Skip to content

Commit

Permalink
Add support for passing time.Duration
Browse files Browse the repository at this point in the history
  • Loading branch information
joris-bright committed Sep 17, 2024
1 parent 806af86 commit 6a40670
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,11 @@ types:
passed to Trino as a time with a time zone
* the result of `trino.Timestamp(year, month, day, hour, minute, second,
nanosecond)` - passed to Trino as a timestamp without a time zone
* `time.Duration` - passed to Trino as an interval day to second (user must round to the desired precision otherwise an error will be returned)

It's not yet possible to pass:
* `float32` or `float64`
* `byte`
* `time.Duration`
* `json.RawMessage`
* maps

Expand Down
125 changes: 125 additions & 0 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"fmt"
"io"
"log"
"math"
"math/big"
"net/http"
"os"
Expand Down Expand Up @@ -987,3 +988,127 @@ func contextSleep(ctx context.Context, d time.Duration) error {
return ctx.Err()
}
}

func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) {
db := integrationOpen(t)
defer db.Close()
tests := []struct {
name string
arg time.Duration
wantErr bool
}{
{
name: "valid 1234567891s",
arg: time.Duration(1234567891) * time.Second,
wantErr: false,
},
{
name: "valid 123456789.1s",
arg: time.Duration(123456789100) * time.Millisecond,
wantErr: false,
},
{
name: "valid 12345678.91s",
arg: time.Duration(12345678910) * time.Millisecond,
wantErr: false,
},
{
name: "valid 1234567.891s",
arg: time.Duration(1234567891) * time.Millisecond,
wantErr: false,
},
{
name: "valid -1234567891s",
arg: time.Duration(-1234567891) * time.Second,
wantErr: false,
},
{
name: "valid -123456789.1s",
arg: time.Duration(-123456789100) * time.Millisecond,
wantErr: false,
},
{
name: "valid -12345678.91s",
arg: time.Duration(-12345678910) * time.Millisecond,
wantErr: false,
},
{
name: "valid -1234567.891s",
arg: time.Duration(-1234567891) * time.Millisecond,
wantErr: false,
},
{
name: "invalid 1234567891.2s",
arg: time.Duration(1234567891200) * time.Millisecond,
wantErr: true,
},
{
name: "invalid 123456789.12s",
arg: time.Duration(123456789120) * time.Millisecond,
wantErr: true,
},
{
name: "invalid 12345678.912s",
arg: time.Duration(12345678912) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -1234567891.2s",
arg: time.Duration(-1234567891200) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -123456789.12s",
arg: time.Duration(-123456789120) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -12345678.912s",
arg: time.Duration(-12345678912) * time.Millisecond,
wantErr: true,
},
{
name: "invalid max seconds (9223372036)",
arg: time.Duration(math.MaxInt64) / time.Second * time.Second,
wantErr: true,
},
{
name: "invalid min seconds (-9223372036)",
arg: time.Duration(math.MinInt64) / time.Second * time.Second,
wantErr: true,
},
{
name: "valid max seconds (2147483647)",
arg: math.MaxInt32 * time.Second,
},
{
name: "valid min seconds (-2147483647)",
arg: -math.MaxInt32 * time.Second,
},
{
name: "valid max minutes (153722867)",
arg: time.Duration(math.MaxInt64) / time.Minute * time.Minute,
},
{
name: "valid min minutes (-153722867)",
arg: time.Duration(math.MinInt64) / time.Minute * time.Minute,
},
{
name: "valid max hours (2562047)",
arg: time.Duration(math.MaxInt64) / time.Hour * time.Hour,
},
{
name: "valid min hours (-2562047)",
arg: time.Duration(math.MinInt64) / time.Hour * time.Hour,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := db.Exec("SELECT ?", test.arg)
if (err != nil) != test.wantErr {
t.Errorf("Exec() error = %v, wantErr %v", err, test.wantErr)
return
}
})
}
}
51 changes: 50 additions & 1 deletion trino/serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package trino
import (
"encoding/json"
"fmt"
"math"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -163,7 +164,7 @@ func Serial(v interface{}) (string, error) {
return "TIMESTAMP " + time.Time(x).Format("'2006-01-02 15:04:05.999999999 Z07:00'"), nil

case time.Duration:
return "", UnsupportedArgError{"time.Duration"}
return serialDuration(x)

// TODO - json.RawMesssage should probably be matched to 'JSON' in Trino
case json.RawMessage:
Expand Down Expand Up @@ -208,3 +209,51 @@ func serialSlice(v []interface{}) (string, error) {

return "ARRAY[" + strings.Join(ss, ", ") + "]", nil
}

const (
// For seconds with milliseconds there is a maximum length of 10 digits
// or 11 characters with the dot and 12 characters with the minus sign and dot
maxIntervalStrLenWithDot = 11 // 123456789.1 and 12345678.91 are valid
)

func serialDuration(dur time.Duration) (string, error) {
switch {
case dur%time.Hour == 0:
return serialHoursInterval(dur), nil
case dur%time.Minute == 0:
return serialMinutesInterval(dur), nil
case dur%time.Second == 0:
return serialSecondsInterval(dur)
case dur%time.Millisecond == 0:
return serialMillisecondsInterval(dur)
default:
return "", fmt.Errorf("trino: duration %v is not a multiple of hours, minutes, seconds or milliseconds", dur)
}
}

func serialHoursInterval(dur time.Duration) string {
return "INTERVAL '" + strconv.Itoa(int(dur/time.Hour)) + "' HOUR"
}

func serialMinutesInterval(dur time.Duration) string {
return "INTERVAL '" + strconv.Itoa(int(dur/time.Minute)) + "' MINUTE"
}

func serialSecondsInterval(dur time.Duration) (string, error) {
seconds := int64(dur / time.Second)
if seconds <= math.MinInt32 || seconds > math.MaxInt32 {
return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds type", dur)
}
return "INTERVAL '" + strconv.FormatInt(seconds, 10) + "' SECOND", nil
}

func serialMillisecondsInterval(dur time.Duration) (string, error) {
seconds := int64(dur / time.Second)
millisInSecond := dur.Abs().Milliseconds() % 1000
intervalNr := strings.TrimRight(fmt.Sprintf("%d.%03d", seconds, millisInSecond), "0")
if seconds > 0 && len(intervalNr) > maxIntervalStrLenWithDot ||
seconds < 0 && len(intervalNr) > maxIntervalStrLenWithDot+1 { // +1 for the minus sign
return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds with millis type", dur)
}
return "INTERVAL '" + intervalNr + "' SECOND", nil
}
81 changes: 81 additions & 0 deletions trino/serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package trino

import (
"math"
"testing"
"time"

Expand Down Expand Up @@ -160,6 +161,86 @@ func TestSerial(t *testing.T) {
value: time.Date(2017, 7, 10, 11, 34, 25, 123456, time.UTC),
expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456 Z'",
},
{
name: "duration",
value: 10*time.Second + 5*time.Millisecond,
expectedSerial: "INTERVAL '10.005' SECOND",
},
{
name: "duration with negative value",
value: -(10*time.Second + 5*time.Millisecond),
expectedSerial: "INTERVAL '-10.005' SECOND",
},
{
name: "minute duration",
value: 10 * time.Minute,
expectedSerial: "INTERVAL '10' MINUTE",
},
{
name: "hour duration",
value: 23 * time.Hour,
expectedSerial: "INTERVAL '23' HOUR",
},
{
name: "max hour duration",
value: (math.MaxInt64 / time.Hour) * time.Hour,
expectedSerial: "INTERVAL '2562047' HOUR",
},
{
name: "min hour duration",
value: (math.MinInt64 / time.Hour) * time.Hour,
expectedSerial: "INTERVAL '-2562047' HOUR",
},
{
name: "max minute duration",
value: (math.MaxInt64 / time.Minute) * time.Minute,
expectedSerial: "INTERVAL '153722867' MINUTE",
},
{
name: "min minute duration",
value: (math.MinInt64 / time.Minute) * time.Minute,
expectedSerial: "INTERVAL '-153722867' MINUTE",
},
{
name: "too big second duration",
value: (math.MaxInt64 / time.Second) * time.Second,
expectedError: true,
},
{
name: "too small second duration",
value: (math.MinInt64 / time.Second) * time.Second,
expectedError: true,
},
{
name: "too big millisecond duration",
value: time.Millisecond*912 + time.Second*12345678,
expectedError: true,
},
{
name: "too small millisecond duration",
value: -(time.Millisecond*910 + time.Second*123456789),
expectedError: true,
},
{
name: "max allowed second duration",
value: math.MaxInt32 * time.Second,
expectedSerial: "INTERVAL '2147483647' SECOND",
},
{
name: "min allowed second duration",
value: -math.MaxInt32 * time.Second,
expectedSerial: "INTERVAL '-2147483647' SECOND",
},
{
name: "max allowed second with milliseconds duration",
value: 999999999*time.Second + 900*time.Millisecond,
expectedSerial: "INTERVAL '999999999.9' SECOND",
},
{
name: "min allowed second with milliseconds duration",
value: -999999999*time.Second - 900*time.Millisecond,
expectedSerial: "INTERVAL '-999999999.9' SECOND",
},
{
name: "nil",
value: nil,
Expand Down
2 changes: 1 addition & 1 deletion trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ func (st *driverStmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case nil:
return nil
case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp:
case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp, time.Duration:
return nil
default:
{
Expand Down

0 comments on commit 6a40670

Please sign in to comment.