Skip to content

Commit

Permalink
feat(c/driver/sqlite): Support binding dictionary-encoded string and binary types (#1224)
Browse files Browse the repository at this point in the history

This PR adds the ability to ingest dictionary-encoded string and binary
columns.

Part of addressing #1008.

From the R bindings:

``` r
library(adbcdrivermanager)

db <- adbc_database_init(adbcsqlite::adbcsqlite(), uri = ":memory:")
con <- adbc_connection_init(db)

df <- data.frame(x = factor(letters[1:10]))
write_adbc(df, con, "tbl")

read_adbc(con, "SELECT * from tbl") |> 
  as.data.frame()  
#>    x
#> 1  a
#> 2  b
#> 3  c
#> 4  d
#> 5  e
#> 6  f
#> 7  g
#> 8  h
#> 9  i
#> 10 j
```

<sup>Created on 2023-10-25 with [reprex
v2.0.2](https://reprex.tidyverse.org)</sup>
  • Loading branch information
paleolimbot authored Oct 26, 2023
1 parent 1bd874b commit 1789870
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 17 deletions.
1 change: 1 addition & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ class PostgresStatementTest : public ::testing::Test,
// Bulk-ingest tests that this fixture skips; each one reports "Not implemented".
void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; }
void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; }
// Ingest of dictionary-encoded strings is likewise skipped for this fixture.
void TestSqlIngestStringDictionary() { GTEST_SKIP() << "Not implemented"; }

// Prepared-statement tests that this fixture skips ("Not yet implemented").
void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; }
void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; }
Expand Down
2 changes: 1 addition & 1 deletion c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ class SqliteStatementTest : public ::testing::Test,

void TestSqlIngestUInt64() {
std::vector<std::optional<uint64_t>> values = {std::nullopt, 0, INT64_MAX};
return TestSqlIngestType(NANOARROW_TYPE_UINT64, values);
return TestSqlIngestType(NANOARROW_TYPE_UINT64, values, /*dictionary_encode*/ false);
}

