Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ pi-mono/
notes/

trajectories*/
artifacts/

.hf_cache/
.hf_datasets_cache/
Expand Down
231 changes: 203 additions & 28 deletions scripts/collect_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,34 @@
"""
Collect DPO trajectories by running N episodes across W parallel workers.

Spins up W Docker containers (one per worker), then round-robins episodes
Spins up W containers (one per worker), then round-robins episodes
across them. Each episode produces:
- result.json (episode metadata + reward)
- pi_session.jsonl (full agent trajectory)
- container_logs.txt (server-side scoring logs)

Default behaviour preserves the original postgres-only flow: with no flags
the script is identical to its pre-multi-task form (20 episodes, 4 workers,
600 s WS timeout, ``frontier-swe-pg:latest`` image). Other tasks
(``notebook``, ``type-checker``, ``libexpat-to-x86asm``) are opt-in via
``--task``. Reward shapes (ratio / reward_json / reward_json_score) are
L1-rubric concerns; by the time scores reach ``frozen_scores`` they are
normalised [0, 1] floats, and the ``EpisodeRubric`` weights (plan / subtask
/ completion / tool) come from ``TaskConfig`` defaults that no task
currently overrides — so the offline reward backfill works for every task.

Usage:
# 20 episodes across 4 parallel workers (default)
# Original postgres flow (unchanged): 20 episodes, 4 workers, 45-min ep
PYTHONPATH=. uv run python scripts/collect_trajectories.py

# Custom settings
# Quick reward-shape smoke against another task
PYTHONPATH=. uv run python scripts/collect_trajectories.py \
--task type-checker --episodes 2 --workers 1

# Override the image tag (useful for local :smoke builds)
PYTHONPATH=. uv run python scripts/collect_trajectories.py \
--episodes 20 --workers 4 --output-dir trajectories/
--task type-checker \
--image frontier-swe-dependent-type-checker:smoke

# Resume from a previous run (skips existing episodes)
PYTHONPATH=. uv run python scripts/collect_trajectories.py --resume
Expand All @@ -38,6 +53,41 @@
from frontier_swe_env.client import FrontierSweEnv # noqa: E402
from frontier_swe_env.models import FrontierSweAction # noqa: E402


def _disable_openenv_ws_keepalive() -> None:
    """Disable the websockets keepalive ping for the EnvClient.

    The notebook env can run a single agent turn for >10 minutes (e.g. 96
    bash calls in a row) without yielding back to the WS event loop. The
    `websockets` library's auto-ping defaults (interval=20s, timeout=20s)
    fire during that gap, the env never pongs in time, and the connection
    dies with `ConnectionClosedError: keepalive ping timeout` mid-episode.
    The OpenEnv pin doesn't expose ping_interval/ping_timeout kwargs yet,
    so we rebind the `ws_connect` symbol inside `openenv.core.env_client`
    to a wrapper that injects `ping_interval=None`.

    The patch is idempotent: calling this function again (e.g. after a
    module reload) leaves the already-installed wrapper in place instead
    of stacking a second wrapper on top of it.

    Raises:
        AssertionError: if `openenv.core.env_client` no longer exposes
            `ws_connect`. Raised explicitly (not via an `assert`
            statement) so the check still runs under `python -O`.

    TODO: remove once openenv-core exposes ping kwargs upstream and we bump
    the pin in pyproject.toml.
    """
    from openenv.core import env_client

    original_connect = getattr(env_client, "ws_connect", None)
    if original_connect is None:
        raise AssertionError(
            "openenv-core layout changed: env_client no longer exposes "
            "ws_connect. Remove this monkey-patch or update it."
        )

    # Already patched by an earlier call: nothing to do.
    if getattr(original_connect, "_fswe_no_keepalive", False):
        return

    def _connect_no_keepalive(*args, **kwargs):
        # setdefault keeps any ping_interval a caller passes explicitly.
        kwargs.setdefault("ping_interval", None)
        return original_connect(*args, **kwargs)

    # Marker consulted by the idempotence guard above.
    _connect_no_keepalive._fswe_no_keepalive = True

    env_client.ws_connect = _connect_no_keepalive


_disable_openenv_ws_keepalive()


logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
Expand All @@ -52,17 +102,62 @@

# Constants

DOCKER_IMAGE = "frontier-swe-pg:latest"
# Per-task profile. Postgres values match the original hardcoded constants
# verbatim (2700 s episode, 600 s WS message timeout, frontier-swe-pg:latest
# image) so the default flow with no flags reproduces the pre-multi-task
# behaviour bit-for-bit. Other tasks need a longer WS timeout because pi
# explores deeper before submitting (observed up to 900 s on type-checker).
#
# The outer keys are the task slugs accepted by --task. Each profile holds:
#   image             - default container image tag; --image overrides it.
#   episode_timeout_s - default episode timeout in seconds; --episode-timeout
#                       overrides it.
#   message_timeout_s - WS message timeout in seconds, copied into
#                       MESSAGE_TIMEOUT_S (no CLI override is visible in main()).
#   dockerfile        - Dockerfile path printed in the "build it first" hint
#                       when the image is not found locally.
TASK_PROFILES: dict[str, dict] = {
"postgres": {
"image": "frontier-swe-pg:latest",
"episode_timeout_s": 2700,
"message_timeout_s": 600.0,
"dockerfile": "docker/Dockerfile.pg",
},
"type-checker": {
"image": "frontier-swe-dependent-type-checker:latest",
"episode_timeout_s": 1800,
"message_timeout_s": 900.0,
"dockerfile": "docker/Dockerfile.dependent-type-checker",
},
"notebook": {
"image": "frontier-swe-notebook:latest",
"episode_timeout_s": 1800,
"message_timeout_s": 900.0,
"dockerfile": "docker/Dockerfile.notebook",
},
"libexpat-to-x86asm": {
"image": "frontier-swe-libexpat-to-x86asm:latest",
"episode_timeout_s": 1800,
"message_timeout_s": 900.0,
"dockerfile": "docker/Dockerfile.libexpat-to-x86asm",
},
}

# Active task settings. main() rebinds these (via ``global``) once CLI args
# are parsed:
#   DOCKER_IMAGE      <- --image if given, else the --task profile's image
#   DOCKERFILE        <- the --task profile's dockerfile
#   EPISODE_TIMEOUT_S <- --episode-timeout if given, else the profile value
#   MESSAGE_TIMEOUT_S <- the --task profile's message_timeout_s
# Defaults below match the original postgres-hardcoded values so an
# import-side smoke that doesn't call main() still sees PG semantics.
DOCKER_IMAGE: str = TASK_PROFILES["postgres"]["image"]
DOCKERFILE: str = TASK_PROFILES["postgres"]["dockerfile"]
EPISODE_TIMEOUT_S: int = TASK_PROFILES["postgres"]["episode_timeout_s"]
MESSAGE_TIMEOUT_S: float = TASK_PROFILES["postgres"]["message_timeout_s"]

# Also rebound by main(): CONTAINER_PREFIX from --container-prefix and
# BASE_PORT from --base-port. Defaults are the original PG-flow values so
# existing call sites stay byte-identical.
CONTAINER_PREFIX = "fswe-worker"
BASE_PORT = 8100 # workers use ports 8100, 8101, 8102, ...

ENV_FILE = ".env"
MAX_TURNS = 20
MESSAGE_TIMEOUT_S = 600.0
EPISODE_TIMEOUT_S = 2700 # 45 min (must match task_config)
CONTAINER_STARTUP_WAIT = 10 # seconds to wait after docker run
CONTAINER_STARTUP_WAIT = 10 # seconds to wait after container run
HEALTH_CHECK_RETRIES = 30
HEALTH_CHECK_INTERVAL = 2

# Container runtime — `docker` or `podman`. Set via --runtime / $CONTAINER_RUNTIME
# or auto-detected at startup. The two CLIs share the verbs we use
# (run, rm, exec, logs, cp, image inspect), so the same code path works.
CONTAINER_RUNTIME: str = "docker"


# Offline reward computation

Expand All @@ -80,7 +175,9 @@ def _compute_reward_offline(result: dict) -> float:

plan_count = max(len(plan), 1) if plan else 1

# Weights (match EpisodeRubric / pg_training_config)
# Weights match EpisodeRubric defaults from TaskConfig. No task in the
# registry currently overrides these, so the same formula applies to
# postgres, notebook, type-checker and libexpat-to-x86asm.
plan_weight = 0.25
subtask_weight = 0.60
completion_weight = 0.10
Expand Down Expand Up @@ -119,13 +216,13 @@ def start_container(worker_id: int) -> bool:

# Remove any existing container with this name
subprocess.run(
["docker", "rm", "-f", name],
[CONTAINER_RUNTIME, "rm", "-f", name],
capture_output=True,
timeout=10,
)

cmd = [
"docker",
CONTAINER_RUNTIME,
"run",
"-d",
"--name",
Expand Down Expand Up @@ -175,7 +272,7 @@ def wait_for_healthy(worker_id: int) -> bool:
def stop_container(worker_id: int) -> None:
"""Stop and remove a worker container."""
name = container_name(worker_id)
subprocess.run(["docker", "rm", "-f", name], capture_output=True, timeout=15)
subprocess.run([CONTAINER_RUNTIME, "rm", "-f", name], capture_output=True, timeout=15)
logger.info("Stopped container %s", name)


Expand All @@ -189,7 +286,7 @@ def reset_container(worker_id: int) -> bool:
name = container_name(worker_id)

# Remove old container
subprocess.run(["docker", "rm", "-f", name], capture_output=True, timeout=15)
subprocess.run([CONTAINER_RUNTIME, "rm", "-f", name], capture_output=True, timeout=15)
time.sleep(1)

# Start fresh
Expand All @@ -209,7 +306,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict:
# Container logs
try:
result = subprocess.run(
["docker", "logs", name],
[CONTAINER_RUNTIME, "logs", name],
capture_output=True,
text=True,
timeout=15,
Expand All @@ -225,7 +322,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict:
try:
result = subprocess.run(
[
"docker",
CONTAINER_RUNTIME,
"exec",
name,
"bash",
Expand All @@ -241,7 +338,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict:
if not session_file:
result = subprocess.run(
[
"docker",
CONTAINER_RUNTIME,
"exec",
name,
"bash",
Expand All @@ -257,7 +354,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict:
if session_file:
dest = episode_dir / "pi_session.jsonl"
result = subprocess.run(
["docker", "cp", f"{name}:{session_file}", str(dest)],
[CONTAINER_RUNTIME, "cp", f"{name}:{session_file}", str(dest)],
capture_output=True,
timeout=30,
)
Expand Down Expand Up @@ -429,7 +526,7 @@ async def run_single_episode(
episode_result = {
"episode_id": episode_id,
"worker_id": worker_id,
"error": str(e),
"error": str(e) or type(e).__name__,
"elapsed_s": round(elapsed, 1),
"turns": turn,
}
Expand Down Expand Up @@ -564,24 +661,26 @@ async def collect(
logger.info("=" * 70)
logger.info("Episodes: %d (%d remaining)", num_episodes, remaining)
logger.info("Workers: %d", num_workers)
logger.info("Image: %s", DOCKER_IMAGE)
logger.info("Output: %s/", out)
logger.info("Per episode: ~45 min (2700s episode + overhead)")
logger.info(
"Estimated: ~%.0f min total", remaining / num_workers * 50
) # 45 min + 5 min overhead
"Per-episode soft ceiling: %d s (env timer is authoritative)",
EPISODE_TIMEOUT_S,
)
logger.info("=" * 70)

# Verify Docker image exists
result = subprocess.run(
["docker", "image", "inspect", DOCKER_IMAGE],
[CONTAINER_RUNTIME, "image", "inspect", DOCKER_IMAGE],
capture_output=True,
timeout=10,
)
if result.returncode != 0:
logger.error(
"Docker image %s not found. Build it first:\n"
" docker build -f docker/Dockerfile.pg -t %s .",
" docker build -f %s -t %s .",
DOCKER_IMAGE,
DOCKERFILE,
DOCKER_IMAGE,
)
sys.exit(1)
Expand Down Expand Up @@ -720,6 +819,32 @@ def main():
parser = argparse.ArgumentParser(
description="Collect DPO trajectories across parallel workers",
)
parser.add_argument(
"--task",
choices=sorted(TASK_PROFILES.keys()),
default="postgres",
help=(
"Task slug — selects the default image and timeouts. "
"Default is postgres (preserves the pre-multi-task hardcoded flow)."
),
)
parser.add_argument(
"--runtime",
choices=["docker", "podman"],
default=None,
help=(
"Container runtime CLI to use. Default: $CONTAINER_RUNTIME, "
"else auto-detect (prefer docker, fall back to podman)."
),
)
parser.add_argument(
"--image",
default=None,
help=(
"Override the Docker image tag (useful for local :smoke builds). "
"Defaults to the image registered for --task."
),
)
parser.add_argument(
"--episodes",
type=int,
Expand All @@ -730,7 +855,26 @@ def main():
"--workers",
type=int,
default=4,
help="Number of parallel Docker containers (default: 4)",
help="Number of parallel containers (default: 4)",
)
parser.add_argument(
"--base-port",
type=int,
default=8100,
help=(
"Host port for worker_id=0 (subsequent workers use base+1, "
"base+2, ...). Bump this to run multiple invocations of the "
"script in parallel without colliding (default: 8100)."
),
)
parser.add_argument(
"--container-prefix",
default="fswe-worker",
help=(
"Prefix used to name worker containers (full name is "
"<prefix>-<worker_id>). Use a unique value per parallel "
"invocation so runs don't fight over container names."
),
)
parser.add_argument(
"--output-dir",
Expand All @@ -752,16 +896,47 @@ def main():
"--episode-timeout",
type=int,
default=None,
help="Override episode timeout in seconds (default: 2700 = 45 min)",
help=(
"Override episode timeout in seconds. Default is the value "
"registered for --task in TASK_PROFILES."
),
)
args = parser.parse_args()

profile = TASK_PROFILES[args.task]
global DOCKER_IMAGE, DOCKERFILE, EPISODE_TIMEOUT_S, MESSAGE_TIMEOUT_S
global CONTAINER_RUNTIME, BASE_PORT, CONTAINER_PREFIX
DOCKER_IMAGE = args.image or profile["image"]
DOCKERFILE = profile["dockerfile"]
EPISODE_TIMEOUT_S = (
args.episode_timeout
if args.episode_timeout is not None
else profile["episode_timeout_s"]
)
MESSAGE_TIMEOUT_S = profile["message_timeout_s"]
BASE_PORT = args.base_port
CONTAINER_PREFIX = args.container_prefix

import os
import shutil

runtime = args.runtime or os.environ.get("CONTAINER_RUNTIME")
if not runtime:
runtime = "docker" if shutil.which("docker") else (
"podman" if shutil.which("podman") else "docker"
)
if not shutil.which(runtime):
logger.error(
"Container runtime %r not on PATH. Install it or pass --runtime.",
runtime,
)
sys.exit(1)
CONTAINER_RUNTIME = runtime
logger.info("Container runtime: %s", CONTAINER_RUNTIME)

if args.max_turns is not None:
global MAX_TURNS
MAX_TURNS = args.max_turns
if args.episode_timeout is not None:
global EPISODE_TIMEOUT_S
EPISODE_TIMEOUT_S = args.episode_timeout

asyncio.run(
collect(
Expand Down
Loading