Skip to content

Commit

Permalink
Introduce CheckEndpointConnector, implement for OpenShift connector.
Browse files Browse the repository at this point in the history
  • Loading branch information
dhaus67 committed May 9, 2022
1 parent 453504c commit 29b843c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 6 deletions.
8 changes: 8 additions & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,11 @@ type RefreshConnector interface {
// changes since the token was last refreshed.
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
}

// CheckEndpointConnector is a connector that can test its endpoints for availability.
type CheckEndpointConnector interface {
// CheckEndpoint is called when a client wants to check the availability of
// endpoints used by the connector. If no error is returned, the connector
// can be assumed to be working.
CheckEndpoint(ctx context.Context) error
}
36 changes: 30 additions & 6 deletions connector/openshift/openshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ type Config struct {
}

var (
_ connector.CallbackConnector = (*openshiftConnector)(nil)
_ connector.RefreshConnector = (*openshiftConnector)(nil)
_ connector.CallbackConnector = (*openshiftConnector)(nil)
_ connector.RefreshConnector = (*openshiftConnector)(nil)
_ connector.CheckEndpointConnector = (*openshiftConnector)(nil)
)

type openshiftConnector struct {
Expand Down Expand Up @@ -78,7 +79,8 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
// OpenWithHTTPClient returns a connector which can be used to login users through an upstream
// OpenShift OAuth2 provider. It provides the ability to inject a http.Client.
func (c *Config) OpenWithHTTPClient(id string, logger log.Logger,
httpClient *http.Client) (conn connector.Connector, err error) {
httpClient *http.Client,
) (conn connector.Connector, err error) {
ctx, cancel := context.WithCancel(context.Background())

wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath
Expand Down Expand Up @@ -125,6 +127,7 @@ func (c *Config) OpenWithHTTPClient(id string, logger log.Logger,
Scopes: []string{"user:info"},
RedirectURL: c.RedirectURI,
}

return &openshiftConnector, nil
}

Expand Down Expand Up @@ -156,7 +159,8 @@ func (e *oauth2Error) Error() string {

// HandleCallback parses the request and returns the user's identity
func (c *openshiftConnector) HandleCallback(s connector.Scopes,
r *http.Request) (identity connector.Identity, err error) {
r *http.Request,
) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
Expand All @@ -176,7 +180,8 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes,
}

func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes,
oldID connector.Identity) (connector.Identity, error) {
oldID connector.Identity,
) (connector.Identity, error) {
var token oauth2.Token
err := json.Unmarshal(oldID.ConnectorData, &token)
if err != nil {
Expand All @@ -189,7 +194,8 @@ func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes,
}

func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes,
token *oauth2.Token) (identity connector.Identity, err error) {
token *oauth2.Token,
) (identity connector.Identity, err error) {
client := c.oauth2Config.Client(ctx, token)
user, err := c.user(ctx, client)
if err != nil {
Expand Down Expand Up @@ -253,6 +259,24 @@ func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u u
return u, err
}

func (c *openshiftConnector) CheckEndpoint(ctx context.Context) error {
return c.checkOAuth2EndpointsAvailability(ctx)
}

func (c *openshiftConnector) checkOAuth2EndpointsAvailability(ctx context.Context) error {
for _, endpoint := range []string{c.oauth2Config.Endpoint.AuthURL, c.oauth2Config.Endpoint.TokenURL} {
req, err := http.NewRequest(http.MethodHead, endpoint, nil)
if err != nil {
return err
}
_, err = c.httpClient.Do(req.WithContext(ctx))
if err != nil {
return err
}
}
return nil
}

func validateAllowedGroups(userGroups, allowedGroups []string) bool {
matchingGroups := groups.Filter(userGroups, allowedGroups)

Expand Down
24 changes: 24 additions & 0 deletions connector/openshift/openshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,30 @@ func TestRefreshIdentityFailure(t *testing.T) {
expectEquals(t, connector.Identity{}, identity)
}

func TestCheckEndpoint(t *testing.T) {
s := newTestServer(map[string]interface{}{
"/oauth/authorize": nil,
"/oauth/token": nil,
})
defer s.Close()

h, err := newHTTPClient(true, "")
expectNil(t, err)

oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: s.URL + "/oauth/authorize", TokenURL: s.URL + "/oauth/token",
},
}}
err = oc.CheckEndpoint(context.Background())
expectNil(t, err)

s.Close()

err = oc.CheckEndpoint(context.Background())
expectNotNil(t, err)
}

func newTestServer(responses map[string]interface{}) *httptest.Server {
var s *httptest.Server
s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 29b843c

Please sign in to comment.