Skip to content

Commit

Permalink
support the implicit flow
Browse files Browse the repository at this point in the history
  • Loading branch information
colemickens committed Jun 26, 2018
1 parent fd56ebd commit 445da8d
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 55 deletions.
4 changes: 2 additions & 2 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ type CallbackConnector interface {
// requested if one has already been issues. There's no good general answer
// for these kind of restrictions, and may require this package to become more
// aware of the global set of user/connector interactions.
LoginURL(s Scopes, callbackURL, state string) (string, error)
LoginURL(s Scopes, callbackURL, state string) (string, []byte, error)

// Handle the callback to the server and return an identity.
HandleCallback(s Scopes, r *http.Request) (identity Identity, err error)
HandleCallback(s Scopes, r *http.Request, data []byte) (identity Identity, err error)
}

// SAMLConnector represents SAML connectors which implement the HTTP POST binding.
Expand Down
8 changes: 4 additions & 4 deletions connector/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
}
}

func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}

type oauth2Error struct {
Expand Down Expand Up @@ -216,7 +216,7 @@ func newHTTPClient(rootCA string) (*http.Client, error) {
}, nil
}

func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
8 changes: 4 additions & 4 deletions connector/gitlab/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
}
}

func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
}
return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}

type oauth2Error struct {
Expand All @@ -118,7 +118,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
8 changes: 4 additions & 4 deletions connector/linkedin/linkedin.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ var (
)

// LoginURL returns an access token request URL
func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.oauth2Config.RedirectURL != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
callbackURL, c.oauth2Config.RedirectURL)
}

return c.oauth2Config.AuthCodeURL(state), nil
return c.oauth2Config.AuthCodeURL(state), nil, nil
}

// HandleCallback handles HTTP redirect from LinkedIn
func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
8 changes: 4 additions & 4 deletions connector/microsoft/microsoft.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi
}
}

func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}

func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
8 changes: 4 additions & 4 deletions connector/mock/connectortest.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ type Callback struct {
}

// LoginURL returns the URL to redirect the user to login with.
func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) {
u, err := url.Parse(callbackURL)
if err != nil {
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
return "", nil, fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
}
v := u.Query()
v.Set("state", state)
u.RawQuery = v.Encode()
return u.String(), nil
return u.String(), nil, nil
}

var connectorData = []byte("foobar")

// HandleCallback parses the request and returns the user's identity
func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (connector.Identity, error) {
return m.Identity, nil
}

Expand Down
110 changes: 91 additions & 19 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oidc

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
Expand All @@ -16,8 +17,11 @@ import (
"golang.org/x/oauth2"

"github.com/coreos/dex/connector"
"github.com/coreos/dex/storage"
)

const NoRefreshTokenDummy = "no-refresh-token"

