Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

feat/enterpriseportal: db layer for subscription licenses #63792

Merged
merged 6 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/enterprise-portal/internal/database/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ go_library(
srcs = [
"database.go",
"migrate.go",
"types.go",
],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database",
tags = [TAG_INFRA_CORESERVICES],
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/tables",
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
"//cmd/enterprise-portal/internal/database/subscriptions",
"//lib/errors",
"//lib/managedservicesplatform/runtime",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/databa

type CodyGatewayAccess struct {
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
Subscription *subscriptions.Subscription `gorm:"foreignKey:SubscriptionID"`
Subscription *subscriptions.TableSubscription `gorm:"foreignKey:SubscriptionID"`

// SubscriptionID is the internal unprefixed UUID of the related subscription.
SubscriptionID string `gorm:"type:uuid;not null;unique"`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ go_library(
tags = [TAG_INFRA_CORESERVICES],
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
"//internal/database/dbtest",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@com_github_stretchr_testify//require",
"@io_gorm_driver_postgres//:postgres",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@ package databasetest
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"testing"
"time"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/schema"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
)

Expand Down Expand Up @@ -57,6 +60,11 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
for _, table := range tables {
err = db.AutoMigrate(table)
require.NoError(t, err)
if m, ok := table.(custommigrator.CustomTableMigrator); ok {
if err := m.RunCustomMigrations(db.Migrator()); err != nil {
require.NoError(t, err)
}
}
}

// Close the connection used to auto-migrate the database.
Expand All @@ -66,7 +74,12 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
require.NoError(t, err)

// Open a new connection to the test suite database.
testDB, err := pgxpool.New(context.Background(), dsn.String())
dbConfig, err := pgxpool.ParseConfig(dsn.String())
require.NoError(t, err)
if testing.Verbose() {
dbConfig.ConnConfig.Tracer = pgxTestTracer{TB: t}
}
testDB, err := pgxpool.NewWithConfig(context.Background(), dbConfig)
require.NoError(t, err)

t.Cleanup(func() {
Expand Down Expand Up @@ -110,3 +123,43 @@ func ClearTablesAfterTest(t *testing.T, db *pgxpool.Pool, tables ...schema.Table
}
})
}

// pgxTestTracer implements various pgx tracing hooks for dumping diagnostics
// in testing.
type pgxTestTracer struct{ testing.TB }

// Select tracing hooks we want to implement.
var (
_ pgx.QueryTracer = pgxTestTracer{}
)

func (t pgxTestTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handy for dumping DB queries in tests

var args []string
if len(data.Args) > 0 {
// Divider for readability
args = append(args, "\n---")
}
for _, arg := range data.Args {
data, err := json.MarshalIndent(arg, "", " ")
if err != nil {
args = append(args, fmt.Sprintf("marshal %T: %+v", arg, err))
}
args = append(args, string(data))
}

t.Logf(`pgx.QueryStart db=%q
%s%s`,
conn.Config().Database,
strings.TrimSpace(data.SQL),
strings.Join(args, "\n"))
return ctx
}

func (t pgxTestTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
if data.Err != nil {
t.Logf(`pgx.QueryEnd db=%q tag=%q error=%q`,
conn.Config().Database,
data.CommandTag.String(),
data.Err)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "custommigrator",
srcs = ["custommigrator.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = ["@io_gorm_gorm//:gorm"],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package custommigrator

import "gorm.io/gorm"

type CustomTableMigrator interface {
// RunCustomMigrations is called after all other migrations have been run.
// It can implement custom migrations.
RunCustomMigrations(migrator gorm.Migrator) error
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
// ⚠️ WARNING: This list is meant to be read-only.
func All() []schema.Tabler {
return []schema.Tabler{
&subscriptions.Subscription{},
&subscriptions.TableSubscription{},
&subscriptions.SubscriptionCondition{},
&subscriptions.SubscriptionLicense{},
&subscriptions.TableSubscriptionLicense{},
&subscriptions.SubscriptionLicenseCondition{},

&codyaccess.CodyGatewayAccess{},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "utctime",
srcs = ["utctime.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//lib/errors",
"//lib/pointers",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package utctime

import (
"database/sql"
"database/sql/driver"
"encoding/json"
"time"

"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)

// Time is a wrapper around time.Time that implements the database/sql.Scanner
// and database/sql/driver.Valuer interfaces to serialize and deserialize time
// in UTC time zone.
//
// Time ensures that time.Time values are always:
//
// - represented in UTC for consistency
// - rounded to microsecond precision
//
// We round the time because PostgreSQL times are represented in microseconds:
// https://www.postgresql.org/docs/current/datatype-datetime.html
type Time time.Time

// Now returns the current time in UTC.
func Now() Time { return Time(time.Now()) }

// FromTime returns a utctime.Time from a time.Time.
func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) }

var _ sql.Scanner = (*Time)(nil)

func (t *Time) Scan(src any) error {
if src == nil {
return nil
}
if v, ok := src.(time.Time); ok {
*t = FromTime(v)
return nil
}
return errors.Newf("value %T is not time.Time", src)
}

var _ driver.Valuer = (*Time)(nil)

// Value must be called with a non-nil Time. driver.Valuer callers will first
// check that the value is non-nil, so this is safe.
func (t Time) Value() (driver.Value, error) {
stdTime := t.GetTime()
return *stdTime, nil
}

var _ json.Marshaler = (*Time)(nil)

func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) }

var _ json.Unmarshaler = (*Time)(nil)

func (t *Time) UnmarshalJSON(data []byte) error {
var stdTime time.Time
if err := json.Unmarshal(data, &stdTime); err != nil {
return err
}
*t = FromTime(stdTime)
return nil
}

// GetTime returns the underlying time.GetTime value, or nil if it is nil.
func (t *Time) GetTime() *time.Time {
if t == nil {
return nil
}
return pointers.Ptr(t.AsTime())
}

// Time casts the Time as a standard time.Time value.
func (t Time) AsTime() time.Time {
return time.Time(t).UTC().Round(time.Microsecond)
}
19 changes: 17 additions & 2 deletions cmd/enterprise-portal/internal/database/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/sourcegraph/sourcegraph/lib/redislock"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
)

// maybeMigrate runs the auto-migration for the database when needed based on
Expand All @@ -42,6 +43,12 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
}
span.End()
}()
logger = logger.
WithTrace(log.TraceContext{
TraceID: span.SpanContext().TraceID().String(),
SpanID: span.SpanContext().SpanID().String(),
}).
With(log.String("database", dbName))

sqlDB, err := contract.PostgreSQL.OpenDatabase(ctx, dbName)
if err != nil {
Expand Down Expand Up @@ -83,17 +90,20 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
span.AddEvent("lock.acquired")

versionKey := fmt.Sprintf("%s:db_version", dbName)
liveVersion := redisClient.Get(ctx, versionKey).Val()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems a bit odd to me that redis would keep the state of the SQL DB :D

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, we should probably store it in DB huh

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@unknwon unknwon Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Start storing this info in DB somewhat increases the seriousness of the value, that requires us to keep it legit and deal with potential compatibility issues. The current use case of it (and why it's in Redis) is merely a quick optimization about "avoid (if possible) running DB migrations if the version is exactly the same, and it's fine if ran twice, does no harm", it doesn't really care if an upgrade or downgrade has happened.

if shouldSkipMigration(
redisClient.Get(ctx, versionKey).Val(),
liveVersion,
currentVersion,
) {
logger.Info("skipped auto-migration",
log.String("database", dbName),
log.String("currentVersion", currentVersion),
)
span.SetAttributes(attribute.Bool("skipped", true))
return nil
}
logger.Info("executing auto-migration",
log.String("liveVersion", liveVersion),
log.String("currentVersion", currentVersion))
span.SetAttributes(attribute.Bool("skipped", false))

// Create a session that ignore debug logging.
Expand All @@ -108,6 +118,11 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
if err != nil {
return errors.Wrapf(err, "auto migrating table for %s", errors.Safe(fmt.Sprintf("%T", table)))
}
if m, ok := table.(custommigrator.CustomTableMigrator); ok {
if err := m.RunCustomMigrations(sess.Migrator()); err != nil {
return errors.Wrapf(err, "running custom migrations for %s", errors.Safe(fmt.Sprintf("%T", table)))
}
}
}

return redisClient.Set(ctx, versionKey, currentVersion, 0).Err()
Expand Down
16 changes: 14 additions & 2 deletions cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,39 @@ go_library(
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/upsert",
"//cmd/enterprise-portal/internal/database/internal/utctime",
"//internal/license",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/errors",
"//lib/pointers",
"@com_github_jackc_pgtype//:pgtype",
"@com_github_google_uuid//:uuid",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@io_gorm_gorm//:gorm",
],
)

go_test(
name = "subscriptions_test",
srcs = ["subscriptions_test.go"],
srcs = [
"licenses_test.go",
"subscriptions_test.go",
],
tags = [
TAG_INFRA_CORESERVICES,
"requires-network",
],
deps = [
":subscriptions",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/databasetest",
"//cmd/enterprise-portal/internal/database/internal/tables",
"//cmd/enterprise-portal/internal/database/internal/utctime",
"//internal/license",
"//lib/pointers",
"@com_github_google_uuid//:uuid",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_hexops_valast//:valast",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
Expand Down
Loading
Loading