Skip to content

Commit

Permalink
Add connectorData to handle state from LoginUrl to HandleCallback
Browse files Browse the repository at this point in the history
Signed-off-by: johnvan7 <[email protected]>
  • Loading branch information
johnvan7 committed Oct 25, 2024
1 parent 6eee8c1 commit 09dea1c
Show file tree
Hide file tree
Showing 22 changed files with 160 additions and 152 deletions.
8 changes: 4 additions & 4 deletions connector/bitbucketcloud/bitbucketcloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ func (b *bitbucketConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi
}
}

func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (b *bitbucketConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if b.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, b.redirectURI)
}

return b.oauth2Config(scopes).AuthCodeURL(state), nil
return b.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}

type oauth2Error struct {
Expand All @@ -131,7 +131,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

func (b *bitbucketConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (b *bitbucketConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
2 changes: 1 addition & 1 deletion connector/bitbucketcloud/bitbucketcloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)

bitbucketConnector := bitbucketConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()}
identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, req)
identity, err := bitbucketConnector.HandleCallback(connector.Scopes{}, nil, req)

expectNil(t, err)
expectEquals(t, identity.Username, "some-login")
Expand Down
4 changes: 2 additions & 2 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ type CallbackConnector interface {
// requested if one has already been issues. There's no good general answer
// for these kind of restrictions, and may require this package to become more
// aware of the global set of user/connector interactions.
LoginURL(s Scopes, callbackURL, state string) (string, error)
LoginURL(s Scopes, callbackURL, state string) (string, []byte, error)

// Handle the callback to the server and return an identity.
HandleCallback(s Scopes, r *http.Request) (identity Identity, err error)
HandleCallback(s Scopes, connData []byte, r *http.Request) (identity Identity, err error)
}

// SAMLConnector represents SAML connectors which implement the HTTP POST binding.
Expand Down
8 changes: 4 additions & 4 deletions connector/gitea/gitea.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ func (c *giteaConnector) oauth2Config(_ connector.Scopes) *oauth2.Config {
}
}

func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *giteaConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
}
return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}

type oauth2Error struct {
Expand All @@ -121,7 +121,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

func (c *giteaConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *giteaConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
4 changes: 2 additions & 2 deletions connector/gitea/gitea_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)

c := giteaConnector{baseURL: s.URL, httpClient: newClient()}
identity, err := c.HandleCallback(connector.Scopes{}, req)
identity, err := c.HandleCallback(connector.Scopes{}, nil, req)

expectNil(t, err)
expectEquals(t, identity.Username, "[email protected]")
expectEquals(t, identity.UserID, "12345678")

c = giteaConnector{baseURL: s.URL, httpClient: newClient()}
identity, err = c.HandleCallback(connector.Scopes{}, req)
identity, err = c.HandleCallback(connector.Scopes{}, nil, req)

expectNil(t, err)
expectEquals(t, identity.Username, "[email protected]")
Expand Down
8 changes: 4 additions & 4 deletions connector/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,12 @@ func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
}
}

func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}

type oauth2Error struct {
Expand All @@ -214,7 +214,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *githubConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
6 changes: 3 additions & 3 deletions connector/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)

c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient()}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)

expectNil(t, err)
expectEquals(t, identity.Username, "some-login")
expectEquals(t, identity.UserID, "12345678")
expectEquals(t, 0, len(identity.Groups))

c = githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), loadAllGroups: true}
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req)

expectNil(t, err)
expectEquals(t, identity.Username, "some-login")
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) {
expectNil(t, err)

c := githubConnector{apiURL: s.URL, hostName: hostURL.Host, httpClient: newClient(), useLoginAsID: true}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)

expectNil(t, err)
expectEquals(t, identity.UserID, "some-login")
Expand Down
8 changes: 4 additions & 4 deletions connector/gitlab/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,11 @@ func (c *gitlabConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
}
}

func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *gitlabConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", c.redirectURI, callbackURL)
}
return c.oauth2Config(scopes).AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil, nil
}

type oauth2Error struct {
Expand All @@ -119,7 +119,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *gitlabConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
12 changes: 6 additions & 6 deletions connector/gitlab/gitlab_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ func TestUsernameIncludedInFederatedIdentity(t *testing.T) {
expectNil(t, err)

c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req)

expectNil(t, err)
expectEquals(t, identity.Username, "[email protected]")
expectEquals(t, identity.UserID, "12345678")
expectEquals(t, 0, len(identity.Groups))

c = gitlabConnector{baseURL: s.URL, httpClient: newClient()}
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req)

expectNil(t, err)
expectEquals(t, identity.Username, "[email protected]")
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestLoginUsedAsIDWhenConfigured(t *testing.T) {
expectNil(t, err)

c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), useLoginAsID: true}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)

