-
Notifications
You must be signed in to change notification settings - Fork 14
Implement Trajectory collection #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| """Environment implementations for ARES.""" | ||
|
|
||
| from ares.environments.trajectory import EpisodeTrajectory | ||
| from ares.environments.trajectory import JsonTrajectoryCollector | ||
| from ares.environments.trajectory import StepRecord | ||
| from ares.environments.trajectory import TrajectoryCollector | ||
|
|
||
| __all__ = [ | ||
| "EpisodeTrajectory", | ||
| "JsonTrajectoryCollector", | ||
| "StepRecord", | ||
| "TrajectoryCollector", | ||
| ] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |
| from ares.containers import containers | ||
| from ares.containers import daytona as ares_daytona | ||
| from ares.environments import base | ||
| from ares.environments import trajectory as trajectory_lib | ||
| from ares.experiment_tracking import stat_tracker | ||
| from ares.llms import queue_mediated_client | ||
| from ares.llms import request | ||
|
|
@@ -67,13 +68,15 @@ def __init__( | |
| step_limit: int = 250, # Same as mini-swe-agent default. | ||
| prefix: str = "harbor_env", | ||
| tracker: stat_tracker.StatTracker | None = None, | ||
| trajectory_collector: trajectory_lib.TrajectoryCollector | None = None, | ||
| ): | ||
| self._tasks = tasks | ||
| self._container_factory = container_factory | ||
| self._code_agent_factory = code_agent_factory | ||
| self._step_limit = step_limit | ||
| self._prefix = prefix | ||
| self._tracker = tracker if tracker is not None else stat_tracker.NullStatTracker() | ||
| self._trajectory_collector = trajectory_collector if trajectory_collector is not None else trajectory_lib.NullTrajectoryCollector() | ||
|
|
||
| # We set the LLM client to a queue mediated client so that | ||
| # we can return LLM requests in the reset and step methods. | ||
|
|
@@ -122,6 +125,22 @@ async def reset(self) -> base.TimeStep[request.LLMRequest, float, float]: | |
| assert ts.observation is not None | ||
| result = base.TimeStep(step_type="FIRST", reward=ts.reward, discount=ts.discount, observation=ts.observation) | ||
|
|
||
| # Record the FIRST step in the trajectory. | ||
| # FIRST steps have only observation; action/reward/discount are None per dm_env semantics. | ||
| assert self._current_task is not None | ||
| self._trajectory_collector.begin_episode(task_name=self._current_task.name) | ||
| self._trajectory_collector.record_step( | ||
| trajectory_lib.StepRecord( | ||
| step_index=0, | ||
| step_type="FIRST", | ||
| observation=trajectory_lib.serialize_llm_request(result.observation), | ||
| action=None, | ||
| reward=None, | ||
| discount=None, | ||
|
Comment on lines
+128
to
+139
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Finding type: Want Baz to fix this for you? Activate Fixer. Other fix methods. Prompt for AI Agents: |
||
| timestamp=time.time(), | ||
| ) | ||
| ) | ||
|
|
||
| reset_end_time = time.time() | ||
| self._tracker.scalar(f"{self._prefix}/reset", reset_end_time - reset_start_time) | ||
| return result | ||
|
|
@@ -145,16 +164,37 @@ async def step(self, action: response.LLMResponse) -> base.TimeStep[request.LLMR | |
| with self._tracker.timeit(f"{self._prefix}/get_time_step"): | ||
| ts = await self._get_time_step() | ||
|
|
||
| truncated = False | ||
| if self._step_count >= self._step_limit: | ||
| _LOGGER.debug("[%d] Step limit reached. Returning LAST timestep.", id(self)) | ||
| assert self._code_agent_task is not None | ||
| self._code_agent_task.cancel() | ||
| # Truncation: step_type="LAST", discount=1.0, unless we're _also_ already in a terminal state. | ||
| truncated = ts.step_type != "LAST" | ||
| ts = base.TimeStep(step_type="LAST", reward=ts.reward, discount=ts.discount, observation=ts.observation) | ||
|
|
||
| if ts.step_type == "LAST": | ||
| self._requires_reset = True | ||
|
|
||
| # Record the step in the trajectory. | ||
| self._trajectory_collector.record_step( | ||
| trajectory_lib.StepRecord( | ||
| step_index=self._step_count, | ||
| step_type=ts.step_type, | ||
| observation=( | ||
| trajectory_lib.serialize_llm_request(ts.observation) | ||
| if ts.observation is not None | ||
| else None | ||
| ), | ||
| action=trajectory_lib.serialize_llm_response(action), | ||
| reward=ts.reward, | ||
| discount=ts.discount, | ||
| timestamp=time.time(), | ||
| ) | ||
| ) | ||
| if ts.step_type == "LAST": | ||
| self._trajectory_collector.end_episode(truncated=truncated) | ||
|
|
||
|
Comment on lines
164
to
+197
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Finalize the trajectory when
🤖 Prompt for AI Agents |
||
| step_end_time = time.time() | ||
| self._tracker.scalar(f"{self._prefix}/step", step_end_time - step_start_time) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems reasonable, but we might want to consider where is the right place to export these.