Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions examples/03_state_snapshotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Example demonstrating environment state snapshotting and restoration.

This example shows how to:
1. Create a snapshot after reset (at episode boundary)
2. Save the snapshot to disk
3. Restore an environment from a saved snapshot
4. Continue execution from the restored state

Example usage:

1. Make sure you have examples dependencies installed
`uv sync --group examples`
2. Run the example
`uv run -m examples.03_state_snapshotting`
"""

import asyncio
import contextlib
import pathlib
import tempfile

from ares.code_agents import mini_swe_agent
from ares.containers import docker
from ares.environments import snapshot
from ares.environments import swebench_env
from ares.llms import chat_completions_compatible


async def main():
    """Run the snapshot demo: snapshot an env, save it, restore it, continue.

    Walks through three parts:
      1. Create a SWE-bench environment, run a few steps, export a snapshot.
      2. Load the snapshot metadata from disk and restore a new environment.
      3. Continue stepping from the restored state.
    """
    # Create an LLM client
    agent = chat_completions_compatible.ChatCompletionCompatibleLLMClient(model="openai/gpt-4o-mini")

    # Load SWE-bench tasks; one task is enough for the demo.
    all_tasks = swebench_env.swebench_verified_tasks()
    tasks = [all_tasks[0]]

    print(f"Running on task: {tasks[0].instance_id}")
    print(f"Repository: {tasks[0].repo}")
    print("-" * 80)

    # Create a temporary directory for snapshots
    with tempfile.TemporaryDirectory() as snapshot_dir:
        snapshot_path = pathlib.Path(snapshot_dir)

        # === PART 1: Create and save a snapshot ===
        print("\n[PART 1] Creating initial environment and snapshot...")

        async with swebench_env.SweBenchEnv(
            tasks=tasks,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
            container_factory=docker.DockerContainer,
        ) as env:
            # Reset the environment to get the first timestep.
            # NOTE: this example reads private attributes (env._step_count,
            # env._code_agent_task) purely for illustration; application code
            # should rely on the public API instead.
            ts = await env.reset()
            print(f"Environment reset complete. Step count: {env._step_count}")

            # Take a few steps before snapshotting
            for i in range(3):
                action = await agent(ts.observation)
                print(f" Step {i}: Taking action...")
                ts = await env.step(action)

                if ts.last():
                    print(" Episode completed early")
                    break

            print(f"Current step count: {env._step_count}")

            # Wait for agent to finish current operation (reach episode boundary)
            # In practice, you'd snapshot after step() returns with done=True
            # or after reset() completes. For this example, we'll simulate
            # waiting for agent to finish.
            if not ts.last():
                print("\n Note: For snapshotting, we need to be at episode boundary.")
                print(" Cancelling agent task to reach boundary...")
                if env._code_agent_task and not env._code_agent_task.done():
                    env._code_agent_task.cancel()
                    # contextlib is imported at module level (not mid-function)
                    # per PEP 8; cancellation is expected here, so suppress it.
                    with contextlib.suppress(asyncio.CancelledError):
                        await env._code_agent_task

            # Now we can export state (at episode boundary)
            print("\n Exporting state snapshot...")
            snap = await env.export_state(snapshot_path, snapshot_id="example-snapshot")

            print(f" ✓ Snapshot created: {snap.snapshot_id}")
            print(f" ✓ Snapshot saved to: {snap.snapshot_dir}")
            print(f" ✓ Step count in snapshot: {snap.step_count}")
            print(f" ✓ Task type: {snap.task_type}")
            print(f" ✓ Container type: {snap.container_type}")

        # === PART 2: Restore from snapshot ===
        print("\n[PART 2] Restoring environment from snapshot...")

        # Load snapshot metadata
        snapshot_file = snapshot_path / "example-snapshot" / "snapshot.json"
        loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file)

        print(f" ✓ Loaded snapshot: {loaded_snap.snapshot_id}")
        print(f" ✓ Original step count: {loaded_snap.step_count}")

        # Restore environment from snapshot
        # Note: This creates a new environment instance with the saved state
        restored_env = await swebench_env.SweBenchEnv.load_from_state(
            loaded_snap,
            container_factory=docker.DockerContainer,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
        )

        print(" ✓ Environment restored")
        print(f" ✓ Restored step count: {restored_env._step_count}")
        print(f" ✓ Task: {restored_env._current_task.instance_id}")

        # Use the restored environment in async context
        async with restored_env:
            print("\n[PART 3] Continuing from restored state...")

            # The environment is now at the same state as when we snapshotted
            # We can continue taking steps from here
            ts = await restored_env.reset()  # Reset to start a new episode
            step_count = 0

            # Take a few more steps to demonstrate it works
            while not ts.last() and step_count < 3:
                action = await agent(ts.observation)
                print(f" Step {step_count}: Taking action from restored env...")
                ts = await restored_env.step(action)
                step_count += 1

            print(f"\n ✓ Completed {step_count} additional steps from restored state")

    print("\n" + "=" * 80)
    print("Snapshot example completed successfully!")
    print("=" * 80)
    print("\nKey takeaways:")
    print(" 1. Snapshots can only be taken at episode boundaries")
    print(" 2. Snapshots save: task state, container filesystem, agent messages")
    print(" 3. Restored environments can continue execution normally")
    print(" 4. Use cases: debugging, RL replay, mechanistic interpretability")


# Script entry point: run the async demo to completion on the event loop.
if __name__ == "__main__":
    asyncio.run(main())
172 changes: 170 additions & 2 deletions src/ares/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import abc
import asyncio
import atexit
import dataclasses
import datetime
import functools
import logging
import os
import pathlib
import time
from types import TracebackType
from typing import Literal, NamedTuple, Protocol, Self
from typing import Any, Literal, NamedTuple, Protocol, Self
import uuid

from numpy.typing import NDArray
Expand All @@ -21,6 +23,7 @@
from ares.containers import containers
from ares.containers import daytona as ares_daytona
from ares.environments import base
from ares.environments import snapshot
from ares.experiment_tracking import stat_tracker
from ares.llms import llm_clients
from ares.llms import queue_mediated_client
Expand Down Expand Up @@ -287,9 +290,11 @@ def __init__(
self._is_active = False
self._container: containers.Container | None = None
self._current_task: TaskType | None = None
self._code_agent: code_agent_base.CodeAgent | None = None
self._code_agent_task: asyncio.Task[None] | None = None
self._step_count = 0
self._is_active = False
self._requires_reset = False
self._saved_agent_messages: list[dict] = []

# Register for cleanup on exit.
_ENVIRONMENT_JANITOR.register_for_cleanup(self)
Expand Down Expand Up @@ -430,6 +435,169 @@ def _assert_active(self) -> None:
if not self._is_active:
raise RuntimeError("Environment is not active.")

def _require_container(self) -> containers.Container:
    """Return the active container, failing loudly when there is none."""
    container = self._container
    if container is None:
        raise RuntimeError("Container is not available.")
    return container

def _require_task(self) -> TaskType:
    """Return the current task, failing loudly when none is set."""
    task = self._current_task
    if task is None:
        raise RuntimeError("No current task set.")
    return task

def _validate_snapshot_allowed(self) -> None:
    """Reject snapshot requests while a code-agent task is still running.

    Snapshots are only meaningful at episode boundaries (after reset() or
    after the final step()).
    """
    agent_task = self._code_agent_task
    if agent_task is None or agent_task.done():
        return
    raise RuntimeError(
        "Cannot snapshot during active episode. Call export_state() after reset() or after final step()."
    )

def _get_task_type(self) -> Literal["swebench", "harbor"]:
    """Identify the task family ('swebench' or 'harbor'); subclasses override."""
    raise NotImplementedError("Override _get_task_type in subclass")

def _get_container_type(self, container: containers.Container) -> Literal["daytona", "docker"]:
    """Classify *container* as 'daytona' or 'docker'.

    Anything that is not a Daytona container is reported as "docker".
    """
    # The daytona module is already imported at the top of this file as
    # `ares_daytona`; the previous function-local re-import was redundant.
    if isinstance(container, ares_daytona.DaytonaContainer):
        return "daytona"
    return "docker"

def _get_agent_messages(self) -> list[dict]:
    """Return a copy of the code agent's message history, or [] if unavailable."""
    agent = self._code_agent
    if agent is None or not hasattr(agent, "_messages"):
        return []
    # Copy so later agent activity cannot mutate the snapshot's view.
    return list(agent._messages)

