diff --git a/connector/bitbucketcloud/bitbucketcloud.go b/connector/bitbucketcloud/bitbucketcloud.go index 5f802e3414..d7fb64caa6 100644 --- a/connector/bitbucketcloud/bitbucketcloud.go +++ b/connector/bitbucketcloud/bitbucketcloud.go @@ -111,12 +111,12 @@ func (b *bitbucketConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if b.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) } - return b.oauth2Config(scopes).AuthCodeURL(state), nil + return b.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -131,7 +131,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (b *bitbucketConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (b *bitbucketConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/bitbucketcloud/bitbucketcloud_test.go b/connector/bitbucketcloud/bitbucketcloud_test.go index 9545ff09c5..67a74dab38 100644 --- a/connector/bitbucketcloud/bitbucketcloud_test.go +++ b/connector/bitbucketcloud/bitbucketcloud_test.go @@ -102,7 +102,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) bitbucketConnector := bitbucketConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, req) + identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") diff --git a/connector/connector.go b/connector/connector.go index d812390f0c..b1e069c3fc 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -63,10 +63,10 @@ type CallbackConnector interface { // requested if one has already been issues. There's no good general answer // for these kind of restrictions, and may require this package to become more // aware of the global set of user/connector interactions. - LoginURL(s Scopes, callbackURL, state string) (string, error) + LoginURL(s Scopes, callbackURL, state string) (string, []byte, error) // Handle the callback to the server and return an identity. - HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) + HandleCallback(s Scopes, connData []byte, r *http.Request) (identity Identity, err error) } // SAMLConnector represents SAML connectors which implement the HTTP POST binding. diff --git a/connector/gitea/gitea.go b/connector/gitea/gitea.go index 62523185d5..059c861705 100644 --- a/connector/gitea/gitea.go +++ b/connector/gitea/gitea.go @@ -102,11 +102,11 @@ func (c *giteaConnector) oauth2Config(_ connector.Scopes) *oauth2.Config { } } -func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -121,7 +121,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *giteaConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *giteaConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/gitea/gitea_test.go b/connector/gitea/gitea_test.go index a71d79956e..4fe7768901 100644 --- a/connector/gitea/gitea_test.go +++ b/connector/gitea/gitea_test.go @@ -30,14 +30,14 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := giteaConnector{baseURL: s.URL, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{}, req) + identity, err := c.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") expectEquals(t, identity.UserID, "12345678") c = giteaConnector{baseURL: s.URL, httpClient: newClient()} - identity, err = c.HandleCallback(connector.Scopes{}, req) + identity, err = c.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") diff --git a/connector/github/github.go b/connector/github/github.go index 18a56628af..eb19f77805 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -194,12 +194,12 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -214,7 +214,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *githubConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/github/github_test.go b/connector/github/github_test.go index 088cbb238c..fb57161ce2 100644 --- a/connector/github/github_test.go +++ b/connector/github/github_test.go @@ -153,7 +153,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -161,7 +161,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectEquals(t, 0, len(identity.Groups)) c = githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), loadAllGroups: true} - identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -194,7 +194,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), useLoginAsID: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "some-login") diff --git a/connector/gitlab/gitlab.go b/connector/gitlab/gitlab.go index fdb2c48204..30fb87fb35 100644 --- a/connector/gitlab/gitlab.go +++ b/connector/gitlab/gitlab.go @@ -100,11 +100,11 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -119,7 +119,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *gitlabConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/gitlab/gitlab_test.go b/connector/gitlab/gitlab_test.go index d828b8bd16..1a9b774ea5 100644 --- a/connector/gitlab/gitlab_test.go +++ b/connector/gitlab/gitlab_test.go @@ -84,7 +84,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") @@ -92,7 +92,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectEquals(t, 0, len(identity.Groups)) c = gitlabConnector{baseURL: s.URL, httpClient: newClient()} - identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") @@ -120,7 +120,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), useLoginAsID: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "joebloggs") @@ -147,7 +147,7 @@ func TestLoginWithTeamWhitelisted(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-1"}} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "12345678") @@ -174,7 +174,7 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-2"}} - _, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + _, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNotNil(t, err, "HandleCallback error") expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups") @@ -208,7 +208,7 @@ func TestRefresh(t *testing.T) { }) expectNil(t, err) - identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req) + identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") expectEquals(t, identity.UserID, "12345678") diff --git a/connector/google/google.go b/connector/google/google.go index e17ec5bd7f..4a8599c0b1 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -168,9 +168,9 @@ func (c *googleConnector) Close() error { return nil } -func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var opts []oauth2.AuthCodeOption @@ -186,7 +186,7 @@ func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + return c.oauth2Config.AuthCodeURL(state, opts...), nil, nil } type oauth2Error struct { @@ -201,7 +201,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *googleConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *googleConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/google/google_test.go b/connector/google/google_test.go index bafcadc8ff..68188c6559 100644 --- a/connector/google/google_test.go +++ b/connector/google/google_test.go @@ -440,7 +440,7 @@ func TestPromptTypeConfig(t *testing.T) { assert.Nil(t, err) assert.Equal(t, test.expectedPromptTypeValue, conn.promptType) - loginURL, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state") + loginURL, _, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state") assert.Nil(t, err) urlp, err := url.Parse(loginURL) diff --git a/connector/linkedin/linkedin.go b/connector/linkedin/linkedin.go index f17d17cca1..0c24ff4756 100644 --- a/connector/linkedin/linkedin.go +++ b/connector/linkedin/linkedin.go @@ -62,17 +62,17 @@ var ( ) // LoginURL returns an access token request URL -func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.oauth2Config.RedirectURL != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.oauth2Config.RedirectURL) } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config.AuthCodeURL(state), nil, nil } // HandleCallback handles HTTP redirect from LinkedIn -func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *linkedInConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 2fcf6a7515..1db16942b7 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -175,9 +175,9 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var options []oauth2.AuthCodeOption @@ -188,10 +188,10 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat options = append(options, oauth2.SetAuthURLParam("domain_hint", c.domainHint)) } - return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil + return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil, nil } -func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *microsoftConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/microsoft/microsoft_test.go b/connector/microsoft/microsoft_test.go index 67be660fce..1fa2f3dcdf 100644 --- a/connector/microsoft/microsoft_test.go +++ b/connector/microsoft/microsoft_test.go @@ -39,7 +39,7 @@ func TestLoginURL(t *testing.T) { tenant: tenant, } - loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState) + loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState) parsedLoginURL, _ := url.Parse(loginURL) queryParams := parsedLoginURL.Query() @@ -70,7 +70,7 @@ func TestLoginURLWithOptions(t *testing.T) { domainHint: domainHint, } - loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") + loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") parsedLoginURL, _ := url.Parse(loginURL) queryParams := parsedLoginURL.Query() @@ -91,7 +91,7 @@ func TestUserIdentityFromGraphAPI(t *testing.T) { req, _ := http.NewRequest("GET", s.URL, nil) c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant} - identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "Jane Doe") expectEquals(t, identity.UserID, "S56767889") @@ -114,7 +114,7 @@ func TestUserGroupsFromGraphAPI(t *testing.T) { req, _ := http.NewRequest("GET", s.URL, nil) c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Groups, []string{"a", "b"}) } diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 7e5979a992..be44bfd12a 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -43,21 +43,21 @@ type Callback struct { } // LoginURL returns the URL to redirect the user to login with. -func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { u, err := url.Parse(callbackURL) if err != nil { - return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) + return "", nil, fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) } v := u.Query() v.Set("state", state) u.RawQuery = v.Encode() - return u.String(), nil + return u.String(), nil, nil } var connectorData = []byte("foobar") // HandleCallback parses the request and returns the user's identity -func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { +func (m *Callback) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (connector.Identity, error) { return m.Identity, nil } diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 1b6077b7f6..bdb601c5f6 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -14,7 +14,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/hashicorp/golang-lru/v2" "golang.org/x/oauth2" "github.com/dexidp/dex/connector" @@ -23,9 +22,8 @@ import ( ) const ( - defaultPkceMaxConcurrentConnections = 256 - codeChallengeMethodPlain = "plain" - codeChallengeMethodS256 = "S256" + codeChallengeMethodPlain = "plain" + codeChallengeMethodS256 = "S256" ) func contains(arr []string, item string) bool { @@ -99,9 +97,6 @@ type Config struct { // If not setted it will be auto-detected the best-fit for the connector. PKCEChallenge string `json:"pkceChallenge"` - // PKCEMaxConcurrentConnections specifies the maximum number of concurrent connections for the PKCE code verify. - PKCEMaxConcurrentConnections int `json:"pkceMaxConcurrentConnections"` - // OverrideClaimMapping will be used to override the options defined in claimMappings. // i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey. // This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`. @@ -235,6 +230,25 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool { return false } +// PKCEChallengeData is used to store info for PKCE Challenge method and verifier +// in the connectorData +type PKCEChallengeData struct { + CodeChallenge string `json:"codeChallenge"` + CodeChallengeMethod string `json:"codeChallengeMethod"` +} + +// Returns an AuthCodeOption according to the provided codeChallengeMethod +func getAuthCodeOptionForCodeChallenge(codeVerifier, codeChallengeMethod string) (oauth2.AuthCodeOption, error) { + switch codeChallengeMethod { + case codeChallengeMethodPlain: + return oauth2.VerifierOption(codeVerifier), nil + case codeChallengeMethodS256: + return oauth2.S256ChallengeOption(codeVerifier), nil + default: + return nil, fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod) + } +} + // Open returns a connector which can be used to login users through an upstream // OpenID Connect provider. func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) { @@ -312,19 +326,6 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, } } - // if PKCE will be used, create a state cache for verifier - var pkceVerifierCache *lru.Cache[string, string] - if c.PKCEChallenge != "" { - pkceCacheSize := c.PKCEMaxConcurrentConnections - if pkceCacheSize == 0 { - pkceCacheSize = defaultPkceMaxConcurrentConnections - } - pkceVerifierCache, err = lru.New[string, string](pkceCacheSize) - if err != nil { - logger.Warn("Unable to create PKCE Verifier cache") - } - } - clientID := c.ClientID return &oidcConnector{ provider: provider, @@ -340,26 +341,24 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, ctx, // Pass our ctx with customized http.Client &oidc.Config{ClientID: clientID}, ), - logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), - cancel: cancel, - httpClient: httpClient, - insecureSkipEmailVerified: c.InsecureSkipEmailVerified, - insecureEnableGroups: c.InsecureEnableGroups, - allowedGroups: c.AllowedGroups, - acrValues: c.AcrValues, - getUserInfo: c.GetUserInfo, - promptType: promptType, - userIDKey: c.UserIDKey, - userNameKey: c.UserNameKey, - overrideClaimMapping: c.OverrideClaimMapping, - preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, - emailKey: c.ClaimMapping.EmailKey, - groupsKey: c.ClaimMapping.GroupsKey, - newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, - groupsFilter: groupsFilter, - pkceChallenge: c.PKCEChallenge, - pkceMaxConcurrentConnections: c.PKCEMaxConcurrentConnections, - pkceVerifierCache: pkceVerifierCache, + logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), + cancel: cancel, + httpClient: httpClient, + insecureSkipEmailVerified: c.InsecureSkipEmailVerified, + insecureEnableGroups: c.InsecureEnableGroups, + allowedGroups: c.AllowedGroups, + acrValues: c.AcrValues, + getUserInfo: c.GetUserInfo, + promptType: promptType, + userIDKey: c.UserIDKey, + userNameKey: c.UserNameKey, + overrideClaimMapping: c.OverrideClaimMapping, + preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, + emailKey: c.ClaimMapping.EmailKey, + groupsKey: c.ClaimMapping.GroupsKey, + newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, + groupsFilter: groupsFilter, + pkceChallenge: c.PKCEChallenge, }, nil } @@ -369,30 +368,28 @@ var ( ) type oidcConnector struct { - provider *oidc.Provider - redirectURI string - oauth2Config *oauth2.Config - verifier *oidc.IDTokenVerifier - cancel context.CancelFunc - logger *slog.Logger - httpClient *http.Client - insecureSkipEmailVerified bool - insecureEnableGroups bool - allowedGroups []string - acrValues []string - getUserInfo bool - promptType string - userIDKey string - userNameKey string - overrideClaimMapping bool - preferredUsernameKey string - emailKey string - groupsKey string - newGroupFromClaims []NewGroupFromClaims - groupsFilter *regexp.Regexp - pkceChallenge string - pkceMaxConcurrentConnections int - pkceVerifierCache *lru.Cache[string, string] + provider *oidc.Provider + redirectURI string + oauth2Config *oauth2.Config + verifier *oidc.IDTokenVerifier + cancel context.CancelFunc + logger *slog.Logger + httpClient *http.Client + insecureSkipEmailVerified bool + insecureEnableGroups bool + allowedGroups []string + acrValues []string + getUserInfo bool + promptType string + userIDKey string + userNameKey string + overrideClaimMapping bool + preferredUsernameKey string + emailKey string + groupsKey string + newGroupFromClaims []NewGroupFromClaims + groupsFilter *regexp.Regexp + pkceChallenge string } func (c *oidcConnector) Close() error { @@ -400,12 +397,13 @@ func (c *oidcConnector) Close() error { return nil } -func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var opts []oauth2.AuthCodeOption + var connectorData []byte if len(c.acrValues) > 0 { acrValues := strings.Join(c.acrValues, " ") @@ -417,21 +415,23 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) } if c.pkceChallenge != "" { - switch c.pkceChallenge { - case codeChallengeMethodPlain: - pkceVerifier := oauth2.GenerateVerifier() - c.pkceVerifierCache.Add(state, pkceVerifier) - opts = append(opts, oauth2.VerifierOption(pkceVerifier)) - case codeChallengeMethodS256: - pkceVerifier := oauth2.GenerateVerifier() - c.pkceVerifierCache.Add(state, pkceVerifier) - opts = append(opts, oauth2.S256ChallengeOption(pkceVerifier)) - default: - c.logger.Warn("unknown PKCEChallenge method") + codeVerifier := oauth2.GenerateVerifier() + authCodeOption, err := getAuthCodeOptionForCodeChallenge(codeVerifier, c.pkceChallenge) + if err != nil { + return "", nil, fmt.Errorf("oidc: failed to get PKCE AuthCodeOption for CodeChallenge: %v", err) } + data := PKCEChallengeData{ + CodeChallenge: codeVerifier, + CodeChallengeMethod: c.pkceChallenge, + } + connectorData, err = json.Marshal(data) + if err != nil { + return "", nil, fmt.Errorf("oidc: failed to create PKCEChallenge data: %v", err) + } + opts = append(opts, authCodeOption) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + return c.oauth2Config.AuthCodeURL(state, opts...), connectorData, nil } type oauth2Error struct { @@ -454,7 +454,7 @@ const ( exchangeCaller ) -func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *oidcConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} @@ -464,20 +464,14 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide var opts []oauth2.AuthCodeOption if c.pkceChallenge != "" { - state := q.Get("state") - if state == "" { - return identity, fmt.Errorf("oidc: missing state in callback") - } - pkceVerifier, found := c.pkceVerifierCache.Get(state) - if !found { - return identity, fmt.Errorf("oidc: received state not in callback cache") + var data PKCEChallengeData + if err := json.Unmarshal(connData, &data); err != nil { + return identity, fmt.Errorf("oidc: failed to parse PKCEChallenge data: %v", err) } - - c.pkceVerifierCache.Remove(state) - if pkceVerifier == "" { - return identity, fmt.Errorf("oidc: invalid state in pkce verifier cache") + if data.CodeChallenge == "" { + return identity, fmt.Errorf("oidc: invalid PKCE CodeChallenge") } - opts = append(opts, oauth2.VerifierOption(pkceVerifier)) + opts = append(opts, oauth2.VerifierOption(data.CodeChallenge)) } token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...) diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 3ae517f8c0..8a99de74c2 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -466,7 +466,11 @@ func TestHandleCallback(t *testing.T) { t.Fatal("failed to create request", err) } - identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + connectorDataStrTemplate := `{"codeChallenge":"abcdefgh123456qwertuiop89101112uvpwizABC234","codeChallengeMethod":"%s"}` + connectorDataStr := fmt.Sprintf(connectorDataStrTemplate, config.PKCEChallenge) + connectorData := []byte(connectorDataStr) + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, connectorData, req) if err != nil { t.Fatal("handle callback failed", err) } diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index 4519a85b6d..3d4408c585 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -138,12 +138,12 @@ func (c *openshiftConnector) Close() error { } // LoginURL returns the URL to redirect the user to login with. -func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config.AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -160,6 +160,7 @@ func (e *oauth2Error) Error() string { // HandleCallback parses the request and returns the user's identity func (c *openshiftConnector) HandleCallback(s connector.Scopes, + connData []byte, r *http.Request, ) (identity connector.Identity, err error) { q := r.URL.Query() diff --git a/connector/openshift/openshift_test.go b/connector/openshift/openshift_test.go index 89ec0e25a9..d6d8603bbc 100644 --- a/connector/openshift/openshift_test.go +++ b/connector/openshift/openshift_test.go @@ -176,7 +176,7 @@ func TestCallbackIdentity(t *testing.T) { TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), }, }} - identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "12345") diff --git a/go.mod b/go.mod index 96671dcd8b..e2100a396b 100644 --- a/go.mod +++ b/go.mod @@ -69,7 +69,6 @@ require ( github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.3 // indirect github.com/googleapis/gax-go/v2 v2.13.0 // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 5cbdb3cf41..cbebe635ec 100644 --- a/go.sum +++ b/go.sum @@ -135,8 +135,6 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= -github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc= github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= diff --git a/server/handlers.go b/server/handlers.go index 6521bf6a93..7a4152c8ad 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -263,12 +263,24 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) + callbackURL, connData, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) if err != nil { s.logger.ErrorContext(r.Context(), "connector returned error when creating callback", "connector_id", connID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } + if len(connData) > 0 { + updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ConnectorData = connData + return a, nil + } + err := s.storage.UpdateAuthRequest(authReq.ID, updater) + if err != nil { + s.logger.ErrorContext(r.Context(), "Failed to set connector data on auth request", "connector_id", connID, "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: loginURL := url.URL{ @@ -463,7 +475,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) s.renderError(r, w, http.StatusBadRequest, "Invalid request") return } - identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r) + identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), authReq.ConnectorData, r) case connector.SAMLConnector: if r.Method != http.MethodPost { s.logger.ErrorContext(r.Context(), "OAuth2 request mapped to SAML connector")