Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Commit

Permalink
lift expire_time
Browse files Browse the repository at this point in the history
  • Loading branch information
bobheadxi committed Jul 12, 2024
1 parent 6a9d351 commit 785cf80
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ var _ driver.Valuer = (*Time)(nil)
// Value must be called with a non-nil Time. driver.Valuer callers will first
// check that the value is non-nil, so this is safe.
func (t Time) Value() (driver.Value, error) {
stdTime := t.Time()
stdTime := t.GetTime()
return *stdTime, nil
}

var _ json.Marshaler = (*Time)(nil)

func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.Time()) }
func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) }

var _ json.Unmarshaler = (*Time)(nil)

Expand All @@ -66,11 +66,15 @@ func (t *Time) UnmarshalJSON(data []byte) error {
return nil
}

// Time returns the underlying time.Time value, or nil if it is nil.
func (t *Time) Time() *time.Time {
// GetTime returns the underlying time.GetTime value, or nil if it is nil.
func (t *Time) GetTime() *time.Time {
if t == nil {
return nil
}
// Ensure the time is in UTC.
return pointers.Ptr((*time.Time)(t).UTC().Round(time.Microsecond))
return pointers.Ptr(t.AsTime())
}

// Time casts the Time as a standard time.Time value.
func (t Time) AsTime() time.Time {
return time.Time(t).UTC().Round(time.Microsecond)
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type createLicenseConditionOpts struct {
}

func (s *licenseConditionsStore) createLicenseCondition(ctx context.Context, licenseID string, opts createLicenseConditionOpts) error {
if opts.TransitionTime.Time().IsZero() {
if opts.TransitionTime.GetTime().IsZero() {
return errors.New("transition time is required")
}
_, err := s.tx.Exec(ctx, `
Expand Down
24 changes: 19 additions & 5 deletions cmd/enterprise-portal/internal/database/subscriptions/licenses.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ type SubscriptionLicense struct {
CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"`
RevokedAt *utctime.Time // Null indicates the license is not revoked.

// ExpireAt is the time at which the license should expire. Expiration does
// NOT get a corresponding condition entry in 'enterprise_portal_subscription_license_conditions'.
ExpireAt utctime.Time `gorm:"not null"`

// LicenseType is the kind of license stored in LicenseData, corresponding
// to the API 'EnterpriseSubscriptionLicenseType'.
LicenseType string `gorm:"not null"`
Expand All @@ -77,6 +81,7 @@ func subscriptionLicenseWithConditionsColumns() []string {

"created_at",
"revoked_at",
"expire_at",

"license_type",
"license_data",
Expand All @@ -98,6 +103,7 @@ func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions,
&l.ID,
&l.CreatedAt,
&l.RevokedAt,
&l.ExpireAt,
&l.LicenseType,
&l.LicenseData,
&l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring
Expand Down Expand Up @@ -207,6 +213,8 @@ type CreateLicenseOpts struct {
Message string
// If nil, the creation time will be set to the current time.
Time *utctime.Time
// Expiration time of the license.
ExpireTime utctime.Time
}

// LicenseKey corresponds to *subscriptionsv1.EnterpriseSubscriptionLicenseKey
Expand All @@ -228,9 +236,12 @@ func (s *LicensesStore) CreateLicenseKey(
// match the time provided in the options.
if opts.Time == nil {
return nil, errors.New("creation time must be specified for licensekeys")
} else if !opts.Time.Time().Equal(license.Info.CreatedAt) {
} else if !opts.Time.GetTime().Equal(utctime.FromTime(license.Info.CreatedAt).AsTime()) {
return nil, errors.New("creation time must match the license key information")
}
if !opts.ExpireTime.GetTime().Equal(utctime.FromTime(license.Info.ExpiresAt).AsTime()) {
return nil, errors.New("expiration time must match the license key information")
}

return s.create(
ctx,
Expand All @@ -253,7 +264,7 @@ func (s *LicensesStore) create(
}
if opts.Time == nil {
opts.Time = pointers.Ptr(utctime.Now())
} else if opts.Time.Time().After(time.Now()) {
} else if opts.Time.GetTime().After(time.Now()) {
return nil, errors.New("creation time cannot be in the future")
}
if licenseType == subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_UNSPECIFIED {
Expand Down Expand Up @@ -284,21 +295,24 @@ INSERT INTO enterprise_portal_subscription_licenses (
subscription_id,
license_type,
license_data,
created_at
created_at,
expire_at
)
VALUES (
@licenseID,
@subscriptionID,
@licenseType,
@licenseData,
@createdAt
@createdAt,
@expireAt
)
`, pgx.NamedArgs{
"licenseID": licenseID.String(),
"subscriptionID": subscriptionID,
"licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)],
"licenseData": licenseData,
"createdAt": opts.Time,
"expireAt": opts.ExpireTime,
}); err != nil {
return nil, errors.Wrap(err, "create license")
}
Expand Down Expand Up @@ -328,7 +342,7 @@ type RevokeLicenseOpts struct {
func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts RevokeLicenseOpts) (*LicenseWithConditions, error) {
if opts.Time == nil {
opts.Time = pointers.Ptr(utctime.Now())
} else if opts.Time.Time().After(time.Now()) {
} else if opts.Time.GetTime().After(time.Now()) {
return nil, errors.New("revocation time cannot be in the future")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestLicensesStore(t *testing.T) {
) {
assert.NotEmpty(t, got.ID)
assert.NotZero(t, got.CreatedAt)
assert.NotZero(t, got.ExpireAt)
assert.Equal(t, "ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY", got.LicenseType)
wantLicenseData.Equal(t, string(got.LicenseData))

Expand All @@ -75,18 +76,20 @@ func TestLicensesStore(t *testing.T) {
Info: license.Info{
Tags: []string{"foo"},
CreatedAt: time.Time{}.Add(1 * time.Hour),
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "asdfasdf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name() + " 1 old",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))),
Message: t.Name() + " 1 old",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))),
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
})
require.NoError(t, err)
testLicense(
got,
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1 old")),
autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`),
autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`),
)
createdLicenses = append(createdLicenses, got)

Expand All @@ -95,18 +98,20 @@ func TestLicensesStore(t *testing.T) {
Info: license.Info{
Tags: []string{"baz"},
CreatedAt: time.Time{}.Add(24 * time.Hour),
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "barasdf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name() + " 1",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
Message: t.Name() + " 1",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
})
require.NoError(t, err)
testLicense(
got,
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`),
)
createdLicenses = append(createdLicenses, got)

Expand All @@ -115,22 +120,24 @@ func TestLicensesStore(t *testing.T) {
Info: license.Info{
Tags: []string{"tag"},
CreatedAt: time.Time{}.Add(24 * time.Hour),
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "asdffdsadf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name() + " 2",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
Message: t.Name() + " 2",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
})
require.NoError(t, err)
testLicense(
got,
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 2")),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-01T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`),
)
createdLicenses = append(createdLicenses, got)

t.Run("timestamps do not match", func(t *testing.T) {
t.Run("createdAt does not match", func(t *testing.T) {
_, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
&subscriptions.LicenseKey{
Info: license.Info{
Expand All @@ -146,6 +153,24 @@ func TestLicensesStore(t *testing.T) {
require.Error(t, err)
autogold.Expect("creation time must match the license key information").Equal(t, err.Error())
})
t.Run("expiresAt does not match", func(t *testing.T) {
_, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
&subscriptions.LicenseKey{
Info: license.Info{
Tags: []string{"tag"},
CreatedAt: time.Time{},
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "asdffdsadf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name(),
Time: pointers.Ptr(utctime.FromTime(time.Time{})),
ExpireTime: utctime.Now(),
})
require.Error(t, err)
autogold.Expect("expiration time must match the license key information").Equal(t, err.Error())
})
})

t.Run("List", func(t *testing.T) {
Expand Down Expand Up @@ -196,11 +221,11 @@ func TestLicensesStore(t *testing.T) {
Time: pointers.Ptr(utctime.FromTime(revokeTime)),
})
require.NoError(t, err)
assert.Equal(t, revokeTime.UTC(), *got.RevokedAt.Time())
assert.Equal(t, revokeTime.UTC(), got.RevokedAt.AsTime())
require.Len(t, got.Conditions, 2)
// Most recent condition is sorted first, and should be the revocation
assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status)
assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.Time())
assert.Equal(t, revokeTime.UTC(), *got.Conditions[0].TransitionTime.GetTime())
assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status)
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
pointers.DerefZero(got.InstanceDomain))
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Time().Round(time.Second))
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.GetTime().Round(time.Second))
})

t.Run("update only archived at", func(t *testing.T) {
Expand All @@ -197,7 +197,7 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName)
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Time().Round(time.Second))
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.GetTime().Round(time.Second))
})

t.Run("force update to zero values", func(t *testing.T) {
Expand Down

0 comments on commit 785cf80

Please sign in to comment.