"""Example demonstrating environment state snapshotting and restoration.

This example shows how to:
1. Create a snapshot after reset (at episode boundary)
2. Save the snapshot to disk
3. Restore an environment from a saved snapshot
4. Continue execution from the restored state

Example usage:

    1. Make sure you have examples dependencies installed
       `uv sync --group examples`
    2. Run the example
       `uv run -m examples.03_state_snapshotting`
"""

import asyncio
import contextlib  # hoisted to module level (was imported mid-function)
import pathlib
import tempfile

from ares.code_agents import mini_swe_agent
from ares.containers import docker
from ares.environments import snapshot
from ares.environments import swebench_env
from ares.llms import chat_completions_compatible


async def main():
    """Run the three-part snapshot demo: snapshot, restore, continue."""
    # Create an LLM client used to drive the agent in both environments.
    agent = chat_completions_compatible.ChatCompletionCompatibleLLMClient(model="openai/gpt-4o-mini")

    # Load SWE-bench tasks and run on just the first one.
    all_tasks = swebench_env.swebench_verified_tasks()
    tasks = [all_tasks[0]]

    print(f"Running on task: {tasks[0].instance_id}")
    print(f"Repository: {tasks[0].repo}")
    print("-" * 80)

    # Create a temporary directory for snapshots.
    with tempfile.TemporaryDirectory() as snapshot_dir:
        snapshot_path = pathlib.Path(snapshot_dir)

        # === PART 1: Create and save a snapshot ===
        print("\n[PART 1] Creating initial environment and snapshot...")

        async with swebench_env.SweBenchEnv(
            tasks=tasks,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
            container_factory=docker.DockerContainer,
        ) as env:
            # Reset the environment to get the first timestep.
            # NOTE: this example peeks at private attributes (_step_count,
            # _code_agent_task, _current_task) for illustration only; they are
            # not part of the public environment API.
            ts = await env.reset()
            print(f"Environment reset complete. Step count: {env._step_count}")

            # Take a few steps before snapshotting.
            for i in range(3):
                action = await agent(ts.observation)
                print(f" Step {i}: Taking action...")
                ts = await env.step(action)

                if ts.last():
                    print(" Episode completed early")
                    break

            print(f"Current step count: {env._step_count}")

            # Snapshots are only legal at episode boundaries. In practice,
            # you'd snapshot after step() returns with done=True or after
            # reset() completes; here we cancel the agent task to force a
            # boundary for demonstration purposes.
            if not ts.last():
                print("\n Note: For snapshotting, we need to be at episode boundary.")
                print(" Cancelling agent task to reach boundary...")
                if env._code_agent_task and not env._code_agent_task.done():
                    env._code_agent_task.cancel()
                    with contextlib.suppress(asyncio.CancelledError):
                        await env._code_agent_task

            # Now we can export state (at episode boundary).
            print("\n Exporting state snapshot...")
            snap = await env.export_state(snapshot_path, snapshot_id="example-snapshot")

            print(f" ✓ Snapshot created: {snap.snapshot_id}")
            print(f" ✓ Snapshot saved to: {snap.snapshot_dir}")
            print(f" ✓ Step count in snapshot: {snap.step_count}")
            print(f" ✓ Task type: {snap.task_type}")
            print(f" ✓ Container type: {snap.container_type}")

        # === PART 2: Restore from snapshot ===
        print("\n[PART 2] Restoring environment from snapshot...")

        # Load snapshot metadata from disk.
        snapshot_file = snapshot_path / "example-snapshot" / "snapshot.json"
        loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file)

        print(f" ✓ Loaded snapshot: {loaded_snap.snapshot_id}")
        print(f" ✓ Original step count: {loaded_snap.step_count}")

        # Restore environment from snapshot.
        # Note: This creates a new environment instance with the saved state.
        restored_env = await swebench_env.SweBenchEnv.load_from_state(
            loaded_snap,
            container_factory=docker.DockerContainer,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
        )

        print(" ✓ Environment restored")
        print(f" ✓ Restored step count: {restored_env._step_count}")
        print(f" ✓ Task: {restored_env._current_task.instance_id}")

        # Use the restored environment in an async context.
        async with restored_env:
            print("\n[PART 3] Continuing from restored state...")

            # The environment is now at the same state as when we snapshotted;
            # we can continue taking steps from here.
            ts = await restored_env.reset()  # Reset to start a new episode
            step_count = 0

            # Take a few more steps to demonstrate it works.
            while not ts.last() and step_count < 3:
                action = await agent(ts.observation)
                print(f" Step {step_count}: Taking action from restored env...")
                ts = await restored_env.step(action)
                step_count += 1

            print(f"\n ✓ Completed {step_count} additional steps from restored state")

    print("\n" + "=" * 80)
    print("Snapshot example completed successfully!")
    print("=" * 80)
    print("\nKey takeaways:")
    print(" 1. Snapshots can only be taken at episode boundaries")
    print(" 2. Snapshots save: task state, container filesystem, agent messages")
    print(" 3. Restored environments can continue execution normally")
    print(" 4. Use cases: debugging, RL replay, mechanistic interpretability")


if __name__ == "__main__":
    asyncio.run(main())