void TestSqlIngestDuration() {
Expand Down
41 changes: 40 additions & 1 deletion c/driver/sqlite/statement_reader.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder,
struct ArrowSchemaView view = {0};
for (int i = 0; i < binder->schema.n_children; i++) {
status = ArrowSchemaViewInit(&view, binder->schema.children[i], &arrow_error);
if (status != 0) {
if (status != NANOARROW_OK) {
SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i,
strerror(status), status, arrow_error.message);
return ADBC_STATUS_INVALID_ARGUMENT;
Expand All @@ -70,6 +70,31 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder,
SetError(error, "Column %d has UNINITIALIZED type", i);
return ADBC_STATUS_INTERNAL;
}

if (view.type == NANOARROW_TYPE_DICTIONARY) {
  // A dictionary column carries its value type on the nested `dictionary`
  // schema, so that schema has to be parsed separately from the column itself.
  struct ArrowSchemaView dict_values = {0};
  status = ArrowSchemaViewInit(&dict_values, binder->schema.children[i]->dictionary,
                               &arrow_error);
  if (status != NANOARROW_OK) {
    SetError(error, "Failed to parse schema for column %d->dictionary: %s (%d): %s",
             i, strerror(status), status, arrow_error.message);
    return ADBC_STATUS_INVALID_ARGUMENT;
  }

  // Only string-like and binary-like dictionary values can be bound; every
  // other value type is rejected up front.
  const int is_supported_value_type = dict_values.type == NANOARROW_TYPE_STRING ||
                                      dict_values.type == NANOARROW_TYPE_LARGE_STRING ||
                                      dict_values.type == NANOARROW_TYPE_BINARY ||
                                      dict_values.type == NANOARROW_TYPE_LARGE_BINARY;
  if (!is_supported_value_type) {
    SetError(error, "Column %d dictionary has unsupported type %s", i,
             ArrowTypeString(dict_values.type));
    return ADBC_STATUS_NOT_IMPLEMENTED;
  }
}

binder->types[i] = view.type;
}

Expand Down Expand Up @@ -353,6 +378,20 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3
SQLITE_STATIC);
break;
}
case NANOARROW_TYPE_DICTIONARY: {
  // Dictionary-encoded column: the column itself stores integer indices and
  // the actual string/binary values live in the attached dictionary array.
  // NOTE(review): a null *index* is presumably handled before this switch is
  // reached (not visible from this case) -- confirm against the enclosing code.
  struct ArrowArrayView* dictionary = binder->batch.children[col]->dictionary;
  int64_t value_index =
      ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row);
  if (ArrowArrayViewIsNull(dictionary, value_index)) {
    // The index points at a null dictionary entry, so the cell is NULL.
    status = sqlite3_bind_null(stmt, col + 1);
  } else {
    struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(dictionary, value_index);
    // SQLITE_STATIC matches the non-dictionary cases: the bytes are bound
    // without asking SQLite to take its own copy.
    if (dictionary->storage_type == NANOARROW_TYPE_BINARY ||
        dictionary->storage_type == NANOARROW_TYPE_LARGE_BINARY) {
      // Binary dictionary values must be stored as BLOB; binding them as text
      // would record arbitrary bytes as a TEXT value.
      status = sqlite3_bind_blob(stmt, col + 1, value.data.as_char, value.size_bytes,
                                 SQLITE_STATIC);
    } else {
      status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, value.size_bytes,
                                 SQLITE_STATIC);
    }
  }
  break;
}
case NANOARROW_TYPE_DATE32: {
int64_t value =
ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row);
Expand Down
64 changes: 50 additions & 14 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,8 @@ void StatementTest::TestRelease() {

template <typename CType>
void StatementTest::TestSqlIngestType(ArrowType type,
const std::vector<std::optional<CType>>& values) {
const std::vector<std::optional<CType>>& values,
bool dictionary_encode) {
if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
GTEST_SKIP();
}
Expand All @@ -1381,6 +1382,38 @@ void StatementTest::TestSqlIngestType(ArrowType type,
ASSERT_THAT(MakeBatch<CType>(&schema.value, &array.value, &na_error, values),
IsOkErrno());

// When requested, rewrite the first column of the batch built above into a
// dictionary-encoded column: int32 indices on the outside, the original values
// as the dictionary.
if (dictionary_encode) {
// Create a dictionary-encoded version of the target schema
Handle<struct ArrowSchema> dict_schema;
ASSERT_THAT(ArrowSchemaInitFromType(&dict_schema.value, NANOARROW_TYPE_INT32),
IsOkErrno());
// Copy the column name onto the index schema, then clear it on the schema
// that is about to become the dictionary.
ASSERT_THAT(ArrowSchemaSetName(&dict_schema.value, schema.value.children[0]->name),
IsOkErrno());
ASSERT_THAT(ArrowSchemaSetName(schema.value.children[0], nullptr), IsOkErrno());

// Swap it into the target schema
ASSERT_THAT(ArrowSchemaAllocateDictionary(&dict_schema.value), IsOkErrno());
// Order matters: the original child is moved into the dictionary slot before
// the index schema is moved into children[0].
ArrowSchemaMove(schema.value.children[0], dict_schema.value.dictionary);
ArrowSchemaMove(&dict_schema.value, schema.value.children[0]);

// Create a dictionary-encoded array with easy 0...n indices so that the
// matched values will be the same.
Handle<struct ArrowArray> dict_array;
ASSERT_THAT(ArrowArrayInitFromType(&dict_array.value, NANOARROW_TYPE_INT32),
IsOkErrno());
ASSERT_THAT(ArrowArrayStartAppending(&dict_array.value), IsOkErrno());
// Row i refers to dictionary entry i. No null index is ever appended, so any
// null in `values` is presumably carried as a null dictionary entry --
// confirm against MakeBatch.
for (size_t i = 0; i < values.size(); i++) {
ASSERT_THAT(ArrowArrayAppendInt(&dict_array.value, static_cast<int64_t>(i)),
IsOkErrno());
}
ASSERT_THAT(ArrowArrayFinishBuildingDefault(&dict_array.value, nullptr), IsOkErrno());

// Swap it into the target batch
ASSERT_THAT(ArrowArrayAllocateDictionary(&dict_array.value), IsOkErrno());
// Same two-step move as for the schema: values into the dictionary slot
// first, then the index array into children[0].
ArrowArrayMove(array.value.children[0], dict_array.value.dictionary);
ArrowArrayMove(&dict_array.value, array.value.children[0]);
}

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
"bulk_ingest", &error),
Expand Down Expand Up @@ -1448,7 +1481,7 @@ void StatementTest::TestSqlIngestNumericType(ArrowType type) {
values.push_back(std::numeric_limits<CType>::max());
}

return TestSqlIngestType(type, values);
return TestSqlIngestType(type, values, false);
}

