-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat/enterpriseportal: db layer for subscription licenses #63792
Changes from all commits
73424a2
71b3cfe
0a54a1b
ba8a30d
0589692
9110453
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
load("@io_bazel_rules_go//go:def.bzl", "go_library") | ||
|
||
go_library( | ||
name = "custommigrator", | ||
srcs = ["custommigrator.go"], | ||
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator", | ||
visibility = ["//cmd/enterprise-portal:__subpackages__"], | ||
deps = ["@io_gorm_gorm//:gorm"], | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
package custommigrator | ||
|
||
import "gorm.io/gorm" | ||
|
||
type CustomTableMigrator interface { | ||
// RunCustomMigrations is called after all other migrations have been run. | ||
// It can implement custom migrations. | ||
RunCustomMigrations(migrator gorm.Migrator) error | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
load("@io_bazel_rules_go//go:def.bzl", "go_library") | ||
|
||
go_library( | ||
name = "utctime", | ||
srcs = ["utctime.go"], | ||
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime", | ||
visibility = ["//cmd/enterprise-portal:__subpackages__"], | ||
deps = [ | ||
"//lib/errors", | ||
"//lib/pointers", | ||
], | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package utctime | ||
|
||
import ( | ||
"database/sql" | ||
"database/sql/driver" | ||
"encoding/json" | ||
"time" | ||
|
||
"github.com/sourcegraph/sourcegraph/lib/errors" | ||
"github.com/sourcegraph/sourcegraph/lib/pointers" | ||
) | ||
|
||
// Time is a wrapper around time.Time that implements the database/sql.Scanner | ||
// and database/sql/driver.Valuer interfaces to serialize and deserialize time | ||
// in UTC time zone. | ||
// | ||
// Time ensures that time.Time values are always: | ||
// | ||
// - represented in UTC for consistency | ||
// - rounded to microsecond precision | ||
// | ||
// We round the time because PostgreSQL times are represented in microseconds: | ||
// https://www.postgresql.org/docs/current/datatype-datetime.html | ||
type Time time.Time | ||
|
||
// Now returns the current time in UTC. | ||
func Now() Time { return Time(time.Now()) } | ||
|
||
// FromTime returns a utctime.Time from a time.Time. | ||
func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) } | ||
|
||
var _ sql.Scanner = (*Time)(nil) | ||
|
||
func (t *Time) Scan(src any) error { | ||
if src == nil { | ||
return nil | ||
} | ||
if v, ok := src.(time.Time); ok { | ||
*t = FromTime(v) | ||
return nil | ||
} | ||
return errors.Newf("value %T is not time.Time", src) | ||
} | ||
|
||
var _ driver.Valuer = (*Time)(nil) | ||
|
||
// Value must be called with a non-nil Time. driver.Valuer callers will first | ||
// check that the value is non-nil, so this is safe. | ||
func (t Time) Value() (driver.Value, error) { | ||
stdTime := t.GetTime() | ||
return *stdTime, nil | ||
} | ||
|
||
var _ json.Marshaler = (*Time)(nil) | ||
|
||
func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) } | ||
|
||
var _ json.Unmarshaler = (*Time)(nil) | ||
|
||
func (t *Time) UnmarshalJSON(data []byte) error { | ||
var stdTime time.Time | ||
if err := json.Unmarshal(data, &stdTime); err != nil { | ||
return err | ||
} | ||
*t = FromTime(stdTime) | ||
return nil | ||
} | ||
|
||
// GetTime returns the underlying time.GetTime value, or nil if it is nil. | ||
func (t *Time) GetTime() *time.Time { | ||
if t == nil { | ||
return nil | ||
} | ||
return pointers.Ptr(t.AsTime()) | ||
} | ||
|
||
// Time casts the Time as a standard time.Time value. | ||
func (t Time) AsTime() time.Time { | ||
return time.Time(t).UTC().Round(time.Microsecond) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ import ( | |
"github.com/sourcegraph/sourcegraph/lib/redislock" | ||
|
||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" | ||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator" | ||
) | ||
|
||
// maybeMigrate runs the auto-migration for the database when needed based on | ||
|
@@ -42,6 +43,12 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr | |
} | ||
span.End() | ||
}() | ||
logger = logger. | ||
WithTrace(log.TraceContext{ | ||
TraceID: span.SpanContext().TraceID().String(), | ||
SpanID: span.SpanContext().SpanID().String(), | ||
}). | ||
With(log.String("database", dbName)) | ||
|
||
sqlDB, err := contract.PostgreSQL.OpenDatabase(ctx, dbName) | ||
if err != nil { | ||
|
@@ -83,17 +90,20 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr | |
span.AddEvent("lock.acquired") | ||
|
||
versionKey := fmt.Sprintf("%s:db_version", dbName) | ||
liveVersion := redisClient.Get(ctx, versionKey).Val() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it seems a bit odd to me that redis would keep the state of the SQL DB :D There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, we should probably store it in DB huh There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Start storing this info in DB somewhat increases the seriousness of the value, that requires us to keep it legit and deal with potential compatibility issues. The current use case of it (and why it's in Redis) is merely a quick optimization about "avoid (if possible) running DB migrations if the version is exactly the same, and it's fine if ran twice, does no harm", it doesn't really care if an upgrade or downgrade has happened. |
||
if shouldSkipMigration( | ||
redisClient.Get(ctx, versionKey).Val(), | ||
liveVersion, | ||
currentVersion, | ||
) { | ||
logger.Info("skipped auto-migration", | ||
log.String("database", dbName), | ||
log.String("currentVersion", currentVersion), | ||
) | ||
span.SetAttributes(attribute.Bool("skipped", true)) | ||
return nil | ||
} | ||
logger.Info("executing auto-migration", | ||
log.String("liveVersion", liveVersion), | ||
log.String("currentVersion", currentVersion)) | ||
span.SetAttributes(attribute.Bool("skipped", false)) | ||
|
||
// Create a session that ignore debug logging. | ||
|
@@ -108,6 +118,11 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr | |
if err != nil { | ||
return errors.Wrapf(err, "auto migrating table for %s", errors.Safe(fmt.Sprintf("%T", table))) | ||
} | ||
if m, ok := table.(custommigrator.CustomTableMigrator); ok { | ||
if err := m.RunCustomMigrations(sess.Migrator()); err != nil { | ||
return errors.Wrapf(err, "running custom migrations for %s", errors.Safe(fmt.Sprintf("%T", table))) | ||
} | ||
} | ||
} | ||
|
||
return redisClient.Set(ctx, versionKey, currentVersion, 0).Err() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handy for dumping DB queries in tests