From 1c2640c3aa0860fcbfadf03e8331697e09d09738 Mon Sep 17 00:00:00 2001 From: Giovanni Vella Date: Wed, 2 Oct 2024 12:56:37 +0000 Subject: [PATCH 1/6] add support to PKCE in OIDC connector Signed-off-by: johnvan7 Signed-off-by: Giovanni Vella --- connector/oidc/oidc.go | 34 +++++++++++++++++++++++++++++++++- connector/oidc/oidc_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 7d0cacb056..b0e8afcf33 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -21,6 +21,11 @@ import ( "github.com/dexidp/dex/pkg/httpclient" ) +const ( + codeChallengeMethodPlain = "plain" + codeChallengeMethodS256 = "S256" +) + // Config holds configuration options for OpenID Connect logins. type Config struct { Issuer string `json:"issuer"` @@ -79,6 +84,8 @@ type Config struct { // PromptType will be used for the prompt parameter (when offline_access, by default prompt=consent) PromptType *string `json:"promptType"` + PKCEChallenge string `json:"pkceChallenge"` + // OverrideClaimMapping will be used to override the options defined in claimMappings. // i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey. // This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`. @@ -268,6 +275,8 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, } } + pkceVerifier := "" + clientID := c.ClientID return &oidcConnector{ provider: provider, @@ -300,6 +309,8 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, groupsKey: c.ClaimMapping.GroupsKey, newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, groupsFilter: groupsFilter, + pkceChallenge: c.PKCEChallenge, + pkceVerifier: pkceVerifier, }, nil } @@ -330,6 +341,8 @@ type oidcConnector struct { groupsKey string newGroupFromClaims []NewGroupFromClaims groupsFilter *regexp.Regexp + pkceChallenge string + pkceVerifier string } func (c *oidcConnector) Close() error { @@ -352,6 +365,20 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) if s.OfflineAccess { opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } + + if c.pkceChallenge != "" { + switch c.pkceChallenge { + case codeChallengeMethodPlain: + c.pkceVerifier = oauth2.GenerateVerifier() + opts = append(opts, oauth2.VerifierOption(c.pkceVerifier)) + case codeChallengeMethodS256: + c.pkceVerifier = oauth2.GenerateVerifier() + opts = append(opts, oauth2.S256ChallengeOption(c.pkceVerifier)) + default: + c.logger.Warn("unknown PKCEChallenge method") + } + } + return c.oauth2Config.AuthCodeURL(state, opts...), nil } @@ -383,7 +410,12 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) - token, err := c.oauth2Config.Exchange(ctx, q.Get("code")) + var opts []oauth2.AuthCodeOption + if c.pkceVerifier != "" { + opts = append(opts, oauth2.VerifierOption(c.pkceVerifier)) + } + + token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...) if err != nil { return identity, fmt.Errorf("oidc: failed to get token: %v", err) } diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 66b35c3fef..3ae517f8c0 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -66,6 +66,7 @@ func TestHandleCallback(t *testing.T) { token map[string]interface{} groupsRegex string newGroupFromClaims []NewGroupFromClaims + pkceChallenge string }{ { name: "simpleCase", @@ -382,6 +383,40 @@ func TestHandleCallback(t *testing.T) { "email_verified": true, }, }, + { + name: "S256PKCEChallenge", + userIDKey: "", // not configured + userNameKey: "", // not configured + pkceChallenge: "S256", + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: []string{"group1", "group2"}, + expectedEmailField: "emailvalue", + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "groups": []string{"group1", "group2"}, + "email": "emailvalue", + "email_verified": true, + }, + }, + { + name: "plainPKCEChallenge", + userIDKey: "", // not configured + userNameKey: "", // not configured + pkceChallenge: "plain", + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: []string{"group1", "group2"}, + expectedEmailField: "emailvalue", + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "groups": []string{"group1", "group2"}, + "email": "emailvalue", + "email_verified": true, + }, + }, } for _, tc := range tests { @@ -413,6 +448,7 @@ func TestHandleCallback(t *testing.T) { InsecureEnableGroups: true, BasicAuthUnsupported: &basicAuth, OverrideClaimMapping: tc.overrideClaimMapping, + PKCEChallenge: tc.pkceChallenge, } config.ClaimMapping.PreferredUsernameKey = tc.preferredUsernameKey config.ClaimMapping.EmailKey = tc.emailKey From 6ef864df1016645938abc68933244b42b586ba10 Mon Sep 17 00:00:00 2001 From: Giovanni Vella Date: Wed, 2 Oct 2024 15:04:43 +0000 Subject: [PATCH 2/6] auto-detect PKCE Challenge Signed-off-by: johnvan7 Signed-off-by: Giovanni Vella --- connector/oidc/oidc.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index b0e8afcf33..daea6b2e08 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -26,6 +26,15 @@ const ( codeChallengeMethodS256 = "S256" ) +func contains(arr []string, item string) bool { + for _, itemFromArray := range arr { + if itemFromArray == item { + return true + } + } + return false +} + // Config holds configuration options for OpenID Connect logins. type Config struct { Issuer string `json:"issuer"` @@ -275,6 +284,26 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, } } + // Obtain CodeChallengeMethodsSupported from the provider + var metadata struct { + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + } + if err := provider.Claims(&metadata); err != nil { + logger.Warn("failed to parse provider metadata") + } + // if PKCEChallenge method has not been setted in the config, auto-detect the best fit + if c.PKCEChallenge == "" { + if contains(metadata.CodeChallengeMethodsSupported, codeChallengeMethodS256) { + c.PKCEChallenge = codeChallengeMethodS256 + } else if contains(metadata.CodeChallengeMethodsSupported, codeChallengeMethodPlain) { + c.PKCEChallenge = codeChallengeMethodPlain + } + } else { + // if PKCEChallenge method has been setted in the config, check if it is supported + if !contains(metadata.CodeChallengeMethodsSupported, c.PKCEChallenge) { + logger.Warn("provided PKCEChallenge method not supported by the connector") + } + } pkceVerifier := "" clientID := c.ClientID From ff2135a52bd713c94d4577acd387782475117a67 Mon Sep 17 00:00:00 2001 From: Giovanni Vella Date: Sat, 19 Oct 2024 14:57:00 +0000 Subject: [PATCH 3/6] Handle pkce concurrent connections Signed-off-by: johnvan7 Signed-off-by: Giovanni Vella --- connector/oidc/oidc.go | 138 ++++++++++++++++++++++++++--------------- go.mod | 1 + go.sum | 2 + 3 files changed, 90 insertions(+), 51 deletions(-) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index daea6b2e08..1b6077b7f6 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -14,6 +14,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/hashicorp/golang-lru/v2" "golang.org/x/oauth2" "github.com/dexidp/dex/connector" @@ -22,8 +23,9 @@ import ( ) const ( - codeChallengeMethodPlain = "plain" - codeChallengeMethodS256 = "S256" + defaultPkceMaxConcurrentConnections = 256 + codeChallengeMethodPlain = "plain" + codeChallengeMethodS256 = "S256" ) func contains(arr []string, item string) bool { @@ -93,8 +95,13 @@ type Config struct { // PromptType will be used for the prompt parameter (when offline_access, by default prompt=consent) PromptType *string `json:"promptType"` + // PKCEChallenge specifies which PKCE algorithm will be used + // If not setted it will be auto-detected the best-fit for the connector. PKCEChallenge string `json:"pkceChallenge"` + // PKCEMaxConcurrentConnections specifies the maximum number of concurrent connections for the PKCE code verify. + PKCEMaxConcurrentConnections int `json:"pkceMaxConcurrentConnections"` + // OverrideClaimMapping will be used to override the options defined in claimMappings. // i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey. // This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`. @@ -304,7 +311,19 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, logger.Warn("provided PKCEChallenge method not supported by the connector") } } - pkceVerifier := "" + + // if PKCE will be used, create a state cache for verifier + var pkceVerifierCache *lru.Cache[string, string] + if c.PKCEChallenge != "" { + pkceCacheSize := c.PKCEMaxConcurrentConnections + if pkceCacheSize == 0 { + pkceCacheSize = defaultPkceMaxConcurrentConnections + } + pkceVerifierCache, err = lru.New[string, string](pkceCacheSize) + if err != nil { + logger.Warn("Unable to create PKCE Verifier cache") + } + } clientID := c.ClientID return &oidcConnector{ @@ -321,25 +340,26 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, ctx, // Pass our ctx with customized http.Client &oidc.Config{ClientID: clientID}, ), - logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), - cancel: cancel, - httpClient: httpClient, - insecureSkipEmailVerified: c.InsecureSkipEmailVerified, - insecureEnableGroups: c.InsecureEnableGroups, - allowedGroups: c.AllowedGroups, - acrValues: c.AcrValues, - getUserInfo: c.GetUserInfo, - promptType: promptType, - userIDKey: c.UserIDKey, - userNameKey: c.UserNameKey, - overrideClaimMapping: c.OverrideClaimMapping, - preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, - emailKey: c.ClaimMapping.EmailKey, - groupsKey: c.ClaimMapping.GroupsKey, - newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, - groupsFilter: groupsFilter, - pkceChallenge: c.PKCEChallenge, - pkceVerifier: pkceVerifier, + logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), + cancel: cancel, + httpClient: httpClient, + insecureSkipEmailVerified: c.InsecureSkipEmailVerified, + insecureEnableGroups: c.InsecureEnableGroups, + allowedGroups: c.AllowedGroups, + acrValues: c.AcrValues, + getUserInfo: c.GetUserInfo, + promptType: promptType, + userIDKey: c.UserIDKey, + userNameKey: c.UserNameKey, + overrideClaimMapping: c.OverrideClaimMapping, + preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, + emailKey: c.ClaimMapping.EmailKey, + groupsKey: c.ClaimMapping.GroupsKey, + newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, + groupsFilter: groupsFilter, + pkceChallenge: c.PKCEChallenge, + pkceMaxConcurrentConnections: c.PKCEMaxConcurrentConnections, + pkceVerifierCache: pkceVerifierCache, }, nil } @@ -349,29 +369,30 @@ var ( ) type oidcConnector struct { - provider *oidc.Provider - redirectURI string - oauth2Config *oauth2.Config - verifier *oidc.IDTokenVerifier - cancel context.CancelFunc - logger *slog.Logger - httpClient *http.Client - insecureSkipEmailVerified bool - insecureEnableGroups bool - allowedGroups []string - acrValues []string - getUserInfo bool - promptType string - userIDKey string - userNameKey string - overrideClaimMapping bool - preferredUsernameKey string - emailKey string - groupsKey string - newGroupFromClaims []NewGroupFromClaims - groupsFilter *regexp.Regexp - pkceChallenge string - pkceVerifier string + provider *oidc.Provider + redirectURI string + oauth2Config *oauth2.Config + verifier *oidc.IDTokenVerifier + cancel context.CancelFunc + logger *slog.Logger + httpClient *http.Client + insecureSkipEmailVerified bool + insecureEnableGroups bool + allowedGroups []string + acrValues []string + getUserInfo bool + promptType string + userIDKey string + userNameKey string + overrideClaimMapping bool + preferredUsernameKey string + emailKey string + groupsKey string + newGroupFromClaims []NewGroupFromClaims + groupsFilter *regexp.Regexp + pkceChallenge string + pkceMaxConcurrentConnections int + pkceVerifierCache *lru.Cache[string, string] } func (c *oidcConnector) Close() error { @@ -398,11 +419,13 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) if c.pkceChallenge != "" { switch c.pkceChallenge { case codeChallengeMethodPlain: - c.pkceVerifier = oauth2.GenerateVerifier() - opts = append(opts, oauth2.VerifierOption(c.pkceVerifier)) + pkceVerifier := oauth2.GenerateVerifier() + c.pkceVerifierCache.Add(state, pkceVerifier) + opts = append(opts, oauth2.VerifierOption(pkceVerifier)) case codeChallengeMethodS256: - c.pkceVerifier = oauth2.GenerateVerifier() - opts = append(opts, oauth2.S256ChallengeOption(c.pkceVerifier)) + pkceVerifier := oauth2.GenerateVerifier() + c.pkceVerifierCache.Add(state, pkceVerifier) + opts = append(opts, oauth2.S256ChallengeOption(pkceVerifier)) default: c.logger.Warn("unknown PKCEChallenge method") } @@ -440,8 +463,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) var opts []oauth2.AuthCodeOption - if c.pkceVerifier != "" { - opts = append(opts, oauth2.VerifierOption(c.pkceVerifier)) + if c.pkceChallenge != "" { + state := q.Get("state") + if state == "" { + return identity, fmt.Errorf("oidc: missing state in callback") + } + pkceVerifier, found := c.pkceVerifierCache.Get(state) + if !found { + return identity, fmt.Errorf("oidc: received state not in callback cache") + } + + c.pkceVerifierCache.Remove(state) + if pkceVerifier == "" { + return identity, fmt.Errorf("oidc: invalid state in pkce verifier cache") + } + opts = append(opts, oauth2.VerifierOption(pkceVerifier)) } token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...) diff --git a/go.mod b/go.mod index e2100a396b..96671dcd8b 100644 --- a/go.mod +++ b/go.mod @@ -69,6 +69,7 @@ require ( github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.3 // indirect github.com/googleapis/gax-go/v2 v2.13.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index cbebe635ec..5cbdb3cf41 100644 --- a/go.sum +++ b/go.sum @@ -135,6 +135,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc= github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= From 39299d71959f1d619c797971c83dfcf070c822be Mon Sep 17 00:00:00 2001 From: Giovanni Vella Date: Fri, 25 Oct 2024 12:57:27 +0000 Subject: [PATCH 4/6] Add connectorData to handle state from LoginUrl to HandleCallback Signed-off-by: johnvan7 Signed-off-by: Giovanni Vella --- connector/bitbucketcloud/bitbucketcloud.go | 8 +- .../bitbucketcloud/bitbucketcloud_test.go | 2 +- connector/connector.go | 4 +- connector/gitea/gitea.go | 8 +- connector/gitea/gitea_test.go | 4 +- connector/github/github.go | 8 +- connector/github/github_test.go | 6 +- connector/gitlab/gitlab.go | 8 +- connector/gitlab/gitlab_test.go | 12 +- connector/google/google.go | 8 +- connector/google/google_test.go | 2 +- connector/linkedin/linkedin.go | 8 +- connector/microsoft/microsoft.go | 8 +- connector/microsoft/microsoft_test.go | 8 +- connector/mock/connectortest.go | 8 +- connector/oidc/oidc.go | 176 +++++++++--------- connector/oidc/oidc_test.go | 6 +- connector/openshift/openshift.go | 7 +- connector/openshift/openshift_test.go | 2 +- go.mod | 1 - go.sum | 2 - server/handlers.go | 16 +- 22 files changed, 160 insertions(+), 152 deletions(-) diff --git a/connector/bitbucketcloud/bitbucketcloud.go b/connector/bitbucketcloud/bitbucketcloud.go index 5f802e3414..d7fb64caa6 100644 --- a/connector/bitbucketcloud/bitbucketcloud.go +++ b/connector/bitbucketcloud/bitbucketcloud.go @@ -111,12 +111,12 @@ func (b *bitbucketConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if b.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI) } - return b.oauth2Config(scopes).AuthCodeURL(state), nil + return b.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -131,7 +131,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (b *bitbucketConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (b *bitbucketConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/bitbucketcloud/bitbucketcloud_test.go b/connector/bitbucketcloud/bitbucketcloud_test.go index 9545ff09c5..67a74dab38 100644 --- a/connector/bitbucketcloud/bitbucketcloud_test.go +++ b/connector/bitbucketcloud/bitbucketcloud_test.go @@ -102,7 +102,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) bitbucketConnector := bitbucketConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, req) + identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") diff --git a/connector/connector.go b/connector/connector.go index d812390f0c..b1e069c3fc 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -63,10 +63,10 @@ type CallbackConnector interface { // requested if one has already been issues. There's no good general answer // for these kind of restrictions, and may require this package to become more // aware of the global set of user/connector interactions. - LoginURL(s Scopes, callbackURL, state string) (string, error) + LoginURL(s Scopes, callbackURL, state string) (string, []byte, error) // Handle the callback to the server and return an identity. - HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) + HandleCallback(s Scopes, connData []byte, r *http.Request) (identity Identity, err error) } // SAMLConnector represents SAML connectors which implement the HTTP POST binding. diff --git a/connector/gitea/gitea.go b/connector/gitea/gitea.go index 62523185d5..059c861705 100644 --- a/connector/gitea/gitea.go +++ b/connector/gitea/gitea.go @@ -102,11 +102,11 @@ func (c *giteaConnector) oauth2Config(_ connector.Scopes) *oauth2.Config { } } -func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -121,7 +121,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *giteaConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *giteaConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/gitea/gitea_test.go b/connector/gitea/gitea_test.go index a71d79956e..4fe7768901 100644 --- a/connector/gitea/gitea_test.go +++ b/connector/gitea/gitea_test.go @@ -30,14 +30,14 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := giteaConnector{baseURL: s.URL, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{}, req) + identity, err := c.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") expectEquals(t, identity.UserID, "12345678") c = giteaConnector{baseURL: s.URL, httpClient: newClient()} - identity, err = c.HandleCallback(connector.Scopes{}, req) + identity, err = c.HandleCallback(connector.Scopes{}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") diff --git a/connector/github/github.go b/connector/github/github.go index 18a56628af..eb19f77805 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -194,12 +194,12 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -214,7 +214,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *githubConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/github/github_test.go b/connector/github/github_test.go index 088cbb238c..fb57161ce2 100644 --- a/connector/github/github_test.go +++ b/connector/github/github_test.go @@ -153,7 +153,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -161,7 +161,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectEquals(t, 0, len(identity.Groups)) c = githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), loadAllGroups: true} - identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some-login") @@ -194,7 +194,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) { expectNil(t, err) c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), useLoginAsID: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "some-login") diff --git a/connector/gitlab/gitlab.go b/connector/gitlab/gitlab.go index fdb2c48204..30fb87fb35 100644 --- a/connector/gitlab/gitlab.go +++ b/connector/gitlab/gitlab.go @@ -100,11 +100,11 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { } } -func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL) } - return c.oauth2Config(scopes).AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -119,7 +119,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *gitlabConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/gitlab/gitlab_test.go b/connector/gitlab/gitlab_test.go index d828b8bd16..1a9b774ea5 100644 --- a/connector/gitlab/gitlab_test.go +++ b/connector/gitlab/gitlab_test.go @@ -84,7 +84,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient()} - identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") @@ -92,7 +92,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) { expectEquals(t, 0, len(identity.Groups)) c = gitlabConnector{baseURL: s.URL, httpClient: newClient()} - identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") @@ -120,7 +120,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), useLoginAsID: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "joebloggs") @@ -147,7 +147,7 @@ func TestLoginWithTeamWhitelisted(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-1"}} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "12345678") @@ -174,7 +174,7 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-2"}} - _, err = c.HandleCallback(connector.Scopes{Groups: true}, req) + _, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNotNil(t, err, "HandleCallback error") expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups") @@ -208,7 +208,7 @@ func TestRefresh(t *testing.T) { }) expectNil(t, err) - identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req) + identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "some@email.com") expectEquals(t, identity.UserID, "12345678") diff --git a/connector/google/google.go b/connector/google/google.go index e17ec5bd7f..4a8599c0b1 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -168,9 +168,9 @@ func (c *googleConnector) Close() error { return nil } -func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var opts []oauth2.AuthCodeOption @@ -186,7 +186,7 @@ func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType)) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + return c.oauth2Config.AuthCodeURL(state, opts...), nil, nil } type oauth2Error struct { @@ -201,7 +201,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *googleConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *googleConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/google/google_test.go b/connector/google/google_test.go index bafcadc8ff..68188c6559 100644 --- a/connector/google/google_test.go +++ b/connector/google/google_test.go @@ -440,7 +440,7 @@ func TestPromptTypeConfig(t *testing.T) { assert.Nil(t, err) assert.Equal(t, test.expectedPromptTypeValue, conn.promptType) - loginURL, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state") + loginURL, _, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state") assert.Nil(t, err) urlp, err := url.Parse(loginURL) diff --git a/connector/linkedin/linkedin.go b/connector/linkedin/linkedin.go index f17d17cca1..0c24ff4756 100644 --- a/connector/linkedin/linkedin.go +++ b/connector/linkedin/linkedin.go @@ -62,17 +62,17 @@ var ( ) // LoginURL returns an access token request URL -func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.oauth2Config.RedirectURL != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.oauth2Config.RedirectURL) } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config.AuthCodeURL(state), nil, nil } // HandleCallback handles HTTP redirect from LinkedIn -func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *linkedInConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 2fcf6a7515..1db16942b7 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -175,9 +175,9 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi } } -func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var options []oauth2.AuthCodeOption @@ -188,10 +188,10 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat options = append(options, oauth2.SetAuthURLParam("domain_hint", c.domainHint)) } - return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil + return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil, nil } -func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *microsoftConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/connector/microsoft/microsoft_test.go b/connector/microsoft/microsoft_test.go index 67be660fce..1fa2f3dcdf 100644 --- a/connector/microsoft/microsoft_test.go +++ b/connector/microsoft/microsoft_test.go @@ -39,7 +39,7 @@ func TestLoginURL(t *testing.T) { tenant: tenant, } - loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState) + loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState) parsedLoginURL, _ := url.Parse(loginURL) queryParams := parsedLoginURL.Query() @@ -70,7 +70,7 @@ func TestLoginURLWithOptions(t *testing.T) { domainHint: domainHint, } - loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") + loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") parsedLoginURL, _ := url.Parse(loginURL) queryParams := parsedLoginURL.Query() @@ -91,7 +91,7 @@ func TestUserIdentityFromGraphAPI(t *testing.T) { req, _ := http.NewRequest("GET", s.URL, nil) c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant} - identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req) expectNil(t, err) expectEquals(t, identity.Username, "Jane Doe") expectEquals(t, identity.UserID, "S56767889") @@ -114,7 +114,7 @@ func TestUserGroupsFromGraphAPI(t *testing.T) { req, _ := http.NewRequest("GET", s.URL, nil) c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Groups, []string{"a", "b"}) } diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 7e5979a992..be44bfd12a 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -43,21 +43,21 @@ type Callback struct { } // LoginURL returns the URL to redirect the user to login with. -func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { u, err := url.Parse(callbackURL) if err != nil { - return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) + return "", nil, fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) } v := u.Query() v.Set("state", state) u.RawQuery = v.Encode() - return u.String(), nil + return u.String(), nil, nil } var connectorData = []byte("foobar") // HandleCallback parses the request and returns the user's identity -func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { +func (m *Callback) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (connector.Identity, error) { return m.Identity, nil } diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 1b6077b7f6..bdb601c5f6 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -14,7 +14,6 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/hashicorp/golang-lru/v2" "golang.org/x/oauth2" "github.com/dexidp/dex/connector" @@ -23,9 +22,8 @@ import ( ) const ( - defaultPkceMaxConcurrentConnections = 256 - codeChallengeMethodPlain = "plain" - codeChallengeMethodS256 = "S256" + codeChallengeMethodPlain = "plain" + codeChallengeMethodS256 = "S256" ) func contains(arr []string, item string) bool { @@ -99,9 +97,6 @@ type Config struct { // If not setted it will be auto-detected the best-fit for the connector. PKCEChallenge string `json:"pkceChallenge"` - // PKCEMaxConcurrentConnections specifies the maximum number of concurrent connections for the PKCE code verify. - PKCEMaxConcurrentConnections int `json:"pkceMaxConcurrentConnections"` - // OverrideClaimMapping will be used to override the options defined in claimMappings. // i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey. // This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`. @@ -235,6 +230,25 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool { return false } +// PKCEChallengeData is used to store info for PKCE Challenge method and verifier +// in the connectorData +type PKCEChallengeData struct { + CodeChallenge string `json:"codeChallenge"` + CodeChallengeMethod string `json:"codeChallengeMethod"` +} + +// Returns an AuthCodeOption according to the provided codeChallengeMethod +func getAuthCodeOptionForCodeChallenge(codeVerifier, codeChallengeMethod string) (oauth2.AuthCodeOption, error) { + switch codeChallengeMethod { + case codeChallengeMethodPlain: + return oauth2.VerifierOption(codeVerifier), nil + case codeChallengeMethodS256: + return oauth2.S256ChallengeOption(codeVerifier), nil + default: + return nil, fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod) + } +} + // Open returns a connector which can be used to login users through an upstream // OpenID Connect provider. func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) { @@ -312,19 +326,6 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, } } - // if PKCE will be used, create a state cache for verifier - var pkceVerifierCache *lru.Cache[string, string] - if c.PKCEChallenge != "" { - pkceCacheSize := c.PKCEMaxConcurrentConnections - if pkceCacheSize == 0 { - pkceCacheSize = defaultPkceMaxConcurrentConnections - } - pkceVerifierCache, err = lru.New[string, string](pkceCacheSize) - if err != nil { - logger.Warn("Unable to create PKCE Verifier cache") - } - } - clientID := c.ClientID return &oidcConnector{ provider: provider, @@ -340,26 +341,24 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, ctx, // Pass our ctx with customized http.Client &oidc.Config{ClientID: clientID}, ), - logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), - cancel: cancel, - httpClient: httpClient, - insecureSkipEmailVerified: c.InsecureSkipEmailVerified, - insecureEnableGroups: c.InsecureEnableGroups, - allowedGroups: c.AllowedGroups, - acrValues: c.AcrValues, - getUserInfo: c.GetUserInfo, - promptType: promptType, - userIDKey: c.UserIDKey, - userNameKey: c.UserNameKey, - overrideClaimMapping: c.OverrideClaimMapping, - preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, - emailKey: c.ClaimMapping.EmailKey, - groupsKey: c.ClaimMapping.GroupsKey, - newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, - groupsFilter: groupsFilter, - pkceChallenge: c.PKCEChallenge, - pkceMaxConcurrentConnections: c.PKCEMaxConcurrentConnections, - pkceVerifierCache: pkceVerifierCache, + logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)), + cancel: cancel, + httpClient: httpClient, + insecureSkipEmailVerified: c.InsecureSkipEmailVerified, + insecureEnableGroups: c.InsecureEnableGroups, + allowedGroups: c.AllowedGroups, + acrValues: c.AcrValues, + getUserInfo: c.GetUserInfo, + promptType: promptType, + userIDKey: c.UserIDKey, + userNameKey: c.UserNameKey, + overrideClaimMapping: c.OverrideClaimMapping, + preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey, + emailKey: c.ClaimMapping.EmailKey, + groupsKey: c.ClaimMapping.GroupsKey, + newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, + groupsFilter: groupsFilter, + pkceChallenge: c.PKCEChallenge, }, nil } @@ -369,30 +368,28 @@ var ( ) type oidcConnector struct { - provider *oidc.Provider - redirectURI string - oauth2Config *oauth2.Config - verifier *oidc.IDTokenVerifier - cancel context.CancelFunc - logger *slog.Logger - httpClient *http.Client - insecureSkipEmailVerified bool - insecureEnableGroups bool - allowedGroups []string - acrValues []string - getUserInfo bool - promptType string - userIDKey string - userNameKey string - overrideClaimMapping bool - preferredUsernameKey string - emailKey string - groupsKey string - newGroupFromClaims []NewGroupFromClaims - groupsFilter *regexp.Regexp - pkceChallenge string - pkceMaxConcurrentConnections int - pkceVerifierCache *lru.Cache[string, string] + provider *oidc.Provider + redirectURI string + oauth2Config *oauth2.Config + verifier *oidc.IDTokenVerifier + cancel context.CancelFunc + logger *slog.Logger + httpClient *http.Client + insecureSkipEmailVerified bool + insecureEnableGroups bool + allowedGroups []string + acrValues []string + getUserInfo bool + promptType string + userIDKey string + userNameKey string + overrideClaimMapping bool + preferredUsernameKey string + emailKey string + groupsKey string + newGroupFromClaims []NewGroupFromClaims + groupsFilter *regexp.Regexp + pkceChallenge string } func (c *oidcConnector) Close() error { @@ -400,12 +397,13 @@ func (c *oidcConnector) Close() error { return nil } -func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { +func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } var opts []oauth2.AuthCodeOption + var connectorData []byte if len(c.acrValues) > 0 { acrValues := strings.Join(c.acrValues, " ") @@ -417,21 +415,23 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) } if c.pkceChallenge != "" { - switch c.pkceChallenge { - case codeChallengeMethodPlain: - pkceVerifier := oauth2.GenerateVerifier() - c.pkceVerifierCache.Add(state, pkceVerifier) - opts = append(opts, oauth2.VerifierOption(pkceVerifier)) - case codeChallengeMethodS256: - pkceVerifier := oauth2.GenerateVerifier() - c.pkceVerifierCache.Add(state, pkceVerifier) - opts = append(opts, oauth2.S256ChallengeOption(pkceVerifier)) - default: - c.logger.Warn("unknown PKCEChallenge method") + codeVerifier := oauth2.GenerateVerifier() + authCodeOption, err := getAuthCodeOptionForCodeChallenge(codeVerifier, c.pkceChallenge) + if err != nil { + return "", nil, fmt.Errorf("oidc: failed to get PKCE AuthCodeOption for CodeChallenge: %v", err) } + data := PKCEChallengeData{ + CodeChallenge: codeVerifier, + CodeChallengeMethod: c.pkceChallenge, + } + connectorData, err = json.Marshal(data) + if err != nil { + return "", nil, fmt.Errorf("oidc: failed to create PKCEChallenge data: %v", err) + } + opts = append(opts, authCodeOption) } - return c.oauth2Config.AuthCodeURL(state, opts...), nil + return c.oauth2Config.AuthCodeURL(state, opts...), connectorData, nil } type oauth2Error struct { @@ -454,7 +454,7 @@ const ( exchangeCaller ) -func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *oidcConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} @@ -464,20 +464,14 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide var opts []oauth2.AuthCodeOption if c.pkceChallenge != "" { - state := q.Get("state") - if state == "" { - return identity, fmt.Errorf("oidc: missing state in callback") - } - pkceVerifier, found := c.pkceVerifierCache.Get(state) - if !found { - return identity, fmt.Errorf("oidc: received state not in callback cache") + var data PKCEChallengeData + if err := json.Unmarshal(connData, &data); err != nil { + return identity, fmt.Errorf("oidc: failed to parse PKCEChallenge data: %v", err) } - - c.pkceVerifierCache.Remove(state) - if pkceVerifier == "" { - return identity, fmt.Errorf("oidc: invalid state in pkce verifier cache") + if data.CodeChallenge == "" { + return identity, fmt.Errorf("oidc: invalid PKCE CodeChallenge") } - opts = append(opts, oauth2.VerifierOption(pkceVerifier)) + opts = append(opts, oauth2.VerifierOption(data.CodeChallenge)) } token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...) diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 3ae517f8c0..8a99de74c2 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -466,7 +466,11 @@ func TestHandleCallback(t *testing.T) { t.Fatal("failed to create request", err) } - identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + connectorDataStrTemplate := `{"codeChallenge":"abcdefgh123456qwertuiop89101112uvpwizABC234","codeChallengeMethod":"%s"}` + connectorDataStr := fmt.Sprintf(connectorDataStrTemplate, config.PKCEChallenge) + connectorData := []byte(connectorDataStr) + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, connectorData, req) if err != nil { t.Fatal("handle callback failed", err) } diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index 4519a85b6d..3d4408c585 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -138,12 +138,12 @@ func (c *openshiftConnector) Close() error { } // LoginURL returns the URL to redirect the user to login with. -func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { +func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config.AuthCodeURL(state), nil, nil } type oauth2Error struct { @@ -160,6 +160,7 @@ func (e *oauth2Error) Error() string { // HandleCallback parses the request and returns the user's identity func (c *openshiftConnector) HandleCallback(s connector.Scopes, + connData []byte, r *http.Request, ) (identity connector.Identity, err error) { q := r.URL.Query() diff --git a/connector/openshift/openshift_test.go b/connector/openshift/openshift_test.go index 89ec0e25a9..d6d8603bbc 100644 --- a/connector/openshift/openshift_test.go +++ b/connector/openshift/openshift_test.go @@ -176,7 +176,7 @@ func TestCallbackIdentity(t *testing.T) { TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), }, }} - identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := oc.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.UserID, "12345") diff --git a/go.mod b/go.mod index 96671dcd8b..e2100a396b 100644 --- a/go.mod +++ b/go.mod @@ -69,7 +69,6 @@ require ( github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.3 // indirect github.com/googleapis/gax-go/v2 v2.13.0 // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/hcl/v2 v2.13.0 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 5cbdb3cf41..cbebe635ec 100644 --- a/go.sum +++ b/go.sum @@ -135,8 +135,6 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= -github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc= github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= diff --git a/server/handlers.go b/server/handlers.go index 6521bf6a93..7a4152c8ad 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -263,12 +263,24 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) + callbackURL, connData, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) if err != nil { s.logger.ErrorContext(r.Context(), "connector returned error when creating callback", "connector_id", connID, "err", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") return } + if len(connData) > 0 { + updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ConnectorData = connData + return a, nil + } + err := s.storage.UpdateAuthRequest(authReq.ID, updater) + if err != nil { + s.logger.ErrorContext(r.Context(), "Failed to set connector data on auth request", "connector_id", connID, "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: loginURL := url.URL{ @@ -463,7 +475,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) s.renderError(r, w, http.StatusBadRequest, "Invalid request") return } - identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r) + identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), authReq.ConnectorData, r) case connector.SAMLConnector: if r.Method != http.MethodPost { s.logger.ErrorContext(r.Context(), "OAuth2 request mapped to SAML connector") From 432c432a4297453434f2403f7c8ca255cbfe6137 Mon Sep 17 00:00:00 2001 From: Giovanni Vella Date: Mon, 28 Oct 2024 09:13:57 +0000 Subject: [PATCH 5/6] trigger GitHub actions Signed-off-by: Giovanni Vella From 1ffd6b35b85d13827f3eca2b6ca50ec67def471e Mon Sep 17 00:00:00 2001 From: Giovanni Vella Date: Fri, 1 Nov 2024 11:01:42 +0000 Subject: [PATCH 6/6] interface fix Signed-off-by: Giovanni Vella --- connector/gitlab/gitlab_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/gitlab/gitlab_test.go b/connector/gitlab/gitlab_test.go index bae97f4a33..045f3907a1 100644 --- a/connector/gitlab/gitlab_test.go +++ b/connector/gitlab/gitlab_test.go @@ -272,7 +272,7 @@ func TestGroupsWithPermission(t *testing.T) { expectNil(t, err) c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), getGroupsPermission: true} - identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req) + identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req) expectNil(t, err) expectEquals(t, identity.Groups, []string{