Skip to content

Commit

Permalink
Support EXECUTE IMMEDIATE
Browse files Browse the repository at this point in the history
Use EXECUTE IMMEDIATE sent in the HTTP request body, instead of putting
the query text in HTTP headers. This should allow sending large query
text. It can be enabled by setting the `explicitPrepare` option to
false in the connection string.
  • Loading branch information
nineinchnick committed Oct 9, 2024
1 parent d71f0cb commit 467db65
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 6 deletions.
2 changes: 2 additions & 0 deletions trino/etc/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ http-server.https.port=8443
http-server.authentication.allow-insecure-over-http=true
http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem
internal-communication.shared-secret=gotrino

query.max-length=5000043
2 changes: 1 addition & 1 deletion trino/etc/jvm.config
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-Xmx1G
-Xmx4G
-XX:+UseG1GC
-XX:G1HeapRegionSize=32M
-XX:+UseGCOverheadLimit
Expand Down
33 changes: 30 additions & 3 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"math/big"
"net/http"
"os"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -75,6 +76,9 @@ func TestMain(m *testing.M) {
flag.Parse()
DefaultQueryTimeout = *integrationServerQueryTimeout
DefaultCancelQueryTimeout = *integrationServerQueryTimeout
if *trinoImageTagFlag == "" {
*trinoImageTagFlag = "latest"
}

var err error
if *integrationServerFlag == "" && !testing.Short() {
Expand All @@ -97,9 +101,6 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatalf("Could not generate TLS certificates: %s", err)
}
if *trinoImageTagFlag == "" {
*trinoImageTagFlag = "latest"
}
resource, err = pool.RunWithOptions(&dt.RunOptions{
Name: name,
Repository: "trinodb/trino",
Expand Down Expand Up @@ -1112,3 +1113,29 @@ func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) {
})
}
}

func TestIntegrationLargeQuery(t *testing.T) {
version, err := strconv.Atoi(*trinoImageTagFlag)
if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 418) {
t.Skip("Skipping test when not using Trino 418 or later.")
}
dsn := *integrationServerFlag
dsn += "?explicitPrepare=false"
db := integrationOpen(t, dsn)
defer db.Close()
rows, err := db.Query("SELECT ?, '"+strings.Repeat("a", 5000000)+"'", 42)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
count := 0
for rows.Next() {
count++
}
if rows.Err() != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatal("not enough rows returned:", count)
}
}
19 changes: 17 additions & 2 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ const (
sslCertPathConfig = "SSLCertPath"
sslCertConfig = "SSLCert"
accessTokenConfig = "accessToken"
explicitPrepareConfig = "explicitPrepare"
)

var (
Expand Down Expand Up @@ -282,6 +283,7 @@ type Conn struct {
kerberosRemoteServiceName string
progressUpdater ProgressUpdater
progressUpdaterPeriod queryProgressCallbackPeriod
useExplicitPrepare bool
}

var (
Expand All @@ -298,6 +300,10 @@ func newConn(dsn string) (*Conn, error) {
query := serverURL.Query()

kerberosEnabled, _ := strconv.ParseBool(query.Get(kerberosEnabledConfig))
useExplicitPrepare := true
if query.Get(explicitPrepareConfig) != "" {
useExplicitPrepare, _ = strconv.ParseBool(query.Get(explicitPrepareConfig))
}

var kerberosClient *client.Client

Expand Down Expand Up @@ -356,6 +362,7 @@ func newConn(dsn string) (*Conn, error) {
kerberosClient: kerberosClient,
kerberosEnabled: kerberosEnabled,
kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig),
useExplicitPrepare: useExplicitPrepare,
}

var user string
Expand Down Expand Up @@ -867,7 +874,7 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt

hs.Add(arg.Name, headerValue)
} else {
if hs.Get(preparedStatementHeader) == "" {
if st.conn.useExplicitPrepare && hs.Get(preparedStatementHeader) == "" {
for _, v := range st.conn.httpHeaders.Values(preparedStatementHeader) {
hs.Add(preparedStatementHeader, v)
}
Expand All @@ -880,7 +887,11 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
return nil, ErrInvalidProgressCallbackHeader
}
if len(ss) > 0 {
query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ")
if st.conn.useExplicitPrepare {
query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ")
} else {
query = "EXECUTE IMMEDIATE " + formatStringLiteral(st.query) + " USING " + strings.Join(ss, ", ")
}
}
}

Expand Down Expand Up @@ -1028,6 +1039,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
return &sr, handleResponseError(resp.StatusCode, sr.Error)
}

func formatStringLiteral(query string) string {
return "'" + strings.ReplaceAll(query, "'", "''") + "'"
}

type driverRows struct {
ctx context.Context
stmt *driverStmt
Expand Down

0 comments on commit 467db65

Please sign in to comment.