Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(c/driver/sqlite): Support binding dictionary-encoded string and binary types #1224

Merged
merged 5 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ class PostgresStatementTest : public ::testing::Test,
void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; }

void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; }
void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; }
Expand Down
2 changes: 1 addition & 1 deletion c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class SqliteStatementTest : public ::testing::Test,

void TestSqlIngestUInt64() {
std::vector<std::optional<uint64_t>> values = {std::nullopt, 0, INT64_MAX};
return TestSqlIngestType(NANOARROW_TYPE_UINT64, values);
return TestSqlIngestType(NANOARROW_TYPE_UINT64, values, /*dictionary_encode*/ false);
}

void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
Expand Down
41 changes: 40 additions & 1 deletion c/driver/sqlite/statement_reader.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder,
struct ArrowSchemaView view = {0};
for (int i = 0; i < binder->schema.n_children; i++) {
status = ArrowSchemaViewInit(&view, binder->schema.children[i], &arrow_error);
if (status != 0) {
if (status != NANOARROW_OK) {
SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i,
strerror(status), status, arrow_error.message);
return ADBC_STATUS_INVALID_ARGUMENT;
Expand All @@ -70,6 +70,31 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder,
SetError(error, "Column %d has UNINITIALIZED type", i);
return ADBC_STATUS_INTERNAL;
}

if (view.type == NANOARROW_TYPE_DICTIONARY) {
struct ArrowSchemaView value_view = {0};
status = ArrowSchemaViewInit(&value_view, binder->schema.children[i]->dictionary,
&arrow_error);
if (status != NANOARROW_OK) {
SetError(error, "Failed to parse schema for column %d->dictionary: %s (%d): %s",
i, strerror(status), status, arrow_error.message);
return ADBC_STATUS_INVALID_ARGUMENT;
}

// We only support string/binary dictionary-encoded values
switch (value_view.type) {
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_LARGE_BINARY:
break;
default:
SetError(error, "Column %d dictionary has unsupported type %s", i,
ArrowTypeString(value_view.type));
return ADBC_STATUS_NOT_IMPLEMENTED;
}
}

binder->types[i] = view.type;
}

Expand Down Expand Up @@ -353,6 +378,20 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3
SQLITE_STATIC);
break;
}
case NANOARROW_TYPE_DICTIONARY: {
int64_t value_index =
ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row);
if (ArrowArrayViewIsNull(binder->batch.children[col]->dictionary,
value_index)) {
status = sqlite3_bind_null(stmt, col + 1);
} else {
struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(
binder->batch.children[col]->dictionary, value_index);
status = sqlite3_bind_text(stmt, col + 1, value.data.as_char,
value.size_bytes, SQLITE_STATIC);
}
break;
}
case NANOARROW_TYPE_DATE32: {
int64_t value =
ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row);
Expand Down
64 changes: 50 additions & 14 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,8 @@ void StatementTest::TestRelease() {

template <typename CType>
void StatementTest::TestSqlIngestType(ArrowType type,
const std::vector<std::optional<CType>>& values) {
const std::vector<std::optional<CType>>& values,
bool dictionary_encode) {
if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
GTEST_SKIP();
}
Expand All @@ -1381,6 +1382,38 @@ void StatementTest::TestSqlIngestType(ArrowType type,
ASSERT_THAT(MakeBatch<CType>(&schema.value, &array.value, &na_error, values),
IsOkErrno());

if (dictionary_encode) {
// Create a dictionary-encoded version of the target schema
Handle<struct ArrowSchema> dict_schema;
ASSERT_THAT(ArrowSchemaInitFromType(&dict_schema.value, NANOARROW_TYPE_INT32),
IsOkErrno());
ASSERT_THAT(ArrowSchemaSetName(&dict_schema.value, schema.value.children[0]->name),
IsOkErrno());
ASSERT_THAT(ArrowSchemaSetName(schema.value.children[0], nullptr), IsOkErrno());

// Swap it into the target schema
ASSERT_THAT(ArrowSchemaAllocateDictionary(&dict_schema.value), IsOkErrno());
ArrowSchemaMove(schema.value.children[0], dict_schema.value.dictionary);
ArrowSchemaMove(&dict_schema.value, schema.value.children[0]);

// Create a dictionary-encoded array with easy 0...n indices so that the
// matched values will be the same.
Handle<struct ArrowArray> dict_array;
ASSERT_THAT(ArrowArrayInitFromType(&dict_array.value, NANOARROW_TYPE_INT32),
IsOkErrno());
ASSERT_THAT(ArrowArrayStartAppending(&dict_array.value), IsOkErrno());
for (size_t i = 0; i < values.size(); i++) {
ASSERT_THAT(ArrowArrayAppendInt(&dict_array.value, static_cast<int64_t>(i)),
IsOkErrno());
}
ASSERT_THAT(ArrowArrayFinishBuildingDefault(&dict_array.value, nullptr), IsOkErrno());

// Swap it into the target batch
ASSERT_THAT(ArrowArrayAllocateDictionary(&dict_array.value), IsOkErrno());
ArrowArrayMove(array.value.children[0], dict_array.value.dictionary);
ArrowArrayMove(&dict_array.value, array.value.children[0]);
}

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
"bulk_ingest", &error),
Expand Down Expand Up @@ -1448,7 +1481,7 @@ void StatementTest::TestSqlIngestNumericType(ArrowType type) {
values.push_back(std::numeric_limits<CType>::max());
}

return TestSqlIngestType(type, values);
return TestSqlIngestType(type, values, false);
}

