From b44261dbc03dc7da74cb6cf4b20ac84d30a4be57 Mon Sep 17 00:00:00 2001 From: SourasishBasu <22051636@kiit.ac.in> Date: Sun, 26 Apr 2026 18:15:25 +0530 Subject: [PATCH] feat(scripts): generalize collect_trajectories.py to all task envs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds task profiles (postgres / notebook / type-checker / libexpat) so the trajectory collector can be pointed at any env via --task. Postgres defaults are preserved bit-for-bit (image, episode/message timeouts, container prefix, base port) so the existing PG flow is unaffected. New CLI flags: --task, --runtime (auto-detects docker/podman), --image, --base-port, --container-prefix. Also patches openenv-core's WebSocket client at runtime to disable the keepalive ping. The agent can run a single turn for >10 minutes without yielding back to the WS event loop, which starves pong replies and kills the connection mid-episode (1011 keepalive ping timeout). Patch is guarded by an assertion so it fails loudly if openenv-core's layout changes — TODO to remove once ping_interval/ping_timeout kwargs are exposed upstream. Verified end-to-end across all four reward shapes; trajectories at artifacts/trajectories-2026-04-26.zip locally. Co-Authored-By: Claude Opus 4.7 (1M context) --- .gitignore | 1 + scripts/collect_trajectories.py | 231 ++++++++++++++++++++++++++++---- 2 files changed, 204 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index 030af7b..3665e25 100644 --- a/.gitignore +++ b/.gitignore @@ -241,6 +241,7 @@ pi-mono/ notes/ trajectories*/ +artifacts/ .hf_cache/ .hf_datasets_cache/ diff --git a/scripts/collect_trajectories.py b/scripts/collect_trajectories.py index c1dd523..7182a43 100644 --- a/scripts/collect_trajectories.py +++ b/scripts/collect_trajectories.py @@ -2,19 +2,34 @@ """ Collect DPO trajectories by running N episodes across W parallel workers. -Spins up W Docker containers (one per worker), then round-robins episodes +Spins up W containers (one per worker), then round-robins episodes across them. Each episode produces: - result.json (episode metadata + reward) - pi_session.jsonl (full agent trajectory) - container_logs.txt (server-side scoring logs) +Default behaviour preserves the original postgres-only flow: with no flags +the script is identical to its pre-multi-task form (20 episodes, 4 workers, +600 s WS timeout, ``frontier-swe-pg:latest`` image). Other tasks +(``notebook``, ``type-checker``, ``libexpat-to-x86asm``) are opt-in via +``--task``. Reward shapes (ratio / reward_json / reward_json_score) are +L1-rubric concerns; by the time scores reach ``frozen_scores`` they are +normalised [0, 1] floats, and the ``EpisodeRubric`` weights (plan / subtask +/ completion / tool) come from ``TaskConfig`` defaults that no task +currently overrides — so the offline reward backfill works for every task. + Usage: - # 20 episodes across 4 parallel workers (default) + # Original postgres flow (unchanged): 20 episodes, 4 workers, 45-min ep PYTHONPATH=. uv run python scripts/collect_trajectories.py - # Custom settings + # Quick reward-shape smoke against another task + PYTHONPATH=. uv run python scripts/collect_trajectories.py \ + --task type-checker --episodes 2 --workers 1 + + # Override the image tag (useful for local :smoke builds) PYTHONPATH=. uv run python scripts/collect_trajectories.py \ - --episodes 20 --workers 4 --output-dir trajectories/ + --task type-checker \ + --image frontier-swe-dependent-type-checker:smoke # Resume from a previous run (skips existing episodes) PYTHONPATH=. uv run python scripts/collect_trajectories.py --resume @@ -38,6 +53,41 @@ from frontier_swe_env.client import FrontierSweEnv # noqa: E402 from frontier_swe_env.models import FrontierSweAction # noqa: E402 + +def _disable_openenv_ws_keepalive() -> None: + """Disable the websockets keepalive ping for the EnvClient. + + The notebook env can run a single agent turn for >10 minutes (e.g. 96 + bash calls in a row) without yielding back to the WS event loop. The + `websockets` library's auto-ping defaults (interval=20s, timeout=20s) + fire during that gap, the env never pongs in time, and the connection + dies with `ConnectionClosedError: keepalive ping timeout` mid-episode. + The OpenEnv pin doesn't expose ping_interval/ping_timeout kwargs yet, + so we rebind the `ws_connect` symbol inside `openenv.core.env_client` + to a wrapper that injects `ping_interval=None`. + + TODO: remove once openenv-core exposes ping kwargs upstream and we bump + the pin in pyproject.toml. + """ + from openenv.core import env_client + + assert hasattr(env_client, "ws_connect"), ( + "openenv-core layout changed: env_client no longer exposes " + "ws_connect. Remove this monkey-patch or update it." + ) + + _orig_connect = env_client.ws_connect + + def _connect_no_keepalive(*args, **kwargs): + kwargs.setdefault("ping_interval", None) + return _orig_connect(*args, **kwargs) + + env_client.ws_connect = _connect_no_keepalive + + +_disable_openenv_ws_keepalive() + + logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", @@ -52,17 +102,62 @@ # Constants -DOCKER_IMAGE = "frontier-swe-pg:latest" +# Per-task profile. Postgres values match the original hardcoded constants +# verbatim (2700 s episode, 600 s WS message timeout, frontier-swe-pg:latest +# image) so the default flow with no flags reproduces the pre-multi-task +# behaviour bit-for-bit. Other tasks need a longer WS timeout because pi +# explores deeper before submitting (observed up to 900 s on type-checker). +TASK_PROFILES: dict[str, dict] = { + "postgres": { + "image": "frontier-swe-pg:latest", + "episode_timeout_s": 2700, + "message_timeout_s": 600.0, + "dockerfile": "docker/Dockerfile.pg", + }, + "type-checker": { + "image": "frontier-swe-dependent-type-checker:latest", + "episode_timeout_s": 1800, + "message_timeout_s": 900.0, + "dockerfile": "docker/Dockerfile.dependent-type-checker", + }, + "notebook": { + "image": "frontier-swe-notebook:latest", + "episode_timeout_s": 1800, + "message_timeout_s": 900.0, + "dockerfile": "docker/Dockerfile.notebook", + }, + "libexpat-to-x86asm": { + "image": "frontier-swe-libexpat-to-x86asm:latest", + "episode_timeout_s": 1800, + "message_timeout_s": 900.0, + "dockerfile": "docker/Dockerfile.libexpat-to-x86asm", + }, +} + +# Mutated by main() once CLI args are parsed. Defaults below match the +# original postgres-hardcoded values so an import-side smoke that doesn't +# call main() still sees PG semantics. +DOCKER_IMAGE: str = TASK_PROFILES["postgres"]["image"] +DOCKERFILE: str = TASK_PROFILES["postgres"]["dockerfile"] +EPISODE_TIMEOUT_S: int = TASK_PROFILES["postgres"]["episode_timeout_s"] +MESSAGE_TIMEOUT_S: float = TASK_PROFILES["postgres"]["message_timeout_s"] + +# Mutated by main() — see --container-prefix / --base-port. Defaults are +# the original PG-flow values so existing call sites stay byte-identical. CONTAINER_PREFIX = "fswe-worker" BASE_PORT = 8100 # workers use ports 8100, 8101, 8102, ... + ENV_FILE = ".env" MAX_TURNS = 20 -MESSAGE_TIMEOUT_S = 600.0 -EPISODE_TIMEOUT_S = 2700 # 45 min (must match task_config) -CONTAINER_STARTUP_WAIT = 10 # seconds to wait after docker run +CONTAINER_STARTUP_WAIT = 10 # seconds to wait after container run HEALTH_CHECK_RETRIES = 30 HEALTH_CHECK_INTERVAL = 2 +# Container runtime — `docker` or `podman`. Set via --runtime / $CONTAINER_RUNTIME +# or auto-detected at startup. The two CLIs share the verbs we use +# (run, rm, exec, logs, cp, image inspect), so the same code path works. +CONTAINER_RUNTIME: str = "docker" + # Offline reward computation @@ -80,7 +175,9 @@ def _compute_reward_offline(result: dict) -> float: plan_count = max(len(plan), 1) if plan else 1 - # Weights (match EpisodeRubric / pg_training_config) + # Weights match EpisodeRubric defaults from TaskConfig. No task in the + # registry currently overrides these, so the same formula applies to + # postgres, notebook, type-checker and libexpat-to-x86asm. plan_weight = 0.25 subtask_weight = 0.60 completion_weight = 0.10 @@ -119,13 +216,13 @@ def start_container(worker_id: int) -> bool: # Remove any existing container with this name subprocess.run( - ["docker", "rm", "-f", name], + [CONTAINER_RUNTIME, "rm", "-f", name], capture_output=True, timeout=10, ) cmd = [ - "docker", + CONTAINER_RUNTIME, "run", "-d", "--name", @@ -175,7 +272,7 @@ def wait_for_healthy(worker_id: int) -> bool: def stop_container(worker_id: int) -> None: """Stop and remove a worker container.""" name = container_name(worker_id) - subprocess.run(["docker", "rm", "-f", name], capture_output=True, timeout=15) + subprocess.run([CONTAINER_RUNTIME, "rm", "-f", name], capture_output=True, timeout=15) logger.info("Stopped container %s", name) @@ -189,7 +286,7 @@ def reset_container(worker_id: int) -> bool: name = container_name(worker_id) # Remove old container - subprocess.run(["docker", "rm", "-f", name], capture_output=True, timeout=15) + subprocess.run([CONTAINER_RUNTIME, "rm", "-f", name], capture_output=True, timeout=15) time.sleep(1) # Start fresh @@ -209,7 +306,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict: # Container logs try: result = subprocess.run( - ["docker", "logs", name], + [CONTAINER_RUNTIME, "logs", name], capture_output=True, text=True, timeout=15, @@ -225,7 +322,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict: try: result = subprocess.run( [ - "docker", + CONTAINER_RUNTIME, "exec", name, "bash", @@ -241,7 +338,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict: if not session_file: result = subprocess.run( [ - "docker", + CONTAINER_RUNTIME, "exec", name, "bash", @@ -257,7 +354,7 @@ def extract_artifacts(worker_id: int, episode_dir: Path) -> dict: if session_file: dest = episode_dir / "pi_session.jsonl" result = subprocess.run( - ["docker", "cp", f"{name}:{session_file}", str(dest)], + [CONTAINER_RUNTIME, "cp", f"{name}:{session_file}", str(dest)], capture_output=True, timeout=30, ) @@ -429,7 +526,7 @@ async def run_single_episode( episode_result = { "episode_id": episode_id, "worker_id": worker_id, - "error": str(e), + "error": str(e) or type(e).__name__, "elapsed_s": round(elapsed, 1), "turns": turn, } @@ -564,24 +661,26 @@ async def collect( logger.info("=" * 70) logger.info("Episodes: %d (%d remaining)", num_episodes, remaining) logger.info("Workers: %d", num_workers) + logger.info("Image: %s", DOCKER_IMAGE) logger.info("Output: %s/", out) - logger.info("Per episode: ~45 min (2700s episode + overhead)") logger.info( - "Estimated: ~%.0f min total", remaining / num_workers * 50 - ) # 45 min + 5 min overhead + "Per-episode soft ceiling: %d s (env timer is authoritative)", + EPISODE_TIMEOUT_S, + ) logger.info("=" * 70) # Verify Docker image exists result = subprocess.run( - ["docker", "image", "inspect", DOCKER_IMAGE], + [CONTAINER_RUNTIME, "image", "inspect", DOCKER_IMAGE], capture_output=True, timeout=10, ) if result.returncode != 0: logger.error( "Docker image %s not found. Build it first:\n" - " docker build -f docker/Dockerfile.pg -t %s .", + " docker build -f %s -t %s .", DOCKER_IMAGE, + DOCKERFILE, DOCKER_IMAGE, ) sys.exit(1) @@ -720,6 +819,32 @@ def main(): parser = argparse.ArgumentParser( description="Collect DPO trajectories across parallel workers", ) + parser.add_argument( + "--task", + choices=sorted(TASK_PROFILES.keys()), + default="postgres", + help=( + "Task slug — selects the default image and timeouts. " + "Default is postgres (preserves the pre-multi-task hardcoded flow)." + ), + ) + parser.add_argument( + "--runtime", + choices=["docker", "podman"], + default=None, + help=( + "Container runtime CLI to use. Default: $CONTAINER_RUNTIME, " + "else auto-detect (prefer docker, fall back to podman)." + ), + ) + parser.add_argument( + "--image", + default=None, + help=( + "Override the Docker image tag (useful for local :smoke builds). " + "Defaults to the image registered for --task." + ), + ) parser.add_argument( "--episodes", type=int, @@ -730,7 +855,26 @@ def main(): "--workers", type=int, default=4, - help="Number of parallel Docker containers (default: 4)", + help="Number of parallel containers (default: 4)", + ) + parser.add_argument( + "--base-port", + type=int, + default=8100, + help=( + "Host port for worker_id=0 (subsequent workers use base+1, " + "base+2, ...). Bump this to run multiple invocations of the " + "script in parallel without colliding (default: 8100)." + ), + ) + parser.add_argument( + "--container-prefix", + default="fswe-worker", + help=( + "Prefix used to name worker containers (full name is " + "-). Use a unique value per parallel " + "invocation so runs don't fight over container names." + ), ) parser.add_argument( "--output-dir", @@ -752,16 +896,47 @@ def main(): "--episode-timeout", type=int, default=None, - help="Override episode timeout in seconds (default: 2700 = 45 min)", + help=( + "Override episode timeout in seconds. Default is the value " + "registered for --task in TASK_PROFILES." + ), ) args = parser.parse_args() + profile = TASK_PROFILES[args.task] + global DOCKER_IMAGE, DOCKERFILE, EPISODE_TIMEOUT_S, MESSAGE_TIMEOUT_S + global CONTAINER_RUNTIME, BASE_PORT, CONTAINER_PREFIX + DOCKER_IMAGE = args.image or profile["image"] + DOCKERFILE = profile["dockerfile"] + EPISODE_TIMEOUT_S = ( + args.episode_timeout + if args.episode_timeout is not None + else profile["episode_timeout_s"] + ) + MESSAGE_TIMEOUT_S = profile["message_timeout_s"] + BASE_PORT = args.base_port + CONTAINER_PREFIX = args.container_prefix + + import os + import shutil + + runtime = args.runtime or os.environ.get("CONTAINER_RUNTIME") + if not runtime: + runtime = "docker" if shutil.which("docker") else ( + "podman" if shutil.which("podman") else "docker" + ) + if not shutil.which(runtime): + logger.error( + "Container runtime %r not on PATH. Install it or pass --runtime.", + runtime, + ) + sys.exit(1) + CONTAINER_RUNTIME = runtime + logger.info("Container runtime: %s", CONTAINER_RUNTIME) + if args.max_turns is not None: global MAX_TURNS MAX_TURNS = args.max_turns - if args.episode_timeout is not None: - global EPISODE_TIMEOUT_S - EPISODE_TIMEOUT_S = args.episode_timeout asyncio.run( collect(