From 73424a2485bec59cbdb86272b8a2eca36ea84b43 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Thu, 11 Jul 2024 09:57:28 -0700 Subject: [PATCH 1/6] feat/enterprise-portal: db layer for licenses --- .../internal/database/BUILD.bazel | 1 + .../database/codyaccess/codygateway.go | 2 +- .../database/databasetest/BUILD.bazel | 2 + .../database/databasetest/databasetest.go | 55 ++- .../tables/custommigrator/BUILD.bazel | 9 + .../tables/custommigrator/custommigrator.go | 9 + .../database/internal/tables/tables.go | 4 +- .../database/internal/utctime/BUILD.bazel | 12 + .../database/internal/utctime/utctime.go | 68 ++++ .../internal/database/migrate.go | 19 +- .../database/subscriptions/BUILD.bazel | 15 +- .../subscriptions/license_conditions.go | 82 +++- .../database/subscriptions/licenses.go | 350 +++++++++++++++++- .../database/subscriptions/licenses_test.go | 206 +++++++++++ .../database/subscriptions/subscriptions.go | 43 ++- .../subscriptions/subscriptions_conditions.go | 9 +- .../subscriptions/subscriptions_test.go | 20 +- .../internal/subscriptionsservice/adapters.go | 2 +- .../internal/subscriptionsservice/v1.go | 12 +- 19 files changed, 854 insertions(+), 66 deletions(-) create mode 100644 cmd/enterprise-portal/internal/database/internal/tables/custommigrator/BUILD.bazel create mode 100644 cmd/enterprise-portal/internal/database/internal/tables/custommigrator/custommigrator.go create mode 100644 cmd/enterprise-portal/internal/database/internal/utctime/BUILD.bazel create mode 100644 cmd/enterprise-portal/internal/database/internal/utctime/utctime.go create mode 100644 cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go diff --git a/cmd/enterprise-portal/internal/database/BUILD.bazel b/cmd/enterprise-portal/internal/database/BUILD.bazel index d0daee3def48b..c042c39f40cd6 100644 --- a/cmd/enterprise-portal/internal/database/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/BUILD.bazel @@ -11,6 +11,7 @@ go_library( visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ "//cmd/enterprise-portal/internal/database/internal/tables", + "//cmd/enterprise-portal/internal/database/internal/tables/custommigrator", "//cmd/enterprise-portal/internal/database/subscriptions", "//lib/errors", "//lib/managedservicesplatform/runtime", diff --git a/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go b/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go index 792dd8b5c029a..89d2b738bba6c 100644 --- a/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go +++ b/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go @@ -4,7 +4,7 @@ import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/databa type CodyGatewayAccess struct { // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - Subscription *subscriptions.Subscription `gorm:"foreignKey:SubscriptionID"` + Subscription *subscriptions.TableSubscription `gorm:"foreignKey:SubscriptionID"` // SubscriptionID is the internal unprefixed UUID of the related subscription. SubscriptionID string `gorm:"type:uuid;not null;unique"` diff --git a/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel b/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel index 045d275fc4c06..d2a6fe3656df4 100644 --- a/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel @@ -7,7 +7,9 @@ go_library( tags = [TAG_INFRA_CORESERVICES], visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ + "//cmd/enterprise-portal/internal/database/internal/tables/custommigrator", "//internal/database/dbtest", + "@com_github_jackc_pgx_v5//:pgx", "@com_github_jackc_pgx_v5//pgxpool", "@com_github_stretchr_testify//require", "@io_gorm_driver_postgres//:postgres", diff --git a/cmd/enterprise-portal/internal/database/databasetest/databasetest.go b/cmd/enterprise-portal/internal/database/databasetest/databasetest.go index 7ecce9364988f..2671aa6fc5285 100644 --- a/cmd/enterprise-portal/internal/database/databasetest/databasetest.go +++ b/cmd/enterprise-portal/internal/database/databasetest/databasetest.go @@ -3,17 +3,20 @@ package databasetest import ( "context" "database/sql" + "encoding/json" "fmt" "strings" "testing" "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator" "github.com/sourcegraph/sourcegraph/internal/database/dbtest" ) @@ -57,6 +60,11 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx for _, table := range tables { err = db.AutoMigrate(table) require.NoError(t, err) + if m, ok := table.(custommigrator.CustomTableMigrator); ok { + if err := m.RunCustomMigrations(db.Migrator()); err != nil { + require.NoError(t, err) + } + } } // Close the connection used to auto-migrate the database. @@ -66,7 +74,12 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx require.NoError(t, err) // Open a new connection to the test suite database. - testDB, err := pgxpool.New(context.Background(), dsn.String()) + dbConfig, err := pgxpool.ParseConfig(dsn.String()) + require.NoError(t, err) + if testing.Verbose() { + dbConfig.ConnConfig.Tracer = pgxTestTracer{TB: t} + } + testDB, err := pgxpool.NewWithConfig(context.Background(), dbConfig) require.NoError(t, err) t.Cleanup(func() { @@ -110,3 +123,43 @@ func ClearTablesAfterTest(t *testing.T, db *pgxpool.Pool, tables ...schema.Table } }) } + +// pgxTestTracer implements various pgx tracing hooks for dumping diagnostics +// in testing. +type pgxTestTracer struct{ testing.TB } + +// Select tracing hooks we want to implement. +var ( + _ pgx.QueryTracer = pgxTestTracer{} +) + +func (t pgxTestTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + var args []string + if len(data.Args) > 0 { + // Divider for readability + args = append(args, "\n---") + } + for _, arg := range data.Args { + data, err := json.MarshalIndent(arg, "", " ") + if err != nil { + args = append(args, fmt.Sprintf("marshal %T: %+v", arg, err)) + } + args = append(args, string(data)) + } + + t.Logf(`pgx.QueryStart db=%q +%s%s`, + conn.Config().Database, + strings.TrimSpace(data.SQL), + strings.Join(args, "\n")) + return ctx +} + +func (t pgxTestTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + if data.Err != nil { + t.Logf(`pgx.QueryEnd db=%q tag=%q error=%q`, + conn.Config().Database, + data.CommandTag.String(), + data.Err) + } +} diff --git a/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/BUILD.bazel b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/BUILD.bazel new file mode 100644 index 0000000000000..d832660e58aa5 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/BUILD.bazel @@ -0,0 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "custommigrator", + srcs = ["custommigrator.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator", + visibility = ["//cmd/enterprise-portal:__subpackages__"], + deps = ["@io_gorm_gorm//:gorm"], +) diff --git a/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/custommigrator.go b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/custommigrator.go new file mode 100644 index 0000000000000..2686eb8ab7e64 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/custommigrator.go @@ -0,0 +1,9 @@ +package custommigrator + +import "gorm.io/gorm" + +type CustomTableMigrator interface { + // RunCustomMigrations is called after all other migrations have been run. + // It can implement custom migrations. + RunCustomMigrations(migrator gorm.Migrator) error +} diff --git a/cmd/enterprise-portal/internal/database/internal/tables/tables.go b/cmd/enterprise-portal/internal/database/internal/tables/tables.go index 3a08337fb7053..7d84ec06e4316 100644 --- a/cmd/enterprise-portal/internal/database/internal/tables/tables.go +++ b/cmd/enterprise-portal/internal/database/internal/tables/tables.go @@ -12,9 +12,9 @@ import ( // ⚠️ WARNING: This list is meant to be read-only. func All() []schema.Tabler { return []schema.Tabler{ - &subscriptions.Subscription{}, + &subscriptions.TableSubscription{}, &subscriptions.SubscriptionCondition{}, - &subscriptions.SubscriptionLicense{}, + &subscriptions.TableSubscriptionLicense{}, &subscriptions.SubscriptionLicenseCondition{}, &codyaccess.CodyGatewayAccess{}, diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/BUILD.bazel b/cmd/enterprise-portal/internal/database/internal/utctime/BUILD.bazel new file mode 100644 index 0000000000000..8e3be7b56e31e --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/utctime/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "utctime", + srcs = ["utctime.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime", + visibility = ["//cmd/enterprise-portal:__subpackages__"], + deps = [ + "//lib/errors", + "//lib/pointers", + ], +) diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go new file mode 100644 index 0000000000000..5a6d8792168b0 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go @@ -0,0 +1,68 @@ +package utctime + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/pointers" +) + +// Time is a wrapper around time.Time that implements the database/sql.Scanner +// and database/sql/driver.Valuer interfaces to serialize and deserialize time +// in UTC time zone. +type Time time.Time + +// Now returns the current time in UTC. +func Now() Time { return Time(time.Now().UTC()) } + +// FromTime returns a utctime.Time from a time.Time. +func FromTime(t time.Time) Time { return Time(t.UTC()) } + +var _ sql.Scanner = (*Time)(nil) + +func (t *Time) Scan(src any) error { + if src == nil { + return nil + } + if v, ok := src.(time.Time); ok { + *t = Time(v.UTC()) + return nil + } + return errors.Newf("value %T is not time.Time", src) +} + +var _ driver.Valuer = (*Time)(nil) + +// Value must be called with a non-nil Time. driver.Valuer callers will first +// check that the value is non-nil, so this is safe. +func (t Time) Value() (driver.Value, error) { + stdTime := t.Time() + return *stdTime, nil +} + +var _ json.Marshaler = (*Time)(nil) + +func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.Time()) } + +var _ json.Unmarshaler = (*Time)(nil) + +func (t *Time) UnmarshalJSON(data []byte) error { + var stdTime time.Time + if err := json.Unmarshal(data, &stdTime); err != nil { + return err + } + *t = FromTime(stdTime) + return nil +} + +// Time returns the underlying time.Time value, or nil if it is nil. +func (t *Time) Time() *time.Time { + if t == nil { + return nil + } + // Ensure the time is in UTC. + return pointers.Ptr((*time.Time)(t).UTC()) +} diff --git a/cmd/enterprise-portal/internal/database/migrate.go b/cmd/enterprise-portal/internal/database/migrate.go index 5a03474975be2..380d9125966de 100644 --- a/cmd/enterprise-portal/internal/database/migrate.go +++ b/cmd/enterprise-portal/internal/database/migrate.go @@ -21,6 +21,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/redislock" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator" ) // maybeMigrate runs the auto-migration for the database when needed based on @@ -42,6 +43,12 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr } span.End() }() + logger = logger. + WithTrace(log.TraceContext{ + TraceID: span.SpanContext().TraceID().String(), + SpanID: span.SpanContext().SpanID().String(), + }). + With(log.String("database", dbName)) sqlDB, err := contract.PostgreSQL.OpenDatabase(ctx, dbName) if err != nil { @@ -83,17 +90,20 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr span.AddEvent("lock.acquired") versionKey := fmt.Sprintf("%s:db_version", dbName) + liveVersion := redisClient.Get(ctx, versionKey).Val() if shouldSkipMigration( - redisClient.Get(ctx, versionKey).Val(), + liveVersion, currentVersion, ) { logger.Info("skipped auto-migration", - log.String("database", dbName), log.String("currentVersion", currentVersion), ) span.SetAttributes(attribute.Bool("skipped", true)) return nil } + logger.Info("executing auto-migration", + log.String("liveVersion", liveVersion), + log.String("currentVersion", currentVersion)) span.SetAttributes(attribute.Bool("skipped", false)) // Create a session that ignore debug logging. @@ -108,6 +118,11 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr if err != nil { return errors.Wrapf(err, "auto migrating table for %s", errors.Safe(fmt.Sprintf("%T", table))) } + if m, ok := table.(custommigrator.CustomTableMigrator); ok { + if err := m.RunCustomMigrations(sess.Migrator()); err != nil { + return errors.Wrapf(err, "running custom migrations for %s", errors.Safe(fmt.Sprintf("%T", table))) + } + } } return redisClient.Set(ctx, versionKey, currentVersion, 0).Err() diff --git a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel index d73ea1939daee..2dcadb12c26cd 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel @@ -14,17 +14,24 @@ go_library( visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ "//cmd/enterprise-portal/internal/database/internal/upsert", + "//cmd/enterprise-portal/internal/database/internal/utctime", + "//internal/license", + "//lib/enterpriseportal/subscriptions/v1:subscriptions", "//lib/errors", "//lib/pointers", - "@com_github_jackc_pgtype//:pgtype", + "@com_github_google_uuid//:uuid", "@com_github_jackc_pgx_v5//:pgx", "@com_github_jackc_pgx_v5//pgxpool", + "@io_gorm_gorm//:gorm", ], ) go_test( name = "subscriptions_test", - srcs = ["subscriptions_test.go"], + srcs = [ + "licenses_test.go", + "subscriptions_test.go", + ], tags = [ TAG_INFRA_CORESERVICES, "requires-network", @@ -33,8 +40,12 @@ go_test( ":subscriptions", "//cmd/enterprise-portal/internal/database/databasetest", "//cmd/enterprise-portal/internal/database/internal/tables", + "//cmd/enterprise-portal/internal/database/internal/utctime", + "//internal/license", "//lib/pointers", "@com_github_google_uuid//:uuid", + "@com_github_hexops_autogold_v2//:autogold", + "@com_github_hexops_valast//:valast", "@com_github_jackc_pgx_v5//:pgx", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go index 80cb11211b596..b047dfd0dadb5 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go @@ -1,11 +1,17 @@ package subscriptions -import "time" +import ( + "context" -type SubscriptionLicenseCondition struct { - // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - License *SubscriptionLicense `gorm:"foreignKey:LicenseID"` + "github.com/jackc/pgx/v5" + + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" + subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/pointers" +) +type SubscriptionLicenseCondition struct { // SubscriptionID is the internal unprefixed UUID of the related license. LicenseID string `gorm:"type:uuid;not null"` // Status is the type of status corresponding to this condition, corresponding @@ -15,9 +21,73 @@ type SubscriptionLicenseCondition struct { Message *string `gorm:"size:256"` // TransitionTime is the time at which the condition was created, i.e. when // the license transitioned into this status. - TransitionTime time.Time `gorm:"not null;default:current_timestamp"` + TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"` } -func (s *SubscriptionLicenseCondition) TableName() string { +func (*SubscriptionLicenseCondition) TableName() string { return "enterprise_portal_subscription_license_conditions" } + +// subscriptionLicenseConditionJSONBAgg must be used with: +// +// JOIN +// enterprise_portal_subscription_license_conditions license_condition +// ON license_condition.license_id = id +// GROUP BY +// id +// +// The conditions are aggregated in JSON to 'conditions', which can be directly +// unmarshaled into the 'SubscriptionLicenseCondition' type using 'pgx'. +func subscriptionLicenseConditionJSONBAgg() string { + return ` +jsonb_agg( + jsonb_build_object( + 'Status', license_condition.status, + 'Message', license_condition.message, + 'TransitionTime', license_condition.transition_time + ) + ORDER BY license_condition.transition_time DESC +) AS conditions` +} + +type licenseConditionsStore struct{ tx pgx.Tx } + +// newLicenseConditionsStore is meant to be used exclusively in the context of +// a transaction, where the parent license is being updated at the same time. +// +// The caller owns the transaction lifecycle. +func newLicenseConditionsStore(tx pgx.Tx) *licenseConditionsStore { + return &licenseConditionsStore{tx: tx} +} + +type createLicenseConditionOpts struct { + Status subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status + Message string + TransitionTime utctime.Time +} + +func (s *licenseConditionsStore) createLicenseCondition(ctx context.Context, licenseID string, opts createLicenseConditionOpts) error { + if opts.TransitionTime.Time().IsZero() { + return errors.New("transition time is required") + } + _, err := s.tx.Exec(ctx, ` +INSERT INTO enterprise_portal_subscription_license_conditions ( + license_id, + status, + message, + transition_time +) +VALUES ( + @licenseID, + @status, + @message, + @transitionTime +)`, pgx.NamedArgs{ + "licenseID": licenseID, + // Convert to string representation of EnterpriseSubscriptionLicenseCondition + "status": subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status_name[int32(opts.Status)], + "message": pointers.NilIfZero(opts.Message), + "transitionTime": opts.TransitionTime, + }) + return err +} diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go index 2b64bbea108e0..c96865900faed 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go @@ -1,15 +1,49 @@ package subscriptions import ( + "context" + "encoding/json" + "fmt" + "strings" "time" - "github.com/jackc/pgtype" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "gorm.io/gorm" + + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" + internallicense "github.com/sourcegraph/sourcegraph/internal/license" + subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" + "github.com/sourcegraph/sourcegraph/lib/pointers" + + "github.com/sourcegraph/sourcegraph/lib/errors" ) -type SubscriptionLicense struct { +// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints +// and initializing tables with gorm. +type TableSubscriptionLicense struct { // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - Subscription *Subscription `gorm:"foreignKey:SubscriptionID"` + Conditions *[]SubscriptionLicenseCondition `gorm:"foreignKey:LicenseID"` + + SubscriptionLicense +} + +func (*TableSubscriptionLicense) TableName() string { + return "enterprise_portal_subscription_licenses" +} +// Implement tables.CustomMigrator +func (s *TableSubscriptionLicense) RunCustomMigrations(migrator gorm.Migrator) error { + if migrator.HasColumn(s, "license_kind") { + if err := migrator.DropColumn(s, "license_kind"); err != nil { + return err + } + } + return nil +} + +type SubscriptionLicense struct { // SubscriptionID is the internal unprefixed UUID of the related subscription. SubscriptionID string `gorm:"type:uuid;not null"` // ID is the internal unprefixed UUID of this license. @@ -19,21 +53,317 @@ type SubscriptionLicense struct { // to this subscription. // // Condition transition details are tracked in 'enterprise_portal_subscription_license_conditions'. - CreatedAt time.Time `gorm:"not null;default:current_timestamp"` - RevokedAt *time.Time // Null indicates the licnese is not revoked. + CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"` + RevokedAt *utctime.Time // Null indicates the license is not revoked. - // LicenseKind is the kind of license stored in LicenseData, corresponding + // LicenseType is the kind of license stored in LicenseData, corresponding // to the API 'EnterpriseSubscriptionLicenseType'. - LicenseKind string `gorm:"not null"` + LicenseType string `gorm:"not null"` // LicenseData is the license data stored in JSON format. It is read-only // and generally never queried in conditions - properties that are should // be stored at the subscription or license level. // // Value shapes correspond to API types appropriate for each // 'EnterpriseSubscriptionLicenseType'. - LicenseData pgtype.JSONB `gorm:"type:jsonb"` + LicenseData json.RawMessage `gorm:"type:jsonb"` } -func (s *SubscriptionLicense) TableName() string { - return "enterprise_portal_subscription_licenses" +// subscriptionLicenseWithConditionsColumns must match scanSubscriptionLicense() +// values. +func subscriptionLicenseWithConditionsColumns() []string { + return []string{ + "subscription_id", + "id", + + "created_at", + "revoked_at", + + "license_type", + "license_data", + + subscriptionLicenseConditionJSONBAgg(), + } +} + +type LicenseWithConditions struct { + SubscriptionLicense + Conditions []SubscriptionLicenseCondition +} + +// scanSubscription matches subscriptionTableColumns() values. +func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions, error) { + var l LicenseWithConditions + err := row.Scan( + &l.SubscriptionID, + &l.ID, + &l.CreatedAt, + &l.RevokedAt, + &l.LicenseType, + &l.LicenseData, + &l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring + ) + return &l, err +} + +// LicensesStore manages licenses belonging to Enterprise subscriptions. +// +// Licenses can only be created and revoked - they can never be updated. +type LicensesStore struct { + db *pgxpool.Pool +} + +func NewLicensesStore(db *pgxpool.Pool) *LicensesStore { + return &LicensesStore{ + db: db, + } +} + +type ListLicensesOpts struct { + SubscriptionID string + // PageSize is the maximum number of licenses to return. + PageSize int +} + +func (opts ListLicensesOpts) toQueryConditions() (where, limitClause string, _ pgx.NamedArgs) { + whereConds := []string{"TRUE"} + namedArgs := pgx.NamedArgs{} + if opts.SubscriptionID != "" { + whereConds = append(whereConds, "subscription_id = @subscriptionID") + namedArgs["subscriptionID"] = opts.SubscriptionID + } + where = strings.Join(whereConds, " AND ") + + if opts.PageSize > 0 { + limitClause = "LIMIT @pageSize" + namedArgs["pageSize"] = opts.PageSize + } + return where, limitClause, namedArgs +} + +func (s *LicensesStore) List(ctx context.Context, opts ListLicensesOpts) ([]*LicenseWithConditions, error) { + where, limitClause, namedArgs := opts.toQueryConditions() + query := fmt.Sprintf(` +SELECT + %s +FROM + enterprise_portal_subscription_licenses +JOIN + enterprise_portal_subscription_license_conditions license_condition + ON license_condition.license_id = id +WHERE + %s +GROUP BY + id +ORDER BY + created_at DESC +%s`, + strings.Join(subscriptionLicenseWithConditionsColumns(), ", "), + where, limitClause) + + rows, err := s.db.Query(ctx, query, namedArgs) + if err != nil { + return nil, errors.Wrap(err, "query rows") + } + defer rows.Close() + + var licenses []*LicenseWithConditions + for rows.Next() { + license, err := scanSubscriptionLicenseWithConditions(rows) + if err != nil { + return nil, errors.Wrap(err, "scan row") + } + licenses = append(licenses, license) + } + return licenses, rows.Err() +} + +func (s *LicensesStore) Get(ctx context.Context, licenseID string) (*LicenseWithConditions, error) { + query := fmt.Sprintf(` +SELECT + %s +FROM + enterprise_portal_subscription_licenses +JOIN + enterprise_portal_subscription_license_conditions license_condition + ON license_condition.license_id = id +WHERE + id = @licenseID +GROUP BY + id`, + strings.Join(subscriptionLicenseWithConditionsColumns(), ", ")) + + license, err := scanSubscriptionLicenseWithConditions( + s.db.QueryRow(ctx, query, pgx.NamedArgs{ + "licenseID": licenseID, + }), + ) + if err != nil { + return nil, errors.Wrap(err, "query rows") + } + return license, nil +} + +type CreateLicenseOpts struct { + Message string + // If nil, the creation time will be set to the current time. + Time *utctime.Time +} + +// LicenseKey corresponds to *subscriptionsv1.EnterpriseSubscriptionLicenseKey +// and the 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY' license type. +type LicenseKey struct { + Info internallicense.Info + // Signed license key with the license information in Info. + SignedKey string +} + +// CreateLicense creates a new classic offline license for the given subscription. +func (s *LicensesStore) CreateLicenseKey( + ctx context.Context, + subscriptionID string, + license *LicenseKey, + opts CreateLicenseOpts, +) (_ *LicenseWithConditions, err error) { + // Special behaviour: the license key embeds the creation time, and it must + // match the time provided in the options. + if opts.Time == nil { + return nil, errors.New("creation time must be specified for licensekeys") + } else if !opts.Time.Time().Equal(license.Info.CreatedAt) { + return nil, errors.New("creation time must match the license key information") + } + + return s.create( + ctx, + subscriptionID, + subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY, + license, + opts, + ) +} + +func (s *LicensesStore) create( + ctx context.Context, + subscriptionID string, + licenseType subscriptionsv1.EnterpriseSubscriptionLicenseType, + license any, + opts CreateLicenseOpts, +) (_ *LicenseWithConditions, err error) { + if subscriptionID == "" { + return nil, errors.New("subscription ID must be specified") + } + if opts.Time == nil { + opts.Time = pointers.Ptr(utctime.Now()) + } else if opts.Time.Time().After(time.Now()) { + return nil, errors.New("creation time cannot be in the future") + } + if licenseType == subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_UNSPECIFIED { + return nil, errors.New("license type must be specified") + } + + licenseID, err := uuid.NewV7() + if err != nil { + return nil, errors.Wrap(err, "generate uuid") + } + licenseData, err := json.Marshal(license) + if err != nil { + return nil, errors.Wrap(err, "marshal license data") + } + tx, err := s.db.Begin(ctx) + if err != nil { + return nil, errors.Wrap(err, "begin transaction") + } + defer func() { + if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil { + err = errors.Append(err, errors.Wrap(err, "rollback")) + } + }() + + if _, err := tx.Exec(ctx, ` +INSERT INTO enterprise_portal_subscription_licenses ( + id, + subscription_id, + license_type, + license_data, + created_at +) +VALUES ( + @licenseID, + @subscriptionID, + @licenseType, + @licenseData, + @createdAt +) +`, pgx.NamedArgs{ + "licenseID": licenseID.String(), + "subscriptionID": subscriptionID, + "licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)], + "licenseData": licenseData, + "createdAt": opts.Time, + }); err != nil { + return nil, errors.Wrap(err, "create license") + } + + if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID.String(), createLicenseConditionOpts{ + Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_CREATED, + Message: opts.Message, + TransitionTime: *opts.Time, + }); err != nil { + return nil, errors.Wrap(err, "create license condition") + } + + if err := tx.Commit(ctx); err != nil { + return nil, errors.Wrap(err, "commit transaction") + } + + return s.Get(ctx, licenseID.String()) +} + +type RevokeLicenseOpts struct { + Message string + // If nil, the revocation time will be set to the current time. + Time *utctime.Time +} + +// Revoke marks the given license as revoked. +func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts RevokeLicenseOpts) (*LicenseWithConditions, error) { + if opts.Time == nil { + opts.Time = pointers.Ptr(utctime.Now()) + } else if opts.Time.Time().After(time.Now()) { + return nil, errors.New("revocation time cannot be in the future") + } + + tx, err := s.db.Begin(ctx) + if err != nil { + return nil, errors.Wrap(err, "begin transaction") + } + defer func() { + if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil { + err = errors.Append(err, rollbackErr) + } + }() + + if _, err := tx.Exec(ctx, ` +UPDATE enterprise_portal_subscription_licenses +SET revoked_at = COALESCE(revoked_at, @revokedAt) -- use existing revoke time if already revoked +WHERE id = @licenseID +`, pgx.NamedArgs{ + "revokedAt": opts.Time, + "licenseID": licenseID, + }); err != nil { + return nil, errors.Wrap(err, "revoke license") + } + + if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID, createLicenseConditionOpts{ + Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_REVOKED, + Message: opts.Message, + TransitionTime: *opts.Time, + }); err != nil { + return nil, errors.Wrap(err, "create license condition") + } + + if err := tx.Commit(ctx); err != nil { + return nil, errors.Wrap(err, "commit transaction") + } + + return s.Get(ctx, licenseID) } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go new file mode 100644 index 0000000000000..784250b001ec3 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -0,0 +1,206 @@ +package subscriptions_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/hexops/autogold/v2" + "github.com/hexops/valast" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" + "github.com/sourcegraph/sourcegraph/internal/license" + "github.com/sourcegraph/sourcegraph/lib/pointers" +) + +func TestLicensesStore(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...) + + subscriptionID1 := uuid.NewString() + subscriptionID2 := uuid.NewString() + + subs := subscriptions.NewStore(db) + _, err := subs.Upsert(ctx, subscriptionID1, subscriptions.UpsertSubscriptionOptions{ + DisplayName: "Acme, Inc. 1", + }) + require.NoError(t, err) + _, err = subs.Upsert(ctx, subscriptionID2, subscriptions.UpsertSubscriptionOptions{ + DisplayName: "Acme, Inc. 2", + }) + require.NoError(t, err) + + licenses := subscriptions.NewLicensesStore(db) + + var createdLicenses []*subscriptions.LicenseWithConditions + getCreatedByLicenseID := func(t *testing.T, licenseID string) *subscriptions.LicenseWithConditions { + for _, l := range createdLicenses { + if l.ID == licenseID { + return l + } + } + t.Errorf("license %q not found", licenseID) + t.FailNow() + return nil + } + t.Run("CreateLicenseKey", func(t *testing.T) { + testLicense := func( + got *subscriptions.LicenseWithConditions, + wantMessage autogold.Value, + wantLicenseData autogold.Value, + ) { + assert.NotEmpty(t, got.ID) + assert.NotZero(t, got.CreatedAt) + assert.Equal(t, "ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY", got.LicenseType) + wantLicenseData.Equal(t, string(got.LicenseData)) + + assert.Len(t, got.Conditions, 1) + wantMessage.Equal(t, got.Conditions[0].Message) + assert.Equal(t, "STATUS_CREATED", got.Conditions[0].Status) + assert.Equal(t, got.CreatedAt, got.Conditions[0].TransitionTime) + } + + got, err := licenses.CreateLicenseKey(ctx, subscriptionID1, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"foo"}, + CreatedAt: time.Time{}.Add(1 * time.Hour), + }, + SignedKey: "asdfasdf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name() + " 1 old", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))), + }) + require.NoError(t, err) + testLicense( + got, + autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1 old")), + autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`), + ) + createdLicenses = append(createdLicenses, got) + + got, err = licenses.CreateLicenseKey(ctx, subscriptionID1, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"baz"}, + CreatedAt: time.Time{}.Add(24 * time.Hour), + }, + SignedKey: "barasdf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name() + " 1", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + }) + require.NoError(t, err) + testLicense( + got, + autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`), + ) + createdLicenses = append(createdLicenses, got) + + got, err = licenses.CreateLicenseKey(ctx, subscriptionID2, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"tag"}, + CreatedAt: time.Time{}.Add(24 * time.Hour), + }, + SignedKey: "asdffdsadf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name() + " 2", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + }) + require.NoError(t, err) + testLicense( + got, + autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 2")), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`), + ) + createdLicenses = append(createdLicenses, got) + + t.Run("timestamps do not match", func(t *testing.T) { + _, err = licenses.CreateLicenseKey(ctx, subscriptionID2, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"tag"}, + CreatedAt: time.Time{}.Add(24 * time.Hour), + }, + SignedKey: "asdffdsadf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name(), + Time: pointers.Ptr(utctime.Now()), + }) + require.Error(t, err) + autogold.Expect("creation time must match the license key information").Equal(t, err.Error()) + }) + }) + + t.Run("List", func(t *testing.T) { + listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{}) + require.NoError(t, err) + assert.Len(t, listedLicenses, len(createdLicenses)) + for _, l := range listedLicenses { + created := getCreatedByLicenseID(t, l.ID) + assert.Equal(t, *created, *l) + } + + t.Run("List by subscription", func(t *testing.T) { + listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{ + SubscriptionID: subscriptionID1, + }) + require.NoError(t, err) + assert.Len(t, listedLicenses, 2) + for _, l := range listedLicenses { + assert.Equal(t, subscriptionID1, l.SubscriptionID) + assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l) + } + + listedLicenses, err = licenses.List(ctx, subscriptions.ListLicensesOpts{ + SubscriptionID: subscriptionID2, + }) + require.NoError(t, err) + assert.Len(t, listedLicenses, 1) + for _, l := range listedLicenses { + assert.Equal(t, subscriptionID2, l.SubscriptionID) + assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l) + } + }) + }) + + t.Run("Get", func(t *testing.T) { + for _, license := range createdLicenses { + got, err := licenses.Get(ctx, license.ID) + require.NoError(t, err) + assert.Equal(t, *license, *got) + } + }) + + t.Run("Revoke", func(t *testing.T) { + for idx, license := range createdLicenses { + revokeTime := time.Now().Add(-time.Second) + got, err := licenses.Revoke(ctx, license.ID, subscriptions.RevokeLicenseOpts{ + Message: fmt.Sprintf("%s %d", t.Name(), idx), + Time: pointers.Ptr(utctime.FromTime(revokeTime)), + }) + require.NoError(t, err) + assert.Equal(t, revokeTime.UTC(), *got.RevokedAt.Time()) + require.Len(t, got.Conditions, 2) + // Most recent condition is sorted first, and should be the revocation + assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status) + assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.Time()) + assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status) + } + }) +} diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go index 94e8d8e0297ec..7b49aa18593ba 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go @@ -10,10 +10,26 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" "github.com/sourcegraph/sourcegraph/lib/errors" - "github.com/sourcegraph/sourcegraph/lib/pointers" ) +// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints +// and initializing tables with gorm. +type TableSubscription struct { + // Each Subscription has many Licenses. + Licenses []*TableSubscriptionLicense `gorm:"foreignKey:SubscriptionID"` + + // Each Subscription has many Conditions. + Conditions *[]SubscriptionCondition `gorm:"foreignKey:SubscriptionID"` + + Subscription +} + +func (*TableSubscription) TableName() string { + return "enterprise_portal_subscriptions" +} + // Subscription is an Enterprise subscription record. type Subscription struct { // ID is the internal (unprefixed) UUID-format identifier for the subscription. @@ -22,7 +38,7 @@ type Subscription struct { // "acme.sourcegraphcloud.com". This is set explicitly. // // It must be unique across all currently un-archived subscriptions. - InstanceDomain string `gorm:"uniqueIndex:,where:archived_at IS NULL"` + InstanceDomain *string `gorm:"uniqueIndex:,where:archived_at IS NULL"` // WARNING: The below fields are not yet used in production. @@ -39,9 +55,9 @@ type Subscription struct { // to this subscription. // // Condition transition details are tracked in 'enterprise_portal_subscription_conditions'. - CreatedAt time.Time `gorm:"not null;default:current_timestamp"` - UpdatedAt time.Time `gorm:"not null;default:current_timestamp"` - ArchivedAt *time.Time // Null indicates the subscription is not archived. + CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"` + UpdatedAt utctime.Time `gorm:"not null;default:current_timestamp"` + ArchivedAt *utctime.Time // Null indicates the subscription is not archived. // SalesforceSubscriptionID associated with this Enterprise subscription. SalesforceSubscriptionID *string @@ -49,11 +65,7 @@ type Subscription struct { SalesforceOpportunityID *string } -func (s Subscription) TableName() string { - return "enterprise_portal_subscriptions" -} - -// subscriptionTableColumns must match s.scan() values. +// subscriptionTableColumns must match scanSubscription() values. func subscriptionTableColumns() []string { return []string{ "id", @@ -67,7 +79,7 @@ func subscriptionTableColumns() []string { } } -// scanSubscription matches s.columns() values. +// scanSubscription matches subscriptionTableColumns() values. func scanSubscription(row pgx.Row) (*Subscription, error) { var s Subscription err := row.Scan( @@ -83,13 +95,6 @@ func scanSubscription(row pgx.Row) (*Subscription, error) { if err != nil { return nil, err } - - s.CreatedAt = s.CreatedAt.UTC() - s.UpdatedAt = s.UpdatedAt.UTC() - if s.ArchivedAt != nil { - s.ArchivedAt = pointers.Ptr(s.ArchivedAt.UTC()) - } - return &s, nil } @@ -172,7 +177,7 @@ WHERE %s } type UpsertSubscriptionOptions struct { - InstanceDomain string + InstanceDomain *string DisplayName string CreatedAt time.Time diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go index f30419769a3f4..9de9f3ab82149 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go @@ -1,14 +1,9 @@ package subscriptions -import ( - "time" -) +import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" // Subscription is an Enterprise subscription condition record. type SubscriptionCondition struct { - // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - Subscription *Subscription `gorm:"foreignKey:SubscriptionID"` - // SubscriptionID is the internal unprefixed UUID of the related subscription. SubscriptionID string `gorm:"type:uuid;not null"` // Status is the type of status corresponding to this condition, corresponding @@ -18,7 +13,7 @@ type SubscriptionCondition struct { Message *string `gorm:"size:256"` // TransitionTime is the time at which the condition was created, i.e. when // the subscription transitioned into this status. - TransitionTime time.Time `gorm:"not null;default:current_timestamp"` + TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"` } func (s *SubscriptionCondition) TableName() string { diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go index 4aacc18c881d8..70b665356037f 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go @@ -20,7 +20,7 @@ func TestSubscriptionsStore(t *testing.T) { t.Parallel() ctx := context.Background() - db := databasetest.NewTestDB(t, "enterprise-portal", "SubscriptionsStore", tables.All()...) + db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...) for _, tc := range []struct { name string @@ -45,19 +45,19 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions. s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, ) require.NoError(t, err) s2, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s2.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s2.sourcegraph.com")}, ) require.NoError(t, err) _, err = s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s3.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s3.sourcegraph.com")}, ) require.NoError(t, err) @@ -79,7 +79,7 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions. t.Run("list by instance domains", func(t *testing.T) { ss, err := s.List(ctx, subscriptions.ListEnterpriseSubscriptionsOptions{ - InstanceDomains: []string{s1.InstanceDomain, s2.InstanceDomain}}, + InstanceDomains: []string{*s1.InstanceDomain, *s2.InstanceDomain}}, ) require.NoError(t, err) require.Len(t, ss, 2) @@ -115,7 +115,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription currentSubscription, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, ) require.NoError(t, err) @@ -140,7 +140,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription t.Cleanup(func() { currentSubscription = got }) got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ - InstanceDomain: "s1-new.sourcegraph.com", + InstanceDomain: pointers.Ptr("s1-new.sourcegraph.com"), }) require.NoError(t, err) assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain) @@ -169,7 +169,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Time().Round(time.Second)) }) t.Run("update only archived at", func(t *testing.T) { @@ -184,7 +184,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Time().Round(time.Second)) }) t.Run("force update to zero values", func(t *testing.T) { @@ -209,7 +209,7 @@ func SubscriptionsStoreGet(t *testing.T, ctx context.Context, s *subscriptions.S s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, ) require.NoError(t, err) diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go index 52ed657b067d3..4569da1e323c6 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go @@ -68,7 +68,7 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs return &subscriptionsv1.EnterpriseSubscription{ Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID, Conditions: conds, - InstanceDomain: subscription.InstanceDomain, + InstanceDomain: pointers.DerefZero(subscription.InstanceDomain), DisplayName: subscription.DisplayName, } } diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go index 7e4a16cbc959d..50ffa0a3e21b1 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go @@ -16,6 +16,7 @@ import ( subscriptionsv1connect "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1/v1connect" "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/sourcegraph/sourcegraph/lib/managedservicesplatform/iam" + "github.com/sourcegraph/sourcegraph/lib/pointers" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" @@ -328,7 +329,7 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne // Empty field paths means update all non-empty fields. if len(fieldPaths) == 0 { if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" { - opts.InstanceDomain = v + opts.InstanceDomain = &v } if v := req.Msg.GetSubscription().GetDisplayName(); v != "" { opts.DisplayName = v @@ -337,22 +338,23 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne for _, p := range fieldPaths { switch p { case "instance_domain": - opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain() + opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain()) case "display_name": opts.DisplayName = req.Msg.GetSubscription().GetDisplayName() case "*": opts.ForceUpdate = true - opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain() + opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain()) } } } // Validate and normalize the domain - if opts.InstanceDomain != "" { - opts.InstanceDomain, err = subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain) + if opts.InstanceDomain != nil { + normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(pointers.DerefZero(opts.InstanceDomain)) if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "invalid instance domain")) } + opts.InstanceDomain = &normalizedDomain } subscription, err := s.store.UpsertEnterpriseSubscription(ctx, subscriptionID, opts) From 71b3cfe0e45872a06428925fde6fd803fbffa555 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Fri, 12 Jul 2024 11:48:45 -0700 Subject: [PATCH 2/6] fix nullable updates --- .../internal/database/BUILD.bazel | 1 + .../database/internal/utctime/utctime.go | 16 +++++-- .../database/subscriptions/BUILD.bazel | 1 + .../database/subscriptions/licenses_test.go | 5 ++- .../database/subscriptions/subscriptions.go | 7 +-- .../subscriptions/subscriptions_test.go | 45 ++++++++++++------- .../internal/database/types.go | 10 +++++ .../internal/subscriptionsservice/adapters.go | 2 +- .../internal/subscriptionsservice/v1.go | 26 +++++++---- 9 files changed, 80 insertions(+), 33 deletions(-) create mode 100644 cmd/enterprise-portal/internal/database/types.go diff --git a/cmd/enterprise-portal/internal/database/BUILD.bazel b/cmd/enterprise-portal/internal/database/BUILD.bazel index c042c39f40cd6..dbbb9871bcb15 100644 --- a/cmd/enterprise-portal/internal/database/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/BUILD.bazel @@ -5,6 +5,7 @@ go_library( srcs = [ "database.go", "migrate.go", + "types.go", ], importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database", tags = [TAG_INFRA_CORESERVICES], diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go index 5a6d8792168b0..25f5ed412650d 100644 --- a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go +++ b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go @@ -13,13 +13,21 @@ import ( // Time is a wrapper around time.Time that implements the database/sql.Scanner // and database/sql/driver.Valuer interfaces to serialize and deserialize time // in UTC time zone. +// +// Time ensures that time.Time values are always: +// +// - represented in UTC for consistency +// - rounded to microsecond precision +// +// We round the time because PostgreSQL times are represented in microseconds: +// https://www.postgresql.org/docs/current/datatype-datetime.html type Time time.Time // Now returns the current time in UTC. -func Now() Time { return Time(time.Now().UTC()) } +func Now() Time { return Time(time.Now()) } // FromTime returns a utctime.Time from a time.Time. -func FromTime(t time.Time) Time { return Time(t.UTC()) } +func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) } var _ sql.Scanner = (*Time)(nil) @@ -28,7 +36,7 @@ func (t *Time) Scan(src any) error { return nil } if v, ok := src.(time.Time); ok { - *t = Time(v.UTC()) + *t = FromTime(v) return nil } return errors.Newf("value %T is not time.Time", src) @@ -64,5 +72,5 @@ func (t *Time) Time() *time.Time { return nil } // Ensure the time is in UTC. - return pointers.Ptr((*time.Time)(t).UTC()) + return pointers.Ptr((*time.Time)(t).UTC().Round(time.Microsecond)) } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel index 2dcadb12c26cd..8cdb672afb02f 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel @@ -38,6 +38,7 @@ go_test( ], deps = [ ":subscriptions", + "//cmd/enterprise-portal/internal/database", "//cmd/enterprise-portal/internal/database/databasetest", "//cmd/enterprise-portal/internal/database/internal/tables", "//cmd/enterprise-portal/internal/database/internal/utctime", diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go index 784250b001ec3..80e2ed87fc1fe 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" @@ -31,11 +32,11 @@ func TestLicensesStore(t *testing.T) { subs := subscriptions.NewStore(db) _, err := subs.Upsert(ctx, subscriptionID1, subscriptions.UpsertSubscriptionOptions{ - DisplayName: "Acme, Inc. 1", + DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 1")), }) require.NoError(t, err) _, err = subs.Upsert(ctx, subscriptionID2, subscriptions.UpsertSubscriptionOptions{ - DisplayName: "Acme, Inc. 2", + DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 2")), }) require.NoError(t, err) diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go index 7b49aa18593ba..da454bcf4eeb0 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go @@ -2,6 +2,7 @@ package subscriptions import ( "context" + "database/sql" "fmt" "strings" "time" @@ -49,7 +50,7 @@ type Subscription struct { // // TODO: Clean up the database post-deploy and remove the 'Unnamed subscription' // part of the constraint. - DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"` + DisplayName *string `gorm:"size:256;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"` // Timestamps representing the latest timestamps of key conditions related // to this subscription. @@ -177,8 +178,8 @@ WHERE %s } type UpsertSubscriptionOptions struct { - InstanceDomain *string - DisplayName string + InstanceDomain *sql.NullString + DisplayName *sql.NullString CreatedAt time.Time ArchivedAt *time.Time diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go index 70b665356037f..567ba850ecdfd 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" @@ -45,19 +46,25 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions. s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) s2, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s2.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s2.sourcegraph.com")), + }, ) require.NoError(t, err) _, err = s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s3.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s3.sourcegraph.com")), + }, ) require.NoError(t, err) @@ -115,14 +122,16 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription currentSubscription, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) got, err := s.Get(ctx, currentSubscription.ID) require.NoError(t, err) assert.Equal(t, currentSubscription.ID, got.ID) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) assert.Empty(t, got.DisplayName) assert.NotZero(t, got.CreatedAt) assert.NotZero(t, got.UpdatedAt) @@ -133,17 +142,19 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{}) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, + pointers.DerefZero(currentSubscription.InstanceDomain), + pointers.DerefZero(got.InstanceDomain)) }) t.Run("update only domain", func(t *testing.T) { t.Cleanup(func() { currentSubscription = got }) got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ - InstanceDomain: pointers.Ptr("s1-new.sourcegraph.com"), + InstanceDomain: pointers.Ptr(database.NewNullString("s1-new.sourcegraph.com")), }) require.NoError(t, err) - assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain) + assert.Equal(t, "s1-new.sourcegraph.com", pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) }) @@ -151,11 +162,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription t.Cleanup(func() { currentSubscription = got }) got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ - DisplayName: "My New Display Name", + DisplayName: pointers.Ptr(database.NewNullString("My New Display Name")), }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) - assert.Equal(t, "My New Display Name", got.DisplayName) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) + assert.Equal(t, "My New Display Name", pointers.DerefZero(got.DisplayName)) }) t.Run("update only created at", func(t *testing.T) { @@ -166,7 +177,9 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription CreatedAt: yesterday, }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, + pointers.DerefZero(currentSubscription.InstanceDomain), + pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) // Round times to allow for some precision drift in CI assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Time().Round(time.Second)) @@ -180,8 +193,8 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription ArchivedAt: pointers.Ptr(yesterday), }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) - assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) + assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName) assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) // Round times to allow for some precision drift in CI assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Time().Round(time.Second)) @@ -209,7 +222,9 @@ func SubscriptionsStoreGet(t *testing.T, ctx context.Context, s *subscriptions.S s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) diff --git a/cmd/enterprise-portal/internal/database/types.go b/cmd/enterprise-portal/internal/database/types.go new file mode 100644 index 0000000000000..366ed5e5481b6 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/types.go @@ -0,0 +1,10 @@ +package database + +import "database/sql" + +func NewNullString(v string) sql.NullString { + return sql.NullString{ + String: v, + Valid: v != "", + } +} diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go index 4569da1e323c6..97e2c7aba5d13 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go @@ -69,7 +69,7 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID, Conditions: conds, InstanceDomain: pointers.DerefZero(subscription.InstanceDomain), - DisplayName: subscription.DisplayName, + DisplayName: pointers.DerefZero(subscription.DisplayName), } } diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go index 50ffa0a3e21b1..0e76abde9538f 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go @@ -19,6 +19,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/pointers" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m" @@ -329,32 +330,41 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne // Empty field paths means update all non-empty fields. if len(fieldPaths) == 0 { if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" { - opts.InstanceDomain = &v + opts.InstanceDomain = pointers.Ptr(database.NewNullString(v)) } if v := req.Msg.GetSubscription().GetDisplayName(); v != "" { - opts.DisplayName = v + opts.DisplayName = pointers.Ptr(database.NewNullString(v)) } } else { for _, p := range fieldPaths { switch p { case "instance_domain": - opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain()) + opts.InstanceDomain = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()), + ) case "display_name": - opts.DisplayName = req.Msg.GetSubscription().GetDisplayName() + opts.DisplayName = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetDisplayName()), + ) case "*": opts.ForceUpdate = true - opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain()) + opts.InstanceDomain = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()), + ) + opts.DisplayName = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetDisplayName()), + ) } } } // Validate and normalize the domain - if opts.InstanceDomain != nil { - normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(pointers.DerefZero(opts.InstanceDomain)) + if opts.InstanceDomain != nil && opts.InstanceDomain.Valid { + normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain.String) if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "invalid instance domain")) } - opts.InstanceDomain = &normalizedDomain + opts.InstanceDomain.String = normalizedDomain } subscription, err := s.store.UpsertEnterpriseSubscription(ctx, subscriptionID, opts) From 0a54a1ba14a59fce180d0ebc4f42fb82c4117590 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Fri, 12 Jul 2024 13:38:47 -0700 Subject: [PATCH 3/6] lift expire_time --- .../database/internal/utctime/utctime.go | 16 +++--- .../subscriptions/license_conditions.go | 2 +- .../database/subscriptions/licenses.go | 24 +++++++-- .../database/subscriptions/licenses_test.go | 49 ++++++++++++++----- .../subscriptions/subscriptions_test.go | 4 +- 5 files changed, 69 insertions(+), 26 deletions(-) diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go index 25f5ed412650d..0e90287687ea5 100644 --- a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go +++ b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go @@ -47,13 +47,13 @@ var _ driver.Valuer = (*Time)(nil) // Value must be called with a non-nil Time. driver.Valuer callers will first // check that the value is non-nil, so this is safe. func (t Time) Value() (driver.Value, error) { - stdTime := t.Time() + stdTime := t.GetTime() return *stdTime, nil } var _ json.Marshaler = (*Time)(nil) -func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.Time()) } +func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) } var _ json.Unmarshaler = (*Time)(nil) @@ -66,11 +66,15 @@ func (t *Time) UnmarshalJSON(data []byte) error { return nil } -// Time returns the underlying time.Time value, or nil if it is nil. -func (t *Time) Time() *time.Time { +// GetTime returns the underlying time.GetTime value, or nil if it is nil. +func (t *Time) GetTime() *time.Time { if t == nil { return nil } - // Ensure the time is in UTC. - return pointers.Ptr((*time.Time)(t).UTC().Round(time.Microsecond)) + return pointers.Ptr(t.AsTime()) +} + +// Time casts the Time as a standard time.Time value. +func (t Time) AsTime() time.Time { + return time.Time(t).UTC().Round(time.Microsecond) } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go index b047dfd0dadb5..ac19bceeca04d 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go @@ -67,7 +67,7 @@ type createLicenseConditionOpts struct { } func (s *licenseConditionsStore) createLicenseCondition(ctx context.Context, licenseID string, opts createLicenseConditionOpts) error { - if opts.TransitionTime.Time().IsZero() { + if opts.TransitionTime.GetTime().IsZero() { return errors.New("transition time is required") } _, err := s.tx.Exec(ctx, ` diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go index c96865900faed..11289b20135f5 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go @@ -56,6 +56,10 @@ type SubscriptionLicense struct { CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"` RevokedAt *utctime.Time // Null indicates the license is not revoked. + // ExpireAt is the time at which the license should expire. Expiration does + // NOT get a corresponding condition entry in 'enterprise_portal_subscription_license_conditions'. + ExpireAt utctime.Time `gorm:"not null"` + // LicenseType is the kind of license stored in LicenseData, corresponding // to the API 'EnterpriseSubscriptionLicenseType'. LicenseType string `gorm:"not null"` @@ -77,6 +81,7 @@ func subscriptionLicenseWithConditionsColumns() []string { "created_at", "revoked_at", + "expire_at", "license_type", "license_data", @@ -98,6 +103,7 @@ func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions, &l.ID, &l.CreatedAt, &l.RevokedAt, + &l.ExpireAt, &l.LicenseType, &l.LicenseData, &l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring @@ -207,6 +213,8 @@ type CreateLicenseOpts struct { Message string // If nil, the creation time will be set to the current time. Time *utctime.Time + // Expiration time of the license. + ExpireTime utctime.Time } // LicenseKey corresponds to *subscriptionsv1.EnterpriseSubscriptionLicenseKey @@ -228,9 +236,12 @@ func (s *LicensesStore) CreateLicenseKey( // match the time provided in the options. if opts.Time == nil { return nil, errors.New("creation time must be specified for licensekeys") - } else if !opts.Time.Time().Equal(license.Info.CreatedAt) { + } else if !opts.Time.GetTime().Equal(utctime.FromTime(license.Info.CreatedAt).AsTime()) { return nil, errors.New("creation time must match the license key information") } + if !opts.ExpireTime.GetTime().Equal(utctime.FromTime(license.Info.ExpiresAt).AsTime()) { + return nil, errors.New("expiration time must match the license key information") + } return s.create( ctx, @@ -253,7 +264,7 @@ func (s *LicensesStore) create( } if opts.Time == nil { opts.Time = pointers.Ptr(utctime.Now()) - } else if opts.Time.Time().After(time.Now()) { + } else if opts.Time.GetTime().After(time.Now()) { return nil, errors.New("creation time cannot be in the future") } if licenseType == subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_UNSPECIFIED { @@ -284,14 +295,16 @@ INSERT INTO enterprise_portal_subscription_licenses ( subscription_id, license_type, license_data, - created_at + created_at, + expire_at ) VALUES ( @licenseID, @subscriptionID, @licenseType, @licenseData, - @createdAt + @createdAt, + @expireAt ) `, pgx.NamedArgs{ "licenseID": licenseID.String(), @@ -299,6 +312,7 @@ VALUES ( "licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)], "licenseData": licenseData, "createdAt": opts.Time, + "expireAt": opts.ExpireTime, }); err != nil { return nil, errors.Wrap(err, "create license") } @@ -328,7 +342,7 @@ type RevokeLicenseOpts struct { func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts RevokeLicenseOpts) (*LicenseWithConditions, error) { if opts.Time == nil { opts.Time = pointers.Ptr(utctime.Now()) - } else if opts.Time.Time().After(time.Now()) { + } else if opts.Time.GetTime().After(time.Now()) { return nil, errors.New("revocation time cannot be in the future") } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go index 80e2ed87fc1fe..1cbd5fdd6700e 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -61,6 +61,7 @@ func TestLicensesStore(t *testing.T) { ) { assert.NotEmpty(t, got.ID) assert.NotZero(t, got.CreatedAt) + assert.NotZero(t, got.ExpireAt) assert.Equal(t, "ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY", got.LicenseType) wantLicenseData.Equal(t, string(got.LicenseData)) @@ -75,18 +76,20 @@ func TestLicensesStore(t *testing.T) { Info: license.Info{ Tags: []string{"foo"}, CreatedAt: time.Time{}.Add(1 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), }, SignedKey: "asdfasdf", }, subscriptions.CreateLicenseOpts{ - Message: t.Name() + " 1 old", - Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))), + Message: t.Name() + " 1 old", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), }) require.NoError(t, err) testLicense( got, autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1 old")), - autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`), + autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`), ) createdLicenses = append(createdLicenses, got) @@ -95,18 +98,20 @@ func TestLicensesStore(t *testing.T) { Info: license.Info{ Tags: []string{"baz"}, CreatedAt: time.Time{}.Add(24 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), }, SignedKey: "barasdf", }, subscriptions.CreateLicenseOpts{ - Message: t.Name() + " 1", - Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + Message: t.Name() + " 1", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), }) require.NoError(t, err) testLicense( got, autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")), - autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`), ) createdLicenses = append(createdLicenses, got) @@ -115,22 +120,24 @@ func TestLicensesStore(t *testing.T) { Info: license.Info{ Tags: []string{"tag"}, CreatedAt: time.Time{}.Add(24 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), }, SignedKey: "asdffdsadf", }, subscriptions.CreateLicenseOpts{ - Message: t.Name() + " 2", - Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + Message: t.Name() + " 2", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), }) require.NoError(t, err) testLicense( got, autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 2")), - autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`), ) createdLicenses = append(createdLicenses, got) - t.Run("timestamps do not match", func(t *testing.T) { + t.Run("createdAt does not match", func(t *testing.T) { _, err = licenses.CreateLicenseKey(ctx, subscriptionID2, &subscriptions.LicenseKey{ Info: license.Info{ @@ -146,6 +153,24 @@ func TestLicensesStore(t *testing.T) { require.Error(t, err) autogold.Expect("creation time must match the license key information").Equal(t, err.Error()) }) + t.Run("expiresAt does not match", func(t *testing.T) { + _, err = licenses.CreateLicenseKey(ctx, subscriptionID2, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"tag"}, + CreatedAt: time.Time{}, + ExpiresAt: time.Time{}.Add(48 * time.Hour), + }, + SignedKey: "asdffdsadf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name(), + Time: pointers.Ptr(utctime.FromTime(time.Time{})), + ExpireTime: utctime.Now(), + }) + require.Error(t, err) + autogold.Expect("expiration time must match the license key information").Equal(t, err.Error()) + }) }) t.Run("List", func(t *testing.T) { @@ -196,11 +221,11 @@ func TestLicensesStore(t *testing.T) { Time: pointers.Ptr(utctime.FromTime(revokeTime)), }) require.NoError(t, err) - assert.Equal(t, revokeTime.UTC(), *got.RevokedAt.Time()) + assert.Equal(t, revokeTime.UTC(), got.RevokedAt.AsTime()) require.Len(t, got.Conditions, 2) // Most recent condition is sorted first, and should be the revocation assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status) - assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.Time()) + assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.GetTime()) assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status) } }) diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go index 567ba850ecdfd..2cbef7b07dfd0 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go @@ -182,7 +182,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Time().Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.GetTime().Round(time.Second)) }) t.Run("update only archived at", func(t *testing.T) { @@ -197,7 +197,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName) assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Time().Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.GetTime().Round(time.Second)) }) t.Run("force update to zero values", func(t *testing.T) { From ba8a30dc83992b6d97024658b6d2f2b920a3250b Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Fri, 12 Jul 2024 16:35:20 -0700 Subject: [PATCH 4/6] fixup test --- .../internal/database/subscriptions/licenses_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go index 1cbd5fdd6700e..a96fc92bbf00a 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -215,17 +215,17 @@ func TestLicensesStore(t *testing.T) { t.Run("Revoke", func(t *testing.T) { for idx, license := range createdLicenses { - revokeTime := time.Now().Add(-time.Second) + revokeTime := utctime.FromTime(time.Now().Add(-time.Second)) got, err := licenses.Revoke(ctx, license.ID, subscriptions.RevokeLicenseOpts{ Message: fmt.Sprintf("%s %d", t.Name(), idx), - Time: pointers.Ptr(utctime.FromTime(revokeTime)), + Time: pointers.Ptr(revokeTime), }) require.NoError(t, err) - assert.Equal(t, revokeTime.UTC(), got.RevokedAt.AsTime()) + assert.Equal(t, revokeTime.AsTime(), got.RevokedAt.AsTime()) require.Len(t, got.Conditions, 2) // Most recent condition is sorted first, and should be the revocation assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status) - assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.GetTime()) + assert.Equal(t, revokeTime.AsTime(), got.Conditions[0].TransitionTime.AsTime()) assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status) } }) From 0589692ed23e811a2346db0f28f94042c4508020 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Fri, 12 Jul 2024 16:36:24 -0700 Subject: [PATCH 5/6] fast-fail --- .../internal/database/subscriptions/licenses_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go index a96fc92bbf00a..c79657b7045ee 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -173,6 +173,12 @@ func TestLicensesStore(t *testing.T) { }) }) + // No point continuing if test licenses did not create, all tests after this + // will fail + if t.Failed() { + t.FailNow() + } + t.Run("List", func(t *testing.T) { listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{}) require.NoError(t, err) From 9110453d15044f3ef6b2ba257c58da6d5589eb4d Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Fri, 12 Jul 2024 17:50:47 -0700 Subject: [PATCH 6/6] left join instead --- .../internal/database/subscriptions/license_conditions.go | 2 +- .../internal/database/subscriptions/licenses.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go index ac19bceeca04d..e3cee4179dc37 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go @@ -30,7 +30,7 @@ func (*SubscriptionLicenseCondition) TableName() string { // subscriptionLicenseConditionJSONBAgg must be used with: // -// JOIN +// LEFT JOIN // enterprise_portal_subscription_license_conditions license_condition // ON license_condition.license_id = id // GROUP BY diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go index 11289b20135f5..386130a55492c 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go @@ -153,7 +153,7 @@ SELECT %s FROM enterprise_portal_subscription_licenses -JOIN +LEFT JOIN enterprise_portal_subscription_license_conditions license_condition ON license_condition.license_id = id WHERE @@ -189,7 +189,7 @@ SELECT %s FROM enterprise_portal_subscription_licenses -JOIN +LEFT JOIN enterprise_portal_subscription_license_conditions license_condition ON license_condition.license_id = id WHERE