Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ override-dependencies = [

[tool.uv.sources]
torch = { index = "pytorch-cu128" }
verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "1960e77" }
verifiers = { git = "https://github.com/nph4rd/verifiers.git", branch = "multiagent-no-opponent-conditioning" }
torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" }
dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "5c1c72b" }
Expand Down
11 changes: 11 additions & 0 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,17 @@ class OrchestratorConfig(BaseConfig):

weight_broadcast: WeightBroadcastConfig = FileSystemWeightBroadcastConfig()

multi_agent_lora: Annotated[
bool,
Field(
description=(
"Enable per-agent LoRA training for multi-agent environments. "
"Each agent trains its own LoRA adapter while sharing the base model. "
"Requires the environment to be a MultiAgentEnv with registered agents."
),
),
] = False

rollout_transport: TransportConfig = FileSystemTransportConfig()

output_dir: Annotated[
Expand Down
32 changes: 31 additions & 1 deletion src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,12 @@ def validate_teacher_model(self):
def auto_setup_output_dir(self):
"""Auto-setup shared output directory for trainer and orchestrator."""
self.trainer.output_dir = self.output_dir
self.orchestrator.output_dir = self.output_dir / "run_default"
if self.orchestrator.multi_agent_lora:
# Multi-agent LoRA: the orchestrator writes to <output_dir>/orchestrator, and
# per-agent run dirs (run_<agent_id>/) are created next to it under output_dir.
self.orchestrator.output_dir = self.output_dir / "orchestrator"
else:
self.orchestrator.output_dir = self.output_dir / "run_default"

validate_shared_output_dir(self.trainer, self.orchestrator)

Expand Down Expand Up @@ -615,6 +620,31 @@ def auto_setup_lora(self):

return self

@model_validator(mode="after")
def auto_setup_multi_agent_lora(self):
    """Align trainer and inference settings when multi-agent LoRA is enabled.

    Does nothing unless ``orchestrator.multi_agent_lora`` is set. When it is:

    * ``trainer.model.lora`` must be configured, otherwise a ``ValueError``
      is raised;
    * ``trainer.pack_full_step`` is forced to ``True``;
    * ``trainer.max_concurrent_runs`` is raised to 2 if it is lower;
    * if an inference config is present, LoRA is enabled on it and its
      ``max_lora_rank`` is set to the trainer's LoRA rank.

    Returns:
        ``self``, as required of an "after" model validator.

    Raises:
        ValueError: if multi-agent LoRA is enabled without a trainer LoRA config.
    """
    if self.orchestrator.multi_agent_lora:
        trainer_cfg = self.trainer

        # Every agent trains its own adapter, so a LoRA config is mandatory.
        if trainer_cfg.model.lora is None:
            raise ValueError(
                "multi_agent_lora requires trainer.model.lora to be configured. "
                "Each agent trains its own LoRA adapter."
            )

        # The packer must wait until every run has data before advancing a step.
        trainer_cfg.pack_full_step = True

        # Multi-agent training needs at least two LoRA slots.
        if trainer_cfg.max_concurrent_runs < 2:
            trainer_cfg.max_concurrent_runs = 2

        inference_cfg = self.inference
        if inference_cfg is not None:
            inference_cfg.enable_lora = True
            inference_cfg.max_lora_rank = self.trainer.model.lora.rank

    return self

@model_validator(mode="after")
def auto_setup_router_replay(self):
if self.trainer.enable_router_replay:
Expand Down
8 changes: 8 additions & 0 deletions src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,14 @@ class TrainerConfig(BaseConfig):
),
] = 1

pack_full_step: Annotated[
bool,
Field(
description="When True, the packer waits for ALL active runs to have data before packing a step. "
"Required for multi-agent LoRA training where each agent must contribute data every step.",
),
] = False

@model_validator(mode="after")
def auto_setup_bench(self):
if self.bench is not None:
Expand Down
52 changes: 52 additions & 0 deletions src/prime_rl/orchestrator/advantage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable

