Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Commit

Permalink
fix nullable updates
Browse files Browse the repository at this point in the history
  • Loading branch information
bobheadxi committed Jul 12, 2024
1 parent 0ee398b commit 6a9d351
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 33 deletions.
1 change: 1 addition & 0 deletions cmd/enterprise-portal/internal/database/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go_library(
srcs = [
"database.go",
"migrate.go",
"types.go",
],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database",
tags = [TAG_INFRA_CORESERVICES],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@ import (
// Time is a wrapper around time.Time that implements the database/sql.Scanner
// and database/sql/driver.Valuer interfaces to serialize and deserialize time
// in UTC time zone.
//
// Time ensures that time.Time values are always:
//
// - represented in UTC for consistency
// - rounded to microsecond precision
//
// We round the time because PostgreSQL times are represented in microseconds:
// https://www.postgresql.org/docs/current/datatype-datetime.html
type Time time.Time

// Now returns the current time in UTC.
func Now() Time { return Time(time.Now().UTC()) }
func Now() Time { return Time(time.Now()) }

// FromTime returns a utctime.Time from a time.Time.
func FromTime(t time.Time) Time { return Time(t.UTC()) }
func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) }

var _ sql.Scanner = (*Time)(nil)

Expand All @@ -28,7 +36,7 @@ func (t *Time) Scan(src any) error {
return nil
}
if v, ok := src.(time.Time); ok {
*t = Time(v.UTC())
*t = FromTime(v)
return nil
}
return errors.Newf("value %T is not time.Time", src)
Expand Down Expand Up @@ -64,5 +72,5 @@ func (t *Time) Time() *time.Time {
return nil
}
// Ensure the time is in UTC.
return pointers.Ptr((*time.Time)(t).UTC())
return pointers.Ptr((*time.Time)(t).UTC().Round(time.Microsecond))
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ go_test(
],
deps = [
":subscriptions",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/databasetest",
"//cmd/enterprise-portal/internal/database/internal/tables",
"//cmd/enterprise-portal/internal/database/internal/utctime",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
Expand All @@ -31,11 +32,11 @@ func TestLicensesStore(t *testing.T) {

subs := subscriptions.NewStore(db)
_, err := subs.Upsert(ctx, subscriptionID1, subscriptions.UpsertSubscriptionOptions{
DisplayName: "Acme, Inc. 1",
DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 1")),
})
require.NoError(t, err)
_, err = subs.Upsert(ctx, subscriptionID2, subscriptions.UpsertSubscriptionOptions{
DisplayName: "Acme, Inc. 2",
DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 2")),
})
require.NoError(t, err)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package subscriptions

import (
"context"
"database/sql"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -49,7 +50,7 @@ type Subscription struct {
//
// TODO: Clean up the database post-deploy and remove the 'Unnamed subscription'
// part of the constraint.
DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`
DisplayName *string `gorm:"size:256;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`

// Timestamps representing the latest timestamps of key conditions related
// to this subscription.
Expand Down Expand Up @@ -177,8 +178,8 @@ WHERE %s
}

type UpsertSubscriptionOptions struct {
InstanceDomain *string
DisplayName string
InstanceDomain *sql.NullString
DisplayName *sql.NullString

CreatedAt time.Time
ArchivedAt *time.Time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
Expand Down Expand Up @@ -45,19 +46,25 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
s1, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
},
)
require.NoError(t, err)
s2, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s2.sourcegraph.com")},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s2.sourcegraph.com")),
},
)
require.NoError(t, err)
_, err = s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s3.sourcegraph.com")},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s3.sourcegraph.com")),
},
)
require.NoError(t, err)

Expand Down Expand Up @@ -115,14 +122,16 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
currentSubscription, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
},
)
require.NoError(t, err)

got, err := s.Get(ctx, currentSubscription.ID)
require.NoError(t, err)
assert.Equal(t, currentSubscription.ID, got.ID)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
assert.Empty(t, got.DisplayName)
assert.NotZero(t, got.CreatedAt)
assert.NotZero(t, got.UpdatedAt)
Expand All @@ -133,29 +142,31 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription

