Skip to content

Commit

Permalink
Support for multiple refresh tokens
Browse files Browse the repository at this point in the history
Signed-off-by: Vinod Patil <[email protected]>
  • Loading branch information
vinod-trilio committed Sep 23, 2021
1 parent f92a6f4 commit 5970b22
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 33 deletions.
15 changes: 11 additions & 4 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,16 @@ type Logger struct {
Format string `json:"format"`
}

type MultipleRefreshToken struct {
Allow bool `json:"allow"`
MaximumCount int `json:"maximumCount"`
ReplacementPolicy string `json:"replacementPolicy"`
}

type RefreshToken struct {
DisableRotation bool `json:"disableRotation"`
ReuseInterval string `json:"reuseInterval"`
AbsoluteLifetime string `json:"absoluteLifetime"`
ValidIfNotUsedFor string `json:"validIfNotUsedFor"`
DisableRotation bool `json:"disableRotation"`
MultipleTokens MultipleRefreshToken `json:"multipleTokens"`
ReuseInterval string `json:"reuseInterval"`
AbsoluteLifetime string `json:"absoluteLifetime"`
ValidIfNotUsedFor string `json:"validIfNotUsedFor"`
}
6 changes: 4 additions & 2 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,9 @@ func runServe(options serveOptions) error {
c.Expiry.RefreshTokens.ValidIfNotUsedFor,
c.Expiry.RefreshTokens.AbsoluteLifetime,
c.Expiry.RefreshTokens.ReuseInterval,
)
c.Expiry.RefreshTokens.MultipleTokens.Allow,
c.Expiry.RefreshTokens.MultipleTokens.MaximumCount,
c.Expiry.RefreshTokens.MultipleTokens.ReplacementPolicy)
if err != nil {
return fmt.Errorf("invalid refresh token expiration policy config: %v", err)
}
Expand Down Expand Up @@ -449,7 +451,7 @@ func runServe(options serveOptions) error {
}

grpcSrv := grpc.NewServer(grpcOptions...)
api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger, version))
api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger, version, c.Expiry.RefreshTokens.MultipleTokens.Allow))

grpcMetrics.InitializeMetrics(grpcSrv)
if c.GRPC.Reflection {
Expand Down
95 changes: 91 additions & 4 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ const (
)

// NewAPI returns a server which implements the gRPC API interface.
func NewAPI(s storage.Storage, logger log.Logger, version string) api.DexServer {
func NewAPI(s storage.Storage, logger log.Logger, version string, multiRefreshTokens bool) api.DexServer {
return dexAPI{
s: s,
logger: logger,
version: version,
s: s,
logger: logger,
version: version,
multiRefreshTokens: multiRefreshTokens,
}
}

Expand All @@ -43,6 +44,8 @@ type dexAPI struct {
s storage.Storage
logger log.Logger
version string

multiRefreshTokens bool
}

func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*api.CreateClientResp, error) {
Expand Down Expand Up @@ -283,6 +286,13 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq)
}

func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) {
if d.multiRefreshTokens {
return d.listMultipleRefreshTokensMode(ctx, req)
}
return d.listRefresh(ctx, req)
}

func (d dexAPI) listRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
Expand Down Expand Up @@ -316,7 +326,84 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.
}, nil
}

func (d dexAPI) listMultipleRefreshTokensMode(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
return nil, err
}

var refreshTokenRefs []*api.RefreshTokenRef

// TODO: As OfflineSession has a reference to lastUpdated RefreshToken, listing add RefreshTokens and filtering
refreshTokens, err := d.s.ListRefreshTokens()
if err != nil {
return nil, err
}
for _, t := range refreshTokens {
if t.Claims.UserID == id.UserId && t.ConnectorID == id.ConnId {
r := api.RefreshTokenRef{
Id: t.ID,
ClientId: t.ClientID,
CreatedAt: t.CreatedAt.Unix(),
LastUsed: t.LastUsed.Unix(),
}
refreshTokenRefs = append(refreshTokenRefs, &r)
}
}

return &api.ListRefreshResp{
RefreshTokens: refreshTokenRefs,
}, nil
}

func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) {
if d.multiRefreshTokens {
return d.revokeMultipleRefreshTokensMode(ctx, req)
}
return d.revokeRefresh(ctx, req)
}

func (d dexAPI) revokeMultipleRefreshTokensMode(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
return nil, err
}

// FIXME: listing all tokens can be slow
refreshTokens, err := d.s.ListRefreshTokens()
if err != nil {
return nil, err
}
if len(refreshTokens) == 0 {
return &api.RevokeRefreshResp{NotFound: true}, nil
}

for _, t := range refreshTokens {
if t.Claims.UserID == id.UserId && t.ConnectorID == id.ConnId && t.ClientID == req.ClientId {
if err := d.s.DeleteRefresh(t.ID); err != nil {
d.logger.Errorf("failed to delete refresh token: %v", err)
return nil, err
}
}
}

updater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
// Remove entry from Refresh list of the OfflineSession object.
delete(old.Refresh, req.ClientId)
return old, nil
}

if err := d.s.UpdateOfflineSessions(id.UserId, id.ConnId, updater); err != nil {
d.logger.Errorf("api: failed to update offline session object: %v", err)
return nil, err
}

return &api.RevokeRefreshResp{}, nil
}

