diff --git a/connector/bitbucketcloud/bitbucketcloud.go b/connector/bitbucketcloud/bitbucketcloud.go index 27a63c4684..32b98a367a 100644 --- a/connector/bitbucketcloud/bitbucketcloud.go +++ b/connector/bitbucketcloud/bitbucketcloud.go @@ -105,12 +105,12 @@ func (b *bitbucketConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if b.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) } - return b.oauth2Config(scopes).AuthCodeURL(state), nil + return b.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -125,7 +125,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (b *bitbucketConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (b *bitbucketConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/bitbucketcloud/bitbucketcloud_test.go b/connector/bitbucketcloud/bitbucketcloud_test.go index b9f4ba08f8..9dd3bd2a06 100644 --- a/connector/bitbucketcloud/bitbucketcloud_test.go +++ b/connector/bitbucketcloud/bitbucketcloud_test.go @@ -89,7 +89,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) bitbucketConnector := bitbucketConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, req) + identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, req, nil) expectNil(t, err) expectEquals(t, identity.Username, "some-login") diff --git a/connector/connector.go b/connector/connector.go index c442c54af2..5bd2e140cd 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -62,10 +62,10 @@ type CallbackConnector interface { // requested if one has already been issues. There's no good general answer // for these kind of restrictions, and may require this package to become more // aware of the global set of user/connector interactions. - LoginURL(s Scopes, callbackURL, state string) (string, error) + LoginURL(s Scopes, callbackURL, state string) (string, []byte, error) // Handle the callback to the server and return an identity. - HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) + HandleCallback(s Scopes, r *http.Request, data []byte) (identity Identity, err error) } // SAMLConnector represents SAML connectors which implement the HTTP POST binding. diff --git a/connector/github/github.go b/connector/github/github.go index 9f7c2782a9..afb213b0c4 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -186,12 +186,12 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -234,7 +234,7 @@ func newHTTPClient(rootCA string) (*http.Client, error) { }, nil } -func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/github/github_test.go b/connector/github/github_test.go index 539a2e69c1..73265b796f 100644 --- a/connector/github/github_test.go +++ b/connector/github/github_test.go @@ -151,7 +151,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req, nil) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -159,7 +159,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectEquals(t, 0, len(identity.Groups)) c = githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), loadAllGroups: true} - identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req, nil) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -193,7 +193,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), useLoginAsID: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req, nil) expectNil(t, err) expectEquals(t, identity.UserID, "some-login") diff --git a/connector/gitlab/gitlab.go b/connector/gitlab/gitlab.go index 41b0beb261..11a3fa18a6 100644 --- a/connector/gitlab/gitlab.go +++ b/connector/gitlab/gitlab.go @@ -90,11 +90,11 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -109,7 +109,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/linkedin/linkedin.go b/connector/linkedin/linkedin.go index 9ab67e57c8..e8c0d3bb4b 100644 --- a/connector/linkedin/linkedin.go +++ b/connector/linkedin/linkedin.go @@ -63,17 +63,17 @@ var ( ) // LoginURL returns an access token request URL -func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.oauth2Config.RedirectURL != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.oauth2Config.RedirectURL) } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config.AuthCodeURL(state), nil, nil } // HandleCallback handles HTTP redirect from LinkedIn -func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index ad6b3e7304..927154e7ff 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -105,15 +105,15 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } -func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index ef8afd4608..0fbee6fb3c 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -44,21 +44,21 @@ type Callback struct { } // LoginURL returns the URL to redirect the user to login with. -func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { u, err := url.Parse(callbackURL) if err != nil { - return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) + return "", nil, fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) } v := u.Query() v.Set("state", state) u.RawQuery = v.Encode() - return u.String(), nil + return u.String(), nil, nil } var connectorData = []byte("foobar") // HandleCallback parses the request and returns the user's identity -func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { +func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (connector.Identity, error) { return m.Identity, nil } diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index bf6a1d2876..47096dde1a 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -3,6 +3,7 @@ package oidc import ( "context" + "encoding/json" "errors" "fmt" "net/http" @@ -16,8 +17,12 @@ import ( "golang.org/x/oauth2" "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/storage" ) +// NoRefreshTokenDummy is a placeholder for a lack of refresh token +const NoRefreshTokenDummy = "no-refresh-token" + // Config holds configuration options for OpenID Connect logins. type Config struct { Issuer string `json:"issuer"` @@ -34,11 +39,19 @@ type Config struct { Scopes []string `json:"scopes"` // defaults to "profile" and "email" + ResponseType string `json:"responseType"` // Default to "code" + // Optional list of whitelisted domains when using Google // If this field is nonempty, only users from a listed domain will be allowed to log in HostedDomains []string `json:"hostedDomains"` } +// connectorData holds state that is needed between starting +// the auth flow, and validating the response +type connectorData struct { + Nonce string +} + // Domains that don't support basic auth. golang.org/x/oauth2 has an internal // list, but it only matches specific URLs, not top level domains. var brokenAuthHeaderDomains = []string{ @@ -94,15 +107,25 @@ func (c *Config) Open(id string, logger logrus.FieldLogger) (conn connector.Conn registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL) } - scopes := []string{oidc.ScopeOpenID} - if len(c.Scopes) > 0 { - scopes = append(scopes, c.Scopes...) - } else { - scopes = append(scopes, "profile", "email") + // if the user specifies scope, respect it + scopes := c.Scopes + if len(c.Scopes) == 0 { + scopes = []string{oidc.ScopeOpenID, "profile", "email"} + fmt.Println(c.Scopes) + } + + if c.ResponseType == "" { + c.ResponseType = "code" + } + + if c.ResponseType != "id_token" && c.ResponseType != "code" { + err := fmt.Errorf("failed to create %s provider, unsupported response_type '%s'", id, c.ResponseType) + cancel() + return nil, err } clientID := c.ClientID - return &oidcConnector{ + connector := &oidcConnector{ redirectURI: c.RedirectURI, oauth2Config: &oauth2.Config{ ClientID: clientID, @@ -114,10 +137,13 @@ func (c *Config) Open(id string, logger logrus.FieldLogger) (conn connector.Conn verifier: provider.Verifier( &oidc.Config{ClientID: clientID}, ), + responseType: c.ResponseType, logger: logger, + ctx: ctx, cancel: cancel, hostedDomains: c.HostedDomains, - }, nil + } + return connector, nil } var ( @@ -133,6 +159,7 @@ type oidcConnector struct { cancel context.CancelFunc logger logrus.FieldLogger hostedDomains []string + responseType string } func (c *oidcConnector) Close() error { @@ -140,9 +167,10 @@ func (c *oidcConnector) Close() error { return nil } -func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { + connData := connectorData{} if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var opts []oauth2.AuthCodeOption @@ -157,7 +185,19 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) if s.OfflineAccess { opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + if c.responseType == "id_token" { + connData.Nonce = storage.NewID() + opts = append(opts, oauth2.SetAuthURLParam("response_type", c.responseType)) + opts = append(opts, oauth2.SetAuthURLParam("response_mode", "form_post")) + opts = append(opts, oauth2.SetAuthURLParam("nonce", connData.Nonce)) + } + authCodeURL := c.oauth2Config.AuthCodeURL(state, opts...) + + connDataBytes, err := json.Marshal(connData) + if err != nil { + return "", nil, fmt.Errorf("failed to encode connector data: %v", err) + } + return authCodeURL, connDataBytes, nil } type oauth2Error struct { @@ -172,21 +212,49 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) { + if len(connDataBytes) == 0 { + return identity, fmt.Errorf("connector data was unexpectedly empty") + } + + var connData connectorData + err = json.Unmarshal(connDataBytes, &connData) + if err != nil { + return identity, fmt.Errorf("failed to parse connector data: %v", err) + } + q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} } + + if c.responseType == "id_token" { + rawIDToken := r.FormValue("id_token") + if rawIDToken == "" { + return identity, fmt.Errorf("authorization response lacked id_token despite using the implicit flow") + } + return c.createIdentity(r.Context(), rawIDToken, connData, NoRefreshTokenDummy) + } + token, err := c.oauth2Config.Exchange(r.Context(), q.Get("code")) if err != nil { return identity, fmt.Errorf("oidc: failed to get token: %v", err) } - return c.createIdentity(r.Context(), identity, token) + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return identity, errors.New("oidc: no id_token in token response") + } + + return c.createIdentity(r.Context(), rawIDToken, connData, token.RefreshToken) } // Refresh is implemented for backwards compatibility, even though it's a no-op. func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { + if c.responseType == "id_token" { + return identity, fmt.Errorf("oidc: there is no refresh_token with implict flow") + } + t := &oauth2.Token{ RefreshToken: string(identity.ConnectorData), Expiry: time.Now().Add(-time.Hour), @@ -196,26 +264,40 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit return identity, fmt.Errorf("oidc: failed to get token: %v", err) } - return c.createIdentity(ctx, identity, token) -} - -func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return identity, errors.New("oidc: no id_token in token response") } + + return c.createIdentity(ctx, rawIDToken, connectorData{}, token.RefreshToken) +} + +func (c *oidcConnector) createIdentity(ctx context.Context, rawIDToken string, connData connectorData, refreshToken string) (identity connector.Identity, err error) { idToken, err := c.verifier.Verify(ctx, rawIDToken) if err != nil { return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } + if c.responseType == "id_token" { + // validate the nonce, we're in the implicit flow + var nonceClaim struct { + Nonce string `json:"nonce"` + } + if err := idToken.Claims(&nonceClaim); err != nil { + return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) + } + if nonceClaim.Nonce != connData.Nonce { + return identity, fmt.Errorf("oidc: invalid nonce from provider") + } + } + var claims struct { Username string `json:"name"` Email string `json:"email"` EmailVerified bool `json:"email_verified"` HostedDomain string `json:"hd"` } - if err := idToken.Claims(&claims); err != nil { + if err = idToken.Claims(&claims); err != nil { return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) } @@ -238,7 +320,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I Username: claims.Username, Email: claims.Email, EmailVerified: claims.EmailVerified, - ConnectorData: []byte(token.RefreshToken), + ConnectorData: []byte(refreshToken), } return identity, nil } diff --git a/server/handlers.go b/server/handlers.go index 4d90dd83fc..38e8110da0 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -302,12 +302,22 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) + callbackURL, connectorData, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) if err != nil { s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(w, http.StatusInternalServerError, "Login error.") return } + + updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ConnectorData = connectorData + return a, nil + } + if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil { + s.logger.Errorf("Failed to set connector Data on auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, "Database error.") + } + http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: if err := s.templates.password(w, r.URL.String(), "", usernamePrompt(conn), false, showBacklink); err != nil { @@ -378,6 +388,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) { var authID string + var isPostFormResponse bool switch r.Method { case http.MethodGet: // OAuth2 callback if authID = r.URL.Query().Get("state"); authID == "" { @@ -386,8 +397,11 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) } case http.MethodPost: // SAML POST binding if authID = r.PostFormValue("RelayState"); authID == "" { - s.renderError(w, http.StatusBadRequest, "User session error.") - return + isPostFormResponse = true + if authID = r.PostFormValue("state"); authID == "" { + s.renderError(w, http.StatusBadRequest, "User session error.") + return + } } default: s.renderError(w, http.StatusBadRequest, "Method not supported") @@ -422,12 +436,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) var identity connector.Identity switch conn := conn.Connector.(type) { case connector.CallbackConnector: - if r.Method != http.MethodGet { + if r.Method != http.MethodGet && !isPostFormResponse { s.logger.Errorf("SAML request mapped to OAuth2 connector") s.renderError(w, http.StatusBadRequest, "Invalid request") return } - identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r) + identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r, authReq.ConnectorData) case connector.SAMLConnector: if r.Method != http.MethodPost { s.logger.Errorf("OAuth2 request mapped to SAML connector") diff --git a/server/server_test.go b/server/server_test.go index 536387c40d..3272261876 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -623,11 +623,8 @@ func TestOAuth2ImplicitFlow(t *testing.T) { t.Fatalf("failed to create client: %v", err) } - src := &nonceSource{nonce: nonce} - idTokenVerifier := p.Verifier(&oidc.Config{ - ClientID: client.ID, - ClaimNonce: src.ClaimNonce, + ClientID: client.ID, }) oauth2Config = &oauth2.Config{ @@ -638,7 +635,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) { RedirectURL: redirectURL, } - checkIDToken := func(u *url.URL) error { + checkIDToken := func(u *url.URL, nonce string) error { if u.Fragment == "" { return fmt.Errorf("url has no fragment: %s", u) } @@ -650,9 +647,15 @@ func TestOAuth2ImplicitFlow(t *testing.T) { if idToken == "" { return errors.New("no id_token in fragment") } - if _, err := idTokenVerifier.Verify(ctx, idToken); err != nil { + parsedToken, err := idTokenVerifier.Verify(ctx, idToken) + if err != nil { return fmt.Errorf("failed to verify id_token: %v", err) } + // check nonce since Verifier no longer does: + // retrieve from the connectordata? + if parsedToken.Nonce != nonce { + return fmt.Errorf("the id_token nonce was incorrect") + } return nil } @@ -669,7 +672,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) { // for an ID Token. u := req.URL.String() if strings.HasPrefix(u, oauth2Server.URL) { - if err := checkIDToken(req.URL); err == nil { + if err := checkIDToken(req.URL, nonce); err == nil { gotIDToken = true } else { t.Error(err) diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 87b3557918..e0934aabf7 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -337,6 +337,8 @@ type AuthRequest struct { Claims Claims `json:"claims,omitempty"` // The connector used to login the user. Set when the user authenticates. ConnectorID string `json:"connectorID,omitempty"` + // Connector Data for auth requests in flight + ConnectorData []byte `json:"connectorData,omitempty"` Expiry time.Time `json:"expiry"` } @@ -385,6 +387,7 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { LoggedIn: a.LoggedIn, ForceApprovalPrompt: a.ForceApprovalPrompt, ConnectorID: a.ConnectorID, + ConnectorData: a.ConnectorData, Expiry: a.Expiry, Claims: fromStorageClaims(a.Claims), } @@ -634,8 +637,8 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline }, UserID: o.UserID, ConnID: o.ConnID, - Refresh: o.Refresh, ConnectorData: o.ConnectorData, + Refresh: o.Refresh, } } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index a5a8832760..1724c51160 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -110,18 +110,18 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { force_approval_prompt, logged_in, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, + connector_id, connector_data, expiry ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 ); `, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), - a.ConnectorID, + a.ConnectorID, a.ConnectorData, a.Expiry, ) if err != nil { @@ -152,15 +152,15 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) claims_user_id = $9, claims_username = $10, claims_email = $11, claims_email_verified = $12, claims_groups = $13, - connector_id = $14, - expiry = $15 - where id = $16; + connector_id = $14, connector_data = $15, + expiry = $16 + where id = $17; `, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), - a.ConnectorID, + a.ConnectorID, a.ConnectorData, a.Expiry, r.ID, ) if err != nil { @@ -182,14 +182,14 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { force_approval_prompt, logged_in, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, expiry + connector_id, connector_data, expiry from auth_request where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, &a.ForceApprovalPrompt, &a.LoggedIn, &a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), - &a.ConnectorID, &a.Expiry, + &a.ConnectorID, &a.ConnectorData, &a.Expiry, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/storage.go b/storage/storage.go index 2c1f3d00d3..3ea24ad9ae 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -181,6 +181,10 @@ type AuthRequest struct { // The connector used to login the user and any data the connector wishes to persists. // Set when the user authenticates. ConnectorID string + // Set when the user starts authenticating upstream + // Used for OIDC flows that require state saved between starting + // the authorization flow and getting the callback + ConnectorData []byte } // AuthCode represents a code which can be exchanged for an OAuth2 token response.