21 changes: 16 additions & 5 deletions src/ale_bench_eval/safe_generation.py
@@ -24,6 +24,8 @@
from pydantic_ai.run import AgentRunResult
from pydantic_ai.settings import ModelSettings

from ale_bench_eval.shared_async_loop import shared_async_loop

OPENAI_COMPATIBLE_PROVIDERS = {
"azure",
"deepseek",
@@ -124,7 +126,9 @@ def safe_generation(
)

try:
result = agent.run_sync(user_prompt=user_prompt, message_history=message_history)
result = shared_async_loop().run(
agent.run(user_prompt=user_prompt, message_history=message_history),
)
model_response = result.all_messages()[-1]
if isinstance(model_response, ModelResponse):
if model_response.finish_reason == "length":
@@ -140,11 +144,18 @@
raise RuntimeError(f"Model API returned an HTTP error: {e}") from e
    # NOTE: If the input string is too long, the API sometimes returns `exceeded your current quota`
except ModelHTTPError as e:
body = e.body or {}
msg = ""
if isinstance(body, dict):
msg = body.get("message") or body.get("error", {}).get("message") or str(body)
else:
msg = str(body)
if any(
[
"string too long" in e.body["message"], # type: ignore
"exceeds the context window" in e.body["message"], # type: ignore
"maximum context length" in e.body["message"], # type: ignore
s in msg.lower()
for s in [
"string too long",
"exceeds the context window",
"maximum context length",
]
):
raise MaxTokenError("Input exceeds the model's maximum token limit.") from e
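A minimal sketch (not part of the diff) of how the new message extraction resolves a provider error body; the payload below is hypothetical but mirrors the nested error/message shape the code anticipates:

# Hypothetical error body; real ModelHTTPError.body contents vary by provider.
body = {"error": {"message": "This model's maximum context length is 8192 tokens."}}
msg = ""
if isinstance(body, dict):
    msg = body.get("message") or body.get("error", {}).get("message") or str(body)
else:
    msg = str(body)
# The lowercase substring check then classifies this as a max-token failure.
assert any(
    s in msg.lower()
    for s in ["string too long", "exceeds the context window", "maximum context length"]
)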
118 changes: 118 additions & 0 deletions src/ale_bench_eval/shared_async_loop.py
@@ -0,0 +1,118 @@
import asyncio
import atexit
import logging
import threading
from concurrent.futures import Future, TimeoutError as FutureTimeoutError
from typing import Coroutine, TypeVar

T = TypeVar("T")


class SharedAsyncLoop:
"""Background event loop shared across threads for async-only providers.

    This class is intended for cases where synchronous code needs to run coroutines, such as when
    calling an async-only provider (e.g., Google GenAI) from a synchronous context.
"""

SHUTDOWN_TIMEOUT = 5

def __init__(self) -> None:
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._atexit_cb = self.shutdown
self._thread.start()
try:
atexit.register(self._atexit_cb)
except Exception:
# Ensure the background thread is cleaned up if registration fails.
self.shutdown()
raise

def _run_loop(self) -> None:
asyncio.set_event_loop(self._loop)
self._loop.run_forever()

async def _drain_pending(self) -> None:
tasks = [t for t in asyncio.all_tasks(self._loop) if t is not asyncio.current_task(self._loop)]
for task in tasks:
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)

def run(self, coroutine: Coroutine[object, object, T], timeout: float | None = None) -> T:
"""Execute a coroutine on the shared event loop from any thread and return its result.

Args:
coroutine (Coroutine): The coroutine to execute.
timeout (float, optional): Maximum time in seconds to wait for the result. If None, wait indefinitely.

        Returns:
            T: The result returned by the coroutine.

        Raises:
            asyncio.TimeoutError: If the coroutine does not complete within the specified timeout.
            Exception: Any exception raised by the coroutine is propagated.

        Note:
            On a timeout or other exception, this method requests cancellation of the underlying
            coroutine via the Future obtained from run_coroutine_threadsafe. If the coroutine
            ignores cancellation, it may keep running briefly.
"""
future: Future[T] = asyncio.run_coroutine_threadsafe(coroutine, self._loop)
try:
return future.result(timeout=timeout)
except FutureTimeoutError as exc:
future.cancel()
raise asyncio.TimeoutError(f"Timed out waiting for coroutine result after {timeout}s") from exc
except Exception:
future.cancel()
raise

def shutdown(self) -> None:
global SHARED_ASYNC_LOOP
with SHARED_ASYNC_LOOP_LOCK:
if self._loop.is_closed():
if SHARED_ASYNC_LOOP is self:
SHARED_ASYNC_LOOP = None
return
if self._loop.is_running():
drain_future = asyncio.run_coroutine_threadsafe(self._drain_pending(), self._loop)
try:
drain_future.result(timeout=self.SHUTDOWN_TIMEOUT)
except FutureTimeoutError:
logging.getLogger(__name__).warning("Timed out cancelling pending tasks on shared async loop")
self._loop.call_soon_threadsafe(self._loop.stop)
if threading.current_thread() is not self._thread and self._thread.is_alive():
self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
if self._thread.is_alive():
logging.getLogger(__name__).warning(
f"Shared async loop thread did not stop within {self.SHUTDOWN_TIMEOUT}s"
)
if not self._loop.is_closed():
self._loop.close()
try:
atexit.unregister(self._atexit_cb)
except Exception:
# During interpreter shutdown unregister may fail; ignore.
pass
if SHARED_ASYNC_LOOP is self:
SHARED_ASYNC_LOOP = None

def is_closed(self) -> bool:
return self._loop.is_closed()


SHARED_ASYNC_LOOP: SharedAsyncLoop | None = None
SHARED_ASYNC_LOOP_LOCK = threading.Lock()


def shared_async_loop() -> SharedAsyncLoop:
"""Returns a singleton instance of SharedAsyncLoop, creating a new instance if None or the previous one is closed.

This function is thread-safe and ensures only one SharedAsyncLoop instance is active at a time.

Returns:
SharedAsyncLoop: The shared async event loop instance.
"""
global SHARED_ASYNC_LOOP
with SHARED_ASYNC_LOOP_LOCK:
if SHARED_ASYNC_LOOP is None or SHARED_ASYNC_LOOP.is_closed():
SHARED_ASYNC_LOOP = SharedAsyncLoop()
return SHARED_ASYNC_LOOP
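
A minimal usage sketch (not part of this diff): fetch_answer is a hypothetical stand-in for an async-only provider call such as agent.run(...), driven from synchronous code via the shared loop.

import asyncio

from ale_bench_eval.shared_async_loop import shared_async_loop


async def fetch_answer() -> str:
    # Stand-in for an async-only provider call, e.g. agent.run(...).
    await asyncio.sleep(0.1)
    return "done"


# Synchronous call site: the coroutine runs on the background daemon-thread
# loop and its result (or exception / timeout) is surfaced here.
result = shared_async_loop().run(fetch_answer(), timeout=30)
print(result)

Because shared_async_loop() hands back a process-wide singleton, repeated calls from different threads reuse the same loop, and the registered atexit hook shuts it down at interpreter exit.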