Skip to content

Commit

Permalink
Support IS NULL and IS NOT NULL on Unions (apache#11321)
Browse files Browse the repository at this point in the history
* Demonstrate unions can't be null

* add scalar test cases

* support "IS NULL" and "IS NOT NULL" on unions

* formatting

* fix comments from @alamb

* fix docstring
  • Loading branch information
samuelcolvin authored and xinlifoobar committed Jul 18, 2024
1 parent b77de98 commit 2d53ac9
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 9 deletions.
34 changes: 33 additions & 1 deletion datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1459,7 +1459,10 @@ impl ScalarValue {
ScalarValue::DurationMillisecond(v) => v.is_none(),
ScalarValue::DurationMicrosecond(v) => v.is_none(),
ScalarValue::DurationNanosecond(v) => v.is_none(),
ScalarValue::Union(v, _, _) => v.is_none(),
ScalarValue::Union(v, _, _) => match v {
Some((_, s)) => s.is_null(),
None => true,
},
ScalarValue::Dictionary(_, v) => v.is_null(),
}
}
Expand Down Expand Up @@ -6514,4 +6517,33 @@ mod tests {
}
intervals
}

/// Builds the two-variant union layout shared by the scalar union tests:
/// type id 0 is a nullable `Int32` field "A" and type id 1 is a nullable
/// `Float64` field "B".
fn union_fields() -> UnionFields {
    let int_variant = (0, Arc::new(Field::new("A", DataType::Int32, true)));
    let float_variant = (1, Arc::new(Field::new("B", DataType::Float64, true)));
    vec![int_variant, float_variant].into_iter().collect()
}

/// A sparse union scalar whose selected variant ("A", type id 0) holds a null
/// `Int32` must itself report `is_null()`.
#[test]
fn sparse_scalar_union_is_null() {
    // The union is populated (`Some`), but the value inside the chosen variant is null.
    let null_int = Box::new(ScalarValue::Int32(None));
    let selected = Some((0_i8, null_int));
    let scalar = ScalarValue::Union(selected, union_fields(), UnionMode::Sparse);
    assert!(scalar.is_null());
}

/// A dense union scalar whose selected variant ("A", type id 0) holds a null
/// `Int32` must itself report `is_null()`.
#[test]
fn dense_scalar_union_is_null() {
    // The union is populated (`Some`), but the value inside the chosen variant is null.
    let null_int = Box::new(ScalarValue::Int32(None));
    let selected = Some((0_i8, null_int));
    let scalar = ScalarValue::Union(selected, union_fields(), UnionMode::Dense);
    assert!(scalar.is_null());
}
}
165 changes: 163 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ use arrow::{
},
record_batch::RecordBatch,
};
use arrow_array::Float32Array;
use arrow_schema::ArrowError;
use arrow_array::{Array, Float32Array, Float64Array, UnionArray};
use arrow_buffer::ScalarBuffer;
use arrow_schema::{ArrowError, UnionFields, UnionMode};
use datafusion_functions_aggregate::count::count_udaf;
use object_store::local::LocalFileSystem;
use std::fs;
Expand Down Expand Up @@ -2195,3 +2196,163 @@ async fn write_parquet_results() -> Result<()> {

Ok(())
}

/// Builds the three-variant union layout used by the union `IS NULL` dataframe
/// tests: type id 0 is a nullable `Int32` field "A", type id 1 a nullable
/// `Float64` field "B", and type id 2 a nullable `Utf8` field "C".
fn union_fields() -> UnionFields {
    let int_variant = (0, Arc::new(Field::new("A", DataType::Int32, true)));
    let float_variant = (1, Arc::new(Field::new("B", DataType::Float64, true)));
    let str_variant = (2, Arc::new(Field::new("C", DataType::Utf8, true)));
    vec![int_variant, float_variant, str_variant]
        .into_iter()
        .collect()
}

/// Checks that `IS NULL` / `IS NOT NULL` filters on a *sparse* union column
/// look at the value of the selected child: a row counts as null when the
/// child picked by its type id is null at that row.
#[tokio::test]
async fn sparse_union_is_null() {
    // union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}]
    // In sparse mode every child array has one slot per row (6 here); the
    // type id of a row selects which child's slot is the row's value.
    let int_array = Int32Array::from(vec![Some(1), None, None, None, None, None]);
    let float_array = Float64Array::from(vec![None, None, Some(3.2), None, None, None]);
    let str_array = StringArray::from(vec![None, None, None, None, Some("a"), None]);
    let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::<ScalarBuffer<i8>>();

    let children = vec![
        Arc::new(int_array) as Arc<dyn Array>,
        Arc::new(float_array),
        Arc::new(str_array),
    ];

    // `None` offsets => sparse union.
    let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

    let field = Field::new(
        "my_union",
        DataType::Union(union_fields(), UnionMode::Sparse),
        true,
    );
    let schema = Arc::new(Schema::new(vec![field]));

    let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();

    let ctx = SessionContext::new();

    ctx.register_batch("union_batch", batch).unwrap();

    let df = ctx.table("union_batch").await.unwrap();

    // view_all
    // NOTE: the table printer pads every cell to the column width (set here by
    // the `my_union` header), so each expected row must be padded to match.
    let expected = [
        "+----------+",
        "| my_union |",
        "+----------+",
        "| {A=1}    |",
        "| {A=}     |",
        "| {B=3.2}  |",
        "| {B=}     |",
        "| {C=a}    |",
        "| {C=}     |",
        "+----------+",
    ];
    assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap());

    // filter where is null
    let result_df = df.clone().filter(col("my_union").is_null()).unwrap();
    let expected = [
        "+----------+",
        "| my_union |",
        "+----------+",
        "| {A=}     |",
        "| {B=}     |",
        "| {C=}     |",
        "+----------+",
    ];
    assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());

    // filter where is not null
    let result_df = df.filter(col("my_union").is_not_null()).unwrap();
    let expected = [
        "+----------+",
        "| my_union |",
        "+----------+",
        "| {A=1}    |",
        "| {B=3.2}  |",
        "| {C=a}    |",
        "+----------+",
    ];
    assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
}

