Skip to content

Commit

Permalink
Add support for forwarding OAuth2 authorization header
Browse files Browse the repository at this point in the history
  • Loading branch information
kalil-pelissier authored and nineinchnick committed Oct 30, 2024
1 parent 3d1f94d commit 70bd4d7
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 43 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ Please refer to the [Coordinator JWT
Authentication](https://trino.io/docs/current/security/jwt.html) for
server-side configuration.

#### Authorization header forwarding
This driver supports forwarding authorization headers by adding a [NamedArg](https://godoc.org/database/sql#NamedArg) with the name `accessToken` (e.g., `accessToken=<your_access_token>`) and setting the `ForwardAuthorizationHeader` field in the [Config](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config) struct to `true`.

When enabled, this configuration will override the `AccessToken` set in the `Config` struct.


#### System access control and per-query user information

It's possible to pass user information to Trino, different from the principal
Expand Down
103 changes: 60 additions & 43 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,17 @@ const (

authorizationHeader = "Authorization"

kerberosEnabledConfig = "KerberosEnabled"
kerberosKeytabPathConfig = "KerberosKeytabPath"
kerberosPrincipalConfig = "KerberosPrincipal"
kerberosRealmConfig = "KerberosRealm"
kerberosConfigPathConfig = "KerberosConfigPath"
kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName"
sslCertPathConfig = "SSLCertPath"
sslCertConfig = "SSLCert"
accessTokenConfig = "accessToken"
explicitPrepareConfig = "explicitPrepare"
kerberosEnabledConfig = "KerberosEnabled"
kerberosKeytabPathConfig = "KerberosKeytabPath"
kerberosPrincipalConfig = "KerberosPrincipal"
kerberosRealmConfig = "KerberosRealm"
kerberosConfigPathConfig = "KerberosConfigPath"
kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName"
sslCertPathConfig = "SSLCertPath"
sslCertConfig = "SSLCert"
accessTokenConfig = "accessToken"
explicitPrepareConfig = "explicitPrepare"
forwardAuthorizationHeaderConfig = "forwardAuthorizationHeader"

mapKeySeparator = ":"
mapEntrySeparator = ";"
Expand All @@ -168,22 +169,23 @@ var _ driver.Driver = &Driver{}

// Config is a configuration that can be encoded to a DSN string.
type Config struct {
ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080
Source string // Source of the connection (optional)
Catalog string // Catalog (optional)
Schema string // Schema (optional)
SessionProperties map[string]string // Session properties (optional)
ExtraCredentials map[string]string // Extra credentials (optional)
CustomClientName string // Custom client name (optional)
KerberosEnabled string // KerberosEnabled (optional, default is false)
KerberosKeytabPath string // Kerberos Keytab Path (optional)
KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional)
KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional)
KerberosRealm string // The Kerberos Realm (optional)
KerberosConfigPath string // The krb5 config path (optional)
SSLCertPath string // The SSL cert path for TLS verification (optional)
SSLCert string // The SSL cert for TLS verification (optional)
AccessToken string // An access token (JWT) for authentication (optional)
ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080
Source string // Source of the connection (optional)
Catalog string // Catalog (optional)
Schema string // Schema (optional)
SessionProperties map[string]string // Session properties (optional)
ExtraCredentials map[string]string // Extra credentials (optional)
CustomClientName string // Custom client name (optional)
KerberosEnabled string // KerberosEnabled (optional, default is false)
KerberosKeytabPath string // Kerberos Keytab Path (optional)
KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional)
KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional)
KerberosRealm string // The Kerberos Realm (optional)
KerberosConfigPath string // The krb5 config path (optional)
SSLCertPath string // The SSL cert path for TLS verification (optional)
SSLCert string // The SSL cert for TLS verification (optional)
AccessToken string // An access token (JWT) for authentication (optional)
ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional)
}

