Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion packages/prime-rl-configs/src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,16 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
]


class TokenExportConfig(BaseConfig):
"""Configures per-token rollout exports from the RL trainer."""

path: Path | None = None
"""JSONL output file. If unset, writes to ``<output_dir>/token_exports/rank_<rank>.jsonl``. Relative paths are resolved under output_dir."""


class TrainerExperimentalConfig(BaseConfig):
pass
token_export: TokenExportConfig | None = None
"""Opt-in per-token JSONL export for rollout debugging. When enabled, writes token ids and aligned trainer metrics after each forward pass."""


class TrainerConfig(BaseConfig):
Expand Down
12 changes: 12 additions & 0 deletions skills/configs/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`.

In TOML, an empty section header (`[ckpt]`) does the same.

## RL trainer token exports

For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer.

```toml
[trainer.experimental.token_export]
# Optional. Relative paths resolve under trainer.output_dir.
path = "token_exports/sample.jsonl"
```

Leave it unset for normal training. When enabled, it exports every sequence from each exporting rank.

## Key files

- `packages/prime-rl-configs/src/prime_rl/` — config classes under `configs/`; `utils/config.py` re-exports `BaseConfig` and `cli`
Expand Down
16 changes: 15 additions & 1 deletion src/prime_rl/trainer/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
loss_mask = training_example.prompt_mask + training_example.completion_mask
inference_logprobs = [0.0] * len(training_example.prompt_ids) + training_example.completion_logprobs
advantages = [training_example.advantage] * len(input_ids)
reward = training_example.reward if training_example.reward is not None else float("nan")
rewards = [reward] * len(input_ids)
position_ids = list(range(len(input_ids)))
mm_token_type_ids = training_example.mm_token_type_ids
assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys"
Expand All @@ -33,6 +35,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
inference_logprobs = inference_logprobs[:seq_len]
position_ids = position_ids[:seq_len]
advantages = advantages[:seq_len]
rewards = rewards[:seq_len]
temperatures = temperatures[:seq_len]
if teacher_logprobs is not None:
teacher_logprobs = teacher_logprobs[:seq_len]
Expand All @@ -48,9 +51,10 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
== len(loss_mask)
== len(position_ids)
== len(inference_logprobs)
== len(rewards)
== len(temperatures)
), (
f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, temperatures: {len(temperatures)}"
f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}, temperatures: {len(temperatures)}"
)
if teacher_logprobs is not None:
assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}"
Expand All @@ -74,6 +78,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch
inference_logprobs=inference_logprobs,
teacher_logprobs=teacher_logprobs,
temperatures=temperatures,
rewards=rewards,
routed_experts=routed_experts,
mm_token_type_ids=mm_token_type_ids,
env_names=env_names,
Expand Down Expand Up @@ -122,9 +127,16 @@ def packed_samples_into_micro_bs(
len(bin_content.input_ids) + len(sample.input_ids) <= max_seq_len
and bin_content.training_mode == sample.training_mode
):
existing_len = len(bin_content.input_ids)
bin_content.input_ids.extend(sample.input_ids)
bin_content.loss_mask.extend(sample.loss_mask)
bin_content.advantages.extend(sample.advantages)
if sample.rewards is not None:
if bin_content.rewards is None:
bin_content.rewards = [float("nan")] * existing_len
bin_content.rewards.extend(sample.rewards)
elif bin_content.rewards is not None:
bin_content.rewards.extend([float("nan")] * len(sample.input_ids))
bin_content.inference_logprobs.extend(sample.inference_logprobs)
bin_content.temperatures.extend(sample.temperatures)
if sample.teacher_logprobs is not None:
Expand Down Expand Up @@ -175,6 +187,8 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa

micro_batch.input_ids.extend([1] * padding_size)
micro_batch.advantages.extend([0.0] * padding_size)
if micro_batch.rewards is not None:
micro_batch.rewards.extend([float("nan")] * padding_size)
micro_batch.loss_mask.extend([False] * padding_size)
micro_batch.position_ids.extend(list(range(padding_size)))
micro_batch.inference_logprobs.extend([0.0] * padding_size)
Expand Down
6 changes: 6 additions & 0 deletions src/prime_rl/trainer/rl/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class TensorMicroBatch(TypedDict):
input_ids: Int[Tensor, "batch seq"]
position_ids: Int[Tensor, "batch seq"]
advantages: Float[Tensor, "batch seq"]
rewards: Float[Tensor, "batch seq"] | None
inference_logprobs: Float[Tensor, "batch seq"]
teacher_logprobs: Float[Tensor, "batch seq"] | None
loss_mask: Bool[Tensor, "batch seq"]
Expand Down Expand Up @@ -108,6 +109,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc
"input_ids": input_ids.unsqueeze(0),
"position_ids": position_ids.unsqueeze(0),
"advantages": advantages.unsqueeze(0),
"rewards": None,
"inference_logprobs": inference_logprobs.unsqueeze(0),
"teacher_logprobs": None,
"temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0),
Expand Down Expand Up @@ -135,6 +137,7 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch:
),
"position_ids": torch.cat([torch.arange(self.seq_len)]).unsqueeze(0),
"advantages": torch.randn(self.seq_len, generator=generator).unsqueeze(0),
"rewards": None,
"inference_logprobs": torch.randn(self.seq_len, generator=generator).unsqueeze(0),
"teacher_logprobs": None,
"temperatures": torch.ones(self.seq_len).unsqueeze(0),
Expand Down Expand Up @@ -211,6 +214,9 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch:
input_ids=torch.tensor(micro_batch.input_ids, dtype=torch.long).unsqueeze(0),
position_ids=torch.tensor(micro_batch.position_ids, dtype=torch.long).unsqueeze(0),
advantages=torch.tensor(micro_batch.advantages, dtype=torch.float).unsqueeze(0),
rewards=torch.tensor(micro_batch.rewards, dtype=torch.float).unsqueeze(0)
if micro_batch.rewards is not None
else None,
inference_logprobs=torch.tensor(micro_batch.inference_logprobs, dtype=torch.float).unsqueeze(0),
teacher_logprobs=torch.tensor(micro_batch.teacher_logprobs, dtype=torch.float).unsqueeze(0)
if micro_batch.teacher_logprobs is not None
Expand Down
234 changes: 234 additions & 0 deletions src/prime_rl/trainer/rl/token_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
import atexit
import json
import math
from collections.abc import Mapping, Sequence
from pathlib import Path
from typing import Any

import torch
from torch import Tensor

from prime_rl.configs.trainer import DefaultLossConfig, TokenExportConfig, TrainerConfig
from prime_rl.trainer.rl.loss import compute_importance_ratio_and_mismatch_kl

SCHEMA_VERSION = 1


class DisabledTokenExporter:
def export(self, *args: Any, **kwargs: Any) -> None:
return

def close(self) -> None:
return


class TokenExporter:
def __init__(
self,
config: TokenExportConfig,
output_dir: Path,
rank: int,
) -> None:
self.config = config
self.rank = rank
self.path = self._resolve_path(config.path, output_dir, rank)
self.path.parent.mkdir(parents=True, exist_ok=True)
self._file = self.path.open("a", encoding="utf-8")
self._closed = False
self._current_step: int | None = None
self._sequences_this_step = 0
atexit.register(self.close)

def export(
self,
step: int,
micro_step: int,
micro_batch: Mapping[str, Any],
model_output: Mapping[str, Tensor],
response_lengths: list[int],
loss_config: Any,
) -> None:
if self._current_step != step:
self._current_step = step
self._sequences_this_step = 0

columns = _export_columns(micro_batch, model_output, loss_config)
_check_lengths(columns)

start = 0
for micro_sequence_idx, length in enumerate(response_lengths):
raw_end = start + length
end = _trim_padding(columns, start, raw_end)
if end > start and any(columns["loss_mask"][start:end]):
self._write(
{
"schema_version": SCHEMA_VERSION,
"step": step,
"rank": self.rank,
"micro_step": micro_step,
"micro_sequence_idx": micro_sequence_idx,
"export_sequence_idx": self._sequences_this_step,
"env_name": _first_non_empty(columns["env_names"][start:end]),
"training_mode": str(micro_batch["training_mode"]),
**_slice_columns(columns, start, end),
}
)
self._sequences_this_step += 1
start = raw_end

def close(self) -> None:
if self._closed:
return
self._closed = True
self._file.close()

def _write(self, record: dict[str, Any]) -> None:
if self._closed:
raise RuntimeError(f"Token exporter is closed for {self.path}")
self._file.write(json.dumps(record, separators=(",", ":"), allow_nan=False) + "\n")

@staticmethod
def _resolve_path(path: Path | None, output_dir: Path, rank: int) -> Path:
if path is None:
return output_dir / "token_exports" / f"rank_{rank}.jsonl"
if path.is_absolute():
return path
return output_dir / path
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shared export path corrupts JSONL

Medium Severity

When token_export.path is explicitly set, _resolve_path doesn't include the trainer rank in the filename. This causes multiple TokenExporter instances to concurrently write to the same file, resulting in corrupted JSONL output.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit c05d797. Configure here.



def setup_token_exporter(
config: TrainerConfig, parallel_dims: Any, world: Any, logger: Any
) -> TokenExporter | DisabledTokenExporter:
token_export_config = config.experimental.token_export
if token_export_config is None:
return DisabledTokenExporter()
if parallel_dims.cp_enabled and parallel_dims.world_mesh["cp"].get_local_rank() != 0:
return DisabledTokenExporter()

exporter = TokenExporter(token_export_config, config.output_dir, world.rank)
logger.info(f"Writing token exports to {exporter.path}")
return exporter


def _export_columns(
micro_batch: Mapping[str, Any], model_output: Mapping[str, Tensor], loss_config: Any
) -> dict[str, list[Any]]:
token_ids = _tensor_to_ints(micro_batch["input_ids"])
seq_len = len(token_ids)
trainer_logprobs = model_output["logprobs"]
export_tensors = _compute_export_tensors(micro_batch, trainer_logprobs, loss_config)

return {
"token_ids": token_ids,
"position_ids": _tensor_to_ints(micro_batch["position_ids"]),
"loss_mask": _tensor_to_bools(micro_batch["loss_mask"]),
"advantages": _tensor_to_floats(micro_batch["advantages"]),
"rewards": _optional_tensor_to_floats(micro_batch.get("rewards"), seq_len),
"inference_logprobs": _tensor_to_floats(micro_batch["inference_logprobs"]),
"trainer_logprobs": _tensor_to_floats(trainer_logprobs),
"entropy": _tensor_to_floats(model_output["entropy"]),
"mismatch_kl": _optional_tensor_to_floats(export_tensors["mismatch_kl"], seq_len),
"log_importance_ratio": _optional_tensor_to_floats(export_tensors["log_importance_ratio"], seq_len),
"importance_ratio": _optional_tensor_to_floats(export_tensors["importance_ratio"], seq_len),
"prob_delta": _optional_tensor_to_floats(export_tensors["prob_delta"], seq_len),
"is_masked": _optional_tensor_to_bools(export_tensors["is_masked"], seq_len),
"is_masked_high": _optional_tensor_to_bools(export_tensors["is_masked_high"], seq_len),
"is_masked_low": _optional_tensor_to_bools(export_tensors["is_masked_low"], seq_len),
"env_names": list(micro_batch["env_names"]),
}


def _compute_export_tensors(
micro_batch: Mapping[str, Any], trainer_logprobs: Tensor, loss_config: Any
) -> dict[str, Tensor | None]:
fields: dict[str, Tensor | None] = {
"log_importance_ratio": None,
"importance_ratio": None,
"mismatch_kl": None,
"prob_delta": None,
"is_masked": None,
"is_masked_high": None,
"is_masked_low": None,
}
if micro_batch["training_mode"] == "sft":
return fields

inference_logprobs = micro_batch["inference_logprobs"].to(trainer_logprobs.device)
loss_mask = micro_batch["loss_mask"].to(trainer_logprobs.device)
advantages = micro_batch["advantages"].to(trainer_logprobs.device)
with torch.no_grad():
log_ratio, ratio, mismatch_kl = compute_importance_ratio_and_mismatch_kl(trainer_logprobs, inference_logprobs)
prob_delta = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)
fields["log_importance_ratio"] = log_ratio
fields["importance_ratio"] = ratio
fields["mismatch_kl"] = mismatch_kl
fields["prob_delta"] = prob_delta
if isinstance(loss_config, DefaultLossConfig):
invalid_high = prob_delta > loss_config.dppo_mask_high
invalid_low = prob_delta < -loss_config.dppo_mask_low
positive_advantages = advantages > 0
negative_advantages = advantages < 0
invalid = torch.where(positive_advantages, invalid_high, invalid_low)
fields["is_masked"] = loss_mask & invalid
fields["is_masked_high"] = loss_mask & positive_advantages & invalid_high
fields["is_masked_low"] = loss_mask & negative_advantages & invalid_low
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OPD export masks mismatch loss

Medium Severity

The token_export module's is_masked* fields for opd training use DefaultLossConfig's DPPO thresholds. This differs from opd_loss_fn's fixed 0.2 probability-delta thresholds, causing exported masking diagnostics to misrepresent actual token masking during opd training.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit c05d797. Configure here.

return fields


def _tensor_to_ints(tensor: Tensor) -> list[int]:
return [int(value) for value in tensor.detach().cpu().reshape(-1).tolist()]


def _tensor_to_bools(tensor: Tensor) -> list[bool]:
return [bool(value) for value in tensor.detach().cpu().reshape(-1).tolist()]


def _tensor_to_floats(tensor: Tensor) -> list[float | None]:
values = tensor.detach().to(dtype=torch.float32, device="cpu").reshape(-1).tolist()
return [_json_float(value) for value in values]


def _optional_tensor_to_floats(tensor: Tensor | None, seq_len: int) -> list[float | None]:
if tensor is None:
return [None] * seq_len
return _tensor_to_floats(tensor)


def _optional_tensor_to_bools(tensor: Tensor | None, seq_len: int) -> list[bool | None]:
if tensor is None:
return [None] * seq_len
return _tensor_to_bools(tensor)


def _check_lengths(columns: Mapping[str, Sequence[Any]]) -> None:
lengths = {key: len(values) for key, values in columns.items()}
if len(set(lengths.values())) != 1:
raise ValueError(f"Token export fields must have aligned lengths, got {lengths}")


def _slice_columns(columns: Mapping[str, Sequence[Any]], start: int, end: int) -> dict[str, list[Any]]:
return {key: list(values[start:end]) for key, values in columns.items() if key != "env_names"}


def _trim_padding(columns: Mapping[str, Sequence[Any]], start: int, end: int) -> int:
env_names = columns["env_names"]
loss_mask = columns["loss_mask"]
while end > start and env_names[end - 1] == "" and not loss_mask[end - 1]:
end -= 1
return end


def _json_float(value: float | None) -> float | None:
if value is None:
return None
value = float(value)
if not math.isfinite(value):
return None
return value


def _first_non_empty(values: Sequence[str]) -> str | None:
for value in values:
if value:
return value
return None
Loading
Loading