chore: Stop exporting schema for every batch in CometBatchIterator #1116

Closed · wants to merge 8 commits
52 changes: 41 additions & 11 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -88,10 +88,44 @@ class NativeUtil {
* an exported batches object containing an array containing number of rows + pairs of memory
* addresses in the format of (address of Arrow array, address of Arrow schema)
*/
def exportBatch(
arrayAddrs: Array[Long],
schemaAddrs: Array[Long],
batch: ColumnarBatch): Int = {
def exportSchema(schemaAddrs: Array[Long], batch: ColumnarBatch): Int = {

(0 until batch.numCols()).foreach { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector

val provider = if (valueVector.getField.getDictionary != null) {
a.getDictionaryProvider
} else {
null
}

// The schema structures are allocated by the native side.
// Don't need to deallocate them here.
val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
val export = getFieldVector(valueVector, "export")
Data.exportField(allocator, export.getField, provider, arrowSchema)
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}
0
}

/**
* Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
* native execution.
*
* @param arrayAddrs
* the addresses of the ArrowArray structures to export the batch into
* @param batch
* the input Comet columnar batch
* @return
* the number of rows in the exported batch
def exportBatch(arrayAddrs: Array[Long], batch: ColumnarBatch): Int = {
val numRows = mutable.ArrayBuffer.empty[Int]

(0 until batch.numCols()).foreach { index =>
@@ -109,14 +143,10 @@ class NativeUtil {

// The array and schema structures are allocated by native side.
// Don't need to deallocate them here.
val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
val arrowArray = ArrowArray.wrap(arrayAddrs(index))
Data.exportVector(
allocator,
getFieldVector(valueVector, "export"),
provider,
arrowArray,
arrowSchema)
val export = getFieldVector(valueVector, "export")
// export array
Data.exportVector(allocator, export, provider, arrowArray)
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
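Net effect in this file: `Data.exportField` now runs once per column when the stream starts, while the per-batch path calls `Data.exportVector` with only the array payload. A minimal arrow-rs sketch of the same split (hypothetical column name and values, not code from this PR):

```rust
use arrow::array::{Array, Int32Array};
use arrow::datatypes::{DataType, Field};
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};

fn main() -> Result<(), arrow::error::ArrowError> {
    // The schema is exported once per column, before any batch is consumed...
    let field = Field::new("c0", DataType::Int32, true);
    let ffi_schema = FFI_ArrowSchema::try_from(&field)?;

    // ...while each batch exports only the array payload against it.
    for chunk in [vec![1, 2, 3], vec![4, 5]] {
        let data = Int32Array::from(chunk).to_data();
        let ffi_array = FFI_ArrowArray::new(&data);
        // (&ffi_array, &ffi_schema) would cross the C Data Interface here.
        let _ = (&ffi_array, &ffi_schema);
    }
    Ok(())
}
```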
102 changes: 76 additions & 26 deletions native/core/src/execution/operators/scan.rs
@@ -62,6 +62,7 @@ pub struct ScanExec {
pub exec_context_id: i64,
/// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object.
pub input_source: Option<Arc<GlobalRef>>,
schema_addrs: Vec<i64>,
/// A description of the input source for informational purposes
pub input_source_description: String,
/// The data types of columns of the input batch. Converted from Spark schema.
@@ -95,14 +96,26 @@ impl ScanExec {
// ScanExec will cast arrays from all future batches to the type determined here, so we
// may end up either unpacking dictionary arrays or dictionary-encoding arrays.
// Dictionary-encoded primitive arrays are always unpacked.
let first_batch = if let Some(input_source) = input_source.as_ref() {
let (first_batch, schema_addrs) = if let Some(input_source) = input_source.as_ref() {
let mut timer = baseline_metrics.elapsed_compute().timer();
let batch =
ScanExec::get_next(exec_context_id, input_source.as_obj(), data_types.len())?;
timer.stop();
batch
if let Some(schema_addrs) =
ScanExec::get_schema(exec_context_id, input_source.as_obj(), data_types.len())?
{
let batch = ScanExec::get_next(
exec_context_id,
input_source.as_obj(),
schema_addrs.clone(),
data_types.len(),
)?;
timer.stop();
(batch, schema_addrs)
} else {
(InputBatch::EOF, vec![])
}
} else {
InputBatch::EOF
(InputBatch::EOF, vec![])
};

let schema = scan_schema(&first_batch, &data_types);
@@ -119,6 +132,7 @@ impl ScanExec {
exec_context_id,
input_source,
input_source_description: input_source_description.to_string(),
schema_addrs,
data_types,
batch: Arc::new(Mutex::new(Some(first_batch))),
cache,
Expand Down Expand Up @@ -167,9 +181,11 @@ impl ScanExec {

let mut current_batch = self.batch.try_lock().unwrap();
if current_batch.is_none() {
let iter = self.input_source.as_ref().unwrap().as_obj();
let next_batch = ScanExec::get_next(
self.exec_context_id,
self.input_source.as_ref().unwrap().as_obj(),
iter,
self.schema_addrs.clone(),
self.data_types.len(),
)?;
*current_batch = Some(next_batch);
@@ -180,10 +196,57 @@ impl ScanExec {
Ok(())
}

/// Invokes JNI call to get the schema.
fn get_schema(
exec_context_id: i64,
iter: &JObject,
num_cols: usize,
) -> Result<Option<Vec<i64>>, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
// This is a unit test. We don't need to call JNI.
return Ok(None);
}

if iter.is_null() {
return Err(CometError::from(ExecutionError::GeneralError(format!(
"Null batch iterator object. Plan id: {}",
exec_context_id
))));
}

let mut env = JVMClasses::get_env()?;

let mut schema_addrs = Vec::with_capacity(num_cols);

for _ in 0..num_cols {
let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
let schema_ptr = Rc::into_raw(arrow_schema) as i64;
schema_addrs.push(schema_ptr);
}

// export schema
let long_schema_addrs = env.new_long_array(num_cols as jsize)?;
env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;
let schema_obj = JObject::from(long_schema_addrs);
let schema_obj = JValueGen::Object(schema_obj.as_ref());

let ret: i32 = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).export_schema(schema_obj) -> i32)?
};

if ret == -1 {
return Ok(None);
}

Ok(Some(schema_addrs))
}

/// Invokes JNI call to get next batch.
fn get_next(
exec_context_id: i64,
iter: &JObject,
schema_addrs: Vec<i64>,
num_cols: usize,
) -> Result<InputBatch, CometError> {
if exec_context_id == TEST_EXEC_CONTEXT_ID {
@@ -201,36 +264,21 @@ impl ScanExec {
let mut env = JVMClasses::get_env()?;

let mut array_addrs = Vec::with_capacity(num_cols);
let mut schema_addrs = Vec::with_capacity(num_cols);

for _ in 0..num_cols {
let arrow_array = Rc::new(FFI_ArrowArray::empty());
let arrow_schema = Rc::new(FFI_ArrowSchema::empty());
let (array_ptr, schema_ptr) = (
Rc::into_raw(arrow_array) as i64,
Rc::into_raw(arrow_schema) as i64,
);

let array_ptr = Rc::into_raw(arrow_array) as i64;
array_addrs.push(array_ptr);
schema_addrs.push(schema_ptr);
}

// Prepare the java array parameters
// export data
let long_array_addrs = env.new_long_array(num_cols as jsize)?;
let long_schema_addrs = env.new_long_array(num_cols as jsize)?;

env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?;
env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?;

let array_obj = JObject::from(long_array_addrs);
let schema_obj = JObject::from(long_schema_addrs);

let array_obj = JValueGen::Object(array_obj.as_ref());
let schema_obj = JValueGen::Object(schema_obj.as_ref());

let num_rows: i32 = unsafe {
jni_call!(&mut env,
comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)?
comet_batch_iterator(iter).next(array_obj) -> i32)?
};

if num_rows == -1 {
@@ -251,7 +299,9 @@
// Drop the Arcs to avoid memory leak
unsafe {
Rc::from_raw(array_ptr as *const FFI_ArrowArray);
Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);

// TODO: the schema structs still need to be dropped eventually
// Rc::from_raw(schema_ptr as *const FFI_ArrowSchema);
}
}

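One thing worth noting for reviewers: `get_schema` hands each `FFI_ArrowSchema` to the JVM via `Rc::into_raw`, which deliberately leaks it so the address stays valid across every `next()` call; the commented-out `Rc::from_raw` is the reclaim the TODO still owes. A self-contained sketch of that round trip, assuming the cleanup eventually lands somewhere like a `Drop` impl:

```rust
use std::rc::Rc;
use arrow::ffi::FFI_ArrowSchema;

fn main() {
    // into_raw transfers ownership to a raw address that can cross JNI calls.
    let schema_ptr = Rc::into_raw(Rc::new(FFI_ArrowSchema::empty())) as i64;

    // ... the same address is reused by every subsequent batch import ...

    // from_raw must run exactly once, after the last batch, or the struct leaks.
    unsafe {
        drop(Rc::from_raw(schema_ptr as *const FFI_ArrowSchema));
    }
}
```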
3 changes: 1 addition & 2 deletions native/core/src/execution/utils.rs
@@ -77,8 +77,7 @@ impl SparkArrowConvert for ArrayData {
// about memory leak.
let mut ffi_array = unsafe {
let array_data = std::ptr::replace(array_ptr, FFI_ArrowArray::empty());
let schema_data = std::ptr::replace(schema_ptr, FFI_ArrowSchema::empty());

let schema_data: &FFI_ArrowSchema = &*schema_ptr;
from_ffi(array_data, &schema_data)?
};

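Because the schema struct now outlives any single batch, the import side borrows it rather than taking ownership, while the array struct is still swapped out for an empty placeholder so the native side is not double-freed. A sketch of that asymmetry as a hypothetical standalone helper:

```rust
use arrow::array::ArrayData;
use arrow::error::Result;
use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema};

/// Take the array by value (replacing it with an empty struct so it is not
/// freed twice), but only borrow the schema, which must stay valid for
/// later batches.
unsafe fn import_column(
    array_ptr: *mut FFI_ArrowArray,
    schema_ptr: *const FFI_ArrowSchema,
) -> Result<ArrayData> {
    let array_data = std::ptr::replace(array_ptr, FFI_ArrowArray::empty());
    let schema_data: &FFI_ArrowSchema = &*schema_ptr;
    from_ffi(array_data, schema_data)
}
```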
6 changes: 5 additions & 1 deletion native/core/src/jvm_bridge/batch_iterator.rs
@@ -26,6 +26,8 @@ use jni::{
/// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class.
pub struct CometBatchIterator<'a> {
pub class: JClass<'a>,
pub method_export_schema: JMethodID,
pub method_export_schema_ret: ReturnType,
pub method_next: JMethodID,
pub method_next_ret: ReturnType,
}
@@ -38,7 +40,9 @@ impl<'a> CometBatchIterator<'a> {

Ok(CometBatchIterator {
class,
method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J[J)I")?,
method_export_schema: env.get_method_id(Self::JVM_CLASS, "exportSchema", "([J)I")?,
method_export_schema_ret: ReturnType::Primitive(Primitive::Int),
method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J)I")?,
method_next_ret: ReturnType::Primitive(Primitive::Int),
})
}
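For reference, the JNI descriptors encode the new shapes: `([J)I` is a method taking a single `long[]` and returning `int`, replacing the old two-array `([J[J)I` form of `next`. A sketch of invoking the new method through the checked (slower) JNI API; the cached-`JMethodID` registration above is the fast path used by the `jni_call!` macro:

```rust
use jni::errors::Result;
use jni::objects::{JObject, JValue};
use jni::JNIEnv;

/// Sketch: call CometBatchIterator.exportSchema, descriptor "([J)I".
fn export_schema(env: &mut JNIEnv, iter: &JObject, schema_addrs: &[i64]) -> Result<i32> {
    // Copy the native schema addresses into a fresh Java long[].
    let addrs = env.new_long_array(schema_addrs.len() as i32)?;
    env.set_long_array_region(&addrs, 0, schema_addrs)?;
    let addrs_obj = JObject::from(addrs);

    // One long[] argument, int return.
    env.call_method(iter, "exportSchema", "([J)I", &[JValue::Object(&addrs_obj)])?
        .i()
}
```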
36 changes: 28 additions & 8 deletions spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -33,26 +33,46 @@
public class CometBatchIterator {
final Iterator<ColumnarBatch> input;
final NativeUtil nativeUtil;
private ColumnarBatch currentBatch = null;

CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
this.input = input;
this.nativeUtil = nativeUtil;
}

/**
* Get the schema of the Arrow arrays.
*
* @param schemaAddrs The addresses of the ArrowSchema structures.
* @return 0 if a batch was available to export the schema from, -1 if there are no more batches.
*/
public int exportSchema(long[] schemaAddrs) {
if (currentBatch == null) {
if (input.hasNext()) {
currentBatch = input.next();
} else {
return -1;
}
}
nativeUtil.exportSchema(schemaAddrs, currentBatch);
return 0;
}

/**
* Get the next batch of Arrow arrays.
*
* @param arrayAddrs The addresses of the ArrowArray structures.
* @param schemaAddrs The addresses of the ArrowSchema structures.
* @return the number of rows in the current batch, or -1 if there are no more batches.
*/
public int next(long[] arrayAddrs, long[] schemaAddrs) {
boolean hasBatch = input.hasNext();

if (!hasBatch) {
return -1;
public int next(long[] arrayAddrs) {
if (currentBatch == null) {
if (input.hasNext()) {
currentBatch = input.next();
} else {
return -1;
}
}

return nativeUtil.exportBatch(arrayAddrs, schemaAddrs, input.next());
int rows = nativeUtil.exportBatch(arrayAddrs, currentBatch);
currentBatch = null;
return rows;
}
}
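The `currentBatch` buffering is the subtle part of this class: `exportSchema` may have to pull the first batch from the upstream iterator before `next` is ever called, so that batch is pinned and served by the following `next` instead of being dropped. A tiny Rust model of the same state machine (integers stand in for `ColumnarBatch`, purely illustrative):

```rust
struct SketchIter<I: Iterator<Item = i32>> {
    input: I,
    current: Option<i32>, // stands in for the pinned ColumnarBatch
}

impl<I: Iterator<Item = i32>> SketchIter<I> {
    // Pin a batch if none is pinned yet; report whether one is available.
    fn pull(&mut self) -> bool {
        if self.current.is_none() {
            self.current = self.input.next();
        }
        self.current.is_some()
    }

    // May pin the first batch early; -1 signals exhaustion, like the Java code.
    fn export_schema(&mut self) -> i32 {
        if self.pull() { 0 } else { -1 }
    }

    // Serves and releases the pinned batch, or pulls a fresh one.
    fn next_batch(&mut self) -> i32 {
        if self.pull() { self.current.take().unwrap() } else { -1 }
    }
}

fn main() {
    let mut it = SketchIter { input: vec![10, 20].into_iter(), current: None };
    assert_eq!(it.export_schema(), 0); // pins batch "10"
    assert_eq!(it.next_batch(), 10); // serves the pinned batch, then releases it
    assert_eq!(it.next_batch(), 20); // pulls a fresh batch
    assert_eq!(it.next_batch(), -1); // exhausted
}
```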