diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index e121ec475..223117443 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -407,8 +407,12 @@ async def stream(
         )
         yield agent_stream
         # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
-        # otherwise usage won't be properly counted:
-        async for _ in agent_stream:
+        # otherwise usage won't be properly counted. If the stream was cancelled, stop consuming instead:
+        try:
+            async for _ in agent_stream:
+                pass
+        except exceptions.StreamCancelled:
+            # Stream was cancelled; don't consume further.
             pass

         model_response = streamed_response.get()
diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py
index 58a7686e0..79b7041e1 100644
--- a/pydantic_ai_slim/pydantic_ai/exceptions.py
+++ b/pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -24,6 +24,7 @@
     'UsageLimitExceeded',
     'ModelHTTPError',
     'FallbackExceptionGroup',
+    'StreamCancelled',
 )


@@ -162,6 +163,14 @@
 class FallbackExceptionGroup(ExceptionGroup):
     """A group of exceptions that can be raised when all fallback models fail."""


+class StreamCancelled(Exception):
+    """Exception raised when a streaming response is cancelled."""
+
+    def __init__(self, message: str = 'Stream was cancelled'):
+        self.message = message
+        super().__init__(message)
+
+
 class ToolRetryError(Exception):
     """Exception used to signal a `ToolRetry` message should be returned to the LLM."""

diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index f76ec2efd..c4a2ab168 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -641,6 +641,14 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         raise NotImplementedError()

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        Implementations should stop the stream and cause any active iteration to raise
+        a StreamCancelled exception. The default implementation is a no-op.
+        """
+        pass
+

 ALLOW_MODEL_REQUESTS = True
 """Whether to allow requests to models.
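The cancellation model introduced here is cooperative: cancel() flips a flag, and the event iterator raises StreamCancelled before handling the next chunk. A minimal standalone sketch of that pattern, assuming this patch is applied; cancellable_stream and the chunk values are illustrative, only StreamCancelled comes from the patch:

import asyncio

from pydantic_ai.exceptions import StreamCancelled  # added by this patch


async def cancellable_stream(chunks, is_cancelled):
    # Same shape as OpenAIStreamedResponse._get_event_iterator below: check the
    # flag before handling each chunk and raise StreamCancelled if it is set.
    for chunk in chunks:
        if is_cancelled():
            raise StreamCancelled('stream was cancelled')
        await asyncio.sleep(0)  # stand-in for a network read
        yield chunk


async def main():
    cancelled = False
    received = []
    try:
        async for text in cancellable_stream(['Hello ', 'world'], lambda: cancelled):
            received.append(text)
            cancelled = True  # flip the flag after the first chunk
    except StreamCancelled:
        pass
    assert received == ['Hello ']  # the second chunk is never yielded


asyncio.run(main())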
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 2b83fcad3..5ade040c1 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -17,7 +17,7 @@
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..builtin_tools import CodeExecutionTool, WebSearchTool
-from ..exceptions import UserError
+from ..exceptions import StreamCancelled, UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,
@@ -1347,9 +1347,14 @@ class OpenAIStreamedResponse(StreamedResponse):
     _response: AsyncIterable[ChatCompletionChunk]
     _timestamp: datetime
     _provider_name: str
+    _cancelled: bool = field(default=False, init=False)

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
+            # Check for cancellation before processing each chunk
+            if self._cancelled:
+                raise StreamCancelled('OpenAI stream was cancelled')
+
             self._usage += _map_usage(chunk)

             if chunk.id and self.provider_response_id is None:
@@ -1418,6 +1423,14 @@ def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
         return self._timestamp

+    async def cancel(self) -> None:
+        """Cancel the streaming response.
+
+        This only marks the stream as cancelled; the iterator raises StreamCancelled
+        before processing the next chunk, so cancellation is cooperative.
+        """
+        self._cancelled = True
+

 @dataclass
 class OpenAIResponsesStreamedResponse(StreamedResponse):
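Because cancel() here only sets the flag, the HTTP connection stays open until the next chunk arrives and the iterator raises. A sketch of a stricter variant that also closes the transport; the close() call is an assumption about the underlying openai.AsyncStream, and it is guarded because _response is only typed as AsyncIterable[ChatCompletionChunk]:

    async def cancel(self) -> None:
        """Cancel the stream and, where possible, close the underlying transport."""
        self._cancelled = True
        # openai.AsyncStream exposes an async close(); guard the call since
        # _response is only typed as AsyncIterable[ChatCompletionChunk].
        close = getattr(self._response, 'close', None)
        if close is not None:
            await close()

Closing the transport may end iteration before the flag check runs, so callers of this variant would need to treat normal termination after cancel() as equivalent to StreamCancelled.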
+ """ + if not self._cancelled: + self._cancelled = True + # Cancel the underlying stream response + await self._raw_stream_response.cancel() + async def get_output(self) -> OutputDataT: """Stream the whole response, validate the output and return it.""" async for _ in self: @@ -227,8 +241,8 @@ async def _stream_text_deltas() -> AsyncIterator[str]: def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: """Stream [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.""" if self._agent_stream_iterator is None: - self._agent_stream_iterator = _get_usage_checking_stream_response( - self._raw_stream_response, self._usage_limits, self.usage + self._agent_stream_iterator = _get_cancellation_aware_stream_response( + self._raw_stream_response, self._usage_limits, self.usage, lambda: self._cancelled ) return self._agent_stream_iterator @@ -450,6 +464,18 @@ async def stream_responses( else: raise ValueError('No stream response or run result provided') # pragma: no cover + async def cancel(self) -> None: + """Cancel the streaming response. + + This will close the underlying network connection and cause any active iteration + over the stream to raise a StreamCancelled exception. + + Subsequent calls to cancel() are safe and will not raise additional exceptions. + """ + if self._stream_response is not None: + await self._stream_response.cancel() + # If there's no stream response, this is a no-op (already completed) + async def get_output(self) -> OutputDataT: """Stream the whole response, validate and return it.""" if self._run_result is not None: @@ -526,21 +552,27 @@ class FinalResult(Generic[OutputDataT]): __repr__ = _utils.dataclasses_no_defaults_repr -def _get_usage_checking_stream_response( +def _get_cancellation_aware_stream_response( stream_response: models.StreamedResponse, limits: UsageLimits | None, get_usage: Callable[[], RunUsage], + is_cancelled: Callable[[], bool], ) -> AsyncIterator[ModelResponseStreamEvent]: - if limits is not None and limits.has_token_limits(): + """Create an iterator that checks for cancellation and usage limits.""" - async def _usage_checking_iterator(): - async for item in stream_response: + async def _cancellation_aware_iterator(): + async for item in stream_response: + # Check for cancellation first + if is_cancelled(): + raise exceptions.StreamCancelled() + + # Then check usage limits if needed + if limits is not None and limits.has_token_limits(): limits.check_tokens(get_usage()) - yield item - return _usage_checking_iterator() - else: - return aiter(stream_response) + yield item + + return _cancellation_aware_iterator() def _get_deferred_tool_requests( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 3ae08f689..79c8f86ff 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -15,6 +15,7 @@ from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError from pydantic_ai.builtin_tools import WebSearchTool +from pydantic_ai.exceptions import StreamCancelled from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -2923,6 +2924,75 @@ async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None assert result.output == snapshot('The capital of France is **Paris**.') +async def test_stream_cancellation(allow_model_requests: None): + """Test that stream cancellation works correctly with mocked responses.""" + # Create a simple stream + stream = [ + text_chunk('Hello '), + text_chunk('world'), + chunk([]), + ] + mock_client = 
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 3ae08f689..79c8f86ff 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -15,6 +15,7 @@
 from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
 from pydantic_ai.builtin_tools import WebSearchTool
+from pydantic_ai.exceptions import StreamCancelled
 from pydantic_ai.messages import (
     AudioUrl,
     BinaryContent,
@@ -2923,6 +2924,75 @@ async def test_openai_model_cerebras_provider_harmony(allow_model_requests: None):
     assert result.output == snapshot('The capital of France is **Paris**.')


+async def test_stream_cancellation(allow_model_requests: None):
+    """Test that stream cancellation works correctly with mocked responses."""
+    # Create a simple stream
+    stream = [
+        text_chunk('Hello '),
+        text_chunk('world'),
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('Hello') as result:
+        # Cancel immediately and then try to iterate
+        await result.cancel()
+
+        # Now try to iterate; this should raise StreamCancelled
+        with pytest.raises(StreamCancelled):
+            async for content in result.stream_text(delta=True):
+                pytest.fail(f'Should not receive content after cancellation: {content}')
+
+
+async def test_multiple_cancel_calls(allow_model_requests: None):
+    """Test that multiple cancel calls are safe."""
+    stream = [
+        text_chunk('Hello '),
+        text_chunk('world '),
+        text_chunk('from '),
+        text_chunk('AI!'),
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('Hello world') as result:
+        # Call cancel multiple times; this should be safe
+        await result.cancel()
+        await result.cancel()
+        await result.cancel()
+
+        # Try to iterate; this should raise StreamCancelled
+        with pytest.raises(StreamCancelled):
+            async for content in result.stream_text(delta=True):
+                pytest.fail(f'Should not receive content after cancellation: {content}')
+
+
+async def test_stream_cancellation_immediate(allow_model_requests: None):
+    """Test immediate cancellation before any iteration."""
+    stream = [
+        text_chunk('This should '),
+        text_chunk('not be '),
+        text_chunk('processed.'),
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m)
+
+    async with agent.run_stream('Tell me a story') as result:
+        # Cancel immediately before any iteration
+        await result.cancel()
+
+        # Try to iterate; this should raise StreamCancelled immediately
+        with pytest.raises(StreamCancelled):
+            async for content in result.stream_text(delta=True):
+                pytest.fail(f'Should not receive any content after immediate cancellation: {content}')
+
+
 def test_deprecated_openai_model(openai_api_key: str):
     with pytest.warns(DeprecationWarning):
         from pydantic_ai.models.openai import OpenAIModel  # type: ignore[reportDeprecated]
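All three tests above cancel before iterating, so the flag check between chunks is never exercised mid-stream. A sketch of such a test, reusing the same helpers; it is not part of this diff, debounce_by=None is passed so each delta is yielded individually rather than batched, and the exact received deltas are an assumption about the mock's chunking:

async def test_stream_cancellation_mid_stream(allow_model_requests: None):
    """Sketch: cancel after the first delta; the next pull raises StreamCancelled."""
    stream = [
        text_chunk('Hello '),
        text_chunk('world'),
        chunk([]),
    ]
    mock_client = MockOpenAI.create_mock_stream(stream)
    m = OpenAIChatModel('gpt-4o-mini', provider=OpenAIProvider(openai_client=mock_client))
    agent = Agent(m)

    async with agent.run_stream('Hello') as result:
        received: list[str] = []
        with pytest.raises(StreamCancelled):
            # debounce_by=None so each delta is yielded individually
            async for content in result.stream_text(delta=True, debounce_by=None):
                received.append(content)
                await result.cancel()  # cancel after the first delta
        assert received == ['Hello ']  # expected under these assumptions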