Skip to content

Commit

Permalink
Combine RequestBody delegate, source state to fix hang (#666)
Browse files Browse the repository at this point in the history
* Combine RequestBody delegate produceMore and waiting continuations into single state

Fixes issue with hang where yield was told to stop producing and but then was told it could continue producing before the next call to yield.

* Attempt to remove withCheckedContinuation for produceMore state

* Avoid allocating deque if only one continuation is created

* Update codecov token

* Resume all continuations on produceMore
  • Loading branch information
adam-fowler authored Jan 29, 2025
1 parent b1bc64f commit c1d1757
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 20 deletions.
112 changes: 92 additions & 20 deletions Sources/HummingbirdCore/Request/RequestBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -132,38 +132,116 @@ extension RequestBody {
>

/// Delegate for NIOThrowingAsyncSequenceProducer
///
/// This can be a struct as the state is stored inside a NIOLockedValueBox which
/// turns it into a reference value
@usableFromInline
final class Delegate: NIOAsyncSequenceProducerDelegate, Sendable {
let checkedContinuations: NIOLockedValueBox<Deque<CheckedContinuation<Void, Never>>>
struct Delegate: NIOAsyncSequenceProducerDelegate, Sendable {
enum State {
case produceMore
case waitingForProduceMore(CheckedContinuation<Void, Never>?)
case multipleWaitingForProduceMore(Deque<CheckedContinuation<Void, Never>>)
case terminated
}
let state: NIOLockedValueBox<State>

@usableFromInline
init() {
self.checkedContinuations = .init([])
self.state = .init(.produceMore)
}

@usableFromInline
func produceMore() {
self.checkedContinuations.withLockedValue {
if let cont = $0.popFirst() {
cont.resume()
self.state.withLockedValue { state in
switch state {
case .produceMore:
break
case .waitingForProduceMore(let continuation):
if let continuation {
continuation.resume()
}
state = .produceMore

case .multipleWaitingForProduceMore(var continuations):
// this isnt exactly correct as the number of continuations
// resumed can overflow the back pressure
while let cont = continuations.popFirst() {
cont.resume()
}
state = .produceMore

case .terminated:
preconditionFailure("Unexpected state")
}
}
}

@usableFromInline
func didTerminate() {
self.checkedContinuations.withLockedValue {
while let cont = $0.popFirst() {
cont.resume()
self.state.withLockedValue { state in
switch state {
case .produceMore:
break
case .waitingForProduceMore(let continuation):
if let continuation {
continuation.resume()
}
state = .terminated
case .multipleWaitingForProduceMore(var continuations):
while let cont = continuations.popFirst() {
cont.resume()
}
state = .terminated
case .terminated:
preconditionFailure("Unexpected state")
}
}
}

@usableFromInline
func waitForProduceMore() async {
await withCheckedContinuation { (cont: CheckedContinuation<Void, Never>) in
self.checkedContinuations.withLockedValue {
$0.append(cont)
switch self.state.withLockedValue({ $0 }) {
case .produceMore, .terminated:
break
case .waitingForProduceMore, .multipleWaitingForProduceMore:
await withCheckedContinuation { (newContinuation: CheckedContinuation<Void, Never>) in
self.state.withLockedValue { state in
switch state {
case .produceMore:
newContinuation.resume()
case .waitingForProduceMore(let firstContinuation):
if let firstContinuation {
var continuations = Deque<CheckedContinuation<Void, Never>>()
continuations.reserveCapacity(2)
continuations.append(firstContinuation)
continuations.append(newContinuation)
state = .multipleWaitingForProduceMore(continuations)
} else {
state = .waitingForProduceMore(newContinuation)
}
case .multipleWaitingForProduceMore(var continuations):
continuations.append(newContinuation)
state = .multipleWaitingForProduceMore(continuations)
case .terminated:
newContinuation.resume()
}
}
}
}
}

@usableFromInline
func stopProducing() {
self.state.withLockedValue { state in
switch state {
case .produceMore:
state = .waitingForProduceMore(nil)
case .waitingForProduceMore:
break
case .multipleWaitingForProduceMore:
break
case .terminated:
break
}
}
}
Expand All @@ -175,14 +253,11 @@ extension RequestBody {
let source: Producer.Source
@usableFromInline
let delegate: Delegate
@usableFromInline
let waitForProduceMore: NIOLockedValueBox<Bool>

@usableFromInline
init(source: Producer.Source, delegate: Delegate) {
self.source = source
self.delegate = delegate
self.waitForProduceMore = .init(false)
}

/// Yields the element to the inbound stream.
Expand All @@ -195,13 +270,10 @@ extension RequestBody {
public func yield(_ element: ByteBuffer) async throws {
// if previous call indicated we should stop producing wait until the delegate
// says we can start producing again
if self.waitForProduceMore.withLockedValue({ $0 }) {
await self.delegate.waitForProduceMore()
self.waitForProduceMore.withLockedValue { $0 = false }
}
await self.delegate.waitForProduceMore()
let result = self.source.yield(element)
if result == .stopProducing {
self.waitForProduceMore.withLockedValue { $0 = true }
self.delegate.stopProducing()
}
}

Expand Down
46 changes: 46 additions & 0 deletions Tests/HummingbirdTests/ApplicationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,52 @@ final class ApplicationTests: XCTestCase {
}
}

/// Test AsyncSequence returned by RequestBody.makeStream() and feeding it data from multiple processes
func testMakeStreamMultipleSources() async throws {
let router = Router()
router.get("numbers") { request, context -> Response in
let body = try await withThrowingTaskGroup(of: Void.self) { group in
let (requestBody, source) = RequestBody.makeStream()
group.addTask {
// Add three tasks feeding the source
await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
for value in 0..<100 {
try await source.yield(ByteBuffer(string: String(describing: value)))
}
}
group.addTask {
for value in 0..<100 {
try await source.yield(ByteBuffer(string: String(describing: value)))
}
}
group.addTask {
for value in 0..<100 {
try await source.yield(ByteBuffer(string: String(describing: value)))
}
}
}
source.finish()
}
var body = ByteBuffer()
for try await buffer in requestBody {
var buffer = buffer
body.writeBuffer(&buffer)
try await Task.sleep(for: .milliseconds(1))
}
return body
}
return Response(status: .ok, body: .init(byteBuffer: body))
}
let app = Application(responder: router.buildResponder())

try await app.test(.router) { client in
try await client.execute(uri: "/numbers", method: .get) { response in
XCTAssertEqual(response.status, .ok)
}
}
}

#if compiler(>=6.0)
/// Test consumeWithInboundCloseHandler
func testConsumeWithInboundHandler() async throws {
Expand Down

0 comments on commit c1d1757

Please sign in to comment.