// Config holds configuration options for OpenID Connect logins.
type Config struct {
Issuer string `json:"issuer"`
Expand All @@ -34,11 +38,19 @@ type Config struct {

Scopes []string `json:"scopes"` // defaults to "profile" and "email"

ResponseType string `json:"responseType"` // Default to "code"

// Optional list of whitelisted domains when using Google
// If this field is nonempty, only users from a listed domain will be allowed to log in
HostedDomains []string `json:"hostedDomains"`
}

// connectorData holds state that is needed between starting
// the auth flow, and validating the response
type connectorData struct {
Nonce string
}

// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
// list, but it only matches specific URLs, not top level domains.
var brokenAuthHeaderDomains = []string{
Expand Down Expand Up @@ -84,7 +96,7 @@ func (c *Config) Open(id string, logger logrus.FieldLogger) (conn connector.Conn
cancel()
return nil, fmt.Errorf("failed to get provider: %v", err)
}

if c.BasicAuthUnsupported != nil {
// Setting "basicAuthUnsupported" always overrides our detection.
if *c.BasicAuthUnsupported {
Expand All @@ -94,15 +106,20 @@ func (c *Config) Open(id string, logger logrus.FieldLogger) (conn connector.Conn
registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL)
}

scopes := []string{oidc.ScopeOpenID}
if len(c.Scopes) > 0 {
scopes = append(scopes, c.Scopes...)
} else {
scopes = append(scopes, "profile", "email")
// if the user specifies scope, respect it
scopes := c.Scopes
if len(c.Scopes) == 0 {
scopes = []string{ oidc.ScopeOpenID, "profile", "email" }
}

if c.ResponseType != "id_token" && c.ResponseType != "code" {
err := fmt.Errorf("failed to create %s provider, unsupported response_type '%s'", id, c.ResponseType)
cancel()
return nil, err
}

clientID := c.ClientID
return &oidcConnector{
connector := &oidcConnector{
redirectURI: c.RedirectURI,
oauth2Config: &oauth2.Config{
ClientID: clientID,
Expand All @@ -114,10 +131,13 @@ func (c *Config) Open(id string, logger logrus.FieldLogger) (conn connector.Conn
verifier: provider.Verifier(
&oidc.Config{ClientID: clientID},
),
responseType: c.ResponseType,
logger: logger,
ctx: ctx,
cancel: cancel,
hostedDomains: c.HostedDomains,
}, nil
}
return connector, nil
}

var (
Expand All @@ -133,16 +153,18 @@ type oidcConnector struct {
cancel context.CancelFunc
logger logrus.FieldLogger
hostedDomains []string
responseType string
}

func (c *oidcConnector) Close() error {
c.cancel()
return nil
}

func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) {
connData := connectorData{}
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

var opts []oauth2.AuthCodeOption
Expand All @@ -157,7 +179,19 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil
if c.responseType == "id_token" {
connData.Nonce = storage.NewID()
opts = append(opts, oauth2.SetAuthURLParam("response_type", c.responseType))
opts = append(opts, oauth2.SetAuthURLParam("response_mode", "form_post"))
opts = append(opts, oauth2.SetAuthURLParam("nonce", connData.Nonce))
}
authCodeURL := c.oauth2Config.AuthCodeURL(state, opts...)

connDataBytes, err := json.Marshal(connData)
if err != nil {
return "", nil, fmt.Errorf("failed to encode connector data: %v", err)
}
return authCodeURL, connDataBytes, nil
}

type oauth2Error struct {
Expand All @@ -172,21 +206,45 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request, connDataBytes []byte) (identity connector.Identity, err error) {
var connData connectorData
err = json.Unmarshal(connDataBytes, &connData)
if err != nil {
return identity, fmt.Errorf("failed to parse connector data: %v", err)
}

q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
}

if c.responseType == "id_token" {
rawIDToken := r.FormValue("id_token")
if rawIDToken == "" {
return identity, fmt.Errorf("authorization response lacked id_token despite using the implicit flow")
}
return c.createIdentity(r.Context(), rawIDToken, connData, NoRefreshTokenDummy)
}

token, err := c.oauth2Config.Exchange(r.Context(), q.Get("code"))
if err != nil {
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}

return c.createIdentity(r.Context(), identity, token)
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return identity, errors.New("oidc: no id_token in token response")
}

return c.createIdentity(r.Context(), rawIDToken, connData, token.RefreshToken)
}

// Refresh is implemented for backwards compatibility, even though it's a no-op.
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
if c.responseType == "id_token" {
return identity, fmt.Errorf("oidc: there is no refresh_token with implict flow")
}

t := &oauth2.Token{
RefreshToken: string(identity.ConnectorData),
Expiry: time.Now().Add(-time.Hour),
Expand All @@ -196,26 +254,40 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}

return c.createIdentity(ctx, identity, token)
}

func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return identity, errors.New("oidc: no id_token in token response")
}

return c.createIdentity(ctx, rawIDToken, connectorData{}, token.RefreshToken)
}

func (c *oidcConnector) createIdentity(ctx context.Context, rawIDToken string, connData connectorData, refreshToken string) (identity connector.Identity, err error) {
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}

if c.responseType == "id_token" {
// validate the nonce, we're in the implicit flow
var nonceClaim struct {
Nonce string `json:"nonce"`
}
if err := idToken.Claims(&nonceClaim); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}
if nonceClaim.Nonce != connData.Nonce {
return identity, fmt.Errorf("oidc: invalid nonce from provider.")
}
}

var claims struct {
Username string `json:"name"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
HostedDomain string `json:"hd"`
}
if err := idToken.Claims(&claims); err != nil {
if err = idToken.Claims(&claims); err != nil {
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}

Expand All @@ -238,7 +310,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
Username: claims.Username,
Email: claims.Email,
EmailVerified: claims.EmailVerified,
ConnectorData: []byte(token.RefreshToken),
ConnectorData: []byte(refreshToken),
}
return identity, nil
}
Loading

0 comments on commit 445da8d

Please sign in to comment.