Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,9 @@ class OrchestratorConfig(BaseConfig):
weight_broadcast: WeightBroadcastConfig = FileSystemWeightBroadcastConfig()
"""Transport used to receive updated weights from the trainer."""

multi_agent_lora: bool = False
"""Enable per-agent LoRA training for multi-agent environments."""

rollout_transport: TransportConfig = FileSystemTransportConfig()
"""Transport used to ship rollouts from orchestrator to trainer."""

Expand Down
22 changes: 22 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,28 @@ def auto_setup_lora(self):

return self

@model_validator(mode="after")
def validate_multi_agent_lora(self):
    """Validate and apply the settings implied by ``orchestrator.multi_agent_lora``.

    Does nothing when the flag is off. When it is on, this validator:

    * requires ``trainer.model.lora`` to be configured;
    * switches on ``trainer.pack_full_step``;
    * lifts ``trainer.max_concurrent_runs`` to at least 2;
    * moves a still-default ``orchestrator.output_dir`` from
      ``outputs/run_default`` to ``outputs/orchestrator``.

    Returns:
        ``self``, as expected from an ``after``-mode model validator.

    Raises:
        ValueError: if multi-agent LoRA is enabled without a trainer LoRA config.
    """
    if not self.orchestrator.multi_agent_lora:
        return self

    trainer = self.trainer
    if trainer.model.lora is None:
        raise ValueError("orchestrator.multi_agent_lora requires trainer.model.lora to be configured.")

    trainer.pack_full_step = True
    if trainer.max_concurrent_runs < 2:
        trainer.max_concurrent_runs = 2

    # ``outputs/run_default`` (the orchestrator's default home) matches the
    # ``output_dir/run_*`` pattern the trainer uses to discover runs (see
    # MultiRunManager.discover_runs). Left in place, the trainer would count
    # it as an extra actor run next to the per-agent ``run_<agent>`` dirs and
    # block forever waiting for batches that never arrive there. Moving the
    # leaf to ``orchestrator`` keeps it outside that pattern. validation.py
    # performs the matching rewrite when a shared output_dir is propagated;
    # here we only touch the value if it is still the untouched default.
    if self.orchestrator.output_dir == Path("outputs/run_default"):
        self.orchestrator.output_dir = Path("outputs/orchestrator")
    return self

@model_validator(mode="after")
def auto_setup_router_replay(self):
if self.trainer.enable_router_replay:
Expand Down
3 changes: 3 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,9 @@ class TrainerConfig(BaseConfig):
max_concurrent_runs: int = Field(1, ge=1)
"""Maximum number of concurrent runs to allow. If 1, only one run may run at a time."""

pack_full_step: bool = False
"""When True, wait for all active runs to have data before packing a training step."""

experimental: TrainerExperimentalConfig = TrainerExperimentalConfig()

@model_validator(mode="after")
Expand Down
3 changes: 2 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def propagate(shared_path: str, *targets: str, aliases: tuple[str, ...] = ()) ->
if get(sub) is not None:
conflicts.append(("output_dir", sub))
fill("trainer.output_dir", output_dir)
fill("orchestrator.output_dir", f"{output_dir}/run_default")
orchestrator_leaf = "orchestrator" if get("orchestrator.multi_agent_lora") is True else "run_default"
fill("orchestrator.output_dir", f"{output_dir}/{orchestrator_leaf}")

