diff --git a/trino/trino.go b/trino/trino.go index 4c46158..a8e6e1f 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -129,13 +129,14 @@ const ( trinoAddedPrepareHeader = trinoHeaderPrefix + `Added-Prepare` trinoDeallocatedPrepareHeader = trinoHeaderPrefix + `Deallocated-Prepare` - KerberosEnabledConfig = "KerberosEnabled" - kerberosKeytabPathConfig = "KerberosKeytabPath" - kerberosPrincipalConfig = "KerberosPrincipal" - kerberosRealmConfig = "KerberosRealm" - kerberosConfigPathConfig = "KerberosConfigPath" - SSLCertPathConfig = "SSLCertPath" - SSLCertConfig = "SSLCert" + KerberosEnabledConfig = "KerberosEnabled" + kerberosKeytabPathConfig = "KerberosKeytabPath" + kerberosPrincipalConfig = "KerberosPrincipal" + kerberosRealmConfig = "KerberosRealm" + kerberosConfigPathConfig = "KerberosConfigPath" + kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName" + SSLCertPathConfig = "SSLCertPath" + SSLCertConfig = "SSLCert" ) var ( @@ -159,20 +160,21 @@ var _ driver.Driver = &Driver{} // Config is a configuration that can be encoded to a DSN string. type Config struct { - ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080 - Source string // Source of the connection (optional) - Catalog string // Catalog (optional) - Schema string // Schema (optional) - SessionProperties map[string]string // Session properties (optional) - ExtraCredentials map[string]string // Extra credentials (optional) - CustomClientName string // Custom client name (optional) - KerberosEnabled string // KerberosEnabled (optional, default is false) - KerberosKeytabPath string // Kerberos Keytab Path (optional) - KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional) - KerberosRealm string // The Kerberos Realm (optional) - KerberosConfigPath string // The krb5 config path (optional) - SSLCertPath string // The SSL cert path for TLS verification (optional) - SSLCert string // The SSL cert for TLS verification (optional) + ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080 + Source string // Source of the connection (optional) + Catalog string // Catalog (optional) + Schema string // Schema (optional) + SessionProperties map[string]string // Session properties (optional) + ExtraCredentials map[string]string // Extra credentials (optional) + CustomClientName string // Custom client name (optional) + KerberosEnabled string // KerberosEnabled (optional, default is false) + KerberosKeytabPath string // Kerberos Keytab Path (optional) + KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional) + KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional) + KerberosRealm string // The Kerberos Realm (optional) + KerberosConfigPath string // The krb5 config path (optional) + SSLCertPath string // The SSL cert path for TLS verification (optional) + SSLCert string // The SSL cert for TLS verification (optional) } // FormatDSN returns a DSN string from the configuration. @@ -229,14 +231,19 @@ func (c *Config) FormatDSN() (string, error) { } if KerberosEnabled { + if !isSSL { + return "", fmt.Errorf("trino: client configuration error, SSL must be enabled for secure env") + } query.Add(KerberosEnabledConfig, "true") query.Add(kerberosKeytabPathConfig, c.KerberosKeytabPath) query.Add(kerberosPrincipalConfig, c.KerberosPrincipal) query.Add(kerberosRealmConfig, c.KerberosRealm) query.Add(kerberosConfigPathConfig, c.KerberosConfigPath) - if !isSSL { - return "", fmt.Errorf("trino: client configuration error, SSL must be enabled for secure env") + remoteServiceName := c.KerberosRemoteServiceName + if remoteServiceName == "" { + remoteServiceName = "trino" } + query.Add(kerberosRemoteServiceNameConfig, remoteServiceName) } // ensure consistent order of items @@ -260,14 +267,15 @@ func (c *Config) FormatDSN() (string, error) { // Conn is a Trino connection. type Conn struct { - baseURL string - auth *url.Userinfo - httpClient http.Client - httpHeaders http.Header - kerberosClient client.Client - kerberosEnabled bool - progressUpdater ProgressUpdater - progressUpdaterPeriod queryProgressCallbackPeriod + baseURL string + auth *url.Userinfo + httpClient http.Client + httpHeaders http.Header + kerberosClient client.Client + kerberosEnabled bool + kerberosRemoteServiceName string + progressUpdater ProgressUpdater + progressUpdaterPeriod queryProgressCallbackPeriod } var ( @@ -339,11 +347,12 @@ func newConn(dsn string) (*Conn, error) { } c := &Conn{ - baseURL: serverURL.Scheme + "://" + serverURL.Host, - httpClient: *httpClient, - httpHeaders: make(http.Header), - kerberosClient: kerberosClient, - kerberosEnabled: kerberosEnabled, + baseURL: serverURL.Scheme + "://" + serverURL.Host, + httpClient: *httpClient, + httpHeaders: make(http.Header), + kerberosClient: kerberosClient, + kerberosEnabled: kerberosEnabled, + kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig), } var user string @@ -455,7 +464,11 @@ func (c *Conn) newRequest(ctx context.Context, method, url string, body io.Reade } if c.kerberosEnabled { - err = c.kerberosClient.SetSPNEGOHeader(req, "trino/"+req.URL.Hostname()) + remoteServiceName := "trino" + if c.kerberosRemoteServiceName != "" { + remoteServiceName = c.kerberosRemoteServiceName + } + err = c.kerberosClient.SetSPNEGOHeader(req, remoteServiceName+"/"+req.URL.Hostname()) if err != nil { return nil, fmt.Errorf("error setting client SPNEGO header: %w", err) } diff --git a/trino/trino_test.go b/trino/trino_test.go index 6c29884..8d265bd 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -139,20 +139,21 @@ func TestConfigWithoutSSLCertPath(t *testing.T) { func TestKerberosConfig(t *testing.T) { c := &Config{ - ServerURI: "https://foobar@localhost:8090", - SessionProperties: map[string]string{"query_priority": "1"}, - KerberosEnabled: "true", - KerberosKeytabPath: "/opt/test.keytab", - KerberosPrincipal: "trino/testhost", - KerberosRealm: "example.com", - KerberosConfigPath: "/etc/krb5.conf", - SSLCertPath: "/tmp/test.cert", + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + KerberosEnabled: "true", + KerberosKeytabPath: "/opt/test.keytab", + KerberosPrincipal: "trino/testhost", + KerberosRealm: "example.com", + KerberosConfigPath: "/etc/krb5.conf", + KerberosRemoteServiceName: "service", + SSLCertPath: "/tmp/test.cert", } dsn, err := c.FormatDSN() require.NoError(t, err) - want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=trino%2Ftesthost&KerberosRealm=example.com&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3D1&source=trino-go-client" + want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=trino%2Ftesthost&KerberosRealm=example.com&KerberosRemoteServiceName=service&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3D1&source=trino-go-client" assert.Equal(t, want, dsn) }