Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions examples/03_state_snapshotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Example demonstrating environment state snapshotting and restoration.

This example shows how to:
1. Create a snapshot after reset (at episode boundary)
2. Save the snapshot to disk
3. Restore an environment from a saved snapshot
4. Continue execution from the restored state

Example usage:

1. Make sure you have examples dependencies installed
`uv sync --group examples`
2. Run the example
`uv run -m examples.03_state_snapshotting`
"""

import asyncio
import pathlib
import tempfile

from ares.code_agents import mini_swe_agent
from ares.containers import docker
from ares.environments import snapshot
from ares.environments import swebench_env
from ares.llms import chat_completions_compatible


async def main():
    """Walk through snapshot creation, persistence, and restoration end to end."""
    # LLM client used to drive the agent in both the original and restored envs.
    agent = chat_completions_compatible.ChatCompletionCompatibleLLMClient(model="openai/gpt-4o-mini")

    # Restrict the run to a single SWE-bench task to keep the demo quick.
    all_tasks = swebench_env.swebench_verified_tasks()
    tasks = [all_tasks[0]]

    print(f"Running on task: {tasks[0].instance_id}")
    print(f"Repository: {tasks[0].repo}")
    print("-" * 80)

    # Snapshots are written under a throwaway directory for this demo.
    with tempfile.TemporaryDirectory() as snapshot_dir:
        snapshot_path = pathlib.Path(snapshot_dir)

        # === PART 1: Create and save a snapshot ===
        print("\n[PART 1] Creating initial environment and snapshot...")

        async with swebench_env.SweBenchEnv(
            tasks=tasks,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
            container_factory=docker.DockerContainer,
        ) as env:
            # First timestep of the episode.
            ts = await env.reset()
            print(f"Environment reset complete. Step count: {env._step_count}")

            # Advance a few steps so the snapshot captures a non-trivial state.
            for step_idx in range(3):
                action = await agent(ts.observation)
                print(f" Step {step_idx}: Taking action...")
                ts = await env.step(action)

                if ts.last():
                    print(" Episode completed early")
                    break

            print(f"Current step count: {env._step_count}")

            # Snapshotting is only valid at an episode boundary. If the agent
            # task is still running, cancel it to simulate reaching one.
            # In practice you would snapshot after step() returns done=True or
            # right after reset() completes.
            if not ts.last():
                print("\n Note: For snapshotting, we need to be at episode boundary.")
                print(" Cancelling agent task to reach boundary...")
                if env._code_agent_task and not env._code_agent_task.done():
                    env._code_agent_task.cancel()
                    import contextlib

                    with contextlib.suppress(asyncio.CancelledError):
                        await env._code_agent_task

            # With the boundary reached, export the full environment state.
            print("\n Exporting state snapshot...")
            snap = await env.export_state(snapshot_path, snapshot_id="example-snapshot")

            print(f" ✓ Snapshot created: {snap.snapshot_id}")
            print(f" ✓ Snapshot saved to: {snap.snapshot_dir}")
            print(f" ✓ Step count in snapshot: {snap.step_count}")
            print(f" ✓ Task type: {snap.task_type}")
            print(f" ✓ Container type: {snap.container_type}")

        # === PART 2: Restore from snapshot ===
        print("\n[PART 2] Restoring environment from snapshot...")

        # Re-read the metadata from disk, exactly as a fresh process would.
        snapshot_file = snapshot_path / "example-snapshot" / "snapshot.json"
        loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file)

        print(f" ✓ Loaded snapshot: {loaded_snap.snapshot_id}")
        print(f" ✓ Original step count: {loaded_snap.step_count}")

        # Build a brand-new environment instance seeded with the saved state.
        restored_env = await swebench_env.SweBenchEnv.load_from_state(
            loaded_snap,
            container_factory=docker.DockerContainer,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
        )

        print(" ✓ Environment restored")
        print(f" ✓ Restored step count: {restored_env._step_count}")
        print(f" ✓ Task: {restored_env._current_task.instance_id}")

        # Use the restored environment in async context.
        async with restored_env:
            print("\n[PART 3] Continuing from restored state...")

            # NOTE(review): reset() may reinitialize the environment and
            # discard the restored state — confirm whether a non-destructive
            # resume path should be used here instead of reset().
            ts = await restored_env.reset()  # Reset to start a new episode
            step_count = 0

            # Demonstrate that stepping works after restoration.
            while not ts.last() and step_count < 3:
                action = await agent(ts.observation)
                print(f" Step {step_count}: Taking action from restored env...")
                ts = await restored_env.step(action)
                step_count += 1

            print(f"\n ✓ Completed {step_count} additional steps from restored state")

    print("\n" + "=" * 80)
    print("Snapshot example completed successfully!")
    print("=" * 80)
    print("\nKey takeaways:")
    print(" 1. Snapshots can only be taken at episode boundaries")
    print(" 2. Snapshots save: task state, container filesystem, agent messages")
    print(" 3. Restored environments can continue execution normally")
    print(" 4. Use cases: debugging, RL replay, mechanistic interpretability")


if __name__ == "__main__":
    asyncio.run(main())
15 changes: 15 additions & 0 deletions src/ares/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Public surface of :mod:`ares.environments`: core environment API and Gym adapters."""

