diff --git a/.gitignore b/.gitignore index 7eaa4976e6..9f207bdaff 100644 --- a/.gitignore +++ b/.gitignore @@ -201,3 +201,6 @@ debug_I2_zero_band outputs/ third_party/ + +# Runtime flock used to serialize `uv sync` across concurrent SLURM jobs. +.uv-sync.lock diff --git a/docs/entrypoints.md b/docs/entrypoints.md index d79a82faed..574ac25b71 100644 --- a/docs/entrypoints.md +++ b/docs/entrypoints.md @@ -65,3 +65,16 @@ uv run torchrun --nproc-per-node 8 src/prime_rl/trainer/sft/train.py ... ``` For more details on multi-node deployment options, see the [deployment](deployment.md) documentation and see the [examples](examples) for concrete training configurations. To see all available configuration options, run `uv run sft --help`. + +## Sweep + +The `sweep` entrypoint materializes and launches hyperparameter studies for `rl` or `sft` target configs. It supports grid, random, and Optuna strategies, with local execution, SLURM submission, and shared-trainer LoRA sweeps for RL. + +Each trial gets a stable directory under the study output directory with generated `overrides.toml`, fully resolved `resolved.toml`, `command.txt`, and `status.json`. The launcher validates every target trial config before launching. Objective tracking reads sweep sidecar metrics when configured, and SLURM sweeps reuse the target `rl` or `sft` config's existing `[slurm]` support. + +```bash +uv run sweep @ examples/sweep/grid_local.toml +uv run sweep @ examples/sweep/grid_local.toml --dry-run +``` + +For details on strategies, schedulers, artifacts, resume, and examples, see the [sweeps](sweeps.md) documentation. To see all sweep configuration options, run `uv run sweep --help`. diff --git a/docs/index.md b/docs/index.md index aa76871f8c..53f33d692d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,6 +4,7 @@ This directory maintains the documentation for PRIME-RL. 
It is organized into th - [**Entrypoints**](entrypoints.md) - Overview of the main components (orchestrator, trainer, inference) and how to run SFT, RL, and evals - [**Configs**](configs.md) - Configuration system using TOML files, CLI arguments, and environment variables +- [**Sweeps**](sweeps.md) - Hyperparameter studies with grid, random, Optuna, local, SLURM, and shared-trainer LoRA schedulers - [**Environments**](environments.md) - Installing and using verifiers environments from the Environments Hub - [**Async Training**](async.md) - Understanding asynchronous off-policy training and step semantics - [**Logging**](logging.md) - Logging with loguru, torchrun, and Weights & Biases @@ -13,4 +14,4 @@ This directory maintains the documentation for PRIME-RL. It is organized into th - [**Benchmarking**](benchmarking.md) - Performance benchmarking and throughput measurement - [**Deployment**](deployment.md) - Training deployment on single-GPU, multi-GPU, and multi-node clusters - [**Kubernetes**](kubernetes.md) - Deploying PRIME-RL on Kubernetes with Helm -- [**Troubleshooting**](troubleshooting.md) - Common issues and their solutions \ No newline at end of file +- [**Troubleshooting**](troubleshooting.md) - Common issues and their solutions diff --git a/docs/mint.json b/docs/mint.json index 216fbe4fa4..b66f5f5c39 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -7,6 +7,7 @@ "index", "entrypoints", "configs", + "sweeps", "environments", "async", "logging", @@ -19,4 +20,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/docs/sweeps.md b/docs/sweeps.md new file mode 100644 index 0000000000..f254f131b8 --- /dev/null +++ b/docs/sweeps.md @@ -0,0 +1,330 @@ +# Sweeps + +The `sweep` entrypoint materializes and launches hyperparameter studies for `rl` +or `sft` configs. It supports grid search, seeded random search, Optuna ask/tell +optimization, local execution, SLURM submission, and shared-trainer LoRA sweeps +for RL. 
+ +## Quick Start + +Run a sweep from a sweep TOML: + +```bash +uv run sweep @ examples/sweep/grid_local.toml +``` + +Validate and materialize trial artifacts without launching anything: + +```bash +uv run sweep @ examples/sweep/grid_local.toml --dry-run +``` + +Optuna support is optional. Install the HPO extra before using +`strategy.type = "optuna"`: + +```bash +uv sync --extra hpo +``` + +## Sweep Config + +A sweep config names the target entrypoint, base target configs, study output +directory, search strategy, scheduler, optional objective, and parameter space. + +```toml +name = "reverse-text-lr" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml"] +output_dir = "outputs/studies/reverse-text-lr" + +[strategy] +type = "grid" + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +values = [1e-6, 3e-6, 1e-5] +``` + +Parameter keys are dotted paths into the target `rl` or `sft` config. The sweep +controller writes each trial's generated overrides to `overrides.toml`, validates +the resolved target config, and launches the target entrypoint with the original +base files plus those overrides. + +## Strategies + +### Grid + +Grid search exhaustively evaluates every combination of `values` entries. Grid +parameters must use explicit choices: + +```toml +[strategy] +type = "grid" + +[parameters."trainer.optim.lr"] +values = [1e-6, 3e-6, 1e-5] + +[parameters."orchestrator.train.sampling.temperature"] +values = [0.7, 1.0] +``` + +### Random + +Random search draws independent samples from the declared distributions. Set a +seed when you need reproducible trial IDs and resume behavior: + +```toml +[strategy] +type = "random" +num_trials = 8 +seed = 42 + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 +``` + +Supported parameter distributions are `choice`, `uniform`, `log_uniform`, and +`int_uniform`. 
+ +### Optuna + +Optuna proposes one trial at a time for local sweeps, or one wave at a time for +`multi_run_lora` sweeps. Optuna requires an `[objective]`. + +```toml +[strategy] +type = "optuna" +num_trials = 12 +seed = 42 +sampler = "tpe" + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" +``` + +Use `strategy.storage` to persist the Optuna study across resume: + +```toml +[strategy] +type = "optuna" +num_trials = 12 +storage = "sqlite:///outputs/studies/reverse-text-optuna/optuna.db" +study_name = "reverse-text-optuna" +``` + +Optuna pruners read intermediate metrics from each trial's local +`metrics.jsonl` sidecar: + +```toml +[strategy.pruner] +type = "median" +n_startup_trials = 2 +n_warmup_steps = 1 +interval_steps = 1 +``` + +Supported pruners are `none`, `median`, `asha`, and `hyperband`. + +## Schedulers + +### Local + +The local scheduler runs target commands as subprocesses on the current machine. +With `max_parallel = 1`, the controller runs trials sequentially. + +```toml +[scheduler] +type = "local" +max_parallel = 1 +``` + +Parallel local sweeps require explicit disjoint GPU groups. Each worker gets a +`CUDA_VISIBLE_DEVICES` value from one group: + +```toml +[scheduler] +type = "local" +max_parallel = 2 + +[scheduler.gpu_assignment] +mode = "static" +visible_devices = [[0, 1], [2, 3]] +``` + +The validator rejects `max_parallel > 1` without GPU assignment so parallel +workers cannot silently colocate trainer and inference stacks on the same GPUs. + +### SLURM + +The SLURM scheduler submits one target job per trial through the target +entrypoint's existing `[slurm]` support, then exits. 
+ +```toml +[scheduler] +type = "slurm" +``` + +The base target config must include a valid `[slurm]` block, usually by composing +the normal run config with a SLURM overlay: + +```toml +base = [ + "examples/reverse_text/rl.toml", + "examples/reverse_text/slurm_rl.toml", +] +``` + +SLURM sweeps are asynchronous by default, so early stopping and Optuna are +only supported with `scheduler.type = "slurm"` when `synchronous = true` is set. + +### Shared-Trainer LoRA + +The `multi_run_lora` scheduler is RL-only. It launches one shared trainer and +one orchestrator per trial. Each orchestrator trains a separate LoRA adapter slot +through the trainer's `MultiRunManager`. + +```toml +entrypoint = "rl" +base = ["examples/reverse_text/rl_multi_run_lora_disagg.toml"] + +[scheduler] +type = "multi_run_lora" +max_concurrent_runs = 3 +shared = ["examples/reverse_text/rl_multi_run_lora_disagg.toml"] +``` + +Requirements: + +- The shared config must enable `trainer.model.lora`. +- `trainer.max_concurrent_runs` must be at least + `scheduler.max_concurrent_runs`. +- Only per-run orchestrator fields may vary. Trainer, model, deployment, and + inference fields are shared by the wave and cannot be swept inside one + `multi_run_lora` launch. + +Static grid/random `multi_run_lora` sweeps launch the whole wave at once. +Optuna `multi_run_lora` sweeps run in waves of `scheduler.max_concurrent_runs`; +the controller tells Optuna each wave's results before asking for the next wave. + +## Objectives and Metrics + +Set `[objective]` when the sweep should rank trials, perform early stopping, or +run Optuna: + +```toml +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" +``` + +During sweep runs, the launcher sets `PRIME_RL_SWEEP_METRICS_JSONL` for each +trial subprocess. `FileMonitor` writes step-indexed metrics to +`<trial_dir>/metrics.jsonl`. The controller reads the latest valid-step row for +final objective attribution and polls the same file for Optuna pruning. 
+ +If the sidecar has no usable objective value, the controller falls back to +legacy `run-*/final_summary.json` files. A clean process exit without a finite +objective is recorded as a failed trial with `failure_stage = "objective"`. + +## Early Stopping + +Early stopping applies after completed trials. It can stop future local trials +or future Optuna `multi_run_lora` waves, but it does not cancel already-running +siblings. + +```toml +[early_stopping] +type = "patience" +patience = 3 +min_trials = 5 +``` + +Threshold stopping is also available: + +```toml +[early_stopping] +type = "threshold" +threshold = 0.7 +min_trials = 2 +``` + +Early stopping requires an objective and, with the SLURM scheduler, +`synchronous = true`. Static `multi_run_lora` launches the whole wave at once, so use +Optuna `multi_run_lora` when you need wave-by-wave stopping. + +## Artifacts and Resume + +Each standard trial gets a stable directory under the study output directory: + +```text +outputs/studies/reverse-text-lr/ + study.toml + manifest.json + trials/ + 0000-a1b2c3d4/ + overrides.toml + resolved.toml + command.txt + status.json + run/ + metrics.jsonl +``` + +The manifest records trial metadata, commands, resolved-config checksums, base +file checksums, git metadata, and objective summaries. Trial IDs have the form +`<index>-<hash>`, where the hash is derived from the flat parameter override +dict. + +Use `resume = true` to reuse completed grid/random trials. Resume fails closed +if the previous manifest is malformed, the target entrypoint changes, the +objective changes, parameter order changes, base file checksums drift, or a +terminal status cannot be trusted. + +Optuna resume requires persistent `strategy.storage`. On resume, the controller +reconciles leftover RUNNING Optuna trials against `status.json` so Optuna +storage, the manifest, and trial artifacts agree. + +`multi_run_lora` resume against a still-running shared trainer is not supported. 
+ +## Failure Handling + +Two fields control failure behavior: + +```toml +continue_on_failure = true +retry_budget = 1 +``` + +`retry_budget` applies to launch failures and failed trial processes where the +scheduler can safely retry. If `continue_on_failure = false`, the controller +stops launching new work after the first failed materialization, runtime failure, +or missing objective, then exits non-zero after writing status and manifest +updates. + +For `multi_run_lora`, `rl-multi-run` writes each orchestrator's return code to +`<run_dir>/control/exit_code`. The controller reconciles per-trial state from +those files instead of marking every trial failed from one aggregate launcher +return code. + +## Examples + +- `examples/sweep/grid_local.toml` — grid search with local sequential execution. +- `examples/sweep/random_local.toml` — seeded random search. +- `examples/sweep/parallel_local.toml` — local parallel trials with explicit GPU groups. +- `examples/sweep/optuna_local.toml` — Optuna TPE ask/tell loop. +- `examples/sweep/optuna_pruner_median.toml` — Optuna median pruning. +- `examples/sweep/optuna_pruner_asha.toml` — Optuna ASHA pruning. +- `examples/sweep/optuna_pruner_hyperband.toml` — Optuna Hyperband pruning. +- `examples/sweep/slurm.toml` — one target SLURM job per trial. +- `examples/sweep/grid_multi_run_lora.toml` — static shared-trainer LoRA sweep. +- `examples/sweep/optuna_multi_run_lora.toml` — Optuna waves over shared-trainer LoRA runs. diff --git a/examples/sweep/early_stop_patience.toml b/examples/sweep/early_stop_patience.toml new file mode 100644 index 0000000000..c31fd9ff76 --- /dev/null +++ b/examples/sweep/early_stop_patience.toml @@ -0,0 +1,37 @@ +# Early stopping (patience): halt the sweep after N consecutive +# completed trials fail to improve on the best-so-far. +# +# Demonstrates: [early_stopping] type="patience". After patience=1 +# consecutive trials with no improvement (and min_trials trials +# completed), the sweep halts new submissions. 
In-flight trials are +# not killed. +# +# Order: trial 0 sets the baseline reward, trial 1 has a bad LR that +# fails to improve, so steps_without_improvement reaches patience=1 +# and the sweep halts before trial 2 launches. +# +# Run: +# uv run sweep @ examples/sweep/early_stop_patience.toml +name = "reverse-text-early-stop-patience" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-early-stop-patience" + +[strategy] +type = "grid" + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[early_stopping] +type = "patience" +patience = 1 +min_trials = 1 + +[parameters."trainer.optim.lr"] +values = [3e-6, 1e-3, 1e-5] diff --git a/examples/sweep/early_stop_threshold.toml b/examples/sweep/early_stop_threshold.toml new file mode 100644 index 0000000000..72df130121 --- /dev/null +++ b/examples/sweep/early_stop_threshold.toml @@ -0,0 +1,37 @@ +# Early stopping (threshold): halt the sweep when a trial's objective +# crosses to the wrong side of a configured floor. +# +# Demonstrates: [early_stopping] type="threshold". With direction=maximize +# and threshold=0.10, the sweep halts after any completed trial whose +# objective is below 0.10 (and at least min_trials trials have completed). +# In-flight trials are not killed; new trials are not submitted. +# +# Order matters for grid expansion: trial 0 should pass the threshold, +# trial 1 has a known-bad LR that collapses the reward, so the controller +# halts after trial 1 and trial 2 stays state="pending" in the manifest. 
+# +# Run: +# uv run sweep @ examples/sweep/early_stop_threshold.toml +name = "reverse-text-early-stop-threshold" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-early-stop-threshold" + +[strategy] +type = "grid" + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[early_stopping] +type = "threshold" +threshold = 0.10 +min_trials = 1 + +[parameters."trainer.optim.lr"] +values = [3e-6, 1e-3, 1e-5] diff --git a/examples/sweep/early_stop_threshold_slurm.toml b/examples/sweep/early_stop_threshold_slurm.toml new file mode 100644 index 0000000000..8d1f9b64dd --- /dev/null +++ b/examples/sweep/early_stop_threshold_slurm.toml @@ -0,0 +1,41 @@ +# Early stopping (threshold) over the synchronous SLURM scheduler. +# +# Demonstrates: pairing [early_stopping] with scheduler.type = "slurm" + +# synchronous = true. The controller observes each trial's objective via +# sbatch --wait + scontrol/metrics fallback, then checks the tracker +# between trials. When a completed trial's objective falls on the wrong +# side of the threshold, the controller halts: no further sbatch +# submissions, and the in-flight trial (if any) is allowed to finish. +# +# Order matters for grid expansion: trial 0 should pass the threshold, +# trial 1 has a known-bad LR that collapses the reward, so the controller +# halts after trial 1 and trial 2 stays state="pending" in the manifest. +# +# Requires: same SLURM prerequisites as examples/sweep/slurm.toml +# (sbatch on PATH, reachable cluster, slurm_base.toml edited). 
+# +# Run: +# uv run sweep @ examples/sweep/early_stop_threshold_slurm.toml +name = "reverse-text-early-stop-threshold-slurm" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-early-stop-threshold-slurm" + +[strategy] +type = "grid" + +[scheduler] +type = "slurm" +synchronous = true + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[early_stopping] +type = "threshold" +threshold = 0.10 +min_trials = 1 + +[parameters."trainer.optim.lr"] +values = [3e-6, 1e-3, 1e-5] diff --git a/examples/sweep/grid_local.toml b/examples/sweep/grid_local.toml new file mode 100644 index 0000000000..2b2bc2e37f --- /dev/null +++ b/examples/sweep/grid_local.toml @@ -0,0 +1,27 @@ +# Grid search over learning rate for the reverse-text RL example. +# +# Demonstrates: basic sweep entrypoint, grid strategy, local scheduler. +# Each trial spawns its own `uv run rl @ examples/sweep/rl.toml @ +# <trial_dir>/overrides.toml` subprocess and writes artifacts under +# `output_dir/trials/<trial_id>/`. +# +# Run with: +# uv run sweep @ examples/sweep/grid_local.toml +name = "reverse-text-lr" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-lr" + +[strategy] +type = "grid" + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +values = [1e-6, 3e-6, 1e-5] diff --git a/examples/sweep/grid_local_disagg.toml b/examples/sweep/grid_local_disagg.toml new file mode 100644 index 0000000000..8eb63364b6 --- /dev/null +++ b/examples/sweep/grid_local_disagg.toml @@ -0,0 +1,33 @@ +# Grid search against a disaggregated inference server. +# +# Demonstrates: running a sweep when inference lives on a separate rig +# (no [inference] block in the base config; orchestrator points at a +# remote vLLM via client.base_url). 
+# +# Prereq: +# - vLLM running on the inference rig serving the same model as the +# trainer-side base (PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT). +# - examples/sweep/rl_disagg.toml points client.base_url at it. +# - Shared NAS mount visible from both rigs at the same absolute path +# so weight broadcasts are reachable. +# +# Run: +# uv run sweep @ examples/sweep/grid_local_disagg.toml +name = "reverse-text-disagg-lr" +entrypoint = "rl" +base = ["examples/sweep/rl_disagg.toml"] +output_dir = "outputs/studies/reverse-text-disagg-lr" + +[strategy] +type = "grid" + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +values = [1e-6, 3e-6] diff --git a/examples/sweep/grid_multi_run_lora.toml b/examples/sweep/grid_multi_run_lora.toml new file mode 100644 index 0000000000..568065753f --- /dev/null +++ b/examples/sweep/grid_multi_run_lora.toml @@ -0,0 +1,34 @@ +# Grid sweep using the multi_run_lora scheduler. +# +# Demonstrates: shared-trainer LoRA sweeps. rl-multi-run launches one +# trainer + one inference + N orchestrators against pre-materialized +# run dirs. The trainer's MultiRunManager rotates through each run's +# LoRA slot; per-trial overrides are restricted to the orchestrator-safe +# allowlist (orchestrator.optim.*, orchestrator.model.lora.*, sampling, +# env, batch/buffer/eval). +# +# Requires: +# - The shared base must be a LoRA-enabled RLConfig with +# trainer.max_concurrent_runs >= scheduler.max_concurrent_runs. +# - LoRA-enabled inference server (either local via [inference] in the +# shared base, or external via orchestrator.client.base_url and +# env VLLM_ALLOW_RUNTIME_LORA_UPDATING=True on the inference rig). 
+# +# Run: +# uv run sweep @ examples/sweep/grid_multi_run_lora.toml +name = "reverse-text-multi-run-lora" +entrypoint = "rl" +base = ["examples/sweep/rl_multi_run_lora_disagg.toml"] +output_dir = "outputs/studies/reverse-text-multi-run-lora" + +[scheduler] +type = "multi_run_lora" +max_concurrent_runs = 3 +shared = ["examples/sweep/rl_multi_run_lora_disagg.toml"] + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."orchestrator.optim.lr"] +values = [1e-6, 3e-6, 1e-5] diff --git a/examples/sweep/grid_multi_run_lora_local.toml b/examples/sweep/grid_multi_run_lora_local.toml new file mode 100644 index 0000000000..20887ea033 --- /dev/null +++ b/examples/sweep/grid_multi_run_lora_local.toml @@ -0,0 +1,25 @@ +# Multi-run-LoRA sweep, self-contained on a single multi-GPU machine. +# +# Same workflow as grid_multi_run_lora.toml but uses a local inference +# server (rl_multi_run_lora.toml) instead of pointing at a remote rig. +# Useful when you have ≥2 GPUs on one box and don't want to set up +# disagg. +# +# Run: +# VLLM_ALLOW_RUNTIME_LORA_UPDATING=True uv run sweep @ examples/sweep/grid_multi_run_lora_local.toml +name = "reverse-text-multi-run-lora-local" +entrypoint = "rl" +base = ["examples/sweep/rl_multi_run_lora.toml"] +output_dir = "outputs/studies/reverse-text-multi-run-lora-local" + +[scheduler] +type = "multi_run_lora" +max_concurrent_runs = 3 +shared = ["examples/sweep/rl_multi_run_lora.toml"] + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."orchestrator.optim.lr"] +values = [1e-6, 3e-6, 1e-5] diff --git a/examples/sweep/optuna_local.toml b/examples/sweep/optuna_local.toml new file mode 100644 index 0000000000..4630ef0e27 --- /dev/null +++ b/examples/sweep/optuna_local.toml @@ -0,0 +1,33 @@ +# Optuna (TPE) adaptive sweep, single trial at a time. +# +# Demonstrates: optuna strategy with TPE sampler over a log-uniform +# parameter. 
The controller asks Optuna for one trial, materializes +# it, runs it through the local scheduler, and reports the objective +# back before asking for the next trial. With seed pinned, the +# proposed LR sequence is reproducible across runs. +# +# Run: +# uv run sweep @ examples/sweep/optuna_local.toml +name = "reverse-text-optuna" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-optuna" + +[strategy] +type = "optuna" +num_trials = 6 +seed = 42 +sampler = "tpe" + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_multi_run_lora.toml b/examples/sweep/optuna_multi_run_lora.toml new file mode 100644 index 0000000000..5b88603b61 --- /dev/null +++ b/examples/sweep/optuna_multi_run_lora.toml @@ -0,0 +1,37 @@ +# Optuna sweep using the multi_run_lora wave driver. +# +# Demonstrates: Optuna + multi_run_lora. num_trials trials are split +# into waves of size max_concurrent_runs; each wave spawns one +# rl-multi-run invocation with that many orchestrators running against +# the same shared trainer. The controller tells Optuna each wave's +# results before proposing the next wave, so wave N+1 benefits from +# wave N's history. A non-trivial pruner (median/asha/hyperband) can +# also prune within a wave by polling each run's metrics.jsonl — +# see optuna_pruner_median.toml for that variant. 
+# +# Run: +# uv run sweep @ examples/sweep/optuna_multi_run_lora.toml +name = "reverse-text-optuna-multi-run-lora" +entrypoint = "rl" +base = ["examples/sweep/rl_multi_run_lora_disagg.toml"] +output_dir = "outputs/studies/reverse-text-optuna-multi-run-lora" + +[strategy] +type = "optuna" +num_trials = 6 +seed = 42 +sampler = "tpe" + +[scheduler] +type = "multi_run_lora" +max_concurrent_runs = 3 +shared = ["examples/sweep/rl_multi_run_lora_disagg.toml"] + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."orchestrator.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_multi_run_lora_local.toml b/examples/sweep/optuna_multi_run_lora_local.toml new file mode 100644 index 0000000000..ce3a80f045 --- /dev/null +++ b/examples/sweep/optuna_multi_run_lora_local.toml @@ -0,0 +1,32 @@ +# Optuna + multi_run_lora wave driver, self-contained on a single +# multi-GPU machine. +# +# Same as optuna_multi_run_lora.toml but uses a local inference server +# (rl_multi_run_lora.toml) instead of pointing at a remote rig. 
+# +# Run: +# VLLM_ALLOW_RUNTIME_LORA_UPDATING=True uv run sweep @ examples/sweep/optuna_multi_run_lora_local.toml +name = "reverse-text-optuna-multi-run-lora-local" +entrypoint = "rl" +base = ["examples/sweep/rl_multi_run_lora.toml"] +output_dir = "outputs/studies/reverse-text-optuna-multi-run-lora-local" + +[strategy] +type = "optuna" +num_trials = 6 +seed = 42 +sampler = "tpe" + +[scheduler] +type = "multi_run_lora" +max_concurrent_runs = 3 +shared = ["examples/sweep/rl_multi_run_lora.toml"] + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."orchestrator.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_pruner_asha.toml b/examples/sweep/optuna_pruner_asha.toml new file mode 100644 index 0000000000..4ac9098666 --- /dev/null +++ b/examples/sweep/optuna_pruner_asha.toml @@ -0,0 +1,50 @@ +# Optuna with ASHA (Asynchronous Successive Halving) pruning. +# +# Demonstrates: aggressive resource-bracketed pruning. ASHA evaluates +# trials at increasingly demanding "rungs": at each rung, only the +# top 1/reduction_factor of trials are promoted to the next rung; the +# rest are pruned. Compared to median, ASHA prunes earlier and harder, +# which is useful when most of the search space is bad and you want to +# fail-fast. +# +# Like the median variant, this requires intermediate metric reports +# from trials. For real use, trials should run long enough that +# multiple rungs can be evaluated; the smoke-test scale here will not +# produce meaningful pruning decisions. +# +# Hyperband is also available (type = "hyperband"); it runs multiple +# ASHA brackets with different min_resource values to hedge across +# resource budgets. 
+# +# Run: +# uv run sweep @ examples/sweep/optuna_pruner_asha.toml +name = "reverse-text-optuna-asha" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-optuna-asha" + +[strategy] +type = "optuna" +num_trials = 12 +seed = 42 +sampler = "tpe" +poll_interval_seconds = 5.0 + +[strategy.pruner] +type = "asha" +min_resource = 1 +reduction_factor = 3 +min_early_stopping_rate = 0 + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_pruner_asha_slurm.toml b/examples/sweep/optuna_pruner_asha_slurm.toml new file mode 100644 index 0000000000..454ce5fc4e --- /dev/null +++ b/examples/sweep/optuna_pruner_asha_slurm.toml @@ -0,0 +1,54 @@ +# Optuna + ASHA pruner over the synchronous SLURM scheduler. +# +# ASHA (Asynchronous Successive Halving) prunes more aggressively than the +# median pruner: trials are evaluated at increasingly demanding "rungs", and +# at each rung only the top 1/reduction_factor are promoted while the rest +# are pruned. Useful when most of the search space is bad and you want to +# fail-fast on cluster time. +# +# How it works on SLURM-sync: each trial is submitted via `sbatch --parsable` +# (captures the job id), the controller polls shared-FS metrics.jsonl while +# the job runs, forwards (step, value) reports to `optuna_trial.report`, +# and on `should_prune()` issues `scancel <job_id>`. Trials are serialized at +# the controller, so ASHA's *adaptive* behavior is preserved even though the +# wall-clock concurrency is one. +# +# Trials need enough steps to cross multiple rungs for ASHA to do useful +# work. Tune `min_resource` (first rung) and `reduction_factor` to the +# trial length you actually run. 
+# +# Requires a shared filesystem between the controller and compute nodes +# so the controller can read metrics.jsonl as the trial writes it. +# +# Run: +# uv run sweep @ examples/sweep/optuna_pruner_asha_slurm.toml +name = "reverse-text-optuna-asha-slurm" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-optuna-asha-slurm" + +[strategy] +type = "optuna" +num_trials = 12 +seed = 42 +sampler = "tpe" +poll_interval_seconds = 10.0 + +[strategy.pruner] +type = "asha" +min_resource = 1 +reduction_factor = 3 +min_early_stopping_rate = 0 + +[scheduler] +type = "slurm" +synchronous = true + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_pruner_hyperband.toml b/examples/sweep/optuna_pruner_hyperband.toml new file mode 100644 index 0000000000..1f6c89e63b --- /dev/null +++ b/examples/sweep/optuna_pruner_hyperband.toml @@ -0,0 +1,47 @@ +# Optuna with Hyperband pruning. +# +# Demonstrates: Hyperband runs multiple ASHA brackets in parallel, +# each with a different min_resource. Bracket 0 starts pruning at +# min_resource; bracket 1 at min_resource * reduction_factor; etc. +# This hedges across resource budgets — bracket 0 fail-fasts the +# obviously bad trials, deeper brackets give borderline trials a +# longer look. Compared to plain ASHA, Hyperband prunes less +# aggressively because some trials get a free pass in the deeper +# brackets. +# +# Requires intermediate metric reports from the trial (the trial +# subprocess writes metrics.jsonl via FileMonitor when the sweep +# launcher sets PRIME_RL_SWEEP_METRICS_JSONL). 
+# +# Run: +# uv run sweep @ examples/sweep/optuna_pruner_hyperband.toml +name = "reverse-text-optuna-hyperband" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-optuna-hyperband" + +[strategy] +type = "optuna" +num_trials = 12 +seed = 42 +sampler = "tpe" +poll_interval_seconds = 5.0 + +[strategy.pruner] +type = "hyperband" +min_resource = 1 +max_resource = "auto" +reduction_factor = 3 + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_pruner_hyperband_slurm.toml b/examples/sweep/optuna_pruner_hyperband_slurm.toml new file mode 100644 index 0000000000..5e4cbab3e1 --- /dev/null +++ b/examples/sweep/optuna_pruner_hyperband_slurm.toml @@ -0,0 +1,55 @@ +# Optuna + Hyperband pruner over the synchronous SLURM scheduler. +# +# Hyperband runs multiple ASHA brackets in parallel, each with a different +# min_resource value, to hedge across resource budgets — instead of picking +# one min_resource and one reduction_factor like ASHA, Hyperband sweeps +# across several so neither extreme of "prune too early" nor "prune too +# late" dominates. +# +# How it works on SLURM-sync: each trial is submitted via `sbatch --parsable` +# (captures the job id), the controller polls shared-FS metrics.jsonl while +# the job runs, forwards (step, value) reports to `optuna_trial.report`, +# and on `should_prune()` issues `scancel <job_id>`. Trials are serialized at +# the controller, so Hyperband's adaptive behavior is preserved even though +# the wall-clock concurrency is one. +# +# Trials need enough steps for Hyperband's outer ASHA brackets to fire. +# Tune min_resource (smallest bracket), max_resource (largest bracket), +# and reduction_factor to your trial step budget. Setting max_resource to +# the trial's max_steps is a common choice. 
+# +# Requires a shared filesystem between the controller and compute nodes +# so the controller can read metrics.jsonl as the trial writes it. +# +# Run: +# uv run sweep @ examples/sweep/optuna_pruner_hyperband_slurm.toml +name = "reverse-text-optuna-hyperband-slurm" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-optuna-hyperband-slurm" + +[strategy] +type = "optuna" +num_trials = 16 +seed = 42 +sampler = "tpe" +poll_interval_seconds = 10.0 + +[strategy.pruner] +type = "hyperband" +min_resource = 1 +max_resource = 30 +reduction_factor = 3 + +[scheduler] +type = "slurm" +synchronous = true + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_pruner_median.toml b/examples/sweep/optuna_pruner_median.toml new file mode 100644 index 0000000000..618e725274 --- /dev/null +++ b/examples/sweep/optuna_pruner_median.toml @@ -0,0 +1,51 @@ +# Optuna with intermediate-metric pruning (median). +# +# Demonstrates: Optuna pruners. While each trial runs, the controller +# polls the trial subprocess's metrics.jsonl sidecar every +# poll_interval_seconds and reports each new (step, value) to Optuna +# via trial.report. If trial.should_prune fires, the controller +# SIGTERMs the trial's process group and marks status="pruned". +# +# MedianPruner prunes a trial whose intermediate value falls below +# the running median of completed trials at the same step. With +# n_startup_trials=2 below, pruning is enabled after 2 trials have +# completed in full; n_warmup_steps=1 skips pruning during the very +# first step of each trial. +# +# Trade-off: pruning requires intermediate metric reports, so trials +# must run long enough (typically 50+ steps) for pruning decisions +# to be meaningful. The max_steps = 20 inherited from examples/sweep/rl.toml +# is a smoke-test scale; raise it in real use. 
+# +# Run: +# uv run sweep @ examples/sweep/optuna_pruner_median.toml +name = "reverse-text-optuna-median" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-optuna-median" + +[strategy] +type = "optuna" +num_trials = 8 +seed = 42 +sampler = "tpe" +poll_interval_seconds = 5.0 + +[strategy.pruner] +type = "median" +n_startup_trials = 2 +n_warmup_steps = 1 +interval_steps = 1 + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_pruner_slurm.toml b/examples/sweep/optuna_pruner_slurm.toml new file mode 100644 index 0000000000..f0c4dc65fd --- /dev/null +++ b/examples/sweep/optuna_pruner_slurm.toml @@ -0,0 +1,44 @@ +# Optuna + median pruner over the synchronous SLURM scheduler. +# +# How it works: the controller submits each trial with `sbatch --parsable` +# (captures the job id), polls the shared-FS metrics.jsonl while the job +# runs, forwards each new (step, value) to `optuna_trial.report`, and on +# `should_prune()` issues `scancel JOB_ID` to terminate the SLURM job. +# Trials are serialized at the controller (one in-flight at a time); +# pruning short-circuits before the trial finishes its full step budget. +# +# Requires a shared filesystem between the controller and compute nodes +# so the controller can read metrics.jsonl as the trial writes it. 
+# +# Run: +# uv run sweep @ examples/sweep/optuna_pruner_slurm.toml +name = "reverse-text-optuna-pruner-slurm" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-optuna-pruner-slurm" + +[strategy] +type = "optuna" +num_trials = 8 +seed = 42 +sampler = "tpe" +poll_interval_seconds = 10.0 + +[strategy.pruner] +type = "median" +n_startup_trials = 2 +n_warmup_steps = 5 +interval_steps = 1 + +[scheduler] +type = "slurm" +synchronous = true + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_pruner_slurm_parallel.toml b/examples/sweep/optuna_pruner_slurm_parallel.toml new file mode 100644 index 0000000000..b33935e932 --- /dev/null +++ b/examples/sweep/optuna_pruner_slurm_parallel.toml @@ -0,0 +1,60 @@ +# Optuna + median pruner over synchronous SLURM with parallel trials. +# +# Combines two features: +# - scheduler.max_parallel = N → up to N concurrent SLURM jobs from a +# single Optuna study, driven by a ThreadPoolExecutor. +# - strategy.pruner = "median" → each in-flight worker polls its +# metrics.jsonl, forwards (step, value) to optuna_trial.report, and +# scancel's the SLURM job if should_prune fires. +# +# Concurrency model: each worker thread holds its own optuna_trial (each +# from a distinct study.ask()). report/should_prune calls go through +# Optuna's storage layer, which serializes them — the same contract that +# makes Optuna's stock study.optimize(n_jobs > 1) work. TPE auto-enables +# constant_liar so concurrent asks diversify across the search space. +# +# Use this when each individual trial is short enough that you want +# adaptive pruning, AND large enough that you want the cluster filling +# multiple trials at once. 
For very short trials, the no-pruner parallel +# example is simpler; for very long single-node trials, the serial +# pruner+slurm-sync example uses one trial in flight. +# +# Requires: +# - sbatch on PATH. +# - A shared filesystem between the controller and compute nodes so the +# controller can read metrics.jsonl as the trial writes it. +# - Edit examples/sweep/slurm_base.toml to match your cluster. +# +# Run: +# uv run sweep @ examples/sweep/optuna_pruner_slurm_parallel.toml +name = "reverse-text-optuna-pruner-slurm-parallel" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-optuna-pruner-slurm-parallel" + +[strategy] +type = "optuna" +num_trials = 12 +seed = 42 +sampler = "tpe" +poll_interval_seconds = 10.0 + +[strategy.pruner] +type = "median" +n_startup_trials = 2 +n_warmup_steps = 5 +interval_steps = 1 + +[scheduler] +type = "slurm" +synchronous = true +max_parallel = 3 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_slurm.toml b/examples/sweep/optuna_slurm.toml new file mode 100644 index 0000000000..f609a6ea36 --- /dev/null +++ b/examples/sweep/optuna_slurm.toml @@ -0,0 +1,42 @@ +# Optuna over the synchronous SLURM scheduler. +# +# Demonstrates: scheduler.synchronous = true. The controller submits each +# trial via `sbatch --wait`, blocking until the job exits, so Optuna can +# observe each trial's objective before proposing the next. Trials still +# run on a SLURM-managed compute node; the controller process holds one +# slot on the login/controller node for the duration of the sweep. +# +# Tradeoff: trials are serialized at the controller (one in-flight at a +# time), so this is most useful when each individual trial is too large +# to run anywhere except the cluster, but you still want adaptive search. 
+# For embarrassingly parallel grid/random sweeps with a stable queue, +# stick with the default (synchronous = false). +# +# Pruners are supported on this scheduler — see +# examples/sweep/optuna_pruner_slurm.toml for a pruning variant. +# +# Run: +# uv run sweep @ examples/sweep/optuna_slurm.toml +name = "reverse-text-optuna-slurm" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-optuna-slurm" + +[strategy] +type = "optuna" +num_trials = 6 +seed = 42 +sampler = "tpe" + +[scheduler] +type = "slurm" +synchronous = true + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_slurm_parallel.toml b/examples/sweep/optuna_slurm_parallel.toml new file mode 100644 index 0000000000..85a5514f37 --- /dev/null +++ b/examples/sweep/optuna_slurm_parallel.toml @@ -0,0 +1,47 @@ +# Optuna (TPE) over the synchronous SLURM scheduler with parallel trials. +# +# How it works: the controller maintains up to ``max_parallel`` concurrent +# in-flight SLURM jobs. As each job exits, the controller tells Optuna the +# result and immediately asks for one more trial to refill the slot. With +# TPE, ``constant_liar=True`` is enabled automatically so concurrent asks +# don't collide on the same region of the search space — a placeholder +# objective is assigned to running trials, biasing the next ask to +# diversify rather than re-propose nearby params. +# +# Cluster fit: each SLURM job uses whatever resources its `[slurm]` +# section requests (nodes, GPUs per node). If your trials each need 1 +# node and your cluster has 3 nodes, set ``max_parallel = 3``. Trials +# wider than 1 node interact normally with the SLURM queue; ``max_parallel`` +# caps the controller-managed in-flight count, not the queue itself. 
+# +# Pruners are intentionally NOT supported in this mode (the pruning loop +# owns the optuna_trial object for the trial's lifetime and Optuna trial +# objects are not safe to share across polling threads). For pruned +# sweeps, fall back to scheduler.synchronous=true with max_parallel=1. +# +# Run: +# uv run sweep @ examples/sweep/optuna_slurm_parallel.toml +name = "reverse-text-optuna-slurm-parallel" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-optuna-slurm-parallel" + +[strategy] +type = "optuna" +num_trials = 12 +seed = 42 +sampler = "tpe" + +[scheduler] +type = "slurm" +synchronous = true +max_parallel = 3 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/optuna_slurm_resume.toml b/examples/sweep/optuna_slurm_resume.toml new file mode 100644 index 0000000000..f343631998 --- /dev/null +++ b/examples/sweep/optuna_slurm_resume.toml @@ -0,0 +1,68 @@ +# Optuna sweep with persistent SQLite storage — survives controller crashes. +# +# Demonstrates: `strategy.storage` set to a SQLite URL, so the Optuna +# study itself persists across controller invocations. Combined with +# `resume = true`, a sweep that was killed mid-run can be restarted and +# the controller will: +# +# 1. Reload the study from disk (load_if_exists = config.resume). +# 2. Validate the existing manifest matches the storage trials +# (resolve_checksum, base_checksums, ids, parameters). +# 3. Reconcile any RUNNING trials left over from the previous run: +# - status.json says "completed" with an objective → tell Optuna +# the value so future asks see it. +# - status.json says "pruned" → tell TrialState.PRUNED. +# - otherwise → tell TrialState.FAIL and mark status failed. +# 4. Resume sampling from trial number == len(study.trials). 
+# +# How to test resume: +# +# # First run (resume = false). Let it produce a few completed trials, +# # then Ctrl-C the controller. +# uv run sweep @ examples/sweep/optuna_slurm_resume.toml +# +# # Edit this file: set `resume = true`. Re-run. The controller should +# # print "Reconciled N RUNNING Optuna trial(s) from interrupted resume" +# # if the crash happened between ask() and tell(), then continue from +# # the next trial index until num_trials are completed. +# uv run sweep @ examples/sweep/optuna_slurm_resume.toml +# +# In-memory studies (no `storage` field) cannot resume — Optuna keeps the +# trial history in process memory, so a controller restart starts from +# scratch. Persistent storage moves that history to SQLite (or any URL +# Optuna supports: postgres, mysql, redis-backed, etc.) so the next +# invocation reads it back. +# +# Path nuance: the SQLite URL is interpreted relative to the controller's +# CWD, not output_dir. Storing the .db under output_dir keeps everything +# for one study in one folder; the URL is `sqlite:///path/to/optuna.db` +# (three slashes = relative to CWD; an absolute path needs four slashes). 
+# +# Run: +# uv run sweep @ examples/sweep/optuna_slurm_resume.toml +name = "reverse-text-optuna-resume" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-optuna-resume" +resume = false + +[strategy] +type = "optuna" +num_trials = 6 +seed = 42 +sampler = "tpe" +study_name = "reverse-text-optuna-resume" +storage = "sqlite:///outputs/studies/reverse-text-optuna-resume/optuna.db" + +[scheduler] +type = "slurm" +synchronous = true + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 diff --git a/examples/sweep/parallel_local.toml b/examples/sweep/parallel_local.toml new file mode 100644 index 0000000000..da6377c8e7 --- /dev/null +++ b/examples/sweep/parallel_local.toml @@ -0,0 +1,39 @@ +# Parallel local sweep with explicit GPU assignment. +# +# Demonstrates: max_parallel > 1 with [scheduler.gpu_assignment] pinning +# each parallel worker to a disjoint device group. Each trial subprocess +# inherits a CUDA_VISIBLE_DEVICES override matching its assigned group. +# +# Requires: at least max_parallel disjoint visible_devices groups; the +# validator rejects max_parallel > 1 without gpu_assignment so two +# workers can never silently colocate trainer/inference stacks on the +# same GPUs. +# +# The example below pins two workers to GPUs [0,1] and [2,3] respectively, +# matching a 4-GPU node with each trial running its own 1-train + 1-infer +# split. Adjust visible_devices to match your hardware. 
+# +# Run: +# uv run sweep @ examples/sweep/parallel_local.toml +name = "reverse-text-parallel" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-parallel" + +[strategy] +type = "grid" + +[scheduler] +type = "local" +max_parallel = 2 + +[scheduler.gpu_assignment] +mode = "static" +visible_devices = [[0, 1], [2, 3]] + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +values = [1e-6, 3e-6, 1e-5, 3e-5] diff --git a/examples/sweep/random_local.toml b/examples/sweep/random_local.toml new file mode 100644 index 0000000000..b465c15ea3 --- /dev/null +++ b/examples/sweep/random_local.toml @@ -0,0 +1,35 @@ +# Random search over distributions (log_uniform LR, uniform temperature). +# +# Demonstrates: random strategy with typed distributions. Each trial +# samples a fresh parameter set; results are reproducible across runs +# because seed is pinned. +# +# Run: +# uv run sweep @ examples/sweep/random_local.toml +name = "reverse-text-random" +entrypoint = "rl" +base = ["examples/sweep/rl.toml"] +output_dir = "outputs/studies/reverse-text-random" + +[strategy] +type = "random" +num_trials = 8 +seed = 7 + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +distribution = "log_uniform" +min = 1e-7 +max = 1e-4 + +[parameters."orchestrator.train.sampling.temperature"] +distribution = "uniform" +min = 0.6 +max = 1.2 diff --git a/examples/sweep/rl.toml b/examples/sweep/rl.toml new file mode 100644 index 0000000000..d19337b86c --- /dev/null +++ b/examples/sweep/rl.toml @@ -0,0 +1,26 @@ +max_steps = 20 +seq_len = 2048 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[wandb] +project = "reverse-text" +name = "reverse-text" + +[orchestrator] +batch_size = 128 +rollouts_per_example = 16 + +[orchestrator.train.sampling] +max_completion_tokens = 128 + 
+[[orchestrator.train.env]] +id = "reverse-text" + +[trainer.optim] +lr = 3e-6 + +[ckpt] # Checkpoint at the end of training + +[inference] \ No newline at end of file diff --git a/examples/sweep/rl_disagg.toml b/examples/sweep/rl_disagg.toml new file mode 100644 index 0000000000..0079a135d3 --- /dev/null +++ b/examples/sweep/rl_disagg.toml @@ -0,0 +1,32 @@ +# Disaggregated reverse-text RL: trainer + orchestrator on rig A, vLLM on rig B. +# No [inference] block here — rig B owns inference; the orchestrator hits it +# over the LAN. +max_steps = 5 +seq_len = 2048 +clean_output_dir = true + +[deployment] +type = "single_node" +num_train_gpus = 1 +num_infer_gpus = 0 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[orchestrator] +batch_size = 128 +rollouts_per_example = 16 + +[orchestrator.client] +base_url = ["http://10.1.0.69:8000/v1"] + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +id = "reverse-text" + +[trainer.optim] +lr = 3e-6 + +[ckpt] diff --git a/examples/sweep/rl_multi_run_lora.toml b/examples/sweep/rl_multi_run_lora.toml new file mode 100644 index 0000000000..24a62753d8 --- /dev/null +++ b/examples/sweep/rl_multi_run_lora.toml @@ -0,0 +1,52 @@ +# Shared base for a multi_run_lora sweep on a single multi-GPU machine +# (no disagg — inference runs locally on a separate GPU). +# +# rl-multi-run launches one trainer + one inference + N orchestrators +# against pre-materialized run dirs. Each orchestrator gets a per-trial +# LoRA name from the sweep; the trainer pushes adapter weights via the +# /load_lora_adapter HTTP endpoint, so vLLM must be LoRA-enabled. +# +# Requires: +# - ≥2 GPUs on this machine (num_train_gpus + num_infer_gpus). 
+max_steps = 20 +max_async_level = 20 +seq_len = 2048 +clean_output_dir = true + +[deployment] +type = "single_node" +num_train_gpus = 1 +num_infer_gpus = 1 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[trainer] +max_concurrent_runs = 3 + +[trainer.model.lora] +rank = 8 +alpha = 16 + +[trainer.optim] +lr = 3e-6 + +[orchestrator] +batch_size = 64 +rollouts_per_example = 8 + +[orchestrator.model.lora] +rank = 8 + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +id = "reverse-text" + +[inference] +enable_lora = true +max_lora_rank = 8 +max_loras = 4 + +[ckpt] diff --git a/examples/sweep/rl_multi_run_lora_disagg.toml b/examples/sweep/rl_multi_run_lora_disagg.toml new file mode 100644 index 0000000000..533e965252 --- /dev/null +++ b/examples/sweep/rl_multi_run_lora_disagg.toml @@ -0,0 +1,42 @@ +# Shared base for a multi_run_lora sweep with disagg inference. +# rl-multi-run launches one trainer + N orchestrators; inference is on rig B. 
+max_steps = 15 +max_async_level = 20 +seq_len = 2048 +clean_output_dir = true + +[deployment] +type = "single_node" +num_train_gpus = 1 +num_infer_gpus = 0 + +[model] +name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" + +[trainer] +max_concurrent_runs = 3 + +[trainer.model.lora] +rank = 8 +alpha = 16 + +[trainer.optim] +lr = 3e-6 + +[orchestrator] +batch_size = 64 +rollouts_per_example = 8 + +[orchestrator.client] +base_url = ["http://10.1.0.69:8000/v1"] + +[orchestrator.model.lora] +rank = 8 + +[orchestrator.train.sampling] +max_completion_tokens = 128 + +[[orchestrator.train.env]] +id = "reverse-text" + +[ckpt] diff --git a/examples/sweep/sft.toml b/examples/sweep/sft.toml new file mode 100644 index 0000000000..cadfb455ef --- /dev/null +++ b/examples/sweep/sft.toml @@ -0,0 +1,14 @@ +max_steps = 100 + +[ckpt] # Checkpoint at the end of training + +[model] +name = "PrimeIntellect/Qwen3-0.6B" + +[data] +name = "willcb/R1-reverse-wikipedia-paragraphs-v1-1000" +seq_len = 4096 +batch_size = 32 + +[optim] +lr = 2e-5 diff --git a/examples/sweep/sft_slurm.toml b/examples/sweep/sft_slurm.toml new file mode 100644 index 0000000000..602e8e8541 --- /dev/null +++ b/examples/sweep/sft_slurm.toml @@ -0,0 +1,43 @@ +# SFT entrypoint over the synchronous SLURM scheduler. +# +# Demonstrates: a non-RL entrypoint going through the same controller +# loop. The SFT sbatch template runs torchrun and exits naturally when +# training completes — no self-scancel is needed because the SFT trainer +# isn't a long-lived orchestrator. sbatch --wait carries the trainer's +# exit code back to the controller directly, so SLURM accounting +# (sacct/scontrol) is not on the critical path for the completed/failed +# decision. +# +# Requires: +# - sbatch on PATH. +# - A reachable SLURM cluster. +# - Edit examples/sweep/sft_slurm_base.toml to match your cluster +# ([slurm] partition/account/time + [deployment] gpus_per_node / +# num_nodes). 
+# +# Run: +# uv run sweep @ examples/sweep/sft_slurm.toml +name = "reverse-text-sft-slurm" +entrypoint = "sft" +base = ["examples/reverse_text/sft.toml", "examples/sweep/sft_slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-sft-slurm" + +[strategy] +type = "grid" + +[scheduler] +type = "slurm" +synchronous = true + +[objective] +metric = "loss/mean" +direction = "minimize" + +# Disabled so a fresh cluster can run this example without configuring +# wandb on every compute node. Flip enabled=true once `wandb login` has +# been run on each node the sweep might land on. +[wandb] +enabled = false + +[parameters."optim.lr"] +values = [1e-5, 2e-5, 5e-5] diff --git a/examples/sweep/sft_slurm_base.toml b/examples/sweep/sft_slurm_base.toml new file mode 100644 index 0000000000..dd1ca945e1 --- /dev/null +++ b/examples/sweep/sft_slurm_base.toml @@ -0,0 +1,20 @@ +# SLURM submission settings for SFT sweeps, chained on top of +# examples/reverse_text/sft.toml when running examples/sweep/sft_slurm.toml. +# Edit to match your cluster. +# +# SFT uses a different [deployment] shape than RL — num_nodes instead +# of num_train_nodes + num_infer_nodes — because SFT is trainer-only +# with no separate inference allocation. +[slurm] +job_name = "reverse-text-sft-sweep" +partition = "gpu" +# account = "your-account" +# time = "01:00:00" +# nodelist = "node-a" +# template_path = "/path/to/custom_sbatch_template.sh" +# pre_run_command = "module load cuda/12.8" + +[deployment] +type = "multi_node" +gpus_per_node = 1 +num_nodes = 1 diff --git a/examples/sweep/sft_slurm_parallel.toml b/examples/sweep/sft_slurm_parallel.toml new file mode 100644 index 0000000000..c028d5235b --- /dev/null +++ b/examples/sweep/sft_slurm_parallel.toml @@ -0,0 +1,63 @@ +# SFT entrypoint over synchronous SLURM with parallel trials. +# +# Pairs the `sft` entrypoint with scheduler.max_parallel > 1. 
Each SFT +# trial is 1 node (no inference disaggregation), so on an N-node cluster +# you can fit N concurrent trials by setting max_parallel = N. +# +# How it works: the controller maintains up to max_parallel concurrent +# in-flight SLURM jobs via a ThreadPoolExecutor. As each `sbatch --wait` +# returns, the main thread tells Optuna the outcome and asks for a fresh +# trial to refill the slot. TPE auto-enables `constant_liar=True` so +# concurrent asks diversify; pruners are not supported (Optuna trial +# objects aren't thread-safe to share across polling threads). +# +# Why a separate SFT example: the existing optuna_slurm_parallel.toml +# uses the RL base, which needs 2 nodes per trial (trainer + inference). +# On a 3-node cluster, that means only 1 RL trial fits at a time even +# with max_parallel = 3 — no actual parallelism. SFT trials are 1-node +# so the controller's max_parallel knob produces real concurrency. +# +# Requires: +# - sbatch on PATH. +# - examples/sweep/sft_slurm_base.toml edited for your cluster. +# - HF_HOME pointing at a shared cache (see README), or per-node +# `wandb login` if you re-enable wandb. +# +# Run: +# uv run sweep @ examples/sweep/sft_slurm_parallel.toml +name = "reverse-text-sft-slurm-parallel" +entrypoint = "sft" +base = ["examples/reverse_text/sft.toml", "examples/sweep/sft_slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-sft-slurm-parallel" +# Wipe stale trial dirs from prior runs so the fresh sweep starts clean. +clean_output_dir = true + +[strategy] +type = "optuna" +num_trials = 9 +seed = 42 +sampler = "tpe" +# Persist the Optuna study to SQLite so the sweep can be resumed if the +# controller is interrupted (set `resume = true` at the top to reload). 
+study_name = "reverse-text-sft-slurm-parallel" +storage = "sqlite:///outputs/studies/reverse-text-sft-slurm-parallel/optuna.db" + +[scheduler] +type = "slurm" +synchronous = true +max_parallel = 3 + +[objective] +metric = "loss/mean" +direction = "minimize" + +# Disabled so a fresh cluster can run this example without configuring +# wandb on every compute node. Flip enabled=true once `wandb login` has +# been run on each node the sweep might land on. +[wandb] +enabled = false + +[parameters."optim.lr"] +distribution = "log_uniform" +min = 1e-6 +max = 1e-4 diff --git a/examples/sweep/slurm.toml b/examples/sweep/slurm.toml new file mode 100644 index 0000000000..44051945da --- /dev/null +++ b/examples/sweep/slurm.toml @@ -0,0 +1,36 @@ +# SLURM scheduler: each trial is submitted as an independent SLURM job. +# +# Demonstrates: scheduler.type = "slurm". The sweep controller submits +# one job per trial via the target entrypoint's existing [slurm] support +# (rendering and sbatch-ing per-trial scripts) and then exits. Jobs run +# asynchronously after submission, so the controller cannot tell when +# they finish — early stopping and Optuna are not supported with this +# scheduler (validator enforces both rules). +# +# Requires: +# - sbatch on PATH. +# - A reachable SLURM cluster. +# - Edit examples/sweep/slurm_base.toml to set your cluster's +# [slurm] partition/account/time/template_path/etc. and the +# [deployment] block (defaults to 1 train node + 1 infer node, +# 1 GPU each — bump to match your cluster). 
+# +# Run: +# uv run sweep @ examples/sweep/slurm.toml +name = "reverse-text-slurm" +entrypoint = "rl" +base = ["examples/sweep/rl.toml", "examples/sweep/slurm_base.toml"] +output_dir = "outputs/studies/reverse-text-slurm" + +[strategy] +type = "grid" + +[scheduler] +type = "slurm" + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" + +[parameters."trainer.optim.lr"] +values = [1e-6, 3e-6, 1e-5] diff --git a/examples/sweep/slurm_base.toml b/examples/sweep/slurm_base.toml new file mode 100644 index 0000000000..cde3c5458d --- /dev/null +++ b/examples/sweep/slurm_base.toml @@ -0,0 +1,22 @@ +# SLURM submission settings, chained on top of examples/sweep/rl.toml +# when running examples/sweep/slurm.toml. Edit to match your cluster. +# +# The [deployment] block forces a multi-node layout with one training +# node and one inference node, each with 1 GPU. This fits a 2 x 1 GPU +# homelab cluster out of the box. Bump gpus_per_node, num_train_nodes, +# and num_infer_nodes to match larger clusters. 
+[slurm] +job_name = "reverse-text-sweep" +partition = "gpu" +# account = "your-account" +# time = "01:00:00" +# nodelist = "node-a,node-b" +# exclude = "node-bad" +# template_path = "/path/to/custom_sbatch_template.sh" +# pre_run_command = "module load cuda/12.8" + +[deployment] +type = "multi_node" +gpus_per_node = 1 +num_train_nodes = 1 +num_infer_nodes = 1 diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 5d04d3369f..5cd873fef3 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -967,8 +967,13 @@ class OrchestratorConfig(BaseConfig): # Data buffer configuration buffer: BufferConfig = BufferConfig() - # The advantage configuration - advantage: AdvantageConfig | None = DefaultAdvantageConfig() + # Inlined rather than `AdvantageConfig | None`: tyro's expand_union_types + # mishandles `Annotated[X | Y, Field(discriminator)] | None` with a + # non-None default and emits a duplicate-subcommand warning. 
+ advantage: Annotated[ + DefaultAdvantageConfig | CustomAdvantageConfig | None, + Field(discriminator="type"), + ] = DefaultAdvantageConfig() # Rollout filters (monitor by default, enforce optionally) filters: list[FilterConfig] = [GibberishFilterConfig(), RepetitionFilterConfig(), ZeroAdvantageFilterConfig()] diff --git a/packages/prime-rl-configs/src/prime_rl/configs/sweep.py b/packages/prime-rl-configs/src/prime_rl/configs/sweep.py new file mode 100644 index 0000000000..a54557ff83 --- /dev/null +++ b/packages/prime-rl-configs/src/prime_rl/configs/sweep.py @@ -0,0 +1,1054 @@ +import math +import warnings +from pathlib import Path +from typing import Annotated, Any, Literal, TypeAlias + +from pydantic import Discriminator, Field, Tag, model_validator + +from prime_rl.utils.config import BaseConfig + +OPTUNA_SEED_MAX = 2**32 - 1 + + +class ChoiceParameterConfig(BaseConfig): + """Choice-valued parameter sampled from an explicit list.""" + + distribution: Literal["choice"] = "choice" + values: Annotated[list[Any], Field(description="Explicit values to sweep over.")] + + @model_validator(mode="after") + def validate_values(self): + if not self.values: + raise ValueError("Sweep parameter values must be non-empty") + for idx, value in enumerate(self.values): + _validate_choice_value(value, f"values[{idx}]") + return self + + +class UniformParameterConfig(BaseConfig): + """Continuous parameter sampled uniformly on [min, max].""" + + distribution: Literal["uniform"] + min: float + max: float + + @model_validator(mode="before") + @classmethod + def reject_bool_bounds(cls, data: Any) -> Any: + _reject_bool_fields(data, ("min", "max"), "Uniform parameter") + return data + + @model_validator(mode="after") + def validate_range(self): + if not math.isfinite(self.min) or not math.isfinite(self.max): + raise ValueError("Uniform parameter min and max must be finite") + if self.min >= self.max: + raise ValueError("Uniform parameter requires min < max") + if not 
math.isfinite(self.max - self.min): + raise ValueError("Uniform parameter range must be finite") + return self + + +class LogUniformParameterConfig(BaseConfig): + """Continuous parameter sampled uniformly in log-space on [min, max].""" + + distribution: Literal["log_uniform"] + min: float + max: float + + @model_validator(mode="before") + @classmethod + def reject_bool_bounds(cls, data: Any) -> Any: + _reject_bool_fields(data, ("min", "max"), "Log-uniform parameter") + return data + + @model_validator(mode="after") + def validate_range(self): + if not math.isfinite(self.min) or not math.isfinite(self.max): + raise ValueError("Log-uniform parameter min and max must be finite") + if self.min <= 0 or self.max <= 0: + raise ValueError("Log-uniform parameter requires positive min and max") + if self.min >= self.max: + raise ValueError("Log-uniform parameter requires min < max") + return self + + +class IntUniformParameterConfig(BaseConfig): + """Integer parameter sampled uniformly from {min, min+step, ..., max}.""" + + distribution: Literal["int_uniform"] + min: int + max: int + step: Annotated[int, Field(ge=1)] = 1 + + @model_validator(mode="before") + @classmethod + def reject_bool_bounds(cls, data: Any) -> Any: + _reject_bool_fields(data, ("min", "max", "step"), "Int-uniform parameter") + return data + + @model_validator(mode="after") + def validate_range(self): + if self.min >= self.max: + raise ValueError("Int-uniform parameter requires min < max") + if (self.max - self.min) % self.step != 0: + raise ValueError( + f"Int-uniform range [{self.min}, {self.max}] is not divisible by step {self.step}; " + "non-divisible ranges silently truncate the search space (the inclusive max is never sampled). " + "Pick a step that divides (max - min) evenly." 
+ ) + return self + + +def _parameter_discriminator(value: Any) -> str: + """Default to ``choice`` so the bare ``{"values": [...]}`` form keeps working.""" + if isinstance(value, dict): + return value.get("distribution", "choice") + return getattr(value, "distribution", "choice") + + +SweepParameterConfig: TypeAlias = Annotated[ + Annotated[ChoiceParameterConfig, Tag("choice")] + | Annotated[UniformParameterConfig, Tag("uniform")] + | Annotated[LogUniformParameterConfig, Tag("log_uniform")] + | Annotated[IntUniformParameterConfig, Tag("int_uniform")], + Discriminator(_parameter_discriminator), +] + + +def _validate_choice_value(value: Any, path: str) -> None: + """Validate that a choice value can be written to generated TOML.""" + if value is None: + raise ValueError( + f"Sweep choice parameter {path} cannot be None; use the string 'None' for nullable target fields." + ) + if isinstance(value, bool): + return + if isinstance(value, int): + return + if isinstance(value, float): + if not math.isfinite(value): + raise ValueError(f"Sweep choice parameter {path} must be finite") + return + if isinstance(value, str): + return + if isinstance(value, dict): + non_string_keys = [key for key in value if not isinstance(key, str)] + if non_string_keys: + raise ValueError( + f"Sweep choice parameter {path} has non-string TOML table key(s): {non_string_keys}" + ) + for key, child in value.items(): + _validate_choice_value(child, f"{path}.{key}") + return + if isinstance(value, (list, tuple)): + for idx, child in enumerate(value): + _validate_choice_value(child, f"{path}[{idx}]") + return + raise ValueError( + f"Sweep choice parameter {path} has value of type {type(value).__name__}, " + "which cannot be written to generated TOML." 
def _choice_value_leaf_items(parent_path: str, value: Any) -> list[tuple[str, Any]]:
    """Flatten a structured choice value into ``(dotted_path, leaf_value)`` pairs.

    Traversal rules:

    * A non-empty dict contributes the leaves of each child under
      ``parent_path.key``; an empty dict is itself a leaf at ``parent_path``.
    * A list/tuple descends only into dict/list/tuple children, and those
      children stay on ``parent_path`` (list indices never become path
      segments). Scalar elements next to nested containers are skipped. A
      sequence with no nested containers is itself a leaf at ``parent_path``.
    * Anything else is a leaf at ``parent_path``.

    The result is never empty.
    """
    if isinstance(value, dict):
        if not value:
            return [(parent_path, value)]
        items: list[tuple[str, Any]] = []
        for key, child in value.items():
            items.extend(_choice_value_leaf_items(f"{parent_path}.{key}", child))
        return items
    if isinstance(value, (list, tuple)):
        nested = [
            item
            for child in value
            if isinstance(child, (dict, list, tuple))
            for item in _choice_value_leaf_items(parent_path, child)
        ]
        return nested or [(parent_path, value)]
    return [(parent_path, value)]


def _choice_value_leaf_paths(parent_path: str, value: Any) -> list[str]:
    """Return only the dotted leaf paths of a structured choice value.

    Derived from ``_choice_value_leaf_items`` so the two views cannot drift
    apart; order and duplicates match that function's output.
    """
    return [leaf_path for leaf_path, _ in _choice_value_leaf_items(parent_path, value)]
class MedianPrunerConfig(BaseConfig):
    """Settings for Optuna's MedianPruner, which prunes a trial whose intermediate
    value falls below the running median of completed trials at the same step."""

    type: Literal["median"] = "median"
    # Warm-up at the study level: no pruning until this many trials have completed.
    n_startup_trials: Annotated[
        int,
        Field(ge=0, description="Trials that must complete before pruning is enabled."),
    ] = 5
    # Warm-up inside a single trial: early steps are never pruned.
    n_warmup_steps: Annotated[
        int,
        Field(ge=0, description="Steps within a trial that are exempt from pruning."),
    ] = 0
    # Spacing between pruning checks, in reported steps.
    interval_steps: Annotated[
        int,
        Field(ge=1, description="Pruning is only checked every Nth reported step."),
    ] = 1

    @model_validator(mode="before")
    @classmethod
    def reject_bool_numeric_fields(cls, data: Any) -> Any:
        """Refuse booleans for the integer-valued fields in the raw input."""
        numeric_fields = ("n_startup_trials", "n_warmup_steps", "interval_steps")
        _reject_bool_fields(data, numeric_fields, "Median pruner")
        return data
Promotes trials whose intermediate + value is in the top ``1/reduction_factor`` at each rung.""" + + type: Literal["asha"] = "asha" + min_resource: Annotated[ + int | Literal["auto"], + Field(description="Minimum resource (steps) before a trial can be pruned."), + ] = "auto" + reduction_factor: Annotated[ + int, + Field(ge=2, description="At each rung, keep the top 1/reduction_factor of trials."), + ] = 4 + min_early_stopping_rate: Annotated[ + int, + Field(ge=0, description="Bracket index offset; 0 enables the most aggressive bracket."), + ] = 0 + + @model_validator(mode="before") + @classmethod + def reject_bool_numeric_fields(cls, data: Any) -> Any: + _reject_bool_fields( + data, + ("min_resource", "reduction_factor", "min_early_stopping_rate"), + "ASHA pruner", + ) + return data + + @model_validator(mode="after") + def validate_min_resource(self): + if isinstance(self.min_resource, int) and self.min_resource < 1: + raise ValueError("ASHA pruner min_resource must be >= 1 or 'auto'") + return self + + +class HyperbandPrunerConfig(BaseConfig): + """Optuna's HyperbandPruner: runs successive-halving across multiple brackets.""" + + type: Literal["hyperband"] = "hyperband" + min_resource: Annotated[ + int, + Field(ge=1, description="Smallest resource budget evaluated in any bracket."), + ] = 1 + max_resource: Annotated[ + int | Literal["auto"], + Field(description="Largest resource budget; ``auto`` infers from reported steps."), + ] = "auto" + reduction_factor: Annotated[ + int, + Field(ge=2, description="At each rung, keep the top 1/reduction_factor of trials."), + ] = 3 + + @model_validator(mode="before") + @classmethod + def reject_bool_numeric_fields(cls, data: Any) -> Any: + _reject_bool_fields( + data, + ("min_resource", "max_resource", "reduction_factor"), + "Hyperband pruner", + ) + return data + + @model_validator(mode="after") + def validate_resources(self): + if isinstance(self.max_resource, int) and self.max_resource < self.min_resource: + raise 
ValueError("Hyperband pruner max_resource must be >= min_resource or 'auto'") + return self + + +PrunerConfig: TypeAlias = Annotated[ + NoPrunerConfig | MedianPrunerConfig | AshaPrunerConfig | HyperbandPrunerConfig, + Field(discriminator="type"), +] + + +class OptunaStrategyConfig(BaseConfig): + """Adaptive sampling backed by Optuna. + + Samplers: ``tpe`` (default) and ``random``. Pruners: ``none`` (default), + ``median``, ``asha`` (successive-halving), and ``hyperband``. Pruners need + intermediate metric reporting from the trial; the controller polls a + sidecar metrics stream while the trial runs and calls + ``optuna_trial.report``/``should_prune`` between samples. + + Storage defaults to in-memory; pass a SQLAlchemy URL (e.g. + ``"sqlite:///optuna.db"``) to persist the study across resume. + """ + + type: Literal["optuna"] = "optuna" + num_trials: Annotated[int, Field(ge=1, description="Number of trials to evaluate.")] + seed: int | None = None + sampler: Literal["tpe", "random"] = "tpe" + pruner: PrunerConfig = NoPrunerConfig() + storage: Annotated[ + str | None, + Field(description="SQLAlchemy storage URL for study persistence; in-memory if unset."), + ] = None + study_name: Annotated[ + str | None, + Field(description="Optuna study_name; defaults to the sweep name."), + ] = None + poll_interval_seconds: Annotated[ + float, + Field( + gt=0, + description=( + "How often the controller polls the trial's intermediate metrics " + "while pruning is enabled. Ignored when pruner.type == 'none'." 
class LocalGpuAssignmentConfig(BaseConfig):
    """Static assignment of CUDA_VISIBLE_DEVICES groups to local workers.

    Each entry in ``visible_devices`` is one device group that pins one trial
    subprocess. ``validate_groups`` enforces that every group is non-empty,
    that indices are non-negative, and that no device appears twice, so two
    parallel workers never share a GPU. ``mode`` is currently fixed to
    ``"static"``; future modes (``"exclusive"`` for live GPU discovery,
    ``"none"`` to leave ``CUDA_VISIBLE_DEVICES`` untouched) will land in later
    phases.
    """

    mode: Literal["static"] = "static"
    visible_devices: Annotated[
        list[list[int]],
        Field(min_length=1, description="Disjoint device groups assigned to parallel workers."),
    ]

    @model_validator(mode="before")
    @classmethod
    def reject_bool_devices(cls, data: Any) -> Any:
        """Reject boolean device indices before they can be coerced to ints.

        Both lists and tuples are inspected. Tuple-shaped groups (possible
        when the config is built from Python rather than TOML) would otherwise
        skip this check entirely. Any other shape is left for field
        validation to report.

        Raises:
            ValueError: if any ``visible_devices[i][j]`` is a ``bool``.
        """
        if not isinstance(data, dict):
            return data
        groups = data.get("visible_devices")
        if not isinstance(groups, (list, tuple)):
            return data
        for group_idx, group in enumerate(groups):
            if not isinstance(group, (list, tuple)):
                continue
            for device_idx, device in enumerate(group):
                if isinstance(device, bool):
                    raise ValueError(
                        "Local GPU assignment visible_devices entries cannot be boolean: "
                        f"visible_devices[{group_idx}][{device_idx}]"
                    )
        return data

    @model_validator(mode="after")
    def validate_groups(self):
        """Enforce non-empty groups, non-negative indices, and global uniqueness.

        Raises:
            ValueError: on the first violated rule, checked in the order above.
        """
        if any(not group for group in self.visible_devices):
            raise ValueError("Each visible_devices group must contain at least one device index")
        flat = [device for group in self.visible_devices for device in group]
        if any(device < 0 for device in flat):
            raise ValueError("visible_devices indices must be non-negative")
        # A device listed twice (in one group or across groups) shrinks the set.
        if len(flat) != len(set(flat)):
            raise ValueError("Each device may only appear in one visible_devices group")
        return self
raise ValueError( + "max_parallel > 1 requires explicit gpu_assignment so parallel workers do not " + "silently colocate trainer/inference stacks on the same GPUs." + ) + available = len(self.gpu_assignment.visible_devices) + if available < self.max_parallel: + raise ValueError( + f"max_parallel={self.max_parallel} requires at least {self.max_parallel} " + f"visible_devices groups, got {available}." + ) + return self + + +class SlurmSweepSchedulerConfig(BaseConfig): + """Submit generated trials through the target entrypoint's SLURM support. + + Throughput is governed by the SLURM cluster. When ``synchronous = true`` + the controller blocks on each trial; with ``max_parallel > 1`` the + controller drives up to N concurrent trials through ``sbatch + --parsable`` and shared-FS polling so Optuna's ask/tell loop can fill + the cluster instead of serializing one trial at a time. + + When ``synchronous = true``, the controller submits each trial via + ``sbatch --wait`` (or ``sbatch --parsable`` + polling in pruning / + parallel mode) and observes per-trial completion. This lets Optuna and + trial-level early stopping work over SLURM (the controller learns each + trial's objective before proposing the next), at the cost of trials + being scheduled at the controller's pace rather than the cluster's + queue cadence. + """ + + type: Literal["slurm"] = "slurm" + synchronous: Annotated[ + bool, + Field( + description=( + "Block on each sbatch submission via 'sbatch --wait' so the " + "controller observes per-trial completion. Required to pair " + "Optuna or early stopping with the SLURM scheduler." + ), + ), + ] = False + max_parallel: Annotated[ + int, + Field( + ge=1, + description=( + "Maximum concurrent in-flight SLURM jobs the controller will " + "manage. Only meaningful with synchronous=true; the controller " + "submits up to this many trials, polls each via shared-FS " + "metrics.jsonl + squeue, and replaces them with fresh Optuna " + "asks as they complete. 
With TPE this enables constant_liar " + "sampling so concurrent asks don't collide on the same region." + ), + ), + ] = 1 + + +# Parameter paths a multi_run_lora sweep is allowed to vary. Must stay in +# sync with both the OrchestratorConfig schema (paths must actually resolve) +# and what the trainer's MultiRunManager treats as per-run-safe; see +# src/prime_rl/trainer/runs.py for the runtime validation hook. Anything +# under trainer.*, model.*, deployment.*, or inference.* is shared across +# runs and would silently mismatch between trials, so it is rejected at +# config-load time. +# +# Prefixes match arbitrary dict subtrees (paths must continue under them); +# fields match concrete schema leaves exactly. Splitting them avoids letting +# bogus paths like ``orchestrator.batch_size_extra`` slip through a +# startswith check. +# +# Targets that resolve to a list (e.g. ``orchestrator.train.env``, +# ``orchestrator.eval.env``) cannot be allowlisted: the sweep materializer's +# ``set_dotted_path`` only walks dict tables, so a path like +# ``orchestrator.train.env.id`` would produce a dict-shaped override that +# RLConfig rejects with "Input should be a valid list". +# +# Fields coupled to the shared trainer (e.g. ``orchestrator.max_steps`` and +# ``orchestrator.max_async_level``) are intentionally not allowlisted. The +# shared trainer owns the actual loop length and weight-broadcast retention +# window for all runs in the wave. +MULTI_RUN_LORA_PARAMETER_PREFIXES: tuple[str, ...] 
= ( + "orchestrator.train.sampling.extra_body.", + "orchestrator.eval.sampling.extra_body.", +) +MULTI_RUN_LORA_PARAMETER_FIELDS: frozenset[str] = frozenset( + { + "orchestrator.optim.lr", + "orchestrator.model.lora.name", + "orchestrator.model.lora.rank", + "orchestrator.model.lora.alpha", + "orchestrator.batch_size", + "orchestrator.token_batch_size", + "orchestrator.oversampling_factor", + "orchestrator.max_inflight_rollouts", + "orchestrator.rollouts_per_example", + "orchestrator.max_off_policy_steps", + "orchestrator.strict_async_level", + "orchestrator.seed", + "orchestrator.tasks_per_minute", + "orchestrator.train.sampling", + "orchestrator.train.sampling.temperature", + "orchestrator.train.sampling.repetition_penalty", + "orchestrator.train.sampling.max_completion_tokens", + "orchestrator.train.sampling.max_tokens", + "orchestrator.train.sampling.min_tokens", + "orchestrator.train.sampling.seed", + "orchestrator.train.sampling.extra_body", + "orchestrator.train.num_workers", + "orchestrator.train.max_retries", + "orchestrator.eval.sampling", + "orchestrator.eval.sampling.temperature", + "orchestrator.eval.sampling.repetition_penalty", + "orchestrator.eval.sampling.top_p", + "orchestrator.eval.sampling.top_k", + "orchestrator.eval.sampling.min_p", + "orchestrator.eval.sampling.max_completion_tokens", + "orchestrator.eval.sampling.max_tokens", + "orchestrator.eval.sampling.min_tokens", + "orchestrator.eval.sampling.reasoning_effort", + "orchestrator.eval.sampling.seed", + "orchestrator.eval.sampling.extra_body", + "orchestrator.eval.num_examples", + "orchestrator.eval.rollouts_per_example", + "orchestrator.eval.num_workers", + "orchestrator.eval.max_retries", + "orchestrator.eval.interval", + "orchestrator.eval.eval_base_model", + "orchestrator.eval.skip_eval_on_resume", + "orchestrator.eval.cancel_inflight_rollouts_on_eval", + # BufferConfig: scalar leaves, plus hash_keys as a whole-list swap. 
def _multi_run_lora_exact_field_shape_errors(parameters: dict[str, SweepParameterConfig]) -> list[str]:
    """Collect shape violations for exact-match allowlisted multi_run_lora fields.

    Only parameters whose path is in ``MULTI_RUN_LORA_PARAMETER_FIELDS`` are
    inspected; paths under ``MULTI_RUN_LORA_PARAMETER_PREFIXES`` are free-form
    and skipped. For each inspected parameter:

    * a non-choice (distribution) parameter is an error when the field is a
      list or table field, and is otherwise accepted without further checks;
    * a choice parameter has every value checked at the parameter path itself
      and at every nested leaf path that is also an allowlisted exact field.
      Scalar fields reject dict/list/tuple values, list fields require a
      non-empty ``list[str]``, and table fields require a dict.

    Returns:
        Human-readable error strings; empty when every shape is acceptable.
    """
    errors: list[str] = []
    for path, parameter in parameters.items():
        if path.startswith(MULTI_RUN_LORA_PARAMETER_PREFIXES):
            continue
        if path not in MULTI_RUN_LORA_PARAMETER_FIELDS:
            continue
        if not isinstance(parameter, ChoiceParameterConfig):
            if path in MULTI_RUN_LORA_LIST_PARAMETER_FIELDS or path in MULTI_RUN_LORA_DICT_PARAMETER_FIELDS:
                errors.append(f"{path}: list/table fields must use explicit choice values")
            continue
        leaf_values: dict[str, list[Any]] = {}
        for value in parameter.values:
            # Always judge the parameter path against the raw choice value.
            # The leaf flattener keeps nested sequences on the parent path, so
            # ``[["a"]]`` flattens to ``(path, ["a"])`` and ``[{}]`` to
            # ``(path, {})``; checking those instead would let a wrongly
            # shaped value pass as a valid list[str] or table.
            leaf_values.setdefault(path, []).append(value)
            for leaf_path, leaf_value in _choice_value_leaf_items(path, value):
                if leaf_path == path:
                    continue
                leaf_values.setdefault(leaf_path, []).append(leaf_value)
        for leaf_path, values in leaf_values.items():
            if leaf_path.startswith(MULTI_RUN_LORA_PARAMETER_PREFIXES):
                continue
            if leaf_path not in MULTI_RUN_LORA_PARAMETER_FIELDS:
                continue
            if leaf_path in MULTI_RUN_LORA_SCALAR_PARAMETER_FIELDS:
                bad_values = [value for value in values if isinstance(value, (dict, list, tuple))]
                if bad_values:
                    errors.append(f"{leaf_path}: scalar field cannot use structured choice value(s) {bad_values!r}")
            elif leaf_path in MULTI_RUN_LORA_LIST_PARAMETER_FIELDS:
                bad_values = [
                    value
                    for value in values
                    if not isinstance(value, (list, tuple))
                    or not value
                    or any(not isinstance(item, str) for item in value)
                ]
                if bad_values:
                    errors.append(f"{leaf_path}: must use non-empty list[str] choice value(s), got {bad_values!r}")
            elif leaf_path in MULTI_RUN_LORA_DICT_PARAMETER_FIELDS:
                bad_values = [value for value in values if not isinstance(value, dict)]
                if bad_values:
                    errors.append(f"{leaf_path}: must use table/dict choice value(s), got {bad_values!r}")
    return errors
+ ), + ), + ] + + @model_validator(mode="before") + @classmethod + def reject_bool_numeric_fields(cls, data: Any) -> Any: + _reject_bool_fields(data, ("max_concurrent_runs",), "multi_run_lora scheduler") + return data + + +SweepSchedulerConfig: TypeAlias = Annotated[ + LocalSweepSchedulerConfig | SlurmSweepSchedulerConfig | MultiRunLoRASchedulerConfig, + Field(discriminator="type"), +] + + +class SweepWandbConfig(BaseConfig): + """W&B metadata injected into generated trials.""" + + enabled: bool = True + group: str | None = None + tags: list[str] = ["sweep"] + + +class ObjectiveConfig(BaseConfig): + """Names the metric the sweep optimizes and where to read it from.""" + + metric: Annotated[ + str, + Field(description="Metric key inside final_summary.json (forward-slash-separated)."), + ] + direction: Literal["maximize", "minimize"] + source: Literal["final_summary"] = "final_summary" + + @model_validator(mode="after") + def validate_metric(self): + if not self.metric.strip(): + raise ValueError("objective.metric must be non-empty") + return self + + +class ThresholdStoppingConfig(BaseConfig): + """Halt the study after a trial whose objective is on the wrong side of a threshold.""" + + type: Literal["threshold"] = "threshold" + threshold: float + min_trials: Annotated[ + int, + Field(ge=1, description="Minimum completed trials before threshold can fire."), + ] = 1 + + @model_validator(mode="before") + @classmethod + def reject_bool_numeric_fields(cls, data: Any) -> Any: + _reject_bool_fields(data, ("threshold", "min_trials"), "Early-stopping threshold") + return data + + @model_validator(mode="after") + def validate_threshold(self): + if not math.isfinite(self.threshold): + raise ValueError("Early-stopping threshold must be finite") + return self + + +class PatienceStoppingConfig(BaseConfig): + """Halt the study after N consecutive completed trials with no improvement.""" + + type: Literal["patience"] = "patience" + patience: Annotated[int, Field(ge=1, 
description="Consecutive non-improving trials required to halt.")] + min_trials: Annotated[ + int, + Field(ge=1, description="Minimum completed trials before patience can fire."), + ] = 1 + + @model_validator(mode="before") + @classmethod + def reject_bool_numeric_fields(cls, data: Any) -> Any: + _reject_bool_fields(data, ("patience", "min_trials"), "Patience early stopping") + return data + + +EarlyStoppingConfig: TypeAlias = Annotated[ + ThresholdStoppingConfig | PatienceStoppingConfig, + Field(discriminator="type"), +] + + +class SweepConfig(BaseConfig): + """Configures a hyperparameter sweep study.""" + + name: str | None = None + entrypoint: Literal["rl", "sft"] = "rl" + base: list[Path] + output_dir: Path + strategy: SearchStrategyConfig = GridStrategyConfig() + scheduler: SweepSchedulerConfig = LocalSweepSchedulerConfig() + parameters: dict[str, SweepParameterConfig] + wandb: SweepWandbConfig | None = SweepWandbConfig() + objective: ObjectiveConfig | None = None + early_stopping: EarlyStoppingConfig | None = None + continue_on_failure: Annotated[ + bool, + Field(description="Schedule remaining trials when one fails. 
Set false to halt-on-first-fail."), + ] = True + retry_budget: Annotated[ + int, + Field(ge=0, description="Retry a failed trial up to this many times before marking it failed."), + ] = 1 + resume: Annotated[ + bool, + Field(description="Reattach to an existing study output dir; preserve completed trial state."), + ] = False + dry_run: bool = False + clean_output_dir: bool = False + + @model_validator(mode="before") + @classmethod + def reject_bool_numeric_fields(cls, data: Any) -> Any: + _reject_bool_fields(data, ("retry_budget",), "Sweep") + return data + + @model_validator(mode="after") + def validate_sweep(self): + if not self.base: + raise ValueError("Sweep base must include at least one target config file") + if not self.parameters: + raise ValueError("Sweep parameters must include at least one parameter") + paths = tuple(self.parameters) + effective_paths = _effective_parameter_paths(self.parameters) + effective_path_set = set(effective_paths) + invalid_paths = [ + path for path in effective_paths if not path or any(not part for part in path.split(".")) + ] + if invalid_paths: + raise ValueError( + "Sweep parameter paths must be non-empty dot-separated config segments. " + f"Invalid path(s): {invalid_paths}" + ) + output_dir_paths = [path for path in effective_paths if "output_dir" in path.split(".")] + if output_dir_paths: + raise ValueError( + "Sweep parameters cannot set output_dir fields; " + "the sweep materializer owns each trial's output directories. 
" + f"Invalid path(s): {output_dir_paths}" + ) + if self.wandb is not None and self.wandb.enabled: + managed_wandb_paths = ("wandb.group", "wandb.name", "wandb.tags") + managed_component_wandb_paths = tuple( + f"{component}.wandb.{field}" + for component in ("trainer", "orchestrator") + for field in ("project", "entity", "name", "group", "tags", "offline") + ) + managed_component_wandb_tables = ("trainer.wandb", "orchestrator.wandb") + wandb_paths = [ + path + for path in effective_paths + if path == "wandb" + or any(path == managed or path.startswith(f"{managed}.") for managed in managed_wandb_paths) + or path in managed_component_wandb_tables + or any( + path == managed or path.startswith(f"{managed}.") + for managed in managed_component_wandb_paths + ) + ] + if wandb_paths: + raise ValueError( + "Sweep parameters cannot set sweep-managed W&B identity/shared fields while sweep wandb " + f"injection is enabled. Disable [wandb] injection or remove path(s): {wandb_paths}" + ) + path_conflicts = [ + (parent, child) + for parent in paths + for child in paths + if parent != child and child.startswith(f"{parent}.") + ] + if path_conflicts: + raise ValueError( + "Sweep parameters cannot include both a parent path and one of its sub-paths: " + f"{path_conflicts}. Split these into separate sweeps or choose one override shape." + ) + if self.resume and self.clean_output_dir: + warnings.warn( + "resume=true takes precedence over clean_output_dir=true; " + "ignoring clean_output_dir so existing trial state is preserved.", + stacklevel=2, + ) + self.clean_output_dir = False + if isinstance(self.strategy, GridStrategyConfig): + non_choice = [ + path for path, parameter in self.parameters.items() if not isinstance(parameter, ChoiceParameterConfig) + ] + if non_choice: + raise ValueError( + "Grid strategy only supports choice (values=...) 
parameters, " + f"but these declare distributions instead: {non_choice}" + ) + if self.resume and isinstance(self.strategy, RandomStrategyConfig) and self.strategy.seed is None: + raise ValueError( + "resume requires a deterministic trial set, but the random strategy has no seed. " + "Set strategy.seed so trial IDs match the previous study, or drop resume." + ) + if self.early_stopping is not None and self.objective is None: + raise ValueError( + "early_stopping requires an objective so the controller knows which metric to compare." + ) + if ( + self.early_stopping is not None + and isinstance(self.scheduler, SlurmSweepSchedulerConfig) + and not self.scheduler.synchronous + ): + raise ValueError( + "early_stopping is not supported with the SLURM scheduler unless " + "scheduler.synchronous=true: the asynchronous SLURM scheduler submits " + "jobs and exits, so it never observes trial completion to decide when to halt." + ) + if ( + isinstance(self.scheduler, SlurmSweepSchedulerConfig) + and self.scheduler.max_parallel > 1 + and not self.scheduler.synchronous + ): + raise ValueError( + "scheduler.max_parallel > 1 requires scheduler.synchronous=true: the " + "controller cannot manage concurrent in-flight jobs without observing each " + "one's terminal state, which is what the synchronous mode provides." + ) + if ( + self.early_stopping is not None + and isinstance(self.scheduler, MultiRunLoRASchedulerConfig) + and not isinstance(self.strategy, OptunaStrategyConfig) + ): + raise ValueError( + "early_stopping is not supported with static multi_run_lora sweeps: the controller " + "launches the whole grid/random wave at once, so it cannot stop future trials. " + "Use the Optuna strategy for wave-by-wave multi_run_lora early stopping." 
+ ) + if isinstance(self.strategy, OptunaStrategyConfig): + if self.objective is None: + raise ValueError("Optuna strategy requires an objective to optimize.") + if ( + isinstance(self.scheduler, SlurmSweepSchedulerConfig) + and not self.scheduler.synchronous + ): + raise ValueError( + "Optuna strategy is not supported with the asynchronous SLURM scheduler: " + "the controller must observe each trial's objective before proposing the " + "next one. Set scheduler.synchronous=true to submit each trial with " + "'sbatch --wait' so the controller blocks per trial." + ) + if isinstance(self.scheduler, LocalSweepSchedulerConfig) and self.scheduler.max_parallel > 1: + raise ValueError( + "Optuna strategy on the local scheduler runs sequentially (ask/tell needs " + "each trial's objective before proposing the next), so scheduler.max_parallel " + "must be 1. Use scheduler.type='slurm' with synchronous=true to drive " + "max_parallel > 1 over SLURM." + ) + if self.resume and self.strategy.storage is None: + raise ValueError( + "Resume with the Optuna strategy requires strategy.storage so the study " + "can be reloaded; in-memory studies vanish when the controller exits." + ) + storage_unsafe_choice_paths = [ + path + for path, parameter in self.parameters.items() + if isinstance(parameter, ChoiceParameterConfig) + and not all(_is_optuna_storage_safe_choice(value) for value in parameter.values) + ] + if storage_unsafe_choice_paths: + raise ValueError( + "Optuna categorical parameters only support storage-safe primitive choices " + f"(bool, int, finite float, or str). 
Invalid parameter path(s): {storage_unsafe_choice_paths}" + ) + ambiguous_choice_paths = [] + for path, parameter in self.parameters.items(): + if not isinstance(parameter, ChoiceParameterConfig): + continue + if any( + _optuna_choices_are_equal(left, right) + for idx, left in enumerate(parameter.values) + for right in parameter.values[idx + 1 :] + ): + ambiguous_choice_paths.append(path) + if ambiguous_choice_paths: + raise ValueError( + "Optuna categorical parameters cannot include duplicate or equality-colliding " + f"choices because Optuna storage cannot distinguish them. Invalid parameter path(s): " + f"{ambiguous_choice_paths}" + ) + if isinstance(self.scheduler, MultiRunLoRASchedulerConfig): + if self.entrypoint != "rl": + raise ValueError( + "multi_run_lora scheduler is RL-only; the shared-trainer architecture " + "depends on the trainer's MultiRunManager which only the rl entrypoint runs." + ) + if self.resume: + raise ValueError( + "Resume is not supported with the multi_run_lora scheduler in Phase 7b; " + "re-attaching to a still-running shared trainer needs reconciliation work " + "that lands in Phase 7c." + ) + offending = [ + path + for path in effective_paths + if path not in MULTI_RUN_LORA_PARAMETER_FIELDS + and not any(path.startswith(prefix) for prefix in MULTI_RUN_LORA_PARAMETER_PREFIXES) + ] + if offending: + allowed = ", ".join( + (*MULTI_RUN_LORA_PARAMETER_PREFIXES, *sorted(MULTI_RUN_LORA_PARAMETER_FIELDS)) + ) + raise ValueError( + "multi_run_lora sweeps may only vary per-run orchestrator fields. " + f"These parameter paths are not in the allowlist ({allowed}): {offending}. " + "Trainer/model/deployment/inference settings cannot vary inside one shared trainer." + ) + shape_errors = _multi_run_lora_exact_field_shape_errors(self.parameters) + if shape_errors: + raise ValueError( + "multi_run_lora exact-field sweep parameters must match the allowlisted field shape. 
" + f"Invalid value(s): {shape_errors}" + ) + if ( + "orchestrator.batch_size" in self.parameters + and "orchestrator.token_batch_size" in self.parameters + ): + raise ValueError( + "multi_run_lora sweeps must set either orchestrator.batch_size or " + "orchestrator.token_batch_size, not both." + ) + if ( + "orchestrator.token_batch_size" in self.parameters + and "orchestrator.oversampling_factor" in self.parameters + ): + raise ValueError( + "multi_run_lora sweeps cannot set orchestrator.oversampling_factor with " + "orchestrator.token_batch_size; oversampling only applies to rollout batching." + ) + for prefix in ("orchestrator.train.sampling", "orchestrator.eval.sampling"): + has_canonical = f"{prefix}.max_completion_tokens" in effective_path_set + has_alias = f"{prefix}.max_tokens" in effective_path_set + if has_canonical and has_alias: + raise ValueError( + f"multi_run_lora sweeps must set either {prefix}.max_completion_tokens " + f"or {prefix}.max_tokens, not both." + ) + return self diff --git a/pyproject.toml b/pyproject.toml index d9b5468fa0..4a6b6d8210 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,9 @@ dependencies = [ [project.scripts] rl = "prime_rl.entrypoints.rl:main" +rl-multi-run = "prime_rl.entrypoints.rl_multi_run:main" sft = "prime_rl.entrypoints.sft:main" +sweep = "prime_rl.entrypoints.sweep:main" inference = "prime_rl.entrypoints.inference:main" trainer = "prime_rl.trainer.rl.train:main" orchestrator = "prime_rl.orchestrator.orchestrator:main" @@ -58,6 +60,9 @@ flash-attn-3 = ["flash_attn_3"] flash-attn-cute = [ "flash-attn-4", ] +hpo = [ + "optuna>=4", +] envs = [ "reverse-text", "alphabet-sort", diff --git a/skills/config/SKILL.md b/skills/config/SKILL.md index e8dc13216c..4b83dae0c1 100644 --- a/skills/config/SKILL.md +++ b/skills/config/SKILL.md @@ -63,6 +63,12 @@ uv run rl @ examples/reverse_text/rl.toml --dry-run --output-dir /tmp/test # Writes resolved TOML to /tmp/test/configs ``` +Sweep configs also support `--dry-run`. 
A sweep dry-run validates the sweep and target trial configs, writes study/trial artifacts, and does not launch target runs: + +```bash +uv run sweep @ path/to/sweep.toml --dry-run +``` + ## Naming CLI uses kebab-case (`--model.max-model-len`), TOML uses snake_case (`max_model_len`). Both refer to the same field. @@ -132,6 +138,69 @@ On the CLI, pass as a JSON string: uv run inference --vllm-extra '{"key1": "value1", "key2": 123}' ``` +### Sweep parameter paths + +Sweep configs name a target entrypoint, one or more base config files, an output directory, a strategy, a scheduler, optional objective/stopping rules, and dotted target config paths under `[parameters]`. The sweep launcher converts parameters to generated override TOML files and then validates the target `rl` or `sft` config normally: + +```toml +name = "reverse-text-lr" +entrypoint = "rl" +base = ["examples/reverse_text/rl.toml"] +output_dir = "outputs/studies/reverse-text-lr" + +[strategy] +type = "grid" + +[scheduler] +type = "local" +max_parallel = 1 + +[objective] +metric = "reward/reverse-text/mean" +direction = "maximize" +``` + +Use `strategy.type = "grid"` for exhaustive choice combinations, `random` for seeded independent samples, and `optuna` for adaptive ask/tell studies. Optuna requires the `hpo` extra (`uv sync --extra hpo`), an `[objective]`, and local, synchronous SLURM, or `multi_run_lora` scheduling. Persistent Optuna resume requires `strategy.storage`, usually a SQLite URL such as `sqlite:///outputs/studies/name/optuna.db`. + +Schedulers are `local`, `slurm`, and `multi_run_lora`. Local parallel sweeps require explicit disjoint GPU groups: + +```toml +[scheduler] +type = "local" +max_parallel = 2 + +[scheduler.gpu_assignment] +mode = "static" +visible_devices = [[0, 1], [2, 3]] +``` + +SLURM sweeps submit through each target config's existing `[slurm]` support. 
The default asynchronous mode exits after submission and cannot use early stopping or Optuna; set `scheduler.synchronous = true` to submit each trial with blocking SLURM behavior so early stopping, Optuna, and Optuna pruners can observe per-trial outcomes. Combine `synchronous = true` with `max_parallel > 1` to drive up to N concurrent in-flight SLURM jobs from a single Optuna study — TPE automatically opts into `constant_liar` so concurrent asks diversify; pruners are not supported under parallel SLURM. `multi_run_lora` is RL-only and uses one shared trainer plus one orchestrator per trial; its `shared` config list describes the shared RL stack, while `base` is still the target config list used for study materialization. + +```toml +[parameters."trainer.optim.lr"] +values = [1e-5, 3e-5] + +[parameters."orchestrator.train.sampling.temperature"] +values = [0.7, 1.0] +``` + +Do not include both a parent path and one of its sub-paths in the same sweep, such as `optim` and `optim.lr`; the generated override TOML can only use one shape for a path. +Every path segment must be non-empty, including keys inside structured parent-table choices: avoid leading dots, trailing dots, empty nested keys, and doubled dots such as `optim..lr`. +Do not sweep `output_dir` or nested `*.output_dir` fields, including inside structured choice values such as `trainer = { ckpt = { output_dir = ... } }`; the sweep materializer owns per-trial output directories so metrics, status, and resume artifacts stay isolated. +When sweep W&B injection is enabled (the default), do not sweep `wandb`, `wandb.group`, `wandb.name`, `wandb.tags`, or nested shared W&B fields such as `trainer.wandb.name` and `orchestrator.wandb.project`, including when those fields are hidden inside a parent structured choice value; the sweep materializer owns per-trial identity and the RL shared W&B auto-setup propagates those fields into trainer/orchestrator configs. 
Set sweep `wandb = None` first if you need to manage target W&B fields yourself. Top-level non-identity fields such as `wandb.project`, and non-shared nested extras such as `orchestrator.wandb.log_extras.interval`, remain sweepable. +Sweep values that affect search space, scheduling, retry counts, or stopping logic must be real finite numbers, not booleans. Choice values must be TOML-serializable and are checked recursively for `nan`/`inf`, `None`, unsupported Python objects, and non-string dict keys; use the string `"None"` for nullable target fields. Boolean choice values are only valid when the resolved target field remains boolean; materialization rejects cases where Pydantic would coerce `true`/`false` into a numeric or string target, including booleans nested inside structured parent-table choices. Uniform/log-uniform bounds must be finite, and uniform ranges must also be finite after subtraction so random sampling cannot produce `inf`; int-uniform bounds/step reject booleans. Scheduler concurrency, GPU indices, retry budget, Optuna numeric fields, pruner numeric fields, and early-stopping numeric fields reject booleans. Optuna `poll_interval_seconds` and threshold early stopping also reject `nan`/`inf`; Optuna `seed` must be in NumPy's accepted range (`0 <= seed <= 2**32 - 1`); Optuna `storage`, when set, must be a non-blank SQLAlchemy URL. ASHA `min_resource` must be `auto` or at least 1, and Hyperband integer `max_resource` must be at least `min_resource`. Optuna choice values are stricter than grid/random choices because Optuna storage must round-trip them: use only `bool`, `int`, finite `float`, or `str`, and do not mix duplicate or equality-colliding values such as `True` with `1` or `1` with `1.0`. +Sweep objectives must name a non-blank `objective.metric`; blank names are rejected at config validation instead of producing every clean trial as a missing-objective failure. 
+ +For `multi_run_lora`, parameters are limited to per-run orchestrator fields that the materializer can represent as TOML tables or exact scalar/list replacements. Do not use sub-paths under list fields such as `orchestrator.train.env.*`, `orchestrator.eval.env.*`, or `orchestrator.buffer.hash_keys.*`, including when those sub-paths are hidden inside structured choice values; swap the whole list only when the allowlist exposes that exact field. Exact allowlisted fields must also use values with the right shape: scalar fields cannot use structured values, `orchestrator.buffer.hash_keys` must use non-empty `list[str]` values, whole `*.sampling` or `*.sampling.extra_body` replacements must use dict/table values, and structured whole-`*.sampling` choices must still keep nested exact fields at their scalar/list/table shape. Do not set both `*.sampling.max_completion_tokens` and the deprecated `*.sampling.max_tokens` alias, including inside a structured whole-`*.sampling` choice. Fields coupled to the shared trainer, such as `orchestrator.max_steps` and `orchestrator.max_async_level`, are not per-run sweep fields. The shared `RLConfig` must enable `trainer.model.lora`, and `scheduler.max_concurrent_runs` must be no greater than `trainer.max_concurrent_runs`. Per-run LoRA rank/alpha overrides are validated against the shared trainer LoRA config after the shared `RLConfig` resolves. `orchestrator.batch_size` and `orchestrator.token_batch_size` are mutually exclusive sweep parameters; varying rollout `batch_size` or `oversampling_factor` clears inherited `token_batch_size`, and varying `token_batch_size` clears inherited rollout-only `batch_size` and `oversampling_factor`. `orchestrator.oversampling_factor` cannot be combined with `orchestrator.token_batch_size` because oversampling only applies to rollout batching. 
When `orchestrator.batch_size`, `orchestrator.token_batch_size`, or `orchestrator.oversampling_factor` varies and the shared TOML did not explicitly set `orchestrator.max_inflight_rollouts`, the materializer drops the auto-resolved shared value; rollout batching recomputes it, while token batching raises the normal requirement to set max-inflight explicitly. Group-level train/eval defaults such as `*.sampling`, `*.sampling.*`, `*.num_workers`, `*.max_retries`, eval `num_examples`, eval `rollouts_per_example`, and eval `interval` must also be allowed to re-propagate into env entries unless that env explicitly set the corresponding field in the shared TOML; deprecated shared aliases like `[[orchestrator.env]]`, `[orchestrator.sampling]`, and `max_tokens` still count as explicit shared settings while they are supported by the config loader. + +When resuming Optuna sweeps with persistent storage, the controller reconciles leftover RUNNING Optuna trials from the previous process. Completed/pruned status files are replayed to Optuna; stale running or missing-objective status files are marked failed so `status.json`, the manifest, and Optuna storage agree. Replayed pruned statuses are normalized to `objective = None`. Terminal Optuna storage trials must also agree with their terminal `status.json` state: completed trials need the same finite objective value, failed/pruned trials must not carry a finite objective, and each `status.json` id must match its manifest variant id. Newly pruned trials should still carry terminal bookkeeping (`finished_at` and the subprocess/per-run return code) while keeping `objective = None`. Grid/random, static `multi_run_lora`, and Optuna trials that fail target-config materialization are written as failed manifest/status artifacts, excluded from scheduler launches, and must clear stale generated resolved configs (`resolved.toml`, plus `control/orch.toml` for `multi_run_lora`) from any reused trial directory. 
With `continue_on_failure=false`, grid/random and static `multi_run_lora` materialization stops after the first failed trial and does not launch scheduler work after that preflight failure. In Optuna `multi_run_lora`, if `continue_on_failure=false` aborts a wave after one trial fails materialization, already-materialized siblings that were asked but never launched are marked failed with `failure_stage = "scheduler"` and an explanatory error before being told failed to Optuna. A clean process exit without a finite objective is recorded as `state = "failed"` with `failure_stage = "objective"` in local grid/random, static `multi_run_lora`, single-trial Optuna, and `multi_run_lora` Optuna flows. Resume requires the previous manifest to be valid JSON, be a JSON object whose `variants` are JSON object entries, the strategy (except increasing `strategy.num_trials`), search parameters, parameter order, and objective to match the previous manifest, exactly one manifest variant entry with an existing `status.json` and `resolved_checksum` for every existing Optuna storage trial, no manifest variants missing from storage, manifest `overrides` and id hashes that match Optuna storage parameters, and matching base-config checksums; previous `TrialState.FAIL` entries are carried into the resumed failure count. +When resuming grid/random sweeps, terminal `status.json` files are only skipped if the previous manifest is valid JSON, is a JSON object with the same entrypoint and scheduler type, compatible strategy (except increasing `strategy.num_trials`), same search parameters/objective and parameter order, unique well-formed variant IDs, no manifest variants outside the regenerated trial set, matching `status.json` ids, valid object-shaped `status.json` files, and still carries resolved/base checksums for the same base file list. 
Parameter order matters because grid/random generation consumes parameters in order, and the manifest stores `parameter_order` explicitly because JSON key sorting cannot preserve it. If the manifest is missing or incomplete, status files are malformed, or a status file is not a JSON object, resume fails closed instead of silently trusting stale completed/submitted statuses. A rejected resume drift must not overwrite the previous trial's `overrides.toml`, `resolved.toml`, or `command.txt`. Legacy completed trials with missing objectives are still counted as objective failures on resume; if an Optuna resume finds a leftover RUNNING storage trial whose status already recorded clean completion without a finite objective, keep `returncode = 0` and mark `failure_stage = "objective"` rather than turning it into a launcher failure. +For local grid/random and Optuna sweeps, `continue_on_failure = false` stops launching new trials after the first runtime failure or missing objective, but the controller should still write per-trial `status.json` updates and the manifest summary for any completed objective-bearing trials before exiting non-zero. Static `multi_run_lora` launches the whole wave at once, so it cannot stop already-running siblings and does not support early stopping; Optuna `multi_run_lora` can early-stop between waves. Static `multi_run_lora` should still reconcile all per-run exit codes and write the objective summary before exiting non-zero. +Before launching a fresh attempt, the sweep runtime clears attempt-scoped objective artifacts (`metrics.jsonl` plus legacy `run-*/final_summary.json`). `multi_run_lora` also clears stale `control/exit_code` and `control/evicted.txt` before invoking `rl-multi-run`, so old pruning or exit signals cannot affect a fresh wave. 
+Because the shared trainer scans every `shared/run_*` directory, `multi_run_lora` launches also mark inactive run directories evicted before starting `rl-multi-run`; only the current static wave or Optuna wave should be discoverable by the trainer. +Optuna pruning reports only finite intermediate metrics with non-negative integer `step` values; malformed, boolean, missing, or negative steps are ignored instead of being sent to `optuna_trial.report`. Final objectives read from `metrics.jsonl` use the same step rule and fall back to legacy `final_summary.json` only when no valid-step sidecar row supplies the metric. `FileMonitor` treats the explicit `monitor.log(..., step=...)` argument as the canonical sweep step even if the metrics payload also contains a `step` key, rejects non-integer, boolean, and negative explicit steps before mutating history or writing the sidecar, and replaces non-finite floats with `null` inside dict/list/tuple metric payloads so `metrics.jsonl` stays standard JSON. When `FileMonitor` is combined with W&B or Prime in `MultiMonitor`, its errors must propagate because the sweep sidecar is required for objective attribution; optional remote-monitor errors can remain warnings. A pruning decision should only be applied while the trial is still running; re-check subprocess liveness after reporting an intermediate metric and immediately before `should_prune()`. In `multi_run_lora`, `rl-multi-run` writes each `control/exit_code` as soon as that orchestrator exits, and the Optuna poller must re-read that file immediately before pruning so a completed run is not reclassified as pruned while sibling runs keep the wave alive. If the launcher has already exited, final objective reconciliation wins. 
+If the target launcher itself cannot be spawned (`OSError`, such as a missing command), local, pruner-enabled Optuna, SLURM, and `multi_run_lora` static/Optuna schedulers retry according to `retry_budget`; only after the final failed launch attempt should the runtime mark the affected trial(s) `failed` with `returncode = -1`, `failure_stage = "launch"`, and the error message in `status.json` before continuing or exiting according to `continue_on_failure`. + ### Discriminated unions Some config fields use discriminated unions (e.g. loss type, data type). Set the `type` field to select the variant: @@ -174,7 +243,7 @@ In TOML, an empty section header does the same: ## Key files -- `src/prime_rl/utils/config.py` — re-exports `BaseConfig` and `cli` from pydantic_config -- `src/prime_rl/configs/` — all domain-specific config classes +- `packages/prime-rl-configs/src/prime_rl/utils/config.py` — re-exports `BaseConfig` and `cli` from pydantic_config +- `packages/prime-rl-configs/src/prime_rl/configs/` — all domain-specific config classes - `configs/debug/` — minimal debug configs for testing - `examples/` — full example configs for various tasks diff --git a/skills/entrypoints/SKILL.md b/skills/entrypoints/SKILL.md index f6589aca97..4e75908ffb 100644 --- a/skills/entrypoints/SKILL.md +++ b/skills/entrypoints/SKILL.md @@ -17,7 +17,7 @@ uv run rl @ examples/reverse_text/rl.toml @ examples/reverse_text/slurm_rl.toml uv run rl @ examples/reverse_text/rl.toml --dry-run # generate scripts without running ``` -- **Config:** `RLConfig` (`src/prime_rl/configs/rl.py`) +- **Config:** `RLConfig` (`packages/prime-rl-configs/src/prime_rl/configs/rl.py`) - **Entrypoint:** `src/prime_rl/entrypoints/rl.py` - **SLURM:** yes — single-node and multi-node @@ -33,7 +33,7 @@ uv run sft @ examples/reverse_text/sft.toml --dry-run # generate scripts without The entrypoint launches torchrun internally — no need to call torchrun directly. 
-- **Config:** `SFTConfig` (`src/prime_rl/configs/sft.py`) +- **Config:** `SFTConfig` (`packages/prime-rl-configs/src/prime_rl/configs/sft.py`) - **Entrypoint:** `src/prime_rl/entrypoints/sft.py` - **SLURM:** yes — single-node and multi-node @@ -71,10 +71,34 @@ curl http://localhost:8000/v1/chat/completions \ -d '{"model": "Qwen/Qwen3-0.6B", "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 50}' ``` -- **Config:** `InferenceConfig` (`src/prime_rl/configs/inference.py`) +- **Config:** `InferenceConfig` (`packages/prime-rl-configs/src/prime_rl/configs/inference.py`) - **Entrypoint:** `src/prime_rl/entrypoints/inference.py` - **SLURM:** yes — single-node, multi-node, and disaggregated deployments +## `sweep` — Hyperparameter sweeps + +Materializes and launches hyperparameter sweep trials for `rl` or `sft` target configs. + +```bash +uv run sweep @ examples/sweep/grid_local.toml +uv run sweep @ examples/sweep/grid_local.toml --dry-run +uv run sweep @ examples/sweep/optuna_local.toml # requires uv sync --extra hpo +``` + +Sweeps support grid/random/Optuna strategies over dotted target config paths. Standard trials write `overrides.toml`, `resolved.toml`, `command.txt`, and `status.json` under the study output directory, then launch the target entrypoint with the configured base files plus the generated `overrides.toml` unless `--dry-run` is set. + +Local `max_parallel > 1` requires explicit disjoint GPU groups under `[scheduler.gpu_assignment]`. SLURM sweeps default to asynchronous submission and exit after queueing jobs; set `scheduler.synchronous = true` when early stopping, Optuna, or Optuna pruners need per-trial completion feedback. Set `scheduler.max_parallel > 1` (synchronous mode only) to drive multiple concurrent SLURM jobs from a single Optuna study — TPE auto-enables `constant_liar` so concurrent asks don't collide on the same region; pruners are rejected under parallel SLURM since Optuna trial objects are not thread-safe. 
+ +`multi_run_lora` sweeps materialize each trial under `shared/run_<trial_id>/`, write the per-run orchestrator config to `control/orch.toml`, and keep `orchestrator.output_dir` pinned to that run directory. The materializer validates the shared `RLConfig` first, requires `trainer.model.lora`, checks `scheduler.max_concurrent_runs <= trainer.max_concurrent_runs`, then layers allowlisted per-run orchestrator overrides and validates the final `OrchestratorConfig` so multi-run LoRA ranks can be below the trainer's max rank. Their replay command uses the shared trainer configs, the generated `shared/_output_override.toml`, and `--runs-dir <run_dirs>` so it goes through the same parser as `rl-multi-run`. +Before each `multi_run_lora` launch, the sweep controller writes `control/evicted.txt` in inactive `shared/run_*` directories so the trainer's directory scan ignores stale runs from previous waves or older studies. The Optuna wave driver also wipes trainer-owned subdirs (`checkpoints/`, `weights/`, `broadcasts/`, `rollouts/`, `run_default/`) inside `shared_dir` between waves via `_reset_trainer_state_for_wave`. If the shared `RLConfig` sets `ckpt.output_dir`, that directory's `checkpoints/` is reset too, resolved once at the top of the wave loop. Without this, a fresh wave-N+1 trainer would silently resume from the previous wave's checkpoints or collide on `step_*` writes; per-trial `run_<trial_id>` directories are intentionally preserved so the sweep controller's eviction logic stays in charge of them. +In shared W&B mode, the `rl-multi-run` launcher process itself is the shared-run primary: it calls `init_wandb_shared_primary` before spawning subprocesses (label `launcher`, `x_primary=True`, `x_update_finish_state=True`) and `finish()`es the run in a `try/finally`. Trainer and per-trial orchestrators stay non-primary (always `WANDB_SHARED_PRIMARY=0`, with explicit per-trial labels `orchestrator-<trial_id>`); they attach to the run the launcher created.
This binds run finish state to the supervisor's lifetime, so a trainer exit at `max_steps` or a pruned orchestrator can no longer mark the shared run finished while sibling orchestrators are still emitting final-eval logs. Single-run still relies on the `WANDB_SHARED_LABEL=="orchestrator"` fallback in `WandbMonitor`. +`rl-multi-run --runs-dir` takes colon-separated run directories. Empty entries (`a::b`, leading/trailing `:`) and duplicate resolved directories are rejected before config parsing because each run needs a distinct `control/orch.toml` and `control/exit_code`. The launcher writes each run's `control/exit_code` as soon as that orchestrator process exits, then rewrites all exit codes during final cleanup; Optuna wave pruning relies on those early files to avoid pruning an already-completed run while sibling orchestrators are still active. A non-zero orchestrator exit must also leave `control/evicted.txt` in place, without overwriting an existing pruning reason, so the shared trainer stops discovering a crashed run and the rest of the wave can drain. Once every orchestrator has stopped and at least one exited non-zero, the launcher should tear down the trainer instead of waiting forever for batches that can no longer arrive. +If the sweep controller cannot spawn `rl-multi-run` at all (`OSError`), it retries the static wave or Optuna wave according to `retry_budget`; only the final failed launch attempt marks the affected run statuses with `failure_stage = "launch"`. Runtime non-zero exits are still reconciled from per-run `control/exit_code` instead of retrying the wave. 
+ +- **Config:** `SweepConfig` (`packages/prime-rl-configs/src/prime_rl/configs/sweep.py`) +- **Entrypoint:** `src/prime_rl/entrypoints/sweep.py` +- **SLURM:** yes, through the target `rl`/`sft` config's existing `[slurm]` support + ## Summary | Command | Purpose | SLURM | Typical use | @@ -82,10 +106,12 @@ curl http://localhost:8000/v1/chat/completions \ | `rl` | Full RL pipeline | yes | Production RL training | | `sft` | Supervised fine-tuning | yes | SFT training | | `inference` | vLLM server | yes | Standalone inference or debugging | +| `sweep` | Hyperparameter sweeps | yes | Launching multiple RL/SFT variants | +| `rl-multi-run` | Shared-trainer LoRA wave launcher | no | Internal target for `multi_run_lora` sweeps | ## Key directories -- `src/prime_rl/entrypoints/` — top-level entrypoints (`rl`, `sft`, `inference`) -- `src/prime_rl/configs/` — all config classes +- `src/prime_rl/entrypoints/` — top-level entrypoints (`rl`, `sft`, `inference`, `sweep`) +- `packages/prime-rl-configs/src/prime_rl/configs/` — config classes - `configs/debug/` — minimal configs for quick testing - `examples/` — full example configs for various tasks diff --git a/src/prime_rl/entrypoints/launch.py b/src/prime_rl/entrypoints/launch.py new file mode 100644 index 0000000000..845401dee1 --- /dev/null +++ b/src/prime_rl/entrypoints/launch.py @@ -0,0 +1,320 @@ +"""Shared subprocess-launch primitives for RL entrypoints. + +``rl_local`` and the multi-run shared-trainer entrypoint both spin up an +inference server, optional teacher inference server, a trainer torchrun +process, and one or more orchestrators. The startup, monitor-thread, and +supervision loop are identical across them; this module factors that +boilerplate out so the two entrypoints differ only in the orchestrator +plumbing. + +All helpers operate on a ``LaunchSupervisor`` that bundles the shared +process / monitor / stop-event lists. 
Callers create one supervisor and +pass it to each ``start_*`` helper, then call ``wait_for_completion`` with +the labels whose termination signals end-of-run. +""" + +from __future__ import annotations + +import json +import os +import re +import sys +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from subprocess import PIPE, Popen +from threading import Event, Thread +from typing import TYPE_CHECKING, Any + +from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process +from prime_rl.utils.utils import get_free_port + +if TYPE_CHECKING: # pragma: no cover + from prime_rl.configs.rl import RLConfig + + +@dataclass +class GpuMapping: + """Resolved local→physical GPU assignments for an RL deployment.""" + + infer: list[int] + trainer: list[int] + teacher: list[int] + physical: dict[int, int] + + +@dataclass +class LaunchSupervisor: + """Shared bookkeeping for spawned subprocesses + their monitor threads.""" + + logger: Any + log_dir: Path + processes: list[Popen] = field(default_factory=list) + monitor_threads: list[Thread] = field(default_factory=list) + stop_events: dict[str, Event] = field(default_factory=dict) + error_queue: list[Exception] = field(default_factory=list) + + +def compute_gpu_mapping(config: RLConfig, get_physical_gpu_ids: Any) -> GpuMapping: + """Resolve the launcher-local GPU layout into physical GPU IDs. + + ``get_physical_gpu_ids`` is injected so this helper does not depend on + pynvml directly and stays trivially testable. 
+ """ + gpu_offset = 0 + num_infer_gpus = config.deployment.num_infer_gpus if config.inference is not None else 0 + infer_local_gpu_ids = list(range(gpu_offset, gpu_offset + num_infer_gpus)) + gpu_offset += num_infer_gpus + trainer_local_gpu_ids = list(range(gpu_offset, gpu_offset + config.deployment.num_train_gpus)) + gpu_offset += config.deployment.num_train_gpus + num_teacher_gpus = config.deployment.num_teacher_gpus or 0 + teacher_local_gpu_ids = ( + list(range(gpu_offset, gpu_offset + num_teacher_gpus)) if num_teacher_gpus > 0 else [] + ) + + total_requested_gpus = num_infer_gpus + config.deployment.num_train_gpus + num_teacher_gpus + physical_gpu_ids = get_physical_gpu_ids() + if total_requested_gpus > len(physical_gpu_ids): + raise ValueError( + f"Requested {total_requested_gpus} GPUs via deployment settings, but only " + f"{len(physical_gpu_ids)} physical GPU(s) are available: {physical_gpu_ids}" + ) + physical = {local_id: physical_gpu_ids[local_id] for local_id in range(total_requested_gpus)} + return GpuMapping( + infer=[physical[i] for i in infer_local_gpu_ids], + trainer=[physical[i] for i in trainer_local_gpu_ids], + teacher=[physical[i] for i in teacher_local_gpu_ids], + physical=physical, + ) + + +def build_wandb_shared_env(config: RLConfig) -> dict[str, str]: + """Compose the WANDB_SHARED_* env that subprocesses inherit when shared mode is on.""" + env: dict[str, str] = {} + if config.wandb and config.wandb.shared: + env["WANDB_SHARED_MODE"] = "1" + env["WANDB_SHARED_RUN_ID"] = os.environ.get("WANDB_SHARED_RUN_ID", uuid.uuid4().hex) + return env + + +def init_wandb_shared_primary( + config: RLConfig, + wandb_shared_env: dict[str, str], + logger: Any | None = None, +) -> Any | None: + """Open the shared W&B run from the launcher process as primary. 
+ + The launcher outlives every subprocess (trainer + orchestrators), so + binding ``x_update_finish_state`` to its lifetime is the only way to + guarantee the shared run finishes after every late metric has flushed. + A primary trainer would mark the run finished at ``max_steps`` while + orchestrators are still emitting final-eval logs; a primary orchestrator + would do the same if pruned. Returns the wandb ``Run`` so the caller + can ``finish()`` it; returns ``None`` when shared mode is off. + """ + if wandb_shared_env.get("WANDB_SHARED_MODE") != "1": + return None + if config.wandb is None: + return None + + import wandb + from wandb.errors import CommError + + run_id = wandb_shared_env["WANDB_SHARED_RUN_ID"] + settings = wandb.Settings( + mode="shared", + x_label="launcher", + x_primary=True, + x_update_finish_state=True, + ) + for attempt in range(5): + try: + return wandb.init( + id=run_id, + project=config.wandb.project, + entity=config.wandb.entity, + name=config.wandb.name, + group=config.wandb.group, + tags=config.wandb.tags, + settings=settings, + ) + except CommError as e: + if attempt == 4: + raise + if logger is not None: + logger.info(f"Transient W&B init error ({e}) - retrying in 10s ({attempt + 1}/5)") + time.sleep(10) + + raise RuntimeError("unreachable") + + +def _start_supervised( + label: str, + cmd: list[str], + env: dict[str, str], + log_path: Path, + supervisor: LaunchSupervisor, +) -> Popen: + """Start ``cmd`` with stdout/stderr to ``log_path`` and a monitor thread. + + Records the process, monitor thread, and stop event under ``label`` on + the supervisor so ``wait_for_completion`` can decide when to return. 
+ """ + log_path.parent.mkdir(parents=True, exist_ok=True) + log_file = open(log_path, "w") + try: + process = Popen(cmd, env=env, stdout=log_file, stderr=log_file) + finally: + log_file.close() + supervisor.processes.append(process) + stop_event = Event() + supervisor.stop_events[label] = stop_event + monitor_thread = Thread( + target=monitor_process, + args=(process, stop_event, supervisor.error_queue, label), + daemon=True, + ) + monitor_thread.start() + supervisor.monitor_threads.append(monitor_thread) + return process + + +def start_inference( + *, + cmd: list[str], + gpu_ids: list[int], + label: str, + log_path: Path, + supervisor: LaunchSupervisor, +) -> Popen: + """Spawn an inference (or teacher inference) server pinned to ``gpu_ids``.""" + env = {**os.environ, "CUDA_VISIBLE_DEVICES": ",".join(map(str, gpu_ids))} + supervisor.logger.info(f"Starting {label} on GPU(s) {' '.join(map(str, gpu_ids))}") + supervisor.logger.debug(f"{label} start command: {' '.join(cmd)}") + return _start_supervised(label, cmd, env, log_path, supervisor) + + +def start_orchestrator( + *, + config_path: Path, + label: str, + log_path: Path, + start_command: list[str], + wandb_shared_env: dict[str, str], + wandb_program: str, + supervisor: LaunchSupervisor, + extra_env: dict[str, str] | None = None, +) -> Popen: + """Spawn an orchestrator subprocess pointed at ``config_path``. + + Single-run mode passes ``label="orchestrator"``. Multi-run mode passes a + per-run label like ``"orchestrator-0000-abc"`` so the supervisor's + ``stop_events`` and the W&B label can be told apart. ``extra_env`` is + layered on top of the default env (after WANDB_*) so callers can scope + per-orchestrator env vars like ``PRIME_RL_SWEEP_METRICS_JSONL`` without + leaking them into sibling orchestrators. 
+ """ + cmd = ["orchestrator", "@", config_path.as_posix()] + env = { + **os.environ, + **wandb_shared_env, + "WANDB_SHARED_LABEL": label, + "WANDB_SHARED_PRIMARY": "0", + "LOGURU_FORCE_COLORS": "1", + "WANDB_PROGRAM": wandb_program, + "WANDB_ARGS": json.dumps(start_command), + } + if extra_env: + env.update(extra_env) + supervisor.logger.info(f"Starting {label} process") + supervisor.logger.debug(f"{label} start command: {' '.join(cmd)}") + return _start_supervised(label, cmd, env, log_path, supervisor) + + +def start_trainer( + *, + config_path: Path, + gpu_ids: list[int], + ranks_filter: list[int], + log_path: Path, + torchrun_log_dir: Path, + start_command: list[str], + wandb_shared_env: dict[str, str], + wandb_program: str, + supervisor: LaunchSupervisor, +) -> Popen: + """Spawn the torchrun-driven trainer subprocess.""" + cmd = [ + "torchrun", + "--role=trainer", + f"--rdzv-endpoint=localhost:{get_free_port()}", + f"--rdzv-id={uuid.uuid4().hex}", + f"--log-dir={torchrun_log_dir}", + f"--local-ranks-filter={','.join(map(str, ranks_filter))}", + "--redirect=3", + "--tee=3", + f"--nproc-per-node={len(gpu_ids)}", + "-m", + "prime_rl.trainer.rl.train", + "@", + config_path.as_posix(), + ] + env = { + **os.environ, + **wandb_shared_env, + "WANDB_SHARED_LABEL": "trainer", + "WANDB_SHARED_PRIMARY": "0", + "CUDA_VISIBLE_DEVICES": ",".join(map(str, gpu_ids)), + "PYTHONUNBUFFERED": "1", + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + "LOGURU_FORCE_COLORS": "1", + "WANDB_PROGRAM": wandb_program, + "WANDB_ARGS": json.dumps(start_command), + } + supervisor.logger.info(f"Starting trainer on GPU(s) {' '.join(map(str, gpu_ids))}") + supervisor.logger.debug(f"Training start command: {' '.join(cmd)}") + return _start_supervised("trainer", cmd, env, log_path, supervisor) + + +def tail_trainer_log(supervisor: LaunchSupervisor, trainer_log: Path) -> Popen: + """Mirror the trainer log to stdout so the user sees live training output.""" + tail = Popen( + ["tail", "-F", 
trainer_log.as_posix()], + stdout=PIPE, + stderr=sys.stderr, + text=True, + bufsize=1, + ) + + def print_trainer_lines() -> None: + assert tail.stdout is not None + rank_prefix = re.compile(r"^\[[a-zA-Z]*[0-9]*\]:") + for line in tail.stdout: + print(rank_prefix.sub("", line), end="", flush=True) + + Thread(target=print_trainer_lines, daemon=True).start() + supervisor.processes.append(tail) + return tail + + +def wait_for_completion( + primary_labels: list[str], + supervisor: LaunchSupervisor, +) -> None: + """Block until every primary label's stop event fires. + + A failure surfaced through ``supervisor.error_queue`` (typically a + crashed monitor thread) tears down all processes and exits 1. Successful + completion returns; the caller is responsible for inspecting individual + return codes and final cleanup. + """ + while not all(supervisor.stop_events[label].is_set() for label in primary_labels): + if supervisor.error_queue: + error = supervisor.error_queue[0] + supervisor.logger.error(f"Error: {error}") + supervisor.logger.error("Terminating all processes...") + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + sys.exit(1) + time.sleep(1) diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index b740f58ba2..0fc872c128 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -1,19 +1,24 @@ -import json import os import signal import subprocess import sys -import time -import uuid from pathlib import Path -from subprocess import Popen -from threading import Event, Thread import pynvml import tomli_w import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before transitive import from prime_rl.configs.rl import RLConfig +from prime_rl.entrypoints.launch import ( + LaunchSupervisor, + build_wandb_shared_env, + compute_gpu_mapping, + start_inference, + start_orchestrator, + start_trainer, + tail_trainer_log, + wait_for_completion, +) from prime_rl.trainer.model 
import pre_download_model from prime_rl.utils.config import cli from prime_rl.utils.logger import get_logger, setup_logger @@ -24,11 +29,8 @@ resolve_latest_ckpt_step, validate_output_dir, ) -from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process, set_proc_title -from prime_rl.utils.utils import ( - get_free_port, - get_log_dir, -) +from prime_rl.utils.process import cleanup_processes, cleanup_threads, set_proc_title +from prime_rl.utils.utils import get_log_dir RL_TOML = "rl.toml" RL_SBATCH = "rl.sbatch" @@ -114,45 +116,18 @@ def rl_local(config: RLConfig): logger.success("Dry run complete. To start an RL run locally, remove --dry-run from your command.") return - # Derive launcher-local GPU IDs from deployment config - gpu_offset = 0 - num_infer_gpus = config.deployment.num_infer_gpus if config.inference is not None else 0 - infer_local_gpu_ids = list(range(gpu_offset, gpu_offset + num_infer_gpus)) - gpu_offset += num_infer_gpus - trainer_local_gpu_ids = list(range(gpu_offset, gpu_offset + config.deployment.num_train_gpus)) - gpu_offset += config.deployment.num_train_gpus - num_teacher_gpus = config.deployment.num_teacher_gpus or 0 - teacher_local_gpu_ids = list(range(gpu_offset, gpu_offset + num_teacher_gpus)) if num_teacher_gpus > 0 else [] - - total_requested_gpus = num_infer_gpus + config.deployment.num_train_gpus + num_teacher_gpus - physical_gpu_ids = get_physical_gpu_ids() - if total_requested_gpus > len(physical_gpu_ids): - raise ValueError( - f"Requested {total_requested_gpus} GPUs via deployment settings, but only " - f"{len(physical_gpu_ids)} physical GPU(s) are available: {physical_gpu_ids}" - ) - physical_gpu_mapping = {local_id: physical_gpu_ids[local_id] for local_id in range(total_requested_gpus)} - logger.info(f"Using local->physical GPU mapping: {physical_gpu_mapping}") - - infer_gpu_ids = [physical_gpu_mapping[local_gpu_id] for local_gpu_id in infer_local_gpu_ids] - trainer_gpu_ids = 
[physical_gpu_mapping[local_gpu_id] for local_gpu_id in trainer_local_gpu_ids] - teacher_gpu_ids = [physical_gpu_mapping[local_gpu_id] for local_gpu_id in teacher_local_gpu_ids] + mapping = compute_gpu_mapping(config, get_physical_gpu_ids) + logger.info(f"Using local->physical GPU mapping: {mapping.physical}") start_command = sys.argv logger.info("Starting RL run") logger.debug(f"RL start command: {' '.join(start_command)}") - # Build shared W&B env vars for subprocesses - wandb_shared_env: dict[str, str] = {} - if config.wandb and config.wandb.shared: - wandb_shared_env["WANDB_SHARED_MODE"] = "1" - wandb_shared_env["WANDB_SHARED_RUN_ID"] = os.environ.get("WANDB_SHARED_RUN_ID", uuid.uuid4().hex) + wandb_shared_env = build_wandb_shared_env(config) - # Check for existing processes on GPUs - all_gpu_ids = list(set(infer_gpu_ids + trainer_gpu_ids + teacher_gpu_ids)) + all_gpu_ids = list(set(mapping.infer + mapping.trainer + mapping.teacher)) check_gpus_available(all_gpu_ids) - # Validate client port matches inference server port if config.inference is not None and not config.orchestrator.client.is_elastic: from urllib.parse import urlparse @@ -167,97 +142,51 @@ def rl_local(config: RLConfig): f"Update the base_url to use port {expected_port} to match the inference server." 
) - # Prepare paths to communicate with the trainer log_dir = get_log_dir(config.output_dir) log_dir.mkdir(parents=True, exist_ok=True) - # Start processes - processes: list[Popen] = [] - monitor_threads: list[Thread] = [] - error_queue: list[Exception] = [] - stop_events: dict[str, Event] = {} + supervisor = LaunchSupervisor(logger=logger, log_dir=log_dir) def sigterm_handler(signum, frame): logger.warning("Received SIGTERM, terminating all processes...") - cleanup_threads(monitor_threads) - cleanup_processes(processes) + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) sys.exit(1) signal.signal(signal.SIGTERM, sigterm_handler) try: - # Optionally, start inference process if config.inference: - inference_cmd = ["inference", "@", (config_dir / INFERENCE_TOML).as_posix()] - logger.info(f"Starting inference on GPU(s) {' '.join(map(str, infer_gpu_ids))}") - logger.debug(f"Inference start command: {' '.join(inference_cmd)}") - # If we don't log stdout, the server hangs - with open(log_dir / "inference.log", "w") as log_file: - inference_process = Popen( - inference_cmd, - env={ - **os.environ, - "CUDA_VISIBLE_DEVICES": ",".join(map(str, infer_gpu_ids)), - }, - stdout=log_file, - stderr=log_file, - ) - processes.append(inference_process) - - # Start monitoring thread - stop_event = Event() - stop_events["inference"] = stop_event - monitor_thread = Thread( - target=monitor_process, - args=(inference_process, stop_event, error_queue, "inference"), - daemon=True, + start_inference( + cmd=["inference", "@", (config_dir / INFERENCE_TOML).as_posix()], + gpu_ids=mapping.infer, + label="inference", + log_path=log_dir / "inference.log", + supervisor=supervisor, + ) + elif config.orchestrator.teacher_rollout_model is None: + logger.warning( + "No inference config specified, skipping starting inference server. Make sure your inference server is running." 
) - monitor_thread.start() - monitor_threads.append(monitor_thread) else: - if config.orchestrator.teacher_rollout_model is None: - logger.warning( - "No inference config specified, skipping starting inference server. Make sure your inference server is running." - ) - else: - logger.info( - "No inference config specified, using orchestrator.teacher_rollout_model for rollout generation." - ) + logger.info( + "No inference config specified, using orchestrator.teacher_rollout_model for rollout generation." + ) - # Optionally, start teacher inference process if config.teacher_inference: - if not teacher_gpu_ids: + if not mapping.teacher: raise ValueError( "teacher_inference is configured but deployment.num_teacher_gpus is not set. " "Either set deployment.num_teacher_gpus to start a teacher inference server, " "or omit teacher_inference and configure orchestrator.teacher_model to use an existing server." ) - - teacher_inference_cmd = ["inference", "@", (config_dir / TEACHER_INFERENCE_TOML).as_posix()] - logger.info(f"Starting teacher inference process on GPU(s) {' '.join(map(str, teacher_gpu_ids))}") - logger.debug(f"Teacher inference start command: {' '.join(teacher_inference_cmd)}") - with open(log_dir / "teacher_inference.log", "w") as log_file: - teacher_inference_process = Popen( - teacher_inference_cmd, - env={ - **os.environ, - "CUDA_VISIBLE_DEVICES": ",".join(map(str, teacher_gpu_ids)), - }, - stdout=log_file, - stderr=log_file, - ) - processes.append(teacher_inference_process) - - # Start monitoring thread - stop_event = Event() - stop_events["teacher_inference"] = stop_event - monitor_thread = Thread( - target=monitor_process, - args=(teacher_inference_process, stop_event, error_queue, "teacher_inference"), - daemon=True, + start_inference( + cmd=["inference", "@", (config_dir / TEACHER_INFERENCE_TOML).as_posix()], + gpu_ids=mapping.teacher, + label="teacher_inference", + log_path=log_dir / "teacher_inference.log", + supervisor=supervisor, ) - 
monitor_thread.start() - monitor_threads.append(monitor_thread) elif ( config.trainer.loss.type == "default" and config.trainer.loss.teacher_tau > 0 ) or config.orchestrator.teacher_model: @@ -266,138 +195,58 @@ def sigterm_handler(signum, frame): "Is your teacher inference server running? Make sure orchestrator.teacher_model is configured." ) - # Start orchestrator process - orchestrator_cmd = [ - "orchestrator", - "@", - (config_dir / ORCHESTRATOR_TOML).as_posix(), - ] - logger.info("Starting orchestrator process") - logger.debug(f"Orchestrator start command: {' '.join(orchestrator_cmd)}") - with open(log_dir / "orchestrator.log", "w") as log_file: - orchestrator_process = Popen( - orchestrator_cmd, - stdout=log_file, - stderr=log_file, - env={ - **os.environ, - **wandb_shared_env, - "WANDB_SHARED_LABEL": "orchestrator", - "LOGURU_FORCE_COLORS": "1", - "WANDB_PROGRAM": "uv run rl", - "WANDB_ARGS": json.dumps(start_command), - }, - ) - processes.append(orchestrator_process) - - # Start monitoring thread - stop_event = Event() - stop_events["orchestrator"] = stop_event - monitor_thread = Thread( - target=monitor_process, - args=(orchestrator_process, stop_event, error_queue, "orchestrator"), - daemon=True, + orchestrator_process = start_orchestrator( + config_path=config_dir / ORCHESTRATOR_TOML, + label="orchestrator", + log_path=log_dir / "orchestrator.log", + start_command=start_command, + wandb_shared_env=wandb_shared_env, + wandb_program="uv run rl", + supervisor=supervisor, ) - monitor_thread.start() - monitor_threads.append(monitor_thread) - - # Start training process - trainer_cmd = [ - "torchrun", - "--role=trainer", - f"--rdzv-endpoint=localhost:{get_free_port()}", - f"--rdzv-id={uuid.uuid4().hex}", - # Pipe all logs to file, and only master rank logs to stdout - f"--log-dir={log_dir / 'trainer' / 'torchrun'}", - f"--local-ranks-filter={','.join(map(str, config.trainer.log.ranks_filter))}", - "--redirect=3", - "--tee=3", - 
f"--nproc-per-node={len(trainer_gpu_ids)}", - "-m", - "prime_rl.trainer.rl.train", - "@", - (config_dir / TRAINER_TOML).as_posix(), - ] - logger.info(f"Starting trainer on GPU(s) {' '.join(map(str, trainer_gpu_ids))}") - logger.debug(f"Training start command: {' '.join(trainer_cmd)}") - with open(log_dir / "trainer.log", "w") as log_file: - trainer_process = Popen( - trainer_cmd, - env={ - **os.environ, - **wandb_shared_env, - "WANDB_SHARED_LABEL": "trainer", - "CUDA_VISIBLE_DEVICES": ",".join(map(str, trainer_gpu_ids)), - "PYTHONUNBUFFERED": "1", - "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", - "LOGURU_FORCE_COLORS": "1", - "WANDB_PROGRAM": "uv run rl", - "WANDB_ARGS": json.dumps(start_command), - }, - stdout=log_file, - stderr=log_file, - ) - processes.append(trainer_process) - # Start monitoring thread - stop_event = Event() - stop_events["trainer"] = stop_event - monitor_thread = Thread( - target=monitor_process, args=(trainer_process, stop_event, error_queue, "trainer"), daemon=True + trainer_process = start_trainer( + config_path=config_dir / TRAINER_TOML, + gpu_ids=mapping.trainer, + ranks_filter=config.trainer.log.ranks_filter, + log_path=log_dir / "trainer.log", + torchrun_log_dir=log_dir / "trainer" / "torchrun", + start_command=start_command, + wandb_shared_env=wandb_shared_env, + wandb_program="uv run rl", + supervisor=supervisor, ) - monitor_thread.start() - monitor_threads.append(monitor_thread) - # Monitor all processes for failures logger.success("Startup complete. 
Showing trainer logs...") + tail_trainer_log(supervisor, log_dir / "trainer.log") + + wait_for_completion(["orchestrator", "trainer"], supervisor) - tail_process = Popen( - f"tail -F '{log_dir / 'trainer.log'}' | sed -u 's/^\\[[a-zA-Z]*[0-9]*\\]://'", - shell=True, - ) - processes.append(tail_process) - - # Check for errors from monitor threads - while not (stop_events["orchestrator"].is_set() and stop_events["trainer"].is_set()): - if error_queue: - error = error_queue[0] - logger.error(f"Error: {error}") - logger.error("Terminating all processes...") - cleanup_threads(monitor_threads) - cleanup_processes(processes) - sys.exit(1) - - # Small delay to avoid busy waiting - time.sleep(1) - - # Check if any critical process failed if orchestrator_process.returncode != 0: logger.error(f"Orchestrator failed with exit code {orchestrator_process.returncode}") - cleanup_threads(monitor_threads) - cleanup_processes(processes) + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) sys.exit(1) if trainer_process.returncode != 0: logger.error(f"Trainer failed with exit code {trainer_process.returncode}") - cleanup_threads(monitor_threads) - cleanup_processes(processes) + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) sys.exit(1) logger.success("RL training finished!") - - # Cleanup threads and processes - cleanup_threads(monitor_threads) - cleanup_processes(processes) + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) except KeyboardInterrupt: logger.warning("Received interrupt signal, terminating all processes...") - cleanup_threads(monitor_threads) - cleanup_processes(processes) + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) sys.exit(1) except Exception as e: logger.error(f"Error occurred: {e}") - cleanup_threads(monitor_threads) - cleanup_processes(processes) + cleanup_threads(supervisor.monitor_threads) + 
cleanup_processes(supervisor.processes) raise diff --git a/src/prime_rl/entrypoints/rl_multi_run.py b/src/prime_rl/entrypoints/rl_multi_run.py new file mode 100644 index 0000000000..dc82678111 --- /dev/null +++ b/src/prime_rl/entrypoints/rl_multi_run.py @@ -0,0 +1,320 @@ +"""Shared-trainer LoRA launcher. + +Boots one trainer + one inference server + N orchestrators against a shared +output directory, one orchestrator per pre-materialized ``run_*`` directory. +The trainer's ``MultiRunManager`` discovers the run dirs, allocates per-run +LoRA adapter slots, and routes each orchestrator's training samples to the +right adapter. + +The sweep controller is the expected caller: it materializes +``{output_dir}/run_{id}/control/orch.toml`` for every trial up front +and then invokes:: + + rl-multi-run @ shared.toml --runs-dir run_a:run_b:run_c + +``shared.toml`` is an ordinary ``RLConfig`` whose trainer block has +``max_concurrent_runs = N`` and whose orchestrator block carries the shared +defaults trial overrides inherit from. The orchestrator block in the shared +file is *not* used to launch a process here; only the per-run ``orch.toml`` +files referenced via ``--runs-dir`` produce orchestrators.
+""" + +import os +import signal +import sys +import time +from pathlib import Path + +from prime_rl.configs.rl import RLConfig +from prime_rl.entrypoints.launch import ( + LaunchSupervisor, + build_wandb_shared_env, + compute_gpu_mapping, + init_wandb_shared_primary, + start_inference, + start_orchestrator, + start_trainer, + tail_trainer_log, +) +from prime_rl.entrypoints.rl import ( + INFERENCE_TOML, + TEACHER_INFERENCE_TOML, + TRAINER_TOML, + check_gpus_available, + get_physical_gpu_ids, + write_subconfigs, +) +from prime_rl.entrypoints.rl_multi_run_args import RUNS_DIR_FLAG, parse_runs_dirs +from prime_rl.sweep.run_control import ( + finished_orchestrator_failures, + record_finished_orchestrator_exit_codes, + record_orchestrator_exit_codes, +) +from prime_rl.utils.config import cli +from prime_rl.utils.logger import setup_logger +from prime_rl.utils.monitor import SWEEP_METRICS_JSONL_ENV +from prime_rl.utils.process import cleanup_processes, cleanup_threads, set_proc_title +from prime_rl.utils.utils import get_log_dir + + +def _validate_run_layout(run_dirs: list[Path]) -> None: + """Each run dir must already contain ``control/orch.toml``. + + The trainer's ``MultiRunManager`` will reject a run whose config file is + missing; failing loudly here gives a much better error than waiting for + the trainer to silently skip the run. + """ + missing = [d for d in run_dirs if not (d / "control" / "orch.toml").exists()] + if missing: + raise SystemExit( + f"{RUNS_DIR_FLAG} entries missing control/orch.toml: {[d.as_posix() for d in missing]}. " + "The sweep launcher must pre-materialize each run before invoking rl-multi-run." 
+ ) + + +def _validate_concurrency(config: RLConfig, run_dirs: list[Path]) -> None: + """Trainer must be sized for at least len(run_dirs) concurrent runs.""" + max_runs = getattr(config.trainer, "max_concurrent_runs", None) + if max_runs is None or max_runs < 1: + raise SystemExit( + "rl-multi-run requires trainer.max_concurrent_runs >= 1 in the shared config " + "(use multi-run-LoRA training)." + ) + if max_runs < len(run_dirs): + raise SystemExit( + f"trainer.max_concurrent_runs={max_runs} but {len(run_dirs)} run dirs were passed; " + "set max_concurrent_runs to at least the number of concurrent trials." + ) + + +def rl_multi_run(config: RLConfig, run_dirs: list[Path]) -> None: + assert config.deployment.type == "single_node", "rl-multi-run is single-node only" + _validate_concurrency(config, run_dirs) + _validate_run_layout(run_dirs) + + logger = setup_logger( + config.log.level or os.environ.get("PRIME_LOG_LEVEL", "info"), + json_logging=config.log.json_logging, + ) + + config_dir = config.output_dir / "configs" + write_subconfigs(config, config_dir) + logger.info(f"Wrote subconfigs to {config_dir}") + + if config.dry_run: + logger.success( + "Dry run complete. To start a multi-run RL launch, remove --dry-run from your command." 
+ ) + return + + mapping = compute_gpu_mapping(config, get_physical_gpu_ids) + logger.info(f"Using local->physical GPU mapping: {mapping.physical}") + + start_command = sys.argv + logger.info(f"Starting multi-run RL launch with {len(run_dirs)} orchestrator(s)") + logger.debug(f"Multi-run RL start command: {' '.join(start_command)}") + + wandb_shared_env = build_wandb_shared_env(config) + + all_gpu_ids = list(set(mapping.infer + mapping.trainer + mapping.teacher)) + check_gpus_available(all_gpu_ids) + + log_dir = get_log_dir(config.output_dir) + log_dir.mkdir(parents=True, exist_ok=True) + + supervisor = LaunchSupervisor(logger=logger, log_dir=log_dir) + + orchestrator_labels: list[str] = [] + orchestrator_processes = [] + recorded_exit_code_dirs: set[Path] = set() + + def sigterm_handler(signum, frame): + logger.warning("Received SIGTERM, terminating all processes...") + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + record_orchestrator_exit_codes(orchestrator_processes, run_dirs) + sys.exit(1) + + signal.signal(signal.SIGTERM, sigterm_handler) + + launcher_wandb_run = None + try: + # The launcher process owns the shared W&B run so its lifetime matches + # the supervisor that waits on every subprocess. Trainer and + # orchestrators stay non-primary so neither a trainer that exits at + # max_steps nor a pruned orchestrator can mark the run finished while + # siblings still have unflushed metrics (e.g. final-eval logs). + launcher_wandb_run = init_wandb_shared_primary(config, wandb_shared_env, logger) + + if config.inference: + start_inference( + cmd=["inference", "@", (config_dir / INFERENCE_TOML).as_posix()], + gpu_ids=mapping.infer, + label="inference", + log_path=log_dir / "inference.log", + supervisor=supervisor, + ) + else: + logger.warning( + "No inference config specified, skipping starting inference server. " + "Make sure your inference server is running." 
+ ) + + if config.teacher_inference: + if not mapping.teacher: + raise ValueError( + "teacher_inference is configured but deployment.num_teacher_gpus is not set." + ) + start_inference( + cmd=["inference", "@", (config_dir / TEACHER_INFERENCE_TOML).as_posix()], + gpu_ids=mapping.teacher, + label="teacher_inference", + log_path=log_dir / "teacher_inference.log", + supervisor=supervisor, + ) + + for run_dir in run_dirs: + run_id = run_dir.name + label = f"orchestrator-{run_id}" + orchestrator_labels.append(label) + orchestrator_processes.append( + start_orchestrator( + config_path=run_dir / "control" / "orch.toml", + label=label, + log_path=log_dir / f"orchestrator-{run_id}.log", + start_command=start_command, + wandb_shared_env=wandb_shared_env, + wandb_program="uv run rl-multi-run", + supervisor=supervisor, + # Per-run sweep sidecar metrics, so the controller can + # read this trial's objective from /metrics.jsonl + # without colliding with sibling orchestrators. + extra_env={SWEEP_METRICS_JSONL_ENV: (run_dir / "metrics.jsonl").as_posix()}, + ) + ) + + trainer_process = start_trainer( + config_path=config_dir / TRAINER_TOML, + gpu_ids=mapping.trainer, + ranks_filter=config.trainer.log.ranks_filter, + log_path=log_dir / "trainer.log", + torchrun_log_dir=log_dir / "trainer" / "torchrun", + start_command=start_command, + wandb_shared_env=wandb_shared_env, + wandb_program="uv run rl-multi-run", + supervisor=supervisor, + ) + + logger.success("Startup complete. Showing trainer logs...") + tail_trainer_log(supervisor, log_dir / "trainer.log") + + # Wait for every orchestrator + the trainer. Unlike single-run, we + # cannot fail-fast on the first non-zero exit: a pruned or failed + # orchestrator (Optuna writes evicted.txt, the orchestrator raises) + # is expected, and the trainer's MultiRunManager keeps driving + # surviving runs. 
The sweep controller attributes failure per-run + # via each control/exit_code file, so we just wait for everything + # to finish and let the post-wait code report status. + # + # Trainer / inference / teacher_inference crashes are still + # terminal — survivors would block forever (trainer waiting for + # batches, orchestrators waiting for weights), so we tear the wave + # down and exit non-zero, but only after recording per-run exit + # codes so the controller can attribute failures correctly. + primary_label_set = {*orchestrator_labels, "trainer"} + infra_labels = [label for label in supervisor.stop_events if label not in primary_label_set] + all_labels = orchestrator_labels + ["trainer"] + + def _terminate_with_exit_codes(reason: str) -> None: + logger.error(f"{reason}; tearing down remaining processes") + record_orchestrator_exit_codes(orchestrator_processes, run_dirs) + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + sys.exit(1) + + while not all(supervisor.stop_events[label].is_set() for label in all_labels): + record_finished_orchestrator_exit_codes( + orchestrator_processes, + run_dirs, + recorded_exit_code_dirs, + ) + failed_finished_orchestrators = finished_orchestrator_failures( + orchestrator_labels, + orchestrator_processes, + supervisor.stop_events, + ) + if failed_finished_orchestrators: + _terminate_with_exit_codes( + "All orchestrators have exited and at least one failed or was pruned: " + f"{failed_finished_orchestrators}" + ) + # An infra subprocess (inference / teacher_inference) is meant + # to run for the whole wave; if its stop_event fires here, it + # exited unexpectedly and the trainer/orchestrators will hang. 
+ failed_infra = [label for label in infra_labels if supervisor.stop_events[label].is_set()] + if failed_infra: + _terminate_with_exit_codes( + f"Infrastructure process(es) exited unexpectedly: {failed_infra}" + ) + if supervisor.stop_events["trainer"].is_set() and trainer_process.returncode != 0: + _terminate_with_exit_codes( + f"Trainer failed with exit code {trainer_process.returncode}" + ) + time.sleep(1) + + # Per-orchestrator exit_code is the sweep controller's source of + # truth for failure attribution; write it as soon as we've waited + # for every orchestrator, before any cleanup that might mask codes. + record_orchestrator_exit_codes(orchestrator_processes, run_dirs) + + failed_orchestrators = [ + (label, proc.returncode) + for label, proc in zip(orchestrator_labels, orchestrator_processes) + if proc.returncode != 0 + ] + if failed_orchestrators: + for label, code in failed_orchestrators: + logger.error(f"{label} failed with exit code {code}") + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + sys.exit(1) + + if trainer_process.returncode != 0: + logger.error(f"Trainer failed with exit code {trainer_process.returncode}") + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + sys.exit(1) + + logger.success("Multi-run RL training finished!") + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + + except KeyboardInterrupt: + logger.warning("Received interrupt signal, terminating all processes...") + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + record_orchestrator_exit_codes(orchestrator_processes, run_dirs) + sys.exit(1) + except Exception as e: + logger.error(f"Error occurred: {e}") + cleanup_threads(supervisor.monitor_threads) + cleanup_processes(supervisor.processes) + record_orchestrator_exit_codes(orchestrator_processes, run_dirs) + raise + finally: + # Run finish on every exit path (success, sys.exit, 
raise, SIGTERM) + # so x_update_finish_state fires only after the supervisor has + # waited on every subprocess. + if launcher_wandb_run is not None: + launcher_wandb_run.finish() + + +def main(): + set_proc_title("MultiRunLauncher") + run_dirs, remaining = parse_runs_dirs(sys.argv[1:]) + config = cli(RLConfig, args=remaining) + rl_multi_run(config, run_dirs) + + +if __name__ == "__main__": + main() diff --git a/src/prime_rl/entrypoints/rl_multi_run_args.py b/src/prime_rl/entrypoints/rl_multi_run_args.py new file mode 100644 index 0000000000..20d3a7752d --- /dev/null +++ b/src/prime_rl/entrypoints/rl_multi_run_args.py @@ -0,0 +1,27 @@ +import argparse +from pathlib import Path + +RUNS_DIR_FLAG = "--runs-dir" + + +def parse_runs_dirs(argv: list[str]) -> tuple[list[Path], list[str]]: + """Peel ``--runs-dir `` off argv before pydantic_config. + + Returns ``(run_dirs, remaining_argv)``. The remaining argv is passed to + ``cli(RLConfig)`` so the standard ``@ shared.toml`` syntax keeps working. 
+ """ + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument(RUNS_DIR_FLAG, required=True) + namespace, remaining = parser.parse_known_args(argv) + raw = namespace.runs_dir + if not raw: + raise SystemExit(f"{RUNS_DIR_FLAG} must list at least one run directory") + pieces = raw.split(":") + if any(not piece for piece in pieces): + raise SystemExit(f"{RUNS_DIR_FLAG} contains an empty run directory entry: {raw!r}") + run_dirs = [Path(piece).resolve() for piece in pieces] + if not run_dirs: + raise SystemExit(f"{RUNS_DIR_FLAG} parsed to no run directories: {raw!r}") + if len(run_dirs) != len(set(run_dirs)): + raise SystemExit(f"{RUNS_DIR_FLAG} contains duplicate run directories: {raw!r}") + return run_dirs, remaining diff --git a/src/prime_rl/entrypoints/sweep.py b/src/prime_rl/entrypoints/sweep.py new file mode 100644 index 0000000000..364238fa6b --- /dev/null +++ b/src/prime_rl/entrypoints/sweep.py @@ -0,0 +1,13 @@ +from prime_rl.configs.sweep import SweepConfig +from prime_rl.sweep.controller import run_sweep +from prime_rl.utils.config import cli +from prime_rl.utils.process import set_proc_title + + +def main(): + set_proc_title("Sweep") + run_sweep(cli(SweepConfig)) + + +if __name__ == "__main__": + main() diff --git a/src/prime_rl/sweep/__init__.py b/src/prime_rl/sweep/__init__.py new file mode 100644 index 0000000000..fd7f099392 --- /dev/null +++ b/src/prime_rl/sweep/__init__.py @@ -0,0 +1 @@ +"""Hyperparameter sweep support.""" diff --git a/src/prime_rl/sweep/controller.py b/src/prime_rl/sweep/controller.py new file mode 100644 index 0000000000..f52c2db002 --- /dev/null +++ b/src/prime_rl/sweep/controller.py @@ -0,0 +1,756 @@ +import json +import shlex +import shutil +from dataclasses import asdict +from pathlib import Path +from typing import Any + +import tomli_w + +from prime_rl.configs.sweep import ( + GridStrategyConfig, + LocalSweepSchedulerConfig, + MultiRunLoRASchedulerConfig, + OptunaStrategyConfig, + RandomStrategyConfig, + 
SlurmSweepSchedulerConfig, + SweepConfig, +) +from prime_rl.sweep.early_stopping import TrialOutcome, TrialOutcomeTracker +from prime_rl.sweep.materialize import ( + SweepDriftError, + SweepStatusError, + Trial, + TrialArtifacts, + materialize_multi_run_trial, + materialize_trial, + multi_run_shared_dir, + read_status_json, + record_multi_run_materialization_failure, + record_trial_materialization_failure, + record_trial_missing_objective, + record_trial_objective, +) +from prime_rl.sweep.metrics import coerce_finite_float, read_final_summary +from prime_rl.sweep.multi_run import run_multi_run_optuna_sweep +from prime_rl.sweep.optuna_loop import run_optuna_sweep, run_optuna_sweep_parallel_slurm +from prime_rl.sweep.reproducibility import git_metadata +from prime_rl.sweep.schedulers import ( + run_trials_locally, + submit_trials_to_multi_run_lora, + submit_trials_to_slurm, + utc_now, +) +from prime_rl.sweep.search import expand_grid, sample_random + + +def _write_toml(path: Path, data: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as f: + tomli_w.dump(data, f) + + +def build_variant(artifact: TrialArtifacts) -> dict[str, Any]: + # Mirror live trial state into the manifest so jq queries against + # variants[*].state / variants[*].objective work without separately + # reading each status.json. status.json is written at materialization + # (state="pending") and updated as trials run, so we always have at + # least the pending values to record here. 
+ try: + status = read_status_json(artifact.status_path) if artifact.status_path.exists() else {} + except SweepStatusError: + status = {} + return { + "id": artifact.trial.id, + "label": artifact.trial.label, + "output_dir": artifact.run_dir.as_posix(), + "overrides": artifact.trial.parameters, + "command": artifact.command, + "status_path": artifact.status_path.as_posix(), + "resolved_checksum": artifact.resolved_checksum, + "base_checksums": artifact.base_checksums, + "state": status.get("state"), + "objective": status.get("objective"), + } + + +def write_manifest_with_variants(config: SweepConfig, variants: list[dict[str, Any]]) -> None: + manifest = { + "name": config.name, + "entrypoint": config.entrypoint, + "strategy": config.strategy.model_dump(mode="json"), + "scheduler": config.scheduler.model_dump(mode="json"), + "parameters": _manifest_parameters(config), + "parameter_order": _manifest_parameter_order(config), + "objective": config.objective.model_dump(mode="json") if config.objective else None, + "early_stopping": config.early_stopping.model_dump(mode="json") if config.early_stopping else None, + "git": git_metadata(), + "variants": variants, + } + (config.output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n") + + +def _write_manifest(config: SweepConfig, artifacts: list[TrialArtifacts]) -> None: + write_manifest_with_variants(config, [build_variant(a) for a in artifacts]) + + +def _update_manifest_summary(config: SweepConfig, summary: dict[str, Any] | None) -> None: + manifest_path = config.output_dir / "manifest.json" + manifest = json.loads(manifest_path.read_text()) + manifest["summary"] = summary + manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n") + + +def _manifest_variant_count(config: SweepConfig, fallback: int) -> int: + manifest_path = config.output_dir / "manifest.json" + if not manifest_path.is_file(): + return fallback + variants = 
json.loads(manifest_path.read_text()).get("variants") + return len(variants) if isinstance(variants, list) else fallback + + +def _expand_trials(config: SweepConfig) -> list[Trial]: + if isinstance(config.strategy, GridStrategyConfig): + return expand_grid(config.parameters) + if isinstance(config.strategy, RandomStrategyConfig): + return sample_random( + config.parameters, + num_trials=config.strategy.num_trials, + seed=config.strategy.seed, + ) + raise ValueError(f"Unsupported sweep strategy: {config.strategy!r}") + + +def _previous_manifest(config: SweepConfig) -> dict[str, Any] | None: + manifest_path = config.output_dir / "manifest.json" + if not manifest_path.exists(): + return None + try: + manifest = json.loads(manifest_path.read_text()) + except json.JSONDecodeError as exc: + raise RuntimeError( + "Resume cannot reuse existing trial results because the previous manifest " + "is not valid JSON. Restore the manifest or start a fresh study/output_dir." + ) from exc + if not isinstance(manifest, dict): + raise RuntimeError( + "Resume cannot reuse existing trial results because the previous manifest " + "is not a JSON object. Restore the manifest or start a fresh study/output_dir." 
+ ) + return manifest + + +def _manifest_objective(config: SweepConfig) -> dict[str, Any] | None: + return config.objective.model_dump(mode="json") if config.objective else None + + +def _manifest_parameters(config: SweepConfig) -> dict[str, Any]: + return {path: parameter.model_dump(mode="json") for path, parameter in config.parameters.items()} + + +def _manifest_parameter_order(config: SweepConfig) -> list[str]: + return list(config.parameters) + + +def _validate_resume_manifest_objective(config: SweepConfig, manifest: dict[str, Any] | None) -> None: + if manifest is None or not manifest.get("variants"): + return + + previous_objective = manifest.get("objective") + current_objective = _manifest_objective(config) + if previous_objective == current_objective: + return + + raise RuntimeError( + "Resume cannot reuse existing trial objectives because the sweep objective changed " + f"(previous={previous_objective}, current={current_objective}). " + "Use the original objective or start a fresh study/output_dir." + ) + + +def _validate_resume_manifest_entrypoint(config: SweepConfig, manifest: dict[str, Any] | None) -> None: + if manifest is None or not manifest.get("variants"): + return + + previous_entrypoint = manifest.get("entrypoint") + if previous_entrypoint == config.entrypoint: + return + + raise RuntimeError( + "Resume cannot reuse existing trial results because the sweep entrypoint changed " + f"(previous={previous_entrypoint}, current={config.entrypoint}). " + "Use the original entrypoint or start a fresh study/output_dir." 
+ ) + + +def _validate_resume_manifest_parameters(config: SweepConfig, manifest: dict[str, Any] | None) -> None: + if manifest is None or not manifest.get("variants"): + return + + previous_parameters = manifest.get("parameters") + current_parameters = _manifest_parameters(config) + if previous_parameters != current_parameters: + raise RuntimeError( + "Resume cannot reuse existing trial results because the sweep parameters changed " + f"(previous={previous_parameters}, current={current_parameters}). " + "Use the original search space or start a fresh study/output_dir." + ) + + previous_order = manifest.get("parameter_order") + current_order = _manifest_parameter_order(config) + if previous_order == current_order: + return + + raise RuntimeError( + "Resume cannot reuse existing trial results because the sweep parameter order changed " + f"(previous={previous_order}, current={current_order}). Trial generation is order-sensitive; " + "use the original parameter order or start a fresh study/output_dir." + ) + + +def _validate_resume_manifest_strategy(config: SweepConfig, manifest: dict[str, Any] | None) -> None: + if manifest is None or not manifest.get("variants"): + return + + previous_strategy = manifest.get("strategy") + current_strategy = config.strategy.model_dump(mode="json") + if not isinstance(previous_strategy, dict): + raise RuntimeError( + "Resume cannot reuse existing trial results because the previous manifest " + "does not record the sweep strategy. Restore the manifest or start a fresh study/output_dir." + ) + + if previous_strategy.get("type") != current_strategy.get("type"): + raise RuntimeError( + "Resume cannot reuse existing trial results because the sweep strategy changed " + f"(previous={previous_strategy}, current={current_strategy}). " + "Use the original strategy or start a fresh study/output_dir." 
+ ) + + if isinstance(config.strategy, RandomStrategyConfig | OptunaStrategyConfig): + previous_without_count = dict(previous_strategy) + current_without_count = dict(current_strategy) + previous_num_trials = previous_without_count.pop("num_trials", None) + current_num_trials = current_without_count.pop("num_trials", None) + + compatible = ( + previous_without_count == current_without_count + and type(previous_num_trials) is int + and type(current_num_trials) is int + and current_num_trials >= previous_num_trials + ) + if compatible: + return + elif previous_strategy == current_strategy: + return + + raise RuntimeError( + "Resume cannot reuse existing trial results because the sweep strategy changed " + f"(previous={previous_strategy}, current={current_strategy}). " + "Only increasing strategy.num_trials is supported on resume; other strategy changes " + "need a fresh study/output_dir." + ) + + +def _validate_resume_manifest_scheduler(config: SweepConfig, manifest: dict[str, Any] | None) -> None: + if manifest is None or not manifest.get("variants"): + return + + previous_scheduler = manifest.get("scheduler") + current_scheduler = config.scheduler.model_dump(mode="json") + if not isinstance(previous_scheduler, dict): + raise RuntimeError( + "Resume cannot reuse existing trial results because the previous manifest " + "does not record the sweep scheduler. Restore the manifest or start a fresh study/output_dir." + ) + + if previous_scheduler.get("type") == current_scheduler.get("type"): + return + + raise RuntimeError( + "Resume cannot reuse existing trial results because the sweep scheduler type changed " + f"(previous={previous_scheduler}, current={current_scheduler}). " + "Use the original scheduler type or start a fresh study/output_dir." 
+ ) + + +def _checksums_from_manifest(manifest: dict[str, Any] | None) -> dict[str, dict[str, Any]]: + """Map trial_id -> {resolved_checksum, base_checksums} from the prior manifest.""" + if manifest is None: + return {} + + variants = manifest.get("variants", []) + if not isinstance(variants, list): + raise RuntimeError( + "Resume cannot reuse existing trial results because the previous manifest " + "does not record variants as a list. Restore the manifest or start a fresh study/output_dir." + ) + + checksums: dict[str, dict[str, Any]] = {} + invalid_ids: list[Any] = [] + duplicate_ids: list[str] = [] + for variant in variants: + if not isinstance(variant, dict): + invalid_ids.append(variant) + continue + trial_id = variant.get("id") + if not isinstance(trial_id, str) or not trial_id: + invalid_ids.append(trial_id) + continue + if trial_id in checksums: + duplicate_ids.append(trial_id) + continue + checksums[trial_id] = { + "resolved_checksum": variant.get("resolved_checksum"), + "base_checksums": variant.get("base_checksums") or {}, + } + + if invalid_ids or duplicate_ids: + raise RuntimeError( + "Resume cannot reuse existing trial results because the previous manifest " + "has malformed or duplicate variant id(s). " + f"Invalid id(s): {invalid_ids}; duplicate id(s): {duplicate_ids}. " + "Repair the manifest or start a fresh study/output_dir." 
+ ) + + return checksums + + +def _materialize_study(config: SweepConfig) -> list[TrialArtifacts]: + if config.output_dir.exists() and config.clean_output_dir: + shutil.rmtree(config.output_dir) + config.output_dir.mkdir(parents=True, exist_ok=True) + + previous_manifest = _previous_manifest(config) if config.resume else None + if config.resume: + _validate_resume_manifest_entrypoint(config, previous_manifest) + _validate_resume_manifest_objective(config, previous_manifest) + _validate_resume_manifest_parameters(config, previous_manifest) + _validate_resume_manifest_strategy(config, previous_manifest) + _validate_resume_manifest_scheduler(config, previous_manifest) + expected = _checksums_from_manifest(previous_manifest) + + _write_toml(config.output_dir / "study.toml", config.model_dump(exclude_none=True, mode="json")) + + trials = _expand_trials(config) + if config.resume: + trial_ids = {trial.id for trial in trials} + extra_manifest_ids = sorted(set(expected) - trial_ids) + if extra_manifest_ids: + raise RuntimeError( + "Resume cannot reuse existing trial results because the previous manifest " + "has variant id(s) that are not in the regenerated trial set: " + f"{extra_manifest_ids}. Restore the original sweep definition, repair the manifest, " + "or start a fresh study/output_dir." 
+ ) + artifacts: list[TrialArtifacts] = [] + for trial in trials: + try: + artifact = materialize_trial( + config, + trial, + resume=config.resume, + expected_checksums=expected.get(trial.id), + ) + except Exception as exc: + if _should_propagate_materialization_error(exc): + raise + artifact = record_trial_materialization_failure(config, trial, exc, finished_at=utc_now()) + artifacts.append(artifact) + if not config.continue_on_failure: + break + continue + artifacts.append(artifact) + _write_manifest(config, artifacts) + return artifacts + + +def _materialize_multi_run_study(config: SweepConfig) -> list[TrialArtifacts]: + """Materialize all trials as ``run_*`` subdirs under a shared trainer dir. + + Multi-run sweeps invoke ``rl-multi-run`` exactly once with all trials + laid out up front. Resume against a still-running shared trainer is a + Phase 7b concern, so this path always writes from scratch. + """ + if config.output_dir.exists() and config.clean_output_dir: + shutil.rmtree(config.output_dir) + config.output_dir.mkdir(parents=True, exist_ok=True) + multi_run_shared_dir(config).mkdir(parents=True, exist_ok=True) + + _write_toml(config.output_dir / "study.toml", config.model_dump(exclude_none=True, mode="json")) + + trials = _expand_trials(config) + assert isinstance(config.scheduler, MultiRunLoRASchedulerConfig) + if len(trials) > config.scheduler.max_concurrent_runs: + raise SystemExit( + f"multi_run_lora scheduler.max_concurrent_runs={config.scheduler.max_concurrent_runs} " + f"but the search expanded to {len(trials)} trials. Increase max_concurrent_runs " + "or shrink the search space; wave-based execution requires Optuna (Phase 7b)." 
+ ) + artifacts: list[TrialArtifacts] = [] + for trial in trials: + try: + artifact = materialize_multi_run_trial(config, trial, config.scheduler) + except Exception as exc: + artifact = record_multi_run_materialization_failure( + config, trial, config.scheduler, exc, finished_at=utc_now() + ) + artifacts.append(artifact) + if not config.continue_on_failure: + break + continue + artifacts.append(artifact) + _write_manifest(config, artifacts) + return artifacts + + +def _build_trial_callback( + config: SweepConfig, + tracker: TrialOutcomeTracker | None, + *, + halt_on_missing_objective: bool = False, +): + if config.objective is None or tracker is None: + return None + + metric = config.objective.metric + + def on_trial_complete(artifact: TrialArtifacts, returncode: int) -> bool: + objective = read_final_summary(artifact.run_dir, metric) if returncode == 0 else None + record_trial_objective(artifact.status_path, objective) + missing_objective = returncode == 0 and objective is None + if missing_objective: + record_trial_missing_objective(artifact.status_path, metric) + outcome = TrialOutcome(trial_id=artifact.trial.id, label=artifact.trial.label, objective=objective) + return tracker.observe(outcome) or (halt_on_missing_objective and missing_objective) + + return on_trial_complete + + +def _count_objective_failures(artifacts: list[TrialArtifacts]) -> int: + missing = 0 + for artifact in artifacts: + status = read_status_json(artifact.status_path) + if status.get("failure_stage") == "objective": + missing += 1 + elif status.get("state") == "completed" and coerce_finite_float(status.get("objective")) is None: + missing += 1 + return missing + + +def _is_materialization_failure(artifact: TrialArtifacts) -> bool: + return read_status_json(artifact.status_path).get("failure_stage") == "materialization" + + +def _should_propagate_materialization_error(exc: Exception) -> bool: + return isinstance(exc, (SweepDriftError, SweepStatusError)) + + +def 
def _run_optuna(config: SweepConfig) -> None:
    """Run an Optuna-driven sweep and exit non-zero if any trial failed.

    Prepares the output directory (wiping it first when ``clean_output_dir``
    is set), validates the previous manifest when resuming, writes
    ``study.toml``, and hands control to an Optuna driver. A synchronous SLURM
    scheduler with ``max_parallel > 1`` uses
    ``run_optuna_sweep_parallel_slurm``; every other scheduler uses
    ``run_optuna_sweep``.

    Raises:
        RuntimeError: From the resume validators when the previous manifest is
            incompatible with ``config``.
        SystemExit: With code 1 when the driver reports failed trials.
    """
    if config.output_dir.exists() and config.clean_output_dir:
        shutil.rmtree(config.output_dir)
    config.output_dir.mkdir(parents=True, exist_ok=True)
    if config.resume:
        previous_manifest = _previous_manifest(config)
        _validate_resume_manifest_entrypoint(config, previous_manifest)
        _validate_resume_manifest_objective(config, previous_manifest)
        _validate_resume_manifest_parameters(config, previous_manifest)
        _validate_resume_manifest_strategy(config, previous_manifest)
        _validate_resume_manifest_scheduler(config, previous_manifest)
    _write_toml(config.output_dir / "study.toml", config.model_dump(exclude_none=True, mode="json"))

    if config.dry_run:
        print(
            "Dry run for Optuna strategy is a no-op: trials are proposed sequentially based on "
            "prior objectives, so they cannot be materialized up front."
        )
        return

    # SlurmSweepSchedulerConfig is already a module-level name (run_sweep uses
    # it unqualified), so no function-scope re-import under an alias is needed.
    use_parallel_slurm = (
        isinstance(config.scheduler, SlurmSweepSchedulerConfig)
        and config.scheduler.synchronous
        and config.scheduler.max_parallel > 1
    )
    driver = run_optuna_sweep_parallel_slurm if use_parallel_slurm else run_optuna_sweep
    failures, tracker, artifacts = driver(
        config,
        write_manifest_with_variants=write_manifest_with_variants,
        build_variant=build_variant,
    )

    if tracker is not None:
        summary = asdict(tracker.summary())
        _update_manifest_summary(config, summary)
        if summary["best_trial_id"] is not None:
            # Prefer the human-readable label; fall back to the trial id.
            label = tracker.best_label or summary["best_trial_id"]
            print(f"Best trial: {label} ({summary['best_value']})")
        if summary["halted_by_early_stopping"]:
            print(f"Sweep halted by early stopping ({summary['halt_reason']}).")

    if failures > 0:
        total = _manifest_variant_count(config, len(artifacts))
        print(f"Sweep finished with {failures} failed trial(s) out of {total}.")
        raise SystemExit(1)
+ ) + for artifact in artifacts: + print(f" {artifact.run_dir}") + if materialization_failures: + print(f"Dry run found {materialization_failures} failed trial materialization(s).") + raise SystemExit(1) + return + + if materialization_failures > 0 and not config.continue_on_failure: + print("Skipping multi_run_lora launch: materialization failed and continue_on_failure=false.") + failures = 0 + elif launchable_artifacts: + failures = submit_trials_to_multi_run_lora( + launchable_artifacts, + shared_paths=config.scheduler.shared, + shared_dir=multi_run_shared_dir(config), + continue_on_failure=config.continue_on_failure, + retry_budget=config.retry_budget, + ) + else: + failures = 0 + + tracker = TrialOutcomeTracker(config.objective, config.early_stopping) if config.objective else None + missing_objective_failures = 0 + if config.objective is not None and tracker is not None: + for artifact in artifacts: + # reconcile_multi_run_artifact already chose between completed / + # failed / pruned; only completed trials have an objective worth + # rereading from metrics.jsonl, but record_trial_objective(None) + # for the others keeps status.json's shape consistent. + status = read_status_json(artifact.status_path) + if status.get("state") == "completed": + objective = read_final_summary(artifact.run_dir, config.objective.metric) + if objective is None: + record_trial_missing_objective(artifact.status_path, config.objective.metric) + missing_objective_failures += 1 + else: + objective = None + if objective is not None or status.get("state") != "completed": + record_trial_objective(artifact.status_path, objective) + tracker.observe( + TrialOutcome(trial_id=artifact.trial.id, label=artifact.trial.label, objective=objective) + ) + summary = asdict(tracker.summary()) + + # Same fix as run_sweep: refresh variants from each trial's final + # status.json so state/objective reflect the post-wave reality, not + # the pending values from materialization. 
Runs outside the + # objective-tracker branch so state is also refreshed when no + # objective is configured. + _write_manifest(config, artifacts) + if tracker is not None: + _update_manifest_summary(config, summary) + if summary["best_trial_id"] is not None: + label = tracker.best_label or summary["best_trial_id"] + print(f"Best trial: {label} ({summary['best_value']})") + + failures += materialization_failures + failures += missing_objective_failures + if failures > 0: + total = _manifest_variant_count(config, len(artifacts)) + print(f"Sweep finished with {failures} failed trial(s) out of {total}.") + raise SystemExit(1) + + +def _run_multi_run_optuna(config: SweepConfig) -> None: + """Drive an Optuna study against ``rl-multi-run`` in waves of size ``max_concurrent_runs``.""" + assert isinstance(config.scheduler, MultiRunLoRASchedulerConfig) + assert isinstance(config.strategy, OptunaStrategyConfig) + if config.output_dir.exists() and config.clean_output_dir: + shutil.rmtree(config.output_dir) + config.output_dir.mkdir(parents=True, exist_ok=True) + multi_run_shared_dir(config).mkdir(parents=True, exist_ok=True) + _write_toml(config.output_dir / "study.toml", config.model_dump(exclude_none=True, mode="json")) + + if config.dry_run: + print( + "Dry run for Optuna + multi_run_lora is a no-op: trials are proposed wave by wave " + "based on prior objectives, so they cannot be materialized up front." 
def _run_multi_run(config: SweepConfig) -> None:
    """Route a shared-trainer LoRA sweep to the driver matching its strategy.

    Optuna strategies go to ``_run_multi_run_optuna``; every other strategy
    goes to ``_run_multi_run_static``.
    """
    assert isinstance(config.scheduler, MultiRunLoRASchedulerConfig)
    if not isinstance(config.strategy, OptunaStrategyConfig):
        _run_multi_run_static(config)
    else:
        _run_multi_run_optuna(config)
Materialized {len(artifacts)} trial(s) under {config.output_dir}.") + for artifact in artifacts: + print(shlex.join(artifact.command)) + if materialization_failures: + print(f"Dry run found {materialization_failures} failed trial materialization(s).") + raise SystemExit(1) + return + + slurm_sync = ( + isinstance(config.scheduler, SlurmSweepSchedulerConfig) and config.scheduler.synchronous + ) + track_objectives = config.objective is not None and ( + isinstance(config.scheduler, LocalSweepSchedulerConfig) or slurm_sync + ) + if config.objective is not None and not track_objectives: + print( + "Note: objective tracking is only computed for the local scheduler or " + "synchronous SLURM scheduler; async SLURM submission produces its own " + "status.json without controller-side reconciliation." + ) + tracker = TrialOutcomeTracker(config.objective, config.early_stopping) if track_objectives else None + on_trial_complete = _build_trial_callback( + config, + tracker, + halt_on_missing_objective=not config.continue_on_failure, + ) + + if tracker is not None and config.resume: + _seed_tracker_from_resume(tracker, artifacts) + + failures = 0 + halt_after_materialization_failure = materialization_failures > 0 and not config.continue_on_failure + counted_completed_missing_objectives = False + resume_missing_objectives = ( + _count_objective_failures(artifacts) + if track_objectives and config.resume + else 0 + ) + if halt_after_materialization_failure: + print("Skipping trial launch: materialization failed and continue_on_failure=false.") + elif resume_missing_objectives > 0 and not config.continue_on_failure: + failures += resume_missing_objectives + counted_completed_missing_objectives = True + print("Skipping new trials: resume found completed trial(s) without recorded objectives.") + elif tracker is not None and tracker.halted: + print("Skipping new trials: early stopping already triggered by completed trials in the study.") + elif isinstance(config.scheduler, 
LocalSweepSchedulerConfig): + gpu_groups = ( + config.scheduler.gpu_assignment.visible_devices if config.scheduler.gpu_assignment is not None else None + ) + failures = run_trials_locally( + launchable_artifacts, + max_parallel=config.scheduler.max_parallel, + gpu_groups=gpu_groups, + continue_on_failure=config.continue_on_failure, + retry_budget=config.retry_budget, + on_trial_complete=on_trial_complete, + ) + elif isinstance(config.scheduler, SlurmSweepSchedulerConfig): + failures = submit_trials_to_slurm( + launchable_artifacts, + continue_on_failure=config.continue_on_failure, + retry_budget=config.retry_budget, + synchronous=config.scheduler.synchronous, + on_trial_complete=on_trial_complete if config.scheduler.synchronous else None, + ) + else: + raise ValueError(f"Unsupported sweep scheduler: {config.scheduler}") + + if track_objectives and not counted_completed_missing_objectives: + failures += _count_objective_failures(artifacts) + failures += materialization_failures + + # Refresh manifest variants from each trial's final status.json so + # state/objective reflect the post-run reality, not the pending + # values written during materialization. Optuna and multi_run_lora + # paths already rewrite the manifest at the end of their drivers. 
+ _write_manifest(config, artifacts) + + if tracker is not None: + summary = asdict(tracker.summary()) + _update_manifest_summary(config, summary) + if summary["best_trial_id"] is not None: + label = tracker.best_label or summary["best_trial_id"] + print(f"Best trial: {label} ({summary['best_value']})") + if summary["halted_by_early_stopping"]: + print(f"Sweep halted by early stopping ({summary['halt_reason']}).") + + if failures > 0: + print(f"Sweep finished with {failures} failed trial(s) out of {len(artifacts)}.") + raise SystemExit(1) diff --git a/src/prime_rl/sweep/early_stopping.py b/src/prime_rl/sweep/early_stopping.py new file mode 100644 index 0000000000..e7c7f2ee26 --- /dev/null +++ b/src/prime_rl/sweep/early_stopping.py @@ -0,0 +1,130 @@ +"""Trial-level early stopping for sweep studies. + +The tracker observes each completed trial's objective value and decides +whether the controller should halt remaining work. All decisions are made +between trials (not in-flight): a worse-than-threshold value or a long enough +run of non-improving trials triggers a halt that prevents new submissions +while in-flight trials finish naturally. 
class TrialOutcomeTracker:
    """Lock-guarded record of trial objectives and the early-stopping decision.

    Each observed outcome is stored. Outcomes whose objective coerces to a
    finite float also update the best value, the non-improvement streak, and
    the halt flag; once set, the halt flag is never cleared.
    """

    def __init__(
        self,
        objective: ObjectiveConfig | None,
        early_stopping: EarlyStoppingConfig | None,
    ):
        self._objective = objective
        self._early_stopping = early_stopping
        self._lock = threading.Lock()
        self._outcomes: list[TrialOutcome] = []
        self._completed = 0
        self._steps_without_improvement = 0
        self._best_value: float | None = None
        self._best_trial_id: str | None = None
        self._best_label: str | None = None
        self._halted = False
        self._halt_reason: Literal["threshold", "patience"] | None = None

    def observe(self, outcome: TrialOutcome) -> bool:
        """Store ``outcome`` and report whether the study should stop."""
        with self._lock:
            self._outcomes.append(outcome)
            score = coerce_finite_float(outcome.objective)
            if score is not None:
                self._completed += 1
                if self._is_improvement(score):
                    self._best_value = score
                    self._best_trial_id = outcome.trial_id
                    self._best_label = outcome.label
                    self._steps_without_improvement = 0
                else:
                    self._steps_without_improvement += 1

                if not self._halted and self._should_halt(score):
                    self._halted = True
            # A missing objective changes nothing beyond the stored outcome.
            return self._halted

    def _is_improvement(self, value: float) -> bool:
        """Return whether ``value`` beats the best value seen so far."""
        if self._best_value is None:
            # The first finite value always becomes the best.
            return True
        if self._objective is None:
            return False
        if self._objective.direction == "maximize":
            return value > self._best_value
        return value < self._best_value

    def _should_halt(self, value: float) -> bool:
        """Evaluate the stopping rule for ``value``, recording the reason on a halt."""
        rule = self._early_stopping
        if rule is None or self._objective is None or self._completed < rule.min_trials:
            return False

        reason: Literal["threshold", "patience"] | None = None
        if isinstance(rule, ThresholdStoppingConfig):
            if self._objective.direction == "maximize":
                breached = value < rule.threshold
            else:
                breached = value > rule.threshold
            if breached:
                reason = "threshold"
        elif isinstance(rule, PatienceStoppingConfig):
            if self._steps_without_improvement >= rule.patience:
                reason = "patience"

        if reason is None:
            return False
        self._halt_reason = reason
        return True

    @property
    def halted(self) -> bool:
        """Whether a stopping rule has fired."""
        with self._lock:
            return self._halted

    def summary(self) -> TrialOutcomeSummary:
        """Snapshot the counters, best trial, and halt state."""
        with self._lock:
            return TrialOutcomeSummary(
                completed=self._completed,
                best_trial_id=self._best_trial_id,
                best_value=self._best_value,
                halted_by_early_stopping=self._halted,
                halt_reason=self._halt_reason,
            )

    @property
    def best_label(self) -> str | None:
        """Label of the best trial so far, or ``None`` when there is none."""
        with self._lock:
            return self._best_label
def merge_nested_overrides(target: dict[str, Any], overrides: dict[str, Any]) -> None:
    """Recursively merge ``overrides`` into ``target`` in place.

    When both sides hold a table (dict) under the same key, the tables are
    merged key by key. In every other case the override value replaces
    whatever ``target`` held.

    Replacement values are deep-copied so that ``target`` never shares nested
    containers with ``overrides``. Without the copy, a later merge into
    ``target`` would silently mutate the caller's earlier override dict.

    Args:
        target: Mapping updated in place.
        overrides: Mapping whose values take precedence; left unmodified.
    """
    import copy  # function-scope: only this helper needs it

    for key, value in overrides.items():
        current = target.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            merge_nested_overrides(current, value)
        else:
            target[key] = copy.deepcopy(value)
def command_for_trial(
    entrypoint: Literal["rl", "sft"],
    base_paths: list[Path],
    overrides_path: Path,
) -> list[str]:
    """Build the ``uv run <entrypoint>`` argv for one trial.

    Each base file is passed as an ``@ <path>`` pair in the given order,
    followed by the generated overrides file, mirroring what a user would
    type by hand. Paths are rendered with forward slashes via ``as_posix``.
    """
    config_files = [*base_paths, overrides_path]
    file_args = [token for config_file in config_files for token in ("@", config_file.as_posix())]
    return ["uv", "run", entrypoint, *file_args]
" + "Restore the trial artifacts or remove the trial directory." + ) from exc + if not isinstance(status, dict): + raise SweepStatusError( + f"Sweep status file at {status_path} must be a JSON object, " + f"not {type(status).__name__}. Restore the trial artifacts or remove the trial directory." + ) + return status + + +def _materialization_error(exc: BaseException) -> str: + message = str(exc) + return f"{type(exc).__name__}: {message}" if message else type(exc).__name__ + + +def record_trial_objective(status_path: Path, value: float | None) -> None: + """Persist an objective value into a trial's status.json, preserving other fields.""" + status = read_status_json(status_path) + status["objective"] = coerce_finite_float(value) + write_json(status_path, status) + + +def record_trial_missing_objective(status_path: Path, metric: str) -> None: + """Mark a clean process exit as failed when it did not produce the sweep metric.""" + status = read_status_json(status_path) + status["state"] = "failed" + status["objective"] = None + status["failure_stage"] = "objective" + status["error"] = f"Trial exited successfully but did not record a finite objective for {metric!r}." + if status.get("returncode") is None: + status["returncode"] = 0 + write_json(status_path, status) + + +def record_trial_pruned( + status_path: Path, + step: int, + value: float, + *, + returncode: int | None = None, + finished_at: str | None = None, +) -> None: + """Mark a trial as pruned by intermediate-metric reporting. + + The Optuna pruning loop terminates the trial subprocess when the sampler + decides the trajectory is unpromising. We record state="pruned" plus the + step/value the prune fired on so the manifest can tell pruned trials + apart from completed and failed runs without re-deriving the cause. 
+ """ + status = read_status_json(status_path) + status["state"] = "pruned" + status["pruned_at_step"] = int(step) + status["pruned_value"] = float(value) + status["objective"] = None + if returncode is not None: + status["returncode"] = int(returncode) + if finished_at is not None: + status["finished_at"] = finished_at + write_json(status_path, status) + + +def validate_target_config(entrypoint: Literal["rl", "sft"], args: list[str]) -> BaseConfig: + config_cls = RLConfig if entrypoint == "rl" else SFTConfig + return cli(config_cls, args=args) + + +def _validate_target_with_overrides( + entrypoint: Literal["rl", "sft"], + base_paths: list[Path], + overrides_path: Path, + overrides: dict[str, Any], + *, + avoid_overwriting: bool, +) -> BaseConfig: + args: list[str] = [] + for base_path in base_paths: + args.extend(["@", base_path.as_posix()]) + if not avoid_overwriting: + write_toml(overrides_path, overrides) + args.extend(["@", overrides_path.as_posix()]) + return validate_target_config(entrypoint, args) + + overrides_path.parent.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile( + dir=overrides_path.parent, + prefix=f".{overrides_path.name}.", + suffix=".tmp.toml", + delete=False, + ) as temp: + temp_path = Path(temp.name) + try: + write_toml(temp_path, overrides) + args.extend(["@", temp_path.as_posix()]) + return validate_target_config(entrypoint, args) + finally: + temp_path.unlink(missing_ok=True) + + +def _target_config_to_toml(config: BaseConfig) -> dict[str, Any]: + return config.model_dump(exclude_none=True, mode="json") + + +def _get_nested_value(data: dict[str, Any], segments: tuple[str | int, ...]) -> tuple[bool, Any]: + current: Any = data + for segment in segments: + if isinstance(segment, int): + if not isinstance(current, list) or segment >= len(current): + return False, None + current = current[segment] + else: + if not isinstance(current, dict) or segment not in current: + return False, None + current = current[segment] + 
return True, current + + +def _canonical_bool_lookup_segments( + segments: tuple[str | int, ...], +) -> tuple[tuple[str | int, ...], ...]: + aliases = { + "max_tokens": "max_completion_tokens", + "skip_eval_on_restart": "skip_eval_on_resume", + "timeout_seconds": "timeout", + } + leaf = segments[-1] if segments else None + canonical_leaf = aliases.get(leaf) if isinstance(leaf, str) else None + if canonical_leaf is None: + return (segments,) + return (segments, (*segments[:-1], canonical_leaf)) + + +def _bool_parameter_targets( + display_path: str, + lookup_segments: tuple[str | int, ...], + value: Any, +) -> list[tuple[str, tuple[str | int, ...], bool]]: + if isinstance(value, bool): + return [(display_path, lookup_segments, value)] + if isinstance(value, dict): + targets: list[tuple[str, tuple[str | int, ...], bool]] = [] + for key, child in value.items(): + child_path = f"{display_path}.{key}" + targets.extend(_bool_parameter_targets(child_path, (*lookup_segments, key), child)) + return targets + if isinstance(value, (list, tuple)): + targets: list[tuple[str, tuple[str | int, ...], bool]] = [] + for idx, child in enumerate(value): + child_path = f"{display_path}[{idx}]" + targets.extend(_bool_parameter_targets(child_path, (*lookup_segments, idx), child)) + return targets + return [] + + +def _reject_bool_target_coercions( + parameters: dict[str, Any], + resolved_data: dict[str, Any], + *, + path_prefix: str = "", +) -> None: + """Reject bool choice values when the target field resolved as non-bool.""" + for path, value in parameters.items(): + lookup_path = path + if path_prefix and lookup_path.startswith(path_prefix): + lookup_path = lookup_path.removeprefix(path_prefix) + lookup_segments: tuple[str | int, ...] 
= tuple(lookup_path.split(".")) + for display_path, segments, bool_value in _bool_parameter_targets(path, lookup_segments, value): + for candidate in _canonical_bool_lookup_segments(segments): + found, resolved_value = _get_nested_value(resolved_data, candidate) + if not found: + continue + if not isinstance(resolved_value, bool): + raise ValueError( + f"Sweep parameter {display_path!r} uses boolean value {bool_value!r}, " + f"but the resolved target field is {type(resolved_value).__name__}. " + "Boolean choice values are only valid for boolean target fields." + ) + break + + +def _merge_wandb_overrides(config: SweepConfig, flat_overrides: dict[str, Any], trial: Trial) -> None: + if config.wandb is None or not config.wandb.enabled: + return + + group = config.wandb.group or config.name + if group is not None: + flat_overrides["wandb.group"] = group + + flat_overrides["wandb.name"] = trial.label or trial.id + + tags = list(dict.fromkeys([*config.wandb.tags, "sweep", f"trial:{trial.id}"])) + if config.name is not None: + tags.append(f"study:{config.name}") + flat_overrides["wandb.tags"] = list(dict.fromkeys(tags)) + + +TERMINAL_RESUME_STATES = frozenset({"completed", "submitted"}) + + +class SweepDriftError(RuntimeError): + """Raised when --resume would skip a trial whose effective config has changed.""" + + +def _existing_terminal_status(status_path: Path) -> dict[str, Any] | None: + """Return parsed status.json if its state should be preserved on resume.""" + if not status_path.exists(): + return None + status = read_status_json(status_path) + if status.get("state") in TERMINAL_RESUME_STATES: + return status + return None + + +def _check_resume_drift( + trial: Trial, + preserved_status: dict[str, Any], + expected: dict[str, Any] | None, + new_resolved_checksum: str, + new_base_checksums: dict[str, str], +) -> None: + """Refuse to skip a terminal trial whose recorded config differs from the live one. 
+ + Trial IDs hash sweep parameters only, so a base TOML edit between runs + leaves the ID stable while the resolved config changes underneath us. + Without this check ``--resume`` would silently honor the old ``status.json`` + and skip work that no longer reflects the current configuration. + """ + if expected is None: + raise SweepDriftError( + f"Refusing to skip {preserved_status['state']} trial {trial.id} on resume because " + "the previous manifest has no checksum entry for it. Drop --resume to start fresh, " + "restore the manifest, or remove the trial directory." + ) + + expected_resolved = expected.get("resolved_checksum") + expected_bases = expected.get("base_checksums") or {} + if expected_resolved is None or not isinstance(expected_bases, dict): + raise SweepDriftError( + f"Refusing to skip {preserved_status['state']} trial {trial.id} on resume because " + "its previous manifest entry is missing resolved/base checksums. Drop --resume to " + "start fresh, restore the manifest, or remove the trial directory." + ) + + missing_bases = [base for base in new_base_checksums if base not in expected_bases] + if missing_bases: + raise SweepDriftError( + f"Refusing to skip {preserved_status['state']} trial {trial.id} on resume because " + f"its previous manifest entry is missing checksum(s) for base file(s): {missing_bases}. " + "Drop --resume to start fresh, restore the manifest, or remove the trial directory." + ) + + extra_bases = [base for base in expected_bases if base not in new_base_checksums] + if extra_bases: + raise SweepDriftError( + f"Refusing to skip {preserved_status['state']} trial {trial.id} on resume because " + f"its previous manifest entry has extra base file(s) no longer present: {extra_bases}. " + "Drop --resume to start fresh, restore the original base config list, or remove the trial directory." 
+ ) + + changed_bases = [ + base for base, checksum in new_base_checksums.items() if expected_bases[base] != checksum + ] + resolved_drift = expected_resolved != new_resolved_checksum + + if not (changed_bases or resolved_drift): + return + + detail = f"changed base files: {changed_bases}" if changed_bases else "the resolved config changed" + raise SweepDriftError( + f"Refusing to skip {preserved_status['state']} trial {trial.id} on resume because " + f"{detail}. Drop --resume to start fresh, revert the change, or remove the trial directory." + ) + + +def materialize_trial( + config: SweepConfig, + trial: Trial, + resume: bool = False, + expected_checksums: dict[str, Any] | None = None, +) -> TrialArtifacts: + trial_dir = config.output_dir / "trials" / trial.id + run_dir = trial_dir / "run" + overrides_path = trial_dir / "overrides.toml" + resolved_path = trial_dir / "resolved.toml" + command_path = trial_dir / "command.txt" + status_path = trial_dir / "status.json" + + flat_overrides = dict(trial.parameters) + flat_overrides["output_dir"] = run_dir.as_posix() + _merge_wandb_overrides(config, flat_overrides, trial) + + overrides = build_nested_overrides(flat_overrides) + + preserved_status = _existing_terminal_status(status_path) if resume else None + + resolved_config = _validate_target_with_overrides( + config.entrypoint, + config.base, + overrides_path, + overrides, + avoid_overwriting=preserved_status is not None, + ) + resolved_data = _target_config_to_toml(resolved_config) + _reject_bool_target_coercions(trial.parameters, resolved_data) + + command = command_for_trial(config.entrypoint, config.base, overrides_path) + + resolved_checksum = toml_checksum(resolved_data) + base_checksums = {base.as_posix(): file_checksum(base) for base in config.base} + + if preserved_status is not None: + if preserved_status.get("id") != trial.id: + raise SweepDriftError( + f"Refusing to skip {preserved_status['state']} trial {trial.id} on resume because " + f"status.json 
belongs to {preserved_status.get('id')!r}. Restore the trial artifacts " + "or remove the mismatched trial directory." + ) + _check_resume_drift(trial, preserved_status, expected_checksums, resolved_checksum, base_checksums) + write_pending_status = preserved_status is None + + write_toml(overrides_path, overrides) + write_toml(resolved_path, resolved_data) + command_path.write_text(shlex.join(command) + "\n") + + if write_pending_status: + write_json( + status_path, + { + "id": trial.id, + "label": trial.label, + "state": "pending", + "pid": None, + "slurm_job_id": None, + "gpu_group": None, + "returncode": None, + "objective": None, + }, + ) + + return TrialArtifacts( + trial=trial, + trial_dir=trial_dir, + run_dir=run_dir, + overrides_path=overrides_path, + resolved_path=resolved_path, + command_path=command_path, + status_path=status_path, + command=command, + resolved_checksum=resolved_checksum, + base_checksums=base_checksums, + ) + + +def record_trial_materialization_failure( + config: SweepConfig, + trial: Trial, + exc: BaseException, + *, + finished_at: str | None = None, +) -> TrialArtifacts: + """Write a manifestable artifact for an Optuna trial that failed config validation.""" + trial_dir = config.output_dir / "trials" / trial.id + run_dir = trial_dir / "run" + overrides_path = trial_dir / "overrides.toml" + resolved_path = trial_dir / "resolved.toml" + command_path = trial_dir / "command.txt" + status_path = trial_dir / "status.json" + + flat_overrides = dict(trial.parameters) + flat_overrides["output_dir"] = run_dir.as_posix() + _merge_wandb_overrides(config, flat_overrides, trial) + write_toml(overrides_path, build_nested_overrides(flat_overrides)) + + command = command_for_trial(config.entrypoint, config.base, overrides_path) + command_path.write_text(shlex.join(command) + "\n") + resolved_path.unlink(missing_ok=True) + + status = { + "id": trial.id, + "label": trial.label, + "state": "failed", + "pid": None, + "slurm_job_id": None, + "gpu_group": 
None, + "returncode": -1, + "objective": None, + "failure_stage": "materialization", + "error": _materialization_error(exc), + } + if finished_at is not None: + status["finished_at"] = finished_at + write_json(status_path, status) + + return TrialArtifacts( + trial=trial, + trial_dir=trial_dir, + run_dir=run_dir, + overrides_path=overrides_path, + resolved_path=resolved_path, + command_path=command_path, + status_path=status_path, + command=command, + resolved_checksum="", + base_checksums={base.as_posix(): file_checksum(base) for base in config.base}, + ) + + +def _merge_multi_run_wandb_overrides( + config: SweepConfig, flat_overrides: dict[str, Any], trial: Trial +) -> None: + """Tag the per-run orchestrator's W&B run with sweep + trial metadata.""" + if config.wandb is None or not config.wandb.enabled: + return + + group = config.wandb.group or config.name + if group is not None: + flat_overrides["orchestrator.wandb.group"] = group + + flat_overrides["orchestrator.wandb.name"] = trial.label or trial.id + + tags = list(dict.fromkeys([*config.wandb.tags, "sweep", f"trial:{trial.id}"])) + if config.name is not None: + tags.append(f"study:{config.name}") + flat_overrides["orchestrator.wandb.tags"] = list(dict.fromkeys(tags)) + + +def multi_run_shared_dir(config: SweepConfig) -> Path: + """Directory that hosts the shared trainer's output and per-run subdirs. + + The trainer's ``MultiRunManager`` scans ``/run_*`` so every trial + directory must sit directly under this path with a ``run_`` prefix. + """ + return config.output_dir / "shared" + + +def multi_run_trial_dir(config: SweepConfig, trial: Trial) -> Path: + """Per-trial directory the trainer will discover as a ``run_*`` slot.""" + return multi_run_shared_dir(config) / f"run_{trial.id}" + + +def write_multi_run_output_override(shared_dir: Path) -> Path: + """Write the trainer ``output_dir`` pin used by ``rl-multi-run`` launches. 
+ + Pinning ``output_dir = `` keeps the trainer's + ``MultiRunManager`` scanning the right ``run_*`` slots regardless of what + the base TOML carries. Materialization and the launcher both call this so + the override file is a known, replayable artifact rather than an + implementation detail of ``build_multi_run_command``. + """ + shared_dir.mkdir(parents=True, exist_ok=True) + path = shared_dir / "_output_override.toml" + write_toml(path, {"output_dir": shared_dir.as_posix()}) + return path + + +def _finalize_multi_run_lora_config( + orchestrator_dict: dict[str, Any], + resolved_rl_config: RLConfig, + trial_parameters: dict[str, Any], +) -> None: + """Apply multi-run LoRA relaxations after shared RLConfig validation.""" + trainer = getattr(resolved_rl_config, "trainer", None) + trainer_model = getattr(trainer, "model", None) + trainer_lora = getattr(trainer_model, "lora", None) + model = orchestrator_dict.get("model") + if trainer_lora is None or not isinstance(model, dict): + return + lora = model.get("lora") + if not isinstance(lora, dict): + return + + if lora.get("rank") is None: + lora["rank"] = trainer_lora.rank + if lora.get("alpha") is None: + lora["alpha"] = trainer_lora.alpha + rank = lora["rank"] + if isinstance(rank, int) and not isinstance(rank, bool) and rank > trainer_lora.rank: + raise ValueError( + f"orchestrator.model.lora.rank ({rank}) exceeds " + f"trainer.model.lora.rank ({trainer_lora.rank})" + ) + + lora_shape_changed = ( + "orchestrator.model.lora.rank" in trial_parameters + or "orchestrator.model.lora.alpha" in trial_parameters + ) + if lora_shape_changed and "orchestrator.model.lora.name" not in trial_parameters: + lora["name"] = f"r{lora['rank']}-a{lora['alpha']}" + + +def _validate_multi_run_shared_config(resolved_rl_config: RLConfig, scheduler: Any) -> None: + if not isinstance(resolved_rl_config, RLConfig): + return + + trainer_lora = resolved_rl_config.trainer.model.lora + if trainer_lora is None: + raise ValueError("multi_run_lora 
requires trainer.model.lora in the shared RLConfig.") + + trainer_max_runs = resolved_rl_config.trainer.max_concurrent_runs + scheduler_max_runs = scheduler.max_concurrent_runs + if trainer_max_runs < scheduler_max_runs: + raise ValueError( + "multi_run_lora scheduler.max_concurrent_runs must be <= " + f"trainer.max_concurrent_runs in the shared RLConfig " + f"(got scheduler={scheduler_max_runs}, trainer={trainer_max_runs})." + ) + + +def _finalize_multi_run_batching_config( + orchestrator_dict: dict[str, Any], + trial_parameters: dict[str, Any], + explicit_shared_orchestrator_fields: set[str], +) -> None: + sets_batch_size = "orchestrator.batch_size" in trial_parameters + sets_token_batch_size = "orchestrator.token_batch_size" in trial_parameters + sets_oversampling_factor = "orchestrator.oversampling_factor" in trial_parameters + if sets_batch_size and sets_token_batch_size: + raise ValueError("Set either orchestrator.batch_size or orchestrator.token_batch_size, not both.") + if sets_batch_size or sets_oversampling_factor: + orchestrator_dict.pop("token_batch_size", None) + if sets_token_batch_size: + orchestrator_dict.pop("batch_size", None) + orchestrator_dict.pop("oversampling_factor", None) + if ( + (sets_batch_size or sets_token_batch_size or sets_oversampling_factor) + and "orchestrator.max_inflight_rollouts" not in trial_parameters + and "max_inflight_rollouts" not in explicit_shared_orchestrator_fields + ): + orchestrator_dict.pop("max_inflight_rollouts", None) + + +def _canonicalize_multi_run_sampling_aliases( + orchestrator_dict: dict[str, Any], + trial_parameters: dict[str, Any], +) -> None: + for section in ("train", "eval"): + path = f"orchestrator.{section}.sampling.max_tokens" + if path not in trial_parameters: + continue + section_config = orchestrator_dict.get(section) + if not isinstance(section_config, dict): + continue + sampling = section_config.get("sampling") + if not isinstance(sampling, dict): + continue + if "max_tokens" in sampling: + 
sampling["max_completion_tokens"] = sampling.pop("max_tokens") + + +def _shared_section(shared_orchestrator: dict[str, Any], section: str) -> dict[str, Any]: + section_data: dict[str, Any] = {} + if section == "train": + if "env" in shared_orchestrator: + section_data["env"] = shared_orchestrator["env"] + if "sampling" in shared_orchestrator: + section_data["sampling"] = shared_orchestrator["sampling"] + raw = shared_orchestrator.get(section) + if isinstance(raw, dict): + merge_nested_overrides(section_data, raw) + return section_data + + +def _raw_envs(shared_section: dict[str, Any]) -> list[Any]: + envs = shared_section.get("env", []) + return envs if isinstance(envs, list) else [] + + +def _raw_env(raw_envs: list[Any], idx: int) -> dict[str, Any]: + if idx >= len(raw_envs): + return {} + raw = raw_envs[idx] + return raw if isinstance(raw, dict) else {} + + +def _clear_inherited_env_default( + env: dict[str, Any], + raw_env: dict[str, Any], + field: str, +) -> None: + if field not in raw_env: + env.pop(field, None) + + +def _clear_inherited_sampling_default( + env: dict[str, Any], + raw_env: dict[str, Any], + field: str, +) -> None: + raw_sampling = raw_env.get("sampling") + raw_sampling = raw_sampling if isinstance(raw_sampling, dict) else {} + aliases = (field, "max_tokens") if field == "max_completion_tokens" else (field,) + if any(alias in raw_sampling for alias in aliases): + return + sampling = env.get("sampling") + if isinstance(sampling, dict): + sampling.pop(field, None) + + +def _clear_inherited_extra_body_default( + env: dict[str, Any], + raw_env: dict[str, Any], + key: str, + group_extra_body: dict[str, Any], +) -> None: + raw_sampling = raw_env.get("sampling") + raw_sampling = raw_sampling if isinstance(raw_sampling, dict) else {} + raw_extra_body = raw_sampling.get("extra_body") + raw_extra_body = raw_extra_body if isinstance(raw_extra_body, dict) else {} + if key in raw_extra_body: + return + sampling = env.get("sampling") + sampling = sampling if 
isinstance(sampling, dict) else {} + extra_body = sampling.get("extra_body") + if isinstance(extra_body, dict): + if key in group_extra_body: + extra_body[key] = group_extra_body[key] + else: + extra_body.pop(key, None) + + +def _swept_sampling_defaults( + trial_parameters: dict[str, Any], + section: str, +) -> tuple[set[str], bool, set[str]]: + prefix = f"orchestrator.{section}.sampling." + fields: set[str] = set() + replaces_extra_body = False + extra_body_keys: set[str] = set() + for path, value in trial_parameters.items(): + if path == prefix.removesuffix("."): + if isinstance(value, dict): + for key in value: + if key == "max_tokens": + fields.add("max_completion_tokens") + elif key == "extra_body": + replaces_extra_body = True + else: + fields.add(key) + continue + if not path.startswith(prefix): + continue + suffix = path.removeprefix(prefix) + if suffix == "max_tokens": + fields.add("max_completion_tokens") + elif suffix == "extra_body": + replaces_extra_body = True + elif suffix.startswith("extra_body."): + extra_body_keys.add(suffix.removeprefix("extra_body.").split(".", 1)[0]) + else: + fields.add(suffix.split(".", 1)[0]) + return fields, replaces_extra_body, extra_body_keys + + +def _clear_inherited_env_defaults_for_section( + orchestrator_dict: dict[str, Any], + shared_orchestrator: dict[str, Any], + trial_parameters: dict[str, Any], + section: str, +) -> None: + section_config = orchestrator_dict.get(section) + if not isinstance(section_config, dict): + return + envs = section_config.get("env") + if not isinstance(envs, list): + return + + raw_section = _shared_section(shared_orchestrator, section) + raw_envs = _raw_envs(raw_section) + + env_default_fields: set[str] = set() + if f"orchestrator.{section}.num_workers" in trial_parameters: + env_default_fields.add("num_workers") + if f"orchestrator.{section}.max_retries" in trial_parameters: + env_default_fields.add("max_retries") + if section == "train": + if any( + path in trial_parameters + for path 
in ( + "orchestrator.batch_size", + "orchestrator.oversampling_factor", + "orchestrator.max_inflight_rollouts", + ) + ): + env_default_fields.add("num_workers") + else: + for field in ("num_examples", "rollouts_per_example", "interval"): + if f"orchestrator.{section}.{field}" in trial_parameters: + env_default_fields.add(field) + if any( + path in trial_parameters + for path in ( + "orchestrator.eval.num_examples", + "orchestrator.eval.rollouts_per_example", + ) + ): + env_default_fields.add("num_workers") + + sampling_fields, replaces_extra_body, extra_body_keys = _swept_sampling_defaults( + trial_parameters, section + ) + group_sampling = section_config.get("sampling") + group_sampling = group_sampling if isinstance(group_sampling, dict) else {} + group_extra_body = group_sampling.get("extra_body") + group_extra_body = group_extra_body if isinstance(group_extra_body, dict) else {} + + for idx, env in enumerate(envs): + if not isinstance(env, dict): + continue + raw_env = _raw_env(raw_envs, idx) + for field in env_default_fields: + _clear_inherited_env_default(env, raw_env, field) + sampling = env.get("sampling") + if not isinstance(sampling, dict): + continue + for field in sampling_fields: + _clear_inherited_sampling_default(env, raw_env, field) + if replaces_extra_body: + _clear_inherited_sampling_default(env, raw_env, "extra_body") + for key in extra_body_keys: + _clear_inherited_extra_body_default(env, raw_env, key, group_extra_body) + + +def _clear_inherited_env_defaults( + orchestrator_dict: dict[str, Any], + shared_orchestrator: dict[str, Any], + trial_parameters: dict[str, Any], +) -> None: + _clear_inherited_env_defaults_for_section( + orchestrator_dict, shared_orchestrator, trial_parameters, "train" + ) + _clear_inherited_env_defaults_for_section( + orchestrator_dict, shared_orchestrator, trial_parameters, "eval" + ) + + +def materialize_multi_run_trial( + config: SweepConfig, + trial: Trial, + scheduler: Any, # MultiRunLoRASchedulerConfig — typed as 
Any to avoid an import cycle +) -> TrialArtifacts: + """Write a per-trial ``run_/control/orch.toml`` for a shared-trainer sweep. + + The shared base TOMLs in ``scheduler.shared`` resolve to a full RLConfig. + Per-trial parameter overrides (already prefixed with ``orchestrator.``) + are layered on top, the orchestrator block is extracted, and its TOML is + written where the trainer's ``MultiRunManager`` will find it. Returned + ``TrialArtifacts.run_dir`` points at the per-trial directory so the + sweep's existing metrics readers (``read_final_summary`` / + ``read_intermediate_metric``) keep working unchanged once the + orchestrator's ``FileMonitor`` writes ``metrics.jsonl`` there. + """ + run_dir = multi_run_trial_dir(config, trial) + control_dir = run_dir / "control" + overrides_path = run_dir / "overrides.toml" + resolved_path = run_dir / "resolved.toml" + command_path = run_dir / "command.txt" + status_path = run_dir / "status.json" + orch_config_path = control_dir / "orch.toml" + + # Trial overrides already use orchestrator.* paths thanks to the + # validator's allowlist. Add the per-run output_dir + W&B identity. + flat_overrides: dict[str, Any] = dict(trial.parameters) + flat_overrides["orchestrator.output_dir"] = run_dir.as_posix() + _merge_multi_run_wandb_overrides(config, flat_overrides, trial) + + overrides = build_nested_overrides(flat_overrides) + write_toml(overrides_path, overrides) + + # Resolve shared RLConfig without per-run orchestrator overrides. The + # single-run RLConfig cross-checks require orchestrator LoRA rank/alpha to + # equal trainer rank/alpha, but multi-run permits per-run rank <= trainer + # rank. Trial overrides are validated against OrchestratorConfig below. 
+ args: list[str] = [] + for base_path in scheduler.shared: + args.extend(["@", base_path.as_posix()]) + + resolved_rl_config = validate_target_config("rl", args) + _validate_multi_run_shared_config(resolved_rl_config, scheduler) + orchestrator_dict = resolved_rl_config.orchestrator.model_dump(exclude_none=True, mode="json") + shared_overrides = _load_nested_toml(scheduler.shared) + shared_orchestrator = shared_overrides.get("orchestrator", {}) + explicit_shared_orchestrator_fields = set(shared_orchestrator) if isinstance(shared_orchestrator, dict) else set() + merge_nested_overrides(orchestrator_dict, overrides.get("orchestrator", {})) + _canonicalize_multi_run_sampling_aliases(orchestrator_dict, trial.parameters) + _finalize_multi_run_batching_config( + orchestrator_dict, + trial.parameters, + explicit_shared_orchestrator_fields, + ) + _clear_inherited_env_defaults( + orchestrator_dict, + shared_orchestrator if isinstance(shared_orchestrator, dict) else {}, + trial.parameters, + ) + _finalize_multi_run_lora_config(orchestrator_dict, resolved_rl_config, trial.parameters) + + # RLConfig.auto_setup_output_dir resets orchestrator.output_dir to + # ``/run_default`` during validation, which would + # collapse every trial onto the same orchestrator directory. Restore the + # per-trial path so each orch.toml targets its own run_ slot. 
+ orchestrator_dict["output_dir"] = run_dir.as_posix() + orchestrator_config = OrchestratorConfig(**orchestrator_dict) + orchestrator_config.output_dir = run_dir + orchestrator_dict = orchestrator_config.model_dump(exclude_none=True, mode="json") + _reject_bool_target_coercions( + trial.parameters, + orchestrator_dict, + path_prefix="orchestrator.", + ) + + write_toml(resolved_path, orchestrator_dict) + write_toml(orch_config_path, orchestrator_dict) + + # Mirror what build_multi_run_command issues so this command.txt is + # actually replayable: rl-multi-run requires --runs-dir with + # colon-separated paths, plus the trainer output_dir pin. + output_override_path = write_multi_run_output_override(multi_run_shared_dir(config)) + command = [ + "rl-multi-run", + *sum((["@", p.as_posix()] for p in scheduler.shared), []), + "@", + output_override_path.as_posix(), + "--runs-dir", + run_dir.as_posix(), + ] + command_path.write_text(shlex.join(command) + "\n") + + resolved_checksum = file_checksum(resolved_path) + base_checksums = {base.as_posix(): file_checksum(base) for base in scheduler.shared} + + write_json( + status_path, + { + "id": trial.id, + "label": trial.label, + "state": "pending", + "pid": None, + "slurm_job_id": None, + "gpu_group": None, + "returncode": None, + "objective": None, + }, + ) + + return TrialArtifacts( + trial=trial, + trial_dir=run_dir, + run_dir=run_dir, + overrides_path=overrides_path, + resolved_path=resolved_path, + command_path=command_path, + status_path=status_path, + command=command, + resolved_checksum=resolved_checksum, + base_checksums=base_checksums, + ) + + +def record_multi_run_materialization_failure( + config: SweepConfig, + trial: Trial, + scheduler: Any, + exc: BaseException, + *, + finished_at: str | None = None, +) -> TrialArtifacts: + """Write a manifestable artifact for a failed per-run orchestrator config.""" + run_dir = multi_run_trial_dir(config, trial) + control_dir = run_dir / "control" + overrides_path = run_dir / 
"overrides.toml" + resolved_path = run_dir / "resolved.toml" + command_path = run_dir / "command.txt" + status_path = run_dir / "status.json" + orch_config_path = control_dir / "orch.toml" + + flat_overrides: dict[str, Any] = dict(trial.parameters) + flat_overrides["orchestrator.output_dir"] = run_dir.as_posix() + _merge_multi_run_wandb_overrides(config, flat_overrides, trial) + write_toml(overrides_path, build_nested_overrides(flat_overrides)) + + output_override_path = write_multi_run_output_override(multi_run_shared_dir(config)) + command = [ + "rl-multi-run", + *sum((["@", p.as_posix()] for p in scheduler.shared), []), + "@", + output_override_path.as_posix(), + "--runs-dir", + run_dir.as_posix(), + ] + command_path.write_text(shlex.join(command) + "\n") + resolved_path.unlink(missing_ok=True) + orch_config_path.unlink(missing_ok=True) + + status = { + "id": trial.id, + "label": trial.label, + "state": "failed", + "pid": None, + "slurm_job_id": None, + "gpu_group": None, + "returncode": -1, + "objective": None, + "failure_stage": "materialization", + "error": _materialization_error(exc), + } + if finished_at is not None: + status["finished_at"] = finished_at + write_json(status_path, status) + + control_dir.mkdir(parents=True, exist_ok=True) + return TrialArtifacts( + trial=trial, + trial_dir=run_dir, + run_dir=run_dir, + overrides_path=overrides_path, + resolved_path=resolved_path, + command_path=command_path, + status_path=status_path, + command=command, + resolved_checksum="", + base_checksums={base.as_posix(): file_checksum(base) for base in scheduler.shared}, + ) diff --git a/src/prime_rl/sweep/metrics.py b/src/prime_rl/sweep/metrics.py new file mode 100644 index 0000000000..d8de6a8bbf --- /dev/null +++ b/src/prime_rl/sweep/metrics.py @@ -0,0 +1,155 @@ +"""Metric readers for sweep trials. 
+ +The canonical source for sweep objectives is ``/metrics.jsonl``, +written by ``FileMonitor`` whenever the sweep launcher sets +``PRIME_RL_SWEEP_METRICS_JSONL`` on the trial subprocess. One JSON line per +``monitor.log()`` call carries ``{step, ...metrics}``. + +Two reader shapes: + +- ``read_final_metric`` returns the value at the largest reported step. Used + by Phase 4 once the trial completes; the last row of metrics.jsonl IS the + final summary for sweep purposes. +- ``read_intermediate_metric`` returns the latest ``(step, value)`` pair while + the trial is still running. Used by the Phase 5b Optuna pruning loop. + +If metrics.jsonl is absent (e.g. an SFT run without FileMonitor for some +reason, or a crashed launch before the monitor was set up), the reader falls +back to the legacy ``final_summary.json`` so existing behavior is preserved. +""" + +import json +import math +from pathlib import Path +from typing import Any + + +def _final_summary_paths(run_dir: Path) -> list[Path]: + """Return any ``run-*/final_summary.json`` files under ``run_dir``.""" + if not run_dir.exists(): + return [] + return sorted(run_dir.glob("run-*/final_summary.json")) + + +def coerce_finite_float(value: Any) -> float | None: + """Return ``value`` as a finite float, or ``None`` for anything else. + + NaN / +Inf / -Inf are rejected because they break later improvement and + threshold comparisons (NaN compares False with everything, +/-Inf would + pin best forever) and ``json.dumps`` writes them as non-standard + ``NaN`` / ``Infinity`` tokens that other readers cannot parse. 
+ """ + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + try: + scalar = float(value) + except OverflowError: + return None + return scalar if math.isfinite(scalar) else None + return None + + +def _metrics_jsonl_path(run_dir: Path) -> Path: + return run_dir / "metrics.jsonl" + + +def _valid_step(row: dict[str, Any]) -> bool: + step = row.get("step") + return isinstance(step, int) and not isinstance(step, bool) and step >= 0 + + +def _step_sort_key(row: dict[str, Any]) -> int: + return row["step"] if _valid_step(row) else -1 + + +def _latest_metric_row( + rows: list[dict[str, Any]], + metric: str, + *, + require_valid_step: bool, +) -> dict[str, Any] | None: + rows_with_metric = [(idx, row) for idx, row in enumerate(rows) if metric in row] + if require_valid_step: + rows_with_metric = [(idx, row) for idx, row in rows_with_metric if _valid_step(row)] + if not rows_with_metric: + return None + _, latest = max(rows_with_metric, key=lambda item: (_step_sort_key(item[1]), item[0])) + return latest + + +def _iter_metrics_rows(run_dir: Path) -> list[dict[str, Any]]: + """Return all JSON objects in metrics.jsonl, skipping malformed lines. + + A partially-flushed final line (e.g. controller polled mid-write) is + silently skipped. This is preferable to raising because it lets the + polling loop keep tolerating intermediate snapshots. + """ + path = _metrics_jsonl_path(run_dir) + if not path.exists(): + return [] + rows: list[dict[str, Any]] = [] + for raw in path.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line: + continue + try: + obj = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(obj, dict): + rows.append(obj) + return rows + + +def read_final_summary(run_dir: Path, metric: str) -> float | None: + """Read ``metric`` from the canonical sidecar, falling back to legacy file. 
def read_intermediate_metric(run_dir: Path, metric: str) -> tuple[int, float] | None:
    """Look up the most recent ``(step, value)`` reading of ``metric``.

    The reading comes from the row of ``metrics.jsonl`` that carries
    ``metric`` at the highest valid step. ``None`` is returned when there are
    no rows, when no row with a valid step contains ``metric``, or when the
    selected value is not a finite scalar. The Optuna pruning loop treats
    ``None`` as "no new data, do not report this round".
    """
    rows = _iter_metrics_rows(run_dir)
    newest = _latest_metric_row(rows, metric, require_valid_step=True) if rows else None
    if newest is not None:
        reading = coerce_finite_float(newest.get(metric))
        if reading is not None:
            return newest.get("step"), reading
    return None
+""" + +from __future__ import annotations + +import json +import shutil +import subprocess +import time +from dataclasses import asdict +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from prime_rl.configs.sweep import ( + MultiRunLoRASchedulerConfig, + OptunaStrategyConfig, + SweepConfig, +) +from prime_rl.sweep import materialize as _materialize +from prime_rl.sweep.early_stopping import TrialOutcome, TrialOutcomeTracker +from prime_rl.sweep.materialize import ( + TrialArtifacts, + materialize_multi_run_trial, + multi_run_shared_dir, + read_status_json, + record_multi_run_materialization_failure, + record_trial_missing_objective, + record_trial_objective, +) +from prime_rl.sweep.metrics import read_final_summary, read_intermediate_metric +from prime_rl.sweep.optuna_loop import ( + _create_study, + _import_optuna, + _make_trial, + _suggest_parameters, +) +from prime_rl.sweep.schedulers import ( + _mark_inactive_multi_run_dirs_evicted, + _read_orchestrator_exit_code, + _read_status, + _reset_multi_run_artifact_runtime, + _write_launch_failure_status, + _write_status, + build_multi_run_command, + reconcile_multi_run_artifact, + utc_now, +) + +if TYPE_CHECKING: # pragma: no cover + import optuna + +EVICTED_FILENAME = "evicted.txt" + +# Subdirs the trainer writes under its ``output_dir`` between runs. Each +# Optuna wave starts a fresh trainer pinned to ``shared_dir``, so leftover +# state from a previous wave (checkpoints, weights, broadcasts, rollouts) +# would be picked up by the new trainer — silently resuming with stale +# checkpoints, or colliding on ``step_*`` writes. Per-trial ``run_`` +# directories are *not* listed here; the controller manages those via +# ``_mark_inactive_multi_run_dirs_evicted``. +_TRAINER_OWNED_SUBDIRS = ("weights", "broadcasts", "rollouts", "run_default") + + +def _resolve_shared_ckpt_dir(shared_paths: list[Path], shared_dir: Path) -> Path: + """Find where the shared trainer's checkpoints land for this study. 
def prune_run(
    run_dir: Path,
    reason: str,
    *,
    step: int | None = None,
    value: float | None = None,
) -> None:
    """Pre-mark the trial pruned in ``status.json`` then write ``evicted.txt``.

    Order matters: the orchestrator's eviction handler raises ``RuntimeError``
    and the orchestrator exits non-zero. If we wrote ``evicted.txt`` first and
    crashed before updating ``status.json``, the launcher's exit-code
    reconciliation would misclassify the deliberately-pruned trial as
    ``failed`` (it has no way to know the eviction was a sampler decision
    rather than a slot-pressure eviction from the trainer).

    For the same reason the status update is atomic: the new content is
    written to a sibling temporary file and renamed over ``status.json``, so
    a crash mid-write leaves the previous, still-parseable status in place
    instead of a truncated file.

    Args:
        run_dir: The trial's run directory; holds ``status.json`` and the
            ``control/`` directory.
        reason: Human-readable explanation, stored in the status and written
            as the content of ``evicted.txt``.
        step: Step at which the prune decision was made, if known.
        value: Metric value that triggered the prune, if known.

    ``step`` and ``value`` are optional so callers without that context
    (manual prune, future heuristics) can still mark the trial pruned.
    """
    status_path = run_dir / "status.json"
    status = read_status_json(status_path)
    status["state"] = "pruned"
    status["pruned_reason"] = reason
    if step is not None:
        status["pruned_at_step"] = int(step)
    if value is not None:
        status["pruned_value"] = float(value)
    # Surfacing pruned trials with a None objective keeps the manifest
    # summary's best-value computation symmetric with single-trial pruning
    # (see materialize.record_trial_pruned).
    status["objective"] = None

    payload = json.dumps(status, indent=2, sort_keys=True) + "\n"
    # Stage next to the target so the rename stays on one filesystem, then
    # swap it in with a single atomic replace.
    staging_path = status_path.with_name(status_path.name + ".tmp")
    staging_path.write_text(payload)
    staging_path.replace(status_path)

    # Only after the status is durable do we signal the orchestrator.
    control_dir = run_dir / "control"
    control_dir.mkdir(parents=True, exist_ok=True)
    (control_dir / EVICTED_FILENAME).write_text(reason + "\n")
def _tell_wave_results(
    optuna: Any,
    study: optuna.Study,
    artifacts: list[TrialArtifacts],
    optuna_trials: list[optuna.Trial],
    metric: str,
    aggregate_returncode: int,
) -> tuple[int, list[float | None]]:
    """Settle every trial of a finished wave and report it to Optuna.

    Each artifact is reconciled against the wave's aggregate return code and
    the resulting state decides what Optuna is told:

    - ``completed`` with a readable objective: the objective value;
    - ``completed`` without one: ``FAIL``, counted as a failure;
    - ``pruned``: ``PRUNED``;
    - anything else: ``FAIL``, counted as a failure.

    Returns ``(failures, objectives)``; ``objectives[i]`` belongs to
    ``artifacts[i]`` and is ``None`` for every outcome except a completed
    trial with a readable objective.
    """
    trial_state = optuna.trial.TrialState
    reconciled_at = utc_now()
    failed = 0
    recorded: list[float | None] = []

    for artifact, optuna_trial in zip(artifacts, optuna_trials):
        outcome = reconcile_multi_run_artifact(
            artifact, aggregate_returncode=aggregate_returncode, finished_at=reconciled_at
        )

        if outcome == "pruned":
            study.tell(optuna_trial, state=trial_state.PRUNED)
            recorded.append(None)
            continue

        if outcome != "completed":
            # Any remaining state is treated as a failed trial.
            record_trial_objective(artifact.status_path, None)
            study.tell(optuna_trial, state=trial_state.FAIL)
            failed += 1
            recorded.append(None)
            continue

        objective = read_final_summary(artifact.run_dir, metric)
        record_trial_objective(artifact.status_path, objective)
        if objective is not None:
            study.tell(optuna_trial, objective)
        else:
            # The process exited cleanly but never produced the metric, so
            # Optuna learned nothing from this slot: report FAIL and count
            # it as a sweep failure.
            record_trial_missing_objective(artifact.status_path, metric)
            study.tell(optuna_trial, state=trial_state.FAIL)
            failed += 1
        recorded.append(objective)

    return failed, recorded
+ """ + optuna = _import_optuna() + strategy = config.strategy + scheduler = config.scheduler + assert isinstance(strategy, OptunaStrategyConfig) + assert isinstance(scheduler, MultiRunLoRASchedulerConfig) + assert config.objective is not None # validated upstream + + study = _create_study(optuna, config) + metric = config.objective.metric + wave_size = scheduler.max_concurrent_runs + total = strategy.num_trials + poll_interval = strategy.poll_interval_seconds + + tracker = TrialOutcomeTracker(config.objective, config.early_stopping) + shared_dir = multi_run_shared_dir(config) + shared_dir.mkdir(parents=True, exist_ok=True) + ckpt_dir = _resolve_shared_ckpt_dir(scheduler.shared, shared_dir) + + all_artifacts: list[TrialArtifacts] = [] + failures = 0 + submitted = 0 + wave_index = 0 + + while submitted < total: + if tracker.halted: + break + this_wave = min(wave_size, total - submitted) + stop_after_wave = False + + # Wipe trainer-owned artifacts from the prior wave before + # materializing or launching this one. Wave 0 starts on a clean + # ``shared_dir`` (just created above), so this is a no-op then; + # subsequent waves inherit checkpoints/weights/broadcasts/rollouts + # from the previous trainer and would otherwise resume or collide. + if wave_index > 0: + _reset_trainer_state_for_wave(shared_dir, ckpt_dir) + wave_index += 1 + + # 1. Ask Optuna for `this_wave` trials and materialize each. + # Failed materializations are excluded from the launch wave but kept + # in the manifest; survivors stay paired so the poll loop and + # reconcile step have aligned (optuna_trial, artifact) lists. 
+ wave_pairs: list[tuple[optuna.Trial, TrialArtifacts]] = [] + wave_artifacts_for_manifest: list[TrialArtifacts] = [] + for offset in range(this_wave): + optuna_trial = study.ask() + params = _suggest_parameters(optuna_trial, config.parameters) + sweep_trial = _make_trial(submitted + offset, params) + try: + artifact = materialize_multi_run_trial(config, sweep_trial, scheduler) + except Exception as exc: + artifact = record_multi_run_materialization_failure( + config, sweep_trial, scheduler, exc, finished_at=utc_now() + ) + wave_artifacts_for_manifest.append(artifact) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + failures += 1 + print(f"Optuna trial {sweep_trial.id} failed materialization: {exc}") + if not config.continue_on_failure: + stop_after_wave = True + break + continue + wave_pairs.append((optuna_trial, artifact)) + wave_artifacts_for_manifest.append(artifact) + + if stop_after_wave: + if wave_pairs: + finished_at = utc_now() + for optuna_trial, artifact in wave_pairs: + _write_status( + artifact, + state="failed", + finished_at=finished_at, + returncode=-1, + objective=None, + failure_stage="scheduler", + error=( + "Trial was not launched because another trial in the same Optuna " + "multi_run_lora wave failed materialization and continue_on_failure=false." + ), + ) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + failures += 1 + if wave_artifacts_for_manifest: + all_artifacts.extend(wave_artifacts_for_manifest) + write_manifest_with_variants(config, [build_variant(a) for a in all_artifacts]) + break + + # If every trial in the wave failed materialization there's nothing + # to launch; advance the counter and try the next wave. 
+ if not wave_pairs: + if wave_artifacts_for_manifest: + all_artifacts.extend(wave_artifacts_for_manifest) + write_manifest_with_variants(config, [build_variant(a) for a in all_artifacts]) + submitted += this_wave + continue + + wave_optuna_trials = [pair[0] for pair in wave_pairs] + wave_artifacts = [pair[1] for pair in wave_pairs] + all_artifacts.extend(wave_artifacts_for_manifest) + write_manifest_with_variants(config, [build_variant(a) for a in all_artifacts]) + + # 2. Spawn rl-multi-run for this wave. + command = build_multi_run_command(wave_artifacts, scheduler.shared, shared_dir) + proc = None + attempts = 0 + while proc is None: + attempts += 1 + started = utc_now() + for artifact in wave_artifacts: + _reset_multi_run_artifact_runtime(artifact) + _write_status(artifact, state="running", started_at=started, attempts=attempts, gpu_group=None) + _mark_inactive_multi_run_dirs_evicted( + shared_dir, + [artifact.run_dir for artifact in wave_artifacts], + reason="Inactive run directory is not part of the current Optuna wave.", + ) + try: + proc = subprocess.Popen(command) + except OSError as exc: + if attempts <= config.retry_budget: + continue + finished_at = utc_now() + wave_failures = 0 + for optuna_trial, artifact in zip(wave_optuna_trials, wave_artifacts): + _write_launch_failure_status(artifact, exc, finished_at=finished_at) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + wave_failures += 1 + failures += wave_failures + + for artifact in wave_artifacts: + tracker.observe( + TrialOutcome( + trial_id=artifact.trial.id, + label=artifact.trial.label, + objective=None, + ) + ) + + submitted += this_wave + if not config.continue_on_failure: + stop_after_wave = True + break + + if proc is None: + if stop_after_wave: + break + continue + + try: + _poll_wave_for_pruning( + optuna, proc, wave_artifacts, wave_optuna_trials, metric, poll_interval + ) + finally: + proc.wait() + + # 4. Reconcile per-trial state and tell Optuna. 
# Exposed here so the controller can refresh the manifest summary.
def tracker_summary(tracker: TrialOutcomeTracker) -> dict[str, Any]:
    """Return the tracker's summary dataclass flattened into a plain dict."""
    summary = tracker.summary()
    return asdict(summary)
+""" + +from __future__ import annotations + +import json +import os +import signal +import subprocess +from collections import Counter +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +from prime_rl.configs.sweep import ( + AshaPrunerConfig, + ChoiceParameterConfig, + HyperbandPrunerConfig, + IntUniformParameterConfig, + LocalSweepSchedulerConfig, + LogUniformParameterConfig, + MedianPrunerConfig, + NoPrunerConfig, + OptunaStrategyConfig, + PrunerConfig, + SlurmSweepSchedulerConfig, + SweepConfig, + SweepParameterConfig, + UniformParameterConfig, +) +from prime_rl.sweep.early_stopping import TrialOutcome, TrialOutcomeTracker +from prime_rl.sweep.materialize import ( + Trial, + TrialArtifacts, + materialize_trial, + record_trial_materialization_failure, + record_trial_missing_objective, + record_trial_objective, + record_trial_pruned, + write_json, +) +from prime_rl.sweep.metrics import coerce_finite_float, read_final_summary, read_intermediate_metric +from prime_rl.sweep.reproducibility import file_checksum +from prime_rl.sweep.schedulers import ( + _build_env, + _reset_metrics_jsonl, + _run_trial_with_pruning_slurm_sync, + _run_with_retries, + _run_with_retries_slurm_sync, + _write_launch_failure_status, + _write_status, + utc_now, +) +from prime_rl.sweep.search import parameters_hash, trial_label + +if TYPE_CHECKING: # pragma: no cover + import optuna + + +def _import_optuna() -> Any: + try: + import optuna # type: ignore[import-not-found] + except ImportError as exc: + raise SystemExit( + "Optuna strategy requires the [hpo] extra. 
def _suggest_parameters(
    optuna_trial: optuna.Trial,
    parameters: dict[str, SweepParameterConfig],
) -> dict[str, Any]:
    """Draw one value per sweep parameter from ``optuna_trial``.

    Each entry of ``parameters`` maps a parameter path to its search-space
    config; the path doubles as the Optuna parameter name. The returned dict
    has the same keys, in the same order, with the suggested values.

    Raises:
        ValueError: If a config is of a type with no Optuna mapping here.
    """

    def _draw(path: str, spec: SweepParameterConfig) -> Any:
        # One suggest_* call per supported search-space kind.
        if isinstance(spec, ChoiceParameterConfig):
            return optuna_trial.suggest_categorical(path, spec.values)
        if isinstance(spec, UniformParameterConfig):
            return optuna_trial.suggest_float(path, spec.min, spec.max)
        if isinstance(spec, LogUniformParameterConfig):
            return optuna_trial.suggest_float(path, spec.min, spec.max, log=True)
        if isinstance(spec, IntUniformParameterConfig):
            return optuna_trial.suggest_int(path, spec.min, spec.max, step=spec.step)
        raise ValueError(f"Unsupported parameter type for Optuna: {type(spec)!r}")

    return {path: _draw(path, spec) for path, spec in parameters.items()}
def _build_pruner(optuna: Any, pruner: PrunerConfig) -> Any:
    """Translate a ``PrunerConfig`` variant into the matching Optuna pruner.

    ``NoPrunerConfig`` maps to ``NopPruner`` so that study creation always
    receives an explicit pruner object. Every other variant forwards its
    fields as keyword arguments to the corresponding Optuna class.

    Raises:
        ValueError: If ``pruner`` is not one of the handled config types.
    """
    if isinstance(pruner, NoPrunerConfig):
        return optuna.pruners.NopPruner()

    if isinstance(pruner, MedianPrunerConfig):
        factory = optuna.pruners.MedianPruner
        options = dict(
            n_startup_trials=pruner.n_startup_trials,
            n_warmup_steps=pruner.n_warmup_steps,
            interval_steps=pruner.interval_steps,
        )
    elif isinstance(pruner, AshaPrunerConfig):
        factory = optuna.pruners.SuccessiveHalvingPruner
        options = dict(
            min_resource=pruner.min_resource,
            reduction_factor=pruner.reduction_factor,
            min_early_stopping_rate=pruner.min_early_stopping_rate,
        )
    elif isinstance(pruner, HyperbandPrunerConfig):
        factory = optuna.pruners.HyperbandPruner
        options = dict(
            min_resource=pruner.min_resource,
            max_resource=pruner.max_resource,
            reduction_factor=pruner.reduction_factor,
        )
    else:
        raise ValueError(f"Unsupported Optuna pruner: {pruner!r}")

    return factory(**options)
def _terminate_process_group(process: subprocess.Popen[bytes], grace_seconds: float = 10.0) -> None:
    """Stop ``process`` together with the rest of its process group.

    Trials are launched with ``start_new_session=True``, so the trial and
    whatever it spawns share one process group. The group first receives
    SIGTERM; if ``process`` has not exited within ``grace_seconds`` the group
    receives SIGKILL and ``process`` is awaited for another ``grace_seconds``.
    Signalling only the parent would leave its children running, holding
    GPUs and skewing the next trial's measurements.

    Nothing is done when ``process`` has already exited or when its process
    group can no longer be found.
    """
    if process.poll() is not None:
        return
    try:
        group_id = os.getpgid(process.pid)
    except ProcessLookupError:
        return

    # Escalation ladder: polite request first, forced kill second.
    for sig in (signal.SIGTERM, signal.SIGKILL):
        try:
            os.killpg(group_id, sig)
        except ProcessLookupError:
            return
        try:
            process.wait(timeout=grace_seconds)
        except subprocess.TimeoutExpired:
            continue
        return
+ """ + env = _build_env(artifact, gpu_group) + _reset_metrics_jsonl(artifact) + _write_status( + artifact, + state="running", + started_at=utc_now(), + attempts=attempt, + gpu_group=list(gpu_group) if gpu_group is not None else None, + ) + try: + process = subprocess.Popen(artifact.command, env=env, start_new_session=True) + except OSError as exc: + return _PollingOutcome( + state="failed", + returncode=-1, + objective=None, + launch_error=True, + launch_exception=exc, + ) + last_reported_step: int | None = None + reports_sent = 0 + + try: + while True: + try: + returncode = process.wait(timeout=poll_interval) + except subprocess.TimeoutExpired: + returncode = None + + sample = read_intermediate_metric(artifact.run_dir, metric) + if sample is not None: + step, value = sample + if last_reported_step is None or step > last_reported_step: + optuna_trial.report(value, step) + last_reported_step = step + reports_sent += 1 + # Only consider pruning while the trial is still running. + # If the subprocess already exited, the run produced its + # final objective and pruning would discard a valid value. + current_returncode = process.poll() + if current_returncode is not None: + returncode = current_returncode + elif optuna_trial.should_prune(): + _terminate_process_group(process) + pruned_returncode = process.returncode if process.returncode is not None else -1 + record_trial_pruned( + artifact.status_path, + step, + value, + returncode=pruned_returncode, + finished_at=utc_now(), + ) + return _PollingOutcome( + state="pruned", + returncode=pruned_returncode, + objective=None, + pruned_at_step=step, + pruned_value=value, + reports_sent=reports_sent, + ) + + if returncode is not None: + break + finally: + # Belt-and-suspenders: if we exit through an unexpected path the + # trial process must not be left running. 
+ _terminate_process_group(process) + + if returncode == 0: + objective = read_final_summary(artifact.run_dir, metric) + _write_status(artifact, state="completed", finished_at=utc_now(), returncode=0) + return _PollingOutcome( + state="completed", returncode=0, objective=objective, reports_sent=reports_sent + ) + + _write_status(artifact, state="failed", finished_at=utc_now(), returncode=returncode) + return _PollingOutcome( + state="failed", returncode=returncode, objective=None, reports_sent=reports_sent + ) + + +def _run_trial_with_pruning_slurm_sync_and_retries( + artifact: TrialArtifacts, + optuna_trial: optuna.Trial, + metric: str, + poll_interval: float, + retry_budget: int, +) -> _PollingOutcome: + """SLURM-sync analog of ``_run_trial_with_pruning_and_retries``. + + Same retry contract: ``completed`` and ``pruned`` return immediately, only + ``failed`` outcomes are retried, and any sent reports disable retry to + avoid biasing the pruner across attempts. + """ + attempts = 0 + while True: + attempts += 1 + outcome = _run_trial_with_pruning_slurm_sync( + artifact, + optuna_trial, + metric, + poll_interval, + attempt=attempts, + ) + if outcome.state in ("completed", "pruned"): + return outcome + if outcome.unsafe_to_continue: + # The underlying SLURM job may still be alive — retrying would + # try to submit a second job for the same trial. Bail out so + # the outer loop can halt the sweep. 
def _run_trial_with_pruning_and_retries(
    artifact: TrialArtifacts,
    gpu_group: list[int] | None,
    optuna_trial: optuna.Trial,
    metric: str,
    poll_interval: float,
    retry_budget: int,
) -> _PollingOutcome:
    """Run the polling driver, retrying failed attempts within ``retry_budget``.

    ``completed`` and ``pruned`` outcomes are returned as-is, so a trial that
    was stopped on purpose is never started again. A failed attempt is
    retried until more than ``retry_budget`` attempts have been made, with
    one exception: once the failed attempt has sent any intermediate report
    to Optuna, no retry happens. Reports accumulate on the Optuna trial
    across attempts, so stale values would bias the pruner during the retry.

    When launching the subprocess itself keeps failing and the budget is
    spent, the launch exception (if recorded) is written to the trial status
    before the outcome is returned.
    """
    attempt = 1
    while True:
        outcome = _run_trial_with_pruning(
            artifact,
            gpu_group,
            optuna_trial,
            metric,
            poll_interval,
            attempt=attempt,
        )
        if outcome.state in ("completed", "pruned"):
            return outcome

        out_of_retries = attempt > retry_budget
        if outcome.launch_error:
            if out_of_retries:
                if outcome.launch_exception is not None:
                    _write_launch_failure_status(artifact, outcome.launch_exception)
                return outcome
        elif outcome.reports_sent > 0 or out_of_retries:
            # Either stale reports would bias a retry or the budget is gone;
            # hand the failure back so the caller records TrialState.FAIL.
            return outcome
        attempt += 1
+ ) + return variants + + +def _variants_for_trial_number(previous_variants: list[dict[str, Any]], trial_number: int) -> list[dict[str, Any]]: + prefix = f"{trial_number:04d}-" + return [variant for variant in previous_variants if str(variant.get("id", "")).startswith(prefix)] + + +def _variant_trial_number(variant: dict[str, Any]) -> int | None: + raw_id = variant.get("id") + if not isinstance(raw_id, str): + return None + prefix, separator, _ = raw_id.partition("-") + if separator != "-" or not prefix.isdigit(): + return None + return int(prefix) + + +def _read_manifest_status(status_path: Path, trial_number: int) -> dict[str, Any]: + try: + status = json.loads(status_path.read_text()) + except json.JSONDecodeError as exc: + raise RuntimeError( + "Optuna resume requires status.json files to be valid JSON objects. " + f"Invalid status for trial number {trial_number} at {status_path}. " + "Restore the trial artifacts or start fresh." + ) from exc + if not isinstance(status, dict): + raise RuntimeError( + "Optuna resume requires status.json files to be valid JSON objects. " + f"Status for trial number {trial_number} at {status_path} is {type(status).__name__}. " + "Restore the trial artifacts or start fresh." 
+ ) + return status + + +def _validate_resume_manifest_coverage(study: Any, previous_variants: list[dict[str, Any]]) -> None: + """Fail closed when Optuna storage has trials the manifest cannot describe.""" + storage_numbers = {trial.number for trial in study.trials} + manifest_numbers = [_variant_trial_number(variant) for variant in previous_variants] + invalid_manifest_ids = [ + variant.get("id") for variant, number in zip(previous_variants, manifest_numbers) if number is None + ] + manifest_counts = Counter(number for number in manifest_numbers if number is not None) + manifest_without_storage = sorted(number for number in manifest_counts if number not in storage_numbers) + missing: list[int] = [] + duplicate = sorted(number for number, count in manifest_counts.items() if count > 1) + missing_status: list[int] = [] + missing_resolved_checksum: list[int] = [] + status_id_mismatches: list[str] = [] + for trial in study.trials: + variants = _variants_for_trial_number(previous_variants, trial.number) + if not variants: + missing.append(trial.number) + continue + variant_id = variants[0].get("id") + raw_status_path = variants[0].get("status_path") + if not isinstance(raw_status_path, str) or not raw_status_path: + missing_status.append(trial.number) + continue + status_path = Path(raw_status_path) + if not status_path.is_file(): + missing_status.append(trial.number) + continue + status = _read_manifest_status(status_path, trial.number) + status_id = status.get("id") + if status_id != variant_id: + status_id_mismatches.append( + f"trial {trial.number}: manifest id={variant_id!r}, status id={status_id!r}" + ) + resolved_checksum = variants[0].get("resolved_checksum") + if not isinstance(resolved_checksum, str): + missing_resolved_checksum.append(trial.number) + if missing: + raise RuntimeError( + "Optuna resume requires manifest variant entries for all existing storage trials. " + f"Missing trial number(s): {missing}. 
Restore the manifest or start a fresh study/output_dir." + ) + if duplicate: + raise RuntimeError( + "Optuna resume requires exactly one manifest variant entry per existing storage trial. " + f"Duplicate trial number(s): {duplicate}. Repair the manifest or start a fresh study/output_dir." + ) + if invalid_manifest_ids or manifest_without_storage: + raise RuntimeError( + "Optuna resume requires manifest variant entries and storage trials to agree. " + f"Invalid manifest id(s): {invalid_manifest_ids}; " + f"manifest trial number(s) missing from storage: {manifest_without_storage}. " + "Restore the Optuna storage or start a fresh study/output_dir." + ) + if missing_status: + raise RuntimeError( + "Optuna resume requires status.json files for all existing storage trials. " + f"Missing status for trial number(s): {missing_status}. Restore the trial artifacts or start fresh." + ) + if missing_resolved_checksum: + raise RuntimeError( + "Optuna resume requires manifest resolved_checksum entries for all existing storage trials. " + f"Missing resolved_checksum for trial number(s): {missing_resolved_checksum}. " + "Restore the manifest or start a fresh study/output_dir." + ) + if status_id_mismatches: + raise RuntimeError( + "Optuna resume requires each status.json id to match its manifest variant id. " + f"Mismatch(es): {status_id_mismatches}. Restore the trial artifacts or start fresh." + ) + + +def _validate_resume_manifest_trial_parameters( + study: Any, previous_variants: list[dict[str, Any]] +) -> None: + """Fail closed when manifest variant ids/overrides drift from Optuna storage.""" + mismatches: list[str] = [] + for trial in study.trials: + variants = _variants_for_trial_number(previous_variants, trial.number) + if len(variants) != 1: + # _validate_resume_manifest_coverage reports missing / duplicate + # variants; keep this check focused on parameter identity. 
def _validate_resume_base_checksums(
    config: SweepConfig, study: Any, previous_variants: list[dict[str, Any]]
) -> None:
    """Fail closed when resumed Optuna trials came from different base TOML(s).

    The checksums of the files in ``config.base`` (keyed by POSIX path) are
    compared with the ``base_checksums`` mapping recorded on each storage
    trial's manifest variant. Trials without exactly one matching variant are
    skipped here.

    Raises:
        RuntimeError: When any inspected trial has no recorded checksums, or
            its recorded paths/checksums differ from the current base files.
    """
    current_base_checksums = {base.as_posix(): file_checksum(base) for base in config.base}
    current_paths = set(current_base_checksums)

    missing_checksum_trials: list[int] = []
    missing_bases: dict[int, list[str]] = {}
    extra_bases: dict[int, list[str]] = {}
    changed_bases: dict[int, list[str]] = {}

    for trial in study.trials:
        matches = _variants_for_trial_number(previous_variants, trial.number)
        if len(matches) != 1:
            # Missing / duplicate variants are reported by
            # _validate_resume_manifest_coverage; this check only looks at
            # checksum drift.
            continue

        recorded = matches[0].get("base_checksums")
        if not (isinstance(recorded, dict) and recorded):
            missing_checksum_trials.append(trial.number)
            continue

        recorded_checksums = {str(path): checksum for path, checksum in recorded.items()}
        recorded_paths = set(recorded_checksums)

        not_recorded = sorted(current_paths - recorded_paths)
        no_longer_configured = sorted(recorded_paths - current_paths)
        drifted = sorted(
            path
            for path in current_paths & recorded_paths
            if recorded_checksums[path] != current_base_checksums[path]
        )
        if not_recorded:
            missing_bases[trial.number] = not_recorded
        if no_longer_configured:
            extra_bases[trial.number] = no_longer_configured
        if drifted:
            changed_bases[trial.number] = drifted

    if missing_checksum_trials or missing_bases or extra_bases or changed_bases:
        raise RuntimeError(
            "Optuna resume requires manifest base config checksums for every existing storage trial, "
            "and they must match the current base files. "
            f"Missing checksum trial(s): {missing_checksum_trials}; "
            f"missing base(s): {missing_bases}; extra base(s): {extra_bases}; "
            f"changed base(s): {changed_bases}. "
            "Restore the original base config(s) or start a fresh study/output_dir."
        )
+ continue + + recorded_state = status.get("state") + if trial.state == optuna.trial.TrialState.COMPLETE: + status_objective = coerce_finite_float(status.get("objective")) + storage_objective = coerce_finite_float(trial.value) + if ( + recorded_state != "completed" + or status_objective is None + or storage_objective is None + or status_objective != storage_objective + ): + mismatches.append( + f"trial {trial.number}: storage=COMPLETE({storage_objective!r}), " + f"status={recorded_state!r}({status.get('objective')!r})" + ) + elif trial.state == optuna.trial.TrialState.PRUNED: + status_objective = coerce_finite_float(status.get("objective")) + if recorded_state != "pruned" or status_objective is not None: + mismatches.append( + f"trial {trial.number}: storage=PRUNED, " + f"status={recorded_state!r}({status.get('objective')!r})" + ) + elif trial.state == optuna.trial.TrialState.FAIL: + status_objective = coerce_finite_float(status.get("objective")) + if recorded_state != "failed" or status_objective is not None: + mismatches.append( + f"trial {trial.number}: storage=FAIL, " + f"status={recorded_state!r}({status.get('objective')!r})" + ) + else: + mismatches.append( + f"trial {trial.number}: unsupported storage state {trial.state!r} " + f"with status={recorded_state!r}" + ) + + if mismatches: + raise RuntimeError( + "Optuna resume requires terminal status.json files to match Optuna storage state. " + f"Mismatch(es): {mismatches}. Restore the trial artifacts/storage or start fresh." 
+ ) + + +def _count_optuna_failures(optuna: Any, study: Any) -> int: + return sum(1 for trial in study.trials if trial.state == optuna.trial.TrialState.FAIL) + + +def _seed_tracker_from_previous(tracker: TrialOutcomeTracker, previous_variants: list[dict[str, Any]]) -> None: + for variant in previous_variants: + raw_status_path = variant.get("status_path") + if not isinstance(raw_status_path, str) or not raw_status_path: + continue + status_path = Path(raw_status_path) + if not status_path.is_file(): + continue + trial_number = _variant_trial_number(variant) + status = _read_manifest_status(status_path, -1 if trial_number is None else trial_number) + if status.get("state") != "completed": + continue + tracker.observe( + TrialOutcome( + trial_id=variant.get("id", ""), + label=variant.get("label", "") or variant.get("id", ""), + objective=status.get("objective"), + ) + ) + + +def _variant_status_for_trial_number( + previous_variants: list[dict[str, Any]], + trial_number: int, +) -> tuple[Path | None, dict[str, Any] | None]: + """Match an Optuna trial number to its sweep status via the ``NNNN-...`` id prefix.""" + prefix = f"{trial_number:04d}-" + for variant in previous_variants: + if not variant.get("id", "").startswith(prefix): + continue + raw_status_path = variant.get("status_path") + if not isinstance(raw_status_path, str) or not raw_status_path: + return None, None + status_path = Path(raw_status_path) + if not status_path.is_file(): + return status_path, None + return status_path, _read_manifest_status(status_path, trial_number) + return None, None + + +def _reconcile_running_trials( + optuna: Any, study: Any, previous_variants: list[dict[str, Any]] +) -> tuple[int, int]: + """Tell Optuna about any RUNNING trials left over from an interrupted run. + + A controller crash between ``study.ask()`` and ``study.tell()`` leaves a + trial RUNNING in persistent storage forever. 
On resume we walk those + trials and: + + - if the matching sweep status.json shows ``completed`` with a finite + objective, tell Optuna the value so adaptive sampling can use it; + - if status.json shows ``pruned``, tell ``TrialState.PRUNED`` so the + sampler treats the slot as a deliberate stop (a crash between + ``record_trial_pruned`` and ``study.tell(PRUNED)`` would otherwise + misclassify it as a failure); + - otherwise tell ``TrialState.FAIL`` and mark any matching stale + status file failed so Optuna storage and the sweep manifest agree. + + Returns ``(reconciled, failures)``. ``failures`` counts RUNNING trials + reconciled to ``TrialState.FAIL`` so the sweep process exits non-zero in + the same way it would have if the original controller had observed the + failure before crashing. + """ + reconciled = 0 + failures = 0 + for trial in study.trials: + if trial.state != optuna.trial.TrialState.RUNNING: + continue + status_path, status = _variant_status_for_trial_number(previous_variants, trial.number) + recorded_state = status.get("state") if status is not None else None + objective: float | None = None + if recorded_state == "completed": + value = status.get("objective") if status is not None else None + objective = coerce_finite_float(value) + # study.tell() accepts a trial number or a Trial; FrozenTrial is not + # accepted, so pass trial.number. 
def run_optuna_sweep(
    config: SweepConfig,
    write_manifest_with_variants: Any,
    build_variant: Any,
) -> tuple[int, TrialOutcomeTracker | None, list[TrialArtifacts]]:
    """Drive an Optuna study end-to-end, one trial at a time.

    Returns ``(failures, tracker, artifacts)`` so the caller can write the
    final manifest summary and exit code in the same shape as the static
    flow. Resume honors persistent storage: previously consumed slots in
    ``study.trials`` are not re-asked, the manifest preserves earlier
    variants, and the tracker is seeded from prior outcomes.

    Args:
        config: Sweep configuration; its strategy must be an
            ``OptunaStrategyConfig`` and its scheduler a local or SLURM
            scheduler config (both asserted below).
        write_manifest_with_variants: Callable invoked as
            ``write_manifest_with_variants(config, variants)`` after every
            materialization and once more at the end.
        build_variant: Callable turning one ``TrialArtifacts`` into the
            manifest variant entry appended after ``previous_variants``.

    Returns:
        ``failures`` (on resume, starts from the number of FAIL trials already
        in Optuna storage), the ``TrialOutcomeTracker`` (``None`` when
        ``config.objective`` is falsy), and the artifacts created in this
        invocation.
    """
    optuna = _import_optuna()
    strategy = config.strategy
    assert isinstance(strategy, OptunaStrategyConfig)
    # Local and synchronous-SLURM schedulers are both supported. Asynchronous
    # SLURM and multi_run_lora are rejected upstream by the SweepConfig validator.
    assert isinstance(config.scheduler, (LocalSweepSchedulerConfig, SlurmSweepSchedulerConfig))

    study = _create_study(optuna, config)

    # Trials run one at a time here, so only the first configured GPU group
    # (if any) is ever handed to the local runner.
    if isinstance(config.scheduler, LocalSweepSchedulerConfig):
        gpu_groups = (
            config.scheduler.gpu_assignment.visible_devices
            if config.scheduler.gpu_assignment is not None
            else None
        )
        gpu_group = gpu_groups[0] if gpu_groups else None
    else:
        gpu_group = None

    tracker = TrialOutcomeTracker(config.objective, config.early_stopping) if config.objective else None

    previous_variants = _load_previous_variants(config) if config.resume else []
    artifacts: list[TrialArtifacts] = []
    failures = 0
    if config.resume:
        # Each validator raises RuntimeError on disagreement between the
        # manifest, the status files and Optuna storage, so nothing below
        # runs against inconsistent state.
        _validate_resume_manifest_coverage(study, previous_variants)
        _validate_resume_manifest_trial_parameters(study, previous_variants)
        _validate_resume_base_checksums(config, study, previous_variants)
        _validate_resume_status_consistency(optuna, study, previous_variants)
        reconciled, _ = _reconcile_running_trials(optuna, study, previous_variants)
        if reconciled:
            print(f"Reconciled {reconciled} RUNNING Optuna trial(s) from interrupted resume.")
        # Counted after reconciliation so trials just told FAIL are included.
        failures = _count_optuna_failures(optuna, study)
        if tracker is not None:
            _seed_tracker_from_previous(tracker, previous_variants)
        if failures > 0 and not config.continue_on_failure:
            return failures, tracker, artifacts

    # On resume, new trial indices continue after the trials already in storage.
    already_consumed = len(study.trials) if config.resume else 0

    for index in range(already_consumed, strategy.num_trials):
        if tracker is not None and tracker.halted:
            break

        optuna_trial = study.ask()
        params = _suggest_parameters(optuna_trial, config.parameters)
        trial = _make_trial(index, params)

        try:
            artifact = materialize_trial(config, trial)
        except Exception as exc:
            # Sampled parameters failed target-config validation. Mark the
            # asked trial failed in Optuna so persistent storage doesn't
            # leak a RUNNING slot, and write a manifest/status artifact so a
            # later resume can account for the terminal storage trial.
            artifact = record_trial_materialization_failure(
                config, trial, exc, finished_at=utc_now()
            )
            artifacts.append(artifact)
            write_manifest_with_variants(
                config, previous_variants + [build_variant(a) for a in artifacts]
            )
            study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL)
            failures += 1
            print(f"Optuna trial {index:04d} failed materialization: {exc}")
            if not config.continue_on_failure:
                break
            continue

        artifacts.append(artifact)
        # The manifest is rewritten before the trial runs, so it already
        # lists this trial while the trial is executing.
        write_manifest_with_variants(config, previous_variants + [build_variant(a) for a in artifacts])
        stop_after_trial = False

        slurm_sync = (
            isinstance(config.scheduler, SlurmSweepSchedulerConfig)
            and config.scheduler.synchronous
        )
        if isinstance(strategy.pruner, NoPrunerConfig):
            # No pruner: run to completion, then read the final objective.
            if slurm_sync:
                returncode = _run_with_retries_slurm_sync(artifact, config.retry_budget)
            else:
                returncode = _run_with_retries(artifact, gpu_group, config.retry_budget)
            objective_value = (
                read_final_summary(artifact.run_dir, config.objective.metric)
                if returncode == 0 and config.objective is not None
                else None
            )
            record_trial_objective(artifact.status_path, objective_value)

            if objective_value is None:
                if returncode == 0:
                    record_trial_missing_objective(artifact.status_path, config.objective.metric)
                study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL)
            else:
                study.tell(optuna_trial, objective_value)

            # A clean exit without a logged objective (returncode==0 but
            # objective_value is None) is also a sweep-level failure: Optuna
            # learned nothing from it, the sampler recorded TrialState.FAIL,
            # and the user almost certainly wants to be alerted rather than
            # let the sweep finish 'successfully' with no usable results.
            if returncode != 0 or objective_value is None:
                failures += 1
                if not config.continue_on_failure:
                    stop_after_trial = True
        else:
            # Pruner configured: the polling runners return an outcome whose
            # state is completed, pruned, or failed.
            if slurm_sync:
                outcome = _run_trial_with_pruning_slurm_sync_and_retries(
                    artifact,
                    optuna_trial,
                    config.objective.metric,
                    strategy.poll_interval_seconds,
                    config.retry_budget,
                )
            else:
                outcome = _run_trial_with_pruning_and_retries(
                    artifact,
                    gpu_group,
                    optuna_trial,
                    config.objective.metric,
                    strategy.poll_interval_seconds,
                    config.retry_budget,
                )
            objective_value = outcome.objective
            if outcome.state == "completed":
                record_trial_objective(artifact.status_path, objective_value)
                if objective_value is None:
                    # Completed without a recorded objective (e.g. metric
                    # never logged): treat as a sweep-level failure too,
                    # not just an Optuna FAIL — the sweep produced no
                    # usable result for this trial.
                    record_trial_missing_objective(artifact.status_path, config.objective.metric)
                    study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL)
                    failures += 1
                    if not config.continue_on_failure:
                        stop_after_trial = True
                else:
                    study.tell(optuna_trial, objective_value)
            elif outcome.state == "pruned":
                # record_trial_pruned already set status.json fields.
                study.tell(optuna_trial, state=optuna.trial.TrialState.PRUNED)
            else:  # failed
                record_trial_objective(artifact.status_path, None)
                study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL)
                failures += 1
                if not config.continue_on_failure or outcome.unsafe_to_continue:
                    # ``unsafe_to_continue`` forces a halt even when the user
                    # set ``continue_on_failure=True``: the underlying SLURM
                    # job may still be running (persistent squeue failure or
                    # unconfirmed scancel), and submitting the next trial
                    # would race the still-active allocation.
                    stop_after_trial = True

        # Every launched trial is shown to the tracker, including failed and
        # pruned ones (objective_value may be None); a truthy return from
        # observe() ends the sweep.
        if tracker is not None:
            tracker_outcome = TrialOutcome(
                trial_id=trial.id,
                label=trial.label,
                objective=objective_value,
            )
            if tracker.observe(tracker_outcome):
                break
        if stop_after_trial:
            break

    write_manifest_with_variants(config, previous_variants + [build_variant(a) for a in artifacts])

    return failures, tracker, artifacts
+ """ + return _run_trial_with_pruning_slurm_sync_and_retries( + artifact, + optuna_trial, + metric, + poll_interval, + retry_budget, + ) + + +def run_optuna_sweep_parallel_slurm( + config: SweepConfig, + write_manifest_with_variants: Any, + build_variant: Any, +) -> tuple[int, TrialOutcomeTracker | None, list[TrialArtifacts]]: + """Drive an Optuna study with up to ``max_parallel`` concurrent SLURM jobs. + + Architecture: the main thread owns the Optuna study and all ask/tell + interactions; a ``ThreadPoolExecutor`` with ``max_workers=max_parallel`` + runs the per-trial ``sbatch --wait`` calls so up to N trials can be + in flight at once. As each future completes the main thread tells + Optuna the outcome and immediately asks for one more trial to refill + the slot, until ``num_trials`` is reached or a halt condition fires. + + Pruners are supported under parallel SLURM-sync as of this PR. Each + worker holds its own ``optuna_trial`` (each from a distinct + ``study.ask()``) and runs the same polling loop the serial pruner + runner uses. Optuna's storage backend serializes the concurrent + ``report``/``should_prune`` calls across threads, the same contract + that makes ``study.optimize(n_jobs>1)`` work in stock Optuna. + + Halt conditions: + - ``tracker.observe`` returns True → stop submitting new trials, wait + for in-flight to finish (early stopping). + - ``unsafe_to_continue`` from a worker (persistent squeue failure or + unconfirmed scancel during a prune) → halt new submissions, drain. + Only fires when a pruner is active (the no-pruner runner doesn't + emit it). + - ``not config.continue_on_failure`` after any failure → halt new + submissions, drain. 
+ """ + from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait + + optuna = _import_optuna() + strategy = config.strategy + assert isinstance(strategy, OptunaStrategyConfig) + assert isinstance(config.scheduler, SlurmSweepSchedulerConfig) + assert config.scheduler.synchronous, "parallel SLURM requires synchronous=true" + max_parallel = config.scheduler.max_parallel + assert max_parallel > 1 + use_pruner = not isinstance(strategy.pruner, NoPrunerConfig) + + study = _create_study(optuna, config) + tracker = TrialOutcomeTracker(config.objective, config.early_stopping) if config.objective else None + + previous_variants = _load_previous_variants(config) if config.resume else [] + artifacts: list[TrialArtifacts] = [] + failures = 0 + if config.resume: + _validate_resume_manifest_coverage(study, previous_variants) + _validate_resume_manifest_trial_parameters(study, previous_variants) + _validate_resume_base_checksums(config, study, previous_variants) + _validate_resume_status_consistency(optuna, study, previous_variants) + reconciled, _ = _reconcile_running_trials(optuna, study, previous_variants) + if reconciled: + print(f"Reconciled {reconciled} RUNNING Optuna trial(s) from interrupted resume.") + failures = _count_optuna_failures(optuna, study) + if tracker is not None: + _seed_tracker_from_previous(tracker, previous_variants) + if failures > 0 and not config.continue_on_failure: + return failures, tracker, artifacts + + already_consumed = len(study.trials) if config.resume else 0 + next_index = already_consumed + halted = False + # Manifest writes happen on the main thread (single producer), so no + # lock is needed — but Optuna's in-memory study is read by the + # asker; ask/tell calls are all serialized by the main thread loop. + + def ask_and_materialize() -> tuple[Any, TrialArtifacts | None] | None: + """Get the next trial from Optuna and materialize it. 
+ + Returns ``(optuna_trial, artifact)`` on success, ``(optuna_trial, None)`` + when materialization fails (caller marks Optuna FAIL and skips), or + ``None`` when there are no more trials to ask for. + """ + nonlocal next_index, failures + if halted or next_index >= strategy.num_trials: + return None + if tracker is not None and tracker.halted: + return None + index = next_index + next_index += 1 + optuna_trial = study.ask() + params = _suggest_parameters(optuna_trial, config.parameters) + trial = _make_trial(index, params) + try: + artifact = materialize_trial(config, trial) + except Exception as exc: + failed_artifact = record_trial_materialization_failure( + config, trial, exc, finished_at=utc_now() + ) + artifacts.append(failed_artifact) + write_manifest_with_variants( + config, previous_variants + [build_variant(a) for a in artifacts] + ) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + failures += 1 + print(f"Optuna trial {index:04d} failed materialization: {exc}") + return (optuna_trial, None) + artifacts.append(artifact) + write_manifest_with_variants( + config, previous_variants + [build_variant(a) for a in artifacts] + ) + return (optuna_trial, artifact) + + with ThreadPoolExecutor( + max_workers=max_parallel, thread_name_prefix="slurm-sync-trial" + ) as executor: + in_flight: dict = {} + + def submit_one() -> bool: + """Ask Optuna for one trial and submit it to the executor. + + Returns True when a trial is in flight, False when there are + no more trials to launch (either reached ``num_trials`` or a + halt condition is active). + """ + while True: + asked = ask_and_materialize() + if asked is None: + return False + optuna_trial, artifact = asked + if artifact is None: + # materialization failure already recorded; try next + # index to keep the slot full. 
+ if not config.continue_on_failure: + return False + continue + if use_pruner: + future = executor.submit( + _run_one_slurm_sync_with_pruner, + artifact, + optuna_trial, + config.objective.metric, + strategy.poll_interval_seconds, + config.retry_budget, + ) + else: + future = executor.submit( + _run_one_slurm_sync_no_pruner, + artifact, + config.objective.metric, + config.retry_budget, + ) + in_flight[future] = (optuna_trial, artifact) + return True + + for _ in range(max_parallel): + if not submit_one(): + break + + while in_flight: + done, _pending = wait(in_flight.keys(), return_when=FIRST_COMPLETED) + for future in done: + optuna_trial, artifact = in_flight.pop(future) + trial_id = artifact.trial.id + trial_label_ = artifact.trial.label + try: + result = future.result() + except Exception as exc: + # Worker threw — surface as failure; the only paths + # that raise here are programming errors since + # _run_with_retries_slurm_sync catches OSError. + record_trial_objective(artifact.status_path, None) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + failures += 1 + print(f"Optuna trial {trial_id} worker raised: {exc}") + if not config.continue_on_failure: + halted = True + continue + + objective_value = result.objective + if isinstance(result, _PollingOutcome): + # Pruner path: result carries a tristate (completed/ + # pruned/failed) plus the unsafe_to_continue flag for + # the SLURM-specific failure modes. 
+ if result.state == "completed": + record_trial_objective(artifact.status_path, objective_value) + if objective_value is None: + record_trial_missing_objective( + artifact.status_path, config.objective.metric + ) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + failures += 1 + if not config.continue_on_failure: + halted = True + else: + study.tell(optuna_trial, objective_value) + elif result.state == "pruned": + # record_trial_pruned already set status.json fields + # in the worker; just tell Optuna so the sampler + # treats this trial as a deliberate stop. + study.tell(optuna_trial, state=optuna.trial.TrialState.PRUNED) + else: # failed + record_trial_objective(artifact.status_path, None) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + failures += 1 + if not config.continue_on_failure or result.unsafe_to_continue: + # unsafe_to_continue forces halt regardless of + # continue_on_failure: persistent squeue + # failure or unconfirmed scancel means the + # SLURM job may still be alive and submitting + # the next trial would race it. + halted = True + else: + # No-pruner path: result is _SlurmSyncWorkerResult. + returncode = result.returncode + record_trial_objective(artifact.status_path, objective_value) + + if objective_value is None: + if returncode == 0: + record_trial_missing_objective( + artifact.status_path, config.objective.metric + ) + study.tell(optuna_trial, state=optuna.trial.TrialState.FAIL) + else: + study.tell(optuna_trial, objective_value) + + if returncode != 0 or objective_value is None: + failures += 1 + if not config.continue_on_failure: + halted = True + + if tracker is not None: + if tracker.observe( + TrialOutcome( + trial_id=trial_id, + label=trial_label_, + objective=objective_value, + ) + ): + halted = True + + # Refill the freed slot unless we are halted. 
def file_checksum(path: Path) -> str:
    """SHA-256 hex digest of a file's contents.

    The file is read in fixed-size chunks and fed to the hash incrementally,
    so memory use stays bounded no matter how large the file is. The digest
    is identical to hashing the whole file in a single call.

    Args:
        path: File to hash; opened in binary mode.

    Returns:
        The lowercase hexadecimal SHA-256 digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        # 1 MiB per read: large enough to amortise syscalls, small enough to
        # keep the peak allocation independent of the file size.
        while chunk := handle.read(1 << 20):
            digest.update(chunk)
    return digest.hexdigest()
+ """ + try: + sha = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=cwd, + check=True, + capture_output=True, + text=True, + ).stdout.strip() + status = subprocess.run( + ["git", "status", "--porcelain"], + cwd=cwd, + check=True, + capture_output=True, + text=True, + ).stdout + return {"sha": sha, "dirty": status.strip() != ""} + except (FileNotFoundError, subprocess.CalledProcessError): + return {"sha": None, "dirty": None} diff --git a/src/prime_rl/sweep/run_control.py b/src/prime_rl/sweep/run_control.py new file mode 100644 index 0000000000..13ac5d7d11 --- /dev/null +++ b/src/prime_rl/sweep/run_control.py @@ -0,0 +1,83 @@ +"""Runtime control-file helpers for multi-run LoRA sweeps.""" + +from pathlib import Path +from typing import Any + +EXIT_CODE_FILENAME = "exit_code" +EVICTED_FILENAME = "evicted.txt" + + +def _mark_failed_orchestrator_evicted(run_dir: Path, code: int) -> None: + """Hide a failed run from the shared trainer's future directory scans.""" + if code == 0: + return + + control_dir = run_dir / "control" + control_dir.mkdir(parents=True, exist_ok=True) + evicted_path = control_dir / EVICTED_FILENAME + if not evicted_path.exists(): + evicted_path.write_text(f"orchestrator exited with code {code}\n") + + +def write_orchestrator_exit_code(run_dir: Path, returncode: int | None) -> None: + """Write a per-orchestrator returncode for the sweep controller to reconcile. + + The sweep controller reads each ``/control/exit_code`` after the + multi-run invocation exits, so it can attribute failures to the actual + orchestrator that crashed instead of marking every trial in the wave + failed. ``None`` means "the launcher tore down the orchestrator before it + produced an exit code"; we record ``-1`` so the controller treats it as an + infrastructure failure. Non-zero exits also write ``evicted.txt`` when the + controller/trainer did not already create one, so the shared trainer + forgets crashed runs instead of waiting on them for the rest of the wave. 
+ """ + control_dir = run_dir / "control" + control_dir.mkdir(parents=True, exist_ok=True) + code = -1 if returncode is None else int(returncode) + (control_dir / EXIT_CODE_FILENAME).write_text(f"{code}\n") + _mark_failed_orchestrator_evicted(run_dir, code) + + +def record_orchestrator_exit_codes(orchestrator_processes: list[Any], run_dirs: list[Path]) -> None: + """Best-effort: write exit_code for every run dir, swallowing per-run write errors. + + A failure to write one exit_code must not prevent the others from being + recorded; the controller falls back to "infrastructure failure" when the + file is missing, which is at least diagnosable. + """ + for proc, run_dir in zip(orchestrator_processes, run_dirs): + try: + write_orchestrator_exit_code(run_dir, proc.returncode) + except OSError: + continue + + +def record_finished_orchestrator_exit_codes( + orchestrator_processes: list[Any], + run_dirs: list[Path], + recorded_run_dirs: set[Path], +) -> None: + """Publish per-run exit codes as orchestrators finish during a wave.""" + for proc, run_dir in zip(orchestrator_processes, run_dirs): + if run_dir in recorded_run_dirs or proc.poll() is None: + continue + try: + write_orchestrator_exit_code(run_dir, proc.returncode) + except OSError: + continue + recorded_run_dirs.add(run_dir) + + +def finished_orchestrator_failures( + orchestrator_labels: list[str], + orchestrator_processes: list[Any], + stop_events: dict[str, Any], +) -> list[tuple[str, int | None]]: + """Return failed orchestrators only once every orchestrator has stopped.""" + if not all(stop_events[label].is_set() for label in orchestrator_labels): + return [] + return [ + (label, proc.returncode) + for label, proc in zip(orchestrator_labels, orchestrator_processes) + if proc.returncode != 0 + ] diff --git a/src/prime_rl/sweep/schedulers.py b/src/prime_rl/sweep/schedulers.py new file mode 100644 index 0000000000..30769339b9 --- /dev/null +++ b/src/prime_rl/sweep/schedulers.py @@ -0,0 +1,1218 @@ +import os 
+import queue +import subprocess +import threading +import time +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal + +from prime_rl.sweep.materialize import ( + TrialArtifacts, + read_status_json, + record_trial_pruned, + write_json, + write_multi_run_output_override, +) +from prime_rl.sweep.metrics import read_final_summary, read_intermediate_metric +from prime_rl.utils.monitor import SWEEP_METRICS_JSONL_ENV + +TrialCompleteCallback = Callable[[TrialArtifacts, int], bool] + + +def utc_now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _read_status(artifacts: TrialArtifacts) -> dict: + return read_status_json(artifacts.status_path) + + +def _write_status(artifacts: TrialArtifacts, **updates) -> None: + status = _read_status(artifacts) + status.update(updates) + write_json(artifacts.status_path, status) + + +def _launch_error(exc: OSError) -> str: + message = str(exc) + return f"{type(exc).__name__}: {message}" if message else type(exc).__name__ + + +def _write_launch_failure_status( + artifact: TrialArtifacts, + exc: OSError, + *, + finished_at: str | None = None, +) -> None: + _write_status( + artifact, + state="failed", + finished_at=finished_at or utc_now(), + returncode=-1, + objective=None, + failure_stage="launch", + error=_launch_error(exc), + ) + + +def _metrics_jsonl_path(artifact: TrialArtifacts) -> str: + return (artifact.run_dir / "metrics.jsonl").as_posix() + + +TRAINING_COMPLETE_SENTINEL = ".training_complete" +"""Marker file written by ``multi_node_rl.sbatch.j2`` immediately before the +trainer scancels its own job to release the inference allocation. 
def _reset_metrics_jsonl(artifact: TrialArtifacts) -> None:
    """Truncate the sidecar metrics file and clear stale per-attempt markers.

    FileMonitor opens in append mode, so without truncation a failed
    attempt's later steps would survive into the retry. read_final_summary
    selects the largest reported step, which would then return the failed
    attempt's value instead of the successful retry's value. The pruning
    loop has the same hazard: a stale row from a previous attempt can fire
    should_prune() before the new attempt has reported anything.

    Legacy ``final_summary.json`` fallback files are attempt-scoped too. If
    the new attempt never writes metrics, stale summaries from an older run
    must not be mistaken for a fresh objective.

    The ``.training_complete`` sentinel (written by multi_node_rl.sbatch.j2
    right before trainer-rank-0 scancels its own job) is also cleared so a
    stale marker from a previous attempt does not turn an actual failure
    into a false "completed".

    Args:
        artifact: Trial whose ``run_dir`` holds the files to reset.

    Raises:
        OSError: If the metrics file cannot be created/truncated or a stale
            file cannot be removed for a reason other than already being gone.
    """
    path = Path(_metrics_jsonl_path(artifact))
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("")
    for summary_path in artifact.run_dir.glob("run-*/final_summary.json"):
        # missing_ok: a file that disappears between the glob and the unlink
        # is already in the desired state and must not abort the attempt.
        summary_path.unlink(missing_ok=True)
    # Single call instead of exists() + unlink(), which could raise if the
    # sentinel vanished between the two steps.
    (artifact.run_dir / TRAINING_COMPLETE_SENTINEL).unlink(missing_ok=True)
def _run_with_retries(artifact: TrialArtifacts, gpu_group: list[int] | None, retry_budget: int) -> int:
    """Execute one trial locally, re-running it until it succeeds or the budget is spent.

    Every attempt first resets the trial's metrics sidecar, then records a
    ``running`` status carrying the cumulative attempt number and the device
    group. The subprocess environment is built once and reused across
    attempts. An attempt counts as the last one when its number exceeds
    ``retry_budget``.

    Returns:
        ``0`` on success, ``-1`` when the final attempt could not even be
        launched (``OSError``), otherwise the non-zero returncode of the
        final attempt.
    """
    child_env = _build_env(artifact, gpu_group)
    attempt = 0
    while True:
        attempt += 1
        budget_spent = attempt > retry_budget
        _reset_metrics_jsonl(artifact)
        pinned_devices = None if gpu_group is None else list(gpu_group)
        _write_status(
            artifact,
            state="running",
            started_at=utc_now(),
            attempts=attempt,
            gpu_group=pinned_devices,
        )
        try:
            finished = subprocess.run(artifact.command, env=child_env)
        except OSError as launch_exc:
            # The command never started; retry unless this was the last attempt.
            if not budget_spent:
                continue
            _write_launch_failure_status(artifact, launch_exc)
            return -1

        code = finished.returncode
        if code == 0:
            _write_status(artifact, state="completed", finished_at=utc_now(), returncode=0)
            return 0
        if budget_spent:
            _write_status(artifact, state="failed", finished_at=utc_now(), returncode=code)
            return code
def _run_parallel(
    artifacts: list[TrialArtifacts],
    max_parallel: int,
    gpu_groups: list[list[int]],
    continue_on_failure: bool,
    retry_budget: int,
    on_trial_complete: TrialCompleteCallback | None,
) -> int:
    """Run trials on a thread pool, giving each running trial its own GPU group.

    GPU groups live in a queue that doubles as a semaphore: a worker takes a
    group before running its trial and puts it back afterwards, so two
    concurrently running trials never hold the same group. Once the halt flag
    is set (a failure with ``continue_on_failure`` off, or the callback
    returning True), trials that have not started yet return immediately.

    Returns:
        The number of trials whose final returncode was non-zero.
    """
    free_groups: queue.Queue[list[int]] = queue.Queue()
    for devices in gpu_groups:
        free_groups.put(devices)

    stop_new_trials = threading.Event()
    counter_guard = threading.Lock()
    failed = 0

    def worker(trial: TrialArtifacts) -> None:
        nonlocal failed
        if stop_new_trials.is_set():
            return

        devices = free_groups.get()
        try:
            code = _run_with_retries(trial, devices, retry_budget)
        finally:
            # Always hand the group back, even if the run raised.
            free_groups.put(devices)

        trial_failed = code != 0
        if trial_failed:
            with counter_guard:
                failed += 1
        if trial_failed and not continue_on_failure:
            stop_new_trials.set()

        if on_trial_complete is None:
            return
        if on_trial_complete(trial, code):
            stop_new_trials.set()

    with ThreadPoolExecutor(max_workers=max_parallel) as pool:
        # Drain the iterator so an exception raised in a worker surfaces here.
        for _ in pool.map(worker, artifacts):
            pass

    return failed
def submit_trials_to_slurm(
    artifacts: list[TrialArtifacts],
    continue_on_failure: bool = True,
    retry_budget: int = 1,
    synchronous: bool = False,
    on_trial_complete: TrialCompleteCallback | None = None,
) -> int:
    """Hand trials to SLURM via the target entrypoint's own submission support.

    Rendering and submitting the job is the entrypoint's responsibility; this
    function only runs each trial's command. No in-flight cap is applied
    because throughput is left to the cluster scheduler. What gets retried
    (up to ``retry_budget``) is the submission itself, not the job.

    With ``synchronous=True`` the work is delegated to the synchronous
    submitter, which blocks on each job (``sbatch --wait``) so the trial
    state goes pending -> running -> completed/failed like the local
    scheduler and per-trial outcomes are observable by Optuna and early
    stopping. ``on_trial_complete`` is forwarded on that path; it fires after
    each finished trial and may halt further submissions.

    On the asynchronous path, trials already recorded as ``submitted`` or
    ``completed`` are skipped. When ``continue_on_failure`` is off, the first
    exhausted submission raises ``SystemExit`` (code ``1`` for a launch error,
    otherwise the command's returncode).

    Returns:
        The number of trials whose submission ultimately failed.
    """
    if synchronous:
        return _submit_trials_to_slurm_sync(
            artifacts,
            continue_on_failure=continue_on_failure,
            retry_budget=retry_budget,
            on_trial_complete=on_trial_complete,
        )

    failed = 0
    for trial in artifacts:
        if _is_submitted_or_completed(trial):
            continue

        attempt = 0
        while True:
            attempt += 1
            budget_spent = attempt > retry_budget
            _write_status(trial, state="submitting", started_at=utc_now(), attempts=attempt)
            try:
                submission = subprocess.run(trial.command)
            except OSError as launch_exc:
                # The submit command itself could not be started.
                if not budget_spent:
                    continue
                _write_launch_failure_status(trial, launch_exc)
                failed += 1
                if continue_on_failure:
                    break
                raise SystemExit(1) from launch_exc

            code = submission.returncode
            if code == 0:
                _write_status(trial, state="submitted", finished_at=utc_now(), returncode=0)
                break
            if not budget_spent:
                continue

            _write_status(trial, state="failed", finished_at=utc_now(), returncode=code)
            failed += 1
            if not continue_on_failure:
                raise SystemExit(code)
            break

    return failed