Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Commit

Permalink
feat/enterpriseportal: all subscriptions APIs use enterprise portal DB (
Browse files Browse the repository at this point in the history
#63959)

This change follows
https://github.com/sourcegraph/sourcegraph/pull/63858 by making the
_all_ subscriptions APIs read and write to the Enterprise Portal
database, instead of dotcomdb, using the data that we sync from dotcomdb
into Enterprise Portal.

With this PR, all initially proposed subscriptions APIs are at least
partially implemented.

Uses hexops/valast#27 for custom `autogold`
rendering of `utctime.Time`

Closes https://linear.app/sourcegraph/issue/CORE-156
Part of https://linear.app/sourcegraph/issue/CORE-158

## Test plan

- [x] Unit tests on API level
- [x] Adapters unit testing
- [x] Simple E2E test:
https://github.com/sourcegraph/sourcegraph/pull/64057
  • Loading branch information
bobheadxi authored Aug 10, 2024
1 parent 8296e98 commit e2c646a
Show file tree
Hide file tree
Showing 29 changed files with 2,920 additions and 316 deletions.
2 changes: 2 additions & 0 deletions cmd/enterprise-portal/internal/codyaccessservice/v1_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type StoreV1 interface {

// GetCodyGatewayUsage retrieves recent Cody Gateway usage data.
// The subscriptionID should not be prefixed.
//
// Returns errStoreUnimplemented if the data source not configured.
GetCodyGatewayUsage(ctx context.Context, subscriptionID string) (*codyaccessv1.CodyGatewayUsage, error)

// GetCodyGatewayAccessBySubscription retrieves Cody Gateway access by
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func (i *Importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.S
}
return pointers.Ptr(utctime.FromTime(*dotcomSub.ArchivedAt))
}(),
SalesforceSubscriptionID: activeLicense.SalesforceSubscriptionID,
SalesforceSubscriptionID: database.NewNullStringPtr(activeLicense.SalesforceSubscriptionID),
},
conditions...,
); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ VALUES (
)`, pgx.NamedArgs{
"licenseID": licenseID,
// Convert to string representation of EnterpriseSubscriptionLicenseCondition
"status": subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status_name[int32(opts.Status)],
"status": opts.Status.String(),
"message": pointers.NilIfZero(opts.Message),
"transitionTime": opts.TransitionTime,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ VALUES (
`, pgx.NamedArgs{
"licenseID": licenseID,
"subscriptionID": subscriptionID,
"licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)],
"licenseType": licenseType.String(),
"licenseData": licenseData,
"createdAt": opts.Time,
"expireAt": opts.ExpireTime,
Expand Down Expand Up @@ -422,6 +422,9 @@ WHERE id = @licenseID
"revokedAt": opts.Time,
"licenseID": licenseID,
}); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrSubscriptionLicenseNotFound
}
return nil, errors.Wrap(err, "revoke license")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (opts ListEnterpriseSubscriptionsOptions) toQueryConditions() (where, limit
if *opts.IsArchived {
whereConds = append(whereConds, "archived_at IS NOT NULL")
} else {
whereConds = append(whereConds, "archived IS NUlL")
whereConds = append(whereConds, "archived_at IS NUlL")
}
}
if len(opts.DisplayNameSubstring) > 0 {
Expand Down Expand Up @@ -221,7 +221,7 @@ type UpsertSubscriptionOptions struct {
CreatedAt utctime.Time
ArchivedAt *utctime.Time

SalesforceSubscriptionID *string
SalesforceSubscriptionID *sql.NullString

// ForceUpdate indicates whether to force update all fields of the subscription
// record.
Expand Down Expand Up @@ -249,9 +249,14 @@ func (opts UpsertSubscriptionOptions) apply(ctx context.Context, db upsert.Exece
return b.Exec(ctx, db)
}

var ErrInvalidArgument = errors.New("invalid argument")

// Upsert upserts a subscription record based on the given options. If the
// operation has additional application meaning, conditions can be provided
// for insert as well.
//
// Constraint errors are returned as a human-friendly error that wraps
// ErrInvalidArgument.
func (s *Store) Upsert(
ctx context.Context,
subscriptionID string,
Expand Down Expand Up @@ -281,7 +286,12 @@ func (s *Store) Upsert(
if err := opts.apply(ctx, tx, subscriptionID); err != nil {
if pgxerrors.IsContraintError(err, "idx_enterprise_portal_subscriptions_display_name") {
return nil, errors.WithSafeDetails(
errors.Newf("display_name %q is already in use", opts.DisplayName.String),
errors.Wrapf(ErrInvalidArgument, "display_name %q is already in use", opts.DisplayName.String),
"%+v", err)
}
if pgxerrors.IsContraintError(err, "idx_enterprise_portal_subscriptions_instance_domain") {
return nil, errors.WithSafeDetails(
errors.Wrapf(ErrInvalidArgument, "instance_domain %q is assigned to another subscription", opts.DisplayName.String),
"%+v", err)
}
return nil, errors.Wrap(err, "upsert")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ VALUES (
)`, pgx.NamedArgs{
"subscriptionID": subscriptionID,
// Convert to string representation of EnterpriseSubscriptionCondition
"status": subscriptionsv1.EnterpriseSubscriptionCondition_Status_name[int32(opts.Status)],
"status": opts.Status.String(),
"message": pointers.NilIfZero(opts.Message),
"transitionTime": opts.TransitionTime,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
subscriptions.UpsertSubscriptionOptions{
DisplayName: database.NewNullString("Subscription 1"),
InstanceDomain: database.NewNullString("s1.sourcegraph.com"),
SalesforceSubscriptionID: pointers.Ptr("sf_sub_id"),
SalesforceSubscriptionID: database.NewNullString("sf_sub_id"),
},
)
require.NoError(t, err)
Expand Down Expand Up @@ -199,6 +199,22 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
assert.Equal(t, s1.ID, ss[0].ID)
})

t.Run("list by not archived", func(t *testing.T) {
t.Parallel()

ss, err := s.List(
ctx,
subscriptions.ListEnterpriseSubscriptionsOptions{
IsArchived: pointers.Ptr(false),
},
)
require.NoError(t, err)
assert.NotEmpty(t, ss)
for _, s := range ss {
assert.Nil(t, s.ArchivedAt)
}
})

t.Run("list with page size", func(t *testing.T) {
t.Parallel()

Expand Down
11 changes: 11 additions & 0 deletions cmd/enterprise-portal/internal/database/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ func NewNullString(v string) *sql.NullString {
}
}

// NewNullString creates an *sql.NullString that indicates "invalid", i.e. null,
// if v is nil or an empty string. It returns a pointer because many use cases
// require a pointer - it is safe to immediately deref the return value if you
// need to, since it always returns a non-nil value.
func NewNullStringPtr(v *string) *sql.NullString {
if v == nil {
return &sql.NullString{}
}
return NewNullString(*v)
}

// NewNullInt32 is like NewNullString, but always produces a valid value.
func NewNullInt32[T int | int32 | int64 | uint64](v T) *sql.NullInt32 {
return &sql.NullInt32{
Expand Down
6 changes: 5 additions & 1 deletion cmd/enterprise-portal/internal/database/utctime/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "utctime",
srcs = ["utctime.go"],
srcs = [
"utctime.go",
"valast.go",
],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/utctime",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//lib/errors",
"//lib/pointers",
"@com_github_hexops_valast//:valast",
],
)
5 changes: 5 additions & 0 deletions cmd/enterprise-portal/internal/database/utctime/utctime.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func Now() Time { return Time(time.Now()) }
// FromTime returns a utctime.Time from a time.Time.
func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) }

// Date is analagous to time.Date, but only represents UTC time.
func Date(year int, month time.Month, day, hour, min, sec, nsec int) Time {
return FromTime(time.Date(year, month, day, hour, min, sec, nsec, time.UTC))
}

var _ sql.Scanner = (*Time)(nil)

func (t *Time) Scan(src any) error {
Expand Down
31 changes: 31 additions & 0 deletions cmd/enterprise-portal/internal/database/utctime/valast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package utctime

import (
"fmt"
"go/ast"
"go/token"

"github.com/hexops/valast"
)

// Register custom representation for autogold.
func init() {
valast.RegisterType(func(ut Time) ast.Expr {
t := ut.AsTime()
return &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: "utctime"},
Sel: &ast.Ident{Name: "Date"},
},
Args: []ast.Expr{
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Year())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Month())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Day())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Hour())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Minute())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Second())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Nanosecond())},
},
}
})
}
14 changes: 14 additions & 0 deletions cmd/enterprise-portal/internal/subscriptionsservice/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,26 @@ go_library(
"//cmd/enterprise-portal/internal/connectutil",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/subscriptions",
"//cmd/enterprise-portal/internal/database/utctime",
"//cmd/enterprise-portal/internal/dotcomdb",
"//cmd/enterprise-portal/internal/samsm2m",
"//internal/collections",
"//internal/license",
"//internal/licensing",
"//internal/trace",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/enterpriseportal/subscriptions/v1/v1connect",
"//lib/errors",
"//lib/managedservicesplatform/iam",
"//lib/pointers",
"@com_connectrpc_connect//:connect",
"@com_github_google_uuid//:uuid",
"@com_github_sourcegraph_log//:log",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//clients/v1:clients",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_x_crypto//ssh",
"@org_golang_x_exp//maps",
],
)
Expand All @@ -45,19 +50,28 @@ go_test(
embed = [":subscriptionsservice"],
deps = [
"//cmd/enterprise-portal/internal/database/subscriptions",
"//cmd/enterprise-portal/internal/database/utctime",
"//cmd/enterprise-portal/internal/samsm2m",
"//internal/license",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/errors",
"//lib/managedservicesplatform/iam",
"//lib/pointers",
"@com_connectrpc_connect//:connect",
"@com_github_derision_test_go_mockgen_v2//testutil/require",
"@com_github_google_uuid//:uuid",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_hexops_valast//:valast",
"@com_github_sourcegraph_log//logtest",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//clients/v1:clients",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//types/known/fieldmaskpb",
"@org_golang_google_protobuf//types/known/timestamppb",
],
)

Expand Down
78 changes: 75 additions & 3 deletions cmd/enterprise-portal/internal/subscriptionsservice/adapters.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ package subscriptionsservice

import (
"encoding/json"
"fmt"
"strings"

"connectrpc.com/connect"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/utctime"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/internal/licensing"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/managedservicesplatform/iam"
Expand Down Expand Up @@ -34,7 +40,7 @@ func convertLicenseToProto(license *subscriptions.LicenseWithConditions) (*subsc
case subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY.String():
var data subscriptions.DataLicenseKey
if err := json.Unmarshal(license.LicenseData, &data); err != nil {
return proto, errors.Wrap(err, "unmarshal license data")
return proto, errors.Wrapf(err, "unmarshal license data: %q", string(license.LicenseData))
}
proto.License = &subscriptionsv1.EnterpriseSubscriptionLicense_Key{
Key: &subscriptionsv1.EnterpriseSubscriptionLicenseKey{
Expand All @@ -46,9 +52,11 @@ func convertLicenseToProto(license *subscriptions.LicenseWithConditions) (*subsc
SalesforceSubscriptionId: pointers.DerefZero(data.Info.SalesforceSubscriptionID),
SalesforceOpportunityId: pointers.DerefZero(data.Info.SalesforceOpportunityID),
},
LicenseKey: data.SignedKey,
LicenseKey: data.SignedKey,
PlanDisplayName: licensing.ProductNameWithBrand(data.Info.Tags),
},
}

default:
return proto, errors.Newf("unknown license type %q", t)
}
Expand Down Expand Up @@ -105,8 +113,72 @@ func convertProtoToIAMTupleRelation(action subscriptionsv1.PermissionRelation) i
func convertProtoRoleToIAMTupleObject(role subscriptionsv1.Role, subscriptionID string) iam.TupleObject {
switch role {
case subscriptionsv1.Role_ROLE_SUBSCRIPTION_CUSTOMER_ADMIN:
return iam.ToTupleObject(iam.TupleTypeCustomerAdmin, subscriptionID)
return iam.ToTupleObject(iam.TupleTypeCustomerAdmin,
strings.TrimPrefix(subscriptionID, subscriptionsv1.EnterpriseSubscriptionIDPrefix))
default:
return ""
}
}

// convertLicenseKeyToLicenseKeyData converts a create-license request into an
// actual license key for creating a database entry.
//
// It may return Connect errors - all other errors should be considered internal
// errors.
func convertLicenseKeyToLicenseKeyData(
createdAt utctime.Time,
sub *subscriptions.Subscription,
key *subscriptionsv1.EnterpriseSubscriptionLicenseKey,
// StoreV1.GetRequiredEnterpriseSubscriptionLicenseKeyTags
requiredTags []string,
// StoreV1.SignEnterpriseSubscriptionLicenseKey
signKeyFn func(license.Info) (string, error),
) (*subscriptions.DataLicenseKey, error) {
if key.GetInfo().GetUserCount() == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("user_count is invalid"))
}
expires := key.GetInfo().GetExpireTime().AsTime()
if expires.Before(createdAt.AsTime()) {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("expiry must be in the future"))
}
tags := key.GetInfo().GetTags()
providedTagPrefixes := map[string]struct{}{}
for _, t := range tags {
providedTagPrefixes[strings.SplitN(t, ":", 2)[0]] = struct{}{}
}
if _, exists := providedTagPrefixes["customer"]; !exists && sub.DisplayName != nil {
tags = append(tags, fmt.Sprintf("customer:%s", *sub.DisplayName))
}
for _, r := range requiredTags {
if _, ok := providedTagPrefixes[r]; !ok {
return nil, connect.NewError(connect.CodeInvalidArgument,
errors.Newf("key tags [%s] are required", strings.Join(requiredTags, ", ")))
}
}

info := license.Info{
Tags: tags,
UserCount: uint(key.GetInfo().GetUserCount()),
CreatedAt: createdAt.AsTime(),
// Cast expiry to utctime and back for uniform representation
ExpiresAt: utctime.FromTime(expires).AsTime(),
// Provided at creation
SalesforceOpportunityID: pointers.NilIfZero(key.GetInfo().GetSalesforceOpportunityId()),
// Inherited from subscription
SalesforceSubscriptionID: sub.SalesforceSubscriptionID,
}
signedKey, err := signKeyFn(info)
if err != nil {
// See StoreV1.SignEnterpriseSubscriptionLicenseKey
if errors.Is(err, errStoreUnimplemented) {
return nil, connect.NewError(connect.CodeUnimplemented,
errors.Wrap(err, "key signing not available"))
}
return nil, errors.Wrap(err, "sign key")
}

return &subscriptions.DataLicenseKey{
Info: info,
SignedKey: signedKey,
}, nil
}
Loading

0 comments on commit e2c646a

Please sign in to comment.