Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.

- **`AdvantageInputs` / `AdvantageOutputs` are now per-group, and `AdvantageOutputs.advantages` is a plain `list[float]`** (second breaking change to this API in three weeks). `AdvantageInputs.rollouts` is now `list[vf.RolloutOutput]` (a single group) instead of `list[list[vf.RolloutOutput]]`, and `AdvantageOutputs.advantages` is now `list[float]` instead of a 2D `Float[Tensor, "num_examples rollouts_per_example"]`. `compute_advantages` calls `advantage_fn` once per group, which lets partial-group training (groups smaller than `rollouts_per_example` after rollout errors) round-trip without the previous bucket-by-size workaround. Custom advantage functions must drop the outer list dimension and return a list of floats — e.g. `AdvantageOutputs(advantages=(rewards - rewards.mean(dim=1, keepdim=True)).tolist())` becomes `AdvantageOutputs(advantages=[r - mean for r in rewards])` (or `.tolist()` if you keep torch internally). (2026-05-22)
- **`orchestrator.advantage.length_penalty` → discriminated sub-config**: The scalar `length_penalty: Literal["tokens","turns"] | None` is replaced by a `LengthPenaltyConfig | None` discriminated on `type`. Token shaping now takes weighted completion + tool-response token costs. Migration: `length_penalty = "tokens"` becomes `[orchestrator.advantage.length_penalty]\ntype = "tokens"` (default weights `completion_weight = 1.0`, `tool_response_weight = 1.0` — total context). `length_penalty = "turns"` becomes `[orchestrator.advantage.length_penalty]\ntype = "turns"`. (2026-05-06)
- **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01)
- **`AdvantageInputs` API**: Replaced the `rewards`/`completion_lengths`/`num_turns` tensor fields with a single `rollouts: list[list[vf.RolloutOutput]]` (grouped by problem). Custom advantage functions can now access any rollout metadata. Existing custom advantages must update their signatures and extract per-rollout fields directly (e.g. `torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])`). (2026-05-01)
Expand Down
21 changes: 10 additions & 11 deletions docs/bring-your-own-algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ kwargs = { clip_eps = 0.2 }

## 2. Custom Advantage Functions

Advantages are computed **per-example** (grouped by `rollouts_per_example`). You provide a function that computes advantages for a batch of examples.
Advantages are computed **per-group** (one example × N rollouts). You provide a function that computes advantages for a single group; the framework calls it once per group and stitches the results back together. Groups may have fewer than `rollouts_per_example` rollouts when some rollouts in the group errored (partial-group training).

### Interface

Expand All @@ -86,8 +86,8 @@ def my_custom_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs:
```python
@dataclass
class AdvantageInputs:
# Rollouts grouped by problem: rollouts[i][j] is the j-th rollout for problem i.
rollouts: list[list[vf.RolloutOutput]]
# All rollouts for a single example (one group).
rollouts: list[vf.RolloutOutput]
```

Each `vf.RolloutOutput` carries the full rollout (`reward`, `trajectory`, etc.), so custom advantages can read any metadata they need (e.g. completion-token counts, turn counts, tool calls).
Expand All @@ -97,22 +97,21 @@ Each `vf.RolloutOutput` carries the full rollout (`reward`, `trajectory`, etc.),
```python
@dataclass
class AdvantageOutputs:
advantages: Float[Tensor, "num_examples rollouts_per_example"]
advantages: list[float] # one entry per rollout in the input group
```

### Example: Normalized Advantage

```python
import torch
import statistics
from prime_rl.orchestrator.advantage import AdvantageInputs, AdvantageOutputs

def normalized_advantage(inputs: AdvantageInputs, eps: float = 1e-8) -> AdvantageOutputs:
"""Normalize advantages to zero mean and unit variance per example."""
rewards = torch.tensor([[r["reward"] for r in group] for group in inputs.rollouts])
mean = rewards.mean(dim=1, keepdim=True)
std = rewards.std(dim=1, keepdim=True)
advantages = (rewards - mean) / (std + eps)
return AdvantageOutputs(advantages=advantages)
"""Normalize advantages to zero mean and unit variance within the group."""
rewards = [r["reward"] for r in inputs.rollouts]
mean = statistics.fmean(rewards)
std = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0
return AdvantageOutputs(advantages=[(r - mean) / (std + eps) for r in rewards])
```

### Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ class EnvConfig(BaseConfig):
),
] = None

state_columns: list[str] = []
"""Extra ``State`` fields to persist into the saved rollout records (in addition to the always-saved ``trajectory`` and ``sampling_args``). Values must be JSON-serializable."""

@property
def stripped_id(self) -> str:
"""Environment ID without the @version suffix."""
Expand Down
7 changes: 7 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,16 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
]


class TokenExportConfig(BaseConfig):
"""Configures per-token rollout exports from the RL trainer."""


class TrainerExperimentalConfig(BaseConfig):
"""Experimental features for the trainer."""

token_export: TokenExportConfig | None = None
"""Opt-in per-token JSONL export for rollout debugging. When enabled, writes token ids and aligned trainer metrics after each forward pass."""


class TrainerConfig(BaseConfig):
"""Configures the RL trainer"""
Expand Down
10 changes: 10 additions & 0 deletions skills/config/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ If you wish to configure values of the default variant, you don't need to set th

For hosted multi-tenant runs where the trainer image's `trainer.loss.type` is fixed, the orchestrator exposes a per-run override that forces SFT loss on every micro-batch without rebuilding the trainer. Set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; both must be configured together (the orchestrator validator enforces this). The orchestrator stamps each `TrainingSample.sft_loss = True`, which the trainer's `compute_loss` honors by dispatching to `sft_loss_fn` per batch — independent of the trainer's configured default loss.

### RL trainer token exports

For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_<step>/rank_<rank>.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer.

```toml
[trainer.experimental.token_export]
```

Leave it unset for normal training. When enabled, it exports every sequence from each exporting rank.

### Model fields

For `BaseModel | None` fields (like `[ckpt]`, `[wandb]`, `[compile]`), a bare flag enables them with defaults:
Expand Down
7 changes: 4 additions & 3 deletions skills/monitor-run/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ These tell you whether training is healthy or diverging.

| Metric | Source | Description |
|--------|--------|-------------|
| `mismatch_kl/mean` | trainer | KL divergence between trainer and (old) inference policy |
| `entropy/mean` | trainer | policy entropy |
| `mismatch_kl/{all,env}/{mean,std,max}` | trainer | KL divergence between trainer and (old) inference policy over trainable tokens (W&B) |
| `entropy/{all,env}/{mean,std,max}` | trainer | policy entropy over trainable tokens (W&B) |
| `masked_advantage_positive/mean` | trainer | fraction of DPPO-masked trainable tokens with positive advantage (W&B) |
| `masked_advantage_negative/mean` | trainer | fraction of DPPO-masked trainable tokens with negative advantage (W&B) |
| `optim/grad_norm` | trainer | gradient norm — spikes may precede divergence |

