8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -407,8 +407,12 @@ async def stream(
)
yield agent_stream
# In case the user didn't manually consume the full stream, ensure it is fully consumed here,
# otherwise usage won't be properly counted:
async for _ in agent_stream:
# However, if the stream was cancelled, we should not consume further.
try:
async for _ in agent_stream:
pass
except exceptions.StreamCancelled:
Collaborator:
Do we need this exception at all?

Like I wrote below: Wouldn't the StreamedResponse just stop yielding once it's been cancelled and there are no more messages, meaning we shouldn't need to do anything special here?
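For illustration, a minimal sketch of the alternative that comment describes: if a cancelled StreamedResponse simply stops yielding, the pre-existing drain loop could stay as it was and no exception handling would be needed here.

    # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
    # otherwise usage won't be properly counted; a cancelled stream would simply yield nothing more:
    async for _ in agent_stream:
        pass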

# Stream was cancelled - don't consume further
pass

model_response = streamed_response.get()
9 changes: 9 additions & 0 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -24,6 +24,7 @@
'UsageLimitExceeded',
'ModelHTTPError',
'FallbackExceptionGroup',
'StreamCancelled',
)


@@ -162,6 +163,14 @@ class FallbackExceptionGroup(ExceptionGroup):
"""A group of exceptions that can be raised when all fallback models fail."""


class StreamCancelled(Exception):
"""Exception raised when a streaming response is cancelled."""

def __init__(self, message: str = 'Stream was cancelled'):
self.message = message
Collaborator:
No need to have a message as it's not used anywhere
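A sketch of the simplified exception that comment implies, with the unused message handling dropped:

    class StreamCancelled(Exception):
        """Exception raised when a streaming response is cancelled."""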

super().__init__(message)


class ToolRetryError(Exception):
"""Exception used to signal a `ToolRetry` message should be returned to the LLM."""

8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -641,6 +641,14 @@ def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
raise NotImplementedError()

async def cancel(self) -> None:
Collaborator:
We should document this feature in the Streaming docs

"""Cancel the streaming response.

This should close the underlying network connection and cause any active iteration
to raise a StreamCancelled exception. The default implementation is a no-op.
"""
pass
Collaborator:
I think this should raise NotImplementedError to not silently keep the stream going when the user thought they canceled it
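A minimal sketch of that suggestion, assuming the base method keeps the same signature:

    async def cancel(self) -> None:
        """Cancel the streaming response.

        Raises NotImplementedError so that models without cancellation support fail
        loudly instead of silently continuing to stream.
        """
        raise NotImplementedError(f'Cancellation is not supported by {type(self).__name__}')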



ALLOW_MODEL_REQUESTS = True
"""Whether to allow requests to models.
15 changes: 14 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -17,7 +17,7 @@
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..builtin_tools import CodeExecutionTool, WebSearchTool
from ..exceptions import UserError
from ..exceptions import StreamCancelled, UserError
from ..messages import (
AudioUrl,
BinaryContent,
@@ -1347,9 +1347,14 @@ class OpenAIStreamedResponse(StreamedResponse):
_response: AsyncIterable[ChatCompletionChunk]
_timestamp: datetime
_provider_name: str
_cancelled: bool = field(default=False, init=False)

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
async for chunk in self._response:
# Check for cancellation before processing each chunk
if self._cancelled:
Collaborator:
Shouldn't we do this after recording the usage?
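A short sketch of the reordering that comment suggests, so usage for a chunk that has already been received is still counted before the cancellation check:

    async for chunk in self._response:
        self._usage += _map_usage(chunk)  # record usage for the chunk we already received
        if self._cancelled:
            raise StreamCancelled('OpenAI stream was cancelled')
        # ... continue processing the chunk as before ...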

raise StreamCancelled('OpenAI stream was cancelled')
Collaborator:
Will this actually cause OpenAI to cleanly close the stream? Shouldn't we call await AsyncStream.close() or something?

They also have their own # Ensure the entire stream is consumed:

https://github.com/openai/openai-python/blob/4756247cee3d9548397b26a29109e76cc9522379/src/openai/_streaming.py#L216-L222

Note that right now, OpenAIStreamedResponse only has access to AsyncIterable[ChatCompletionChunk], but that's derived from the AsyncStream[ChatCompletionChunk]:

    async def _process_streamed_response(
        self, response: AsyncStream[ChatCompletionChunk], model_request_parameters: ModelRequestParameters
    ) -> OpenAIStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        peekable_response = _utils.PeekableAsyncStream(response)

So in order to access that cancel method, we may need to put it on _utils.PeekableAsyncStream as well, and then forward it to the underlying stream.
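A rough sketch of that forwarding, assuming PeekableAsyncStream keeps a reference to the wrapped stream; the peeking and iteration logic is omitted, and AsyncStream.close() is the openai-python method linked above:

    from collections.abc import AsyncIterable
    from typing import Generic, TypeVar

    T = TypeVar('T')

    class PeekableAsyncStream(Generic[T]):
        """Only the cancellation-forwarding part is sketched here."""

        def __init__(self, stream: AsyncIterable[T]) -> None:
            self._stream = stream

        async def cancel(self) -> None:
            # Forward to openai's AsyncStream.close(), which releases the HTTP response.
            close = getattr(self._stream, 'close', None)
            if close is not None:
                await close()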


self._usage += _map_usage(chunk)

if chunk.id and self.provider_response_id is None:
@@ -1418,6 +1423,14 @@ def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
return self._timestamp

async def cancel(self) -> None:
Collaborator:
This doesn't need to be async if the recommended behavior is to always just set a flag and then cancel on the next iteration

"""Cancel the streaming response.

This marks the stream as cancelled, which will cause the iterator to raise
a StreamCancelled exception on the next iteration.
"""
self._cancelled = True


@dataclass
class OpenAIResponsesStreamedResponse(StreamedResponse):
Collaborator:
Can we implement this for the OpenAI Responses API as well? I'm OK leaving Google and Anthropic off for now, although it'd be amazing if we supported those as well :)
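For illustration, a sketch of how the same flag-based approach might carry over to OpenAIResponsesStreamedResponse; these members mirror OpenAIStreamedResponse above and are not part of this PR:

    _cancelled: bool = field(default=False, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        async for chunk in self._response:
            if self._cancelled:
                raise StreamCancelled('OpenAI Responses stream was cancelled')
            # ... existing Responses event handling ...

    async def cancel(self) -> None:
        self._cancelled = True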

52 changes: 42 additions & 10 deletions pydantic_ai_slim/pydantic_ai/result.py
@@ -54,6 +54,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):

_agent_stream_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
_initial_run_ctx_usage: RunUsage = field(init=False)
_cancelled: bool = field(default=False, init=False)

def __post_init__(self):
self._initial_run_ctx_usage = copy(self._run_ctx.usage)
@@ -123,6 +124,19 @@ def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
return self._raw_stream_response.timestamp

async def cancel(self) -> None:
"""Cancel the streaming response.

This will close the underlying network connection and cause any active iteration
over the stream to raise a StreamCancelled exception.

Subsequent calls to cancel() are safe and will not raise additional exceptions.
"""
if not self._cancelled:
self._cancelled = True
# Cancel the underlying stream response
await self._raw_stream_response.cancel()

async def get_output(self) -> OutputDataT:
"""Stream the whole response, validate the output and return it."""
async for _ in self:
@@ -227,8 +241,8 @@ async def _stream_text_deltas() -> AsyncIterator[str]:
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
"""Stream [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
if self._agent_stream_iterator is None:
self._agent_stream_iterator = _get_usage_checking_stream_response(
self._raw_stream_response, self._usage_limits, self.usage
self._agent_stream_iterator = _get_cancellation_aware_stream_response(
self._raw_stream_response, self._usage_limits, self.usage, lambda: self._cancelled
Collaborator:
Related to what I wrote below, I don't understand why we need this lambda, instead of just pushing the cancelation down to the wrapped StreamedResponse, and then relying on that to stop yielding events

)

return self._agent_stream_iterator
@@ -450,6 +464,18 @@ async def stream_responses(
else:
raise ValueError('No stream response or run result provided') # pragma: no cover

async def cancel(self) -> None:
"""Cancel the streaming response.

This will close the underlying network connection and cause any active iteration
over the stream to raise a StreamCancelled exception.

Subsequent calls to cancel() are safe and will not raise additional exceptions.
"""
if self._stream_response is not None:
await self._stream_response.cancel()
# If there's no stream response, this is a no-op (already completed)

async def get_output(self) -> OutputDataT:
"""Stream the whole response, validate and return it."""
if self._run_result is not None:
@@ -526,21 +552,27 @@ class FinalResult(Generic[OutputDataT]):
__repr__ = _utils.dataclasses_no_defaults_repr


def _get_usage_checking_stream_response(
def _get_cancellation_aware_stream_response(
stream_response: models.StreamedResponse,
limits: UsageLimits | None,
get_usage: Callable[[], RunUsage],
is_cancelled: Callable[[], bool],
) -> AsyncIterator[ModelResponseStreamEvent]:
if limits is not None and limits.has_token_limits():
"""Create an iterator that checks for cancellation and usage limits."""

async def _usage_checking_iterator():
async for item in stream_response:
async def _cancellation_aware_iterator():
async for item in stream_response:
# Check for cancellation first
if is_cancelled():
raise exceptions.StreamCancelled()
Collaborator:
Why do we have to raise it here and inside the StreamedResponse?

Wouldn't the stream_response just stop yielding once it's cancelled and there are no more messages, meaning we shouldn't need to do anything special here?
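For reference, a sketch of the simpler shape that comment implies: keep the original usage-checking wrapper (the removed lines above) and rely on a cancelled StreamedResponse to stop yielding on its own.

    def _get_usage_checking_stream_response(
        stream_response: models.StreamedResponse,
        limits: UsageLimits | None,
        get_usage: Callable[[], RunUsage],
    ) -> AsyncIterator[ModelResponseStreamEvent]:
        if limits is not None and limits.has_token_limits():

            async def _usage_checking_iterator():
                async for item in stream_response:
                    limits.check_tokens(get_usage())
                    yield item

            return _usage_checking_iterator()
        else:
            return aiter(stream_response)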


# Then check usage limits if needed
if limits is not None and limits.has_token_limits():
limits.check_tokens(get_usage())
yield item

return _usage_checking_iterator()
else:
return aiter(stream_response)
yield item

return _cancellation_aware_iterator()


def _get_deferred_tool_requests(
70 changes: 70 additions & 0 deletions tests/models/test_openai.py
@@ -15,6 +15,7 @@

from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.exceptions import StreamCancelled
from pydantic_ai.messages import (
AudioUrl,
BinaryContent,
@@ -2923,6 +2924,75 @@ async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None
assert result.output == snapshot('The capital of France is **Paris**.')


async def test_stream_cancellation(allow_model_requests: None):
"""Test that stream cancellation works correctly with mocked responses."""
# Create a simple stream
stream = [
text_chunk('Hello '),
text_chunk('world'),
chunk([]),
]
mock_client = MockOpenAI.create_mock_stream(stream)
m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
agent = Agent(m)

async with agent.run_stream('Hello') as result:
# Cancel immediately and then try to iterate
await result.cancel()

# Now try to iterate - this should raise StreamCancelled
with pytest.raises(StreamCancelled):
async for content in result.stream_text(delta=True):
pytest.fail(f'Should not receive content after cancellation: {content}')


async def test_multiple_cancel_calls(allow_model_requests: None):
"""Test that multiple cancel calls are safe."""
stream = [
text_chunk('Hello '),
text_chunk('world '),
text_chunk('from '),
text_chunk('AI!'),
chunk([]),
]
mock_client = MockOpenAI.create_mock_stream(stream)
m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
agent = Agent(m)

async with agent.run_stream('Hello world') as result:
Collaborator:
Can we add a test that cancels streaming in the middle of an unfinished tool call, like in the example at https://ai.pydantic.dev/agents/#streaming-events-and-final-output after a ToolCallPartDelta, and then see what the final ModelResponse in result.all_messages() looks like? I imagine it would have an incomplete ToolCallPart. I wonder if we should indicate on the ModelResponse somehow that it's incomplete because it's been canceled, and cannot be used as message_history, for example.
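A rough sketch of such a test, following the agent.iter streaming example linked above; the tool_call_chunk helper and the final inspection step are assumptions rather than existing code in this module, and PartDeltaEvent/ToolCallPartDelta come from pydantic_ai.messages:

    async def test_stream_cancellation_mid_tool_call(allow_model_requests: None):
        """Cancel while a tool call is still being streamed."""
        stream = [
            tool_call_chunk('get_weather', '{"city": "Par'),  # hypothetical helper: partial tool-call delta
            tool_call_chunk(None, 'is"}'),  # never consumed once cancelled
            chunk([]),
        ]
        mock_client = MockOpenAI.create_mock_stream(stream)
        m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
        agent = Agent(m)

        async with agent.iter('What is the weather in Paris?') as run:
            async for node in run:
                if Agent.is_model_request_node(node):
                    async with node.stream(run.ctx) as request_stream:
                        with pytest.raises(StreamCancelled):
                            async for event in request_stream:
                                if isinstance(event, PartDeltaEvent) and isinstance(event.delta, ToolCallPartDelta):
                                    await request_stream.cancel()
                    break  # stop the run after the cancelled model request

        # Open question from the comment: the final ModelResponse presumably ends with an
        # incomplete ToolCallPart; it may also need to be flagged as cancelled so it isn't
        # reused as message_history.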

# Call cancel multiple times - should be safe
await result.cancel()
await result.cancel()
await result.cancel()

# Try to iterate - should raise StreamCancelled
with pytest.raises(StreamCancelled):
async for content in result.stream_text(delta=True):
pytest.fail(f'Should not receive content after cancellation: {content}')


async def test_stream_cancellation_immediate(allow_model_requests: None):
"""Test immediate cancellation before iteration."""
stream = [
text_chunk('This should '),
text_chunk('not be '),
text_chunk('processed.'),
chunk([]),
]
mock_client = MockOpenAI.create_mock_stream(stream)
m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
agent = Agent(m)

async with agent.run_stream('Tell me a story') as result:
# Cancel immediately before any iteration
await result.cancel()

# Try to iterate - should raise StreamCancelled immediately
with pytest.raises(StreamCancelled):
async for content in result.stream_text(delta=True):
pytest.fail(f'Should not receive any content after immediate cancellation: {content}')


def test_deprecated_openai_model(openai_api_key: str):
with pytest.warns(DeprecationWarning):
from pydantic_ai.models.openai import OpenAIModel # type: ignore[reportDeprecated]