diff --git a/subscription.go b/subscription.go index 1d692bc..2fa49d7 100644 --- a/subscription.go +++ b/subscription.go @@ -107,16 +107,17 @@ type SubscriptionProtocol interface { // SubscriptionContext represents a shared context for protocol implementations with the websocket connection inside type SubscriptionContext struct { context.Context - websocketConn WebsocketConn - OnConnected func() - onDisconnected func() - cancel context.CancelFunc - subscriptions map[string]Subscription - disabledLogTypes []OperationMessageType - log func(args ...interface{}) - acknowledged int32 - exitStatusCodes []int - mutex sync.Mutex + websocketConn WebsocketConn + OnConnected func() + onDisconnected func() + onConnectionAlive func() + cancel context.CancelFunc + subscriptions map[string]Subscription + disabledLogTypes []OperationMessageType + log func(args ...interface{}) + acknowledged int32 + exitStatusCodes []int + mutex sync.Mutex } // Log prints condition logging with message type filters @@ -419,6 +420,12 @@ func (sc *SubscriptionClient) OnDisconnected(fn func()) *SubscriptionClient { return sc } +// OnConnectionAlive event is triggered when the websocket receive a connection alive message (differs per protocol) +func (sc *SubscriptionClient) OnConnectionAlive(fn func()) *SubscriptionClient { + sc.context.onConnectionAlive = fn + return sc +} + // get internal client status func (sc *SubscriptionClient) getClientStatus() int32 { return atomic.LoadInt32(&sc.clientStatus) diff --git a/subscription_graphql_ws.go b/subscription_graphql_ws.go index e7bf3a7..33e798f 100644 --- a/subscription_graphql_ws.go +++ b/subscription_graphql_ws.go @@ -132,6 +132,9 @@ func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription Subscript _ = gws.Unsubscribe(ctx, message.ID) case GQLPing: ctx.Log(message, "server", GQLPing) + if ctx.onConnectionAlive != nil { + ctx.onConnectionAlive() + } // send pong response message back to the server msg := OperationMessage{ Type: GQLPong, diff --git a/subscriptions_transport_ws.go b/subscriptions_transport_ws.go index a681cf2..d76b601 100644 --- a/subscriptions_transport_ws.go +++ b/subscriptions_transport_ws.go @@ -164,6 +164,9 @@ func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscri _ = stw.Unsubscribe(ctx, message.ID) case GQLConnectionKeepAlive: ctx.Log(message, "server", GQLConnectionKeepAlive) + if ctx.onConnectionAlive != nil { + ctx.onConnectionAlive() + } case GQLConnectionAck: // Expected response to the ConnectionInit message from the client acknowledging a successful connection with the server. // The client is now ready to request subscription operations.