Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

feat/enterpriseportal: all subscriptions APIs use enterprise portal DB #63959

2 changes: 2 additions & 0 deletions cmd/enterprise-portal/internal/codyaccessservice/v1_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ type StoreV1 interface {

// GetCodyGatewayUsage retrieves recent Cody Gateway usage data.
// The subscriptionID should not be prefixed.
//
// Returns errStoreUnimplemented if the data source not configured.
GetCodyGatewayUsage(ctx context.Context, subscriptionID string) (*codyaccessv1.CodyGatewayUsage, error)

// GetCodyGatewayAccessBySubscription retrieves Cody Gateway access by
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func (i *Importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.S
}
return pointers.Ptr(utctime.FromTime(*dotcomSub.ArchivedAt))
}(),
SalesforceSubscriptionID: activeLicense.SalesforceSubscriptionID,
SalesforceSubscriptionID: database.NewNullStringPtr(activeLicense.SalesforceSubscriptionID),
},
conditions...,
); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ VALUES (
)`, pgx.NamedArgs{
"licenseID": licenseID,
// Convert to string representation of EnterpriseSubscriptionLicenseCondition
"status": subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status_name[int32(opts.Status)],
"status": opts.Status.String(),
"message": pointers.NilIfZero(opts.Message),
"transitionTime": opts.TransitionTime,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ VALUES (
`, pgx.NamedArgs{
"licenseID": licenseID,
"subscriptionID": subscriptionID,
"licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)],
"licenseType": licenseType.String(),
"licenseData": licenseData,
"createdAt": opts.Time,
"expireAt": opts.ExpireTime,
Expand Down Expand Up @@ -422,6 +422,9 @@ WHERE id = @licenseID
"revokedAt": opts.Time,
"licenseID": licenseID,
}); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrSubscriptionLicenseNotFound
}
return nil, errors.Wrap(err, "revoke license")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func (opts ListEnterpriseSubscriptionsOptions) toQueryConditions() (where, limit
if *opts.IsArchived {
whereConds = append(whereConds, "archived_at IS NOT NULL")
} else {
whereConds = append(whereConds, "archived IS NUlL")
whereConds = append(whereConds, "archived_at IS NUlL")
}
}
if len(opts.DisplayNameSubstring) > 0 {
Expand Down Expand Up @@ -221,7 +221,7 @@ type UpsertSubscriptionOptions struct {
CreatedAt utctime.Time
ArchivedAt *utctime.Time

SalesforceSubscriptionID *string
SalesforceSubscriptionID *sql.NullString

// ForceUpdate indicates whether to force update all fields of the subscription
// record.
Expand Down Expand Up @@ -249,9 +249,14 @@ func (opts UpsertSubscriptionOptions) apply(ctx context.Context, db upsert.Exece
return b.Exec(ctx, db)
}

var ErrInvalidArgument = errors.New("invalid argument")

// Upsert upserts a subscription record based on the given options. If the
// operation has additional application meaning, conditions can be provided
// for insert as well.
//
// Constraint errors are returned as a human-friendly error that wraps
// ErrInvalidArgument.
func (s *Store) Upsert(
ctx context.Context,
subscriptionID string,
Expand Down Expand Up @@ -281,7 +286,12 @@ func (s *Store) Upsert(
if err := opts.apply(ctx, tx, subscriptionID); err != nil {
if pgxerrors.IsContraintError(err, "idx_enterprise_portal_subscriptions_display_name") {
return nil, errors.WithSafeDetails(
errors.Newf("display_name %q is already in use", opts.DisplayName.String),
errors.Wrapf(ErrInvalidArgument, "display_name %q is already in use", opts.DisplayName.String),
"%+v", err)
}
if pgxerrors.IsContraintError(err, "idx_enterprise_portal_subscriptions_instance_domain") {
return nil, errors.WithSafeDetails(
errors.Wrapf(ErrInvalidArgument, "instance_domain %q is assigned to another subscription", opts.DisplayName.String),
"%+v", err)
}
return nil, errors.Wrap(err, "upsert")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ VALUES (
)`, pgx.NamedArgs{
"subscriptionID": subscriptionID,
// Convert to string representation of EnterpriseSubscriptionCondition
"status": subscriptionsv1.EnterpriseSubscriptionCondition_Status_name[int32(opts.Status)],
"status": opts.Status.String(),
"message": pointers.NilIfZero(opts.Message),
"transitionTime": opts.TransitionTime,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
subscriptions.UpsertSubscriptionOptions{
DisplayName: database.NewNullString("Subscription 1"),
InstanceDomain: database.NewNullString("s1.sourcegraph.com"),
SalesforceSubscriptionID: pointers.Ptr("sf_sub_id"),
SalesforceSubscriptionID: database.NewNullString("sf_sub_id"),
},
)
require.NoError(t, err)
Expand Down Expand Up @@ -199,6 +199,22 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
assert.Equal(t, s1.ID, ss[0].ID)
})

t.Run("list by not archived", func(t *testing.T) {
t.Parallel()

ss, err := s.List(
ctx,
subscriptions.ListEnterpriseSubscriptionsOptions{
IsArchived: pointers.Ptr(false),
},
)
require.NoError(t, err)
assert.NotEmpty(t, ss)
for _, s := range ss {
assert.Nil(t, s.ArchivedAt)
}
})

t.Run("list with page size", func(t *testing.T) {
t.Parallel()

Expand Down
11 changes: 11 additions & 0 deletions cmd/enterprise-portal/internal/database/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ func NewNullString(v string) *sql.NullString {
}
}

// NewNullString creates an *sql.NullString that indicates "invalid", i.e. null,
// if v is nil or an empty string. It returns a pointer because many use cases
// require a pointer - it is safe to immediately deref the return value if you
// need to, since it always returns a non-nil value.
func NewNullStringPtr(v *string) *sql.NullString {
if v == nil {
return &sql.NullString{}
}
return NewNullString(*v)
}

// NewNullInt32 is like NewNullString, but always produces a valid value.
func NewNullInt32[T int | int32 | int64 | uint64](v T) *sql.NullInt32 {
return &sql.NullInt32{
Expand Down
6 changes: 5 additions & 1 deletion cmd/enterprise-portal/internal/database/utctime/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "utctime",
srcs = ["utctime.go"],
srcs = [
"utctime.go",
"valast.go",
],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/utctime",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//lib/errors",
"//lib/pointers",
"@com_github_hexops_valast//:valast",
],
)
5 changes: 5 additions & 0 deletions cmd/enterprise-portal/internal/database/utctime/utctime.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func Now() Time { return Time(time.Now()) }
// FromTime returns a utctime.Time from a time.Time.
func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) }

// Date is analagous to time.Date, but only represents UTC time.
func Date(year int, month time.Month, day, hour, min, sec, nsec int) Time {
return FromTime(time.Date(year, month, day, hour, min, sec, nsec, time.UTC))
}

var _ sql.Scanner = (*Time)(nil)

func (t *Time) Scan(src any) error {
Expand Down
31 changes: 31 additions & 0 deletions cmd/enterprise-portal/internal/database/utctime/valast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package utctime

import (
"fmt"
"go/ast"
"go/token"

"github.com/hexops/valast"
)

// Register custom representation for autogold.
func init() {
valast.RegisterType(func(ut Time) ast.Expr {
t := ut.AsTime()
return &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: "utctime"},
Sel: &ast.Ident{Name: "Date"},
},
Args: []ast.Expr{
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Year())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Month())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Day())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Hour())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Minute())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Second())},
&ast.BasicLit{Kind: token.INT, Value: fmt.Sprintf("%d", t.Nanosecond())},
},
}
})
}
14 changes: 14 additions & 0 deletions cmd/enterprise-portal/internal/subscriptionsservice/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,26 @@ go_library(
"//cmd/enterprise-portal/internal/connectutil",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/subscriptions",
"//cmd/enterprise-portal/internal/database/utctime",
"//cmd/enterprise-portal/internal/dotcomdb",
"//cmd/enterprise-portal/internal/samsm2m",
"//internal/collections",
"//internal/license",
"//internal/licensing",
"//internal/trace",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/enterpriseportal/subscriptions/v1/v1connect",
"//lib/errors",
"//lib/managedservicesplatform/iam",
"//lib/pointers",
"@com_connectrpc_connect//:connect",
"@com_github_google_uuid//:uuid",
"@com_github_sourcegraph_log//:log",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//clients/v1:clients",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",
"@org_golang_google_protobuf//types/known/timestamppb",
"@org_golang_x_crypto//ssh",
"@org_golang_x_exp//maps",
],
)
Expand All @@ -45,19 +50,28 @@ go_test(
embed = [":subscriptionsservice"],
deps = [
"//cmd/enterprise-portal/internal/database/subscriptions",
"//cmd/enterprise-portal/internal/database/utctime",
"//cmd/enterprise-portal/internal/samsm2m",
"//internal/license",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/errors",
"//lib/managedservicesplatform/iam",
"//lib/pointers",
"@com_connectrpc_connect//:connect",
"@com_github_derision_test_go_mockgen_v2//testutil/require",
"@com_github_google_uuid//:uuid",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_hexops_valast//:valast",
"@com_github_sourcegraph_log//logtest",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//clients/v1:clients",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@org_golang_google_protobuf//encoding/protojson",
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//types/known/fieldmaskpb",
"@org_golang_google_protobuf//types/known/timestamppb",
],
)

Expand Down
78 changes: 75 additions & 3 deletions cmd/enterprise-portal/internal/subscriptionsservice/adapters.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ package subscriptionsservice

import (
"encoding/json"
"fmt"
"strings"

"connectrpc.com/connect"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/utctime"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/internal/licensing"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/managedservicesplatform/iam"
Expand Down Expand Up @@ -34,7 +40,7 @@ func convertLicenseToProto(license *subscriptions.LicenseWithConditions) (*subsc
case subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY.String():
var data subscriptions.DataLicenseKey
if err := json.Unmarshal(license.LicenseData, &data); err != nil {
return proto, errors.Wrap(err, "unmarshal license data")
return proto, errors.Wrapf(err, "unmarshal license data: %q", string(license.LicenseData))
}
proto.License = &subscriptionsv1.EnterpriseSubscriptionLicense_Key{
Key: &subscriptionsv1.EnterpriseSubscriptionLicenseKey{
Expand All @@ -46,9 +52,11 @@ func convertLicenseToProto(license *subscriptions.LicenseWithConditions) (*subsc
SalesforceSubscriptionId: pointers.DerefZero(data.Info.SalesforceSubscriptionID),
SalesforceOpportunityId: pointers.DerefZero(data.Info.SalesforceOpportunityID),
},
LicenseKey: data.SignedKey,
LicenseKey: data.SignedKey,
PlanDisplayName: licensing.ProductNameWithBrand(data.Info.Tags),
},
}

default:
return proto, errors.Newf("unknown license type %q", t)
}
Expand Down Expand Up @@ -105,8 +113,72 @@ func convertProtoToIAMTupleRelation(action subscriptionsv1.PermissionRelation) i
func convertProtoRoleToIAMTupleObject(role subscriptionsv1.Role, subscriptionID string) iam.TupleObject {
switch role {
case subscriptionsv1.Role_ROLE_SUBSCRIPTION_CUSTOMER_ADMIN:
return iam.ToTupleObject(iam.TupleTypeCustomerAdmin, subscriptionID)
return iam.ToTupleObject(iam.TupleTypeCustomerAdmin,
strings.TrimPrefix(subscriptionID, subscriptionsv1.EnterpriseSubscriptionIDPrefix))
default:
return ""
}
}

// convertLicenseKeyToLicenseKeyData converts a create-license request into an
// actual license key for creating a database entry.
//
// It may return Connect errors - all other errors should be considered internal
// errors.
func convertLicenseKeyToLicenseKeyData(
createdAt utctime.Time,
sub *subscriptions.Subscription,
key *subscriptionsv1.EnterpriseSubscriptionLicenseKey,
// StoreV1.GetRequiredEnterpriseSubscriptionLicenseKeyTags
requiredTags []string,
// StoreV1.SignEnterpriseSubscriptionLicenseKey
signKeyFn func(license.Info) (string, error),
) (*subscriptions.DataLicenseKey, error) {
if key.GetInfo().GetUserCount() == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("user_count is invalid"))
}
expires := key.GetInfo().GetExpireTime().AsTime()
if expires.Before(createdAt.AsTime()) {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("expiry must be in the future"))
}
tags := key.GetInfo().GetTags()
providedTagPrefixes := map[string]struct{}{}
for _, t := range tags {
providedTagPrefixes[strings.SplitN(t, ":", 2)[0]] = struct{}{}
}
if _, exists := providedTagPrefixes["customer"]; !exists && sub.DisplayName != nil {
tags = append(tags, fmt.Sprintf("customer:%s", *sub.DisplayName))
}
for _, r := range requiredTags {
if _, ok := providedTagPrefixes[r]; !ok {
return nil, connect.NewError(connect.CodeInvalidArgument,
errors.Newf("key tags [%s] are required", strings.Join(requiredTags, ", ")))
}
}

info := license.Info{
Tags: tags,
UserCount: uint(key.GetInfo().GetUserCount()),
CreatedAt: createdAt.AsTime(),
// Cast expiry to utctime and back for uniform representation
ExpiresAt: utctime.FromTime(expires).AsTime(),
// Provided at creation
SalesforceOpportunityID: pointers.NilIfZero(key.GetInfo().GetSalesforceOpportunityId()),
// Inherited from subscription
SalesforceSubscriptionID: sub.SalesforceSubscriptionID,
}
signedKey, err := signKeyFn(info)
if err != nil {
// See StoreV1.SignEnterpriseSubscriptionLicenseKey
if errors.Is(err, errStoreUnimplemented) {
return nil, connect.NewError(connect.CodeUnimplemented,
errors.Wrap(err, "key signing not available"))
}
return nil, errors.Wrap(err, "sign key")
}

return &subscriptions.DataLicenseKey{
Info: info,
SignedKey: signedKey,
}, nil
}
Loading
Loading