Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Commit

Permalink
feat/enterpriseportal: db layer for cody gateway access (#63737)
Browse files Browse the repository at this point in the history
Implements Cody Gateway access in Enterprise Portal DB, such that it
replicates the behaviour it has today, retrieving:

- Quota overrides
- Subscription display name
- Active license info
- Non-revoked, non-expired license keys as hashes
- Revocation + non-expiry replaces the existing mechanism of flagging
licenses as `access_token_enabled`. Since we ended up doing zero-config
for Cody Gateway, the only license hashes that are valid for Cody
Gateway are non-expired licenses - once your license expires you should
be switching to a new license key anyway.

It's fairly similar to the `dotcomdb` shim we built before, but for our
new tables. See https://github.com/sourcegraph/sourcegraph/pull/63792
for the licenses tables.

None of this is going live yet.

Part of https://linear.app/sourcegraph/issue/CORE-100
Part of https://linear.app/sourcegraph/issue/CORE-160

## Test plan

DB integration tests

`sg run enterprise-portal` does the migrations without a hitch
  • Loading branch information
bobheadxi authored Jul 19, 2024
1 parent 7a16ccf commit df228a7
Show file tree
Hide file tree
Showing 10 changed files with 679 additions and 17 deletions.
34 changes: 33 additions & 1 deletion cmd/enterprise-portal/internal/database/codyaccess/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
load("//dev:go_defs.bzl", "go_test")
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "codyaccess",
srcs = ["codygateway.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = ["//cmd/enterprise-portal/internal/database/subscriptions"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/pgxerrors",
"//cmd/enterprise-portal/internal/database/internal/upsert",
"//cmd/enterprise-portal/internal/database/subscriptions",
"//internal/license",
"//lib/errors",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@io_gorm_gorm//:gorm",
],
)

