Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
bobheadxi committed Jul 23, 2024
1 parent 9afee6e commit c6041c4
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 26 deletions.
36 changes: 21 additions & 15 deletions cmd/enterprise-portal/internal/subscriptionsservice/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ func RegisterV1(
&handlerV1{
logger: logger.Scoped("subscriptions.v1"),
store: store,
subscriptionIDGenerator: func() (string, error) {
id, err := uuid.NewRandom()
if err != nil {
return "", errors.Wrap(err, "uuid")
}
return id.String(), nil
},
},
opts...,
),
Expand All @@ -51,6 +58,8 @@ func RegisterV1(
type handlerV1 struct {
logger log.Logger
store StoreV1

subscriptionIDGenerator func() (string, error)
}

var _ subscriptionsv1connect.SubscriptionsServiceHandler = (*handlerV1)(nil)
Expand Down Expand Up @@ -327,31 +336,28 @@ func (s *handlerV1) CreateEnterpriseSubscription(ctx context.Context, req *conne
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("subscription details are required"))
}

// Validate required arguments.
if strings.TrimSpace(sub.GetDisplayName()) == "" {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("display_name is required"))
}

// Generate a new ID for the subscription.
if sub.Id == "" {
subscriptionID, err := uuid.NewRandom()
if err != nil {
return nil, connectutil.InternalError(ctx, s.logger, err, "failed to generate new subscription ID")
}
sub.Id = subscriptionID.String()
} else {
_, err := uuid.Parse(strings.TrimPrefix(sub.Id, subscriptionsv1.EnterpriseSubscriptionIDPrefix))
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "custom subscription.id must be a UUID"))
}
if sub.Id != "" {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("subscription_id can not be set"))
}
sub.Id, err = s.subscriptionIDGenerator()
if err != nil {
return nil, connectutil.InternalError(ctx, s.logger, err, "failed to generate new subscription ID")
}

// Check for an existing subscription, just in case.
if _, err := s.store.GetEnterpriseSubscription(ctx, sub.Id); err == nil {
return nil, connect.NewError(connect.CodeAlreadyExists, err)
} else if !errors.Is(err, subscriptions.ErrSubscriptionNotFound) {

Check notice

Code scanning / Semgrep OSS

Semgrep Finding: security-semgrep-rules.semgrep-rules.generic.comment-tagging-rule Note

Code that highlight SECURITY in comment has changed. Please review the code for changes. The changes might be sensitive.
return nil, connectutil.InternalError(ctx, logger, err,
"failed to check for existing subscription")
}

if strings.TrimSpace(sub.GetDisplayName()) == "" {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("display_name is required required"))
}

