Merge remote-tracking branch 'upstream/main' into optimize-null-count
kazuyukitanimura committed Nov 4, 2024
2 parents df03371 + ac4223c commit 01fd025
Showing 7 changed files with 70 additions and 13 deletions.
@@ -25,7 +25,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.arrow.c.*;
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CometSchemaImporter;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
@@ -64,6 +64,7 @@ public void readBatch(int total) {
FieldVector fieldVector = Data.importVector(allocator, array, schema, null);
vector = new CometPlainVector(fieldVector, useDecimal128);
}

vector.setNumValues(total);
}

@@ -62,7 +62,17 @@ impl BloomFilterAgg {
assert!(matches!(data_type, DataType::Binary));
Self {
name: name.into(),
signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
signature: Signature::uniform(
1,
vec![
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::Utf8,
],
Volatility::Immutable,
),
expr,
num_items: extract_i32_from_literal(num_items),
num_bits: extract_i32_from_literal(num_bits),
@@ -112,10 +122,25 @@ impl Accumulator for SparkBloomFilter {
(0..arr.len()).try_for_each(|index| {
let v = ScalarValue::try_from_array(arr, index)?;

if let ScalarValue::Int64(Some(value)) = v {
self.put_long(value);
} else {
unreachable!()
match v {
ScalarValue::Int8(Some(value)) => {
self.put_long(value as i64);
}
ScalarValue::Int16(Some(value)) => {
self.put_long(value as i64);
}
ScalarValue::Int32(Some(value)) => {
self.put_long(value as i64);
}
ScalarValue::Int64(Some(value)) => {
self.put_long(value);
}
ScalarValue::Utf8(Some(value)) => {
self.put_binary(value.as_bytes());
}
_ => {
unreachable!()
}
}
Ok(())
})
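
The match above widens integral inputs to i64 before calling put_long and hashes strings through their UTF-8 bytes via put_binary. Roughly the same semantics, expressed against Spark's own sketch BloomFilter, look like the following Scala sketch (sizes are illustrative, not Comet's code):

import java.nio.charset.StandardCharsets
import org.apache.spark.util.sketch.BloomFilter

// Illustrative sizes; the Comet aggregate derives numItems/numBits from its literal arguments.
val filter: BloomFilter = BloomFilter.create(1000000L, 8388608L)

def putValue(v: Any): Unit = v match {
  case b: Byte     => filter.putLong(b.toLong)  // integral types are widened to long
  case s: Short    => filter.putLong(s.toLong)
  case i: Int      => filter.putLong(i.toLong)
  case l: Long     => filter.putLong(l)
  case str: String => filter.putBinary(str.getBytes(StandardCharsets.UTF_8))
  case other       => throw new IllegalArgumentException(s"unsupported input: $other")
}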
17 changes: 17 additions & 0 deletions native/core/src/execution/datafusion/util/spark_bloom_filter.rs
@@ -115,6 +115,23 @@ impl SparkBloomFilter {
bit_changed
}

pub fn put_binary(&mut self, item: &[u8]) -> bool {
// Here we first hash the input bytes into 2 int hash values, h1 and h2, then produce
// n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
let h1 = spark_compatible_murmur3_hash(item, 0);
let h2 = spark_compatible_murmur3_hash(item, h1);
let bit_size = self.bits.bit_size() as i32;
let mut bit_changed = false;
for i in 1..=self.num_hash_functions {
let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
if combined_hash < 0 {
combined_hash = !combined_hash;
}
bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
}
bit_changed
}

pub fn might_contain_long(&self, item: i64) -> bool {
let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
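The put_binary added above follows the same Spark-compatible double hashing as put_long: two murmur3 hashes h1 and h2 are combined as h1 + i * h2 for i = 1..num_hash_functions, negative results are bit-flipped, and each result is reduced modulo the bit size to pick a bit to set. A minimal Scala sketch of that index derivation, where murmur3 and setBit are hypothetical stand-ins for the hash function and the underlying bit array:

// Sketch of the index derivation only; `murmur3` and `setBit` are hypothetical stand-ins.
def putItem(item: Array[Byte], numHashFunctions: Int, bitSize: Int,
    murmur3: (Array[Byte], Int) => Int, setBit: Int => Boolean): Boolean = {
  val h1 = murmur3(item, 0)
  val h2 = murmur3(item, h1)
  var bitChanged = false
  for (i <- 1 to numHashFunctions) {
    var combined = h1 + i * h2             // Int arithmetic wraps on overflow, like the wrapping ops above
    if (combined < 0) combined = ~combined // flip the bits of negative hashes
    bitChanged |= setBit(combined % bitSize)
  }
  bitChanged
}
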
14 changes: 10 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -769,11 +769,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val numBitsExpr = exprToProto(numBits, inputs, binding)
val dataType = serializeDataType(bloom_filter.dataType)

// TODO: Support more types
// https://github.com/apache/datafusion-comet/issues/1023
if (childExpr.isDefined &&
child.dataType
.isInstanceOf[LongType] &&
(child.dataType
.isInstanceOf[ByteType] ||
child.dataType
.isInstanceOf[ShortType] ||
child.dataType
.isInstanceOf[IntegerType] ||
child.dataType
.isInstanceOf[LongType] ||
child.dataType
.isInstanceOf[StringType]) &&
numItemsExpr.isDefined &&
numBitsExpr.isDefined &&
dataType.isDefined) {
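
The chained isInstanceOf checks above spell out the types the native bloom_filter_agg now accepts (byte, short, int, long, string). A more compact, equivalent form would test membership in a supported-type set; a sketch only, not the committed code:

import org.apache.spark.sql.types._

val supportedBloomFilterTypes: Seq[DataType] =
  Seq(ByteType, ShortType, IntegerType, LongType, StringType)
val childTypeSupported = supportedBloomFilterTypes.contains(child.dataType)
// if (childExpr.isDefined && childTypeSupported && numItemsExpr.isDefined && ...)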
4 changes: 3 additions & 1 deletion spark/src/main/scala/org/apache/spark/Plugins.scala
@@ -42,6 +42,8 @@ import org.apache.comet.{CometConf, CometSparkSessionExtensions}
* To enable this plugin, set the config "spark.plugins" to `org.apache.spark.CometPlugin`.
*/
class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPlugin {
private val EXECUTOR_MEMORY_DEFAULT = "1g"

override def init(sc: SparkContext, pluginContext: PluginContext): ju.Map[String, String] = {
logInfo("CometDriverPlugin init")
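
As the class comment above notes, the plugin is enabled through spark.plugins. A minimal sketch of setting it for the first session (illustrative only; the config must be in place before the SparkContext starts, e.g. via spark-defaults.conf, spark-submit --conf, or the builder of the first SparkSession):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("comet-example") // hypothetical app name
  .config("spark.plugins", "org.apache.spark.CometPlugin")
  .getOrCreate()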

@@ -53,7 +55,7 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPlugin
sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key)
} else {
// By default, executorMemory * spark.executor.memoryOverheadFactor, with minimum of 384MB
val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key)
val executorMemory = sc.getConf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY_DEFAULT)
val memoryOverheadFactor = getMemoryOverheadFactor(sc.getConf)
val memoryOverheadMinMib = getMemoryOverheadMinMib(sc.getConf)
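
The comment above describes the default overhead formula. As a minimal sketch, assuming the factor and minimum come from the getMemoryOverheadFactor/getMemoryOverheadMinMib helpers shown, the computation is roughly:

// overhead = max(executorMemoryMb * overheadFactor, minOverheadMib); Spark's default floor is 384 MiB.
def defaultOverheadMib(executorMemoryMb: Long, overheadFactor: Double, minOverheadMib: Long = 384L): Long =
  math.max((executorMemoryMb * overheadFactor).toLong, minOverheadMib)

// e.g. defaultOverheadMib(1024, 0.1) == 384, while defaultOverheadMib(8192, 0.1) == 819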

@@ -946,8 +946,12 @@ class CometExecSuite extends CometTestBase {
(0 until 100)
.map(_ => (Random.nextInt(), Random.nextInt() % 5)),
"tbl") {
val df = sql("SELECT bloom_filter_agg(cast(_2 as long)) FROM tbl")
checkSparkAnswerAndOperator(df)

(if (isSpark35Plus) Seq("tinyint", "short", "int", "long", "string") else Seq("long"))
.foreach { input_type =>
val df = sql(f"SELECT bloom_filter_agg(cast(_2 as $input_type)) FROM tbl")
checkSparkAnswerAndOperator(df)
}
}

spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
