Implement Stream Cancellation #2901
Conversation
Force-pushed from e7dab8d to 8fda0ab
Force-pushed from 8fda0ab to fc05434
@willgdjones Thanks Will. I get where Cursor is going with this, but I'm not sure if it's doing too little (shouldn't we be calling openai.AgentStream.close()
at some point?) and/or too much (do we need the multiple canceled booleans and exception, if the wrapped stream could just stop yielding events). Would be good to get your (human, not AI!) take :)
"""Exception raised when a streaming response is cancelled.""" | ||
|
||
def __init__(self, message: str = 'Stream was cancelled'): | ||
self.message = message |
No need to have a message as it's not used anywhere
        This should close the underlying network connection and cause any active iteration
        to raise a StreamCancelled exception. The default implementation is a no-op.
        """
        pass
I think this should raise NotImplementedError, so it doesn't silently keep the stream going when the user thought they cancelled it.
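For illustration, a minimal sketch of that suggestion, using a simplified stand-in class rather than the real StreamedResponse:

    from abc import ABC


    class StreamedResponse(ABC):  # simplified stand-in for the real base class
        async def cancel(self) -> None:
            """Cancel the streaming response.

            The base class cannot close any connection itself, so it fails loudly
            instead of silently letting the stream keep running.
            """
            raise NotImplementedError(f'{type(self).__name__} does not support cancellation')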
    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            # Check for cancellation before processing each chunk
            if self._cancelled:
Shouldn't we do this after recording the usage?
@@ -1418,6 +1423,14 @@ def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp

    async def cancel(self) -> None:
This doesn't need to be async if the recommended behavior is to always just set a flag and then cancel on the next iteration
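A rough sketch of that shape, with an illustrative class standing in for the real one and plain str chunks standing in for real events:

    from collections.abc import AsyncIterator


    class ExampleStreamedResponse:
        """Illustrative stand-in for StreamedResponse; not the real class."""

        def __init__(self, response: AsyncIterator[str]) -> None:
            self._response = response
            self._cancelled = False

        def cancel(self) -> None:
            # Synchronous: just flip the flag; the iterator reacts on its next step.
            self._cancelled = True

        async def _get_event_iterator(self) -> AsyncIterator[str]:
            async for chunk in self._response:
                if self._cancelled:
                    break  # stop yielding instead of raising
                yield chunk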
    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            # Check for cancellation before processing each chunk
            if self._cancelled:
                raise StreamCancelled('OpenAI stream was cancelled')
Will this actually cause OpenAI to cleanly close the stream? Shouldn't we call await AsyncStream.close() or something? They also have their own # Ensure the entire stream is consumed handling.
Note that right now, OpenAIStreamedResponse only has access to AsyncIterable[ChatCompletionChunk], but that's derived from the AsyncStream[ChatCompletionChunk]:
pydantic-ai/pydantic_ai_slim/pydantic_ai/models/openai.py (lines 578 to 582 in f903d5b)

    async def _process_streamed_response(
        self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters
    ) -> OpenAIStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)
So in order to access that cancel method, we may need to put it on _utils.PeekableAsyncStream
as well, and then forward it to the underlying stream.
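Something along these lines, sketched with a simplified stand-in; the real _utils.PeekableAsyncStream also supports peeking, and the forwarding method name here is an assumption:

    from collections.abc import AsyncIterator
    from typing import Generic, TypeVar

    T = TypeVar('T')


    class PeekableAsyncStream(Generic[T]):
        """Simplified stand-in showing only the forwarding being discussed."""

        def __init__(self, stream: AsyncIterator[T]) -> None:
            self._stream = stream

        def __aiter__(self) -> AsyncIterator[T]:
            return self._stream

        async def close(self) -> None:
            # Forward to the wrapped stream if it supports closing, e.g. the OpenAI
            # AsyncStream's close(), which releases the underlying HTTP connection.
            close = getattr(self._stream, 'close', None)
            if close is not None:
                await close()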
    async for item in stream_response:
        # Check for cancellation first
        if is_cancelled():
            raise exceptions.StreamCancelled()
Why do we have to raise it here and inside the StreamedResponse?
Wouldn't the stream_response just stop yielding once it's cancelled and there are no more messages, meaning we shouldn't need to do anything special here?
-        self._agent_stream_iterator = _get_usage_checking_stream_response(
-            self._raw_stream_response, self._usage_limits, self.usage
+        self._agent_stream_iterator = _get_cancellation_aware_stream_response(
+            self._raw_stream_response, self._usage_limits, self.usage, lambda: self._cancelled
Related to what I wrote below, I don't understand why we need this lambda, instead of just pushing the cancellation down to the wrapped StreamedResponse and then relying on that to stop yielding events.
    try:
        async for _ in agent_stream:
            pass
    except exceptions.StreamCancelled:
Do we need this exception at all?
Like I wrote below: wouldn't the StreamedResponse just stop yielding once it's been cancelled and there are no more messages, meaning we shouldn't need to do anything special here?
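In other words, if cancellation simply ends the stream, the draining code could reduce to something like this (drain is a hypothetical name for wherever this loop lives):

    async def drain(agent_stream) -> None:
        """Sketch: no try/except needed if cancellation just ends the stream."""
        async for _ in agent_stream:
            pass  # iteration ends naturally once the stream is cancelled or exhausted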
    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
    agent = Agent(m)

    async with agent.run_stream('Hello world') as result:
Can we add a test that cancels streaming in the middle of an unfinished tool call, like in the example at https://ai.pydantic.dev/agents/#streaming-events-and-final-output after a ToolCallPartDelta, and then see what the final ModelResponse in result.all_messages() looks like? I imagine it would have an incomplete ToolCallPart. I wonder if we should indicate on the ModelResponse somehow that it's incomplete because it's been canceled, and cannot be used as message_history, for example.
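Not part of the PR yet, but such a test might look roughly like the sketch below; the fixture, the event-stream API, and where cancel() lives are all placeholders rather than the actual pydantic-ai API:

    import pytest

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelResponse, ToolCallPart


    @pytest.mark.anyio
    async def test_cancel_mid_tool_call(mock_tool_call_model):  # hypothetical mock model fixture
        """Cancel while the model is still streaming an unfinished tool call."""
        agent = Agent(mock_tool_call_model)

        async with agent.run_stream('Hello world') as result:
            async for _event in result.stream_events():  # placeholder for the event stream API
                # Cancel as soon as the first ToolCallPartDelta-style event arrives.
                await result.cancel()
                break

            last = result.all_messages()[-1]
            assert isinstance(last, ModelResponse)
            # Expectation under discussion: an incomplete ToolCallPart, ideally flagged
            # so the response isn't reused as message_history.
            assert any(isinstance(p, ToolCallPart) for p in last.parts)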
@@ -641,6 +641,14 @@ def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        raise NotImplementedError()

    async def cancel(self) -> None:
We should document this feature in the Streaming docs
Thank you for this detailed critique! I will address these points when I get the chance.
Fixes #1524
Pydantic AI Stream Cancellation

This implementation adds stream cancellation functionality, allowing users to cancel streaming responses when clients disconnect or explicitly request cancellation.

🎯 Problem Solved

Previously, when users broke early from a streaming loop, Pydantic AI would continue consuming the entire response in the background to ensure proper usage tracking. This led to unnecessary work and token consumption for responses the caller no longer wanted.

✨ Features

Call await stream.cancel() to stop streaming.

Basic Usage
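The original usage snippet is not preserved in this capture; the sketch below shows the intended call pattern, assuming cancel() is exposed on the object returned by run_stream as in the tests above. should_stop() is a hypothetical app-level condition such as a client disconnect.

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o-mini')


    def should_stop() -> bool:
        """Hypothetical stand-in for an app-level condition, e.g. a client disconnect."""
        return True


    async def main() -> None:
        async with agent.run_stream('Tell me a long story') as stream:
            async for chunk in stream.stream_text(delta=True):
                print(chunk, end='')
                if should_stop():
                    await stream.cancel()  # stop consuming the rest of the response
                    break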
📚 API Reference
AgentStream.cancel()
StreamCancelled Exception
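The reference content is elided in this capture; the signatures, as they appear in the PR's diffs above, are roughly:

    class StreamCancelled(Exception):
        """Exception raised when a streaming response is cancelled."""


    class AgentStream:  # existing class; only the new method is shown
        async def cancel(self) -> None:
            """Cancel the streaming response."""
            ...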
🏗️ Implementation Details

Architecture

The implementation consists of several components:
- StreamCancelled Exception (exceptions.py)
- AgentStream.cancel() (result.py)
- StreamedResponse.cancel() (models/__init__.py)
- OpenAIStreamedResponse.cancel() (models/openai.py)
- Cancellation-Aware Iterator (result.py)
- Agent Graph Updates (_agent_graph.py)

Usage Tracking
Error Handling

cancel() calls are safe