Skip to content
93 changes: 72 additions & 21 deletions Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

import DequeModule

@usableFromInline
enum OutboundAction<OutboundOut>: Sendable where OutboundOut: Sendable {
/// Write value
case write(OutboundOut)
/// Write value and flush pipeline
case writeAndFlush(OutboundOut, EventLoopPromise<Void>)
/// flush writes to writer
case flush(EventLoopPromise<Void>)
}

/// A ``ChannelHandler`` that is used to transform the inbound portion of a NIO
/// ``Channel`` into an asynchronous sequence that supports back-pressure. It's also used
/// to write the outbound portion of a NIO ``Channel`` from Swift Concurrency with back-pressure
Expand Down Expand Up @@ -77,7 +87,7 @@ internal final class NIOAsyncChannelHandler<InboundIn: Sendable, ProducerElement

@usableFromInline
typealias Writer = NIOAsyncWriter<
OutboundOut,
OutboundAction<OutboundOut>,
NIOAsyncChannelHandlerWriterDelegate<OutboundOut>
>

Expand Down Expand Up @@ -372,7 +382,10 @@ struct NIOAsyncChannelHandlerProducerDelegate: @unchecked Sendable, NIOAsyncSequ

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@usableFromInline
struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSinkDelegate, @unchecked Sendable {
struct NIOAsyncChannelHandlerWriterDelegate<OutboundOut: Sendable>: NIOAsyncWriterSinkDelegate, @unchecked Sendable {
@usableFromInline
typealias Element = OutboundAction<OutboundOut>

@usableFromInline
let eventLoop: EventLoop

Expand All @@ -386,7 +399,7 @@ struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSi
let _didTerminate: ((any Error)?) -> Void

@inlinable
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelHandler<InboundIn, ProducerElement, Element>) {
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelHandler<InboundIn, ProducerElement, OutboundOut>) {
self.eventLoop = handler.eventLoop
self._didYieldContentsOf = handler._didYield(sequence:)
self._didYield = handler._didYield(element:)
Expand Down Expand Up @@ -430,35 +443,27 @@ struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSi
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncChannelHandler {
@inlinable
func _didYield(sequence: Deque<OutboundOut>) {
func _didYield(sequence: Deque<OutboundAction<OutboundOut>>) {
// This is always called from an async context, so we must loop-hop.
// Because we always loop-hop, we're always at the top of a stack frame. As this
// is the only source of writes for us, and as this channel handler doesn't implement
// func write(), we cannot possibly re-entrantly write. That means we can skip many of the
// awkward re-entrancy protections NIO usually requires, and can safely just do an iterative
// write.
self.eventLoop.preconditionInEventLoop()
guard let context = self.context else {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved into _doOutboundWrites as we need to complete promises even if the channel handler is no longer there.

// Already removed from the channel by now, we can stop.
return
}

self._doOutboundWrites(context: context, writes: sequence)
}

@inlinable
func _didYield(element: OutboundOut) {
func _didYield(element: OutboundAction<OutboundOut>) {
// This is always called from an async context, so we must loop-hop.
// Because we always loop-hop, we're always at the top of a stack frame. As this
// is the only source of writes for us, and as this channel handler doesn't implement
// func write(), we cannot possibly re-entrantly write. That means we can skip many of the
// awkward re-entrancy protections NIO usually requires, and can safely just do an iterative
// write.
self.eventLoop.preconditionInEventLoop()
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved into _doOutboundWrites as we need to complete promises even if the channel handler is no longer there.

return
}

self._doOutboundWrite(context: context, write: element)
}
Expand All @@ -475,18 +480,64 @@ extension NIOAsyncChannelHandler {
}

@inlinable
func _doOutboundWrites(context: ChannelHandlerContext, writes: Deque<OutboundOut>) {
for write in writes {
context.write(Self.wrapOutboundOut(write), promise: nil)
func _doOutboundWrites(context: ChannelHandlerContext?, writes: Deque<OutboundAction<OutboundOut>>) {
// write everything but the last item
for write in writes.dropLast() {
switch write {
case .write(let value), .writeAndFlush(let value, _):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
}
}
// write last item
switch writes.last {
case .write(let value):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
case .writeAndFlush(let value, let promise):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
promise.succeed()
return
}
context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise)
case .none:
break
}

context.flush()
}

@inlinable
func _doOutboundWrite(context: ChannelHandlerContext, write: OutboundOut) {
context.write(Self.wrapOutboundOut(write), promise: nil)
context.flush()
func _doOutboundWrite(context: ChannelHandlerContext?, write: OutboundAction<OutboundOut>) {
switch write {
case .write(let value):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
case .writeAndFlush(let value, let promise):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
promise.succeed()
return
}
context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise)
}
}
}

Expand Down
75 changes: 66 additions & 9 deletions Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@usableFromInline
typealias _Writer = NIOAsyncWriter<
OutboundOut,
OutboundAction<OutboundOut>,
NIOAsyncChannelHandlerWriterDelegate<OutboundOut>
>

Expand Down Expand Up @@ -66,7 +66,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@usableFromInline
enum Backing: Sendable {
case asyncStream(AsyncStream<OutboundOut>.Continuation)
case writer(_Writer)
case writer(_Writer, EventLoop)
}

@usableFromInline
Expand All @@ -93,7 +93,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
) throws {
eventLoop.preconditionInEventLoop()
let writer = _Writer.makeWriter(
elementType: OutboundOut.self,
elementType: OutboundAction<OutboundOut>.self,
isWritable: true,
finishOnDeinit: closeOnDeinit,
delegate: .init(handler: handler)
Expand All @@ -102,7 +102,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
handler.sink = writer.sink
handler.writer = writer.writer

self._backing = .writer(writer.writer)
self._backing = .writer(writer.writer, eventLoop)
}

@inlinable
Expand All @@ -118,8 +118,23 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
switch self._backing {
case .asyncStream(let continuation):
continuation.yield(data)
case .writer(let writer):
try await writer.yield(data)
case .writer(let writer, _):
try await writer.yield(.write(data))
}
}

/// Send a write into the ``ChannelPipeline`` and flush it right away.
///
/// This method suspends until the write has been written and flushed.
@inlinable
public func writeAndFlush(_ data: OutboundOut) async throws {
switch self._backing {
case .asyncStream(let continuation):
continuation.yield(data)
case .writer(let writer, let eventLoop):
try await self.withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(.writeAndFlush(data, promise))
}
}
}

