190 changes: 126 additions & 64 deletions amplifier_module_loop_streaming/__init__.py
@@ -75,6 +75,15 @@ def __init__(self, config: dict[str, Any]):
# Store ephemeral injections from tool:post hooks for next iteration
self._pending_ephemeral_injections: list[dict[str, Any]] = []

# Retry configuration for transient provider errors
self.retry_max_attempts = int(config.get("retry_max_attempts", 3))
self.retry_base_delay_seconds = float(
config.get("retry_base_delay_seconds", 1.0)
)
self.retry_max_delay_seconds = float(
config.get("retry_max_delay_seconds", 30.0)
)
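For reference, a config dict overriding the three new keys might look like the sketch below; only the key names and defaults are taken from the diff, the override values are illustrative:

retry_config = {
    "retry_max_attempts": 5,          # retries after the initial attempt (default 3)
    "retry_base_delay_seconds": 0.5,  # first backoff step (default 1.0)
    "retry_max_delay_seconds": 60.0,  # cap on any single delay (default 30.0)
}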

async def _apply_rate_limit_delay(
self, hooks: HookRegistry, iteration: int
) -> None:
@@ -106,6 +115,105 @@ async def _apply_rate_limit_delay(
)
await asyncio.sleep(remaining_ms / 1000)

async def _call_provider_with_retry(
self,
call_fn,
hooks: HookRegistry,
provider_name: str | None,
) -> Any:
"""Call a provider function with exponential backoff retry on retryable errors.

Wraps both async (provider.complete()) and sync (provider.stream()) calls.
Retries only when the error has retryable=True (e.g., RateLimitError,
ProviderUnavailableError, LLMTimeoutError). Honors retry_after from the
error when available, otherwise uses exponential backoff.

Args:
call_fn: Callable that makes the provider call. May return a coroutine
(for complete()) or a value (for stream()).
hooks: Hook registry for event emission.
provider_name: Name of the provider for event data.

Returns:
The result of call_fn().

Raises:
LLMError: If the error is not retryable or all retries are exhausted.
Exception: If a non-LLM error occurs (never retried).
"""
max_retries = self.retry_max_attempts

for attempt in range(max_retries + 1):
try:
result = call_fn()
if asyncio.iscoroutine(result):
result = await result
return result
except LLMError as e:
is_last_attempt = attempt >= max_retries
if e.retryable and not is_last_attempt:
# Calculate delay: prefer server-provided retry_after
retry_after = getattr(e, "retry_after", None)
if retry_after is not None:
delay = float(retry_after)
else:
delay = min(
self.retry_base_delay_seconds * (2**attempt),
self.retry_max_delay_seconds,
)

await hooks.emit(
"provider:retry",
{
"provider": provider_name,
"error": {
"type": type(e).__name__,
"msg": str(e),
},
"attempt": attempt + 1,
"max_retries": max_retries,
"delay_seconds": delay,
"retryable": e.retryable,
"status_code": e.status_code,
},
)
logger.warning(
"Retryable provider error (attempt %d/%d): %s. "
"Retrying in %.1fs...",
attempt + 1,
max_retries,
e,
delay,
)
await asyncio.sleep(delay)
continue

# Not retryable or final attempt — emit error and raise
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
"retryable": e.retryable,
"status_code": e.status_code,
},
)
raise
except Exception as e:
# Non-LLM errors are never retried
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
},
)
raise

# Unreachable, but satisfies type checkers
msg = "Retry loop exited unexpectedly"
raise RuntimeError(msg)
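The delay schedule in the loop above reduces to a small pure function. A standalone sketch mirroring that logic, useful for unit-testing the schedule (names here are illustrative, not part of the module):

def backoff_delay(
    attempt: int,
    base: float = 1.0,
    cap: float = 30.0,
    retry_after: float | None = None,
) -> float:
    # A server-provided retry_after always wins; otherwise exponential backoff.
    if retry_after is not None:
        return float(retry_after)
    return min(base * (2**attempt), cap)

# With the defaults, attempts 0..4 wait 1, 2, 4, 8, 16 seconds, capped at 30.
assert [backoff_delay(a) for a in range(5)] == [1.0, 2.0, 4.0, 8.0, 16.0]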

async def execute(
self,
prompt: str,
@@ -390,28 +498,11 @@ async def _execute_stream(
kwargs = {}
if self.extended_thinking:
kwargs["extended_thinking"] = True
try:
response = await provider.complete(chat_request, **kwargs)
except LLMError as e:
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
"retryable": e.retryable,
"status_code": e.status_code,
},
)
raise
except Exception as e:
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
},
)
raise
response = await self._call_provider_with_retry(
lambda: provider.complete(chat_request, **kwargs),
hooks,
provider_name,
)

# Update rate limit timestamp after non-streaming response
self._last_provider_call_end = time.monotonic()
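Note the lambda passed to the wrapper rather than the coroutine itself: a coroutine object can only be awaited once, so each retry attempt must build a fresh call. A minimal standalone illustration of the pitfall (not from the module):

import asyncio

async def flaky() -> str:
    return "ok"

async def main() -> None:
    coro = flaky()
    await coro  # first await succeeds
    try:
        await coro  # awaiting the same coroutine object again fails
    except RuntimeError as e:
        print(e)  # "cannot reuse already awaited coroutine"
    # A zero-argument callable sidesteps this: each call makes a new coroutine.
    make_call = lambda: flaky()
    for _ in range(2):
        await make_call()

asyncio.run(main())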
@@ -706,7 +797,11 @@ async def _execute_stream(
if self.extended_thinking:
kwargs["extended_thinking"] = True

response = await provider.complete(max_iter_chat_request, **kwargs)
response = await self._call_provider_with_retry(
lambda: provider.complete(max_iter_chat_request, **kwargs),
hooks,
provider_name,
)
content = (
response.content if hasattr(response, "content") else str(response)
)
@@ -719,25 +814,9 @@
# Add to context
await context.add_message({"role": "assistant", "content": content})

except LLMError as e:
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
"retryable": e.retryable,
"status_code": e.status_code,
},
)
logger.error(f"Error getting final response after max iterations: {e}")
except Exception as e:
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
},
)
        except Exception as e:
            # LLMError subclasses Exception, so a single handler covers both.
            # _call_provider_with_retry already emitted provider:error and
            # exhausted retries; log and continue gracefully.
            logger.error(f"Error getting final response after max iterations: {e}")

# Emit execution end
@@ -771,28 +850,11 @@ async def _stream_from_provider(

# Convert tools dict to list for provider
tools_list = list(tools.values()) if tools else []
try:
stream_iter = provider.stream(chat_request, tools=tools_list)
except LLMError as e:
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
"retryable": e.retryable,
"status_code": e.status_code,
},
)
raise
except Exception as e:
await hooks.emit(
PROVIDER_ERROR,
{
"provider": provider_name,
"error": {"type": type(e).__name__, "msg": str(e)},
},
)
raise
stream_iter = await self._call_provider_with_retry(
lambda: provider.stream(chat_request, tools=tools_list),
hooks,
provider_name,
)

async for chunk in stream_iter:
# Check for immediate cancellation between chunks
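For completeness: the retry wrapper assumes a small contract on LLMError, namely retryable, status_code, and an optional retry_after. A sketch of that assumed shape, with attribute names taken from the diff (the project's real class is defined elsewhere):

class LLMError(Exception):
    """Assumed interface only; the actual hierarchy lives in the project."""

    def __init__(
        self,
        msg: str,
        *,
        retryable: bool = False,
        status_code: int | None = None,
        retry_after: float | None = None,
    ) -> None:
        super().__init__(msg)
        self.retryable = retryable
        self.status_code = status_code
        self.retry_after = retry_after

# Example: a 429 that asks the client to wait five seconds before retrying.
err = LLMError("rate limited", retryable=True, status_code=429, retry_after=5.0)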