createdAt := utctime.Now()
createdSub, err := s.store.UpsertEnterpriseSubscription(ctx, sub.Id,
subscriptions.UpsertSubscriptionOptions{
Expand Down
182 changes: 171 additions & 11 deletions cmd/enterprise-portal/internal/subscriptionsservice/v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"slices"
"sync/atomic"
"testing"

"connectrpc.com/connect"
Expand All @@ -19,6 +20,7 @@ import (
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/utctime"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/managedservicesplatform/iam"
Expand All @@ -29,23 +31,26 @@ type testHandlerV1 struct {
mockStore *MockStoreV1
}

func newTestHandlerV1() *testHandlerV1 {
func newTestHandlerV1(tokenScopes ...scopes.Scope) *testHandlerV1 {
mockStore := NewMockStoreV1()
mockStore.IntrospectSAMSTokenFunc.SetDefaultReturn(
&sams.IntrospectTokenResponse{
Active: true,
Scopes: scopes.Scopes{
samsm2m.EnterprisePortalScope("subscription", scopes.ActionRead),
samsm2m.EnterprisePortalScope("subscription", scopes.ActionWrite),
samsm2m.EnterprisePortalScope("permission.subscription", scopes.ActionWrite),
},
Scopes: tokenScopes,
},
nil,
)

var uuidSeq atomic.Int32
predictableUUID := func() (string, error) {
return fmt.Sprintf("uuid-%d", uuidSeq.Add(1)), nil
}

return &testHandlerV1{
handlerV1: &handlerV1{
logger: logtest.NoOp(nil),
store: mockStore,
logger: logtest.NoOp(nil),
store: mockStore,
subscriptionIDGenerator: predictableUUID,
},
mockStore: mockStore,
}
Expand Down Expand Up @@ -169,7 +174,12 @@ func TestHandlerV1_ListEnterpriseSubscriptions(t *testing.T) {
req := connect.NewRequest(tc.list)
req.Header().Add("Authorization", "Bearer foolmeifyoucan")

h := newTestHandlerV1()
h := newTestHandlerV1(
samsm2m.EnterprisePortalScope(
scopes.PermissionEnterprisePortalSubscription,
scopes.ActionRead,
),
)
h.mockStore.IAMListObjectsFunc.SetDefaultHook(func(_ context.Context, opts iam.ListObjectsOptions) ([]string, error) {
return tc.iamObjectsHook(opts)
})
Expand Down Expand Up @@ -198,6 +208,142 @@ func TestHandlerV1_ListEnterpriseSubscriptions(t *testing.T) {
}
}

func TestHandlerV1_CreateEnterpriseSubscription(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
name string
tokenScopes scopes.Scopes
create *subscriptionsv1.CreateEnterpriseSubscriptionRequest
wantError autogold.Value
wantUpsertOpts autogold.Value
}{
{
name: "no parameters",
create: &subscriptionsv1.CreateEnterpriseSubscriptionRequest{
Subscription: &subscriptionsv1.EnterpriseSubscription{},
},
wantError: autogold.Expect("invalid_argument: display_name is required"),
},
{
name: "custom subscription ID",
create: &subscriptionsv1.CreateEnterpriseSubscriptionRequest{
Subscription: &subscriptionsv1.EnterpriseSubscription{
Id: "not-allowed",
DisplayName: t.Name(),
},
},
wantError: autogold.Expect("invalid_argument: subscription_id can not be set"),
},
{
name: "insufficient scopes",
tokenScopes: scopes.Scopes{
samsm2m.EnterprisePortalScope(
scopes.PermissionEnterprisePortalSubscription,
scopes.ActionRead,
),
},
create: &subscriptionsv1.CreateEnterpriseSubscriptionRequest{
Subscription: &subscriptionsv1.EnterpriseSubscription{
Id: "not-allowed",
DisplayName: t.Name(),
},
},
wantError: autogold.Expect("permission_denied: insufficient scope"),
},
{
name: "with required params only",
create: &subscriptionsv1.CreateEnterpriseSubscriptionRequest{
Subscription: &subscriptionsv1.EnterpriseSubscription{
DisplayName: t.Name(),
},
},
wantUpsertOpts: autogold.Expect(subscriptions.UpsertSubscriptionOptions{
InstanceDomain: &sql.NullString{},
DisplayName: &sql.NullString{
String: "TestHandlerV1_CreateEnterpriseSubscription",
Valid: true,
},
SalesforceSubscriptionID: &sql.NullString{},
SalesforceOpportunityID: &sql.NullString{},
}),
},
{
name: "with message and optional fields",
create: &subscriptionsv1.CreateEnterpriseSubscriptionRequest{
Subscription: &subscriptionsv1.EnterpriseSubscription{
DisplayName: t.Name(),
Salesforce: &subscriptionsv1.EnterpriseSubscription_SalesforceMetadata{
SubscriptionId: "sf_sub",
OpportunityId: "sf_opp",
},
},
Message: "hello world",
},
wantUpsertOpts: autogold.Expect(subscriptions.UpsertSubscriptionOptions{
InstanceDomain: &sql.NullString{},
DisplayName: &sql.NullString{
String: "TestHandlerV1_CreateEnterpriseSubscription",
Valid: true,
},
SalesforceSubscriptionID: &sql.NullString{
String: "sf_sub",
Valid: true,
},
SalesforceOpportunityID: &sql.NullString{
String: "sf_opp",
Valid: true,
},
}),
},
} {
req := connect.NewRequest(tc.create)
req.Header().Add("Authorization", "Bearer foolmeifyoucan")

if tc.tokenScopes == nil {
tc.tokenScopes = scopes.Scopes{
samsm2m.EnterprisePortalScope(
scopes.PermissionEnterprisePortalSubscription,
scopes.ActionWrite,
),
}
}
h := newTestHandlerV1(tc.tokenScopes...)
h.mockStore.GetEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, id string) (*subscriptions.SubscriptionWithConditions, error) {
return nil, subscriptions.ErrSubscriptionNotFound
})
h.mockStore.UpsertEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, _ string, opts subscriptions.UpsertSubscriptionOptions, conds ...subscriptions.CreateSubscriptionConditionOptions) (*subscriptions.SubscriptionWithConditions, error) {
require.Len(t, conds, 1) // create must have condition

// Condition must match upsert
assert.Equal(t, tc.create.GetMessage(), conds[0].Message)
assert.Equal(t, opts.CreatedAt, conds[0].TransitionTime)
assert.Equal(t, subscriptionsv1.EnterpriseSubscriptionCondition_STATUS_CREATED,
conds[0].Status)

// Set to zero time for convenience with autogold
assert.NotZero(t, opts.CreatedAt)
opts.CreatedAt = utctime.Time{}
tc.wantUpsertOpts.Equal(t, opts)

return &subscriptions.SubscriptionWithConditions{}, nil
})
_, err := h.CreateEnterpriseSubscription(ctx, req)
if tc.wantError != nil {
require.Error(t, err)
tc.wantError.Equal(t, err.Error())
} else {
require.NoError(t, err)
}
if tc.wantUpsertOpts != nil {
mockrequire.CalledOnce(t, h.mockStore.UpsertEnterpriseSubscriptionFunc)
mockrequire.CalledOnce(t, h.mockStore.GetEnterpriseSubscriptionFunc)
} else {
mockrequire.NotCalled(t, h.mockStore.UpsertEnterpriseSubscriptionFunc)
}
}
}

