From c1d175755fc4eaba84908d769fb83b6f94934691 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 29 Jan 2025 15:01:06 +0000 Subject: [PATCH] Combine RequestBody delegate, source state to fix hang (#666) * Combine RequestBody delegate produceMore and waiting continuations into single state Fixes issue with hang where yield was told to stop producing and but then was told it could continue producing before the next call to yield. * Attempt to remove withCheckedContinuation for produceMore state * Avoid allocating deque if only one continuation is created * Update codecov token * Resume all continuations on produceMore --- .../HummingbirdCore/Request/RequestBody.swift | 112 ++++++++++++++---- Tests/HummingbirdTests/ApplicationTests.swift | 46 +++++++ 2 files changed, 138 insertions(+), 20 deletions(-) diff --git a/Sources/HummingbirdCore/Request/RequestBody.swift b/Sources/HummingbirdCore/Request/RequestBody.swift index 082d3a4f..45a90904 100644 --- a/Sources/HummingbirdCore/Request/RequestBody.swift +++ b/Sources/HummingbirdCore/Request/RequestBody.swift @@ -132,38 +132,116 @@ extension RequestBody { > /// Delegate for NIOThrowingAsyncSequenceProducer + /// + /// This can be a struct as the state is stored inside a NIOLockedValueBox which + /// turns it into a reference value @usableFromInline - final class Delegate: NIOAsyncSequenceProducerDelegate, Sendable { - let checkedContinuations: NIOLockedValueBox>> + struct Delegate: NIOAsyncSequenceProducerDelegate, Sendable { + enum State { + case produceMore + case waitingForProduceMore(CheckedContinuation?) + case multipleWaitingForProduceMore(Deque>) + case terminated + } + let state: NIOLockedValueBox @usableFromInline init() { - self.checkedContinuations = .init([]) + self.state = .init(.produceMore) } @usableFromInline func produceMore() { - self.checkedContinuations.withLockedValue { - if let cont = $0.popFirst() { - cont.resume() + self.state.withLockedValue { state in + switch state { + case .produceMore: + break + case .waitingForProduceMore(let continuation): + if let continuation { + continuation.resume() + } + state = .produceMore + + case .multipleWaitingForProduceMore(var continuations): + // this isnt exactly correct as the number of continuations + // resumed can overflow the back pressure + while let cont = continuations.popFirst() { + cont.resume() + } + state = .produceMore + + case .terminated: + preconditionFailure("Unexpected state") } } } @usableFromInline func didTerminate() { - self.checkedContinuations.withLockedValue { - while let cont = $0.popFirst() { - cont.resume() + self.state.withLockedValue { state in + switch state { + case .produceMore: + break + case .waitingForProduceMore(let continuation): + if let continuation { + continuation.resume() + } + state = .terminated + case .multipleWaitingForProduceMore(var continuations): + while let cont = continuations.popFirst() { + cont.resume() + } + state = .terminated + case .terminated: + preconditionFailure("Unexpected state") } } } @usableFromInline func waitForProduceMore() async { - await withCheckedContinuation { (cont: CheckedContinuation) in - self.checkedContinuations.withLockedValue { - $0.append(cont) + switch self.state.withLockedValue({ $0 }) { + case .produceMore, .terminated: + break + case .waitingForProduceMore, .multipleWaitingForProduceMore: + await withCheckedContinuation { (newContinuation: CheckedContinuation) in + self.state.withLockedValue { state in + switch state { + case .produceMore: + newContinuation.resume() + case .waitingForProduceMore(let firstContinuation): + if let firstContinuation { + var continuations = Deque>() + continuations.reserveCapacity(2) + continuations.append(firstContinuation) + continuations.append(newContinuation) + state = .multipleWaitingForProduceMore(continuations) + } else { + state = .waitingForProduceMore(newContinuation) + } + case .multipleWaitingForProduceMore(var continuations): + continuations.append(newContinuation) + state = .multipleWaitingForProduceMore(continuations) + case .terminated: + newContinuation.resume() + } + } + } + } + } + + @usableFromInline + func stopProducing() { + self.state.withLockedValue { state in + switch state { + case .produceMore: + state = .waitingForProduceMore(nil) + case .waitingForProduceMore: + break + case .multipleWaitingForProduceMore: + break + case .terminated: + break } } } @@ -175,14 +253,11 @@ extension RequestBody { let source: Producer.Source @usableFromInline let delegate: Delegate - @usableFromInline - let waitForProduceMore: NIOLockedValueBox @usableFromInline init(source: Producer.Source, delegate: Delegate) { self.source = source self.delegate = delegate - self.waitForProduceMore = .init(false) } /// Yields the element to the inbound stream. @@ -195,13 +270,10 @@ extension RequestBody { public func yield(_ element: ByteBuffer) async throws { // if previous call indicated we should stop producing wait until the delegate // says we can start producing again - if self.waitForProduceMore.withLockedValue({ $0 }) { - await self.delegate.waitForProduceMore() - self.waitForProduceMore.withLockedValue { $0 = false } - } + await self.delegate.waitForProduceMore() let result = self.source.yield(element) if result == .stopProducing { - self.waitForProduceMore.withLockedValue { $0 = true } + self.delegate.stopProducing() } } diff --git a/Tests/HummingbirdTests/ApplicationTests.swift b/Tests/HummingbirdTests/ApplicationTests.swift index fc107018..e090d0be 100644 --- a/Tests/HummingbirdTests/ApplicationTests.swift +++ b/Tests/HummingbirdTests/ApplicationTests.swift @@ -867,6 +867,52 @@ final class ApplicationTests: XCTestCase { } } + /// Test AsyncSequence returned by RequestBody.makeStream() and feeding it data from multiple processes + func testMakeStreamMultipleSources() async throws { + let router = Router() + router.get("numbers") { request, context -> Response in + let body = try await withThrowingTaskGroup(of: Void.self) { group in + let (requestBody, source) = RequestBody.makeStream() + group.addTask { + // Add three tasks feeding the source + await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + for value in 0..<100 { + try await source.yield(ByteBuffer(string: String(describing: value))) + } + } + group.addTask { + for value in 0..<100 { + try await source.yield(ByteBuffer(string: String(describing: value))) + } + } + group.addTask { + for value in 0..<100 { + try await source.yield(ByteBuffer(string: String(describing: value))) + } + } + } + source.finish() + } + var body = ByteBuffer() + for try await buffer in requestBody { + var buffer = buffer + body.writeBuffer(&buffer) + try await Task.sleep(for: .milliseconds(1)) + } + return body + } + return Response(status: .ok, body: .init(byteBuffer: body)) + } + let app = Application(responder: router.buildResponder()) + + try await app.test(.router) { client in + try await client.execute(uri: "/numbers", method: .get) { response in + XCTAssertEqual(response.status, .ok) + } + } + } + #if compiler(>=6.0) /// Test consumeWithInboundCloseHandler func testConsumeWithInboundHandler() async throws {