diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index 8958ca6fae6..ec2b974fdf9 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -18,8 +18,10 @@ //! A two-dimensional batch of column-oriented data with a defined //! [schema](arrow_schema::Schema). +use crate::cast::AsArray; use crate::{new_empty_array, Array, ArrayRef, StructArray}; -use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef}; +use std::collections::VecDeque; use std::ops::Index; use std::sync::Arc; @@ -394,6 +396,56 @@ impl RecordBatch { ) } + /// Normalize a semi-structured [`RecordBatch`] into a flat table. + /// + /// If max_level is 0, normalizes all levels. + pub fn normalize(&self, separator: &str, mut max_level: usize) -> Result { + if max_level == 0 { + max_level = usize::MAX; + } + if self.num_rows() == 0 { + // No data, only need to normalize the schema + return Ok(Self::new_empty(Arc::new( + self.schema.normalize(separator, max_level)?, + ))); + } + let mut queue: VecDeque<(usize, (ArrayRef, FieldRef))> = VecDeque::new(); + + for (c, f) in self.columns.iter().zip(self.schema.fields()) { + queue.push_back((0, ((*c).clone(), (*f).clone()))); + } + + let mut columns: Vec = Vec::new(); + let mut fields: Vec = Vec::new(); + + while let Some((depth, (c, f))) = queue.pop_front() { + if depth < max_level { + match f.data_type() { + DataType::Struct(ff) => { + // Need to zip these in reverse to maintain original order + for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() { + let new_key = format!("{}{}{}", f.name(), separator, fff.name()); + let updated_field = Field::new( + new_key.as_str(), + fff.data_type().clone(), + fff.is_nullable(), + ); + queue.push_front((depth + 1, (cff.clone(), Arc::new(updated_field)))) + } + } + _ => { + columns.push(c); + fields.push(f); + } + } + } else { + columns.push(c); + fields.push(f); + } + } + RecordBatch::try_new(Arc::new(Schema::new(fields)), columns) + } + /// Returns the number of columns in the record batch. /// /// # Example @@ -1197,6 +1249,172 @@ mod tests { assert_ne!(batch1, batch2); } + #[test] + fn normalize() { + let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""])); + let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)])); + let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)])); + + let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true)); + let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true)); + let year_field = Arc::new(Field::new("year", DataType::Int64, true)); + + let a = Arc::new(StructArray::from(vec![ + (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef), + (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef), + (year_field.clone(), Arc::new(year.clone()) as ArrayRef), + ])); + + let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)])); + + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])), + false, + ), + Field::new("month", DataType::Int64, true), + ]); + + let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()]) + .expect("valid conversion") + .normalize(".", 0) + .expect("valid normalization"); + + let expected = RecordBatch::try_from_iter_with_nullable(vec![ + ("a.animals", animals.clone(), true), + ("a.n_legs", n_legs.clone(), true), + ("a.year", year.clone(), true), + ("month", month.clone(), true), + ]) + .expect("valid conversion"); + + assert_eq!(expected, normalized); + } + + #[test] + fn normalize_nested() { + // Initialize schema + let a = Arc::new(Field::new("a", DataType::Int64, true)); + let b = Arc::new(Field::new("b", DataType::Int64, false)); + let c = Arc::new(Field::new("c", DataType::Int64, true)); + + let one = Arc::new(Field::new( + "1", + DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])), + false, + )); + let two = Arc::new(Field::new( + "2", + DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])), + true, + )); + + let exclamation = Arc::new(Field::new( + "!", + DataType::Struct(Fields::from(vec![one.clone(), two.clone()])), + false, + )); + + let schema = Schema::new(vec![exclamation.clone()]); + + // Initialize fields + let a_field = Int64Array::from(vec![Some(0), Some(1)]); + let b_field = Int64Array::from(vec![Some(2), Some(3)]); + let c_field = Int64Array::from(vec![None, Some(4)]); + + let one_field = StructArray::from(vec![ + (a.clone(), Arc::new(a_field.clone()) as ArrayRef), + (b.clone(), Arc::new(b_field.clone()) as ArrayRef), + (c.clone(), Arc::new(c_field.clone()) as ArrayRef), + ]); + let two_field = StructArray::from(vec![ + (a.clone(), Arc::new(a_field.clone()) as ArrayRef), + (b.clone(), Arc::new(b_field.clone()) as ArrayRef), + (c.clone(), Arc::new(c_field.clone()) as ArrayRef), + ]); + + let exclamation_field = Arc::new(StructArray::from(vec![ + (one.clone(), Arc::new(one_field) as ArrayRef), + (two.clone(), Arc::new(two_field) as ArrayRef), + ])); + + // Normalize top level + let normalized = + RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()]) + .expect("valid conversion") + .normalize(".", 1) + .expect("valid normalization"); + + let expected = RecordBatch::try_from_iter_with_nullable(vec![ + ( + "!.1", + Arc::new(StructArray::from(vec![ + (a.clone(), Arc::new(a_field.clone()) as ArrayRef), + (b.clone(), Arc::new(b_field.clone()) as ArrayRef), + (c.clone(), Arc::new(c_field.clone()) as ArrayRef), + ])) as ArrayRef, + false, + ), + ( + "!.2", + Arc::new(StructArray::from(vec![ + (a.clone(), Arc::new(a_field.clone()) as ArrayRef), + (b.clone(), Arc::new(b_field.clone()) as ArrayRef), + (c.clone(), Arc::new(c_field.clone()) as ArrayRef), + ])) as ArrayRef, + true, + ), + ]) + .expect("valid conversion"); + + assert_eq!(expected, normalized); + + // Normalize all levels + let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field]) + .expect("valid conversion") + .normalize(".", 0) + .expect("valid normalization"); + + let expected = RecordBatch::try_from_iter_with_nullable(vec![ + ("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true), + ("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false), + ("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true), + ("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true), + ("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false), + ("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true), + ]) + .expect("valid conversion"); + + assert_eq!(expected, normalized); + } + + #[test] + fn normalize_empty() { + let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true)); + let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true)); + let year_field = Arc::new(Field::new("year", DataType::Int64, true)); + + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])), + false, + ), + Field::new("month", DataType::Int64, true), + ]); + + let normalized = RecordBatch::new_empty(Arc::new(schema.clone())) + .normalize(".", 0) + .expect("valid normalization"); + + let expected = RecordBatch::new_empty(Arc::new( + schema.normalize(".", 0).expect("valid normalization"), + )); + + assert_eq!(expected, normalized); + } + #[test] fn project() { let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs index c5c22b52713..ca0532cecdc 100644 --- a/arrow-schema/src/schema.rs +++ b/arrow-schema/src/schema.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::error::ArrowError; use crate::field::Field; -use crate::{FieldRef, Fields}; +use crate::{DataType, FieldRef, Fields}; /// A builder to facilitate building a [`Schema`] from iteratively from [`FieldRef`] #[derive(Debug, Default)] @@ -413,6 +413,81 @@ impl Schema { &self.metadata } + /// Returns a new schema, normalized based on the max_level + /// This carries metadata from the parent schema over as well + pub fn normalize(&self, separator: &str, mut max_level: usize) -> Result { + if max_level == 0 { + max_level = usize::MAX; + } + let mut new_fields: Vec = vec![]; + for field in self.fields() { + match field.data_type() { + DataType::Struct(nested_fields) => { + let field_name = field.name().as_str(); + new_fields = [ + new_fields, + Self::normalizer( + nested_fields.to_vec(), + field_name, + separator, + max_level - 1, + ), + ] + .concat(); + } + _ => new_fields.push(Arc::new(Field::new( + field.name(), + field.data_type().clone(), + field.is_nullable(), + ))), + }; + } + Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) + } + + fn normalizer( + fields: Vec, + key_string: &str, + separator: &str, + max_level: usize, + ) -> Vec { + let mut new_fields: Vec = vec![]; + if max_level > 0 { + for field in fields { + match field.data_type() { + DataType::Struct(nested_fields) => { + let field_name = field.name().as_str(); + let new_key = format!("{key_string}{separator}{field_name}"); + new_fields = [ + new_fields, + Self::normalizer( + nested_fields.to_vec(), + new_key.as_str(), + separator, + max_level - 1, + ), + ] + .concat(); + } + _ => new_fields.push(Arc::new(Field::new( + format!("{key_string}{separator}{}", field.name()), + field.data_type().clone(), + field.is_nullable(), + ))), + }; + } + } else { + for field in fields { + new_fields.push(Arc::new(Field::new( + format!("{key_string}{separator}{}", field.name()), + field.data_type().clone(), + field.is_nullable(), + ))); + } + } + new_fields + } + /// Look up a column by name and return a immutable reference to the column along with /// its index. pub fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> { @@ -697,6 +772,87 @@ mod tests { schema.index_of("nickname").unwrap(); } + #[test] + fn normalize() { + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Struct(Fields::from(vec![ + Arc::new(Field::new("animals", DataType::Utf8, true)), + Arc::new(Field::new("n_legs", DataType::Int64, true)), + Arc::new(Field::new("year", DataType::Int64, true)), + ])), + false, + ), + Field::new("month", DataType::Int64, true), + ]) + .normalize(".", 0) + .expect("valid normalization"); + + let expected = Schema::new(vec![ + Field::new("a.animals", DataType::Utf8, true), + Field::new("a.n_legs", DataType::Int64, true), + Field::new("a.year", DataType::Int64, true), + Field::new("month", DataType::Int64, true), + ]); + + assert_eq!(schema, expected); + } + + #[test] + fn normalize_nested() { + let a = Arc::new(Field::new("a", DataType::Utf8, true)); + let b = Arc::new(Field::new("b", DataType::Int64, false)); + let c = Arc::new(Field::new("c", DataType::Int64, true)); + + let d = Arc::new(Field::new("d", DataType::Utf8, true)); + let e = Arc::new(Field::new("e", DataType::Int64, false)); + let f = Arc::new(Field::new("f", DataType::Int64, true)); + + let one = Arc::new(Field::new( + "1", + DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])), + false, + )); + let two = Arc::new(Field::new( + "2", + DataType::Struct(Fields::from(vec![d.clone(), e.clone(), f.clone()])), + true, + )); + + let exclamation = Arc::new(Field::new( + "!", + DataType::Struct(Fields::from(vec![one, two])), + false, + )); + + let normalize_all = Schema::new(vec![exclamation.clone()]) + .normalize(".", 0) + .expect("valid normalization"); + + let expected = Schema::new(vec![ + Field::new("!.1.a", DataType::Utf8, true), + Field::new("!.1.b", DataType::Int64, false), + Field::new("!.1.c", DataType::Int64, true), + Field::new("!.2.d", DataType::Utf8, true), + Field::new("!.2.e", DataType::Int64, false), + Field::new("!.2.f", DataType::Int64, true), + ]); + + assert_eq!(normalize_all, expected); + + let normalize_depth_one = Schema::new(vec![exclamation]) + .normalize(".", 1) + .expect("valid normalization"); + + let expected = Schema::new(vec![ + Field::new("!.1", DataType::Struct(Fields::from(vec![a, b, c])), false), + Field::new("!.2", DataType::Struct(Fields::from(vec![d, e, f])), true), + ]); + + assert_eq!(normalize_depth_one, expected); + } + #[test] #[should_panic( expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"