got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t,
pointers.DerefZero(currentSubscription.InstanceDomain),
pointers.DerefZero(got.InstanceDomain))
})

t.Run("update only domain", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })

got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr("s1-new.sourcegraph.com"),
InstanceDomain: pointers.Ptr(database.NewNullString("s1-new.sourcegraph.com")),
})
require.NoError(t, err)
assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain)
assert.Equal(t, "s1-new.sourcegraph.com", pointers.DerefZero(got.InstanceDomain))
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
})

t.Run("update only display name", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })

got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
DisplayName: "My New Display Name",
DisplayName: pointers.Ptr(database.NewNullString("My New Display Name")),
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, "My New Display Name", got.DisplayName)
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
assert.Equal(t, "My New Display Name", pointers.DerefZero(got.DisplayName))
})

t.Run("update only created at", func(t *testing.T) {
Expand All @@ -166,7 +177,9 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
CreatedAt: yesterday,
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t,
pointers.DerefZero(currentSubscription.InstanceDomain),
pointers.DerefZero(got.InstanceDomain))
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Time().Round(time.Second))
Expand All @@ -180,8 +193,8 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
ArchivedAt: pointers.Ptr(yesterday),
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName)
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Time().Round(time.Second))
Expand Down Expand Up @@ -209,7 +222,9 @@ func SubscriptionsStoreGet(t *testing.T, ctx context.Context, s *subscriptions.S
s1, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: pointers.Ptr("s1.sourcegraph.com")},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
},
)
require.NoError(t, err)

Expand Down
10 changes: 10 additions & 0 deletions cmd/enterprise-portal/internal/database/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package database

import "database/sql"

func NewNullString(v string) sql.NullString {
return sql.NullString{
String: v,
Valid: v != "",
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs
Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID,
Conditions: conds,
InstanceDomain: pointers.DerefZero(subscription.InstanceDomain),
DisplayName: subscription.DisplayName,
DisplayName: pointers.DerefZero(subscription.DisplayName),
}
}

Expand Down
26 changes: 18 additions & 8 deletions cmd/enterprise-portal/internal/subscriptionsservice/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/sourcegraph/sourcegraph/lib/pointers"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m"
Expand Down Expand Up @@ -329,32 +330,41 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne
// Empty field paths means update all non-empty fields.
if len(fieldPaths) == 0 {
if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" {
opts.InstanceDomain = &v
opts.InstanceDomain = pointers.Ptr(database.NewNullString(v))
}
if v := req.Msg.GetSubscription().GetDisplayName(); v != "" {
opts.DisplayName = v
opts.DisplayName = pointers.Ptr(database.NewNullString(v))
}
} else {
for _, p := range fieldPaths {
switch p {
case "instance_domain":
opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain())
opts.InstanceDomain = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()),
)
case "display_name":
opts.DisplayName = req.Msg.GetSubscription().GetDisplayName()
opts.DisplayName = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetDisplayName()),
)
case "*":
opts.ForceUpdate = true
opts.InstanceDomain = pointers.Ptr(req.Msg.GetSubscription().GetInstanceDomain())
opts.InstanceDomain = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()),
)
opts.DisplayName = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetDisplayName()),
)
}
}
}

// Validate and normalize the domain
if opts.InstanceDomain != nil {
normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(pointers.DerefZero(opts.InstanceDomain))
if opts.InstanceDomain != nil && opts.InstanceDomain.Valid {
normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain.String)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "invalid instance domain"))
}
opts.InstanceDomain = &normalizedDomain
opts.InstanceDomain.String = normalizedDomain
}

subscription, err := s.store.UpsertEnterpriseSubscription(ctx, subscriptionID, opts)
Expand Down

0 comments on commit 6a9d351

Please sign in to comment.