Expand All @@ -133,8 +148,26 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
for data in sequence {
continuation.yield(data)
}
case .writer(let writer):
try await writer.yield(contentsOf: sequence)
case .writer(let writer, _):
try await writer.yield(contentsOf: sequence.map { .write($0) })
}
}

/// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away.
///
/// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again.
@inlinable
public func writeAndFlush<Writes: Sequence>(contentsOf sequence: Writes) async throws
where Writes.Element == OutboundOut {
switch self._backing {
case .asyncStream(let continuation):
for data in sequence {
continuation.yield(data)
}
case .writer(let writer, let eventLoop):
try await withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(contentsOf: sequence.map { .writeAndFlush($0, promise) })
}
}
}

Expand All @@ -151,17 +184,41 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
}
}

/// Ensure all writes to the writer have been read
@inlinable
public func flush() async throws {
if case .writer(let writer, let eventLoop) = self._backing {
try await self.withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(.flush(promise))
}
}
}

/// Finishes the writer.
///
/// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it.
public func finish() {
switch self._backing {
case .asyncStream(let continuation):
continuation.finish()
case .writer(let writer):
case .writer(let writer, _):
writer.finish()
}
}

@usableFromInline
func withPromise(
eventLoop: EventLoop,
_ process: (EventLoopPromise<Void>) async throws -> Void
) async throws {
let promise = eventLoop.makePromise(of: Void.self)
do {
try await process(promise)
try await promise.futureResult.get()
} catch {
promise.fail(error)
}
}
}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
Expand Down
84 changes: 84 additions & 0 deletions Tests/NIOCoreTests/AsyncChannel/AsyncChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,63 @@ final class AsyncChannelTests: XCTestCase {
}
}