from ares.environments.base import Environment
from ares.environments.base import StepType
from ares.environments.base import TimeStep
from ares.environments.gym_wrapper import AsyncGymWrapper
from ares.environments.gym_wrapper import GymWrapper
from ares.environments.gym_wrapper import wrap_as_gym

# Explicit export list: names re-exported from base and gym_wrapper above.
__all__ = [
    "AsyncGymWrapper",
    "Environment",
    "GymWrapper",
    "StepType",
    "TimeStep",
    "wrap_as_gym",
]
172 changes: 170 additions & 2 deletions src/ares/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import abc
import asyncio
import atexit
import dataclasses
import datetime
import functools
import logging
import os
import pathlib
import time
from types import TracebackType
from typing import Literal, NamedTuple, Protocol, Self
from typing import Any, Literal, NamedTuple, Protocol, Self
import uuid

from numpy.typing import NDArray
Expand All @@ -21,6 +23,7 @@
from ares.containers import containers
from ares.containers import daytona as ares_daytona
from ares.environments import base
from ares.environments import snapshot
from ares.experiment_tracking import stat_tracker
from ares.llms import llm_clients
from ares.llms import queue_mediated_client
Expand Down Expand Up @@ -287,9 +290,11 @@ def __init__(
self._is_active = False
self._container: containers.Container | None = None
self._current_task: TaskType | None = None
self._code_agent: code_agent_base.CodeAgent | None = None
self._code_agent_task: asyncio.Task[None] | None = None
self._step_count = 0
self._is_active = False
self._requires_reset = False
self._saved_agent_messages: list[dict] = []

# Register for cleanup on exit.
_ENVIRONMENT_JANITOR.register_for_cleanup(self)
Expand Down Expand Up @@ -430,6 +435,169 @@ def _assert_active(self) -> None:
if not self._is_active:
raise RuntimeError("Environment is not active.")

def _require_container(self) -> containers.Container:
    """Return the active container, failing fast when none is attached.

    Raises:
        RuntimeError: If the environment has no container yet.
    """
    container = self._container
    if container is not None:
        return container
    raise RuntimeError("Container is not available.")
Comment on lines +438 to +442
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_require_container and other snapshot helpers raise bare RuntimeError/ValueError, preventing callers/tests from distinguishing snapshot/episode-boundary failures — should we introduce environment-specific exceptions like SnapshotError/SnapshotTimingError and raise those instead?

Finding type: AI Coding Guidelines | Severity: 🟠 Medium


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents:

Before applying, verify this suggestion against the current code. In
src/ares/environments/base.py around lines 438 to 442, the helper methods
(_require_container, _require_task, _validate_snapshot_allowed and related snapshot
logic) currently raise bare RuntimeError/ValueError. Add environment-specific exception
classes (e.g. SnapshotError as base, SnapshotTimingError for mid-episode snapshot
attempts, and SnapshotMetadataError for invalid snapshot metadata) near the top of the
module or in src/ares/environments/__init__.py. Replace the RuntimeError in
_require_container/_require_task with SnapshotError or a more specific subclass, replace
the RuntimeError raised in _validate_snapshot_allowed with SnapshotTimingError, and
replace the ValueError in _restore_container (invalid snapshot metadata) with
SnapshotMetadataError. Ensure export_state/docstrings reference the new exceptions and
update any tests or callers that expect RuntimeError to catch the new exception types.


def _require_task(self) -> TaskType:
    """Return the task currently loaded into the environment.

    Raises:
        RuntimeError: If no task has been set.
    """
    task = self._current_task
    if task is None:
        raise RuntimeError("No current task set.")
    return task

def _validate_snapshot_allowed(self) -> None:
    """Raise error if snapshot not allowed (mid-episode).

    A still-running code agent task means the container state is being
    mutated concurrently, so a snapshot taken now would be inconsistent.
    """
    # NOTE(review): if reset() leaves the agent task running, this guard may
    # also reject the legitimate "snapshot right after reset" case flagged in
    # review — confirm against reset()'s task lifecycle (e.g. by also
    # consulting an episode-boundary flag such as _requires_reset).
    if self._code_agent_task is not None and not self._code_agent_task.done():
        raise RuntimeError(
            "Cannot snapshot during active episode. Call export_state() after reset() or after final step()."
        )
Comment on lines +450 to +455
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_validate_snapshot_allowed() blocks snapshots while _code_agent_task is running, so export_state() fails after reset(); should we relax the guard to allow snapshots at the episode boundary by checking _requires_reset?

Finding type: Logical Bugs | Severity: 🔴 High


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents:

Before applying, verify this suggestion against the current code. In
src/ares/environments/base.py around lines 450 to 455, the _validate_snapshot_allowed
method forbids snapshots whenever self._code_agent_task is running, but reset() (around
lines ~302-322) starts the code agent and returns while that task is still pending,
which prevents valid 'snapshot after reset' behavior. Modify _validate_snapshot_allowed
so it only raises if the agent task is running AND self._requires_reset is False (i.e.,
allow snapshots immediately after reset when _requires_reset is True). Additionally, in
the reset() method (lines ~302-322) set self._requires_reset = True at the end of reset
to mark the post-reset boundary, and in the first step() invocation clear
self._requires_reset = False (or where the agent begins actual execution) so subsequent
snapshots are again blocked while the episode is active. Ensure these changes include
small comments explaining the intent.


def _get_task_type(self) -> Literal["swebench", "harbor"]:
    """Identify the task family this environment serves.

    The base class has no sensible default, so subclasses must override.
    """
    raise NotImplementedError("Override _get_task_type in subclass")

def _get_container_type(self, container: containers.Container) -> Literal["daytona", "docker"]:
    """Classify a container instance as 'daytona' or 'docker'."""
    # Local import, matching the original's call-time placement.
    from ares.containers.daytona import DaytonaContainer

    if isinstance(container, DaytonaContainer):
        return "daytona"
    return "docker"

def _get_agent_messages(self) -> list[dict]:
    """Return a copy of the code agent's message history, or [] when unavailable."""
    agent = self._code_agent
    if agent is None or not hasattr(agent, "_messages"):
        return []
    # Defensive copy so later agent activity cannot mutate the caller's view.
    return list(agent._messages)

async def _restore_container(self, snap: snapshot.EnvironmentSnapshot) -> containers.Container:
    """Restore container from filesystem snapshot.

    Rebuilds a container from the snapshot's original image (preferred) or
    Dockerfile, starts it, then replays the saved filesystem tarball.

    Args:
        snap: Snapshot metadata describing how to rebuild the container.

    Returns:
        A started container with the snapshot's filesystem applied.

    Raises:
        ValueError: If the snapshot records neither an image nor a Dockerfile.
    """
    # Create new container from original image/dockerfile
    if snap.container_image:
        container = self._container_factory.from_image(
            image=snap.container_image,
            resources=containers.Resources(**snap.container_resources) if snap.container_resources else None,
        )
    elif snap.container_dockerfile_path:
        container = self._container_factory.from_dockerfile(
            dockerfile_path=pathlib.Path(snap.container_dockerfile_path),
            resources=containers.Resources(**snap.container_resources) if snap.container_resources else None,
        )
    else:
        raise ValueError("Snapshot must have either image or dockerfile_path")

    # Start container
    await container.start()

    # Restore filesystem from tarball
    fs_path = snap.snapshot_dir / "container_fs.tar.gz"
    if fs_path.exists():
        # NOTE(review): fs_path is a single .tar.gz file, but upload_dir was
        # flagged in review as expecting a directory to walk — confirm that
        # passing a tarball here actually transfers and extracts the
        # filesystem (it may need upload_file plus in-container extraction).
        await container.upload_dir(fs_path, "/")

    return container
Comment on lines +494 to +498
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how upload_dir is implemented in container classes
ast-grep --pattern $'async def upload_dir($_, local_path$_, remote_path$_) {
  $$$
}'

