From 030b3dffe232bf94a0ec1d67786db45500ed8df3 Mon Sep 17 00:00:00 2001 From: Jason Quense Date: Mon, 17 May 2021 11:13:11 -0400 Subject: [PATCH 1/3] WIP --- package.json | 5 +- src/GraphqlSocketSubscriptionServer.ts | 110 +++++++++++++++++++++++++ yarn.lock | 17 ++++ 3 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 src/GraphqlSocketSubscriptionServer.ts diff --git a/package.json b/package.json index c307a5ba..5dc60007 100644 --- a/package.json +++ b/package.json @@ -43,8 +43,11 @@ "conventionalCommits": true }, "dependencies": { + "@types/ws": "^7.4.1", "express": "^4.17.1", - "redis": "^3.1.2" + "graphql-ws": "^4.3.2", + "redis": "^3.1.2", + "ws": "^7.4.4" }, "peerDependencies": { "graphql": ">=0.12.3", diff --git a/src/GraphqlSocketSubscriptionServer.ts b/src/GraphqlSocketSubscriptionServer.ts new file mode 100644 index 00000000..e9ac8477 --- /dev/null +++ b/src/GraphqlSocketSubscriptionServer.ts @@ -0,0 +1,110 @@ +import type { IncomingMessage } from 'http'; +import { promisify } from 'util'; + +import express from 'express'; +import type { GraphQLSchema } from 'graphql'; +import { useServer } from 'graphql-ws/lib/use/ws'; +import ws from 'ws'; + +import AuthorizedSocketConnection from './AuthorizedSocketConnection'; +import type { CreateValidationRules } from './AuthorizedSocketConnection'; +import type { CredentialsManager } from './CredentialsManager'; +import type { CreateLogger, Logger } from './Logger'; +import type { Subscriber } from './Subscriber'; + +export type SubscriptionServerConfig = { + path: string; + schema: GraphQLSchema; + subscriber: Subscriber; + createCredentialsManager: (request: any) => CredentialsManager; + hasPermission: (data: any, credentials: TCredentials) => boolean; + createContext?: ( + request: any, + credentials: TCredentials | null | undefined, + ) => TContext; + maxSubscriptionsPerConnection?: number; + createValidationRules?: CreateValidationRules; + createLogger?: CreateLogger; +}; + +// eslint-disable-next-line @typescript-eslint/no-empty-function +const defaultCreateLogger = () => () => {}; + +export default class SubscriptionServer { + config: SubscriptionServerConfig; + + log: Logger; + + server: ws.Server | null = null; + + constructor(config: SubscriptionServerConfig) { + this.config = config; + + const createLogger: CreateLogger = + config.createLogger || defaultCreateLogger; + this.log = createLogger('@4c/SubscriptionServer::Server'); + } + + attach(httpServer: any) { + this.server = new ws.Server({ + server: httpServer, + path: this.config.path, + }); + + const { createContext } = this.config; + + useServer( + // from the previous step + { + schema: this.config.schema, + context: (ctx, msg, args) => { + + }, + onConnect() + + // credentialsManager: this.config.createCredentialsManager(request), + // hasPermission: this.config.hasPermission, + createContext: + createContext && + ((credentials: TCredentials | null | undefined) => + createContext(request, credentials)), + maxSubscriptionsPerConnection: this.config + .maxSubscriptionsPerConnection, + createValidationRules: this.config.createValidationRules, + createLogger: this.config.createLogger || defaultCreateLogger, + }, + wsServer, + ); + + this.server.on('connection', this.handleConnection); + } + + handleConnection = (socket: ws, req: IncomingMessage) => { + this.log('debug', 'new socket connection'); + + const request = Object.create((express as any).request); + Object.assign(request, req); + + const { createContext } = this.config; + + // eslint-disable-next-line no-new + new AuthorizedSocketConnection(socket, { + schema: this.config.schema, + subscriber: this.config.subscriber, + credentialsManager: this.config.createCredentialsManager(request), + hasPermission: this.config.hasPermission, + createContext: + createContext && + ((credentials: TCredentials | null | undefined) => + createContext(request, credentials)), + maxSubscriptionsPerConnection: this.config.maxSubscriptionsPerConnection, + createValidationRules: this.config.createValidationRules, + createLogger: this.config.createLogger || defaultCreateLogger, + }); + }; + + async close() { + // @ts-ignore + await promisify((...args) => this.io.close(...args))(); + } +} diff --git a/yarn.lock b/yarn.lock index aff4449b..704c2a35 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1985,6 +1985,13 @@ dependencies: "@types/node" "*" +"@types/ws@^7.4.1": + version "7.4.1" + resolved "https://registry.yarnpkg.com/@types/ws/-/ws-7.4.1.tgz#49eacb15a0534663d53a36fbf5b4d98f5ae9a73a" + integrity sha512-ISCK1iFnR+jYv7+jLNX0wDqesZ/5RAeY3wUx6QaphmocphU61h+b+PHjS18TF4WIPTu/MMzxIq2PHr32o2TS5Q== + dependencies: + "@types/node" "*" + "@types/yargs-parser@*": version "15.0.0" resolved "https://registry.yarnpkg.com/@types/yargs-parser/-/yargs-parser-15.0.0.tgz#cb3f9f741869e20cce330ffbeb9271590483882d" @@ -5902,6 +5909,11 @@ graphql-relay@^0.8.0: resolved "https://registry.yarnpkg.com/graphql-relay/-/graphql-relay-0.8.0.tgz#35f0090f0f056192767c1acdaa402daed19ede6d" integrity sha512-NU7CkwNxPzkqpBgv76Cgycrc3wmWVA2K5Sxm9DHSSLLuQTpaSRAUsX1sf2gITf+XQpkccsv56/z0LojXTyQbUw== +graphql-ws@^4.3.2: + version "4.3.2" + resolved "https://registry.yarnpkg.com/graphql-ws/-/graphql-ws-4.3.2.tgz#c58b03acc3bd5d4a92a6e9f729d29ba5e90d46a3" + integrity sha512-jsW6eOlko7fJek1iaSGQFj97AWuhexL9A3PuxYtyke/VlMdbSFzmDR4PlPPCTBBskRg6tNRb5RTbBVSd2T60JQ== + graphql@^15.5.1: version "15.5.1" resolved "https://registry.yarnpkg.com/graphql/-/graphql-15.5.1.tgz#f2f84415d8985e7b84731e7f3536f8bb9d383aad" @@ -12751,6 +12763,11 @@ ws@^6.2.1: dependencies: async-limiter "~1.0.0" +ws@^7.4.4: + version "7.4.4" + resolved "https://registry.yarnpkg.com/ws/-/ws-7.4.4.tgz#383bc9742cb202292c9077ceab6f6047b17f2d59" + integrity sha512-Qm8k8ojNQIMx7S+Zp8u/uHOx7Qazv3Yv4q68MiWWWOJhiwG5W3x7iqmRtJo8xxrciZUY4vRxUTJCKuRnF28ZZw== + ws@^7.4.5: version "7.5.2" resolved "https://registry.yarnpkg.com/ws/-/ws-7.5.2.tgz#09cc8fea3bec1bc5ed44ef51b42f945be36900f6" From 6ef63e853bafd12d8b8bd89aa9b2835e548f5296 Mon Sep 17 00:00:00 2001 From: Jason Quense Date: Tue, 18 May 2021 15:20:32 -0400 Subject: [PATCH 2/3] feat!: add websocket server option BREAKING CHANGE: the subscription server export is now an abstract class --- README.md | 35 ++++ package.json | 5 +- src/AuthorizedSocketConnection.ts | 21 ++- src/GraphqlSocketSubscriptionServer.ts | 110 ------------- src/SocketIOSubscriptionServer.ts | 84 ++++++++++ src/SubscriptionServer.ts | 63 ++------ src/WebSocketSubscriptionServer.ts | 170 ++++++++++++++++++++ src/index.ts | 3 + src/types.ts | 28 ++++ test/data/schema.graphql | 214 +++++++++++++++++++++++++ test/data/schema.js | 2 - test/helpers.ts | 76 ++++++++- test/socket-io.test.ts | 9 +- test/tsconfig.json | 5 + test/websocket.test.ts | 124 ++++++++++++++ update-schema.js | 11 ++ yarn.lock | 36 +++-- 17 files changed, 800 insertions(+), 196 deletions(-) delete mode 100644 src/GraphqlSocketSubscriptionServer.ts create mode 100644 src/SocketIOSubscriptionServer.ts create mode 100644 src/WebSocketSubscriptionServer.ts create mode 100644 src/types.ts create mode 100644 test/data/schema.graphql create mode 100644 test/websocket.test.ts create mode 100644 update-schema.js diff --git a/README.md b/README.md index 7fe3dc73..7246ee50 100644 --- a/README.md +++ b/README.md @@ -1 +1,36 @@ # GraphQL Subscription Server + +A subscription server for GraphQL subscriptions. Supports streaming over plain web sockets +or Socket.IO, and integrates with Redis or any other Pub/Sub service. + +## Setup + +### Socket.IO + +```js +import http from 'http'; +import { + SocketIOSubscriptionServer, // or WebSocketSubscriptionServer + JwtCredentialManager, + RedisSubscriber, +} from '@4c/graphql-subscription-server'; + +const server = http.createServer(); + +const subscriptionServer = new SocketIOSubscriptionServer({ + schema, + path: '/socket.io/graphql', + subscriber: new RedisSubscriber(), + hasPermission: (message, credentials) => { + authorize(message, credentials); + }, + createCredentialsManager: (req) => new JwtCredentialManager(), + createLogger: () => console.debug, +}); + +subscriptionServer.attach(server); + +server.listen(4000, () => { + console.log('server running'); +}); +``` diff --git a/package.json b/package.json index 5dc60007..7002fa55 100644 --- a/package.json +++ b/package.json @@ -19,7 +19,8 @@ "tdd": "jest --watch", "test": "yarn lint && yarn typecheck && jest", "testonly": "jest", - "typecheck": "tsc --noEmit && tsc -p test --noEmit" + "typecheck": "tsc --noEmit && tsc -p test --noEmit", + "update-schema": "NODE_ENV=test babel-node ./update-schema.js" }, "gitHooks": { "pre-commit": "lint-staged" @@ -47,7 +48,7 @@ "express": "^4.17.1", "graphql-ws": "^4.3.2", "redis": "^3.1.2", - "ws": "^7.4.4" + "ws": "^7.4.5" }, "peerDependencies": { "graphql": ">=0.12.3", diff --git a/src/AuthorizedSocketConnection.ts b/src/AuthorizedSocketConnection.ts index 4fff856f..97d7cf15 100644 --- a/src/AuthorizedSocketConnection.ts +++ b/src/AuthorizedSocketConnection.ts @@ -9,13 +9,13 @@ import { validate, } from 'graphql'; import { ExecutionResult } from 'graphql/execution/execute'; -import io from 'socket.io'; import * as AsyncUtils from './AsyncUtils'; import { CredentialsManager } from './CredentialsManager'; import { CreateLogger, Logger } from './Logger'; import { Subscriber } from './Subscriber'; import SubscriptionContext from './SubscriptionContext'; +import { WebSocket } from './types'; export type CreateValidationRules = ({ query, @@ -62,7 +62,7 @@ const acknowledge = (cb?: () => void) => { * - Rudimentary connection constraints (max connections) */ export default class AuthorizedSocketConnection { - socket: io.Socket; + socket: WebSocket; config: AuthorizedSocketOptions; @@ -76,7 +76,7 @@ export default class AuthorizedSocketConnection { readonly clientId: string; constructor( - socket: io.Socket, + socket: WebSocket, config: AuthorizedSocketOptions, ) { this.socket = socket; @@ -85,14 +85,13 @@ export default class AuthorizedSocketConnection { this.log = config.createLogger('AuthorizedSocket'); this.subscriptionContexts = new Map(); - this.clientId = this.socket.id; + this.clientId = this.socket.id!; - this.socket - .on('authenticate', this.handleAuthenticate) - .on('subscribe', this.handleSubscribe) - .on('unsubscribe', this.handleUnsubscribe) - .on('connect', this.handleConnect) - .on('disconnect', this.handleDisconnect); + this.socket.on('authenticate', this.handleAuthenticate); + this.socket.on('subscribe', this.handleSubscribe); + this.socket.on('unsubscribe', this.handleUnsubscribe); + this.socket.on('connect', this.handleConnect); + this.socket.on('disconnect', this.handleDisconnect); } emitError(error: { code: string; data?: any }) { @@ -125,7 +124,7 @@ export default class AuthorizedSocketConnection { }); await this.config.credentialsManager.authenticate(authorization); - } catch (error) { + } catch (error: any) { this.log('error', error.message, { error, clientId: this.clientId }); this.emitError({ code: 'invalid_authorization' }); } diff --git a/src/GraphqlSocketSubscriptionServer.ts b/src/GraphqlSocketSubscriptionServer.ts deleted file mode 100644 index e9ac8477..00000000 --- a/src/GraphqlSocketSubscriptionServer.ts +++ /dev/null @@ -1,110 +0,0 @@ -import type { IncomingMessage } from 'http'; -import { promisify } from 'util'; - -import express from 'express'; -import type { GraphQLSchema } from 'graphql'; -import { useServer } from 'graphql-ws/lib/use/ws'; -import ws from 'ws'; - -import AuthorizedSocketConnection from './AuthorizedSocketConnection'; -import type { CreateValidationRules } from './AuthorizedSocketConnection'; -import type { CredentialsManager } from './CredentialsManager'; -import type { CreateLogger, Logger } from './Logger'; -import type { Subscriber } from './Subscriber'; - -export type SubscriptionServerConfig = { - path: string; - schema: GraphQLSchema; - subscriber: Subscriber; - createCredentialsManager: (request: any) => CredentialsManager; - hasPermission: (data: any, credentials: TCredentials) => boolean; - createContext?: ( - request: any, - credentials: TCredentials | null | undefined, - ) => TContext; - maxSubscriptionsPerConnection?: number; - createValidationRules?: CreateValidationRules; - createLogger?: CreateLogger; -}; - -// eslint-disable-next-line @typescript-eslint/no-empty-function -const defaultCreateLogger = () => () => {}; - -export default class SubscriptionServer { - config: SubscriptionServerConfig; - - log: Logger; - - server: ws.Server | null = null; - - constructor(config: SubscriptionServerConfig) { - this.config = config; - - const createLogger: CreateLogger = - config.createLogger || defaultCreateLogger; - this.log = createLogger('@4c/SubscriptionServer::Server'); - } - - attach(httpServer: any) { - this.server = new ws.Server({ - server: httpServer, - path: this.config.path, - }); - - const { createContext } = this.config; - - useServer( - // from the previous step - { - schema: this.config.schema, - context: (ctx, msg, args) => { - - }, - onConnect() - - // credentialsManager: this.config.createCredentialsManager(request), - // hasPermission: this.config.hasPermission, - createContext: - createContext && - ((credentials: TCredentials | null | undefined) => - createContext(request, credentials)), - maxSubscriptionsPerConnection: this.config - .maxSubscriptionsPerConnection, - createValidationRules: this.config.createValidationRules, - createLogger: this.config.createLogger || defaultCreateLogger, - }, - wsServer, - ); - - this.server.on('connection', this.handleConnection); - } - - handleConnection = (socket: ws, req: IncomingMessage) => { - this.log('debug', 'new socket connection'); - - const request = Object.create((express as any).request); - Object.assign(request, req); - - const { createContext } = this.config; - - // eslint-disable-next-line no-new - new AuthorizedSocketConnection(socket, { - schema: this.config.schema, - subscriber: this.config.subscriber, - credentialsManager: this.config.createCredentialsManager(request), - hasPermission: this.config.hasPermission, - createContext: - createContext && - ((credentials: TCredentials | null | undefined) => - createContext(request, credentials)), - maxSubscriptionsPerConnection: this.config.maxSubscriptionsPerConnection, - createValidationRules: this.config.createValidationRules, - createLogger: this.config.createLogger || defaultCreateLogger, - }); - }; - - async close() { - // @ts-ignore - await promisify((...args) => this.io.close(...args))(); - } -} diff --git a/src/SocketIOSubscriptionServer.ts b/src/SocketIOSubscriptionServer.ts new file mode 100644 index 00000000..803b8275 --- /dev/null +++ b/src/SocketIOSubscriptionServer.ts @@ -0,0 +1,84 @@ +import { promisify } from 'util'; + +import express from 'express'; +import type io from 'socket.io'; + +import SubscriptionServer, { + SubscriptionServerConfig, +} from './SubscriptionServer'; + +export interface SocketIOSubscriptionServerConfig + extends SubscriptionServerConfig { + socketIoServer?: io.Server; +} + +export default class SocketIOSubscriptionServer< + TContext, + TCredentials, +> extends SubscriptionServer { + io: io.Server; + + constructor({ + socketIoServer, + ...config + }: SocketIOSubscriptionServerConfig) { + super(config); + + this.io = socketIoServer!; + if (!this.io) { + // eslint-disable-next-line global-require, @typescript-eslint/no-var-requires + const IoServer = require('socket.io').Server; + this.io = new IoServer({ + serveClient: false, + path: this.config.path, + transports: ['websocket'], + allowEIO3: true, + }); + } + + this.io.on('connection', (socket: io.Socket) => { + const clientId = socket.id; + + const request = Object.create((express as any).request); + Object.assign(request, socket.request); + + this.log('debug', 'SubscriptionServer: new socket connection', { + clientId, + numClients: this.io.engine?.clientsCount ?? 0, + }); + + this.opened( + { + id: clientId, + protocol: 'socket-io', + on: socket.on.bind(socket), + emit(event: string, data: any) { + socket.emit(event, data); + }, + close() { + socket.disconnect(); + }, + }, + request, + ); + + // add after so the logs happen in order + socket.once('disconnect', (reason) => { + this.log('debug', 'socket disconnected', { + reason, + clientId, + numClients: (this.io.engine.clientsCount ?? 0) - 1, // number hasn't decremented at this point for this client + }); + }); + }); + } + + attach(httpServer: any) { + this.io.attach(httpServer); + } + + async close() { + // @ts-ignore + await promisify((...args) => this.io.close(...args))(); + } +} diff --git a/src/SubscriptionServer.ts b/src/SubscriptionServer.ts index 2c742873..942d6b1f 100644 --- a/src/SubscriptionServer.ts +++ b/src/SubscriptionServer.ts @@ -1,6 +1,4 @@ -import { promisify } from 'util'; - -import express from 'express'; +import { Request } from 'express'; import type { GraphQLSchema } from 'graphql'; import type { Server, Socket } from 'socket.io'; @@ -9,8 +7,9 @@ import type { CreateValidationRules } from './AuthorizedSocketConnection'; import type { CredentialsManager } from './CredentialsManager'; import { CreateLogger, Logger, noopCreateLogger } from './Logger'; import type { Subscriber } from './Subscriber'; +import { WebSocket } from './types'; -export type SubscriptionServerConfig = { +export interface SubscriptionServerConfig { path: string; schema: GraphQLSchema; subscriber: Subscriber; @@ -23,51 +22,25 @@ export type SubscriptionServerConfig = { maxSubscriptionsPerConnection?: number; createValidationRules?: CreateValidationRules; createLogger?: CreateLogger; - socketIoServer?: Server; -}; +} -export default class SubscriptionServer { +export default abstract class SubscriptionServer { config: SubscriptionServerConfig; log: Logger; - io: Server; - constructor(config: SubscriptionServerConfig) { this.config = config; - const createLogger = config.createLogger || noopCreateLogger; - this.log = createLogger('SubscriptionServer'); - - this.io = config.socketIoServer!; - if (!this.io) { - // eslint-disable-next-line global-require, @typescript-eslint/no-var-requires - const IoServer = require('socket.io').Server; - this.io = new IoServer({ - serveClient: false, - path: this.config.path, - transports: ['websocket'], - allowEIO3: true, - }); - } - - this.io.on('connection', this.handleConnection); - } + const createLogger: CreateLogger = config.createLogger || noopCreateLogger; - attach(httpServer: any) { - this.io.attach(httpServer); + this.log = createLogger('SubscriptionServer'); } - handleConnection = (socket: Socket) => { - const clientId = socket.id; - - this.log('debug', 'new socket connection', { - clientId, - numClients: this.io.engine?.clientsCount ?? 0, - }); + public abstract attach(httpServer: any): void; - const request = Object.create((express as any).request); - Object.assign(request, socket.request); + protected opened(socket: WebSocket, request: Request) { + this.log('debug', 'new socket connection'); const { createContext } = this.config; @@ -85,19 +58,7 @@ export default class SubscriptionServer { createValidationRules: this.config.createValidationRules, createLogger: this.config.createLogger || noopCreateLogger, }); - - // add after so the logs happen in order - socket.once('disconnect', (reason) => { - this.log('debug', 'socket disconnected', { - reason, - clientId, - numClients: (this.io.engine.clientsCount ?? 0) - 1, // number hasn't decremented at this point for this client - }); - }); - }; - - async close() { - // @ts-ignore - await promisify((...args) => this.io.close(...args))(); } + + abstract close(): void | Promise; } diff --git a/src/WebSocketSubscriptionServer.ts b/src/WebSocketSubscriptionServer.ts new file mode 100644 index 00000000..033c805f --- /dev/null +++ b/src/WebSocketSubscriptionServer.ts @@ -0,0 +1,170 @@ +/* eslint-disable max-classes-per-file */ +import { EventEmitter } from 'events'; +import type * as http from 'http'; +import url from 'url'; + +import ws from 'ws'; + +import SubscriptionServer, { + SubscriptionServerConfig, +} from './SubscriptionServer'; +import { MessageType } from './types'; + +interface Message { + type: MessageType; + payload: any; + ackId?: number; +} + +class GraphQLSocket extends EventEmitter { + protocol: 'graphql-transport-ws' | 'socket-io'; + + private pingHandle: NodeJS.Timeout | null; + + private pongWait: NodeJS.Timeout | null; + + constructor(private socket: ws, { keepAlive = 12 * 1000 } = {}) { + super(); + this.socket = socket; + + this.protocol = + socket.protocol === 'graphql-transport-ws' + ? socket.protocol + : 'socket-io'; + + socket.on('message', (data) => { + let msg: Message | null = null; + try { + msg = JSON.parse(data.toString()); + } catch (err) { + // this.log('err'); + } + super.emit(msg!.type, msg!.payload, this.ack(msg)); + }); + + socket.on('close', (code: number, reason: string) => { + clearTimeout(this.pongWait!); + clearInterval(this.pingHandle!); + + super.emit('close', code, reason); + }); + + // keep alive through ping-pong messages + this.pongWait = null; + + this.pingHandle = + keepAlive > 0 && Number.isFinite(keepAlive) + ? setInterval(() => { + // ping pong on open sockets only + if (this.socket.readyState === this.socket.OPEN) { + // terminate the connection after pong wait has passed because the client is idle + this.pongWait = setTimeout(() => { + this.socket.terminate(); + }, keepAlive); + + // listen for client's pong and stop socket termination + this.socket.once('pong', () => { + clearTimeout(this.pongWait!); + this.pongWait = null; + }); + + this.socket.ping(); + } + }, keepAlive) + : null; + } + + private ack(msg: { ackId?: number } | null) { + if (!msg || msg.ackId == null) return undefined; + const { ackId } = msg; + return (data: any) => { + this.socket.send( + JSON.stringify({ type: `ack:${ackId}`, payload: data }), + ); + }; + } + + emit(msg: MessageType, payload?: any) { + this.socket.send( + JSON.stringify({ + type: msg, + payload, + }), + ); + return true; + } + + close(code: number, reason: string) { + this.socket.close(code, reason); + } +} + +export default class WebSocketSubscriptionServer< + TContext, + TCredentials, +> extends SubscriptionServer { + private ws: ws.Server; + + constructor(config: SubscriptionServerConfig) { + super(config); + + this.ws = new ws.Server({ noServer: true }); + + this.ws.on('error', () => { + // catch the first thrown error and re-throw it once all clients have been notified + let firstErr: Error | null = null; + + // report server errors by erroring out all clients with the same error + for (const client of this.ws.clients) { + try { + client.close(1011, 'Internal Error'); + } catch (err: any) { + firstErr = firstErr ?? err; + } + } + + if (firstErr) throw firstErr; + }); + + this.ws.on('connection', (socket, request) => { + const gqlSocket = new GraphQLSocket(socket); + + this.opened(gqlSocket, request as any); + + // socket io clients do this behind the scenes + // so we keep it out of the server logic + if (gqlSocket.protocol === 'socket-io') { + // inform the client they are good to go + gqlSocket.emit('connect'); + } + }); + } + + attach(httpServer: http.Server) { + httpServer.on( + 'upgrade', + (req: http.IncomingMessage, socket: any, head) => { + const { pathname } = url.parse(req.url!); + if (pathname !== this.config.path) { + socket.destroy(); + return; + } + + this.ws.handleUpgrade(req, socket, head, (client) => { + this.ws.emit('connection', client, req); + }); + }, + ); + } + + async close() { + for (const client of this.ws.clients) { + client.close(1001, 'Going away'); + } + this.ws.removeAllListeners(); + + await new Promise((resolve, reject) => { + this.ws.close((err) => (err ? reject(err) : resolve())); + }); + } +} diff --git a/src/index.ts b/src/index.ts index d7257e2f..e992e3e3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,6 +2,8 @@ export { default as EventSubscriber } from './EventSubscriber'; export { default as JwtCredentialsManager } from './JwtCredentialsManager'; export { default as RedisSubscriber } from './RedisSubscriber'; export { default as SubscriptionServer } from './SubscriptionServer'; +export { default as SocketIOSubscriptionServer } from './SocketIOSubscriptionServer'; +export { default as WebSocketSubscriptionServer } from './WebSocketSubscriptionServer'; export { AsyncQueue } from './AsyncUtils'; export type { CreateValidationRules } from './AuthorizedSocketConnection'; @@ -11,3 +13,4 @@ export type { Subscriber } from './Subscriber'; export type { SubscriptionServerConfig } from './SubscriptionServer'; export type { JwtCredentials } from './JwtCredentialsManager'; +export type { SocketIOSubscriptionServerConfig } from './SocketIOSubscriptionServer'; diff --git a/src/types.ts b/src/types.ts new file mode 100644 index 00000000..7ae7d2cc --- /dev/null +++ b/src/types.ts @@ -0,0 +1,28 @@ +export interface WebSocket { + protocol: string; + id?: string; + + close(code: number, reason: string): Promise | void; + + on( + message: string, + listener: (data: any, ack?: (data?: any) => Promise | void) => void, + ): void; + + emit(message: string, data: any): Promise | void | boolean; + // onMessage(cb: (data: string) => Promise): void; +} + +export type MessageType = + | 'authenticate' + | 'subscribe' + | 'unsubscribe' + | 'connect' + | 'disconnect' + | `ack:${number}`; + +export interface BaseMessage { + type: MessageType; + payload?: D; + ackId?: number; +} diff --git a/test/data/schema.graphql b/test/data/schema.graphql new file mode 100644 index 00000000..2e110296 --- /dev/null +++ b/test/data/schema.graphql @@ -0,0 +1,214 @@ +schema { + query: Root + mutation: Mutation + subscription: Subscription +} + +type Root { + viewer: User + + """ + Fetches an object given its ID + """ + node( + """ + The ID of an object + """ + id: ID! + ): Node +} + +type User implements Node { + """ + The ID of an object + """ + id: ID! + todos( + status: String = "any" + after: String + first: Int + before: String + last: Int + ): TodoConnection + numTodos: Int + numCompletedTodos: Int +} + +""" +An object with an ID +""" +interface Node { + """ + The id of the object. + """ + id: ID! +} + +""" +A connection to a list of items. +""" +type TodoConnection { + """ + Information to aid in pagination. + """ + pageInfo: PageInfo! + + """ + A list of edges. + """ + edges: [TodoEdge] +} + +""" +Information about pagination in a connection. +""" +type PageInfo { + """ + When paginating forwards, are there more items? + """ + hasNextPage: Boolean! + + """ + When paginating backwards, are there more items? + """ + hasPreviousPage: Boolean! + + """ + When paginating backwards, the cursor to continue. + """ + startCursor: String + + """ + When paginating forwards, the cursor to continue. + """ + endCursor: String +} + +""" +An edge in a connection. +""" +type TodoEdge { + """ + The item at the end of the edge + """ + node: Todo + + """ + A cursor for use in pagination + """ + cursor: String! +} + +type Todo implements Node { + """ + The ID of an object + """ + id: ID! + complete: Boolean + text: String +} + +type Mutation { + addTodo(input: AddTodoInput!): AddTodoPayload + changeTodoStatus(input: ChangeTodoStatusInput!): ChangeTodoStatusPayload + markAllTodos(input: MarkAllTodosInput!): MarkAllTodosPayload + removeCompletedTodos( + input: RemoveCompletedTodosInput! + ): RemoveCompletedTodosPayload + removeTodo(input: RemoveTodoInput!): RemoveTodoPayload + renameTodo(input: RenameTodoInput!): RenameTodoPayload +} + +type AddTodoPayload { + viewer: User + todoEdge: TodoEdge + clientMutationId: String +} + +input AddTodoInput { + text: String! + clientMutationId: String +} + +type ChangeTodoStatusPayload { + viewer: User + todo: Todo + clientMutationId: String +} + +input ChangeTodoStatusInput { + id: ID! + complete: Boolean! + clientMutationId: String +} + +type MarkAllTodosPayload { + viewer: User + changedTodos: [Todo] + clientMutationId: String +} + +input MarkAllTodosInput { + complete: Boolean! + clientMutationId: String +} + +type RemoveCompletedTodosPayload { + viewer: User + deletedIds: [String] + clientMutationId: String +} + +input RemoveCompletedTodosInput { + clientMutationId: String +} + +type RemoveTodoPayload { + viewer: User + deletedId: ID + clientMutationId: String +} + +input RemoveTodoInput { + id: ID! + clientMutationId: String +} + +type RenameTodoPayload { + todo: Todo + clientMutationId: String +} + +input RenameTodoInput { + id: ID! + text: String! + clientMutationId: String +} + +type Subscription { + todoUpdated( + input: TodoUpdatedSubscriptionInput! + ): TodoUpdatedSubscriptionPayload + todoCreated( + input: TodoCreatedSubscriptionInput! + ): TodoCreatedSubscriptionPayload +} + +type TodoUpdatedSubscriptionPayload { + todo: Todo + clientSubscriptionId: String +} + +input TodoUpdatedSubscriptionInput { + id: ID! + clientSubscriptionId: String +} + +type TodoCreatedSubscriptionPayload { + todo: Todo + clientSubscriptionId: String +} + +input TodoCreatedSubscriptionInput { + clientSubscriptionId: String +} diff --git a/test/data/schema.js b/test/data/schema.js index 4fc655f5..7554d207 100644 --- a/test/data/schema.js +++ b/test/data/schema.js @@ -255,8 +255,6 @@ const GraphQLMutation = new GraphQLObjectType({ }, }); -// const delay = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); - const GraphQLTodoUpdatedSubscription = subscriptionWithClientId({ name: 'TodoUpdatedSubscription', diff --git a/test/helpers.ts b/test/helpers.ts index b271101d..8b8fa051 100644 --- a/test/helpers.ts +++ b/test/helpers.ts @@ -1,11 +1,26 @@ +// import { RedisClient } from 'redis'; +// import type { Socket } from 'socket.io-client'; +// import socketio from 'socket.io-client'; +import { EventEmitter } from 'events'; import http from 'http'; import socketio, { Socket } from 'socket.io-client'; +import WebSocket from 'ws'; import { CredentialsManager } from '../src/CredentialsManager'; import RedisSubscriber from '../src/RedisSubscriber'; import type SubscriptionServer from '../src/SubscriptionServer'; +function uuid() { + return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, (c) => { + // eslint-disable-next-line no-bitwise + const r = (Math.random() * 16) | 0; + // eslint-disable-next-line no-bitwise + const v = c === 'x' ? r : (r & 0x3) | 0x8; + return v.toString(16); + }); +} + export const delay = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); export function graphql(strings: any): string { @@ -55,23 +70,74 @@ export async function startServer( httpServer, subscriber, async close() { + httpServer.close(); await server.close(); }, }; } +let i = 0; +class WebSocketShim extends EventEmitter { + socket: WebSocket; + + id: string; + + constructor(path: string, { protocol = 'socket-io' } = {}) { + super(); + const socket = new WebSocket(path, { protocol }); + + this.socket = socket; + this.id = uuid(); + + socket.on('message', (data) => { + const { type, payload } = JSON.parse(data.toString()); + super.emit(type, payload); + }); + } + + private ack(cb?: (data?: any) => void) { + if (!cb) return undefined; + + const ackId = i++; + this.once(`ack:${ackId}`, (data) => { + cb(data); + }); + return ackId; + } + + emit(type: string, data: any, cb?: () => void) { + this.socket.send( + JSON.stringify({ + type, + payload: data, + ackId: this.ack(cb), + }), + ); + return true; + } + + disconnect() { + this.socket.removeAllListeners(); + this.socket.close(); + } +} + export class TestClient { - socket: Socket; + socket: Socket | WebSocketShim; constructor( public subscriber: RedisSubscriber, public query: string, public variables: Record | null = null, + { engine = 'socket.io' }: { engine?: 'socket.io' | 'ws' } = {}, ) { - this.socket = socketio('http://localhost:5000', { - path: '/graphql', - transports: ['websocket'], - }); + this.socket = + engine === 'socket.io' + ? socketio('http://localhost:5000', { + path: '/graphql', + transports: ['websocket'], + }) + : new WebSocketShim('ws://localhost:5000/graphql'); } init() { diff --git a/test/socket-io.test.ts b/test/socket-io.test.ts index 8e5c9678..990972ea 100644 --- a/test/socket-io.test.ts +++ b/test/socket-io.test.ts @@ -2,7 +2,7 @@ import socketio from 'socket.io-client'; import { CreateLogger } from '../src'; -import SubscriptionServer from '../src/SubscriptionServer'; +import SocketIOSubscriptionServer from '../src/SocketIOSubscriptionServer'; import { maskNonDeterministicValues } from '../src/Testing'; import schema from './data/schema'; import { @@ -13,10 +13,8 @@ import { startServer, } from './helpers'; -const sleep = () => new Promise((resolve) => process.nextTick(resolve)); - function createServer(subscriber, options = {}) { - return new SubscriptionServer({ + return new SocketIOSubscriptionServer({ ...options, path: '/graphql', schema, @@ -25,7 +23,6 @@ function createServer(subscriber, options = {}) { return creds !== null; }, createCredentialsManager: () => new TestCredentialsManager(), - // createLogger: () => console.debug, }); } @@ -177,7 +174,7 @@ describe('socket-io client', () => { promises.push(socket.subscribe(`s-${id}`)); promises.push(socket.unsubscribe(`s-${id}`)); - await sleep(); + await delay(0); } await Promise.all(promises); diff --git a/test/tsconfig.json b/test/tsconfig.json index 7c57ad61..df821f1a 100644 --- a/test/tsconfig.json +++ b/test/tsconfig.json @@ -1,9 +1,14 @@ { "extends": "../tsconfig.json", "compilerOptions": { + "rootDir": "../", "noImplicitAny": false, "types": ["jest", "node"], "rootDir": ".." }, +<<<<<<< HEAD "include": ["**/*.ts", "../src"] +======= + "include": ["./", "../src"] +>>>>>>> ed71098... feat!: add websocket server option } diff --git a/test/websocket.test.ts b/test/websocket.test.ts new file mode 100644 index 00000000..081783e3 --- /dev/null +++ b/test/websocket.test.ts @@ -0,0 +1,124 @@ +import WebSocketSubscriptionServer from '../src/WebSocketSubscriptionServer'; +import schema from './data/schema'; +import { + TestClient, + TestCredentialsManager, + graphql, + startServer, +} from './helpers'; + +function createServer(subscriber) { + return new WebSocketSubscriptionServer({ + path: '/graphql', + schema, + subscriber, + hasPermission: (_, creds) => { + return creds !== null; + }, + createCredentialsManager: () => new TestCredentialsManager(), + // createLogger: () => console.debug, + }); +} + +type PromiseType

= P extends Promise ? R : never; + +describe('socket-io client', () => { + let server: PromiseType>; + let client: TestClient | null = null; + + async function createClient(query: string, variables: any) { + client = new TestClient(server.subscriber, query, variables, { + engine: 'ws', + }); + + await client.init(); + return client; + } + + beforeAll(async () => { + server = await startServer(createServer); + }); + + afterEach(() => { + client?.close(); + client = null; + }); + + afterAll(async () => { + client?.close(); + await server.close(); + }); + + it('should subscribe', async () => { + const socket = await createClient( + graphql` + subscription TestTodoUpdatedSubscription( + $input: TodoUpdatedSubscriptionInput! + ) { + todoUpdated(input: $input) { + todo { + text + } + } + } + `, + { + input: { + id: '1', + }, + }, + ); + + await socket.authenticate(); + + expect( + await socket.getSubscriptionResult({ + topic: `todo:1:updated`, + data: { + id: '1', + text: 'Make work', + }, + }), + ).toMatchInlineSnapshot(` + Object { + "event": "subscription update", + "payload": Object { + "data": Object { + "todoUpdated": Object { + "todo": Object { + "text": "Buy a unicorn", + }, + }, + }, + "id": "foo", + }, + } + `); + }); + + it('should unsubscribe', async () => { + const socket = await createClient( + graphql` + subscription TestTodoUpdatedSubscription( + $input: TodoUpdatedSubscriptionInput! + ) { + todoUpdated(input: $input) { + todo { + text + } + } + } + `, + { + input: { + id: '1', + }, + }, + ); + + await socket.authenticate(); + await socket.subscribe(); + + await socket.unsubscribe(); + }); +}); diff --git a/update-schema.js b/update-schema.js new file mode 100644 index 00000000..19f2659c --- /dev/null +++ b/update-schema.js @@ -0,0 +1,11 @@ +import fs from 'fs'; +import path from 'path'; + +import { printSchema } from 'graphql/utilities'; + +import schema from './test/data/schema'; + +fs.writeFileSync( + path.join(__dirname, './test/data/schema.graphql'), + printSchema(schema), +); diff --git a/yarn.lock b/yarn.lock index 704c2a35..9ff6965a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -440,7 +440,12 @@ chalk "^2.0.0" js-tokens "^4.0.0" -"@babel/parser@^7.1.0", "@babel/parser@^7.14.5", "@babel/parser@^7.15.0", "@babel/parser@^7.7.0", "@babel/parser@^7.7.2": +"@babel/parser@^7.1.0", "@babel/parser@^7.7.0": + version "7.12.10" + resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.12.10.tgz#824600d59e96aea26a5a2af5a9d812af05c3ae81" + integrity sha512-PJdRPwyoOqFAWfLytxrWwGrAxghCgh/yTNCYciOz8QgjflA7aZhECPZAa2VUedKg2+QMWkI0L9lynh2SNmNEgA== + +"@babel/parser@^7.14.5", "@babel/parser@^7.15.0", "@babel/parser@^7.7.2": version "7.15.2" resolved "https://registry.yarnpkg.com/@babel/parser/-/parser-7.15.2.tgz#08d4ffcf90d211bf77e7cc7154c6f02d468d2b1d" integrity sha512-bMJXql1Ss8lFnvr11TZDH4ArtwlAS5NG9qBmdiFW2UHHm6MVoR+GDc5XE2b9K938cyjc9O6/+vjjcffLDtfuDg== @@ -1155,7 +1160,15 @@ pirates "^4.0.0" source-map-support "^0.5.16" -"@babel/runtime-corejs3@^7.10.2", "@babel/runtime-corejs3@^7.9.2": +"@babel/runtime-corejs3@^7.10.2": + version "7.12.5" + resolved "https://registry.yarnpkg.com/@babel/runtime-corejs3/-/runtime-corejs3-7.12.5.tgz#ffee91da0eb4c6dae080774e94ba606368e414f4" + integrity sha512-roGr54CsTmNPPzZoCP1AmDXuBoNao7tnSA83TXTwt+UK5QVyh1DIJnrgYRPWKCF2flqZQXwa7Yr8v7VmLzF0YQ== + dependencies: + core-js-pure "^3.0.0" + regenerator-runtime "^0.13.4" + +"@babel/runtime-corejs3@^7.9.2": version "7.14.0" resolved "https://registry.yarnpkg.com/@babel/runtime-corejs3/-/runtime-corejs3-7.14.0.tgz#6bf5fbc0b961f8e3202888cb2cd0fb7a0a9a3f66" integrity sha512-0R0HTZWHLk6G8jIk0FtoX+AatCtKnswS98VhXwGImFc759PJRp4Tru0PQYZofyijTFUr+gT8Mu7sgXVJLQ0ceg== @@ -1906,11 +1919,16 @@ resolved "https://registry.yarnpkg.com/@types/minimist/-/minimist-1.2.1.tgz#283f669ff76d7b8260df8ab7a4262cc83d988256" integrity sha512-fZQQafSREFyuZcdWFAExYjBiCL7AUCdgsk80iO0q4yihYYdcIiH28CcuPTGFgLOCC8RlW49GSQxdHwZP+I7CNg== -"@types/node@*", "@types/node@>= 8", "@types/node@>=10.0.0": +"@types/node@*", "@types/node@>= 8": version "15.6.1" resolved "https://registry.yarnpkg.com/@types/node/-/node-15.6.1.tgz#32d43390d5c62c5b6ec486a9bc9c59544de39a08" integrity sha512-7EIraBEyRHEe7CH+Fm1XvgqU6uwZN8Q7jppJGcqjROMT29qhAuuOxYB1uEY5UMYQKEmA5D+5tBnhdaPXSsLONA== +"@types/node@>=10.0.0": + version "15.3.0" + resolved "https://registry.yarnpkg.com/@types/node/-/node-15.3.0.tgz#d6fed7d6bc6854306da3dea1af9f874b00783e26" + integrity sha512-8/bnjSZD86ZfpBsDlCIkNXIvm+h6wi9g7IqL+kmFkQ+Wvu3JrasgLElfiPgoo8V8vVfnEi0QVS12gbl94h9YsQ== + "@types/node@^12.7.1": version "12.19.4" resolved "https://registry.yarnpkg.com/@types/node/-/node-12.19.4.tgz#cdfbb62e26c7435ed9aab9c941393cc3598e9b46" @@ -9953,7 +9971,12 @@ prettier-linter-helpers@^1.0.0: dependencies: fast-diff "^1.1.2" -prettier@^2.2.1, prettier@^2.3.2: +prettier@^2.2.1: + version "2.2.1" + resolved "https://registry.yarnpkg.com/prettier/-/prettier-2.2.1.tgz#795a1a78dd52f073da0cd42b21f9c91381923ff5" + integrity sha512-PqyhM2yCjg/oKkFPtTGUojv7gnZAoG80ttl45O6x2Ug/rMJw4wcc9k6aaf2hibP7BGVCCM33gZoGjyvt9mm16Q== + +prettier@^2.3.2: version "2.3.2" resolved "https://registry.yarnpkg.com/prettier/-/prettier-2.3.2.tgz#ef280a05ec253712e486233db5c6f23441e7342d" integrity sha512-lnJzDfJ66zkMy58OL5/NY5zp70S7Nz6KqcKkXYzn2tMVrNxvbqaBpg7H3qHaLxCJ5lNMsGuM8+ohS7cZrthdLQ== @@ -12763,11 +12786,6 @@ ws@^6.2.1: dependencies: async-limiter "~1.0.0" -ws@^7.4.4: - version "7.4.4" - resolved "https://registry.yarnpkg.com/ws/-/ws-7.4.4.tgz#383bc9742cb202292c9077ceab6f6047b17f2d59" - integrity sha512-Qm8k8ojNQIMx7S+Zp8u/uHOx7Qazv3Yv4q68MiWWWOJhiwG5W3x7iqmRtJo8xxrciZUY4vRxUTJCKuRnF28ZZw== - ws@^7.4.5: version "7.5.2" resolved "https://registry.yarnpkg.com/ws/-/ws-7.5.2.tgz#09cc8fea3bec1bc5ed44ef51b42f945be36900f6" From eb49b91bb98910489d1b4d407d4c2124ae521792 Mon Sep 17 00:00:00 2001 From: Jason Quense Date: Tue, 5 Oct 2021 14:20:57 -0400 Subject: [PATCH 3/3] update --- src/SocketIOSubscriptionServer.ts | 6 +- src/SubscriptionServer.ts | 5 +- src/WebSocketSubscriptionServer.ts | 106 +++++++++++++++++++---------- src/types.ts | 6 +- test/helpers.ts | 5 +- test/socket-io.test.ts | 1 + test/websocket.test.ts | 82 +++++++++++++++++++++- 7 files changed, 161 insertions(+), 50 deletions(-) diff --git a/src/SocketIOSubscriptionServer.ts b/src/SocketIOSubscriptionServer.ts index 803b8275..16c8da67 100644 --- a/src/SocketIOSubscriptionServer.ts +++ b/src/SocketIOSubscriptionServer.ts @@ -42,15 +42,15 @@ export default class SocketIOSubscriptionServer< const request = Object.create((express as any).request); Object.assign(request, socket.request); - this.log('debug', 'SubscriptionServer: new socket connection', { + this.log('debug', 'new socket connection', { clientId, numClients: this.io.engine?.clientsCount ?? 0, }); - this.opened( + this.initConnection( { id: clientId, - protocol: 'socket-io', + protocol: '4c-subscription-server', on: socket.on.bind(socket), emit(event: string, data: any) { socket.emit(event, data); diff --git a/src/SubscriptionServer.ts b/src/SubscriptionServer.ts index 942d6b1f..c71ca817 100644 --- a/src/SubscriptionServer.ts +++ b/src/SubscriptionServer.ts @@ -1,6 +1,5 @@ import { Request } from 'express'; import type { GraphQLSchema } from 'graphql'; -import type { Server, Socket } from 'socket.io'; import AuthorizedSocketConnection from './AuthorizedSocketConnection'; import type { CreateValidationRules } from './AuthorizedSocketConnection'; @@ -39,9 +38,7 @@ export default abstract class SubscriptionServer { public abstract attach(httpServer: any): void; - protected opened(socket: WebSocket, request: Request) { - this.log('debug', 'new socket connection'); - + protected initConnection(socket: WebSocket, request: Request) { const { createContext } = this.config; // eslint-disable-next-line no-new diff --git a/src/WebSocketSubscriptionServer.ts b/src/WebSocketSubscriptionServer.ts index 033c805f..095243b6 100644 --- a/src/WebSocketSubscriptionServer.ts +++ b/src/WebSocketSubscriptionServer.ts @@ -8,8 +8,12 @@ import ws from 'ws'; import SubscriptionServer, { SubscriptionServerConfig, } from './SubscriptionServer'; -import { MessageType } from './types'; +import { MessageType, SupportedProtocols } from './types'; +export type DisconnectReason = + | 'server disconnect' + | 'client disconnect' + | 'ping timeout'; interface Message { type: MessageType; payload: any; @@ -17,20 +21,24 @@ interface Message { } class GraphQLSocket extends EventEmitter { - protocol: 'graphql-transport-ws' | 'socket-io'; + protocol: SupportedProtocols; - private pingHandle: NodeJS.Timeout | null; + isAlive = true; - private pongWait: NodeJS.Timeout | null; - - constructor(private socket: ws, { keepAlive = 12 * 1000 } = {}) { + constructor(private socket: ws) { super(); + this.socket = socket; + this.isAlive = true; this.protocol = socket.protocol === 'graphql-transport-ws' ? socket.protocol - : 'socket-io'; + : '4c-subscription-server'; + + socket.on('pong', () => { + this.isAlive = true; + }); socket.on('message', (data) => { let msg: Message | null = null; @@ -43,35 +51,16 @@ class GraphQLSocket extends EventEmitter { }); socket.on('close', (code: number, reason: string) => { - clearTimeout(this.pongWait!); - clearInterval(this.pingHandle!); - + this.isAlive = false; + super.emit('disconnect', 'client disconnect'); super.emit('close', code, reason); }); + } - // keep alive through ping-pong messages - this.pongWait = null; - - this.pingHandle = - keepAlive > 0 && Number.isFinite(keepAlive) - ? setInterval(() => { - // ping pong on open sockets only - if (this.socket.readyState === this.socket.OPEN) { - // terminate the connection after pong wait has passed because the client is idle - this.pongWait = setTimeout(() => { - this.socket.terminate(); - }, keepAlive); - - // listen for client's pong and stop socket termination - this.socket.once('pong', () => { - clearTimeout(this.pongWait!); - this.pongWait = null; - }); - - this.socket.ping(); - } - }, keepAlive) - : null; + disconnect(reason?: DisconnectReason) { + this.emit('disconnect', reason); + super.emit('disconnect', reason); + this.socket.terminate(); } private ack(msg: { ackId?: number } | null) { @@ -97,18 +86,39 @@ class GraphQLSocket extends EventEmitter { close(code: number, reason: string) { this.socket.close(code, reason); } + + ping() { + if (this.socket.readyState === this.socket.OPEN) { + this.isAlive = false; + this.socket.ping(); + } + } } +export interface WebSocketSubscriptionServerConfig + extends SubscriptionServerConfig { + keepAlive?: number; +} export default class WebSocketSubscriptionServer< TContext, TCredentials, > extends SubscriptionServer { private ws: ws.Server; - constructor(config: SubscriptionServerConfig) { + private gqlClients = new WeakMap(); + + readonly keepAlive: number; + + private pingHandle: NodeJS.Timeout | null = null; + + constructor({ + keepAlive = 15_000, + ...config + }: WebSocketSubscriptionServerConfig) { super(config); this.ws = new ws.Server({ noServer: true }); + this.keepAlive = keepAlive; this.ws.on('error', () => { // catch the first thrown error and re-throw it once all clients have been notified @@ -126,14 +136,16 @@ export default class WebSocketSubscriptionServer< if (firstErr) throw firstErr; }); + this.scheduleLivelinessCheck(); this.ws.on('connection', (socket, request) => { const gqlSocket = new GraphQLSocket(socket); + this.gqlClients.set(socket, gqlSocket); - this.opened(gqlSocket, request as any); + this.initConnection(gqlSocket, request as any); // socket io clients do this behind the scenes // so we keep it out of the server logic - if (gqlSocket.protocol === 'socket-io') { + if (gqlSocket.protocol === '4c-subscription-server') { // inform the client they are good to go gqlSocket.emit('connect'); } @@ -158,6 +170,8 @@ export default class WebSocketSubscriptionServer< } async close() { + clearTimeout(this.pingHandle!); + for (const client of this.ws.clients) { client.close(1001, 'Going away'); } @@ -167,4 +181,24 @@ export default class WebSocketSubscriptionServer< this.ws.close((err) => (err ? reject(err) : resolve())); }); } + + private scheduleLivelinessCheck() { + clearTimeout(this.pingHandle!); + this.pingHandle = setTimeout(() => { + for (const socket of this.ws.clients) { + const gql = this.gqlClients.get(socket); + if (!gql) { + continue; + } + if (!gql.isAlive) { + gql.disconnect('ping timeout'); + return; + } + + gql.ping(); + } + + this.scheduleLivelinessCheck(); + }, this.keepAlive); + } } diff --git a/src/types.ts b/src/types.ts index 7ae7d2cc..513e93a2 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,5 +1,9 @@ +export type SupportedProtocols = + | 'graphql-transport-ws' + | '4c-subscription-server'; + export interface WebSocket { - protocol: string; + protocol: SupportedProtocols; id?: string; close(code: number, reason: string): Promise | void; diff --git a/test/helpers.ts b/test/helpers.ts index 8b8fa051..a0ae0780 100644 --- a/test/helpers.ts +++ b/test/helpers.ts @@ -1,6 +1,3 @@ -// import { RedisClient } from 'redis'; -// import type { Socket } from 'socket.io-client'; -// import socketio from 'socket.io-client'; import { EventEmitter } from 'events'; import http from 'http'; @@ -70,8 +67,8 @@ export async function startServer( httpServer, subscriber, async close() { - httpServer.close(); await server.close(); + httpServer.close(); }, }; } diff --git a/test/socket-io.test.ts b/test/socket-io.test.ts index 990972ea..212ac787 100644 --- a/test/socket-io.test.ts +++ b/test/socket-io.test.ts @@ -1,4 +1,5 @@ /* eslint-disable no-underscore-dangle */ + import socketio from 'socket.io-client'; import { CreateLogger } from '../src'; diff --git a/test/websocket.test.ts b/test/websocket.test.ts index 081783e3..03c9f1af 100644 --- a/test/websocket.test.ts +++ b/test/websocket.test.ts @@ -1,14 +1,18 @@ +/* eslint-disable no-underscore-dangle */ + import WebSocketSubscriptionServer from '../src/WebSocketSubscriptionServer'; import schema from './data/schema'; import { TestClient, TestCredentialsManager, + delay, graphql, startServer, } from './helpers'; -function createServer(subscriber) { +function createServer(subscriber, options = {}) { return new WebSocketSubscriptionServer({ + ...options, path: '/graphql', schema, subscriber, @@ -22,7 +26,7 @@ function createServer(subscriber) { type PromiseType

= P extends Promise ? R : never; -describe('socket-io client', () => { +describe('websocket server', () => { let server: PromiseType>; let client: TestClient | null = null; @@ -121,4 +125,78 @@ describe('socket-io client', () => { await socket.unsubscribe(); }); + + it('should not race unsubscribe call', async () => { + const socket = await createClient( + graphql` + subscription TestTodoUpdatedSubscription( + $input: TodoUpdatedSubscriptionInput! + ) { + todoUpdated(input: $input) { + todo { + text + } + } + } + `, + { + input: { + id: '1', + }, + }, + ); + + await socket.authenticate(); + + const range = Array.from({ length: 2 }, (_, i) => i); + const promises = [] as any[]; + for (const id of range) { + promises.push(socket.subscribe(`s-${id}`)); + promises.push(socket.unsubscribe(`s-${id}`)); + + await delay(0); + } + + await Promise.all(promises); + + expect(server.subscriber._queues.size).toEqual(0); + expect(server.subscriber._channels.size).toEqual(0); + }); + + it('should clean up on client close', async () => { + const socket = await createClient( + graphql` + subscription TestTodoUpdatedSubscription( + $input: TodoUpdatedSubscriptionInput! + ) { + todoUpdated(input: $input) { + todo { + text + } + } + } + `, + { + input: { + id: '1', + }, + }, + ); + + expect(server.subscriber._queues.size).toEqual(0); + expect(server.subscriber._channels.size).toEqual(0); + + await socket.authenticate(); + await socket.subscribe(); + + expect(server.subscriber._queues.size).toEqual(1); + expect(server.subscriber._channels.size).toEqual(1); + + socket.close(); + + await delay(50); + + expect(server.subscriber._queues.size).toEqual(0); + expect(server.subscriber._channels.size).toEqual(0); + }); });