Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RecordBatch normalization (flattening) #6758

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
220 changes: 219 additions & 1 deletion arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
//! A two-dimensional batch of column-oriented data with a defined
//! [schema](arrow_schema::Schema).

use crate::cast::AsArray;
use crate::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
use std::collections::VecDeque;
use std::ops::Index;
use std::sync::Arc;

Expand Down Expand Up @@ -394,6 +396,56 @@ impl RecordBatch {
)
}

/// Normalize a semi-structured [`RecordBatch`] into a flat table.
///
/// If max_level is 0, normalizes all levels.
pub fn normalize(&self, separator: &str, mut max_level: usize) -> Result<Self, ArrowError> {
if max_level == 0 {
max_level = usize::MAX;
}
if self.num_rows() == 0 {
// No data, only need to normalize the schema
return Ok(Self::new_empty(Arc::new(
self.schema.normalize(separator, max_level)?,
)));
}
let mut queue: VecDeque<(usize, (ArrayRef, FieldRef))> = VecDeque::new();

for (c, f) in self.columns.iter().zip(self.schema.fields()) {
queue.push_back((0, ((*c).clone(), (*f).clone())));
}

let mut columns: Vec<ArrayRef> = Vec::new();
let mut fields: Vec<FieldRef> = Vec::new();

while let Some((depth, (c, f))) = queue.pop_front() {
if depth < max_level {
match f.data_type() {
DataType::Struct(ff) => {
// Need to zip these in reverse to maintain original order
for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
let new_key = format!("{}{}{}", f.name(), separator, fff.name());
let updated_field = Field::new(
new_key.as_str(),
fff.data_type().clone(),
fff.is_nullable(),
);
queue.push_front((depth + 1, (cff.clone(), Arc::new(updated_field))))
}
}
_ => {
columns.push(c);
fields.push(f);
}
}
} else {
columns.push(c);
fields.push(f);
}
}
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
}

/// Returns the number of columns in the record batch.
///
/// # Example
Expand Down Expand Up @@ -1197,6 +1249,172 @@ mod tests {
assert_ne!(batch1, batch2);
}

#[test]
fn normalize() {
let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
let year_field = Arc::new(Field::new("year", DataType::Int64, true));

let a = Arc::new(StructArray::from(vec![
(animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
(n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
(year_field.clone(), Arc::new(year.clone()) as ArrayRef),
]));

let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

let schema = Schema::new(vec![
Field::new(
"a",
DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
false,
),
Field::new("month", DataType::Int64, true),
]);

let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
.expect("valid conversion")
.normalize(".", 0)
.expect("valid normalization");

let expected = RecordBatch::try_from_iter_with_nullable(vec![
("a.animals", animals.clone(), true),
("a.n_legs", n_legs.clone(), true),
("a.year", year.clone(), true),
("month", month.clone(), true),
])
.expect("valid conversion");

assert_eq!(expected, normalized);
}

#[test]
fn normalize_nested() {
// Initialize schema
let a = Arc::new(Field::new("a", DataType::Int64, true));
let b = Arc::new(Field::new("b", DataType::Int64, false));
let c = Arc::new(Field::new("c", DataType::Int64, true));

let one = Arc::new(Field::new(
"1",
DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
false,
));
let two = Arc::new(Field::new(
"2",
DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
true,
));

let exclamation = Arc::new(Field::new(
"!",
DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
false,
));

let schema = Schema::new(vec![exclamation.clone()]);

// Initialize fields
let a_field = Int64Array::from(vec![Some(0), Some(1)]);
let b_field = Int64Array::from(vec![Some(2), Some(3)]);
let c_field = Int64Array::from(vec![None, Some(4)]);

let one_field = StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
]);
let two_field = StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
]);

let exclamation_field = Arc::new(StructArray::from(vec![
(one.clone(), Arc::new(one_field) as ArrayRef),
(two.clone(), Arc::new(two_field) as ArrayRef),
]));

// Normalize top level
let normalized =
RecordBatch::try_new(Arc::new(schema.clone()), vec![exclamation_field.clone()])
.expect("valid conversion")
.normalize(".", 1)
.expect("valid normalization");

let expected = RecordBatch::try_from_iter_with_nullable(vec![
(
"!.1",
Arc::new(StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
])) as ArrayRef,
false,
),
(
"!.2",
Arc::new(StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
])) as ArrayRef,
true,
),
])
.expect("valid conversion");

assert_eq!(expected, normalized);

// Normalize all levels
let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
.expect("valid conversion")
.normalize(".", 0)
.expect("valid normalization");

let expected = RecordBatch::try_from_iter_with_nullable(vec![
("!.1.a", Arc::new(a_field.clone()) as ArrayRef, true),
("!.1.b", Arc::new(b_field.clone()) as ArrayRef, false),
("!.1.c", Arc::new(c_field.clone()) as ArrayRef, true),
("!.2.a", Arc::new(a_field.clone()) as ArrayRef, true),
("!.2.b", Arc::new(b_field.clone()) as ArrayRef, false),
("!.2.c", Arc::new(c_field.clone()) as ArrayRef, true),
])
.expect("valid conversion");

assert_eq!(expected, normalized);
}

#[test]
fn normalize_empty() {
let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
let year_field = Arc::new(Field::new("year", DataType::Int64, true));

let schema = Schema::new(vec![
Field::new(
"a",
DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
false,
),
Field::new("month", DataType::Int64, true),
]);

let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
.normalize(".", 0)
.expect("valid normalization");

let expected = RecordBatch::new_empty(Arc::new(
schema.normalize(".", 0).expect("valid normalization"),
));

assert_eq!(expected, normalized);
}

#[test]
fn project() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
Expand Down
Loading
Loading