Expand Down Expand Up @@ -97,3 +98,54 @@ def compute_advantages(

result = advantage_fn(inputs)
return result.advantages.flatten().tolist()


def compute_per_agent_advantages(rollouts: list[dict]) -> None:
    """Compute per-agent GRPO advantages for multi-agent rollouts, in place.

    For multi-agent environments, each trajectory step is tagged with an
    agent id (``step["extras"]["agent_id"]``) and carries a per-agent
    ``step["reward"]``. Standard GRPO computes advantages from the
    rollout-level mean reward, which is invariant when agent payoffs sum to
    a constant.

    This function instead groups rollouts by ``example_id`` and, for each
    (example, agent) pair, uses as baseline the mean reward over all of that
    agent's tagged steps in the group. ``reward - baseline`` is written to
    ``step["advantage"]`` on each of those steps.

    Steps without an agent id or without a reward are left untouched. If no
    step in any rollout carries an agent id, the function is a no-op and does
    not read ``example_id`` at all.

    Args:
        rollouts: Rollout dicts. Each may hold a ``"trajectory"`` list of step
            dicts; rollouts are mutated in place through their steps.

    Raises:
        KeyError: if agent-tagged steps exist and a rollout lacks ``example_id``.
    """
    if not rollouts:
        return

    def _agent_id_of(step: dict):
        # `extras` may be absent or explicitly None; treat both as "no agent".
        return (step.get("extras") or {}).get("agent_id")

    # Scan the whole batch rather than a prefix: in a mixed batch the first
    # few rollouts may have no agent-tagged steps while later ones do.
    # The `is not None` test matches the one used when collecting rewards below.
    has_agents = any(
        _agent_id_of(step) is not None
        for rollout in rollouts
        for step in (rollout.get("trajectory") or [])
    )
    if not has_agents:
        return

    # Group rollouts by example_id so baselines are per example.
    groups: dict = defaultdict(list)
    for rollout in rollouts:
        groups[rollout["example_id"]].append(rollout)

    for group in groups.values():
        # agent_id -> list of (step, reward) across every rollout of this example
        agent_entries: dict = defaultdict(list)
        for rollout in group:
            for step in rollout.get("trajectory") or []:
                agent_id = _agent_id_of(step)
                reward = step.get("reward")
                if agent_id is not None and reward is not None:
                    agent_entries[agent_id].append((step, reward))

        # Per-agent baseline, then per-step advantage relative to it.
        for entries in agent_entries.values():
            baseline = sum(reward for _, reward in entries) / len(entries)
            for step, reward in entries:
                step["advantage"] = reward - baseline
2 changes: 2 additions & 0 deletions src/prime_rl/orchestrator/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ async def evaluate_env(
ckpt_step: int,
step: int,
get_client: Callable[[], Awaitable[vf.ClientConfig]],
actor_models: dict[str, str] | None = None,
):
logger = get_logger()
logger.info(f"Evaluating {env_name} ({num_examples=}, {rollouts_per_example=})")
Expand All @@ -111,6 +112,7 @@ async def evaluate_env(
rollouts_per_example=rollouts_per_example,
get_client=get_client,
max_retries=max_retries,
actor_models=actor_models,
)
eval_time = time.perf_counter() - eval_start_time

Expand Down
88 changes: 81 additions & 7 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

import tomli_w

from prime_rl.orchestrator.advantage import compute_advantages
from prime_rl.orchestrator.advantage import compute_advantages, compute_per_agent_advantages
from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step, get_eval_sampling_args
from prime_rl.orchestrator.event_loop_lag import EventLoopLagMonitor
from prime_rl.orchestrator.patches import monkey_patch_chat_completion_logprobs, monkey_patch_oai_iterable_types
from prime_rl.orchestrator.trajectories import build_vlm_image_cache, interleave_rollout, offload_images_to_disk
from prime_rl.transport import TrainingBatch, TrainingSample, setup_training_batch_sender
from prime_rl.transport import TrainingBatch, TrainingSample, setup_training_batch_sender, setup_multi_run_training_batch_sender
from prime_rl.utils.pathing import get_log_dir

# This monkey patch is necessary to avoid Pydantic validating fields using typing.Iterable (e.g. in multimodal or tool call messages) lazily which leads to tokenization errors, for more info see https://github.com/PrimeIntellect-ai/prime-rl/pull/1249
Expand Down Expand Up @@ -302,6 +302,28 @@ def _cleanup_env_processes():
else:
checkpoint_step = config.ckpt.resume_step

# Multi-agent LoRA: set up per-agent policy routing
actor_lora_mapping: dict[str, str] | None = None
if config.multi_agent_lora:
agents = train_env_group.envs[0].agents
assert len(agents) >= 2, (
"multi_agent_lora requires a MultiAgentEnv with at least 2 registered agents"
)
actor_lora_mapping = {agent_id: f"run_{agent_id}" for agent_id in agents}
logger.info(f"Multi-agent LoRA enabled: {actor_lora_mapping}")

# Create per-actor run directories with minimal orch.toml
# Only include model.lora.name and optim.lr — the trainer fills in the rest
for run_name in actor_lora_mapping.values():
run_config_dir = config.output_dir.parent / run_name / "control"
run_config_dir.mkdir(parents=True, exist_ok=True)
actor_orch_config = {
"model": {"lora": {"name": run_name}},
"optim": {"lr": config.optim.lr},
}
with open(run_config_dir / "orch.toml", "wb") as f:
tomli_w.dump(actor_orch_config, f)

scheduler = Scheduler(
env=train_env_group,
buffer=buffer,
Expand All @@ -314,6 +336,7 @@ def _cleanup_env_processes():
lora_name=config.model.lora.name if config.model.lora else None,
deferred_group_scoring_tasks=train_env_deferred_group_scoring_tasks,
config=config,
actor_lora_mapping=actor_lora_mapping,
)

if checkpoint_step is not None and config.model.lora is not None:
Expand Down Expand Up @@ -343,7 +366,14 @@ def _cleanup_env_processes():

# Setup training batch sender for sending training examples to trainer
logger.info(f"Initializing training batch sender ({config.rollout_transport})")
training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport)
if actor_lora_mapping is not None:
multi_run_sender = setup_multi_run_training_batch_sender(
config.output_dir.parent, list(actor_lora_mapping.values()), config.rollout_transport
)
training_batch_sender = multi_run_sender
else:
multi_run_sender = None
training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport)