go_test(
name = "codyaccess_test",
srcs = ["codygateway_test.go"],
tags = [
TAG_INFRA_CORESERVICES,
"requires-network",
],
deps = [
":codyaccess",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/databasetest",
"//cmd/enterprise-portal/internal/database/internal/tables",
"//cmd/enterprise-portal/internal/database/internal/utctime",
"//cmd/enterprise-portal/internal/database/subscriptions",
"//internal/license",
"//lib/pointers",
"@com_github_google_uuid//:uuid",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)
295 changes: 283 additions & 12 deletions cmd/enterprise-portal/internal/database/codyaccess/codygateway.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,301 @@
package codyaccess

import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
import (
"context"
"database/sql"
"fmt"
"strings"

type CodyGatewayAccess struct {
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"gorm.io/gorm"

"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/pgxerrors"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints
// and initializing tables with gorm.
type TableCodyGatewayAccess struct {
Subscription *subscriptions.TableSubscription `gorm:"foreignKey:SubscriptionID"`

CodyGatewayAccess
}

func (*TableCodyGatewayAccess) TableName() string {
return "enterprise_portal_cody_gateway_access"
}

func (t *TableCodyGatewayAccess) RunCustomMigrations(migrator gorm.Migrator) error {
// gorm seems to refuse to drop the 'not null' constriant on a column
// unless we forcibly run AlterColumn.
columns := []string{
"chat_completions_rate_limit",
"chat_completions_rate_limit_interval_seconds",
"code_completions_rate_limit",
"code_completions_rate_limit_interval_seconds",
"embeddings_rate_limit",
"embeddings_rate_limit_interval_seconds",
}
for _, column := range columns {
if err := migrator.AlterColumn(t, column); err != nil {
return err
}
}
return nil
}

type CodyGatewayAccess struct {
// SubscriptionID is the internal unprefixed UUID of the related subscription.
SubscriptionID string `gorm:"type:uuid;not null;unique"`

// Whether or not a subscription has Cody Gateway access enabled.
Enabled bool `gorm:"not null"`
Enabled bool `gorm:"not null;default:false"`

// chat_completions_rate_limit
ChatCompletionsRateLimit int64 `gorm:"type:bigint;not null"`
ChatCompletionsRateLimitIntervalSeconds int `gorm:"not null"`
ChatCompletionsRateLimit sql.NullInt64
ChatCompletionsRateLimitIntervalSeconds sql.NullInt32

// code_completions_rate_limit
CodeCompletionsRateLimit int64 `gorm:"type:bigint;not null"`
CodeCompletionsRateLimitIntervalSeconds int `gorm:"not null"`
CodeCompletionsRateLimit sql.NullInt64
CodeCompletionsRateLimitIntervalSeconds sql.NullInt32

// embeddings_rate_limit
EmbeddingsRateLimit int64 `gorm:"type:bigint;not null"`
EmbeddingsRateLimitIntervalSeconds int `gorm:"not null"`
EmbeddingsRateLimit sql.NullInt64
EmbeddingsRateLimitIntervalSeconds sql.NullInt32
}

func (s *CodyGatewayAccess) TableName() string {
return "enterprise_portal_cody_gateway_access"
// codyGatewayAccessTableColumns must match scanCodyGatewayAccess() values.
// Requires 'codyGatewayAccessJoinClauses'.
func codyGatewayAccessTableColumns() []string {
return []string{
"subscription.id",
"enabled",
"chat_completions_rate_limit",
"chat_completions_rate_limit_interval_seconds",
"code_completions_rate_limit",
"code_completions_rate_limit_interval_seconds",
"embeddings_rate_limit",
"embeddings_rate_limit_interval_seconds",
// Subscriptions
"subscription.display_name",
// Licenses - depends on license key info
"active_license.license_data->'Info' as active_license_info",
"tokens.license_key_hashes as license_key_hashes",
}
}

// scanCodyGatewayAccess matches s.columns() values.
func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetails, error) {
var a CodyGatewayAccessWithSubscriptionDetails
// RIGHT JOIN may surface null in enterprise_portal_cody_gateway_access if
// an active subscription exists, but explicit access is not configured. In
// this case we still need to return a valid CodyGatewayAccessWithSubscriptionDetails,
// just with empty fields.
var maybeEnabled *bool
err := row.Scan(
&a.SubscriptionID,
&maybeEnabled,
&a.ChatCompletionsRateLimit,
&a.ChatCompletionsRateLimitIntervalSeconds,
&a.CodeCompletionsRateLimit,
&a.CodeCompletionsRateLimitIntervalSeconds,
&a.EmbeddingsRateLimit,
&a.EmbeddingsRateLimitIntervalSeconds,
// Subscriptions fields
&a.DisplayName,
// License fields
&a.ActiveLicenseInfo,
&a.LicenseKeyHashes,
)
if err != nil {
return nil, err
}
if maybeEnabled != nil {
a.Enabled = *maybeEnabled
}
return &a, nil
}

const codyGatewayAccessJoinClauses = `
-- We want Cody Gateway access records for every subscription, even if an
-- an explicit one doesn't exist yet.
RIGHT JOIN
enterprise_portal_subscriptions AS subscription
ON access.subscription_id = subscription.id
-- Join against the "active license" of a subscription, which is currently used
-- as the source for default subscription access properties.
-- We may want to move user counts, product tags, etc. to the subscription table
-- in the future instead.
LEFT JOIN
enterprise_portal_subscription_licenses AS active_license
ON active_license.id = (
SELECT id
FROM enterprise_portal_subscription_licenses
WHERE
enterprise_portal_subscription_licenses.license_type = 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY'
AND access.subscription_id = enterprise_portal_subscription_licenses.subscription_id
-- Get most recently created license key as the "active license"
ORDER BY enterprise_portal_subscription_licenses.created_at DESC
LIMIT 1
)
-- Join against collected license key hashes of each subscription, which we use
-- as 'access tokens' to Cody Gateway
LEFT JOIN (
SELECT
licenses.subscription_id,
ARRAY_AGG(digest(licenses.license_data->>'SignedKey','sha256')) AS license_key_hashes
FROM
enterprise_portal_subscription_licenses AS licenses
WHERE
licenses.license_type = 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY'
AND licenses.expire_at > NOW() -- expires in future
AND licenses.revoked_at IS NULL -- is not revoked
GROUP BY
licenses.subscription_id
) tokens ON tokens.subscription_id = subscription.id
`

// Store is the storage layer for Cody Gateway access. It aims to mirror the
// existing behaviour as close as possible, and as such has extensive
// dependencies on licensing.
type CodyGatewayStore struct {
db *pgxpool.Pool
}

func NewCodyGatewayStore(db *pgxpool.Pool) *CodyGatewayStore {
return &CodyGatewayStore{db: db}
}

// CodyGatewayAccessWithSubscriptionDetails extends CodyGatewayAccess with metadata from
// other tables used in the codyaccess API.
type CodyGatewayAccessWithSubscriptionDetails struct {
CodyGatewayAccess

// DisplayName is the display name of the related subscription.
DisplayName string

ActiveLicenseInfo *license.Info
LicenseKeyHashes [][]byte
}

var ErrSubscriptionDoesNotExist = errors.New("subscription does not exist")

// Get returns the Cody Gateway access for the given subscription.
func (s *CodyGatewayStore) Get(ctx context.Context, subscriptionID string) (*CodyGatewayAccessWithSubscriptionDetails, error) {
query := fmt.Sprintf(`SELECT
%s
FROM
enterprise_portal_cody_gateway_access AS access
%s
WHERE
subscription.id = @subscriptionID
AND subscription.archived_at IS NULL`,
strings.Join(codyGatewayAccessTableColumns(), ", "),
codyGatewayAccessJoinClauses)

sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, pgx.NamedArgs{
"subscriptionID": subscriptionID,
}))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
// RIGHT JOIN in query ensures that if we find no result, it's
// because the subscription does not exist or is archived.
return nil, errors.WithSafeDetails(
errors.WithStack(ErrSubscriptionDoesNotExist),
err.Error())
}
return nil, err
}
return sub, nil
}

