diff --git a/pyproject.toml b/pyproject.toml index a684dcc056..babd935f95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ override-dependencies = [ [tool.uv.sources] torch = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "1960e77" } +verifiers = { git = "https://github.com/nph4rd/verifiers.git", branch = "multiagent-no-opponent-conditioning" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "5c1c72b" } diff --git a/src/prime_rl/configs/orchestrator.py b/src/prime_rl/configs/orchestrator.py index 506eed7905..9230980d32 100644 --- a/src/prime_rl/configs/orchestrator.py +++ b/src/prime_rl/configs/orchestrator.py @@ -726,6 +726,17 @@ class OrchestratorConfig(BaseConfig): weight_broadcast: WeightBroadcastConfig = FileSystemWeightBroadcastConfig() + multi_agent_lora: Annotated[ + bool, + Field( + description=( + "Enable per-agent LoRA training for multi-agent environments. " + "Each agent trains its own LoRA adapter while sharing the base model. " + "Requires the environment to be a MultiAgentEnv with registered agents." + ), + ), + ] = False + rollout_transport: TransportConfig = FileSystemTransportConfig() output_dir: Annotated[ diff --git a/src/prime_rl/configs/rl.py b/src/prime_rl/configs/rl.py index 1b8e3d2067..9dbc84aed3 100644 --- a/src/prime_rl/configs/rl.py +++ b/src/prime_rl/configs/rl.py @@ -369,7 +369,12 @@ def validate_teacher_model(self): def auto_setup_output_dir(self): """Auto-setup shared output directory for trainer and orchestrator.""" self.trainer.output_dir = self.output_dir - self.orchestrator.output_dir = self.output_dir / "run_default" + if self.orchestrator.multi_agent_lora: + # Multi-agent LoRA: orchestrator writes to the trainer output_dir root, + # and per-agent run dirs (run_/) are created alongside it. + self.orchestrator.output_dir = self.output_dir / "orchestrator" + else: + self.orchestrator.output_dir = self.output_dir / "run_default" validate_shared_output_dir(self.trainer, self.orchestrator) @@ -615,6 +620,31 @@ def auto_setup_lora(self): return self + @model_validator(mode="after") + def auto_setup_multi_agent_lora(self): + """Auto-configure multi-agent LoRA: ensure LoRA and inference settings are consistent.""" + if not self.orchestrator.multi_agent_lora: + return self + + if self.trainer.model.lora is None: + raise ValueError( + "multi_agent_lora requires trainer.model.lora to be configured. " + "Each agent trains its own LoRA adapter." + ) + + # Ensure packer waits for all runs to have data before advancing a step + self.trainer.pack_full_step = True + + # Need at least 2 LoRA slots for multi-agent training + if self.trainer.max_concurrent_runs < 2: + self.trainer.max_concurrent_runs = 2 + + if self.inference is not None: + self.inference.enable_lora = True + self.inference.max_lora_rank = self.trainer.model.lora.rank + + return self + @model_validator(mode="after") def auto_setup_router_replay(self): if self.trainer.enable_router_replay: diff --git a/src/prime_rl/configs/trainer.py b/src/prime_rl/configs/trainer.py index c8a87a88af..08fa877ee7 100644 --- a/src/prime_rl/configs/trainer.py +++ b/src/prime_rl/configs/trainer.py @@ -705,6 +705,14 @@ class TrainerConfig(BaseConfig): ), ] = 1 + pack_full_step: Annotated[ + bool, + Field( + description="When True, the packer waits for ALL active runs to have data before packing a step. " + "Required for multi-agent LoRA training where each agent must contribute data every step.", + ), + ] = False + @model_validator(mode="after") def auto_setup_bench(self): if self.bench is not None: diff --git a/src/prime_rl/orchestrator/advantage.py b/src/prime_rl/orchestrator/advantage.py index 25f942a044..acb1a21f93 100644 --- a/src/prime_rl/orchestrator/advantage.py +++ b/src/prime_rl/orchestrator/advantage.py @@ -1,3 +1,4 @@ +from collections import defaultdict from dataclasses import dataclass from typing import Callable @@ -97,3 +98,54 @@ def compute_advantages( result = advantage_fn(inputs) return result.advantages.flatten().tolist() + + +def compute_per_agent_advantages(rollouts: list[dict]) -> None: + """Compute per-agent GRPO advantages for multi-agent rollouts. + + For multi-agent environments, each trajectory step is tagged with an agent_id + and has a per-agent reward. Standard GRPO computes advantages from the rollout- + level mean reward, which is invariant when agent payoffs sum to a constant. + + This function computes advantages per agent: for each (example, agent) pair, + the baseline is the mean of that agent's rewards across rollouts of the same + example. Advantages are written directly to trajectory steps so they flow + through interleave_rollout -> TrainingSample.advantage. + + No-ops if rollouts don't contain per-agent trajectory steps. + """ + if not rollouts: + return + + # Quick check: do rollouts have per-agent trajectory steps? + has_agents = False + for r in rollouts[:3]: + for step in r.get("trajectory", []): + if step.get("extras", {}).get("agent_id"): + has_agents = True + break + if has_agents: + break + if not has_agents: + return + + # Group rollouts by example_id + groups: dict[str, list[dict]] = defaultdict(list) + for r in rollouts: + groups[r["example_id"]].append(r) + + for group in groups.values(): + # Collect per-agent rewards: agent_id -> list of (step, reward) + agent_entries: dict[str, list[tuple[dict, float]]] = defaultdict(list) + for r in group: + for step in r.get("trajectory", []): + agent_id = step.get("extras", {}).get("agent_id") + reward = step.get("reward") + if agent_id is not None and reward is not None: + agent_entries[agent_id].append((step, reward)) + + # Compute per-agent baseline and set per-step advantages + for agent_id, entries in agent_entries.items(): + baseline = sum(reward for _, reward in entries) / len(entries) + for step, reward in entries: + step["advantage"] = reward - baseline diff --git a/src/prime_rl/orchestrator/eval_utils.py b/src/prime_rl/orchestrator/eval_utils.py index 5dee7f3aff..a613fbd70f 100644 --- a/src/prime_rl/orchestrator/eval_utils.py +++ b/src/prime_rl/orchestrator/eval_utils.py @@ -99,6 +99,7 @@ async def evaluate_env( ckpt_step: int, step: int, get_client: Callable[[], Awaitable[vf.ClientConfig]], + actor_models: dict[str, str] | None = None, ): logger = get_logger() logger.info(f"Evaluating {env_name} ({num_examples=}, {rollouts_per_example=})") @@ -111,6 +112,7 @@ async def evaluate_env( rollouts_per_example=rollouts_per_example, get_client=get_client, max_retries=max_retries, + actor_models=actor_models, ) eval_time = time.perf_counter() - eval_start_time diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 7331ab3b0c..5dd648c21c 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -8,12 +8,12 @@ import tomli_w -from prime_rl.orchestrator.advantage import compute_advantages +from prime_rl.orchestrator.advantage import compute_advantages, compute_per_agent_advantages from prime_rl.orchestrator.eval_utils import compute_eval_ckpt_step, get_eval_sampling_args from prime_rl.orchestrator.event_loop_lag import EventLoopLagMonitor from prime_rl.orchestrator.patches import monkey_patch_chat_completion_logprobs, monkey_patch_oai_iterable_types from prime_rl.orchestrator.trajectories import build_vlm_image_cache, interleave_rollout, offload_images_to_disk -from prime_rl.transport import TrainingBatch, TrainingSample, setup_training_batch_sender +from prime_rl.transport import TrainingBatch, TrainingSample, setup_training_batch_sender, setup_multi_run_training_batch_sender from prime_rl.utils.pathing import get_log_dir # This monkey patch is necessary to avoid Pydantic validating fields using typing.Iterable (e.g. in multimodal or tool call messages) lazily which leads to tokenization errors, for more info see https://github.com/PrimeIntellect-ai/prime-rl/pull/1249 @@ -302,6 +302,28 @@ def _cleanup_env_processes(): else: checkpoint_step = config.ckpt.resume_step + # Multi-agent LoRA: set up per-agent policy routing + actor_lora_mapping: dict[str, str] | None = None + if config.multi_agent_lora: + agents = train_env_group.envs[0].agents + assert len(agents) >= 2, ( + "multi_agent_lora requires a MultiAgentEnv with at least 2 registered agents" + ) + actor_lora_mapping = {agent_id: f"run_{agent_id}" for agent_id in agents} + logger.info(f"Multi-agent LoRA enabled: {actor_lora_mapping}") + + # Create per-actor run directories with minimal orch.toml + # Only include model.lora.name and optim.lr — the trainer fills in the rest + for run_name in actor_lora_mapping.values(): + run_config_dir = config.output_dir.parent / run_name / "control" + run_config_dir.mkdir(parents=True, exist_ok=True) + actor_orch_config = { + "model": {"lora": {"name": run_name}}, + "optim": {"lr": config.optim.lr}, + } + with open(run_config_dir / "orch.toml", "wb") as f: + tomli_w.dump(actor_orch_config, f) + scheduler = Scheduler( env=train_env_group, buffer=buffer, @@ -314,6 +336,7 @@ def _cleanup_env_processes(): lora_name=config.model.lora.name if config.model.lora else None, deferred_group_scoring_tasks=train_env_deferred_group_scoring_tasks, config=config, + actor_lora_mapping=actor_lora_mapping, ) if checkpoint_step is not None and config.model.lora is not None: @@ -343,7 +366,14 @@ def _cleanup_env_processes(): # Setup training batch sender for sending training examples to trainer logger.info(f"Initializing training batch sender ({config.rollout_transport})") - training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport) + if actor_lora_mapping is not None: + multi_run_sender = setup_multi_run_training_batch_sender( + config.output_dir.parent, list(actor_lora_mapping.values()), config.rollout_transport + ) + training_batch_sender = multi_run_sender + else: + multi_run_sender = None + training_batch_sender = setup_training_batch_sender(config.output_dir, config.rollout_transport) # Track last online eval checkpoint step for this process last_eval_step = -1 @@ -458,6 +488,7 @@ def _cleanup_env_processes(): max_retries=eval_env_config.max_retries, ckpt_step=ckpt_step, step=progress.step, + actor_models=scheduler.actor_model_names or None, ) for eval_env, eval_env_name, eval_env_config in zip(eval_envs, eval_env_names, config.eval.env) ] @@ -523,6 +554,11 @@ def _cleanup_env_processes(): config.advantage, ) + # For multi-agent environments, compute per-agent advantages and set them + # on trajectory steps. This overrides the rollout-level advantage for each + # agent's steps with an advantage relative to that agent's baseline reward. + compute_per_agent_advantages(train_rollouts) + # Convert rollouts to training samples parallel_preprocess_start = time.perf_counter() @@ -537,8 +573,10 @@ def _cleanup_env_processes(): vlm_cache = None # Process rollouts in parallel + split_by_agent = actor_lora_mapping is not None + def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[TrainingSample] | None: - return interleave_rollout(rollout, vlm_cache=vlm_cache, cache_key=rollout_idx) + return interleave_rollout(rollout, vlm_cache=vlm_cache, cache_key=rollout_idx, split_by_agent=split_by_agent) loop = asyncio.get_event_loop() futures = [ @@ -560,8 +598,11 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin if samples is not None: rollout_samples_per_rollout.append(len(samples)) for sample in samples: - sample.advantage = advantage - sample.reward = rollout["reward"] + # Use sample-level values if set (multi-agent), else rollout-level + if sample.advantage is None: + sample.advantage = advantage + if sample.reward is None: + sample.reward = rollout["reward"] sample_decode_tokens = sum(sample.completion_mask) sample_prefill_tokens = len(sample.prompt_ids) + len(sample.completion_mask) - sample_decode_tokens rollout_decode_tokens += sample_decode_tokens @@ -600,7 +641,39 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin step=progress.step, ) - training_batch_sender.send(training_batch) + # Multi-agent LoRA: split training samples by actor_id and send to per-agent run dirs + if multi_run_sender is not None and actor_lora_mapping is not None: + per_actor_examples: dict[str, list[TrainingSample]] = { + run_name: [] for run_name in actor_lora_mapping.values() + } + for sample in train_examples: + if sample.actor_id is not None and sample.actor_id in actor_lora_mapping: + per_actor_examples[actor_lora_mapping[sample.actor_id]].append(sample) + for run_name, examples in per_actor_examples.items(): + if examples: + multi_run_sender.send_to_run(run_name, TrainingBatch(examples=examples, step=progress.step)) + + # Retry with exponential backoff if batch is empty (e.g., inference temporarily unavailable) + if len(training_batch.examples) == 0: + empty_batch_retries += 1 + if empty_batch_retries >= max_empty_batch_retries: + raise RuntimeError( + f"Step {progress.step} failed after {max_empty_batch_retries} consecutive empty batches" + ) + backoff = min(30 * (2 ** (empty_batch_retries - 1)), 300) # 30s, 60s, 120s, 240s, 300s cap + logger.warning( + f"Step {progress.step} produced 0 training samples " + f"(attempt {empty_batch_retries}/{max_empty_batch_retries}). Retrying in {backoff}s..." + ) + # Cancel validation task to avoid accumulating background tasks + val_task.cancel() + await asyncio.sleep(backoff) + continue + + # Reset retry counter on successful batch + empty_batch_retries = 0 + if multi_run_sender is None: + training_batch_sender.send(training_batch) # Await and process val results await val_task @@ -829,6 +902,7 @@ def compute_solve_rates(df): max_retries=eval_env_config.max_retries, ckpt_step=ckpt_step, step=progress.step, + actor_models=scheduler.actor_model_names or None, ) for eval_env, eval_env_name, eval_env_config in zip(eval_envs, eval_env_names, config.eval.env) ] diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index 6ea82e4d6e..0417362523 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -68,6 +68,7 @@ def __init__( tasks_per_minute: int | None, lora_name: str | None = None, deferred_group_scoring_tasks: set[str] | None = None, + actor_lora_mapping: dict[str, str] | None = None, ): self.logger = get_logger() if tasks_per_minute is not None: @@ -99,6 +100,17 @@ def __init__( task_list = ", ".join(sorted(self.deferred_group_scoring_tasks)) self.logger.info(f"Deferred group scoring active for task(s): {task_list}") + # Multi-agent LoRA: per-actor policy routing + self.actor_lora_mapping = actor_lora_mapping + if actor_lora_mapping is not None: + self.actor_ckpt_steps: dict[str, int] = {agent_id: 0 for agent_id in actor_lora_mapping} + self.actor_model_names: dict[str, str] = {agent_id: config.model.name for agent_id in actor_lora_mapping} + self.actor_run_dirs: dict[str, str] = { + agent_id: run_name for agent_id, run_name in actor_lora_mapping.items() + } + else: + self.actor_model_names: dict[str, str] = {} + # Track in-flight requests: task -> info self.inflight_requests: dict[asyncio.Task, InflightRolloutInfo] = {} @@ -205,6 +217,7 @@ async def schedule_rollout(self, group_id: int): model_name=self.model_name, sampling_args=self.sampling_args, max_retries=self.max_retries_by_task.get(group.example["task"], 0), + actor_models=self.actor_model_names or None, ) ) self.inflight_requests[run_rollout_task] = InflightRolloutInfo( @@ -255,6 +268,9 @@ def _compute_next_ckpt_step(self) -> int: return max(async_away_ckpt_step, latest_ckpt_step) async def _apply_policy_update(self, next_ckpt_step: int) -> None: + if self.actor_lora_mapping is not None: + return await self._apply_policy_update_multi_actor(next_ckpt_step) + async_away_ckpt_step = max(self.step - self.max_async_level, 0) if next_ckpt_step == async_away_ckpt_step: self.logger.info( @@ -306,13 +322,71 @@ def _clear_inflight_policy_update(done_task: asyncio.Task) -> None: async def maybe_update_policy(self): """Updates the policy to the latest available checkpoint. Aborts rollout requests that are older than the max retention steps.""" while True: - next_ckpt_step = self._compute_next_ckpt_step() + if self.actor_lora_mapping is not None: + next_ckpt_step = self._compute_next_ckpt_step_multi_actor() + else: + next_ckpt_step = self._compute_next_ckpt_step() if next_ckpt_step <= self.ckpt_step: return task = await self._get_or_start_policy_update_task(next_ckpt_step) await asyncio.shield(task) + def _compute_next_ckpt_step_multi_actor(self) -> int: + """Compute next checkpoint step for multi-actor mode (minimum across all actors).""" + min_latest = float("inf") + for agent_id, run_name in self.actor_lora_mapping.items(): + broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) + latest_step = get_latest_ckpt_step(broadcast_dir) or 0 + min_latest = min(min_latest, latest_step) + next_step = int(min_latest) if min_latest != float("inf") else 0 + async_away_ckpt_step = max(self.step - self.max_async_level, 0) + if self.strict_async_level: + return async_away_ckpt_step + return max(async_away_ckpt_step, next_step) + + async def _apply_policy_update_multi_actor(self, next_ckpt_step: int) -> None: + """Apply policy update for multi-actor LoRA mode.""" + async_away_ckpt_step = max(self.step - self.max_async_level, 0) + if next_ckpt_step == async_away_ckpt_step: + self.checkpoint_ready.clear() + self.logger.info( + f"Orchestrator paused: waiting for all actors to reach step {next_ckpt_step} " + f"(>{self.max_async_level} step(s) ahead). Training is progressing normally." + ) + wait_start = time.perf_counter() + for agent_id, run_name in self.actor_lora_mapping.items(): + broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) + await wait_for_path(get_step_path(broadcast_dir, next_ckpt_step) / "STABLE") + self.wait_for_ckpt_time = time.perf_counter() - wait_start + self.logger.info( + f"Orchestrator resumed: all actors reached step {next_ckpt_step} " + f"(after {self.wait_for_ckpt_time:.2f}s)" + ) + + for agent_id, run_name in self.actor_lora_mapping.items(): + broadcast_dir = get_broadcast_dir(self.config.output_dir.parent / run_name) + latest_step = get_latest_ckpt_step(broadcast_dir) or 0 + + if latest_step > self.actor_ckpt_steps[agent_id]: + weights_path = get_step_path(broadcast_dir, latest_step) + lora_name = run_name + update_start = time.perf_counter() + await self.inference_pool.update_weights(weights_path, lora_name=lora_name, step=latest_step) + self.update_weights_time = time.perf_counter() - update_start + + self.actor_ckpt_steps[agent_id] = latest_step + self.actor_model_names[agent_id] = lora_name + self.inference_pool.update_model_name(lora_name) + self.logger.debug( + f"Updated actor {agent_id} ({lora_name}) to step {latest_step} " + f"in {self.update_weights_time:.2f}s" + ) + + self.ckpt_step = next_ckpt_step + self.checkpoint_ready.set() + await self._update_off_policy() + async def _update_off_policy(self) -> None: stale_group_ids = { info.group_id diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 56c3c23f04..31d4841785 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -41,6 +41,7 @@ def interleave_rollout( output: vf.RolloutOutput, vlm_cache: "VLMImageCache | None" = None, cache_key: int | None = None, + split_by_agent: bool = False, ) -> list[TrainingSample] | None: """ Convert vf.RolloutOutput to trainable rollouts by interleaving trajectory steps @@ -80,11 +81,14 @@ def interleave_rollout( # this field should be guaranteed because we set temperature in get_sampling_args temperature = output["sampling_args"]["temperature"] + def _is_step_trainable(step: vf.TrajectoryStep) -> bool: + return step.get("extras", {}).get("is_trainable", True) + def make_sample(step: vf.TrajectoryStep) -> TrainingSample: """Create a new TrainingSample from a trajectory step.""" tokens = step["tokens"] assert tokens is not None - if has_error: + if has_error or not _is_step_trainable(step): completion_mask = [False] * len(tokens["completion_mask"]) else: completion_mask = [bool(i) for i in tokens["completion_mask"]] @@ -103,8 +107,10 @@ def make_sample(step: vf.TrajectoryStep) -> TrainingSample: completion_logprobs=list(tokens["completion_logprobs"]), completion_temperatures=[temperature] * len(completion_ids), teacher_logprobs=None, - advantage=None, + advantage=step.get("advantage"), + reward=step.get("reward"), routed_experts=routed_experts, + actor_id=step.get("extras", {}).get("agent_id"), ) def extend_sample(sample: TrainingSample, step: vf.TrajectoryStep, prefix_len: int) -> None: @@ -122,13 +128,22 @@ def extend_sample(sample: TrainingSample, step: vf.TrajectoryStep, prefix_len: i # Extend with new completion tokens completion_ids = tokens["completion_ids"] sample.completion_ids.extend(completion_ids) - if has_error: + if has_error or not _is_step_trainable(step): sample.completion_mask.extend([False] * len(tokens["completion_mask"])) else: sample.completion_mask.extend(bool(i) for i in tokens["completion_mask"]) sample.completion_logprobs.extend(tokens["completion_logprobs"]) sample.completion_temperatures.extend([temperature] * len(completion_ids)) + # Update reward/advantage/actor_id to use the latest merged step's values (multi-agent) + if step.get("reward") is not None: + sample.reward = step["reward"] + if step.get("advantage") is not None: + sample.advantage = step["advantage"] + agent_id = step.get("extras", {}).get("agent_id") + if agent_id is not None: + sample.actor_id = agent_id + if tokens.get("routed_experts") is not None and sample.routed_experts is not None: step_routed = tokens["routed_experts"] # The previous step's last routing entry was zero-padded by _align_routed_experts @@ -141,12 +156,23 @@ def extend_sample(sample: TrainingSample, step: vf.TrajectoryStep, prefix_len: i expected_len = len(sample.prompt_ids) + len(sample.completion_ids) sample.routed_experts = _align_routed_experts(sample.routed_experts, expected_len) - # Track [prefix_tokens, sample, last_step_idx] per active sample + # Track [prefix_tokens, sample, last_step_idx, agent_meta] per active sample + # agent_meta: dict with per-completion-token agent_ids and per-agent reward/advantage (for split_by_agent) active_samples: list[list] = [] first_tokens = trajectory[0]["tokens"] first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"] - active_samples.append([first_prefix, make_sample(trajectory[0]), 0]) + if split_by_agent: + first_agent_id = trajectory[0].get("extras", {}).get("agent_id") + first_step = trajectory[0] + agent_meta: dict = { + "ids": [first_agent_id] * len(first_tokens["completion_ids"]), + "rewards": {first_agent_id: first_step.get("reward")} if first_agent_id else {}, + "advantages": {first_agent_id: first_step.get("advantage")} if first_agent_id else {}, + } + else: + agent_meta = {} + active_samples.append([first_prefix, make_sample(trajectory[0]), 0, agent_meta]) for step_idx, step in enumerate(trajectory[1:], start=1): tokens = step["tokens"] @@ -154,17 +180,25 @@ def extend_sample(sample: TrainingSample, step: vf.TrajectoryStep, prefix_len: i # Check if this step extends ANY active prefix matched_idx = None - for idx, (prefix_tokens, _, _) in enumerate(active_samples): + for idx, (prefix_tokens, _, _, _) in enumerate(active_samples): if step_prompt_ids[: len(prefix_tokens)] == prefix_tokens: matched_idx = idx break if matched_idx is not None: # Extension holds - merge into matched sample - prefix_tokens, sample, _ = active_samples[matched_idx] + prefix_tokens, sample, _, meta = active_samples[matched_idx] extend_sample(sample, step, len(prefix_tokens)) active_samples[matched_idx][0] = tokens["prompt_ids"] + tokens["completion_ids"] active_samples[matched_idx][2] = step_idx + if split_by_agent: + step_agent_id = step.get("extras", {}).get("agent_id") + new_prompt_len = len(tokens["prompt_ids"]) - len(prefix_tokens) + new_completion_len = len(tokens["completion_ids"]) + meta["ids"].extend([step_agent_id] * (new_prompt_len + new_completion_len)) + if step_agent_id: + meta["rewards"][step_agent_id] = step.get("reward") + meta["advantages"][step_agent_id] = step.get("advantage") else: # No prefix matches - start a new sample logger.debug( @@ -172,18 +206,59 @@ def extend_sample(sample: TrainingSample, step: vf.TrajectoryStep, prefix_len: i f"Starting new sample (active_prefixes={len(active_samples)}, step_prompt_len={len(step_prompt_ids)})." ) new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] - active_samples.append([new_prefix, make_sample(step), step_idx]) + if split_by_agent: + new_agent_id = step.get("extras", {}).get("agent_id") + new_meta: dict = { + "ids": [new_agent_id] * len(tokens["completion_ids"]), + "rewards": {new_agent_id: step.get("reward")} if new_agent_id else {}, + "advantages": {new_agent_id: step.get("advantage")} if new_agent_id else {}, + } + else: + new_meta = {} + active_samples.append([new_prefix, make_sample(step), step_idx, new_meta]) # Attach images once per sample using only the last merged step if vlm_cache is not None: key = output["example_id"] if cache_key is None else cache_key - for _, sample, last_step_idx in active_samples: + for _, sample, last_step_idx, _ in active_samples: pv, shape, grids = vlm_cache.get_for_step(key, last_step_idx) sample.pixel_values = pv sample.pixel_values_shape = shape sample.image_grid_thw = grids - return [sample for _, sample, _ in active_samples] + if not split_by_agent: + return [sample for _, sample, _, _ in active_samples] + + # Split merged samples into per-agent copies with agent-specific completion masks + result: list[TrainingSample] = [] + for _, sample, _, meta in active_samples: + agent_ids = meta["ids"] + unique_agents = set(a for a in agent_ids if a is not None) + if len(unique_agents) <= 1: + result.append(sample) + continue + for agent_id in unique_agents: + agent_mask = [ + sample.completion_mask[i] and agent_ids[i] == agent_id for i in range(len(agent_ids)) + ] + agent_sample = TrainingSample( + prompt_ids=sample.prompt_ids, + prompt_mask=sample.prompt_mask, + completion_ids=sample.completion_ids, + completion_mask=agent_mask, + completion_logprobs=sample.completion_logprobs, + completion_temperatures=sample.completion_temperatures, + teacher_logprobs=sample.teacher_logprobs, + advantage=meta["advantages"].get(agent_id, sample.advantage), + reward=meta["rewards"].get(agent_id, sample.reward), + routed_experts=sample.routed_experts, + actor_id=agent_id, + pixel_values=sample.pixel_values, + pixel_values_shape=sample.pixel_values_shape, + image_grid_thw=sample.image_grid_thw, + ) + result.append(agent_sample) + return result # ============================================================================= diff --git a/src/prime_rl/orchestrator/vf_utils.py b/src/prime_rl/orchestrator/vf_utils.py index 6acfbee23e..470da36a8d 100644 --- a/src/prime_rl/orchestrator/vf_utils.py +++ b/src/prime_rl/orchestrator/vf_utils.py @@ -85,6 +85,7 @@ async def run_rollout( sampling_args: dict, max_retries: int = DEFAULT_RETRIES, state_columns: list[str] = DEFAULT_STATE_COLUMNS, + actor_models: dict[str, str] | None = None, ) -> vf.RolloutOutput: """ Wrapper for vf.Environment.run_rollout(). @@ -100,6 +101,7 @@ async def run_rollout( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + actor_models=actor_models, ) @@ -112,6 +114,7 @@ async def run_group( sampling_args: dict, max_retries: int = DEFAULT_RETRIES, state_columns: list[str] = DEFAULT_STATE_COLUMNS, + actor_models: dict[str, str] | None = None, ) -> list[vf.RolloutOutput]: """ Wrapper for vf.Environment.run_group(). @@ -127,6 +130,7 @@ async def run_group( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + actor_models=actor_models, ) @@ -142,6 +146,7 @@ async def generate( max_retries: int = DEFAULT_RETRIES, state_columns: list[str] = DEFAULT_STATE_COLUMNS, pbar_description: str = "Generating rollouts", + actor_models: dict[str, str] | None = None, ) -> list[vf.RolloutOutput]: """ Wrapper for vf.Environment.generate(). @@ -174,6 +179,7 @@ async def run_group_with_progress(example): max_retries=max_retries, state_columns=state_columns, sampling_args=sampling_args, + actor_models=actor_models, ) pbar.update(rollouts_per_example) return result @@ -198,6 +204,7 @@ async def evaluate( get_client: Callable[[], Awaitable[vf.ClientConfig]] | None = None, max_retries: int = DEFAULT_RETRIES, state_columns: list[str] = DEFAULT_STATE_COLUMNS, + actor_models: dict[str, str] | None = None, ) -> list[vf.RolloutOutput]: """ Wrapper for vf.Environment.evaluate(). @@ -221,6 +228,7 @@ async def evaluate( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + actor_models=actor_models, ) return outputs diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 0f7a07b77c..37d39ca2b3 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -148,6 +148,7 @@ def __init__( pad_to_multiple_of: int, tokenizer: PreTrainedTokenizer, config: TransportConfig, + pack_full_step: bool = False, ): self.world = get_world() @@ -159,6 +160,7 @@ def __init__( transport_config=config, pad_to_multiple_of=pad_to_multiple_of, start_step=start_step, + pack_full_step=pack_full_step, ) non_dp_world_size = self.world.world_size // dp_world_size diff --git a/src/prime_rl/trainer/rl/packer.py b/src/prime_rl/trainer/rl/packer.py index 9bee96a40a..c34d55ac9b 100644 --- a/src/prime_rl/trainer/rl/packer.py +++ b/src/prime_rl/trainer/rl/packer.py @@ -96,8 +96,10 @@ def __init__( tokenizer: PreTrainedTokenizer, config: TransportConfig, start_step: int = 0, + pack_full_step: bool = False, ): super().__init__(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, config, start_step) + self.pack_full_step = pack_full_step # Per-run buffer: stores (TrainingSample, step) tuples self.buffers: list[deque[tuple[TrainingSample, int]]] = [ deque() for _ in range(self.multi_run_manager.max_runs) @@ -192,8 +194,22 @@ def _count_tokens(self, threshold: int | None = None) -> int: return tokens return tokens + def _all_runs_have_step_data(self) -> bool: + """Check that every active run has buffered data for its current step.""" + if not self.multi_run_manager.used_idxs: + return False + for run_idx in self.multi_run_manager.used_idxs: + if len(self.buffers[run_idx]) == 0: + return False + _, step = self.buffers[run_idx][0] + if step > self.multi_run_manager.progress[run_idx].step: + return False + return True + def _has_enough_tokens(self) -> bool: """Check if we have enough samples in buffer to pack a step""" + if self.pack_full_step: + return self._all_runs_have_step_data() # When not using small batch granularity, require at least one full batch threshold = self.seq_len * self.dp_world_size return self._count_tokens(threshold) >= threshold @@ -258,13 +274,16 @@ def pack(self): start_time = time.time() while not self._has_enough_tokens(): - if time.time() - start_time > TIMEOUT_SECONDS and self._count_tokens() > 0: + if not self.pack_full_step and time.time() - start_time > TIMEOUT_SECONDS and self._count_tokens() > 0: self.logger.warning("Timeout waiting for enough tokens to pack") break time.sleep(1) self._get_batch() - token_budget = self.seq_len * self.dp_world_size + if self.pack_full_step: + token_budget = self._count_tokens() + else: + token_budget = self.seq_len * self.dp_world_size selected_samples = self._select_samples_round_robin(token_budget) assert selected_samples, "No samples selected" @@ -313,9 +332,12 @@ def setup_packer( tokenizer: PreTrainedTokenizer, transport_config: TransportConfig, start_step: int = 0, + pack_full_step: bool = False, ) -> BasePacker: multi_run_manager = get_multi_run_manager() if multi_run_manager.max_runs == 1: return SinglePacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) else: - return MultiPacker(dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step) + return MultiPacker( + dp_world_size, seq_len, pad_to_multiple_of, tokenizer, transport_config, start_step, pack_full_step + ) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 7975b4bd80..e267d16d95 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -206,6 +206,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: config.model.cp, tokenizer, config.rollout_transport, + pack_full_step=config.pack_full_step, ) logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") diff --git a/src/prime_rl/transport/__init__.py b/src/prime_rl/transport/__init__.py index e4c3153dc7..cda2c02ca9 100644 --- a/src/prime_rl/transport/__init__.py +++ b/src/prime_rl/transport/__init__.py @@ -7,6 +7,7 @@ FileSystemMicroBatchSender, FileSystemTrainingBatchReceiver, FileSystemTrainingBatchSender, + MultiRunFileSystemTrainingBatchSender, ) from prime_rl.transport.types import MicroBatch, TrainingBatch, TrainingSample from prime_rl.transport.zmq import ( @@ -26,6 +27,14 @@ def setup_training_batch_sender(output_dir: Path, transport: TransportConfig) -> raise ValueError(f"Invalid transport type: {transport.type}") +def setup_multi_run_training_batch_sender( + output_dir: Path, run_names: list[str], transport: TransportConfig +) -> MultiRunFileSystemTrainingBatchSender: + if transport.type != "filesystem": + raise ValueError(f"Multi-run sender only supports filesystem transport, got: {transport.type}") + return MultiRunFileSystemTrainingBatchSender(output_dir, run_names) + + def setup_training_batch_receiver(transport: TransportConfig) -> TrainingBatchReceiver: if transport.type == "filesystem": return FileSystemTrainingBatchReceiver() @@ -62,6 +71,7 @@ def setup_micro_batch_receiver( "FileSystemTrainingBatchReceiver", "FileSystemMicroBatchSender", "FileSystemMicroBatchReceiver", + "MultiRunFileSystemTrainingBatchSender", "MicroBatchReceiver", "MicroBatchSender", "TrainingSample", @@ -69,6 +79,7 @@ def setup_micro_batch_receiver( "MicroBatch", "setup_training_batch_sender", "setup_training_batch_receiver", + "setup_multi_run_training_batch_sender", "setup_micro_batch_sender", "setup_micro_batch_receiver", ] diff --git a/src/prime_rl/transport/filesystem.py b/src/prime_rl/transport/filesystem.py index f609f888e1..85cee29eae 100644 --- a/src/prime_rl/transport/filesystem.py +++ b/src/prime_rl/transport/filesystem.py @@ -113,6 +113,20 @@ def reset_run(self, idx: int) -> None: del self._received_steps[idx] +class MultiRunFileSystemTrainingBatchSender(TrainingBatchSender): + """Filesystem-based sender that routes batches to per-run directories.""" + + def __init__(self, output_dir: Path, run_names: list[str]): + super().__init__(output_dir) + self.senders = {name: FileSystemTrainingBatchSender(output_dir / name) for name in run_names} + + def send_to_run(self, run_name: str, batch: TrainingBatch) -> None: + self.senders[run_name].send(batch) + + def send(self, batch: TrainingBatch) -> None: + raise NotImplementedError("Use send_to_run() for multi-run sender") + + class FileSystemMicroBatchSender(MicroBatchSender): """Filesystem-based micro batch sender that writes micro batches to disk.""" diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 8579112fa4..898b2c0718 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -23,6 +23,8 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr routed_experts: list[list[list[int]]] | None = None # [seq_len, layers, topk] + actor_id: str | None = None + class TrainingBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): """A batch of training examples with metadata for transport."""