diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 3d07f2ff62..4b4ddd1020 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -47,6 +47,9 @@ type Config struct { // querying the storage. Cannot be specified without enabling a passwords // database. StaticPasswords []password `json:"staticPasswords"` + + // URL base to use for public-facing links and redirects. Defaults to Issuer. + PublicURL string `json:"publicURL"` } //Validate the configuration diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 5c8732aaa3..a7b256afd4 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -244,6 +244,7 @@ func serve(cmd *cobra.Command, args []string) error { Logger: logger, Now: now, PrometheusRegistry: prometheusRegistry, + PublicURL: c.PublicURL, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/server/handlers.go b/server/handlers.go index 5512d87fb7..ace02294ff 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -164,10 +164,10 @@ type discovery struct { func (s *Server) discoveryHandler() (http.HandlerFunc, error) { d := discovery{ Issuer: s.issuerURL.String(), - Auth: s.absURL("/auth"), - Token: s.absURL("/token"), - Keys: s.absURL("/keys"), - UserInfo: s.absURL("/userinfo"), + Auth: s.absURL(s.publicURL, "/auth"), + Token: s.absURL(s.issuerURL, "/token"), + Keys: s.absURL(s.issuerURL, "/keys"), + UserInfo: s.absURL(s.issuerURL, "/userinfo"), Subjects: []string{"public"}, IDTokenAlgs: []string{string(jose.RS256)}, Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, @@ -241,7 +241,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { if authReq.ConnectorID != "" { for _, c := range connectors { if c.ID == authReq.ConnectorID { - http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) + http.Redirect(w, r, s.absPath(s.publicURL, "/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) return } } @@ -253,7 +253,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { for _, c := range connectors { // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter // on create the auth request. - http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) + http.Redirect(w, r, s.absPath(s.publicURL, "/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) return } } @@ -266,7 +266,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { Type: conn.Type, // TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter // on create the auth request. - URL: s.absPath("/auth", conn.ID) + "?req=" + authReq.ID, + URL: s.absPath(s.publicURL, "/auth", conn.ID) + "?req=" + authReq.ID, } } @@ -320,7 +320,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) + callbackURL, err := conn.LoginURL(scopes, s.absURL(s.publicURL, "/callback"), authReqID) if err != nil { s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") diff --git a/server/handlers_test.go b/server/handlers_test.go index b30076dd41..82eee82b1f 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -119,3 +119,28 @@ func TestHandleInvalidSAMLCallbacks(t *testing.T) { } } } + +func TestHandleDiscoveryPublic(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, server := newTestServer(ctx, t, func(c *Config) { + c.PublicURL = "https://dex.example.com/" + }) + defer httpServer.Close() + + rr := httptest.NewRecorder() + server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)) + if rr.Code != http.StatusOK { + t.Errorf("expected 200 got %d", rr.Code) + } + config := map[string]interface{}{} + err := json.Unmarshal(rr.Body.Bytes(), &config) + if err != nil { + t.Fatal(err.Error()) + } + authURL := config["authorization_endpoint"].(string) + if authURL != "https://dex.example.com/auth" { + t.Errorf("expected https://dex.example.com/auth got %s", authURL) + } +} diff --git a/server/server.go b/server/server.go index 09292b1672..1c8476ef32 100644 --- a/server/server.go +++ b/server/server.go @@ -90,6 +90,8 @@ type Config struct { Logger log.Logger PrometheusRegistry *prometheus.Registry + + PublicURL string } // WebConfig holds the server's frontend templates and asset configuration. @@ -130,6 +132,7 @@ func value(val, defaultValue time.Duration) time.Duration { // Server is the top level object. type Server struct { issuerURL url.URL + publicURL url.URL // mutex for the connectors map. mu sync.Mutex @@ -175,6 +178,16 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) return nil, fmt.Errorf("server: can't parse issuer URL") } + var publicURL *url.URL + if c.PublicURL == "" { + publicURL = issuerURL + } else { + publicURL, err = url.Parse(c.PublicURL) + if err != nil { + return nil, fmt.Errorf("server: can't parse public URL") + } + } + if c.Storage == nil { return nil, errors.New("server: storage cannot be nil") } @@ -224,6 +237,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) templates: tmpls, passwordConnector: c.PasswordConnector, logger: c.Logger, + publicURL: *publicURL, } // Retrieves connector objects in backend storage. This list includes the static connectors @@ -330,17 +344,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.mux.ServeHTTP(w, r) } -func (s *Server) absPath(pathItems ...string) string { +func (s *Server) absPath(base url.URL, pathItems ...string) string { paths := make([]string, len(pathItems)+1) - paths[0] = s.issuerURL.Path + paths[0] = base.Path copy(paths[1:], pathItems) return path.Join(paths...) } -func (s *Server) absURL(pathItems ...string) string { - u := s.issuerURL - u.Path = s.absPath(pathItems...) - return u.String() +func (s *Server) absURL(base url.URL, pathItems ...string) string { + base.Path = s.absPath(base, pathItems...) + return base.String() } func newPasswordDB(s storage.Storage) interface {