From 785cf8061d296820136bd995087bd6be073367ac Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Fri, 12 Jul 2024 13:38:47 -0700 Subject: [PATCH] lift expire_time --- .../database/internal/utctime/utctime.go | 16 +++--- .../subscriptions/license_conditions.go | 2 +- .../database/subscriptions/licenses.go | 24 +++++++-- .../database/subscriptions/licenses_test.go | 49 ++++++++++++++----- .../subscriptions/subscriptions_test.go | 4 +- 5 files changed, 69 insertions(+), 26 deletions(-) diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go index 25f5ed412650d..0e90287687ea5 100644 --- a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go +++ b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go @@ -47,13 +47,13 @@ var _ driver.Valuer = (*Time)(nil) // Value must be called with a non-nil Time. driver.Valuer callers will first // check that the value is non-nil, so this is safe. func (t Time) Value() (driver.Value, error) { - stdTime := t.Time() + stdTime := t.GetTime() return *stdTime, nil } var _ json.Marshaler = (*Time)(nil) -func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.Time()) } +func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) } var _ json.Unmarshaler = (*Time)(nil) @@ -66,11 +66,15 @@ func (t *Time) UnmarshalJSON(data []byte) error { return nil } -// Time returns the underlying time.Time value, or nil if it is nil. -func (t *Time) Time() *time.Time { +// GetTime returns the underlying time.GetTime value, or nil if it is nil. +func (t *Time) GetTime() *time.Time { if t == nil { return nil } - // Ensure the time is in UTC. - return pointers.Ptr((*time.Time)(t).UTC().Round(time.Microsecond)) + return pointers.Ptr(t.AsTime()) +} + +// Time casts the Time as a standard time.Time value. +func (t Time) AsTime() time.Time { + return time.Time(t).UTC().Round(time.Microsecond) } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go index b047dfd0dadb5..ac19bceeca04d 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go @@ -67,7 +67,7 @@ type createLicenseConditionOpts struct { } func (s *licenseConditionsStore) createLicenseCondition(ctx context.Context, licenseID string, opts createLicenseConditionOpts) error { - if opts.TransitionTime.Time().IsZero() { + if opts.TransitionTime.GetTime().IsZero() { return errors.New("transition time is required") } _, err := s.tx.Exec(ctx, ` diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go index c96865900faed..11289b20135f5 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go @@ -56,6 +56,10 @@ type SubscriptionLicense struct { CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"` RevokedAt *utctime.Time // Null indicates the license is not revoked. + // ExpireAt is the time at which the license should expire. Expiration does + // NOT get a corresponding condition entry in 'enterprise_portal_subscription_license_conditions'. + ExpireAt utctime.Time `gorm:"not null"` + // LicenseType is the kind of license stored in LicenseData, corresponding // to the API 'EnterpriseSubscriptionLicenseType'. LicenseType string `gorm:"not null"` @@ -77,6 +81,7 @@ func subscriptionLicenseWithConditionsColumns() []string { "created_at", "revoked_at", + "expire_at", "license_type", "license_data", @@ -98,6 +103,7 @@ func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions, &l.ID, &l.CreatedAt, &l.RevokedAt, + &l.ExpireAt, &l.LicenseType, &l.LicenseData, &l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring @@ -207,6 +213,8 @@ type CreateLicenseOpts struct { Message string // If nil, the creation time will be set to the current time. Time *utctime.Time + // Expiration time of the license. + ExpireTime utctime.Time } // LicenseKey corresponds to *subscriptionsv1.EnterpriseSubscriptionLicenseKey @@ -228,9 +236,12 @@ func (s *LicensesStore) CreateLicenseKey( // match the time provided in the options. if opts.Time == nil { return nil, errors.New("creation time must be specified for licensekeys") - } else if !opts.Time.Time().Equal(license.Info.CreatedAt) { + } else if !opts.Time.GetTime().Equal(utctime.FromTime(license.Info.CreatedAt).AsTime()) { return nil, errors.New("creation time must match the license key information") } + if !opts.ExpireTime.GetTime().Equal(utctime.FromTime(license.Info.ExpiresAt).AsTime()) { + return nil, errors.New("expiration time must match the license key information") + } return s.create( ctx, @@ -253,7 +264,7 @@ func (s *LicensesStore) create( } if opts.Time == nil { opts.Time = pointers.Ptr(utctime.Now()) - } else if opts.Time.Time().After(time.Now()) { + } else if opts.Time.GetTime().After(time.Now()) { return nil, errors.New("creation time cannot be in the future") } if licenseType == subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_UNSPECIFIED { @@ -284,14 +295,16 @@ INSERT INTO enterprise_portal_subscription_licenses ( subscription_id, license_type, license_data, - created_at + created_at, + expire_at ) VALUES ( @licenseID, @subscriptionID, @licenseType, @licenseData, - @createdAt + @createdAt, + @expireAt ) `, pgx.NamedArgs{ "licenseID": licenseID.String(), @@ -299,6 +312,7 @@ VALUES ( "licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)], "licenseData": licenseData, "createdAt": opts.Time, + "expireAt": opts.ExpireTime, }); err != nil { return nil, errors.Wrap(err, "create license") } @@ -328,7 +342,7 @@ type RevokeLicenseOpts struct { func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts RevokeLicenseOpts) (*LicenseWithConditions, error) { if opts.Time == nil { opts.Time = pointers.Ptr(utctime.Now()) - } else if opts.Time.Time().After(time.Now()) { + } else if opts.Time.GetTime().After(time.Now()) { return nil, errors.New("revocation time cannot be in the future") } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go index 80e2ed87fc1fe..1cbd5fdd6700e 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -61,6 +61,7 @@ func TestLicensesStore(t *testing.T) { ) { assert.NotEmpty(t, got.ID) assert.NotZero(t, got.CreatedAt) + assert.NotZero(t, got.ExpireAt) assert.Equal(t, "ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY", got.LicenseType) wantLicenseData.Equal(t, string(got.LicenseData)) @@ -75,18 +76,20 @@ func TestLicensesStore(t *testing.T) { Info: license.Info{ Tags: []string{"foo"}, CreatedAt: time.Time{}.Add(1 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), }, SignedKey: "asdfasdf", }, subscriptions.CreateLicenseOpts{ - Message: t.Name() + " 1 old", - Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))), + Message: t.Name() + " 1 old", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), }) require.NoError(t, err) testLicense( got, autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1 old")), - autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`), + autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`), ) createdLicenses = append(createdLicenses, got) @@ -95,18 +98,20 @@ func TestLicensesStore(t *testing.T) { Info: license.Info{ Tags: []string{"baz"}, CreatedAt: time.Time{}.Add(24 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), }, SignedKey: "barasdf", }, subscriptions.CreateLicenseOpts{ - Message: t.Name() + " 1", - Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + Message: t.Name() + " 1", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), }) require.NoError(t, err) testLicense( got, autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")), - autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`), ) createdLicenses = append(createdLicenses, got) @@ -115,22 +120,24 @@ func TestLicensesStore(t *testing.T) { Info: license.Info{ Tags: []string{"tag"}, CreatedAt: time.Time{}.Add(24 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), }, SignedKey: "asdffdsadf", }, subscriptions.CreateLicenseOpts{ - Message: t.Name() + " 2", - Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + Message: t.Name() + " 2", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), }) require.NoError(t, err) testLicense( got, autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 2")), - autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`), ) createdLicenses = append(createdLicenses, got) - t.Run("timestamps do not match", func(t *testing.T) { + t.Run("createdAt does not match", func(t *testing.T) { _, err = licenses.CreateLicenseKey(ctx, subscriptionID2, &subscriptions.LicenseKey{ Info: license.Info{ @@ -146,6 +153,24 @@ func TestLicensesStore(t *testing.T) { require.Error(t, err) autogold.Expect("creation time must match the license key information").Equal(t, err.Error()) }) + t.Run("expiresAt does not match", func(t *testing.T) { + _, err = licenses.CreateLicenseKey(ctx, subscriptionID2, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"tag"}, + CreatedAt: time.Time{}, + ExpiresAt: time.Time{}.Add(48 * time.Hour), + }, + SignedKey: "asdffdsadf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name(), + Time: pointers.Ptr(utctime.FromTime(time.Time{})), + ExpireTime: utctime.Now(), + }) + require.Error(t, err) + autogold.Expect("expiration time must match the license key information").Equal(t, err.Error()) + }) }) t.Run("List", func(t *testing.T) { @@ -196,11 +221,11 @@ func TestLicensesStore(t *testing.T) { Time: pointers.Ptr(utctime.FromTime(revokeTime)), }) require.NoError(t, err) - assert.Equal(t, revokeTime.UTC(), *got.RevokedAt.Time()) + assert.Equal(t, revokeTime.UTC(), got.RevokedAt.AsTime()) require.Len(t, got.Conditions, 2) // Most recent condition is sorted first, and should be the revocation assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status) - assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.Time()) + assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.GetTime()) assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status) } }) diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go index 567ba850ecdfd..2cbef7b07dfd0 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go @@ -182,7 +182,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Time().Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.GetTime().Round(time.Second)) }) t.Run("update only archived at", func(t *testing.T) { @@ -197,7 +197,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName) assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Time().Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.GetTime().Round(time.Second)) }) t.Run("force update to zero values", func(t *testing.T) {