void StatementTest::TestSqlIngestBool() {
Expand Down Expand Up @@ -1497,25 +1530,23 @@ void StatementTest::TestSqlIngestFloat64() {

void StatementTest::TestSqlIngestString() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", ""}));
NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", ""}, false));
}

void StatementTest::TestSqlIngestLargeString() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", ""}));
NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", ""}, false));
}

void StatementTest::TestSqlIngestBinary() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::vector<std::byte>>(
NANOARROW_TYPE_BINARY,
{
std::nullopt, std::vector<std::byte>{},
std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
std::vector<std::byte>{
std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, std::byte{0x04}
},
std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}
}));
{std::nullopt, std::vector<std::byte>{},
std::vector<std::byte>{std::byte{0x00}, std::byte{0x01}},
std::vector<std::byte>{std::byte{0x01}, std::byte{0x02}, std::byte{0x03},
std::byte{0x04}},
std::vector<std::byte>{std::byte{0xfe}, std::byte{0xff}}},
false));
}

void StatementTest::TestSqlIngestDate32() {
Expand Down Expand Up @@ -1737,6 +1768,12 @@ void StatementTest::TestSqlIngestInterval() {
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

// Ingests a string column in dictionary-encoded form: same sample values as the
// plain string ingest test, but with dictionary encoding switched on.
void StatementTest::TestSqlIngestStringDictionary() {
  const std::vector<std::optional<std::string>> sample_values = {std::nullopt, "", "",
                                                                 "1234", ""};
  constexpr bool kDictionaryEncode = true;
  ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(NANOARROW_TYPE_STRING,
                                                         sample_values, kDictionaryEncode));
}

void StatementTest::TestSqlIngestTableEscaping() {
std::string name = "create_table_escaping";

Expand Down Expand Up @@ -2112,8 +2149,7 @@ void StatementTest::TestSqlIngestErrors() {
{"coltwo", NANOARROW_TYPE_INT64}}),
IsOkErrno());
ASSERT_THAT(
(MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
{-42}, {-42})),
(MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error, {-42}, {-42})),
IsOkErrno(&na_error));

ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
Expand Down
7 changes: 6 additions & 1 deletion c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ class StatementTest {
void TestSqlIngestTimestampTz();
void TestSqlIngestInterval();

// Dictionary-encoded
void TestSqlIngestStringDictionary();

// ---- End Type-specific tests ----------------

void TestSqlIngestTableEscaping();
Expand Down Expand Up @@ -387,7 +390,8 @@ class StatementTest {
struct AdbcStatement statement;

template <typename CType>
void TestSqlIngestType(ArrowType type, const std::vector<std::optional<CType>>& values);
void TestSqlIngestType(ArrowType type, const std::vector<std::optional<CType>>& values,
bool dictionary_encode);

template <typename CType>
void TestSqlIngestNumericType(ArrowType type);
Expand Down Expand Up @@ -424,6 +428,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \
TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \
TEST_F(FIXTURE, SqlIngestStringDictionary) { TestSqlIngestStringDictionary(); } \
TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \
TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); } \
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \
Expand Down

0 comments on commit 1789870

Please sign in to comment.