diff --git a/README.md b/README.md index c4b0e27..6d8915d 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ without need for dynamic linking of C libraries. | feature | `src_sqlite` | `src_duckdb` | `src_postgres` | `src_mysql` | `src_tiberius` | | dependency | [rusqlite](https://crates.io/crates/rusqlite) | [duckdb](https://crates.io/crates/duckdb) | [postgres](https://crates.io/crates/postgres) | [mysql](https://crates.io/crates/mysql) | [tiberius](https://crates.io/crates/tiberius) | | query | x | x | x | x | x | -| query params | | | x | | | +| query params | x | x | x | | x | | schema get | x | x | x | x | x | | schema edit | x | x | x | x | x | | append | x | x | x | x | x | diff --git a/connector_arrow/src/api.rs b/connector_arrow/src/api.rs index af9756b..4f475f9 100644 --- a/connector_arrow/src/api.rs +++ b/connector_arrow/src/api.rs @@ -45,11 +45,24 @@ pub trait Statement<'conn> { where Self: 'stmt; - /// Start executing. - /// This will create a reader that can retrieve the result schema and data. - fn start<'p, I>(&mut self, params: I) -> Result, ConnectorError> + /// Execute this statement once. + /// Returns a reader that can retrieve the result schema and data. + fn start<'p, I>(&mut self, args: I) -> Result, ConnectorError> where - I: IntoIterator; + I: IntoIterator, + { + let args: Vec<_> = args.into_iter().collect(); + let batch = crate::params::vec_to_record_batch(args)?; + self.start_batch((&batch, 0)) + } + + /// Execute this statement once. + /// Query arguments are read from record batch, from the specified row. + /// Returns a reader that can retrieve the result schema and data. + fn start_batch( + &mut self, + args: (&RecordBatch, usize), + ) -> Result, ConnectorError>; } /// Reads result of the query, starting with the schema. diff --git a/connector_arrow/src/duckdb/mod.rs b/connector_arrow/src/duckdb/mod.rs index 4238018..60f835e 100644 --- a/connector_arrow/src/duckdb/mod.rs +++ b/connector_arrow/src/duckdb/mod.rs @@ -8,11 +8,13 @@ pub use append::DuckDBAppender; use arrow::datatypes::{DataType, TimeUnit}; use arrow::record_batch::RecordBatch; +use itertools::Itertools; use std::sync::Arc; -use crate::api::{ArrowValue, Connector, ResultReader, Statement}; +use crate::api::{Connector, ResultReader, Statement}; use crate::errors::ConnectorError; +use crate::util::{transport, ArrayCellRef}; pub struct DuckDBConnection { inner: duckdb::Connection, @@ -126,11 +128,20 @@ impl<'conn> Statement<'conn> for DuckDBStatement<'conn> { where Self: 'stmt; - fn start<'p, I>(&mut self, _params: I) -> Result, ConnectorError> - where - I: IntoIterator, - { - let arrow = self.stmt.query_arrow([])?; + fn start_batch<'p>( + &mut self, + args: (&RecordBatch, usize), + ) -> Result, ConnectorError> { + // args + let arg_cells = ArrayCellRef::vec_from_batch(args.0, args.1); + let mut args: Vec = Vec::with_capacity(arg_cells.len()); + for cell in arg_cells { + transport::transport(cell.field, &cell, &mut args)?; + } + let args = args.iter().map(|x| x as &dyn duckdb::ToSql).collect_vec(); + + // query + let arrow = self.stmt.query_arrow(args.as_slice())?; Ok(DuckDBReader { arrow }) } } diff --git a/connector_arrow/src/mysql/query.rs b/connector_arrow/src/mysql/query.rs index 648f1f9..5608693 100644 --- a/connector_arrow/src/mysql/query.rs +++ b/connector_arrow/src/mysql/query.rs @@ -19,12 +19,10 @@ impl<'conn, C: Queryable> Statement<'conn> for MySQLStatement<'conn, C> { where Self: 'stmt; - fn start<'p, I>(&mut self, _params: I) -> Result, ConnectorError> - where - I: IntoIterator, - { - // TODO: params - + fn start_batch<'p>( + &mut self, + _args: (&RecordBatch, usize), + ) -> Result, ConnectorError> { let query_result = self.queryable.exec_iter(&self.stmt, ())?; // PacCell is needed so we can return query_result and result_set that mutably borrows query result. diff --git a/connector_arrow/src/params.rs b/connector_arrow/src/params.rs index ffe1c5e..3c9085d 100644 --- a/connector_arrow/src/params.rs +++ b/connector_arrow/src/params.rs @@ -1,11 +1,38 @@ +use arrow::array::RecordBatch; use arrow::datatypes::*; +use itertools::{zip_eq, Itertools}; use std::any::Any; +use std::sync::Arc; use crate::api::ArrowValue; use crate::types::{FixedSizeBinaryType, NullType}; use crate::util::transport::{Produce, ProduceTy}; +use crate::util::ArrowRowWriter; use crate::{impl_produce_unsupported, ConnectorError}; +pub(crate) fn vec_to_record_batch( + args: Vec<&dyn ArrowValue>, +) -> Result { + Ok(if args.is_empty() { + let opts = arrow::array::RecordBatchOptions::new().with_row_count(Some(1)); + let schema = Arc::new(arrow::datatypes::Schema::new(vec![] as Vec)); + RecordBatch::try_new_with_options(schema, vec![], &opts).unwrap() + } else { + let schema = Arc::new(arrow::datatypes::Schema::new( + args.iter() + .map(|a| Field::new("", a.get_data_type().clone(), true)) + .collect_vec(), + )); + let mut arrow_writer = ArrowRowWriter::new(schema.clone(), 1); + arrow_writer.prepare_for_batch(1)?; + for (field, a) in zip_eq(schema.fields(), args) { + crate::util::transport::transport(field, a, &mut arrow_writer)?; + } + + arrow_writer.finish().unwrap().into_iter().next().unwrap() + }) +} + impl<'r> Produce<'r> for &'r dyn ArrowValue {} macro_rules! impl_arrow_value_plain { diff --git a/connector_arrow/src/postgres/mod.rs b/connector_arrow/src/postgres/mod.rs index 0582365..9e7a208 100644 --- a/connector_arrow/src/postgres/mod.rs +++ b/connector_arrow/src/postgres/mod.rs @@ -81,7 +81,6 @@ impl Connector for PostgresConnection { .map_err(PostgresError::Postgres)?; Ok(query::PostgresStatement { client: &mut self.client, - query: query.to_string(), stmt, }) } diff --git a/connector_arrow/src/postgres/query.rs b/connector_arrow/src/postgres/query.rs index d7c65c1..c37b210 100644 --- a/connector_arrow/src/postgres/query.rs +++ b/connector_arrow/src/postgres/query.rs @@ -1,48 +1,39 @@ use arrow::datatypes::*; use arrow::record_batch::RecordBatch; -use bytes::BytesMut; -use itertools::Itertools; + use postgres::fallible_iterator::FallibleIterator; -use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type}; +use postgres::types::{FromSql, Type}; use postgres::{Client, Row, RowIter}; -use crate::api::{ArrowValue, ResultReader, Statement}; +use crate::api::{ResultReader, Statement}; use crate::types::{ArrowType, FixedSizeBinaryType}; -use crate::util::transport; use crate::util::CellReader; +use crate::util::{transport, ArrayCellRef}; use crate::{errors::ConnectorError, util::RowsReader}; use super::{types, PostgresError}; pub struct PostgresStatement<'conn> { pub(super) client: &'conn mut Client, - pub(super) query: String, pub(super) stmt: postgres::Statement, } impl<'conn> Statement<'conn> for PostgresStatement<'conn> { type Reader<'stmt> = PostgresBatchStream<'stmt> where Self: 'stmt; - fn start<'p, I>(&mut self, params: I) -> Result, ConnectorError> - where - I: IntoIterator, - { + fn start_batch<'p>( + &mut self, + args: (&RecordBatch, usize), + ) -> Result, ConnectorError> { let stmt = &self.stmt; let schema = types::pg_stmt_to_arrow(stmt)?; - // prepare params - let params = params - .into_iter() - .map(|p| { - let field = Field::new("", p.get_data_type().clone(), true); - ParamCell { field, value: p } - }) - .collect_vec(); + let arg_row = ArrayCellRef::vec_from_batch(args.0, args.1); // query let rows = self .client - .query_raw::<_, ParamCell, _>(&self.query, params) + .query_raw::<_, _, _>(stmt, &arg_row) .map_err(PostgresError::from)?; // create the row reader @@ -335,34 +326,3 @@ impl Binary<'_> { Ok(self.0.to_vec()) } } - -#[derive(Debug)] -struct ParamCell<'a> { - field: Field, - value: &'a dyn ArrowValue, -} - -// this is needed for params -impl<'a> ToSql for ParamCell<'a> { - fn to_sql( - &self, - _ty: &postgres::types::Type, - out: &mut BytesMut, - ) -> Result> - where - Self: Sized, - { - crate::util::transport::transport(&self.field, self.value, out)?; - Ok(IsNull::No) - } - - fn accepts(_: &postgres::types::Type) -> bool - where - Self: Sized, - { - // we don't need type validation, arrays cannot contain wrong types - true - } - - to_sql_checked!(); -} diff --git a/connector_arrow/src/sqlite/query.rs b/connector_arrow/src/sqlite/query.rs index d9852cb..59850eb 100644 --- a/connector_arrow/src/sqlite/query.rs +++ b/connector_arrow/src/sqlite/query.rs @@ -1,14 +1,15 @@ use std::sync::Arc; +use arrow::array::RecordBatch; use arrow::datatypes::*; -use itertools::zip_eq; +use itertools::{zip_eq, Itertools}; use rusqlite::types::{Type, Value}; -use crate::api::{ArrowValue, Connector, Statement}; +use crate::api::{Connector, Statement}; use crate::types::FixedSizeBinaryType; -use crate::util::transport::{Produce, ProduceTy}; -use crate::util::ArrowReader; +use crate::util::transport::{self, Produce, ProduceTy}; use crate::util::{collect_rows_to_arrow, CellReader, RowsReader}; +use crate::util::{ArrayCellRef, ArrowReader}; use crate::ConnectorError; use super::SQLiteConnection; @@ -20,14 +21,23 @@ pub struct SQLiteStatement<'conn> { impl<'conn> Statement<'conn> for SQLiteStatement<'conn> { type Reader<'task> = ArrowReader where Self: 'task; - fn start<'p, I>(&mut self, _params: I) -> Result, ConnectorError> - where - I: IntoIterator, - { + fn start_batch<'p>( + &mut self, + args: (&RecordBatch, usize), + ) -> Result, ConnectorError> { let column_count = self.stmt.column_count(); + // args + let arg_cells = ArrayCellRef::vec_from_batch(args.0, args.1); + let mut args: Vec = Vec::with_capacity(arg_cells.len()); + for cell in arg_cells { + transport::transport(cell.field, &cell, &mut args)?; + } + let args = args.iter().map(|x| x as &dyn rusqlite::ToSql).collect_vec(); + + // query let rows: Vec> = { - let mut rows_iter = self.stmt.query([])?; + let mut rows_iter = self.stmt.query(args.as_slice())?; // read all of the rows into a buffer let mut rows = Vec::with_capacity(1024); diff --git a/connector_arrow/src/tiberius/append.rs b/connector_arrow/src/tiberius/append.rs index 1093a24..86b3df6 100644 --- a/connector_arrow/src/tiberius/append.rs +++ b/connector_arrow/src/tiberius/append.rs @@ -53,10 +53,12 @@ impl<'conn, S: AsyncRead + AsyncWrite + Unpin + Send> Append<'conn> for Tiberius for row_number in 0..batch.num_rows() { let mut tb_row = TokenRow::with_capacity(row_ref.len()); + let mut buffer = Vec::with_capacity(1); for cell_ref in &mut row_ref { cell_ref.row_number = row_number; - crate::util::transport::transport(cell_ref.field, &*cell_ref, &mut tb_row)?; + crate::util::transport::transport(cell_ref.field, &*cell_ref, &mut buffer)?; + tb_row.push(buffer.pop().unwrap()); } let f = self.bulk_load.send(tb_row); @@ -72,7 +74,7 @@ impl<'conn, S: AsyncRead + AsyncWrite + Unpin + Send> Append<'conn> for Tiberius } } -impl Consume for TokenRow<'static> {} +impl Consume for Vec> {} macro_rules! impl_consume_ty { ($ArrTy: ty, $variant: ident) => { @@ -80,7 +82,7 @@ macro_rules! impl_consume_ty { }; ($ArrTy: ty, $variant: ident, $conversion: expr) => { - impl ConsumeTy<$ArrTy> for TokenRow<'static> { + impl ConsumeTy<$ArrTy> for Vec> { fn consume( &mut self, _ty: &DataType, @@ -96,7 +98,7 @@ macro_rules! impl_consume_ty { }; } -impl ConsumeTy for TokenRow<'static> { +impl ConsumeTy for Vec> { fn consume(&mut self, _ty: &DataType, _: ()) { self.push(ColumnData::U8(None)) } @@ -126,7 +128,7 @@ impl_consume_ty!(TimestampMicrosecondType, I64); impl_consume_ty!(TimestampNanosecondType, I64); impl_consume_unsupported!( - TokenRow<'static>, + Vec>, ( Date32Type, Date64Type, @@ -151,7 +153,7 @@ fn u64_to_numeric(val: u64) -> Numeric { Numeric::new_with_scale(i128::from(val), 0) } -impl ConsumeTy for TokenRow<'static> { +impl ConsumeTy for Vec> { fn consume(&mut self, ty: &DataType, value: i128) { let DataType::Decimal128(p, s) = ty else { panic!() @@ -178,7 +180,7 @@ impl ConsumeTy for TokenRow<'static> { } } -impl ConsumeTy for TokenRow<'static> { +impl ConsumeTy for Vec> { fn consume(&mut self, ty: &DataType, value: i256) { let DataType::Decimal256(p, s) = ty else { panic!() diff --git a/connector_arrow/src/tiberius/query.rs b/connector_arrow/src/tiberius/query.rs index bcd870d..1c2bf3f 100644 --- a/connector_arrow/src/tiberius/query.rs +++ b/connector_arrow/src/tiberius/query.rs @@ -1,13 +1,15 @@ use arrow::{datatypes::*, record_batch::RecordBatch}; use futures::{AsyncRead, AsyncWrite, StreamExt}; +use itertools::Itertools; use std::sync::Arc; -use tiberius::{ColumnData, QueryStream}; +use tiberius::{ColumnData, QueryStream, ToSql}; use tokio::runtime::Runtime; use crate::api::{ResultReader, Statement}; use crate::impl_produce_unsupported; use crate::types::{ArrowType, FixedSizeBinaryType, NullType}; -use crate::util::transport::ProduceTy; +use crate::util::transport::{self, ProduceTy}; +use crate::util::ArrayCellRef; use crate::util::{self, transport::Produce}; use crate::ConnectorError; @@ -23,16 +25,24 @@ impl<'conn, S: AsyncRead + AsyncWrite + Unpin + Send> Statement<'conn> where Self: 'stmt; - fn start<'p, I>(&mut self, _params: I) -> Result, ConnectorError> - where - I: IntoIterator, - { - // TODO: params + fn start_batch<'p>( + &mut self, + args: (&RecordBatch, usize), + ) -> Result, ConnectorError> { + // args + let arg_cells = ArrayCellRef::vec_from_batch(args.0, args.1); + let mut args: Vec> = Vec::with_capacity(arg_cells.len()); + for cell in arg_cells { + transport::transport(cell.field, &cell, &mut args)?; + } + let args = args.iter().map(Value).collect_vec(); + let args = args.iter().map(|a| a as &dyn ToSql).collect_vec(); + // query let mut stream = self .conn .rt - .block_on(self.conn.client.query(&self.query, &[]))?; + .block_on(self.conn.client.query(&self.query, args.as_slice()))?; // get columns let columns = self.conn.rt.block_on(stream.columns())?; @@ -221,3 +231,11 @@ impl<'a> tiberius::FromSql<'a> for StrOrNum { } } } + +struct Value<'a>(&'a ColumnData<'a>); + +impl<'a> ToSql for Value<'a> { + fn to_sql(&self) -> ColumnData<'_> { + self.0.clone() + } +} diff --git a/connector_arrow/src/util/row_reader.rs b/connector_arrow/src/util/row_reader.rs index 1b5e165..580d827 100644 --- a/connector_arrow/src/util/row_reader.rs +++ b/connector_arrow/src/util/row_reader.rs @@ -1,5 +1,6 @@ -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::{ArrayRef, AsArray, RecordBatch}; use arrow::datatypes::*; +use itertools::zip_eq; use crate::types::{ArrowType, FixedSizeBinaryType}; use crate::ConnectorError; @@ -13,6 +14,18 @@ pub struct ArrayCellRef<'a> { pub row_number: usize, } +impl<'a> ArrayCellRef<'a> { + pub fn vec_from_batch(batch: &'a RecordBatch, row_number: usize) -> Vec { + zip_eq(batch.columns(), batch.schema_ref().fields()) + .map(|(array, field)| ArrayCellRef { + array, + field, + row_number, + }) + .collect() + } +} + impl<'r> Produce<'r> for &ArrayCellRef<'r> {} impl<'r> ProduceTy<'r, BooleanType> for &ArrayCellRef<'r> { diff --git a/connector_arrow/tests/it/test_duckdb.rs b/connector_arrow/tests/it/test_duckdb.rs index f6c53c2..4eb21ba 100644 --- a/connector_arrow/tests/it/test_duckdb.rs +++ b/connector_arrow/tests/it/test_duckdb.rs @@ -14,6 +14,19 @@ fn query_01() { super::tests::query_01(&mut conn); } +#[test] +#[ignore] +fn query_02() { + let mut conn = init(); + super::tests::query_02(&mut conn); +} + +#[test] +fn query_03() { + let mut conn = init(); + super::tests::query_03(&mut conn); +} + #[rstest] #[case::empty("roundtrip::empty", spec::empty())] #[case::null_bool("roundtrip::null_bool", spec::null_bool())] diff --git a/connector_arrow/tests/it/test_sqlite.rs b/connector_arrow/tests/it/test_sqlite.rs index 5031877..0e2469f 100644 --- a/connector_arrow/tests/it/test_sqlite.rs +++ b/connector_arrow/tests/it/test_sqlite.rs @@ -14,6 +14,20 @@ fn query_01() { super::tests::query_01(&mut conn); } +#[test] +#[ignore] +fn query_02() { + let mut conn = init(); + super::tests::query_02(&mut conn); +} + +#[test] +#[ignore] +fn query_03() { + let mut conn = init(); + super::tests::query_03(&mut conn); +} + #[rstest] // #[case::empty("roundtrip::empty", spec::empty())] #[case::null_bool("roundtrip::null_bool", spec::null_bool())] diff --git a/connector_arrow/tests/it/test_tiberius.rs b/connector_arrow/tests/it/test_tiberius.rs index 3a9f6de..8f8cfb4 100644 --- a/connector_arrow/tests/it/test_tiberius.rs +++ b/connector_arrow/tests/it/test_tiberius.rs @@ -52,6 +52,13 @@ fn query_02() { super::tests::query_02(&mut conn); } +#[test] +#[ignore] +fn query_03() { + let mut conn = init(); + super::tests::query_03(&mut conn); +} + #[test] fn schema_get() { let table_name = "simple::schema_get"; diff --git a/connector_arrow/tests/it/tests.rs b/connector_arrow/tests/it/tests.rs index 3040222..eaa23c6 100644 --- a/connector_arrow/tests/it/tests.rs +++ b/connector_arrow/tests/it/tests.rs @@ -48,7 +48,7 @@ pub fn query_02(conn: &mut C) { pub fn query_03(conn: &mut C) { let query = "SELECT - CAST($1 as bool) as a_bool, CAST($2 as integer) as an_int, CAST($3 as real) as a_real, CAST($4 as text) as a_text + CAST($1 as boolean) as a_bool, CAST($2 as integer) as an_int, CAST($3 as real) as a_real, CAST($4 as text) as a_text "; let mut stmt = conn.query(query).unwrap();