// FormatDSN returns a DSN string from the configuration.
Expand Down Expand Up @@ -211,6 +213,10 @@ func (c *Config) FormatDSN() (string, error) {
query := make(url.Values)
query.Add("source", source)

if c.ForwardAuthorizationHeader {
query.Add(forwardAuthorizationHeaderConfig, "true")
}

KerberosEnabled, _ := strconv.ParseBool(c.KerberosEnabled)
isSSL := serverURL.Scheme == "https"

Expand Down Expand Up @@ -277,16 +283,17 @@ func (c *Config) FormatDSN() (string, error) {

// Conn is a Trino connection.
type Conn struct {
baseURL string
auth *url.Userinfo
httpClient http.Client
httpHeaders http.Header
kerberosClient *client.Client
kerberosEnabled bool
kerberosRemoteServiceName string
progressUpdater ProgressUpdater
progressUpdaterPeriod queryProgressCallbackPeriod
useExplicitPrepare bool
baseURL string
auth *url.Userinfo
httpClient http.Client
httpHeaders http.Header
kerberosEnabled bool
kerberosClient *client.Client
kerberosRemoteServiceName string
progressUpdater ProgressUpdater
progressUpdaterPeriod queryProgressCallbackPeriod
useExplicitPrepare bool
forwardAuthorizationHeader bool
}

var (
Expand All @@ -303,6 +310,9 @@ func newConn(dsn string) (*Conn, error) {
query := serverURL.Query()

kerberosEnabled, _ := strconv.ParseBool(query.Get(kerberosEnabledConfig))

forwardAuthorizationHeader, _ := strconv.ParseBool(query.Get(forwardAuthorizationHeaderConfig))

useExplicitPrepare := true
if query.Get(explicitPrepareConfig) != "" {
useExplicitPrepare, _ = strconv.ParseBool(query.Get(explicitPrepareConfig))
Expand Down Expand Up @@ -359,13 +369,14 @@ func newConn(dsn string) (*Conn, error) {
}

c := &Conn{
baseURL: serverURL.Scheme + "://" + serverURL.Host,
httpClient: *httpClient,
httpHeaders: make(http.Header),
kerberosClient: kerberosClient,
kerberosEnabled: kerberosEnabled,
kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig),
useExplicitPrepare: useExplicitPrepare,
baseURL: serverURL.Scheme + "://" + serverURL.Host,
httpClient: *httpClient,
httpHeaders: make(http.Header),
kerberosClient: kerberosClient,
kerberosEnabled: kerberosEnabled,
kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig),
useExplicitPrepare: useExplicitPrepare,
forwardAuthorizationHeader: forwardAuthorizationHeader,
}

var user string
Expand Down Expand Up @@ -909,6 +920,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
continue
}

if st.conn.forwardAuthorizationHeader && arg.Name == accessTokenConfig {
token := arg.Value.(string)
hs.Add(authorizationHeader, getAuthorization(token))
continue
}

s, err := Serial(arg.Value)
if err != nil {
return nil, err
Expand Down
32 changes: 32 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1911,3 +1911,35 @@ func TestExec(t *testing.T) {
_, err = db.Exec("DROP TABLE memory.default.test")
require.NoError(t, err, "Failed executing DROP TABLE query")
}

func TestForwardAuthorizationHeaderConfig(t *testing.T) {
c := &Config{
ServerURI: "https://foobar@localhost:8090",
ForwardAuthorizationHeader: true,
}

dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "https://foobar@localhost:8090?forwardAuthorizationHeader=true&source=trino-go-client"

assert.Equal(t, want, dsn)
}

func TestForwardAuthorizationHeader(t *testing.T) {
var captureAuthHeader string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Capture the Authorization header for later inspection
captureAuthHeader = r.Header.Get("Authorization")
}))

t.Cleanup(ts.Close)

db, err := sql.Open("trino", ts.URL+"?forwardAuthorizationHeader=true")
require.NoError(t, err)

_, _ = db.Query("SELECT 1", sql.Named("accessToken", string("token"))) // Ingore response to focus on header capture
require.Equal(t, "Bearer token", captureAuthHeader, "Authorization header is incorrect")

assert.NoError(t, db.Close())
}

0 comments on commit 70bd4d7

Please sign in to comment.