diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 37167bb090..9cc98db035 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -342,9 +342,16 @@ type Logger struct { Format string `json:"format"` } +type MultipleRefreshToken struct { + Allow bool `json:"allow"` + MaximumCount int `json:"maximumCount"` + ReplacementPolicy string `json:"replacementPolicy"` +} + type RefreshToken struct { - DisableRotation bool `json:"disableRotation"` - ReuseInterval string `json:"reuseInterval"` - AbsoluteLifetime string `json:"absoluteLifetime"` - ValidIfNotUsedFor string `json:"validIfNotUsedFor"` + DisableRotation bool `json:"disableRotation"` + MultipleTokens MultipleRefreshToken `json:"multipleTokens"` + ReuseInterval string `json:"reuseInterval"` + AbsoluteLifetime string `json:"absoluteLifetime"` + ValidIfNotUsedFor string `json:"validIfNotUsedFor"` } diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 92068934d8..7a548ac59f 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -310,7 +310,9 @@ func runServe(options serveOptions) error { c.Expiry.RefreshTokens.ValidIfNotUsedFor, c.Expiry.RefreshTokens.AbsoluteLifetime, c.Expiry.RefreshTokens.ReuseInterval, - ) + c.Expiry.RefreshTokens.MultipleTokens.Allow, + c.Expiry.RefreshTokens.MultipleTokens.MaximumCount, + c.Expiry.RefreshTokens.MultipleTokens.ReplacementPolicy) if err != nil { return fmt.Errorf("invalid refresh token expiration policy config: %v", err) } @@ -449,7 +451,7 @@ func runServe(options serveOptions) error { } grpcSrv := grpc.NewServer(grpcOptions...) - api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger, version)) + api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger, version, c.Expiry.RefreshTokens.MultipleTokens.Allow)) grpcMetrics.InitializeMetrics(grpcSrv) if c.GRPC.Reflection { diff --git a/server/api.go b/server/api.go index a68742b3cc..27bac5d16a 100644 --- a/server/api.go +++ b/server/api.go @@ -29,11 +29,12 @@ const ( ) // NewAPI returns a server which implements the gRPC API interface. -func NewAPI(s storage.Storage, logger log.Logger, version string) api.DexServer { +func NewAPI(s storage.Storage, logger log.Logger, version string, multiRefreshTokens bool) api.DexServer { return dexAPI{ - s: s, - logger: logger, - version: version, + s: s, + logger: logger, + version: version, + multiRefreshTokens: multiRefreshTokens, } } @@ -43,6 +44,8 @@ type dexAPI struct { s storage.Storage logger log.Logger version string + + multiRefreshTokens bool } func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*api.CreateClientResp, error) { @@ -283,6 +286,13 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq) } func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) { + if d.multiRefreshTokens { + return d.listMultipleRefreshTokensMode(ctx, req) + } + return d.listRefresh(ctx, req) +} + +func (d dexAPI) listRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) { id := new(internal.IDTokenSubject) if err := internal.Unmarshal(req.UserId, id); err != nil { d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err) @@ -316,7 +326,84 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api. }, nil } +func (d dexAPI) listMultipleRefreshTokensMode(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) { + id := new(internal.IDTokenSubject) + if err := internal.Unmarshal(req.UserId, id); err != nil { + d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err) + return nil, err + } + + var refreshTokenRefs []*api.RefreshTokenRef + + // TODO: As OfflineSession has a reference to lastUpdated RefreshToken, listing add RefreshTokens and filtering + refreshTokens, err := d.s.ListRefreshTokens() + if err != nil { + return nil, err + } + for _, t := range refreshTokens { + if t.Claims.UserID == id.UserId && t.ConnectorID == id.ConnId { + r := api.RefreshTokenRef{ + Id: t.ID, + ClientId: t.ClientID, + CreatedAt: t.CreatedAt.Unix(), + LastUsed: t.LastUsed.Unix(), + } + refreshTokenRefs = append(refreshTokenRefs, &r) + } + } + + return &api.ListRefreshResp{ + RefreshTokens: refreshTokenRefs, + }, nil +} + func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) { + if d.multiRefreshTokens { + return d.revokeMultipleRefreshTokensMode(ctx, req) + } + return d.revokeRefresh(ctx, req) +} + +func (d dexAPI) revokeMultipleRefreshTokensMode(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) { + id := new(internal.IDTokenSubject) + if err := internal.Unmarshal(req.UserId, id); err != nil { + d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err) + return nil, err + } + + // FIXME: listing all tokens can be slow + refreshTokens, err := d.s.ListRefreshTokens() + if err != nil { + return nil, err + } + if len(refreshTokens) == 0 { + return &api.RevokeRefreshResp{NotFound: true}, nil + } + + for _, t := range refreshTokens { + if t.Claims.UserID == id.UserId && t.ConnectorID == id.ConnId && t.ClientID == req.ClientId { + if err := d.s.DeleteRefresh(t.ID); err != nil { + d.logger.Errorf("failed to delete refresh token: %v", err) + return nil, err + } + } + } + + updater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + // Remove entry from Refresh list of the OfflineSession object. + delete(old.Refresh, req.ClientId) + return old, nil + } + + if err := d.s.UpdateOfflineSessions(id.UserId, id.ConnId, updater); err != nil { + d.logger.Errorf("api: failed to update offline session object: %v", err) + return nil, err + } + + return &api.RevokeRefreshResp{}, nil +} + +func (d dexAPI) revokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) { id := new(internal.IDTokenSubject) if err := internal.Unmarshal(req.UserId, id); err != nil { d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err) diff --git a/server/api_test.go b/server/api_test.go index 020613402c..a5c4c7ff81 100644 --- a/server/api_test.go +++ b/server/api_test.go @@ -36,7 +36,7 @@ func newAPI(s storage.Storage, logger log.Logger, t *testing.T) *apiClient { } serv := grpc.NewServer() - api.RegisterDexServer(serv, NewAPI(s, logger, "test")) + api.RegisterDexServer(serv, NewAPI(s, logger, "test", false)) go serv.Serve(l) // Dial will retry automatically if the serv.Serve() goroutine diff --git a/server/handlers.go b/server/handlers.go index 2a4f8c71de..d03c56c00f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -961,13 +961,19 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo return nil, err } } else { - if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { - // Delete old refresh token from storage. - if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound { - s.logger.Errorf("failed to delete refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - deleteToken = true - return nil, err + if !s.refreshTokenPolicy.allowMultiple { + if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { + // Delete old refresh token from storage. + if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return nil, err + } + } + } else { + if err := s.deleteRefreshTokens(refresh.ConnectorID, refresh.Claims.UserID); err != nil { + s.logger.Errorf("error while deleting refresh token: %v", err) } } @@ -1205,18 +1211,24 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli return } } else { - if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { - // Delete old refresh token from storage. - if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil { - if err == storage.ErrNotFound { - s.logger.Warnf("database inconsistent, refresh token missing: %v", oldTokenRef.ID) - } else { - s.logger.Errorf("failed to delete refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - deleteToken = true - return + if !s.refreshTokenPolicy.allowMultiple { + if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { + // Delete old refresh token from storage. + if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil { + if err == storage.ErrNotFound { + s.logger.Warnf("database inconsistent, refresh token missing: %v", oldTokenRef.ID) + } else { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } } } + } else { + if err := s.deleteRefreshTokens(refresh.ConnectorID, refresh.Claims.UserID); err != nil { + s.logger.Errorf("error while deleting refresh token: %v", err) + } } // Update existing OfflineSession obj with new RefreshTokenRef. @@ -1237,6 +1249,48 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli s.writeAccessToken(w, resp) } +func (s *Server) deleteRefreshTokens(connectorID string, userID string) error { + refreshTokens, err := s.storage.ListRefreshTokens() + if err != nil { + return err + } + + var userRefreshTokens []storage.RefreshToken + for index := range refreshTokens { + refreshToken := refreshTokens[index] + if refreshToken.ConnectorID == connectorID && refreshToken.Claims.UserID == userID { + userRefreshTokens = append(userRefreshTokens, refreshToken) + } + } + + if len(userRefreshTokens) <= s.refreshTokenPolicy.maxTokens { + return nil + } + + sort.SliceStable(userRefreshTokens, func(i, j int) bool { + if s.refreshTokenPolicy.tokenReplacementPolicy == FCFS { + return userRefreshTokens[i].CreatedAt.Before(userRefreshTokens[j].CreatedAt) + } else { + return userRefreshTokens[i].LastUsed.Before(userRefreshTokens[j].LastUsed) + } + }) + + tokensToDelete := userRefreshTokens[:len(userRefreshTokens)-s.refreshTokenPolicy.maxTokens] + var deletionError bool + for index := range tokensToDelete { + refreshToken := tokensToDelete[index] + if err := s.storage.DeleteRefresh(refreshToken.ID); err != nil { + deletionError = true + s.logger.Errorf("error while deleting refresh token: %v", err) + } + } + + if deletionError { + return fmt.Errorf("error while deleting refresh token for userID %s of connector %s", userID, connectorID) + } + return nil +} + type accessTokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 8ea7ea9ef1..bc1c7be2ab 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -197,8 +197,13 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError { offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { if old.Refresh[refresh.ClientID].ID != refresh.ID { + // For multiple tokens only latest one is referred in offline session + if s.refreshTokenPolicy.allowMultiple { + return old, nil + } return old, errors.New("refresh token invalid") } + old.Refresh[refresh.ClientID].LastUsed = lastUsed old.ConnectorData = ident.ConnectorData return old, nil diff --git a/server/rotation.go b/server/rotation.go index 98489767e0..ddcee5381a 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -178,6 +178,13 @@ func (k keyRotator) rotate() error { return nil } +type tokenReplacementPolicy string + +const ( + LRU tokenReplacementPolicy = "LRU" + FCFS tokenReplacementPolicy = "FCFS" +) + type RefreshTokenPolicy struct { rotateRefreshTokens bool // enable rotation @@ -185,12 +192,17 @@ type RefreshTokenPolicy struct { validIfNotUsedFor time.Duration // interval from last token update to the end of its life reuseInterval time.Duration // interval within which old refresh token is allowed to be reused + allowMultiple bool + maxTokens int + tokenReplacementPolicy tokenReplacementPolicy + now func() time.Time logger log.Logger } -func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { +func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string, + allowMultiple bool, maxTokens int, tokenReplacementPolicy string) (*RefreshTokenPolicy, error) { r := RefreshTokenPolicy{now: time.Now, logger: logger} var err error @@ -220,6 +232,33 @@ func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, r.rotateRefreshTokens = !rotation logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens) + + r.allowMultiple = allowMultiple + logger.Infof("config refresh tokens allow multiple: %v", allowMultiple) + + if allowMultiple { + if maxTokens < 1 { + r.maxTokens = 50 + } else { + r.maxTokens = maxTokens + } + logger.Infof("config refresh tokens max multiple tokens: %v", r.maxTokens) + + if tokenReplacementPolicy != "" { + switch tokenReplacementPolicy { + case string(LRU): + r.tokenReplacementPolicy = LRU + case string(FCFS): + r.tokenReplacementPolicy = FCFS + default: + return nil, fmt.Errorf("invalid config value %q for token replacement policy", tokenReplacementPolicy) + } + } else { + r.tokenReplacementPolicy = LRU + } + logger.Infof("config refresh tokens token replacement policy: %v", r.tokenReplacementPolicy) + } + return &r, nil } diff --git a/server/rotation_test.go b/server/rotation_test.go index e279bf543e..b6206b3735 100644 --- a/server/rotation_test.go +++ b/server/rotation_test.go @@ -110,7 +110,7 @@ func TestRefreshTokenPolicy(t *testing.T) { Level: logrus.DebugLevel, } - r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m") + r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m", false, 0, "") require.NoError(t, err) t.Run("Allowed", func(t *testing.T) { diff --git a/server/server_test.go b/server/server_test.go index 6f4bcb81aa..316ceef009 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -122,7 +122,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi // Default rotation policy if server.refreshTokenPolicy == nil { - server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "") + server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "", false, 0, "") if err != nil { t.Fatalf("failed to prepare rotation policy: %v", err) } diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 13549ef5ff..fd10b181cb 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -366,8 +366,17 @@ func (cli *client) ListClients() ([]storage.Client, error) { return nil, errors.New("not implemented") } -func (cli *client) ListRefreshTokens() ([]storage.RefreshToken, error) { - return nil, errors.New("not implemented") +func (cli *client) ListRefreshTokens() (refreshTokens []storage.RefreshToken, err error) { + var refreshList RefreshList + if err = cli.list(resourceRefreshToken, &refreshList); err != nil { + return refreshTokens, fmt.Errorf("failed to list refresh tokens: %v", err) + } + + for _, refreshToken := range refreshList.RefreshTokens { + refreshTokens = append(refreshTokens, toStorageRefreshToken(refreshToken)) + } + + return } func (cli *client) ListPasswords() (passwords []storage.Password, err error) {