func (s *CodyGatewayStore) List(ctx context.Context) ([]*CodyGatewayAccessWithSubscriptionDetails, error) {
query := fmt.Sprintf(`SELECT
%s
FROM
enterprise_portal_cody_gateway_access AS access
%s
WHERE
subscription.archived_at IS NULL`,
strings.Join(codyGatewayAccessTableColumns(), ", "),
codyGatewayAccessJoinClauses)

rows, err := s.db.Query(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
var accs []*CodyGatewayAccessWithSubscriptionDetails
for rows.Next() {
sub, err := scanCodyGatewayAccess(rows)
if err != nil {
return nil, err
}
accs = append(accs, sub)
}
if err := rows.Err(); err != nil {
return nil, err
}
return accs, nil
}

type UpsertCodyGatewayAccessOptions struct {
// Whether or not a subscription has Cody Gateway access enabled.
Enabled *bool

// chat_completions_rate_limit
ChatCompletionsRateLimit *int64
ChatCompletionsRateLimitIntervalSeconds *int

// code_completions_rate_limit
CodeCompletionsRateLimit *int64
CodeCompletionsRateLimitIntervalSeconds *int

// embeddings_rate_limit
EmbeddingsRateLimit *int64
EmbeddingsRateLimitIntervalSeconds *int

// ForceUpdate indicates whether to force update all fields of the subscription
// record.
ForceUpdate bool
}

// toQuery returns the query based on the options. It returns an empty query if
// nothing to update.
func (opts UpsertCodyGatewayAccessOptions) Exec(ctx context.Context, db *pgxpool.Pool, subscriptionID string) error {
b := upsert.New("enterprise_portal_cody_gateway_access", "subscription_id", opts.ForceUpdate)
upsert.Field(b, "subscription_id", subscriptionID)
upsert.Field(b, "enabled", opts.Enabled,
upsert.WithColumnDefault(),
upsert.WithValueOnForceUpdate(false))
upsert.Field(b, "chat_completions_rate_limit", opts.ChatCompletionsRateLimit)
upsert.Field(b, "chat_completions_rate_limit_interval_seconds", opts.ChatCompletionsRateLimitIntervalSeconds)
upsert.Field(b, "code_completions_rate_limit", opts.CodeCompletionsRateLimit)
upsert.Field(b, "code_completions_rate_limit_interval_seconds", opts.CodeCompletionsRateLimitIntervalSeconds)
upsert.Field(b, "embeddings_rate_limit", opts.EmbeddingsRateLimit)
upsert.Field(b, "embeddings_rate_limit_interval_seconds", opts.EmbeddingsRateLimitIntervalSeconds)
return b.Exec(ctx, db)
}

// Upsert upserts a Cody Gatweway access record based on the given options.
// The caller should check that the subscription is not archived.
//
// If the subscription does not exist, then ErrSubscriptionDoesNotExist is
// returned.
func (s *CodyGatewayStore) Upsert(ctx context.Context, subscriptionID string, opts UpsertCodyGatewayAccessOptions) (*CodyGatewayAccessWithSubscriptionDetails, error) {
if err := opts.Exec(ctx, s.db, subscriptionID); err != nil {
if pgxerrors.IsContraintError(err, "fk_enterprise_portal_cody_gateway_access_subscription") {
return nil, errors.WithSafeDetails(
errors.WithStack(ErrSubscriptionDoesNotExist),
err.Error())
}
return nil, err
}
return s.Get(ctx, subscriptionID)
}
Loading

0 comments on commit df228a7

Please sign in to comment.