source-snowflake: discover vector columns #2157

Open · wants to merge 1 commit into base: main
147 changes: 147 additions & 0 deletions source-snowflake/.snapshots/TestVectorDatatypes-discover
@@ -0,0 +1,147 @@
Binding 0:
{
"recommended_name": "test_vectordatatypes_56892992",
"resource_config_json": {
"schema": "PUBLIC",
"table": "test_VectorDatatypes_56892992"
},
"document_schema_json": {
"$defs": {
"Test_VectorDatatypes_56892992": {
"type": "object",
"required": [
"ID"
],
"$anchor": "Test_VectorDatatypes_56892992",
"properties": {
"A": {
"items": {
"type": "number"
},
"maxItems": 5,
"minItems": 5,
"type": [
"array",
"null"
]
},
"B": {
"items": {
"type": "integer"
},
"maxItems": 3,
"minItems": 3,
"type": [
"array",
"null"
]
},
"ID": {
"type": "integer"
}
}
}
},
"allOf": [
{
"if": {
"properties": {
"_meta": {
"properties": {
"op": {
"const": "d"
}
}
}
}
},
"then": {
"reduce": {
"delete": true,
"strategy": "merge"
}
},
"else": {
"reduce": {
"strategy": "merge"
}
},
"required": [
"_meta"
],
"properties": {
"_meta": {
"type": "object",
"required": [
"op",
"source"
],
"properties": {
"before": {
"$ref": "#Test_VectorDatatypes_56892992",
"description": "Record state immediately before this change was applied.",
"reduce": {
"strategy": "firstWriteWins"
}
},
"op": {
"enum": [
"c",
"d",
"u"
],
"description": "Change operation type: 'c' Create/Insert, 'u' Update, 'd' Delete."
},
"source": {
"$id": "https://github.com/estuary/connectors/source-snowflake/snowflake-source-metadata",
"properties": {
"ts_ms": {
"type": "integer",
"description": "Unix timestamp (in millis) at which this event was recorded by the database."
},
"schema": {
"type": "string",
"description": "Database schema (namespace) of the event."
},
"snapshot": {
"type": "boolean",
"description": "Snapshot is true if the record was produced from an initial table backfill and unset if produced from the replication log."
},
"table": {
"type": "string",
"description": "Database table of the event."
},
"seq": {
"type": "integer",
"description": "The sequence number of the staging table from which this document was read"
},
"off": {
"type": "integer",
"description": "The offset within that staging table at which this document occurred"
}
},
"type": "object",
"required": [
"schema",
"table",
"seq",
"off"
]
}
},
"reduce": {
"strategy": "merge"
}
}
}
},
{
"$ref": "#Test_VectorDatatypes_56892992"
}
]
},
"key": [
"/ID"
]
}

10 changes: 10 additions & 0 deletions source-snowflake/.snapshots/TestVectorDatatypes-init
@@ -0,0 +1,10 @@
# ================================
# Collection "acmeCo/test/test_vectordatatypes_56892992": 2 Documents
# ================================
{"A":"[1.100000,2.200000,3.000000,4.400000,5.500000]","B":null,"ID":1,"_meta":{"op":"c","source":{"schema":"PUBLIC","snapshot":true,"table":"test_VectorDatatypes_56892992","seq":0,"off":0}}}
{"A":null,"B":"[1,2,3]","ID":2,"_meta":{"op":"c","source":{"schema":"PUBLIC","snapshot":true,"table":"test_VectorDatatypes_56892992","seq":0,"off":1}}}
# ================================
# Final State Checkpoint
# ================================
{"streams":{"PUBLIC%2Ftest_VectorDatatypes_56892992":{"off":0,"seq":2,"uid":"FFFFFFFFFFFFFFFF_PUBLIC_TEST_VECTORDATATYPES_56892992"}}}

10 changes: 10 additions & 0 deletions source-snowflake/.snapshots/TestVectorDatatypes-main
@@ -0,0 +1,10 @@
# ================================
# Collection "acmeCo/test/test_vectordatatypes_56892992": 2 Documents
# ================================
{"A":"[1.100000,2.200000,3.000000,4.400000,5.500000]","B":null,"ID":3,"_meta":{"op":"c","source":{"schema":"PUBLIC","table":"test_VectorDatatypes_56892992","seq":2,"off":0}}}
{"A":null,"B":"[1,2,3]","ID":4,"_meta":{"op":"c","source":{"schema":"PUBLIC","table":"test_VectorDatatypes_56892992","seq":2,"off":1}}}
# ================================
# Final State Checkpoint
# ================================
{"streams":{"PUBLIC%2Ftest_VectorDatatypes_56892992":{"off":0,"seq":3,"uid":"FFFFFFFFFFFFFFFF_PUBLIC_TEST_VECTORDATATYPES_56892992"}}}

17 changes: 17 additions & 0 deletions source-snowflake/capture_test.go
@@ -105,6 +105,23 @@ func TestVariantDatatypes(t *testing.T) {
t.Run("capture", func(t *testing.T) { verifiedCapture(ctx, t, cs) })
}

func TestVectorDatatypes(t *testing.T) {
var ctx, tb = context.Background(), snowflakeTestBackend(t)
var uniqueID = "56892992"
var tableName = tb.CreateTable(ctx, t, uniqueID, "(id INTEGER PRIMARY KEY NOT NULL, a VECTOR(FLOAT, 5), b VECTOR(INT, 3))")
var cs = tb.CaptureSpec(ctx, t, regexp.MustCompile(uniqueID))
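// Expect discovery to emit array schemas for the VECTOR columns, then
// capture rows from both the initial backfill and later staged changes.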

t.Run("discover", func(t *testing.T) { cs.VerifyDiscover(ctx, t, regexp.MustCompile(uniqueID)) })

tb.Query(ctx, t, fmt.Sprintf(`INSERT INTO %s SELECT 1, [1.1,2.2,3,4.4,5.5]::VECTOR(FLOAT,5), NULL;`, tableName))
tb.Query(ctx, t, fmt.Sprintf(`INSERT INTO %s SELECT 2, NULL, [1,2,3]::VECTOR(INT,3);`, tableName))
t.Run("init", func(t *testing.T) { verifiedCapture(ctx, t, cs) })

tb.Query(ctx, t, fmt.Sprintf(`INSERT INTO %s SELECT 3, [1.1,2.2,3,4.4,5.5]::VECTOR(FLOAT,5), NULL;`, tableName))
tb.Query(ctx, t, fmt.Sprintf(`INSERT INTO %s SELECT 4, NULL, [1,2,3]::VECTOR(INT,3);`, tableName))
t.Run("main", func(t *testing.T) { verifiedCapture(ctx, t, cs) })
}

func TestLargeCapture(t *testing.T) {
var ctx, tb = context.Background(), snowflakeTestBackend(t)
var uniqueID = "63855638"
106 changes: 104 additions & 2 deletions source-snowflake/discovery.go
@@ -6,15 +6,18 @@ import (
"database/sql"
"encoding/json"
"fmt"
"regexp"
"slices"
"strconv"
"strings"
"sync"

pc "github.com/estuary/flow/go/protocols/capture"
pf "github.com/estuary/flow/go/protocols/flow"
"github.com/invopop/jsonschema"
"github.com/jmoiron/sqlx"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)

func (snowflakeDriver) Discover(ctx context.Context, req *pc.Request_Discover) (*pc.Response_Discovered, error) {
@@ -71,6 +74,11 @@ type snowflakeDiscoveryColumn struct {
IsNullable string `db:"IS_NULLABLE"`
NumericScale *int `db:"NUMERIC_SCALE"`
NumericPrecision *int `db:"NUMERIC_PRECISION"`

// Non-standard values that can only be populated from a DESCRIBE TABLE
// query; see resolveVectorColumns.
vectorType string // INT or FLOAT
vectorDimension int
}

type snowflakeDiscoveryPrimaryKey struct {
@@ -110,6 +118,8 @@ func performSnowflakeDiscovery(ctx context.Context, cfg *config, db *sql.DB) (*s
var columns []*snowflakeDiscoveryColumn
if err := xdb.Select(&columns, "SELECT * FROM information_schema.columns;"); err != nil {
return nil, fmt.Errorf("error listing columns: %w", err)
} else if columns, err = resolveVectorColumns(ctx, xdb, columns); err != nil {
return nil, fmt.Errorf("error resolving VECTOR columns: %w", err)
}
var primaryKeysQuery = fmt.Sprintf("SHOW PRIMARY KEYS IN DATABASE %s;", quoteSnowflakeIdentifier(cfg.Database))
var primaryKeys []*snowflakeDiscoveryPrimaryKey
@@ -266,7 +276,7 @@ func schemaFromDiscovery(info *snowflakeDiscoveryInfo) (json.RawMessage, error)
for _, column := range info.Columns {
var jsonType, err = translateDBToJSONType(column)
if err != nil {
logrus.WithFields(logrus.Fields{
log.WithFields(log.Fields{
"error": err,
"type": column.DataType,
}).Warn("error translating column type to JSON schema")
@@ -380,6 +390,23 @@ func translateDBToJSONType(column *snowflakeDiscoveryColumn) (*jsonschema.Schema
} else {
schema = columnSchema{jsonType: "number"}
}
case "VECTOR":
schema = columnSchema{
jsonType: "array",
extras: map[string]any{
// The array always contains exactly as many items as the
// vector dimension, so minItems and maxItems are both set to it.
"minItems": column.vectorDimension,
"maxItems": column.vectorDimension,
},
}
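// Snowflake vectors hold either INT or FLOAT values, so any other
// value here indicates a bug in resolveVectorColumns.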
if column.vectorType == "INT" {
schema.extras["items"] = &jsonschema.Schema{Type: "integer"}
} else if column.vectorType == "FLOAT" {
schema.extras["items"] = &jsonschema.Schema{Type: "number"}
} else {
return nil, fmt.Errorf("internal error: unknown vector type %q (found on column %q of table %q)", column.vectorType, column.Name, column.Table)
}
default:
if s, ok := snowflakeTypeToJSON[column.DataType]; ok {
schema = s
@@ -445,3 +472,78 @@ var snowflakeTypeToJSON = map[string]columnSchema{
"OBJECT": {jsonType: "object"},
"ARRAY": {jsonType: "array"},
}

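// vectorRegexp extracts the value type and dimension from a Snowflake
// VECTOR type string as reported by DESCRIBE TABLE, for example
// "VECTOR(FLOAT, 5)" or "VECTOR(INT, 3)".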
var vectorRegexp = regexp.MustCompile(`(?i)^VECTOR\((FLOAT|INT),\s*(\d+)\)$`)

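// describeTableColumn models the two columns of a DESCRIBE TABLE result
// row that matter here: the column name and its full type string.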
type describeTableColumn struct {
Name string `db:"name"`
Type_ string `db:"type"`
}

// resolveVectorColumns populates all VECTOR columns' value type (either FLOAT
// or INT), and the dimension of the vector. This information is only available
// from parsing the response of a DESCRIBE TABLE query, since it isn't in the
// standard INFORMATION_SCHEMA.COLUMNS view.
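//
// As an illustration, a column declared VECTOR(FLOAT, 5) is reported by
// DESCRIBE TABLE with the type string "VECTOR(FLOAT, 5)", which
// vectorRegexp parses into vectorType "FLOAT" and vectorDimension 5.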
func resolveVectorColumns(ctx context.Context, xdb *sqlx.DB, discoveredColumns []*snowflakeDiscoveryColumn) ([]*snowflakeDiscoveryColumn, error) {
var tablesWithVectors = make(map[string]bool)
for _, column := range discoveredColumns {
if column.DataType == "VECTOR" {
tablesWithVectors[column.Table] = true
}
}

if len(tablesWithVectors) == 0 {
return discoveredColumns, nil
}
log.WithField("count", len(tablesWithVectors)).Debug("describing tables to discover VECTOR columns")

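// The mutex serializes writes to discoveredColumns, since each
// DESCRIBE TABLE result is applied from its own goroutine.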
var mu sync.Mutex
group, groupCtx := errgroup.WithContext(ctx)

// I'm not aware of any real rate limits for DESCRIBE TABLE, and there
// probably won't be many tables with VECTOR columns, so we might as well
// run all these concurrently.
for table := range tablesWithVectors {
group.Go(func() (err error) {
var describeTableQuery = fmt.Sprintf("DESCRIBE TABLE %s;", quoteSnowflakeIdentifier(table))
var describedColumns []*describeTableColumn
if err = xdb.SelectContext(groupCtx, &describedColumns, describeTableQuery); err != nil {
return err
}

mu.Lock()
defer mu.Unlock()

for _, describedColumn := range describedColumns {
if matches := vectorRegexp.FindStringSubmatch(describedColumn.Type_); matches != nil {
// This looks like a VECTOR(something, n) column. Find the
// appropriate discovered column and add its information
// there.
for _, col := range discoveredColumns {
if col.Table == table && col.Name == describedColumn.Name {
col.vectorType = matches[1]
if col.vectorDimension, err = strconv.Atoi(matches[2]); err != nil {
return fmt.Errorf(
"parsing VECTOR column %q (%q) for table %q: %w",
col.Name,
describedColumn.Type_,
table,
err,
)
}
break
}
}
}
}

return nil
})
}

if err := group.Wait(); err != nil {
return nil, err
}

return discoveredColumns, nil
}