async def _restore_container(self, snap: snapshot.EnvironmentSnapshot) -> containers.Container:
    """Create and start a fresh container, then restore its filesystem from *snap*.

    Args:
        snap: Snapshot whose image/dockerfile metadata and filesystem tarball
            (``container_fs.tar.gz``) drive the restoration.

    Returns:
        A started container with the snapshotted filesystem applied.

    Raises:
        ValueError: If the snapshot has neither an image nor a dockerfile path.
    """
    resources = containers.Resources(**snap.container_resources) if snap.container_resources else None

    # Recreate the container from the original image or dockerfile.
    if snap.container_image:
        container = self._container_factory.from_image(
            image=snap.container_image,
            resources=resources,
        )
    elif snap.container_dockerfile_path:
        container = self._container_factory.from_dockerfile(
            dockerfile_path=pathlib.Path(snap.container_dockerfile_path),
            resources=resources,
        )
    else:
        raise ValueError("Snapshot must have either image or dockerfile_path")

    # Start container
    await container.start()

    # Restore filesystem from the snapshot tarball.
    # Bug fix: upload_dir() expects a local *directory* to walk and re-tar;
    # handing it the .tar.gz file raised NotADirectoryError and aborted
    # restoration. Unpack the archive to a temp directory and upload that.
    fs_path = snap.snapshot_dir / "container_fs.tar.gz"
    if fs_path.exists():
        import tarfile
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(fs_path, "r:gz") as tar:
                # NOTE(review): snapshot archives are produced by export_state()
                # and treated as trusted; revisit extraction filtering if
                # snapshots can ever come from untrusted sources.
                tar.extractall(tmp_dir)
            await container.upload_dir(pathlib.Path(tmp_dir), "/")

    return container

