diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e0de85b98..867050082f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs. +- **`AdvantageInputs` / `AdvantageOutputs` are now per-group, and `AdvantageOutputs.advantages` is a plain `list[float]`** (second breaking change to this API in three weeks). `AdvantageInputs.rollouts` is now `list[vf.RolloutOutput]` (a single group) instead of `list[list[vf.RolloutOutput]]`, and `AdvantageOutputs.advantages` is now `list[float]` instead of a 2D `Float[Tensor, "num_examples rollouts_per_example"]`. `compute_advantages` calls `advantage_fn` once per group, which lets partial-group training (groups smaller than `rollouts_per_example` after rollout errors) round-trip without the previous bucket-by-size workaround. Custom advantage functions must drop the outer list dimension and return a list of floats — e.g. `AdvantageOutputs(advantages=(rewards - rewards.mean(dim=1, keepdim=True)).tolist())` becomes `AdvantageOutputs(advantages=[r - mean for r in rewards])` (or `.tolist()` if you keep torch internally). (2026-05-22) - **`orchestrator.advantage.length_penalty` → discriminated sub-config**: The scalar `length_penalty: Literal["tokens","turns"] | None` is replaced by a `LengthPenaltyConfig | None` discriminated on `type`. Token shaping now takes weighted completion + tool-response token costs. Migration: `length_penalty = "tokens"` becomes `[orchestrator.advantage.length_penalty]\ntype = "tokens"` (default weights `completion_weight = 1.0`, `tool_response_weight = 1.0` — total context). `length_penalty = "turns"` becomes `[orchestrator.advantage.length_penalty]\ntype = "turns"`. (2026-05-06) - **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01) - **`AdvantageInputs` API**: Replaced the `rewards`/`completion_lengths`/`num_turns` tensor fields with a single `rollouts: list[list[vf.RolloutOutput]]` (grouped by problem). Custom advantage functions can now access any rollout metadata. Existing custom advantages must update their signatures and extract per-rollout fields directly (e.g. `torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])`). (2026-05-01) diff --git a/docs/bring-your-own-algorithms.md b/docs/bring-your-own-algorithms.md index 816d9c94c4..fba1b1072f 100644 --- a/docs/bring-your-own-algorithms.md +++ b/docs/bring-your-own-algorithms.md @@ -70,7 +70,7 @@ kwargs = { clip_eps = 0.2 } ## 2. Custom Advantage Functions -Advantages are computed **per-example** (grouped by `rollouts_per_example`). You provide a function that computes advantages for a batch of examples. +Advantages are computed **per-group** (one example × N rollouts). You provide a function that computes advantages for a single group; the framework calls it once per group and stitches the results back together. Groups may have fewer than `rollouts_per_example` rollouts when some rollouts in the group errored (partial-group training). ### Interface @@ -86,8 +86,8 @@ def my_custom_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs: ```python @dataclass class AdvantageInputs: - # Rollouts grouped by problem: rollouts[i][j] is the j-th rollout for problem i. - rollouts: list[list[vf.RolloutOutput]] + # All rollouts for a single example (one group). + rollouts: list[vf.RolloutOutput] ``` Each `vf.RolloutOutput` carries the full rollout (`reward`, `trajectory`, etc.), so custom advantages can read any metadata they need (e.g. completion-token counts, turn counts, tool calls). @@ -97,22 +97,21 @@ Each `vf.RolloutOutput` carries the full rollout (`reward`, `trajectory`, etc.), ```python @dataclass class AdvantageOutputs: - advantages: Float[Tensor, "num_examples rollouts_per_example"] + advantages: list[float] # one entry per rollout in the input group ``` ### Example: Normalized Advantage ```python -import torch +import statistics from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs def normalized_advantage(inputs: AdvantageInputs, eps: float = 1e-8) -> AdvantageOutputs: - """Normalize advantages to zero mean and unit variance per example.""" - rewards = torch.tensor([[r["reward"] for r in group] for group in inputs.rollouts]) - mean = rewards.mean(dim=1, keepdim=True) - std = rewards.std(dim=1, keepdim=True) - advantages = (rewards - mean) / (std + eps) - return AdvantageOutputs(advantages=advantages) + """Normalize advantages to zero mean and unit variance within the group.""" + rewards = [r["reward"] for r in inputs.rollouts] + mean = statistics.fmean(rewards) + std = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 + return AdvantageOutputs(advantages=[(r - mean) / (std + eps) for r in rewards]) ``` ### Configuration diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 8feba60128..051d8847d5 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -337,6 +337,9 @@ class EnvConfig(BaseConfig): ), ] = None + state_columns: list[str] = [] + """Extra ``State`` fields to persist into the saved rollout records (in addition to the always-saved ``trajectory`` and ``sampling_args``). Values must be JSON-serializable.""" + @property def stripped_id(self) -> str: """Environment ID without the @version suffix.""" diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index a076d1e29c..fa506b23f2 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -767,9 +767,16 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig): ] +class TokenExportConfig(BaseConfig): + """Configures per-token rollout exports from the RL trainer.""" + + class TrainerExperimentalConfig(BaseConfig): """Experimental features for the trainer.""" + token_export: TokenExportConfig | None = None + """Opt-in per-token JSONL export for rollout debugging. When enabled, writes token ids and aligned trainer metrics after each forward pass.""" + class TrainerConfig(BaseConfig): """Configures the RL trainer""" diff --git a/skills/config/SKILL.md b/skills/config/SKILL.md index e8dc13216c..f6ffe75b91 100644 --- a/skills/config/SKILL.md +++ b/skills/config/SKILL.md @@ -157,6 +157,16 @@ If you wish to configure values of the default variant, you don't need to set th For hosted multi-tenant runs where the trainer image's `trainer.loss.type` is fixed, the orchestrator exposes a per-run override that forces SFT loss on every micro-batch without rebuilding the trainer. Set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; both must be configured together (the orchestrator validator enforces this). The orchestrator stamps each `TrainingSample.sft_loss = True`, which the trainer's `compute_loss` honors by dispatching to `sft_loss_fn` per batch — independent of the trainer's configured default loss. +### RL trainer token exports + +For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_/rank_.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer. + +```toml +[trainer.experimental.token_export] +``` + +Leave it unset for normal training. When enabled, it exports every sequence from each exporting rank. + ### Model fields For `BaseModel | None` fields (like `[ckpt]`, `[wandb]`, `[compile]`), a bare flag enables them with defaults: diff --git a/skills/monitor-run/SKILL.md b/skills/monitor-run/SKILL.md index 2a8404dffb..1a7bbb8006 100644 --- a/skills/monitor-run/SKILL.md +++ b/skills/monitor-run/SKILL.md @@ -143,8 +143,10 @@ These tell you whether training is healthy or diverging. | Metric | Source | Description | |--------|--------|-------------| -| `mismatch_kl/mean` | trainer | KL divergence between trainer and (old) inference policy | -| `entropy/mean` | trainer | policy entropy | +| `mismatch_kl/{all,env}/{mean,std,max}` | trainer | KL divergence between trainer and (old) inference policy over trainable tokens (W&B) | +| `entropy/{all,env}/{mean,std,max}` | trainer | policy entropy over trainable tokens (W&B) | +| `masked_advantage_positive/mean` | trainer | fraction of DPPO-masked trainable tokens with positive advantage (W&B) | +| `masked_advantage_negative/mean` | trainer | fraction of DPPO-masked trainable tokens with negative advantage (W&B) | | `optim/grad_norm` | trainer | gradient norm — spikes may precede divergence | #### Performance @@ -251,4 +253,3 @@ PRIME-RL::Launcher ``` For multi-node runs, trainer and inference processes are distributed across separate nodes. Use `srun` or `ssh` to inspect processes on other nodes directly. - diff --git a/src/prime_rl/orchestrator/advantage.py b/src/prime_rl/orchestrator/advantage.py index 63b1d50325..c60a99aa13 100644 --- a/src/prime_rl/orchestrator/advantage.py +++ b/src/prime_rl/orchestrator/advantage.py @@ -1,3 +1,4 @@ +from collections import defaultdict from dataclasses import dataclass from typing import Callable @@ -19,20 +20,16 @@ @dataclass class AdvantageInputs: - """Inputs for advantage computation. + """Inputs for advantage computation of a single group (one example × N rollouts).""" - `rollouts` is grouped by problem: `rollouts[i][j]` is the j-th rollout for problem i, - so `len(rollouts) == num_problems` and `len(rollouts[0]) == rollouts_per_example`. - """ - - rollouts: list[list[vf.RolloutOutput]] + rollouts: list[vf.RolloutOutput] @dataclass class AdvantageOutputs: - """Outputs from advantage computation.""" + """Outputs from advantage computation of a single group.""" - advantages: Float[Tensor, "num_problems rollouts_per_example"] + advantages: list[float] AdvantageFn = Callable[..., AdvantageOutputs] @@ -41,6 +38,9 @@ class AdvantageOutputs: Expected signature: def my_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs: ... + +The function receives a single group and returns a list of advantages with one +entry per rollout. `compute_advantages` calls it once per group. """ @@ -48,73 +48,63 @@ def default_advantage_fn( inputs: AdvantageInputs, length_penalty: LengthPenaltyConfig | None = None, ) -> AdvantageOutputs: - """Default GRPO advantage: reward minus per-problem baseline. + """Default GRPO advantage for a single group: reward minus per-group baseline. `length_penalty` enables correctness-gated efficiency shaping over a per-rollout cost: tokens (weighted completion + tool-response) or trajectory turn count. """ - rewards = torch.tensor([[r["reward"] for r in group] for group in inputs.rollouts], dtype=torch.float32) + rewards = torch.tensor([r["reward"] for r in inputs.rollouts], dtype=torch.float32) if isinstance(length_penalty, TokensLengthPenaltyConfig): w_c = length_penalty.completion_weight w_t = length_penalty.tool_response_weight costs = torch.tensor( - [ - [w_c * get_model_completion_len(r) + w_t * get_tool_response_len(r) for r in group] - for group in inputs.rollouts - ], + [w_c * get_model_completion_len(r) + w_t * get_tool_response_len(r) for r in inputs.rollouts], dtype=rewards.dtype, ) - return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs)) + return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist()) if isinstance(length_penalty, TurnsLengthPenaltyConfig): - costs = torch.tensor( - [[get_num_turns(r) for r in group] for group in inputs.rollouts], - dtype=rewards.dtype, - ) - return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs)) + costs = torch.tensor([get_num_turns(r) for r in inputs.rollouts], dtype=rewards.dtype) + return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist()) - baseline = rewards.mean(dim=1, keepdim=True) - return AdvantageOutputs(advantages=rewards - baseline) + return AdvantageOutputs(advantages=(rewards - rewards.mean()).tolist()) def _efficiency_shaping( - rewards: Float[Tensor, "num_problems rollouts_per_example"], - costs: Float[Tensor, "num_problems rollouts_per_example"], -) -> Float[Tensor, "num_problems rollouts_per_example"]: + rewards: Float[Tensor, "group_size"], + costs: Float[Tensor, "group_size"], +) -> Float[Tensor, "group_size"]: """Correctness-gated efficiency shaping with bounded advantages. Shapes rewards with a bounded efficiency bonus before standard GRPO subtraction, - preserving zero-mean advantages per group. `costs` is a per-rollout cost (e.g., - completion length in tokens or number of turns). + preserving zero-mean advantages within the group. `costs` is a per-rollout cost + (e.g., completion length in tokens or number of turns). Correct rollouts get reward amplified by up to 2x based on relative efficiency. Incorrect rollouts are untouched. Lower-cost correct rollouts get higher advantage. """ - max_reward = rewards.max(dim=1, keepdim=True).values + max_reward = rewards.max() correct_mask = rewards >= max_reward - num_correct = correct_mask.sum(dim=1, keepdim=True) + num_correct = correct_mask.sum() # No shaping when max reward is 0 — no correct rollouts to differentiate - has_correct = max_reward > 0 + if max_reward <= 0: + return rewards - rewards.mean() - # Mean cost of correct rollouts per problem - correct_costs = costs * correct_mask - mean_correct_cost = correct_costs.sum(dim=1, keepdim=True) / num_correct.clamp(min=1) + # Mean cost of correct rollouts + mean_correct_cost = (costs * correct_mask).sum() / num_correct.clamp(min=1) # Bounded efficiency bonus: [0, 1], positive for below-average cost, zero for above. # When mean_correct_cost is 0 (e.g. tool-only shaping with no harness metric, or # all-zero turn counts), no rollouts can be differentiated — fall back to no bonus. - has_cost = mean_correct_cost > 0 - safe_mean = torch.where(has_cost, mean_correct_cost, torch.ones_like(mean_correct_cost)) - bonus = (1 - costs / safe_mean).clamp(0, 1) * has_cost + if mean_correct_cost <= 0: + return rewards - rewards.mean() + + bonus = (1 - costs / mean_correct_cost).clamp(0, 1) # Shape rewards: correct rollouts amplified by up to 2x, incorrect untouched shaped_rewards = rewards * (1 + bonus * correct_mask) - baseline = shaped_rewards.mean(dim=1, keepdim=True) - - shaped = shaped_rewards - baseline - unshaped = rewards - rewards.mean(dim=1, keepdim=True) - return torch.where(has_correct, shaped, unshaped) + return shaped_rewards - shaped_rewards.mean() def setup_advantage_fn(config: AdvantageConfig) -> AdvantageFn: @@ -136,31 +126,26 @@ def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs: def compute_advantages( rollouts: list[vf.RolloutOutput], - samples_per_problem: int, advantage_config: AdvantageConfig | None, ) -> None: - """ - Computes advantages from rollouts, grouped by problem. - Stores advantages in-place on the rollouts. + """Computes advantages from rollouts, grouped by (env_name, example_id), and + stores them in-place on the rollouts. - Args: - rollouts: List of rollouts to store advantages on - samples_per_problem: Number of samples (and thus, rewards) per problem - advantage_config: Configuration for advantage computation (DefaultAdvantageConfig or CustomAdvantageConfig) + `advantage_fn` is called once per group, so groups may have varying sizes + (partial-group training drops failed rollouts rather than rescheduling them). """ - rewards = [r["reward"] for r in rollouts] - if not advantage_config: - for rollout, reward in zip(rollouts, rewards): - rollout["advantage"] = reward + for rollout in rollouts: + rollout["advantage"] = rollout["reward"] return advantage_fn = setup_advantage_fn(advantage_config) - grouped = [rollouts[i : i + samples_per_problem] for i in range(0, len(rollouts), samples_per_problem)] - inputs = AdvantageInputs(rollouts=grouped) - result = advantage_fn(inputs) - advantages = result.advantages.flatten().tolist() + groups_by_example: dict[tuple[str, int], list[vf.RolloutOutput]] = defaultdict(list) + for rollout in rollouts: + groups_by_example[(rollout["env_name"], rollout["example_id"])].append(rollout) - for rollout, advantage in zip(rollouts, advantages): - rollout["advantage"] = advantage + for group in groups_by_example.values(): + result = advantage_fn(AdvantageInputs(rollouts=group)) + for rollout, advantage in zip(group, result.advantages): + rollout["advantage"] = advantage diff --git a/src/prime_rl/orchestrator/envs.py b/src/prime_rl/orchestrator/envs.py index 1de52994a9..c7ac150aa6 100644 --- a/src/prime_rl/orchestrator/envs.py +++ b/src/prime_rl/orchestrator/envs.py @@ -110,6 +110,15 @@ def _sampling_args_with_salt(self, cache_salt: str) -> dict: sampling_args["extra_body"] = extra_body return sampling_args + @property + def state_columns(self) -> list[str]: + """Required columns plus any extras configured on the env, deduped (required first).""" + merged: list[str] = [] + for col in (*REQUIRED_STATE_COLUMNS, *self.config.state_columns): + if col not in merged: + merged.append(col) + return merged + async def run_rollout( self, client: vf.ClientConfig, @@ -124,7 +133,7 @@ async def run_rollout( model=model_name, sampling_args=self._sampling_args_with_salt(cache_salt), max_retries=self.config.max_retries, - state_columns=REQUIRED_STATE_COLUMNS, + state_columns=self.state_columns, env_client=self.env_client, ) @@ -143,7 +152,7 @@ async def run_group( model=model_name, sampling_args=self._sampling_args_with_salt(cache_salt), max_retries=self.config.max_retries, - state_columns=REQUIRED_STATE_COLUMNS, + state_columns=self.state_columns, env_client=self.env_client, ) diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index e72a040c33..acf3714f3c 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -429,7 +429,7 @@ async def orchestrate(config: OrchestratorConfig): # Compute advantages (in-place) num_rollouts = len(train_rollouts) num_unique_examples = len({(r["env_name"], r["example_id"]) for r in train_rollouts}) - compute_advantages(train_rollouts, config.rollouts_per_example, config.advantage) + compute_advantages(train_rollouts, config.advantage) # Apply rollout filters — sets rollout["filters"] and rollout["is_filtered"] apply_filters(rollout_filters, train_rollouts) @@ -553,6 +553,7 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin for sample in samples: sample.advantage = rollout["advantage"] sample.reward = rollout["reward"] + sample.env_name = rollout["env_name"] if config.use_sft_loss: sample.sft_loss = True sample_decode_tokens = sum(sample.completion_mask) diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index c266757c1c..66e01329bb 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -32,8 +32,6 @@ class InflightRequest: env_name: str group_id: int | None = None rollout_count: int = 1 - # Dispatch round the request belongs to (see GroupState.current_round). - round_id: int = 0 @dataclass @@ -44,19 +42,7 @@ class GroupState: rollouts_to_schedule: int completed_rollouts: list[vf.RolloutOutput] = field(default_factory=list) pinned_client: vf.ClientConfig | None = None - # Number of dispatch rounds in which at least one rollout returned errored - # or empty trajectories. Compared against - # config.max_error_reschedule_attempts to decide when to drop a - # permanently-stuck group. Counts rounds, not rollouts: a failed round in - # an individual-scoring env that happens to dispatch N rollouts at once - # still only counts as 1. - failed_attempts: int = 0 - # Round id assigned to newly-dispatched rollouts. Advances after a failure - # is counted so the resulting reschedule starts a new round. - current_round: int = 0 - # Highest round already counted as failed; used to dedupe failures from - # multiple rollouts in the same round. - last_failed_round: int = -1 + failed_rollouts: int = 0 class Scheduler: @@ -128,6 +114,7 @@ def __init__( self.cancelled_rollouts_count = 0 self.empty_rollouts_by_env: dict[str, int] = defaultdict(int) self.errored_rollouts_by_env: dict[str, int] = defaultdict(int) + self.errors_by_type: dict[str, int] = defaultdict(int) self.total_rollouts_by_env: dict[str, int] = defaultdict(int) self.dropped_groups_by_env: dict[str, int] = defaultdict(int) self.last_batch_generation_time = 0.0 @@ -245,7 +232,6 @@ async def schedule_rollout(self, group_id: int): env_name=env_name, group_id=group_id, rollout_count=rollout_count, - round_id=group.current_round, ) @property @@ -448,69 +434,65 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: rollouts: list[vf.RolloutOutput] = result if isinstance(result, list) else [result] self.total_rollouts_by_env[env_name] += len(rollouts) - # Check for empty/errored rollouts and reschedule - valid_rollouts = [] - has_failures = False - last_failure_reason: str | None = None + # Partition rollouts into valid vs failed and tally per-rollout + # error metrics. Tally every failure (group-scoring envs return + # N rollouts per task) so error-rate metrics aren't deflated. + valid_rollouts: list[vf.RolloutOutput] = [] for rollout in rollouts: if rollout["error"] is not None: self.errored_rollouts_by_env[env_name] += 1 - has_failures = True - last_failure_reason = rollout["error"]["error_chain_repr"] + self.errors_by_type[rollout["error"]["error"]] += 1 self.logger.warning( - f"Rollout error in group {group_id} ({env_name}), re-scheduling " - f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete): " - f"{last_failure_reason}" + f"Rollout failed in group {group_id} ({env_name}) - " + f"{rollout['error']['error_chain_repr']}" ) elif len(rollout["trajectory"]) == 0: self.empty_rollouts_by_env[env_name] += 1 - has_failures = True - last_failure_reason = "empty trajectory" - self.logger.warning( - f"Empty trajectory in group {group_id} ({env_name}), re-scheduling " - f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete)" - ) + self.logger.warning(f"Empty trajectory in group {group_id} ({env_name})") else: rollout["env_name"] = env_name valid_rollouts.append(rollout) - if has_failures: - # Dedupe failures within the same dispatch round: an - # individual-scoring env dispatches N rollouts at once, - # so a single failed round can produce up to N failed - # tasks. We only count the round once. - if rollout_info.round_id > group.last_failed_round: - group.failed_attempts += 1 - group.last_failed_round = rollout_info.round_id - group.current_round = rollout_info.round_id + 1 - max_attempts = self.config.max_error_reschedule_attempts - if max_attempts is not None and group.failed_attempts >= max_attempts: - # Permanently-stuck group: drop it from this step and let the - # rest of the batch proceed. Avoids a single bad example (e.g. - # an agent rollout whose sandbox poll keeps timing out) - # blocking step progress forever. - self.dropped_groups_by_env[env_name] += 1 - self.logger.warning( - f"Dropping group {group_id} ({env_name}) after {group.failed_attempts} " - f"failed dispatch rounds ({len(group.completed_rollouts)}/{self.rollouts_per_example} " - f"complete). Last failure: {last_failure_reason}. Set " - f"orchestrator.max_error_reschedule_attempts higher (or to None) " - f"to retry more aggressively." - ) - await self.drop_group(group_id) - continue - - if has_failures and env.requires_group_scoring: - # Group scoring requires all rollouts — discard partial results, reschedule full group - group.completed_rollouts.clear() - group.rollouts_to_schedule = self.rollouts_per_example + num_failed = len(rollouts) - len(valid_rollouts) + group.failed_rollouts += num_failed + + # Group-scoring envs compute scores over all N rollouts + # together; the surviving rollouts carry scores computed against + # the (now-missing) failed ones, so partial salvage is unsafe. + # Drop the whole group on any failure. + if num_failed > 0 and env.requires_group_scoring: + self.dropped_groups_by_env[env_name] += 1 + self.logger.warning( + f"Dropping group-scored group {group_id} ({env_name}) after rollout failure" + ) + await self.drop_group(group_id) continue - # For individual scoring, reschedule only the failed ones - group.rollouts_to_schedule += len(rollouts) - len(valid_rollouts) group.completed_rollouts.extend(valid_rollouts) - if len(group.completed_rollouts) < self.rollouts_per_example: + + # Wait until every dispatched rollout has come back (succeeded + # or failed) before finalizing. The group may finalize as a + # partial group (< rollouts_per_example) when some rollouts + # errored - downstream advantage computation groups by + # (env_name, example_id), so variable-size groups are fine. + if len(group.completed_rollouts) + group.failed_rollouts < self.rollouts_per_example: + continue + + if not group.completed_rollouts: + self.dropped_groups_by_env[env_name] += 1 + self.logger.warning( + f"Dropping group {group_id} ({env_name}) - all {self.rollouts_per_example} rollouts failed" + ) + self.groups.pop(group_id, None) continue + + if group.failed_rollouts > 0: + self.logger.warning( + f"Partial group {group_id} ({env_name}) - " + f"{len(group.completed_rollouts)}/{self.rollouts_per_example} valid " + f"({group.failed_rollouts} failed)" + ) + completed_rollouts = self.groups.pop(group_id).completed_rollouts except asyncio.CancelledError: @@ -524,7 +506,7 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: continue self.buffer.update(completed_rollouts) - accepted_rollouts = self.buffer.sample_rollouts(n=self.rollouts_per_example) + accepted_rollouts = self.buffer.sample_rollouts(n=len(completed_rollouts)) batch_rollouts.extend(accepted_rollouts) progress_increment = self.get_batch_progress_increment(accepted_rollouts) @@ -586,6 +568,8 @@ def get_metrics(self) -> dict[str, float]: metrics[f"errored_rollouts/{env_name}"] = self.errored_rollouts_by_env.get(env_name, 0) / env_total for env_name, count in self.dropped_groups_by_env.items(): metrics[f"dropped_groups/{env_name}"] = count + for error_type, count in self.errors_by_type.items(): + metrics[f"error/{error_type}/count"] = count by_env: dict[str, list[int]] = {} for info in self.inflight_requests.values(): by_env.setdefault(info.env_name, []).append(info.off_policy_steps) @@ -595,6 +579,7 @@ def get_metrics(self) -> dict[str, float]: self.cancelled_rollouts_count = 0 self.empty_rollouts_by_env.clear() self.errored_rollouts_by_env.clear() + self.errors_by_type.clear() self.total_rollouts_by_env.clear() self.dropped_groups_by_env.clear() diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 4cd6f5643c..7590ca04cd 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -369,6 +369,7 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: completion_temperatures=[temperature] * len(completion_ids), teacher_logprobs=None, advantage=None, + env_name=output["env_name"], routed_experts=_pack_routed_experts(routed_experts), mm_token_type_ids=None, ) diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index ca248a43d4..71e54612a4 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -58,8 +58,12 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch loss_mask = training_example.prompt_mask + training_example.completion_mask inference_logprobs = [0.0] * len(training_example.prompt_ids) + training_example.completion_logprobs advantages = [training_example.advantage] * len(input_ids) + reward = training_example.reward if training_example.reward is not None else float("nan") + rewards = [reward] * len(input_ids) position_ids = list(range(len(input_ids))) mm_token_type_ids = training_example.mm_token_type_ids + assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys" + env_names = [training_example.env_name] * len(input_ids) # Per-token temperatures: prompt tokens use first completion temp (masked out anyway) # Default to 1.0 if completion is empty (e.g., model generated only tool calls with no text) @@ -79,6 +83,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch inference_logprobs = inference_logprobs[:seq_len] position_ids = position_ids[:seq_len] advantages = advantages[:seq_len] + rewards = rewards[:seq_len] temperatures = temperatures[:seq_len] if teacher_logprobs is not None: teacher_logprobs = teacher_logprobs[:seq_len] @@ -86,6 +91,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch routed_experts = _slice_routed_experts(routed_experts, seq_len) if mm_token_type_ids is not None: mm_token_type_ids = mm_token_type_ids[:seq_len] + env_names = env_names[:seq_len] assert ( len(input_ids) @@ -93,9 +99,10 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch == len(loss_mask) == len(position_ids) == len(inference_logprobs) + == len(rewards) == len(temperatures) ), ( - f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, temperatures: {len(temperatures)}" + f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}, temperatures: {len(temperatures)}" ) if teacher_logprobs is not None: assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}" @@ -110,6 +117,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch assert len(mm_token_type_ids) == len(input_ids), ( f"mm_token_type_ids: {len(mm_token_type_ids)}, input_ids: {len(input_ids)}" ) + assert len(env_names) == len(input_ids), f"env_names: {len(env_names)}, input_ids: {len(input_ids)}" return MicroBatch( input_ids=input_ids, @@ -119,8 +127,10 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch inference_logprobs=inference_logprobs, teacher_logprobs=teacher_logprobs, temperatures=temperatures, + rewards=rewards, routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, + env_names=env_names, # Multimodal fields (Qwen3-VL) - passed through without modification pixel_values=training_example.pixel_values, pixel_values_shape=training_example.pixel_values_shape, @@ -169,9 +179,16 @@ def packed_samples_into_micro_bs( len(bin_content.input_ids) + len(sample.input_ids) <= max_seq_len and bin_content.sft_loss == sample.sft_loss ): + existing_len = len(bin_content.input_ids) bin_content.input_ids.extend(sample.input_ids) bin_content.loss_mask.extend(sample.loss_mask) bin_content.advantages.extend(sample.advantages) + if sample.rewards is not None: + if bin_content.rewards is None: + bin_content.rewards = [float("nan")] * existing_len + bin_content.rewards.extend(sample.rewards) + elif bin_content.rewards is not None: + bin_content.rewards.extend([float("nan")] * len(sample.input_ids)) bin_content.inference_logprobs.extend(sample.inference_logprobs) bin_content.temperatures.extend(sample.temperatures) if sample.teacher_logprobs is not None: @@ -185,6 +202,7 @@ def packed_samples_into_micro_bs( if bin_content.mm_token_type_ids is None: bin_content.mm_token_type_ids = [] bin_content.mm_token_type_ids.extend(sample.mm_token_type_ids) + bin_content.env_names.extend(sample.env_names) bin_content.position_ids.extend(sample.position_ids) bin_content.lora_num_tokens[idx] += len(sample.input_ids) break @@ -209,11 +227,19 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa padding_size = (pad_to_multiple_of - (len(micro_batch.input_ids) % pad_to_multiple_of)) % pad_to_multiple_of + if len(micro_batch.env_names) != len(micro_batch.input_ids): + raise ValueError( + f"MicroBatch.env_names must match input_ids length before padding: " + f"env_names={len(micro_batch.env_names)}, input_ids={len(micro_batch.input_ids)}" + ) + if not (pad_to_multiple_of > 1 and padding_size > 0): return micro_batch micro_batch.input_ids.extend([1] * padding_size) micro_batch.advantages.extend([0.0] * padding_size) + if micro_batch.rewards is not None: + micro_batch.rewards.extend([float("nan")] * padding_size) micro_batch.loss_mask.extend([False] * padding_size) micro_batch.position_ids.extend(list(range(padding_size))) micro_batch.inference_logprobs.extend([0.0] * padding_size) @@ -228,6 +254,7 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa micro_batch.mm_token_type_ids.extend([0] * padding_size) if micro_batch.routed_experts is not None: _pad_routed_experts(micro_batch, padding_size) + micro_batch.env_names.extend([""] * padding_size) return micro_batch diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index cabd126f59..ad1b65d3b7 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -26,10 +26,12 @@ class TensorMicroBatch(TypedDict): input_ids: Int[Tensor, "batch seq"] position_ids: Int[Tensor, "batch seq"] advantages: Float[Tensor, "batch seq"] + rewards: Float[Tensor, "batch seq"] | None inference_logprobs: Float[Tensor, "batch seq"] teacher_logprobs: Float[Tensor, "batch seq"] | None loss_mask: Bool[Tensor, "batch seq"] temperatures: Float[Tensor, "batch seq"] # Per-token temperatures + env_names: list[str] # Batch level lora_num_tokens: Int[Tensor, "n_loras"] @@ -111,9 +113,11 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc "input_ids": input_ids.unsqueeze(0), "position_ids": position_ids.unsqueeze(0), "advantages": advantages.unsqueeze(0), + "rewards": None, "inference_logprobs": inference_logprobs.unsqueeze(0), "teacher_logprobs": None, "temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0), + "env_names": ["fake"] * input_ids.shape[0], "loss_mask": loss_mask.unsqueeze(0), "lora_num_tokens": lora_num_tokens, "routed_experts": None, @@ -138,9 +142,11 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: ), "position_ids": torch.cat([torch.arange(self.seq_len)]).unsqueeze(0), "advantages": torch.randn(self.seq_len, generator=generator).unsqueeze(0), + "rewards": None, "inference_logprobs": torch.randn(self.seq_len, generator=generator).unsqueeze(0), "teacher_logprobs": None, "temperatures": torch.ones(self.seq_len).unsqueeze(0), + "env_names": ["fake"] * self.seq_len, "loss_mask": torch.ones(self.seq_len, dtype=torch.bool).unsqueeze(0), "lora_num_tokens": lora_num_tokens, "routed_experts": None, @@ -217,12 +223,16 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: input_ids=torch.tensor(micro_batch.input_ids, dtype=torch.long).unsqueeze(0), position_ids=torch.tensor(micro_batch.position_ids, dtype=torch.long).unsqueeze(0), advantages=torch.tensor(micro_batch.advantages, dtype=torch.float).unsqueeze(0), + rewards=torch.tensor(micro_batch.rewards, dtype=torch.float).unsqueeze(0) + if micro_batch.rewards is not None + else None, inference_logprobs=torch.tensor(micro_batch.inference_logprobs, dtype=torch.float).unsqueeze(0), teacher_logprobs=torch.tensor(micro_batch.teacher_logprobs, dtype=torch.float).unsqueeze(0) if micro_batch.teacher_logprobs is not None else None, loss_mask=torch.tensor(micro_batch.loss_mask, dtype=torch.bool).unsqueeze(0), temperatures=torch.tensor(micro_batch.temperatures, dtype=torch.float).unsqueeze(0), + env_names=micro_batch.env_names, lora_num_tokens=torch.tensor(micro_batch.lora_num_tokens, dtype=torch.int32), # Multimodal fields - no batch dimension for these as they are variable-sized pixel_values=torch.frombuffer(bytearray(micro_batch.pixel_values), dtype=torch.float32).reshape( diff --git a/src/prime_rl/trainer/rl/loss.py b/src/prime_rl/trainer/rl/loss.py index d67c055a68..dfb13dc9b1 100644 --- a/src/prime_rl/trainer/rl/loss.py +++ b/src/prime_rl/trainer/rl/loss.py @@ -104,6 +104,15 @@ def _safe_mean(values: Tensor, mask: Tensor) -> Tensor: return values[mask].sum() / denom +def compute_importance_ratio_and_mismatch_kl( + trainer_logprobs: Tensor, inference_logprobs: Tensor +) -> tuple[Tensor, Tensor, Tensor]: + log_importance_ratio = trainer_logprobs - inference_logprobs + importance_ratio = torch.exp(log_importance_ratio) + mismatch_kl = importance_ratio - log_importance_ratio - 1 + return log_importance_ratio, importance_ratio, mismatch_kl + + def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs: """ DPPO+KL loss, combining: @@ -122,22 +131,23 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO advantages = inputs.advantages loss_mask = inputs.loss_mask - trainer_probs = torch.exp(trainer_logprobs) - inference_probs = torch.exp(inference_logprobs) - probs_diff = trainer_probs - inference_probs + log_importance_ratio, importance_ratio, mismatch_kl = compute_importance_ratio_and_mismatch_kl( + trainer_logprobs, inference_logprobs + ) + + probs_diff = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs) dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low - dppo_invalid_mask = torch.where(advantages > 0, dppo_invalid_mask_high, dppo_invalid_mask_low) + positive_advantages = advantages > 0 + negative_advantages = advantages < 0 + dppo_invalid_mask = torch.where(positive_advantages, dppo_invalid_mask_high, dppo_invalid_mask_low) is_masked = dppo_invalid_mask - is_masked_high = (advantages > 0) & dppo_invalid_mask_high - is_masked_low = (advantages < 0) & dppo_invalid_mask_low + is_masked_high = positive_advantages & dppo_invalid_mask_high + is_masked_low = negative_advantages & dppo_invalid_mask_low + drop_mask = loss_mask & is_masked keep_mask = loss_mask & ~is_masked - log_importance_ratio = trainer_logprobs - inference_logprobs - importance_ratio = torch.exp(log_importance_ratio) - mismatch_kl = importance_ratio - log_importance_ratio - 1 - advantages = loss_config.adv_tau * advantages if teacher_logprobs is not None: teacher_kl = teacher_logprobs - trainer_logprobs @@ -150,12 +160,13 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossO loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum() metrics = { - "mismatch_kl": _safe_mean(mismatch_kl, loss_mask), # all trainable tokens "masked_mismatch_kl": _safe_mean(mismatch_kl, loss_mask & is_masked), # all trainable, masked tokens "unmasked_mismatch_kl": _safe_mean(mismatch_kl, keep_mask), # all trainable, unmasked tokens "is_masked": _safe_mean(is_masked, loss_mask), "is_masked_low": _safe_mean(is_masked_low, loss_mask), "is_masked_high": _safe_mean(is_masked_high, loss_mask), + "masked_advantage_positive": _safe_mean(positive_advantages, drop_mask), + "masked_advantage_negative": _safe_mean(negative_advantages, drop_mask), } if teacher_kl is not None: metrics["teacher_kl"] = _safe_mean(teacher_kl, loss_mask) @@ -230,7 +241,11 @@ def compute_loss( teacher_logprobs = [None] * len(trainer_logprobs) for t_logp, i_logp, teach_logp, adv, mask in zip( - trainer_logprobs, inference_logprobs, teacher_logprobs, advantages, loss_mask + trainer_logprobs, + inference_logprobs, + teacher_logprobs, + advantages, + loss_mask, ): inputs = LossInputs( trainer_logprobs=t_logp, diff --git a/src/prime_rl/trainer/rl/token_export.py b/src/prime_rl/trainer/rl/token_export.py new file mode 100644 index 0000000000..6a1e7a3239 --- /dev/null +++ b/src/prime_rl/trainer/rl/token_export.py @@ -0,0 +1,243 @@ +import atexit +import json +import math +from collections.abc import Mapping, Sequence +from pathlib import Path +from typing import Any + +import torch +from torch import Tensor + +from prime_rl.configs.trainer import DefaultLossConfig, TrainerConfig +from prime_rl.trainer.rl.loss import compute_importance_ratio_and_mismatch_kl + +SCHEMA_VERSION = 1 + + +class DisabledTokenExporter: + def export(self, *args: Any, **kwargs: Any) -> None: + return + + def close(self) -> None: + return + + +class TokenExporter: + def __init__( + self, + output_dir: Path, + rank: int, + ) -> None: + self.rank = rank + self.output_dir = output_dir / "token_exports" + self._file: Any | None = None + self._closed = False + self._current_step: int | None = None + self._sequences_this_step = 0 + atexit.register(self.close) + + def export( + self, + step: int, + micro_step: int, + micro_batch: Mapping[str, Any], + model_output: Mapping[str, Tensor], + response_lengths: list[int], + loss_config: Any, + ) -> None: + if self._current_step != step: + self._start_step(step) + + columns = _export_columns(micro_batch, model_output, loss_config) + _check_lengths(columns) + + start = 0 + for micro_sequence_idx, length in enumerate(response_lengths): + raw_end = start + length + end = _trim_padding(columns, start, raw_end) + if end > start and any(columns["loss_mask"][start:end]): + self._write( + { + "schema_version": SCHEMA_VERSION, + "step": step, + "rank": self.rank, + "micro_step": micro_step, + "micro_sequence_idx": micro_sequence_idx, + "export_sequence_idx": self._sequences_this_step, + "env_name": _first_non_empty(columns["env_names"][start:end]), + "training_mode": _training_mode(micro_batch), + **_slice_columns(columns, start, end), + } + ) + self._sequences_this_step += 1 + start = raw_end + + def close(self) -> None: + if self._closed: + return + self._closed = True + if self._file is not None: + self._file.close() + self._file = None + + def _start_step(self, step: int) -> None: + if self._closed: + raise RuntimeError(f"Token exporter is closed for {self.output_dir}") + if self._file is not None: + self._file.close() + self._current_step = step + self._sequences_this_step = 0 + step_dir = self.output_dir / f"step_{step}" + step_dir.mkdir(parents=True, exist_ok=True) + self._file = (step_dir / f"rank_{self.rank}.jsonl").open("w", encoding="utf-8") + + def _write(self, record: dict[str, Any]) -> None: + if self._closed: + raise RuntimeError(f"Token exporter is closed for {self.output_dir}") + if self._file is None: + raise RuntimeError("Token exporter has no active step file") + self._file.write(json.dumps(record, separators=(",", ":"), allow_nan=False) + "\n") + + +def setup_token_exporter( + config: TrainerConfig, parallel_dims: Any, world: Any, logger: Any +) -> TokenExporter | DisabledTokenExporter: + token_export_config = config.experimental.token_export + if token_export_config is None: + return DisabledTokenExporter() + if parallel_dims.cp_enabled and parallel_dims.world_mesh["cp"].get_local_rank() != 0: + return DisabledTokenExporter() + + exporter = TokenExporter(config.output_dir, world.rank) + logger.info(f"Writing token exports under {exporter.output_dir}") + return exporter + + +def _export_columns( + micro_batch: Mapping[str, Any], model_output: Mapping[str, Tensor], loss_config: Any +) -> dict[str, list[Any]]: + token_ids = _tensor_to_ints(micro_batch["input_ids"]) + seq_len = len(token_ids) + trainer_logprobs = model_output["logprobs"] + export_tensors = _compute_export_tensors(micro_batch, trainer_logprobs, loss_config) + + return { + "token_ids": token_ids, + "position_ids": _tensor_to_ints(micro_batch["position_ids"]), + "loss_mask": _tensor_to_bools(micro_batch["loss_mask"]), + "advantages": _tensor_to_floats(micro_batch["advantages"]), + "rewards": _optional_tensor_to_floats(micro_batch.get("rewards"), seq_len), + "inference_logprobs": _tensor_to_floats(micro_batch["inference_logprobs"]), + "trainer_logprobs": _tensor_to_floats(trainer_logprobs), + "entropy": _tensor_to_floats(model_output["entropy"]), + "mismatch_kl": _optional_tensor_to_floats(export_tensors["mismatch_kl"], seq_len), + "log_importance_ratio": _optional_tensor_to_floats(export_tensors["log_importance_ratio"], seq_len), + "importance_ratio": _optional_tensor_to_floats(export_tensors["importance_ratio"], seq_len), + "prob_delta": _optional_tensor_to_floats(export_tensors["prob_delta"], seq_len), + "is_masked": _optional_tensor_to_bools(export_tensors["is_masked"], seq_len), + "is_masked_high": _optional_tensor_to_bools(export_tensors["is_masked_high"], seq_len), + "is_masked_low": _optional_tensor_to_bools(export_tensors["is_masked_low"], seq_len), + "env_names": list(micro_batch["env_names"]), + } + + +def _compute_export_tensors( + micro_batch: Mapping[str, Any], trainer_logprobs: Tensor, loss_config: Any +) -> dict[str, Tensor | None]: + fields: dict[str, Tensor | None] = { + "log_importance_ratio": None, + "importance_ratio": None, + "mismatch_kl": None, + "prob_delta": None, + "is_masked": None, + "is_masked_high": None, + "is_masked_low": None, + } + if _training_mode(micro_batch) == "sft": + return fields + + inference_logprobs = micro_batch["inference_logprobs"].to(trainer_logprobs.device) + loss_mask = micro_batch["loss_mask"].to(trainer_logprobs.device) + advantages = micro_batch["advantages"].to(trainer_logprobs.device) + with torch.no_grad(): + log_ratio, ratio, mismatch_kl = compute_importance_ratio_and_mismatch_kl(trainer_logprobs, inference_logprobs) + prob_delta = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs) + fields["log_importance_ratio"] = log_ratio + fields["importance_ratio"] = ratio + fields["mismatch_kl"] = mismatch_kl + fields["prob_delta"] = prob_delta + if isinstance(loss_config, DefaultLossConfig): + invalid_high = prob_delta > loss_config.dppo_mask_high + invalid_low = prob_delta < -loss_config.dppo_mask_low + positive_advantages = advantages > 0 + negative_advantages = advantages < 0 + invalid = torch.where(positive_advantages, invalid_high, invalid_low) + fields["is_masked"] = loss_mask & invalid + fields["is_masked_high"] = loss_mask & positive_advantages & invalid_high + fields["is_masked_low"] = loss_mask & negative_advantages & invalid_low + return fields + + +def _training_mode(micro_batch: Mapping[str, Any]) -> str: + if micro_batch.get("sft_loss", False): + return "sft" + return str(micro_batch.get("training_mode", "rl")) + + +def _tensor_to_ints(tensor: Tensor) -> list[int]: + return [int(value) for value in tensor.detach().cpu().reshape(-1).tolist()] + + +def _tensor_to_bools(tensor: Tensor) -> list[bool]: + return [bool(value) for value in tensor.detach().cpu().reshape(-1).tolist()] + + +def _tensor_to_floats(tensor: Tensor) -> list[float | None]: + values = tensor.detach().to(dtype=torch.float32, device="cpu").reshape(-1).tolist() + return [_json_float(value) for value in values] + + +def _optional_tensor_to_floats(tensor: Tensor | None, seq_len: int) -> list[float | None]: + if tensor is None: + return [None] * seq_len + return _tensor_to_floats(tensor) + + +def _optional_tensor_to_bools(tensor: Tensor | None, seq_len: int) -> list[bool | None]: + if tensor is None: + return [None] * seq_len + return _tensor_to_bools(tensor) + + +def _check_lengths(columns: Mapping[str, Sequence[Any]]) -> None: + lengths = {key: len(values) for key, values in columns.items()} + if len(set(lengths.values())) != 1: + raise ValueError(f"Token export fields must have aligned lengths, got {lengths}") + + +def _slice_columns(columns: Mapping[str, Sequence[Any]], start: int, end: int) -> dict[str, list[Any]]: + return {key: list(values[start:end]) for key, values in columns.items() if key != "env_names"} + + +def _trim_padding(columns: Mapping[str, Sequence[Any]], start: int, end: int) -> int: + env_names = columns["env_names"] + loss_mask = columns["loss_mask"] + while end > start and env_names[end - 1] == "" and not loss_mask[end - 1]: + end -= 1 + return end + + +def _json_float(value: float | None) -> float | None: + if value is None: + return None + value = float(value) + if not math.isfinite(value): + return None + return value + + +def _first_non_empty(values: Sequence[str]) -> str | None: + for value in values: + if value: + return value + return None diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 4b2b932297..29b0e32659 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -29,11 +29,13 @@ from prime_rl.trainer.rl.loss import ( compute_entropy, compute_loss, + compute_importance_ratio_and_mismatch_kl, selective_log_softmax, setup_loss_fn, shift_tensor_left, shift_tensor_right, ) +from prime_rl.trainer.rl.token_export import setup_token_exporter from prime_rl.trainer.model import ( forward, setup_tokenizer, @@ -238,6 +240,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: config.rollout_transport, ) + token_exporter = setup_token_exporter(config, parallel_dims, world, logger) + gc_handler = GarbageCollection(config.gc.interval) if config.gc else None logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") @@ -337,9 +341,15 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: forward_backward_start_time = time.perf_counter() seq_len = micro_batches[0]["input_ids"].shape[1] - # Normalize by the local number of unmasked tokens in the batch (per-batch length normalization) - loss_scale = sum(micro_batch["loss_mask"].sum().item() for micro_batch in micro_batches) - loss_scale = max(loss_scale, 1) + # Normalize by the global (dp_cp) number of unmasked tokens in the batch, so every rank + # divides by the same denominator. With a per-rank denominator, ranks with fewer loss + # tokens implicitly upweight their per-token gradient contribution after FSDP averaging. + # FSDP's per-rank divide is undone after the microbatch loop via fsdp_gradient_divide_factor. + local_loss_scale = sum(micro_batch["loss_mask"].sum().item() for micro_batch in micro_batches) + global_loss_scale = torch.tensor(local_loss_scale, dtype=torch.int64, device="cuda") + dp_cp_group = parallel_dims.get_mesh("dp_cp").get_group() + dist.all_reduce(global_loss_scale, op=dist.ReduceOp.SUM, group=dp_cp_group) + loss_scale = max(global_loss_scale.item(), 1) logger.debug(f"Starting forward and backward pass ({batch_size=})") tensors = Tensors() # Used to accumulate tensor statistics across micro-batches and ranks for logging @@ -474,9 +484,29 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: loss.backward() # Add relevant tensors to tensor dict for logging purposes - tensors["entropy"].append(out["entropy"][loss_mask].detach().to("cpu")) + entropy = out["entropy"][loss_mask].detach().to("cpu") + tensors["entropy/all"].append(entropy) tensors["loss"].append(loss.detach().to("cpu").unsqueeze(0)) + env_names = micro_batch["env_names"] + masked_env_names = [env_name for env_name, keep in zip(env_names, loss_mask.flatten().tolist()) if keep] + env_to_indices: dict[str, list[int]] = {} + for idx, env_name in enumerate(masked_env_names): + env_to_indices.setdefault(env_name, []).append(idx) + + for env_name, indices in env_to_indices.items(): + tensors[f"entropy/{env_name}"].append(entropy[indices]) + + if not micro_batch["sft_loss"]: + with torch.no_grad(): + _, _, mismatch_kl = compute_importance_ratio_and_mismatch_kl(out["logprobs"], inference_logprobs) + mismatch_kl = mismatch_kl[loss_mask].detach().to("cpu") + tensors["mismatch_kl/all"].append(mismatch_kl) + for env_name, indices in env_to_indices.items(): + tensors[f"mismatch_kl/{env_name}"].append(mismatch_kl[indices]) + + token_exporter.export(progress.step, micro_step, micro_batch, out, response_lengths, config.loss) + if is_tt_moe_model(model): load_balance_stats = get_load_balance_stats(model) for k, v in load_balance_stats.items(): @@ -485,17 +515,22 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # Add loss tensors to tensor dict for logging purposes for key, loss_tensor in loss_tensors.items(): - loss_tensor = loss_tensor.detach().to("cpu") - tensors[key].append(loss_tensor) + tensors[key].append(loss_tensor.detach().to("cpu")) # Debug log with *local, micro step* stats - micro_step_message = f"Micro Step {micro_step}/{len(micro_batches)} | Loss: {tensors['loss'][-1].mean().item():.4f} | Entropy: {tensors['entropy'][-1].mean().item():.4f}" - if "mismatch_kl" in tensors: - micro_step_message += f" | Mismatch KL: {tensors['mismatch_kl'][-1].mean().item():.4f}" + micro_step_message = f"Micro Step {micro_step}/{len(micro_batches)} | Loss: {tensors['loss'][-1].mean().item():.4f} | Entropy: {tensors['entropy/all'][-1].mean().item():.4f}" + if not micro_batch["sft_loss"]: + micro_step_message += f" | Mismatch KL: {tensors['mismatch_kl/all'][-1].mean().item():.4f}" if "max_vio" in tensors: micro_step_message += f" | Max Vio: {tensors['max_vio'][-1].mean().item():.4f}" logger.debug(micro_step_message) + # compute_loss already divided by the global token count. Undo FSDP's per-rank averaging + # across dp_cp so the final gradient is the true per-token mean over the global batch. + for param in model.parameters(): + if param.grad is not None: + param.grad.mul_(parallel_dims.fsdp_gradient_divide_factor) + # Optionally, clip the gradients grad_norm: torch.Tensor | None = None if config.optim.max_norm is not None: @@ -539,9 +574,9 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # Log step metrics step_time = time.perf_counter() - step_start_time - step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Loss: {tensor_stats['loss/mean']:.4f} | Entropy: {tensor_stats['entropy/mean']:.4f}" - if "mismatch_kl/mean" in tensor_stats: - step_message += f" | Mismatch KL: {tensor_stats['mismatch_kl/mean']:.4f}" + step_message = f"Step {progress.step} | Time: {step_time:.2f}s | Loss: {tensor_stats['loss/mean']:.4f} | Entropy: {tensor_stats['entropy/all/mean']:.4f}" + if "mismatch_kl/all/mean" in tensor_stats: + step_message += f" | Mismatch KL: {tensor_stats['mismatch_kl/all/mean']:.4f}" if grad_norm is not None: step_message += f" | Grad. Norm: {grad_norm:.4f}" step_message += f" | LR: {current_lr:.2e} | Throughput: {throughput:.0f} tokens/s | MFU: {mfu:.1f}% | Peak Mem.: {peak_memory:.1f} GiB" @@ -570,8 +605,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: monitor.log(optim_metrics, step=progress.step) # Compute derived metrics - entropy_mean = tensor_stats.get("entropy/mean", 0.0) - mismatch_kl_mean = tensor_stats.get("mismatch_kl/mean") + entropy_mean = tensor_stats.get("entropy/all/mean", 0.0) + mismatch_kl_mean = tensor_stats.get("mismatch_kl/all/mean") if mismatch_kl_mean is not None and entropy_mean > 0: tensor_stats["kl_ent_ratio/mean"] = mismatch_kl_mean / entropy_mean @@ -605,8 +640,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: peak_memory_gib=peak_memory, learning_rate=current_lr, mfu=mfu, - entropy=tensor_stats.get("entropy/mean", 0.0), - mismatch_kl=tensor_stats.get("mismatch_kl/mean", 0.0), + entropy=tensor_stats.get("entropy/all/mean", 0.0), + mismatch_kl=tensor_stats.get("mismatch_kl/all/mean", 0.0), zero_grad_ratio=zero_grad_ratio, ) # Update run/LoRA metrics @@ -650,6 +685,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: prof.export_chrome_trace(trace_file) logger.info(f"Saved trace to {trace_file}") + token_exporter.close() + # Write final checkpoint (only for single-run mode; multi-run checkpoints are managed by MultiCheckpointManager) if config.max_concurrent_runs == 1 and ckpt_manager is not None: if not (config.ckpt and config.ckpt.weights_only): diff --git a/src/prime_rl/trainer/utils.py b/src/prime_rl/trainer/utils.py index 48b4df643c..8f437f2e07 100644 --- a/src/prime_rl/trainer/utils.py +++ b/src/prime_rl/trainer/utils.py @@ -363,10 +363,16 @@ def __init__(self): def compute_stats(self) -> dict[str, float | int]: """Synchronize the tensor statistic across all ranks for each key and compute relevant statistics.""" + local_keys = list(self.keys()) + gathered_keys: list[list[str] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered_keys, local_keys) + keys = sorted({key for rank_keys in gathered_keys if rank_keys is not None for key in rank_keys}) + metrics = {} - for key in list(self.keys()): + for key in keys: # All-gather tensors across steps and ranks (get global distribution) - tensors = torch.cat(self.pop(key), dim=0).to("cuda") + values = self.pop(key, []) + tensors = torch.cat(values, dim=0).to("cuda") if values else torch.empty(0, device="cuda") assert tensors.ndim == 1, "Can only aggregate 1D tensors" tensors = flexible_all_gather(tensors) assert tensors.ndim == 1, "Can only aggregate 1D tensors" @@ -393,6 +399,11 @@ def compute_stats(self) -> dict[str, float | int]: return metrics +def _is_env_tensor_stat(key: str, allowed_stats: set[str]) -> bool: + parts = key.split("/") + return len(parts) >= 3 and parts[-1] in allowed_stats + + def filter_rl_trainer_tensor_stats_for_wandb(metrics: dict[str, float | int]) -> dict[str, float | int]: """Drop noisy per-token distribution keys before sending RL trainer stats to W&B.""" skip_prefixes = ("trainer_probs/", "inference_probs/") @@ -400,6 +411,8 @@ def filter_rl_trainer_tensor_stats_for_wandb(metrics: dict[str, float | int]) -> "is_masked/", "is_masked_low/", "is_masked_high/", + "masked_advantage_positive/", + "masked_advantage_negative/", "mismatch_kl/", "masked_mismatch_kl/", "unmasked_mismatch_kl/", @@ -411,9 +424,12 @@ def filter_rl_trainer_tensor_stats_for_wandb(metrics: dict[str, float | int]) -> continue if any(k.startswith(p) for p in skip_prefixes): continue - if k.startswith("entropy/") and k != "entropy/mean": + if k.startswith("entropy/") and not _is_env_tensor_stat(k, {"mean", "std", "max"}): continue if any(k.startswith(p) for p in mean_max_only_prefixes): + if _is_env_tensor_stat(k, {"mean", "std", "max"}): + out[k] = v + continue if not (k.endswith("/mean") or k.endswith("/max")): continue out[k] = v diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index cc943e9b76..aad577efeb 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -19,6 +19,7 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr completion_mask: list[bool] completion_logprobs: list[float] completion_temperatures: list[float] # Per-token temperatures used during generation + env_name: str teacher_logprobs: list[float] | None = None advantage: float | None = None reward: float | None = None @@ -55,6 +56,7 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): inference_logprobs: list[float] position_ids: list[int] temperatures: list[float] # Per-token temperatures used during generation + env_names: list[str] teacher_logprobs: list[float] | None = None lora_num_tokens: list[int] | None = None routed_experts: RoutedExperts | None = None @@ -68,3 +70,4 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): mm_token_type_ids: list[int] | None = None sft_loss: bool = False # When True, trainer uses SFT loss instead of RL loss for this batch + rewards: list[float] | None = None diff --git a/tests/unit/orchestrator/test_advantage.py b/tests/unit/orchestrator/test_advantage.py index 28051ec97d..93c2b7ea4e 100644 --- a/tests/unit/orchestrator/test_advantage.py +++ b/tests/unit/orchestrator/test_advantage.py @@ -1,4 +1,6 @@ -import torch +import math + +import pytest from prime_rl.configs.orchestrator import ( CustomAdvantageConfig, @@ -15,7 +17,13 @@ ) -def _make_rollout(reward: float, completion_len: int = 0, num_turns: int = 1) -> dict: +def _make_rollout( + reward: float, + completion_len: int = 0, + num_turns: int = 1, + env_name: str = "test", + example_id: int = 0, +) -> dict: """Create a minimal rollout dict for advantage testing. `completion_len` tokens are split across `num_turns` trajectory steps. @@ -25,21 +33,21 @@ def _make_rollout(reward: float, completion_len: int = 0, num_turns: int = 1) -> {"tokens": {"prompt_ids": [0], "completion_ids": list(range(per_turn + (rem if i == 0 else 0)))}} for i in range(num_turns) ] - return {"reward": reward, "trajectory": trajectory} + return { + "reward": reward, + "trajectory": trajectory, + "env_name": env_name, + "example_id": example_id, + } -def _make_inputs(rewards, completion_lengths=None, num_turns=None) -> AdvantageInputs: - """Build AdvantageInputs from 2D arrays of rewards/lengths/turns.""" - rewards_t = torch.as_tensor(rewards, dtype=torch.float32) - num_problems, rollouts_per_example = rewards_t.shape +def _make_group(rewards, completion_lengths=None, num_turns=None) -> AdvantageInputs: + """Build single-group AdvantageInputs from 1D arrays of rewards/lengths/turns.""" rollouts = [] - for i in range(num_problems): - group = [] - for j in range(rollouts_per_example): - cl = int(completion_lengths[i][j]) if completion_lengths is not None else 0 - nt = int(num_turns[i][j]) if num_turns is not None else 1 - group.append(_make_rollout(float(rewards_t[i, j]), cl, nt)) - rollouts.append(group) + for i, reward in enumerate(rewards): + cl = int(completion_lengths[i]) if completion_lengths is not None else 0 + nt = int(num_turns[i]) if num_turns is not None else 1 + rollouts.append(_make_rollout(float(reward), cl, nt)) return AdvantageInputs(rollouts=rollouts) @@ -49,23 +57,16 @@ def _make_inputs(rewards, completion_lengths=None, num_turns=None) -> AdvantageI def test_default_advantage_fn_simple_mean(): - inputs = _make_inputs( - rewards=[[1.0, 0.5, 0.8], [0.2, 0.9, 0.1]], - completion_lengths=[[10, 12, 8], [15, 11, 9]], - ) + inputs = _make_group(rewards=[1.0, 0.5, 0.8], completion_lengths=[10, 12, 8]) result = default_advantage_fn(inputs) - assert result.advantages.shape == (2, 3) - # Check that mean is subtracted per row - assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(2), atol=1e-6) + assert len(result.advantages) == 3 + assert sum(result.advantages) == pytest.approx(0.0, abs=1e-6) def test_efficiency_mixed_group(): """Mixed group: reward shaping preserves zero-mean, shorter correct gets higher advantage.""" - inputs = _make_inputs( - rewards=[[1.0, 1.0, 0.0, 1.0]], - completion_lengths=[[10, 30, 20, 20]], - ) + inputs = _make_group(rewards=[1.0, 1.0, 0.0, 1.0], completion_lengths=[10, 30, 20, 20]) result = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) # mean_correct_len = (10+30+20)/3 = 20 @@ -73,247 +74,190 @@ def test_efficiency_mixed_group(): # shaped_rewards = R * (1 + bonus * correct_mask) = [1.5, 1, 0, 1] # baseline = mean(shaped_rewards) = 0.875 # A = shaped_rewards - baseline = [0.625, 0.125, -0.875, 0.125] - expected = torch.tensor([[0.625, 0.125, -0.875, 0.125]]) - assert torch.allclose(result.advantages, expected, atol=1e-6) + assert result.advantages == pytest.approx([0.625, 0.125, -0.875, 0.125], abs=1e-6) # Zero-mean per group - assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(1), atol=1e-6) + assert sum(result.advantages) == pytest.approx(0.0, abs=1e-6) # All correct rollouts have positive advantage - rewards = torch.tensor([r["reward"] for r in inputs.rollouts[0]]) - correct_mask = rewards >= 1.0 - assert (result.advantages[0][correct_mask] > 0).all() + for rollout, adv in zip(inputs.rollouts, result.advantages): + if rollout["reward"] >= 1.0: + assert adv > 0 def test_efficiency_all_correct_group(): """All-correct group: zero-mean, shorter gets higher advantage.""" - inputs = _make_inputs( - rewards=[[1.0, 1.0, 1.0]], - completion_lengths=[[10, 20, 40]], - ) + inputs = _make_group(rewards=[1.0, 1.0, 1.0], completion_lengths=[10, 20, 40]) result = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) # mean_len = 70/3 ≈ 23.33 # bonus = clamp(1 - [10, 20, 40] / (70/3), 0, 1) = [4/7, 1/7, 0] # shaped_rewards = [1+4/7, 1+1/7, 1] = [11/7, 8/7, 1] - # baseline = mean = (11/7 + 8/7 + 1) / 3 = (11+8+7)/(7*3) = 26/21 - # A = shaped - baseline - shaped = torch.tensor([[11.0 / 7, 8.0 / 7, 1.0]]) - baseline = shaped.mean(dim=1, keepdim=True) - expected = shaped - baseline - assert torch.allclose(result.advantages, expected, atol=1e-6) + shaped = [11.0 / 7, 8.0 / 7, 1.0] + mean_shaped = sum(shaped) / len(shaped) + expected = [s - mean_shaped for s in shaped] + assert result.advantages == pytest.approx(expected, abs=1e-6) # Zero-mean - assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(1), atol=1e-6) + assert sum(result.advantages) == pytest.approx(0.0, abs=1e-6) # Shortest has highest advantage - assert result.advantages[0, 0] > result.advantages[0, 1] > result.advantages[0, 2] + assert result.advantages[0] > result.advantages[1] > result.advantages[2] def test_efficiency_all_zero_rewards(): """When all rewards are 0, no length shaping — falls back to standard GRPO.""" - inputs = _make_inputs( - rewards=[[0.0, 0.0, 0.0]], - completion_lengths=[[10, 20, 15]], - ) + inputs = _make_group(rewards=[0.0, 0.0, 0.0], completion_lengths=[10, 20, 15]) result_with = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) result_without = default_advantage_fn(inputs) - assert torch.allclose(result_with.advantages, result_without.advantages, atol=1e-6) + assert result_with.advantages == pytest.approx(result_without.advantages, abs=1e-6) def test_efficiency_single_correct(): """Single correct rollout: bonus=0 (at its own mean), same as standard GRPO.""" - inputs = _make_inputs( - rewards=[[1.0, 0.0, 0.0, 0.0]], - completion_lengths=[[100, 50, 200, 150]], - ) + inputs = _make_group(rewards=[1.0, 0.0, 0.0, 0.0], completion_lengths=[100, 50, 200, 150]) result = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) - expected = torch.tensor([[0.75, -0.25, -0.25, -0.25]]) - assert torch.allclose(result.advantages, expected, atol=1e-6) + assert result.advantages == pytest.approx([0.75, -0.25, -0.25, -0.25], abs=1e-6) def test_efficiency_shorter_correct_higher_advantage(): """Among correct rollouts in a mixed group, shorter always gets higher advantage.""" - inputs = _make_inputs( - rewards=[[1.0, 1.0, 1.0, 0.0, 0.0]], - completion_lengths=[[50, 100, 200, 80, 120]], - ) + inputs = _make_group(rewards=[1.0, 1.0, 1.0, 0.0, 0.0], completion_lengths=[50, 100, 200, 80, 120]) result = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) - advs = result.advantages[0] + advs = result.advantages assert advs[0] > advs[1] > advs[2] - assert (advs[:3] > 0).all() - assert (advs[3:] < 0).all() + assert all(a > 0 for a in advs[:3]) + assert all(a < 0 for a in advs[3:]) def test_efficiency_zero_mean_per_group(): - """Reward shaping preserves zero-mean advantages per group.""" - inputs = _make_inputs( - rewards=[ - [1.0, 1.0, 0.0, 1.0], # mixed - [1.0, 1.0, 1.0, 1.0], # all correct - ], - completion_lengths=[ - [10, 30, 20, 20], - [10, 20, 40, 80], - ], + """Reward shaping preserves zero-mean advantages within each group.""" + mixed = default_advantage_fn( + _make_group(rewards=[1.0, 1.0, 0.0, 1.0], completion_lengths=[10, 30, 20, 20]), + length_penalty=_TOKENS_COMPLETION, + ) + all_correct = default_advantage_fn( + _make_group(rewards=[1.0, 1.0, 1.0, 1.0], completion_lengths=[10, 20, 40, 80]), + length_penalty=_TOKENS_COMPLETION, ) - result = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) - assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(2), atol=1e-6) + assert sum(mixed.advantages) == pytest.approx(0.0, abs=1e-6) + assert sum(all_correct.advantages) == pytest.approx(0.0, abs=1e-6) def test_efficiency_amplification_bounded(): """Even with extreme length outliers, reward amplification is capped at 2x.""" - inputs = _make_inputs( - rewards=[[1.0, 1.0, 0.0]], - completion_lengths=[[1, 10000, 5000]], - ) + inputs = _make_group(rewards=[1.0, 1.0, 0.0], completion_lengths=[1, 10000, 5000]) result = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) # Shortest correct gets bonus ≈ 1, so shaped_reward ≈ 2 # Standard reward = 1, so amplification ≈ 2x # shaped_rewards ≈ [2, 1, 0], baseline ≈ 1, max advantage ≈ 1 - assert result.advantages[0, 0] < 1.0 + 1e-3 - - -def test_efficiency_multiple_problems(): - """Handles multiple problems independently.""" - inputs = _make_inputs( - rewards=[ - [1.0, 1.0, 0.0], # mixed - [1.0, 1.0, 1.0], # all correct - ], - completion_lengths=[ - [10, 20, 15], - [10, 20, 40], - ], - ) - result = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) - - # Row 0: mixed group — shorter correct > longer correct - assert result.advantages[0, 0] > result.advantages[0, 1] - assert (result.advantages[0, :2] > 0).all() - assert result.advantages[0, 2] < 0 - - # Row 1: all-correct group — shorter gets higher advantage - assert result.advantages[1, 0] > result.advantages[1, 1] > result.advantages[1, 2] - - # Both rows have zero-mean - assert torch.allclose(result.advantages.mean(dim=1), torch.zeros(2), atol=1e-6) + assert result.advantages[0] < 1.0 + 1e-3 def test_efficiency_tokens_with_tool_response_weight(): """`tool_response_weight` shifts shaping onto tool-response tokens read from rollout metrics.""" rollouts = [ - [ - { - "reward": 1.0, - "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], - "metrics": {"rlm_total_tool_response_tokens": 200}, - }, - { - "reward": 1.0, - "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], - "metrics": {"rlm_total_tool_response_tokens": 0}, - }, - { - "reward": 1.0, - "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], - "metrics": {"rlm_total_tool_response_tokens": 100}, - }, - ] + { + "reward": 1.0, + "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], + "metrics": {"rlm_total_tool_response_tokens": 200}, + }, + { + "reward": 1.0, + "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], + "metrics": {"rlm_total_tool_response_tokens": 0}, + }, + { + "reward": 1.0, + "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}], + "metrics": {"rlm_total_tool_response_tokens": 100}, + }, ] inputs = AdvantageInputs(rollouts=rollouts) # completion tokens identical (10 each) → completion-only shaping is a no-op result_completion_only = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) - assert torch.allclose(result_completion_only.advantages, torch.zeros(1, 3), atol=1e-6) + assert result_completion_only.advantages == pytest.approx([0.0, 0.0, 0.0], abs=1e-6) # tool-response only: costs are [200, 0, 100], mean=100, bonus is one-sided # so only the below-mean rollout (idx 1) gets amplified; the at/above-mean tie. result_tool_only = default_advantage_fn(inputs, length_penalty=_TOKENS_TOOL_ONLY) - advs = result_tool_only.advantages[0] + advs = result_tool_only.advantages assert advs[1] > advs[0] assert advs[1] > advs[2] - assert torch.allclose(advs[0], advs[2], atol=1e-6) - assert torch.allclose(result_tool_only.advantages.mean(dim=1), torch.zeros(1), atol=1e-6) + assert advs[0] == pytest.approx(advs[2], abs=1e-6) + assert sum(advs) == pytest.approx(0.0, abs=1e-6) def test_efficiency_fractional_weight_with_int_rewards(): """Fractional weights must not truncate when rollout rewards are emitted as ints.""" rollouts_int = [ - [ - {"reward": 1, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(7))}}]}, - {"reward": 1, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(11))}}]}, - {"reward": 0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(13))}}]}, - ] + {"reward": 1, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(7))}}]}, + {"reward": 1, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(11))}}]}, + {"reward": 0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(13))}}]}, ] - rollouts_float = [[{**r, "reward": float(r["reward"])} for r in g] for g in rollouts_int] + rollouts_float = [{**r, "reward": float(r["reward"])} for r in rollouts_int] fractional = TokensLengthPenaltyConfig(completion_weight=0.3, tool_response_weight=0.0) int_result = default_advantage_fn(AdvantageInputs(rollouts=rollouts_int), length_penalty=fractional) float_result = default_advantage_fn(AdvantageInputs(rollouts=rollouts_float), length_penalty=fractional) - assert torch.allclose(int_result.advantages, float_result.advantages, atol=1e-6) + assert int_result.advantages == pytest.approx(float_result.advantages, abs=1e-6) def test_efficiency_zero_costs_falls_back_to_plain_grpo(): """When all effective costs are zero, shaping is a no-op (no NaNs from div-by-zero).""" # tool-only weights but no harness metric → all costs == 0 rollouts = [ - [ - {"reward": 1.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, - {"reward": 1.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, - {"reward": 0.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, - ] + {"reward": 1.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, + {"reward": 1.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, + {"reward": 0.0, "trajectory": [{"tokens": {"prompt_ids": [0], "completion_ids": list(range(10))}}]}, ] inputs = AdvantageInputs(rollouts=rollouts) result = default_advantage_fn(inputs, length_penalty=_TOKENS_TOOL_ONLY) expected = default_advantage_fn(inputs) # plain GRPO - assert not torch.isnan(result.advantages).any() - assert torch.allclose(result.advantages, expected.advantages, atol=1e-6) + assert not any(math.isnan(a) for a in result.advantages) + assert result.advantages == pytest.approx(expected.advantages, abs=1e-6) def test_efficiency_tokens_default_weights_match_completion_when_no_metric(): """Default TokensLengthPenaltyConfig (1,1) reduces to completion-only when rollouts lack the metric.""" - inputs = _make_inputs( - rewards=[[1.0, 1.0, 0.0, 1.0]], - completion_lengths=[[10, 30, 20, 20]], - ) + inputs = _make_group(rewards=[1.0, 1.0, 0.0, 1.0], completion_lengths=[10, 30, 20, 20]) result_default = default_advantage_fn(inputs, length_penalty=TokensLengthPenaltyConfig()) result_completion = default_advantage_fn(inputs, length_penalty=_TOKENS_COMPLETION) - assert torch.allclose(result_default.advantages, result_completion.advantages, atol=1e-6) + assert result_default.advantages == pytest.approx(result_completion.advantages, abs=1e-6) def test_efficiency_turns_penalty(): """`TurnsLengthPenaltyConfig` shapes by trajectory turn count rather than token count.""" - inputs = _make_inputs( - rewards=[[1.0, 1.0, 0.0, 1.0]], + inputs = _make_group( + rewards=[1.0, 1.0, 0.0, 1.0], # token counts identical, but turns differ — turns penalty should still differentiate - completion_lengths=[[100, 100, 100, 100]], - num_turns=[[1, 3, 2, 2]], + completion_lengths=[100, 100, 100, 100], + num_turns=[1, 3, 2, 2], ) result = default_advantage_fn(inputs, length_penalty=TurnsLengthPenaltyConfig()) # mean_correct_turns = (1+3+2)/3 = 2 # bonus = clamp(1 - [1,3,2,2]/2, 0, 1) = [0.5, 0, 0, 0] - expected = torch.tensor([[0.625, 0.125, -0.875, 0.125]]) - assert torch.allclose(result.advantages, expected, atol=1e-6) + assert result.advantages == pytest.approx([0.625, 0.125, -0.875, 0.125], abs=1e-6) def test_compute_advantages_with_config(): rewards = [1.0, 0.5, 0.8, 0.2, 0.9, 0.1] lengths = [10, 12, 8, 15, 11, 9] - rollouts = [_make_rollout(r, l) for r, l in zip(rewards, lengths)] + rollouts = [_make_rollout(r, l, example_id=i // 3) for i, (r, l) in enumerate(zip(rewards, lengths))] - compute_advantages(rollouts, samples_per_problem=3, advantage_config=DefaultAdvantageConfig()) + compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) advantages = [r["advantage"] for r in rollouts] assert len(advantages) == 6 - assert abs(sum(advantages[:3])) < 1e-5 - assert abs(sum(advantages[3:])) < 1e-5 + assert sum(advantages[:3]) == pytest.approx(0.0, abs=1e-5) + assert sum(advantages[3:]) == pytest.approx(0.0, abs=1e-5) def test_compute_advantages_no_cross_group_leakage(): @@ -321,18 +265,15 @@ def test_compute_advantages_no_cross_group_leakage(): Two problems with very different reward scales — cross-group leakage would pull the small-scale group's advantages toward the large-scale group's mean (and vice versa). - Distinct positional values also catch slicing/transpose bugs in the flat→grouped→flat - round-trip. + Distinct positional values also catch ordering bugs in the group→flat round-trip. """ rewards = [10.0, 20.0, 30.0, 0.0, 0.1, 0.2] - rollouts = [_make_rollout(r) for r in rewards] + rollouts = [_make_rollout(r, example_id=i // 3) for i, r in enumerate(rewards)] - compute_advantages(rollouts, samples_per_problem=3, advantage_config=DefaultAdvantageConfig()) + compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) advantages = [r["advantage"] for r in rollouts] - expected = [-10.0, 0.0, 10.0, -0.1, 0.0, 0.1] - for got, want in zip(advantages, expected): - assert abs(got - want) < 1e-5, (advantages, expected) + assert advantages == pytest.approx([-10.0, 0.0, 10.0, -0.1, 0.0, 0.1], abs=1e-5) def test_compute_advantages_without_config(): @@ -340,12 +281,72 @@ def test_compute_advantages_without_config(): lengths = [10, 12, 8] rollouts = [_make_rollout(r, l) for r, l in zip(rewards, lengths)] - compute_advantages(rollouts, samples_per_problem=3, advantage_config=None) + compute_advantages(rollouts, advantage_config=None) advantages = [r["advantage"] for r in rollouts] assert advantages == rewards +def test_compute_advantages_partial_groups(): + """Partial groups (size < rollouts_per_example) are advantaged against their own mean. + + Two groups of different sizes must round-trip cleanly: each group's advantages + must sum to zero and not leak into the other. + """ + # Group A (example_id=0): 4 rollouts. Group B (example_id=1): 2 rollouts. + rollouts = [ + _make_rollout(1.0, example_id=0), + _make_rollout(0.0, example_id=0), + _make_rollout(1.0, example_id=0), + _make_rollout(0.0, example_id=0), + _make_rollout(0.3, example_id=1), + _make_rollout(0.7, example_id=1), + ] + + compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) + + advantages = [r["advantage"] for r in rollouts] + # Group A: mean=0.5, advantages=[0.5, -0.5, 0.5, -0.5] + assert advantages[:4] == pytest.approx([0.5, -0.5, 0.5, -0.5], abs=1e-5) + # Group B: mean=0.5, advantages=[-0.2, 0.2] + assert advantages[4:] == pytest.approx([-0.2, 0.2], abs=1e-5) + + +def test_compute_advantages_singleton_group_gets_zero_advantage(): + """A group of size 1 has reward == mean, so its advantage is 0 (filterable downstream).""" + rollouts = [ + _make_rollout(0.5, example_id=0), + _make_rollout(0.8, example_id=0), + _make_rollout(0.3, example_id=1), # singleton group + ] + + compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) + + advantages = [r["advantage"] for r in rollouts] + # Group 0: mean=0.65, advantages=[-0.15, 0.15] + assert advantages[:2] == pytest.approx([-0.15, 0.15], abs=1e-5) + # Group 1 (singleton): advantage=0 + assert advantages[2] == pytest.approx(0.0, abs=1e-5) + + +def test_compute_advantages_disambiguates_example_id_across_envs(): + """example_id=0 in env A and example_id=0 in env B must not be grouped together.""" + rollouts = [ + _make_rollout(1.0, env_name="env_a", example_id=0), + _make_rollout(0.0, env_name="env_a", example_id=0), + _make_rollout(100.0, env_name="env_b", example_id=0), + _make_rollout(200.0, env_name="env_b", example_id=0), + ] + + compute_advantages(rollouts, advantage_config=DefaultAdvantageConfig()) + + advantages = [r["advantage"] for r in rollouts] + # env_a group: mean=0.5, advantages=[0.5, -0.5] + assert advantages[:2] == pytest.approx([0.5, -0.5], abs=1e-5) + # env_b group: mean=150, advantages=[-50, 50] + assert advantages[2:] == pytest.approx([-50.0, 50.0], abs=1e-5) + + def test_setup_advantage_fn_with_custom_config(): config = CustomAdvantageConfig( import_path="tests.unit.orchestrator.test_advantage._dummy_custom_advantage", @@ -353,17 +354,13 @@ def test_setup_advantage_fn_with_custom_config(): ) advantage_fn = setup_advantage_fn(config) - inputs = _make_inputs( - rewards=[[1.0, 0.5, 0.8]], - completion_lengths=[[10, 12, 8]], - ) + inputs = _make_group(rewards=[1.0, 0.5, 0.8], completion_lengths=[10, 12, 8]) result = advantage_fn(inputs) assert isinstance(result, AdvantageOutputs) - assert torch.allclose(result.advantages, torch.tensor([[2.0, 1.0, 1.6]])) + assert result.advantages == pytest.approx([2.0, 1.0, 1.6], abs=1e-6) def _dummy_custom_advantage(inputs: AdvantageInputs, scale: float = 1.0) -> AdvantageOutputs: """A simple custom advantage for testing.""" - rewards = torch.tensor([[r["reward"] for r in group] for group in inputs.rollouts]) - return AdvantageOutputs(advantages=rewards * scale) + return AdvantageOutputs(advantages=[r["reward"] * scale for r in inputs.rollouts]) diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index fc95de4e2f..eab84ebb2e 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -16,7 +16,11 @@ def _routed_experts(data, dtype=np.uint8): @pytest.fixture def make_training_example(): - def _make_training_example(temperature: float = 1.0, sft_loss: bool = False) -> TrainingSample: + def _make_training_example( + temperature: float = 1.0, + sft_loss: bool = False, + env_name: str = "test-env", + ) -> TrainingSample: return TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -26,12 +30,26 @@ def _make_training_example(temperature: float = 1.0, sft_loss: bool = False) -> completion_temperatures=[temperature, temperature], # Per-token temperatures teacher_logprobs=[0.0, 0.0, 0.0, 0.0], advantage=1.0, + env_name=env_name, sft_loss=sft_loss, ) return _make_training_example +def test_training_sample_requires_env_name(): + with pytest.raises(TypeError, match="env_name"): + TrainingSample( + prompt_ids=[1, 2], + prompt_mask=[False, False], + completion_ids=[3, 4], + completion_mask=[True, True], + completion_logprobs=[-0.1, -0.2], + completion_temperatures=[1.0, 1.0], + advantage=1.0, + ) + + @pytest.mark.parametrize( ("rollout_count", "num_train_workers", "expected_batches_per_worker"), [(4, 2, 2), (5, 2, 3), (7, 1, 7), (11, 4, 3)] ) @@ -68,8 +86,8 @@ def test_prepare_batch_balances_micro_batches_across_workers( def test_prepare_batch_packs_different_temperatures(make_training_example): """With per-token temperatures, samples can be packed together regardless of their temperature values.""" - example1 = make_training_example(temperature=0.7) - example2 = make_training_example(temperature=1.1) + example1 = make_training_example(temperature=0.7, env_name="env-a") + example2 = make_training_example(temperature=1.1, env_name="env-b") batches_per_gpu = prepare_batch( rollouts=[example1, example2], @@ -88,6 +106,7 @@ def test_prepare_batch_packs_different_temperatures(make_training_example): assert flat_batches[0].temperatures[:4] == [0.7, 0.7, 0.7, 0.7] # Second sample (4 tokens): all get temp 1.1 assert flat_batches[0].temperatures[4:8] == [1.1, 1.1, 1.1, 1.1] + assert flat_batches[0].env_names == ["env-a"] * 4 + ["env-b"] * 4 def test_prepare_sample_propagates_sft_loss(make_training_example): @@ -128,6 +147,7 @@ def test_prepare_sample_with_routed_experts(): completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], advantage=1.0, + env_name="test-env", routed_experts=routed, ) @@ -136,6 +156,7 @@ def test_prepare_sample_with_routed_experts(): assert micro_batch.routed_experts.data == routed.data assert micro_batch.routed_experts.shape == routed.shape assert micro_batch.routed_experts.dtype == routed.dtype + assert micro_batch.env_names == ["test-env"] * 4 def test_prepare_sample_truncates_routed_experts(): @@ -151,6 +172,7 @@ def test_prepare_sample_truncates_routed_experts(): completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], advantage=1.0, + env_name="test-env", routed_experts=routed, ) @@ -159,6 +181,7 @@ def test_prepare_sample_truncates_routed_experts(): assert micro_batch.routed_experts.data == expected.data assert micro_batch.routed_experts.shape == expected.shape assert micro_batch.routed_experts.dtype == expected.dtype + assert micro_batch.env_names == ["test-env"] * 3 def test_prepare_sample_none_routed_experts(): @@ -171,6 +194,7 @@ def test_prepare_sample_none_routed_experts(): completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], advantage=1.0, + env_name="test-env", ) micro_batch = prepare_sample(sample, seq_len=8) diff --git a/tests/unit/orchestrator/test_sft_trajectories.py b/tests/unit/orchestrator/test_sft_trajectories.py index be3b30e093..a65456b674 100644 --- a/tests/unit/orchestrator/test_sft_trajectories.py +++ b/tests/unit/orchestrator/test_sft_trajectories.py @@ -60,6 +60,7 @@ def test_pretokenize_rollout_trajectory_for_sft(): tokenizer = SimpleChatTokenizer() output = vf.RolloutOutput( example_id=42, + env_name="test-env", trajectory=[ vf.TrajectoryStep( prompt=[{"role": "user", "content": "U1"}], diff --git a/tests/unit/orchestrator/test_teacher_logprobs.py b/tests/unit/orchestrator/test_teacher_logprobs.py index fb5a04c3b0..d63fdce792 100644 --- a/tests/unit/orchestrator/test_teacher_logprobs.py +++ b/tests/unit/orchestrator/test_teacher_logprobs.py @@ -50,6 +50,7 @@ async def _run(): completion_mask=[True, True], completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], + env_name="test-env", ) result = await orchestrator_utils.compute_teacher_logprobs( diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 303a02fd11..d9f6a700f4 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -19,6 +19,13 @@ interleave_rollout, ) +_interleave_rollout = interleave_rollout + + +def interleave_rollout(output, *args, **kwargs): + output.setdefault("env_name", "test-env") + return _interleave_rollout(output, *args, **kwargs) + def _pixels(data: list[list[float]]) -> tuple[bytes, list[int]]: """Convert pixel values list to (bytes, shape) for test cache data.""" @@ -388,6 +395,7 @@ def test_branching_equivalent_multi_step_trajectory_with_tool_calls( def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output): + single_step_trajectory_output["env_name"] = "test-env" rollouts = interleave_rollout(single_step_trajectory_output) assert rollouts is not None assert len(rollouts) == 1 @@ -399,6 +407,7 @@ def test_interleave_rollout_single_step_trajectory(single_step_trajectory_output assert rollout.completion_mask == [True, True] assert rollout.completion_logprobs == [-0.1, -0.2] assert rollout.completion_temperatures == [1.0, 1.0] + assert rollout.env_name == "test-env" def test_interleave_rollout_multi_step_trajectory(multi_step_trajectory_output): diff --git a/tests/unit/train/models/test_nemotron_h_kl.py b/tests/unit/train/models/test_nemotron_h_kl.py index 7e590cadf9..09c0403359 100644 --- a/tests/unit/train/models/test_nemotron_h_kl.py +++ b/tests/unit/train/models/test_nemotron_h_kl.py @@ -105,8 +105,8 @@ def test_kl_zero_when_identical(): ) result = default_loss_fn(inputs, DefaultLossConfig()) - assert result.metrics["mismatch_kl"].item() == pytest.approx(0.0, abs=1e-6), ( - f"Expected zero KL for identical models, got {result.metrics['mismatch_kl'].item()}" + assert result.metrics["unmasked_mismatch_kl"].item() == pytest.approx(0.0, abs=1e-6), ( + f"Expected zero KL for identical models, got {result.metrics['unmasked_mismatch_kl'].item()}" ) @@ -138,7 +138,7 @@ def test_kl_positive_after_perturbation(): loss_mask=loss_mask, ) result = default_loss_fn(inputs, DefaultLossConfig()) - kl = result.metrics["mismatch_kl"].item() + kl = result.metrics["unmasked_mismatch_kl"].item() assert kl > 0, f"Expected positive KL after perturbation, got {kl}" assert kl < 100, f"KL unexpectedly large: {kl}" diff --git a/tests/unit/train/rl/test_packer.py b/tests/unit/train/rl/test_packer.py index c661ec0df5..68c77fbe4a 100644 --- a/tests/unit/train/rl/test_packer.py +++ b/tests/unit/train/rl/test_packer.py @@ -46,6 +46,7 @@ def make_training_sample() -> TrainingSample: completion_mask=[True], completion_logprobs=[-0.1], completion_temperatures=[1.0], + env_name="test-env", )