# Cascade trainer.tokenizer.chat_template → inference.model.chat_template
# (vLLM ``--chat-template``). Read trainer's value *after* the shared
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ dev = [
members = [
"packages/prime-rl-configs",
"deps/pydantic-config",
"deps/verifiers",
# `deps/verifiers` itself is intentionally NOT a workspace member: the
# `verifiers` package is sourced from our fork via [tool.uv.sources] below.
# The environment sub-packages under deps/verifiers/environments/* remain
# workspace members so they can be imported by their workspace names.
"deps/renderers",
"deps/verifiers/environments/alphabet_sort",
"deps/verifiers/environments/math_python",
Expand Down Expand Up @@ -193,7 +196,7 @@ nixl-cu12 = false

[tool.uv.sources]
prime-rl-configs = { workspace = true }
verifiers = { workspace = true }
verifiers = { git = "https://github.com/nph4rd/verifiers.git", branch = "multiagent-no-opponent-conditioning" }
renderers = { workspace = true }
aime2024 = { workspace = true }
aime2025 = { workspace = true }
Expand Down
52 changes: 52 additions & 0 deletions src/prime_rl/orchestrator/advantage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable

Expand Down Expand Up @@ -164,3 +165,54 @@ def compute_advantages(

for rollout, advantage in zip(rollouts, advantages):
rollout["advantage"] = advantage


def compute_per_agent_advantages(rollouts: list[dict]) -> None:
    """Compute per-agent GRPO advantages for multi-agent rollouts, in place.

    In multi-agent environments each trajectory step carries an agent id in
    ``step["extras"]["agent_id"]`` and a per-agent ``step["reward"]``. A
    rollout-level baseline is uninformative when agent payoffs sum to a
    constant, so the baseline here is computed per agent instead: for every
    (env, example, agent) triple it is the mean reward over all of that
    agent's steps across the rollouts of that example. The result
    ``reward - baseline`` is written to ``step["advantage"]``.

    Rollouts are grouped by ``(env_name, example_id)`` rather than
    ``example_id`` alone, so examples from different environments that share
    an id do not pollute each other's baseline. A missing ``env_name`` is
    treated as ``None``.

    Steps without an agent id or without a reward are left untouched, which
    makes the function a no-op for single-agent rollouts and for empty input.

    Args:
        rollouts: Rollout dicts. Each may have a ``"trajectory"`` list of step
            dicts; rollouts that contribute at least one agent-tagged step
            must have an ``"example_id"``.

    Raises:
        KeyError: if a rollout with agent-tagged steps has no ``"example_id"``.
    """
    # (env_name, example_id, agent_id) -> [(step, reward), ...]
    entries_by_agent: dict[tuple, list[tuple[dict, float]]] = defaultdict(list)

    # Scan every rollout: sampling only the head of the batch would skip the
    # whole batch whenever the first few rollouts happen to have no steps.
    for rollout in rollouts:
        # ``or`` guards against keys that are present but explicitly None.
        for step in rollout.get("trajectory") or []:
            agent_id = (step.get("extras") or {}).get("agent_id")
            reward = step.get("reward")
            if agent_id is None or reward is None:
                continue
            key = (rollout.get("env_name"), rollout["example_id"], agent_id)
            entries_by_agent[key].append((step, reward))

    for entries in entries_by_agent.values():
        baseline = sum(reward for _, reward in entries) / len(entries)
        for step, reward in entries:
            step["advantage"] = reward - baseline
12 changes: 11 additions & 1 deletion src/prime_rl/orchestrator/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ async def run_rollout(
example: dict,
model_name: str,
cache_salt: str,
actor_models: dict[str, str] | None = None,
) -> vf.RolloutOutput:
"""Run a single rollout for an example."""
return await self.env.run_rollout(
Expand All @@ -126,6 +127,7 @@ async def run_rollout(
max_retries=self.config.max_retries,
state_columns=REQUIRED_STATE_COLUMNS,
env_client=self.env_client,
actor_models=actor_models,
)

async def run_group(
Expand All @@ -135,6 +137,7 @@ async def run_group(
model_name: str,
rollouts_per_example: int,
cache_salt: str,
actor_models: dict[str, str] | None = None,
) -> list[vf.RolloutOutput]:
"""Run a group of rollouts for an example. Required for group-scoring envs."""
return await self.env.run_group(
Expand All @@ -145,6 +148,7 @@ async def run_group(
max_retries=self.config.max_retries,
state_columns=REQUIRED_STATE_COLUMNS,
env_client=self.env_client,
actor_models=actor_models,
)

def shutdown(self) -> None:
Expand Down Expand Up @@ -180,6 +184,7 @@ async def evaluate(
ckpt_step: int,
step: int,
cache_salt: str,
actor_models: dict[str, str] | None = None,
) -> list[vf.RolloutOutput]:
num_examples = len(self.examples)
rollouts_per_example = self.config.rollouts_per_example
Expand All @@ -200,6 +205,7 @@ async def run_with_progress(example: dict) -> list[vf.RolloutOutput] | None:
model_name=model_name,
rollouts_per_example=rollouts_per_example,
cache_salt=cache_salt,
actor_models=actor_models,
)
pbar.update(rollouts_per_example)
return outputs
Expand All @@ -217,7 +223,11 @@ async def run_with_progress(example: dict) -> list[vf.RolloutOutput] | None:
try:
client = await get_client()
output = await self.run_rollout(
client=client, example=example, model_name=model_name, cache_salt=cache_salt
client=client,
example=example,
model_name=model_name,
cache_salt=cache_salt,
actor_models=actor_models,
)
pbar.update(1)
return [output]
Expand Down
12 changes: 10 additions & 2 deletions src/prime_rl/orchestrator/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,22 @@ def check(self, rollout: vf.RolloutOutput) -> FilterResult:
class ZeroAdvantageFilter:
"""Flags rollouts with zero advantage.

This filter is applied after advantages are computed and checks if the
rollout's advantage field is zero.
This filter is applied after advantages are computed. Multi-agent rollouts
may have per-step advantages even when the rollout-level advantage is zero.
"""

name: str
enforce: bool = True

def check(self, rollout: vf.RolloutOutput) -> FilterResult:
step_advantages = [
step.get("advantage")
for step in rollout.get("trajectory", [])
if step.get("advantage") is not None
]
if step_advantages:
return FilterResult(detected=all(advantage == 0.0 for advantage in step_advantages))

advantage = rollout.get("advantage")
if advantage is not None and advantage == 0.0:
return FilterResult(detected=True)
Expand Down
70 changes: 64 additions & 6 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tomli_w

import prime_rl._compat # noqa: F401 — patch ring_flash_attn compat before transitive import
from prime_rl.orchestrator.advantage import compute_advantages
from prime_rl.orchestrator.advantage import compute_advantages, compute_per_agent_advantages
from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step
from prime_rl.orchestrator.event_loop_lag import EventLoopLagMonitor
from prime_rl.orchestrator.inference_metrics import InferenceMetricsCollector
Expand All @@ -18,7 +18,12 @@
offload_images_to_disk,
pretokenize_rollout_trajectory,
)
from prime_rl.transport import TrainingBatch, TrainingSample, setup_training_batch_sender
from prime_rl.transport import (
TrainingBatch,
TrainingSample,
setup_multi_run_training_batch_sender,
setup_training_batch_sender,
)
from prime_rl.utils.pathing import get_log_dir, get_rollout_dir, get_step_path
from prime_rl.utils.usage_reporter import UsageReporter

Expand Down Expand Up @@ -237,6 +242,32 @@ async def orchestrate(config: OrchestratorConfig):
else:
checkpoint_step = config.ckpt.resume_step

# Multi-agent LoRA: set up per-agent policy routing
actor_lora_mapping: dict[str, str] | None = None
if config.multi_agent_lora:
agent_sets = [tuple(getattr(env.env, "agents", ())) for env in train_envs]
agent_sets = [agents for agents in agent_sets if agents]
assert agent_sets, "multi_agent_lora requires at least one MultiAgentEnv with registered agents"
agents = agent_sets[0]
assert len(agents) >= 2, "multi_agent_lora requires a MultiAgentEnv with at least 2 registered agents"
assert all(agent_set == agents for agent_set in agent_sets), (
"multi_agent_lora requires all MultiAgentEnv train envs to use the same agents"
)
actor_lora_mapping = {agent_id: f"run_{agent_id}" for agent_id in agents}
logger.info(f"Multi-agent LoRA enabled: {actor_lora_mapping}")

# Create per-actor run directories with minimal orch.toml
# Only include model.lora.name and optim.lr — the trainer fills in the rest
for run_name in actor_lora_mapping.values():
run_config_dir = config.output_dir.parent / run_name / "control"
run_config_dir.mkdir(parents=True, exist_ok=True)
actor_orch_config = {
"model": {"lora": {"name": run_name}},
"optim": {"lr": config.optim.lr},
}
with open(run_config_dir / "orch.toml", "wb") as f:
tomli_w.dump(actor_orch_config, f)

scheduler = Scheduler(
train_envs=train_envs,
buffer=buffer,
Expand All @@ -249,6 +280,7 @@ async def orchestrate(config: OrchestratorConfig):
tasks_per_minute=config.tasks_per_minute,
lora_name=config.student.model.lora.name if config.student.model.lora else None,
config=config,
actor_lora_mapping=actor_lora_mapping,
)

# Wait for pools to be ready
Expand Down Expand Up @@ -281,7 +313,14 @@ async def orchestrate(config: OrchestratorConfig):

# Setup training batch sender for sending training examples to trainer
logger.info(f"Initializing training batch sender ({config.rollout_transport})")
training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport)
if actor_lora_mapping is not None:
multi_run_sender = setup_multi_run_training_batch_sender(
config.output_dir.parent, list(actor_lora_mapping.values()), config.rollout_transport
)
training_batch_sender = multi_run_sender
else:
multi_run_sender = None
training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport)

# Track last online eval checkpoint step per eval env
last_eval_steps: dict[str, int] = {env.name: -1 for env in eval_envs} if eval_envs else {}
Expand Down Expand Up @@ -393,6 +432,7 @@ async def orchestrate(config: OrchestratorConfig):
ckpt_step=ckpt_step,
step=progress.step,
cache_salt=str(ckpt_step),
actor_models=scheduler.actor_model_names or None,
)
for eval_env in envs_to_eval
]
Expand Down Expand Up @@ -427,6 +467,7 @@ async def orchestrate(config: OrchestratorConfig):
num_rollouts = len(train_rollouts)
num_unique_examples = len({(r["env_name"], r["example_id"]) for r in train_rollouts})
compute_advantages(train_rollouts, config.rollouts_per_example, config.advantage)
compute_per_agent_advantages(train_rollouts)