Repository: withmartian/ares

Length of output: 42


🏁 Script executed:

# First, find the container class and how upload_dir is used
rg "upload_dir" -A 5 -B 2

Repository: withmartian/ares

Length of output: 3609


🏁 Script executed:

# Find container class definitions
rg "class.*container" -i

Repository: withmartian/ares

Length of output: 1041


🏁 Script executed:

# Look at the imports in base.py to understand the container type
head -50 src/ares/environments/base.py

Repository: withmartian/ares

Length of output: 1567


🏁 Script executed:

# Get the full upload_dir implementation
sed -n '/(async def upload_dir/,/^[[:space:]]*async def\|^[[:space:]]*def\|^class/p' src/ares/containers/containers.py | head -30

Repository: withmartian/ares

Length of output: 42


🏁 Script executed:

# Alternative: use rg to get more context around upload_dir
rg "async def upload_dir" -A 20 src/ares/containers/containers.py

Repository: withmartian/ares

Length of output: 1032


🏁 Script executed:

# Also check how fs_path is used in base.py context
sed -n '490,510p' src/ares/environments/base.py

Repository: withmartian/ares

Length of output: 782


🏁 Script executed:

# Check how the tarball is created/saved in export_state
rg "container_fs.tar.gz" -B 5 -A 2

Repository: withmartian/ares

Length of output: 1664


🏁 Script executed:

# Check the snapshot test to see intended behavior
sed -n '1,100p' src/ares/environments/snapshot_test.py | grep -A 20 "upload_dir"

Repository: withmartian/ares

Length of output: 42


🏁 Script executed:

# Look for where tarball is extracted or how filesystem is meant to be restored
rg "tar.*gz" -i

Repository: withmartian/ares

Length of output: 337


🏁 Script executed:

# Check download_dir implementation to understand how the tarball is created
rg "async def download_dir" -A 20 src/ares/containers/containers.py

Repository: withmartian/ares

Length of output: 954


🏁 Script executed:

# Check the full context of how export_state saves the filesystem
sed -n '507,550p' src/ares/environments/base.py

Repository: withmartian/ares

Length of output: 1546


🏁 Script executed:

# Check the snapshot.py to understand how snapshot stores/loads data
grep -n "snapshot_dir\|download_dir\|upload_dir" src/ares/environments/snapshot.py | head -20

Repository: withmartian/ares

Length of output: 387


🏁 Script executed:

# Check if there's any extraction or decompression happening with the "tarball"
rg "tar\|gz\|extract\|decompress" src/ares/environments/base.py

Repository: withmartian/ares

Length of output: 42


🏁 Script executed:

# Check the test to see what's actually in container_fs.tar.gz
sed -n '1,200p' src/ares/environments/snapshot_test.py | tail -100

Repository: withmartian/ares

Length of output: 3127


🏁 Script executed:

# Look at the full download_dir call to understand the behavior better
sed -n '530,545p' src/ares/environments/base.py

Repository: withmartian/ares

Length of output: 610


🏁 Script executed:

# Let me check what download_dir actually creates - is it a tarball or a directory?
# Look at the test mock implementation
sed -n '1,150p' src/ares/environments/snapshot_test.py | grep -A 30 "class MockContainerWithDirOps"

Repository: withmartian/ares

Length of output: 42


🏁 Script executed:

# Check if the naming is misleading - does "container_fs.tar.gz" actually create a tarball or directory structure?
# Let me see the full mock implementation for upload_dir
sed -n '1,300p' src/ares/environments/snapshot_test.py | grep -A 10 "async def upload_dir"

Repository: withmartian/ares

Length of output: 42


🏁 Script executed:

# Get the full snapshot_test.py to see the mock implementations
cat src/ares/environments/snapshot_test.py

Repository: withmartian/ares

Length of output: 13577


🏁 Script executed:

# Let me check the exact implementation of upload_dir more carefully, especially the destination_path line
sed -n '/async def upload_dir/,/async def download_file/p' src/ares/containers/containers.py | head -20

Repository: withmartian/ares

Length of output: 765


🏁 Script executed:

# Check if remote_path is maybe converted to Path somewhere
rg "destination_path.*remote_path" -B 2 -A 2

Repository: withmartian/ares

Length of output: 458


🏁 Script executed:

# Check if this code actually works - the remote_path is a string, but it's being used with /
# Let me trace through: remote_path is "/" (a string), relative_path is Path
# destination_path = str(remote_path / relative_path) should fail
# Let me verify if there's any conversion happening
python3 << 'EOF'
# Test if string / Path works
s = "/"
p = "file.txt"
try:
    result = s / p
    print("Success:", result)
except TypeError as e:
    print("Error:", e)
EOF

Repository: withmartian/ares

Length of output: 117


🏁 Script executed:

# Also check if there's maybe a different version or if pathlib.Path is being used
rg "destination_path.*=" -B 5 -A 1 src/ares/containers/

Repository: withmartian/ares

Length of output: 550


🏁 Script executed:

# Check the actual test to see if it exercises upload_dir properly
grep -A 30 "async def upload_dir" src/ares/environments/snapshot_test.py

Repository: withmartian/ares

Length of output: 1172


Critical bug: upload_dir call will fail due to type error in container implementation.

The code at line 497 calls container.upload_dir(fs_path, "/") where fs_path is a .tar.gz file. However, the upload_dir implementation in src/ares/containers/containers.py has a type error on line that computes destination_path = str(remote_path / relative_path). Since remote_path is a string (not a pathlib.Path), the / operator fails with TypeError: unsupported operand type(s) for /: 'str' and 'str'.

Additionally, upload_dir expects a directory and iterates through it with rglob("*"). Passing a .tar.gz file will not yield any files to upload, breaking snapshot restoration. The tarball should either be uploaded as a single file using upload_file or extracted before passing to upload_dir.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/ares/environments/base.py` around lines 494 - 498, The call
container.upload_dir(fs_path, "/") is wrong because upload_dir expects a
directory and its implementation uses remote_path / relative_path which fails
when remote_path is a str; either upload the tarball as a single file with
container.upload_file(fs_path, "/container_fs.tar.gz") or extract fs_path to a
temporary directory and call container.upload_dir(extracted_dir, "/");
additionally fix the upload_dir implementation in
src/ares/containers/containers.py by ensuring remote_path is converted to a
pathlib.Path (e.g., remote_base = Path(remote_path)) before using the / operator
and keep iteration over Path.rglob("*") to handle directories correctly.


@classmethod
def _default_container_factory(cls, snap: snapshot.EnvironmentSnapshot) -> containers.ContainerFactory:
    """Pick the container factory used when the caller supplies none.

    The snapshot argument is accepted (and currently ignored) so a future
    version can select a factory based on the recorded container type.
    """
    del snap  # Unused for now; kept for forward compatibility.
    # Daytona is the default backend.
    return ares_daytona.DaytonaContainer

async def export_state(
    self,
    output_dir: pathlib.Path,
    *,
    snapshot_id: str | None = None,
) -> snapshot.EnvironmentSnapshot:
    """Export environment state to snapshot.

    Captures everything needed to rehydrate this environment later: the
    container filesystem tarball, the serialized task, the agent's message
    history, and bookkeeping counters.

    Args:
        output_dir: Directory to save snapshot files (tarballs, metadata)
        snapshot_id: Optional ID (defaults to UUID)

    Returns:
        EnvironmentSnapshot with metadata

    Raises:
        RuntimeError: If called during active episode (running code agent)
    """
    # Snapshots are only meaningful at episode boundaries.
    self._validate_snapshot_allowed()

    snapshot_id = snapshot_id or str(uuid.uuid4())
    snapshot_dir = output_dir / snapshot_id
    snapshot_dir.mkdir(parents=True, exist_ok=True)

    # 1. Download container filesystem.
    container = self._require_container()
    fs_path = snapshot_dir / "container_fs.tar.gz"
    await container.download_dir("/", fs_path)

    # 2. Serialize task (format is subclass-specific).
    task = self._require_task()
    task_data = self._serialize_task(task)

    # 3. Get agent messages (if agent exists).
    agent_messages = self._get_agent_messages()

    # Bug fix: the previous `str(getattr(...)) if hasattr(...)` turned a
    # present-but-None dockerfile_path attribute into the literal string
    # "None", which _restore_container would then treat as a real path.
    dockerfile_path = getattr(container, "dockerfile_path", None)

    # 4. Create snapshot metadata.
    snap = snapshot.EnvironmentSnapshot(
        snapshot_id=snapshot_id,
        created_at=datetime.datetime.now().isoformat(),
        snapshot_dir=snapshot_dir,
        step_count=self._step_count,
        step_limit=self._step_limit,
        requires_reset=self._requires_reset,
        task_type=self._get_task_type(),
        task_data=task_data,
        container_type=self._get_container_type(container),
        container_image=getattr(container, "image", None),
        container_dockerfile_path=str(dockerfile_path) if dockerfile_path is not None else None,
        container_resources=dataclasses.asdict(container.resources) if container.resources else None,
        agent_messages=agent_messages,
    )

    # Save metadata JSON next to the filesystem tarball.
    snap.save_to_file(snapshot_dir / "snapshot.json")

    return snap

async def _restore_from_snapshot(self, snap: snapshot.EnvironmentSnapshot) -> None:
    """Internal helper to restore state from snapshot.

    Called by subclass load_from_state implementations after creating the environment.

    Args:
        snap: The snapshot to restore from
    """
    # Restore container (rebuild + replay filesystem tarball).
    self._container = await self._restore_container(snap)

    # Deserialize and set task
    task = self._deserialize_task(snap.task_data, snap.task_type)
    self._current_task = task

    # Restore bookkeeping counters recorded at export time.
    self._step_count = snap.step_count
    self._requires_reset = snap.requires_reset

    # Store messages for later restoration.
    # NOTE(review): nothing visible here re-injects these messages into a
    # newly created code agent — confirm the agent startup path seeds its
    # history from _saved_agent_messages, otherwise restored runs lose the
    # prior conversation.
    self._saved_agent_messages = snap.agent_messages

Comment on lines +586 to +589
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we seed _code_agent._messages from _saved_agent_messages before run so restored snapshots actually resume the prior conversation?

Finding type: Logical Bugs | Severity: 🔴 High


Want Baz to fix this for you? Activate Fixer

Other fix methods

Fix in Cursor

Prompt for AI Agents:

Before applying, verify this suggestion against the current code. In
src/ares/environments/base.py around lines 586 to 589, the _restore_from_snapshot method
stores prior agent messages in self._saved_agent_messages but nothing injects those
messages into the newly created code agent, causing conversation history loss. Modify
the restore/start flow so that when subclasses instantiate the code agent (e.g., their
_start_code_agent or equivalent that currently creates MiniSWECodeAgent), they seed the
agent's internal message history with self._saved_agent_messages before calling
run/start: either add an optional messages parameter to the code agent
factory/constructor and pass self._saved_agent_messages, or after construction set
code_agent._messages = list(self._saved_agent_messages) (defensive copy) before starting
the agent. Also clear self._saved_agent_messages after seeding to avoid double-restore.

@abc.abstractmethod
def _serialize_task(self, task: TaskType) -> dict:
    """Serialize task to dict. Override in subclasses.

    The returned dict is stored in the snapshot and must round-trip through
    the subclass's _deserialize_task.
    """
    pass

@classmethod
@abc.abstractmethod
def _deserialize_task(cls, task_data: dict, task_type: str) -> Any:
    """Deserialize task from dict. Override in subclasses.

    Args:
        task_data: Dict previously produced by _serialize_task.
        task_type: Task family tag recorded in the snapshot
            (e.g. "swebench" or "harbor").
    """
    pass

@abc.abstractmethod
async def _reset_task(self) -> None:
"""Should set `self._current_task` with a TaskType"""
Expand Down
Loading