Skip to content

Commit

Permalink
Neon $withAuth (#3562)
Browse files Browse the repository at this point in the history
* Added `` for `neon-http` driver, updated `@neondatabase/serverless` version, fixed `.catch` on `` requiring strict return type

* Fixed package version mismatch

---------

Co-authored-by: Andrii Sherman <[email protected]>
  • Loading branch information
Sukairo-02 and AndriiSherman authored Nov 21, 2024
1 parent fcaa0a5 commit a2d734c
Show file tree
Hide file tree
Showing 15 changed files with 8,090 additions and 4,559 deletions.
4 changes: 2 additions & 2 deletions drizzle-orm/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"@electric-sql/pglite": ">=0.2.0",
"@libsql/client": ">=0.10.0",
"@libsql/client-wasm": ">=0.10.0",
"@neondatabase/serverless": ">=0.1",
"@neondatabase/serverless": ">=0.10.0",
"@op-engineering/op-sqlite": ">=2",
"@opentelemetry/api": "^1.4.1",
"@planetscale/database": ">=1",
Expand Down Expand Up @@ -169,7 +169,7 @@
"@libsql/client": "^0.10.0",
"@libsql/client-wasm": "^0.10.0",
"@miniflare/d1": "^2.14.4",
"@neondatabase/serverless": "^0.9.0",
"@neondatabase/serverless": "^0.10.0",
"@op-engineering/op-sqlite": "^2.0.16",
"@opentelemetry/api": "^1.4.1",
"@originjs/vite-plugin-commonjs": "^1.0.3",
Expand Down
56 changes: 56 additions & 0 deletions drizzle-orm/src/neon-http/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,67 @@ export class NeonHttpDriver {
}
}

function wrap<T extends object>(
target: T,
token: string,
cb: (target: any, p: string | symbol, res: any) => any,
deep?: boolean,
) {
return new Proxy(target, {
get(target, p) {
const element = target[p as keyof typeof p];
if (typeof element !== 'function' && (typeof element !== 'object' || element === null)) return element;

if (deep) return wrap(element, token, cb);
if (p === 'query') return wrap(element, token, cb, true);

return new Proxy(element as any, {
apply(target, thisArg, argArray) {
const res = target.call(thisArg, ...argArray);
if ('setToken' in res && typeof res.setToken === 'function') {
res.setToken(token);
}
return cb(target, p, res);
},
});
},
});
}

export class NeonHttpDatabase<
TSchema extends Record<string, unknown> = Record<string, never>,
> extends PgDatabase<NeonHttpQueryResultHKT, TSchema> {
static override readonly [entityKind]: string = 'NeonHttpDatabase';

$withAuth(
token: string,
): Omit<
this,
Exclude<
keyof this,
| '$count'
| 'delete'
| 'select'
| 'selectDistinct'
| 'selectDistinctOn'
| 'update'
| 'insert'
| 'with'
| 'query'
| 'execute'
| 'refreshMaterializedView'
>
> {
this.authToken = token;

return wrap(this, token, (target, p, res) => {
if (p === 'with') {
return wrap(res, token, (_, __, res) => res);
}
return res;
});
}

/** @internal */
declare readonly session: NeonHttpSession<TSchema, ExtractTablesWithRelations<TSchema>>;

Expand Down
64 changes: 54 additions & 10 deletions drizzle-orm/src/neon-http/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,43 @@ export class NeonHttpPreparedQuery<T extends PreparedQueryConfig> extends PgPrep
super(query);
}

