Skip to content

Commit

Permalink
Handle Sendability of RemovableChannelHandler (#2953)
Browse files Browse the repository at this point in the history
Motivation:

RemovableChannelHandlers have a large API surface in NIOCore. That API
surface is a bit awkward with regard to strict concurrency, and needs
some cleanup.

Modifications:

This patch adds some new API that is necessary to safely work with
RemovableChannelHandlers, deprecates some API that cannot plausibly be
used, and cleans up some other parts of the API.

Result:

Easier to work with RemovableChannelHandlers
  • Loading branch information
Lukasa authored Oct 29, 2024
1 parent 02906a6 commit 411c2c5
Show file tree
Hide file tree
Showing 13 changed files with 167 additions and 59 deletions.
5 changes: 5 additions & 0 deletions Sources/NIOCore/AsyncAwaitSupport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ extension ChannelPipeline {
}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@available(
*,
deprecated,
message: "Use .syncOperations.removeHandler(context:) instead, this method is not Sendable-safe."
)
public func removeHandler(context: ChannelHandlerContext) async throws {
try await self.removeHandler(context: context).get()
}
Expand Down
161 changes: 129 additions & 32 deletions Sources/NIOCore/ChannelPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,9 @@ public final class ChannelPipeline: ChannelInvoker {
/// - handler: the `ChannelHandler` to add
/// - position: The position in the `ChannelPipeline` to add `handler`. Defaults to `.last`.
/// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was added.
@preconcurrency
public func addHandler(
_ handler: ChannelHandler,
_ handler: ChannelHandler & Sendable,
name: String? = nil,
position: ChannelPipeline.Position = .last
) -> EventLoopFuture<Void> {
Expand Down Expand Up @@ -349,7 +350,8 @@ public final class ChannelPipeline: ChannelInvoker {
/// - parameters:
/// - handler: the `ChannelHandler` to remove.
/// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was removed.
public func removeHandler(_ handler: RemovableChannelHandler) -> EventLoopFuture<Void> {
@preconcurrency
public func removeHandler(_ handler: RemovableChannelHandler & Sendable) -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.removeHandler(handler, promise: promise)
return promise.futureResult
Expand All @@ -371,6 +373,11 @@ public final class ChannelPipeline: ChannelInvoker {
/// - parameters:
/// - context: the `ChannelHandlerContext` that belongs to `ChannelHandler` that should be removed.
/// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was removed.
@available(
*,
deprecated,
message: "Use .syncOperations.removeHandler(context:) instead, this method is not Sendable-safe."
)
public func removeHandler(context: ChannelHandlerContext) -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.removeHandler(context: context, promise: promise)
Expand All @@ -382,14 +389,11 @@ public final class ChannelPipeline: ChannelInvoker {
/// - parameters:
/// - handler: the `ChannelHandler` to remove.
/// - promise: An `EventLoopPromise` that will complete when the `ChannelHandler` is removed.
public func removeHandler(_ handler: RemovableChannelHandler, promise: EventLoopPromise<Void>?) {
@preconcurrency
public func removeHandler(_ handler: RemovableChannelHandler & Sendable, promise: EventLoopPromise<Void>?) {
@Sendable
func removeHandler0() {
switch self.contextSync(handler: handler) {
case .success(let context):
self.removeHandler(context: context, promise: promise)
case .failure(let error):
promise?.fail(error)
}
self.syncOperations.removeHandler(handler, promise: promise)
}

if self.eventLoop.inEventLoop {
Expand All @@ -407,13 +411,9 @@ public final class ChannelPipeline: ChannelInvoker {
/// - name: the name that was used to add the `ChannelHandler` to the `ChannelPipeline` before.
/// - promise: An `EventLoopPromise` that will complete when the `ChannelHandler` is removed.
public func removeHandler(name: String, promise: EventLoopPromise<Void>?) {
@Sendable
func removeHandler0() {
switch self.contextSync(name: name) {
case .success(let context):
self.removeHandler(context: context, promise: promise)
case .failure(let error):
promise?.fail(error)
}
self.syncOperations.removeHandler(name: name, promise: promise)
}

if self.eventLoop.inEventLoop {
Expand All @@ -430,13 +430,22 @@ public final class ChannelPipeline: ChannelInvoker {
/// - parameters:
/// - context: the `ChannelHandlerContext` that belongs to `ChannelHandler` that should be removed.
/// - promise: An `EventLoopPromise` that will complete when the `ChannelHandler` is removed.
@available(
*,
deprecated,
message: "Use .syncOperations.removeHandler(context:) instead, this method is not Sendable-safe."
)
public func removeHandler(context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
guard context.handler is RemovableChannelHandler else {
let sendableView = context.sendableView

guard sendableView.channelHandlerIsRemovable else {
promise?.fail(ChannelError._unremovableHandler)
return
}

@Sendable
func removeHandler0() {
context.startUserTriggeredRemoval(promise: promise)
sendableView.wrappedValue.startUserTriggeredRemoval(promise: promise)
}

if self.eventLoop.inEventLoop {
Expand All @@ -453,7 +462,13 @@ public final class ChannelPipeline: ChannelInvoker {
/// - parameters:
/// - handler: the `ChannelHandler` for which the `ChannelHandlerContext` should be returned
/// - returns: the `EventLoopFuture` which will be notified once the the operation completes.
public func context(handler: ChannelHandler) -> EventLoopFuture<ChannelHandlerContext> {
@available(
*,
deprecated,
message: "This method is not strict concurrency safe. Prefer .syncOperations.context(handler:)"
)
@preconcurrency
public func context(handler: ChannelHandler & Sendable) -> EventLoopFuture<ChannelHandlerContext> {
let promise = self.eventLoop.makePromise(of: ChannelHandlerContext.self)

if self.eventLoop.inEventLoop {
Expand Down Expand Up @@ -1005,8 +1020,9 @@ extension ChannelPipeline {
/// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`.
///
/// - returns: A future that will be completed when all of the supplied `ChannelHandler`s were added.
@preconcurrency
public func addHandlers(
_ handlers: [ChannelHandler],
_ handlers: [ChannelHandler & Sendable],
position: ChannelPipeline.Position = .last
) -> EventLoopFuture<Void> {
let future: EventLoopFuture<Void>
Expand All @@ -1030,8 +1046,9 @@ extension ChannelPipeline {
/// - position: The position in the `ChannelPipeline` to add `handlers`. Defaults to `.last`.
///
/// - returns: A future that will be completed when all of the supplied `ChannelHandler`s were added.
@preconcurrency
public func addHandlers(
_ handlers: ChannelHandler...,
_ handlers: (ChannelHandler & Sendable)...,
position: ChannelPipeline.Position = .last
) -> EventLoopFuture<Void> {
self.addHandlers(handlers, position: position)
Expand Down Expand Up @@ -1149,29 +1166,75 @@ extension ChannelPipeline {
/// - parameters:
/// - handler: the `ChannelHandler` to remove.
/// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was removed.
@preconcurrency
public func removeHandler(_ handler: RemovableChannelHandler) -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.removeHandler(handler, promise: promise)
return promise.futureResult
}

/// Remove a ``ChannelHandler`` from the ``ChannelPipeline``.
///
/// - parameters:
/// - handler: the ``ChannelHandler`` to remove.
/// - promise: an ``EventLoopPromise`` to notify when the ``ChannelHandler`` was removed.
public func removeHandler(_ handler: RemovableChannelHandler, promise: EventLoopPromise<Void>?) {
switch self._pipeline.contextSync(handler: handler) {
case .success(let context):
self._pipeline.removeHandler(context: context, promise: promise)
self.removeHandler(context: context, promise: promise)
case .failure(let error):
promise.fail(error)
promise?.fail(error)
}
}

/// Remove a `ChannelHandler` from the `ChannelPipeline`.
///
/// - parameters:
/// - name: the name that was used to add the `ChannelHandler` to the `ChannelPipeline` before.
/// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was removed.
public func removeHandler(name: String) -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.removeHandler(name: name, promise: promise)
return promise.futureResult
}

/// Remove a ``ChannelHandler`` from the ``ChannelPipeline``.
///
/// - parameters:
/// - name: the name that was used to add the `ChannelHandler` to the `ChannelPipeline` before.
/// - promise: an ``EventLoopPromise`` to notify when the ``ChannelHandler`` was removed.
public func removeHandler(name: String, promise: EventLoopPromise<Void>?) {
switch self._pipeline.contextSync(name: name) {
case .success(let context):
self.removeHandler(context: context, promise: promise)
case .failure(let error):
promise?.fail(error)
}
}

/// Remove a `ChannelHandler` from the `ChannelPipeline`.
///
/// - parameters:
/// - context: the `ChannelHandlerContext` that belongs to `ChannelHandler` that should be removed.
/// - returns: the `EventLoopFuture` which will be notified once the `ChannelHandler` was removed.
public func removeHandler(context: ChannelHandlerContext) -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self._pipeline.removeHandler(context: context, promise: promise)
self.removeHandler(context: context, promise: promise)
return promise.futureResult
}

/// Remove a `ChannelHandler` from the `ChannelPipeline`.
///
/// - parameters:
/// - context: the `ChannelHandlerContext` that belongs to `ChannelHandler` that should be removed.
/// - promise: an ``EventLoopPromise`` to notify when the ``ChannelHandler`` was removed.
public func removeHandler(context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
if context.handler is RemovableChannelHandler {
context.startUserTriggeredRemoval(promise: promise)
} else {
promise?.fail(ChannelError.unremovableHandler)
}
}

/// Returns the `ChannelHandlerContext` for the given handler instance if it is in
/// the `ChannelPipeline`, if it exists.
///
Expand Down Expand Up @@ -1367,26 +1430,24 @@ extension ChannelPipeline.SynchronousOperations: Sendable {}

extension ChannelPipeline {
/// A `Position` within the `ChannelPipeline` used to insert handlers into the `ChannelPipeline`.
public enum Position {
@preconcurrency
public enum Position: Sendable {
/// The first `ChannelHandler` -- the front of the `ChannelPipeline`.
case first

/// The last `ChannelHandler` -- the back of the `ChannelPipeline`.
case last

/// Before the given `ChannelHandler`.
case before(ChannelHandler)
case before(ChannelHandler & Sendable)

/// After the given `ChannelHandler`.
case after(ChannelHandler)
case after(ChannelHandler & Sendable)
}
}

@available(*, unavailable)
extension ChannelPipeline.Position: Sendable {}

/// Special `ChannelHandler` that forwards all events to the `Channel.Unsafe` implementation.
final class HeadChannelHandler: _ChannelOutboundHandler {
final class HeadChannelHandler: _ChannelOutboundHandler, Sendable {

static let name = "head"
static let sharedInstance = HeadChannelHandler()
Expand Down Expand Up @@ -1442,7 +1503,7 @@ extension CloseMode {
}

/// Special `ChannelInboundHandler` which will consume all inbound events.
final class TailChannelHandler: _ChannelInboundHandler {
final class TailChannelHandler: _ChannelInboundHandler, Sendable {

static let name = "tail"
static let sharedInstance = TailChannelHandler()
Expand Down Expand Up @@ -1977,6 +2038,42 @@ extension ChannelHandlerContext {
}
}

extension ChannelHandlerContext {
var sendableView: SendableView {
SendableView(wrapping: self)
}

/// A wrapper over ``ChannelHandlerContext`` that allows access to the thread-safe API
/// surface on the type.
///
/// Very little of ``ChannelHandlerContext`` is thread-safe, but in a rare few places
/// there are things we can access. This type makes those available.
struct SendableView: @unchecked Sendable {
private let context: ChannelHandlerContext

fileprivate init(wrapping context: ChannelHandlerContext) {
self.context = context
}

/// Whether the ``ChannelHandler`` associated with this context conforms to
/// ``RemovableChannelHandler``.
var channelHandlerIsRemovable: Bool {
// `context.handler` is not mutable, and set at construction, so this access is
// acceptable. The protocol conformance check is also safe.
self.context.handler is RemovableChannelHandler
}

/// Grabs the underlying ``ChannelHandlerContext``. May only be called on the
/// event loop.
var wrappedValue: ChannelHandlerContext {
// The event loop lookup here is also thread-safe, so we can grab the value out
// and use it.
self.context.eventLoop.preconditionInEventLoop()
return self.context
}
}
}

extension ChannelPipeline: CustomDebugStringConvertible {
public var debugDescription: String {
// This method forms output in the following format:
Expand Down
4 changes: 2 additions & 2 deletions Sources/NIOHTTP1/HTTPServerUpgradeHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha
)
self.upgradeState = .upgradeComplete
// When we remove ourselves we'll be delivering any buffered data.
context.pipeline.removeHandler(context: context, promise: nil)
context.pipeline.syncOperations.removeHandler(context: context, promise: nil)

case .failure(let error):
// Remain in the '.upgrading' state.
Expand Down Expand Up @@ -357,7 +357,7 @@ public final class HTTPServerUpgradeHandler: ChannelInboundHandler, RemovableCha
context.fireChannelReadComplete()

// Ok, we've delivered all the parts. We can now remove ourselves, which should happen synchronously.
context.pipeline.removeHandler(context: context, promise: nil)
context.pipeline.syncOperations.removeHandler(context: context, promise: nil)
}

/// Builds the initial mandatory HTTP headers for HTTP upgrade responses.
Expand Down
4 changes: 2 additions & 2 deletions Sources/NIOHTTP1/NIOHTTPClientUpgradeHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC
self.upgradeState = .upgradeComplete
}
.whenComplete { _ in
context.pipeline.removeHandler(context: context, promise: nil)
context.pipeline.syncOperations.removeHandler(context: context, promise: nil)
}
}
}
Expand Down Expand Up @@ -397,7 +397,7 @@ public final class NIOHTTPClientUpgradeHandler: ChannelDuplexHandler, RemovableC
context.fireChannelRead(Self.wrapInboundOut(data))

// We've delivered the data. We can now remove ourselves, which should happen synchronously.
context.pipeline.removeHandler(context: context, promise: nil)
context.pipeline.syncOperations.removeHandler(context: context, promise: nil)
}
}

Expand Down
10 changes: 7 additions & 3 deletions Sources/NIOTLS/ApplicationProtocolNegotiationHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
}

private func userFutureCompleted(context: ChannelHandlerContext, result: Result<Void, Error>) {
context.eventLoop.assertInEventLoop()

switch self.stateMachine.userFutureCompleted(with: result) {
case .fireErrorCaughtAndRemoveHandler(let error):
context.fireErrorCaught(error)
context.pipeline.removeHandler(self, promise: nil)
context.pipeline.syncOperations.removeHandler(self, promise: nil)

case .fireErrorCaughtAndStartUnbuffering(let error):
context.fireErrorCaught(error)
Expand All @@ -141,22 +143,24 @@ public final class ApplicationProtocolNegotiationHandler: ChannelInboundHandler,
self.unbuffer(context: context)

case .removeHandler:
context.pipeline.removeHandler(self, promise: nil)
context.pipeline.syncOperations.removeHandler(self, promise: nil)

case .none:
break
}
}

private func unbuffer(context: ChannelHandlerContext) {
context.eventLoop.assertInEventLoop()

while true {
switch self.stateMachine.unbuffer() {
case .fireChannelRead(let data):
context.fireChannelRead(data)

case .fireChannelReadCompleteAndRemoveHandler:
context.fireChannelReadComplete()
context.pipeline.removeHandler(self, promise: nil)
context.pipeline.syncOperations.removeHandler(self, promise: nil)
return
}
}
Expand Down
Loading

0 comments on commit 411c2c5

Please sign in to comment.