func (d dexAPI) revokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) {
id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func newAPI(s storage.Storage, logger log.Logger, t *testing.T) *apiClient {
}

serv := grpc.NewServer()
api.RegisterDexServer(serv, NewAPI(s, logger, "test"))
api.RegisterDexServer(serv, NewAPI(s, logger, "test", false))
go serv.Serve(l)

// Dial will retry automatically if the serv.Serve() goroutine
Expand Down
88 changes: 71 additions & 17 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,13 +961,19 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
return nil, err
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return nil, err
if !s.refreshTokenPolicy.allowMultiple {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return nil, err
}
}
} else {
if err := s.deleteRefreshTokens(refresh.ConnectorID, refresh.Claims.UserID); err != nil {
s.logger.Errorf("error while deleting refresh token: %v", err)
}
}

Expand Down Expand Up @@ -1205,18 +1211,24 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
return
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
if err == storage.ErrNotFound {
s.logger.Warnf("database inconsistent, refresh token missing: %v", oldTokenRef.ID)
} else {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
if !s.refreshTokenPolicy.allowMultiple {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
if err == storage.ErrNotFound {
s.logger.Warnf("database inconsistent, refresh token missing: %v", oldTokenRef.ID)
} else {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}
}
} else {
if err := s.deleteRefreshTokens(refresh.ConnectorID, refresh.Claims.UserID); err != nil {
s.logger.Errorf("error while deleting refresh token: %v", err)
}
}

// Update existing OfflineSession obj with new RefreshTokenRef.
Expand All @@ -1237,6 +1249,48 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
s.writeAccessToken(w, resp)
}

func (s *Server) deleteRefreshTokens(connectorID string, userID string) error {
refreshTokens, err := s.storage.ListRefreshTokens()
if err != nil {
return err
}

var userRefreshTokens []storage.RefreshToken
for index := range refreshTokens {
refreshToken := refreshTokens[index]
if refreshToken.ConnectorID == connectorID && refreshToken.Claims.UserID == userID {
userRefreshTokens = append(userRefreshTokens, refreshToken)
}
}

if len(userRefreshTokens) <= s.refreshTokenPolicy.maxTokens {
return nil
}

sort.SliceStable(userRefreshTokens, func(i, j int) bool {
if s.refreshTokenPolicy.tokenReplacementPolicy == FCFS {
return userRefreshTokens[i].CreatedAt.Before(userRefreshTokens[j].CreatedAt)
} else {
return userRefreshTokens[i].LastUsed.Before(userRefreshTokens[j].LastUsed)
}
})

tokensToDelete := userRefreshTokens[:len(userRefreshTokens)-s.refreshTokenPolicy.maxTokens]
var deletionError bool
for index := range tokensToDelete {
refreshToken := tokensToDelete[index]
if err := s.storage.DeleteRefresh(refreshToken.ID); err != nil {
deletionError = true
s.logger.Errorf("error while deleting refresh token: %v", err)
}
}

if deletionError {
return fmt.Errorf("error while deleting refresh token for userID %s of connector %s", userID, connectorID)
}
return nil
}

type accessTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Expand Down
5 changes: 5 additions & 0 deletions server/refreshhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,13 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre
func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError {
offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if old.Refresh[refresh.ClientID].ID != refresh.ID {
// For multiple tokens only latest one is referred in offline session
if s.refreshTokenPolicy.allowMultiple {
return old, nil
}
return old, errors.New("refresh token invalid")
}

old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
return old, nil
Expand Down
41 changes: 40 additions & 1 deletion server/rotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,31 @@ func (k keyRotator) rotate() error {
return nil
}

type tokenReplacementPolicy string

const (
LRU tokenReplacementPolicy = "LRU"
FCFS tokenReplacementPolicy = "FCFS"
)

type RefreshTokenPolicy struct {
rotateRefreshTokens bool // enable rotation

absoluteLifetime time.Duration // interval from token creation to the end of its life
validIfNotUsedFor time.Duration // interval from last token update to the end of its life
reuseInterval time.Duration // interval within which old refresh token is allowed to be reused

allowMultiple bool
maxTokens int
tokenReplacementPolicy tokenReplacementPolicy

now func() time.Time

logger log.Logger
}

func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string,
allowMultiple bool, maxTokens int, tokenReplacementPolicy string) (*RefreshTokenPolicy, error) {
r := RefreshTokenPolicy{now: time.Now, logger: logger}
var err error

Expand Down Expand Up @@ -220,6 +232,33 @@ func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor,

r.rotateRefreshTokens = !rotation
logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens)

r.allowMultiple = allowMultiple
logger.Infof("config refresh tokens allow multiple: %v", allowMultiple)

if allowMultiple {
if maxTokens < 1 {
r.maxTokens = 50
} else {
r.maxTokens = maxTokens
}
logger.Infof("config refresh tokens max multiple tokens: %v", r.maxTokens)

if tokenReplacementPolicy != "" {
switch tokenReplacementPolicy {
case string(LRU):
r.tokenReplacementPolicy = LRU
case string(FCFS):
r.tokenReplacementPolicy = FCFS
default:
return nil, fmt.Errorf("invalid config value %q for token replacement policy", tokenReplacementPolicy)
}
} else {
r.tokenReplacementPolicy = LRU
}
logger.Infof("config refresh tokens token replacement policy: %v", r.tokenReplacementPolicy)
}

return &r, nil
}

Expand Down
2 changes: 1 addition & 1 deletion server/rotation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func TestRefreshTokenPolicy(t *testing.T) {
Level: logrus.DebugLevel,
}

r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m")
r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m", false, 0, "")
require.NoError(t, err)

t.Run("Allowed", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi

// Default rotation policy
if server.refreshTokenPolicy == nil {
server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "")
server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "", false, 0, "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
Expand Down
Loading

0 comments on commit 5970b22

Please sign in to comment.