From d73a61011bfb60585eccca30aa9374e38f4b2ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Mon, 19 Feb 2024 13:23:27 +0100 Subject: [PATCH] feat: postgres append, schema get & edit --- Cargo.lock | 139 +---------- README.md | 6 +- connector_arrow/Cargo.toml | 14 +- connector_arrow/src/postgres/append.rs | 230 ++++++++++++++++++ connector_arrow/src/postgres/mod.rs | 15 +- .../src/postgres/protocol_extended.rs | 4 +- .../src/postgres/protocol_simple.rs | 4 +- connector_arrow/src/postgres/schema.rs | 108 ++++++++ connector_arrow/src/postgres/types.rs | 52 +++- connector_arrow/src/sqlite/query.rs | 2 +- connector_arrow/src/util/mod.rs | 2 + connector_arrow/src/util/row_reader.rs | 157 ++++++++++++ connector_arrow/src/util/transport.rs | 22 +- connector_arrow/tests/it/test_duckdb.rs | 21 +- connector_arrow/tests/it/test_postgres.rs | 77 +++++- connector_arrow/tests/it/test_sqlite.rs | 21 +- connector_arrow/tests/it/util.rs | 18 +- 17 files changed, 708 insertions(+), 184 deletions(-) create mode 100644 connector_arrow/src/postgres/append.rs create mode 100644 connector_arrow/src/postgres/schema.rs create mode 100644 connector_arrow/src/util/row_reader.rs diff --git a/Cargo.lock b/Cargo.lock index 048adfe..7905e1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -51,21 +51,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "alloc-no-stdlib" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" -dependencies = [ - "alloc-no-stdlib", -] - [[package]] name = "allocator-api2" version = "0.2.16" @@ -152,10 +137,7 @@ dependencies = [ "arrow-array", "arrow-buffer", "arrow-cast", - "arrow-csv", "arrow-data", - "arrow-ipc", - "arrow-json", "arrow-ord", "arrow-row", "arrow-schema", @@ -224,25 +206,6 @@ dependencies = [ "num", ] -[[package]] -name = "arrow-csv" -version = "49.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e09aa6246a1d6459b3f14baeaa49606cfdbca34435c46320e14054d244987ca" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-schema", - "chrono", - "csv", - "csv-core", - "lazy_static", - "lexical-core", - "regex", -] - [[package]] name = "arrow-data" version = "49.0.0" @@ -269,26 +232,6 @@ dependencies = [ "flatbuffers", ] -[[package]] -name = "arrow-json" -version = "49.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d82565c91fd627922ebfe2810ee4e8346841b6f9361b87505a9acea38b614fee" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "indexmap", - "lexical-core", - "num", - "serde", - "serde_json", -] - [[package]] name = "arrow-ord" version = "49.0.0" @@ -453,27 +396,6 @@ dependencies = [ "syn_derive", ] -[[package]] -name = "brotli" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "2.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bstr" version = "0.2.17" @@ -537,7 +459,6 @@ version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ - "jobserver", "libc", ] @@ -589,6 +510,8 @@ name = "connector_arrow" version = "0.3.0" dependencies = [ "arrow", + "byteorder", + "bytes", "chrono", "csv", "duckdb", @@ -600,6 +523,7 @@ dependencies = [ "log", "parquet", "postgres", + "postgres-protocol", "rusqlite", "rust_decimal", "rust_decimal_macros", @@ -1059,15 +983,6 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" -[[package]] -name = "jobserver" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" -dependencies = [ - "libc", -] - [[package]] name = "js-sys" version = "0.3.56" @@ -1212,15 +1127,6 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" -[[package]] -name = "lz4_flex" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "912b45c753ff5f7f5208307e8ace7d2a2e30d024e26d3509f3dce546c044ce15" -dependencies = [ - "twox-hash", -] - [[package]] name = "md-5" version = "0.10.6" @@ -1396,20 +1302,15 @@ dependencies = [ "arrow-schema", "arrow-select", "base64", - "brotli", "bytes", "chrono", - "flate2", "hashbrown 0.14.3", - "lz4_flex", "num", "num-bigint", "paste", "seq-macro", - "snap", "thrift", "twox-hash", - "zstd", ] [[package]] @@ -1895,12 +1796,6 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" -[[package]] -name = "snap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" - [[package]] name = "socket2" version = "0.5.5" @@ -2516,31 +2411,3 @@ dependencies = [ "quote", "syn 2.0.48", ] - -[[package]] -name = "zstd" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "7.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" -dependencies = [ - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "2.0.9+zstd.1.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" -dependencies = [ - "cc", - "pkg-config", -] diff --git a/README.md b/README.md index 70a5a25..371ec66 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,9 @@ Similar to [ADBC](https://arrow.apache.org/docs/format/ADBC.html), but written i | Streaming | | | x | | Temporal types | | x | | | Container types | | x | | -| Schema get | x | x | | -| Schema edit | x | x | | -| Append | x | x | | +| Schema get | x | x | x | +| Schema edit | x | x | x | +| Append | x | x | x | None of the sources are enabled by default, use 
features to enable them.

diff --git a/connector_arrow/Cargo.toml b/connector_arrow/Cargo.toml
index 2374064..e52459f 100644
--- a/connector_arrow/Cargo.toml
+++ b/connector_arrow/Cargo.toml
@@ -31,6 +31,8 @@ urlencoding = { version = "2.1", optional = true }
 uuid = { version = "0.8", optional = true }
 fallible-streaming-iterator = { version = "0.1", optional = true }
 csv = { version = "1", optional = true }
+bytes = { version = "1", optional = true }
+byteorder = { version = "1", optional = true }
 
 [dependencies.postgres]
 version = "0.19"
@@ -38,6 +40,11 @@ default-features = false
 optional = true
 features = ["with-chrono-0_4", "with-uuid-0_8", "with-serde_json-1"]
 
+[dependencies.postgres-protocol]
+version = "0.6.6"
+default-features = false
+optional = true
+
 [dependencies.rusqlite]
 version = "0.30.0"
 default-features = false
@@ -50,8 +57,8 @@ optional = true
 
 [dev-dependencies]
 env_logger = "0.11"
-arrow = { version = "49", features = ["prettyprint"] }
-parquet = { version = "49" }
+arrow = { version = "49", features = ["prettyprint"], default-features = false }
+parquet = { version = "49", features = ["arrow"], default-features = false }
 insta = { version = "1.34.0" }
 similar-asserts = { version = "1.5.0" }
 
@@ -59,11 +66,14 @@ similar-asserts = { version = "1.5.0" }
 all = ["src_sqlite", "src_duckdb", "src_postgres"]
 src_postgres = [
     "postgres",
+    "postgres-protocol",
     "csv",
     "hex",
     "uuid",
     "rust_decimal",
     "rust_decimal_macros",
+    "bytes",
+    "byteorder"
 ]
 src_sqlite = ["rusqlite", "fallible-streaming-iterator", "urlencoding"]
 src_duckdb = [

diff --git a/connector_arrow/src/postgres/append.rs b/connector_arrow/src/postgres/append.rs
new file mode 100644
index 0000000..a6cb3cc
--- /dev/null
+++ b/connector_arrow/src/postgres/append.rs
@@ -0,0 +1,230 @@
+use std::sync::Arc;
+
+use arrow::datatypes::*;
+use arrow::record_batch::RecordBatch;
+use bytes::BytesMut;
+use itertools::{zip_eq, Itertools};
+use postgres::binary_copy::BinaryCopyInWriter;
+use postgres::types::{to_sql_checked, IsNull, ToSql};
+use postgres::{Client, CopyInWriter};
+
+use crate::api::Append;
+use crate::types::{FixedSizeBinaryType, NullType};
+use crate::util::transport::{Consume, ConsumeTy};
+use crate::util::ArrayCellRef;
+use crate::{impl_consume_unsupported, ConnectorError};
+
+use super::PostgresError;
+
+pub struct PostgresAppender<'c> {
+    writer: Writer<'c>,
+}
+
+impl<'conn> PostgresAppender<'conn> {
+    pub fn new(client: &'conn mut Client, table_name: &str) -> Result<Self, ConnectorError> {
+        let query = format!("COPY BINARY \"{table_name}\" FROM stdin");
+        let writer = client.copy_in(&query).map_err(PostgresError::Postgres)?;
+        let writer = Writer::Uninitialized(writer);
+        Ok(Self { writer })
+    }
+}
+
+enum Writer<'c> {
+    Uninitialized(CopyInWriter<'c>),
+    Invalid,
+    Initialized { writer: BinaryCopyInWriter<'c> },
+}
+
+impl<'c> Writer<'c> {
+    fn as_binary(
+        &mut self,
+        schema: SchemaRef,
+    ) -> Result<&mut BinaryCopyInWriter<'c>, ConnectorError> {
+        if let Writer::Uninitialized(_) = self {
+            // replace the plain writer with a new binary one
+            let Writer::Uninitialized(w) = std::mem::replace(self, Writer::Invalid) else {
+                unreachable!();
+            };
+
+            // The types don't really matter: they are used only for client-side
+            // checking that the declared type matches the passed value, and our
+            // ToSql::accepts always returns true, so that check always passes.
+            let types = vec![postgres::types::Type::VOID; schema.fields().len()];
+
+            *self = Writer::Initialized {
+                writer: BinaryCopyInWriter::new(w, &types),
+            }
+        }
+
+        // return binary writer
+        let Writer::Initialized { writer } = self else {
+            unreachable!();
+        };
+        Ok(writer)
+    }
+
+    fn finish(mut self) -> Result<u64, ConnectorError> {
+        let schema = Arc::new(Schema::new(vec![] as Vec<Field>));
+        self.as_binary(schema)?;
+        match self {
+            Writer::Initialized { writer: w, .. } => {
+                Ok(w.finish().map_err(PostgresError::Postgres)?)
+            }
+            Writer::Uninitialized(_) | Writer::Invalid => unreachable!(),
+        }
+    }
+}
+
+impl<'conn> Append<'conn> for PostgresAppender<'conn> {
+    fn append(&mut self, batch: RecordBatch) -> Result<(), ConnectorError> {
+        let writer = self.writer.as_binary(batch.schema())?;
+
+        let schema = batch.schema();
+        let mut row = zip_eq(batch.columns(), schema.fields())
+            .map(|(array, field)| ArrayCellRef {
+                array,
+                field,
+                row_number: 0,
+            })
+            .collect_vec();
+
+        for row_number in 0..batch.num_rows() {
+            for cell in &mut row {
+                cell.row_number = row_number;
+            }
+
+            writer.write_raw(&row).map_err(PostgresError::Postgres)?;
+        }
+        Ok(())
+    }
+
+    fn finish(self) -> Result<(), ConnectorError> {
+        self.writer.finish()?;
+        Ok(())
+    }
+}
+
+impl<'a> ToSql for ArrayCellRef<'a> {
+    fn to_sql(
+        &self,
+        _ty: &postgres::types::Type,
+        out: &mut BytesMut,
+    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
+    where
+        Self: Sized,
+    {
+        if self.array.is_null(self.row_number) || matches!(self.field.data_type(), DataType::Null) {
+            return Ok(IsNull::Yes);
+        }
+        crate::util::transport::transport(self.field, self, out)?;
+        Ok(IsNull::No)
+    }
+
+    fn accepts(_: &postgres::types::Type) -> bool
+    where
+        Self: Sized,
+    {
+        // we don't need type validation, arrays cannot contain wrong types
+        true
+    }
+
+    to_sql_checked!();
+}
+
+impl Consume for BytesMut {}
+
+macro_rules! impl_consume_ty {
+    ($ArrTy: ty, $to_sql: ident) => {
+        impl ConsumeTy<$ArrTy> for BytesMut {
+            fn consume(&mut self, value: <$ArrTy as crate::types::ArrowType>::Native) {
+                postgres_protocol::types::$to_sql(value, self);
+            }
+
+            fn consume_null(&mut self) {}
+        }
+    };
+}
+
+macro_rules! impl_consume_ref_ty {
+    ($ArrTy: ty, $to_sql: ident) => {
+        impl ConsumeTy<$ArrTy> for BytesMut {
+            fn consume(&mut self, value: <$ArrTy as crate::types::ArrowType>::Native) {
+                postgres_protocol::types::$to_sql(&value, self);
+            }
+
+            fn consume_null(&mut self) {}
+        }
+    };
+}
+
+impl ConsumeTy<NullType> for BytesMut {
+    fn consume(&mut self, _: ()) {}
+
+    fn consume_null(&mut self) {}
+}
+
+impl_consume_ty!(BooleanType, bool_to_sql);
+impl_consume_ty!(Int8Type, char_to_sql);
+impl_consume_ty!(Int16Type, int2_to_sql);
+impl_consume_ty!(Int32Type, int4_to_sql);
+impl_consume_ty!(Int64Type, int8_to_sql);
+// impl_consume_ty!(UInt8Type, );
+// impl_consume_ty!(UInt16Type, );
+impl_consume_ty!(UInt32Type, oid_to_sql);
+// impl_consume_ty!(UInt64Type, );
+// impl_consume_ty!(Float16Type, );
+impl_consume_ty!(Float32Type, float4_to_sql);
+impl_consume_ty!(Float64Type, float8_to_sql);
+// impl_consume_ty!(TimestampSecondType, );
+// impl_consume_ty!(TimestampMillisecondType, );
+impl_consume_ty!(TimestampMicrosecondType, timestamp_to_sql);
+// impl_consume_ty!(TimestampNanosecondType, );
+// impl_consume_ty!(Date32Type, date_to_sql);
+// impl_consume_ty!(Date64Type, date_to_sql);
+// impl_consume_ty!(Time32SecondType, );
+// impl_consume_ty!(Time32MillisecondType, );
+impl_consume_ty!(Time64MicrosecondType, time_to_sql);
+// impl_consume_ty!(Time64NanosecondType, );
+// impl_consume_ty!(IntervalYearMonthType, );
+// impl_consume_ty!(IntervalDayTimeType, );
+// impl_consume_ty!(IntervalMonthDayNanoType, );
+// impl_consume_ty!(DurationSecondType, );
+// impl_consume_ty!(DurationMillisecondType, );
+// impl_consume_ty!(DurationMicrosecondType, );
+// impl_consume_ty!(DurationNanosecondType, );
+impl_consume_ref_ty!(BinaryType, bytea_to_sql);
+impl_consume_ref_ty!(LargeBinaryType, bytea_to_sql);
+impl_consume_ref_ty!(FixedSizeBinaryType, bytea_to_sql);
+impl_consume_ref_ty!(Utf8Type, text_to_sql);
+impl_consume_ref_ty!(LargeUtf8Type, text_to_sql);
+// impl_consume_ty!(Decimal128Type, );
+// impl_consume_ty!(Decimal256Type, );
+
+impl_consume_unsupported!(
+    BytesMut,
+    (
+        UInt8Type,
+        UInt16Type,
+        UInt64Type,
+        Float16Type,
+        TimestampSecondType,
+        TimestampMillisecondType,
+        TimestampNanosecondType,
+        Date32Type,
+        Date64Type,
+        Time32SecondType,
+        Time32MillisecondType,
+        Time64NanosecondType,
+        IntervalYearMonthType,
+        IntervalDayTimeType,
+        IntervalMonthDayNanoType,
+        DurationSecondType,
+        DurationMillisecondType,
+        DurationMicrosecondType,
+        DurationNanosecondType,
+        Decimal128Type,
+        Decimal256Type,
+    )
+);

diff --git a/connector_arrow/src/postgres/mod.rs b/connector_arrow/src/postgres/mod.rs
index 48d8c3b..d682b9a 100644
--- a/connector_arrow/src/postgres/mod.rs
+++ b/connector_arrow/src/postgres/mod.rs
@@ -10,15 +10,17 @@
 //! let stmt = conn.query("SELECT * FROM my_table").unwrap();
 //! ```
+mod append;
 mod protocol_extended;
 mod protocol_simple;
+mod schema;
 mod types;
 
 use postgres::Client;
 use std::marker::PhantomData;
 use thiserror::Error;
 
-use crate::api::{unimplemented, Connection, Statement};
+use crate::api::{Connection, Statement};
 use crate::errors::ConnectorError;
 
 pub struct PostgresConnection<'a, P> {
@@ -68,10 +70,13 @@ where
 {
     type Stmt<'conn> = PostgresStatement<'conn, P> where Self: 'conn;
 
-    type Append<'conn> = unimplemented::Appender where Self: 'conn;
+    type Append<'conn> = append::PostgresAppender<'conn> where Self: 'conn;
 
     fn query<'a>(&'a mut self, query: &str) -> Result<Self::Stmt<'a>, ConnectorError> {
-        let stmt = self.client.prepare(query).map_err(PostgresError::from)?;
+        let stmt = self
+            .client
+            .prepare(query)
+            .map_err(PostgresError::Postgres)?;
         Ok(PostgresStatement {
             client: self.client,
             query: query.to_string(),
@@ -80,8 +85,8 @@
         })
     }
 
-    fn append<'a>(&'a mut self, _: &str) -> Result<Self::Append<'a>, ConnectorError> {
-        unimplemented!()
+    fn append<'a>(&'a mut self, table_name: &str) -> Result<Self::Append<'a>, ConnectorError> {
+        append::PostgresAppender::new(self.client, table_name)
     }
 }

diff --git a/connector_arrow/src/postgres/protocol_extended.rs b/connector_arrow/src/postgres/protocol_extended.rs
index b375995..806fa7a 100644
--- a/connector_arrow/src/postgres/protocol_extended.rs
+++ b/connector_arrow/src/postgres/protocol_extended.rs
@@ -19,7 +19,7 @@ impl<'conn> Statement<'conn> for PostgresStatement<'conn, ProtocolExtended> {
     fn start(&mut self, _params: ()) -> Result<Self::Reader<'_>, ConnectorError> {
         let stmt = &self.stmt;
 
-        let schema = types::convert_schema(stmt)?;
+        let schema = types::pg_stmt_to_arrow(stmt)?;
 
         let rows = self
             .client
@@ -158,7 +158,7 @@ impl_produce!(
     LargeUtf8Type,
 );
 
-crate::impl_produce_unused!(
+crate::impl_produce_unsupported!(
     CellRef<'r>,
     (
         UInt8Type,

diff --git a/connector_arrow/src/postgres/protocol_simple.rs b/connector_arrow/src/postgres/protocol_simple.rs
index 4fa8c99..b81173b 100644
--- a/connector_arrow/src/postgres/protocol_simple.rs
+++ b/connector_arrow/src/postgres/protocol_simple.rs
@@ -16,7 +16,7 @@ impl<'conn> Statement<'conn> for PostgresStatement<'conn, ProtocolSimple> {
     fn start(&mut self, _params: ()) -> Result<Self::Reader<'_>, ConnectorError> {
         let stmt = &self.stmt;
 
-        let schema = types::convert_schema(stmt)?;
+        let schema = types::pg_stmt_to_arrow(stmt)?;
 
         let rows = self
             .client
@@ -109,7 +109,7 @@ impl_simple_produce!(
     Decimal256Type,
 );
 
-crate::impl_produce_unused!(
+crate::impl_produce_unsupported!(
     CellRef<'r>,
     (
         UInt8Type,

diff --git a/connector_arrow/src/postgres/schema.rs b/connector_arrow/src/postgres/schema.rs
new file mode 100644
index 0000000..d3dd6a4
--- /dev/null
+++ b/connector_arrow/src/postgres/schema.rs
@@ -0,0 +1,108 @@
+use std::sync::Arc;
+
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use itertools::Itertools;
+use postgres::error::SqlState;
+use postgres::types::Type;
+
+use crate::api::{SchemaEdit, SchemaGet};
+use crate::{ConnectorError, TableCreateError, TableDropError};
+
+use super::PostgresError;
+
+impl<'a, P> SchemaGet for super::PostgresConnection<'a, P> {
+    fn table_list(&mut self) -> Result<Vec<String>, ConnectorError> {
+        let query = "
+            SELECT relname
+            FROM pg_class
+            JOIN pg_namespace ON (relnamespace = pg_namespace.oid)
+            WHERE nspname = current_schema AND relkind = 'r'
+        ";
+        let rows = self.client.query(query, &[]).map_err(PostgresError::from)?;
+
+        let table_names = rows.into_iter().map(|r| r.get(0)).collect_vec();
+        Ok(table_names)
+    }
+
+    fn table_get(
+        &mut self,
+        table_name: &str,
+    ) -> Result<SchemaRef, ConnectorError> {
+        let query = "
+            SELECT attname, atttypid, attnotnull
+            FROM pg_attribute
+            JOIN pg_class ON (attrelid = pg_class.oid)
+            JOIN pg_namespace ON (relnamespace = pg_namespace.oid)
+            WHERE nspname = current_schema AND relname = $1 AND attnum > 0 AND atttypid > 0
+            ORDER BY attnum;
+        ";
+        let res = self.client.query(query, &[&table_name.to_string()]);
+        let rows = res.map_err(PostgresError::Postgres)?;
+
+        let fields: Vec<_> = rows
+            .into_iter()
+            .map(|row| -> Result<_, ConnectorError> {
+                let name: String = row.get(0);
+                let typid: u32 = row.get(1);
+                let notnull: bool = row.get(2);
+
+                let ty =
+                    Type::from_oid(typid).ok_or_else(|| ConnectorError::IncompatibleSchema {
+                        table_name: table_name.to_string(),
+                        message: format!("column `{name}` has unsupported type (oid = {typid})"),
+                        hint: Some("Only built-in PostgreSQL types are supported".to_string()),
+                    })?;
+                let ty = super::types::pg_ty_to_arrow(&ty);
+
+                Ok(Field::new(name, ty, !notnull))
+            })
+            .try_collect()?;
+
+        Ok(Arc::new(Schema::new(fields)))
+    }
+}
+
+impl<'a, P> SchemaEdit for super::PostgresConnection<'a, P> {
+    fn table_create(&mut self, name: &str, schema: SchemaRef) -> Result<(), TableCreateError> {
+        let column_defs = schema
+            .fields()
+            .iter()
+            .map(|field| {
+                let ty = super::types::arrow_ty_to_pg(field.data_type());
+
+                let is_nullable =
+                    field.is_nullable() || matches!(field.data_type(), DataType::Null);
+                let not_null = if is_nullable { "" } else { " NOT NULL" };
+
+                format!("{} {}{}", field.name(), ty, not_null)
+            })
+            .join(",");
+
+        let ddl = format!("CREATE TABLE \"{name}\" ({column_defs});");
+
+        let res = self.client.execute(&ddl, &[]);
+        match res {
+            Ok(_) => Ok(()),
+            Err(e) if matches!(e.code(), Some(&SqlState::DUPLICATE_TABLE)) => {
+                Err(TableCreateError::TableExists)
+            }
+            Err(e) => Err(TableCreateError::Connector(ConnectorError::Postgres(
+                PostgresError::Postgres(e),
+            ))),
+        }
+    }
+
+    fn table_drop(&mut self, name: &str) -> Result<(), TableDropError> {
+        let res = self.client.execute(&format!("DROP TABLE \"{name}\""), &[]);
+
+        match res {
+            Ok(_) => Ok(()),
+            Err(e) if matches!(e.code(), Some(&SqlState::UNDEFINED_TABLE)) => {
+                Err(TableDropError::TableNonexistent)
+            }
+            Err(err) => Err(TableDropError::Connector(ConnectorError::Postgres(
+                PostgresError::Postgres(err),
+            ))),
+        }
+    }
+}

diff --git a/connector_arrow/src/postgres/types.rs b/connector_arrow/src/postgres/types.rs
index ac040ff..4e88e0a 100644
--- a/connector_arrow/src/postgres/types.rs
+++ b/connector_arrow/src/postgres/types.rs
@@ -1,22 +1,22 @@
 use std::sync::Arc;
 
-use arrow::datatypes::{DataType as ArrowType, Field, Schema};
+use arrow::datatypes::{DataType as ArrowType, *};
 use postgres::types::Type as PgType;
 
 use crate::errors::ConnectorError;
 
-pub fn convert_schema(
+pub fn pg_stmt_to_arrow(
     stmt: &postgres::Statement,
 ) -> Result<Arc<Schema>, ConnectorError> {
     let fields: Vec<_> = stmt
         .columns()
         .iter()
-        .map(|col| Field::new(col.name(), convert_type(col.type_()), true))
+        .map(|col| Field::new(col.name(), pg_ty_to_arrow(col.type_()), true))
         .collect();
     Ok(Arc::new(Schema::new(fields)))
 }
 
-pub fn convert_type(ty: &PgType) -> ArrowType {
+pub fn pg_ty_to_arrow(ty: &PgType) -> ArrowType {
     match ty.name() {
         "int2" => ArrowType::Int16,
         "int4" => ArrowType::Int32,
@@ -50,3 +50,47 @@
         _ => unimplemented!("{}", ty.name()),
     }
 }
+
+pub(crate) fn arrow_ty_to_pg(data_type: &ArrowType) -> PgType {
+    match data_type {
+        // PostgreSQL has no Null column type, so we fall back to an arbitrary nullable type
+        ArrowType::Null => PgType::INT2,
+
+        ArrowType::Boolean => PgType::BOOL,
+        ArrowType::Int8 => PgType::CHAR,
+        ArrowType::Int16 => PgType::INT2,
+        ArrowType::Int32 => PgType::INT4,
+        ArrowType::Int64 => PgType::INT8,
+        // ArrowType::UInt8 => PgType::,
+        // ArrowType::UInt16 => PgType::,
+        ArrowType::UInt32 => PgType::OID,
+        // ArrowType::UInt64 => PgType::,
+        // ArrowType::Float16 => PgType::,
+        ArrowType::Float32 => PgType::FLOAT4,
+        ArrowType::Float64 => PgType::FLOAT8,
+        ArrowType::Timestamp(_, None) => PgType::TIMESTAMP,
+        ArrowType::Timestamp(_, Some(_)) => PgType::TIMESTAMPTZ,
+        // ArrowType::Date32 => PgType::,
+        // ArrowType::Date64 => PgType::,
+        // ArrowType::Time32(_) => PgType::,
+        // ArrowType::Time64(_) => PgType::,
+        ArrowType::Duration(_) => PgType::INTERVAL,
+        // ArrowType::Interval(_) => PgType::,
+        ArrowType::Binary => PgType::BYTEA,
+        ArrowType::FixedSizeBinary(_) => PgType::BYTEA,
+        ArrowType::LargeBinary => PgType::BYTEA,
+        ArrowType::Utf8 => PgType::TEXT,
+        ArrowType::LargeUtf8 => PgType::TEXT,
+        // ArrowType::List(_) => PgType::,
+        // ArrowType::FixedSizeList(_, _) => PgType::,
+        // ArrowType::LargeList(_) => PgType::,
+        // ArrowType::Struct(_) => PgType::,
+        // ArrowType::Union(_, _) => PgType::,
+        // ArrowType::Dictionary(_, _) => PgType::,
+        // ArrowType::Decimal128(_, _) => PgType::,
+        // ArrowType::Decimal256(_, _) => PgType::,
+        // ArrowType::Map(_, _) => PgType::,
+        // ArrowType::RunEndEncoded(_, _) => PgType::,
+        _ => unimplemented!("data type: {data_type}"),
+    }
+}

diff --git a/connector_arrow/src/sqlite/query.rs b/connector_arrow/src/sqlite/query.rs
index da9adab..fb960c0 100644
--- a/connector_arrow/src/sqlite/query.rs
+++ b/connector_arrow/src/sqlite/query.rs
@@ -174,7 +174,7 @@ impl<'r> ProduceTy<'r, LargeBinaryType> for Value {
     }
 }
 
-crate::impl_produce_unused!(
+crate::impl_produce_unsupported!(
     Value,
     (
         BooleanType,

diff --git a/connector_arrow/src/util/mod.rs b/connector_arrow/src/util/mod.rs
index 979137d..d34e236 100644
--- a/connector_arrow/src/util/mod.rs
+++ b/connector_arrow/src/util/mod.rs
@@ -3,9 +3,11 @@
 mod arrow_reader;
 mod row_collect;
+mod row_reader;
 mod row_writer;
 
 pub mod transport;
 
 pub use arrow_reader::ArrowReader;
 pub use row_collect::{collect_rows_to_arrow, CellReader, RowsReader};
+pub use row_reader::ArrayCellRef;
 pub use row_writer::ArrowRowWriter;

diff --git a/connector_arrow/src/util/row_reader.rs b/connector_arrow/src/util/row_reader.rs
new file mode 100644
index 0000000..2992767
--- /dev/null
+++ b/connector_arrow/src/util/row_reader.rs
@@ -0,0 +1,157 @@
+use arrow::array::{ArrayRef, AsArray};
+use arrow::datatypes::*;
+
+use crate::types::{ArrowType, FixedSizeBinaryType};
+use crate::{impl_produce_unsupported, ConnectorError};
+
+use super::transport::{Produce, ProduceTy};
+
+#[derive(Debug)]
+pub struct ArrayCellRef<'a> {
+    pub array: &'a ArrayRef,
+    pub field: &'a Field,
+    pub row_number: usize,
+}
+
+impl<'r> Produce<'r> for &ArrayCellRef<'r> {}
+
+impl<'r> ProduceTy<'r, BooleanType> for &ArrayCellRef<'r> {
+    fn produce(self) -> Result<<BooleanType as ArrowType>::Native, ConnectorError> {
+        let array = self.array.as_boolean();
+        Ok(array.value(self.row_number))
+    }
+
+    fn produce_opt(self) -> Result<Option<<BooleanType as ArrowType>::Native>, ConnectorError> {
+        Ok(if self.array.is_null(self.row_number) {
+            None
+        } else {
+            let array = self.array.as_boolean();
+            Some(array.value(self.row_number))
+        })
+    }
+}
+
+macro_rules! impl_produce_ty {
+    ($($t: ty,)+) => {
+        $(
+            impl<'r> ProduceTy<'r, $t> for &ArrayCellRef<'r> {
+                fn produce(self) -> Result<<$t as ArrowType>::Native, ConnectorError> {
+                    let array = self.array.as_primitive::<$t>();
+                    Ok(array.value(self.row_number))
+                }
+
+                fn produce_opt(self) -> Result<Option<<$t as ArrowType>::Native>, ConnectorError> {
+                    Ok(if self.array.is_null(self.row_number) {
+                        None
+                    } else {
+                        let array = self.array.as_primitive::<$t>();
+                        Some(array.value(self.row_number))
+                    })
+                }
+            }
+        )+
+    };
+}
+
+impl_produce_ty!(
+    Int8Type,
+    Int16Type,
+    Int32Type,
+    Int64Type,
+    UInt8Type,
+    UInt16Type,
+    UInt32Type,
+    UInt64Type,
+    Float16Type,
+    Float32Type,
+    Float64Type,
+    TimestampSecondType,
+    TimestampMillisecondType,
+    TimestampMicrosecondType,
+    TimestampNanosecondType,
+    Date32Type,
+    Date64Type,
+    Time32SecondType,
+    Time32MillisecondType,
+    Time64MicrosecondType,
+    Time64NanosecondType,
+    IntervalYearMonthType,
+    IntervalDayTimeType,
+    IntervalMonthDayNanoType,
+    DurationSecondType,
+    DurationMillisecondType,
+    DurationMicrosecondType,
+    DurationNanosecondType,
+);
+
+// TODO: implement ProduceTy for byte array types without cloning
+
+impl<'r> ProduceTy<'r, BinaryType> for &ArrayCellRef<'r> {
+    fn produce(self) -> Result<Vec<u8>, ConnectorError> {
+        let array = self.array.as_bytes::<BinaryType>();
+        Ok(array.value(self.row_number).to_vec())
+    }
+    fn produce_opt(self) -> Result<Option<<BinaryType as ArrowType>::Native>, ConnectorError> {
+        Ok(if self.array.is_null(self.row_number) {
+            None
+        } else {
+            Some(ProduceTy::<BinaryType>::produce(self)?)
+        })
+    }
+}
+impl<'r> ProduceTy<'r, LargeBinaryType> for &ArrayCellRef<'r> {
+    fn produce(self) -> Result<Vec<u8>, ConnectorError> {
+        let array = self.array.as_bytes::<LargeBinaryType>();
+        Ok(array.value(self.row_number).to_vec())
+    }
+    fn produce_opt(self) -> Result<Option<<LargeBinaryType as ArrowType>::Native>, ConnectorError> {
+        Ok(if self.array.is_null(self.row_number) {
+            None
+        } else {
+            Some(ProduceTy::<LargeBinaryType>::produce(self)?)
+        })
+    }
+}
+impl<'r> ProduceTy<'r, FixedSizeBinaryType> for &ArrayCellRef<'r> {
+    fn produce(self) -> Result<Vec<u8>, ConnectorError> {
+        let array = self.array.as_fixed_size_binary();
+        Ok(array.value(self.row_number).to_vec())
+    }
+    fn produce_opt(
+        self,
+    ) -> Result<Option<<FixedSizeBinaryType as ArrowType>::Native>, ConnectorError> {
+        Ok(if self.array.is_null(self.row_number) {
+            None
+        } else {
+            Some(ProduceTy::<FixedSizeBinaryType>::produce(self)?)
+        })
+    }
+}
+impl<'r> ProduceTy<'r, Utf8Type> for &ArrayCellRef<'r> {
+    fn produce(self) -> Result<String, ConnectorError> {
+        let array = self.array.as_bytes::<Utf8Type>();
+        Ok(array.value(self.row_number).to_string())
+    }
+    fn produce_opt(self) -> Result<Option<<Utf8Type as ArrowType>::Native>, ConnectorError> {
+        Ok(if self.array.is_null(self.row_number) {
+            None
+        } else {
+            Some(ProduceTy::<Utf8Type>::produce(self)?)
+        })
+    }
+}
+impl<'r> ProduceTy<'r, LargeUtf8Type> for &ArrayCellRef<'r> {
+    fn produce(self) -> Result<String, ConnectorError> {
+        let array = self.array.as_bytes::<LargeUtf8Type>();
+        Ok(array.value(self.row_number).to_string())
+    }
+    fn produce_opt(self) -> Result<Option<<LargeUtf8Type as ArrowType>::Native>, ConnectorError> {
+        Ok(if self.array.is_null(self.row_number) {
+            None
+        } else {
+            Some(ProduceTy::<LargeUtf8Type>::produce(self)?)
+        })
+    }
+}
+
+impl_produce_unsupported!(&ArrayCellRef<'r>, (Decimal128Type, Decimal256Type,));

diff --git a/connector_arrow/src/util/transport.rs b/connector_arrow/src/util/transport.rs
index 3bb47c0..c8ef47a 100644
--- a/connector_arrow/src/util/transport.rs
+++ b/connector_arrow/src/util/transport.rs
@@ -207,15 +207,31 @@ pub mod print {
 }
 
 #[macro_export]
-macro_rules! impl_produce_unused {
+macro_rules! impl_produce_unsupported {
     ($p: ty, ($($t: ty,)+)) => {
         $(
             impl<'r> $crate::util::transport::ProduceTy<'r, $t> for $p {
                 fn produce(self) -> Result<<$t as $crate::types::ArrowType>::Native, ConnectorError> {
-                    unimplemented!();
+                    unimplemented!("unsupported");
                 }
                 fn produce_opt(self) -> Result<Option<<$t as $crate::types::ArrowType>::Native>, ConnectorError> {
-                    unimplemented!();
+                    unimplemented!("unsupported");
                 }
             }
         )+
     };
 }
+
+#[macro_export]
+macro_rules! impl_consume_unsupported {
+    ($c: ty, ($($t: ty,)+)) => {
+        $(
+            impl $crate::util::transport::ConsumeTy<$t> for $c {
+                fn consume(&mut self, _val: <$t as $crate::types::ArrowType>::Native) {
+                    unimplemented!("unsupported");
+                }
+                fn consume_null(&mut self) {
+                    unimplemented!("unsupported");
+                }
+            }
+        )+
+    };
+}

diff --git a/connector_arrow/tests/it/test_duckdb.rs b/connector_arrow/tests/it/test_duckdb.rs
index 73099b2..9e10f73 100644
--- a/connector_arrow/tests/it/test_duckdb.rs
+++ b/connector_arrow/tests/it/test_duckdb.rs
@@ -18,35 +18,44 @@ fn coerce_ty(ty: &DataType) -> Option<DataType> {
 
 #[test]
 fn roundtrip_basic_small() {
+    let table_name = "roundtrip_basic_small";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
-    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), coerce_ty);
+    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), table_name, coerce_ty);
 }
 
 #[test]
 fn roundtrip_empty() {
+    let table_name = "roundtrip_empty";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/empty.parquet").unwrap();
-    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), coerce_ty);
+    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), table_name, coerce_ty);
 }
 
 #[test]
 fn introspection_basic_small() {
+    let table_name = "introspection_basic_small";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
-    let (table, schema_file, _) =
-        super::util::load_parquet_if_not_exists(&mut conn, path.as_path());
+    let (schema_file, _) =
+        super::util::load_parquet_if_not_exists(&mut conn, path.as_path(), table_name);
     let schema_file_coerced = super::util::cast_schema(&schema_file, &coerce_ty);
 
-    let schema_introspection = conn.table_get(&table).unwrap();
+    let schema_introspection = conn.table_get(table_name).unwrap();
     similar_asserts::assert_eq!(schema_file_coerced, schema_introspection);
 }
 
 #[test]
 fn schema_edit_01() {
+    let table_name = "schema_edit_01";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
-    let (_, schema, _) = super::util::load_parquet_if_not_exists(&mut conn, path.as_path());
+    let (schema, _) =
+        super::util::load_parquet_if_not_exists(&mut conn, path.as_path(), table_name);
 
     let _ignore = conn.table_drop("test_table2");

diff --git a/connector_arrow/tests/it/test_postgres.rs b/connector_arrow/tests/it/test_postgres.rs
index a416d09..5658f3c 100644
--- a/connector_arrow/tests/it/test_postgres.rs
+++ b/connector_arrow/tests/it/test_postgres.rs
@@ -1,9 +1,12 @@
-use arrow::util::pretty::pretty_format_batches;
-use connector_arrow::postgres::{PostgresConnection, ProtocolExtended, ProtocolSimple};
-use insta::assert_display_snapshot;
+use arrow::{datatypes::DataType, util::pretty::pretty_format_batches};
+use insta::{assert_debug_snapshot, assert_display_snapshot};
 use postgres::{Client, NoTls};
+use std::{env, path::PathBuf, str::FromStr};
 
-use std::env;
+use connector_arrow::{
+    api::{SchemaEdit, SchemaGet},
+    postgres::{PostgresConnection, ProtocolExtended, ProtocolSimple},
+};
 
 fn init() -> Client {
     let _ = env_logger::builder().is_test(true).try_init();
 
     let dburl = env::var("POSTGRES_URL").unwrap();
     Client::connect(&dburl, NoTls).unwrap()
 }
 
@@ -12,6 +15,72 @@
+fn coerce_ty(ty: &DataType) -> Option<DataType> {
+    match ty {
+        DataType::Null => Some(DataType::Int16),
+        DataType::Utf8 => Some(DataType::LargeUtf8),
+        _ => None,
+    }
+}
+
+#[test]
+fn roundtrip_basic_small() {
+    let table_name = "roundtrip_basic_small";
+
+    let mut conn = init();
+    let mut conn = PostgresConnection::<ProtocolExtended>::new(&mut conn);
+
+    let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
+    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), table_name, coerce_ty);
+}
+
+#[test]
+fn roundtrip_empty() {
+    let table_name = "roundtrip_empty";
+
+    let mut conn = init();
+    let mut conn = PostgresConnection::<ProtocolExtended>::new(&mut conn);
+
+    let path = PathBuf::from_str("tests/data/empty.parquet").unwrap();
+    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), table_name, coerce_ty);
+}
+
+#[test]
+fn introspection_basic_small() {
+    let table_name = "introspection_basic_small";
+
+    let mut conn = init();
+    let mut conn = PostgresConnection::<ProtocolExtended>::new(&mut conn);
+
+    let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
+    let (schema_file, _) =
+        super::util::load_parquet_if_not_exists(&mut conn, path.as_path(), table_name);
+    let schema_file_coerced = super::util::cast_schema(&schema_file, &coerce_ty);
+
+    let schema_introspection = conn.table_get(table_name).unwrap();
+    similar_asserts::assert_eq!(schema_file_coerced, schema_introspection);
+}
+
+#[test]
+fn schema_edit_01() {
+    let table_name = "schema_edit_01";
+
+    let mut conn = init();
+    let mut conn = PostgresConnection::<ProtocolExtended>::new(&mut conn);
+
+    let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
+    let (schema, _) =
+        super::util::load_parquet_if_not_exists(&mut conn, path.as_path(), table_name);
+
+    let _ignore = conn.table_drop("test_table2");
+
+    conn.table_create("test_table2", schema.clone()).unwrap();
+    assert_debug_snapshot!(
+        conn.table_create("test_table2", schema.clone()).unwrap_err(), @"TableExists"
+    );
+
+    conn.table_drop("test_table2").unwrap();
+    assert_debug_snapshot!(
+        conn.table_drop("test_table2").unwrap_err(), @"TableNonexistent"
+    );
+}
+
 #[test]
 fn test_protocol_simple() {
     let mut conn = init();

diff --git a/connector_arrow/tests/it/test_sqlite.rs b/connector_arrow/tests/it/test_sqlite.rs
index 69e6c7d..e1fce4b 100644
--- a/connector_arrow/tests/it/test_sqlite.rs
+++ b/connector_arrow/tests/it/test_sqlite.rs
@@ -36,17 +36,21 @@ fn coerce_ty(ty: &DataType) -> Option<DataType> {
 
 #[test]
 fn roundtrip_basic_small() {
+    let table_name = "roundtrip_basic_small";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
-    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), coerce_ty);
+    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), table_name, coerce_ty);
 }
 
 #[test]
 #[ignore] // SQLite cannot infer schema from an empty response, as there are no rows to infer from
 fn roundtrip_empty() {
+    let table_name = "roundtrip_empty";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/empty.parquet").unwrap();
-    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), coerce_ty);
+    super::util::roundtrip_of_parquet(&mut conn, path.as_path(), table_name, coerce_ty);
 }
 
 #[test]
@@ -66,21 +70,26 @@ fn query_04() {
 
 #[test]
 #[ignore] // cannot introspect the Null column
 fn introspection_basic_small() {
+    let table_name = "introspection_basic_small";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
-    let (table, schema_file, _) =
-        super::util::load_parquet_if_not_exists(&mut conn, path.as_path());
+    let (schema_file, _) =
+        super::util::load_parquet_if_not_exists(&mut conn, path.as_path(), table_name);
     let schema_file_coerced = super::util::cast_schema(&schema_file, &coerce_ty);
 
-    let schema_introspection = conn.table_get(&table).unwrap();
+    let schema_introspection = conn.table_get(table_name).unwrap();
     similar_asserts::assert_eq!(schema_file_coerced, schema_introspection);
 }
 
 #[test]
 fn schema_edit_01() {
+    let table_name = "schema_edit_01";
+
     let mut conn = init();
     let path = PathBuf::from_str("tests/data/basic_small.parquet").unwrap();
-    let (_, schema, _) = super::util::load_parquet_if_not_exists(&mut conn, path.as_path());
+    let (schema, _) =
+        super::util::load_parquet_if_not_exists(&mut conn, path.as_path(), table_name);
 
     let _ignore = conn.table_drop("test_table2");

diff --git a/connector_arrow/tests/it/util.rs b/connector_arrow/tests/it/util.rs
index 1a602c7..0c7f1ad 100644
--- a/connector_arrow/tests/it/util.rs
+++ b/connector_arrow/tests/it/util.rs
@@ -14,7 +14,8 @@ use connector_arrow::ConnectorError;
 pub fn load_parquet_if_not_exists<C>(
     conn: &mut C,
     file_path: &Path,
-) -> (String, SchemaRef, Vec<RecordBatch>)
+    table_name: &str,
+) -> (SchemaRef, Vec<RecordBatch>)
 where
     C: Connection + SchemaEdit,
 {
@@ -32,12 +33,9 @@ where
     };
 
     // table create
-    let table_name = file_path.file_name().unwrap().to_str().unwrap().to_string();
-    match conn.table_create(&table_name, schema.clone()) {
+    match conn.table_create(table_name, schema.clone()) {
         Ok(_) => (),
-        Err(connector_arrow::TableCreateError::TableExists) => {
-            return (table_name, schema, arrow_file)
-        }
+        Err(connector_arrow::TableCreateError::TableExists) => return (schema, arrow_file),
         Err(e) => panic!("{}", e),
     }
 
@@ -50,16 +48,16 @@ where
         appender.finish().unwrap();
     }
 
-    (table_name, schema, arrow_file)
+    (schema, arrow_file)
 }
 
-#[track_caller]
-pub fn roundtrip_of_parquet<C, F>(conn: &mut C, file_path: &Path, coerce_ty: F)
+// #[track_caller]
+pub fn roundtrip_of_parquet<C, F>(conn: &mut C, file_path: &Path, table_name: &str, coerce_ty: F)
 where
     C: Connection + SchemaEdit,
     F: Fn(&DataType) -> Option<DataType>,
 {
-    let (table_name, schema_file, arrow_file) = load_parquet_if_not_exists(conn, file_path);
+    let (schema_file, arrow_file) = load_parquet_if_not_exists(conn, file_path, table_name);
 
     // read from table
     let (schema_query, arrow_query) = {
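---

Usage sketch (not part of the patch): the snippet below wires together the `SchemaEdit`, `Append`, and `SchemaGet` impls this patch adds, mirroring what the new tests do. The `POSTGRES_URL` env var follows the tests; the table name and column layout are illustrative, and the `?`-based error plumbing assumes the crate's error types implement `std::error::Error` (they are thiserror-based).

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use connector_arrow::api::{Append, Connection, SchemaEdit, SchemaGet};
use connector_arrow::postgres::{PostgresConnection, ProtocolExtended};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Connect with the plain `postgres` client; the connector wraps a &mut Client.
    let mut client =
        postgres::Client::connect(&std::env::var("POSTGRES_URL")?, postgres::NoTls)?;
    let mut conn = PostgresConnection::<ProtocolExtended>::new(&mut client);

    // Declare a schema and create the table (SchemaEdit).
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]));
    conn.table_create("example", schema.clone())?;

    // Append one RecordBatch through the binary COPY writer (Append).
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef,
            Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef,
        ],
    )?;
    let mut appender = conn.append("example")?;
    appender.append(batch)?;
    appender.finish()?;

    // Read the schema back out of the catalog (SchemaGet).
    let introspected = conn.table_get("example")?;
    assert_eq!(schema.fields().len(), introspected.fields().len());

    conn.table_drop("example")?;
    Ok(())
}
```

Note that, as the README table and `coerce_ty` in the tests indicate, types do not always round-trip exactly: for example `Utf8` comes back as `LargeUtf8` and `Null` columns are stored as nullable `INT2`, so compare schemas after applying the same coercions the tests use.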