diff --git a/src/agents/result.py b/src/agents/result.py
index 438d53af2..f5955904c 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -306,6 +306,7 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
         - A MaxTurnsExceeded exception if the agent exceeds the max_turns limit.
         - A GuardrailTripwireTriggered exception if a guardrail is tripped.
         """
+        cancelled = False
         try:
             while True:
                 self._check_errors()
@@ -320,7 +321,9 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
                 try:
                     item = await self._event_queue.get()
                 except asyncio.CancelledError:
-                    break
+                    cancelled = True
+                    self.cancel()
+                    raise
 
                 if isinstance(item, QueueCompleteSentinel):
                     # Await input guardrails if they are still running, so late
@@ -337,11 +340,16 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
                 yield item
                 self._event_queue.task_done()
         finally:
-            # Ensure main execution completes before cleanup to avoid race conditions
-            # with session operations
-            await self._await_task_safely(self._run_impl_task)
-            # Safely terminate all background tasks after main execution has finished
-            self._cleanup_tasks()
+            if cancelled:
+                # Cancellation should return promptly, so avoid waiting on long-running tasks.
+                # Tasks have already been cancelled above.
+                self._cleanup_tasks()
+            else:
+                # Ensure main execution completes before cleanup to avoid race conditions
+                # with session operations
+                await self._await_task_safely(self._run_impl_task)
+                # Safely terminate all background tasks after main execution has finished
+                self._cleanup_tasks()
 
         if self._stored_exception:
             raise self._stored_exception
diff --git a/tests/test_cancel_streaming.py b/tests/test_cancel_streaming.py
index ddf603f9f..76b78cb58 100644
--- a/tests/test_cancel_streaming.py
+++ b/tests/test_cancel_streaming.py
@@ -1,13 +1,31 @@
+import asyncio
 import json
+import time
 
 import pytest
+from openai.types.responses import ResponseCompletedEvent
 
 from agents import Agent, Runner
+from agents.stream_events import RawResponsesStreamEvent
 
 from .fake_model import FakeModel
 from .test_responses import get_function_tool, get_function_tool_call, get_text_message
 
 
+class SlowCompleteFakeModel(FakeModel):
+    """A FakeModel that delays before emitting the completed event in streaming."""
+
+    def __init__(self, delay_seconds: float):
+        super().__init__()
+        self._delay_seconds = delay_seconds
+
+    async def stream_response(self, *args, **kwargs):
+        async for ev in super().stream_response(*args, **kwargs):
+            if isinstance(ev, ResponseCompletedEvent) and self._delay_seconds > 0:
+                await asyncio.sleep(self._delay_seconds)
+            yield ev
+
+
 @pytest.mark.asyncio
 async def test_simple_streaming_with_cancel():
     model = FakeModel()
@@ -131,3 +149,30 @@ async def test_cancel_immediate_mode_explicit():
     assert result.is_complete
     assert result._event_queue.empty()
     assert result._cancel_mode == "immediate"
+
+
+@pytest.mark.asyncio
+async def test_stream_events_respects_asyncio_timeout_cancellation():
+    model = SlowCompleteFakeModel(delay_seconds=0.5)
+    model.set_next_output([get_text_message("Final response")])
+    agent = Agent(name="TimeoutTester", model=model)
+
+    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
+    event_iter = result.stream_events().__aiter__()
+
+    # Consume events until the output item is done so the next event is delayed.
+    while True:
+        event = await asyncio.wait_for(event_iter.__anext__(), timeout=1.0)
+        if (
+            isinstance(event, RawResponsesStreamEvent)
+            and event.data.type == "response.output_item.done"
+        ):
+            break
+
+    start = time.perf_counter()
+    with pytest.raises(asyncio.TimeoutError):
+        await asyncio.wait_for(event_iter.__anext__(), timeout=0.1)
+    elapsed = time.perf_counter() - start
+
+    assert elapsed < 0.3, "Cancellation should propagate promptly when waiting for events."
+    result.cancel()