Register for cleanup on exit. _ENVIRONMENT_JANITOR.register_for_cleanup(self) @@ -430,6 +435,169 @@ def _assert_active(self) -> None: if not self._is_active: raise RuntimeError("Environment is not active.") + def _require_container(self) -> containers.Container: + """Get container or raise if not available.""" + if self._container is None: + raise RuntimeError("Container is not available.") + return self._container + + def _require_task(self) -> TaskType: + """Get current task or raise if not available.""" + if self._current_task is None: + raise RuntimeError("No current task set.") + return self._current_task + + def _validate_snapshot_allowed(self) -> None: + """Raise error if snapshot not allowed (mid-episode).""" + if self._code_agent_task is not None and not self._code_agent_task.done(): + raise RuntimeError( + "Cannot snapshot during active episode. Call export_state() after reset() or after final step()." + ) + + def _get_task_type(self) -> Literal["swebench", "harbor"]: + """Return 'swebench' or 'harbor'. 
Override in subclasses if needed.""" + # This will be overridden in subclasses if needed + raise NotImplementedError("Override _get_task_type in subclass") + + def _get_container_type(self, container: containers.Container) -> Literal["daytona", "docker"]: + """Return 'daytona' or 'docker'.""" + from ares.containers.daytona import DaytonaContainer + + return "daytona" if isinstance(container, DaytonaContainer) else "docker" + + def _get_agent_messages(self) -> list[dict]: + """Get agent message history if available.""" + if self._code_agent is not None and hasattr(self._code_agent, "_messages"): + return list(self._code_agent._messages) + return [] + + async def _restore_container(self, snap: snapshot.EnvironmentSnapshot) -> containers.Container: + """Restore container from filesystem snapshot.""" + # Create new container from original image/dockerfile + if snap.container_image: + container = self._container_factory.from_image( + image=snap.container_image, + resources=containers.Resources(**snap.container_resources) if snap.container_resources else None, + ) + elif snap.container_dockerfile_path: + container = self._container_factory.from_dockerfile( + dockerfile_path=pathlib.Path(snap.container_dockerfile_path), + resources=containers.Resources(**snap.container_resources) if snap.container_resources else None, + ) + else: + raise ValueError("Snapshot must have either image or dockerfile_path") + + # Start container + await container.start() + + # Restore filesystem from tarball + fs_path = snap.snapshot_dir / "container_fs.tar.gz" + if fs_path.exists(): + await container.upload_dir(fs_path, "/") + + return container + + @classmethod + def _default_container_factory(cls, snap: snapshot.EnvironmentSnapshot) -> containers.ContainerFactory: + """Create default container factory from snapshot metadata.""" + del snap # Unused - could be used in future to select factory based on container_type + # Default to Daytona + return ares_daytona.DaytonaContainer + + async def 
export_state( + self, + output_dir: pathlib.Path, + *, + snapshot_id: str | None = None, + ) -> snapshot.EnvironmentSnapshot: + """Export environment state to snapshot. + + Args: + output_dir: Directory to save snapshot files (tarballs, metadata) + snapshot_id: Optional ID (defaults to UUID) + + Returns: + EnvironmentSnapshot with metadata + + Raises: + RuntimeError: If called during active episode (running code agent) + """ + # Validate episode boundary + self._validate_snapshot_allowed() + + snapshot_id = snapshot_id or str(uuid.uuid4()) + snapshot_dir = output_dir / snapshot_id + snapshot_dir.mkdir(parents=True, exist_ok=True) + + # 1. Download container filesystem + container = self._require_container() + fs_path = snapshot_dir / "container_fs.tar.gz" + await container.download_dir("/", fs_path) + + # 2. Serialize task + task = self._require_task() + task_data = self._serialize_task(task) + + # 3. Get agent messages (if agent exists) + agent_messages = self._get_agent_messages() + + # 4. Create snapshot metadata + snap = snapshot.EnvironmentSnapshot( + snapshot_id=snapshot_id, + created_at=datetime.datetime.now().isoformat(), + snapshot_dir=snapshot_dir, + step_count=self._step_count, + step_limit=self._step_limit, + requires_reset=self._requires_reset, + task_type=self._get_task_type(), + task_data=task_data, + container_type=self._get_container_type(container), + container_image=getattr(container, "image", None), + container_dockerfile_path=( + str(getattr(container, "dockerfile_path", None)) if hasattr(container, "dockerfile_path") else None + ), + container_resources=dataclasses.asdict(container.resources) if container.resources else None, + agent_messages=agent_messages, + ) + + # Save metadata JSON + snap.save_to_file(snapshot_dir / "snapshot.json") + + return snap + + async def _restore_from_snapshot(self, snap: snapshot.EnvironmentSnapshot) -> None: + """Internal helper to restore state from snapshot. 
+ + Called by subclass load_from_state implementations after creating the environment. + + Args: + snap: The snapshot to restore from + """ + # Restore container + self._container = await self._restore_container(snap) + + # Deserialize and set task + task = self._deserialize_task(snap.task_data, snap.task_type) + self._current_task = task + + # Restore state + self._step_count = snap.step_count + self._requires_reset = snap.requires_reset + + # Store messages for later restoration + self._saved_agent_messages = snap.agent_messages + + @abc.abstractmethod + def _serialize_task(self, task: TaskType) -> dict: + """Serialize task to dict. Override in subclasses.""" + pass + + @classmethod + @abc.abstractmethod + def _deserialize_task(cls, task_data: dict, task_type: str) -> Any: + """Deserialize task from dict. Override in subclasses.""" + pass + @abc.abstractmethod async def _reset_task(self) -> None: """Should set `self._current_task` with a TaskType""" diff --git a/src/ares/environments/gym_wrapper.py b/src/ares/environments/gym_wrapper.py new file mode 100644 index 0000000..9cb04ff --- /dev/null +++ b/src/ares/environments/gym_wrapper.py @@ -0,0 +1,209 @@ +"""Gymnasium-like wrappers for ARES environments. + +Provides :func:`wrap_as_gym` to adapt any ARES ``Environment`` to the +`gymnasium `_ API, making it accessible to +researchers and libraries already familiar with that interface. + +Usage:: + + import asyncio + from ares.environments import wrap_as_gym + + async def main(): + async with MyAresEnv(...) as ares_env: + env = wrap_as_gym(ares_env) + obs, info = env.reset() + while True: + action = policy(obs) + obs, reward, terminated, truncated, info = env.step(action) + if terminated or truncated: + break + +For users already inside an async context, use :class:`AsyncGymWrapper` +directly to avoid the overhead of spinning up a nested event loop. 
+""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from ares.environments import base + + +def _run(coro: Any) -> Any: + """Run a coroutine, re-using the running loop if one exists.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None and loop.is_running(): + # We are already inside an async context (e.g. Jupyter). + # Callers should use AsyncGymWrapper directly in this case, + # but we give a clear error rather than silently deadlocking. + raise RuntimeError( + "Cannot use GymWrapper from inside a running event loop. " + "Use AsyncGymWrapper instead, or call wrap_as_gym() from a " + "synchronous context." + ) + return asyncio.run(coro) + + +class AsyncGymWrapper[ActionType, ObservationType]: + """Async gymnasium-compatible wrapper for an ARES environment. + + Exposes the gymnasium ``reset`` / ``step`` / ``close`` interface as + *async* methods, making it suitable for use inside ``asyncio`` event + loops. + + Args: + env: Any ARES :class:`~ares.environments.base.Environment`. + """ + + def __init__( + self, + env: base.Environment[ActionType, ObservationType, Any, Any], + ) -> None: + self._env = env + # ARES environments operate on structured LLM objects rather than + # array-based gym spaces, so we expose None here. Libraries that + # strictly require numpy spaces should subclass and override. + self.observation_space: Any = None + self.action_space: Any = None + + async def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObservationType, dict[str, Any]]: + """Reset the environment and return the first observation. + + Args: + seed: Ignored (ARES environments manage their own randomness). + options: Ignored. + + Returns: + A tuple of ``(observation, info)`` matching the gymnasium v26+ + interface. ``info`` is an empty dict. + """ + del seed, options # Unused; kept for API compatibility. 
+ time_step = await self._env.reset() + return time_step.observation, {} + + async def step(self, action: ActionType) -> tuple[ObservationType, float, bool, bool, dict[str, Any]]: + """Take a step in the environment. + + Args: + action: The action to apply (an :class:`~ares.llms.response.LLMResponse` + for code environments). + + Returns: + A tuple of ``(observation, reward, terminated, truncated, info)`` + matching the gymnasium v26+ interface. + + * ``terminated`` is ``True`` when the episode ended because the task + finished (the agent produced a final answer / the environment + reached a terminal state). + * ``truncated`` is always ``False``; ARES uses ``TimeoutError`` + rather than a truncation flag when a time limit is hit. + * ``info`` carries the raw :class:`~ares.environments.base.TimeStep` + under the key ``"time_step"`` so callers can inspect ``step_type`` + and ``discount`` if needed. + """ + time_step = await self._env.step(action) + reward = float(time_step.reward) if time_step.reward is not None else 0.0 + terminated = time_step.last() + info: dict[str, Any] = {"time_step": time_step} + return time_step.observation, reward, terminated, False, info + + async def close(self) -> None: + """Release resources held by the underlying ARES environment.""" + await self._env.close() + + async def __aenter__(self) -> AsyncGymWrapper[ActionType, ObservationType]: + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + +class GymWrapper[ActionType, ObservationType]: + """Synchronous gymnasium-compatible wrapper for an ARES environment. + + Bridges the fully-async ARES interface to a synchronous gymnasium-style + API by running each coroutine in a new ``asyncio`` event loop. + + .. warning:: + This wrapper cannot be used from within an already-running event + loop (e.g. inside ``async def`` functions or Jupyter notebooks). + Use :class:`AsyncGymWrapper` in those contexts. 
+ + Args: + env: Any ARES :class:`~ares.environments.base.Environment`. + """ + + def __init__( + self, + env: base.Environment[ActionType, ObservationType, Any, Any], + ) -> None: + self._async = AsyncGymWrapper(env) + self.observation_space: Any = None + self.action_space: Any = None + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObservationType, dict[str, Any]]: + """Reset the environment (synchronous). + + See :meth:`AsyncGymWrapper.reset` for full documentation. + """ + return _run(self._async.reset(seed=seed, options=options)) + + def step(self, action: ActionType) -> tuple[ObservationType, float, bool, bool, dict[str, Any]]: + """Take a step in the environment (synchronous). + + See :meth:`AsyncGymWrapper.step` for full documentation. + """ + return _run(self._async.step(action)) + + def close(self) -> None: + """Release resources held by the underlying ARES environment.""" + _run(self._async.close()) + + def __enter__(self) -> GymWrapper[ActionType, ObservationType]: + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + +def wrap_as_gym[ActionType, ObservationType]( + env: base.Environment[ActionType, ObservationType, Any, Any], + *, + async_mode: bool = False, +) -> GymWrapper[ActionType, ObservationType] | AsyncGymWrapper[ActionType, ObservationType]: + """Wrap an ARES environment in a gymnasium-compatible interface. + + Args: + env: Any ARES :class:`~ares.environments.base.Environment`. + async_mode: If ``True``, return an :class:`AsyncGymWrapper` with + *async* ``reset``/``step``/``close`` methods. If ``False`` + (default), return a synchronous :class:`GymWrapper`. + + Returns: + A :class:`GymWrapper` (sync) or :class:`AsyncGymWrapper` (async). + + Example:: + + # Synchronous usage (default) + async with MyAresEnv(...) as ares_env: + env = wrap_as_gym(ares_env) + obs, info = env.reset() + + # Async usage (inside async context) + async with MyAresEnv(...) 
as ares_env: + env = wrap_as_gym(ares_env, async_mode=True) + obs, info = await env.reset() + """ + if async_mode: + return AsyncGymWrapper(env) + return GymWrapper(env) diff --git a/src/ares/environments/gym_wrapper_test.py b/src/ares/environments/gym_wrapper_test.py new file mode 100644 index 0000000..21b6e47 --- /dev/null +++ b/src/ares/environments/gym_wrapper_test.py @@ -0,0 +1,202 @@ +"""Tests for gymnasium-compatible wrappers.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +import pytest + +from ares.environments.base import TimeStep +from ares.environments.gym_wrapper import AsyncGymWrapper +from ares.environments.gym_wrapper import GymWrapper +from ares.environments.gym_wrapper import wrap_as_gym + + +def _make_time_step(step_type: str, reward: float | None = None, obs: Any = "obs") -> TimeStep: + return TimeStep(step_type=step_type, reward=reward, discount=None, observation=obs) + + +def _make_env( + first_obs: Any = "first_obs", + mid_obs: Any = "mid_obs", + last_obs: Any = "last_obs", + final_reward: float = 1.0, +) -> MagicMock: + env = MagicMock() + env.reset = AsyncMock(return_value=_make_time_step("FIRST", None, first_obs)) + env.step = AsyncMock( + side_effect=[ + _make_time_step("MID", None, mid_obs), + _make_time_step("LAST", final_reward, last_obs), + ] + ) + env.close = AsyncMock() + return env + + +# --------------------------------------------------------------------------- +# AsyncGymWrapper +# --------------------------------------------------------------------------- + + +class TestAsyncGymWrapper: + @pytest.mark.asyncio + async def test_reset_returns_obs_and_empty_info(self) -> None: + env = _make_env(first_obs="hello") + wrapper = AsyncGymWrapper(env) + obs, info = await wrapper.reset() + assert obs == "hello" + assert info == {} + + @pytest.mark.asyncio + async def test_step_mid_not_terminated(self) -> None: + env = _make_env() + wrapper = 
AsyncGymWrapper(env) + await wrapper.reset() + obs, reward, terminated, truncated, _info = await wrapper.step("action") + assert obs == "mid_obs" + assert reward == 0.0 + assert terminated is False + assert truncated is False + + @pytest.mark.asyncio + async def test_step_last_terminated(self) -> None: + env = _make_env(final_reward=0.5) + wrapper = AsyncGymWrapper(env) + await wrapper.reset() + await wrapper.step("action") # MID + obs, reward, terminated, truncated, _info = await wrapper.step("action") # LAST + assert obs == "last_obs" + assert reward == pytest.approx(0.5) + assert terminated is True + assert truncated is False + + @pytest.mark.asyncio + async def test_step_info_contains_time_step(self) -> None: + env = _make_env() + wrapper = AsyncGymWrapper(env) + await wrapper.reset() + _, _, _, _, info = await wrapper.step("action") + assert "time_step" in info + assert isinstance(info["time_step"], TimeStep) + + @pytest.mark.asyncio + async def test_close_delegates_to_env(self) -> None: + env = _make_env() + wrapper = AsyncGymWrapper(env) + await wrapper.close() + env.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_context_manager_calls_close(self) -> None: + env = _make_env() + async with AsyncGymWrapper(env) as wrapper: + await wrapper.reset() + env.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_observation_and_action_space_none(self) -> None: + env = _make_env() + wrapper = AsyncGymWrapper(env) + assert wrapper.observation_space is None + assert wrapper.action_space is None + + @pytest.mark.asyncio + async def test_reward_none_becomes_zero(self) -> None: + env = MagicMock() + env.reset = AsyncMock(return_value=_make_time_step("FIRST", None)) + env.step = AsyncMock(return_value=_make_time_step("MID", None)) + env.close = AsyncMock() + wrapper = AsyncGymWrapper(env) + await wrapper.reset() + _, reward, _, _, _ = await wrapper.step("action") + assert reward == 0.0 + + @pytest.mark.asyncio + async def 
test_reset_seed_and_options_ignored(self) -> None: + env = _make_env() + wrapper = AsyncGymWrapper(env) + obs, _info = await wrapper.reset(seed=42, options={"foo": "bar"}) + assert obs == "first_obs" + + +# --------------------------------------------------------------------------- +# GymWrapper (sync bridge) +# --------------------------------------------------------------------------- + + +class TestGymWrapper: + def test_reset_returns_obs_and_empty_info(self) -> None: + env = _make_env(first_obs="hello") + wrapper = GymWrapper(env) + obs, info = wrapper.reset() + assert obs == "hello" + assert info == {} + + def test_step_mid_not_terminated(self) -> None: + env = _make_env() + wrapper = GymWrapper(env) + wrapper.reset() + obs, _reward, terminated, truncated, _info = wrapper.step("action") + assert obs == "mid_obs" + assert terminated is False + assert truncated is False + + def test_step_last_terminated(self) -> None: + env = _make_env(final_reward=1.0) + wrapper = GymWrapper(env) + wrapper.reset() + wrapper.step("action") # MID + _obs, reward, terminated, _truncated, _info = wrapper.step("action") # LAST + assert terminated is True + assert reward == pytest.approx(1.0) + + def test_close_delegates_to_env(self) -> None: + env = _make_env() + wrapper = GymWrapper(env) + wrapper.close() + env.close.assert_awaited_once() + + def test_context_manager_calls_close(self) -> None: + env = _make_env() + with GymWrapper(env) as wrapper: + wrapper.reset() + env.close.assert_awaited_once() + + def test_observation_and_action_space_none(self) -> None: + env = _make_env() + wrapper = GymWrapper(env) + assert wrapper.observation_space is None + assert wrapper.action_space is None + + +# --------------------------------------------------------------------------- +# wrap_as_gym factory +# --------------------------------------------------------------------------- + + +class TestWrapAsGym: + def test_default_returns_sync_wrapper(self) -> None: + env = _make_env() + result = 
wrap_as_gym(env) + assert isinstance(result, GymWrapper) + + def test_async_mode_returns_async_wrapper(self) -> None: + env = _make_env() + result = wrap_as_gym(env, async_mode=True) + assert isinstance(result, AsyncGymWrapper) + + def test_sync_wrapper_is_usable(self) -> None: + env = _make_env(first_obs="x") + gym_env = wrap_as_gym(env) + obs, _info = gym_env.reset() + assert obs == "x" + + @pytest.mark.asyncio + async def test_async_wrapper_is_usable(self) -> None: + env = _make_env(first_obs="y") + gym_env = wrap_as_gym(env, async_mode=True) + obs, _info = await gym_env.reset() + assert obs == "y" diff --git a/src/ares/environments/harbor_env.py b/src/ares/environments/harbor_env.py index a986356..e25f9a5 100644 --- a/src/ares/environments/harbor_env.py +++ b/src/ares/environments/harbor_env.py @@ -11,6 +11,7 @@ import logging import pathlib import random +from typing import Literal from harbor.models.task import task as harbor_task from harbor.models.trial import paths as harbor_paths @@ -164,3 +165,65 @@ async def _parse_reward_file(self, remote_path: pathlib.Path | str) -> float | N else: raise ValueError(f"Unsupported reward file type: {remote_path}") + + def _get_task_type(self) -> Literal["swebench", "harbor"]: + """Return task type for snapshotting.""" + return "harbor" + + def _serialize_task(self, task: harbor_task.Task) -> dict: + """Serialize Harbor task (just save task directory path).""" + return {"task_dir": str(task.task_dir)} + + @classmethod + def _deserialize_task(cls, task_data: dict, task_type: str) -> harbor_task.Task: + """Deserialize Harbor task (reload from directory).""" + del task_type # Unused - validated by caller + return harbor_task.Task(task_dir=pathlib.Path(task_data["task_dir"])) + + @classmethod + async def load_from_state( + cls, + snapshot_path: "base.snapshot.EnvironmentSnapshot | pathlib.Path", + *, + container_factory: containers.ContainerFactory | None = None, + code_agent_factory: code_agent_base.CodeAgentFactory | None 
= None, + tracker: stat_tracker.StatTracker | None = None, + ) -> "HarborEnv": + """Restore HarborEnv from snapshot. + + Args: + snapshot_path: EnvironmentSnapshot or path to snapshot.json + container_factory: Override factory (uses snapshot metadata if None) + code_agent_factory: Override factory (uses default if None) + tracker: Optional stat tracker + + Returns: + Restored environment (NOT active, use async with) + """ + from ares.environments import snapshot as snapshot_module + + # Load snapshot if path provided + if isinstance(snapshot_path, pathlib.Path): + snap = snapshot_module.EnvironmentSnapshot.load_from_file(snapshot_path) + else: + snap = snapshot_path + + # Deserialize task + task = cls._deserialize_task(snap.task_data, snap.task_type) + + # Create environment instance with tasks argument + container_factory = container_factory or base.ares_daytona.DaytonaContainer + code_agent_factory = code_agent_factory or mini_swe_agent.MiniSWECodeAgent + + env = cls( + tasks=[task], + container_factory=container_factory, + code_agent_factory=code_agent_factory, + step_limit=snap.step_limit, + tracker=tracker, + ) + + # Restore state using base helper + await env._restore_from_snapshot(snap) + + return env diff --git a/src/ares/environments/snapshot.py b/src/ares/environments/snapshot.py new file mode 100644 index 0000000..f87f180 --- /dev/null +++ b/src/ares/environments/snapshot.py @@ -0,0 +1,145 @@ +"""Environment state snapshotting for RL research and mechanistic interpretability. 
+ +This module provides functionality to snapshot and restore environment state, +enabling use cases like: +- RL trajectory replay and analysis +- Debugging failed episodes +- Mechanistic interpretability of agent behavior +- Checkpointing long-running experiments + +## Usage Example + +```python +import pathlib +from ares.environments import swebench_env, snapshot + +# Create and run environment +async with swebench_env.SweBenchEnv(tasks=[task]) as env: + ts = await env.reset() + + # Take some steps + for _ in range(10): + action = await agent(ts.observation) + ts = await env.step(action) + if ts.last(): + break + + # Wait for episode to complete (required for snapshotting) + if env._code_agent_task and not env._code_agent_task.done(): + env._code_agent_task.cancel() + await env._code_agent_task + + # Export state at episode boundary + snap = await env.export_state(pathlib.Path("./snapshots")) + +# Later: restore from snapshot +loaded_snap = snapshot.EnvironmentSnapshot.load_from_file( + pathlib.Path("./snapshots/abc-123/snapshot.json") +) + +restored_env = await swebench_env.SweBenchEnv.load_from_state(loaded_snap) +async with restored_env: + # Continue from saved state + ts = await restored_env.reset() + ... 
+``` + +## Limitations + +- **Episode boundaries only**: Snapshots can only be created when no code agent + task is running (after reset() or after final step() with done=True) +- **No mid-execution state**: Agent message history is saved, but not mid-execution + state like running asyncio tasks or futures +- **Large snapshots**: Container filesystems are saved as tarballs (100MB-2GB typical) +- **Container restoration**: Containers are recreated from original images and + filesystem is restored from tarball, not from running container state + +## What Gets Snapshotted + +Serializable state: +- Step count and step limit +- Task metadata (serialized Pydantic models or paths) +- Container metadata (image, dockerfile, resources) +- Agent message history + +Non-serializable state (cannot snapshot): +- Running asyncio tasks and futures +- Active LLM request queues +- Live container connections +""" + +import dataclasses +import json +import pathlib +from typing import Literal + + +@dataclasses.dataclass(frozen=True) +class EnvironmentSnapshot: + """Complete environment state snapshot. 
+ + Can only be created at episode boundaries: + - After env.reset() completes (FIRST timestep) + - After env.step() returns LAST timestep (done=True) + - When no code agent task is running + + Limitations: + - Snapshots only at episode boundaries (after reset or final step) + - Cannot snapshot mid-episode (running async tasks/futures) + - Agent message history preserved, but not mid-execution state + - Large filesystem snapshots (100MB-2GB tarballs) + """ + + # Unique identifier and metadata + snapshot_id: str + created_at: str # ISO timestamp + snapshot_dir: pathlib.Path + + # Episode state + step_count: int + step_limit: int + requires_reset: bool + + # Task metadata (for reconstruction) + task_type: Literal["swebench", "harbor"] + task_data: dict # Serialized task (Pydantic model_dump or Harbor path) + + # Container metadata + container_type: Literal["daytona", "docker"] + container_image: str | None + container_dockerfile_path: str | None + container_resources: dict | None # Serialized Resources + + # Code agent state + agent_messages: list[dict] # Chat history from MiniSWECodeAgent._messages + + def save_to_file(self, path: pathlib.Path) -> None: + """Save snapshot metadata to JSON file. + + Args: + path: Path to save JSON file (typically snapshot_dir/snapshot.json) + """ + # Convert pathlib.Path to string for JSON serialization + data = dataclasses.asdict(self) + data["snapshot_dir"] = str(data["snapshot_dir"]) + + with open(path, "w") as f: + json.dump(data, f, indent=2) + + @classmethod + def load_from_file(cls, path: pathlib.Path) -> "EnvironmentSnapshot": + """Load snapshot metadata from JSON file. 
+ + Args: + path: Path to JSON file (typically snapshot_dir/snapshot.json) + + Returns: + EnvironmentSnapshot instance + """ + with open(path) as f: + data = json.load(f) + + # Convert string back to pathlib.Path + data["snapshot_dir"] = pathlib.Path(data["snapshot_dir"]) + + return cls(**data) diff --git a/src/ares/environments/snapshot_test.py b/src/ares/environments/snapshot_test.py new file mode 100644 index 0000000..0669325 --- /dev/null +++ b/src/ares/environments/snapshot_test.py @@ -0,0 +1,398 @@ +"""Tests for environment state snapshotting.""" + +import pathlib +import tempfile + +import pytest + +from ares.environments import snapshot +from ares.environments import swebench_env +from ares.testing import mock_container + +# Mock task for testing +_MOCK_SWEBENCH_TASK = swebench_env.SwebenchTask( + repo="test/repo", + instance_id="test-instance-1", + base_commit="abc123", + patch="diff --git a/file.py", + test_patch="diff --git a/test_file.py", + problem_statement="Fix the bug", + hints_text="", + created_at="2024-01-01", + version="1.0", + FAIL_TO_PASS='["test_case_1"]', + PASS_TO_PASS='["test_case_2"]', + environment_setup_commit="def456", +) + + +@pytest.mark.asyncio +async def test_snapshot_dataclass_save_and_load(tmp_path: pathlib.Path): + """Test EnvironmentSnapshot can be saved to and loaded from disk.""" + snap = snapshot.EnvironmentSnapshot( + snapshot_id="test-123", + created_at="2024-01-01T00:00:00", + snapshot_dir=tmp_path / "snapshots" / "test-123", + step_count=5, + step_limit=100, + requires_reset=False, + task_type="swebench", + task_data={"repo": "test/repo", "instance_id": "test-1"}, + container_type="docker", + container_image="python:3.12", + container_dockerfile_path=None, + container_resources={"cpu": 2, "memory": 4096}, + agent_messages=[{"role": "user", "content": "Hello"}], + ) + + snap.snapshot_dir.mkdir(parents=True, exist_ok=True) + snapshot_file = snap.snapshot_dir / "snapshot.json" + + # Save + snap.save_to_file(snapshot_file) + 
assert snapshot_file.exists() + + # Load + loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file) + assert loaded_snap.snapshot_id == snap.snapshot_id + assert loaded_snap.step_count == snap.step_count + assert loaded_snap.task_data == snap.task_data + assert loaded_snap.snapshot_dir == snap.snapshot_dir + + +def test_swebench_task_serialization(): + """Test SwebenchTask can be serialized and deserialized via SweBenchEnv methods.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + # Serialize via env method (handles JSON string conversion) + task_data = env._serialize_task(_MOCK_SWEBENCH_TASK) + + # Verify serialization produces a dict with JSON strings + assert isinstance(task_data, dict) + assert task_data["instance_id"] == "test-instance-1" + + # Verify deserialization recreates the task + restored_task = swebench_env.SweBenchEnv._deserialize_task(task_data, "swebench") + assert restored_task.instance_id == _MOCK_SWEBENCH_TASK.instance_id + assert restored_task.repo == _MOCK_SWEBENCH_TASK.repo + + +@pytest.mark.asyncio +async def test_validate_snapshot_allowed_raises_during_active_episode(): + """Test that _validate_snapshot_allowed raises when agent task is running.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + step_limit=100, + ) + + # Create a mock running task (not done) + import asyncio + + async def mock_running_task(): + await asyncio.sleep(100) # Never completes + + env._code_agent_task = asyncio.create_task(mock_running_task()) + + # Should raise because task is running + with pytest.raises(RuntimeError, match="Cannot snapshot during active episode"): + env._validate_snapshot_allowed() + + # Cleanup + env._code_agent_task.cancel() + import contextlib + + with contextlib.suppress(asyncio.CancelledError): + await env._code_agent_task + + +@pytest.mark.asyncio +async def 
test_validate_snapshot_allowed_succeeds_when_task_done(): + """Test that _validate_snapshot_allowed succeeds when agent task is done.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + step_limit=100, + ) + + # Create a completed task + import asyncio + + async def mock_completed_task(): + return None + + env._code_agent_task = asyncio.create_task(mock_completed_task()) + await env._code_agent_task # Wait for completion + + # Should not raise + env._validate_snapshot_allowed() + + +@pytest.mark.asyncio +async def test_validate_snapshot_allowed_succeeds_when_no_task(): + """Test that _validate_snapshot_allowed succeeds when no agent task exists.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + step_limit=100, + ) + + # No agent task set + env._code_agent_task = None + + # Should not raise + env._validate_snapshot_allowed() + + +def test_swebench_env_serialize_task(): + """Test SweBenchEnv._serialize_task produces correct dict.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + serialized = env._serialize_task(_MOCK_SWEBENCH_TASK) + + assert isinstance(serialized, dict) + assert serialized["instance_id"] == "test-instance-1" + assert serialized["repo"] == "test/repo" + + +def test_swebench_env_deserialize_task(): + """Test SweBenchEnv._deserialize_task recreates task from dict.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + # Serialize first (this will convert lists to JSON strings) + task_data = env._serialize_task(_MOCK_SWEBENCH_TASK) + + # Then deserialize + restored_task = swebench_env.SweBenchEnv._deserialize_task(task_data, "swebench") + + assert isinstance(restored_task, swebench_env.SwebenchTask) + assert restored_task.instance_id == 
"test-instance-1" + assert restored_task.repo == "test/repo" + + +def test_swebench_env_get_task_type(): + """Test SweBenchEnv._get_task_type returns 'swebench'.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + assert env._get_task_type() == "swebench" + + +def test_get_container_type_daytona(): + """Test _get_container_type identifies Daytona containers.""" + from ares.containers import daytona + + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + # Mock a Daytona container + mock_daytona = type("MockDaytona", (daytona.DaytonaContainer,), {})() + + container_type = env._get_container_type(mock_daytona) + assert container_type == "daytona" + + +def test_get_container_type_docker(): + """Test _get_container_type identifies Docker containers.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + # Mock a Docker container (not Daytona) + mock_docker = mock_container.MockContainer() + + container_type = env._get_container_type(mock_docker) + assert container_type == "docker" + + +@pytest.mark.asyncio +async def test_export_state_basic_metadata(tmp_path: pathlib.Path): + """Test export_state creates snapshot with correct metadata.""" + + # Create a mock container with download_dir support + class MockContainerWithDownload(mock_container.MockContainer): + def __init__(self): + super().__init__() + self.resources = None # Add resources attribute + + async def download_dir(self, remote_path: str, local_path: pathlib.Path): + """Mock download_dir that creates an empty tarball.""" + del remote_path # Unused in mock + local_path.parent.mkdir(parents=True, exist_ok=True) + local_path.write_text("mock tarball content") + + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, 
+ step_limit=50, + ) + + # Set up minimal state + container = MockContainerWithDownload() + await container.start() + env._container = container + env._current_task = _MOCK_SWEBENCH_TASK + env._step_count = 10 + env._requires_reset = False + env._code_agent_task = None # No running task + + # Export state + snap = await env.export_state(tmp_path, snapshot_id="test-export-123") + + # Verify metadata + assert snap.snapshot_id == "test-export-123" + assert snap.step_count == 10 + assert snap.step_limit == 50 + assert snap.requires_reset is False + assert snap.task_type == "swebench" + assert snap.task_data["instance_id"] == "test-instance-1" + assert snap.container_type == "docker" + + # Verify files were created + assert (snap.snapshot_dir / "snapshot.json").exists() + assert (snap.snapshot_dir / "container_fs.tar.gz").exists() + + +@pytest.mark.asyncio +async def test_export_state_auto_generates_snapshot_id(tmp_path: pathlib.Path): + """Test export_state auto-generates UUID when snapshot_id not provided.""" + + # Create a mock container with download_dir support + class MockContainerWithDownload(mock_container.MockContainer): + def __init__(self): + super().__init__() + self.resources = None # Add resources attribute + + async def download_dir(self, remote_path: str, local_path: pathlib.Path): + del remote_path # Unused in mock + local_path.parent.mkdir(parents=True, exist_ok=True) + local_path.write_text("mock tarball") + + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + # Set up minimal state + container = MockContainerWithDownload() + await container.start() + env._container = container + env._current_task = _MOCK_SWEBENCH_TASK + env._code_agent_task = None + + # Export without snapshot_id + snap = await env.export_state(tmp_path) + + # Should have a UUID-like snapshot_id + assert snap.snapshot_id is not None + assert len(snap.snapshot_id) > 0 + + +@pytest.mark.asyncio +async def 
test_export_state_raises_if_no_container(): + """Test export_state raises if container not available.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + env._container = None + env._current_task = _MOCK_SWEBENCH_TASK + env._code_agent_task = None + + with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(RuntimeError, match="Container is not available"): + await env.export_state(pathlib.Path(tmp_dir)) + + +@pytest.mark.asyncio +async def test_export_state_raises_if_no_task(): + """Test export_state raises if current task not available.""" + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + container_factory=mock_container.MockContainerFactory, + ) + + container = mock_container.MockContainer() + await container.start() + env._container = container + env._current_task = None + env._code_agent_task = None + + with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(RuntimeError, match="No current task set"): + await env.export_state(pathlib.Path(tmp_dir)) + + +@pytest.mark.asyncio +async def test_load_from_state_creates_valid_env(tmp_path: pathlib.Path): + """Test load_from_state creates a properly initialized environment.""" + + # Create a mock container with download_dir and upload_dir support + class MockContainerWithDirOps(mock_container.MockContainer): + def __init__(self): + super().__init__() + self.resources = None + self.image = "python:3.12" # Add image attribute for snapshot + + async def download_dir(self, remote_path: str, local_path: pathlib.Path): + del remote_path # Unused in mock + local_path.parent.mkdir(parents=True, exist_ok=True) + local_path.write_text("mock tarball") + + async def upload_dir(self, local_path: pathlib.Path, remote_path: str): + """Mock upload_dir for container restoration.""" + del local_path, remote_path # Unused in mock + + # Create and export state + env = swebench_env.SweBenchEnv( + tasks=[_MOCK_SWEBENCH_TASK], + 
container_factory=mock_container.MockContainerFactory, + step_limit=42, + ) + + container = MockContainerWithDirOps() + await container.start() + env._container = container + env._current_task = _MOCK_SWEBENCH_TASK + env._step_count = 7 + env._requires_reset = False + env._code_agent_task = None + + snap = await env.export_state(tmp_path, snapshot_id="test-load") + + # Load from snapshot + class MockContainerFactory: + @classmethod + def from_image(cls, *, image: str, name: str | None = None, resources=None): + del image, name, resources + return MockContainerWithDirOps() + + @classmethod + def from_dockerfile(cls, *, dockerfile_path, name: str | None = None, resources=None): + del dockerfile_path, name, resources + return MockContainerWithDirOps() + + restored_env = await swebench_env.SweBenchEnv.load_from_state(snap, container_factory=MockContainerFactory) + + # Verify restoration + assert restored_env._step_count == 7 + assert restored_env._step_limit == 42 + assert restored_env._requires_reset is False + assert restored_env._current_task.instance_id == _MOCK_SWEBENCH_TASK.instance_id + assert restored_env._container is not None + + # Cleanup + await restored_env.close() diff --git a/src/ares/environments/swebench_env.py b/src/ares/environments/swebench_env.py index 62feaa6..939093e 100644 --- a/src/ares/environments/swebench_env.py +++ b/src/ares/environments/swebench_env.py @@ -10,9 +10,10 @@ import functools import json import logging +import pathlib import random import time -from typing import Any, cast +from typing import Any, Literal, cast import datasets import pydantic @@ -247,3 +248,72 @@ async def _compute_reward(self) -> float: test_result = await _run_tests_and_evaluate(self._container, self._current_task, self._test_spec) return 1.0 if test_result["resolved"] else 0.0 + + def _get_task_type(self) -> Literal["swebench", "harbor"]: + """Return task type for snapshotting.""" + return "swebench" + + def _serialize_task(self, task: SwebenchTask) -> 
dict: + """Serialize SwebenchTask using Pydantic.""" + data = task.model_dump() + # Convert lists back to JSON strings for field validators + data["FAIL_TO_PASS"] = json.dumps(data["FAIL_TO_PASS"]) + data["PASS_TO_PASS"] = json.dumps(data["PASS_TO_PASS"]) + return data + + @classmethod + def _deserialize_task(cls, task_data: dict, task_type: str) -> SwebenchTask: + """Deserialize SwebenchTask using Pydantic.""" + del task_type # Unused - validated by caller + # The field validators will convert JSON strings to lists + return SwebenchTask.model_validate(task_data) + + @classmethod + async def load_from_state( + cls, + snapshot_path: "base.snapshot.EnvironmentSnapshot | pathlib.Path", + *, + container_factory: containers.ContainerFactory | None = None, + code_agent_factory: code_agent_base.CodeAgentFactory | None = None, + tracker: stat_tracker.StatTracker | None = None, + ) -> "SweBenchEnv": + """Restore SweBenchEnv from snapshot. + + Args: + snapshot_path: EnvironmentSnapshot or path to snapshot.json + container_factory: Override factory (uses snapshot metadata if None) + code_agent_factory: Override factory (uses default if None) + tracker: Optional stat tracker + + Returns: + Restored environment (NOT active, use async with) + """ + import pathlib as pathlib_module + + from ares.environments import snapshot as snapshot_module + + # Load snapshot if path provided + if isinstance(snapshot_path, pathlib_module.Path): + snap = snapshot_module.EnvironmentSnapshot.load_from_file(snapshot_path) + else: + snap = snapshot_path + + # Deserialize task + task = cls._deserialize_task(snap.task_data, snap.task_type) + + # Create environment instance with tasks argument + container_factory = container_factory or base.ares_daytona.DaytonaContainer + code_agent_factory = code_agent_factory or mini_swe_agent.MiniSWECodeAgent + + env = cls( + tasks=[task], + container_factory=container_factory, + code_agent_factory=code_agent_factory, + step_limit=snap.step_limit, + 
tracker=tracker, + ) + + # Restore state using base helper + await env._restore_from_snapshot(snap) + + return env