Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple refresh tokens per user. #1829

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Config struct {
// querying the storage. Cannot be specified without enabling a passwords
// database.
StaticPasswords []password `json:"staticPasswords"`

EnableMultiRefreshTokens bool `json:"enableMultiRefreshTokens"`
}

//Validate the configuration
Expand Down
25 changes: 13 additions & 12 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,18 @@ func serve(cmd *cobra.Command, args []string) error {
now := func() time.Time { return time.Now().UTC() }

serverConfig := server.Config{
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
PasswordConnector: c.OAuth2.PasswordConnector,
AllowedOrigins: c.Web.AllowedOrigins,
Issuer: c.Issuer,
Storage: s,
Web: c.Frontend,
Logger: logger,
Now: now,
PrometheusRegistry: prometheusRegistry,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
PasswordConnector: c.OAuth2.PasswordConnector,
AllowedOrigins: c.Web.AllowedOrigins,
Issuer: c.Issuer,
Storage: s,
Web: c.Frontend,
Logger: logger,
Now: now,
PrometheusRegistry: prometheusRegistry,
EnableMultiRefreshTokens: c.EnableMultiRefreshTokens,
}
if c.Expiry.SigningKeys != "" {
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
Expand Down Expand Up @@ -326,7 +327,7 @@ func serve(cmd *cobra.Command, args []string) error {
return fmt.Errorf("listening on %s failed: %v", c.GRPC.Addr, err)
}
s := grpc.NewServer(grpcOptions...)
api.RegisterDexServer(s, server.NewAPI(serverConfig.Storage, logger))
api.RegisterDexServer(s, server.NewAPI(serverConfig.Storage, logger, c.EnableMultiRefreshTokens))
grpcMetrics.InitializeMetrics(s)
if c.GRPC.Reflection {
logger.Info("enabling reflection in grpc service")
Expand Down
96 changes: 91 additions & 5 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,18 @@ const (
)

// NewAPI returns a server which implements the gRPC API interface.
func NewAPI(s storage.Storage, logger log.Logger) api.DexServer {
func NewAPI(s storage.Storage, logger log.Logger, enableMultiRefreshTokens bool) api.DexServer {
return dexAPI{
s: s,
logger: logger,
s: s,
logger: logger,
enableMultiRefreshTokens: enableMultiRefreshTokens,
}
}

type dexAPI struct {
s storage.Storage
logger log.Logger
s storage.Storage
logger log.Logger
enableMultiRefreshTokens bool
}

func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*api.CreateClientResp, error) {
Expand Down Expand Up @@ -281,6 +283,13 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq)
}

func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) {
if d.enableMultiRefreshTokens {
return d.listRefreshMultiRefreshMode(ctx, req)
}
return d.listRefresh(ctx, req)
}

func (d dexAPI) listRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
Expand Down Expand Up @@ -316,7 +325,45 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.
}, nil
}

func (d dexAPI) listRefreshMultiRefreshMode(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
return nil, err
}

var refreshTokenRefs []*api.RefreshTokenRef

// FIXME: listing all tokens can be slow
refreshTokens, err := d.s.ListRefreshTokens()
if err != nil {
return nil, err
}
for _, t := range refreshTokens {
if t.Claims.UserID == id.UserId && t.ConnectorID == id.ConnId {
r := api.RefreshTokenRef{
Id: t.ID,
ClientId: t.ClientID,
CreatedAt: t.CreatedAt.Unix(),
LastUsed: t.LastUsed.Unix(),
}
refreshTokenRefs = append(refreshTokenRefs, &r)
}
}

return &api.ListRefreshResp{
RefreshTokens: refreshTokenRefs,
}, nil
}

func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) {
if d.enableMultiRefreshTokens {
return d.revokeRefreshMultiRefreshMode(ctx, req)
}
return d.revokeRefresh(ctx, req)
}

func (d dexAPI) revokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
Expand Down Expand Up @@ -366,3 +413,42 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*

return &api.RevokeRefreshResp{}, nil
}

func (d dexAPI) revokeRefreshMultiRefreshMode(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
return nil, err
}

// FIXME: listing all tokens can be slow
refreshTokens, err := d.s.ListRefreshTokens()
if err != nil {
return nil, err
}
if len(refreshTokens) == 0 {
return &api.RevokeRefreshResp{NotFound: true}, nil
}

for _, t := range refreshTokens {
if t.Claims.UserID == id.UserId && t.ConnectorID == id.ConnId && t.ClientID == req.ClientId {
if err := d.s.DeleteRefresh(t.ID); err != nil {
d.logger.Errorf("failed to delete refresh token: %v", err)
return nil, err
}
}
}

updater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
// Remove entry from Refresh list of the OfflineSession object.
delete(old.Refresh, req.ClientId)
return old, nil
}

if err := d.s.UpdateOfflineSessions(id.UserId, id.ConnId, updater); err != nil {
d.logger.Errorf("api: failed to update offline session object: %v", err)
return nil, err
}

return &api.RevokeRefreshResp{}, nil
}
2 changes: 1 addition & 1 deletion server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func newAPI(s storage.Storage, logger log.Logger, t *testing.T) *apiClient {
}

serv := grpc.NewServer()
api.RegisterDexServer(serv, NewAPI(s, logger))
api.RegisterDexServer(serv, NewAPI(s, logger, false))
go serv.Serve(l)

// Dial will retry automatically if the serv.Serve() goroutine
Expand Down
44 changes: 26 additions & 18 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -919,13 +919,15 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
return nil, err
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return nil, err
if !s.enableMultiRefreshTokens {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return nil, err
}
}
}

Expand Down Expand Up @@ -1119,7 +1121,11 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
// in offline session for the user.
if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if old.Refresh[refresh.ClientID].ID != refresh.ID {
return old, errors.New("refresh token invalid")
if s.enableMultiRefreshTokens {
return old, nil
} else {
return old, errors.New("refresh token invalid")
}
}
old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
Expand Down Expand Up @@ -1358,16 +1364,18 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
return
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
if err == storage.ErrNotFound {
s.logger.Warnf("database inconsistent, refresh token missing: %v", oldTokenRef.ID)
} else {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
if !s.enableMultiRefreshTokens {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
if err == storage.ErrNotFound {
s.logger.Warnf("database inconsistent, refresh token missing: %v", oldTokenRef.ID)
} else {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
}
}
Expand Down
31 changes: 18 additions & 13 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ type Config struct {
Logger log.Logger

PrometheusRegistry *prometheus.Registry

EnableMultiRefreshTokens bool
}

// WebConfig holds the server's frontend templates and asset configuration.
Expand Down Expand Up @@ -163,6 +165,8 @@ type Server struct {
deviceRequestsValidFor time.Duration

logger log.Logger

enableMultiRefreshTokens bool
}

// NewServer constructs a server from the provided config.
Expand Down Expand Up @@ -223,19 +227,20 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
}

s := &Server{
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now,
templates: tmpls,
passwordConnector: c.PasswordConnector,
logger: c.Logger,
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now,
templates: tmpls,
passwordConnector: c.PasswordConnector,
logger: c.Logger,
enableMultiRefreshTokens: c.EnableMultiRefreshTokens,
}

// Retrieves connector objects in backend storage. This list includes the static connectors
Expand Down