Skip to content

Commit

Permalink
refactor: wrap instead of extend public impls
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen committed Feb 20, 2024
1 parent 1e3d842 commit 187449c
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 80 deletions.
9 changes: 7 additions & 2 deletions connector_arrow/src/duckdb/append.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@ use duckdb::{types::Value, Appender};

use crate::{api::Append, ConnectorError};

impl<'conn> Append<'conn> for Appender<'conn> {
pub struct DuckDBAppender<'conn> {
pub(super) inner: Appender<'conn>,
}

impl<'conn> Append<'conn> for DuckDBAppender<'conn> {
fn append(&mut self, batch: RecordBatch) -> Result<(), ConnectorError> {
for row_index in 0..batch.num_rows() {
let row = convert_row(&batch, row_index);
self.append_row(duckdb::appender_params_from_iter(row))?;
self.inner
.append_row(duckdb::appender_params_from_iter(row))?;
}

Ok(())
Expand Down
26 changes: 20 additions & 6 deletions connector_arrow/src/duckdb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,43 @@
mod append;
mod schema;

#[doc(hidden)]
pub use append::DuckDBAppender;

use arrow::datatypes::DataType;
use arrow::record_batch::RecordBatch;
use duckdb::{Appender, Arrow};

use std::sync::Arc;

use crate::api::{Connection, ResultReader, Statement};
use crate::errors::ConnectorError;

impl Connection for duckdb::Connection {
pub struct DuckDBConnection {
inner: duckdb::Connection,
}

impl DuckDBConnection {
pub fn new(inner: duckdb::Connection) -> Self {
Self { inner }
}
}

impl Connection for DuckDBConnection {
type Stmt<'conn> = DuckDBStatement<'conn>
where
Self: 'conn;

type Append<'conn> = Appender<'conn> where Self: 'conn;
type Append<'conn> = DuckDBAppender<'conn> where Self: 'conn;

fn query<'a>(&'a mut self, query: &str) -> Result<Self::Stmt<'a>, ConnectorError> {
let stmt = duckdb::Connection::prepare(self, query)?;
let stmt = self.inner.prepare(query)?;

Ok(DuckDBStatement { stmt })
}
fn append<'a>(&'a mut self, table_name: &str) -> Result<Self::Append<'a>, ConnectorError> {
Ok(self.appender(table_name)?)
Ok(DuckDBAppender {
inner: self.inner.appender(table_name)?,
})
}

fn coerce_type(ty: &DataType) -> Option<DataType> {
Expand Down Expand Up @@ -57,7 +71,7 @@ impl<'conn> Statement<'conn> for DuckDBStatement<'conn> {

#[doc(hidden)]
pub struct DuckDBReader<'stmt> {
arrow: Arrow<'stmt>,
arrow: duckdb::Arrow<'stmt>,
}

impl<'stmt> ResultReader<'stmt> for DuckDBReader<'stmt> {
Expand Down
14 changes: 8 additions & 6 deletions connector_arrow/src/duckdb/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use itertools::Itertools;
use crate::api::{SchemaEdit, SchemaGet};
use crate::{ConnectorError, TableCreateError, TableDropError};

