diff --git a/amplifier_core/cancellation.py b/amplifier_core/cancellation.py index 5bd9476..5b44f2b 100644 --- a/amplifier_core/cancellation.py +++ b/amplifier_core/cancellation.py @@ -5,6 +5,8 @@ The app layer provides the POLICY (when to cancel). """ +import asyncio +import logging from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Awaitable, Callable, Set @@ -161,8 +163,22 @@ def on_cancel(self, callback: Callable[[], Awaitable[None]]) -> None: async def trigger_callbacks(self) -> None: """Trigger all registered cancellation callbacks.""" + _logger = logging.getLogger(__name__) + first_fatal = None for callback in self._on_cancel_callbacks: try: await callback() + except asyncio.CancelledError: + # CancelledError is a BaseException (Python 3.9+). Log and continue + # so all cancellation callbacks run. + _logger.warning("CancelledError in cancellation callback") except Exception: pass # Don't let callback errors prevent cancellation + except BaseException as e: + # Track fatal exceptions (KeyboardInterrupt, SystemExit) for re-raise + # after all callbacks complete. + _logger.warning(f"Fatal exception in cancellation callback: {e}") + if first_fatal is None: + first_fatal = e + if first_fatal is not None: + raise first_fatal diff --git a/amplifier_core/coordinator.py b/amplifier_core/coordinator.py index cf0f62a..e972d0b 100644 --- a/amplifier_core/coordinator.py +++ b/amplifier_core/coordinator.py @@ -12,6 +12,7 @@ identifiers and basic state necessary to make module boundaries work. """ +import asyncio import inspect import logging from collections.abc import Awaitable @@ -335,6 +336,14 @@ async def collect_contributions(self, channel: str) -> list[Any]: if result is not None: contributions.append(result) + except asyncio.CancelledError: + # CancelledError is a BaseException (Python 3.9+) - catch specifically. + # Stop collecting (honor cancellation signal) and return what we have. 
+ logger.warning( + f"Collection cancelled during contributor " + f"'{contributor['name']}' on channel '{channel}'" + ) + break except Exception as e: logger.warning( f"Contributor '{contributor['name']}' on channel '{channel}' failed: {e}" @@ -344,6 +353,7 @@ async def collect_contributions(self, channel: str) -> list[Any]: async def cleanup(self): """Call all registered cleanup functions.""" + first_fatal = None for cleanup_fn in reversed(self._cleanup_functions): try: if callable(cleanup_fn): @@ -353,8 +363,16 @@ async def cleanup(self): result = cleanup_fn() if inspect.iscoroutine(result): await result - except Exception as e: + except BaseException as e: + # Catch BaseException to survive asyncio.CancelledError (a BaseException + # subclass since Python 3.8) so remaining cleanup functions still run. + # Track non-Exception errors (CancelledError, KeyboardInterrupt, SystemExit) + # for re-raise after all cleanup completes. logger.error(f"Error during cleanup: {e}") + if first_fatal is None and not isinstance(e, Exception): + first_fatal = e + if first_fatal is not None: + raise first_fatal def reset_turn(self): """Reset per-turn tracking.
Call at turn boundaries.""" diff --git a/amplifier_core/hooks.py b/amplifier_core/hooks.py index 9b473e3..a8abbf0 100644 --- a/amplifier_core/hooks.py +++ b/amplifier_core/hooks.py @@ -70,18 +70,24 @@ def register( Returns: Unregister function """ - hook_handler = HookHandler(handler=handler, priority=priority, name=name or handler.__name__) + hook_handler = HookHandler( + handler=handler, priority=priority, name=name or handler.__name__ + ) self._handlers[event].append(hook_handler) self._handlers[event].sort() # Keep sorted by priority - logger.debug(f"Registered hook '{hook_handler.name}' for event '{event}' with priority {priority}") + logger.debug( + f"Registered hook '{hook_handler.name}' for event '{event}' with priority {priority}" + ) def unregister(): """Remove this handler from the registry.""" if hook_handler in self._handlers[event]: self._handlers[event].remove(hook_handler) - logger.debug(f"Unregistered hook '{hook_handler.name}' from event '{event}'") + logger.debug( + f"Unregistered hook '{hook_handler.name}' from event '{event}'" + ) return unregister @@ -140,11 +146,15 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult: result = await hook_handler.handler(event, current_data) if not isinstance(result, HookResult): - logger.warning(f"Handler '{hook_handler.name}' returned invalid result type") + logger.warning( + f"Handler '{hook_handler.name}' returned invalid result type" + ) continue if result.action == "deny": - logger.info(f"Event '{event}' denied by handler '{hook_handler.name}': {result.reason}") + logger.info( + f"Event '{event}' denied by handler '{hook_handler.name}': {result.reason}" + ) return result if result.action == "modify" and result.data is not None: @@ -154,15 +164,27 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult: # Collect inject_context actions for merging if result.action == "inject_context" and result.context_injection: inject_context_results.append(result) - 
logger.debug(f"Handler '{hook_handler.name}' returned inject_context") + logger.debug( + f"Handler '{hook_handler.name}' returned inject_context" + ) # Preserve ask_user (only first one, can't merge approvals) if result.action == "ask_user" and special_result is None: special_result = result logger.debug(f"Handler '{hook_handler.name}' returned ask_user") + except asyncio.CancelledError: + # CancelledError is a BaseException (Python 3.9+). Log and continue + # so all handlers observe the event (important for cleanup events + # like session:end that flow through emit). + logger.error( + f"CancelledError in hook handler '{hook_handler.name}' " + f"for event '{event}'" + ) except Exception as e: - logger.error(f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}") + logger.error( + f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}" + ) # Continue with other handlers even if one fails # If multiple inject_context results, merge them. @@ -173,7 +195,9 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult: merged_inject = self._merge_inject_context_results(inject_context_results) if special_result is None: special_result = merged_inject - logger.debug(f"Merged {len(inject_context_results)} inject_context results") + logger.debug( + f"Merged {len(inject_context_results)} inject_context results" + ) else: # ask_user already captured - don't overwrite it logger.debug( @@ -208,7 +232,9 @@ def _merge_inject_context_results(self, results: list[HookResult]) -> HookResult return results[0] # Combine all injections - combined_content = "\n\n".join(result.context_injection for result in results if result.context_injection) + combined_content = "\n\n".join( + result.context_injection for result in results if result.context_injection + ) # Use settings from first result (role, ephemeral, suppress_output) first = results[0] @@ -221,7 +247,9 @@ def _merge_inject_context_results(self, results: list[HookResult]) -> HookResult 
suppress_output=first.suppress_output, ) - async def emit_and_collect(self, event: str, data: dict[str, Any], timeout: float = 1.0) -> list[Any]: + async def emit_and_collect( + self, event: str, data: dict[str, Any], timeout: float = 1.0 + ) -> list[Any]: """ Emit event and collect data from all handler responses. @@ -247,27 +275,46 @@ async def emit_and_collect(self, event: str, data: dict[str, Any], timeout: floa logger.debug(f"No handlers for event '{event}'") return [] - logger.debug(f"Collecting responses for event '{event}' from {len(handlers)} handlers") + logger.debug( + f"Collecting responses for event '{event}' from {len(handlers)} handlers" + ) responses = [] for hook_handler in handlers: try: # Call handler with timeout - result = await asyncio.wait_for(hook_handler.handler(event, data), timeout=timeout) + result = await asyncio.wait_for( + hook_handler.handler(event, data), timeout=timeout + ) if not isinstance(result, HookResult): - logger.warning(f"Handler '{hook_handler.name}' returned invalid result type") + logger.warning( + f"Handler '{hook_handler.name}' returned invalid result type" + ) continue # Collect response data if present if result.data is not None: responses.append(result.data) - logger.debug(f"Collected response from handler '{hook_handler.name}'") + logger.debug( + f"Collected response from handler '{hook_handler.name}'" + ) except TimeoutError: - logger.warning(f"Handler '{hook_handler.name}' timed out after {timeout}s") + logger.warning( + f"Handler '{hook_handler.name}' timed out after {timeout}s" + ) + except asyncio.CancelledError: + # CancelledError is a BaseException (Python 3.9+). Log and continue + # so all handlers get a chance to respond. 
+ logger.error( + f"CancelledError in hook handler '{hook_handler.name}' " + f"for event '{event}'" + ) except Exception as e: - logger.error(f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}") + logger.error( + f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}" + ) # Continue with other handlers logger.debug(f"Collected {len(responses)} responses for event '{event}'") @@ -286,4 +333,7 @@ def list_handlers(self, event: str | None = None) -> dict[str, list[str]]: if event: handlers = self._handlers.get(event, []) return {event: [h.name for h in handlers if h.name is not None]} - return {evt: [h.name for h in handlers if h.name is not None] for evt, handlers in self._handlers.items()} + return { + evt: [h.name for h in handlers if h.name is not None] + for evt, handlers in self._handlers.items() + } diff --git a/amplifier_core/session.py b/amplifier_core/session.py index 44f7a32..d641d7a 100644 --- a/amplifier_core/session.py +++ b/amplifier_core/session.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -def _safe_exception_str(e: Exception) -> str: +def _safe_exception_str(e: BaseException) -> str: """ CRITICAL: Explicitly handle exception string conversion for Windows cp1252 compatibility. Default encoding can fail on non-cp1252 characters, causing a crash during error handling. @@ -432,8 +432,9 @@ async def execute(self, prompt: str) -> str: self.status.status = "completed" return result - except Exception as e: - # Check if this was a cancellation-related exception + except BaseException as e: + # Catch BaseException to handle asyncio.CancelledError (a BaseException + # subclass since Python 3.8). All paths re-raise after status tracking.
if self.coordinator.cancellation.is_cancelled: self.status.status = "cancelled" from .events import CANCEL_COMPLETED @@ -455,10 +456,13 @@ async def execute(self, prompt: str) -> str: async def cleanup(self: "AmplifierSession") -> None: """Clean up session resources.""" - await self.coordinator.cleanup() - # Clean up sys.path modifications - if self.loader: - self.loader.cleanup() + try: + await self.coordinator.cleanup() + finally: + # Clean up sys.path modifications - must always run even if + # coordinator cleanup raises (e.g., asyncio.CancelledError) + if self.loader: + self.loader.cleanup() async def __aenter__(self: "AmplifierSession"): """Async context manager entry.""" diff --git a/tests/test_cancellation_resilience.py b/tests/test_cancellation_resilience.py new file mode 100644 index 0000000..e9aeae4 --- /dev/null +++ b/tests/test_cancellation_resilience.py @@ -0,0 +1,308 @@ +"""Tests for asyncio.CancelledError resilience across kernel exception-handling sites. + +Since Python 3.8, CancelledError is a BaseException subclass, so bare `except Exception` +misses it. These tests verify each fixed site handles CancelledError correctly.
+""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from amplifier_core.cancellation import CancellationToken +from amplifier_core.coordinator import ModuleCoordinator +from amplifier_core.hooks import HookRegistry +from amplifier_core.models import HookResult +from amplifier_core.session import AmplifierSession + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def coordinator(): + """Create a minimal coordinator for testing.""" + + class MockSession: + session_id = "test-session" + + mock_session = MockSession() + return ModuleCoordinator(session=mock_session) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# 1. collect_contributions — CancelledError breaks, Exception continues +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_collect_contributions_breaks_on_cancelled_error(coordinator): + """CancelledError in a contributor stops collection and returns partial results.""" + called = [] + + coordinator.register_contributor("ch", "mod1", lambda: "result-1") + + async def cancelling_contributor(): + called.append("mod2") + raise asyncio.CancelledError() + + coordinator.register_contributor("ch", "mod2", cancelling_contributor) + + def mod3_contributor(): + called.append("mod3") + return "result-3" + + coordinator.register_contributor("ch", "mod3", mod3_contributor) + + contributions = await coordinator.collect_contributions("ch") + + # Only mod1 returned before cancellation; mod3 never ran + assert contributions == ["result-1"] + assert "mod2" in called + assert "mod3" not in called + + +@pytest.mark.asyncio +async def test_collect_contributions_exception_continues(coordinator): + """A regular Exception in a contributor doesn't stop collection.""" + 
coordinator.register_contributor("ch", "mod1", lambda: "result-1") + + def failing_contributor(): + raise RuntimeError("boom") + + coordinator.register_contributor("ch", "mod2", failing_contributor) + coordinator.register_contributor("ch", "mod3", lambda: "result-3") + + contributions = await coordinator.collect_contributions("ch") + + assert "result-1" in contributions + assert "result-3" in contributions + assert len(contributions) == 2 + + +# --------------------------------------------------------------------------- +# 2. coordinator.cleanup — survives CancelledError, re-raises fatal exceptions +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cleanup_continues_after_cancelled_error(coordinator): + """All cleanup functions run even if one raises CancelledError. + + CancelledError is a BaseException (not Exception), so it is tracked as + fatal and re-raised after all cleanup functions complete. + """ + called = [] + + async def cleanup1(): + called.append(1) + raise asyncio.CancelledError() + + def cleanup2(): + called.append(2) + + def cleanup3(): + called.append(3) + + coordinator.register_cleanup(cleanup1) + coordinator.register_cleanup(cleanup2) + coordinator.register_cleanup(cleanup3) + + # cleanup runs in reverse order: 3, 2, 1 + # CancelledError is re-raised after all cleanup completes + with pytest.raises(asyncio.CancelledError): + await coordinator.cleanup() + + assert sorted(called) == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_cleanup_reraises_keyboard_interrupt_after_completing(coordinator): + """KeyboardInterrupt is re-raised after all cleanup functions have run.""" + called = [] + + def cleanup1(): + called.append(1) + raise KeyboardInterrupt() + + def cleanup2(): + called.append(2) + + coordinator.register_cleanup(cleanup1) + coordinator.register_cleanup(cleanup2) + + # Reverse order: cleanup2 runs first, then cleanup1 raises + with pytest.raises(KeyboardInterrupt): + 
await coordinator.cleanup() + + assert sorted(called) == [1, 2] + + +@pytest.mark.asyncio +async def test_cleanup_reraises_system_exit_after_completing(coordinator): + """SystemExit is re-raised after all cleanup functions have run.""" + called = [] + + def cleanup1(): + called.append(1) + raise SystemExit(1) + + def cleanup2(): + called.append(2) + + coordinator.register_cleanup(cleanup1) + coordinator.register_cleanup(cleanup2) + + with pytest.raises(SystemExit): + await coordinator.cleanup() + + assert sorted(called) == [1, 2] + + +# --------------------------------------------------------------------------- +# 3. session.cleanup — try/finally ensures loader.cleanup always runs +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_session_cleanup_runs_loader_even_if_coordinator_fails(): + """loader.cleanup() runs even when coordinator.cleanup() raises CancelledError.""" + minimal_config = { + "session": { + "orchestrator": "test-orch", + "context": "test-ctx", + }, + } + session = AmplifierSession(config=minimal_config) + + # Mock coordinator.cleanup to raise CancelledError + session.coordinator.cleanup = AsyncMock(side_effect=asyncio.CancelledError()) + # Mock loader.cleanup to track that it was called + session.loader.cleanup = MagicMock() + + with pytest.raises(asyncio.CancelledError): + await session.cleanup() + + session.loader.cleanup.assert_called_once() + + +# --------------------------------------------------------------------------- +# 4. 
hooks.emit — CancelledError in a handler doesn't skip remaining handlers +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emit_continues_after_cancelled_error(): + """All hook handlers run even if one raises CancelledError.""" + registry = HookRegistry() + called = [] + + async def handler1(event, data): + called.append("h1") + raise asyncio.CancelledError() + + async def handler2(event, data): + called.append("h2") + return HookResult(action="continue") + + async def handler3(event, data): + called.append("h3") + return HookResult(action="continue") + + registry.register("test:event", handler1, priority=0, name="h1") + registry.register("test:event", handler2, priority=1, name="h2") + registry.register("test:event", handler3, priority=2, name="h3") + + result = await registry.emit("test:event", {"key": "value"}) + + assert called == ["h1", "h2", "h3"] + assert result.action == "continue" + + +# --------------------------------------------------------------------------- +# 5. 
hooks.emit_and_collect — same pattern for the collection variant +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_emit_and_collect_continues_after_cancelled_error(): + """emit_and_collect runs all handlers even if one raises CancelledError.""" + registry = HookRegistry() + called = [] + + async def handler1(event, data): + called.append("h1") + raise asyncio.CancelledError() + + async def handler2(event, data): + called.append("h2") + return HookResult(action="continue", data={"from": "h2"}) + + async def handler3(event, data): + called.append("h3") + return HookResult(action="continue", data={"from": "h3"}) + + registry.register("test:event", handler1, priority=0, name="h1") + registry.register("test:event", handler2, priority=1, name="h2") + registry.register("test:event", handler3, priority=2, name="h3") + + responses = await registry.emit_and_collect("test:event", {"key": "value"}) + + assert called == ["h1", "h2", "h3"] + assert len(responses) == 2 + assert {"from": "h2"} in responses + assert {"from": "h3"} in responses + + +# --------------------------------------------------------------------------- +# 6. 
cancellation.trigger_callbacks — survives CancelledError, re-raises fatal +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_trigger_callbacks_continues_after_cancelled_error(): + """All cancellation callbacks run even if one raises CancelledError.""" + token = CancellationToken() + called = [] + + async def cb1(): + called.append(1) + raise asyncio.CancelledError() + + async def cb2(): + called.append(2) + + async def cb3(): + called.append(3) + + token.on_cancel(cb1) + token.on_cancel(cb2) + token.on_cancel(cb3) + + await token.trigger_callbacks() + + assert called == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_trigger_callbacks_reraises_keyboard_interrupt_after_completing(): + """KeyboardInterrupt is re-raised after all cancellation callbacks run.""" + token = CancellationToken() + called = [] + + async def cb1(): + called.append(1) + raise KeyboardInterrupt() + + async def cb2(): + called.append(2) + + token.on_cancel(cb1) + token.on_cancel(cb2) + + with pytest.raises(KeyboardInterrupt): + await token.trigger_callbacks() + + assert called == [1, 2]