Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Commit

Permalink
implement license filters
Browse files Browse the repository at this point in the history
  • Loading branch information
bobheadxi committed Aug 8, 2024
1 parent f8c9093 commit 289be88
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ func NewLicensesStore(db *pgxpool.Pool) *LicensesStore {
}

type ListLicensesOpts struct {
SubscriptionID string
SubscriptionID string
LicenseType subscriptionsv1.EnterpriseSubscriptionLicenseType
LicenseKeySubstring string
// PageSize is the maximum number of licenses to return.
PageSize int
}
Expand All @@ -144,6 +146,16 @@ func (opts ListLicensesOpts) toQueryConditions() (where, limitClause string, _ p
whereConds = append(whereConds, "subscription_id = @subscriptionID")
namedArgs["subscriptionID"] = opts.SubscriptionID
}
if opts.LicenseType > 0 {
whereConds = append(whereConds,
"license_type = @licenseType")
namedArgs["licenseType"] = opts.LicenseType.String()
}
if opts.LicenseKeySubstring != "" {
whereConds = append(whereConds,
"license_data->>'SignedKey' LIKE '%' || @licenseKeySubstring || '%'")
namedArgs["licenseKeySubstring"] = opts.LicenseKeySubstring
}
where = strings.Join(whereConds, " AND ")

if opts.PageSize > 0 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/utctime"
"github.com/sourcegraph/sourcegraph/internal/license"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)

Expand All @@ -42,6 +43,8 @@ func TestLicensesStore(t *testing.T) {

licenses := subscriptions.NewLicensesStore(db)

const signedKeyExample = "<signed-key-example>"

var createdLicenses []*subscriptions.LicenseWithConditions
getCreatedByLicenseID := func(t *testing.T, licenseID string) *subscriptions.LicenseWithConditions {
for _, l := range createdLicenses {
Expand Down Expand Up @@ -100,7 +103,7 @@ func TestLicensesStore(t *testing.T) {
CreatedAt: time.Time{}.Add(24 * time.Hour),
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "barasdf",
SignedKey: signedKeyExample,
},
subscriptions.CreateLicenseOpts{
Message: t.Name() + " 1",
Expand All @@ -111,7 +114,7 @@ func TestLicensesStore(t *testing.T) {
testLicense(
got,
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "<signed-key-example>"}`),
)
createdLicenses = append(createdLicenses, got)

Expand Down Expand Up @@ -209,6 +212,24 @@ func TestLicensesStore(t *testing.T) {
assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l)
}
})

t.Run("List by license key substring", func(t *testing.T) {
listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{
LicenseType: subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY,
LicenseKeySubstring: signedKeyExample,
})
require.NoError(t, err)
require.Len(t, listedLicenses, 1)
assert.Equal(t, subscriptionID1, listedLicenses[0].SubscriptionID)

listedLicenses, err = licenses.List(ctx, subscriptions.ListLicensesOpts{
LicenseType: subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY,
LicenseKeySubstring: signedKeyExample[2:5],
})
require.NoError(t, err)
require.Len(t, listedLicenses, 1)
assert.Equal(t, subscriptionID1, listedLicenses[0].SubscriptionID)
})
})

t.Run("Get", func(t *testing.T) {
Expand Down
41 changes: 36 additions & 5 deletions cmd/enterprise-portal/internal/subscriptionsservice/v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,37 @@ func (s *handlerV1) ListEnterpriseSubscriptionLicenses(ctx context.Context, req
// Validate filters
filters := req.Msg.GetFilters()
for _, filter := range filters {
// TODO: Implement additional filtering as needed
switch f := filter.GetFilter().(type) {
case *subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_Type:
return nil, connect.NewError(connect.CodeUnimplemented,
errors.New("filtering by type is not implemented"))
if f.Type == 0 {
return nil, connect.NewError(
connect.CodeInvalidArgument,
errors.New(`invalid filter: "type" is not valid`),
)
}
if opts.LicenseType != 0 {
return nil, connect.NewError(
connect.CodeInvalidArgument,
errors.New(`invalid filter: "type" provided more than once`),
)
}
opts.LicenseType = f.Type

case *subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_LicenseKeySubstring:
return nil, connect.NewError(connect.CodeUnimplemented,
errors.New("filtering by license key substring is not implemented"))
if f.LicenseKeySubstring == "" {
return nil, connect.NewError(
connect.CodeInvalidArgument,
errors.New(`invalid filter: "license_key_substring" is provided but is empty`),
)
}
if opts.LicenseKeySubstring != "" {
return nil, connect.NewError(
connect.CodeInvalidArgument,
errors.New(`invalid filter: "license_key_substring"" provided multiple times`),
)
}
opts.LicenseKeySubstring = f.LicenseKeySubstring

case *subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId:
if f.SubscriptionId == "" {
return nil, connect.NewError(
Expand All @@ -244,6 +267,14 @@ func (s *handlerV1) ListEnterpriseSubscriptionLicenses(ctx context.Context, req
}
}

if opts.LicenseType != subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY &&
opts.LicenseKeySubstring != "" {
return nil, connect.NewError(
connect.CodeInvalidArgument,
errors.New(`invalid filters: "license_type" must be 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY' to use the "license_key_substring" filter`),
)
}

licenses, err := s.store.ListEnterpriseSubscriptionLicenses(ctx, opts)
if err != nil {
if errors.Is(err, dotcomdb.ErrCodyGatewayAccessNotFound) {
Expand Down

0 comments on commit 289be88

Please sign in to comment.