# Track last online eval checkpoint step for this process
last_eval_step = -1
Expand Down Expand Up @@ -458,6 +488,7 @@ def _cleanup_env_processes():
max_retries=eval_env_config.max_retries,
ckpt_step=ckpt_step,
step=progress.step,
actor_models=scheduler.actor_model_names or None,
)
for eval_env, eval_env_name, eval_env_config in zip(eval_envs, eval_env_names, config.eval.env)
]
Expand Down Expand Up @@ -523,6 +554,11 @@ def _cleanup_env_processes():
config.advantage,
)

# For multi-agent environments, compute per-agent advantages and set them
# on trajectory steps. This overrides the rollout-level advantage for each
# agent's steps with an advantage relative to that agent's baseline reward.
compute_per_agent_advantages(train_rollouts)

# Convert rollouts to training samples
parallel_preprocess_start = time.perf_counter()

Expand All @@ -537,8 +573,10 @@ def _cleanup_env_processes():
vlm_cache = None

# Process rollouts in parallel
split_by_agent = actor_lora_mapping is not None

def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[TrainingSample] | None:
return interleave_rollout(rollout, vlm_cache=vlm_cache, cache_key=rollout_idx)
return interleave_rollout(rollout, vlm_cache=vlm_cache, cache_key=rollout_idx, split_by_agent=split_by_agent)

loop = asyncio.get_event_loop()
futures = [
Expand All @@ -560,8 +598,11 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin
if samples is not None:
rollout_samples_per_rollout.append(len(samples))
for sample in samples:
sample.advantage = advantage
sample.reward = rollout["reward"]
# Use sample-level values if set (multi-agent), else rollout-level
if sample.advantage is None:
sample.advantage = advantage
if sample.reward is None:
sample.reward = rollout["reward"]
sample_decode_tokens = sum(sample.completion_mask)
sample_prefill_tokens = len(sample.prompt_ids) + len(sample.completion_mask) - sample_decode_tokens
rollout_decode_tokens += sample_decode_tokens
Expand Down Expand Up @@ -600,7 +641,39 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin
step=progress.step,
)

training_batch_sender.send(training_batch)
# Multi-agent LoRA: split training samples by actor_id and send to per-agent run dirs
if multi_run_sender is not None and actor_lora_mapping is not None:
per_actor_examples: dict[str, list[TrainingSample]] = {
run_name: [] for run_name in actor_lora_mapping.values()
}
for sample in train_examples:
if sample.actor_id is not None and sample.actor_id in actor_lora_mapping:
per_actor_examples[actor_lora_mapping[sample.actor_id]].append(sample)
for run_name, examples in per_actor_examples.items():
if examples:
multi_run_sender.send_to_run(run_name, TrainingBatch(examples=examples, step=progress.step))

# Retry with exponential backoff if batch is empty (e.g., inference temporarily unavailable)
if len(training_batch.examples) == 0:
empty_batch_retries += 1
if empty_batch_retries >= max_empty_batch_retries:
raise RuntimeError(
f"Step {progress.step} failed after {max_empty_batch_retries} consecutive empty batches"
)
backoff = min(30 * (2 ** (empty_batch_retries - 1)), 300) # 30s, 60s, 120s, 240s, 300s cap
logger.warning(
f"Step {progress.step} produced 0 training samples "
f"(attempt {empty_batch_retries}/{max_empty_batch_retries}). Retrying in {backoff}s..."
)
# Cancel validation task to avoid accumulating background tasks
val_task.cancel()
await asyncio.sleep(backoff)
continue

# Reset retry counter on successful batch
empty_batch_retries = 0
if multi_run_sender is None:
training_batch_sender.send(training_batch)

# Await and process val results
await val_task
Expand Down Expand Up @@ -829,6 +902,7 @@ def compute_solve_rates(df):
max_retries=eval_env_config.max_retries,
ckpt_step=ckpt_step,
step=progress.step,
actor_models=scheduler.actor_model_names or None,
)
for eval_env, eval_env_name, eval_env_config in zip(eval_envs, eval_env_names, config.eval.env)
]
Expand Down
Loading