diff --git a/cmd/enterprise-portal/internal/database/BUILD.bazel b/cmd/enterprise-portal/internal/database/BUILD.bazel index c042c39f40cd6..dbbb9871bcb15 100644 --- a/cmd/enterprise-portal/internal/database/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/BUILD.bazel @@ -5,6 +5,7 @@ go_library( srcs = [ "database.go", "migrate.go", + "types.go", ], importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database", tags = [TAG_INFRA_CORESERVICES], diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go index 5a6d8792168b0..25f5ed412650d 100644 --- a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go +++ b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go @@ -13,13 +13,21 @@ import ( // Time is a wrapper around time.Time that implements the database/sql.Scanner // and database/sql/driver.Valuer interfaces to serialize and deserialize time // in UTC time zone. +// +// Time ensures that time.Time values are always: +// +// - represented in UTC for consistency +// - rounded to microsecond precision +// +// We round the time because PostgreSQL times are represented in microseconds: +// https://www.postgresql.org/docs/current/datatype-datetime.html type Time time.Time // Now returns the current time in UTC. -func Now() Time { return Time(time.Now().UTC()) } +func Now() Time { return Time(time.Now()) } // FromTime returns a utctime.Time from a time.Time. -func FromTime(t time.Time) Time { return Time(t.UTC()) } +func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) } var _ sql.Scanner = (*Time)(nil) @@ -28,7 +36,7 @@ func (t *Time) Scan(src any) error { return nil } if v, ok := src.(time.Time); ok { - *t = Time(v.UTC()) + *t = FromTime(v) return nil } return errors.Newf("value %T is not time.Time", src) @@ -64,5 +72,5 @@ func (t *Time) Time() *time.Time { return nil } // Ensure the time is in UTC. - return pointers.Ptr((*time.Time)(t).UTC()) + return pointers.Ptr((*time.Time)(t).UTC().Round(time.Microsecond)) } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel index 2dcadb12c26cd..8cdb672afb02f 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel @@ -38,6 +38,7 @@ go_test( ], deps = [ ":subscriptions", + "//cmd/enterprise-portal/internal/database", "//cmd/enterprise-portal/internal/database/databasetest", "//cmd/enterprise-portal/internal/database/internal/tables", "//cmd/enterprise-portal/internal/database/internal/utctime", diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go index 784250b001ec3..80e2ed87fc1fe 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" @@ -31,11 +32,11 @@ func TestLicensesStore(t *testing.T) { subs := subscriptions.NewStore(db) _, err := subs.Upsert(ctx, subscriptionID1, subscriptions.UpsertSubscriptionOptions{ - DisplayName: "Acme, Inc. 1", + DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 1")), }) require.NoError(t, err) _, err = subs.Upsert(ctx, subscriptionID2, subscriptions.UpsertSubscriptionOptions{ - DisplayName: "Acme, Inc. 2", + DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 2")), }) require.NoError(t, err) diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go index 7b49aa18593ba..da454bcf4eeb0 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go @@ -2,6 +2,7 @@ package subscriptions import ( "context" + "database/sql" "fmt" "strings" "time" @@ -49,7 +50,7 @@ type Subscription struct { // // TODO: Clean up the database post-deploy and remove the 'Unnamed subscription' // part of the constraint. - DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"` + DisplayName *string `gorm:"size:256;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"` // Timestamps representing the latest timestamps of key conditions related // to this subscription. @@ -177,8 +178,8 @@ WHERE %s } type UpsertSubscriptionOptions struct { - InstanceDomain *string - DisplayName string + InstanceDomain *sql.NullString + DisplayName *sql.NullString CreatedAt time.Time ArchivedAt *time.Time diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go index 70b665356037f..567ba850ecdfd 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" @@ -45,19 +46,25 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions. s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) s2, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s2.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s2.sourcegraph.com")), + }, ) require.NoError(t, err) _, err = s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s3.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s3.sourcegraph.com")), + }, ) require.NoError(t, err) @@ -115,14 +122,16 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription currentSubscription, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) got, err := s.Get(ctx, currentSubscription.ID) require.NoError(t, err) assert.Equal(t, currentSubscription.ID, got.ID) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) assert.Empty(t, got.DisplayName) assert.NotZero(t, got.CreatedAt) assert.NotZero(t, got.UpdatedAt) @@ -133,17 +142,19 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{}) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, + pointers.DerefZero(currentSubscription.InstanceDomain), + pointers.DerefZero(got.InstanceDomain)) }) t.Run("update only domain", func(t *testing.T) { t.Cleanup(func() { currentSubscription = got }) got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ - InstanceDomain: pointers.Ptr("s1-new.sourcegraph.com"), + InstanceDomain: pointers.Ptr(database.NewNullString("s1-new.sourcegraph.com")), }) require.NoError(t, err) - assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain) + assert.Equal(t, "s1-new.sourcegraph.com", pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) }) @@ -151,11 +162,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription t.Cleanup(func() { currentSubscription = got }) got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ - DisplayName: "My New Display Name", + DisplayName: pointers.Ptr(database.NewNullString("My New Display Name")), }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) - assert.Equal(t, "My New Display Name", got.DisplayName) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) + assert.Equal(t, "My New Display Name", pointers.DerefZero(got.DisplayName)) }) t.Run("update only created at", func(t *testing.T) { @@ -166,7 +177,9 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription CreatedAt: yesterday, }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, + pointers.DerefZero(currentSubscription.InstanceDomain), + pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) // Round times to allow for some precision drift in CI assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Time().Round(time.Second)) @@ -180,8 +193,8 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription ArchivedAt: pointers.Ptr(yesterday), }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) - assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) + assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName) assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) // Round times to allow for some precision drift in CI assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Time().Round(time.Second)) @@ -209,7 +222,9 @@ func SubscriptionsStoreGet(t *testing.T, ctx context.Context, s *subscriptions.S s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) diff --git a/cmd/enterprise-portal/internal/database/types.go b/cmd/enterprise-portal/internal/database/types.go new file mode 100644 index 0000000000000..366ed5e5481b6 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/types.go @@ -0,0 +1,10 @@ +package database + +import "database/sql" + +func NewNullString(v string) sql.NullString { + return sql.NullString{ + String: v, + Valid: v != "", + } +} diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go index 4569da1e323c6..97e2c7aba5d13 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go @@ -69,7 +69,7 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID, Conditions: conds, InstanceDomain: pointers.DerefZero(subscription.InstanceDomain), - DisplayName: subscription.DisplayName, + DisplayName: pointers.DerefZero(subscription.DisplayName), } } diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go index 50ffa0a3e21b1..0e76abde9538f 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go @@ -19,6 +19,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/pointers" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m" @@ -329,32 +330,41 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne // Empty field paths means update all non-empty fields. if len(fieldPaths) == 0 { if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" { - opts.InstanceDomain = &v + opts.InstanceDomain = pointers.Ptr(database.NewNullString(v)) } if v := req.Msg.GetSubscription().GetDisplayName(); v != "" { - opts.DisplayName = v + opts.DisplayName = pointers.Ptr(database.NewNullString(v)) } } else { for _, p := range fieldPaths { switch p { case "instance_domain": - opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain()) + opts.InstanceDomain = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()), + ) case "display_name": - opts.DisplayName = req.Msg.GetSubscription().GetDisplayName() + opts.DisplayName = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetDisplayName()), + ) case "*": opts.ForceUpdate = true - opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain()) + opts.InstanceDomain = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()), + ) + opts.DisplayName = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetDisplayName()), + ) } } } // Validate and normalize the domain - if opts.InstanceDomain != nil { - normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(pointers.DerefZero(opts.InstanceDomain)) + if opts.InstanceDomain != nil && opts.InstanceDomain.Valid { + normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain.String) if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "invalid instance domain")) } - opts.InstanceDomain = &normalizedDomain + opts.InstanceDomain.String = normalizedDomain } subscription, err := s.store.UpsertEnterpriseSubscription(ctx, subscriptionID, opts)