diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index e46a92e92818..66896a9cd771 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -29,8 +29,9 @@ use arrow::{
     },
     record_batch::RecordBatch,
 };
-use arrow_array::Float32Array;
-use arrow_schema::ArrowError;
+use arrow_array::{Array, Float32Array, Float64Array, UnionArray};
+use arrow_buffer::ScalarBuffer;
+use arrow_schema::{ArrowError, UnionFields, UnionMode};
 use datafusion_functions_aggregate::count::count_udaf;
 use object_store::local::LocalFileSystem;
 use std::fs;
@@ -2195,3 +2196,116 @@ async fn write_parquet_results() -> Result<()> {
 
     Ok(())
 }
+
+/// Returns the two-variant union layout shared by the union tests below:
+/// type id `0` is `A: Int32` and type id `1` is `B: Float64`.
+fn union_fields() -> UnionFields {
+    [
+        (0, Arc::new(Field::new("A", DataType::Int32, false))),
+        (1, Arc::new(Field::new("B", DataType::Float64, false))),
+    ]
+    .into_iter()
+    .collect()
+}
+
+/// Registers `array` as the only column (`my_union`) of the `union_batch` table,
+/// then checks the unfiltered output and the result of an `IS NULL` filter.
+///
+/// `mode` is used for the column's declared type, so it must be the layout that
+/// `array` was built with.
+async fn assert_union_is_null(array: UnionArray, mode: UnionMode) -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![Field::new(
+        "my_union",
+        DataType::Union(union_fields(), mode),
+        false,
+    )]));
+    let batch = RecordBatch::try_new(schema, vec![Arc::new(array)])?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("union_batch", batch)?;
+    let df = ctx.table("union_batch").await?;
+
+    // view_all: all four rows are present (the macro sorts before comparing)
+    let expected = [
+        "+----------+",
+        "| my_union |",
+        "+----------+",
+        "| {A=1}    |",
+        "| {A=34}   |",
+        "| {A=}     |",
+        "| {B=3.2}  |",
+        "+----------+",
+    ];
+    assert_batches_sorted_eq!(expected, &df.clone().collect().await?);
+
+    // filter where is null: only the row whose selected child value is null remains
+    let result_df = df.filter(col("my_union").is_null())?;
+    let expected = [
+        "+----------+",
+        "| my_union |",
+        "+----------+",
+        "| {A=}     |",
+        "+----------+",
+    ];
+    assert_batches_sorted_eq!(expected, &result_df.collect().await?);
+
+    Ok(())
+}
+
+/// `IS NULL` on a sparse union column keeps the row whose selected child is null.
+#[tokio::test]
+async fn sparse_union_is_null() -> Result<()> {
+    // union of [{A=1}, {A=null}, {B=3.2}, {A=34}]
+    // sparse layout: no offsets, both children span all four rows
+    let int_array = Int32Array::from(vec![Some(1), None, None, Some(34)]);
+    let float_array = Float64Array::from(vec![None, None, Some(3.2), None]);
+    let type_ids = [0, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
+
+    let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
+
+    let array = UnionArray::try_new(union_fields(), type_ids, None, children)?;
+
+    assert_union_is_null(array, UnionMode::Sparse).await
+}
+
+/// `IS NULL` on a dense union column keeps the row whose selected child is null.
+#[tokio::test]
+async fn dense_union_is_null() -> Result<()> {
+    // union of [{A=1}, {A=null}, {B=3.2}, {A=34}]
+    // dense layout: `offsets[i]` is the index of row `i` inside its selected child
+    let int_array = Int32Array::from(vec![Some(1), None, Some(34)]);
+    let float_array = Float64Array::from(vec![3.2]);
+    let type_ids = [0, 0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
+    let offsets = [0, 1, 0, 2].into_iter().collect::<ScalarBuffer<i32>>();
+
+    let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
+
+    let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children)?;
+
+    assert_union_is_null(array, UnionMode::Dense).await
+}
+
+// TODO: move the scalar tests below somewhere more appropriate; they were
+// only added here for simplicity.
+
+/// A sparse union scalar wrapping a null `Int32` value reports `is_null()`.
+#[test]
+fn sparse_union_is_null_scalar() {
+    let sparse_scalar = ScalarValue::Union(
+        Some((0_i8, Box::new(ScalarValue::Int32(None)))),
+        union_fields(),
+        UnionMode::Sparse,
+    );
+    assert!(sparse_scalar.is_null());
+}
+
+/// A dense union scalar wrapping a null `Int32` value reports `is_null()`.
+#[test]
+fn dense_union_is_null_scalar() {
+    let dense_scalar = ScalarValue::Union(
+        Some((0_i8, Box::new(ScalarValue::Int32(None)))),
+        union_fields(),
+        UnionMode::Dense,
+    );
+    assert!(dense_scalar.is_null());
+}