# Apply rollout filters — sets rollout["filters"] and rollout["is_filtered"]
apply_filters(rollout_filters, train_rollouts)
Expand Down Expand Up @@ -521,12 +562,15 @@ async def _pretokenize_all() -> None:
mm_token_type_ids_mapping = None

# Process rollouts in parallel
split_by_agent = actor_lora_mapping is not None

def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[TrainingSample] | None:
return interleave_rollout(
rollout,
vlm_cache=vlm_cache,
cache_key=rollout_idx,
mm_token_type_ids_mapping=mm_token_type_ids_mapping,
split_by_agent=split_by_agent,
)

results = await asyncio.gather(
Expand All @@ -548,8 +592,10 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin
samples = []
rollout_samples_per_rollout.append(len(samples))
for sample in samples:
sample.advantage = rollout["advantage"]
sample.reward = rollout["reward"]
if sample.advantage is None:
sample.advantage = rollout["advantage"]
if sample.reward is None:
sample.reward = rollout["reward"]
sample.env_name = rollout["env_name"]
sample.training_mode = config.training_mode
sample_decode_tokens = sum(sample.completion_mask)
Expand Down Expand Up @@ -590,7 +636,18 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin
step=progress.step,
)

training_batch_sender.send(training_batch)
if multi_run_sender is not None and actor_lora_mapping is not None:
per_actor_examples: dict[str, list[TrainingSample]] = {
run_name: [] for run_name in actor_lora_mapping.values()
}
for sample in train_examples:
if sample.actor_id is not None and sample.actor_id in actor_lora_mapping:
per_actor_examples[actor_lora_mapping[sample.actor_id]].append(sample)
for run_name, examples in per_actor_examples.items():
if examples:
multi_run_sender.send_to_run(run_name, TrainingBatch(examples=examples, step=progress.step))
else:
training_batch_sender.send(training_batch)

step_time = time.perf_counter() - step_start_time

Expand Down Expand Up @@ -820,6 +877,7 @@ def compute_solve_rates(df):
ckpt_step=ckpt_step,
step=progress.step,
cache_salt=str(ckpt_step),
actor_models=scheduler.actor_model_names or None,
)
for eval_env in eval_envs
]
Expand Down
Loading