void StatementTest::TestSqlIngestBool() {
Expand Down Expand Up @@ -1497,25 +1530,23 @@ void StatementTest::TestSqlIngestFloat64() {

void StatementTest::TestSqlIngestString() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}));
NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}, false));
}

void StatementTest::TestSqlIngestLargeString() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}));
NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}, false));
}

void StatementTest::TestSqlIngestBinary() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
NANOARROW_TYPE_BINARY,
{
std::nullopt, std::vector<std::byte>{},
std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
std::vector<std::byte>{
std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}
},
std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}
}));
{std::nullopt, std::vector<std::byte>{},
std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
std::vector<std::byte>{std::byte{0x01}, std::byte{0x02}, std::byte{0x03},
std::byte{0x04}},
std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}},
false));
}

void StatementTest::TestSqlIngestDate32() {
Expand Down Expand Up @@ -1737,6 +1768,12 @@ void StatementTest::TestSqlIngestInterval() {
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::TestSqlIngestStringDictionary() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"},
/*dictionary_encode*/ true));
}

void StatementTest::TestSqlIngestTableEscaping() {
std::string name = "create_table_escaping";

Expand Down Expand Up @@ -2112,8 +2149,7 @@ void StatementTest::TestSqlIngestErrors() {
{"coltwo", NANOARROW_TYPE_INT64}}),
IsOkErrno());
ASSERT_THAT(
(MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
{-42}, {-42})),
(MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error, {-42}, {-42})),
IsOkErrno(&na_error));

ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
Expand Down
7 changes: 6 additions & 1 deletion c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ class StatementTest {
void TestSqlIngestTimestampTz();
void TestSqlIngestInterval();

// Dictionary-encoded
void TestSqlIngestStringDictionary();

// ---- End Type-specific tests ----------------

void TestSqlIngestTableEscaping();
Expand Down Expand Up @@ -385,7 +388,8 @@ class StatementTest {
struct AdbcStatement statement;

template <typename CType>
void TestSqlIngestType(ArrowType type, const std::vector<std::optional<CType>>& values);
void TestSqlIngestType(ArrowType type, const std::vector<std::optional<CType>>& values,
bool dictionary_encode);

template <typename CType>
void TestSqlIngestNumericType(ArrowType type);
Expand Down Expand Up @@ -422,6 +426,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \
TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \
TEST_F(FIXTURE, SqlIngestStringDictionary) { TestSqlIngestStringDictionary(); } \
TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \
TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); } \
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \
Expand Down
Loading