@classmethod
def _default_container_factory(cls, snap: snapshot.EnvironmentSnapshot) -> containers.ContainerFactory:
    """Return the fallback container factory (Daytona) for restoration.

    The snapshot metadata is currently ignored; a future version could pick
    the factory based on ``snap.container_type``.
    """
    del snap  # Unused for now.
    return ares_daytona.DaytonaContainer

async def export_state(
    self,
    output_dir: pathlib.Path,
    *,
    snapshot_id: str | None = None,
) -> snapshot.EnvironmentSnapshot:
    """Export environment state to snapshot.

    Args:
        output_dir: Directory to save snapshot files (tarballs, metadata)
        snapshot_id: Optional ID (defaults to UUID)

    Returns:
        EnvironmentSnapshot with metadata

    Raises:
        RuntimeError: If called during active episode (running code agent)
    """
    # Validate episode boundary
    self._validate_snapshot_allowed()

    snapshot_id = snapshot_id or str(uuid.uuid4())
    snapshot_dir = output_dir / snapshot_id
    snapshot_dir.mkdir(parents=True, exist_ok=True)

    # 1. Download container filesystem as a gzipped tarball.
    container = self._require_container()
    fs_path = snapshot_dir / "container_fs.tar.gz"
    await container.download_dir("/", fs_path)

    # 2. Serialize task
    task = self._require_task()
    task_data = self._serialize_task(task)

    # 3. Get agent messages (if agent exists)
    agent_messages = self._get_agent_messages()

    # 4. Create snapshot metadata.
    # Bug fix: the old expression str(getattr(container, "dockerfile_path",
    # None)) produced the literal string "None" whenever the attribute
    # existed but was None.
    dockerfile_path = getattr(container, "dockerfile_path", None)
    snap = snapshot.EnvironmentSnapshot(
        snapshot_id=snapshot_id,
        # Timezone-aware UTC timestamp so snapshots taken on different hosts
        # compare consistently (naive now() depends on local time).
        created_at=datetime.datetime.now(datetime.timezone.utc).isoformat(),
        snapshot_dir=snapshot_dir,
        step_count=self._step_count,
        step_limit=self._step_limit,
        requires_reset=self._requires_reset,
        task_type=self._get_task_type(),
        task_data=task_data,
        container_type=self._get_container_type(container),
        container_image=getattr(container, "image", None),
        container_dockerfile_path=str(dockerfile_path) if dockerfile_path is not None else None,
        container_resources=dataclasses.asdict(container.resources) if container.resources else None,
        agent_messages=agent_messages,
    )

    # Save metadata JSON alongside the filesystem tarball.
    snap.save_to_file(snapshot_dir / "snapshot.json")

    return snap

async def _restore_from_snapshot(self, snap: snapshot.EnvironmentSnapshot) -> None:
    """Internal helper to restore state from snapshot.

    Called by subclass load_from_state implementations after the environment
    instance has been constructed.

    Args:
        snap: The snapshot to restore from
    """
    # Bring up a container with the snapshotted filesystem.
    self._container = await self._restore_container(snap)

    # Rebuild the task object from its serialized form.
    self._current_task = self._deserialize_task(snap.task_data, snap.task_type)

    # Carry over episode bookkeeping.
    self._step_count = snap.step_count
    self._requires_reset = snap.requires_reset

    # Keep the message history around; it is applied to the code agent later.
    self._saved_agent_messages = snap.agent_messages

@abc.abstractmethod
def _serialize_task(self, task: TaskType) -> dict:
    """Serialize *task* into a JSON-compatible dict for snapshot storage.

    The result is stored as ``EnvironmentSnapshot.task_data`` by
    export_state() and later passed to ``_deserialize_task`` on restore.
    Subclasses must implement this.
    """
    pass

@classmethod
@abc.abstractmethod
def _deserialize_task(cls, task_data: dict, task_type: str) -> Any:
    """Reconstruct a task object from data produced by ``_serialize_task``.

    Args:
        task_data: Dict previously returned by ``_serialize_task``.
        task_type: Task family tag from the snapshot (e.g. "swebench" or
            "harbor", per ``_get_task_type``).
    """
    pass

@abc.abstractmethod
async def _reset_task(self) -> None:
"""Should set `self._current_task` with a TaskType"""
Expand Down
Loading
Loading