Commit

addressed issues in code review
cocoa-xu committed Jul 3, 2024
1 parent fee71e8 commit 39be2ff
Showing 1 changed file with 66 additions and 72 deletions.
138 changes: 66 additions & 72 deletions go/adbc/driver/bigquery/record_reader.go
@@ -141,6 +141,44 @@ func runPlainQuery(ctx context.Context, query *bigquery.Query, alloc memory.Allocator
	return bigqueryRdr, totalRows, nil
}

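// queryRecordWithSchemaCallback executes the query once per row of bound
// parameters in rec, streams the resulting records into ch via goroutines
// started on group, and reports each result's schema through rdrSchema.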
func queryRecordWithSchemaCallback(ctx context.Context, group *errgroup.Group, query *bigquery.Query, rec arrow.Record, ch chan arrow.Record, parameterMode string, alloc memory.Allocator, rdrSchema func(schema *arrow.Schema)) (int64, error) {
	totalRows := int64(-1)
	for i := 0; i < int(rec.NumRows()); i++ {
		parameters, err := getQueryParameter(rec, i, parameterMode)
		if err != nil {
			return -1, err
		}
		if parameters != nil {
			query.QueryConfig.Parameters = parameters
		}

		arrowIterator, rows, err := runQuery(ctx, query, false)
		if err != nil {
			return -1, err
		}
		totalRows = rows
		rdr, err := ipcReaderFromArrowIterator(arrowIterator, alloc)
		if err != nil {
			return -1, err
		}
		rdrSchema(rdr.Schema())
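		// stream this result set's records into the shared channel;
		// each record is retained before it is handed off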
		group.Go(func() error {
			defer rdr.Release()
			for rdr.Next() && ctx.Err() == nil {
				rec := rdr.Record()
				rec.Retain()
				ch <- rec
			}
			return checkContext(ctx, rdr.Err())
		})
	}
	return totalRows, nil
}

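// queryRecord is queryRecordWithSchemaCallback with a no-op schema callback.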
func queryRecord(ctx context.Context, group *errgroup.Group, query *bigquery.Query, rec arrow.Record, ch chan arrow.Record, parameterMode string, alloc memory.Allocator) (int64, error) {
	return queryRecordWithSchemaCallback(ctx, group, query, rec, ch, parameterMode, alloc, func(schema *arrow.Schema) {})
}

// kicks off a goroutine for each endpoint and returns a reader which
// gathers all of the records as they come in.
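// The first record of bound parameters is queried synchronously and provides
// the reader's schema; the remaining records are handled in the background,
// with their result batches prefetched through the errgroup.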
func newRecordReader(ctx context.Context, query *bigquery.Query, boundParameters array.RecordReader, parameterMode string, alloc memory.Allocator, resultRecordBufferSize, prefetchConcurrency int) (bigqueryRdr *reader, totalRows int64, err error) {
@@ -180,34 +218,12 @@ func newRecordReader(ctx context.Context, query *bigquery.Query, boundParameters
	}

	rec := recs[0]
	for i := 0; i < int(rec.NumRows()); i++ {
		parameters, err := getQueryParameter(rec, i, parameterMode)
		if err != nil {
			return nil, -1, err
		}
		if parameters != nil {
			query.QueryConfig.Parameters = parameters
		}

		arrowIterator, rows, err := runQuery(ctx, query, false)
		if err != nil {
			return nil, -1, err
		}
		totalRows = rows
		rdr, err := ipcReaderFromArrowIterator(arrowIterator, alloc)
		if err != nil {
			return nil, -1, err
		}
		bigqueryRdr.schema = rdr.Schema()
		group.Go(func() error {
			defer rdr.Release()
			for rdr.Next() && ctx.Err() == nil {
				rec := rdr.Record()
				rec.Retain()
				ch <- rec
			}
			return checkContext(ctx, rdr.Err())
		})
	totalRows, err = queryRecordWithSchemaCallback(ctx, group, query, rec, ch, parameterMode, alloc, func(schema *arrow.Schema) {
		// only need to assign once
		bigqueryRdr.schema = schema
	})
	if err != nil {
		return nil, -1, err
	}

	lastChannelIndex := len(chs) - 1
@@ -216,52 +232,30 @@ func newRecordReader(ctx context.Context, query *bigquery.Query, boundParameters
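			// each remaining batch of bound parameters gets its own buffered result channel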
			batchIndex := index + 1
			record := values
			chs[batchIndex] = make(chan arrow.Record, resultRecordBufferSize)
			group.Go(func() error {
				if batchIndex != lastChannelIndex {
					defer close(chs[batchIndex])
				}
				for i := 0; i < int(record.NumRows()); i++ {
					parameters, err := getQueryParameter(record, i, parameterMode)
					if err != nil {
						return err
					}
					if parameters != nil {
						query.QueryConfig.Parameters = parameters
					}

					arrowIterator, rows, err := runQuery(ctx, query, false)
					if err != nil {
						return err
					}
					totalRows = rows
					rdr, err := ipcReaderFromArrowIterator(arrowIterator, alloc)
					if err != nil {
						return err
					}
					defer rdr.Release()

					for rdr.Next() && ctx.Err() == nil {
						rec := rdr.Record()
						rec.Retain()
						chs[batchIndex] <- rec
					}
					err = checkContext(ctx, rdr.Err())
					if err != nil {
						return err
					}
				}
				return nil
			})
			if batchIndex != lastChannelIndex {
				defer close(chs[batchIndex])
			}
			totalRows, err = queryRecord(ctx, group, query, record, chs[batchIndex], parameterMode, alloc)
			// if queryRecord returns an error,
			// we never enter `group.Go(func() error {})`,
			// so it is safe to record the error on bigqueryRdr.err here
			if err != nil {
				bigqueryRdr.err = err
			}
		}

		// place this here so that we always clean up, but they can't be in a
		// separate goroutine. Otherwise we'll have a race condition between
		// the call to wait and the calls to group.Go to kick off the jobs
		// to perform the pre-fetching (GH-1283).
		bigqueryRdr.err = group.Wait()
		// don't close the last channel until after the group is finished,
		// so that Next() can only return after reader.err may have been set
		close(chs[lastChannelIndex])
		// if queryRecord never returned an error,
		// we can just wait for the group to finish
		if err == nil {
			// place this here so that we always clean up, but they can't be in a
			// separate goroutine. Otherwise we'll have a race condition between
			// the call to wait and the calls to group.Go to kick off the jobs
			// to perform the pre-fetching (GH-1283).
			bigqueryRdr.err = group.Wait()
			// don't close the last channel until after the group is finished,
			// so that Next() can only return after reader.err may have been set
			close(chs[lastChannelIndex])
		}
	}()

	return bigqueryRdr, totalRows, nil

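The helper introduced above leans on a common Go fan-in pattern: errgroup-managed producers write into a buffered channel, and the channel is closed only after Wait returns, so a closed channel always means every producer has finished and any error has already been recorded. Below is a minimal, self-contained sketch of that pattern, not code from this driver: it uses plain ints in place of arrow.Record and made-up names.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	group, ctx := errgroup.WithContext(context.Background())
	ch := make(chan int, 4) // buffered, analogous to resultRecordBufferSize

	// one producer per batch, mirroring the group.Go calls in
	// queryRecordWithSchemaCallback
	for batch := 0; batch < 3; batch++ {
		batch := batch
		group.Go(func() error {
			for i := 0; i < 2 && ctx.Err() == nil; i++ {
				ch <- batch*10 + i
			}
			return ctx.Err()
		})
	}

	// close the channel only after every producer is done, so the
	// consumer below never misses a value or an error
	var groupErr error
	go func() {
		groupErr = group.Wait()
		close(ch)
	}()

	for v := range ch {
		fmt.Println("got", v)
	}
	if groupErr != nil {
		fmt.Println("prefetch failed:", groupErr)
	}
}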