diff --git a/connector/bitbucketcloud/bitbucketcloud.go b/connector/bitbucketcloud/bitbucketcloud.go index 5f802e3414..d7fb64caa6 100644 --- a/connector/bitbucketcloud/bitbucketcloud.go +++ b/connector/bitbucketcloud/bitbucketcloud.go @@ -111,12 +111,12 @@ func (b *bitbucketConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if b.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) } - return b.oauth2Config(scopes).AuthCodeURL(state), nil + return b.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -131,7 +131,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (b *bitbucketConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (b *bitbucketConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/bitbucketcloud/bitbucketcloud_test.go b/connector/bitbucketcloud/bitbucketcloud_test.go index 9545ff09c5..67a74dab38 100644 --- a/connector/bitbucketcloud/bitbucketcloud_test.go +++ b/connector/bitbucketcloud/bitbucketcloud_test.go @@ -102,7 +102,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) bitbucketConnector := bitbucketConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, req) + identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") diff --git a/connector/connector.go b/connector/connector.go index d812390f0c..b1e069c3fc 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -63,10 +63,10 @@ type CallbackConnector interface { // requested if one has already been issues. There's no good general answer // for these kind of restrictions, and may require this package to become more // aware of the global set of user/connector interactions. - LoginURL(s Scopes, callbackURL, state string) (string, error) + LoginURL(s Scopes, callbackURL, state string) (string, []byte, error) // Handle the callback to the server and return an identity. - HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) + HandleCallback(s Scopes, connData []byte, r *http.Request) (identity Identity, err error) } // SAMLConnector represents SAML connectors which implement the HTTP POST binding. diff --git a/connector/gitea/gitea.go b/connector/gitea/gitea.go index 62523185d5..059c861705 100644 --- a/connector/gitea/gitea.go +++ b/connector/gitea/gitea.go @@ -102,11 +102,11 @@ func (c *giteaConnector) oauth2Config(_ connector.Scopes) *oauth2.Config { } } -func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -121,7 +121,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *giteaConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *giteaConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/gitea/gitea_test.go b/connector/gitea/gitea_test.go index a71d79956e..4fe7768901 100644 --- a/connector/gitea/gitea_test.go +++ b/connector/gitea/gitea_test.go @@ -30,14 +30,14 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := giteaConnector{baseURL: s.URL, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{}, req) + identity, err := c.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") expectEquals(t, identity.UserID, "12345678") c = giteaConnector{baseURL: s.URL, httpClient: newClient()} - identity, err = c.HandleCallback(connector.Scopes{}, req) + identity, err = c.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") diff --git a/connector/github/github.go b/connector/github/github.go index 18a56628af..eb19f77805 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -194,12 +194,12 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -214,7 +214,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *githubConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/github/github_test.go b/connector/github/github_test.go index 088cbb238c..fb57161ce2 100644 --- a/connector/github/github_test.go +++ b/connector/github/github_test.go @@ -153,7 +153,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -161,7 +161,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectEquals(t, 0, len(identity.Groups)) c = githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), loadAllGroups: true} - identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -194,7 +194,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), useLoginAsID: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "some-login") diff --git a/connector/gitlab/gitlab.go b/connector/gitlab/gitlab.go index 7aa4439842..b845ddcb32 100644 --- a/connector/gitlab/gitlab.go +++ b/connector/gitlab/gitlab.go @@ -105,11 +105,11 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -124,7 +124,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *gitlabConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/gitlab/gitlab_test.go b/connector/gitlab/gitlab_test.go index b67b30c045..045f3907a1 100644 --- a/connector/gitlab/gitlab_test.go +++ b/connector/gitlab/gitlab_test.go @@ -84,7 +84,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") @@ -92,7 +92,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectEquals(t, 0, len(identity.Groups)) c = gitlabConnector{baseURL: s.URL, httpClient: newClient()} - identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") @@ -120,7 +120,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), useLoginAsID: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "joebloggs") @@ -147,7 +147,7 @@ func TestLoginWithTeamWhitelisted(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-1"}} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "12345678") @@ -174,7 +174,7 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-2"}} - _, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + _, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNotNil(t, err, "HandleCallback error") expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups") @@ -208,7 +208,7 @@ func TestRefresh(t *testing.T) { }) expectNil(t, err) - identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req) + identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") expectEquals(t, identity.UserID, "12345678") @@ -272,7 +272,7 @@ func TestGroupsWithPermission(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), getGroupsPermission: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Groups, []string{ diff --git a/connector/google/google.go b/connector/google/google.go index e17ec5bd7f..4a8599c0b1 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -168,9 +168,9 @@ func (c *googleConnector) Close() error { return nil } -func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var opts []oauth2.AuthCodeOption @@ -186,7 +186,7 @@ func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + return c.oauth2Config.AuthCodeURL(state, opts...), nil, nil } type oauth2Error struct { @@ -201,7 +201,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *googleConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *googleConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/google/google_test.go b/connector/google/google_test.go index bafcadc8ff..68188c6559 100644 --- a/connector/google/google_test.go +++ b/connector/google/google_test.go @@ -440,7 +440,7 @@ func TestPromptTypeConfig(t *testing.T) { assert.Nil(t, err) assert.Equal(t, test.expectedPromptTypeValue, conn.promptType) - loginURL, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state") + loginURL, _, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state") assert.Nil(t, err) urlp, err := url.Parse(loginURL) diff --git a/connector/linkedin/linkedin.go b/connector/linkedin/linkedin.go index f17d17cca1..0c24ff4756 100644 --- a/connector/linkedin/linkedin.go +++ b/connector/linkedin/linkedin.go @@ -62,17 +62,17 @@ var ( ) // LoginURL returns an access token request URL -func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.oauth2Config.RedirectURL != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.oauth2Config.RedirectURL) } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config.AuthCodeURL(state), nil, nil } // HandleCallback handles HTTP redirect from LinkedIn -func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *linkedInConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 2fcf6a7515..1db16942b7 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -175,9 +175,9 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var options []oauth2.AuthCodeOption @@ -188,10 +188,10 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat options = append(options, oauth2.SetAuthURLParam("domain_hint", c.domainHint)) } - return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil + return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil, nil } -func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *microsoftConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/microsoft/microsoft_test.go b/connector/microsoft/microsoft_test.go index 67be660fce..1fa2f3dcdf 100644 --- a/connector/microsoft/microsoft_test.go +++ b/connector/microsoft/microsoft_test.go @@ -39,7 +39,7 @@ func TestLoginURL(t *testing.T) { tenant: tenant, } - loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState) + loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState) parsedLoginURL, _ := url.Parse(loginURL) queryParams := parsedLoginURL.Query() @@ -70,7 +70,7 @@ func TestLoginURLWithOptions(t *testing.T) { domainHint: domainHint, } - loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") + loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") parsedLoginURL, _ := url.Parse(loginURL) queryParams := parsedLoginURL.Query() @@ -91,7 +91,7 @@ func TestUserIdentityFromGraphAPI(t *testing.T) { req, _ := http.NewRequest("GET", s.URL, nil) c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant} - identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "Jane Doe") expectEquals(t, identity.UserID, "S56767889") @@ -114,7 +114,7 @@ func TestUserGroupsFromGraphAPI(t *testing.T) { req, _ := http.NewRequest("GET", s.URL, nil) c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Groups, []string{"a", "b"}) } diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 7e5979a992..be44bfd12a 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -43,21 +43,21 @@ type Callback struct { } // LoginURL returns the URL to redirect the user to login with. -func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { u, err := url.Parse(callbackURL) if err != nil { - return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) + return "", nil, fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) } v := u.Query() v.Set("state", state) u.RawQuery = v.Encode() - return u.String(), nil + return u.String(), nil, nil } var connectorData = []byte("foobar") // HandleCallback parses the request and returns the user's identity -func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { +func (m *Callback) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (connector.Identity, error) { return m.Identity, nil } diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 1ea0c1fc1a..60f4c6096b 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -21,6 +21,20 @@ import ( "github.com/dexidp/dex/pkg/httpclient" ) +const ( + codeChallengeMethodPlain = "plain" + codeChallengeMethodS256 = "S256" +) + +func contains(arr []string, item string) bool { + for _, itemFromArray := range arr { + if itemFromArray == item { + return true + } + } + return false +} + // Config holds configuration options for OpenID Connect logins. type Config struct { Issuer string `json:"issuer"` @@ -84,6 +98,10 @@ type Config struct { // PromptType will be used for the prompt parameter (when offline_access, by default prompt=consent) PromptType *string `json:"promptType"` + // PKCEChallenge specifies which PKCE algorithm will be used + // If not setted it will be auto-detected the best-fit for the connector. + PKCEChallenge string `json:"pkceChallenge"` + // OverrideClaimMapping will be used to override the options defined in claimMappings. // i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey. // This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`. @@ -217,6 +235,25 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool { return false } +// PKCEChallengeData is used to store info for PKCE Challenge method and verifier +// in the connectorData +type PKCEChallengeData struct { + CodeChallenge string `json:"codeChallenge"` + CodeChallengeMethod string `json:"codeChallengeMethod"` +} + +// Returns an AuthCodeOption according to the provided codeChallengeMethod +func getAuthCodeOptionForCodeChallenge(codeVerifier, codeChallengeMethod string) (oauth2.AuthCodeOption, error) { + switch codeChallengeMethod { + case codeChallengeMethodPlain: + return oauth2.VerifierOption(codeVerifier), nil + case codeChallengeMethodS256: + return oauth2.S256ChallengeOption(codeVerifier), nil + default: + return nil, fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod) + } +} + // Open returns a connector which can be used to login users through an upstream // OpenID Connect provider. func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) { @@ -275,6 +312,27 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, } } + // Obtain CodeChallengeMethodsSupported from the provider + var metadata struct { + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + } + if err := provider.Claims(&metadata); err != nil { + logger.Warn("failed to parse provider metadata") + } + // if PKCEChallenge method has not been setted in the config, auto-detect the best fit + if c.PKCEChallenge == "" { + if contains(metadata.CodeChallengeMethodsSupported, codeChallengeMethodS256) { + c.PKCEChallenge = codeChallengeMethodS256 + } else if contains(metadata.CodeChallengeMethodsSupported, codeChallengeMethodPlain) { + c.PKCEChallenge = codeChallengeMethodPlain + } + } else { + // if PKCEChallenge method has been setted in the config, check if it is supported + if !contains(metadata.CodeChallengeMethodsSupported, c.PKCEChallenge) { + logger.Warn("provided PKCEChallenge method not supported by the connector") + } + } + clientID := c.ClientID return &oidcConnector{ provider: provider, @@ -307,6 +365,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, groupsKey: c.ClaimMapping.GroupsKey, newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, groupsFilter: groupsFilter, + pkceChallenge: c.PKCEChallenge, }, nil } @@ -337,6 +396,7 @@ type oidcConnector struct { groupsKey string newGroupFromClaims []NewGroupFromClaims groupsFilter *regexp.Regexp + pkceChallenge string } func (c *oidcConnector) Close() error { @@ -344,12 +404,13 @@ func (c *oidcConnector) Close() error { return nil } -func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var opts []oauth2.AuthCodeOption + var connectorData []byte if len(c.acrValues) > 0 { acrValues := strings.Join(c.acrValues, " ") @@ -359,7 +420,25 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) if s.OfflineAccess { opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + + if c.pkceChallenge != "" { + codeVerifier := oauth2.GenerateVerifier() + authCodeOption, err := getAuthCodeOptionForCodeChallenge(codeVerifier, c.pkceChallenge) + if err != nil { + return "", nil, fmt.Errorf("oidc: failed to get PKCE AuthCodeOption for CodeChallenge: %v", err) + } + data := PKCEChallengeData{ + CodeChallenge: codeVerifier, + CodeChallengeMethod: c.pkceChallenge, + } + connectorData, err = json.Marshal(data) + if err != nil { + return "", nil, fmt.Errorf("oidc: failed to create PKCEChallenge data: %v", err) + } + opts = append(opts, authCodeOption) + } + + return c.oauth2Config.AuthCodeURL(state, opts...), connectorData, nil } type oauth2Error struct { @@ -382,7 +461,7 @@ const ( exchangeCaller ) -func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *oidcConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} @@ -390,7 +469,19 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) - token, err := c.oauth2Config.Exchange(ctx, q.Get("code")) + var opts []oauth2.AuthCodeOption + if c.pkceChallenge != "" { + var data PKCEChallengeData + if err := json.Unmarshal(connData, &data); err != nil { + return identity, fmt.Errorf("oidc: failed to parse PKCEChallenge data: %v", err) + } + if data.CodeChallenge == "" { + return identity, fmt.Errorf("oidc: invalid PKCE CodeChallenge") + } + opts = append(opts, oauth2.VerifierOption(data.CodeChallenge)) + } + + token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...) if err != nil { return identity, fmt.Errorf("oidc: failed to get token: %v", err) } diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index e31d4e0b94..c3738757ae 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -66,6 +66,7 @@ func TestHandleCallback(t *testing.T) { token map[string]interface{} groupsRegex string newGroupFromClaims []NewGroupFromClaims + pkceChallenge string }{ { name: "simpleCase", @@ -431,6 +432,40 @@ func TestHandleCallback(t *testing.T) { "email_verified": true, }, }, + { + name: "S256PKCEChallenge", + userIDKey: "", // not configured + userNameKey: "", // not configured + pkceChallenge: "S256", + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: []string{"group1", "group2"}, + expectedEmailField: "emailvalue", + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "groups": []string{"group1", "group2"}, + "email": "emailvalue", + "email_verified": true, + }, + }, + { + name: "plainPKCEChallenge", + userIDKey: "", // not configured + userNameKey: "", // not configured + pkceChallenge: "plain", + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: []string{"group1", "group2"}, + expectedEmailField: "emailvalue", + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "groups": []string{"group1", "group2"}, + "email": "emailvalue", + "email_verified": true, + }, + }, } for _, tc := range tests { @@ -462,6 +497,7 @@ func TestHandleCallback(t *testing.T) { InsecureEnableGroups: true, BasicAuthUnsupported: &basicAuth, OverrideClaimMapping: tc.overrideClaimMapping, + PKCEChallenge: tc.pkceChallenge, } config.ClaimMapping.PreferredUsernameKey = tc.preferredUsernameKey config.ClaimMapping.EmailKey = tc.emailKey @@ -479,7 +515,11 @@ func TestHandleCallback(t *testing.T) { t.Fatal("failed to create request", err) } - identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + connectorDataStrTemplate := `{"codeChallenge":"abcdefgh123456qwertuiop89101112uvpwizABC234","codeChallengeMethod":"%s"}` + connectorDataStr := fmt.Sprintf(connectorDataStrTemplate, config.PKCEChallenge) + connectorData := []byte(connectorDataStr) + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, connectorData, req) if err != nil { t.Fatal("handle callback failed", err) } diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index 4519a85b6d..3d4408c585 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -138,12 +138,12 @@ func (c *openshiftConnector) Close() error { } // LoginURL returns the URL to redirect the user to login with. -func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config.AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -160,6 +160,7 @@ func (e *oauth2Error) Error() string { // HandleCallback parses the request and returns the user's identity func (c *openshiftConnector) HandleCallback(s connector.Scopes, + connData []byte, r *http.Request, ) (identity connector.Identity, err error) { q := r.URL.Query() diff --git a/connector/openshift/openshift_test.go b/connector/openshift/openshift_test.go index 89ec0e25a9..d6d8603bbc 100644 --- a/connector/openshift/openshift_test.go +++ b/connector/openshift/openshift_test.go @@ -176,7 +176,7 @@ func TestCallbackIdentity(t *testing.T) { TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), }, }} - identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "12345") diff --git a/server/handlers.go b/server/handlers.go index 5954820caa..08e1aa72e0 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -263,12 +263,24 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) + callbackURL, connData, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) if err != nil { s.logger.ErrorContext(r.Context(), "connector returned error when creating callback", "connector_id", connID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } + if len(connData) > 0 { + updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ConnectorData = connData + return a, nil + } + err := s.storage.UpdateAuthRequest(authReq.ID, updater) + if err != nil { + s.logger.ErrorContext(r.Context(), "Failed to set connector data on auth request", "connector_id", connID, "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: loginURL := url.URL{ @@ -463,7 +475,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) s.renderError(r, w, http.StatusBadRequest, "Invalid request") return } - identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r) + identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), authReq.ConnectorData, r) case connector.SAMLConnector: if r.Method != http.MethodPost { s.logger.ErrorContext(r.Context(), "OAuth2 request mapped to SAML connector")