From f6483faed75128cca19d295d7d92579e7df32b30 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 24 Feb 2026 20:50:55 -0600 Subject: [PATCH 01/16] support per-agent rewards from multi-agent environments --- src/prime_rl/orchestrator/orchestrator.py | 6 ++++-- src/prime_rl/orchestrator/trajectories.py | 9 ++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 1b0cb4b3ee..8bfb6c9458 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -548,8 +548,10 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin samples = [] rollout_samples_per_rollout.append(len(samples)) for sample in samples: - sample.advantage = rollout["advantage"] - sample.reward = rollout["reward"] + if sample.advantage is None: + sample.advantage = rollout["advantage"] + if sample.reward is None: + sample.reward = rollout["reward"] sample.env_name = rollout["env_name"] sample.training_mode = config.training_mode sample_decode_tokens = sum(sample.completion_mask) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 94412eb8bf..11403effce 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -336,7 +336,8 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: completion_logprobs=list(tokens["completion_logprobs"]), completion_temperatures=[temperature] * len(completion_ids), teacher_logprobs=None, - advantage=None, + advantage=step.get("advantage"), + reward=step.get("reward"), env_name=output["env_name"], routed_experts=routed_experts, mm_token_type_ids=None, @@ -363,6 +364,12 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non sample.completion_logprobs.extend(tokens["completion_logprobs"]) sample.completion_temperatures.extend([temperature] * len(completion_ids)) + # Update reward/advantage to use the latest merged step's values (multi-agent) + if step.get("reward") is not None: + sample.reward = step["reward"] + if step.get("advantage") is not None: + sample.advantage = step["advantage"] + if tokens.get("routed_experts") is not None and sample.routed_experts is not None: step_routed = tokens["routed_experts"] # The previous step's last routing entry was zero-padded by _align_routed_experts From a471966ef913e4ee9e304b6b89d32c9f87474914 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 24 Feb 2026 21:08:03 -0600 Subject: [PATCH 02/16] point verifiers to multiagent-heterogeneous-rewards branch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b8c1500971..99e8e5756e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,7 +193,7 @@ nixl-cu12 = false [tool.uv.sources] prime-rl-configs = { workspace = true } -verifiers = { workspace = true } +verifiers = { git = "https://github.com/nph4rd/verifiers.git", branch = "multiagent-no-opponent-conditioning" } renderers = { workspace = true } aime2024 = { workspace = true } aime2025 = { workspace = true } From 0a1bc6c62ed1a0f112426f477daf15e1834e97e9 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Mar 2026 18:32:30 -0600 Subject: [PATCH 03/16] log per-agent rewards to wandb for multi-agent environments --- src/prime_rl/orchestrator/envs.py | 2 +- src/prime_rl/orchestrator/orchestrator.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index 1de52994a9..88a4463903 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -21,7 +21,7 @@ from prime_rl.utils.monitor import get_monitor from prime_rl.utils.utils import capitalize -REQUIRED_STATE_COLUMNS = ["trajectory", "sampling_args"] +REQUIRED_STATE_COLUMNS = ["trajectory", "sampling_args", "agent_rewards"] class Env: diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 8bfb6c9458..4ac02e36b2 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -721,6 +721,19 @@ def compute_solve_rates(df): "step": progress.step, } + # Per-agent reward metrics (for multi-agent environments) + agent_reward_sums: dict[str, float] = {} + agent_reward_counts: dict[str, int] = {} + for rollout in train_rollouts: + agent_rewards = rollout.get("agent_rewards") + if not agent_rewards: + continue + for agent_id, reward in agent_rewards.items(): + agent_reward_sums[agent_id] = agent_reward_sums.get(agent_id, 0.0) + reward + agent_reward_counts[agent_id] = agent_reward_counts.get(agent_id, 0) + 1 + for agent_id in agent_reward_sums: + to_log[f"agent_reward/{agent_id}"] = agent_reward_sums[agent_id] / agent_reward_counts[agent_id] + # Per-env metrics per_env_columns = [ "seq_len", From 4710803c28340ab09c830c750b4102250dc9b398 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Mar 2026 18:43:05 -0600 Subject: [PATCH 04/16] revert agent_rewards logging, now handled via metrics --- src/prime_rl/orchestrator/envs.py | 2 +- src/prime_rl/orchestrator/orchestrator.py | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index 88a4463903..1de52994a9 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -21,7 +21,7 @@ from prime_rl.utils.monitor import get_monitor from prime_rl.utils.utils import capitalize -REQUIRED_STATE_COLUMNS = ["trajectory", "sampling_args", "agent_rewards"] +REQUIRED_STATE_COLUMNS = ["trajectory", "sampling_args"] class Env: diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 4ac02e36b2..8bfb6c9458 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -721,19 +721,6 @@ def compute_solve_rates(df): "step": progress.step, } - # Per-agent reward metrics (for multi-agent environments) - agent_reward_sums: dict[str, float] = {} - agent_reward_counts: dict[str, int] = {} - for rollout in train_rollouts: - agent_rewards = rollout.get("agent_rewards") - if not agent_rewards: - continue - for agent_id, reward in agent_rewards.items(): - agent_reward_sums[agent_id] = agent_reward_sums.get(agent_id, 0.0) + reward - agent_reward_counts[agent_id] = agent_reward_counts.get(agent_id, 0) + 1 - for agent_id in agent_reward_sums: - to_log[f"agent_reward/{agent_id}"] = agent_reward_sums[agent_id] / agent_reward_counts[agent_id] - # Per-env metrics per_env_columns = [ "seq_len", From 792a461a853448a9618e2a953b1bd36af17db093 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Mar 2026 19:29:11 -0600 Subject: [PATCH 05/16] compute per-agent grpo advantages for multi-agent environments --- src/prime_rl/orchestrator/advantage.py | 52 +++++++++++++++++++++++ src/prime_rl/orchestrator/orchestrator.py | 3 +- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/orchestrator/advantage.py b/src/prime_rl/orchestrator/advantage.py index 63b1d50325..50749c3b9d 100644 --- a/src/prime_rl/orchestrator/advantage.py +++ b/src/prime_rl/orchestrator/advantage.py @@ -1,3 +1,4 @@ +from collections import defaultdict from dataclasses import dataclass from typing import Callable @@ -164,3 +165,54 @@ def compute_advantages( for rollout, advantage in zip(rollouts, advantages): rollout["advantage"] = advantage + + +def compute_per_agent_advantages(rollouts: list[dict]) -> None: + """Compute per-agent GRPO advantages for multi-agent rollouts. + + For multi-agent environments, each trajectory step is tagged with an agent_id + and has a per-agent reward. Standard GRPO computes advantages from the rollout- + level mean reward, which is invariant when agent payoffs sum to a constant. + + This function computes advantages per agent: for each (example, agent) pair, + the baseline is the mean of that agent's rewards across rollouts of the same + example. Advantages are written directly to trajectory steps so they flow + through interleave_rollout -> TrainingSample.advantage. + + No-ops if rollouts don't contain per-agent trajectory steps. + """ + if not rollouts: + return + + # Quick check: do rollouts have per-agent trajectory steps? + has_agents = False + for r in rollouts[:3]: + for step in r.get("trajectory", []): + if step.get("extras", {}).get("agent_id"): + has_agents = True + break + if has_agents: + break + if not has_agents: + return + + # Group rollouts by example_id + groups: dict[str, list[dict]] = defaultdict(list) + for r in rollouts: + groups[r["example_id"]].append(r) + + for group in groups.values(): + # Collect per-agent rewards: agent_id -> list of (step, reward) + agent_entries: dict[str, list[tuple[dict, float]]] = defaultdict(list) + for r in group: + for step in r.get("trajectory", []): + agent_id = step.get("extras", {}).get("agent_id") + reward = step.get("reward") + if agent_id is not None and reward is not None: + agent_entries[agent_id].append((step, reward)) + + # Compute per-agent baseline and set per-step advantages + for agent_id, entries in agent_entries.items(): + baseline = sum(reward for _, reward in entries) / len(entries) + for step, reward in entries: + step["advantage"] = reward - baseline diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 8bfb6c9458..4996dd64e8 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -7,7 +7,7 @@ import tomli_w import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before transitive import -from prime_rl.orchestrator.advantage import compute_advantages +from prime_rl.orchestrator.advantage import compute_advantages, compute_per_agent_advantages from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step from prime_rl.orchestrator.event_loop_lag import EventLoopLagMonitor from prime_rl.orchestrator.inference_metrics import InferenceMetricsCollector @@ -427,6 +427,7 @@ async def orchestrate(config: OrchestratorConfig): num_rollouts = len(train_rollouts) num_unique_examples = len({(r["env_name"], r["example_id"]) for r in train_rollouts}) compute_advantages(train_rollouts, config.rollouts_per_example, config.advantage) + compute_per_agent_advantages(train_rollouts) # Apply rollout filters — sets rollout["filters"] and rollout["is_filtered"] apply_filters(rollout_filters, train_rollouts) From bef068f2d387fad936b5416beba759c76804792e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Mar 2026 20:30:58 -0600 Subject: [PATCH 06/16] respect per-step is_trainable flag in interleave_rollout --- src/prime_rl/orchestrator/trajectories.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 11403effce..a254cf4266 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -315,9 +315,14 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any return None prepared_steps.append(prepared) - def make_sample(tokens: dict[str, Any]) -> TrainingSample: + def _is_step_trainable(step: vf.TrajectoryStep) -> bool: + return step.get("extras", {}).get("is_trainable", True) + + def make_sample(step_idx: int) -> TrainingSample: """Create a new TrainingSample from a trajectory step.""" - if has_error: + step = trajectory[step_idx] + tokens = prepared_steps[step_idx] + if has_error or not _is_step_trainable(step): completion_mask = [False] * len(tokens["completion_mask"]) else: completion_mask = [bool(i) for i in tokens["completion_mask"]] @@ -345,6 +350,7 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> None: """Extend an existing sample with a new trajectory step (extension property holds).""" + step = trajectory[step_idx] tokens = prepared_steps[step_idx] # Extend with new prompt tokens (mask=False, no gradient) @@ -357,7 +363,7 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non # Extend with new completion tokens completion_ids = tokens["completion_ids"] sample.completion_ids.extend(completion_ids) - if has_error: + if has_error or not _is_step_trainable(step): sample.completion_mask.extend([False] * len(tokens["completion_mask"])) else: sample.completion_mask.extend(bool(i) for i in tokens["completion_mask"]) @@ -387,7 +393,7 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non first_tokens = prepared_steps[0] first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"] - active_samples.append((first_prefix, make_sample(first_tokens), 0)) + active_samples.append((first_prefix, make_sample(0), 0)) for step_idx, _step in enumerate(trajectory[1:], start=1): tokens = prepared_steps[step_idx] @@ -412,7 +418,7 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non f"Starting new sample (active_prefixes={len(active_samples)}, step_prompt_len={len(step_prompt_ids)})." ) new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] - active_samples.append((new_prefix, make_sample(tokens), step_idx)) + active_samples.append((new_prefix, make_sample(step_idx), step_idx)) # Attach images once per sample using only the last merged step. Prompt # tokens already contain fully expanded <|image_pad|> placeholders because From 33b827be4b43eb7121d8e0da4edbb006ba08faeb Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 7 Mar 2026 23:15:16 -0600 Subject: [PATCH 07/16] add multi-agent lora support for per-agent policy training --- .../src/prime_rl/configs/orchestrator.py | 3 ++ .../src/prime_rl/configs/rl.py | 6 +++ .../src/prime_rl/utils/validation.py | 3 +- src/prime_rl/orchestrator/envs.py | 12 ++++- src/prime_rl/orchestrator/orchestrator.py | 53 +++++++++++++++++-- src/prime_rl/orchestrator/scheduler.py | 49 +++++++++++++++++ src/prime_rl/orchestrator/trajectories.py | 6 ++- src/prime_rl/transport/__init__.py | 11 ++++ src/prime_rl/transport/filesystem.py | 14 +++++ src/prime_rl/transport/types.py | 2 + 10 files changed, 153 insertions(+), 6 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 82ed1134c7..15c91ddb09 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -598,6 +598,9 @@ class OrchestratorConfig(BaseConfig): weight_broadcast: WeightBroadcastConfig = FileSystemWeightBroadcastConfig() """Transport used to receive updated weights from the trainer.""" + multi_agent_lora: bool = False + """Enable per-agent LoRA training for multi-agent environments.""" + rollout_transport: TransportConfig = FileSystemTransportConfig() """Transport used to ship rollouts from orchestrator to trainer.""" diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index 7a5809c1c6..e234da86f7 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -467,6 +467,12 @@ def auto_setup_lora(self): return self + @model_validator(mode="after") + def validate_multi_agent_lora(self): + if self.orchestrator.multi_agent_lora and self.trainer.model.lora is None: + raise ValueError("orchestrator.multi_agent_lora requires trainer.model.lora to be configured.") + return self + @model_validator(mode="after") def auto_setup_router_replay(self): if self.trainer.enable_router_replay: diff --git a/packages/prime-rl-configs/src/prime_rl/utils/validation.py b/packages/prime-rl-configs/src/prime_rl/utils/validation.py index 944917c252..f0baad637c 100644 --- a/packages/prime-rl-configs/src/prime_rl/utils/validation.py +++ b/packages/prime-rl-configs/src/prime_rl/utils/validation.py @@ -139,7 +139,8 @@ def propagate(shared_path: str, *targets: str, aliases: tuple[str, ...] = ()) -> if get(sub) is not None: conflicts.append(("output_dir", sub)) fill("trainer.output_dir", output_dir) - fill("orchestrator.output_dir", f"{output_dir}/run_default") + orchestrator_leaf = "orchestrator" if get("orchestrator.multi_agent_lora") is True else "run_default" + fill("orchestrator.output_dir", f"{output_dir}/{orchestrator_leaf}") # Cascade trainer.tokenizer.chat_template → inference.model.chat_template # (vLLM ``--chat-template``). Read trainer's value *after* the shared diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index 1de52994a9..4c70a1fdcb 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -116,6 +116,7 @@ async def run_rollout( example: dict, model_name: str, cache_salt: str, + actor_models: dict[str, str] | None = None, ) -> vf.RolloutOutput: """Run a single rollout for an example.""" return await self.env.run_rollout( @@ -126,6 +127,7 @@ async def run_rollout( max_retries=self.config.max_retries, state_columns=REQUIRED_STATE_COLUMNS, env_client=self.env_client, + actor_models=actor_models, ) async def run_group( @@ -135,6 +137,7 @@ async def run_group( model_name: str, rollouts_per_example: int, cache_salt: str, + actor_models: dict[str, str] | None = None, ) -> list[vf.RolloutOutput]: """Run a group of rollouts for an example. Required for group-scoring envs.""" return await self.env.run_group( @@ -145,6 +148,7 @@ async def run_group( max_retries=self.config.max_retries, state_columns=REQUIRED_STATE_COLUMNS, env_client=self.env_client, + actor_models=actor_models, ) def shutdown(self) -> None: @@ -180,6 +184,7 @@ async def evaluate( ckpt_step: int, step: int, cache_salt: str, + actor_models: dict[str, str] | None = None, ) -> list[vf.RolloutOutput]: num_examples = len(self.examples) rollouts_per_example = self.config.rollouts_per_example @@ -200,6 +205,7 @@ async def run_with_progress(example: dict) -> list[vf.RolloutOutput] | None: model_name=model_name, rollouts_per_example=rollouts_per_example, cache_salt=cache_salt, + actor_models=actor_models, ) pbar.update(rollouts_per_example) return outputs @@ -217,7 +223,11 @@ async def run_with_progress(example: dict) -> list[vf.RolloutOutput] | None: try: client = await get_client() output = await self.run_rollout( - client=client, example=example, model_name=model_name, cache_salt=cache_salt + client=client, + example=example, + model_name=model_name, + cache_salt=cache_salt, + actor_models=actor_models, ) pbar.update(1) return [output] diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 4996dd64e8..f9aa5a299b 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -18,7 +18,12 @@ offload_images_to_disk, pretokenize_rollout_trajectory, ) -from prime_rl.transport import TrainingBatch, TrainingSample, setup_training_batch_sender +from prime_rl.transport import ( + TrainingBatch, + TrainingSample, + setup_multi_run_training_batch_sender, + setup_training_batch_sender, +) from prime_rl.utils.pathing import get_log_dir, get_rollout_dir, get_step_path from prime_rl.utils.usage_reporter import UsageReporter @@ -237,6 +242,27 @@ async def orchestrate(config: OrchestratorConfig): else: checkpoint_step = config.ckpt.resume_step + # Multi-agent LoRA: set up per-agent policy routing + actor_lora_mapping: dict[str, str] | None = None + if config.multi_agent_lora: + agent_sets = [tuple(getattr(env.env, "agents", ())) for env in train_envs] + agent_sets = [agents for agents in agent_sets if agents] + assert agent_sets, "multi_agent_lora requires at least one MultiAgentEnv with registered agents" + agents = agent_sets[0] + assert len(agents) >= 2, "multi_agent_lora requires a MultiAgentEnv with at least 2 registered agents" + assert all(agent_set == agents for agent_set in agent_sets), ( + "multi_agent_lora requires all MultiAgentEnv train envs to use the same agents" + ) + actor_lora_mapping = {agent_id: f"run_{agent_id}" for agent_id in agents} + logger.info(f"Multi-agent LoRA enabled: {actor_lora_mapping}") + + # Create per-actor run directories with orch.toml + for run_name in actor_lora_mapping.values(): + run_config_dir = config.output_dir.parent / run_name / "control" + run_config_dir.mkdir(parents=True, exist_ok=True) + with open(run_config_dir / "orch.toml", "wb") as f: + tomli_w.dump(config.model_dump(exclude_none=True, mode="json"), f) + scheduler = Scheduler( train_envs=train_envs, buffer=buffer, @@ -249,6 +275,7 @@ async def orchestrate(config: OrchestratorConfig): tasks_per_minute=config.tasks_per_minute, lora_name=config.student.model.lora.name if config.student.model.lora else None, config=config, + actor_lora_mapping=actor_lora_mapping, ) # Wait for pools to be ready @@ -281,7 +308,14 @@ async def orchestrate(config: OrchestratorConfig): # Setup training batch sender for sending training examples to trainer logger.info(f"Initializing training batch sender ({config.rollout_transport})") - training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport) + if actor_lora_mapping is not None: + multi_run_sender = setup_multi_run_training_batch_sender( + config.output_dir.parent, list(actor_lora_mapping.values()), config.rollout_transport + ) + training_batch_sender = multi_run_sender + else: + multi_run_sender = None + training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport) # Track last online eval checkpoint step per eval env last_eval_steps: dict[str, int] = {env.name: -1 for env in eval_envs} if eval_envs else {} @@ -393,6 +427,7 @@ async def orchestrate(config: OrchestratorConfig): ckpt_step=ckpt_step, step=progress.step, cache_salt=str(ckpt_step), + actor_models=scheduler.actor_model_names or None, ) for eval_env in envs_to_eval ] @@ -593,7 +628,18 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin step=progress.step, ) - training_batch_sender.send(training_batch) + if multi_run_sender is not None and actor_lora_mapping is not None: + per_actor_examples: dict[str, list[TrainingSample]] = { + run_name: [] for run_name in actor_lora_mapping.values() + } + for sample in train_examples: + if sample.actor_id is not None and sample.actor_id in actor_lora_mapping: + per_actor_examples[actor_lora_mapping[sample.actor_id]].append(sample) + for run_name, examples in per_actor_examples.items(): + if examples: + multi_run_sender.send_to_run(run_name, TrainingBatch(examples=examples, step=progress.step)) + else: + training_batch_sender.send(training_batch) step_time = time.perf_counter() - step_start_time @@ -823,6 +869,7 @@ def compute_solve_rates(df): ckpt_step=ckpt_step, step=progress.step, cache_salt=str(ckpt_step), + actor_models=scheduler.actor_model_names or None, ) for eval_env in eval_envs ] diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index 02840b6443..38e0045ad2 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -83,6 +83,7 @@ def __init__( strict_async_level: bool, tasks_per_minute: int | None, lora_name: str | None = None, + actor_lora_mapping: dict[str, str] | None = None, ): self.logger = get_logger() if tasks_per_minute is not None: @@ -119,6 +120,16 @@ def __init__( if group_scoring_envs: self.logger.info(f"Group rollout scoring active for env(s): {', '.join(group_scoring_envs)}") + # Multi-agent LoRA: per-actor policy routing + self.actor_lora_mapping = actor_lora_mapping + if actor_lora_mapping is not None: + self.actor_ckpt_steps: dict[str, int] = {agent_id: 0 for agent_id in actor_lora_mapping} + self.actor_model_names: dict[str, str] = { + agent_id: config.student.model.name for agent_id in actor_lora_mapping + } + else: + self.actor_model_names: dict[str, str] = {} + # Track in-flight requests: task -> info self.inflight_requests: dict[asyncio.Task, InflightRequest] = {} @@ -224,6 +235,7 @@ async def schedule_rollout(self, group_id: int): env = self.train_envs.get(env_name) cache_salt = str(self.ckpt_step) + actor_models = self.actor_model_names or None if env.requires_group_scoring: rollout_count = group.rollouts_to_schedule group.rollouts_to_schedule = 0 @@ -234,6 +246,7 @@ async def schedule_rollout(self, group_id: int): model_name=self.model_name, rollouts_per_example=rollout_count, cache_salt=cache_salt, + actor_models=actor_models, ) ) else: @@ -245,6 +258,7 @@ async def schedule_rollout(self, group_id: int): example=group.example, model_name=self.model_name, cache_salt=cache_salt, + actor_models=actor_models, ) ) self.inflight_requests[task] = InflightRequest( @@ -362,6 +376,9 @@ def _clear_inflight_policy_update(done_task: asyncio.Task) -> None: async def maybe_update_policy(self): """Updates the policy to the latest available checkpoint. Aborts rollout requests that are older than the max retention steps.""" + if self.actor_lora_mapping is not None: + return await self._update_policy_multi_actor() + while True: next_ckpt_step = self._compute_next_ckpt_step() if next_ckpt_step <= self.ckpt_step: @@ -370,6 +387,38 @@ async def maybe_update_policy(self): task = await self._get_or_start_policy_update_task(next_ckpt_step) await asyncio.shield(task) + async def _update_policy_multi_actor(self): + """Update per-actor LoRA adapters from their respective broadcast directories.""" + assert self.actor_lora_mapping is not None + any_updated = False + min_actor_step = float("inf") + + for agent_id, run_name in self.actor_lora_mapping.items(): + broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) + latest_step = get_latest_ckpt_step(broadcast_dir) or 0 + min_actor_step = min(min_actor_step, latest_step) + + if latest_step > self.actor_ckpt_steps[agent_id]: + weights_path = get_step_path(broadcast_dir, latest_step) + lora_name = run_name + update_start = time.perf_counter() + await self.student_inference.update_weights(weights_path, lora_name=lora_name, step=latest_step) + self.update_weights_time = time.perf_counter() - update_start + + self.actor_ckpt_steps[agent_id] = latest_step + self.actor_model_names[agent_id] = lora_name + any_updated = True + self.logger.debug( + f"Updated actor {agent_id} ({lora_name}) to step {latest_step} " + f"in {self.update_weights_time:.2f}s" + ) + + if any_updated: + new_ckpt_step = int(min_actor_step) if min_actor_step != float("inf") else 0 + if new_ckpt_step > self.ckpt_step: + self.ckpt_step = new_ckpt_step + await self._update_off_policy() + async def _update_off_policy(self) -> None: stale_group_ids = { info.group_id diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index a254cf4266..ce1a9de702 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -346,6 +346,7 @@ def make_sample(step_idx: int) -> TrainingSample: env_name=output["env_name"], routed_experts=routed_experts, mm_token_type_ids=None, + actor_id=step.get("extras", {}).get("agent_id"), ) def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> None: @@ -370,11 +371,14 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non sample.completion_logprobs.extend(tokens["completion_logprobs"]) sample.completion_temperatures.extend([temperature] * len(completion_ids)) - # Update reward/advantage to use the latest merged step's values (multi-agent) + # Update reward/advantage/actor_id to use the latest merged step's values (multi-agent) if step.get("reward") is not None: sample.reward = step["reward"] if step.get("advantage") is not None: sample.advantage = step["advantage"] + agent_id = step.get("extras", {}).get("agent_id") + if agent_id is not None: + sample.actor_id = agent_id if tokens.get("routed_experts") is not None and sample.routed_experts is not None: step_routed = tokens["routed_experts"] diff --git a/src/prime_rl/transport/__init__.py b/src/prime_rl/transport/__init__.py index e4c3153dc7..cda2c02ca9 100644 --- a/src/prime_rl/transport/__init__.py +++ b/src/prime_rl/transport/__init__.py @@ -7,6 +7,7 @@ FileSystemMicroBatchSender, FileSystemTrainingBatchReceiver, FileSystemTrainingBatchSender, + MultiRunFileSystemTrainingBatchSender, ) from prime_rl.transport.types import MicroBatch, TrainingBatch, TrainingSample from prime_rl.transport.zmq import ( @@ -26,6 +27,14 @@ def setup_training_batch_sender(output_dir: Path, transport: TransportConfig) -> raise ValueError(f"Invalid transport type: {transport.type}") +def setup_multi_run_training_batch_sender( + output_dir: Path, run_names: list[str], transport: TransportConfig +) -> MultiRunFileSystemTrainingBatchSender: + if transport.type != "filesystem": + raise ValueError(f"Multi-run sender only supports filesystem transport, got: {transport.type}") + return MultiRunFileSystemTrainingBatchSender(output_dir, run_names) + + def setup_training_batch_receiver(transport: TransportConfig) -> TrainingBatchReceiver: if transport.type == "filesystem": return FileSystemTrainingBatchReceiver() @@ -62,6 +71,7 @@ def setup_micro_batch_receiver( "FileSystemTrainingBatchReceiver", "FileSystemMicroBatchSender", "FileSystemMicroBatchReceiver", + "MultiRunFileSystemTrainingBatchSender", "MicroBatchReceiver", "MicroBatchSender", "TrainingSample", @@ -69,6 +79,7 @@ def setup_micro_batch_receiver( "MicroBatch", "setup_training_batch_sender", "setup_training_batch_receiver", + "setup_multi_run_training_batch_sender", "setup_micro_batch_sender", "setup_micro_batch_receiver", ] diff --git a/src/prime_rl/transport/filesystem.py b/src/prime_rl/transport/filesystem.py index fb9ec68999..455b109848 100644 --- a/src/prime_rl/transport/filesystem.py +++ b/src/prime_rl/transport/filesystem.py @@ -113,6 +113,20 @@ def reset_run(self, idx: int) -> None: del self._received_steps[idx] +class MultiRunFileSystemTrainingBatchSender(TrainingBatchSender): + """Filesystem-based sender that routes batches to per-run directories.""" + + def __init__(self, output_dir: Path, run_names: list[str]): + super().__init__(output_dir) + self.senders = {name: FileSystemTrainingBatchSender(output_dir / name) for name in run_names} + + def send_to_run(self, run_name: str, batch: TrainingBatch) -> None: + self.senders[run_name].send(batch) + + def send(self, batch: TrainingBatch) -> None: + raise NotImplementedError("Use send_to_run() for multi-run sender") + + class FileSystemMicroBatchSender(MicroBatchSender): """Filesystem-based micro batch sender that writes micro batches to disk.""" diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 332d6dc7a3..33cc36ff95 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -28,6 +28,8 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr routed_experts: list[list[list[int]]] | None = None # [seq_len, layers, topk] + actor_id: str | None = None + # mm_token_type_ids: token type ids per token [batch seq], int64 (0=text, 1=image, 2=video) mm_token_type_ids: list[int] | None = None From e821d1b00ea74ef86800153fcc949ec90dedd844 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 7 Mar 2026 23:36:14 -0600 Subject: [PATCH 08/16] split merged multi-agent samples by agent for per-agent lora training --- src/prime_rl/orchestrator/orchestrator.py | 3 + src/prime_rl/orchestrator/trajectories.py | 77 ++++++++++++++++++++--- 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index f9aa5a299b..0529188c46 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -557,12 +557,15 @@ async def _pretokenize_all() -> None: mm_token_type_ids_mapping = None # Process rollouts in parallel + split_by_agent = actor_lora_mapping is not None + def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[TrainingSample] | None: return interleave_rollout( rollout, vlm_cache=vlm_cache, cache_key=rollout_idx, mm_token_type_ids_mapping=mm_token_type_ids_mapping, + split_by_agent=split_by_agent, ) results = await asyncio.gather( diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index ce1a9de702..3fb4f7125e 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -254,6 +254,7 @@ def interleave_rollout( vlm_cache: "VLMImageCache | None" = None, cache_key: int | None = None, mm_token_type_ids_mapping: dict[int, int] | None = None, + split_by_agent: bool = False, ) -> list[TrainingSample] | None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -392,29 +393,50 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non expected_len = len(sample.prompt_ids) + len(sample.completion_ids) sample.routed_experts = _align_routed_experts(sample.routed_experts, expected_len) - # Track [prefix_tokens, sample, last_step_idx] per active sample - active_samples: list[tuple[list[int], TrainingSample, int]] = [] + # Track [prefix_tokens, sample, last_step_idx, agent_meta] per active sample + # agent_meta: dict with per-completion-token agent_ids and per-agent reward/advantage (for split_by_agent) + active_samples: list[tuple[list[int], TrainingSample, int, dict[str, Any]]] = [] + + def make_agent_meta(step_idx: int) -> dict[str, Any]: + if not split_by_agent: + return {} + step = trajectory[step_idx] + tokens = prepared_steps[step_idx] + agent_id = step.get("extras", {}).get("agent_id") + return { + "ids": [agent_id] * len(tokens["completion_ids"]), + "rewards": {agent_id: step.get("reward")} if agent_id else {}, + "advantages": {agent_id: step.get("advantage")} if agent_id else {}, + } first_tokens = prepared_steps[0] first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"] - active_samples.append((first_prefix, make_sample(0), 0)) + active_samples.append((first_prefix, make_sample(0), 0, make_agent_meta(0))) - for step_idx, _step in enumerate(trajectory[1:], start=1): + for step_idx, step in enumerate(trajectory[1:], start=1): tokens = prepared_steps[step_idx] step_prompt_ids = tokens["prompt_ids"] # Check if this step extends ANY active prefix matched_idx = None - for idx, (prefix_tokens, _, _) in enumerate(active_samples): + for idx, (prefix_tokens, _, _, _) in enumerate(active_samples): if step_prompt_ids[: len(prefix_tokens)] == prefix_tokens: matched_idx = idx break if matched_idx is not None: # Extension holds - merge into matched sample - prefix_tokens, sample, _ = active_samples[matched_idx] + prefix_tokens, sample, _, meta = active_samples[matched_idx] extend_sample(sample, len(prefix_tokens), step_idx=step_idx) - active_samples[matched_idx] = (tokens["prompt_ids"] + tokens["completion_ids"], sample, step_idx) + if split_by_agent: + step_agent_id = step.get("extras", {}).get("agent_id") + new_prompt_len = len(tokens["prompt_ids"]) - len(prefix_tokens) + new_completion_len = len(tokens["completion_ids"]) + meta["ids"].extend([step_agent_id] * (new_prompt_len + new_completion_len)) + if step_agent_id: + meta["rewards"][step_agent_id] = step.get("reward") + meta["advantages"][step_agent_id] = step.get("advantage") + active_samples[matched_idx] = (tokens["prompt_ids"] + tokens["completion_ids"], sample, step_idx, meta) else: # No prefix matches - start a new sample logger.debug( @@ -422,7 +444,7 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non f"Starting new sample (active_prefixes={len(active_samples)}, step_prompt_len={len(step_prompt_ids)})." ) new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] - active_samples.append((new_prefix, make_sample(step_idx), step_idx)) + active_samples.append((new_prefix, make_sample(step_idx), step_idx, make_agent_meta(step_idx))) # Attach images once per sample using only the last merged step. Prompt # tokens already contain fully expanded <|image_pad|> placeholders because @@ -431,7 +453,7 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non # fallback path so features and tokens stay 1:1. if vlm_cache is not None: key = output["example_id"] if cache_key is None else cache_key - for _, sample, last_step_idx in active_samples: + for _, sample, last_step_idx, _ in active_samples: pv, shape, grids = vlm_cache.get_for_step(key, last_step_idx) sample.pixel_values = pv sample.pixel_values_shape = shape @@ -441,7 +463,42 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non mm_token_type_ids_mapping.get(token_id, 0) for token_id in sample.prompt_ids + sample.completion_ids ] - return [sample for _, sample, _ in active_samples] + if not split_by_agent: + return [sample for _, sample, _, _ in active_samples] + + # Split merged samples into per-agent copies with agent-specific completion masks + result: list[TrainingSample] = [] + for _, sample, _, meta in active_samples: + agent_ids = meta["ids"] + unique_agents = set(a for a in agent_ids if a is not None) + if len(unique_agents) <= 1: + result.append(sample) + continue + for agent_id in unique_agents: + agent_mask = [ + sample.completion_mask[i] and agent_ids[i] == agent_id for i in range(len(agent_ids)) + ] + agent_sample = TrainingSample( + prompt_ids=sample.prompt_ids, + prompt_mask=sample.prompt_mask, + completion_ids=sample.completion_ids, + completion_mask=agent_mask, + completion_logprobs=sample.completion_logprobs, + completion_temperatures=sample.completion_temperatures, + env_name=sample.env_name, + teacher_logprobs=sample.teacher_logprobs, + advantage=meta["advantages"].get(agent_id, sample.advantage), + reward=meta["rewards"].get(agent_id, sample.reward), + routed_experts=sample.routed_experts, + actor_id=agent_id, + mm_token_type_ids=sample.mm_token_type_ids, + pixel_values=sample.pixel_values, + pixel_values_shape=sample.pixel_values_shape, + image_grid_thw=sample.image_grid_thw, + training_mode=sample.training_mode, + ) + result.append(agent_sample) + return result # ============================================================================= From 7177db34eef6c90df9971a9a8a6daa035d9747df Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 7 Mar 2026 23:48:19 -0600 Subject: [PATCH 09/16] fix multi-agent lora orch.toml and add pack_full_step --- .../src/prime_rl/configs/rl.py | 2 ++ .../src/prime_rl/configs/trainer.py | 3 +++ src/prime_rl/orchestrator/orchestrator.py | 9 +++++-- src/prime_rl/trainer/rl/data.py | 2 ++ src/prime_rl/trainer/rl/packer.py | 26 +++++++++++++++++-- src/prime_rl/trainer/rl/train.py | 1 + 6 files changed, 39 insertions(+), 4 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index e234da86f7..432bb255dd 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -471,6 +471,8 @@ def auto_setup_lora(self): def validate_multi_agent_lora(self): if self.orchestrator.multi_agent_lora and self.trainer.model.lora is None: raise ValueError("orchestrator.multi_agent_lora requires trainer.model.lora to be configured.") + if self.orchestrator.multi_agent_lora: + self.trainer.pack_full_step = True return self @model_validator(mode="after") diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index f4d37cd9d0..491d483ec3 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -549,6 +549,9 @@ class TrainerConfig(BaseConfig): max_concurrent_runs: int = Field(1, ge=1) """Maximum number of concurrent runs to allow. If 1, only one run may run at a time.""" + pack_full_step: bool = False + """When True, wait for all active runs to have data before packing a training step.""" + experimental: TrainerExperimentalConfig = TrainerExperimentalConfig() @model_validator(mode="after") diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 0529188c46..6d4b6c7d27 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -256,12 +256,17 @@ async def orchestrate(config: OrchestratorConfig): actor_lora_mapping = {agent_id: f"run_{agent_id}" for agent_id in agents} logger.info(f"Multi-agent LoRA enabled: {actor_lora_mapping}") - # Create per-actor run directories with orch.toml + # Create per-actor run directories with minimal orch.toml + # Only include model.lora.name and optim.lr — the trainer fills in the rest for run_name in actor_lora_mapping.values(): run_config_dir = config.output_dir.parent / run_name / "control" run_config_dir.mkdir(parents=True, exist_ok=True) + actor_orch_config = { + "model": {"lora": {"name": run_name}}, + "optim": {"lr": config.optim.lr}, + } with open(run_config_dir / "orch.toml", "wb") as f: - tomli_w.dump(config.model_dump(exclude_none=True, mode="json"), f) + tomli_w.dump(actor_orch_config, f) scheduler = Scheduler( train_envs=train_envs, diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index e732db3e55..1d2e8322c0 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -161,6 +161,7 @@ def __init__( pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, + pack_full_step: bool = False, ): self.world = get_world() @@ -172,6 +173,7 @@ def __init__( transport_config=config, pad_to_multiple_of=pad_to_multiple_of, start_step=start_step, + pack_full_step=pack_full_step, ) non_dp_world_size = self.world.world_size // dp_world_size diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index cf9dcfa02e..d614ff478b 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -124,8 +124,10 @@ def __init__( tokenizer: PreTrainedTokenizer, config: TransportConfig, start_step: int = 0, + pack_full_step: bool = False, ): super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) + self.pack_full_step = pack_full_step # Per-run buffer: stores (TrainingSample, step) tuples self.buffers: list[deque[tuple[TrainingSample, int]]] = [ deque() for _ in range(self.multi_run_manager.max_runs) @@ -221,8 +223,22 @@ def _count_tokens(self, threshold: int | None = None) -> int: return tokens return tokens + def _all_runs_have_step_data(self) -> bool: + """Check that every active run has buffered data for its current step.""" + if not self.multi_run_manager.used_idxs: + return False + for run_idx in self.multi_run_manager.used_idxs: + if len(self.buffers[run_idx]) == 0: + return False + _, step = self.buffers[run_idx][0] + if step > self.multi_run_manager.progress[run_idx].step: + return False + return True + def _has_enough_tokens(self) -> bool: """Check if we have enough samples in buffer to pack a step""" + if self.pack_full_step: + return self._all_runs_have_step_data() # When not using small batch granularity, require at least one full batch threshold = self.seq_len * self.dp_world_size return self._count_tokens(threshold) >= threshold @@ -293,7 +309,10 @@ def pack(self): time.sleep(1) self._get_batch() - token_budget = self.seq_len * self.dp_world_size + if self.pack_full_step: + token_budget = self._count_tokens() + else: + token_budget = self.seq_len * self.dp_world_size selected_samples = self._select_samples_round_robin(token_budget) assert selected_samples, "No samples selected" @@ -342,9 +361,12 @@ def setup_packer( tokenizer: PreTrainedTokenizer, transport_config: TransportConfig, start_step: int = 0, + pack_full_step: bool = False, ) -> BasePacker: multi_run_manager = get_multi_run_manager() if multi_run_manager.max_runs == 1: return SinglePacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) else: - return MultiPacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) + return MultiPacker( + dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step, pack_full_step + ) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 4e75111907..017dc6c649 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -237,6 +237,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: config.model.cp, tokenizer, config.rollout_transport, + pack_full_step=config.pack_full_step, ) gc_handler = GarbageCollection(config.gc.interval) if config.gc else None From b32110856a0e0157adc4906251d2d9f4b1d1c3b7 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 7 Mar 2026 23:53:45 -0600 Subject: [PATCH 10/16] enforce async level in multi-actor policy updates --- src/prime_rl/orchestrator/scheduler.py | 28 +++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index 38e0045ad2..3da6014767 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -390,6 +390,23 @@ async def maybe_update_policy(self): async def _update_policy_multi_actor(self): """Update per-actor LoRA adapters from their respective broadcast directories.""" assert self.actor_lora_mapping is not None + async_away_ckpt_step = max(self.step - self.max_async_level, 0) + if async_away_ckpt_step > self.ckpt_step: + self.checkpoint_ready.clear() + self.logger.info( + f"Orchestrator paused: waiting for all actors to reach step {async_away_ckpt_step} " + f"(>{self.max_async_level} step(s) ahead). Training is progressing normally." + ) + wait_start = time.perf_counter() + for agent_id, run_name in self.actor_lora_mapping.items(): + broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) + await wait_for_path(get_step_path(broadcast_dir, async_away_ckpt_step) / "STABLE") + self.wait_for_ckpt_time = time.perf_counter() - wait_start + self.logger.info( + f"Orchestrator resumed: all actors reached step {async_away_ckpt_step} " + f"(after {self.wait_for_ckpt_time:.2f}s)" + ) + any_updated = False min_actor_step = float("inf") @@ -413,11 +430,12 @@ async def _update_policy_multi_actor(self): f"in {self.update_weights_time:.2f}s" ) - if any_updated: - new_ckpt_step = int(min_actor_step) if min_actor_step != float("inf") else 0 - if new_ckpt_step > self.ckpt_step: - self.ckpt_step = new_ckpt_step - await self._update_off_policy() + new_ckpt_step = int(min_actor_step) if min_actor_step != float("inf") else 0 + if new_ckpt_step > self.ckpt_step: + self.ckpt_step = new_ckpt_step + await self._update_off_policy() + + self.checkpoint_ready.set() async def _update_off_policy(self) -> None: stale_group_ids = { From e3a6cc07e84e6c2a145a4141dc0faef619e3042f Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sun, 8 Mar 2026 00:04:37 -0600 Subject: [PATCH 11/16] auto-set max_concurrent_runs and fix packer timeout for multi-agent lora --- packages/prime-rl-configs/src/prime_rl/configs/rl.py | 2 ++ src/prime_rl/trainer/rl/packer.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index 432bb255dd..a8af1a1c19 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -473,6 +473,8 @@ def validate_multi_agent_lora(self): raise ValueError("orchestrator.multi_agent_lora requires trainer.model.lora to be configured.") if self.orchestrator.multi_agent_lora: self.trainer.pack_full_step = True + if self.trainer.max_concurrent_runs < 2: + self.trainer.max_concurrent_runs = 2 return self @model_validator(mode="after") diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index d614ff478b..b190279bd4 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -303,7 +303,7 @@ def pack(self): start_time = time.time() while not self._has_enough_tokens(): - if time.time() - start_time > TIMEOUT_SECONDS and self._count_tokens() > 0: + if not self.pack_full_step and time.time() - start_time > TIMEOUT_SECONDS and self._count_tokens() > 0: self.logger.warning("Timeout waiting for enough tokens to pack") break time.sleep(1) From 78ff0ac33b1d3a01af110dd181735e2409dbfaeb Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 21 Mar 2026 18:19:27 -0600 Subject: [PATCH 12/16] use upstream dedup pattern for multi-actor policy updates --- src/prime_rl/orchestrator/scheduler.py | 49 +++++++++++++++----------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index 3da6014767..6649b34ea2 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -322,6 +322,9 @@ def _compute_next_ckpt_step(self) -> int: return max(async_away_ckpt_step, latest_ckpt_step) async def _apply_policy_update(self, next_ckpt_step: int) -> None: + if self.actor_lora_mapping is not None: + return await self._apply_policy_update_multi_actor(next_ckpt_step) + async_away_ckpt_step = max(self.step - self.max_async_level, 0) if next_ckpt_step == async_away_ckpt_step: self.logger.info( @@ -376,44 +379,54 @@ def _clear_inflight_policy_update(done_task: asyncio.Task) -> None: async def maybe_update_policy(self): """Updates the policy to the latest available checkpoint. Aborts rollout requests that are older than the max retention steps.""" - if self.actor_lora_mapping is not None: - return await self._update_policy_multi_actor() - while True: - next_ckpt_step = self._compute_next_ckpt_step() + if self.actor_lora_mapping is not None: + next_ckpt_step = self._compute_next_ckpt_step_multi_actor() + else: + next_ckpt_step = self._compute_next_ckpt_step() if next_ckpt_step <= self.ckpt_step: return task = await self._get_or_start_policy_update_task(next_ckpt_step) await asyncio.shield(task) - async def _update_policy_multi_actor(self): - """Update per-actor LoRA adapters from their respective broadcast directories.""" + def _compute_next_ckpt_step_multi_actor(self) -> int: + """Compute next checkpoint step for multi-actor mode (minimum across all actors).""" assert self.actor_lora_mapping is not None + min_latest = float("inf") + for agent_id, run_name in self.actor_lora_mapping.items(): + broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) + latest_step = get_latest_ckpt_step(broadcast_dir) or 0 + min_latest = min(min_latest, latest_step) + next_step = int(min_latest) if min_latest != float("inf") else 0 async_away_ckpt_step = max(self.step - self.max_async_level, 0) - if async_away_ckpt_step > self.ckpt_step: + if self.strict_async_level: + return async_away_ckpt_step + return max(async_away_ckpt_step, next_step) + + async def _apply_policy_update_multi_actor(self, next_ckpt_step: int) -> None: + """Apply policy update for multi-actor LoRA mode.""" + assert self.actor_lora_mapping is not None + async_away_ckpt_step = max(self.step - self.max_async_level, 0) + if next_ckpt_step == async_away_ckpt_step: self.checkpoint_ready.clear() self.logger.info( - f"Orchestrator paused: waiting for all actors to reach step {async_away_ckpt_step} " + f"Orchestrator paused: waiting for all actors to reach step {next_ckpt_step} " f"(>{self.max_async_level} step(s) ahead). Training is progressing normally." ) wait_start = time.perf_counter() for agent_id, run_name in self.actor_lora_mapping.items(): broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) - await wait_for_path(get_step_path(broadcast_dir, async_away_ckpt_step) / "STABLE") + await wait_for_path(get_step_path(broadcast_dir, next_ckpt_step) / "STABLE") self.wait_for_ckpt_time = time.perf_counter() - wait_start self.logger.info( - f"Orchestrator resumed: all actors reached step {async_away_ckpt_step} " + f"Orchestrator resumed: all actors reached step {next_ckpt_step} " f"(after {self.wait_for_ckpt_time:.2f}s)" ) - any_updated = False - min_actor_step = float("inf") - for agent_id, run_name in self.actor_lora_mapping.items(): broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) latest_step = get_latest_ckpt_step(broadcast_dir) or 0 - min_actor_step = min(min_actor_step, latest_step) if latest_step > self.actor_ckpt_steps[agent_id]: weights_path = get_step_path(broadcast_dir, latest_step) @@ -424,18 +437,14 @@ async def _update_policy_multi_actor(self): self.actor_ckpt_steps[agent_id] = latest_step self.actor_model_names[agent_id] = lora_name - any_updated = True self.logger.debug( f"Updated actor {agent_id} ({lora_name}) to step {latest_step} " f"in {self.update_weights_time:.2f}s" ) - new_ckpt_step = int(min_actor_step) if min_actor_step != float("inf") else 0 - if new_ckpt_step > self.ckpt_step: - self.ckpt_step = new_ckpt_step - await self._update_off_policy() - + self.ckpt_step = next_ckpt_step self.checkpoint_ready.set() + await self._update_off_policy() async def _update_off_policy(self) -> None: stale_group_ids = { From 1ffd1eff381a422be9c9db176aa5452b85089bdd Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 20 May 2026 13:28:38 -0600 Subject: [PATCH 13/16] honor per-step advantages in ZeroAdvantageFilter for multi-agent rollouts --- src/prime_rl/orchestrator/filters.py | 12 +++++- tests/unit/orchestrator/test_filters.py | 52 ++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/src/prime_rl/orchestrator/filters.py b/src/prime_rl/orchestrator/filters.py index b2921d22b1..761ac9856d 100644 --- a/src/prime_rl/orchestrator/filters.py +++ b/src/prime_rl/orchestrator/filters.py @@ -98,14 +98,22 @@ def check(self, rollout: vf.RolloutOutput) -> FilterResult: class ZeroAdvantageFilter: """Flags rollouts with zero advantage. - This filter is applied after advantages are computed and checks if the - rollout's advantage field is zero. + This filter is applied after advantages are computed. Multi-agent rollouts + may have per-step advantages even when the rollout-level advantage is zero. """ name: str enforce: bool = True def check(self, rollout: vf.RolloutOutput) -> FilterResult: + step_advantages = [ + step.get("advantage") + for step in rollout.get("trajectory", []) + if step.get("advantage") is not None + ] + if step_advantages: + return FilterResult(detected=all(advantage == 0.0 for advantage in step_advantages)) + advantage = rollout.get("advantage") if advantage is not None and advantage == 0.0: return FilterResult(detected=True) diff --git a/tests/unit/orchestrator/test_filters.py b/tests/unit/orchestrator/test_filters.py index e77f51f61f..60c297004d 100644 --- a/tests/unit/orchestrator/test_filters.py +++ b/tests/unit/orchestrator/test_filters.py @@ -1,9 +1,10 @@ import math -from prime_rl.configs.orchestrator import GibberishFilterConfig, RepetitionFilterConfig +from prime_rl.configs.orchestrator import GibberishFilterConfig, RepetitionFilterConfig, ZeroAdvantageFilterConfig from prime_rl.orchestrator.filters import ( GibberishFilter, RepetitionFilter, + ZeroAdvantageFilter, apply_filters, setup_filter, setup_filters, @@ -61,6 +62,10 @@ def _make_repetition_filter(window=5, prob_threshold=0.99, enforce=False): ) +def _make_zero_advantage_filter(enforce=True): + return ZeroAdvantageFilter(name="zero_advantage", enforce=enforce) + + # --- GibberishFilter tests --- @@ -169,6 +174,43 @@ def test_repetition_varied_probs_no_trigger(): assert result.detected is False +# --- ZeroAdvantageFilter tests --- + + +def test_zero_advantage_detects_rollout_level_zero(): + zero_advantage_filter = _make_zero_advantage_filter() + rollout = _make_rollout(completion_ids=[1], completion_logprobs=[-1.0]) + rollout["advantage"] = 0.0 + + result = zero_advantage_filter.check(rollout) + + assert result.detected is True + + +def test_zero_advantage_uses_step_advantages_when_present(): + zero_advantage_filter = _make_zero_advantage_filter() + rollout = _make_rollout(completion_ids=[1, 2], completion_logprobs=[-1.0, -1.0], multi_step=True) + rollout["advantage"] = 0.0 + rollout["trajectory"][0]["advantage"] = 1.0 + rollout["trajectory"][1]["advantage"] = -1.0 + + result = zero_advantage_filter.check(rollout) + + assert result.detected is False + + +def test_zero_advantage_detects_all_step_advantages_zero(): + zero_advantage_filter = _make_zero_advantage_filter() + rollout = _make_rollout(completion_ids=[1, 2], completion_logprobs=[-1.0, -1.0], multi_step=True) + rollout["advantage"] = 1.0 + rollout["trajectory"][0]["advantage"] = 0.0 + rollout["trajectory"][1]["advantage"] = 0.0 + + result = zero_advantage_filter.check(rollout) + + assert result.detected is True + + # --- setup_filter / setup_filters tests --- @@ -204,6 +246,14 @@ def test_setup_filter_repetition_enforce(): assert repetition_filter.enforce is True +def test_setup_filter_zero_advantage(): + config = ZeroAdvantageFilterConfig() + zero_advantage_filter = setup_filter(config, vocab_size=128_000) + assert isinstance(zero_advantage_filter, ZeroAdvantageFilter) + assert zero_advantage_filter.name == "zero_advantage" + assert zero_advantage_filter.enforce is True + + def test_setup_filters_multiple(): configs = [ GibberishFilterConfig(), From 4937e102cf7f83ade8782e0b15346480eae614d4 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 20 May 2026 13:32:21 -0600 Subject: [PATCH 14/16] emit per-agent trainer metrics alongside per-env breakdowns --- src/prime_rl/trainer/batch.py | 13 +++++++++++++ src/prime_rl/trainer/rl/data.py | 4 ++++ src/prime_rl/trainer/rl/train.py | 17 ++++++++++++++++- src/prime_rl/transport/types.py | 3 +++ 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index e5a16ba05b..d65ab08b21 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -16,6 +16,9 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch mm_token_type_ids = training_example.mm_token_type_ids assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys" env_names = [training_example.env_name] * len(input_ids) + actor_id = training_example.actor_id or "" + assert actor_id != "all", "actor_id='all' is reserved for aggregate metric keys" + actor_ids = [actor_id] * len(input_ids) # Per-token temperatures: prompt tokens use first completion temp (masked out anyway) # Default to 1.0 if completion is empty (e.g., model generated only tool calls with no text) @@ -41,6 +44,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch if mm_token_type_ids is not None: mm_token_type_ids = mm_token_type_ids[:seq_len] env_names = env_names[:seq_len] + actor_ids = actor_ids[:seq_len] assert ( len(input_ids) @@ -65,6 +69,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch f"mm_token_type_ids: {len(mm_token_type_ids)}, input_ids: {len(input_ids)}" ) assert len(env_names) == len(input_ids), f"env_names: {len(env_names)}, input_ids: {len(input_ids)}" + assert len(actor_ids) == len(input_ids), f"actor_ids: {len(actor_ids)}, input_ids: {len(input_ids)}" return MicroBatch( input_ids=input_ids, @@ -77,6 +82,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, env_names=env_names, + actor_ids=actor_ids, # Multimodal fields (Qwen3-VL) - passed through without modification pixel_values=training_example.pixel_values, pixel_values_shape=training_example.pixel_values_shape, @@ -143,6 +149,7 @@ def packed_samples_into_micro_bs( bin_content.mm_token_type_ids = [] bin_content.mm_token_type_ids.extend(sample.mm_token_type_ids) bin_content.env_names.extend(sample.env_names) + bin_content.actor_ids.extend(sample.actor_ids) bin_content.position_ids.extend(sample.position_ids) bin_content.lora_num_tokens[idx] += len(sample.input_ids) break @@ -172,6 +179,11 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa f"MicroBatch.env_names must match input_ids length before padding: " f"env_names={len(micro_batch.env_names)}, input_ids={len(micro_batch.input_ids)}" ) + if len(micro_batch.actor_ids) != len(micro_batch.input_ids): + raise ValueError( + f"MicroBatch.actor_ids must match input_ids length before padding: " + f"actor_ids={len(micro_batch.actor_ids)}, input_ids={len(micro_batch.input_ids)}" + ) if not (pad_to_multiple_of > 1 and padding_size > 0): return micro_batch @@ -191,6 +203,7 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa if micro_batch.mm_token_type_ids is not None: micro_batch.mm_token_type_ids.extend([0] * padding_size) micro_batch.env_names.extend([""] * padding_size) + micro_batch.actor_ids.extend([""] * padding_size) return micro_batch diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 1d2e8322c0..ceed2489f2 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -25,6 +25,7 @@ class TensorMicroBatch(TypedDict): loss_mask: Bool[Tensor, "batch seq"] temperatures: Float[Tensor, "batch seq"] # Per-token temperatures env_names: list[str] + actor_ids: list[str] # Batch level lora_num_tokens: Int[Tensor, "n_loras"] @@ -111,6 +112,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc "teacher_logprobs": None, "temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0), "env_names": ["fake"] * input_ids.shape[0], + "actor_ids": [""] * input_ids.shape[0], "loss_mask": loss_mask.unsqueeze(0), "lora_num_tokens": lora_num_tokens, "routed_experts": None, @@ -139,6 +141,7 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: "teacher_logprobs": None, "temperatures": torch.ones(self.seq_len).unsqueeze(0), "env_names": ["fake"] * self.seq_len, + "actor_ids": [""] * self.seq_len, "loss_mask": torch.ones(self.seq_len, dtype=torch.bool).unsqueeze(0), "lora_num_tokens": lora_num_tokens, "routed_experts": None, @@ -212,6 +215,7 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: loss_mask=torch.tensor(micro_batch.loss_mask, dtype=torch.bool).unsqueeze(0), temperatures=torch.tensor(micro_batch.temperatures, dtype=torch.float).unsqueeze(0), env_names=micro_batch.env_names, + actor_ids=micro_batch.actor_ids, lora_num_tokens=torch.tensor(micro_batch.lora_num_tokens, dtype=torch.int32), # Multimodal fields - no batch dimension for these as they are variable-sized pixel_values=torch.frombuffer(bytearray(micro_batch.pixel_values), dtype=torch.float32).reshape( diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 017dc6c649..b7da346a51 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -480,8 +480,9 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: tensors["entropy/all"].append(entropy) tensors["loss"].append(loss.detach().to("cpu").unsqueeze(0)) + keep_flags = loss_mask.flatten().tolist() env_names = micro_batch["env_names"] - masked_env_names = [env_name for env_name, keep in zip(env_names, loss_mask.flatten().tolist()) if keep] + masked_env_names = [env_name for env_name, keep in zip(env_names, keep_flags) if keep] env_to_indices: dict[str, list[int]] = {} for idx, env_name in enumerate(masked_env_names): env_to_indices.setdefault(env_name, []).append(idx) @@ -489,6 +490,18 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: for env_name, indices in env_to_indices.items(): tensors[f"entropy/{env_name}"].append(entropy[indices]) + # Per-agent breakdown for multi-agent training. Skips agentless ("") tokens + # so single-agent runs don't emit a noisy `entropy/` key. + actor_ids = micro_batch["actor_ids"] + masked_actor_ids = [actor_id for actor_id, keep in zip(actor_ids, keep_flags) if keep] + actor_to_indices: dict[str, list[int]] = {} + for idx, actor_id in enumerate(masked_actor_ids): + if actor_id: + actor_to_indices.setdefault(actor_id, []).append(idx) + + for actor_id, indices in actor_to_indices.items(): + tensors[f"entropy/actor/{actor_id}"].append(entropy[indices]) + if micro_batch["training_mode"] != "sft": with torch.no_grad(): _, _, mismatch_kl = compute_importance_ratio_and_mismatch_kl(out["logprobs"], inference_logprobs) @@ -496,6 +509,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: tensors["mismatch_kl/all"].append(mismatch_kl) for env_name, indices in env_to_indices.items(): tensors[f"mismatch_kl/{env_name}"].append(mismatch_kl[indices]) + for actor_id, indices in actor_to_indices.items(): + tensors[f"mismatch_kl/actor/{actor_id}"].append(mismatch_kl[indices]) if is_tt_moe_model(model): load_balance_stats = get_load_balance_stats(model) diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 33cc36ff95..624d3bc3b5 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -57,6 +57,9 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): position_ids: list[int] temperatures: list[float] # Per-token temperatures used during generation env_names: list[str] + # Per-token actor_id ("" for non-multi-agent or padding). Mirrors env_names + # so train.py can emit per-agent metric breakdowns. + actor_ids: list[str] teacher_logprobs: list[float] | None = None lora_num_tokens: list[int] | None = None routed_experts: list[list[list[int]]] | None = None From aa649af69e63f689c9be668918d91446549c0a4e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 20 May 2026 15:19:59 -0600 Subject: [PATCH 15/16] drop deps/verifiers from workspace members; conflicts with git source override --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 99e8e5756e..0a06aeda9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,10 @@ dev = [ members = [ "packages/prime-rl-configs", "deps/pydantic-config", - "deps/verifiers", + # `deps/verifiers` itself is intentionally NOT a workspace member: the + # `verifiers` package is sourced from our fork via [tool.uv.sources] below. + # The environment sub-packages under deps/verifiers/environments/* remain + # workspace members so they can be imported by their workspace names. "deps/renderers", "deps/verifiers/environments/alphabet_sort", "deps/verifiers/environments/math_python", From 42796a76ffa3a00e78af49fdf1af64009756927f Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 20 May 2026 16:07:34 -0600 Subject: [PATCH 16/16] re-leaf orchestrator output_dir to outputs/orchestrator under multi_agent_lora --- packages/prime-rl-configs/src/prime_rl/configs/rl.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index a8af1a1c19..0b65384432 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -475,6 +475,18 @@ def validate_multi_agent_lora(self): self.trainer.pack_full_step = True if self.trainer.max_concurrent_runs < 2: self.trainer.max_concurrent_runs = 2 + # The orchestrator's output_dir defaults to ``outputs/run_default``, + # which collides with the trainer's ``output_dir/run_*`` glob (see + # MultiRunManager.discover_runs): the trainer would treat the + # orchestrator's home as a third "actor" run alongside + # run_proposer / run_responder and deadlock waiting for batches + # there. Re-leaf to ``outputs/orchestrator`` so it stays outside + # the discovery pattern. Mirrors the equivalent rewrite in + # validation.py for the shared-output_dir propagation path; only + # applies when the user hasn't overridden orchestrator.output_dir. + default_leaf = Path("outputs/run_default") + if self.orchestrator.output_dir == default_leaf: + self.orchestrator.output_dir = Path("outputs/orchestrator") return self @model_validator(mode="after")