Skip to content

Commit

Permalink
Don't fail when decoding empty strings as floats in server responses
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick committed Aug 24, 2024
1 parent d64a6b6 commit 806af86
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 23 deletions.
64 changes: 43 additions & 21 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,27 +713,27 @@ type stmtResponse struct {
}

type stmtStats struct {
State string `json:"state"`
Scheduled bool `json:"scheduled"`
Nodes int `json:"nodes"`
TotalSplits int `json:"totalSplits"`
QueuesSplits int `json:"queuedSplits"`
RunningSplits int `json:"runningSplits"`
CompletedSplits int `json:"completedSplits"`
UserTimeMillis int `json:"userTimeMillis"`
CPUTimeMillis int64 `json:"cpuTimeMillis"`
WallTimeMillis int64 `json:"wallTimeMillis"`
QueuedTimeMillis int64 `json:"queuedTimeMillis"`
ElapsedTimeMillis int64 `json:"elapsedTimeMillis"`
ProcessedRows int64 `json:"processedRows"`
ProcessedBytes int64 `json:"processedBytes"`
PhysicalInputBytes int64 `json:"physicalInputBytes"`
PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"`
PeakMemoryBytes int64 `json:"peakMemoryBytes"`
SpilledBytes int64 `json:"spilledBytes"`
RootStage stmtStage `json:"rootStage"`
ProgressPercentage float32 `json:"progressPercentage"`
RunningPercentage float32 `json:"runningPercentage"`
State string `json:"state"`
Scheduled bool `json:"scheduled"`
Nodes int `json:"nodes"`
TotalSplits int `json:"totalSplits"`
QueuesSplits int `json:"queuedSplits"`
RunningSplits int `json:"runningSplits"`
CompletedSplits int `json:"completedSplits"`
UserTimeMillis int `json:"userTimeMillis"`
CPUTimeMillis int64 `json:"cpuTimeMillis"`
WallTimeMillis int64 `json:"wallTimeMillis"`
QueuedTimeMillis int64 `json:"queuedTimeMillis"`
ElapsedTimeMillis int64 `json:"elapsedTimeMillis"`
ProcessedRows int64 `json:"processedRows"`
ProcessedBytes int64 `json:"processedBytes"`
PhysicalInputBytes int64 `json:"physicalInputBytes"`
PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"`
PeakMemoryBytes int64 `json:"peakMemoryBytes"`
SpilledBytes int64 `json:"spilledBytes"`
RootStage stmtStage `json:"rootStage"`
ProgressPercentage jsonFloat64 `json:"progressPercentage"`
RunningPercentage jsonFloat64 `json:"runningPercentage"`
}

type ErrTrino struct {
Expand Down Expand Up @@ -792,6 +792,28 @@ type stmtStage struct {
SubStages []stmtStage `json:"subStages"`
}

type jsonFloat64 float64

func (f *jsonFloat64) UnmarshalJSON(data []byte) error {
var v float64
err := json.Unmarshal(data, &v)
if err != nil {
var jsonErr *json.UnmarshalTypeError
if errors.As(err, &jsonErr) {
if f != nil {
*f = 0
}
return nil
}
return err
}
p := (*float64)(f)
*p = v
return nil
}

var _ json.Unmarshaler = new(jsonFloat64)

func (st *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, driver.ErrSkip
}
Expand Down
39 changes: 37 additions & 2 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,34 @@ func TestRoundTripRetryQueryError(t *testing.T) {
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
}

func TestRoundTripBogusData(t *testing.T) {
count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count == 0 {
count++
w.WriteHeader(http.StatusServiceUnavailable)
return
}
w.WriteHeader(http.StatusOK)
// some invalid JSON
w.Write([]byte(`{"stats": {"progressPercentage": ""}}`))
}))

t.Cleanup(ts.Close)

db, err := sql.Open("trino", ts.URL)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})

rows, err := db.Query("SELECT 1")
require.NoError(t, err)
assert.False(t, rows.Next())
require.NoError(t, rows.Err())
}

func TestRoundTripCancellation(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
Expand Down Expand Up @@ -336,10 +364,12 @@ func TestQueryForUsername(t *testing.T) {
}

type TestQueryProgressCallback struct {
statusMap map[time.Time]string
progressMap map[time.Time]float64
statusMap map[time.Time]string
}

func (qpc *TestQueryProgressCallback) Update(qpi QueryProgressInfo) {
qpc.progressMap[time.Now()] = float64(qpi.QueryStats.ProgressPercentage)
qpc.statusMap[time.Now()] = qpi.QueryStats.State
}

Expand Down Expand Up @@ -387,9 +417,11 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
assert.NoError(t, db.Close())
})

progressMap := make(map[time.Time]float64)
statusMap := make(map[time.Time]string)
progressUpdater := &TestQueryProgressCallback{
statusMap: statusMap,
progressMap: progressMap,
statusMap: statusMap,
}
progressUpdaterPeriod, err := time.ParseDuration("1ms")
require.NoError(t, err)
Expand All @@ -416,6 +448,8 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
}

// sort time in order to calculate interval
assert.NotEmpty(t, progressMap)
assert.NotEmpty(t, statusMap)
var keys []time.Time
for k := range statusMap {
keys = append(keys, k)
Expand All @@ -428,6 +462,7 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) {
if i > 0 {
assert.GreaterOrEqual(t, k.Sub(keys[i-1]), progressUpdaterPeriod)
}
assert.GreaterOrEqual(t, progressMap[k], 0.0)
}
}

Expand Down

0 comments on commit 806af86

Please sign in to comment.