Implement Stream Cancellation #2901
base: main
pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -24,6 +24,7 @@
     'UsageLimitExceeded',
    'ModelHTTPError',
    'FallbackExceptionGroup',
+   'StreamCancelled',
 )
@@ -162,6 +163,14 @@ class FallbackExceptionGroup(ExceptionGroup):
     """A group of exceptions that can be raised when all fallback models fail."""


+class StreamCancelled(Exception):
+    """Exception raised when a streaming response is cancelled."""
+
+    def __init__(self, message: str = 'Stream was cancelled'):
+        self.message = message
+        super().__init__(message)
+
+
 class ToolRetryError(Exception):
     """Exception used to signal a `ToolRetry` message should be returned to the LLM."""

Review comment (on `self.message = message`): No need to have a message as it's not used anywhere.
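As context for reviewers, a minimal sketch of how calling code might treat the new exception as a graceful stop rather than an error; the `drain` helper and its `stream` argument are illustrative, standing in for any async iterator that can raise it:

from collections.abc import AsyncIterator

from pydantic_ai.exceptions import StreamCancelled


async def drain(stream: AsyncIterator[str]) -> list[str]:
    """Collect stream items until the stream finishes or is cancelled."""
    parts: list[str] = []
    try:
        async for part in stream:
            parts.append(part)
    except StreamCancelled:
        pass  # cancellation is an expected, non-error exit here
    return parts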
pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -641,6 +641,14 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         raise NotImplementedError()

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This should close the underlying network connection and cause any active iteration
+        to raise a StreamCancelled exception. The default implementation is a no-op.
+        """
+        pass
+

 ALLOW_MODEL_REQUESTS = True
 """Whether to allow requests to models.

Review comment (on `async def cancel`): We should document this feature in the Streaming docs.
Review comment (on `pass`): I think this should raise.
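To make the intended contract concrete, here is a self-contained toy version of the flag-based pattern; `DemoStreamedResponse` and its helpers are illustrative names, not the real `StreamedResponse`:

import asyncio
from collections.abc import AsyncIterator


class StreamCancelled(Exception):
    """Local stand-in for pydantic_ai.exceptions.StreamCancelled."""


class DemoStreamedResponse:
    """Toy model of the contract: cancel() flips a flag, and the event
    iterator raises StreamCancelled before processing the next chunk."""

    def __init__(self, chunks: list[str]) -> None:
        self._chunks = chunks
        self._cancelled = False

    async def cancel(self) -> None:
        self._cancelled = True

    async def events(self) -> AsyncIterator[str]:
        for chunk in self._chunks:
            if self._cancelled:
                raise StreamCancelled('stream was cancelled')
            await asyncio.sleep(0)  # stand-in for a network read
            yield chunk


async def main() -> None:
    resp = DemoStreamedResponse(['a', 'b', 'c'])
    events = resp.events()
    print(await anext(events))  # -> a
    await resp.cancel()
    try:
        await anext(events)
    except StreamCancelled:
        print('iteration stopped with StreamCancelled')


asyncio.run(main())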
pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -17,7 +17,7 @@
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..builtin_tools import CodeExecutionTool, WebSearchTool
-from ..exceptions import UserError
+from ..exceptions import StreamCancelled, UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -1347,9 +1347,14 @@ class OpenAIStreamedResponse(StreamedResponse):
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
     _provider_name: str
+    _cancelled: bool = field(default=False, init=False)

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
+            # Check for cancellation before processing each chunk
+            if self._cancelled:
+                raise StreamCancelled('OpenAI stream was cancelled')
+
             self._usage += _map_usage(chunk)

             if chunk.id and self.provider_response_id is None:

Review comment (on the cancellation check): Shouldn't we do this after recording the usage?
Review comment (on the raise): Will this actually cause OpenAI to cleanly close the stream? Shouldn't we call […]? They also have their own […]. Note that right now […]:
    pydantic-ai/pydantic_ai_slim/pydantic_ai/models/openai.py, lines 578 to 582 in f903d5b
So in order to access that cancel method, we may need to put it on […].
@@ -1418,6 +1423,14 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         return self._timestamp

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This marks the stream as cancelled, which will cause the iterator to raise
+        a StreamCancelled exception on the next iteration.
+        """
+        self._cancelled = True
+

 @dataclass
 class OpenAIResponsesStreamedResponse(StreamedResponse):

Review comment (on `async def cancel`): This doesn't need to be async if the recommended behavior is to always just set a flag and then cancel on the next iteration.
Review comment (on `OpenAIResponsesStreamedResponse`): Can we implement this for the OpenAI Responses API as well? I'm OK leaving Google and Anthropic off for now, although it'd be amazing if we supported those as well :)
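On the question of cleanly closing the connection: the openai Python SDK's AsyncStream does expose an awaitable close(). A sketch of the direction the reviewer suggests; the `_raw_stream` field is a hypothetical addition, not something this PR introduces:

from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk


class ClosingOpenAIStreamedResponse:
    """Hypothetical variant: keep the SDK's AsyncStream (not just an
    AsyncIterable) so cancel() can also release the HTTP connection."""

    _raw_stream: AsyncStream[ChatCompletionChunk]  # assumed field, not in this PR
    _cancelled: bool = False

    async def cancel(self) -> None:
        if not self._cancelled:
            self._cancelled = True
            await self._raw_stream.close()  # closes the underlying HTTP response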
pydantic_ai_slim/pydantic_ai/result.py
@@ -54,6 +54,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):

     _agent_stream_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
     _initial_run_ctx_usage: RunUsage = field(init=False)
+    _cancelled: bool = field(default=False, init=False)

     def __post_init__(self):
         self._initial_run_ctx_usage = copy(self._run_ctx.usage)
@@ -123,6 +124,19 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         return self._raw_stream_response.timestamp

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This will close the underlying network connection and cause any active iteration
+        over the stream to raise a StreamCancelled exception.
+
+        Subsequent calls to cancel() are safe and will not raise additional exceptions.
+        """
+        if not self._cancelled:
+            self._cancelled = True
+            # Cancel the underlying stream response
+            await self._raw_stream_response.cancel()
+
     async def get_output(self) -> OutputDataT:
         """Stream the whole response, validate the output and return it."""
         async for _ in self:
@@ -227,8 +241,8 @@ async def _stream_text_deltas() -> AsyncIterator[str]:
     def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
         """Stream [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
         if self._agent_stream_iterator is None:
-            self._agent_stream_iterator = _get_usage_checking_stream_response(
-                self._raw_stream_response, self._usage_limits, self.usage
+            self._agent_stream_iterator = _get_cancellation_aware_stream_response(
+                self._raw_stream_response, self._usage_limits, self.usage, lambda: self._cancelled
             )

         return self._agent_stream_iterator

Review comment (on the lambda): Related to what I wrote below, I don't understand why we need this lambda, instead of just pushing the cancellation down to the wrapped [StreamedResponse].
@@ -450,6 +464,18 @@ async def stream_responses(
         else:
             raise ValueError('No stream response or run result provided')  # pragma: no cover

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This will close the underlying network connection and cause any active iteration
+        over the stream to raise a StreamCancelled exception.
+
+        Subsequent calls to cancel() are safe and will not raise additional exceptions.
+        """
+        if self._stream_response is not None:
+            await self._stream_response.cancel()
+        # If there's no stream response, this is a no-op (already completed)
+
     async def get_output(self) -> OutputDataT:
         """Stream the whole response, validate and return it."""
         if self._run_result is not None:
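For reference, roughly how a caller would drive cancellation end to end once this is merged; a sketch assuming this PR's cancel() semantics, with an illustrative model name:

import asyncio

from pydantic_ai import Agent
from pydantic_ai.exceptions import StreamCancelled

agent = Agent('openai:gpt-4o-mini')  # illustrative model choice


async def main() -> None:
    async with agent.run_stream('Write a very long story') as result:
        collected = ''
        try:
            async for delta in result.stream_text(delta=True):
                collected += delta
                if len(collected) > 200:  # we have enough output
                    await result.cancel()  # safe to call more than once
        except StreamCancelled:
            pass  # expected exit after cancel()
        print(collected)


asyncio.run(main())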
@@ -526,21 +552,27 @@ class FinalResult(Generic[OutputDataT]):
     __repr__ = _utils.dataclasses_no_defaults_repr


-def _get_usage_checking_stream_response(
+def _get_cancellation_aware_stream_response(
     stream_response: models.StreamedResponse,
     limits: UsageLimits | None,
     get_usage: Callable[[], RunUsage],
+    is_cancelled: Callable[[], bool],
 ) -> AsyncIterator[ModelResponseStreamEvent]:
-    if limits is not None and limits.has_token_limits():
+    """Create an iterator that checks for cancellation and usage limits."""

-        async def _usage_checking_iterator():
-            async for item in stream_response:
-                limits.check_tokens(get_usage())
-                yield item
+    async def _cancellation_aware_iterator():
+        async for item in stream_response:
+            # Check for cancellation first
+            if is_cancelled():
+                raise exceptions.StreamCancelled()
+
+            # Then check usage limits if needed
+            if limits is not None and limits.has_token_limits():
+                limits.check_tokens(get_usage())
+            yield item

-        return _usage_checking_iterator()
-    else:
-        return aiter(stream_response)
+    return _cancellation_aware_iterator()


 def _get_deferred_tool_requests(

Review comment (on `raise exceptions.StreamCancelled()`): Why do we have to raise it here and inside the […]? Wouldn't the […]?
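The wrapper in isolation, as a runnable sketch; a local StreamCancelled stands in for the real exception, and the is_cancelled callback plays the role of the `lambda: self._cancelled` above:

import asyncio
from collections.abc import AsyncIterator, Callable


class StreamCancelled(Exception):
    """Local stand-in for pydantic_ai.exceptions.StreamCancelled."""


async def cancellation_aware(
    source: AsyncIterator[str],
    is_cancelled: Callable[[], bool],
) -> AsyncIterator[str]:
    # Re-check the callback before forwarding each item, like the PR's wrapper
    async for item in source:
        if is_cancelled():
            raise StreamCancelled()
        yield item


async def main() -> None:
    async def source() -> AsyncIterator[str]:
        for chunk in ('a', 'b', 'c'):
            yield chunk

    cancelled = False
    stream = cancellation_aware(source(), lambda: cancelled)
    print(await anext(stream))  # -> a
    cancelled = True  # flip the flag mid-iteration
    try:
        await anext(stream)
    except StreamCancelled:
        print('next chunk was not forwarded')


asyncio.run(main())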
tests/models/test_openai.py
@@ -15,6 +15,7 @@

 from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
 from pydantic_ai.builtin_tools import WebSearchTool
+from pydantic_ai.exceptions import StreamCancelled
 from pydantic_ai.messages import (
     AudioUrl,
     BinaryContent,
@@ -2923,6 +2924,75 @@ async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None
     assert result.output == snapshot('The capital of France is **Paris**.')


+async def test_stream_cancellation(allow_model_requests: None):
+    """Test that stream cancellation works correctly with mocked responses."""
+    # Create a simple stream
+    stream = [
+        text_chunk('Hello '),
+        text_chunk('world'),
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('Hello') as result:
+        # Cancel immediately and then try to iterate
+        await result.cancel()
+
+        # Now try to iterate - this should raise StreamCancelled
+        with pytest.raises(StreamCancelled):
+            async for content in result.stream_text(delta=True):
+                pytest.fail(f'Should not receive content after cancellation: {content}')
+
+
+async def test_multiple_cancel_calls(allow_model_requests: None):
+    """Test that multiple cancel calls are safe."""
+    stream = [
+        text_chunk('Hello '),
+        text_chunk('world '),
+        text_chunk('from '),
+        text_chunk('AI!'),
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('Hello world') as result:
+        # Call cancel multiple times - should be safe
+        await result.cancel()
+        await result.cancel()
+        await result.cancel()
+
+        # Try to iterate - should raise StreamCancelled
+        with pytest.raises(StreamCancelled):
+            async for content in result.stream_text(delta=True):
+                pytest.fail(f'Should not receive content after cancellation: {content}')
+
+
+async def test_stream_cancellation_immediate(allow_model_requests: None):
+    """Test immediate cancellation before iteration."""
+    stream = [
+        text_chunk('This should '),
+        text_chunk('not be '),
+        text_chunk('processed.'),
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('Tell me a story') as result:
+        # Cancel immediately before any iteration
+        await result.cancel()
+
+        # Try to iterate - should raise StreamCancelled immediately
+        with pytest.raises(StreamCancelled):
+            async for content in result.stream_text(delta=True):
+                pytest.fail(f'Should not receive any content after immediate cancellation: {content}')
+
+
 def test_deprecated_openai_model(openai_api_key: str):
     with pytest.warns(DeprecationWarning):
         from pydantic_ai.models.openai import OpenAIModel  # type: ignore[reportDeprecated]

Review comment (on `async with agent.run_stream('Hello world') as result:`): Can we add a test that cancels streaming in the middle of an unfinished tool call, like in the example at https://ai.pydantic.dev/agents/#streaming-events-and-final-output, after a […].
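Addressing the review comment above about cancelling mid-tool-call, a sketch of what such a test could look like; `tool_call_chunk` is a hypothetical helper (analogous to `text_chunk`) emitting a partial tool-call delta, and the node/stream structure follows the documented `agent.iter` pattern:

async def test_stream_cancellation_mid_tool_call(allow_model_requests: None):
    """Sketch: cancel while the model is still emitting an unfinished tool call."""
    stream = [
        tool_call_chunk('get_weather', '{"city": "Par'),  # arguments cut off mid-JSON
        tool_call_chunk(None, 'is"}'),
        chunk([]),
    ]
    mock_client = MockOpenAI.create_mock_stream(stream)
    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
    agent = Agent(m)

    async with agent.iter('What is the weather in Paris?') as run:
        async for node in run:
            if Agent.is_model_request_node(node):
                async with node.stream(run.ctx) as request_stream:
                    # Cancel as soon as the first partial tool-call event arrives
                    async for _event in request_stream:
                        await request_stream.cancel()
                        break
                    with pytest.raises(StreamCancelled):
                        async for _event in request_stream:
                            pass
                break  # stop driving the run after cancellation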
Review comment (overall): Do we need this exception at all? Like I wrote below: wouldn't the StreamedResponse just stop yielding once it's been cancelled and there are no more messages, meaning we shouldn't need to do anything special here?