diff --git a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py index f4d37cd9d0..8464331bc0 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/trainer.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/trainer.py @@ -479,8 +479,16 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig): ] +class TokenExportConfig(BaseConfig): + """Configures per-token rollout exports from the RL trainer.""" + + path: Path | None = None + """JSONL output file. If unset, writes to ``/token_exports/rank_.jsonl``. Relative paths are resolved under output_dir.""" + + class TrainerExperimentalConfig(BaseConfig): - pass + token_export: TokenExportConfig | None = None + """Opt-in per-token JSONL export for rollout debugging. When enabled, writes token ids and aligned trainer metrics after each forward pass.""" class TrainerConfig(BaseConfig): diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 1eddac38ac..c8ca0fe463 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -60,6 +60,18 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`. In TOML, an empty section header (`[ckpt]`) does the same. +## RL trainer token exports + +For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer. + +```toml +[trainer.experimental.token_export] +# Optional. Relative paths resolve under trainer.output_dir. +path = "token_exports/sample.jsonl" +``` + +Leave it unset for normal training. When enabled, it exports every sequence from each exporting rank. + ## Key files - `packages/prime-rl-configs/src/prime_rl/` — config classes under `configs/`; `utils/config.py` re-exports `BaseConfig` and `cli` diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index 9f0a923b89..9db4aefd74 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -12,6 +12,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch loss_mask = training_example.prompt_mask + training_example.completion_mask inference_logprobs = [0.0] * len(training_example.prompt_ids) + training_example.completion_logprobs advantages = [training_example.advantage] * len(input_ids) + reward = training_example.reward if training_example.reward is not None else float("nan") + rewards = [reward] * len(input_ids) position_ids = list(range(len(input_ids))) mm_token_type_ids = training_example.mm_token_type_ids assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys" @@ -33,6 +35,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch inference_logprobs = inference_logprobs[:seq_len] position_ids = position_ids[:seq_len] advantages = advantages[:seq_len] + rewards = rewards[:seq_len] temperatures = temperatures[:seq_len] if teacher_logprobs is not None: teacher_logprobs = teacher_logprobs[:seq_len] @@ -48,9 +51,10 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch == len(loss_mask) == len(position_ids) == len(inference_logprobs) + == len(rewards) == len(temperatures) ), ( - f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, temperatures: {len(temperatures)}" + f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}, temperatures: {len(temperatures)}" ) if teacher_logprobs is not None: assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}" @@ -74,6 +78,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch inference_logprobs=inference_logprobs, teacher_logprobs=teacher_logprobs, temperatures=temperatures, + rewards=rewards, routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, env_names=env_names, @@ -122,9 +127,16 @@ def packed_samples_into_micro_bs( len(bin_content.input_ids) + len(sample.input_ids) <= max_seq_len and bin_content.training_mode == sample.training_mode ): + existing_len = len(bin_content.input_ids) bin_content.input_ids.extend(sample.input_ids) bin_content.loss_mask.extend(sample.loss_mask) bin_content.advantages.extend(sample.advantages) + if sample.rewards is not None: + if bin_content.rewards is None: + bin_content.rewards = [float("nan")] * existing_len + bin_content.rewards.extend(sample.rewards) + elif bin_content.rewards is not None: + bin_content.rewards.extend([float("nan")] * len(sample.input_ids)) bin_content.inference_logprobs.extend(sample.inference_logprobs) bin_content.temperatures.extend(sample.temperatures) if sample.teacher_logprobs is not None: @@ -175,6 +187,8 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa micro_batch.input_ids.extend([1] * padding_size) micro_batch.advantages.extend([0.0] * padding_size) + if micro_batch.rewards is not None: + micro_batch.rewards.extend([float("nan")] * padding_size) micro_batch.loss_mask.extend([False] * padding_size) micro_batch.position_ids.extend(list(range(padding_size))) micro_batch.inference_logprobs.extend([0.0] * padding_size) diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index acc6b23384..73e35159af 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -20,6 +20,7 @@ class TensorMicroBatch(TypedDict): input_ids: Int[Tensor, "batch seq"] position_ids: Int[Tensor, "batch seq"] advantages: Float[Tensor, "batch seq"] + rewards: Float[Tensor, "batch seq"] | None inference_logprobs: Float[Tensor, "batch seq"] teacher_logprobs: Float[Tensor, "batch seq"] | None loss_mask: Bool[Tensor, "batch seq"] @@ -108,6 +109,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc "input_ids": input_ids.unsqueeze(0), "position_ids": position_ids.unsqueeze(0), "advantages": advantages.unsqueeze(0), + "rewards": None, "inference_logprobs": inference_logprobs.unsqueeze(0), "teacher_logprobs": None, "temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0), @@ -135,6 +137,7 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: ), "position_ids": torch.cat([torch.arange(self.seq_len)]).unsqueeze(0), "advantages": torch.randn(self.seq_len, generator=generator).unsqueeze(0), + "rewards": None, "inference_logprobs": torch.randn(self.seq_len, generator=generator).unsqueeze(0), "teacher_logprobs": None, "temperatures": torch.ones(self.seq_len).unsqueeze(0), @@ -211,6 +214,9 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: input_ids=torch.tensor(micro_batch.input_ids, dtype=torch.long).unsqueeze(0), position_ids=torch.tensor(micro_batch.position_ids, dtype=torch.long).unsqueeze(0), advantages=torch.tensor(micro_batch.advantages, dtype=torch.float).unsqueeze(0), + rewards=torch.tensor(micro_batch.rewards, dtype=torch.float).unsqueeze(0) + if micro_batch.rewards is not None + else None, inference_logprobs=torch.tensor(micro_batch.inference_logprobs, dtype=torch.float).unsqueeze(0), teacher_logprobs=torch.tensor(micro_batch.teacher_logprobs, dtype=torch.float).unsqueeze(0) if micro_batch.teacher_logprobs is not None diff --git a/src/prime_rl/trainer/rl/token_export.py b/src/prime_rl/trainer/rl/token_export.py new file mode 100644 index 0000000000..a142088b82 --- /dev/null +++ b/src/prime_rl/trainer/rl/token_export.py @@ -0,0 +1,234 @@ +import atexit +import json +import math +from collections.abc import Mapping, Sequence +from pathlib import Path +from typing import Any + +import torch +from torch import Tensor + +from prime_rl.configs.trainer import DefaultLossConfig, TokenExportConfig, TrainerConfig +from prime_rl.trainer.rl.loss import compute_importance_ratio_and_mismatch_kl + +SCHEMA_VERSION = 1 + + +class DisabledTokenExporter: + def export(self, *args: Any, **kwargs: Any) -> None: + return + + def close(self) -> None: + return + + +class TokenExporter: + def __init__( + self, + config: TokenExportConfig, + output_dir: Path, + rank: int, + ) -> None: + self.config = config + self.rank = rank + self.path = self._resolve_path(config.path, output_dir, rank) + self.path.parent.mkdir(parents=True, exist_ok=True) + self._file = self.path.open("a", encoding="utf-8") + self._closed = False + self._current_step: int | None = None + self._sequences_this_step = 0 + atexit.register(self.close) + + def export( + self, + step: int, + micro_step: int, + micro_batch: Mapping[str, Any], + model_output: Mapping[str, Tensor], + response_lengths: list[int], + loss_config: Any, + ) -> None: + if self._current_step != step: + self._current_step = step + self._sequences_this_step = 0 + + columns = _export_columns(micro_batch, model_output, loss_config) + _check_lengths(columns) + + start = 0 + for micro_sequence_idx, length in enumerate(response_lengths): + raw_end = start + length + end = _trim_padding(columns, start, raw_end) + if end > start and any(columns["loss_mask"][start:end]): + self._write( + { + "schema_version": SCHEMA_VERSION, + "step": step, + "rank": self.rank, + "micro_step": micro_step, + "micro_sequence_idx": micro_sequence_idx, + "export_sequence_idx": self._sequences_this_step, + "env_name": _first_non_empty(columns["env_names"][start:end]), + "training_mode": str(micro_batch["training_mode"]), + **_slice_columns(columns, start, end), + } + ) + self._sequences_this_step += 1 + start = raw_end + + def close(self) -> None: + if self._closed: + return + self._closed = True + self._file.close() + + def _write(self, record: dict[str, Any]) -> None: + if self._closed: + raise RuntimeError(f"Token exporter is closed for {self.path}") + self._file.write(json.dumps(record, separators=(",", ":"), allow_nan=False) + "\n") + + @staticmethod + def _resolve_path(path: Path | None, output_dir: Path, rank: int) -> Path: + if path is None: + return output_dir / "token_exports" / f"rank_{rank}.jsonl" + if path.is_absolute(): + return path + return output_dir / path + + +def setup_token_exporter( + config: TrainerConfig, parallel_dims: Any, world: Any, logger: Any +) -> TokenExporter | DisabledTokenExporter: + token_export_config = config.experimental.token_export + if token_export_config is None: + return DisabledTokenExporter() + if parallel_dims.cp_enabled and parallel_dims.world_mesh["cp"].get_local_rank() != 0: + return DisabledTokenExporter() + + exporter = TokenExporter(token_export_config, config.output_dir, world.rank) + logger.info(f"Writing token exports to {exporter.path}") + return exporter + + +def _export_columns( + micro_batch: Mapping[str, Any], model_output: Mapping[str, Tensor], loss_config: Any +) -> dict[str, list[Any]]: + token_ids = _tensor_to_ints(micro_batch["input_ids"]) + seq_len = len(token_ids) + trainer_logprobs = model_output["logprobs"] + export_tensors = _compute_export_tensors(micro_batch, trainer_logprobs, loss_config) + + return { + "token_ids": token_ids, + "position_ids": _tensor_to_ints(micro_batch["position_ids"]), + "loss_mask": _tensor_to_bools(micro_batch["loss_mask"]), + "advantages": _tensor_to_floats(micro_batch["advantages"]), + "rewards": _optional_tensor_to_floats(micro_batch.get("rewards"), seq_len), + "inference_logprobs": _tensor_to_floats(micro_batch["inference_logprobs"]), + "trainer_logprobs": _tensor_to_floats(trainer_logprobs), + "entropy": _tensor_to_floats(model_output["entropy"]), + "mismatch_kl": _optional_tensor_to_floats(export_tensors["mismatch_kl"], seq_len), + "log_importance_ratio": _optional_tensor_to_floats(export_tensors["log_importance_ratio"], seq_len), + "importance_ratio": _optional_tensor_to_floats(export_tensors["importance_ratio"], seq_len), + "prob_delta": _optional_tensor_to_floats(export_tensors["prob_delta"], seq_len), + "is_masked": _optional_tensor_to_bools(export_tensors["is_masked"], seq_len), + "is_masked_high": _optional_tensor_to_bools(export_tensors["is_masked_high"], seq_len), + "is_masked_low": _optional_tensor_to_bools(export_tensors["is_masked_low"], seq_len), + "env_names": list(micro_batch["env_names"]), + } + + +def _compute_export_tensors( + micro_batch: Mapping[str, Any], trainer_logprobs: Tensor, loss_config: Any +) -> dict[str, Tensor | None]: + fields: dict[str, Tensor | None] = { + "log_importance_ratio": None, + "importance_ratio": None, + "mismatch_kl": None, + "prob_delta": None, + "is_masked": None, + "is_masked_high": None, + "is_masked_low": None, + } + if micro_batch["training_mode"] == "sft": + return fields + + inference_logprobs = micro_batch["inference_logprobs"].to(trainer_logprobs.device) + loss_mask = micro_batch["loss_mask"].to(trainer_logprobs.device) + advantages = micro_batch["advantages"].to(trainer_logprobs.device) + with torch.no_grad(): + log_ratio, ratio, mismatch_kl = compute_importance_ratio_and_mismatch_kl(trainer_logprobs, inference_logprobs) + prob_delta = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs) + fields["log_importance_ratio"] = log_ratio + fields["importance_ratio"] = ratio + fields["mismatch_kl"] = mismatch_kl + fields["prob_delta"] = prob_delta + if isinstance(loss_config, DefaultLossConfig): + invalid_high = prob_delta > loss_config.dppo_mask_high + invalid_low = prob_delta < -loss_config.dppo_mask_low + positive_advantages = advantages > 0 + negative_advantages = advantages < 0 + invalid = torch.where(positive_advantages, invalid_high, invalid_low) + fields["is_masked"] = loss_mask & invalid + fields["is_masked_high"] = loss_mask & positive_advantages & invalid_high + fields["is_masked_low"] = loss_mask & negative_advantages & invalid_low + return fields + + +def _tensor_to_ints(tensor: Tensor) -> list[int]: + return [int(value) for value in tensor.detach().cpu().reshape(-1).tolist()] + + +def _tensor_to_bools(tensor: Tensor) -> list[bool]: + return [bool(value) for value in tensor.detach().cpu().reshape(-1).tolist()] + + +def _tensor_to_floats(tensor: Tensor) -> list[float | None]: + values = tensor.detach().to(dtype=torch.float32, device="cpu").reshape(-1).tolist() + return [_json_float(value) for value in values] + + +def _optional_tensor_to_floats(tensor: Tensor | None, seq_len: int) -> list[float | None]: + if tensor is None: + return [None] * seq_len + return _tensor_to_floats(tensor) + + +def _optional_tensor_to_bools(tensor: Tensor | None, seq_len: int) -> list[bool | None]: + if tensor is None: + return [None] * seq_len + return _tensor_to_bools(tensor) + + +def _check_lengths(columns: Mapping[str, Sequence[Any]]) -> None: + lengths = {key: len(values) for key, values in columns.items()} + if len(set(lengths.values())) != 1: + raise ValueError(f"Token export fields must have aligned lengths, got {lengths}") + + +def _slice_columns(columns: Mapping[str, Sequence[Any]], start: int, end: int) -> dict[str, list[Any]]: + return {key: list(values[start:end]) for key, values in columns.items() if key != "env_names"} + + +def _trim_padding(columns: Mapping[str, Sequence[Any]], start: int, end: int) -> int: + env_names = columns["env_names"] + loss_mask = columns["loss_mask"] + while end > start and env_names[end - 1] == "" and not loss_mask[end - 1]: + end -= 1 + return end + + +def _json_float(value: float | None) -> float | None: + if value is None: + return None + value = float(value) + if not math.isfinite(value): + return None + return value + + +def _first_non_empty(values: Sequence[str]) -> str | None: + for value in values: + if value: + return value + return None diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index e437efb9e2..d359bae5e7 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -35,6 +35,7 @@ shift_tensor_left, shift_tensor_right, ) +from prime_rl.trainer.rl.token_export import setup_token_exporter from prime_rl.trainer.model import ( forward, setup_tokenizer, @@ -239,6 +240,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: config.rollout_transport, ) + token_exporter = setup_token_exporter(config, parallel_dims, world, logger) + gc_handler = GarbageCollection(config.gc.interval) if config.gc else None logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") @@ -500,6 +503,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: for env_name, indices in env_to_indices.items(): tensors[f"mismatch_kl/{env_name}"].append(mismatch_kl[indices]) + token_exporter.export(progress.step, micro_step, micro_batch, out, response_lengths, config.loss) + if is_tt_moe_model(model): load_balance_stats = get_load_balance_stats(model) for k, v in load_balance_stats.items(): @@ -678,6 +683,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: prof.export_chrome_trace(trace_file) logger.info(f"Saved trace to {trace_file}") + token_exporter.close() + # Write final checkpoint (only for single-run mode; multi-run checkpoints are managed by MultiCheckpointManager) if config.max_concurrent_runs == 1 and ckpt_manager is not None: if not (config.ckpt and config.ckpt.weights_only): diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index de9b246eb2..d4c947224f 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -80,3 +80,4 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): # Loss dispatch is batch-driven (rl/opd → default loss with mode-specific taus, # sft → sft loss). All samples packed into a micro batch share the same mode. training_mode: TrainingMode = "rl" + rewards: list[float] | None = None