This repository has been archived by the owner on Sep 30, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat/enterpriseportal: db layer for cody gateway access (#63737)
Implements Cody Gateway access in Enterprise Portal DB, such that it replicates the behaviour it has today, retrieving: - Quota overrides - Subscription display name - Active license info - Non-revoked, non-expired license keys as hashes - Revocation + non-expiry replaces the existing mechanism of flagging licenses as `access_token_enabled`. Since we ended up doing zero-config for Cody Gateway, the only license hashes that are valid for Cody Gateway are non-expired licenses - once your license expires you should be switching to a new license key anyway. It's fairly similar to the `dotcomdb` shim we built before, but for our new tables. See https://github.com/sourcegraph/sourcegraph/pull/63792 for the licenses tables. None of this is going live yet. Part of https://linear.app/sourcegraph/issue/CORE-100 Part of https://linear.app/sourcegraph/issue/CORE-160 ## Test plan DB integration tests `sg run enterprise-portal` does the migrations without a hitch
- Loading branch information
Showing
10 changed files
with
679 additions
and
17 deletions.
There are no files selected for viewing
34 changes: 33 additions & 1 deletion
34
cmd/enterprise-portal/internal/database/codyaccess/BUILD.bazel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,41 @@ | ||
load("//dev:go_defs.bzl", "go_test") | ||
load("@io_bazel_rules_go//go:def.bzl", "go_library") | ||
|
||
go_library( | ||
name = "codyaccess", | ||
srcs = ["codygateway.go"], | ||
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess", | ||
visibility = ["//cmd/enterprise-portal:__subpackages__"], | ||
deps = ["//cmd/enterprise-portal/internal/database/subscriptions"], | ||
deps = [ | ||
"//cmd/enterprise-portal/internal/database/internal/pgxerrors", | ||
"//cmd/enterprise-portal/internal/database/internal/upsert", | ||
"//cmd/enterprise-portal/internal/database/subscriptions", | ||
"//internal/license", | ||
"//lib/errors", | ||
"@com_github_jackc_pgx_v5//:pgx", | ||
"@com_github_jackc_pgx_v5//pgxpool", | ||
"@io_gorm_gorm//:gorm", | ||
], | ||
) | ||
|
||
go_test( | ||
name = "codyaccess_test", | ||
srcs = ["codygateway_test.go"], | ||
tags = [ | ||
TAG_INFRA_CORESERVICES, | ||
"requires-network", | ||
], | ||
deps = [ | ||
":codyaccess", | ||
"//cmd/enterprise-portal/internal/database", | ||
"//cmd/enterprise-portal/internal/database/databasetest", | ||
"//cmd/enterprise-portal/internal/database/internal/tables", | ||
"//cmd/enterprise-portal/internal/database/internal/utctime", | ||
"//cmd/enterprise-portal/internal/database/subscriptions", | ||
"//internal/license", | ||
"//lib/pointers", | ||
"@com_github_google_uuid//:uuid", | ||
"@com_github_stretchr_testify//assert", | ||
"@com_github_stretchr_testify//require", | ||
], | ||
) |
295 changes: 283 additions & 12 deletions
295
cmd/enterprise-portal/internal/database/codyaccess/codygateway.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,301 @@ | ||
package codyaccess | ||
|
||
import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" | ||
import ( | ||
"context" | ||
"database/sql" | ||
"fmt" | ||
"strings" | ||
|
||
type CodyGatewayAccess struct { | ||
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. | ||
"github.com/jackc/pgx/v5" | ||
"github.com/jackc/pgx/v5/pgxpool" | ||
"gorm.io/gorm" | ||
|
||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/pgxerrors" | ||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert" | ||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" | ||
"github.com/sourcegraph/sourcegraph/internal/license" | ||
"github.com/sourcegraph/sourcegraph/lib/errors" | ||
) | ||
|
||
// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints | ||
// and initializing tables with gorm. | ||
type TableCodyGatewayAccess struct { | ||
Subscription *subscriptions.TableSubscription `gorm:"foreignKey:SubscriptionID"` | ||
|
||
CodyGatewayAccess | ||
} | ||
|
||
func (*TableCodyGatewayAccess) TableName() string { | ||
return "enterprise_portal_cody_gateway_access" | ||
} | ||
|
||
func (t *TableCodyGatewayAccess) RunCustomMigrations(migrator gorm.Migrator) error { | ||
// gorm seems to refuse to drop the 'not null' constriant on a column | ||
// unless we forcibly run AlterColumn. | ||
columns := []string{ | ||
"chat_completions_rate_limit", | ||
"chat_completions_rate_limit_interval_seconds", | ||
"code_completions_rate_limit", | ||
"code_completions_rate_limit_interval_seconds", | ||
"embeddings_rate_limit", | ||
"embeddings_rate_limit_interval_seconds", | ||
} | ||
for _, column := range columns { | ||
if err := migrator.AlterColumn(t, column); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
type CodyGatewayAccess struct { | ||
// SubscriptionID is the internal unprefixed UUID of the related subscription. | ||
SubscriptionID string `gorm:"type:uuid;not null;unique"` | ||
|
||
// Whether or not a subscription has Cody Gateway access enabled. | ||
Enabled bool `gorm:"not null"` | ||
Enabled bool `gorm:"not null;default:false"` | ||
|
||
// chat_completions_rate_limit | ||
ChatCompletionsRateLimit int64 `gorm:"type:bigint;not null"` | ||
ChatCompletionsRateLimitIntervalSeconds int `gorm:"not null"` | ||
ChatCompletionsRateLimit sql.NullInt64 | ||
ChatCompletionsRateLimitIntervalSeconds sql.NullInt32 | ||
|
||
// code_completions_rate_limit | ||
CodeCompletionsRateLimit int64 `gorm:"type:bigint;not null"` | ||
CodeCompletionsRateLimitIntervalSeconds int `gorm:"not null"` | ||
CodeCompletionsRateLimit sql.NullInt64 | ||
CodeCompletionsRateLimitIntervalSeconds sql.NullInt32 | ||
|
||
// embeddings_rate_limit | ||
EmbeddingsRateLimit int64 `gorm:"type:bigint;not null"` | ||
EmbeddingsRateLimitIntervalSeconds int `gorm:"not null"` | ||
EmbeddingsRateLimit sql.NullInt64 | ||
EmbeddingsRateLimitIntervalSeconds sql.NullInt32 | ||
} | ||
|
||
func (s *CodyGatewayAccess) TableName() string { | ||
return "enterprise_portal_cody_gateway_access" | ||
// codyGatewayAccessTableColumns must match scanCodyGatewayAccess() values. | ||
// Requires 'codyGatewayAccessJoinClauses'. | ||
func codyGatewayAccessTableColumns() []string { | ||
return []string{ | ||
"subscription.id", | ||
"enabled", | ||
"chat_completions_rate_limit", | ||
"chat_completions_rate_limit_interval_seconds", | ||
"code_completions_rate_limit", | ||
"code_completions_rate_limit_interval_seconds", | ||
"embeddings_rate_limit", | ||
"embeddings_rate_limit_interval_seconds", | ||
// Subscriptions | ||
"subscription.display_name", | ||
// Licenses - depends on license key info | ||
"active_license.license_data->'Info' as active_license_info", | ||
"tokens.license_key_hashes as license_key_hashes", | ||
} | ||
} | ||
|
||
// scanCodyGatewayAccess matches s.columns() values. | ||
func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetails, error) { | ||
var a CodyGatewayAccessWithSubscriptionDetails | ||
// RIGHT JOIN may surface null in enterprise_portal_cody_gateway_access if | ||
// an active subscription exists, but explicit access is not configured. In | ||
// this case we still need to return a valid CodyGatewayAccessWithSubscriptionDetails, | ||
// just with empty fields. | ||
var maybeEnabled *bool | ||
err := row.Scan( | ||
&a.SubscriptionID, | ||
&maybeEnabled, | ||
&a.ChatCompletionsRateLimit, | ||
&a.ChatCompletionsRateLimitIntervalSeconds, | ||
&a.CodeCompletionsRateLimit, | ||
&a.CodeCompletionsRateLimitIntervalSeconds, | ||
&a.EmbeddingsRateLimit, | ||
&a.EmbeddingsRateLimitIntervalSeconds, | ||
// Subscriptions fields | ||
&a.DisplayName, | ||
// License fields | ||
&a.ActiveLicenseInfo, | ||
&a.LicenseKeyHashes, | ||
) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if maybeEnabled != nil { | ||
a.Enabled = *maybeEnabled | ||
} | ||
return &a, nil | ||
} | ||
|
||
const codyGatewayAccessJoinClauses = ` | ||
-- We want Cody Gateway access records for every subscription, even if an | ||
-- an explicit one doesn't exist yet. | ||
RIGHT JOIN | ||
enterprise_portal_subscriptions AS subscription | ||
ON access.subscription_id = subscription.id | ||
-- Join against the "active license" of a subscription, which is currently used | ||
-- as the source for default subscription access properties. | ||
-- We may want to move user counts, product tags, etc. to the subscription table | ||
-- in the future instead. | ||
LEFT JOIN | ||
enterprise_portal_subscription_licenses AS active_license | ||
ON active_license.id = ( | ||
SELECT id | ||
FROM enterprise_portal_subscription_licenses | ||
WHERE | ||
enterprise_portal_subscription_licenses.license_type = 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY' | ||
AND access.subscription_id = enterprise_portal_subscription_licenses.subscription_id | ||
-- Get most recently created license key as the "active license" | ||
ORDER BY enterprise_portal_subscription_licenses.created_at DESC | ||
LIMIT 1 | ||
) | ||
-- Join against collected license key hashes of each subscription, which we use | ||
-- as 'access tokens' to Cody Gateway | ||
LEFT JOIN ( | ||
SELECT | ||
licenses.subscription_id, | ||
ARRAY_AGG(digest(licenses.license_data->>'SignedKey','sha256')) AS license_key_hashes | ||
FROM | ||
enterprise_portal_subscription_licenses AS licenses | ||
WHERE | ||
licenses.license_type = 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY' | ||
AND licenses.expire_at > NOW() -- expires in future | ||
AND licenses.revoked_at IS NULL -- is not revoked | ||
GROUP BY | ||
licenses.subscription_id | ||
) tokens ON tokens.subscription_id = subscription.id | ||
` | ||
|
||
// Store is the storage layer for Cody Gateway access. It aims to mirror the | ||
// existing behaviour as close as possible, and as such has extensive | ||
// dependencies on licensing. | ||
type CodyGatewayStore struct { | ||
db *pgxpool.Pool | ||
} | ||
|
||
func NewCodyGatewayStore(db *pgxpool.Pool) *CodyGatewayStore { | ||
return &CodyGatewayStore{db: db} | ||
} | ||
|
||
// CodyGatewayAccessWithSubscriptionDetails extends CodyGatewayAccess with metadata from | ||
// other tables used in the codyaccess API. | ||
type CodyGatewayAccessWithSubscriptionDetails struct { | ||
CodyGatewayAccess | ||
|
||
// DisplayName is the display name of the related subscription. | ||
DisplayName string | ||
|
||
ActiveLicenseInfo *license.Info | ||
LicenseKeyHashes [][]byte | ||
} | ||
|
||
var ErrSubscriptionDoesNotExist = errors.New("subscription does not exist") | ||
|
||
// Get returns the Cody Gateway access for the given subscription. | ||
func (s *CodyGatewayStore) Get(ctx context.Context, subscriptionID string) (*CodyGatewayAccessWithSubscriptionDetails, error) { | ||
query := fmt.Sprintf(`SELECT | ||
%s | ||
FROM | ||
enterprise_portal_cody_gateway_access AS access | ||
%s | ||
WHERE | ||
subscription.id = @subscriptionID | ||
AND subscription.archived_at IS NULL`, | ||
strings.Join(codyGatewayAccessTableColumns(), ", "), | ||
codyGatewayAccessJoinClauses) | ||
|
||
sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, pgx.NamedArgs{ | ||
"subscriptionID": subscriptionID, | ||
})) | ||
if err != nil { | ||
if errors.Is(err, pgx.ErrNoRows) { | ||
// RIGHT JOIN in query ensures that if we find no result, it's | ||
// because the subscription does not exist or is archived. | ||
return nil, errors.WithSafeDetails( | ||
errors.WithStack(ErrSubscriptionDoesNotExist), | ||
err.Error()) | ||
} | ||
return nil, err | ||
} | ||
return sub, nil | ||
} | ||
|
||
func (s *CodyGatewayStore) List(ctx context.Context) ([]*CodyGatewayAccessWithSubscriptionDetails, error) { | ||
query := fmt.Sprintf(`SELECT | ||
%s | ||
FROM | ||
enterprise_portal_cody_gateway_access AS access | ||
%s | ||
WHERE | ||
subscription.archived_at IS NULL`, | ||
strings.Join(codyGatewayAccessTableColumns(), ", "), | ||
codyGatewayAccessJoinClauses) | ||
|
||
rows, err := s.db.Query(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer rows.Close() | ||
var accs []*CodyGatewayAccessWithSubscriptionDetails | ||
for rows.Next() { | ||
sub, err := scanCodyGatewayAccess(rows) | ||
if err != nil { | ||
return nil, err | ||
} | ||
accs = append(accs, sub) | ||
} | ||
if err := rows.Err(); err != nil { | ||
return nil, err | ||
} | ||
return accs, nil | ||
} | ||
|
||
type UpsertCodyGatewayAccessOptions struct { | ||
// Whether or not a subscription has Cody Gateway access enabled. | ||
Enabled *bool | ||
|
||
// chat_completions_rate_limit | ||
ChatCompletionsRateLimit *int64 | ||
ChatCompletionsRateLimitIntervalSeconds *int | ||
|
||
// code_completions_rate_limit | ||
CodeCompletionsRateLimit *int64 | ||
CodeCompletionsRateLimitIntervalSeconds *int | ||
|
||
// embeddings_rate_limit | ||
EmbeddingsRateLimit *int64 | ||
EmbeddingsRateLimitIntervalSeconds *int | ||
|
||
// ForceUpdate indicates whether to force update all fields of the subscription | ||
// record. | ||
ForceUpdate bool | ||
} | ||
|
||
// toQuery returns the query based on the options. It returns an empty query if | ||
// nothing to update. | ||
func (opts UpsertCodyGatewayAccessOptions) Exec(ctx context.Context, db *pgxpool.Pool, subscriptionID string) error { | ||
b := upsert.New("enterprise_portal_cody_gateway_access", "subscription_id", opts.ForceUpdate) | ||
upsert.Field(b, "subscription_id", subscriptionID) | ||
upsert.Field(b, "enabled", opts.Enabled, | ||
upsert.WithColumnDefault(), | ||
upsert.WithValueOnForceUpdate(false)) | ||
upsert.Field(b, "chat_completions_rate_limit", opts.ChatCompletionsRateLimit) | ||
upsert.Field(b, "chat_completions_rate_limit_interval_seconds", opts.ChatCompletionsRateLimitIntervalSeconds) | ||
upsert.Field(b, "code_completions_rate_limit", opts.CodeCompletionsRateLimit) | ||
upsert.Field(b, "code_completions_rate_limit_interval_seconds", opts.CodeCompletionsRateLimitIntervalSeconds) | ||
upsert.Field(b, "embeddings_rate_limit", opts.EmbeddingsRateLimit) | ||
upsert.Field(b, "embeddings_rate_limit_interval_seconds", opts.EmbeddingsRateLimitIntervalSeconds) | ||
return b.Exec(ctx, db) | ||
} | ||
|
||
// Upsert upserts a Cody Gatweway access record based on the given options. | ||
// The caller should check that the subscription is not archived. | ||
// | ||
// If the subscription does not exist, then ErrSubscriptionDoesNotExist is | ||
// returned. | ||
func (s *CodyGatewayStore) Upsert(ctx context.Context, subscriptionID string, opts UpsertCodyGatewayAccessOptions) (*CodyGatewayAccessWithSubscriptionDetails, error) { | ||
if err := opts.Exec(ctx, s.db, subscriptionID); err != nil { | ||
if pgxerrors.IsContraintError(err, "fk_enterprise_portal_cody_gateway_access_subscription") { | ||
return nil, errors.WithSafeDetails( | ||
errors.WithStack(ErrSubscriptionDoesNotExist), | ||
err.Error()) | ||
} | ||
return nil, err | ||
} | ||
return s.Get(ctx, subscriptionID) | ||
} |
Oops, something went wrong.