From 57a28f0b1c2b34bf5f0f03a29c2acbb6456bc222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Tue, 20 Feb 2024 13:52:04 +0100 Subject: [PATCH] test generator --- Cargo.lock | 14 +- Cargo.toml | 2 +- connector_arrow/Cargo.toml | 6 +- connector_arrow/src/duckdb/mod.rs | 1 + connector_arrow/src/util/coerce.rs | 42 ++- connector_arrow/tests/data/numeric.parquet | Bin 7686 -> 0 bytes connector_arrow/tests/data/temporal.parquet | Bin 7107 -> 0 bytes .../tests/it/generator.rs | 307 ++++++++---------- connector_arrow/tests/it/main.rs | 1 + connector_arrow/tests/it/test_duckdb.rs | 17 +- connector_arrow/tests/it/test_sqlite.rs | 63 +++- connector_arrow/tests/it/tests.rs | 41 ++- connector_arrow/tests/it/util.rs | 26 +- test_generator/Cargo.toml | 13 - 14 files changed, 304 insertions(+), 229 deletions(-) delete mode 100644 connector_arrow/tests/data/numeric.parquet delete mode 100644 connector_arrow/tests/data/temporal.parquet rename test_generator/src/main.rs => connector_arrow/tests/it/generator.rs (68%) delete mode 100644 test_generator/Cargo.toml diff --git a/Cargo.lock b/Cargo.lock index 30503bb..5103deb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -517,6 +517,7 @@ dependencies = [ "duckdb", "env_logger", "fallible-streaming-iterator", + "half", "hex", "insta", "itertools", @@ -524,6 +525,8 @@ dependencies = [ "parquet", "postgres", "postgres-protocol", + "rand", + "rand_chacha", "rusqlite", "rust_decimal", "rust_decimal_macros", @@ -1902,17 +1905,6 @@ dependencies = [ "xattr", ] -[[package]] -name = "test_generator" -version = "0.1.0" -dependencies = [ - "arrow", - "half", - "parquet", - "rand", - "rand_chacha", -] - [[package]] name = "thiserror" version = "1.0.56" diff --git a/Cargo.toml b/Cargo.toml index fec7155..b74c6db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,3 @@ [workspace] -members = ["connector_arrow", "test_generator"] +members = ["connector_arrow"] resolver = "2" diff --git a/connector_arrow/Cargo.toml b/connector_arrow/Cargo.toml index e52459f..e6930d1 100644 --- a/connector_arrow/Cargo.toml +++ b/connector_arrow/Cargo.toml @@ -61,6 +61,10 @@ arrow = { version = "49", features = ["prettyprint"], default-features = false } parquet = { version = "49", features = ["arrow"], default-features = false } insta = { version = "1.34.0" } similar-asserts = { version = "1.5.0" } +half = "2.3.1" +rand = { version = "0.8.5", default-features = false } +rand_chacha = "0.3.1" + [features] all = ["src_sqlite", "src_duckdb", "src_postgres"] @@ -73,7 +77,7 @@ src_postgres = [ "rust_decimal", "rust_decimal_macros", "bytes", - "byteorder" + "byteorder", ] src_sqlite = ["rusqlite", "fallible-streaming-iterator", "urlencoding"] src_duckdb = [ diff --git a/connector_arrow/src/duckdb/mod.rs b/connector_arrow/src/duckdb/mod.rs index e8681a0..4110835 100644 --- a/connector_arrow/src/duckdb/mod.rs +++ b/connector_arrow/src/duckdb/mod.rs @@ -31,6 +31,7 @@ impl Connection for duckdb::Connection { fn coerce_type(ty: &DataType) -> Option { match ty { DataType::Null => Some(DataType::Int64), + DataType::Float16 => Some(DataType::Float32), _ => None, } } diff --git a/connector_arrow/src/util/coerce.rs b/connector_arrow/src/util/coerce.rs index 983697d..c4690fb 100644 --- a/connector_arrow/src/util/coerce.rs +++ b/connector_arrow/src/util/coerce.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::array::{Array, ArrayRef, AsArray, Float32Builder, Float64Builder}; +use arrow::datatypes::{DataType, Field, Float16Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use itertools::Itertools; @@ -38,7 +38,11 @@ where F: Fn(&DataType) -> Option + Copy, { match coerce_fn(array.data_type()) { - Some(new_ty) => arrow::compute::cast(&array, &new_ty), + Some(new_ty) => match (array.data_type(), &new_ty) { + (DataType::Float16, DataType::Float32) => Ok(coerce_float_16_to_32(&array)), + (DataType::Float16, DataType::Float64) => Ok(coerce_float_16_to_64(&array)), + _ => arrow::compute::cast(&array, &new_ty), + }, None => Ok(array), } } @@ -58,3 +62,35 @@ where .collect_vec(), )) } + +fn coerce_float_16_to_32(array: &dyn Array) -> ArrayRef { + // inefficient, but we don't need efficiency here + + let array = array.as_primitive::(); + let mut builder = Float32Builder::with_capacity(array.len()); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(array.value(i).to_f32()); + } + } + Arc::new(builder.finish()) as ArrayRef +} + +fn coerce_float_16_to_64(array: &dyn Array) -> ArrayRef { + // inefficient, but we don't need efficiency here + + let array = array.as_primitive::(); + let mut builder = Float64Builder::with_capacity(array.len()); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + builder.append_value(array.value(i).to_f64()); + } + } + Arc::new(builder.finish()) as ArrayRef +} diff --git a/connector_arrow/tests/data/numeric.parquet b/connector_arrow/tests/data/numeric.parquet deleted file mode 100644 index 5a0c854701bf5aa5c4ce140fac4cb43d35e8bfc2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7686 zcmb7JeQXrR72i7>bJ*rH-t`?EV5hz zK&u)dL?c>72t^2`lpv)+5u%79TG1qG(#ZV>RVi{?gxsh?RhmWyl?u@crGWe1o1NL2 z-F2j{bnDrf_xrt%d2e=R-OjeGGYd__G+Obt#DJT*z+CXO88H``MP`v{>f!KI=C=?9=+z6XSEMXMIqH}X2 zBSZbY!vxJ(8QXYP2)T+mxHh~>!hxcoq)3pIIwY8=ake_d%9Z{=x* zJg#G8qJ=2ub$t$~Usvn%-34R|-L}_vp}1ieDfRiTB7=M{Cgs7-yI6GIeIgr_c?XAR z>esd8-kv&oDdeAbObl=y0f9=U>k!6uJR$*Nj>$6<`&H1~>`B?PgS;%^8fSG=fd{!cGxMV#uI@j3?88LBTtXFmg`2)50o`;FL2k~>EWcg(e&A#( z$XVSCUI%D46*wsN*S>+K;kB{z?>7&x7(74zk6+9z`NN~1`rH;}Fe7UpGDWX_nq^<{ zcLyTD8h@cnYkpAq@ACE)*YGv$T^R$})X<1Rrwj+blC2bWft4^OZE$y-D3egBWQ{Q+@;dKcag4n_v1SPP6V zNhCmliZ_H#rzeh3uLJ)xhA^G4hU_|*LCQ$*>i=?!2xX?zp zcy*d{qLlOYg@YX^=G$Lvhoj0)lUv2NdI<~7f-|Q>;cmyY5hLnsni5eoOv93zz>wAOiHC*lCuKmfE`eQ_#75e&TZUV(!F23~#37-q2)sLkM$5O7vf z_!1C{#~zm3FT}sUJE=o*Ff^x?-sc{ zfyea8Irf=38X3NhVHt<~#WR^Av7|lo zAuvp~f#?e)I-W!i-(27lDTSpNHZk8mK3`+w8yLY2e2>>oej5-UZnu%BCEPc05)cQ*_ADcU-Z{O421f5CJE;X59F zUS^*!gYQu=3-YjKeT3jIQW(Y8%Uy8rQSYi{AF<-Kp2MDX136BvLyrGMT@L@z?NZdf zUyc@l3E~`N-9hvOp{JI*XmzzxCi2Kq`+AIz`HA4G@0N(B26y0&n$OLP)IH!$tOheafLh0&Pu6F@tu|S@s;szxt~yv zMb-qeJWS(c{`Eema zudK3feqGjtx((0DosirU;Tm%qssEHgzpp>#QqQ(MpI&XBSslpl)|{#pG+agvCu%^0 zH4(KUb_D*E;r=UH%C>CT-1SYq6^VX2KZ)1=u|~#N2mJ8b$k-UGA7Cxm1%DaKQ9nGw zGh@y81YP1UJZy)a`YehKAWeR(!I5DHSQ}$49eENOLvna-N1FO92(v+?X&8fNXuxxt z0d!*rP~}oc)q&4BRF_93>=nE<0TE+ayrE9kntvvl%#!`+2wfd*^xV4-{6lf{11>NG z^lcmMOE;!A?cB1XZ@6=`5BnI~o@`Ghx#0xGzC1m*0DThGk;^Cwzg$yaQ{CW8`JNQZ z<+HQtKOpowl&b!=gDr~w`sO6aHUTrsk$>3A(=&Drpz+(H4mPv?_Jc5= zB(l|!%t4Uhs(%&vrDqT|?Csn-v=@EoPS>pKgSck=c$_4-^1U8BnI`ZszSSSUhI)!0 z>%;g7%YkbYe!0$Xch#hn7-W+>eDUY_ycgf2;2|vee4Ri3&E#qi#-GOnFc1X{ZDVIk z-*CtNO+$5iU+UUF%35~%@r>hN_TVZ00uLVTY_Asl-grL+nv?(_Hw{g4j`fce()iO+NnhZ=ZJp@MCGJme!OY{2-4iUJUd^Yw=C zZ3je7vQ20&uWxdiB%3MTo1tct2dD#YJaP<&ayZ8a+tBs4oql`=$M+)M%?}4ZBJggG zFqT4l^>{KK&ladT$36gxcv7e-$6uh(jr#GGorCqdUwCNw|~!?-J=!z`gf0hcVxICQ@1v? PHf6lI5We`se}Ml7)WDT- diff --git a/connector_arrow/tests/data/temporal.parquet b/connector_arrow/tests/data/temporal.parquet deleted file mode 100644 index f13b97230ae857a4b8a52da0b861fcd12e7cb250..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 7107 zcmbtZduUtN8NX5;S94sy;(L3wsF|Baw@&O>vTP@D*JNDlHAN3we#tV7Op+Bzwk)}l zABi@IH@3n;$U-RjBOM`(7QzT+^p8@yY=dl!5IojF3E2ZSvJplwg_SMYMmF#J&egs5 zNV?aL^nlMjkMDfn?|Yo@oFk8jqrGe$%dkuUt2PEIwux=(W=(7hThG=rjFDj&6O7lH z2blU6Gv6_unNJoMWAh80p;#_oOeXTVneL8IK9@XrDz}u)vJ7Y90$f0_2#nfO1*gkjlA_`Wyb;b^YkCL2x_Ut_ZdqF?O0{8XmZivI;cU{(EzCj28EM9{(ry~&TL&5*TP0RGmm;bc!Wb7~RzWvJ%=Eg<$_jecmxuFGL zrr}NUs%^l}W-~fArQz4IM)lPG|c`zsr8~wLbg)gZd%b!nOoh zBMH&(_J3vX!=JbR6nqM!R7lo$n&dfNX~{-0F-B#Rbg5V>0#X;uMFV%yAas+^tQEoO z2;Sk=tw2S^nJ%iTUAz--2&x6-pxCY$p5>`zfr;W=*RA7MHFj7X&hd5NfmPWHQerTc zlhtllY%$sAQ1jAx19#pa^rM}r#Snp1k3=2?;#&_`>$#U6{r*bxUnjLd;5c806q|rJ z(yt9dU5p^qUNs2R5w-o>lT%(T5L;P-M}PD=I)fR-@zh>;7PH8RF8uhn(Pu%lr;JDn zk7ofxrDXtvMLAXeuv`?sZo7V11DHnYhm{(q%8Ll_NI!MVQ6TA*2DZm?PCceDrAdSI zDDmkgpidZn@w*@W=I!+70Zsf2psBR;`B5|MI6IDckik0v9?gxYFEKe(d?T7(?!i|w zCOt@#=hpuI6PcK8GFzE?Gtc{Si^;-rEZZ4P#tOlFZZS>4{=mR}U=Wzsw-OBaYJj37 z#T5utF00T`fYO5sIS!us$DRMa>^<>487QX!b;ZtNfat))wR&RBtp+o1VDOJ+%$$ia zw;IiA@IX^uiPOgNw0wtvY5AO|r{yymrscy_p4OBmpE5u%H<7*`q}9;Z3;V}a zl?CB2=0W&K$x=bDM-5h5t!1KI_jF8w#)SeawJcQ;jb^?gR4Jqtu>fSqvC8fi?f`4& zVBBost{Bp(vl;&lV(!AH8(B5if<7L{b&C2L!ElTr#8#`Ghc+jpi*Eywm$*eSQ z?^Q(H+^00xI#}2h60~ctl-sP){IvC}jj=)@zjCUWNGIoGYzN!`t;Vgd?b!=$meQ$h14?$&GX=N%F zQ_5&A2DMsgwYGWrVOS%0hQ6a%8v|MNfHGdbHqhg`C%LsL@J+?K$U*Cbm5i!)v}J(- z6$HA6iM0hqm3`mD@~lxmgIf>blM+dEawiaQCbPx-v_%-VC|Bt1zcUXACi2eYZgsQm z+`5(B&j?SADyOYOnLJsg6a9)&xNF$a`;hy%o8`ER0(*iH{%^zf7vwO}7-K=w*}etLH(Yd|0TjJd{{F3KE~9#PhGB zqHpe2CGKE{D!&cfYY;waki2&JFmS!|DthhO3o6%&4Y19d6T-vv{_X>}{xU+O|8-=}d2zSxtFR8+Gb1fz9=04%iFYQd)eU$cq zNc;AFi+qdS&UJgOW6S$C_h}EybJx4r5F>nYK;f>F=ZPZHRsPy5gg0IUv;pr=kvF&2%9#WceA@9U#y5(iqr;Kcqzevz{{VV-wf}E>Z~>kw_~3IL<@cEqV}|JG(UFojf};?(h%FN(df$bfx%>ZBtYS;NBU$wPy(YOJ|C(GN+HOhiCCmJ8=4L0$0OO<#Qb1sIucrr z&xciB!0oOeorVHbqp>|JF}giz~C!EAZGkE~B1fg&%QpCb!Ml|p+%f&Eba!iT8lFNuA_vq3~Pk5j~6 zQz3ga(z~MCY$Y|}7+jnlOJ!twUuI7A`>+=NJO+(a+%AkK+2625NT?QquEdcbZXhRy zr4(O@zR`4iO2il|#=TiN#)>`<#9Cw`@5>ZdLZy5u==3cpG3_pO)!L7e@2`*#CfJsI zI>tAb?VQjDJjyX;=y|4fzm7tNSA@%Om*-cAFYnNGJ}W`z!4~p~@Qu zxmm^5Nbe=~tjZsGVpND!2YZU_ zcU47*Crl`n`c&ntqR%}5xr@rNXq2HarlMg{JSOFF@Eb`@#6FZgGm=8(+Kc{$6bWBa zlr4xo67-w!-ZnUa(column_specs: Vec, rng: &mut R) -> RecordBatch { + let mut arrays = Vec::new(); + let mut fields = Vec::new(); + for column in column_specs { + let array = generate_array(&column.data_type, &column.values, rng); + arrays.push(array); + + let field = Field::new(column.field_name, column.data_type, column.is_nullable); + fields.push(field); + } + let schema = Arc::new(Schema::new(fields)); + RecordBatch::try_new(schema, arrays).unwrap() +} + +pub fn spec_simple() -> Vec { + domains_to_batch_spec(&[DataType::Null, DataType::Boolean], &[false, true]) +} + +pub fn spec_numeric() -> Vec { + domains_to_batch_spec( + &[ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, + ], + &[false, true], + ) +} + +pub fn spec_timestamp() -> Vec { + domains_to_batch_spec( + &[ + DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:30"))), + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("+07:30"))), + DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+07:30"))), + DataType::Timestamp(TimeUnit::Second, Some(Arc::from("+07:30"))), + ], + &[true], + ) +} +pub fn spec_date() -> Vec { + domains_to_batch_spec(&[DataType::Date32, DataType::Date64], &[true]) +} +pub fn spec_time() -> Vec { + domains_to_batch_spec( + &[ + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Second), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Microsecond), + ], + &[true], + ) +} +pub fn spec_duration() -> Vec { + domains_to_batch_spec( + &[ + DataType::Duration(TimeUnit::Nanosecond), + DataType::Duration(TimeUnit::Microsecond), + DataType::Duration(TimeUnit::Millisecond), + DataType::Duration(TimeUnit::Second), + ], + &[true], + ) +} +pub fn spec_interval() -> Vec { + domains_to_batch_spec( + &[ + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::DayTime), + ], + &[true], + ) +} + +pub fn domains_to_batch_spec( + data_types_domain: &[DataType], + is_nullable_domain: &[bool], +) -> Vec { + let value_gen_process_domain = [ + ValueGenProcess::Low, + ValueGenProcess::High, + ValueGenProcess::Null, + ValueGenProcess::RandomUniform, + ]; + + let mut columns = Vec::new(); + for data_type in data_types_domain { + for is_nullable in is_nullable_domain { + let is_nullable = *is_nullable; + if matches!(data_type, &DataType::Null) && !is_nullable { + continue; + } + + let mut field_name = data_type.to_string(); + if is_nullable { + field_name += "_null"; + } + let mut col = ColumnSpec { + field_name, + data_type: data_type.clone(), + is_nullable, + values: Vec::new(), + }; + + for gen_process in value_gen_process_domain { + col.values.push(ValuesSpec { + gen_process: if matches!(gen_process, ValueGenProcess::Null) && !is_nullable { + ValueGenProcess::RandomUniform + } else { + gen_process + }, + repeat: 1, + }); + } + columns.push(col); + } + } + columns +} + #[derive(Clone, Copy)] enum ValueGenProcess { Null, @@ -14,19 +146,19 @@ enum ValueGenProcess { RandomUniform, } -struct ValuesDesc { +struct ValuesSpec { gen_process: ValueGenProcess, repeat: usize, } -struct ColumnDesc { +pub struct ColumnSpec { field_name: String, is_nullable: bool, data_type: DataType, - values: Vec, + values: Vec, } -fn count_values(values: &[ValuesDesc]) -> usize { +fn count_values(values: &[ValuesSpec]) -> usize { values.iter().map(|v| v.repeat).sum() } @@ -48,7 +180,7 @@ macro_rules! gen_array { }}; } -fn generate_array(data_type: &DataType, values: &[ValuesDesc], rng: &mut R) -> ArrayRef { +fn generate_array(data_type: &DataType, values: &[ValuesSpec], rng: &mut R) -> ArrayRef { match data_type { DataType::Null => { let mut builder = NullBuilder::with_capacity(count_values(values)); @@ -258,160 +390,3 @@ fn generate_array(data_type: &DataType, values: &[ValuesDesc], rng: &mut DataType::RunEndEncoded(_, _) => todo!(), } } - -fn generate_batch(columns_desc: Vec, rng: &mut R) -> RecordBatch { - let mut arrays = Vec::new(); - let mut fields = Vec::new(); - for column in columns_desc { - let array = generate_array(&column.data_type, &column.values, rng); - arrays.push(array); - - let field = Field::new(column.field_name, column.data_type, column.is_nullable); - fields.push(field); - } - let schema = Arc::new(Schema::new(fields)); - RecordBatch::try_new(schema, arrays).unwrap() -} - -fn numeric() -> Vec { - let data_types_domain = [ - DataType::Null, - DataType::Boolean, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - // DataType::Float16, - DataType::Float32, - DataType::Float64, - ]; - let is_nullable_domain = [false, true]; - let value_gen_process_domain = [ - ValueGenProcess::Low, - ValueGenProcess::High, - ValueGenProcess::Null, - ValueGenProcess::RandomUniform, - ]; - - let mut columns = Vec::new(); - for data_type in &data_types_domain { - for is_nullable in is_nullable_domain { - if matches!(data_type, &DataType::Null) && !is_nullable { - continue; - } - - let mut field_name = data_type.to_string(); - if is_nullable { - field_name += "_null"; - } - let mut col = ColumnDesc { - field_name, - data_type: data_type.clone(), - is_nullable, - values: Vec::new(), - }; - - for gen_process in value_gen_process_domain { - col.values.push(ValuesDesc { - gen_process: if matches!(gen_process, ValueGenProcess::Null) && !is_nullable { - ValueGenProcess::RandomUniform - } else { - gen_process - }, - repeat: 1, - }); - } - columns.push(col); - } - } - columns -} - -fn temporal() -> Vec { - let data_types_domain = [ - DataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:30"))), - DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("+07:30"))), - DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+07:30"))), - DataType::Timestamp(TimeUnit::Second, Some(Arc::from("+07:30"))), - DataType::Date32, - DataType::Date64, - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Second), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Microsecond), - // DataType::Duration(TimeUnit::Nanosecond), - // DataType::Duration(TimeUnit::Microsecond), - // DataType::Duration(TimeUnit::Millisecond), - // DataType::Duration(TimeUnit::Second), - DataType::Interval(IntervalUnit::YearMonth), - // DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::DayTime), - ]; - let is_nullable_domain = [true]; - let value_gen_process_domain = [ - ValueGenProcess::Low, - ValueGenProcess::High, - ValueGenProcess::Null, - ValueGenProcess::RandomUniform, - ]; - - let mut columns = Vec::new(); - for data_type in &data_types_domain { - for is_nullable in is_nullable_domain { - if matches!(data_type, &DataType::Null) && !is_nullable { - continue; - } - - let mut field_name = data_type.to_string(); - if is_nullable { - field_name += "_null"; - } - let mut col = ColumnDesc { - field_name, - data_type: data_type.clone(), - is_nullable, - values: Vec::new(), - }; - - for gen_process in value_gen_process_domain { - col.values.push(ValuesDesc { - gen_process: if matches!(gen_process, ValueGenProcess::Null) && !is_nullable { - ValueGenProcess::RandomUniform - } else { - gen_process - }, - repeat: 1, - }); - } - columns.push(col); - } - } - columns -} - -fn write_parquet_to_file(batch: RecordBatch, file_name: &str) { - let path = Path::new("connector_arrow/tests/data/file").with_file_name(file_name); - - let mut file = File::create(path).unwrap(); - - let schema = batch.schema(); - let mut writer = - parquet::arrow::arrow_writer::ArrowWriter::try_new(&mut file, schema, None).unwrap(); - writer.write(&batch).unwrap(); - writer.close().unwrap(); -} - -fn main() { - let mut rng = rand_chacha::ChaCha8Rng::from_seed([0; 32]); - - write_parquet_to_file(generate_batch(numeric(), &mut rng), "numeric.parquet"); - write_parquet_to_file(generate_batch(temporal(), &mut rng), "temporal.parquet"); -} diff --git a/connector_arrow/tests/it/main.rs b/connector_arrow/tests/it/main.rs index 1fbd683..f3a2410 100644 --- a/connector_arrow/tests/it/main.rs +++ b/connector_arrow/tests/it/main.rs @@ -1,3 +1,4 @@ +mod generator; mod tests; mod util; diff --git a/connector_arrow/tests/it/test_duckdb.rs b/connector_arrow/tests/it/test_duckdb.rs index 4b89a06..cabdb71 100644 --- a/connector_arrow/tests/it/test_duckdb.rs +++ b/connector_arrow/tests/it/test_duckdb.rs @@ -29,22 +29,21 @@ fn roundtrip_empty() { } #[test] -fn roundtrip_numeric() { - let table_name = "roundtrip_number"; - let file_name = "numeric.parquet"; +fn roundtrip_simple() { + let table_name = "roundtrip_simple"; let mut conn = init(); - super::tests::roundtrip_of_parquet(&mut conn, file_name, table_name); + let column_spec = super::generator::spec_simple(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); } #[test] -#[ignore] -fn roundtrip_temporal() { - let table_name = "roundtrip_temporal"; - let file_name = "temporal.parquet"; +fn roundtrip_numeric() { + let table_name = "roundtrip_numeric"; let mut conn = init(); - super::tests::roundtrip_of_parquet(&mut conn, file_name, table_name); + let column_spec = super::generator::spec_numeric(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); } #[test] diff --git a/connector_arrow/tests/it/test_sqlite.rs b/connector_arrow/tests/it/test_sqlite.rs index 1c4c9fd..295352a 100644 --- a/connector_arrow/tests/it/test_sqlite.rs +++ b/connector_arrow/tests/it/test_sqlite.rs @@ -29,23 +29,72 @@ fn roundtrip_empty() { super::tests::roundtrip_of_parquet(&mut conn, file_name, table_name); } +#[test] +fn roundtrip_simple() { + let table_name = "roundtrip_simple"; + + let mut conn = init(); + let column_spec = super::generator::spec_simple(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); +} + #[test] fn roundtrip_numeric() { - let table_name = "roundtrip_number"; - let file_name = "numeric.parquet"; + let table_name = "roundtrip_numeric"; let mut conn = init(); - super::tests::roundtrip_of_parquet(&mut conn, file_name, table_name); + let column_spec = super::generator::spec_numeric(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); } #[test] #[ignore] -fn roundtrip_temporal() { - let table_name = "roundtrip_temporal"; - let file_name = "temporal.parquet"; +fn roundtrip_timestamp() { + let table_name = "roundtrip_timestamp"; let mut conn = init(); - super::tests::roundtrip_of_parquet(&mut conn, file_name, table_name); + let column_spec = super::generator::spec_timestamp(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); +} + +#[test] +#[ignore] +fn roundtrip_date() { + let table_name = "roundtrip_date"; + + let mut conn = init(); + let column_spec = super::generator::spec_date(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); +} + +#[test] +#[ignore] +fn roundtrip_time() { + let table_name = "roundtrip_time"; + + let mut conn = init(); + let column_spec = super::generator::spec_time(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); +} + +#[test] +#[ignore] +fn roundtrip_duration() { + let table_name = "roundtrip_duration"; + + let mut conn = init(); + let column_spec = super::generator::spec_duration(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); +} + +#[test] +#[ignore] +fn roundtrip_interval() { + let table_name = "roundtrip_interval"; + + let mut conn = init(); + let column_spec = super::generator::spec_interval(); + super::tests::roundtrip_of_generated(&mut conn, table_name, column_spec); } #[test] diff --git a/connector_arrow/tests/it/tests.rs b/connector_arrow/tests/it/tests.rs index aab49d7..ee3192f 100644 --- a/connector_arrow/tests/it/tests.rs +++ b/connector_arrow/tests/it/tests.rs @@ -5,8 +5,13 @@ use connector_arrow::{ api::{Connection, ResultReader, SchemaEdit, SchemaGet, Statement}, TableCreateError, TableDropError, }; +use rand::SeedableRng; -use crate::util::{load_parquet_into_table, query_table}; +use crate::util::{load_into_table, query_table}; +use crate::{ + generator::{generate_batch, ColumnSpec}, + util::read_parquet, +}; #[track_caller] pub fn query_01(conn: &mut C) { @@ -15,11 +20,11 @@ pub fn query_01(conn: &mut C) { similar_asserts::assert_eq!( pretty_format_batches(&results).unwrap().to_string(), - "+---+---+ -| a | b | -+---+---+ -| 1 | | -+---+---+" + "+---+---+\n\ + | a | b |\n\ + +---+---+\n\ + | 1 | |\n\ + +---+---+" ); } @@ -29,21 +34,36 @@ where { let file_path = Path::new("./tests/data/a").with_file_name(file_name); + let (schema_file, batches_file) = read_parquet(&file_path).unwrap(); let (schema_file, batches_file) = - load_parquet_into_table(conn, &file_path, table_name).unwrap(); + load_into_table(conn, schema_file, batches_file, table_name).unwrap(); let (schema_query, batches_query) = query_table(conn, table_name).unwrap(); similar_asserts::assert_eq!(schema_file, schema_query); similar_asserts::assert_eq!(batches_file, batches_query); } +pub fn roundtrip_of_generated(conn: &mut C, table_name: &str, column_specs: Vec) +where + C: Connection + SchemaEdit, +{ + let mut rng = rand_chacha::ChaCha8Rng::from_seed([0; 32]); + let batch = generate_batch(column_specs, &mut rng); + + let (_, batches_file) = load_into_table(conn, batch.schema(), vec![batch], table_name).unwrap(); + + let (_, batches_query) = query_table(conn, table_name).unwrap(); + + similar_asserts::assert_eq!(batches_file, batches_query); +} + pub fn introspection(conn: &mut C, file_name: &str, table_name: &str) where C: Connection + SchemaEdit + SchemaGet, { let file_path = Path::new("./tests/data/a").with_file_name(file_name); - let (schema_loaded, _) = - super::util::load_parquet_into_table(conn, &file_path, table_name).unwrap(); + let (schema_file, batches_file) = read_parquet(&file_path).unwrap(); + let (schema_loaded, _) = load_into_table(conn, schema_file, batches_file, table_name).unwrap(); let schema_introspection = conn.table_get(table_name).unwrap(); similar_asserts::assert_eq!(schema_loaded, schema_introspection); @@ -55,7 +75,8 @@ where { let file_path = Path::new("./tests/data/a").with_file_name(file_name); - let (schema, _) = super::util::load_parquet_into_table(conn, &file_path, table_name).unwrap(); + let (schema_file, batches_file) = read_parquet(&file_path).unwrap(); + let (schema, _) = load_into_table(conn, schema_file, batches_file, table_name).unwrap(); let table_name2 = table_name.to_string() + "2"; diff --git a/connector_arrow/tests/it/util.rs b/connector_arrow/tests/it/util.rs index 8ea9840..42074c8 100644 --- a/connector_arrow/tests/it/util.rs +++ b/connector_arrow/tests/it/util.rs @@ -22,16 +22,26 @@ pub fn read_parquet(file_path: &Path) -> Result<(SchemaRef, Vec), A Ok((schema, batches)) } -pub fn load_parquet_into_table( +#[allow(dead_code)] +pub fn write_parquet(path: &Path, batch: RecordBatch) { + let mut file = File::create(path).unwrap(); + + let schema = batch.schema(); + let mut writer = + parquet::arrow::arrow_writer::ArrowWriter::try_new(&mut file, schema, None).unwrap(); + writer.write(&batch).unwrap(); + writer.close().unwrap(); +} + +pub fn load_into_table( conn: &mut C, - file_path: &Path, + schema: SchemaRef, + batches: Vec, table_name: &str, ) -> Result<(SchemaRef, Vec), ConnectorError> where C: Connection + SchemaEdit, { - let (schema_file, batches_file) = read_parquet(file_path)?; - // table drop match conn.table_drop(table_name) { Ok(_) | Err(TableDropError::TableNonexistent) => (), @@ -39,7 +49,7 @@ where } // table create - match conn.table_create(table_name, schema_file.clone()) { + match conn.table_create(table_name, schema.clone()) { Ok(_) => (), Err(TableCreateError::TableExists) => { panic!("table was just deleted, how can it exist now?") @@ -50,14 +60,14 @@ where // write into table { let mut appender = conn.append(&table_name).unwrap(); - for batch in batches_file.clone() { + for batch in batches.clone() { appender.append(batch).unwrap(); } appender.finish().unwrap(); } - let schema_coerced = coerce::coerce_schema(schema_file, &C::coerce_type); - let batches_coerced = coerce::coerce_batches(&batches_file, C::coerce_type).unwrap(); + let schema_coerced = coerce::coerce_schema(schema, &C::coerce_type); + let batches_coerced = coerce::coerce_batches(&batches, C::coerce_type).unwrap(); Ok((schema_coerced, batches_coerced)) } diff --git a/test_generator/Cargo.toml b/test_generator/Cargo.toml deleted file mode 100644 index f083cc2..0000000 --- a/test_generator/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "test_generator" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -arrow = { version = "49", features = ["prettyprint"], default-features = false } -half = "2.3.1" -parquet = { version = "49", features = ["arrow"], default-features = false } -rand = {version = "0.8.5", default-features = false} -rand_chacha = "0.3.1"