diff --git a/connector_arrow/src/duckdb/append.rs b/connector_arrow/src/duckdb/append.rs index a7e5597..3960619 100644 --- a/connector_arrow/src/duckdb/append.rs +++ b/connector_arrow/src/duckdb/append.rs @@ -5,11 +5,16 @@ use duckdb::{types::Value, Appender}; use crate::{api::Append, ConnectorError}; -impl<'conn> Append<'conn> for Appender<'conn> { +pub struct DuckDBAppender<'conn> { + pub(super) inner: Appender<'conn>, +} + +impl<'conn> Append<'conn> for DuckDBAppender<'conn> { fn append(&mut self, batch: RecordBatch) -> Result<(), ConnectorError> { for row_index in 0..batch.num_rows() { let row = convert_row(&batch, row_index); - self.append_row(duckdb::appender_params_from_iter(row))?; + self.inner + .append_row(duckdb::appender_params_from_iter(row))?; } Ok(()) diff --git a/connector_arrow/src/duckdb/mod.rs b/connector_arrow/src/duckdb/mod.rs index 4110835..fd110e4 100644 --- a/connector_arrow/src/duckdb/mod.rs +++ b/connector_arrow/src/duckdb/mod.rs @@ -3,29 +3,43 @@ mod append; mod schema; +#[doc(hidden)] +pub use append::DuckDBAppender; + use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; -use duckdb::{Appender, Arrow}; use std::sync::Arc; use crate::api::{Connection, ResultReader, Statement}; use crate::errors::ConnectorError; -impl Connection for duckdb::Connection { +pub struct DuckDBConnection { + inner: duckdb::Connection, +} + +impl DuckDBConnection { + pub fn new(inner: duckdb::Connection) -> Self { + Self { inner } + } +} + +impl Connection for DuckDBConnection { type Stmt<'conn> = DuckDBStatement<'conn> where Self: 'conn; - type Append<'conn> = Appender<'conn> where Self: 'conn; + type Append<'conn> = DuckDBAppender<'conn> where Self: 'conn; fn query<'a>(&'a mut self, query: &str) -> Result, ConnectorError> { - let stmt = duckdb::Connection::prepare(self, query)?; + let stmt = self.inner.prepare(query)?; Ok(DuckDBStatement { stmt }) } fn append<'a>(&'a mut self, table_name: &str) -> Result, ConnectorError> { - Ok(self.appender(table_name)?) + Ok(DuckDBAppender { + inner: self.inner.appender(table_name)?, + }) } fn coerce_type(ty: &DataType) -> Option { @@ -57,7 +71,7 @@ impl<'conn> Statement<'conn> for DuckDBStatement<'conn> { #[doc(hidden)] pub struct DuckDBReader<'stmt> { - arrow: Arrow<'stmt>, + arrow: duckdb::Arrow<'stmt>, } impl<'stmt> ResultReader<'stmt> for DuckDBReader<'stmt> { diff --git a/connector_arrow/src/duckdb/schema.rs b/connector_arrow/src/duckdb/schema.rs index 567759e..1d147ff 100644 --- a/connector_arrow/src/duckdb/schema.rs +++ b/connector_arrow/src/duckdb/schema.rs @@ -4,10 +4,12 @@ use itertools::Itertools; use crate::api::{SchemaEdit, SchemaGet}; use crate::{ConnectorError, TableCreateError, TableDropError}; -impl SchemaGet for duckdb::Connection { +use super::DuckDBConnection; + +impl SchemaGet for DuckDBConnection { fn table_list(&mut self) -> Result, ConnectorError> { let query_tables = "SHOW TABLES;"; - let mut statement = self.prepare(query_tables)?; + let mut statement = self.inner.prepare(query_tables)?; let mut tables_res = statement.query([])?; let mut table_names = Vec::new(); @@ -20,14 +22,14 @@ impl SchemaGet for duckdb::Connection { fn table_get(&mut self, name: &str) -> Result { let query_schema = format!("SELECT * FROM \"{name}\" WHERE FALSE;"); - let mut statement = self.prepare(&query_schema)?; + let mut statement = self.inner.prepare(&query_schema)?; let results = statement.query_arrow([])?; Ok(results.get_schema()) } } -impl SchemaEdit for duckdb::Connection { +impl SchemaEdit for DuckDBConnection { fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> { let column_defs = schema .fields() @@ -45,7 +47,7 @@ impl SchemaEdit for duckdb::Connection { let ddl = format!("CREATE TABLE \"{name}\" ({column_defs});"); - let res = self.execute(&ddl, []); + let res = self.inner.execute(&ddl, []); match res { Ok(_) => Ok(()), Err(e) @@ -62,7 +64,7 @@ impl SchemaEdit for duckdb::Connection { // TODO: properly escape let ddl = format!("DROP TABLE \"{name}\";"); - let res = self.execute(&ddl, []); + let res = self.inner.execute(&ddl, []); match res { Ok(_) => Ok(()), diff --git a/connector_arrow/src/lib.rs b/connector_arrow/src/lib.rs index 371fec8..03738f5 100644 --- a/connector_arrow/src/lib.rs +++ b/connector_arrow/src/lib.rs @@ -14,18 +14,23 @@ //! //! Example for SQLite: //! ``` -//! use connector_arrow::api::{Connection, Statement}; -//! use connector_arrow::arrow::record_batch::RecordBatch; +//! # use connector_arrow::api::{Connection, Statement, ResultReader}; +//! # use connector_arrow::arrow::record_batch::RecordBatch; +//! # use connector_arrow::arrow::datatypes::SchemaRef; +//! # use connector_arrow::sqlite::SQLiteConnection; //! //! # fn main() -> Result<(), connector_arrow::ConnectorError> { //! // a regular rusqlite connection -//! let mut conn = rusqlite::Connection::open_in_memory()?; +//! let conn = rusqlite::Connection::open_in_memory()?; +//! +//! // wrap into connector_arrow connection +//! let mut conn = SQLiteConnection::new(conn); //! -//! // provided by connector_arrow::api::Connection //! let mut stmt = conn.query("SELECT 1 as a")?; //! -//! // provided by connector_arrow::api::Statement -//! let reader = stmt.start(())?; +//! let mut reader = stmt.start(())?; +//! +//! let schema: SchemaRef = reader.get_schema()?; //! //! // reader implements Iterator> //! let batches: Vec = reader.collect::>()?; diff --git a/connector_arrow/src/postgres/mod.rs b/connector_arrow/src/postgres/mod.rs index 064848e..1db0a87 100644 --- a/connector_arrow/src/postgres/mod.rs +++ b/connector_arrow/src/postgres/mod.rs @@ -5,11 +5,10 @@ //! use connector_arrow::postgres::{PostgresConnection, ProtocolExtended}; //! use connector_arrow::api::Connection; //! -//! let mut client = Client::connect("postgres://localhost:5432/my_db", NoTls).unwrap(); +//! let client = Client::connect("postgres://localhost:5432/my_db", NoTls).unwrap(); //! -//! let mut conn = PostgresConnection::::new(&mut client); +//! let mut conn = PostgresConnection::::new(client); //! -//! // provided by api::Connection //! let stmt = conn.query("SELECT * FROM my_table").unwrap(); //! ```` @@ -32,18 +31,22 @@ use crate::errors::ConnectorError; /// Requires generic argument `Protocol`, which can be one of the following types: /// - [ProtocolExtended] /// - [ProtocolSimple] -pub struct PostgresConnection<'a, Protocol> { - client: &'a mut Client, +pub struct PostgresConnection { + client: Client, _protocol: PhantomData, } -impl<'a, Protocol> PostgresConnection<'a, Protocol> { - pub fn new(client: &'a mut Client) -> Self { +impl PostgresConnection { + pub fn new(client: Client) -> Self { PostgresConnection { client, _protocol: PhantomData, } } + + pub fn unwrap(self) -> Client { + self.client + } } /// Extended PostgreSQL wire protocol. @@ -78,7 +81,7 @@ pub enum PostgresError { IO(#[from] std::io::Error), } -impl<'c, P> Connection for PostgresConnection<'c, P> +impl

Connection for PostgresConnection

where for<'conn> PostgresStatement<'conn, P>: Statement<'conn>, { @@ -92,7 +95,7 @@ where .prepare(query) .map_err(PostgresError::Postgres)?; Ok(PostgresStatement { - client: self.client, + client: &mut self.client, query: query.to_string(), stmt, _protocol: &PhantomData, @@ -100,7 +103,7 @@ where } fn append<'a>(&'a mut self, table_name: &str) -> Result, ConnectorError> { - append::PostgresAppender::new(self.client, table_name) + append::PostgresAppender::new(&mut self.client, table_name) } fn coerce_type(ty: &DataType) -> Option { diff --git a/connector_arrow/src/postgres/schema.rs b/connector_arrow/src/postgres/schema.rs index 89fd724..cd66b1c 100644 --- a/connector_arrow/src/postgres/schema.rs +++ b/connector_arrow/src/postgres/schema.rs @@ -10,7 +10,7 @@ use crate::{ConnectorError, TableCreateError, TableDropError}; use super::PostgresError; -impl<'a, P> SchemaGet for super::PostgresConnection<'a, P> { +impl

SchemaGet for super::PostgresConnection

{ fn table_list(&mut self) -> Result, ConnectorError> { let query = " SELECT relname @@ -62,7 +62,7 @@ impl<'a, P> SchemaGet for super::PostgresConnection<'a, P> { } } -impl<'a, P> SchemaEdit for super::PostgresConnection<'a, P> { +impl

SchemaEdit for super::PostgresConnection

{ fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> { let column_defs = schema .fields() diff --git a/connector_arrow/src/sqlite/mod.rs b/connector_arrow/src/sqlite/mod.rs index 8a68ce7..1969001 100644 --- a/connector_arrow/src/sqlite/mod.rs +++ b/connector_arrow/src/sqlite/mod.rs @@ -5,27 +5,37 @@ mod query; mod schema; mod types; -use crate::api::Connection; -use crate::errors::ConnectorError; - #[doc(hidden)] pub use append::SQLiteAppender; -use arrow::datatypes::DataType; #[doc(hidden)] pub use query::SQLiteStatement; -impl Connection for rusqlite::Connection { +use crate::api::Connection; +use crate::errors::ConnectorError; +use arrow::datatypes::DataType; + +pub struct SQLiteConnection { + inner: rusqlite::Connection, +} + +impl SQLiteConnection { + pub fn new(inner: rusqlite::Connection) -> Self { + Self { inner } + } +} + +impl Connection for SQLiteConnection { type Stmt<'conn> = SQLiteStatement<'conn> where Self: 'conn; type Append<'conn> = SQLiteAppender<'conn> where Self: 'conn; fn query(&mut self, query: &str) -> Result { - let stmt = rusqlite::Connection::prepare(self, query)?; + let stmt = self.inner.prepare(query)?; Ok(SQLiteStatement { stmt }) } fn append<'a>(&'a mut self, table: &str) -> Result, ConnectorError> { - let transaction = self.transaction()?; + let transaction = self.inner.transaction()?; SQLiteAppender::new(table.to_string(), transaction) } diff --git a/connector_arrow/src/sqlite/schema.rs b/connector_arrow/src/sqlite/schema.rs index e569bf0..c343070 100644 --- a/connector_arrow/src/sqlite/schema.rs +++ b/connector_arrow/src/sqlite/schema.rs @@ -6,11 +6,12 @@ use crate::api::{SchemaEdit, SchemaGet}; use crate::errors::{ConnectorError, TableCreateError, TableDropError}; use super::types::{self, ty_from_arrow}; +use super::SQLiteConnection; -impl SchemaGet for rusqlite::Connection { +impl SchemaGet for SQLiteConnection { fn table_list(&mut self) -> Result, ConnectorError> { let query_tables = "SELECT name FROM sqlite_master WHERE type = 'table';"; - let mut statement = self.prepare(query_tables)?; + let mut statement = self.inner.prepare(query_tables)?; let mut tables_res = statement.query(())?; let mut table_names = Vec::new(); @@ -26,7 +27,7 @@ impl SchemaGet for rusqlite::Connection { table_name: &str, ) -> Result { let query_columns = format!("PRAGMA table_info(\"{}\");", table_name); - let mut statement = self.prepare(&query_columns)?; + let mut statement = self.inner.prepare(&query_columns)?; let mut columns_res = statement.query(())?; // contains columns: cid, name, type, notnull, dflt_value, pk @@ -44,7 +45,7 @@ impl SchemaGet for rusqlite::Connection { } } -impl SchemaEdit for rusqlite::Connection { +impl SchemaEdit for SQLiteConnection { fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> { table_create(self, name, schema) } @@ -55,7 +56,7 @@ impl SchemaEdit for rusqlite::Connection { } pub(crate) fn table_create( - conn: &mut rusqlite::Connection, + conn: &mut SQLiteConnection, name: &str, schema: SchemaRef, ) -> Result<(), TableCreateError> { @@ -73,7 +74,7 @@ pub(crate) fn table_create( let ddl = format!("CREATE TABLE \"{name}\" ({column_defs});"); - let res = conn.execute(&ddl, ()); + let res = conn.inner.execute(&ddl, ()); match res { Ok(_) => Ok(()), Err(e) if e.to_string().ends_with("already exists") => Err(TableCreateError::TableExists), @@ -81,13 +82,10 @@ pub(crate) fn table_create( } } -pub(crate) fn table_drop( - conn: &mut rusqlite::Connection, - name: &str, -) -> Result<(), TableDropError> { +pub(crate) fn table_drop(conn: &mut SQLiteConnection, name: &str) -> Result<(), TableDropError> { let ddl = format!("DROP TABLE \"{name}\";"); - let res = conn.execute(&ddl, ()); + let res = conn.inner.execute(&ddl, ()); match res { Ok(_) => Ok(()), Err(e) if e.to_string().starts_with("no such table") => { diff --git a/connector_arrow/tests/it/test_duckdb.rs b/connector_arrow/tests/it/test_duckdb.rs index 53a15af..fbb8e2a 100644 --- a/connector_arrow/tests/it/test_duckdb.rs +++ b/connector_arrow/tests/it/test_duckdb.rs @@ -1,10 +1,11 @@ use super::spec; use rstest::*; -fn init() -> duckdb::Connection { +fn init() -> connector_arrow::duckdb::DuckDBConnection { let _ = env_logger::builder().is_test(true).try_init(); - duckdb::Connection::open_in_memory().unwrap() + let conn = duckdb::Connection::open_in_memory().unwrap(); + connector_arrow::duckdb::DuckDBConnection::new(conn) } #[test] diff --git a/connector_arrow/tests/it/test_postgres_extended.rs b/connector_arrow/tests/it/test_postgres_extended.rs index dae5a48..ffc0c68 100644 --- a/connector_arrow/tests/it/test_postgres_extended.rs +++ b/connector_arrow/tests/it/test_postgres_extended.rs @@ -3,21 +3,17 @@ use rstest::*; use super::spec; -fn init() -> postgres::Client { +fn init() -> PostgresConnection { let _ = env_logger::builder().is_test(true).try_init(); let dburl = std::env::var("POSTGRES_URL").unwrap(); - postgres::Client::connect(&dburl, postgres::NoTls).unwrap() -} - -fn wrap_conn(client: &mut postgres::Client) -> PostgresConnection { + let client = postgres::Client::connect(&dburl, postgres::NoTls).unwrap(); PostgresConnection::new(client) } #[test] fn query_01() { - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); super::tests::query_01(&mut conn); } @@ -26,8 +22,7 @@ fn query_01() { #[case::null_bool("extended::roundtrip::null_bool", spec::null_bool())] #[case::numeric("extended::roundtrip::numeric", spec::numeric())] fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); super::tests::roundtrip(&mut conn, table_name, spec); } @@ -35,8 +30,7 @@ fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { fn schema_get() { let table_name = "extended::schema_get"; - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); let column_spec = super::spec::all_types(); super::tests::schema_get(&mut conn, table_name, column_spec); } @@ -45,8 +39,7 @@ fn schema_get() { fn schema_edit() { let table_name = "extended::schema_edit"; - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); let column_spec = super::spec::all_types(); super::tests::schema_edit(&mut conn, table_name, column_spec); } diff --git a/connector_arrow/tests/it/test_postgres_simple.rs b/connector_arrow/tests/it/test_postgres_simple.rs index b065063..5818b9e 100644 --- a/connector_arrow/tests/it/test_postgres_simple.rs +++ b/connector_arrow/tests/it/test_postgres_simple.rs @@ -3,21 +3,17 @@ use rstest::*; use super::spec; -fn init() -> postgres::Client { +fn init() -> PostgresConnection { let _ = env_logger::builder().is_test(true).try_init(); let dburl = std::env::var("POSTGRES_URL").unwrap(); - postgres::Client::connect(&dburl, postgres::NoTls).unwrap() -} - -fn wrap_conn(client: &mut postgres::Client) -> PostgresConnection { + let client = postgres::Client::connect(&dburl, postgres::NoTls).unwrap(); PostgresConnection::new(client) } #[test] fn query_01() { - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); super::tests::query_01(&mut conn); } @@ -26,8 +22,7 @@ fn query_01() { #[case::null_bool("simple::roundtrip::null_bool", spec::null_bool())] #[case::numeric("simple::roundtrip::numeric", spec::numeric())] fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); super::tests::roundtrip(&mut conn, table_name, spec); } @@ -35,8 +30,7 @@ fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { fn schema_get() { let table_name = "simple::schema_get"; - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); let column_spec = super::spec::all_types(); super::tests::schema_get(&mut conn, table_name, column_spec); } @@ -45,8 +39,7 @@ fn schema_get() { fn schema_edit() { let table_name = "simple::schema_edit"; - let mut client = init(); - let mut conn = wrap_conn(&mut client); + let mut conn = init(); let column_spec = super::spec::all_types(); super::tests::schema_edit(&mut conn, table_name, column_spec); } diff --git a/connector_arrow/tests/it/test_sqlite.rs b/connector_arrow/tests/it/test_sqlite.rs index 0040aab..751f701 100644 --- a/connector_arrow/tests/it/test_sqlite.rs +++ b/connector_arrow/tests/it/test_sqlite.rs @@ -1,10 +1,11 @@ use super::spec; use rstest::*; -fn init() -> rusqlite::Connection { +fn init() -> connector_arrow::sqlite::SQLiteConnection { let _ = env_logger::builder().is_test(true).try_init(); - rusqlite::Connection::open_in_memory().unwrap() + let conn = rusqlite::Connection::open_in_memory().unwrap(); + connector_arrow::sqlite::SQLiteConnection::new(conn) } #[test]