#### Performance
Expand Down Expand Up @@ -251,4 +253,3 @@ PRIME-RL::Launcher
```

For multi-node runs, trainer and inference processes are distributed across separate nodes. Use `srun` or `ssh` to inspect processes on other nodes directly.

103 changes: 44 additions & 59 deletions src/prime_rl/orchestrator/advantage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable

Expand All @@ -19,20 +20,16 @@

@dataclass
class AdvantageInputs:
"""Inputs for advantage computation.
"""Inputs for advantage computation of a single group (one example × N rollouts)."""

`rollouts` is grouped by problem: `rollouts[i][j]` is the j-th rollout for problem i,
so `len(rollouts) == num_problems` and `len(rollouts[0]) == rollouts_per_example`.
"""

rollouts: list[list[vf.RolloutOutput]]
rollouts: list[vf.RolloutOutput]


@dataclass
class AdvantageOutputs:
"""Outputs from advantage computation."""
"""Outputs from advantage computation of a single group."""

advantages: Float[Tensor, "num_problems rollouts_per_example"]
advantages: list[float]


AdvantageFn = Callable[..., AdvantageOutputs]
Expand All @@ -41,80 +38,73 @@ class AdvantageOutputs:
Expected signature:
def my_advantage(inputs: AdvantageInputs, **kwargs) -> AdvantageOutputs:
...

The function receives a single group and returns a list of advantages with one
entry per rollout. `compute_advantages` calls it once per group.
"""


def default_advantage_fn(
inputs: AdvantageInputs,
length_penalty: LengthPenaltyConfig | None = None,
) -> AdvantageOutputs:
"""Default GRPO advantage: reward minus per-problem baseline.
"""Default GRPO advantage for a single group: reward minus per-group baseline.

`length_penalty` enables correctness-gated efficiency shaping over a per-rollout
cost: tokens (weighted completion + tool-response) or trajectory turn count.
"""
rewards = torch.tensor([[r["reward"] for r in group] for group in inputs.rollouts], dtype=torch.float32)
rewards = torch.tensor([r["reward"] for r in inputs.rollouts], dtype=torch.float32)

if isinstance(length_penalty, TokensLengthPenaltyConfig):
w_c = length_penalty.completion_weight
w_t = length_penalty.tool_response_weight
costs = torch.tensor(
[
[w_c * get_model_completion_len(r) + w_t * get_tool_response_len(r) for r in group]
for group in inputs.rollouts
],
[w_c * get_model_completion_len(r) + w_t * get_tool_response_len(r) for r in inputs.rollouts],
dtype=rewards.dtype,
)
return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs))
return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist())
if isinstance(length_penalty, TurnsLengthPenaltyConfig):
costs = torch.tensor(
[[get_num_turns(r) for r in group] for group in inputs.rollouts],
dtype=rewards.dtype,
)
return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs))
costs = torch.tensor([get_num_turns(r) for r in inputs.rollouts], dtype=rewards.dtype)
return AdvantageOutputs(advantages=_efficiency_shaping(rewards, costs).tolist())

baseline = rewards.mean(dim=1, keepdim=True)
return AdvantageOutputs(advantages=rewards - baseline)
return AdvantageOutputs(advantages=(rewards - rewards.mean()).tolist())


def _efficiency_shaping(
rewards: Float[Tensor, "num_problems rollouts_per_example"],
costs: Float[Tensor, "num_problems rollouts_per_example"],
) -> Float[Tensor, "num_problems rollouts_per_example"]:
rewards: Float[Tensor, "group_size"],
costs: Float[Tensor, "group_size"],
) -> Float[Tensor, "group_size"]:
"""Correctness-gated efficiency shaping with bounded advantages.

Shapes rewards with a bounded efficiency bonus before standard GRPO subtraction,
preserving zero-mean advantages per group. `costs` is a per-rollout cost (e.g.,
completion length in tokens or number of turns).
preserving zero-mean advantages within the group. `costs` is a per-rollout cost
(e.g., completion length in tokens or number of turns).

Correct rollouts get reward amplified by up to 2x based on relative efficiency.
Incorrect rollouts are untouched. Lower-cost correct rollouts get higher advantage.
"""
max_reward = rewards.max(dim=1, keepdim=True).values
max_reward = rewards.max()
correct_mask = rewards >= max_reward
num_correct = correct_mask.sum(dim=1, keepdim=True)
num_correct = correct_mask.sum()

# No shaping when max reward is 0 — no correct rollouts to differentiate
has_correct = max_reward > 0
if max_reward <= 0:
return rewards - rewards.mean()

# Mean cost of correct rollouts per problem
correct_costs = costs * correct_mask
mean_correct_cost = correct_costs.sum(dim=1, keepdim=True) / num_correct.clamp(min=1)
# Mean cost of correct rollouts
mean_correct_cost = (costs * correct_mask).sum() / num_correct.clamp(min=1)

# Bounded efficiency bonus: [0, 1], positive for below-average cost, zero for above.
# When mean_correct_cost is 0 (e.g. tool-only shaping with no harness metric, or
# all-zero turn counts), no rollouts can be differentiated — fall back to no bonus.
has_cost = mean_correct_cost > 0
safe_mean = torch.where(has_cost, mean_correct_cost, torch.ones_like(mean_correct_cost))
bonus = (1 - costs / safe_mean).clamp(0, 1) * has_cost
if mean_correct_cost <= 0:
return rewards - rewards.mean()

bonus = (1 - costs / mean_correct_cost).clamp(0, 1)

# Shape rewards: correct rollouts amplified by up to 2x, incorrect untouched
shaped_rewards = rewards * (1 + bonus * correct_mask)
baseline = shaped_rewards.mean(dim=1, keepdim=True)

shaped = shaped_rewards - baseline
unshaped = rewards - rewards.mean(dim=1, keepdim=True)
return torch.where(has_correct, shaped, unshaped)
return shaped_rewards - shaped_rewards.mean()


def setup_advantage_fn(config: AdvantageConfig) -> AdvantageFn:
Expand All @@ -136,31 +126,26 @@ def advantage_fn(inputs: AdvantageInputs) -> AdvantageOutputs:

def compute_advantages(
rollouts: list[vf.RolloutOutput],
samples_per_problem: int,
advantage_config: AdvantageConfig | None,
) -> None:
"""
Computes advantages from rollouts, grouped by problem.
Stores advantages in-place on the rollouts.
"""Computes advantages from rollouts, grouped by (env_name, example_id), and
stores them in-place on the rollouts.

Args:
rollouts: List of rollouts to store advantages on
samples_per_problem: Number of samples (and thus, rewards) per problem
advantage_config: Configuration for advantage computation (DefaultAdvantageConfig or CustomAdvantageConfig)
`advantage_fn` is called once per group, so groups may have varying sizes
(partial-group training drops failed rollouts rather than rescheduling them).
"""
rewards = [r["reward"] for r in rollouts]

if not advantage_config:
for rollout, reward in zip(rollouts, rewards):
rollout["advantage"] = reward
for rollout in rollouts:
rollout["advantage"] = rollout["reward"]
return

advantage_fn = setup_advantage_fn(advantage_config)
grouped = [rollouts[i : i + samples_per_problem] for i in range(0, len(rollouts), samples_per_problem)]
inputs = AdvantageInputs(rollouts=grouped)

result = advantage_fn(inputs)
advantages = result.advantages.flatten().tolist()
groups_by_example: dict[tuple[str, int], list[vf.RolloutOutput]] = defaultdict(list)
for rollout in rollouts:
groups_by_example[(rollout["env_name"], rollout["example_id"])].append(rollout)

for rollout, advantage in zip(rollouts, advantages):
rollout["advantage"] = advantage
for group in groups_by_example.values():
result = advantage_fn(AdvantageInputs(rollouts=group))
for rollout, advantage in zip(group, result.advantages):
rollout["advantage"] = advantage
13 changes: 11 additions & 2 deletions src/prime_rl/orchestrator/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ def _sampling_args_with_salt(self, cache_salt: str) -> dict:
sampling_args["extra_body"] = extra_body
return sampling_args

@property
def state_columns(self) -> list[str]:
"""Required columns plus any extras configured on the env, deduped (required first)."""
merged: list[str] = []
for col in (*REQUIRED_STATE_COLUMNS, *self.config.state_columns):
if col not in merged:
merged.append(col)
return merged

async def run_rollout(
self,
client: vf.ClientConfig,
Expand All @@ -124,7 +133,7 @@ async def run_rollout(
model=model_name,
sampling_args=self._sampling_args_with_salt(cache_salt),
max_retries=self.config.max_retries,
state_columns=REQUIRED_STATE_COLUMNS,
state_columns=self.state_columns,
env_client=self.env_client,
)

Expand All @@ -143,7 +152,7 @@ async def run_group(
model=model_name,
sampling_args=self._sampling_args_with_salt(cache_salt),
max_retries=self.config.max_retries,
state_columns=REQUIRED_STATE_COLUMNS,
state_columns=self.state_columns,
env_client=self.env_client,
)

Expand Down
3 changes: 2 additions & 1 deletion src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ async def orchestrate(config: OrchestratorConfig):
# Compute advantages (in-place)
num_rollouts = len(train_rollouts)
num_unique_examples = len({(r["env_name"], r["example_id"]) for r in train_rollouts})
compute_advantages(train_rollouts, config.rollouts_per_example, config.advantage)
compute_advantages(train_rollouts, config.advantage)

# Apply rollout filters — sets rollout["filters"] and rollout["is_filtered"]
apply_filters(rollout_filters, train_rollouts)
Expand Down Expand Up @@ -553,6 +553,7 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin
for sample in samples:
sample.advantage = rollout["advantage"]
sample.reward = rollout["reward"]
sample.env_name = rollout["env_name"]
if config.use_sft_loss:
sample.sft_loss = True
sample_decode_tokens = sum(sample.completion_mask)
Expand Down
Loading
Loading