diff --git a/src/ares/__init__.py b/src/ares/__init__.py index 9d40675..b5d97b2 100644 --- a/src/ares/__init__.py +++ b/src/ares/__init__.py @@ -16,6 +16,12 @@ >>> tracker = stat_tracker.LoggingStatTracker() >>> env = ares.make("sbv-mswea", container_factory=daytona.DaytonaContainer, tracker=tracker) +Collect episode trajectories: + + >>> from ares.environments.trajectory import JsonTrajectoryCollector + >>> collector = JsonTrajectoryCollector(output_dir="./trajectories") + >>> env = ares.make("sbv-mswea", trajectory_collector=collector) + To see available presets: >>> all_presets = ares.info() # Get list of all presets @@ -35,7 +41,7 @@ >>> ares.registry.register_preset("my-env", MyEnvSpec()) All other functionality is available via submodules: -- ares.environments: Environment implementations +- ares.environments: Environment implementations and trajectory collection - ares.code_agents: Code agent implementations - ares.containers: Container management - ares.llms: LLM client implementations @@ -47,6 +53,12 @@ from ares import presets # noqa: F401 from ares.environments.base import Environment from ares.environments.base import TimeStep + +# Trajectory collection +from ares.environments.trajectory import EpisodeTrajectory +from ares.environments.trajectory import JsonTrajectoryCollector +from ares.environments.trajectory import StepRecord +from ares.environments.trajectory import TrajectoryCollector from ares.registry import EnvironmentInfo # Import registry functions to expose at top level @@ -58,7 +70,11 @@ __all__ = [ "Environment", "EnvironmentInfo", + "EpisodeTrajectory", + "JsonTrajectoryCollector", + "StepRecord", "TimeStep", + "TrajectoryCollector", "info", "list_presets", "make", diff --git a/src/ares/environments/__init__.py b/src/ares/environments/__init__.py index e69de29..635851c 100644 --- a/src/ares/environments/__init__.py +++ b/src/ares/environments/__init__.py @@ -0,0 +1,13 @@ +"""Environment implementations for ARES.""" + +from 
ares.environments.trajectory import EpisodeTrajectory +from ares.environments.trajectory import JsonTrajectoryCollector +from ares.environments.trajectory import StepRecord +from ares.environments.trajectory import TrajectoryCollector + +__all__ = [ + "EpisodeTrajectory", + "JsonTrajectoryCollector", + "StepRecord", + "TrajectoryCollector", +] diff --git a/src/ares/environments/code_env.py b/src/ares/environments/code_env.py index 5c8b30a..67f356f 100644 --- a/src/ares/environments/code_env.py +++ b/src/ares/environments/code_env.py @@ -27,6 +27,7 @@ from ares.containers import containers from ares.containers import daytona as ares_daytona from ares.environments import base +from ares.environments import trajectory as trajectory_lib from ares.experiment_tracking import stat_tracker from ares.llms import queue_mediated_client from ares.llms import request @@ -67,6 +68,7 @@ def __init__( step_limit: int = 250, # Same as mini-swe-agent default. prefix: str = "harbor_env", tracker: stat_tracker.StatTracker | None = None, + trajectory_collector: trajectory_lib.TrajectoryCollector | None = None, ): self._tasks = tasks self._container_factory = container_factory @@ -74,6 +76,7 @@ def __init__( self._step_limit = step_limit self._prefix = prefix self._tracker = tracker if tracker is not None else stat_tracker.NullStatTracker() + self._trajectory_collector = trajectory_collector if trajectory_collector is not None else trajectory_lib.NullTrajectoryCollector() # We set the LLM client to a queue mediated client so that # we can return LLM requests in the reset and step methods. @@ -122,6 +125,22 @@ async def reset(self) -> base.TimeStep[request.LLMRequest, float, float]: assert ts.observation is not None result = base.TimeStep(step_type="FIRST", reward=ts.reward, discount=ts.discount, observation=ts.observation) + # Record the FIRST step in the trajectory. + # FIRST steps have only observation; action/reward/discount are None per dm_env semantics. 
+ assert self._current_task is not None + self._trajectory_collector.begin_episode(task_name=self._current_task.name) + self._trajectory_collector.record_step( + trajectory_lib.StepRecord( + step_index=0, + step_type="FIRST", + observation=trajectory_lib.serialize_llm_request(result.observation), + action=None, + reward=None, + discount=None, + timestamp=time.time(), + ) + ) + reset_end_time = time.time() self._tracker.scalar(f"{self._prefix}/reset", reset_end_time - reset_start_time) return result @@ -145,16 +164,37 @@ async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMR with self._tracker.timeit(f"{self._prefix}/get_time_step"): ts = await self._get_time_step() + truncated = False if self._step_count >= self._step_limit: _LOGGER.debug("[%d] Step limit reached. Returning LAST timestep.", id(self)) assert self._code_agent_task is not None self._code_agent_task.cancel() # Truncation: step_type="LAST", discount=1.0, unless we're _also_ already in a terminal state. + truncated = ts.step_type != "LAST" ts = base.TimeStep(step_type="LAST", reward=ts.reward, discount=ts.discount, observation=ts.observation) if ts.step_type == "LAST": self._requires_reset = True + # Record the step in the trajectory. 
+ self._trajectory_collector.record_step( + trajectory_lib.StepRecord( + step_index=self._step_count, + step_type=ts.step_type, + observation=( + trajectory_lib.serialize_llm_request(ts.observation) + if ts.observation is not None + else None + ), + action=trajectory_lib.serialize_llm_response(action), + reward=ts.reward, + discount=ts.discount, + timestamp=time.time(), + ) + ) + if ts.step_type == "LAST": + self._trajectory_collector.end_episode(truncated=truncated) + step_end_time = time.time() self._tracker.scalar(f"{self._prefix}/step", step_end_time - step_start_time) diff --git a/src/ares/environments/trajectory.py b/src/ares/environments/trajectory.py new file mode 100644 index 0000000..6dabc2a --- /dev/null +++ b/src/ares/environments/trajectory.py @@ -0,0 +1,294 @@ +"""Episode trajectory collection for ARES environments. + +Provides data models and collectors for recording episode trajectories. +Trajectories capture the full sequence of (observation, action, reward, discount) +tuples that flow through the environment loop, enabling: + +- Behavior cloning (learning from recorded expert episodes) +- Batch / offline RL (training on collected experience) +- Debugging and analysis (replaying what happened in an episode) + +Usage: + + >>> from ares.environments.trajectory import JsonTrajectoryCollector + >>> collector = JsonTrajectoryCollector(output_dir="./trajectories") + >>> env = CodeEnvironment(tasks=..., trajectory_collector=collector) + +Trajectories are stored as one JSON file per episode. Each file contains episode +metadata (task name, timing, reward, truncation status) and an ordered list of +step records. + +Step record semantics follow the dm_env convention: + + - FIRST step (from reset): observation is set; action/reward/discount are None. + - MID steps (from step): action is the LLMResponse provided, observation is + the next LLMRequest, reward is the intermediate reward (usually 0.0). 
+ - LAST step (from step): action is the final LLMResponse, observation may be + None (terminal), reward is the episode reward. +""" + +import dataclasses +import json +import logging +import pathlib +import time +from typing import Any, Protocol, runtime_checkable +import uuid + +from ares.environments.base import StepType +from ares.llms import request +from ares.llms import response + +_LOGGER = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Serialization helpers +# --------------------------------------------------------------------------- + + +def serialize_llm_request(req: request.LLMRequest) -> dict[str, Any]: + """Serialize an LLMRequest to a JSON-compatible dict.""" + return dataclasses.asdict(req) + + +def serialize_llm_response(resp: response.LLMResponse) -> dict[str, Any]: + """Serialize an LLMResponse to a JSON-compatible dict.""" + return dataclasses.asdict(resp) + + +# --------------------------------------------------------------------------- +# Data models +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass(frozen=True) +class StepRecord: + """Records the data from a single environment step. + + For the FIRST step (from ``reset()``): + - ``observation`` is the initial ``LLMRequest`` (serialized). + - ``action``, ``reward``, and ``discount`` are ``None``. + + For MID steps (from ``step()``): + - ``action`` is the ``LLMResponse`` that was provided to ``step()``. + - ``observation`` is the resulting ``LLMRequest``. + - ``reward`` and ``discount`` are from the returned ``TimeStep``. + + For the LAST step (from ``step()``): + - ``action`` is the final ``LLMResponse`` provided to ``step()``. + - ``observation`` may be ``None`` (terminal state). + - ``reward`` contains the episode reward. + - ``discount`` is ``0.0`` (terminal) or ``1.0`` (truncated). 
+ """ + + step_index: int + step_type: StepType + observation: dict[str, Any] | None + action: dict[str, Any] | None + reward: float | None + discount: float | None + timestamp: float + + def to_dict(self) -> dict[str, Any]: + """Convert to a plain dict for JSON serialization.""" + return dataclasses.asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "StepRecord": + """Reconstruct a StepRecord from a plain dict.""" + return cls(**data) + + +@dataclasses.dataclass +class EpisodeTrajectory: + """A complete episode trajectory with metadata and step records. + + Attributes: + episode_id: Unique identifier for this episode. + task_name: The name of the task that was run. + steps: Ordered list of step records. + start_time: Wall-clock time when the episode started (``time.time()``). + end_time: Wall-clock time when the episode ended. + total_reward: The reward from the final step. + num_steps: Total number of steps recorded. + truncated: Whether the episode was truncated (step limit reached). 
+ """ + + episode_id: str + task_name: str + steps: list[StepRecord] + start_time: float + end_time: float | None = None + total_reward: float | None = None + num_steps: int = 0 + truncated: bool = False + + def to_dict(self) -> dict[str, Any]: + """Convert to a plain dict for JSON serialization.""" + return { + "episode_id": self.episode_id, + "task_name": self.task_name, + "start_time": self.start_time, + "end_time": self.end_time, + "total_reward": self.total_reward, + "num_steps": self.num_steps, + "truncated": self.truncated, + "steps": [step.to_dict() for step in self.steps], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "EpisodeTrajectory": + """Reconstruct an EpisodeTrajectory from a plain dict.""" + steps = [StepRecord.from_dict(s) for s in data.get("steps", [])] + return cls( + episode_id=data["episode_id"], + task_name=data["task_name"], + steps=steps, + start_time=data["start_time"], + end_time=data.get("end_time"), + total_reward=data.get("total_reward"), + num_steps=data.get("num_steps", len(steps)), + truncated=data.get("truncated", False), + ) + + @classmethod + def load(cls, path: str | pathlib.Path) -> "EpisodeTrajectory": + """Load an EpisodeTrajectory from a JSON file.""" + with open(path) as f: + return cls.from_dict(json.load(f)) + + +# --------------------------------------------------------------------------- +# Collector protocol and implementations +# --------------------------------------------------------------------------- + + +@runtime_checkable +class TrajectoryCollector(Protocol): + """Protocol for collecting episode trajectories. + + Implementations receive step-by-step data from the environment loop + and persist or aggregate it however they choose. + + Lifecycle:: + + collector.begin_episode(task_name="my-task") + collector.record_step(first_step_record) + collector.record_step(mid_step_record) + ... 
+ trajectory = collector.end_episode(truncated=False) + """ + + def begin_episode(self, *, task_name: str) -> None: + """Signal the start of a new episode.""" + ... + + def record_step(self, record: StepRecord) -> None: + """Record a single step within the current episode.""" + ... + + def end_episode(self, *, truncated: bool = False) -> EpisodeTrajectory: + """Finalize the current episode and return the completed trajectory.""" + ... + + +class NullTrajectoryCollector: + """No-op trajectory collector that discards all data.""" + + def begin_episode(self, *, task_name: str) -> None: + del task_name + + def record_step(self, record: StepRecord) -> None: + del record + + def end_episode(self, *, truncated: bool = False) -> EpisodeTrajectory: + del truncated + return EpisodeTrajectory( + episode_id="", task_name="", steps=[], start_time=0.0 + ) + + +class JsonTrajectoryCollector: + """Collects trajectories and persists each episode as a JSON file. + + Files are named ``{episode_id}.json`` and written to *output_dir*. + + Args: + output_dir: Directory where episode JSON files will be saved. + Created automatically if it does not exist. + """ + + def __init__(self, output_dir: str | pathlib.Path): + self._output_dir = pathlib.Path(output_dir) + self._output_dir.mkdir(parents=True, exist_ok=True) + self._current_episode: EpisodeTrajectory | None = None + + @property + def output_dir(self) -> pathlib.Path: + """The directory where trajectory files are written.""" + return self._output_dir + + def begin_episode(self, *, task_name: str) -> None: + """Start recording a new episode. + + If a previous episode was not ended, it is discarded with a warning. 
+ """ + if self._current_episode is not None: + _LOGGER.warning( + "Previous episode %s was not ended — discarding %d steps.", + self._current_episode.episode_id, + len(self._current_episode.steps), + ) + + episode_id = str(uuid.uuid4()) + self._current_episode = EpisodeTrajectory( + episode_id=episode_id, + task_name=task_name, + steps=[], + start_time=time.time(), + ) + _LOGGER.debug( + "Started trajectory collection for episode %s (task: %s).", + episode_id, + task_name, + ) + + def record_step(self, record: StepRecord) -> None: + """Append a step record to the current episode.""" + if self._current_episode is None: + raise RuntimeError("No episode in progress. Call begin_episode() first.") + self._current_episode.steps.append(record) + + def end_episode(self, *, truncated: bool = False) -> EpisodeTrajectory: + """Finalize the current episode, write it to disk, and return it.""" + if self._current_episode is None: + raise RuntimeError("No episode in progress. Call begin_episode() first.") + + episode = self._current_episode + episode.end_time = time.time() + episode.truncated = truncated + episode.num_steps = len(episode.steps) + + # Extract total reward from the last step. + if episode.steps: + last_step = episode.steps[-1] + episode.total_reward = last_step.reward + + # Persist to disk. 
+ filename = f"{episode.episode_id}.json" + filepath = self._output_dir / filename + with open(filepath, "w") as f: + json.dump(episode.to_dict(), f, indent=2) + + _LOGGER.info( + "Saved trajectory for episode %s: %d steps, reward=%s, truncated=%s → %s", + episode.episode_id, + episode.num_steps, + episode.total_reward, + truncated, + filepath, + ) + + self._current_episode = None + return episode diff --git a/src/ares/environments/trajectory_test.py b/src/ares/environments/trajectory_test.py new file mode 100644 index 0000000..5cda90b --- /dev/null +++ b/src/ares/environments/trajectory_test.py @@ -0,0 +1,394 @@ +"""Tests for episode trajectory collection.""" + +import json +import pathlib +import time + +import pytest + +from ares.environments import trajectory +from ares.llms import request +from ares.llms import response + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_llm_request(content: str = "Hello") -> request.LLMRequest: + """Create a simple LLMRequest for testing.""" + return request.LLMRequest(messages=[{"role": "user", "content": content}]) + + +def _make_llm_response(content: str = "Reply") -> response.LLMResponse: + """Create a simple LLMResponse for testing.""" + return response.LLMResponse( + data=[response.TextData(content=content)], + cost=0.01, + usage=response.Usage(prompt_tokens=50, generated_tokens=25), + ) + + +def _make_step_record( + step_index: int = 0, + step_type: str = "MID", + with_observation: bool = True, + with_action: bool = True, + reward: float | None = 0.0, + discount: float | None = 1.0, +) -> trajectory.StepRecord: + """Create a StepRecord for testing.""" + return trajectory.StepRecord( + step_index=step_index, + step_type=step_type, + observation=trajectory.serialize_llm_request(_make_llm_request()) if with_observation else None, + 
def _make_llm_request(content: str = "Hello") -> request.LLMRequest:
    """Build an LLMRequest holding a single user message."""
    user_message = {"role": "user", "content": content}
    return request.LLMRequest(messages=[user_message])


def _make_llm_response(content: str = "Reply") -> response.LLMResponse:
    """Build an LLMResponse with one text chunk and fixed cost/usage."""
    text_chunk = response.TextData(content=content)
    token_usage = response.Usage(prompt_tokens=50, generated_tokens=25)
    return response.LLMResponse(data=[text_chunk], cost=0.01, usage=token_usage)


def _make_step_record(
    step_index: int = 0,
    step_type: str = "MID",
    with_observation: bool = True,
    with_action: bool = True,
    reward: float | None = 0.0,
    discount: float | None = 1.0,
) -> trajectory.StepRecord:
    """Build a StepRecord, optionally leaving observation and/or action as None."""
    observation = None
    if with_observation:
        observation = trajectory.serialize_llm_request(_make_llm_request())

    action = None
    if with_action:
        action = trajectory.serialize_llm_response(_make_llm_response())

    return trajectory.StepRecord(
        step_index=step_index,
        step_type=step_type,
        observation=observation,
        action=action,
        reward=reward,
        discount=discount,
        timestamp=time.time(),
    )


def _make_first_step_record() -> trajectory.StepRecord:
    """Build the FIRST record used by most tests: observation only, index 0."""
    return _make_step_record(step_index=0, step_type="FIRST", with_action=False, reward=None, discount=None)


# ---------------------------------------------------------------------------
# Serialization helpers
# ---------------------------------------------------------------------------


class TestSerializeLLMRequest:
    """serialize_llm_request turns an LLMRequest into a plain dict."""

    def test_basic_request(self):
        serialized = trajectory.serialize_llm_request(_make_llm_request("Test message"))

        assert isinstance(serialized, dict)
        assert serialized["messages"] == [{"role": "user", "content": "Test message"}]

    def test_request_with_all_fields(self):
        full_request = request.LLMRequest(
            messages=[{"role": "user", "content": "Hello"}],
            max_output_tokens=100,
            temperature=0.7,
            top_p=0.9,
            system_prompt="You are a helpful assistant.",
        )
        serialized = trajectory.serialize_llm_request(full_request)

        expected_fields = {
            "max_output_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "system_prompt": "You are a helpful assistant.",
        }
        for field_name, expected in expected_fields.items():
            assert serialized[field_name] == expected

    def test_result_is_json_serializable(self):
        serialized = trajectory.serialize_llm_request(_make_llm_request("Test"))
        # The test passes as long as encoding does not raise.
        json.dumps(serialized)


class TestSerializeLLMResponse:
    """serialize_llm_response turns an LLMResponse into a plain dict."""

    def test_basic_response(self):
        serialized = trajectory.serialize_llm_response(_make_llm_response("Test reply"))

        assert isinstance(serialized, dict)
        assert serialized["data"] == [{"content": "Test reply"}]
        assert serialized["cost"] == 0.01
        usage = serialized["usage"]
        assert usage["prompt_tokens"] == 50
        assert usage["generated_tokens"] == 25

    def test_result_is_json_serializable(self):
        serialized = trajectory.serialize_llm_response(_make_llm_response("Test"))
        # The test passes as long as encoding does not raise.
        json.dumps(serialized)


# ---------------------------------------------------------------------------
# StepRecord
# ---------------------------------------------------------------------------


class TestStepRecord:
    """Construction, dict conversion and immutability of StepRecord."""

    def test_create_first_step(self):
        first = _make_first_step_record()

        assert first.step_index == 0
        assert first.step_type == "FIRST"
        assert first.observation is not None
        assert first.action is None
        assert first.reward is None
        assert first.discount is None

    def test_create_mid_step(self):
        mid = _make_step_record(step_index=3, step_type="MID", reward=0.0, discount=1.0)

        assert mid.step_index == 3
        assert mid.step_type == "MID"
        assert mid.observation is not None
        assert mid.action is not None
        assert mid.reward == 0.0
        assert mid.discount == 1.0

    def test_create_last_step(self):
        last = _make_step_record(step_index=10, step_type="LAST", with_observation=False, reward=1.0, discount=0.0)

        assert last.step_index == 10
        assert last.step_type == "LAST"
        assert last.observation is None
        assert last.reward == 1.0
        assert last.discount == 0.0

    def test_to_dict(self):
        as_dict = _make_step_record(step_index=5, step_type="MID").to_dict()

        assert isinstance(as_dict, dict)
        assert as_dict["step_index"] == 5
        assert as_dict["step_type"] == "MID"
        for key in ("observation", "action", "reward", "discount", "timestamp"):
            assert key in as_dict

    def test_to_dict_json_serializable(self):
        as_dict = _make_step_record().to_dict()
        json.dumps(as_dict)  # Should not raise.

    def test_from_dict_roundtrip(self):
        original = _make_step_record(step_index=7, step_type="MID", reward=0.5, discount=0.99)
        restored = trajectory.StepRecord.from_dict(original.to_dict())

        compared_fields = (
            "step_index",
            "step_type",
            "observation",
            "action",
            "reward",
            "discount",
            "timestamp",
        )
        for field_name in compared_fields:
            assert getattr(restored, field_name) == getattr(original, field_name)

    def test_frozen(self):
        frozen_record = _make_step_record()
        with pytest.raises(AttributeError):
            frozen_record.step_index = 999  # type: ignore[misc]


# ---------------------------------------------------------------------------
# EpisodeTrajectory
# ---------------------------------------------------------------------------


class TestEpisodeTrajectory:
    """Defaults, dict conversion and file loading of EpisodeTrajectory."""

    def test_create(self):
        episode = trajectory.EpisodeTrajectory(
            episode_id="test-123",
            task_name="my-task",
            steps=[],
            start_time=time.time(),
        )

        assert episode.episode_id == "test-123"
        assert episode.task_name == "my-task"
        assert episode.steps == []
        assert episode.end_time is None
        assert episode.total_reward is None
        assert episode.num_steps == 0
        assert episode.truncated is False

    def test_to_dict(self):
        episode = trajectory.EpisodeTrajectory(
            episode_id="ep-abc",
            task_name="swe-bench-123",
            steps=[
                _make_first_step_record(),
                _make_step_record(step_index=1, step_type="MID"),
                _make_step_record(step_index=2, step_type="LAST", with_observation=False, reward=1.0, discount=0.0),
            ],
            start_time=1000.0,
            end_time=1060.0,
            total_reward=1.0,
            num_steps=3,
            truncated=False,
        )
        as_dict = episode.to_dict()

        assert as_dict["episode_id"] == "ep-abc"
        assert as_dict["task_name"] == "swe-bench-123"
        assert as_dict["start_time"] == 1000.0
        assert as_dict["end_time"] == 1060.0
        assert as_dict["total_reward"] == 1.0
        assert as_dict["num_steps"] == 3
        assert as_dict["truncated"] is False
        assert len(as_dict["steps"]) == 3

    def test_to_dict_json_serializable(self):
        episode = trajectory.EpisodeTrajectory(
            episode_id="ep-1",
            task_name="task-1",
            steps=[_make_step_record(step_index=0, step_type="MID")],
            start_time=time.time(),
        )
        json.dumps(episode.to_dict())  # Should not raise.

    def test_from_dict_roundtrip(self):
        original = trajectory.EpisodeTrajectory(
            episode_id="ep-roundtrip",
            task_name="test-task",
            steps=[
                _make_first_step_record(),
                _make_step_record(step_index=1, step_type="LAST", reward=0.75, discount=0.0),
            ],
            start_time=1000.0,
            end_time=1010.0,
            total_reward=0.75,
            num_steps=2,
            truncated=True,
        )
        restored = trajectory.EpisodeTrajectory.from_dict(original.to_dict())

        compared_fields = (
            "episode_id",
            "task_name",
            "start_time",
            "end_time",
            "total_reward",
            "num_steps",
            "truncated",
        )
        for field_name in compared_fields:
            assert getattr(restored, field_name) == getattr(original, field_name)
        assert len(restored.steps) == len(original.steps)

    def test_load(self, tmp_path: pathlib.Path):
        episode = trajectory.EpisodeTrajectory(
            episode_id="ep-load",
            task_name="task-load",
            steps=[_make_step_record(step_index=0, step_type="MID", reward=0.5)],
            start_time=100.0,
            end_time=200.0,
            total_reward=0.5,
            num_steps=1,
        )
        episode_file = tmp_path / "test_episode.json"
        episode_file.write_text(json.dumps(episode.to_dict()))

        loaded = trajectory.EpisodeTrajectory.load(episode_file)

        assert loaded.episode_id == "ep-load"
        assert loaded.task_name == "task-load"
        assert loaded.total_reward == 0.5
        assert len(loaded.steps) == 1


# ---------------------------------------------------------------------------
# JsonTrajectoryCollector
# ---------------------------------------------------------------------------


class TestJsonTrajectoryCollector:
    """End-to-end behaviour of the JSON-file-backed collector."""

    def test_creates_output_dir(self, tmp_path: pathlib.Path):
        nested_dir = tmp_path / "trajectories" / "nested"
        collector = trajectory.JsonTrajectoryCollector(output_dir=nested_dir)

        assert nested_dir.exists()
        assert collector.output_dir == nested_dir

    def test_full_episode_saves_file(self, tmp_path: pathlib.Path):
        collector = trajectory.JsonTrajectoryCollector(output_dir=tmp_path)

        collector.begin_episode(task_name="json-task")
        for record in (
            _make_first_step_record(),
            _make_step_record(step_index=1, step_type="MID"),
            _make_step_record(step_index=2, step_type="LAST", with_observation=False, reward=0.8, discount=0.0),
        ):
            collector.record_step(record)
        finished = collector.end_episode(truncated=False)

        # Check file was created.
        episode_file = tmp_path / f"{finished.episode_id}.json"
        assert episode_file.exists()

        # Check file contents.
        saved = json.loads(episode_file.read_text())
        assert saved["episode_id"] == finished.episode_id
        assert saved["task_name"] == "json-task"
        assert saved["num_steps"] == 3
        assert saved["total_reward"] == 0.8
        assert saved["truncated"] is False
        assert len(saved["steps"]) == 3

    def test_load_saved_episode(self, tmp_path: pathlib.Path):
        """Test that saved episodes can be loaded back."""
        collector = trajectory.JsonTrajectoryCollector(output_dir=tmp_path)

        collector.begin_episode(task_name="roundtrip-task")
        collector.record_step(_make_first_step_record())
        collector.record_step(_make_step_record(step_index=1, step_type="LAST", reward=0.5, discount=0.0))
        finished = collector.end_episode()

        loaded = trajectory.EpisodeTrajectory.load(tmp_path / f"{finished.episode_id}.json")

        assert loaded.episode_id == finished.episode_id
        assert loaded.task_name == finished.task_name
        assert loaded.total_reward == finished.total_reward
        assert loaded.num_steps == finished.num_steps
        assert len(loaded.steps) == len(finished.steps)

    def test_multiple_episodes_create_separate_files(self, tmp_path: pathlib.Path):
        collector = trajectory.JsonTrajectoryCollector(output_dir=tmp_path)

        seen_ids = []
        for episode_number in range(3):
            collector.begin_episode(task_name=f"task-{episode_number}")
            collector.record_step(_make_first_step_record())
            collector.record_step(
                _make_step_record(step_index=1, step_type="LAST", reward=float(episode_number), discount=0.0)
            )
            seen_ids.append(collector.end_episode().episode_id)

        assert len(list(tmp_path.glob("*.json"))) == 3

        # All episode IDs should be unique.
        assert len(set(seen_ids)) == 3

    def test_record_step_without_begin_raises(self, tmp_path: pathlib.Path):
        collector = trajectory.JsonTrajectoryCollector(output_dir=tmp_path)
        with pytest.raises(RuntimeError, match="No episode in progress"):
            collector.record_step(_make_step_record())

    def test_end_episode_without_begin_raises(self, tmp_path: pathlib.Path):
        collector = trajectory.JsonTrajectoryCollector(output_dir=tmp_path)
        with pytest.raises(RuntimeError, match="No episode in progress"):
            collector.end_episode()

    def test_implements_protocol(self, tmp_path: pathlib.Path):
        collector = trajectory.JsonTrajectoryCollector(output_dir=tmp_path)
        assert isinstance(collector, trajectory.TrajectoryCollector)

    def test_truncated_flag_saved(self, tmp_path: pathlib.Path):
        collector = trajectory.JsonTrajectoryCollector(output_dir=tmp_path)

        collector.begin_episode(task_name="trunc-task")
        collector.record_step(_make_first_step_record())
        collector.record_step(_make_step_record(step_index=1, step_type="LAST", reward=0.0, discount=1.0))
        finished = collector.end_episode(truncated=True)

        saved = json.loads((tmp_path / f"{finished.episode_id}.json").read_text())
        assert saved["truncated"] is True
Harbor Verified environment with mini-swe-agent.""" all_tasks = self.ds @@ -78,6 +80,7 @@ def get_env( code_agent_factory=self.code_agent_factory, step_limit=250, # Same as mini-swe-agent default. tracker=tracker, + trajectory_collector=trajectory_collector, ) diff --git a/src/ares/registry.py b/src/ares/registry.py index 959d1ce..f686e64 100644 --- a/src/ares/registry.py +++ b/src/ares/registry.py @@ -20,6 +20,7 @@ from ares.containers import containers from ares.containers import docker from ares.environments import base +from ares.environments import trajectory from ares.experiment_tracking import stat_tracker _LOGGER = logging.getLogger(__name__) @@ -261,6 +262,7 @@ def get_env( selector: TaskSelector, container_factory: containers.ContainerFactory, tracker: stat_tracker.StatTracker | None = None, + trajectory_collector: trajectory.TrajectoryCollector | None = None, ) -> base.Environment: """Create and return an environment instance. @@ -268,6 +270,7 @@ def get_env( selector: Task selector to filter which tasks to include. container_factory: Factory for creating containers. Required. tracker: Statistics tracker for monitoring. Optional. + trajectory_collector: Trajectory collector for recording episodes. Optional. Returns: A configured environment instance ready for use in the RL loop. 
@@ -407,12 +410,14 @@ def get_env( selector: TaskSelector, container_factory: containers.ContainerFactory, tracker: stat_tracker.StatTracker | None = None, + trajectory_collector: trajectory.TrajectoryCollector | None = None, ) -> base.Environment: """Delegate to the decorated function.""" return func( selector=selector, container_factory=container_factory, tracker=tracker, + trajectory_collector=trajectory_collector, ) # Register the auto-generated spec @@ -529,6 +534,7 @@ def make( *, container_factory: containers.ContainerFactory = docker.DockerContainer, tracker: stat_tracker.StatTracker | None = None, + trajectory_collector: trajectory.TrajectoryCollector | None = None, ) -> base.Environment: """Create an environment instance from a registered preset. @@ -546,6 +552,9 @@ def make( - "sbv-mswea@2/8" - Shard 2 out of 8 total shards container_factory: Factory for creating containers. Defaults to DockerContainer. tracker: Statistics tracker for monitoring. Optional. + trajectory_collector: Trajectory collector for recording episode data. Optional. + When provided, every episode will have its (observation, action, reward, discount) + sequence recorded via the collector. See :mod:`ares.environments.trajectory`. Returns: An environment instance configured according to the preset. 
@@ -582,6 +591,12 @@ def make( >>> from ares.experiment_tracking import stat_tracker >>> tracker = stat_tracker.LoggingStatTracker() >>> env = make("sbv-mswea", tracker=tracker) + + Collect episode trajectories: + + >>> from ares.environments.trajectory import JsonTrajectoryCollector + >>> collector = JsonTrajectoryCollector(output_dir="./trajectories") + >>> env = make("sbv-mswea", trajectory_collector=collector) """ # Parse the selector syntax preset_id_clean, selector = _parse_selector(preset_id) @@ -592,14 +607,21 @@ def make( spec = _REGISTRY[preset_id_clean] _LOGGER.debug( - "Creating environment from preset '%s' with selector=%s, container_factory=%s, tracker=%s", + "Creating environment from preset '%s' with selector=%s, container_factory=%s, " + "tracker=%s, trajectory_collector=%s", preset_id_clean, selector, container_factory, tracker, + trajectory_collector, ) - env = spec.get_env(selector=selector, container_factory=container_factory, tracker=tracker) + env = spec.get_env( + selector=selector, + container_factory=container_factory, + tracker=tracker, + trajectory_collector=trajectory_collector, + ) _LOGGER.debug("Successfully created environment from preset '%s'", preset_id_clean) return env