func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
ctx := context.Background()
const mockSubscriptionID = "es_80ca12e2-54b4-448c-a61a-390b1a9c1224"
Expand Down Expand Up @@ -281,7 +427,12 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
req := connect.NewRequest(tc.update)
req.Header().Add("Authorization", "Bearer foolmeifyoucan")

h := newTestHandlerV1()
h := newTestHandlerV1(
samsm2m.EnterprisePortalScope(
scopes.PermissionEnterprisePortalSubscription,
scopes.ActionWrite,
),
)
h.mockStore.ListEnterpriseSubscriptionsFunc.SetDefaultReturn(
[]*subscriptions.SubscriptionWithConditions{
{Subscription: subscriptions.Subscription{
Expand All @@ -304,6 +455,10 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
}
}

func TestHandlerV1_ArchiveEnterpriseSubscription(t *testing.T) {

}

func TestHandlerV1_UpdateEnterpriseSubscriptionMembership(t *testing.T) {
const (
subscriptionID = "80ca12e2-54b4-448c-a61a-390b1a9c1224"
Expand Down Expand Up @@ -426,7 +581,12 @@ func TestHandlerV1_UpdateEnterpriseSubscriptionMembership(t *testing.T) {
req := connect.NewRequest(tc.req)
req.Header().Add("Authorization", "Bearer foolmeifyoucan")

h := newTestHandlerV1()
h := newTestHandlerV1(
samsm2m.EnterprisePortalScope(
scopes.PermissionEnterprisePortalSubscriptionPermission,
scopes.ActionWrite,
),
)
h.mockStore.ListEnterpriseSubscriptionsFunc.SetDefaultHook(
func(_ context.Context, opts subscriptions.ListEnterpriseSubscriptionsOptions) ([]*subscriptions.SubscriptionWithConditions, error) {
if slices.Contains(opts.IDs, subscriptionID) ||
Expand Down

0 comments on commit c6041c4

Please sign in to comment.