func testAllWritesAreWritten() async throws {
let channel = NIOAsyncTestingChannel()
let promise = channel.testingEventLoop.makePromise(of: Void.self)
let wrapped = try await channel.testingEventLoop.executeInContext {
try channel.pipeline.syncOperations.addHandler(DelayingChannelHandler(promise: promise))
return try NIOAsyncChannel<Never, String>(wrappingChannelSynchronously: channel)
}
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await wrapped.executeThenClose { inbound, outbound in
try await outbound.write("hello")
try await outbound.writeAndFlush("world")
}
}
group.addTask {
let firstRead = try await channel.waitForOutboundWrite(as: String.self)
let secondRead = try await channel.waitForOutboundWrite(as: String.self)

XCTAssertEqual(firstRead, "hello")
XCTAssertEqual(secondRead, "world")
}

// wait 50 milliseconds to ensure we are inside write and flush then
// trigger pipeline flush by succeeding promise in DelayingChannelHandler
try await Task.sleep(for: .milliseconds(50))
promise.succeed()
}
}

func testAllWritesInSequenceAreWritten() async throws {
let channel = NIOAsyncTestingChannel()
let promise = channel.testingEventLoop.makePromise(of: Void.self)
let wrapped = try await channel.testingEventLoop.executeInContext {
try channel.pipeline.syncOperations.addHandler(DelayingChannelHandler(promise: promise))
return try NIOAsyncChannel<Never, String>(wrappingChannelSynchronously: channel)
}
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await wrapped.executeThenClose { inbound, outbound in
try await outbound.writeAndFlush(contentsOf: ["hello", "world"])
}
}
group.addTask {
let firstRead = try await channel.waitForOutboundWrite(as: String.self)
let secondRead = try await channel.waitForOutboundWrite(as: String.self)

XCTAssertEqual(firstRead, "hello")
XCTAssertEqual(secondRead, "world")
}

// wait 50 milliseconds to ensure we are inside write and flush then
// trigger pipeline flush by succeeding promise in DelayingChannelHandler
try await Task.sleep(for: .milliseconds(50))
promise.succeed()
}
}

func testErrorsArePropagatedButAfterReads() async throws {
let channel = NIOAsyncTestingChannel()
let wrapped = try await channel.testingEventLoop.executeInContext {
Expand Down Expand Up @@ -429,6 +486,17 @@ private final class CloseRecorder: ChannelOutboundHandler, @unchecked Sendable {
}
}

struct UnsafeContext: @unchecked Sendable {
private let _context: ChannelHandlerContext
var context: ChannelHandlerContext {
self._context.eventLoop.preconditionInEventLoop()
return _context
}
init(_ context: ChannelHandlerContext) {
self._context = context
}
}

private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHandler, Sendable {
typealias OutboundIn = Any

Expand All @@ -438,6 +506,22 @@ private final class CloseSuppressor: ChannelOutboundHandler, RemovableChannelHan
}
}

private final class DelayingChannelHandler: ChannelOutboundHandler, RemovableChannelHandler, Sendable {
typealias OutboundIn = Any
typealias OutboundOut = Any
let waitPromise: EventLoopPromise<Void>

init(promise: EventLoopPromise<Void>) {
self.waitPromise = promise
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let unsafeTransfer = UnsafeTransfer((context: context, data: data))
self.waitPromise.futureResult.whenComplete { _ in
unsafeTransfer.wrappedValue.context.writeAndFlush(unsafeTransfer.wrappedValue.data, promise: promise)
}
}
}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncTestingChannel {
fileprivate func closeIgnoringSuppression() async throws {
Expand Down