expectNil(t, err)
expectEquals(t, identity.UserID, "joebloggs")
Expand All @@ -147,7 +147,7 @@ func TestLoginWithTeamWhitelisted(t *testing.T) {
expectNil(t, err)

c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-1"}}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)

expectNil(t, err)
expectEquals(t, identity.UserID, "12345678")
Expand All @@ -174,7 +174,7 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) {
expectNil(t, err)

c := gitlabConnector{baseURL: s.URL, httpClient: newClient(), groups: []string{"team-2"}}
_, err = c.HandleCallback(connector.Scopes{Groups: true}, req)
_, err = c.HandleCallback(connector.Scopes{Groups: true}, nil, req)

expectNotNil(t, err, "HandleCallback error")
expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups")
Expand Down Expand Up @@ -208,7 +208,7 @@ func TestRefresh(t *testing.T) {
})
expectNil(t, err)

identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req)
identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "[email protected]")
expectEquals(t, identity.UserID, "12345678")
Expand Down
8 changes: 4 additions & 4 deletions connector/google/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ func (c *googleConnector) Close() error {
return nil
}

func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

var opts []oauth2.AuthCodeOption
Expand All @@ -186,7 +186,7 @@ func (c *googleConnector) LoginURL(s connector.Scopes, callbackURL, state string
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType))
}

return c.oauth2Config.AuthCodeURL(state, opts...), nil
return c.oauth2Config.AuthCodeURL(state, opts...), nil, nil
}

type oauth2Error struct {
Expand All @@ -201,7 +201,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}

func (c *googleConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *googleConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
2 changes: 1 addition & 1 deletion connector/google/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func TestPromptTypeConfig(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, test.expectedPromptTypeValue, conn.promptType)

loginURL, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state")
loginURL, _, err := conn.LoginURL(connector.Scopes{OfflineAccess: true}, ts.URL+"/callback", "state")
assert.Nil(t, err)

urlp, err := url.Parse(loginURL)
Expand Down
8 changes: 4 additions & 4 deletions connector/linkedin/linkedin.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ var (
)

// LoginURL returns an access token request URL
func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *linkedInConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.oauth2Config.RedirectURL != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
callbackURL, c.oauth2Config.RedirectURL)
}

return c.oauth2Config.AuthCodeURL(state), nil
return c.oauth2Config.AuthCodeURL(state), nil, nil
}

// HandleCallback handles HTTP redirect from LinkedIn
func (c *linkedInConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *linkedInConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
8 changes: 4 additions & 4 deletions connector/microsoft/microsoft.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ func (c *microsoftConnector) oauth2Config(scopes connector.Scopes) *oauth2.Confi
}
}

func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, []byte, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
return "", nil, fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}

var options []oauth2.AuthCodeOption
Expand All @@ -188,10 +188,10 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat
options = append(options, oauth2.SetAuthURLParam("domain_hint", c.domainHint))
}

return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil
return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil, nil
}

func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
func (c *microsoftConnector) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand Down
8 changes: 4 additions & 4 deletions connector/microsoft/microsoft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestLoginURL(t *testing.T) {
tenant: tenant,
}

loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState)
loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState)

parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()
Expand Down Expand Up @@ -70,7 +70,7 @@ func TestLoginURLWithOptions(t *testing.T) {
domainHint: domainHint,
}

loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state")
loginURL, _, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state")

parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()
Expand All @@ -91,7 +91,7 @@ func TestUserIdentityFromGraphAPI(t *testing.T) {
req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant}
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: false}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Username, "Jane Doe")
expectEquals(t, identity.UserID, "S56767889")
Expand All @@ -114,7 +114,7 @@ func TestUserGroupsFromGraphAPI(t *testing.T) {
req, _ := http.NewRequest("GET", s.URL, nil)

c := microsoftConnector{apiURL: s.URL, graphURL: s.URL, tenant: tenant}
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, req)
identity, err := c.HandleCallback(connector.Scopes{Groups: true}, nil, req)
expectNil(t, err)
expectEquals(t, identity.Groups, []string{"a", "b"})
}
Expand Down
8 changes: 4 additions & 4 deletions connector/mock/connectortest.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,21 @@ type Callback struct {
}

// LoginURL returns the URL to redirect the user to login with.
func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
func (m *Callback) LoginURL(s connector.Scopes, callbackURL, state string) (string, []byte, error) {
u, err := url.Parse(callbackURL)
if err != nil {
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
return "", nil, fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
}
v := u.Query()
v.Set("state", state)
u.RawQuery = v.Encode()
return u.String(), nil
return u.String(), nil, nil
}

var connectorData = []byte("foobar")

// HandleCallback parses the request and returns the user's identity
func (m *Callback) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
func (m *Callback) HandleCallback(s connector.Scopes, connData []byte, r *http.Request) (connector.Identity, error) {
return m.Identity, nil
}

Expand Down
Loading

0 comments on commit 09dea1c

Please sign in to comment.