From 3a05f4a21eb448e120ac2b7c1cff52e6b0d661fd Mon Sep 17 00:00:00 2001 From: Moustafa Amer Date: Wed, 20 Dec 2023 15:23:07 -0500 Subject: [PATCH] feat: rebase saml idp flow feature This feature rebases from the current master branch to include PR #1514. --- go.sum | 6 ++ server/handlers.go | 2 + storage/ent/client/client.go | 7 ++ storage/ent/client/types.go | 4 ++ storage/ent/db/migrate/schema.go | 1 + storage/ent/db/mutation.go | 76 ++++++++++++++++++++- storage/ent/db/oauth2client.go | 20 +++++- storage/ent/db/oauth2client/oauth2client.go | 3 + storage/ent/db/oauth2client/where.go | 10 +++ storage/ent/db/oauth2client_create.go | 19 ++++++ storage/ent/db/oauth2client_update.go | 53 ++++++++++++++ storage/ent/schema/client.go | 7 ++ 12 files changed, 204 insertions(+), 4 deletions(-) diff --git a/go.sum b/go.sum index 5026e7c0f1..c8f02b1024 100644 --- a/go.sum +++ b/go.sum @@ -149,6 +149,8 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= +github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= @@ -161,6 +163,8 @@ github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/I github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -326,6 +330,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901 h1:0wxTF6pSjIIhNt7mo9GvjDfzyCOiWhmICgtO/Ah948s= +golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/server/handlers.go b/server/handlers.go index 307556fe07..7fd8151ca1 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -414,6 +414,8 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) s.renderError(r, w, code, fmt.Sprintf("Error processing SAML callback: %s.", err)) return } + // remove before PR + s.logger.Infof("SAML callback processed successfully") default: s.renderError(r, w, http.StatusBadRequest, "Method not supported") return diff --git a/storage/ent/client/client.go b/storage/ent/client/client.go index 07434bd60b..48e0f4e1cf 100644 --- a/storage/ent/client/client.go +++ b/storage/ent/client/client.go @@ -4,10 +4,16 @@ import ( "context" "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/ent/schema" ) // CreateClient saves provided oauth2 client settings into the database. func (d *Database) CreateClient(client storage.Client) error { + + samlDbSchema := schema.SAMLInitiated{ + Scopes: client.SAMLInitiated.Scopes, + RedirectURI: client.SAMLInitiated.RedirectURI, + } _, err := d.client.OAuth2Client.Create(). SetID(client.ID). SetName(client.Name). @@ -16,6 +22,7 @@ func (d *Database) CreateClient(client storage.Client) error { SetLogoURL(client.LogoURL). SetRedirectUris(client.RedirectURIs). SetTrustedPeers(client.TrustedPeers). + SetSamlInitiated(samlDbSchema). Save(context.TODO()) if err != nil { return convertDBError("create oauth2 client: %w", err) diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 397d4d30a2..8364d628ae 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -83,6 +83,10 @@ func toStorageClient(c *db.OAuth2Client) storage.Client { Public: c.Public, Name: c.Name, LogoURL: c.LogoURL, + SAMLInitiated: storage.SAMLInitiatedConfig{ + Scopes: c.SamlInitiated.Scopes, + RedirectURI: c.SamlInitiated.RedirectURI, + }, } } diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index d3295a0c79..2176096161 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -134,6 +134,7 @@ var ( {Name: "public", Type: field.TypeBool}, {Name: "name", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "logo_url", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "saml_initiated", Type: field.TypeJSON, Nullable: true}, } // Oauth2clientsTable holds the schema information for the "oauth2clients" table. Oauth2clientsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index aec11425c5..ce4de32da6 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -23,6 +23,7 @@ import ( "github.com/dexidp/dex/storage/ent/db/password" "github.com/dexidp/dex/storage/ent/db/predicate" "github.com/dexidp/dex/storage/ent/db/refreshtoken" + "github.com/dexidp/dex/storage/ent/schema" jose "gopkg.in/square/go-jose.v2" ) @@ -5132,6 +5133,7 @@ type OAuth2ClientMutation struct { public *bool name *string logo_url *string + samlInitiated *schema.SAMLInitiated clearedFields map[string]struct{} done bool oldValue func(context.Context) (*OAuth2Client, error) @@ -5516,6 +5518,55 @@ func (m *OAuth2ClientMutation) ResetLogoURL() { m.logo_url = nil } +// SetSamlInitiated sets the "samlInitiated" field. +func (m *OAuth2ClientMutation) SetSamlInitiated(si schema.SAMLInitiated) { + m.samlInitiated = &si +} + +// SamlInitiated returns the value of the "samlInitiated" field in the mutation. +func (m *OAuth2ClientMutation) SamlInitiated() (r schema.SAMLInitiated, exists bool) { + v := m.samlInitiated + if v == nil { + return + } + return *v, true +} + +// OldSamlInitiated returns the old "samlInitiated" field's value of the OAuth2Client entity. +// If the OAuth2Client object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *OAuth2ClientMutation) OldSamlInitiated(ctx context.Context) (v schema.SAMLInitiated, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSamlInitiated is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSamlInitiated requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSamlInitiated: %w", err) + } + return oldValue.SamlInitiated, nil +} + +// ClearSamlInitiated clears the value of the "samlInitiated" field. +func (m *OAuth2ClientMutation) ClearSamlInitiated() { + m.samlInitiated = nil + m.clearedFields[oauth2client.FieldSamlInitiated] = struct{}{} +} + +// SamlInitiatedCleared returns if the "samlInitiated" field was cleared in this mutation. +func (m *OAuth2ClientMutation) SamlInitiatedCleared() bool { + _, ok := m.clearedFields[oauth2client.FieldSamlInitiated] + return ok +} + +// ResetSamlInitiated resets all changes to the "samlInitiated" field. +func (m *OAuth2ClientMutation) ResetSamlInitiated() { + m.samlInitiated = nil + delete(m.clearedFields, oauth2client.FieldSamlInitiated) +} + // Where appends a list predicates to the OAuth2ClientMutation builder. func (m *OAuth2ClientMutation) Where(ps ...predicate.OAuth2Client) { m.predicates = append(m.predicates, ps...) @@ -5550,7 +5601,7 @@ func (m *OAuth2ClientMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *OAuth2ClientMutation) Fields() []string { - fields := make([]string, 0, 6) + fields := make([]string, 0, 7) if m.secret != nil { fields = append(fields, oauth2client.FieldSecret) } @@ -5569,6 +5620,9 @@ func (m *OAuth2ClientMutation) Fields() []string { if m.logo_url != nil { fields = append(fields, oauth2client.FieldLogoURL) } + if m.samlInitiated != nil { + fields = append(fields, oauth2client.FieldSamlInitiated) + } return fields } @@ -5589,6 +5643,8 @@ func (m *OAuth2ClientMutation) Field(name string) (ent.Value, bool) { return m.Name() case oauth2client.FieldLogoURL: return m.LogoURL() + case oauth2client.FieldSamlInitiated: + return m.SamlInitiated() } return nil, false } @@ -5610,6 +5666,8 @@ func (m *OAuth2ClientMutation) OldField(ctx context.Context, name string) (ent.V return m.OldName(ctx) case oauth2client.FieldLogoURL: return m.OldLogoURL(ctx) + case oauth2client.FieldSamlInitiated: + return m.OldSamlInitiated(ctx) } return nil, fmt.Errorf("unknown OAuth2Client field %s", name) } @@ -5661,6 +5719,13 @@ func (m *OAuth2ClientMutation) SetField(name string, value ent.Value) error { } m.SetLogoURL(v) return nil + case oauth2client.FieldSamlInitiated: + v, ok := value.(schema.SAMLInitiated) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSamlInitiated(v) + return nil } return fmt.Errorf("unknown OAuth2Client field %s", name) } @@ -5697,6 +5762,9 @@ func (m *OAuth2ClientMutation) ClearedFields() []string { if m.FieldCleared(oauth2client.FieldTrustedPeers) { fields = append(fields, oauth2client.FieldTrustedPeers) } + if m.FieldCleared(oauth2client.FieldSamlInitiated) { + fields = append(fields, oauth2client.FieldSamlInitiated) + } return fields } @@ -5717,6 +5785,9 @@ func (m *OAuth2ClientMutation) ClearField(name string) error { case oauth2client.FieldTrustedPeers: m.ClearTrustedPeers() return nil + case oauth2client.FieldSamlInitiated: + m.ClearSamlInitiated() + return nil } return fmt.Errorf("unknown OAuth2Client nullable field %s", name) } @@ -5743,6 +5814,9 @@ func (m *OAuth2ClientMutation) ResetField(name string) error { case oauth2client.FieldLogoURL: m.ResetLogoURL() return nil + case oauth2client.FieldSamlInitiated: + m.ResetSamlInitiated() + return nil } return fmt.Errorf("unknown OAuth2Client field %s", name) } diff --git a/storage/ent/db/oauth2client.go b/storage/ent/db/oauth2client.go index 39a4cf82ab..08f44ebb9a 100644 --- a/storage/ent/db/oauth2client.go +++ b/storage/ent/db/oauth2client.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/dexidp/dex/storage/ent/db/oauth2client" + "github.com/dexidp/dex/storage/ent/schema" ) // OAuth2Client is the model entity for the OAuth2Client schema. @@ -28,8 +29,10 @@ type OAuth2Client struct { // Name holds the value of the "name" field. Name string `json:"name,omitempty"` // LogoURL holds the value of the "logo_url" field. - LogoURL string `json:"logo_url,omitempty"` - selectValues sql.SelectValues + LogoURL string `json:"logo_url,omitempty"` + // SamlInitiated holds the value of the "samlInitiated" field. + SamlInitiated schema.SAMLInitiated `json:"samlInitiated,omitempty"` + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -37,7 +40,7 @@ func (*OAuth2Client) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case oauth2client.FieldRedirectUris, oauth2client.FieldTrustedPeers: + case oauth2client.FieldRedirectUris, oauth2client.FieldTrustedPeers, oauth2client.FieldSamlInitiated: values[i] = new([]byte) case oauth2client.FieldPublic: values[i] = new(sql.NullBool) @@ -104,6 +107,14 @@ func (o *OAuth2Client) assignValues(columns []string, values []any) error { } else if value.Valid { o.LogoURL = value.String } + case oauth2client.FieldSamlInitiated: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field samlInitiated", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &o.SamlInitiated); err != nil { + return fmt.Errorf("unmarshal field samlInitiated: %w", err) + } + } default: o.selectValues.Set(columns[i], values[i]) } @@ -157,6 +168,9 @@ func (o *OAuth2Client) String() string { builder.WriteString(", ") builder.WriteString("logo_url=") builder.WriteString(o.LogoURL) + builder.WriteString(", ") + builder.WriteString("samlInitiated=") + builder.WriteString(fmt.Sprintf("%v", o.SamlInitiated)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/oauth2client/oauth2client.go b/storage/ent/db/oauth2client/oauth2client.go index 08df76be9c..c676bea9c7 100644 --- a/storage/ent/db/oauth2client/oauth2client.go +++ b/storage/ent/db/oauth2client/oauth2client.go @@ -23,6 +23,8 @@ const ( FieldName = "name" // FieldLogoURL holds the string denoting the logo_url field in the database. FieldLogoURL = "logo_url" + // FieldSamlInitiated holds the string denoting the samlinitiated field in the database. + FieldSamlInitiated = "saml_initiated" // Table holds the table name of the oauth2client in the database. Table = "oauth2clients" ) @@ -36,6 +38,7 @@ var Columns = []string{ FieldPublic, FieldName, FieldLogoURL, + FieldSamlInitiated, } // ValidColumn reports if the column name is valid (part of the table columns). diff --git a/storage/ent/db/oauth2client/where.go b/storage/ent/db/oauth2client/where.go index 55aee79b1a..6e69068d05 100644 --- a/storage/ent/db/oauth2client/where.go +++ b/storage/ent/db/oauth2client/where.go @@ -307,6 +307,16 @@ func LogoURLContainsFold(v string) predicate.OAuth2Client { return predicate.OAuth2Client(sql.FieldContainsFold(FieldLogoURL, v)) } +// SamlInitiatedIsNil applies the IsNil predicate on the "samlInitiated" field. +func SamlInitiatedIsNil() predicate.OAuth2Client { + return predicate.OAuth2Client(sql.FieldIsNull(FieldSamlInitiated)) +} + +// SamlInitiatedNotNil applies the NotNil predicate on the "samlInitiated" field. +func SamlInitiatedNotNil() predicate.OAuth2Client { + return predicate.OAuth2Client(sql.FieldNotNull(FieldSamlInitiated)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.OAuth2Client) predicate.OAuth2Client { return predicate.OAuth2Client(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/oauth2client_create.go b/storage/ent/db/oauth2client_create.go index 5b472cd36d..5ecd10980d 100644 --- a/storage/ent/db/oauth2client_create.go +++ b/storage/ent/db/oauth2client_create.go @@ -10,6 +10,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" "github.com/dexidp/dex/storage/ent/db/oauth2client" + "github.com/dexidp/dex/storage/ent/schema" ) // OAuth2ClientCreate is the builder for creating a OAuth2Client entity. @@ -55,6 +56,20 @@ func (oc *OAuth2ClientCreate) SetLogoURL(s string) *OAuth2ClientCreate { return oc } +// SetSamlInitiated sets the "samlInitiated" field. +func (oc *OAuth2ClientCreate) SetSamlInitiated(si schema.SAMLInitiated) *OAuth2ClientCreate { + oc.mutation.SetSamlInitiated(si) + return oc +} + +// SetNillableSamlInitiated sets the "samlInitiated" field if the given value is not nil. +func (oc *OAuth2ClientCreate) SetNillableSamlInitiated(si *schema.SAMLInitiated) *OAuth2ClientCreate { + if si != nil { + oc.SetSamlInitiated(*si) + } + return oc +} + // SetID sets the "id" field. func (oc *OAuth2ClientCreate) SetID(s string) *OAuth2ClientCreate { oc.mutation.SetID(s) @@ -186,6 +201,10 @@ func (oc *OAuth2ClientCreate) createSpec() (*OAuth2Client, *sqlgraph.CreateSpec) _spec.SetField(oauth2client.FieldLogoURL, field.TypeString, value) _node.LogoURL = value } + if value, ok := oc.mutation.SamlInitiated(); ok { + _spec.SetField(oauth2client.FieldSamlInitiated, field.TypeJSON, value) + _node.SamlInitiated = value + } return _node, _spec } diff --git a/storage/ent/db/oauth2client_update.go b/storage/ent/db/oauth2client_update.go index b70feacc40..9fabac0cfd 100644 --- a/storage/ent/db/oauth2client_update.go +++ b/storage/ent/db/oauth2client_update.go @@ -13,6 +13,7 @@ import ( "entgo.io/ent/schema/field" "github.com/dexidp/dex/storage/ent/db/oauth2client" "github.com/dexidp/dex/storage/ent/db/predicate" + "github.com/dexidp/dex/storage/ent/schema" ) // OAuth2ClientUpdate is the builder for updating OAuth2Client entities. @@ -88,6 +89,26 @@ func (ou *OAuth2ClientUpdate) SetLogoURL(s string) *OAuth2ClientUpdate { return ou } +// SetSamlInitiated sets the "samlInitiated" field. +func (ou *OAuth2ClientUpdate) SetSamlInitiated(si schema.SAMLInitiated) *OAuth2ClientUpdate { + ou.mutation.SetSamlInitiated(si) + return ou +} + +// SetNillableSamlInitiated sets the "samlInitiated" field if the given value is not nil. +func (ou *OAuth2ClientUpdate) SetNillableSamlInitiated(si *schema.SAMLInitiated) *OAuth2ClientUpdate { + if si != nil { + ou.SetSamlInitiated(*si) + } + return ou +} + +// ClearSamlInitiated clears the value of the "samlInitiated" field. +func (ou *OAuth2ClientUpdate) ClearSamlInitiated() *OAuth2ClientUpdate { + ou.mutation.ClearSamlInitiated() + return ou +} + // Mutation returns the OAuth2ClientMutation object of the builder. func (ou *OAuth2ClientUpdate) Mutation() *OAuth2ClientMutation { return ou.mutation @@ -186,6 +207,12 @@ func (ou *OAuth2ClientUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := ou.mutation.LogoURL(); ok { _spec.SetField(oauth2client.FieldLogoURL, field.TypeString, value) } + if value, ok := ou.mutation.SamlInitiated(); ok { + _spec.SetField(oauth2client.FieldSamlInitiated, field.TypeJSON, value) + } + if ou.mutation.SamlInitiatedCleared() { + _spec.ClearField(oauth2client.FieldSamlInitiated, field.TypeJSON) + } if n, err = sqlgraph.UpdateNodes(ctx, ou.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{oauth2client.Label} @@ -266,6 +293,26 @@ func (ouo *OAuth2ClientUpdateOne) SetLogoURL(s string) *OAuth2ClientUpdateOne { return ouo } +// SetSamlInitiated sets the "samlInitiated" field. +func (ouo *OAuth2ClientUpdateOne) SetSamlInitiated(si schema.SAMLInitiated) *OAuth2ClientUpdateOne { + ouo.mutation.SetSamlInitiated(si) + return ouo +} + +// SetNillableSamlInitiated sets the "samlInitiated" field if the given value is not nil. +func (ouo *OAuth2ClientUpdateOne) SetNillableSamlInitiated(si *schema.SAMLInitiated) *OAuth2ClientUpdateOne { + if si != nil { + ouo.SetSamlInitiated(*si) + } + return ouo +} + +// ClearSamlInitiated clears the value of the "samlInitiated" field. +func (ouo *OAuth2ClientUpdateOne) ClearSamlInitiated() *OAuth2ClientUpdateOne { + ouo.mutation.ClearSamlInitiated() + return ouo +} + // Mutation returns the OAuth2ClientMutation object of the builder. func (ouo *OAuth2ClientUpdateOne) Mutation() *OAuth2ClientMutation { return ouo.mutation @@ -394,6 +441,12 @@ func (ouo *OAuth2ClientUpdateOne) sqlSave(ctx context.Context) (_node *OAuth2Cli if value, ok := ouo.mutation.LogoURL(); ok { _spec.SetField(oauth2client.FieldLogoURL, field.TypeString, value) } + if value, ok := ouo.mutation.SamlInitiated(); ok { + _spec.SetField(oauth2client.FieldSamlInitiated, field.TypeJSON, value) + } + if ouo.mutation.SamlInitiatedCleared() { + _spec.ClearField(oauth2client.FieldSamlInitiated, field.TypeJSON) + } _node = &OAuth2Client{config: ouo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/schema/client.go b/storage/ent/schema/client.go index b897c52a2e..36229d7906 100644 --- a/storage/ent/schema/client.go +++ b/storage/ent/schema/client.go @@ -18,6 +18,11 @@ create table client ); */ +type SAMLInitiated struct { + Scopes []string `json:"scopes,omitempty"` + RedirectURI string `json:"redirect_uri,omitempty"` +} + // OAuth2Client holds the schema definition for the Client entity. type OAuth2Client struct { ent.Schema @@ -45,6 +50,8 @@ func (OAuth2Client) Fields() []ent.Field { field.Text("logo_url"). SchemaType(textSchema). NotEmpty(), + field.JSON("samlInitiated", SAMLInitiated{}). + Optional(), } }