Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for id_token_hint #1125

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions cmd/example-app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,21 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
}
token, err = oauth2Config.Exchange(ctx, code)
case "POST":
// Form request from frontend to refresh a token.
// Form request from frontend to refresh a token; or login again with hint
refresh := r.FormValue("refresh_token")
if refresh == "" {
http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest)
idTokenHint := r.FormValue("id_token_hint")
if refresh == "" && idTokenHint == "" {
http.Error(w, fmt.Sprintf("no refresh_token or id_token_hint in request: %q", r.Form), http.StatusBadRequest)
return
}
if idTokenHint != "" {
// redirect to auth URL with the hint, using default scopes
scopes := []string{"openid", "profile", "email"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems odd that these aren't taken from the first page.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was a bit lazy here. I'm going to try to fix this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually not that simple -- we don't know much of the granted scopes when retrieving the id_token 🤔 Since this is the example app, I haven't spent too much time on implementing this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np. it's not critical.

authURL := a.oauth2Config(scopes).AuthCodeURL(exampleAppState)
http.Redirect(w, r, authURL+"&id_token_hint="+idTokenHint, http.StatusSeeOther)
return
}
// reaching this means refresh_token handling
t := &oauth2.Token{
RefreshToken: refresh,
Expiry: time.Now().Add(-time.Hour),
Expand Down
14 changes: 9 additions & 5 deletions cmd/example-app/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,17 @@ pre {
<body>
<p> Token: <pre><code>{{ .IDToken }}</code></pre></p>
<p> Claims: <pre><code>{{ .Claims }}</code></pre></p>
{{ if .RefreshToken }}
{{ if .RefreshToken }}
<p> Refresh Token: <pre><code>{{ .RefreshToken }}</code></pre></p>
<form action="{{ .RedirectURL }}" method="post">
<input type="hidden" name="refresh_token" value="{{ .RefreshToken }}">
<input type="submit" value="Redeem refresh token">
<form action="{{ .RedirectURL }}" method="post">
<input type="hidden" name="refresh_token" value="{{ .RefreshToken }}">
<input type="submit" value="Redeem refresh token">
</form>
{{ end }}
<form action="{{ .RedirectURL }}" method="post">
<input type="hidden" name="id_token_hint" value="{{ .IDToken }}">
<input type="submit" value="Login again using ID token hint">
</form>
{{ end }}
</body>
</html>
`))
Expand Down
13 changes: 12 additions & 1 deletion server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
return
}

// if we already know which connector the user should use -- go there directly
if authReq.ConnectorID != "" {
http.Redirect(w, r, s.absPath("/auth", authReq.ConnectorID)+"?backlink=none&req="+authReq.ID, http.StatusFound)
return
}

connectors, e := s.storage.ListConnectors()
if e != nil {
s.logger.Errorf("Failed to get list of connectors: %v", err)
Expand Down Expand Up @@ -223,7 +229,12 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
return
}
scopes := parseScopes(authReq.Scopes)
showBacklink := len(s.connectors) > 1

// allows for overriding the backlink display -- this is set when the initial
// request provided an `id_token_hint`, and we've forwarded the client based
// on this token's connector id.
backlinkOverride := r.FormValue("backlink")
showBacklink := len(s.connectors) > 1 && backlinkOverride != "none"

switch r.Method {
case "GET":
Expand Down
10 changes: 10 additions & 0 deletions server/internal/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package internal

import (
"encoding/base64"
"encoding/json"

"github.com/golang/protobuf/proto"
)
Expand All @@ -23,3 +24,12 @@ func Unmarshal(s string, message proto.Message) error {
}
return proto.Unmarshal(data, message)
}

// UnmarshalJSON unmarshals the subject claim's internal format
func (s *IDTokenSubject) UnmarshalJSON(src []byte) error {
var sub string
if err := json.Unmarshal(src, &sub); err != nil {
return err
}
return Unmarshal(sub, s)
}
34 changes: 34 additions & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
scopes := strings.Fields(q.Get("scope"))
responseTypes := strings.Fields(q.Get("response_type"))

var connectorID string
if hint := q.Get("id_token_hint"); hint != "" {
connectorID, err = connectorIDFromIDTokenHint(hint)
if err != nil {
s.logger.Errorf("failed to process id_token_hint: %s", err)
return req, &authErr{"", "", errInvalidRequest, "Invalid id_token_hint."}
}
}

client, err := s.storage.GetClient(clientID)
if err != nil {
if err == storage.ErrNotFound {
Expand Down Expand Up @@ -484,6 +493,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
return storage.AuthRequest{
ID: storage.NewID(),
ClientID: client.ID,
ConnectorID: connectorID,
State: state,
Nonce: nonce,
ForceApprovalPrompt: q.Get("approval_prompt") == "force",
Expand Down Expand Up @@ -548,3 +558,27 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
host, _, err := net.SplitHostPort(u.Host)
return err == nil && host == "localhost"
}

// connectorIDFromIDTokenHint tries to extract the connector used when the
// passed ID token was issued. It does NOT check the JWT signature -- we might
// get presented an ID token that has long since expired and we still want to
// provide the correct connector pre-selection. Note, however, that the token
// MAY BE VALID and should this not be part of any error output.
func connectorIDFromIDTokenHint(hint string) (string, error) {
parts := strings.SplitN(hint, ".", 3)
if len(parts) != 3 {
return "", errors.New("wrong format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return "", fmt.Errorf("failed to decode payload: %v", err)
}

cl := struct {
Sub internal.IDTokenSubject `json:"sub"`
}{}
if err := json.Unmarshal([]byte(payload), &cl); err != nil {
return "", fmt.Errorf("failed to unmarshal payload: %v", err)
}
return cl.Sub.ConnId, nil
}
75 changes: 70 additions & 5 deletions server/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func TestParseAuthorizationRequest(t *testing.T) {
queryParams map[string]string

wantErr bool

authReqCheck func(*testing.T, *storage.AuthRequest)
}{
{
name: "normal request",
Expand All @@ -41,6 +43,66 @@ func TestParseAuthorizationRequest(t *testing.T) {
"scope": "openid email profile",
},
},
{
name: "request with valid id_token_hint",
clients: []storage.Client{
{
ID: "foo",
RedirectURIs: []string{"https://example.com/foo"},
},
},
supportedResponseTypes: []string{"code"},
queryParams: map[string]string{
// sub: "Cg0wLTM4NS0yODA4OS0wEgRtb2Nr" = {0-385-28089-0, "mock"}
"id_token_hint": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJDZzB3TFRNNE5TMHlPREE0T1Mwd0VnUnRiMk5yIn0.M1mYqRIYeAaLeo3B7DWj_Nxm589tbworSGffCIgBz04",
"client_id": "foo",
"redirect_uri": "https://example.com/foo",
"response_type": "code",
"scope": "openid email profile",
},
authReqCheck: func(t *testing.T, ar *storage.AuthRequest) {
if ar.ConnectorID != "mock" {
t.Errorf("expected connectorID \"mock\", got %v", ar.ConnectorID)
}
},
},
{
name: "request with non-jwt id_token_hint",
clients: []storage.Client{
{
ID: "foo",
RedirectURIs: []string{"https://example.com/foo"},
},
},
supportedResponseTypes: []string{"code"},
queryParams: map[string]string{
"id_token_hint": "notevenajwt",
"client_id": "foo",
"redirect_uri": "https://example.com/foo",
"response_type": "code",
"scope": "openid email profile",
},
wantErr: true,
},
{
name: "request with invalid id_token_hint (bad sub)",
clients: []storage.Client{
{
ID: "foo",
RedirectURIs: []string{"https://example.com/foo"},
},
},
supportedResponseTypes: []string{"code"},
queryParams: map[string]string{
// sub: "ject"
"id_token_hint": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJqZWN0In0.GwGd9UBSu2XrfeULp8u3KI0jZPt1ccUIGk1TaCCtLqE",
"client_id": "foo",
"redirect_uri": "https://example.com/foo",
"response_type": "code",
"scope": "openid email profile",
},
wantErr: true,
},
{
name: "POST request",
clients: []storage.Client{
Expand Down Expand Up @@ -145,7 +207,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
}

for _, tc := range tests {
func() {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -168,14 +230,17 @@ func TestParseAuthorizationRequest(t *testing.T) {
} else {
req = httptest.NewRequest("GET", httpServer.URL+"/auth?"+params.Encode(), nil)
}
_, err := server.parseAuthorizationRequest(req)
resp, err := server.parseAuthorizationRequest(req)
if err != nil && !tc.wantErr {
t.Errorf("%s: %v", tc.name, err)
t.Fatal(err)
}
if err == nil && tc.wantErr {
t.Errorf("%s: expected error", tc.name)
t.Error("expected error")
}
if tc.authReqCheck != nil {
tc.authReqCheck(t, &resp)
}
}()
})
}
}

Expand Down