Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for passing time.Duration #124

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,11 @@ types:
passed to Trino as a time with a time zone
* the result of `trino.Timestamp(year, month, day, hour, minute, second,
nanosecond)` - passed to Trino as a timestamp without a time zone
* `time.Duration` - passed to Trino as an interval day to second. Because Trino does not support nanosecond precision for intervals, if the nanosecond part of the value is not zero, an error will be returned.

It's not yet possible to pass:
* `float32` or `float64`
* `byte`
* `time.Duration`
* `json.RawMessage`
* maps

Expand Down
125 changes: 125 additions & 0 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"fmt"
"io"
"log"
"math"
"math/big"
"net/http"
"os"
Expand Down Expand Up @@ -987,3 +988,127 @@ func contextSleep(ctx context.Context, d time.Duration) error {
return ctx.Err()
}
}

func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) {
joris-bright marked this conversation as resolved.
Show resolved Hide resolved
db := integrationOpen(t)
defer db.Close()
tests := []struct {
name string
arg time.Duration
wantErr bool
}{
{
name: "valid 1234567891s",
arg: time.Duration(1234567891) * time.Second,
wantErr: false,
},
{
name: "valid 123456789.1s",
arg: time.Duration(123456789100) * time.Millisecond,
wantErr: false,
},
{
name: "valid 12345678.91s",
arg: time.Duration(12345678910) * time.Millisecond,
wantErr: false,
},
{
name: "valid 1234567.891s",
arg: time.Duration(1234567891) * time.Millisecond,
wantErr: false,
},
{
name: "valid -1234567891s",
arg: time.Duration(-1234567891) * time.Second,
wantErr: false,
},
{
name: "valid -123456789.1s",
arg: time.Duration(-123456789100) * time.Millisecond,
wantErr: false,
},
{
name: "valid -12345678.91s",
arg: time.Duration(-12345678910) * time.Millisecond,
wantErr: false,
},
{
name: "valid -1234567.891s",
arg: time.Duration(-1234567891) * time.Millisecond,
wantErr: false,
},
{
name: "invalid 1234567891.2s",
arg: time.Duration(1234567891200) * time.Millisecond,
wantErr: true,
},
{
name: "invalid 123456789.12s",
arg: time.Duration(123456789120) * time.Millisecond,
wantErr: true,
},
{
name: "invalid 12345678.912s",
arg: time.Duration(12345678912) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -1234567891.2s",
arg: time.Duration(-1234567891200) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -123456789.12s",
arg: time.Duration(-123456789120) * time.Millisecond,
wantErr: true,
},
{
name: "invalid -12345678.912s",
arg: time.Duration(-12345678912) * time.Millisecond,
wantErr: true,
},
{
name: "invalid max seconds (9223372036)",
arg: time.Duration(math.MaxInt64) / time.Second * time.Second,
wantErr: true,
},
{
name: "invalid min seconds (-9223372036)",
arg: time.Duration(math.MinInt64) / time.Second * time.Second,
wantErr: true,
},
{
name: "valid max seconds (2147483647)",
arg: math.MaxInt32 * time.Second,
},
{
name: "valid min seconds (-2147483647)",
arg: -math.MaxInt32 * time.Second,
},
{
name: "valid max minutes (153722867)",
arg: time.Duration(math.MaxInt64) / time.Minute * time.Minute,
},
{
name: "valid min minutes (-153722867)",
arg: time.Duration(math.MinInt64) / time.Minute * time.Minute,
},
{
name: "valid max hours (2562047)",
arg: time.Duration(math.MaxInt64) / time.Hour * time.Hour,
},
{
name: "valid min hours (-2562047)",
arg: time.Duration(math.MinInt64) / time.Hour * time.Hour,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := db.Exec("SELECT ?", test.arg)
if (err != nil) != test.wantErr {
t.Errorf("Exec() error = %v, wantErr %v", err, test.wantErr)
return
}
})
}
}
51 changes: 50 additions & 1 deletion trino/serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package trino
import (
"encoding/json"
"fmt"
"math"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -163,7 +164,7 @@ func Serial(v interface{}) (string, error) {
return "TIMESTAMP " + time.Time(x).Format("'2006-01-02 15:04:05.999999999 Z07:00'"), nil

case time.Duration:
return "", UnsupportedArgError{"time.Duration"}
return serialDuration(x)

// TODO - json.RawMesssage should probably be matched to 'JSON' in Trino
case json.RawMessage:
Expand Down Expand Up @@ -208,3 +209,51 @@ func serialSlice(v []interface{}) (string, error) {

return "ARRAY[" + strings.Join(ss, ", ") + "]", nil
}

const (
// For seconds with milliseconds there is a maximum length of 10 digits
// or 11 characters with the dot and 12 characters with the minus sign and dot
maxIntervalStrLenWithDot = 11 // 123456789.1 and 12345678.91 are valid
)

func serialDuration(dur time.Duration) (string, error) {
switch {
case dur%time.Hour == 0:
return serialHoursInterval(dur), nil
case dur%time.Minute == 0:
return serialMinutesInterval(dur), nil
case dur%time.Second == 0:
return serialSecondsInterval(dur)
case dur%time.Millisecond == 0:
return serialMillisecondsInterval(dur)
default:
return "", fmt.Errorf("trino: duration %v is not a multiple of hours, minutes, seconds or milliseconds", dur)
}
}

func serialHoursInterval(dur time.Duration) string {
return "INTERVAL '" + strconv.Itoa(int(dur/time.Hour)) + "' HOUR"
}

func serialMinutesInterval(dur time.Duration) string {
return "INTERVAL '" + strconv.Itoa(int(dur/time.Minute)) + "' MINUTE"
}

func serialSecondsInterval(dur time.Duration) (string, error) {
seconds := int64(dur / time.Second)
if seconds <= math.MinInt32 || seconds > math.MaxInt32 {
return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds type", dur)
}
return "INTERVAL '" + strconv.FormatInt(seconds, 10) + "' SECOND", nil
}

func serialMillisecondsInterval(dur time.Duration) (string, error) {
seconds := int64(dur / time.Second)
millisInSecond := dur.Abs().Milliseconds() % 1000
intervalNr := strings.TrimRight(fmt.Sprintf("%d.%03d", seconds, millisInSecond), "0")
if seconds > 0 && len(intervalNr) > maxIntervalStrLenWithDot ||
seconds < 0 && len(intervalNr) > maxIntervalStrLenWithDot+1 { // +1 for the minus sign
return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds with millis type", dur)
}
return "INTERVAL '" + intervalNr + "' SECOND", nil
}
81 changes: 81 additions & 0 deletions trino/serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package trino

import (
"math"
"testing"
"time"

Expand Down Expand Up @@ -160,6 +161,86 @@ func TestSerial(t *testing.T) {
value: time.Date(2017, 7, 10, 11, 34, 25, 123456, time.UTC),
expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456 Z'",
},
{
name: "duration",
value: 10*time.Second + 5*time.Millisecond,
expectedSerial: "INTERVAL '10.005' SECOND",
},
{
name: "duration with negative value",
value: -(10*time.Second + 5*time.Millisecond),
expectedSerial: "INTERVAL '-10.005' SECOND",
},
{
name: "minute duration",
value: 10 * time.Minute,
expectedSerial: "INTERVAL '10' MINUTE",
},
{
name: "hour duration",
value: 23 * time.Hour,
expectedSerial: "INTERVAL '23' HOUR",
},
{
name: "max hour duration",
value: (math.MaxInt64 / time.Hour) * time.Hour,
expectedSerial: "INTERVAL '2562047' HOUR",
},
{
name: "min hour duration",
value: (math.MinInt64 / time.Hour) * time.Hour,
expectedSerial: "INTERVAL '-2562047' HOUR",
},
{
name: "max minute duration",
value: (math.MaxInt64 / time.Minute) * time.Minute,
expectedSerial: "INTERVAL '153722867' MINUTE",
},
{
name: "min minute duration",
value: (math.MinInt64 / time.Minute) * time.Minute,
expectedSerial: "INTERVAL '-153722867' MINUTE",
},
{
name: "too big second duration",
value: (math.MaxInt64 / time.Second) * time.Second,
expectedError: true,
},
{
name: "too small second duration",
value: (math.MinInt64 / time.Second) * time.Second,
expectedError: true,
},
{
name: "too big millisecond duration",
value: time.Millisecond*912 + time.Second*12345678,
expectedError: true,
},
{
name: "too small millisecond duration",
value: -(time.Millisecond*910 + time.Second*123456789),
expectedError: true,
},
{
name: "max allowed second duration",
value: math.MaxInt32 * time.Second,
expectedSerial: "INTERVAL '2147483647' SECOND",
},
{
name: "min allowed second duration",
value: -math.MaxInt32 * time.Second,
expectedSerial: "INTERVAL '-2147483647' SECOND",
},
{
name: "max allowed second with milliseconds duration",
value: 999999999*time.Second + 900*time.Millisecond,
expectedSerial: "INTERVAL '999999999.9' SECOND",
},
{
name: "min allowed second with milliseconds duration",
value: -999999999*time.Second - 900*time.Millisecond,
expectedSerial: "INTERVAL '-999999999.9' SECOND",
},
{
name: "nil",
value: nil,
Expand Down
2 changes: 1 addition & 1 deletion trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ func (st *driverStmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) {
case nil:
return nil
case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp:
case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp, time.Duration:
return nil
default:
{
Expand Down