diff --git a/Cargo.lock b/Cargo.lock index 8046c17..c1186d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -511,6 +511,7 @@ dependencies = [ "postgres-protocol", "rand", "rand_chacha", + "rstest", "rusqlite", "rust_decimal", "rust_decimal_macros", @@ -827,6 +828,12 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "half" version = "2.3.1" @@ -1475,6 +1482,12 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "relative-path" +version = "1.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc" + [[package]] name = "rend" version = "0.4.1" @@ -1513,6 +1526,33 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "rstest" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" +dependencies = [ + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.48", + "unicode-ident", +] + [[package]] name = "rusqlite" version = "0.30.0" @@ -1560,6 +1600,15 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.30" @@ -1597,6 +1646,12 @@ version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c107b6f4780854c8b126e228ea8869f4d7b71260f962fefb57b996b8959ba6b" +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + [[package]] name = "serde" version = "1.0.196" diff --git a/connector_arrow/Cargo.toml b/connector_arrow/Cargo.toml index 97574cd..daafaab 100644 --- a/connector_arrow/Cargo.toml +++ b/connector_arrow/Cargo.toml @@ -62,6 +62,7 @@ similar-asserts = { version = "1.5.0" } half = "2.3.1" rand = { version = "0.8.5", default-features = false } rand_chacha = "0.3.1" +rstest = { version = "0.18.2", default-features = false } [features] diff --git a/connector_arrow/tests/it/generator.rs b/connector_arrow/tests/it/generator.rs index 0002383..83be716 100644 --- a/connector_arrow/tests/it/generator.rs +++ b/connector_arrow/tests/it/generator.rs @@ -4,8 +4,10 @@ use half::f16; use rand::Rng; use std::sync::Arc; +use super::spec::*; + pub fn generate_batch( - column_specs: Vec, + column_specs: ArrowGenSpec, rng: &mut R, ) -> (SchemaRef, Vec) { let mut arrays = Vec::new(); @@ -27,209 +29,6 @@ pub fn generate_batch( } } -pub fn spec_all_types() -> Vec { - domains_to_batch_spec( - &[ - DataType::Null, - DataType::Boolean, - // DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - // DataType::UInt8, - // DataType::UInt16, - // DataType::UInt32, - // DataType::UInt64, - // DataType::Float16, - DataType::Float32, - DataType::Float64, - // DataType::Timestamp(TimeUnit::Nanosecond, None), - // DataType::Timestamp(TimeUnit::Microsecond, None), - // DataType::Timestamp(TimeUnit::Millisecond, None), - // DataType::Timestamp(TimeUnit::Second, None), - // DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:30"))), - // DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("+07:30"))), - // DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+07:30"))), - // DataType::Timestamp(TimeUnit::Second, Some(Arc::from("+07:30"))), - // DataType::Time32(TimeUnit::Millisecond), - // DataType::Time32(TimeUnit::Second), - // DataType::Time64(TimeUnit::Nanosecond), - // DataType::Time64(TimeUnit::Microsecond), - // DataType::Duration(TimeUnit::Nanosecond), - // DataType::Duration(TimeUnit::Microsecond), - // DataType::Duration(TimeUnit::Millisecond), - // DataType::Duration(TimeUnit::Second), - // DataType::Interval(IntervalUnit::YearMonth), - // DataType::Interval(IntervalUnit::MonthDayNano), - // DataType::Interval(IntervalUnit::DayTime), - ], - &[false, true], - &[ValueGenProcess::High], - ) -} - -pub fn spec_empty() -> Vec { - domains_to_batch_spec( - &[DataType::Null, DataType::Int64, DataType::Float64], - &[false, true], - &[], - ) -} - -pub fn spec_null_bool() -> Vec { - domains_to_batch_spec( - &[DataType::Null, DataType::Boolean], - &[false, true], - &VALUE_GEN_PROCESS_ALL, - ) -} - -pub fn spec_numeric() -> Vec { - domains_to_batch_spec( - &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float16, - DataType::Float32, - DataType::Float64, - ], - &[false, true], - &VALUE_GEN_PROCESS_ALL, - ) -} - -pub fn spec_timestamp() -> Vec { - domains_to_batch_spec( - &[ - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:30"))), - DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("+07:30"))), - DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+07:30"))), - DataType::Timestamp(TimeUnit::Second, Some(Arc::from("+07:30"))), - ], - &[true], - &VALUE_GEN_PROCESS_ALL, - ) -} -pub fn spec_date() -> Vec { - domains_to_batch_spec( - &[DataType::Date32, DataType::Date64], - &[true], - &VALUE_GEN_PROCESS_ALL, - ) -} -pub fn spec_time() -> Vec { - domains_to_batch_spec( - &[ - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Second), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Microsecond), - ], - &[true], - &VALUE_GEN_PROCESS_ALL, - ) -} -pub fn spec_duration() -> Vec { - domains_to_batch_spec( - &[ - DataType::Duration(TimeUnit::Nanosecond), - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Second), - ], - &[true], - &VALUE_GEN_PROCESS_ALL, - ) -} -pub fn spec_interval() -> Vec { - domains_to_batch_spec( - &[ - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::DayTime), - ], - &[true], - &VALUE_GEN_PROCESS_ALL, - ) -} - -pub fn domains_to_batch_spec( - data_types_domain: &[DataType], - is_nullable_domain: &[bool], - value_gen_process_domain: &[ValueGenProcess], -) -> Vec { - let mut columns = Vec::new(); - for data_type in data_types_domain { - for is_nullable in is_nullable_domain { - let is_nullable = *is_nullable; - if matches!(data_type, &DataType::Null) && !is_nullable { - continue; - } - - let mut field_name = data_type.to_string(); - if is_nullable { - field_name += "_null"; - } - let mut col = ColumnSpec { - field_name, - data_type: data_type.clone(), - is_nullable, - values: Vec::new(), - }; - - for gen_process in value_gen_process_domain { - col.values.push(ValuesSpec { - gen_process: if matches!(gen_process, ValueGenProcess::Null) && !is_nullable { - ValueGenProcess::RandomUniform - } else { - *gen_process - }, - repeat: 1, - }); - } - columns.push(col); - } - } - columns -} - -#[derive(Clone, Copy)] -pub enum ValueGenProcess { - Null, - Low, - High, - RandomUniform, -} - -const VALUE_GEN_PROCESS_ALL: [ValueGenProcess; 4] = [ - ValueGenProcess::Low, - ValueGenProcess::High, - ValueGenProcess::Null, - ValueGenProcess::RandomUniform, -]; - -struct ValuesSpec { - gen_process: ValueGenProcess, - repeat: usize, -} - -pub struct ColumnSpec { - field_name: String, - is_nullable: bool, - data_type: DataType, - values: Vec, -} - fn count_values(values: &[ValuesSpec]) -> usize { values.iter().map(|v| v.repeat).sum() } diff --git a/connector_arrow/tests/it/main.rs b/connector_arrow/tests/it/main.rs index f3a2410..e56770b 100644 --- a/connector_arrow/tests/it/main.rs +++ b/connector_arrow/tests/it/main.rs @@ -1,4 +1,5 @@ mod generator; +mod spec; mod tests; mod util; diff --git a/connector_arrow/tests/it/spec.rs b/connector_arrow/tests/it/spec.rs new file mode 100644 index 0000000..ed24421 --- /dev/null +++ b/connector_arrow/tests/it/spec.rs @@ -0,0 +1,216 @@ +use arrow::datatypes::*; +use std::sync::Arc; + +pub type ArrowGenSpec = Vec; + +#[derive(Clone, Copy)] +pub enum ValueGenProcess { + Null, + Low, + High, + RandomUniform, +} + +const VALUE_GEN_PROCESS_ALL: [ValueGenProcess; 4] = [ + ValueGenProcess::Low, + ValueGenProcess::High, + ValueGenProcess::Null, + ValueGenProcess::RandomUniform, +]; + +pub struct ValuesSpec { + pub gen_process: ValueGenProcess, + pub repeat: usize, +} + +pub struct ColumnSpec { + pub field_name: String, + pub is_nullable: bool, + pub data_type: DataType, + pub values: Vec, +} + +pub fn all_types() -> Vec { + domains_to_batch_spec( + &[ + DataType::Null, + DataType::Boolean, + // DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + // DataType::UInt8, + // DataType::UInt16, + // DataType::UInt32, + // DataType::UInt64, + // DataType::Float16, + DataType::Float32, + DataType::Float64, + // DataType::Timestamp(TimeUnit::Nanosecond, None), + // DataType::Timestamp(TimeUnit::Microsecond, None), + // DataType::Timestamp(TimeUnit::Millisecond, None), + // DataType::Timestamp(TimeUnit::Second, None), + // DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:30"))), + // DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("+07:30"))), + // DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+07:30"))), + // DataType::Timestamp(TimeUnit::Second, Some(Arc::from("+07:30"))), + // DataType::Time32(TimeUnit::Millisecond), + // DataType::Time32(TimeUnit::Second), + // DataType::Time64(TimeUnit::Nanosecond), + // DataType::Time64(TimeUnit::Microsecond), + // DataType::Duration(TimeUnit::Nanosecond), + // DataType::Duration(TimeUnit::Microsecond), + // DataType::Duration(TimeUnit::Millisecond), + // DataType::Duration(TimeUnit::Second), + // DataType::Interval(IntervalUnit::YearMonth), + // DataType::Interval(IntervalUnit::MonthDayNano), + // DataType::Interval(IntervalUnit::DayTime), + ], + &[false, true], + &[ValueGenProcess::High], + ) +} + +pub fn empty() -> Vec { + domains_to_batch_spec( + &[DataType::Null, DataType::Int64, DataType::Float64], + &[false, true], + &[], + ) +} + +pub fn null_bool() -> Vec { + domains_to_batch_spec( + &[DataType::Null, DataType::Boolean], + &[false, true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +pub fn numeric() -> Vec { + domains_to_batch_spec( + &[ + // DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + // DataType::UInt8, + // DataType::UInt16, + // DataType::UInt32, + // DataType::UInt64, + // DataType::Float16, + DataType::Float32, + DataType::Float64, + ], + &[false, true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +#[allow(dead_code)] +pub fn timestamp() -> Vec { + domains_to_batch_spec( + &[ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:30"))), + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("+07:30"))), + DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+07:30"))), + DataType::Timestamp(TimeUnit::Second, Some(Arc::from("+07:30"))), + ], + &[true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +#[allow(dead_code)] +pub fn date() -> Vec { + domains_to_batch_spec( + &[DataType::Date32, DataType::Date64], + &[true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +#[allow(dead_code)] +pub fn time() -> Vec { + domains_to_batch_spec( + &[ + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Second), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Microsecond), + ], + &[true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +#[allow(dead_code)] +pub fn duration() -> Vec { + domains_to_batch_spec( + &[ + DataType::Duration(TimeUnit::Nanosecond), + DataType::Duration(TimeUnit::Microsecond), + DataType::Duration(TimeUnit::Millisecond), + DataType::Duration(TimeUnit::Second), + ], + &[true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +#[allow(dead_code)] +pub fn interval() -> Vec { + domains_to_batch_spec( + &[ + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::DayTime), + ], + &[true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +fn domains_to_batch_spec( + data_types_domain: &[DataType], + is_nullable_domain: &[bool], + value_gen_process_domain: &[ValueGenProcess], +) -> Vec { + let mut columns = Vec::new(); + for data_type in data_types_domain { + for is_nullable in is_nullable_domain { + let is_nullable = *is_nullable; + if matches!(data_type, &DataType::Null) && !is_nullable { + continue; + } + + let mut field_name = data_type.to_string(); + if is_nullable { + field_name += "_null"; + } + let mut col = ColumnSpec { + field_name, + data_type: data_type.clone(), + is_nullable, + values: Vec::new(), + }; + + for gen_process in value_gen_process_domain { + col.values.push(ValuesSpec { + gen_process: if matches!(gen_process, ValueGenProcess::Null) && !is_nullable { + ValueGenProcess::RandomUniform + } else { + *gen_process + }, + repeat: 1, + }); + } + columns.push(col); + } + } + columns +} diff --git a/connector_arrow/tests/it/test_duckdb.rs b/connector_arrow/tests/it/test_duckdb.rs index 970b067..53a15af 100644 --- a/connector_arrow/tests/it/test_duckdb.rs +++ b/connector_arrow/tests/it/test_duckdb.rs @@ -1,3 +1,6 @@ +use super::spec; +use rstest::*; + fn init() -> duckdb::Connection { let _ = env_logger::builder().is_test(true).try_init(); @@ -10,31 +13,13 @@ fn query_01() { super::tests::query_01(&mut conn); } -#[test] -fn roundtrip_empty() { - let table_name = "roundtrip_empty"; - - let mut conn = init(); - let column_spec = super::generator::spec_empty(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -fn roundtrip_null_bool() { - let table_name = "roundtrip_null_bool"; - - let mut conn = init(); - let column_spec = super::generator::spec_null_bool(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -fn roundtrip_numeric() { - let table_name = "roundtrip_numeric"; - +#[rstest] +#[case::empty("roundtrip::empty", spec::empty())] +#[case::null_bool("roundtrip::null_bool", spec::null_bool())] +#[case::numeric("roundtrip::numeric", spec::numeric())] +fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut conn = init(); - let column_spec = super::generator::spec_numeric(); - super::tests::roundtrip(&mut conn, table_name, column_spec); + super::tests::roundtrip(&mut conn, table_name, spec); } #[test] @@ -42,8 +27,7 @@ fn schema_get() { let table_name = "schema_get"; let mut conn = init(); - let column_spec = super::generator::spec_all_types(); - super::tests::schema_get(&mut conn, table_name, column_spec); + super::tests::schema_get(&mut conn, table_name, spec::all_types()); } #[test] @@ -51,6 +35,5 @@ fn schema_edit() { let table_name = "schema_edit"; let mut conn = init(); - let column_spec = super::generator::spec_all_types(); - super::tests::schema_edit(&mut conn, table_name, column_spec); + super::tests::schema_edit(&mut conn, table_name, spec::all_types()); } diff --git a/connector_arrow/tests/it/test_postgres_extended.rs b/connector_arrow/tests/it/test_postgres_extended.rs index 28fb26d..dae5a48 100644 --- a/connector_arrow/tests/it/test_postgres_extended.rs +++ b/connector_arrow/tests/it/test_postgres_extended.rs @@ -1,4 +1,7 @@ use connector_arrow::postgres::{PostgresConnection, ProtocolExtended}; +use rstest::*; + +use super::spec; fn init() -> postgres::Client { let _ = env_logger::builder().is_test(true).try_init(); @@ -18,24 +21,14 @@ fn query_01() { super::tests::query_01(&mut conn); } -#[test] -fn roundtrip_empty() { - let table_name = "extended::roundtrip_empty"; - - let mut client = init(); - let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_empty(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -fn roundtrip_null_bool() { - let table_name = "extended::roundtrip_null_bool"; - +#[rstest] +#[case::empty("extended::roundtrip::empty", spec::empty())] +#[case::null_bool("extended::roundtrip::null_bool", spec::null_bool())] +#[case::numeric("extended::roundtrip::numeric", spec::numeric())] +fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut client = init(); let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_null_bool(); - super::tests::roundtrip(&mut conn, table_name, column_spec); + super::tests::roundtrip(&mut conn, table_name, spec); } #[test] @@ -44,7 +37,7 @@ fn schema_get() { let mut client = init(); let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_all_types(); + let column_spec = super::spec::all_types(); super::tests::schema_get(&mut conn, table_name, column_spec); } @@ -54,6 +47,6 @@ fn schema_edit() { let mut client = init(); let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_all_types(); + let column_spec = super::spec::all_types(); super::tests::schema_edit(&mut conn, table_name, column_spec); } diff --git a/connector_arrow/tests/it/test_postgres_simple.rs b/connector_arrow/tests/it/test_postgres_simple.rs index cb1c183..b065063 100644 --- a/connector_arrow/tests/it/test_postgres_simple.rs +++ b/connector_arrow/tests/it/test_postgres_simple.rs @@ -1,4 +1,7 @@ use connector_arrow::postgres::{PostgresConnection, ProtocolSimple}; +use rstest::*; + +use super::spec; fn init() -> postgres::Client { let _ = env_logger::builder().is_test(true).try_init(); @@ -18,24 +21,14 @@ fn query_01() { super::tests::query_01(&mut conn); } -#[test] -fn roundtrip_empty() { - let table_name = "simple::roundtrip_empty"; - - let mut client = init(); - let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_empty(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -fn roundtrip_null_bool() { - let table_name = "simple::roundtrip_null_bool"; - +#[rstest] +#[case::empty("simple::roundtrip::empty", spec::empty())] +#[case::null_bool("simple::roundtrip::null_bool", spec::null_bool())] +#[case::numeric("simple::roundtrip::numeric", spec::numeric())] +fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut client = init(); let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_null_bool(); - super::tests::roundtrip(&mut conn, table_name, column_spec); + super::tests::roundtrip(&mut conn, table_name, spec); } #[test] @@ -44,7 +37,7 @@ fn schema_get() { let mut client = init(); let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_all_types(); + let column_spec = super::spec::all_types(); super::tests::schema_get(&mut conn, table_name, column_spec); } @@ -54,6 +47,6 @@ fn schema_edit() { let mut client = init(); let mut conn = wrap_conn(&mut client); - let column_spec = super::generator::spec_all_types(); + let column_spec = super::spec::all_types(); super::tests::schema_edit(&mut conn, table_name, column_spec); } diff --git a/connector_arrow/tests/it/test_sqlite.rs b/connector_arrow/tests/it/test_sqlite.rs index afb0571..0040aab 100644 --- a/connector_arrow/tests/it/test_sqlite.rs +++ b/connector_arrow/tests/it/test_sqlite.rs @@ -1,3 +1,6 @@ +use super::spec; +use rstest::*; + fn init() -> rusqlite::Connection { let _ = env_logger::builder().is_test(true).try_init(); @@ -10,82 +13,13 @@ fn query_01() { super::tests::query_01(&mut conn); } -#[test] -#[ignore] // SQLite cannot infer schema from an empty response, as there is no rows to infer from -fn roundtrip_empty() { - let table_name = "roundtrip_empty"; - - let mut conn = init(); - let column_spec = super::generator::spec_empty(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -fn roundtrip_null_bool() { - let table_name = "roundtrip_null_bool"; - - let mut conn = init(); - let column_spec = super::generator::spec_null_bool(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -fn roundtrip_numeric() { - let table_name = "roundtrip_numeric"; - - let mut conn = init(); - let column_spec = super::generator::spec_numeric(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -#[ignore] -fn roundtrip_timestamp() { - let table_name = "roundtrip_timestamp"; - - let mut conn = init(); - let column_spec = super::generator::spec_timestamp(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -#[ignore] -fn roundtrip_date() { - let table_name = "roundtrip_date"; - - let mut conn = init(); - let column_spec = super::generator::spec_date(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -#[ignore] -fn roundtrip_time() { - let table_name = "roundtrip_time"; - - let mut conn = init(); - let column_spec = super::generator::spec_time(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -#[ignore] -fn roundtrip_duration() { - let table_name = "roundtrip_duration"; - - let mut conn = init(); - let column_spec = super::generator::spec_duration(); - super::tests::roundtrip(&mut conn, table_name, column_spec); -} - -#[test] -#[ignore] -fn roundtrip_interval() { - let table_name = "roundtrip_interval"; - +#[rstest] +// #[case::empty("roundtrip::empty", spec::empty())] +#[case::null_bool("roundtrip::null_bool", spec::null_bool())] +#[case::numeric("roundtrip::numeric", spec::numeric())] +fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut conn = init(); - let column_spec = super::generator::spec_interval(); - super::tests::roundtrip(&mut conn, table_name, column_spec); + super::tests::roundtrip(&mut conn, table_name, spec); } #[test] @@ -94,7 +28,7 @@ fn schema_get() { let table_name = "schema_get"; let mut conn = init(); - let column_spec = super::generator::spec_all_types(); + let column_spec = super::spec::all_types(); super::tests::schema_get(&mut conn, table_name, column_spec); } @@ -103,6 +37,6 @@ fn schema_edit() { let table_name = "schema_edit"; let mut conn = init(); - let column_spec = super::generator::spec_all_types(); + let column_spec = super::spec::all_types(); super::tests::schema_edit(&mut conn, table_name, column_spec); } diff --git a/connector_arrow/tests/it/tests.rs b/connector_arrow/tests/it/tests.rs index 8c34a01..23e1926 100644 --- a/connector_arrow/tests/it/tests.rs +++ b/connector_arrow/tests/it/tests.rs @@ -6,8 +6,8 @@ use connector_arrow::{ }; use rand::SeedableRng; -use crate::generator::{generate_batch, ColumnSpec}; use crate::util::{load_into_table, query_table}; +use crate::{generator::generate_batch, spec::ArrowGenSpec}; #[track_caller] pub fn query_01(conn: &mut C) { @@ -24,12 +24,12 @@ pub fn query_01(conn: &mut C) { ); } -pub fn roundtrip(conn: &mut C, table_name: &str, column_specs: Vec) +pub fn roundtrip(conn: &mut C, table_name: &str, spec: ArrowGenSpec) where C: Connection + SchemaEdit, { let mut rng = rand_chacha::ChaCha8Rng::from_seed([0; 32]); - let (schema, batches) = generate_batch(column_specs, &mut rng); + let (schema, batches) = generate_batch(spec, &mut rng); load_into_table(conn, schema.clone(), &batches, table_name).unwrap(); @@ -42,12 +42,12 @@ where similar_asserts::assert_eq!(batches_coerced, batches_query); } -pub fn schema_get(conn: &mut C, table_name: &str, column_specs: Vec) +pub fn schema_get(conn: &mut C, table_name: &str, spec: ArrowGenSpec) where C: Connection + SchemaEdit + SchemaGet, { let mut rng = rand_chacha::ChaCha8Rng::from_seed([0; 32]); - let (schema, batches) = generate_batch(column_specs, &mut rng); + let (schema, batches) = generate_batch(spec, &mut rng); load_into_table(conn, schema.clone(), &batches, table_name).unwrap(); let schema = coerce::coerce_schema(schema, &C::coerce_type, Some(false)); @@ -57,12 +57,12 @@ where similar_asserts::assert_eq!(schema, schema_introspection); } -pub fn schema_edit(conn: &mut C, table_name: &str, column_specs: Vec) +pub fn schema_edit(conn: &mut C, table_name: &str, spec: ArrowGenSpec) where C: Connection + SchemaEdit + SchemaGet, { let mut rng = rand_chacha::ChaCha8Rng::from_seed([0; 32]); - let (schema, _) = generate_batch(column_specs, &mut rng); + let (schema, _) = generate_batch(spec, &mut rng); let table_name2 = table_name.to_string() + "2";