Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions amplifier_core/cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
The app layer provides the POLICY (when to cancel).
"""

import asyncio
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Awaitable, Callable, Set
Expand Down Expand Up @@ -161,8 +163,22 @@ def on_cancel(self, callback: Callable[[], Awaitable[None]]) -> None:

async def trigger_callbacks(self) -> None:
"""Trigger all registered cancellation callbacks."""
_logger = logging.getLogger(__name__)
first_fatal = None
for callback in self._on_cancel_callbacks:
try:
await callback()
except asyncio.CancelledError:
# CancelledError is a BaseException (Python 3.9+). Log and continue
# so all cancellation callbacks run.
_logger.warning("CancelledError in cancellation callback")
except Exception:
pass # Don't let callback errors prevent cancellation
except BaseException as e:
# Track fatal exceptions (KeyboardInterrupt, SystemExit) for re-raise
# after all callbacks complete.
_logger.warning(f"Fatal exception in cancellation callback: {e}")
if first_fatal is None:
first_fatal = e
if first_fatal is not None:
raise first_fatal
20 changes: 19 additions & 1 deletion amplifier_core/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
identifiers and basic state necessary to make module boundaries work.
"""

import asyncio
import inspect
import logging
from collections.abc import Awaitable
Expand Down Expand Up @@ -335,6 +336,14 @@ async def collect_contributions(self, channel: str) -> list[Any]:

