Skip to content

Commit

Permalink
Handle pkce concurrent connections
Browse files Browse the repository at this point in the history
Signed-off-by: johnvan7 <[email protected]>
  • Loading branch information
johnvan7 committed Oct 19, 2024
1 parent 3abdf95 commit 6eee8c1
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 51 deletions.
138 changes: 87 additions & 51 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/hashicorp/golang-lru/v2"
"golang.org/x/oauth2"

"github.com/dexidp/dex/connector"
Expand All @@ -22,8 +23,9 @@ import (
)

const (
codeChallengeMethodPlain = "plain"
codeChallengeMethodS256 = "S256"
defaultPkceMaxConcurrentConnections = 256
codeChallengeMethodPlain = "plain"
codeChallengeMethodS256 = "S256"
)

func contains(arr []string, item string) bool {
Expand Down Expand Up @@ -93,8 +95,13 @@ type Config struct {
// PromptType will be used for the prompt parameter (when offline_access, by default prompt=consent)
PromptType *string `json:"promptType"`

// PKCEChallenge specifies which PKCE algorithm will be used
// If not setted it will be auto-detected the best-fit for the connector.
PKCEChallenge string `json:"pkceChallenge"`

// PKCEMaxConcurrentConnections specifies the maximum number of concurrent connections for the PKCE code verify.
PKCEMaxConcurrentConnections int `json:"pkceMaxConcurrentConnections"`

// OverrideClaimMapping will be used to override the options defined in claimMappings.
// i.e. if there are 'email' and `preferred_email` claims available, by default Dex will always use the `email` claim independent of the ClaimMapping.EmailKey.
// This setting allows you to override the default behavior of Dex and enforce the mappings defined in `claimMapping`.
Expand Down Expand Up @@ -304,7 +311,19 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
logger.Warn("provided PKCEChallenge method not supported by the connector")
}
}
pkceVerifier := ""

// if PKCE will be used, create a state cache for verifier
var pkceVerifierCache *lru.Cache[string, string]
if c.PKCEChallenge != "" {
pkceCacheSize := c.PKCEMaxConcurrentConnections
if pkceCacheSize == 0 {
pkceCacheSize = defaultPkceMaxConcurrentConnections
}
pkceVerifierCache, err = lru.New[string, string](pkceCacheSize)
if err != nil {
logger.Warn("Unable to create PKCE Verifier cache")
}
}

clientID := c.ClientID
return &oidcConnector{
Expand All @@ -321,25 +340,26 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
ctx, // Pass our ctx with customized http.Client
&oidc.Config{ClientID: clientID},
),
logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)),
cancel: cancel,
httpClient: httpClient,
insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
insecureEnableGroups: c.InsecureEnableGroups,
allowedGroups: c.AllowedGroups,
acrValues: c.AcrValues,
getUserInfo: c.GetUserInfo,
promptType: promptType,
userIDKey: c.UserIDKey,
userNameKey: c.UserNameKey,
overrideClaimMapping: c.OverrideClaimMapping,
preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey,
emailKey: c.ClaimMapping.EmailKey,
groupsKey: c.ClaimMapping.GroupsKey,
newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims,
groupsFilter: groupsFilter,
pkceChallenge: c.PKCEChallenge,
pkceVerifier: pkceVerifier,
logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)),
cancel: cancel,
httpClient: httpClient,
insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
insecureEnableGroups: c.InsecureEnableGroups,
allowedGroups: c.AllowedGroups,
acrValues: c.AcrValues,
getUserInfo: c.GetUserInfo,
promptType: promptType,
userIDKey: c.UserIDKey,
userNameKey: c.UserNameKey,
overrideClaimMapping: c.OverrideClaimMapping,
preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey,
emailKey: c.ClaimMapping.EmailKey,
groupsKey: c.ClaimMapping.GroupsKey,
newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims,
groupsFilter: groupsFilter,
pkceChallenge: c.PKCEChallenge,
pkceMaxConcurrentConnections: c.PKCEMaxConcurrentConnections,
pkceVerifierCache: pkceVerifierCache,
}, nil
}

Expand All @@ -349,29 +369,30 @@ var (
)

type oidcConnector struct {
provider *oidc.Provider
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
cancel context.CancelFunc
logger *slog.Logger
httpClient *http.Client
insecureSkipEmailVerified bool
insecureEnableGroups bool
allowedGroups []string
acrValues []string
getUserInfo bool
promptType string
userIDKey string
userNameKey string
overrideClaimMapping bool
preferredUsernameKey string
emailKey string
groupsKey string
newGroupFromClaims []NewGroupFromClaims
groupsFilter *regexp.Regexp
pkceChallenge string
pkceVerifier string
provider *oidc.Provider
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
cancel context.CancelFunc
logger *slog.Logger
httpClient *http.Client
insecureSkipEmailVerified bool
insecureEnableGroups bool
allowedGroups []string
acrValues []string
getUserInfo bool
promptType string
userIDKey string
userNameKey string
overrideClaimMapping bool
preferredUsernameKey string
emailKey string
groupsKey string
newGroupFromClaims []NewGroupFromClaims
groupsFilter *regexp.Regexp
pkceChallenge string
pkceMaxConcurrentConnections int
pkceVerifierCache *lru.Cache[string, string]
}

func (c *oidcConnector) Close() error {
Expand All @@ -398,11 +419,13 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
if c.pkceChallenge != "" {
switch c.pkceChallenge {
case codeChallengeMethodPlain:
c.pkceVerifier = oauth2.GenerateVerifier()
opts = append(opts, oauth2.VerifierOption(c.pkceVerifier))
pkceVerifier := oauth2.GenerateVerifier()
c.pkceVerifierCache.Add(state, pkceVerifier)
opts = append(opts, oauth2.VerifierOption(pkceVerifier))
case codeChallengeMethodS256:
c.pkceVerifier = oauth2.GenerateVerifier()
opts = append(opts, oauth2.S256ChallengeOption(c.pkceVerifier))
pkceVerifier := oauth2.GenerateVerifier()
c.pkceVerifierCache.Add(state, pkceVerifier)
opts = append(opts, oauth2.S256ChallengeOption(pkceVerifier))
default:
c.logger.Warn("unknown PKCEChallenge method")
}
Expand Down Expand Up @@ -440,8 +463,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient)

var opts []oauth2.AuthCodeOption
if c.pkceVerifier != "" {
opts = append(opts, oauth2.VerifierOption(c.pkceVerifier))
if c.pkceChallenge != "" {
state := q.Get("state")
if state == "" {
return identity, fmt.Errorf("oidc: missing state in callback")
}
pkceVerifier, found := c.pkceVerifierCache.Get(state)
if !found {
return identity, fmt.Errorf("oidc: received state not in callback cache")
}

c.pkceVerifierCache.Remove(state)
if pkceVerifier == "" {
return identity, fmt.Errorf("oidc: invalid state in pkce verifier cache")
}
opts = append(opts, oauth2.VerifierOption(pkceVerifier))
}

token, err := c.oauth2Config.Exchange(ctx, q.Get("code"), opts...)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ require (
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.3 // indirect
github.com/googleapis/gax-go/v2 v2.13.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/hashicorp/hcl/v2 v2.13.0 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc=
github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
Expand Down

0 comments on commit 6eee8c1

Please sign in to comment.