async execute(placeholderValues: Record<string, unknown> | undefined = {}): Promise<T['execute']> {
async execute(placeholderValues: Record<string, unknown> | undefined): Promise<T['execute']>;
/** @internal */
async execute(placeholderValues: Record<string, unknown> | undefined, token?: string): Promise<T['execute']>;
/** @internal */
async execute(
placeholderValues: Record<string, unknown> | undefined = {},
token: string | undefined = this.authToken,
): Promise<T['execute']> {
const params = fillPlaceholders(this.query.params, placeholderValues);

this.logger.logQuery(this.query.sql, params);

const { fields, client, query, customResultMapper } = this;

if (!fields && !customResultMapper) {
return client(query.sql, params, rawQueryConfig);
return client(
query.sql,
params,
token === undefined
? rawQueryConfig
: {
...rawQueryConfig,
authToken: token,
},
);
}

const result = await client(query.sql, params, queryConfig);
const result = await client(
query.sql,
params,
token === undefined
? queryConfig
: {
...queryConfig,
authToken: token,
},
);

return this.mapResult(result);
}
Expand All @@ -71,13 +96,26 @@ export class NeonHttpPreparedQuery<T extends PreparedQueryConfig> extends PgPrep
all(placeholderValues: Record<string, unknown> | undefined = {}): Promise<T['all']> {
const params = fillPlaceholders(this.query.params, placeholderValues);
this.logger.logQuery(this.query.sql, params);
return this.client(this.query.sql, params, rawQueryConfig).then((result) => result.rows);
return this.client(
this.query.sql,
params,
this.authToken === undefined ? rawQueryConfig : {
...rawQueryConfig,
authToken: this.authToken,
},
).then((result) => result.rows);
}

values(placeholderValues: Record<string, unknown> | undefined = {}): Promise<T['values']> {
values(placeholderValues: Record<string, unknown> | undefined): Promise<T['values']>;
/** @internal */
values(placeholderValues: Record<string, unknown> | undefined, token?: string): Promise<T['values']>;
/** @internal */
values(placeholderValues: Record<string, unknown> | undefined = {}, token?: string): Promise<T['values']> {
const params = fillPlaceholders(this.query.params, placeholderValues);
this.logger.logQuery(this.query.sql, params);
return this.client(this.query.sql, params, { arrayMode: true, fullResults: true }).then((result) => result.rows);
return this.client(this.query.sql, params, { arrayMode: true, fullResults: true, authToken: token }).then((
result,
) => result.rows);
}

/** @internal */
Expand Down Expand Up @@ -125,7 +163,9 @@ export class NeonHttpSession<
);
}

async batch<U extends BatchItem<'pg'>, T extends Readonly<[U, ...U[]]>>(queries: T) {
async batch<U extends BatchItem<'pg'>, T extends Readonly<[U, ...U[]]>>(
queries: T,
) {
const preparedQueries: PreparedQuery[] = [];
const builtQueries: NeonQueryPromise<any, true>[] = [];

Expand All @@ -143,7 +183,7 @@ export class NeonHttpSession<

const batchResults = await this.client.transaction(builtQueries, queryConfig);

return batchResults.map((result, i) => preparedQueries[i]!.mapResult(result, true));
return batchResults.map((result, i) => preparedQueries[i]!.mapResult(result, true)) as any;
}

// change return type to QueryRows<true>
Expand All @@ -161,8 +201,12 @@ export class NeonHttpSession<
return this.client(query, params, { arrayMode: false, fullResults: true });
}

override async count(sql: SQL): Promise<number> {
const res = await this.execute<{ rows: [{ count: string }] }>(sql);
override async count(sql: SQL): Promise<number>;
/** @internal */
override async count(sql: SQL, token?: string): Promise<number>;
/** @internal */
override async count(sql: SQL, token?: string): Promise<number> {
const res = await this.execute<{ rows: [{ count: string }] }>(sql, token);

return Number(
res['rows'][0]['count'],
Expand Down
4 changes: 3 additions & 1 deletion drizzle-orm/src/pg-core/db.ts
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,8 @@ export class PgDatabase<
return new PgRefreshMaterializedView(view, this.session, this.dialect);
}

protected authToken?: string;

execute<TRow extends Record<string, unknown> = Record<string, unknown>>(
query: SQLWrapper | string,
): PgRaw<PgQueryResultKind<TQueryResult, TRow>> {
Expand All @@ -611,7 +613,7 @@ export class PgDatabase<
false,
);
return new PgRaw(
() => prepared.execute(),
() => prepared.execute(undefined, this.authToken),
sequel,
builtQuery,
(result) => prepared.mapResult(result, true),
Expand Down
10 changes: 8 additions & 2 deletions drizzle-orm/src/pg-core/query-builders/count.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export class PgCountBuilder<
TSession extends PgSession<any, any, any>,
> extends SQL<number> implements Promise<number>, SQLWrapper {
private sql: SQL<number>;
private token?: string;

static override readonly [entityKind] = 'PgCountBuilder';
[Symbol.toStringTag] = 'PgCountBuilder';
Expand Down Expand Up @@ -46,19 +47,24 @@ export class PgCountBuilder<
);
}

/** @intrnal */
setToken(token: string) {
this.token = token;
}

then<TResult1 = number, TResult2 = never>(
onfulfilled?: ((value: number) => TResult1 | PromiseLike<TResult1>) | null | undefined,
onrejected?: ((reason: any) => TResult2 | PromiseLike<TResult2>) | null | undefined,
): Promise<TResult1 | TResult2> {
return Promise.resolve(this.session.count(this.sql))
return Promise.resolve(this.session.count(this.sql, this.token))
.then(
onfulfilled,
onrejected,
);
}

catch(
onRejected?: ((reason: any) => never | PromiseLike<never>) | null | undefined,
onRejected?: ((reason: any) => any) | null | undefined,
): Promise<number> {
return this.then(undefined, onRejected);
}
Expand Down
9 changes: 8 additions & 1 deletion drizzle-orm/src/pg-core/query-builders/delete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,16 @@ export class PgDeleteBase<
return this._prepare(name);
}

private authToken?: string;
/** @internal */
setToken(token: string) {
this.authToken = token;
return this;
}

override execute: ReturnType<this['prepare']>['execute'] = (placeholderValues) => {
return tracer.startActiveSpan('drizzle.operation', () => {
return this._prepare().execute(placeholderValues);
return this._prepare().execute(placeholderValues, this.authToken);
});
};

Expand Down
53 changes: 35 additions & 18 deletions drizzle-orm/src/pg-core/query-builders/insert.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ export class PgInsertBuilder<
private overridingSystemValue_?: boolean,
) {}

