-
Notifications
You must be signed in to change notification settings - Fork 14
Add gymnasium-compatible wrappers for ARES environments (fixes #25) #98
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| """Example demonstrating environment state snapshotting and restoration. | ||
|
|
||
| This example shows how to: | ||
| 1. Create a snapshot after reset (at episode boundary) | ||
| 2. Save the snapshot to disk | ||
| 3. Restore an environment from a saved snapshot | ||
| 4. Continue execution from the restored state | ||
|
|
||
| Example usage: | ||
|
|
||
| 1. Make sure you have examples dependencies installed | ||
| `uv sync --group examples` | ||
| 2. Run the example | ||
| `uv run -m examples.03_state_snapshotting` | ||
| """ | ||
|
|
||
| import asyncio | ||
| import pathlib | ||
| import tempfile | ||
|
|
||
| from ares.code_agents import mini_swe_agent | ||
| from ares.containers import docker | ||
| from ares.environments import snapshot | ||
| from ares.environments import swebench_env | ||
| from ares.llms import chat_completions_compatible | ||
|
|
||
|
|
||
| async def main(): | ||
| # Create an LLM client | ||
| agent = chat_completions_compatible.ChatCompletionCompatibleLLMClient(model="openai/gpt-4o-mini") | ||
|
|
||
| # Load SWE-bench tasks | ||
| all_tasks = swebench_env.swebench_verified_tasks() | ||
| tasks = [all_tasks[0]] | ||
|
|
||
| print(f"Running on task: {tasks[0].instance_id}") | ||
| print(f"Repository: {tasks[0].repo}") | ||
| print("-" * 80) | ||
|
|
||
| # Create a temporary directory for snapshots | ||
| with tempfile.TemporaryDirectory() as snapshot_dir: | ||
| snapshot_path = pathlib.Path(snapshot_dir) | ||
|
|
||
| # === PART 1: Create and save a snapshot === | ||
| print("\n[PART 1] Creating initial environment and snapshot...") | ||
|
|
||
| async with swebench_env.SweBenchEnv( | ||
| tasks=tasks, | ||
| code_agent_factory=mini_swe_agent.MiniSWECodeAgent, | ||
| container_factory=docker.DockerContainer, | ||
| ) as env: | ||
| # Reset the environment to get the first timestep | ||
| ts = await env.reset() | ||
| print(f"Environment reset complete. Step count: {env._step_count}") | ||
|
|
||
| # Take a few steps before snapshotting | ||
| for i in range(3): | ||
| action = await agent(ts.observation) | ||
| print(f" Step {i}: Taking action...") | ||
| ts = await env.step(action) | ||
|
|
||
| if ts.last(): | ||
| print(" Episode completed early") | ||
| break | ||
|
|
||
| print(f"Current step count: {env._step_count}") | ||
|
|
||
| # Wait for agent to finish current operation (reach episode boundary) | ||
| # In practice, you'd snapshot after step() returns with done=True | ||
| # or after reset() completes. For this example, we'll simulate | ||
| # waiting for agent to finish. | ||
| if not ts.last(): | ||
| print("\n Note: For snapshotting, we need to be at episode boundary.") | ||
| print(" Cancelling agent task to reach boundary...") | ||
| if env._code_agent_task and not env._code_agent_task.done(): | ||
| env._code_agent_task.cancel() | ||
| import contextlib | ||
|
|
||
| with contextlib.suppress(asyncio.CancelledError): | ||
| await env._code_agent_task | ||
|
|
||
| # Now we can export state (at episode boundary) | ||
| print("\n Exporting state snapshot...") | ||
| snap = await env.export_state(snapshot_path, snapshot_id="example-snapshot") | ||
|
|
||
| print(f" ✓ Snapshot created: {snap.snapshot_id}") | ||
| print(f" ✓ Snapshot saved to: {snap.snapshot_dir}") | ||
| print(f" ✓ Step count in snapshot: {snap.step_count}") | ||
| print(f" ✓ Task type: {snap.task_type}") | ||
| print(f" ✓ Container type: {snap.container_type}") | ||
|
|
||
| # === PART 2: Restore from snapshot === | ||
| print("\n[PART 2] Restoring environment from snapshot...") | ||
|
|
||
| # Load snapshot metadata | ||
| snapshot_file = snapshot_path / "example-snapshot" / "snapshot.json" | ||
| loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file) | ||
|
|
||
| print(f" ✓ Loaded snapshot: {loaded_snap.snapshot_id}") | ||
| print(f" ✓ Original step count: {loaded_snap.step_count}") | ||
|
|
||
| # Restore environment from snapshot | ||
| # Note: This creates a new environment instance with the saved state | ||
| restored_env = await swebench_env.SweBenchEnv.load_from_state( | ||
| loaded_snap, | ||
| container_factory=docker.DockerContainer, | ||
| code_agent_factory=mini_swe_agent.MiniSWECodeAgent, | ||
| ) | ||
|
|
||
| print(" ✓ Environment restored") | ||
| print(f" ✓ Restored step count: {restored_env._step_count}") | ||
| print(f" ✓ Task: {restored_env._current_task.instance_id}") | ||
|
|
||
| # Use the restored environment in async context | ||
| async with restored_env: | ||
| print("\n[PART 3] Continuing from restored state...") | ||
|
|
||
| # The environment is now at the same state as when we snapshotted | ||
| # We can continue taking steps from here | ||
| ts = await restored_env.reset() # Reset to start a new episode | ||
| step_count = 0 | ||
|
|
||
| # Take a few more steps to demonstrate it works | ||
| while not ts.last() and step_count < 3: | ||
| action = await agent(ts.observation) | ||
| print(f" Step {step_count}: Taking action from restored env...") | ||
| ts = await restored_env.step(action) | ||
| step_count += 1 | ||
|
|
||
| print(f"\n ✓ Completed {step_count} additional steps from restored state") | ||
|
|
||
| print("\n" + "=" * 80) | ||
| print("Snapshot example completed successfully!") | ||
| print("=" * 80) | ||
| print("\nKey takeaways:") | ||
| print(" 1. Snapshots can only be taken at episode boundaries") | ||
| print(" 2. Snapshots save: task state, container filesystem, agent messages") | ||
| print(" 3. Restored environments can continue execution normally") | ||
| print(" 4. Use cases: debugging, RL replay, mechanistic interpretability") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from ares.environments.base import Environment | ||
| from ares.environments.base import StepType | ||
| from ares.environments.base import TimeStep | ||
| from ares.environments.gym_wrapper import AsyncGymWrapper | ||
| from ares.environments.gym_wrapper import GymWrapper | ||
| from ares.environments.gym_wrapper import wrap_as_gym | ||
|
|
||
| __all__ = [ | ||
| "AsyncGymWrapper", | ||
| "Environment", | ||
| "GymWrapper", | ||
| "StepType", | ||
| "TimeStep", | ||
| "wrap_as_gym", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,13 +5,15 @@ | |
| import abc | ||
| import asyncio | ||
| import atexit | ||
| import dataclasses | ||
| import datetime | ||
| import functools | ||
| import logging | ||
| import os | ||
| import pathlib | ||
| import time | ||
| from types import TracebackType | ||
| from typing import Literal, NamedTuple, Protocol, Self | ||
| from typing import Any, Literal, NamedTuple, Protocol, Self | ||
| import uuid | ||
|
|
||
| from numpy.typing import NDArray | ||
|
|
@@ -21,6 +23,7 @@ | |
| from ares.containers import containers | ||
| from ares.containers import daytona as ares_daytona | ||
| from ares.environments import base | ||
| from ares.environments import snapshot | ||
| from ares.experiment_tracking import stat_tracker | ||
| from ares.llms import llm_clients | ||
| from ares.llms import queue_mediated_client | ||
|
|
@@ -287,9 +290,11 @@ def __init__( | |
| self._is_active = False | ||
| self._container: containers.Container | None = None | ||
| self._current_task: TaskType | None = None | ||
| self._code_agent: code_agent_base.CodeAgent | None = None | ||
| self._code_agent_task: asyncio.Task[None] | None = None | ||
| self._step_count = 0 | ||
| self._is_active = False | ||
| self._requires_reset = False | ||
| self._saved_agent_messages: list[dict] = [] | ||
|
|
||
| # Register for cleanup on exit. | ||
| _ENVIRONMENT_JANITOR.register_for_cleanup(self) | ||
|
|
@@ -430,6 +435,169 @@ def _assert_active(self) -> None: | |
| if not self._is_active: | ||
| raise RuntimeError("Environment is not active.") | ||
|
|
||
| def _require_container(self) -> containers.Container: | ||
| """Get container or raise if not available.""" | ||
| if self._container is None: | ||
| raise RuntimeError("Container is not available.") | ||
| return self._container | ||
|
Comment on lines
+438
to
+442
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Finding type: Want Baz to fix this for you? Activate Fixer. Other fix methods. Prompt for AI Agents: |
||
|
|
||
| def _require_task(self) -> TaskType: | ||
| """Get current task or raise if not available.""" | ||
| if self._current_task is None: | ||
| raise RuntimeError("No current task set.") | ||
| return self._current_task | ||
|
|
||
| def _validate_snapshot_allowed(self) -> None: | ||
| """Raise error if snapshot not allowed (mid-episode).""" | ||
| if self._code_agent_task is not None and not self._code_agent_task.done(): | ||
| raise RuntimeError( | ||
| "Cannot snapshot during active episode. Call export_state() after reset() or after final step()." | ||
| ) | ||
|
Comment on lines
+450
to
+455
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Finding type: Want Baz to fix this for you? Activate Fixer. Other fix methods. Prompt for AI Agents: |
||
|
|
||
| def _get_task_type(self) -> Literal["swebench", "harbor"]: | ||
| """Return 'swebench' or 'harbor'. Override in subclasses if needed.""" | ||
| # This will be overridden in subclasses if needed | ||
| raise NotImplementedError("Override _get_task_type in subclass") | ||
|
|
||
| def _get_container_type(self, container: containers.Container) -> Literal["daytona", "docker"]: | ||
| """Return 'daytona' or 'docker'.""" | ||
| from ares.containers.daytona import DaytonaContainer | ||
|
|
||
| return "daytona" if isinstance(container, DaytonaContainer) else "docker" | ||
|
|
||
| def _get_agent_messages(self) -> list[dict]: | ||
| """Get agent message history if available.""" | ||
| if self._code_agent is not None and hasattr(self._code_agent, "_messages"): | ||
| return list(self._code_agent._messages) | ||
| return [] | ||
|
|
||
| async def _restore_container(self, snap: snapshot.EnvironmentSnapshot) -> containers.Container: | ||
| """Restore container from filesystem snapshot.""" | ||
| # Create new container from original image/dockerfile | ||
| if snap.container_image: | ||
| container = self._container_factory.from_image( | ||
| image=snap.container_image, | ||
| resources=containers.Resources(**snap.container_resources) if snap.container_resources else None, | ||
| ) | ||
| elif snap.container_dockerfile_path: | ||
| container = self._container_factory.from_dockerfile( | ||
| dockerfile_path=pathlib.Path(snap.container_dockerfile_path), | ||
| resources=containers.Resources(**snap.container_resources) if snap.container_resources else None, | ||
| ) | ||
| else: | ||
| raise ValueError("Snapshot must have either image or dockerfile_path") | ||
|
|
||
| # Start container | ||
| await container.start() | ||
|
|
||
| # Restore filesystem from tarball | ||
| fs_path = snap.snapshot_dir / "container_fs.tar.gz" | ||
| if fs_path.exists(): | ||
| await container.upload_dir(fs_path, "/") | ||
|
|
||
| return container | ||
|
Comment on lines
+494
to
+498
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Check how upload_dir is implemented in container classes
ast-grep --pattern $'async def upload_dir($_, local_path$_, remote_path$_) {
$$$
}'Repository: withmartian/ares Length of output: 42 🏁 Script executed: # First, find the container class and how upload_dir is used
rg "upload_dir" -A 5 -B 2Repository: withmartian/ares Length of output: 3609 🏁 Script executed: # Find container class definitions
rg "class.*container" -iRepository: withmartian/ares Length of output: 1041 🏁 Script executed: # Look at the imports in base.py to understand the container type
head -50 src/ares/environments/base.pyRepository: withmartian/ares Length of output: 1567 🏁 Script executed: # Get the full upload_dir implementation
sed -n '/(async def upload_dir/,/^[[:space:]]*async def\|^[[:space:]]*def\|^class/p' src/ares/containers/containers.py | head -30Repository: withmartian/ares Length of output: 42 🏁 Script executed: # Alternative: use rg to get more context around upload_dir
rg "async def upload_dir" -A 20 src/ares/containers/containers.pyRepository: withmartian/ares Length of output: 1032 🏁 Script executed: # Also check how fs_path is used in base.py context
sed -n '490,510p' src/ares/environments/base.pyRepository: withmartian/ares Length of output: 782 🏁 Script executed: # Check how the tarball is created/saved in export_state
rg "container_fs.tar.gz" -B 5 -A 2Repository: withmartian/ares Length of output: 1664 🏁 Script executed: # Check the snapshot test to see intended behavior
sed -n '1,100p' src/ares/environments/snapshot_test.py | grep -A 20 "upload_dir"Repository: withmartian/ares Length of output: 42 🏁 Script executed: # Look for where tarball is extracted or how filesystem is meant to be restored
rg "tar.*gz" -iRepository: withmartian/ares Length of output: 337 🏁 Script executed: # Check download_dir implementation to understand how the tarball is created
rg "async def download_dir" -A 20 src/ares/containers/containers.pyRepository: withmartian/ares Length of output: 954 🏁 Script executed: # Check the full context of how export_state saves the filesystem
sed -n '507,550p' src/ares/environments/base.pyRepository: withmartian/ares Length of output: 1546 🏁 Script executed: # Check the snapshot.py to understand how snapshot stores/loads data
grep -n "snapshot_dir\|download_dir\|upload_dir" src/ares/environments/snapshot.py | head -20Repository: withmartian/ares Length of output: 387 🏁 Script executed: # Check if there's any extraction or decompression happening with the "tarball"
rg "tar\|gz\|extract\|decompress" src/ares/environments/base.pyRepository: withmartian/ares Length of output: 42 🏁 Script executed: # Check the test to see what's actually in container_fs.tar.gz
sed -n '1,200p' src/ares/environments/snapshot_test.py | tail -100Repository: withmartian/ares Length of output: 3127 🏁 Script executed: # Look at the full download_dir call to understand the behavior better
sed -n '530,545p' src/ares/environments/base.pyRepository: withmartian/ares Length of output: 610 🏁 Script executed: # Let me check what download_dir actually creates - is it a tarball or a directory?
# Look at the test mock implementation
sed -n '1,150p' src/ares/environments/snapshot_test.py | grep -A 30 "class MockContainerWithDirOps"Repository: withmartian/ares Length of output: 42 🏁 Script executed: # Check if the naming is misleading - does "container_fs.tar.gz" actually create a tarball or directory structure?
# Let me see the full mock implementation for upload_dir
sed -n '1,300p' src/ares/environments/snapshot_test.py | grep -A 10 "async def upload_dir"Repository: withmartian/ares Length of output: 42 🏁 Script executed: # Get the full snapshot_test.py to see the mock implementations
cat src/ares/environments/snapshot_test.pyRepository: withmartian/ares Length of output: 13577 🏁 Script executed: # Let me check the exact implementation of upload_dir more carefully, especially the destination_path line
sed -n '/async def upload_dir/,/async def download_file/p' src/ares/containers/containers.py | head -20Repository: withmartian/ares Length of output: 765 🏁 Script executed: # Check if remote_path is maybe converted to Path somewhere
rg "destination_path.*remote_path" -B 2 -A 2Repository: withmartian/ares Length of output: 458 🏁 Script executed: # Check if this code actually works - the remote_path is a string, but it's being used with /
# Let me trace through: remote_path is "/" (a string), relative_path is Path
# destination_path = str(remote_path / relative_path) should fail
# Let me verify if there's any conversion happening
python3 << 'EOF'
# Test if string / Path works
s = "/"
p = "file.txt"
try:
result = s / p
print("Success:", result)
except TypeError as e:
print("Error:", e)
EOFRepository: withmartian/ares Length of output: 117 🏁 Script executed: # Also check if there's maybe a different version or if pathlib.Path is being used
rg "destination_path.*=" -B 5 -A 1 src/ares/containers/Repository: withmartian/ares Length of output: 550 🏁 Script executed: # Check the actual test to see if it exercises upload_dir properly
grep -A 30 "async def upload_dir" src/ares/environments/snapshot_test.pyRepository: withmartian/ares Length of output: 1172 Critical bug: The code at line 497 calls Additionally, 🤖 Prompt for AI Agents |
||
|
|
||
| @classmethod | ||
| def _default_container_factory(cls, snap: snapshot.EnvironmentSnapshot) -> containers.ContainerFactory: | ||
| """Create default container factory from snapshot metadata.""" | ||
| del snap # Unused - could be used in future to select factory based on container_type | ||
| # Default to Daytona | ||
| return ares_daytona.DaytonaContainer | ||
|
|
||
| async def export_state( | ||
| self, | ||
| output_dir: pathlib.Path, | ||
| *, | ||
| snapshot_id: str | None = None, | ||
| ) -> snapshot.EnvironmentSnapshot: | ||
| """Export environment state to snapshot. | ||
|
|
||
| Args: | ||
| output_dir: Directory to save snapshot files (tarballs, metadata) | ||
| snapshot_id: Optional ID (defaults to UUID) | ||
|
|
||
| Returns: | ||
| EnvironmentSnapshot with metadata | ||
|
|
||
| Raises: | ||
| RuntimeError: If called during active episode (running code agent) | ||
| """ | ||
| # Validate episode boundary | ||
| self._validate_snapshot_allowed() | ||
|
|
||
| snapshot_id = snapshot_id or str(uuid.uuid4()) | ||
| snapshot_dir = output_dir / snapshot_id | ||
| snapshot_dir.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| # 1. Download container filesystem | ||
| container = self._require_container() | ||
| fs_path = snapshot_dir / "container_fs.tar.gz" | ||
| await container.download_dir("/", fs_path) | ||
|
|
||
| # 2. Serialize task | ||
| task = self._require_task() | ||
| task_data = self._serialize_task(task) | ||
|
|
||
| # 3. Get agent messages (if agent exists) | ||
| agent_messages = self._get_agent_messages() | ||
|
|
||
| # 4. Create snapshot metadata | ||
| snap = snapshot.EnvironmentSnapshot( | ||
| snapshot_id=snapshot_id, | ||
| created_at=datetime.datetime.now().isoformat(), | ||
| snapshot_dir=snapshot_dir, | ||
| step_count=self._step_count, | ||
| step_limit=self._step_limit, | ||
| requires_reset=self._requires_reset, | ||
| task_type=self._get_task_type(), | ||
| task_data=task_data, | ||
| container_type=self._get_container_type(container), | ||
| container_image=getattr(container, "image", None), | ||
| container_dockerfile_path=( | ||
| str(getattr(container, "dockerfile_path", None)) if hasattr(container, "dockerfile_path") else None | ||
| ), | ||
| container_resources=dataclasses.asdict(container.resources) if container.resources else None, | ||
| agent_messages=agent_messages, | ||
| ) | ||
|
|
||
| # Save metadata JSON | ||
| snap.save_to_file(snapshot_dir / "snapshot.json") | ||
|
|
||
| return snap | ||
|
|
||
| async def _restore_from_snapshot(self, snap: snapshot.EnvironmentSnapshot) -> None: | ||
| """Internal helper to restore state from snapshot. | ||
|
|
||
| Called by subclass load_from_state implementations after creating the environment. | ||
|
|
||
| Args: | ||
| snap: The snapshot to restore from | ||
| """ | ||
| # Restore container | ||
| self._container = await self._restore_container(snap) | ||
|
|
||
| # Deserialize and set task | ||
| task = self._deserialize_task(snap.task_data, snap.task_type) | ||
| self._current_task = task | ||
|
|
||
| # Restore state | ||
| self._step_count = snap.step_count | ||
| self._requires_reset = snap.requires_reset | ||
|
|
||
| # Store messages for later restoration | ||
| self._saved_agent_messages = snap.agent_messages | ||
|
|
||
|
Comment on lines
+586
to
+589
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we seed …? Finding type: Want Baz to fix this for you? Activate Fixer. Other fix methods. Prompt for AI Agents: |
||
| @abc.abstractmethod | ||
| def _serialize_task(self, task: TaskType) -> dict: | ||
| """Serialize task to dict. Override in subclasses.""" | ||
| pass | ||
|
|
||
| @classmethod | ||
| @abc.abstractmethod | ||
| def _deserialize_task(cls, task_data: dict, task_type: str) -> Any: | ||
| """Deserialize task from dict. Override in subclasses.""" | ||
| pass | ||
|
|
||
| @abc.abstractmethod | ||
| async def _reset_task(self) -> None: | ||
| """Should set `self._current_task` with a TaskType""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`restored_env.reset()` discards the restored snapshot because `reset()` always stops/clears and reinitializes. Should we change `Environment.reset()` to avoid stopping/clearing for restored envs, or add a start-restored method, or else mark snapshots as non-resumable and verify that `_restore_from_snapshot()` rehydrates agent/task/message history so we can continue without `reset()`? Finding type:
Logical Bugs | Severity: 🔴 High. Want Baz to fix this for you? Activate Fixer
Other fix methods
Prompt for AI Agents: