"""Example demonstrating environment state snapshotting and restoration.

This example shows how to:
1. Create a snapshot after reset (at episode boundary)
2. Save the snapshot to disk
3. Restore an environment from a saved snapshot
4. Continue execution from the restored state

Example usage:

    1. Make sure you have examples dependencies installed
       `uv sync --group examples`
    2. Run the example
       `uv run examples/03_state_snapshotting.py`

(The script is run by path rather than with `-m`: Python module names cannot
begin with a digit, so `-m examples.03_state_snapshotting` would fail.)
"""

import asyncio
import contextlib
import pathlib
import tempfile

from ares.code_agents import mini_swe_agent
from ares.containers import docker
from ares.environments import snapshot
from ares.environments import swebench_env
from ares.llms import chat_completions_compatible


async def main():
    # Create an LLM client
    agent = chat_completions_compatible.ChatCompletionCompatibleLLMClient(model="openai/gpt-4o-mini")

    # Load SWE-bench tasks
    all_tasks = swebench_env.swebench_verified_tasks()
    tasks = [all_tasks[0]]

    print(f"Running on task: {tasks[0].instance_id}")
    print(f"Repository: {tasks[0].repo}")
    print("-" * 80)

    # Create a temporary directory for snapshots
    with tempfile.TemporaryDirectory() as snapshot_dir:
        snapshot_path = pathlib.Path(snapshot_dir)

        # === PART 1: Create and save a snapshot ===
        print("\n[PART 1] Creating initial environment and snapshot...")

        async with swebench_env.SweBenchEnv(
            tasks=tasks,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
            container_factory=docker.DockerContainer,
        ) as env:
            # Reset the environment to get the first timestep
            ts = await env.reset()
            print(f"Environment reset complete. Step count: {env._step_count}")

            # Take a few steps before snapshotting
            for i in range(3):
                action = await agent(ts.observation)
                print(f"  Step {i}: Taking action...")
                ts = await env.step(action)

                if ts.last():
                    print("  Episode completed early")
                    break

            print(f"Current step count: {env._step_count}")

            # Snapshots are only allowed at episode boundaries (after reset()
            # completes or after step() returns the final timestep). If the
            # agent task is still running, cancel it to reach a boundary.
            # Awaiting a cancelled task re-raises asyncio.CancelledError, so
            # it is suppressed explicitly.
            if not ts.last():
                print("\n  Note: For snapshotting, we need to be at episode boundary.")
                print("  Cancelling agent task to reach boundary...")
                if env._code_agent_task and not env._code_agent_task.done():
                    env._code_agent_task.cancel()
                    with contextlib.suppress(asyncio.CancelledError):
                        await env._code_agent_task

            # Now we can export state (at episode boundary)
            print("\n  Exporting state snapshot...")
            snap = await env.export_state(snapshot_path, snapshot_id="example-snapshot")

            print(f"  ✓ Snapshot created: {snap.snapshot_id}")
            print(f"  ✓ Snapshot saved to: {snap.snapshot_dir}")
            print(f"  ✓ Step count in snapshot: {snap.step_count}")
            print(f"  ✓ Task type: {snap.task_type}")
            print(f"  ✓ Container type: {snap.container_type}")

        # === PART 2: Restore from snapshot ===
        print("\n[PART 2] Restoring environment from snapshot...")

        # Load snapshot metadata
        snapshot_file = snapshot_path / "example-snapshot" / "snapshot.json"
        loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file)

        print(f"  ✓ Loaded snapshot: {loaded_snap.snapshot_id}")
        print(f"  ✓ Original step count: {loaded_snap.step_count}")

        # Restore environment from snapshot.
        # Note: This creates a new environment instance with the saved state.
        restored_env = await swebench_env.SweBenchEnv.load_from_state(
            loaded_snap,
            container_factory=docker.DockerContainer,
            code_agent_factory=mini_swe_agent.MiniSWECodeAgent,
        )

        print("  ✓ Environment restored")
        print(f"  ✓ Restored step count: {restored_env._step_count}")
        print(f"  ✓ Task: {restored_env._current_task.instance_id}")

        # Use the restored environment in async context
        async with restored_env:
            print("\n[PART 3] Continuing from restored state...")

            # The environment is now at the same state as when we snapshotted.
            # We can continue taking steps from here.
            ts = await restored_env.reset()  # Reset to start a new episode
            step_count = 0

            # Take a few more steps to demonstrate it works
            while not ts.last() and step_count < 3:
                action = await agent(ts.observation)
                print(f"  Step {step_count}: Taking action from restored env...")
                ts = await restored_env.step(action)
                step_count += 1

            print(f"\n  ✓ Completed {step_count} additional steps from restored state")

    print("\n" + "=" * 80)
    print("Snapshot example completed successfully!")
    print("=" * 80)
    print("\nKey takeaways:")
    print("  1. Snapshots can only be taken at episode boundaries")
    print("  2. Snapshots save: task state, container filesystem, agent messages")
    print("  3. Restored environments can continue execution normally")
    print("  4. Use cases: debugging, RL replay, mechanistic interpretability")


if __name__ == "__main__":
    asyncio.run(main())
# NOTE: The functions below are methods of `CodeBaseEnv` in
# src/ares/environments/base.py, reconstructed from a whitespace-mangled
# patch; indent them one level when splicing back into the class body.
# Patch 1's `CodeBaseEnv.load_from_state` is intentionally not reproduced:
# patch 2 of this series removes it (it instantiated the abstract base
# class) in favour of subclass `load_from_state` implementations that call
# the `_restore_from_snapshot` helper.


def _require_container(self) -> containers.Container:
    """Get container or raise if not available."""
    if self._container is None:
        raise RuntimeError("Container is not available.")
    return self._container


def _require_task(self) -> TaskType:
    """Get current task or raise if not available."""
    if self._current_task is None:
        raise RuntimeError("No current task set.")
    return self._current_task


def _validate_snapshot_allowed(self) -> None:
    """Raise error if snapshot not allowed (mid-episode).

    A snapshot is only valid at an episode boundary, i.e. when no code agent
    task is currently running.
    """
    if self._code_agent_task is not None and not self._code_agent_task.done():
        raise RuntimeError(
            "Cannot snapshot during active episode. Call export_state() after reset() or after final step()."
        )


def _get_task_type(self) -> Literal["swebench", "harbor"]:
    """Return 'swebench' or 'harbor'. Override in subclasses if needed."""
    raise NotImplementedError("Override _get_task_type in subclass")


def _get_container_type(self, container: containers.Container) -> Literal["daytona", "docker"]:
    """Return 'daytona' or 'docker'."""
    from ares.containers.daytona import DaytonaContainer

    return "daytona" if isinstance(container, DaytonaContainer) else "docker"


def _get_agent_messages(self) -> list[dict]:
    """Get agent message history if available.

    Returns a copy of the agent's private `_messages` list when present;
    otherwise an empty list.
    """
    if self._code_agent is not None and hasattr(self._code_agent, "_messages"):
        return list(self._code_agent._messages)
    return []


async def _restore_container(self, snap: snapshot.EnvironmentSnapshot) -> containers.Container:
    """Restore container from filesystem snapshot.

    Recreates the container from the original image or dockerfile recorded in
    the snapshot, starts it, then restores the saved filesystem tarball.
    """
    # Create new container from original image/dockerfile
    if snap.container_image:
        container = self._container_factory.from_image(
            image=snap.container_image,
            resources=containers.Resources(**snap.container_resources) if snap.container_resources else None,
        )
    elif snap.container_dockerfile_path:
        container = self._container_factory.from_dockerfile(
            dockerfile_path=pathlib.Path(snap.container_dockerfile_path),
            resources=containers.Resources(**snap.container_resources) if snap.container_resources else None,
        )
    else:
        raise ValueError("Snapshot must have either image or dockerfile_path")

    # Start container
    await container.start()

    # Restore filesystem from tarball.
    # NOTE(review): `upload_dir` is handed the path of a .tar.gz *file*; this
    # assumes the container API extracts archives uploaded to "/" — confirm
    # against `containers.Container.upload_dir`.
    fs_path = snap.snapshot_dir / "container_fs.tar.gz"
    if fs_path.exists():
        await container.upload_dir(fs_path, "/")

    return container


@classmethod
def _default_container_factory(cls, snap: snapshot.EnvironmentSnapshot) -> containers.ContainerFactory:
    """Create default container factory from snapshot metadata."""
    del snap  # Unused - could be used in future to select factory based on container_type
    # Default to Daytona
    return ares_daytona.DaytonaContainer


async def export_state(
    self,
    output_dir: pathlib.Path,
    *,
    snapshot_id: str | None = None,
) -> snapshot.EnvironmentSnapshot:
    """Export environment state to snapshot.

    Args:
        output_dir: Directory to save snapshot files (tarballs, metadata)
        snapshot_id: Optional ID (defaults to UUID)

    Returns:
        EnvironmentSnapshot with metadata

    Raises:
        RuntimeError: If called during active episode (running code agent)
    """
    # Validate episode boundary
    self._validate_snapshot_allowed()

    snapshot_id = snapshot_id or str(uuid.uuid4())
    snapshot_dir = output_dir / snapshot_id
    snapshot_dir.mkdir(parents=True, exist_ok=True)

    # 1. Download container filesystem
    container = self._require_container()
    fs_path = snapshot_dir / "container_fs.tar.gz"
    await container.download_dir("/", fs_path)

    # 2. Serialize task
    task = self._require_task()
    task_data = self._serialize_task(task)

    # 3. Get agent messages (if agent exists)
    agent_messages = self._get_agent_messages()

    # 4. Create snapshot metadata
    snap = snapshot.EnvironmentSnapshot(
        snapshot_id=snapshot_id,
        # Fix: record an aware UTC timestamp. A naive datetime.now() is
        # ambiguous when snapshots move between machines/timezones.
        created_at=datetime.datetime.now(datetime.timezone.utc).isoformat(),
        snapshot_dir=snapshot_dir,
        step_count=self._step_count,
        step_limit=self._step_limit,
        requires_reset=self._requires_reset,
        task_type=self._get_task_type(),
        task_data=task_data,
        container_type=self._get_container_type(container),
        container_image=getattr(container, "image", None),
        container_dockerfile_path=(
            str(getattr(container, "dockerfile_path", None)) if hasattr(container, "dockerfile_path") else None
        ),
        container_resources=dataclasses.asdict(container.resources) if container.resources else None,
        agent_messages=agent_messages,
    )

    # Save metadata JSON
    snap.save_to_file(snapshot_dir / "snapshot.json")

    return snap


@abc.abstractmethod
def _serialize_task(self, task: TaskType) -> dict:
    """Serialize task to dict. Override in subclasses."""


@classmethod
@abc.abstractmethod
def _deserialize_task(cls, task_data: dict, task_type: str) -> Any:
    """Deserialize task from dict. Override in subclasses."""
# NOTE: The functions below are methods of `HarborEnv` in
# src/ares/environments/harbor_env.py, reconstructed from a
# whitespace-mangled patch; indent one level when splicing into the class.


def _get_task_type(self) -> Literal["swebench", "harbor"]:
    """Identify this environment's task family for snapshot metadata."""
    return "harbor"


def _serialize_task(self, task: harbor_task.Task) -> dict:
    """Reduce a Harbor task to its on-disk location; the directory is all
    that is needed to reload it."""
    return {"task_dir": str(task.task_dir)}


@classmethod
def _deserialize_task(cls, task_data: dict, task_type: str) -> harbor_task.Task:
    """Rebuild a Harbor task from the directory path stored in a snapshot."""
    del task_type  # Unused - validated by caller
    return harbor_task.Task(task_dir=pathlib.Path(task_data["task_dir"]))
