diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 932e685e4e..f76920e035 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -812,6 +812,7 @@ class PostgresStatementTest : public ::testing::Test, void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; } void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; } void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; } + void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; } void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; } void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; } diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index e5566df260..c07aaeebd0 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -246,7 +246,7 @@ class SqliteStatementTest : public ::testing::Test, void TestSqlIngestUInt64() { std::vector> values = {std::nullopt, 0, INT64_MAX}; - return TestSqlIngestType(NANOARROW_TYPE_UINT64, values); + return TestSqlIngestType(NANOARROW_TYPE_UINT64, values, /*dictionary_encode*/ false); } void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; } diff --git a/c/driver/sqlite/statement_reader.c b/c/driver/sqlite/statement_reader.c index c609e1e416..3654a2e3de 100644 --- a/c/driver/sqlite/statement_reader.c +++ b/c/driver/sqlite/statement_reader.c @@ -60,7 +60,7 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, struct ArrowSchemaView view = {0}; for (int i = 0; i < binder->schema.n_children; i++) { status = ArrowSchemaViewInit(&view, binder->schema.children[i], &arrow_error); - if (status != 0) { + if (status != NANOARROW_OK) { SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i, strerror(status), status, arrow_error.message); return ADBC_STATUS_INVALID_ARGUMENT; @@ -70,6 +70,31 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, SetError(error, "Column %d has UNINITIALIZED type", i); return ADBC_STATUS_INTERNAL; } + + if (view.type == NANOARROW_TYPE_DICTIONARY) { + struct ArrowSchemaView value_view = {0}; + status = ArrowSchemaViewInit(&value_view, binder->schema.children[i]->dictionary, + &arrow_error); + if (status != NANOARROW_OK) { + SetError(error, "Failed to parse schema for column %d->dictionary: %s (%d): %s", + i, strerror(status), status, arrow_error.message); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + // We only support string/binary dictionary-encoded values + switch (value_view.type) { + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: + break; + default: + SetError(error, "Column %d dictionary has unsupported type %s", i, + ArrowTypeString(value_view.type)); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + } + binder->types[i] = view.type; } @@ -353,6 +378,20 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 SQLITE_STATIC); break; } + case NANOARROW_TYPE_DICTIONARY: { + int64_t value_index = + ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); + if (ArrowArrayViewIsNull(binder->batch.children[col]->dictionary, + value_index)) { + status = sqlite3_bind_null(stmt, col + 1); + } else { + struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe( + binder->batch.children[col]->dictionary, value_index); + status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, + value.size_bytes, SQLITE_STATIC); + } + break; + } case NANOARROW_TYPE_DATE32: { int64_t value = ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index c2f32ff502..0f9023e890 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -1366,7 +1366,8 @@ void StatementTest::TestRelease() { template void StatementTest::TestSqlIngestType(ArrowType type, - const std::vector>& values) { + const std::vector>& values, + bool dictionary_encode) { if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -1381,6 +1382,38 @@ void StatementTest::TestSqlIngestType(ArrowType type, ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, values), IsOkErrno()); + if (dictionary_encode) { + // Create a dictionary-encoded version of the target schema + Handle dict_schema; + ASSERT_THAT(ArrowSchemaInitFromType(&dict_schema.value, NANOARROW_TYPE_INT32), + IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetName(&dict_schema.value, schema.value.children[0]->name), + IsOkErrno()); + ASSERT_THAT(ArrowSchemaSetName(schema.value.children[0], nullptr), IsOkErrno()); + + // Swap it into the target schema + ASSERT_THAT(ArrowSchemaAllocateDictionary(&dict_schema.value), IsOkErrno()); + ArrowSchemaMove(schema.value.children[0], dict_schema.value.dictionary); + ArrowSchemaMove(&dict_schema.value, schema.value.children[0]); + + // Create a dictionary-encoded array with easy 0...n indices so that the + // matched values will be the same. + Handle dict_array; + ASSERT_THAT(ArrowArrayInitFromType(&dict_array.value, NANOARROW_TYPE_INT32), + IsOkErrno()); + ASSERT_THAT(ArrowArrayStartAppending(&dict_array.value), IsOkErrno()); + for (size_t i = 0; i < values.size(); i++) { + ASSERT_THAT(ArrowArrayAppendInt(&dict_array.value, static_cast(i)), + IsOkErrno()); + } + ASSERT_THAT(ArrowArrayFinishBuildingDefault(&dict_array.value, nullptr), IsOkErrno()); + + // Swap it into the target batch + ASSERT_THAT(ArrowArrayAllocateDictionary(&dict_array.value), IsOkErrno()); + ArrowArrayMove(array.value.children[0], dict_array.value.dictionary); + ArrowArrayMove(&dict_array.value, array.value.children[0]); + } + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE, "bulk_ingest", &error), @@ -1448,7 +1481,7 @@ void StatementTest::TestSqlIngestNumericType(ArrowType type) { values.push_back(std::numeric_limits::max()); } - return TestSqlIngestType(type, values); + return TestSqlIngestType(type, values, false); } void StatementTest::TestSqlIngestBool() { @@ -1497,25 +1530,23 @@ void StatementTest::TestSqlIngestFloat64() { void StatementTest::TestSqlIngestString() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( - NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"})); + NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}, false)); } void StatementTest::TestSqlIngestLargeString() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( - NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"})); + NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}, false)); } void StatementTest::TestSqlIngestBinary() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( NANOARROW_TYPE_BINARY, - { - std::nullopt, std::vector{}, - std::vector{std::byte{0x00}, std::byte{0x01}}, - std::vector{ - std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04} - }, - std::vector{std::byte{0xfe}, std::byte{0xff}} - })); + {std::nullopt, std::vector{}, + std::vector{std::byte{0x00}, std::byte{0x01}}, + std::vector{std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, + std::byte{0x04}}, + std::vector{std::byte{0xfe}, std::byte{0xff}}}, + false)); } void StatementTest::TestSqlIngestDate32() { @@ -1737,6 +1768,12 @@ void StatementTest::TestSqlIngestInterval() { ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error)); } +void StatementTest::TestSqlIngestStringDictionary() { + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( + NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}, + /*dictionary_encode*/ true)); +} + void StatementTest::TestSqlIngestTableEscaping() { std::string name = "create_table_escaping"; @@ -2112,8 +2149,7 @@ void StatementTest::TestSqlIngestErrors() { {"coltwo", NANOARROW_TYPE_INT64}}), IsOkErrno()); ASSERT_THAT( - (MakeBatch(&schema.value, &array.value, &na_error, - {-42}, {-42})), + (MakeBatch(&schema.value, &array.value, &na_error, {-42}, {-42})), IsOkErrno(&na_error)); ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index d125d1be10..47fa4df86b 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -327,6 +327,9 @@ class StatementTest { void TestSqlIngestTimestampTz(); void TestSqlIngestInterval(); + // Dictionary-encoded + void TestSqlIngestStringDictionary(); + // ---- End Type-specific tests ---------------- void TestSqlIngestTableEscaping(); @@ -385,7 +388,8 @@ class StatementTest { struct AdbcStatement statement; template - void TestSqlIngestType(ArrowType type, const std::vector>& values); + void TestSqlIngestType(ArrowType type, const std::vector>& values, + bool dictionary_encode); template void TestSqlIngestNumericType(ArrowType type); @@ -422,6 +426,7 @@ class StatementTest { TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \ TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \ TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \ + TEST_F(FIXTURE, SqlIngestStringDictionary) { TestSqlIngestStringDictionary(); } \ TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \ TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); } \ TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \