diff --git a/connector/connector.go b/connector/connector.go index aab994b468..409262bec5 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -98,3 +98,11 @@ type RefreshConnector interface { // changes since the token was last refreshed. Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error) } + +// CheckEndpointConnector is a connector that can test its endpoints for availability. +type CheckEndpointConnector interface { + // CheckEndpoint is called when a client wants to check the availability of + // endpoints used by the connector. If no error is returned, the connector + // can be assumed to be working. + CheckEndpoint(ctx context.Context) error +} diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index 05919973e3..f05de3d20c 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -38,8 +38,9 @@ type Config struct { } var ( - _ connector.CallbackConnector = (*openshiftConnector)(nil) - _ connector.RefreshConnector = (*openshiftConnector)(nil) + _ connector.CallbackConnector = (*openshiftConnector)(nil) + _ connector.RefreshConnector = (*openshiftConnector)(nil) + _ connector.CheckEndpointConnector = (*openshiftConnector)(nil) ) type openshiftConnector struct { @@ -78,7 +79,8 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e // OpenWithHTTPClient returns a connector which can be used to login users through an upstream // OpenShift OAuth2 provider. It provides the ability to inject a http.Client. func (c *Config) OpenWithHTTPClient(id string, logger log.Logger, - httpClient *http.Client) (conn connector.Connector, err error) { + httpClient *http.Client, +) (conn connector.Connector, err error) { ctx, cancel := context.WithCancel(context.Background()) wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath @@ -125,6 +127,7 @@ func (c *Config) OpenWithHTTPClient(id string, logger log.Logger, Scopes: []string{"user:info"}, RedirectURL: c.RedirectURI, } + return &openshiftConnector, nil } @@ -156,7 +159,8 @@ func (e *oauth2Error) Error() string { // HandleCallback parses the request and returns the user's identity func (c *openshiftConnector) HandleCallback(s connector.Scopes, - r *http.Request) (identity connector.Identity, err error) { + r *http.Request, +) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} @@ -176,7 +180,8 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, } func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, - oldID connector.Identity) (connector.Identity, error) { + oldID connector.Identity, +) (connector.Identity, error) { var token oauth2.Token err := json.Unmarshal(oldID.ConnectorData, &token) if err != nil { @@ -189,7 +194,8 @@ func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, } func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, - token *oauth2.Token) (identity connector.Identity, err error) { + token *oauth2.Token, +) (identity connector.Identity, err error) { client := c.oauth2Config.Client(ctx, token) user, err := c.user(ctx, client) if err != nil { @@ -253,6 +259,24 @@ func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u u return u, err } +func (c *openshiftConnector) CheckEndpoint(ctx context.Context) error { + return c.checkOAuth2EndpointsAvailability(ctx) +} + +func (c *openshiftConnector) checkOAuth2EndpointsAvailability(ctx context.Context) error { + for _, endpoint := range []string{c.oauth2Config.Endpoint.AuthURL, c.oauth2Config.Endpoint.TokenURL} { + req, err := http.NewRequest(http.MethodHead, endpoint, nil) + if err != nil { + return err + } + _, err = c.httpClient.Do(req.WithContext(ctx)) + if err != nil { + return err + } + } + return nil +} + func validateAllowedGroups(userGroups, allowedGroups []string) bool { matchingGroups := groups.Filter(userGroups, allowedGroups) diff --git a/connector/openshift/openshift_test.go b/connector/openshift/openshift_test.go index 6280b831de..b39cfa26a2 100644 --- a/connector/openshift/openshift_test.go +++ b/connector/openshift/openshift_test.go @@ -257,6 +257,30 @@ func TestRefreshIdentityFailure(t *testing.T) { expectEquals(t, connector.Identity{}, identity) } +func TestCheckEndpoint(t *testing.T) { + s := newTestServer(map[string]interface{}{ + "/oauth/authorize": nil, + "/oauth/token": nil, + }) + defer s.Close() + + h, err := newHTTPClient(true, "") + expectNil(t, err) + + oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: s.URL + "/oauth/authorize", TokenURL: s.URL + "/oauth/token", + }, + }} + err = oc.CheckEndpoint(context.Background()) + expectNil(t, err) + + s.Close() + + err = oc.CheckEndpoint(context.Background()) + expectNotNil(t, err) +} + func newTestServer(responses map[string]interface{}) *httptest.Server { var s *httptest.Server s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {