Skip to content

Commit

Permalink
Merge pull request #1180 from JoelSpeed/refresh-tokens
Browse files Browse the repository at this point in the history
Implement refreshing with Google
  • Loading branch information
bonifaido authored Nov 19, 2019
2 parents c392236 + 3156553 commit b1e98d8
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 58 deletions.
59 changes: 50 additions & 9 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package oidc

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -60,6 +62,11 @@ var brokenAuthHeaderDomains = []string{
"oktapreview.com",
}

// connectorData stores information for sessions authenticated by this connector
type connectorData struct {
RefreshToken []byte
}

// Detect auth header provider issues for known providers. This lets users
// avoid having to explicitly set "basicAuthUnsupported" in their config.
//
Expand Down Expand Up @@ -167,14 +174,19 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

var opts []oauth2.AuthCodeOption
if len(c.hostedDomains) > 0 {
preferredDomain := c.hostedDomains[0]
if len(c.hostedDomains) > 1 {
preferredDomain = "*"
}
return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil
opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain))
}

if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
}
return c.oauth2Config.AuthCodeURL(state), nil
return c.oauth2Config.AuthCodeURL(state, opts...), nil
}

type oauth2Error struct {
Expand All @@ -199,11 +211,35 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}

return c.createIdentity(r.Context(), identity, token)
}

// Refresh is used to refresh a session with the refresh token provided by the IdP
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
cd := connectorData{}
err := json.Unmarshal(identity.ConnectorData, &cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to unmarshal connector data: %v", err)
}

t := &oauth2.Token{
RefreshToken: string(cd.RefreshToken),
Expiry: time.Now().Add(-time.Hour),
}
token, err := c.oauth2Config.TokenSource(ctx, t).Token()
if err != nil {
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
}

return c.createIdentity(ctx, identity, token)
}

func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return identity, errors.New("oidc: no id_token in token response")
}
idToken, err := c.verifier.Verify(r.Context(), rawIDToken)
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
Expand All @@ -215,7 +251,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide

// We immediately want to run getUserInfo if configured before we validate the claims
if c.getUserInfo {
userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token))
userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
if err != nil {
return identity, fmt.Errorf("oidc: error loading userinfo: %v", err)
}
Expand Down Expand Up @@ -260,11 +296,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
}
}

cd := connectorData{
RefreshToken: []byte(token.RefreshToken),
}

connData, err := json.Marshal(&cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to encode connector data: %v", err)
}

identity = connector.Identity{
UserID: idToken.Subject,
Username: name,
Email: email,
EmailVerified: emailVerified,
ConnectorData: connData,
}

if c.userIDKey != "" {
Expand All @@ -277,8 +323,3 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide

return identity, nil
}

// Refresh is implemented for backwards compatibility, even though it's a no-op.
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
return identity, nil
}
60 changes: 57 additions & 3 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,45 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth
s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q",
authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups)

return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil
returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID
_, ok := conn.(connector.RefreshConnector)
if !ok {
return returnURL, nil
}

// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return "", err
}
offlineSessions := storage.OfflineSessions{
UserID: identity.UserID,
ConnID: authReq.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: identity.ConnectorData,
}

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
return "", err
}
} else {
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if len(identity.ConnectorData) > 0 {
old.ConnectorData = identity.ConnectorData
}
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
return "", err
}
}

return returnURL, nil
}

func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -962,6 +1000,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
scopes = requestedScopes
}

var connectorData []byte
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return
}
} else if len(refresh.ConnectorData) > 0 {
// Use the old connector data if it exists, should be deleted once used
connectorData = refresh.ConnectorData
} else {
connectorData = session.ConnectorData
}

conn, err := s.getConnector(refresh.ConnectorID)
if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
Expand All @@ -975,7 +1026,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups,
ConnectorData: refresh.ConnectorData,
ConnectorData: connectorData,
}

// Can the connector refresh the identity? If so, attempt to refresh the data
Expand Down Expand Up @@ -1041,8 +1092,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
old.Claims.Email = ident.Email
old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData
old.LastUsed = lastUsed

// ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}
return old, nil
}

Expand All @@ -1053,6 +1106,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return old, errors.New("refresh token invalid")
}
old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
Expand Down
14 changes: 8 additions & 6 deletions storage/conformance/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
userID1 := storage.NewID()
session1 := storage.OfflineSessions{
UserID: userID1,
ConnID: "Conn1",
Refresh: make(map[string]*storage.RefreshTokenRef),
UserID: userID1,
ConnID: "Conn1",
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
}

// Creating an OfflineSession with an empty Refresh list to ensure that
Expand All @@ -532,9 +533,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {

userID2 := storage.NewID()
session2 := storage.OfflineSessions{
UserID: userID2,
ConnID: "Conn2",
Refresh: make(map[string]*storage.RefreshTokenRef),
UserID: userID2,
ConnID: "Conn2",
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
}

if err := s.CreateOfflineSessions(session2); err != nil {
Expand Down
21 changes: 12 additions & 9 deletions storage/etcd/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,24 +188,27 @@ type Keys struct {

// OfflineSessions is a mirrored struct from storage with JSON struct tags
type OfflineSessions struct {
UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
}

func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
return OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
}

func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
if s.Refresh == nil {
// Server code assumes this will be non-nil.
Expand Down
21 changes: 12 additions & 9 deletions storage/kubernetes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -552,9 +552,10 @@ type OfflineSessions struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`

UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
}

func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
Expand All @@ -567,17 +568,19 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
Name: cli.offlineTokenName(o.UserID, o.ConnID),
Namespace: cli.namespace,
},
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
}

func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
if s.Refresh == nil {
// Server code assumes this will be non-nil.
Expand Down
Loading

0 comments on commit b1e98d8

Please sign in to comment.