+ +This module provides functionality to snapshot and restore environment state, +enabling use cases like: +- RL trajectory replay and analysis +- Debugging failed episodes +- Mechanistic interpretability of agent behavior +- Checkpointing long-running experiments + +## Usage Example + +```python +import pathlib +from ares.environments import swebench_env, snapshot + +# Create and run environment +async with swebench_env.SweBenchEnv(tasks=[task]) as env: + ts = await env.reset() + + # Take some steps + for _ in range(10): + action = await agent(ts.observation) + ts = await env.step(action) + if ts.last(): + break + + # Wait for episode to complete (required for snapshotting) + if env._code_agent_task and not env._code_agent_task.done(): + env._code_agent_task.cancel() + await env._code_agent_task + + # Export state at episode boundary + snap = await env.export_state(pathlib.Path("./snapshots")) + +# Later: restore from snapshot +loaded_snap = snapshot.EnvironmentSnapshot.load_from_file( + pathlib.Path("./snapshots/abc-123/snapshot.json") +) + +restored_env = await swebench_env.SweBenchEnv.load_from_state(loaded_snap) +async with restored_env: + # Continue from saved state + ts = await restored_env.reset() + ... 
+``` + +## Limitations + +- **Episode boundaries only**: Snapshots can only be created when no code agent + task is running (after reset() or after final step() with done=True) +- **No mid-execution state**: Agent message history is saved, but not mid-execution + state like running asyncio tasks or futures +- **Large snapshots**: Container filesystems are saved as tarballs (100MB-2GB typical) +- **Container restoration**: Containers are recreated from original images and + filesystem is restored from tarball, not from running container state + +## What Gets Snapshotted + +Serializable state: +- Step count and step limit +- Task metadata (serialized Pydantic models or paths) +- Container metadata (image, dockerfile, resources) +- Agent message history + +Non-serializable state (cannot snapshot): +- Running asyncio tasks and futures +- Active LLM request queues +- Live container connections +""" + +import dataclasses +import json +import pathlib +from typing import Literal + + +@dataclasses.dataclass(frozen=True) +class EnvironmentSnapshot: + """Complete environment state snapshot. 
+ + Can only be created at episode boundaries: + - After env.reset() completes (FIRST timestep) + - After env.step() returns LAST timestep (done=True) + - When no code agent task is running + + Limitations: + - Snapshots only at episode boundaries (after reset or final step) + - Cannot snapshot mid-episode (running async tasks/futures) + - Agent message history preserved, but not mid-execution state + - Large filesystem snapshots (100MB-2GB tarballs) + """ + + # Unique identifier and metadata + snapshot_id: str + created_at: str # ISO timestamp + snapshot_dir: pathlib.Path + + # Episode state + step_count: int + step_limit: int + requires_reset: bool + + # Task metadata (for reconstruction) + task_type: Literal["swebench", "harbor"] + task_data: dict # Serialized task (Pydantic model_dump or Harbor path) + + # Container metadata + container_type: Literal["daytona", "docker"] + container_image: str | None + container_dockerfile_path: str | None + container_resources: dict | None # Serialized Resources + + # Code agent state + agent_messages: list[dict] # Chat history from MiniSWECodeAgent._messages + + def save_to_file(self, path: pathlib.Path) -> None: + """Save snapshot metadata to JSON file. + + Args: + path: Path to save JSON file (typically snapshot_dir/snapshot.json) + """ + # Convert pathlib.Path to string for JSON serialization + data = dataclasses.asdict(self) + data["snapshot_dir"] = str(data["snapshot_dir"]) + + with open(path, "w") as f: + json.dump(data, f, indent=2) + + @classmethod + def load_from_file(cls, path: pathlib.Path) -> "EnvironmentSnapshot": + """Load snapshot metadata from JSON file. 
+ + Args: + path: Path to JSON file (typically snapshot_dir/snapshot.json) + + Returns: + EnvironmentSnapshot instance + """ + with open(path) as f: + data = json.load(f) + + # Convert string back to pathlib.Path + data["snapshot_dir"] = pathlib.Path(data["snapshot_dir"]) + + return cls(**data) diff --git a/src/ares/environments/snapshot_test.py b/src/ares/environments/snapshot_test.py new file mode 100644 index 0000000..534c081 --- /dev/null +++ b/src/ares/environments/snapshot_test.py @@ -0,0 +1,336 @@ +"""Tests for environment state snapshotting.""" + +import pathlib +import tempfile + +import pytest + +from ares.environments import snapshot +from ares.environments import swebench_env +from ares.testing import mock_container + +# Mock task for testing +_MOCK_SWEBENCH_TASK = swebench_env.SwebenchTask( + repo="test/repo", + instance_id="test-instance-1", + base_commit="abc123", + patch="diff --git a/file.py", + test_patch="diff --git a/test_file.py", + problem_statement="Fix the bug", + hints_text="", + created_at="2024-01-01", + version="1.0", + FAIL_TO_PASS='["test_case_1"]', + PASS_TO_PASS='["test_case_2"]', + environment_setup_commit="def456", +) + + +@pytest.mark.asyncio +async def test_snapshot_dataclass_save_and_load(tmp_path: pathlib.Path): + """Test EnvironmentSnapshot can be saved to and loaded from disk.""" + snap = snapshot.EnvironmentSnapshot( + snapshot_id="test-123", + created_at="2024-01-01T00:00:00", + snapshot_dir=tmp_path / "snapshots" / "test-123", + step_count=5, + step_limit=100, + requires_reset=False, + task_type="swebench", + task_data={"repo": "test/repo", "instance_id": "test-1"}, + container_type="docker", + container_image="python:3.12", + container_dockerfile_path=None, + container_resources={"cpu": 2, "memory": 4096}, + agent_messages=[{"role": "user", "content": "Hello"}], + ) + + snap.snapshot_dir.mkdir(parents=True, exist_ok=True) + snapshot_file = snap.snapshot_dir / "snapshot.json" + + # Save + snap.save_to_file(snapshot_file) + 
# NOTE: reconstructed body of src/ares/environments/snapshot_test.py.
# Fixes applied: function-local `import asyncio`/`import contextlib`
# hoisted here (module scope for this region); the duplicated inline
# MockContainerWithDownload class is deduplicated into one helper; the
# save/load test had an @pytest.mark.asyncio marker but no awaits, so it
# is a plain test function.

import asyncio
import contextlib


class _MockContainerWithDownload(mock_container.MockContainer):
    """MockContainer with a stubbed download_dir that writes a fake tarball."""

    def __init__(self):
        super().__init__()
        self.resources = None  # export_state reads container.resources

    async def download_dir(self, remote_path: str, local_path: pathlib.Path):
        """Mock download_dir that creates a placeholder tarball file."""
        del remote_path  # Unused in mock
        local_path.parent.mkdir(parents=True, exist_ok=True)
        local_path.write_text("mock tarball content")


def test_snapshot_dataclass_save_and_load(tmp_path: pathlib.Path):
    """Test EnvironmentSnapshot can be saved to and loaded from disk."""
    snap = snapshot.EnvironmentSnapshot(
        snapshot_id="test-123",
        created_at="2024-01-01T00:00:00",
        snapshot_dir=tmp_path / "snapshots" / "test-123",
        step_count=5,
        step_limit=100,
        requires_reset=False,
        task_type="swebench",
        task_data={"repo": "test/repo", "instance_id": "test-1"},
        container_type="docker",
        container_image="python:3.12",
        container_dockerfile_path=None,
        container_resources={"cpu": 2, "memory": 4096},
        agent_messages=[{"role": "user", "content": "Hello"}],
    )

    snap.snapshot_dir.mkdir(parents=True, exist_ok=True)
    snapshot_file = snap.snapshot_dir / "snapshot.json"

    # Save
    snap.save_to_file(snapshot_file)
    assert snapshot_file.exists()

    # Load
    loaded_snap = snapshot.EnvironmentSnapshot.load_from_file(snapshot_file)
    assert loaded_snap.snapshot_id == snap.snapshot_id
    assert loaded_snap.step_count == snap.step_count
    assert loaded_snap.task_data == snap.task_data
    assert loaded_snap.snapshot_dir == snap.snapshot_dir


def test_swebench_task_serialization():
    """Test SwebenchTask can be serialized and deserialized via SweBenchEnv methods."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    # Serialize via env method (handles JSON string conversion)
    task_data = env._serialize_task(_MOCK_SWEBENCH_TASK)

    # Verify serialization produces a dict with JSON strings
    assert isinstance(task_data, dict)
    assert task_data["instance_id"] == "test-instance-1"

    # Verify deserialization recreates the task
    restored_task = swebench_env.SweBenchEnv._deserialize_task(task_data, "swebench")
    assert restored_task.instance_id == _MOCK_SWEBENCH_TASK.instance_id
    assert restored_task.repo == _MOCK_SWEBENCH_TASK.repo


@pytest.mark.asyncio
async def test_validate_snapshot_allowed_raises_during_active_episode():
    """Test that _validate_snapshot_allowed raises when agent task is running."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
        step_limit=100,
    )

    async def mock_running_task():
        await asyncio.sleep(100)  # Never completes

    env._code_agent_task = asyncio.create_task(mock_running_task())

    # Should raise because task is running
    with pytest.raises(RuntimeError, match="Cannot snapshot during active episode"):
        env._validate_snapshot_allowed()

    # Cleanup
    env._code_agent_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await env._code_agent_task


@pytest.mark.asyncio
async def test_validate_snapshot_allowed_succeeds_when_task_done():
    """Test that _validate_snapshot_allowed succeeds when agent task is done."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
        step_limit=100,
    )

    async def mock_completed_task():
        return None

    env._code_agent_task = asyncio.create_task(mock_completed_task())
    await env._code_agent_task  # Wait for completion

    # Should not raise
    env._validate_snapshot_allowed()


@pytest.mark.asyncio
async def test_validate_snapshot_allowed_succeeds_when_no_task():
    """Test that _validate_snapshot_allowed succeeds when no agent task exists."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
        step_limit=100,
    )

    # No agent task set
    env._code_agent_task = None

    # Should not raise
    env._validate_snapshot_allowed()


def test_swebench_env_serialize_task():
    """Test SweBenchEnv._serialize_task produces correct dict."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    serialized = env._serialize_task(_MOCK_SWEBENCH_TASK)

    assert isinstance(serialized, dict)
    assert serialized["instance_id"] == "test-instance-1"
    assert serialized["repo"] == "test/repo"


def test_swebench_env_deserialize_task():
    """Test SweBenchEnv._deserialize_task recreates task from dict."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    # Serialize first (this will convert lists to JSON strings)
    task_data = env._serialize_task(_MOCK_SWEBENCH_TASK)

    # Then deserialize
    restored_task = swebench_env.SweBenchEnv._deserialize_task(task_data, "swebench")

    assert isinstance(restored_task, swebench_env.SwebenchTask)
    assert restored_task.instance_id == "test-instance-1"
    assert restored_task.repo == "test/repo"


def test_swebench_env_get_task_type():
    """Test SweBenchEnv._get_task_type returns 'swebench'."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    assert env._get_task_type() == "swebench"


def test_get_container_type_daytona():
    """Test _get_container_type identifies Daytona containers."""
    from ares.containers import daytona

    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    # Mock a Daytona container
    mock_daytona = type("MockDaytona", (daytona.DaytonaContainer,), {})()

    container_type = env._get_container_type(mock_daytona)
    assert container_type == "daytona"


def test_get_container_type_docker():
    """Test _get_container_type identifies Docker containers."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    # Mock a Docker container (not Daytona)
    mock_docker = mock_container.MockContainer()

    container_type = env._get_container_type(mock_docker)
    assert container_type == "docker"


@pytest.mark.asyncio
async def test_export_state_basic_metadata(tmp_path: pathlib.Path):
    """Test export_state creates snapshot with correct metadata."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
        step_limit=50,
    )

    # Set up minimal state
    container = _MockContainerWithDownload()
    await container.start()
    env._container = container
    env._current_task = _MOCK_SWEBENCH_TASK
    env._step_count = 10
    env._requires_reset = False
    env._code_agent_task = None  # No running task

    # Export state
    snap = await env.export_state(tmp_path, snapshot_id="test-export-123")

    # Verify metadata
    assert snap.snapshot_id == "test-export-123"
    assert snap.step_count == 10
    assert snap.step_limit == 50
    assert snap.requires_reset is False
    assert snap.task_type == "swebench"
    assert snap.task_data["instance_id"] == "test-instance-1"
    assert snap.container_type == "docker"

    # Verify files were created
    assert (snap.snapshot_dir / "snapshot.json").exists()
    assert (snap.snapshot_dir / "container_fs.tar.gz").exists()


@pytest.mark.asyncio
async def test_export_state_auto_generates_snapshot_id(tmp_path: pathlib.Path):
    """Test export_state auto-generates UUID when snapshot_id not provided."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    # Set up minimal state
    container = _MockContainerWithDownload()
    await container.start()
    env._container = container
    env._current_task = _MOCK_SWEBENCH_TASK
    env._code_agent_task = None

    # Export without snapshot_id
    snap = await env.export_state(tmp_path)

    # Should have a UUID-like snapshot_id
    assert snap.snapshot_id is not None
    assert len(snap.snapshot_id) > 0
# NOTE: reconstructed tail of snapshot_test.py plus the SweBenchEnv snapshot
# methods from src/ares/environments/swebench_env.py (whitespace-mangled
# patch source).


@pytest.mark.asyncio
async def test_export_state_raises_if_no_container():
    """Test export_state raises if container not available."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    env._container = None
    env._current_task = _MOCK_SWEBENCH_TASK
    env._code_agent_task = None

    with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(RuntimeError, match="Container is not available"):
        await env.export_state(pathlib.Path(tmp_dir))


@pytest.mark.asyncio
async def test_export_state_raises_if_no_task():
    """Test export_state raises if current task not available."""
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
    )

    container = mock_container.MockContainer()
    await container.start()
    env._container = container
    env._current_task = None
    env._code_agent_task = None

    with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(RuntimeError, match="No current task set"):
        await env.export_state(pathlib.Path(tmp_dir))


# --- Methods of SweBenchEnv; indent one level when splicing into the class ---


def _get_task_type(self) -> Literal["swebench", "harbor"]:
    """Return task type for snapshotting."""
    return "swebench"


def _serialize_task(self, task: SwebenchTask) -> dict:
    """Serialize SwebenchTask using Pydantic.

    FAIL_TO_PASS / PASS_TO_PASS are re-encoded as JSON strings so the field
    validators can parse them again on deserialization.
    """
    # NOTE(review): relies on `json` being importable at swebench_env.py
    # module scope; the patch does not add that import — confirm it is
    # already present in the file.
    data = task.model_dump()
    data["FAIL_TO_PASS"] = json.dumps(data["FAIL_TO_PASS"])
    data["PASS_TO_PASS"] = json.dumps(data["PASS_TO_PASS"])
    return data


@classmethod
def _deserialize_task(cls, task_data: dict, task_type: str) -> SwebenchTask:
    """Deserialize SwebenchTask using Pydantic.

    The field validators convert the JSON strings back to lists.
    """
    del task_type  # Unused - validated by caller
    return SwebenchTask.model_validate(task_data)
async def _restore_from_snapshot(self, snap: snapshot.EnvironmentSnapshot) -> None:
    """Populate this environment's mutable state from a snapshot.

    Subclass ``load_from_state`` implementations call this after the
    environment object itself has been constructed.

    Args:
        snap: Snapshot whose container, task, counters, and agent
            transcript should be restored onto this instance.
    """
    # Bring the container back up from the snapshot first; task and
    # counters are plain attribute assignments after that.
    self._container = await self._restore_container(snap)

    # Rehydrate the task object and the episode bookkeeping.
    self._current_task = self._deserialize_task(snap.task_data, snap.task_type)
    self._step_count = snap.step_count
    self._requires_reset = snap.requires_reset

    # Agent messages are kept aside and re-applied later, once an agent exists.
    self._saved_agent_messages = snap.agent_messages
@classmethod
async def load_from_state(
    cls,
    snapshot_path: "base.snapshot.EnvironmentSnapshot | pathlib.Path",
    *,
    container_factory: containers.ContainerFactory | None = None,
    code_agent_factory: code_agent_base.CodeAgentFactory | None = None,
    tracker: stat_tracker.StatTracker | None = None,
) -> "HarborEnv":
    """Restore a HarborEnv from a previously exported snapshot.

    Args:
        snapshot_path: EnvironmentSnapshot instance or path to snapshot.json.
        container_factory: Override container factory. Defaults to
            DaytonaContainer when None (the previous docstring claimed the
            snapshot metadata was consulted, but it never was).
        code_agent_factory: Override code-agent factory. Defaults to
            MiniSWECodeAgent when None.
        tracker: Optional stat tracker forwarded to the constructor.

    Returns:
        Restored environment (NOT active; enter it with ``async with``).
    """
    # Local import avoids a circular dependency with the base module.
    from ares.environments import snapshot as snapshot_module

    # Accept either an already-loaded snapshot or a path to one on disk.
    if isinstance(snapshot_path, pathlib.Path):
        snap = snapshot_module.EnvironmentSnapshot.load_from_file(snapshot_path)
    else:
        snap = snapshot_path

    # The constructor needs the task up front (as a single-element task list).
    task = cls._deserialize_task(snap.task_data, snap.task_type)

    # Explicit `is None` checks (not `or`) so a falsy-but-valid factory
    # passed by the caller is never silently replaced.
    if container_factory is None:
        container_factory = base.ares_daytona.DaytonaContainer
    if code_agent_factory is None:
        code_agent_factory = mini_swe_agent.MiniSWECodeAgent

    env = cls(
        tasks=[task],
        container_factory=container_factory,
        code_agent_factory=code_agent_factory,
        step_limit=snap.step_limit,
        tracker=tracker,
    )

    # Re-apply container filesystem, counters, and agent transcript.
    await env._restore_from_snapshot(snap)
    return env
@pytest.mark.asyncio
async def test_load_from_state_creates_valid_env(tmp_path: pathlib.Path):
    """Round-trip: export_state followed by load_from_state yields a usable env."""

    class _DirAwareMockContainer(mock_container.MockContainer):
        """Mock container that also supports the dir-transfer API snapshots use."""

        def __init__(self):
            super().__init__()
            self.resources = None
            # Snapshot export records the container image, so the mock needs one.
            self.image = "python:3.12"

        async def download_dir(self, remote_path: str, local_path: pathlib.Path):
            del remote_path  # Unused in mock
            local_path.parent.mkdir(parents=True, exist_ok=True)
            local_path.write_text("mock tarball")

        async def upload_dir(self, local_path: pathlib.Path, remote_path: str):
            """No-op upload used during container restoration."""
            del local_path, remote_path  # Unused in mock

    # Build an environment carrying non-default episode state, then export it.
    env = swebench_env.SweBenchEnv(
        tasks=[_MOCK_SWEBENCH_TASK],
        container_factory=mock_container.MockContainerFactory,
        step_limit=42,
    )
    source_container = _DirAwareMockContainer()
    await source_container.start()
    env._container = source_container
    env._current_task = _MOCK_SWEBENCH_TASK
    env._step_count = 7
    env._requires_reset = False
    env._code_agent_task = None

    snap = await env.export_state(tmp_path, snapshot_id="test-load")

    class _DirAwareFactory:
        """Factory stub that hands out dir-aware mock containers."""

        @classmethod
        def from_image(cls, *, image: str, name: str | None = None, resources=None):
            del image, name, resources
            return _DirAwareMockContainer()

        @classmethod
        def from_dockerfile(cls, *, dockerfile_path, name: str | None = None, resources=None):
            del dockerfile_path, name, resources
            return _DirAwareMockContainer()

    restored_env = await swebench_env.SweBenchEnv.load_from_state(snap, container_factory=_DirAwareFactory)

    # The restored env must carry over counters, limits, task identity, container.
    assert restored_env._step_count == 7
    assert restored_env._step_limit == 42
    assert restored_env._requires_reset is False
    assert restored_env._current_task.instance_id == _MOCK_SWEBENCH_TASK.instance_id
    assert restored_env._container is not None

    # Cleanup
    await restored_env.close()
@classmethod
async def load_from_state(
    cls,
    snapshot_path: "base.snapshot.EnvironmentSnapshot | pathlib.Path",
    *,
    container_factory: containers.ContainerFactory | None = None,
    code_agent_factory: code_agent_base.CodeAgentFactory | None = None,
    tracker: stat_tracker.StatTracker | None = None,
) -> "SweBenchEnv":
    """Restore a SweBenchEnv from a previously exported snapshot.

    Args:
        snapshot_path: EnvironmentSnapshot instance or path to snapshot.json.
        container_factory: Override container factory. Defaults to
            DaytonaContainer when None (the previous docstring claimed the
            snapshot metadata was consulted, but it never was).
        code_agent_factory: Override code-agent factory. Defaults to
            MiniSWECodeAgent when None.
        tracker: Optional stat tracker forwarded to the constructor.

    Returns:
        Restored environment (NOT active; enter it with ``async with``).
    """
    # Local import avoids a circular dependency with the base module.
    from ares.environments import snapshot as snapshot_module

    # Accept either an already-loaded snapshot or a path to one on disk.
    # Uses the module-level `pathlib` import; the former local
    # `import pathlib as pathlib_module` alias was redundant.
    if isinstance(snapshot_path, pathlib.Path):
        snap = snapshot_module.EnvironmentSnapshot.load_from_file(snapshot_path)
    else:
        snap = snapshot_path

    # The constructor needs the task up front (as a single-element task list).
    task = cls._deserialize_task(snap.task_data, snap.task_type)

    # Explicit `is None` checks (not `or`) so a falsy-but-valid factory
    # passed by the caller is never silently replaced.
    if container_factory is None:
        container_factory = base.ares_daytona.DaytonaContainer
    if code_agent_factory is None:
        code_agent_factory = mini_swe_agent.MiniSWECodeAgent

    env = cls(
        tasks=[task],
        container_factory=container_factory,
        code_agent_factory=code_agent_factory,
        step_limit=snap.step_limit,
        tracker=tracker,
    )

    # Re-apply container filesystem, counters, and agent transcript.
    await env._restore_from_snapshot(snap)
    return env