if result is not None:
contributions.append(result)
except asyncio.CancelledError:
# CancelledError is a BaseException (Python 3.9+) - catch specifically.
# Stop collecting (honor cancellation signal) and return what we have.
logger.warning(
f"Collection cancelled during contributor "
f"'{contributor['name']}' on channel '{channel}'"
)
break
except Exception as e:
logger.warning(
f"Contributor '{contributor['name']}' on channel '{channel}' failed: {e}"
Expand All @@ -344,6 +353,7 @@ async def collect_contributions(self, channel: str) -> list[Any]:

async def cleanup(self):
"""Call all registered cleanup functions."""
first_fatal = None
for cleanup_fn in reversed(self._cleanup_functions):
try:
if callable(cleanup_fn):
Expand All @@ -353,8 +363,16 @@ async def cleanup(self):
result = cleanup_fn()
if inspect.iscoroutine(result):
await result
except Exception as e:
except BaseException as e:
# Catch BaseException to survive asyncio.CancelledError (a BaseException
# subclass since Python 3.9) so remaining cleanup functions still run.
# Track fatal exceptions (KeyboardInterrupt, SystemExit) for re-raise
# after all cleanup completes.
logger.error(f"Error during cleanup: {e}")
if first_fatal is None and not isinstance(e, Exception):
first_fatal = e
if first_fatal is not None:
raise first_fatal

def reset_turn(self):
"""Reset per-turn tracking. Call at turn boundaries."""
Expand Down
84 changes: 67 additions & 17 deletions amplifier_core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,24 @@ def register(
Returns:
Unregister function
"""
hook_handler = HookHandler(handler=handler, priority=priority, name=name or handler.__name__)
hook_handler = HookHandler(
handler=handler, priority=priority, name=name or handler.__name__
)

self._handlers[event].append(hook_handler)
self._handlers[event].sort() # Keep sorted by priority

logger.debug(f"Registered hook '{hook_handler.name}' for event '{event}' with priority {priority}")
logger.debug(
f"Registered hook '{hook_handler.name}' for event '{event}' with priority {priority}"
)

def unregister():
"""Remove this handler from the registry."""
if hook_handler in self._handlers[event]:
self._handlers[event].remove(hook_handler)
logger.debug(f"Unregistered hook '{hook_handler.name}' from event '{event}'")
logger.debug(
f"Unregistered hook '{hook_handler.name}' from event '{event}'"
)

return unregister

Expand Down Expand Up @@ -140,11 +146,15 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult:
result = await hook_handler.handler(event, current_data)

if not isinstance(result, HookResult):
logger.warning(f"Handler '{hook_handler.name}' returned invalid result type")
logger.warning(
f"Handler '{hook_handler.name}' returned invalid result type"
)
continue

if result.action == "deny":
logger.info(f"Event '{event}' denied by handler '{hook_handler.name}': {result.reason}")
logger.info(
f"Event '{event}' denied by handler '{hook_handler.name}': {result.reason}"
)
return result

if result.action == "modify" and result.data is not None:
Expand All @@ -154,15 +164,27 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult:
# Collect inject_context actions for merging
if result.action == "inject_context" and result.context_injection:
inject_context_results.append(result)
logger.debug(f"Handler '{hook_handler.name}' returned inject_context")
logger.debug(
f"Handler '{hook_handler.name}' returned inject_context"
)

# Preserve ask_user (only first one, can't merge approvals)
if result.action == "ask_user" and special_result is None:
special_result = result
logger.debug(f"Handler '{hook_handler.name}' returned ask_user")

except asyncio.CancelledError:
# CancelledError is a BaseException (Python 3.9+). Log and continue
# so all handlers observe the event (important for cleanup events
# like session:end that flow through emit).
logger.error(
f"CancelledError in hook handler '{hook_handler.name}' "
f"for event '{event}'"
)
except Exception as e:
logger.error(f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}")
logger.error(
f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}"
)
# Continue with other handlers even if one fails

# If multiple inject_context results, merge them.
Expand All @@ -173,7 +195,9 @@ async def emit(self, event: str, data: dict[str, Any]) -> HookResult:
merged_inject = self._merge_inject_context_results(inject_context_results)
if special_result is None:
special_result = merged_inject
logger.debug(f"Merged {len(inject_context_results)} inject_context results")
logger.debug(
f"Merged {len(inject_context_results)} inject_context results"
)
else:
# ask_user already captured - don't overwrite it
logger.debug(
Expand Down Expand Up @@ -208,7 +232,9 @@ def _merge_inject_context_results(self, results: list[HookResult]) -> HookResult
return results[0]

# Combine all injections
combined_content = "\n\n".join(result.context_injection for result in results if result.context_injection)
combined_content = "\n\n".join(
result.context_injection for result in results if result.context_injection
)

# Use settings from first result (role, ephemeral, suppress_output)
first = results[0]
Expand All @@ -221,7 +247,9 @@ def _merge_inject_context_results(self, results: list[HookResult]) -> HookResult
suppress_output=first.suppress_output,
)

async def emit_and_collect(self, event: str, data: dict[str, Any], timeout: float = 1.0) -> list[Any]:
async def emit_and_collect(
self, event: str, data: dict[str, Any], timeout: float = 1.0
) -> list[Any]:
"""
Emit event and collect data from all handler responses.

Expand All @@ -247,27 +275,46 @@ async def emit_and_collect(self, event: str, data: dict[str, Any], timeout: floa
logger.debug(f"No handlers for event '{event}'")
return []

logger.debug(f"Collecting responses for event '{event}' from {len(handlers)} handlers")
logger.debug(
f"Collecting responses for event '{event}' from {len(handlers)} handlers"
)

responses = []
for hook_handler in handlers:
try:
# Call handler with timeout
result = await asyncio.wait_for(hook_handler.handler(event, data), timeout=timeout)
result = await asyncio.wait_for(
hook_handler.handler(event, data), timeout=timeout
)

if not isinstance(result, HookResult):
logger.warning(f"Handler '{hook_handler.name}' returned invalid result type")
logger.warning(
f"Handler '{hook_handler.name}' returned invalid result type"
)
continue

# Collect response data if present
if result.data is not None:
responses.append(result.data)
logger.debug(f"Collected response from handler '{hook_handler.name}'")
logger.debug(
f"Collected response from handler '{hook_handler.name}'"
)

except TimeoutError:
logger.warning(f"Handler '{hook_handler.name}' timed out after {timeout}s")
logger.warning(
f"Handler '{hook_handler.name}' timed out after {timeout}s"
)
except asyncio.CancelledError:
# CancelledError is a BaseException (Python 3.9+). Log and continue
# so all handlers get a chance to respond.
logger.error(
f"CancelledError in hook handler '{hook_handler.name}' "
f"for event '{event}'"
)
except Exception as e:
logger.error(f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}")
logger.error(
f"Error in hook handler '{hook_handler.name}' for event '{event}': {e}"
)
# Continue with other handlers

logger.debug(f"Collected {len(responses)} responses for event '{event}'")
Expand All @@ -286,4 +333,7 @@ def list_handlers(self, event: str | None = None) -> dict[str, list[str]]:
if event:
handlers = self._handlers.get(event, [])
return {event: [h.name for h in handlers if h.name is not None]}
return {evt: [h.name for h in handlers if h.name is not None] for evt, handlers in self._handlers.items()}
return {
evt: [h.name for h in handlers if h.name is not None]
for evt, handlers in self._handlers.items()
}
18 changes: 11 additions & 7 deletions amplifier_core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
logger = logging.getLogger(__name__)


def _safe_exception_str(e: Exception) -> str:
def _safe_exception_str(e: BaseException) -> str:
"""
CRITICAL: Explicitly handle exception string conversion for Windows cp1252 compatibility.
Default encoding can fail on non-cp1252 characters, causing a crash during error handling.
Expand Down Expand Up @@ -432,8 +432,9 @@ async def execute(self, prompt: str) -> str:
self.status.status = "completed"
return result

except Exception as e:
# Check if this was a cancellation-related exception
except BaseException as e:
# Catch BaseException to handle asyncio.CancelledError (a BaseException
# subclass since Python 3.9). All paths re-raise after status tracking.
if self.coordinator.cancellation.is_cancelled:
self.status.status = "cancelled"
from .events import CANCEL_COMPLETED
Expand All @@ -455,10 +456,13 @@ async def execute(self, prompt: str) -> str:

async def cleanup(self: "AmplifierSession") -> None:
"""Clean up session resources."""
await self.coordinator.cleanup()
# Clean up sys.path modifications
if self.loader:
self.loader.cleanup()
try:
await self.coordinator.cleanup()
finally:
# Clean up sys.path modifications - must always run even if
# coordinator cleanup raises (e.g., asyncio.CancelledError)
if self.loader:
self.loader.cleanup()

async def __aenter__(self: "AmplifierSession"):
"""Async context manager entry."""
Expand Down
Loading