impl SchemaGet for duckdb::Connection {
use super::DuckDBConnection;

impl SchemaGet for DuckDBConnection {
fn table_list(&mut self) -> Result<Vec<String>, ConnectorError> {
let query_tables = "SHOW TABLES;";
let mut statement = self.prepare(query_tables)?;
let mut statement = self.inner.prepare(query_tables)?;
let mut tables_res = statement.query([])?;

let mut table_names = Vec::new();
Expand All @@ -20,14 +22,14 @@ impl SchemaGet for duckdb::Connection {

fn table_get(&mut self, name: &str) -> Result<arrow::datatypes::SchemaRef, ConnectorError> {
let query_schema = format!("SELECT * FROM \"{name}\" WHERE FALSE;");
let mut statement = self.prepare(&query_schema)?;
let mut statement = self.inner.prepare(&query_schema)?;
let results = statement.query_arrow([])?;

Ok(results.get_schema())
}
}

impl SchemaEdit for duckdb::Connection {
impl SchemaEdit for DuckDBConnection {
fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> {
let column_defs = schema
.fields()
Expand All @@ -45,7 +47,7 @@ impl SchemaEdit for duckdb::Connection {

let ddl = format!("CREATE TABLE \"{name}\" ({column_defs});");

let res = self.execute(&ddl, []);
let res = self.inner.execute(&ddl, []);
match res {
Ok(_) => Ok(()),
Err(e)
Expand All @@ -62,7 +64,7 @@ impl SchemaEdit for duckdb::Connection {
// TODO: properly escape
let ddl = format!("DROP TABLE \"{name}\";");

let res = self.execute(&ddl, []);
let res = self.inner.execute(&ddl, []);

match res {
Ok(_) => Ok(()),
Expand Down
17 changes: 11 additions & 6 deletions connector_arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@
//!
//! Example for SQLite:
//! ```
//! use connector_arrow::api::{Connection, Statement};
//! use connector_arrow::arrow::record_batch::RecordBatch;
//! # use connector_arrow::api::{Connection, Statement, ResultReader};
//! # use connector_arrow::arrow::record_batch::RecordBatch;
//! # use connector_arrow::arrow::datatypes::SchemaRef;
//! # use connector_arrow::sqlite::SQLiteConnection;
//!
//! # fn main() -> Result<(), connector_arrow::ConnectorError> {
//! // a regular rusqlite connection
//! let mut conn = rusqlite::Connection::open_in_memory()?;
//! let conn = rusqlite::Connection::open_in_memory()?;
//!
//! // wrap into connector_arrow connection
//! let mut conn = SQLiteConnection::new(conn);
//!
//! // provided by connector_arrow::api::Connection
//! let mut stmt = conn.query("SELECT 1 as a")?;
//!
//! // provided by connector_arrow::api::Statement
//! let reader = stmt.start(())?;
//! let mut reader = stmt.start(())?;
//!
//! let schema: SchemaRef = reader.get_schema()?;
//!
//! // reader implements Iterator<Item = Result<RecordBatch, _>>
//! let batches: Vec<RecordBatch> = reader.collect::<Result<_, _>>()?;
Expand Down
23 changes: 13 additions & 10 deletions connector_arrow/src/postgres/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
//! use connector_arrow::postgres::{PostgresConnection, ProtocolExtended};
//! use connector_arrow::api::Connection;
//!
//! let mut client = Client::connect("postgres://localhost:5432/my_db", NoTls).unwrap();
//! let client = Client::connect("postgres://localhost:5432/my_db", NoTls).unwrap();
//!
//! let mut conn = PostgresConnection::<ProtocolExtended>::new(&mut client);
//! let mut conn = PostgresConnection::<ProtocolExtended>::new(client);
//!
//! // provided by api::Connection
//! let stmt = conn.query("SELECT * FROM my_table").unwrap();
//! ````
Expand All @@ -32,18 +31,22 @@ use crate::errors::ConnectorError;
/// Requires generic argument `Protocol`, which can be one of the following types:
/// - [ProtocolExtended]
/// - [ProtocolSimple]
pub struct PostgresConnection<'a, Protocol> {
client: &'a mut Client,
pub struct PostgresConnection<Protocol> {
client: Client,
_protocol: PhantomData<Protocol>,
}

impl<'a, Protocol> PostgresConnection<'a, Protocol> {
pub fn new(client: &'a mut Client) -> Self {
impl<Protocol> PostgresConnection<Protocol> {
pub fn new(client: Client) -> Self {
PostgresConnection {
client,
_protocol: PhantomData,
}
}

pub fn unwrap(self) -> Client {
self.client
}
}

/// Extended PostgreSQL wire protocol.
Expand Down Expand Up @@ -78,7 +81,7 @@ pub enum PostgresError {
IO(#[from] std::io::Error),
}

impl<'c, P> Connection for PostgresConnection<'c, P>
impl<P> Connection for PostgresConnection<P>
where
for<'conn> PostgresStatement<'conn, P>: Statement<'conn>,
{
Expand All @@ -92,15 +95,15 @@ where
.prepare(query)
.map_err(PostgresError::Postgres)?;
Ok(PostgresStatement {
client: self.client,
client: &mut self.client,
query: query.to_string(),
stmt,
_protocol: &PhantomData,
})
}

fn append<'a>(&'a mut self, table_name: &str) -> Result<Self::Append<'a>, ConnectorError> {
append::PostgresAppender::new(self.client, table_name)
append::PostgresAppender::new(&mut self.client, table_name)
}

fn coerce_type(ty: &DataType) -> Option<DataType> {
Expand Down
4 changes: 2 additions & 2 deletions connector_arrow/src/postgres/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{ConnectorError, TableCreateError, TableDropError};

use super::PostgresError;

impl<'a, P> SchemaGet for super::PostgresConnection<'a, P> {
impl<P> SchemaGet for super::PostgresConnection<P> {
fn table_list(&mut self) -> Result<Vec<String>, ConnectorError> {
let query = "
SELECT relname
Expand Down Expand Up @@ -62,7 +62,7 @@ impl<'a, P> SchemaGet for super::PostgresConnection<'a, P> {
}
}

impl<'a, P> SchemaEdit for super::PostgresConnection<'a, P> {
impl<P> SchemaEdit for super::PostgresConnection<P> {
fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> {
let column_defs = schema
.fields()
Expand Down
24 changes: 17 additions & 7 deletions connector_arrow/src/sqlite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,37 @@ mod query;
mod schema;
mod types;

use crate::api::Connection;
use crate::errors::ConnectorError;

#[doc(hidden)]
pub use append::SQLiteAppender;
use arrow::datatypes::DataType;
#[doc(hidden)]
pub use query::SQLiteStatement;

impl Connection for rusqlite::Connection {
use crate::api::Connection;
use crate::errors::ConnectorError;
use arrow::datatypes::DataType;

pub struct SQLiteConnection {
inner: rusqlite::Connection,
}

impl SQLiteConnection {
pub fn new(inner: rusqlite::Connection) -> Self {
Self { inner }
}
}

impl Connection for SQLiteConnection {
type Stmt<'conn> = SQLiteStatement<'conn> where Self: 'conn;

type Append<'conn> = SQLiteAppender<'conn> where Self: 'conn;

fn query(&mut self, query: &str) -> Result<SQLiteStatement, ConnectorError> {
let stmt = rusqlite::Connection::prepare(self, query)?;
let stmt = self.inner.prepare(query)?;
Ok(SQLiteStatement { stmt })
}

fn append<'a>(&'a mut self, table: &str) -> Result<Self::Append<'a>, ConnectorError> {
let transaction = self.transaction()?;
let transaction = self.inner.transaction()?;

SQLiteAppender::new(table.to_string(), transaction)
}
Expand Down
20 changes: 9 additions & 11 deletions connector_arrow/src/sqlite/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use crate::api::{SchemaEdit, SchemaGet};
use crate::errors::{ConnectorError, TableCreateError, TableDropError};

use super::types::{self, ty_from_arrow};
use super::SQLiteConnection;

impl SchemaGet for rusqlite::Connection {
impl SchemaGet for SQLiteConnection {
fn table_list(&mut self) -> Result<Vec<String>, ConnectorError> {
let query_tables = "SELECT name FROM sqlite_master WHERE type = 'table';";
let mut statement = self.prepare(query_tables)?;
let mut statement = self.inner.prepare(query_tables)?;
let mut tables_res = statement.query(())?;

let mut table_names = Vec::new();
Expand All @@ -26,7 +27,7 @@ impl SchemaGet for rusqlite::Connection {
table_name: &str,
) -> Result<arrow::datatypes::SchemaRef, ConnectorError> {
let query_columns = format!("PRAGMA table_info(\"{}\");", table_name);
let mut statement = self.prepare(&query_columns)?;
let mut statement = self.inner.prepare(&query_columns)?;
let mut columns_res = statement.query(())?;
// contains columns: cid, name, type, notnull, dflt_value, pk

Expand All @@ -44,7 +45,7 @@ impl SchemaGet for rusqlite::Connection {
}
}

impl SchemaEdit for rusqlite::Connection {
impl SchemaEdit for SQLiteConnection {
fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> {
table_create(self, name, schema)
}
Expand All @@ -55,7 +56,7 @@ impl SchemaEdit for rusqlite::Connection {
}

pub(crate) fn table_create(
conn: &mut rusqlite::Connection,
conn: &mut SQLiteConnection,
name: &str,
schema: SchemaRef,
) -> Result<(), TableCreateError> {
Expand All @@ -73,21 +74,18 @@ pub(crate) fn table_create(

let ddl = format!("CREATE TABLE \"{name}\" ({column_defs});");

let res = conn.execute(&ddl, ());
let res = conn.inner.execute(&ddl, ());
match res {
Ok(_) => Ok(()),
Err(e) if e.to_string().ends_with("already exists") => Err(TableCreateError::TableExists),
Err(e) => Err(TableCreateError::Connector(ConnectorError::SQLite(e))),
}
}

pub(crate) fn table_drop(
conn: &mut rusqlite::Connection,
name: &str,
) -> Result<(), TableDropError> {
pub(crate) fn table_drop(conn: &mut SQLiteConnection, name: &str) -> Result<(), TableDropError> {
let ddl = format!("DROP TABLE \"{name}\";");

let res = conn.execute(&ddl, ());
let res = conn.inner.execute(&ddl, ());
match res {
Ok(_) => Ok(()),
Err(e) if e.to_string().starts_with("no such table") => {
Expand Down
5 changes: 3 additions & 2 deletions connector_arrow/tests/it/test_duckdb.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use super::spec;
use rstest::*;

fn init() -> duckdb::Connection {
fn init() -> connector_arrow::duckdb::DuckDBConnection {
let _ = env_logger::builder().is_test(true).try_init();

duckdb::Connection::open_in_memory().unwrap()
let conn = duckdb::Connection::open_in_memory().unwrap();
connector_arrow::duckdb::DuckDBConnection::new(conn)
}

#[test]
Expand Down
Loading

0 comments on commit 187449c

Please sign in to comment.