/// Checks that `IS NULL` / `IS NOT NULL` filters on a *dense* union column
/// look at the value of the selected child: a row counts as null when the
/// child slot addressed by its type id and offset is null.
#[tokio::test]
async fn dense_union_is_null() {
    // union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}]
    // In dense mode each child stores only its own values (2 per child here);
    // a row's type id picks the child and its offset picks the slot in it.
    let int_array = Int32Array::from(vec![Some(1), None]);
    let float_array = Float64Array::from(vec![Some(3.2), None]);
    let str_array = StringArray::from(vec![Some("a"), None]);
    let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::<ScalarBuffer<i8>>();
    let offsets = [0, 1, 0, 1, 0, 1]
        .into_iter()
        .collect::<ScalarBuffer<i32>>();

    let children = vec![
        Arc::new(int_array) as Arc<dyn Array>,
        Arc::new(float_array),
        Arc::new(str_array),
    ];

    // `Some(offsets)` => dense union.
    let array =
        UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();

    let field = Field::new(
        "my_union",
        DataType::Union(union_fields(), UnionMode::Dense),
        true,
    );
    let schema = Arc::new(Schema::new(vec![field]));

    let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap();

    let ctx = SessionContext::new();

    ctx.register_batch("union_batch", batch).unwrap();

    let df = ctx.table("union_batch").await.unwrap();

    // view_all
    // NOTE: the table printer pads every cell to the column width (set here by
    // the `my_union` header), so each expected row must be padded to match.
    let expected = [
        "+----------+",
        "| my_union |",
        "+----------+",
        "| {A=1}    |",
        "| {A=}     |",
        "| {B=3.2}  |",
        "| {B=}     |",
        "| {C=a}    |",
        "| {C=}     |",
        "+----------+",
    ];
    assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap());

    // filter where is null
    let result_df = df.clone().filter(col("my_union").is_null()).unwrap();
    let expected = [
        "+----------+",
        "| my_union |",
        "+----------+",
        "| {A=}     |",
        "| {B=}     |",
        "| {C=}     |",
        "+----------+",
    ];
    assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());

    // filter where is not null
    let result_df = df.filter(col("my_union").is_not_null()).unwrap();
    let expected = [
        "+----------+",
        "| my_union |",
        "+----------+",
        "| {A=1}    |",
        "| {B=3.2}  |",
        "| {C=a}    |",
        "+----------+",
    ];
    assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
}
54 changes: 51 additions & 3 deletions datafusion/physical-expr/src/expressions/is_not_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ impl PhysicalExpr for IsNotNullExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let arg = self.arg.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new(
compute::is_not_null(array.as_ref())?,
))),
ColumnarValue::Array(array) => {
let is_null = super::is_null::compute_is_null(array)?;
let is_not_null = compute::not(&is_null)?;
Ok(ColumnarValue::Array(Arc::new(is_not_null)))
}
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
ScalarValue::Boolean(Some(!scalar.is_null())),
)),
Expand Down Expand Up @@ -120,6 +122,8 @@ mod tests {
array::{BooleanArray, StringArray},
datatypes::*,
};
use arrow_array::{Array, Float64Array, Int32Array, UnionArray};
use arrow_buffer::ScalarBuffer;
use datafusion_common::cast::as_boolean_array;

#[test]
Expand All @@ -143,4 +147,48 @@ mod tests {

Ok(())
}

/// Evaluates `my_union IS NOT NULL` over a sparse union column and checks that
/// a row is reported as not-null exactly when its selected child value is set.
#[test]
fn union_is_not_null_op() {
    // Sparse union with logical rows [{A=1}, {A=}, {B=1.1}, {B=1.2}, {B=}].
    let fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Int32, true))),
        (1, Arc::new(Field::new("B", DataType::Float64, true))),
    ]
    .into_iter()
    .collect();

    // One slot per row in every child; the type id selects which child counts.
    let ints = Int32Array::from(vec![Some(1), None, None, None, None]);
    let floats = Float64Array::from(vec![None, None, Some(1.1), Some(1.2), None]);
    let child_arrays = vec![Arc::new(ints) as Arc<dyn Array>, Arc::new(floats)];
    let type_ids: ScalarBuffer<i8> = [0, 0, 1, 1, 1].into_iter().collect();

    let union_array =
        UnionArray::try_new(fields.clone(), type_ids, None, child_arrays).unwrap();

    let schema = Schema::new(vec![Field::new(
        "my_union",
        DataType::Union(fields, UnionMode::Sparse),
        true,
    )]);

    // expression: "my_union is not null"
    let expr = is_not_null(col("my_union", &schema).unwrap()).unwrap();
    let batch =
        RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union_array)]).unwrap();

    let evaluated = expr
        .evaluate(&batch)
        .unwrap()
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_boolean_array(&evaluated).unwrap();

    let expected = BooleanArray::from(vec![true, false, true, true, false]);

    assert_eq!(&expected, actual);
}
}
Loading

0 comments on commit 2d53ac9

Please sign in to comment.