overridingSystemValue(): Omit<PgInsertBuilder<TTable, TQueryResult, true>, 'overridingSystemValue'> {
this.overridingSystemValue_ = true;
return this as any;
private authToken?: string;
/** @internal */
setToken(token: string) {
this.authToken = token;
return this;
}

values(value: PgInsertValue<TTable, OverrideT>): PgInsertBase<TTable, TQueryResult>;
values(values: PgInsertValue<TTable, OverrideT>[]): PgInsertBase<TTable, TQueryResult>;
values(
values: PgInsertValue<TTable, OverrideT> | PgInsertValue<TTable, OverrideT>[],
): PgInsertBase<TTable, TQueryResult> {
values(value: PgInsertValue<TTable>): PgInsertBase<TTable, TQueryResult>;
values(values: PgInsertValue<TTable>[]): PgInsertBase<TTable, TQueryResult>;
values(values: PgInsertValue<TTable> | PgInsertValue<TTable>[]): PgInsertBase<TTable, TQueryResult> {
values = Array.isArray(values) ? values : [values];
if (values.length === 0) {
throw new Error('values() must be called with at least one value');
Expand All @@ -87,15 +87,25 @@ export class PgInsertBuilder<
return result;
});

return new PgInsertBase(
this.table,
mappedValues,
this.session,
this.dialect,
this.withList,
false,
this.overridingSystemValue_,
);
return this.authToken === undefined
? new PgInsertBase(
this.table,
mappedValues,
this.session,
this.dialect,
this.withList,
false,
this.overridingSystemValue_,
)
: new PgInsertBase(
this.table,
mappedValues,
this.session,
this.dialect,
this.withList,
false,
this.overridingSystemValue_,
).setToken(this.authToken) as any;
}

select(selectQuery: (qb: QueryBuilder) => PgInsertSelectQueryBuilder<TTable>): PgInsertBase<TTable, TQueryResult>;
Expand Down Expand Up @@ -385,9 +395,16 @@ export class PgInsertBase<
return this._prepare(name);
}

private authToken?: string;
/** @internal */
setToken(token: string) {
this.authToken = token;
return this;
}

override execute: ReturnType<this['prepare']>['execute'] = (placeholderValues) => {
return tracer.startActiveSpan('drizzle.operation', () => {
return this._prepare().execute(placeholderValues);
return this._prepare().execute(placeholderValues, this.authToken);
});
};

Expand Down
9 changes: 8 additions & 1 deletion drizzle-orm/src/pg-core/query-builders/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,16 @@ export class PgRelationalQuery<TResult> extends QueryPromise<TResult>
return this._toSQL().builtQuery;
}

private authToken?: string;
/** @internal */
setToken(token: string) {
this.authToken = token;
return this;
}

override execute(): Promise<TResult> {
return tracer.startActiveSpan('drizzle.operation', () => {
return this._prepare().execute();
return this._prepare().execute(undefined, this.authToken);
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,16 @@ export class PgRefreshMaterializedView<TQueryResult extends PgQueryResultHKT>
return this._prepare(name);
}

private authToken?: string;
/** @internal */
setToken(token: string) {
this.authToken = token;
return this;
}

execute: ReturnType<this['prepare']>['execute'] = (placeholderValues) => {
return tracer.startActiveSpan('drizzle.operation', () => {
return this._prepare().execute(placeholderValues);
return this._prepare().execute(placeholderValues, this.authToken);
});
};
}
Loading

0 comments on commit a2d734c

Please sign in to comment.