diff --git a/cmd/dex/config.go b/cmd/dex/config.go index dd6d2e2ab9..6b046e6361 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -51,6 +51,9 @@ type Config struct { // querying the storage. Cannot be specified without enabling a passwords // database. StaticPasswords []password `json:"staticPasswords"` + + // TOTP represents the configuration for two-factor authentication. + TOTP TOTP `json:"twoFactorAuthn"` } // Validate the configuration @@ -422,3 +425,10 @@ type RefreshToken struct { AbsoluteLifetime string `json:"absoluteLifetime"` ValidIfNotUsedFor string `json:"validIfNotUsedFor"` } + +type TOTP struct { + // Issuer is the name of the service (will be shown in the authenticator app). + Issuer string `json:"issuer"` + // Connectors is a list of connectors that will use TOTP. + Connectors []string `json:"connectors"` +} diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index c6d37cb03e..515e2c3ab1 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -343,6 +343,11 @@ expiry: idTokens: "25h" authRequests: "25h" +twoFactorAuthn: + issuer: dex + connectors: + - mock + logger: level: "debug" format: "json" @@ -432,6 +437,10 @@ logger: IDTokens: "25h", AuthRequests: "25h", }, + TOTP: TOTP{ + Issuer: "dex", + Connectors: []string{"mock"}, + }, Logger: Logger{ Level: slog.LevelDebug, Format: "json", diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 6fcca04da3..0562b0210c 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -302,6 +302,8 @@ func runServe(options serveOptions) error { Now: now, PrometheusRegistry: prometheusRegistry, HealthChecker: healthChecker, + TOTPIssuer: c.TOTP.Issuer, + TOTPConnectors: c.TOTP.Connectors, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 147597a265..2cc7e3b634 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -75,6 +75,12 @@ telemetry: http: 0.0.0.0:5558 # enableProfiling: true +# Configuration for the two-factor authentication +# twoFactorAuthn: +# issuer: "dex" +# connectors: +# - mock + # Uncomment this block to enable the gRPC API. This values MUST be different # from the HTTP endpoints. # grpc: diff --git a/go.mod b/go.mod index 890cc8dfe5..bf3260a252 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.22 github.com/oklog/run v1.1.0 github.com/pkg/errors v0.9.1 + github.com/pquerna/otp v1.4.0 github.com/prometheus/client_golang v1.19.1 github.com/russellhaering/goxmldsig v1.4.0 github.com/spf13/cobra v1.8.1 @@ -53,6 +54,7 @@ require ( github.com/agext/levenshtein v1.2.1 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect diff --git a/go.sum b/go.sum index da52911df8..16ec785b4e 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/beevik/etree v1.4.0 h1:oz1UedHRepuY3p4N5OjE0nK1WLCqtzHf25bxplKOHLs= github.com/beevik/etree v1.4.0/go.mod h1:cyWiXwGoasx60gHvtnEh5x8+uIjUVnjWqBvEnhnqKDA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -187,6 +189,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.4.0 h1:wZvl1TIVxKRThZIBiwOOHOGP/1+nZyWBil9Y2XNEDzg= +github.com/pquerna/otp v1.4.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/server/handlers.go b/server/handlers.go index 63cb612295..5c05cb0376 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -11,7 +11,6 @@ import ( "html/template" "net/http" "net/url" - "path" "sort" "strconv" "strings" @@ -514,6 +513,11 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, a.LoggedIn = true a.Claims = claims a.ConnectorData = identity.ConnectorData + + if !s.totp.enabledForConnector(a.ConnectorID) { + a.TOTPValidated = true + } + return a, nil } if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil { @@ -529,36 +533,11 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, "connector_id", authReq.ConnectorID, "username", claims.Username, "preferred_username", claims.PreferredUsername, "email", email, "groups", claims.Groups) - // we can skip the redirect to /approval and go ahead and send code if it's not required - if s.skipApproval && !authReq.ForceApprovalPrompt { - return "", true, nil - } - - // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original - // flow would be unable to poll for the result at the /approval endpoint - h := hmac.New(sha256.New, authReq.HMACKey) - h.Write([]byte(authReq.ID)) - mac := h.Sum(nil) - - returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + "&hmac=" + base64.RawURLEncoding.EncodeToString(mac) - _, ok := conn.(connector.RefreshConnector) - if !ok { - return returnURL, false, nil - } - - offlineAccessRequested := false - for _, scope := range authReq.Scopes { - if scope == scopeOfflineAccess { - offlineAccessRequested = true - break - } - } - if !offlineAccessRequested { - return returnURL, false, nil - } - // Try to retrieve an existing OfflineSession object for the corresponding user. - session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID) + // TODO(nabokihms): We create an offline session even if the offline access is not requested. + // In the future it will be possible to migrate to sessions. + // Sessions may contain attributes like approval status, etc. + _, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID) if err != nil { if err != storage.ErrNotFound { s.logger.ErrorContext(ctx, "failed to get offline session", "err", err) @@ -571,18 +550,25 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, ConnectorData: identity.ConnectorData, } + if s.totp.enabledForConnector(authReq.ConnectorID) { + generated, err := s.totp.generate(authReq.ConnectorID, identity.Email) + if err != nil { + s.logger.ErrorContext(ctx, "failed to generate totp for offline session", "err", err) + return "", false, err + } + offlineSessions.TOTP = generated.String() + } + // Create a new OfflineSession object for the user and add a reference object for // the newly received refreshtoken. if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { s.logger.ErrorContext(ctx, "failed to create offline session", "err", err) return "", false, err } - - return returnURL, false, nil } // Update existing OfflineSession obj with new RefreshTokenRef. - if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if err := s.storage.UpdateOfflineSessions(identity.UserID, authReq.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { if len(identity.ConnectorData) > 0 { old.ConnectorData = identity.ConnectorData } @@ -592,7 +578,32 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, return "", false, err } - return returnURL, false, nil + // we can skip the redirect to /approval and /totp and go ahead and send code if it's not required + if s.skipApproval && !authReq.ForceApprovalPrompt && !s.totp.enabledForConnector(authReq.ConnectorID) { + return "", true, nil + } + + // an HMAC is used here to ensure that the request ID is unpredictable, ensuring that an attacker who intercepted the original + // flow would be unable to poll for the result at the /approval endpoint + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + mac := h.Sum(nil) + + // Deep copy issuer URL to avoid modifying the global one. + returnURL, _ := url.Parse(s.issuerURL.String()) + values := returnURL.Query() + values.Set("req", authReq.ID) + values.Set("hmac", base64.RawURLEncoding.EncodeToString(mac)) + + if s.totp.enabledForConnector(authReq.ConnectorID) { + values.Set("state", identity.UserID) + returnURL = returnURL.JoinPath("totp") + } else { + returnURL = returnURL.JoinPath("approval") + } + + returnURL.RawQuery = values.Encode() + return returnURL.String(), false, nil } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { @@ -613,7 +624,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "Database error.") return } - if !authReq.LoggedIn { + if !authReq.LoggedIn || !authReq.TOTPValidated { s.logger.ErrorContext(r.Context(), "auth request does not have an identity for approval") s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.") return diff --git a/server/server.go b/server/server.go index 1cf71c5038..03f7798a0c 100644 --- a/server/server.go +++ b/server/server.go @@ -120,6 +120,9 @@ type Config struct { PrometheusRegistry *prometheus.Registry HealthChecker gosundheit.Health + + TOTPIssuer string + TOTPConnectors []string } // WebConfig holds the server's frontend templates and asset configuration. @@ -197,6 +200,8 @@ type Server struct { refreshTokenPolicy *RefreshTokenPolicy logger *slog.Logger + + totp *secondFactorAuthenticator } // NewServer constructs a server from the provided config. @@ -312,6 +317,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) now: now, templates: tmpls, passwordConnector: c.PasswordConnector, + totp: newSecondFactorAuthenticator(c.TOTPIssuer, c.TOTPConnectors), logger: c.Logger, } @@ -463,6 +469,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) // "authproxy" connector. handleFunc("/callback/{connector}", s.handleConnectorCallback) handleFunc("/approval", s.handleApproval) + handleFunc("/totp", s.handleTOTPVerify) handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !c.HealthChecker.IsHealthy() { s.renderError(r, w, http.StatusInternalServerError, "Health check failed.") diff --git a/server/templates.go b/server/templates.go index b77663e1f5..46bca92ec6 100644 --- a/server/templates.go +++ b/server/templates.go @@ -22,6 +22,7 @@ const ( tmplError = "error.html" tmplDevice = "device.html" tmplDeviceSuccess = "device_success.html" + tmplTOTPVerify = "totp_verify.html" ) var requiredTmpls = []string{ @@ -32,6 +33,7 @@ var requiredTmpls = []string{ tmplError, tmplDevice, tmplDeviceSuccess, + tmplTOTPVerify, } type templates struct { @@ -42,6 +44,7 @@ type templates struct { errorTmpl *template.Template deviceTmpl *template.Template deviceSuccessTmpl *template.Template + tmplTOTPVerify *template.Template } type webConfig struct { @@ -169,6 +172,7 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) { errorTmpl: tmpls.Lookup(tmplError), deviceTmpl: tmpls.Lookup(tmplDevice), deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess), + tmplTOTPVerify: tmpls.Lookup(tmplTOTPVerify), }, nil } @@ -282,6 +286,21 @@ func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, client return renderTemplate(w, t.deviceSuccessTmpl, data) } +func (t *templates) totpVerify(r *http.Request, w http.ResponseWriter, postURL, issuer, connector, qrCode string, lastWasInvalid bool) error { + if lastWasInvalid { + w.WriteHeader(http.StatusUnauthorized) + } + data := struct { + PostURL string + Invalid bool + Issuer string + Connector string + QRCode string + ReqPath string + }{postURL, lastWasInvalid, issuer, connector, qrCode, r.URL.Path} + return renderTemplate(w, t.tmplTOTPVerify, data) +} + func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo) error { sort.Sort(byName(connectors)) data := struct { diff --git a/server/totphandler.go b/server/totphandler.go new file mode 100644 index 0000000000..2d907afd60 --- /dev/null +++ b/server/totphandler.go @@ -0,0 +1,187 @@ +package server + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "image/png" + "net/http" + "strings" + + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + + "github.com/dexidp/dex/storage" +) + +func (s *Server) handleTOTPVerify(w http.ResponseWriter, r *http.Request) { + macEncoded := r.FormValue("hmac") + if macEncoded == "" { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.") + return + } + mac, err := base64.RawURLEncoding.DecodeString(macEncoded) + if err != nil { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.") + return + } + + authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get auth request", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + if !authReq.LoggedIn { + s.logger.ErrorContext(r.Context(), "auth request does not have an identity for TOTP verification") + s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.") + return + } + + // build expected hmac with secret key + h := hmac.New(sha256.New, authReq.HMACKey) + h.Write([]byte(authReq.ID)) + expectedMAC := h.Sum(nil) + // constant time comparison + if !hmac.Equal(mac, expectedMAC) { + s.renderError(r, w, http.StatusUnauthorized, "Unauthorized request.") + return + } + + offlineSession, err := s.storage.GetOfflineSessions(authReq.Claims.UserID, authReq.ConnectorID) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get offline session", "err", err, "connector_id", authReq.ConnectorID, "user_id", authReq.Claims.UserID) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + + // TODO(nabokihms): compose the redirect URL the right way + returnURL := strings.ReplaceAll(r.URL.String(), "/totp", "/approval") + if offlineSession.TOTP == "" || authReq.TOTPValidated { + http.Redirect(w, r, returnURL, http.StatusSeeOther) + return + } + + switch r.Method { + case http.MethodGet: + s.renderTOTPValidatePage(offlineSession, false, w, r) + return + case http.MethodPost: + password := r.FormValue("totp") + + generated, err := otp.NewKeyFromURL(offlineSession.TOTP) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to load TOTP QR code", "err", err, "connector_id", offlineSession.ConnID, "user_id", offlineSession.ConnID) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + + ok := totp.Validate(password, generated.Secret()) + if !ok { + s.renderTOTPValidatePage(offlineSession, true, w, r) + s.logger.ErrorContext(r.Context(), "failed TOTP attempt: Invalid credentials.", "user", "????") + return + } + + // If the TOTP is valid, update the offline session and auth request to reflect that. + if err := s.storage.UpdateOfflineSessions(offlineSession.UserID, offlineSession.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + old.TOTPConfirmed = true + return old, nil + }); err != nil { + s.logger.ErrorContext(r.Context(), "failed to update offline session", "err", err, "connector_id", offlineSession.ConnID, "user_id", offlineSession.ConnID) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + if err := s.storage.UpdateAuthRequest(authReq.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { + old.TOTPValidated = true + return old, nil + }); err != nil { + s.logger.ErrorContext(r.Context(), "failed to update auth request", "err", err, "auth_request_id", authReq.ID) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + + // we can skip the redirect to /approval and go ahead and send code if it's not required + if s.skipApproval && !authReq.ForceApprovalPrompt { + authReq, err = s.storage.GetAuthRequest(authReq.ID) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to get finalized auth request", "err", err) + s.renderError(r, w, http.StatusInternalServerError, "Login error.") + return + } + s.sendCodeResponse(w, r, authReq) + return + } + http.Redirect(w, r, returnURL, http.StatusSeeOther) + default: + s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.") + } +} + +// generateQRCode generates a QR code image for the OTP key. +// Returned value is a base64 encoded PNG image. +func generateQRCode(o storage.OfflineSessions) (string, error) { + generated, err := otp.NewKeyFromURL(o.TOTP) + if err != nil { + return "", fmt.Errorf("failed to load TOTP QR code: %w", err) + } + + qrCodeImage, err := generated.Image(300, 300) + if err != nil { + return "", fmt.Errorf("failed to generate TOTP QR code: %w", err) + } + + var buf bytes.Buffer + err = png.Encode(&buf, qrCodeImage) + if err != nil { + return "", fmt.Errorf("failed to encode TOTP QR code: %w", err) + } + + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil +} + +func (s *Server) renderTOTPValidatePage(o storage.OfflineSessions, lastFail bool, w http.ResponseWriter, r *http.Request) { + qrCode := "" + var err error + + // Show QR code only once when the offline session is registered + if o.TOTP != "" && !o.TOTPConfirmed { + qrCode, err = generateQRCode(o) + if err != nil { + s.logger.ErrorContext(r.Context(), "failed to generate QR code", "err", err, "connector_id", o.ConnID, "user_id", o.ConnID) + s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") + return + } + } + if err := s.templates.totpVerify(r, w, r.URL.String(), s.totp.issuer, o.ConnID, qrCode, lastFail); err != nil { + s.logger.ErrorContext(r.Context(), "server template error", "err", err) + } +} + +type secondFactorAuthenticator struct { + issuer string + // To check that TOTP is enabled for the connector. + connectors map[string]struct{} +} + +func newSecondFactorAuthenticator(issuer string, connectors []string) *secondFactorAuthenticator { + c := make(map[string]struct{}) + for _, conn := range connectors { + c[conn] = struct{}{} + } + return &secondFactorAuthenticator{issuer: issuer, connectors: c} +} + +func (s *secondFactorAuthenticator) generate(connID, email string) (*otp.Key, error) { + return totp.Generate(totp.GenerateOpts{ + Issuer: s.issuer, + AccountName: fmt.Sprintf("(%s) %s", connID, email), + }) +} + +func (s *secondFactorAuthenticator) enabledForConnector(connID string) bool { + _, ok := s.connectors[connID] + return ok +} diff --git a/server/totphandler_test.go b/server/totphandler_test.go new file mode 100644 index 0000000000..712400c317 --- /dev/null +++ b/server/totphandler_test.go @@ -0,0 +1,308 @@ +package server + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path" + "testing" + "time" + + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + + "github.com/dexidp/dex/storage" +) + +const testTOTPKey = "otpauth://totp/Example:user3?secret=JBSWY3DPEHPK3PXP&issuer=Example" + +func testNewHMAC(id, key string) string { + h := hmac.New(sha256.New, []byte(key)) + h.Write([]byte(id)) + return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) +} + +func testGenerateTOTPCode() string { + key, _ := otp.NewKeyFromURL(testTOTPKey) + code, _ := totp.GenerateCode(key.Secret(), time.Now()) + return code +} + +func TestHandleTOTPVerify(t *testing.T) { + tests := []struct { + testName string + authRequest storage.AuthRequest + offlineSession storage.OfflineSessions + values url.Values + expectedResponseCode int + expectedServerResponse string + }{ + { + testName: "Missing HMAC", + authRequest: storage.AuthRequest{ + ID: "authReq1", + LoggedIn: true, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user1"}, + ConnectorID: "conn1", + }, + offlineSession: storage.OfflineSessions{ + UserID: "user1", + ConnID: "conn1", + TOTP: "otpauth://totp/Example:user1?secret=JBSWY3DPEHPK3PXP&issuer=Example", + }, + values: url.Values(map[string][]string{ + "req": {"authReq1"}, + }), + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Already validated", + authRequest: storage.AuthRequest{ + ID: "authReq3", + LoggedIn: true, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user3"}, + ConnectorID: "conn3", + TOTPValidated: true, + }, + offlineSession: storage.OfflineSessions{ + UserID: "user3", + ConnID: "conn3", + TOTP: testTOTPKey, + }, + values: url.Values(map[string][]string{ + "req": {"authReq3"}, + "hmac": {testNewHMAC("authReq3", "secret")}, + }), + expectedResponseCode: http.StatusSeeOther, + }, + { + testName: "Not logged user", + authRequest: storage.AuthRequest{ + ID: "authReq100", + LoggedIn: false, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user1"}, + ConnectorID: "conn1", + }, + offlineSession: storage.OfflineSessions{ + UserID: "user1", + ConnID: "conn1", + TOTP: "otpauth://totp/Example:user1?secret=JBSWY3DPEHPK3PXP&issuer=Example", + }, + values: url.Values(map[string][]string{ + "req": {"authReq100"}, + }), + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Invalid HMAC", + authRequest: storage.AuthRequest{ + ID: "authReq2", + LoggedIn: true, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user2"}, + ConnectorID: "conn2", + }, + offlineSession: storage.OfflineSessions{ + UserID: "user2", + ConnID: "conn2", + TOTP: "otpauth://totp/Example:user2?secret=JBSWY3DPEHPK3PXP&issuer=Example", + }, + values: url.Values(map[string][]string{ + "req": {"authReq2"}, + "hmac": {base64.RawURLEncoding.EncodeToString([]byte("invalidvalidhmac"))}, + }), + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Redirect if no TOTP", + authRequest: storage.AuthRequest{ + ID: "authReq3", + LoggedIn: true, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user3"}, + ConnectorID: "conn3", + }, + offlineSession: storage.OfflineSessions{ + UserID: "user3", + ConnID: "conn3", + }, + values: url.Values(map[string][]string{ + "req": {"authReq3"}, + "hmac": {testNewHMAC("authReq3", "secret")}, + }), + expectedResponseCode: http.StatusSeeOther, + }, + { + testName: "Successful TOTP Verification page", + authRequest: storage.AuthRequest{ + ID: "authReq3", + LoggedIn: true, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user3"}, + ConnectorID: "conn3", + }, + offlineSession: storage.OfflineSessions{ + UserID: "user3", + ConnID: "conn3", + TOTP: testTOTPKey, + }, + values: url.Values(map[string][]string{ + "req": {"authReq3"}, + "hmac": {testNewHMAC("authReq3", "secret")}, + }), + expectedResponseCode: http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Now = time.Now + }) + defer httpServer.Close() + + if err := s.storage.CreateAuthRequest(context.TODO(), tc.authRequest); err != nil { + t.Fatalf("failed to create auth request: %v", err) + } + + if err := s.storage.CreateOfflineSessions(context.TODO(), tc.offlineSession); err != nil { + t.Fatalf("failed to create offline session: %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "totp") + u.RawQuery = tc.values.Encode() + req, _ := http.NewRequest("GET", u.String(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("%s: Unexpected Response Type. Expected %v got %v: %s", tc.testName, tc.expectedResponseCode, rr.Code, rr.Body.String()) + } + + if len(tc.expectedServerResponse) > 0 { + result, _ := io.ReadAll(rr.Body) + if string(result) != tc.expectedServerResponse { + t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result) + } + } + }) + } +} + +func TestHandleTOTPForm(t *testing.T) { + tests := []struct { + testName string + authRequest storage.AuthRequest + offlineSession storage.OfflineSessions + values url.Values + expectedResponseCode int + expectedServerResponse string + }{ + { + testName: "Successful TOTP Verification", + authRequest: storage.AuthRequest{ + ID: "authReq3", + LoggedIn: true, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user3"}, + ConnectorID: "conn3", + Expiry: time.Now().Add(time.Hour), + }, + offlineSession: storage.OfflineSessions{ + UserID: "user3", + ConnID: "conn3", + TOTP: testTOTPKey, + TOTPConfirmed: true, + }, + values: url.Values(map[string][]string{ + "req": {"authReq3"}, + "hmac": {testNewHMAC("authReq3", "secret")}, + "totp": {testGenerateTOTPCode()}, + }), + expectedResponseCode: http.StatusSeeOther, + }, + { + testName: "Unsuccessful TOTP Verification", + authRequest: storage.AuthRequest{ + ID: "authReq3", + LoggedIn: true, + HMACKey: []byte("secret"), + Claims: storage.Claims{UserID: "user3"}, + ConnectorID: "conn3", + Expiry: time.Now().Add(time.Hour), + }, + offlineSession: storage.OfflineSessions{ + UserID: "user3", + ConnID: "conn3", + TOTP: testTOTPKey, + TOTPConfirmed: true, + }, + values: url.Values(map[string][]string{ + "req": {"authReq3"}, + "hmac": {testNewHMAC("authReq3", "secret")}, + "totp": {"invalidpassword"}, + }), + expectedResponseCode: http.StatusUnauthorized, + }, + } + + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Now = time.Now + }) + defer httpServer.Close() + + if err := s.storage.CreateAuthRequest(context.TODO(), tc.authRequest); err != nil { + t.Fatalf("failed to create auth request: %v", err) + } + + if err := s.storage.CreateOfflineSessions(context.TODO(), tc.offlineSession); err != nil { + t.Fatalf("failed to create offline session: %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "totp") + u.RawQuery = tc.values.Encode() + req, _ := http.NewRequest("POST", u.String(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("%s: Unexpected Response Type. Expected %v got %v: %s %s", tc.testName, tc.expectedResponseCode, rr.Code, rr.Result().Header.Get("Location"), rr.Body.String()) + } + + if len(tc.expectedServerResponse) > 0 { + result, _ := io.ReadAll(rr.Body) + if string(result) != tc.expectedServerResponse { + t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result) + } + } + }) + } +} diff --git a/storage/ent/client/authrequest.go b/storage/ent/client/authrequest.go index 42db702d68..5be6413787 100644 --- a/storage/ent/client/authrequest.go +++ b/storage/ent/client/authrequest.go @@ -32,6 +32,7 @@ func (d *Database) CreateAuthRequest(ctx context.Context, authRequest storage.Au SetConnectorID(authRequest.ConnectorID). SetConnectorData(authRequest.ConnectorData). SetHmacKey(authRequest.HMACKey). + SetTotpValidated(authRequest.TOTPValidated). Save(ctx) if err != nil { return convertDBError("create auth request: %w", err) @@ -96,6 +97,7 @@ func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthReq SetConnectorID(newAuthRequest.ConnectorID). SetConnectorData(newAuthRequest.ConnectorData). SetHmacKey(newAuthRequest.HMACKey). + SetTotpValidated(newAuthRequest.TOTPValidated). Save(context.TODO()) if err != nil { return rollback(tx, "update auth request uploading: %w", err) diff --git a/storage/ent/client/offlinesession.go b/storage/ent/client/offlinesession.go index 22469eced9..44cdd4f61a 100644 --- a/storage/ent/client/offlinesession.go +++ b/storage/ent/client/offlinesession.go @@ -22,6 +22,8 @@ func (d *Database) CreateOfflineSessions(ctx context.Context, session storage.Of SetConnID(session.ConnID). SetConnectorData(session.ConnectorData). SetRefresh(encodedRefresh). + SetTotp(session.TOTP). + SetTotpConfirmed(session.TOTPConfirmed). Save(ctx) if err != nil { return convertDBError("create offline session: %w", err) @@ -80,6 +82,8 @@ func (d *Database) UpdateOfflineSessions(userID string, connID string, updater f SetConnID(newOfflineSession.ConnID). SetConnectorData(newOfflineSession.ConnectorData). SetRefresh(encodedRefresh). + SetTotp(newOfflineSession.TOTP). + SetTotpConfirmed(newOfflineSession.TOTPConfirmed). Save(context.TODO()) if err != nil { return rollback(tx, "update offline session uploading: %w", err) diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 397d4d30a2..81bca0ae33 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -45,7 +45,8 @@ func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, - HMACKey: a.HmacKey, + HMACKey: a.HmacKey, + TOTPValidated: a.TotpValidated, } } @@ -100,6 +101,8 @@ func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions { UserID: o.UserID, ConnID: o.ConnID, ConnectorData: *o.ConnectorData, + TOTP: o.Totp, + TOTPConfirmed: o.TotpConfirmed, } if o.Refresh != nil { diff --git a/storage/ent/db/authrequest.go b/storage/ent/db/authrequest.go index b95592e58c..a0f2545379 100644 --- a/storage/ent/db/authrequest.go +++ b/storage/ent/db/authrequest.go @@ -57,8 +57,10 @@ type AuthRequest struct { // CodeChallengeMethod holds the value of the "code_challenge_method" field. CodeChallengeMethod string `json:"code_challenge_method,omitempty"` // HmacKey holds the value of the "hmac_key" field. - HmacKey []byte `json:"hmac_key,omitempty"` - selectValues sql.SelectValues + HmacKey []byte `json:"hmac_key,omitempty"` + // TotpValidated holds the value of the "totp_validated" field. + TotpValidated bool `json:"totp_validated,omitempty"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -68,7 +70,7 @@ func (*AuthRequest) scanValues(columns []string) ([]any, error) { switch columns[i] { case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData, authrequest.FieldHmacKey: values[i] = new([]byte) - case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified: + case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified, authrequest.FieldTotpValidated: values[i] = new(sql.NullBool) case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod: values[i] = new(sql.NullString) @@ -221,6 +223,12 @@ func (ar *AuthRequest) assignValues(columns []string, values []any) error { } else if value != nil { ar.HmacKey = *value } + case authrequest.FieldTotpValidated: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field totp_validated", values[i]) + } else if value.Valid { + ar.TotpValidated = value.Bool + } default: ar.selectValues.Set(columns[i], values[i]) } @@ -318,6 +326,9 @@ func (ar *AuthRequest) String() string { builder.WriteString(", ") builder.WriteString("hmac_key=") builder.WriteString(fmt.Sprintf("%v", ar.HmacKey)) + builder.WriteString(", ") + builder.WriteString("totp_validated=") + builder.WriteString(fmt.Sprintf("%v", ar.TotpValidated)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authrequest/authrequest.go b/storage/ent/db/authrequest/authrequest.go index 0998c79932..5ed7e4e93c 100644 --- a/storage/ent/db/authrequest/authrequest.go +++ b/storage/ent/db/authrequest/authrequest.go @@ -51,6 +51,8 @@ const ( FieldCodeChallengeMethod = "code_challenge_method" // FieldHmacKey holds the string denoting the hmac_key field in the database. FieldHmacKey = "hmac_key" + // FieldTotpValidated holds the string denoting the totp_validated field in the database. + FieldTotpValidated = "totp_validated" // Table holds the table name of the authrequest in the database. Table = "auth_requests" ) @@ -78,6 +80,7 @@ var Columns = []string{ FieldCodeChallenge, FieldCodeChallengeMethod, FieldHmacKey, + FieldTotpValidated, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -97,6 +100,8 @@ var ( DefaultCodeChallenge string // DefaultCodeChallengeMethod holds the default value on creation for the "code_challenge_method" field. DefaultCodeChallengeMethod string + // DefaultTotpValidated holds the default value on creation for the "totp_validated" field. + DefaultTotpValidated bool // IDValidator is a validator for the "id" field. It is called by the builders before save. IDValidator func(string) error ) @@ -183,3 +188,8 @@ func ByCodeChallenge(opts ...sql.OrderTermOption) OrderOption { func ByCodeChallengeMethod(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCodeChallengeMethod, opts...).ToFunc() } + +// ByTotpValidated orders the results by the totp_validated field. +func ByTotpValidated(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpValidated, opts...).ToFunc() +} diff --git a/storage/ent/db/authrequest/where.go b/storage/ent/db/authrequest/where.go index 4d3a39bec5..1da8d249ee 100644 --- a/storage/ent/db/authrequest/where.go +++ b/storage/ent/db/authrequest/where.go @@ -149,6 +149,11 @@ func HmacKey(v []byte) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldEQ(FieldHmacKey, v)) } +// TotpValidated applies equality check predicate on the "totp_validated" field. It's identical to TotpValidatedEQ. +func TotpValidated(v bool) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldTotpValidated, v)) +} + // ClientIDEQ applies the EQ predicate on the "client_id" field. func ClientIDEQ(v string) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldEQ(FieldClientID, v)) @@ -1054,6 +1059,16 @@ func HmacKeyLTE(v []byte) predicate.AuthRequest { return predicate.AuthRequest(sql.FieldLTE(FieldHmacKey, v)) } +// TotpValidatedEQ applies the EQ predicate on the "totp_validated" field. +func TotpValidatedEQ(v bool) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldEQ(FieldTotpValidated, v)) +} + +// TotpValidatedNEQ applies the NEQ predicate on the "totp_validated" field. +func TotpValidatedNEQ(v bool) predicate.AuthRequest { + return predicate.AuthRequest(sql.FieldNEQ(FieldTotpValidated, v)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthRequest) predicate.AuthRequest { return predicate.AuthRequest(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/authrequest_create.go b/storage/ent/db/authrequest_create.go index 3fe0c2b1f7..62085c0ba9 100644 --- a/storage/ent/db/authrequest_create.go +++ b/storage/ent/db/authrequest_create.go @@ -164,6 +164,20 @@ func (arc *AuthRequestCreate) SetHmacKey(b []byte) *AuthRequestCreate { return arc } +// SetTotpValidated sets the "totp_validated" field. +func (arc *AuthRequestCreate) SetTotpValidated(b bool) *AuthRequestCreate { + arc.mutation.SetTotpValidated(b) + return arc +} + +// SetNillableTotpValidated sets the "totp_validated" field if the given value is not nil. +func (arc *AuthRequestCreate) SetNillableTotpValidated(b *bool) *AuthRequestCreate { + if b != nil { + arc.SetTotpValidated(*b) + } + return arc +} + // SetID sets the "id" field. func (arc *AuthRequestCreate) SetID(s string) *AuthRequestCreate { arc.mutation.SetID(s) @@ -217,6 +231,10 @@ func (arc *AuthRequestCreate) defaults() { v := authrequest.DefaultCodeChallengeMethod arc.mutation.SetCodeChallengeMethod(v) } + if _, ok := arc.mutation.TotpValidated(); !ok { + v := authrequest.DefaultTotpValidated + arc.mutation.SetTotpValidated(v) + } } // check runs all checks and user-defined validators on the builder. @@ -269,6 +287,9 @@ func (arc *AuthRequestCreate) check() error { if _, ok := arc.mutation.HmacKey(); !ok { return &ValidationError{Name: "hmac_key", err: errors.New(`db: missing required field "AuthRequest.hmac_key"`)} } + if _, ok := arc.mutation.TotpValidated(); !ok { + return &ValidationError{Name: "totp_validated", err: errors.New(`db: missing required field "AuthRequest.totp_validated"`)} + } if v, ok := arc.mutation.ID(); ok { if err := authrequest.IDValidator(v); err != nil { return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthRequest.id": %w`, err)} @@ -389,6 +410,10 @@ func (arc *AuthRequestCreate) createSpec() (*AuthRequest, *sqlgraph.CreateSpec) _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) _node.HmacKey = value } + if value, ok := arc.mutation.TotpValidated(); ok { + _spec.SetField(authrequest.FieldTotpValidated, field.TypeBool, value) + _node.TotpValidated = value + } return _node, _spec } diff --git a/storage/ent/db/authrequest_update.go b/storage/ent/db/authrequest_update.go index 0f314a4f51..ab40d5c499 100644 --- a/storage/ent/db/authrequest_update.go +++ b/storage/ent/db/authrequest_update.go @@ -311,6 +311,20 @@ func (aru *AuthRequestUpdate) SetHmacKey(b []byte) *AuthRequestUpdate { return aru } +// SetTotpValidated sets the "totp_validated" field. +func (aru *AuthRequestUpdate) SetTotpValidated(b bool) *AuthRequestUpdate { + aru.mutation.SetTotpValidated(b) + return aru +} + +// SetNillableTotpValidated sets the "totp_validated" field if the given value is not nil. +func (aru *AuthRequestUpdate) SetNillableTotpValidated(b *bool) *AuthRequestUpdate { + if b != nil { + aru.SetTotpValidated(*b) + } + return aru +} + // Mutation returns the AuthRequestMutation object of the builder. func (aru *AuthRequestUpdate) Mutation() *AuthRequestMutation { return aru.mutation @@ -439,6 +453,9 @@ func (aru *AuthRequestUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := aru.mutation.HmacKey(); ok { _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) } + if value, ok := aru.mutation.TotpValidated(); ok { + _spec.SetField(authrequest.FieldTotpValidated, field.TypeBool, value) + } if n, err = sqlgraph.UpdateNodes(ctx, aru.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authrequest.Label} @@ -741,6 +758,20 @@ func (aruo *AuthRequestUpdateOne) SetHmacKey(b []byte) *AuthRequestUpdateOne { return aruo } +// SetTotpValidated sets the "totp_validated" field. +func (aruo *AuthRequestUpdateOne) SetTotpValidated(b bool) *AuthRequestUpdateOne { + aruo.mutation.SetTotpValidated(b) + return aruo +} + +// SetNillableTotpValidated sets the "totp_validated" field if the given value is not nil. +func (aruo *AuthRequestUpdateOne) SetNillableTotpValidated(b *bool) *AuthRequestUpdateOne { + if b != nil { + aruo.SetTotpValidated(*b) + } + return aruo +} + // Mutation returns the AuthRequestMutation object of the builder. func (aruo *AuthRequestUpdateOne) Mutation() *AuthRequestMutation { return aruo.mutation @@ -899,6 +930,9 @@ func (aruo *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthReque if value, ok := aruo.mutation.HmacKey(); ok { _spec.SetField(authrequest.FieldHmacKey, field.TypeBytes, value) } + if value, ok := aruo.mutation.TotpValidated(); ok { + _spec.SetField(authrequest.FieldTotpValidated, field.TypeBool, value) + } _node = &AuthRequest{config: aruo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index d3295a0c79..ac9a80f437 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -56,6 +56,7 @@ var ( {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "hmac_key", Type: field.TypeBytes}, + {Name: "totp_validated", Type: field.TypeBool, Default: false}, } // AuthRequestsTable holds the schema information for the "auth_requests" table. AuthRequestsTable = &schema.Table{ @@ -148,6 +149,8 @@ var ( {Name: "conn_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "refresh", Type: field.TypeBytes}, {Name: "connector_data", Type: field.TypeBytes, Nullable: true}, + {Name: "totp", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "totp_confirmed", Type: field.TypeBool, Nullable: true, Default: false}, } // OfflineSessionsTable holds the schema information for the "offline_sessions" table. OfflineSessionsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 71203574e6..549c94d3ba 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -1258,6 +1258,7 @@ type AuthRequestMutation struct { code_challenge *string code_challenge_method *string hmac_key *[]byte + totp_validated *bool clearedFields map[string]struct{} done bool oldValue func(context.Context) (*AuthRequest, error) @@ -2188,6 +2189,42 @@ func (m *AuthRequestMutation) ResetHmacKey() { m.hmac_key = nil } +// SetTotpValidated sets the "totp_validated" field. +func (m *AuthRequestMutation) SetTotpValidated(b bool) { + m.totp_validated = &b +} + +// TotpValidated returns the value of the "totp_validated" field in the mutation. +func (m *AuthRequestMutation) TotpValidated() (r bool, exists bool) { + v := m.totp_validated + if v == nil { + return + } + return *v, true +} + +// OldTotpValidated returns the old "totp_validated" field's value of the AuthRequest entity. +// If the AuthRequest object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthRequestMutation) OldTotpValidated(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotpValidated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotpValidated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotpValidated: %w", err) + } + return oldValue.TotpValidated, nil +} + +// ResetTotpValidated resets all changes to the "totp_validated" field. +func (m *AuthRequestMutation) ResetTotpValidated() { + m.totp_validated = nil +} + // Where appends a list predicates to the AuthRequestMutation builder. func (m *AuthRequestMutation) Where(ps ...predicate.AuthRequest) { m.predicates = append(m.predicates, ps...) @@ -2222,7 +2259,7 @@ func (m *AuthRequestMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthRequestMutation) Fields() []string { - fields := make([]string, 0, 20) + fields := make([]string, 0, 21) if m.client_id != nil { fields = append(fields, authrequest.FieldClientID) } @@ -2283,6 +2320,9 @@ func (m *AuthRequestMutation) Fields() []string { if m.hmac_key != nil { fields = append(fields, authrequest.FieldHmacKey) } + if m.totp_validated != nil { + fields = append(fields, authrequest.FieldTotpValidated) + } return fields } @@ -2331,6 +2371,8 @@ func (m *AuthRequestMutation) Field(name string) (ent.Value, bool) { return m.CodeChallengeMethod() case authrequest.FieldHmacKey: return m.HmacKey() + case authrequest.FieldTotpValidated: + return m.TotpValidated() } return nil, false } @@ -2380,6 +2422,8 @@ func (m *AuthRequestMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldCodeChallengeMethod(ctx) case authrequest.FieldHmacKey: return m.OldHmacKey(ctx) + case authrequest.FieldTotpValidated: + return m.OldTotpValidated(ctx) } return nil, fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2529,6 +2573,13 @@ func (m *AuthRequestMutation) SetField(name string, value ent.Value) error { } m.SetHmacKey(v) return nil + case authrequest.FieldTotpValidated: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpValidated(v) + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } @@ -2665,6 +2716,9 @@ func (m *AuthRequestMutation) ResetField(name string) error { case authrequest.FieldHmacKey: m.ResetHmacKey() return nil + case authrequest.FieldTotpValidated: + m.ResetTotpValidated() + return nil } return fmt.Errorf("unknown AuthRequest field %s", name) } @@ -5805,6 +5859,8 @@ type OfflineSessionMutation struct { conn_id *string refresh *[]byte connector_data *[]byte + totp *string + totp_confirmed *bool clearedFields map[string]struct{} done bool oldValue func(context.Context) (*OfflineSession, error) @@ -6072,6 +6128,104 @@ func (m *OfflineSessionMutation) ResetConnectorData() { delete(m.clearedFields, offlinesession.FieldConnectorData) } +// SetTotp sets the "totp" field. +func (m *OfflineSessionMutation) SetTotp(s string) { + m.totp = &s +} + +// Totp returns the value of the "totp" field in the mutation. +func (m *OfflineSessionMutation) Totp() (r string, exists bool) { + v := m.totp + if v == nil { + return + } + return *v, true +} + +// OldTotp returns the old "totp" field's value of the OfflineSession entity. +// If the OfflineSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *OfflineSessionMutation) OldTotp(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotp is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotp requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotp: %w", err) + } + return oldValue.Totp, nil +} + +// ClearTotp clears the value of the "totp" field. +func (m *OfflineSessionMutation) ClearTotp() { + m.totp = nil + m.clearedFields[offlinesession.FieldTotp] = struct{}{} +} + +// TotpCleared returns if the "totp" field was cleared in this mutation. +func (m *OfflineSessionMutation) TotpCleared() bool { + _, ok := m.clearedFields[offlinesession.FieldTotp] + return ok +} + +// ResetTotp resets all changes to the "totp" field. +func (m *OfflineSessionMutation) ResetTotp() { + m.totp = nil + delete(m.clearedFields, offlinesession.FieldTotp) +} + +// SetTotpConfirmed sets the "totp_confirmed" field. +func (m *OfflineSessionMutation) SetTotpConfirmed(b bool) { + m.totp_confirmed = &b +} + +// TotpConfirmed returns the value of the "totp_confirmed" field in the mutation. +func (m *OfflineSessionMutation) TotpConfirmed() (r bool, exists bool) { + v := m.totp_confirmed + if v == nil { + return + } + return *v, true +} + +// OldTotpConfirmed returns the old "totp_confirmed" field's value of the OfflineSession entity. +// If the OfflineSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *OfflineSessionMutation) OldTotpConfirmed(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotpConfirmed is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotpConfirmed requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotpConfirmed: %w", err) + } + return oldValue.TotpConfirmed, nil +} + +// ClearTotpConfirmed clears the value of the "totp_confirmed" field. +func (m *OfflineSessionMutation) ClearTotpConfirmed() { + m.totp_confirmed = nil + m.clearedFields[offlinesession.FieldTotpConfirmed] = struct{}{} +} + +// TotpConfirmedCleared returns if the "totp_confirmed" field was cleared in this mutation. +func (m *OfflineSessionMutation) TotpConfirmedCleared() bool { + _, ok := m.clearedFields[offlinesession.FieldTotpConfirmed] + return ok +} + +// ResetTotpConfirmed resets all changes to the "totp_confirmed" field. +func (m *OfflineSessionMutation) ResetTotpConfirmed() { + m.totp_confirmed = nil + delete(m.clearedFields, offlinesession.FieldTotpConfirmed) +} + // Where appends a list predicates to the OfflineSessionMutation builder. func (m *OfflineSessionMutation) Where(ps ...predicate.OfflineSession) { m.predicates = append(m.predicates, ps...) @@ -6106,7 +6260,7 @@ func (m *OfflineSessionMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *OfflineSessionMutation) Fields() []string { - fields := make([]string, 0, 4) + fields := make([]string, 0, 6) if m.user_id != nil { fields = append(fields, offlinesession.FieldUserID) } @@ -6119,6 +6273,12 @@ func (m *OfflineSessionMutation) Fields() []string { if m.connector_data != nil { fields = append(fields, offlinesession.FieldConnectorData) } + if m.totp != nil { + fields = append(fields, offlinesession.FieldTotp) + } + if m.totp_confirmed != nil { + fields = append(fields, offlinesession.FieldTotpConfirmed) + } return fields } @@ -6135,6 +6295,10 @@ func (m *OfflineSessionMutation) Field(name string) (ent.Value, bool) { return m.Refresh() case offlinesession.FieldConnectorData: return m.ConnectorData() + case offlinesession.FieldTotp: + return m.Totp() + case offlinesession.FieldTotpConfirmed: + return m.TotpConfirmed() } return nil, false } @@ -6152,6 +6316,10 @@ func (m *OfflineSessionMutation) OldField(ctx context.Context, name string) (ent return m.OldRefresh(ctx) case offlinesession.FieldConnectorData: return m.OldConnectorData(ctx) + case offlinesession.FieldTotp: + return m.OldTotp(ctx) + case offlinesession.FieldTotpConfirmed: + return m.OldTotpConfirmed(ctx) } return nil, fmt.Errorf("unknown OfflineSession field %s", name) } @@ -6189,6 +6357,20 @@ func (m *OfflineSessionMutation) SetField(name string, value ent.Value) error { } m.SetConnectorData(v) return nil + case offlinesession.FieldTotp: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotp(v) + return nil + case offlinesession.FieldTotpConfirmed: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotpConfirmed(v) + return nil } return fmt.Errorf("unknown OfflineSession field %s", name) } @@ -6222,6 +6404,12 @@ func (m *OfflineSessionMutation) ClearedFields() []string { if m.FieldCleared(offlinesession.FieldConnectorData) { fields = append(fields, offlinesession.FieldConnectorData) } + if m.FieldCleared(offlinesession.FieldTotp) { + fields = append(fields, offlinesession.FieldTotp) + } + if m.FieldCleared(offlinesession.FieldTotpConfirmed) { + fields = append(fields, offlinesession.FieldTotpConfirmed) + } return fields } @@ -6239,6 +6427,12 @@ func (m *OfflineSessionMutation) ClearField(name string) error { case offlinesession.FieldConnectorData: m.ClearConnectorData() return nil + case offlinesession.FieldTotp: + m.ClearTotp() + return nil + case offlinesession.FieldTotpConfirmed: + m.ClearTotpConfirmed() + return nil } return fmt.Errorf("unknown OfflineSession nullable field %s", name) } @@ -6259,6 +6453,12 @@ func (m *OfflineSessionMutation) ResetField(name string) error { case offlinesession.FieldConnectorData: m.ResetConnectorData() return nil + case offlinesession.FieldTotp: + m.ResetTotp() + return nil + case offlinesession.FieldTotpConfirmed: + m.ResetTotpConfirmed() + return nil } return fmt.Errorf("unknown OfflineSession field %s", name) } diff --git a/storage/ent/db/offlinesession.go b/storage/ent/db/offlinesession.go index 7adc3afca3..611636cf83 100644 --- a/storage/ent/db/offlinesession.go +++ b/storage/ent/db/offlinesession.go @@ -24,6 +24,10 @@ type OfflineSession struct { Refresh []byte `json:"refresh,omitempty"` // ConnectorData holds the value of the "connector_data" field. ConnectorData *[]byte `json:"connector_data,omitempty"` + // Totp holds the value of the "totp" field. + Totp string `json:"totp,omitempty"` + // TotpConfirmed holds the value of the "totp_confirmed" field. + TotpConfirmed bool `json:"totp_confirmed,omitempty"` selectValues sql.SelectValues } @@ -34,7 +38,9 @@ func (*OfflineSession) scanValues(columns []string) ([]any, error) { switch columns[i] { case offlinesession.FieldRefresh, offlinesession.FieldConnectorData: values[i] = new([]byte) - case offlinesession.FieldID, offlinesession.FieldUserID, offlinesession.FieldConnID: + case offlinesession.FieldTotpConfirmed: + values[i] = new(sql.NullBool) + case offlinesession.FieldID, offlinesession.FieldUserID, offlinesession.FieldConnID, offlinesession.FieldTotp: values[i] = new(sql.NullString) default: values[i] = new(sql.UnknownType) @@ -81,6 +87,18 @@ func (os *OfflineSession) assignValues(columns []string, values []any) error { } else if value != nil { os.ConnectorData = value } + case offlinesession.FieldTotp: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field totp", values[i]) + } else if value.Valid { + os.Totp = value.String + } + case offlinesession.FieldTotpConfirmed: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field totp_confirmed", values[i]) + } else if value.Valid { + os.TotpConfirmed = value.Bool + } default: os.selectValues.Set(columns[i], values[i]) } @@ -130,6 +148,12 @@ func (os *OfflineSession) String() string { builder.WriteString("connector_data=") builder.WriteString(fmt.Sprintf("%v", *v)) } + builder.WriteString(", ") + builder.WriteString("totp=") + builder.WriteString(os.Totp) + builder.WriteString(", ") + builder.WriteString("totp_confirmed=") + builder.WriteString(fmt.Sprintf("%v", os.TotpConfirmed)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/offlinesession/offlinesession.go b/storage/ent/db/offlinesession/offlinesession.go index e7dbc446b7..fecf7ccad6 100644 --- a/storage/ent/db/offlinesession/offlinesession.go +++ b/storage/ent/db/offlinesession/offlinesession.go @@ -19,6 +19,10 @@ const ( FieldRefresh = "refresh" // FieldConnectorData holds the string denoting the connector_data field in the database. FieldConnectorData = "connector_data" + // FieldTotp holds the string denoting the totp field in the database. + FieldTotp = "totp" + // FieldTotpConfirmed holds the string denoting the totp_confirmed field in the database. + FieldTotpConfirmed = "totp_confirmed" // Table holds the table name of the offlinesession in the database. Table = "offline_sessions" ) @@ -30,6 +34,8 @@ var Columns = []string{ FieldConnID, FieldRefresh, FieldConnectorData, + FieldTotp, + FieldTotpConfirmed, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -47,6 +53,8 @@ var ( UserIDValidator func(string) error // ConnIDValidator is a validator for the "conn_id" field. It is called by the builders before save. ConnIDValidator func(string) error + // DefaultTotpConfirmed holds the default value on creation for the "totp_confirmed" field. + DefaultTotpConfirmed bool // IDValidator is a validator for the "id" field. It is called by the builders before save. IDValidator func(string) error ) @@ -68,3 +76,13 @@ func ByUserID(opts ...sql.OrderTermOption) OrderOption { func ByConnID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldConnID, opts...).ToFunc() } + +// ByTotp orders the results by the totp field. +func ByTotp(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotp, opts...).ToFunc() +} + +// ByTotpConfirmed orders the results by the totp_confirmed field. +func ByTotpConfirmed(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotpConfirmed, opts...).ToFunc() +} diff --git a/storage/ent/db/offlinesession/where.go b/storage/ent/db/offlinesession/where.go index e0f19ab2ce..d30f0bbbf1 100644 --- a/storage/ent/db/offlinesession/where.go +++ b/storage/ent/db/offlinesession/where.go @@ -82,6 +82,16 @@ func ConnectorData(v []byte) predicate.OfflineSession { return predicate.OfflineSession(sql.FieldEQ(FieldConnectorData, v)) } +// Totp applies equality check predicate on the "totp" field. It's identical to TotpEQ. +func Totp(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldEQ(FieldTotp, v)) +} + +// TotpConfirmed applies equality check predicate on the "totp_confirmed" field. It's identical to TotpConfirmedEQ. +func TotpConfirmed(v bool) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldEQ(FieldTotpConfirmed, v)) +} + // UserIDEQ applies the EQ predicate on the "user_id" field. func UserIDEQ(v string) predicate.OfflineSession { return predicate.OfflineSession(sql.FieldEQ(FieldUserID, v)) @@ -302,6 +312,101 @@ func ConnectorDataNotNil() predicate.OfflineSession { return predicate.OfflineSession(sql.FieldNotNull(FieldConnectorData)) } +// TotpEQ applies the EQ predicate on the "totp" field. +func TotpEQ(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldEQ(FieldTotp, v)) +} + +// TotpNEQ applies the NEQ predicate on the "totp" field. +func TotpNEQ(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldNEQ(FieldTotp, v)) +} + +// TotpIn applies the In predicate on the "totp" field. +func TotpIn(vs ...string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldIn(FieldTotp, vs...)) +} + +// TotpNotIn applies the NotIn predicate on the "totp" field. +func TotpNotIn(vs ...string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldNotIn(FieldTotp, vs...)) +} + +// TotpGT applies the GT predicate on the "totp" field. +func TotpGT(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldGT(FieldTotp, v)) +} + +// TotpGTE applies the GTE predicate on the "totp" field. +func TotpGTE(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldGTE(FieldTotp, v)) +} + +// TotpLT applies the LT predicate on the "totp" field. +func TotpLT(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldLT(FieldTotp, v)) +} + +// TotpLTE applies the LTE predicate on the "totp" field. +func TotpLTE(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldLTE(FieldTotp, v)) +} + +// TotpContains applies the Contains predicate on the "totp" field. +func TotpContains(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldContains(FieldTotp, v)) +} + +// TotpHasPrefix applies the HasPrefix predicate on the "totp" field. +func TotpHasPrefix(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldHasPrefix(FieldTotp, v)) +} + +// TotpHasSuffix applies the HasSuffix predicate on the "totp" field. +func TotpHasSuffix(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldHasSuffix(FieldTotp, v)) +} + +// TotpIsNil applies the IsNil predicate on the "totp" field. +func TotpIsNil() predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldIsNull(FieldTotp)) +} + +// TotpNotNil applies the NotNil predicate on the "totp" field. +func TotpNotNil() predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldNotNull(FieldTotp)) +} + +// TotpEqualFold applies the EqualFold predicate on the "totp" field. +func TotpEqualFold(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldEqualFold(FieldTotp, v)) +} + +// TotpContainsFold applies the ContainsFold predicate on the "totp" field. +func TotpContainsFold(v string) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldContainsFold(FieldTotp, v)) +} + +// TotpConfirmedEQ applies the EQ predicate on the "totp_confirmed" field. +func TotpConfirmedEQ(v bool) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldEQ(FieldTotpConfirmed, v)) +} + +// TotpConfirmedNEQ applies the NEQ predicate on the "totp_confirmed" field. +func TotpConfirmedNEQ(v bool) predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldNEQ(FieldTotpConfirmed, v)) +} + +// TotpConfirmedIsNil applies the IsNil predicate on the "totp_confirmed" field. +func TotpConfirmedIsNil() predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldIsNull(FieldTotpConfirmed)) +} + +// TotpConfirmedNotNil applies the NotNil predicate on the "totp_confirmed" field. +func TotpConfirmedNotNil() predicate.OfflineSession { + return predicate.OfflineSession(sql.FieldNotNull(FieldTotpConfirmed)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.OfflineSession) predicate.OfflineSession { return predicate.OfflineSession(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/offlinesession_create.go b/storage/ent/db/offlinesession_create.go index b8250aac8d..2fa25d317d 100644 --- a/storage/ent/db/offlinesession_create.go +++ b/storage/ent/db/offlinesession_create.go @@ -43,6 +43,34 @@ func (osc *OfflineSessionCreate) SetConnectorData(b []byte) *OfflineSessionCreat return osc } +// SetTotp sets the "totp" field. +func (osc *OfflineSessionCreate) SetTotp(s string) *OfflineSessionCreate { + osc.mutation.SetTotp(s) + return osc +} + +// SetNillableTotp sets the "totp" field if the given value is not nil. +func (osc *OfflineSessionCreate) SetNillableTotp(s *string) *OfflineSessionCreate { + if s != nil { + osc.SetTotp(*s) + } + return osc +} + +// SetTotpConfirmed sets the "totp_confirmed" field. +func (osc *OfflineSessionCreate) SetTotpConfirmed(b bool) *OfflineSessionCreate { + osc.mutation.SetTotpConfirmed(b) + return osc +} + +// SetNillableTotpConfirmed sets the "totp_confirmed" field if the given value is not nil. +func (osc *OfflineSessionCreate) SetNillableTotpConfirmed(b *bool) *OfflineSessionCreate { + if b != nil { + osc.SetTotpConfirmed(*b) + } + return osc +} + // SetID sets the "id" field. func (osc *OfflineSessionCreate) SetID(s string) *OfflineSessionCreate { osc.mutation.SetID(s) @@ -56,6 +84,7 @@ func (osc *OfflineSessionCreate) Mutation() *OfflineSessionMutation { // Save creates the OfflineSession in the database. func (osc *OfflineSessionCreate) Save(ctx context.Context) (*OfflineSession, error) { + osc.defaults() return withHooks(ctx, osc.sqlSave, osc.mutation, osc.hooks) } @@ -81,6 +110,14 @@ func (osc *OfflineSessionCreate) ExecX(ctx context.Context) { } } +// defaults sets the default values of the builder before save. +func (osc *OfflineSessionCreate) defaults() { + if _, ok := osc.mutation.TotpConfirmed(); !ok { + v := offlinesession.DefaultTotpConfirmed + osc.mutation.SetTotpConfirmed(v) + } +} + // check runs all checks and user-defined validators on the builder. func (osc *OfflineSessionCreate) check() error { if _, ok := osc.mutation.UserID(); !ok { @@ -158,6 +195,14 @@ func (osc *OfflineSessionCreate) createSpec() (*OfflineSession, *sqlgraph.Create _spec.SetField(offlinesession.FieldConnectorData, field.TypeBytes, value) _node.ConnectorData = &value } + if value, ok := osc.mutation.Totp(); ok { + _spec.SetField(offlinesession.FieldTotp, field.TypeString, value) + _node.Totp = value + } + if value, ok := osc.mutation.TotpConfirmed(); ok { + _spec.SetField(offlinesession.FieldTotpConfirmed, field.TypeBool, value) + _node.TotpConfirmed = value + } return _node, _spec } @@ -179,6 +224,7 @@ func (oscb *OfflineSessionCreateBulk) Save(ctx context.Context) ([]*OfflineSessi for i := range oscb.builders { func(i int, root context.Context) { builder := oscb.builders[i] + builder.defaults() var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { mutation, ok := m.(*OfflineSessionMutation) if !ok { diff --git a/storage/ent/db/offlinesession_update.go b/storage/ent/db/offlinesession_update.go index d912acf1a9..76f77c603d 100644 --- a/storage/ent/db/offlinesession_update.go +++ b/storage/ent/db/offlinesession_update.go @@ -73,6 +73,46 @@ func (osu *OfflineSessionUpdate) ClearConnectorData() *OfflineSessionUpdate { return osu } +// SetTotp sets the "totp" field. +func (osu *OfflineSessionUpdate) SetTotp(s string) *OfflineSessionUpdate { + osu.mutation.SetTotp(s) + return osu +} + +// SetNillableTotp sets the "totp" field if the given value is not nil. +func (osu *OfflineSessionUpdate) SetNillableTotp(s *string) *OfflineSessionUpdate { + if s != nil { + osu.SetTotp(*s) + } + return osu +} + +// ClearTotp clears the value of the "totp" field. +func (osu *OfflineSessionUpdate) ClearTotp() *OfflineSessionUpdate { + osu.mutation.ClearTotp() + return osu +} + +// SetTotpConfirmed sets the "totp_confirmed" field. +func (osu *OfflineSessionUpdate) SetTotpConfirmed(b bool) *OfflineSessionUpdate { + osu.mutation.SetTotpConfirmed(b) + return osu +} + +// SetNillableTotpConfirmed sets the "totp_confirmed" field if the given value is not nil. +func (osu *OfflineSessionUpdate) SetNillableTotpConfirmed(b *bool) *OfflineSessionUpdate { + if b != nil { + osu.SetTotpConfirmed(*b) + } + return osu +} + +// ClearTotpConfirmed clears the value of the "totp_confirmed" field. +func (osu *OfflineSessionUpdate) ClearTotpConfirmed() *OfflineSessionUpdate { + osu.mutation.ClearTotpConfirmed() + return osu +} + // Mutation returns the OfflineSessionMutation object of the builder. func (osu *OfflineSessionUpdate) Mutation() *OfflineSessionMutation { return osu.mutation @@ -147,6 +187,18 @@ func (osu *OfflineSessionUpdate) sqlSave(ctx context.Context) (n int, err error) if osu.mutation.ConnectorDataCleared() { _spec.ClearField(offlinesession.FieldConnectorData, field.TypeBytes) } + if value, ok := osu.mutation.Totp(); ok { + _spec.SetField(offlinesession.FieldTotp, field.TypeString, value) + } + if osu.mutation.TotpCleared() { + _spec.ClearField(offlinesession.FieldTotp, field.TypeString) + } + if value, ok := osu.mutation.TotpConfirmed(); ok { + _spec.SetField(offlinesession.FieldTotpConfirmed, field.TypeBool, value) + } + if osu.mutation.TotpConfirmedCleared() { + _spec.ClearField(offlinesession.FieldTotpConfirmed, field.TypeBool) + } if n, err = sqlgraph.UpdateNodes(ctx, osu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{offlinesession.Label} @@ -213,6 +265,46 @@ func (osuo *OfflineSessionUpdateOne) ClearConnectorData() *OfflineSessionUpdateO return osuo } +// SetTotp sets the "totp" field. +func (osuo *OfflineSessionUpdateOne) SetTotp(s string) *OfflineSessionUpdateOne { + osuo.mutation.SetTotp(s) + return osuo +} + +// SetNillableTotp sets the "totp" field if the given value is not nil. +func (osuo *OfflineSessionUpdateOne) SetNillableTotp(s *string) *OfflineSessionUpdateOne { + if s != nil { + osuo.SetTotp(*s) + } + return osuo +} + +// ClearTotp clears the value of the "totp" field. +func (osuo *OfflineSessionUpdateOne) ClearTotp() *OfflineSessionUpdateOne { + osuo.mutation.ClearTotp() + return osuo +} + +// SetTotpConfirmed sets the "totp_confirmed" field. +func (osuo *OfflineSessionUpdateOne) SetTotpConfirmed(b bool) *OfflineSessionUpdateOne { + osuo.mutation.SetTotpConfirmed(b) + return osuo +} + +// SetNillableTotpConfirmed sets the "totp_confirmed" field if the given value is not nil. +func (osuo *OfflineSessionUpdateOne) SetNillableTotpConfirmed(b *bool) *OfflineSessionUpdateOne { + if b != nil { + osuo.SetTotpConfirmed(*b) + } + return osuo +} + +// ClearTotpConfirmed clears the value of the "totp_confirmed" field. +func (osuo *OfflineSessionUpdateOne) ClearTotpConfirmed() *OfflineSessionUpdateOne { + osuo.mutation.ClearTotpConfirmed() + return osuo +} + // Mutation returns the OfflineSessionMutation object of the builder. func (osuo *OfflineSessionUpdateOne) Mutation() *OfflineSessionMutation { return osuo.mutation @@ -317,6 +409,18 @@ func (osuo *OfflineSessionUpdateOne) sqlSave(ctx context.Context) (_node *Offlin if osuo.mutation.ConnectorDataCleared() { _spec.ClearField(offlinesession.FieldConnectorData, field.TypeBytes) } + if value, ok := osuo.mutation.Totp(); ok { + _spec.SetField(offlinesession.FieldTotp, field.TypeString, value) + } + if osuo.mutation.TotpCleared() { + _spec.ClearField(offlinesession.FieldTotp, field.TypeString) + } + if value, ok := osuo.mutation.TotpConfirmed(); ok { + _spec.SetField(offlinesession.FieldTotpConfirmed, field.TypeBool, value) + } + if osuo.mutation.TotpConfirmedCleared() { + _spec.ClearField(offlinesession.FieldTotpConfirmed, field.TypeBool) + } _node = &OfflineSession{config: osuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index 797c97613b..d27201527b 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -82,6 +82,10 @@ func init() { authrequestDescCodeChallengeMethod := authrequestFields[19].Descriptor() // authrequest.DefaultCodeChallengeMethod holds the default value on creation for the code_challenge_method field. authrequest.DefaultCodeChallengeMethod = authrequestDescCodeChallengeMethod.Default.(string) + // authrequestDescTotpValidated is the schema descriptor for totp_validated field. + authrequestDescTotpValidated := authrequestFields[21].Descriptor() + // authrequest.DefaultTotpValidated holds the default value on creation for the totp_validated field. + authrequest.DefaultTotpValidated = authrequestDescTotpValidated.Default.(bool) // authrequestDescID is the schema descriptor for id field. authrequestDescID := authrequestFields[0].Descriptor() // authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save. @@ -198,6 +202,10 @@ func init() { offlinesessionDescConnID := offlinesessionFields[2].Descriptor() // offlinesession.ConnIDValidator is a validator for the "conn_id" field. It is called by the builders before save. offlinesession.ConnIDValidator = offlinesessionDescConnID.Validators[0].(func(string) error) + // offlinesessionDescTotpConfirmed is the schema descriptor for totp_confirmed field. + offlinesessionDescTotpConfirmed := offlinesessionFields[6].Descriptor() + // offlinesession.DefaultTotpConfirmed holds the default value on creation for the totp_confirmed field. + offlinesession.DefaultTotpConfirmed = offlinesessionDescTotpConfirmed.Default.(bool) // offlinesessionDescID is the schema descriptor for id field. offlinesessionDescID := offlinesessionFields[0].Descriptor() // offlinesession.IDValidator is a validator for the "id" field. It is called by the builders before save. diff --git a/storage/ent/schema/authrequest.go b/storage/ent/schema/authrequest.go index 2b75927b6f..7cdf36773f 100644 --- a/storage/ent/schema/authrequest.go +++ b/storage/ent/schema/authrequest.go @@ -88,6 +88,7 @@ func (AuthRequest) Fields() []ent.Field { SchemaType(textSchema). Default(""), field.Bytes("hmac_key"), + field.Bool("totp_validated").Default(false), } } diff --git a/storage/ent/schema/offlinesession.go b/storage/ent/schema/offlinesession.go index e9a166c344..433a7f0169 100644 --- a/storage/ent/schema/offlinesession.go +++ b/storage/ent/schema/offlinesession.go @@ -37,6 +37,8 @@ func (OfflineSession) Fields() []ent.Field { NotEmpty(), field.Bytes("refresh"), field.Bytes("connector_data").Nillable().Optional(), + field.Text("totp").Optional(), + field.Bool("totp_confirmed").Default(false).Optional(), } } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index b3756604dd..acf1d6a7fd 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -86,6 +86,8 @@ type AuthRequest struct { CodeChallengeMethod string `json:"code_challenge_method,omitempty"` HMACKey []byte `json:"hmac_key"` + + TOTPValidated bool `json:"totp_validated,omitempty"` } func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { @@ -106,6 +108,7 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, HMACKey: a.HMACKey, + TOTPValidated: a.TOTPValidated, } } @@ -128,7 +131,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { CodeChallenge: a.CodeChallenge, CodeChallengeMethod: a.CodeChallengeMethod, }, - HMACKey: a.HMACKey, + HMACKey: a.HMACKey, + TOTPValidated: a.TOTPValidated, } } @@ -231,6 +235,8 @@ type OfflineSessions struct { ConnID string `json:"conn_id,omitempty"` Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` ConnectorData []byte `json:"connectorData,omitempty"` + TOTP string `json:"totp,omitempty"` + TOTPConfirmed bool `json:"totp_confirmed,omitempty"` } func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { @@ -239,6 +245,8 @@ func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { ConnID: o.ConnID, Refresh: o.Refresh, ConnectorData: o.ConnectorData, + TOTP: o.TOTP, + TOTPConfirmed: o.TOTPConfirmed, } } @@ -248,6 +256,8 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { ConnID: o.ConnID, Refresh: o.Refresh, ConnectorData: o.ConnectorData, + TOTP: o.TOTP, + TOTPConfirmed: o.TOTPConfirmed, } if s.Refresh == nil { // Server code assumes this will be non-nil. diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index c126ddc087..796591266a 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -358,6 +358,8 @@ type AuthRequest struct { CodeChallengeMethod string `json:"code_challenge_method,omitempty"` HMACKey []byte `json:"hmac_key"` + + TOTPValidated bool `json:"totp_validated,omitempty"` } // AuthRequestList is a list of AuthRequests. @@ -386,7 +388,8 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest { CodeChallenge: req.CodeChallenge, CodeChallengeMethod: req.CodeChallengeMethod, }, - HMACKey: req.HMACKey, + HMACKey: req.HMACKey, + TOTPValidated: req.TOTPValidated, } return a } @@ -416,6 +419,7 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { CodeChallenge: a.PKCE.CodeChallenge, CodeChallengeMethod: a.PKCE.CodeChallengeMethod, HMACKey: a.HMACKey, + TOTPValidated: a.TOTPValidated, } return req } @@ -665,6 +669,8 @@ type OfflineSessions struct { ConnID string `json:"connID,omitempty"` Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` ConnectorData []byte `json:"connectorData,omitempty"` + TOTP string `json:"totp,omitempty"` + TOTPConfirmed bool `json:"totpConfirmed,omitempty"` } func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { @@ -681,6 +687,8 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline ConnID: o.ConnID, Refresh: o.Refresh, ConnectorData: o.ConnectorData, + TOTP: o.TOTP, + TOTPConfirmed: o.TOTPConfirmed, } } @@ -690,6 +698,8 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { ConnID: o.ConnID, Refresh: o.Refresh, ConnectorData: o.ConnectorData, + TOTP: o.TOTP, + TOTPConfirmed: o.TOTPConfirmed, } if s.Refresh == nil { // Server code assumes this will be non-nil. diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 1249243ced..7d8b62ca30 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -134,10 +134,10 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err connector_id, connector_data, expiry, code_challenge, code_challenge_method, - hmac_key + hmac_key, totp_validated ) values ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21 + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22 ); `, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, @@ -147,7 +147,7 @@ func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) err a.ConnectorID, a.ConnectorData, a.Expiry, a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, - a.HMACKey, + a.HMACKey, a.TOTPValidated, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -180,8 +180,9 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) connector_id = $15, connector_data = $16, expiry = $17, code_challenge = $18, code_challenge_method = $19, - hmac_key = $20 - where id = $21; + hmac_key = $20, + totp_validated = $21 + where id = $22; `, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ForceApprovalPrompt, a.LoggedIn, @@ -190,7 +191,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry, - a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, + a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod, a.HMACKey, a.TOTPValidated, r.ID, ) if err != nil { @@ -212,7 +213,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry, - code_challenge, code_challenge_method, hmac_key + code_challenge, code_challenge_method, hmac_key, totp_validated from auth_request where id = $1; `, id).Scan( &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, @@ -221,7 +222,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, - &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.HMACKey, + &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, &a.HMACKey, &a.TOTPValidated, ) if err != nil { if err == sql.ErrNoRows { @@ -694,13 +695,13 @@ func scanPassword(s scanner) (p storage.Password, err error) { func (c *conn) CreateOfflineSessions(ctx context.Context, s storage.OfflineSessions) error { _, err := c.Exec(` insert into offline_session ( - user_id, conn_id, refresh, connector_data + user_id, conn_id, refresh, connector_data, totp, totp_confirmed ) values ( - $1, $2, $3, $4 + $1, $2, $3, $4, $5, $6 ); `, - s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData, + s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData, s.TOTP, s.TOTPConfirmed, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -726,10 +727,12 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( update offline_session set refresh = $1, - connector_data = $2 - where user_id = $3 AND conn_id = $4; + connector_data = $2, + totp = $3, + totp_confirmed = $4 + where user_id = $5 AND conn_id = $6; `, - encoder(newSession.Refresh), newSession.ConnectorData, s.UserID, s.ConnID, + encoder(newSession.Refresh), newSession.ConnectorData, newSession.TOTP, newSession.TOTPConfirmed, s.UserID, s.ConnID, ) if err != nil { return fmt.Errorf("update offline session: %v", err) @@ -745,7 +748,7 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (storage.Offline func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { return scanOfflineSessions(q.QueryRow(` select - user_id, conn_id, refresh, connector_data + user_id, conn_id, refresh, connector_data, totp, totp_confirmed from offline_session where user_id = $1 AND conn_id = $2; `, userID, connID)) @@ -753,7 +756,7 @@ func getOfflineSessions(q querier, userID string, connID string) (storage.Offlin func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { err = s.Scan( - &o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData, + &o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData, &o.TOTP, &o.TOTPConfirmed, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 83e9c20d94..621fcbb76f 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -298,4 +298,17 @@ var migrations = []migration{ add column hmac_key bytea;`, }, }, + { + stmts: []string{ + ` + alter table offline_session + add column totp text;`, + ` + alter table offline_session + add column totp_confirmed boolean default false;`, + ` + alter table auth_request + add column totp_validated boolean not null default false;`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 03883ef5aa..f1fa20cf4a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -233,6 +233,9 @@ type AuthRequest struct { // HMACKey is used when generating an AuthRequest-specific HMAC HMACKey []byte + + // TOTPValidated is set to true if the user has validated their second authentication factor. + TOTPValidated bool } // AuthCode represents a code which can be exchanged for an OAuth2 token response. @@ -330,6 +333,11 @@ type OfflineSessions struct { // Authentication data provided by an upstream source. ConnectorData []byte + + // TOTP is the otp key used to generate TOTP codes for the user. + // The second factor is ignored if the field is empty. + TOTP string + TOTPConfirmed bool } // Password is an email to password mapping managed by the storage. diff --git a/web/templates/totp_verify.html b/web/templates/totp_verify.html new file mode 100644 index 0000000000..2a77b8d35e --- /dev/null +++ b/web/templates/totp_verify.html @@ -0,0 +1,28 @@ +{{ template "header.html" . }} + +
Scan the QR code below using your preferred authenticator app for future authentications
+ + {{ else }} +Open your authenticator app and enter the code
for {{ .Issuer }}: ({{ .Connector }})