def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the token export visualizer."""
    cli = argparse.ArgumentParser(description="Render prime-rl token export JSONL as HTML.")
    cli.add_argument("input", type=Path, help="Path to a token export JSONL file or directory of JSONL exports.")
    cli.add_argument("--output", "-o", type=Path, help="Path to write the HTML file.")
    cli.add_argument(
        "--record-index",
        type=int,
        default=0,
        help="Record index to render after filters are applied, or initial record in --all-records mode.",
    )
    cli.add_argument(
        "--all-records",
        action="store_true",
        help="Embed every matched record in one navigable HTML page. Implied for directory inputs.",
    )
    cli.add_argument("--step", type=int, help="Only consider records from this trainer step.")
    cli.add_argument("--rank", type=int, help="Only consider records from this trainer rank.")
    cli.add_argument("--env-name", help="Only consider records for this env name.")
    cli.add_argument(
        "--tokenizer",
        help="Tokenizer model name or local path used to decode token ids into text fragments.",
    )
    cli.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Pass trust_remote_code=True when loading --tokenizer.",
    )
    cli.add_argument(
        "--max-mismatch",
        type=float,
        help="Mismatch KL value mapped to the deepest red. Defaults to the p95 finite trainable mismatch.",
    )
    return cli


def main() -> None:
    """Load, filter and render token export records, then print the HTML path.

    Raises:
        ValueError: if no record survives the --step/--rank/--env-name filters.
        IndexError: if --record-index is outside the matched records.
    """
    args = _build_arg_parser().parse_args()

    matched = _filter_records(
        _load_records(args.input),
        step=args.step,
        rank=args.rank,
        env_name=args.env_name,
    )
    if not matched:
        raise ValueError("No token export records matched the requested filters.")

    # Directory inputs always embed every record; --record-index then only
    # picks which record the page opens on.
    if args.all_records or args.input.is_dir():
        if not 0 <= args.record_index < len(matched):
            raise IndexError(
                f"record-index {args.record_index} is out of range for {len(matched)} matched records."
            )
        selected, initial_index = matched, args.record_index
    else:
        selected, initial_index = [_select_record(matched, args.record_index)], 0

    if args.tokenizer:
        tokenizer = _load_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
        for item in selected:
            _decode_token_texts(item, tokenizer)
    for item in selected:
        _add_derived_token_fields(item)

    destination = args.output or _default_output_path(args.input)
    destination.parent.mkdir(parents=True, exist_ok=True)
    html = _render_html(selected, max_mismatch=args.max_mismatch, initial_index=initial_index)
    destination.write_text(html, encoding="utf-8")
    print(destination)
/ "index.html" + return path.with_suffix(".html") + + +def _load_records(path: Path) -> list[dict[str, Any]]: + if path.is_dir(): + files = sorted(file for file in path.rglob("*.jsonl") if file.is_file()) + if not files: + raise FileNotFoundError(f"No JSONL files found under {path}") + root = path + else: + files = [path] + root = path.parent + + records = [] + for file in files: + source_file = _relative_source(file, root) + with file.open("r", encoding="utf-8") as f: + for line_number, line in enumerate(f, start=1): + if not line.strip(): + continue + record = json.loads(line) + record["_source_file"] = source_file + record["_source_line"] = line_number + record["_source_record_index"] = len(records) + records.append(record) + return records + + +def _relative_source(file: Path, root: Path) -> str: + try: + return str(file.relative_to(root)) + except ValueError: + return str(file) + + +def _filter_records( + records: list[dict[str, Any]], + *, + step: int | None, + rank: int | None, + env_name: str | None, +) -> list[dict[str, Any]]: + return [ + record + for record in records + if (step is None or record.get("step") == step) + and (rank is None or record.get("rank") == rank) + and (env_name is None or record.get("env_name") == env_name) + ] + + +def _select_record(records: list[dict[str, Any]], record_index: int) -> dict[str, Any]: + if record_index < 0 or record_index >= len(records): + raise IndexError(f"record-index {record_index} is out of range for {len(records)} matched records.") + return records[record_index] + + +def _load_tokenizer(tokenizer_name: str, trust_remote_code: bool) -> Any: + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code) + + +def _decode_token_texts(record: dict[str, Any], tokenizer: Any) -> None: + tokens = record.get("tokens", []) + token_ids = [int(token["id"]) for token in tokens] + token_texts = tokenizer.batch_decode( + [[token_id] for token_id in 
def _decode_token_texts(record: dict[str, Any], tokenizer: Any) -> None:
    """Attach a decoded ``text`` field to every token of ``record`` in place.

    Each token id is decoded on its own (a batch of single-id sequences),
    with special tokens kept and tokenization-space cleanup disabled.
    """
    entries = record.get("tokens", [])
    singletons = [[int(entry["id"])] for entry in entries]
    decoded = tokenizer.batch_decode(
        singletons,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )
    for entry, fragment in zip(entries, decoded):
        entry["text"] = fragment


def _add_derived_token_fields(record: dict[str, Any]) -> None:
    """Fill in the two sample-KL fields from ``log_importance_ratio`` when absent.

    Tokens whose log ratio is missing or non-finite are left untouched, and
    values already present on a token are never overwritten.
    """
    for entry in record.get("tokens", []):
        log_ratio = _finite_float(entry.get("log_importance_ratio"))
        if log_ratio is not None:
            entry.setdefault("sample_kl_trainer_to_inference", log_ratio)
            entry.setdefault("sample_kl_inference_to_trainer", -log_ratio)
+ +
+ +
+ + + + +""" + return ( + document.replace("__RECORDS_JSON__", _json_script(prepared_records)) + .replace("__MISMATCH_SCALE__", json.dumps(scale)) + .replace("__INITIAL_RECORD_INDEX__", json.dumps(initial_index)) + ) + + +def _prepare_records(records: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [ + { + "meta": _record_meta(record, ordinal), + "segments": _chat_segments(record.get("tokens", [])), + } + for ordinal, record in enumerate(records) + ] + + +def _record_meta(record: dict[str, Any], ordinal: int) -> dict[str, Any]: + tokens = record.get("tokens", []) + token_count = len(tokens) + trainable_count = sum(1 for token in tokens if token.get("loss_mask")) + label_parts = [f"#{ordinal}"] + for key, label in ( + ("step", "step"), + ("rank", "rank"), + ("env_name", "env"), + ("export_sequence_idx", "seq"), + ): + value = record.get(key) + if value is not None: + label_parts.append(f"{label} {value}") + + subtitle_parts = [ + f"{trainable_count}/{token_count} trainable", + str(record.get("_source_file", "")), + f"line {record.get('_source_line')}", + ] + meta = { + "ordinal": ordinal, + "label": " | ".join(label_parts), + "subtitle": " | ".join(part for part in subtitle_parts if part), + "step": record.get("step"), + "rank": record.get("rank"), + "micro_step": record.get("micro_step"), + "micro_sequence_idx": record.get("micro_sequence_idx"), + "export_sequence_idx": record.get("export_sequence_idx"), + "env_name": record.get("env_name"), + "source_file": record.get("_source_file"), + "source_line": record.get("_source_line"), + "token_count": token_count, + "trainable_token_count": trainable_count, + } + meta["search_text"] = " ".join(str(value) for value in meta.values() if value is not None).lower() + return meta + + +def _json_script(value: Any) -> str: + payload = json.dumps(_json_safe(value), ensure_ascii=False, allow_nan=False, separators=(",", ":")) + return payload.replace("&", "\\u0026").replace("<", "\\u003c").replace(">", "\\u003e") + + +def 
_json_safe(value: Any) -> Any: + if isinstance(value, float): + return value if math.isfinite(value) else None + if isinstance(value, dict): + return {str(key): _json_safe(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_json_safe(item) for item in value] + return value + + +def _chat_segments(tokens: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not tokens or tokens[0].get("text") is None: + return [{"role": "tokens", "tokens": tokens}] + + segments = [] + idx = 0 + pending = [] + while idx < len(tokens): + if tokens[idx].get("text") != "<|im_start|>": + pending.append(tokens[idx]) + idx += 1 + continue + + if pending: + segments.append({"role": "tokens", "tokens": pending}) + pending = [] + + role = "message" + idx += 1 + if idx < len(tokens): + role = str(tokens[idx].get("text", role)).strip() or role + idx += 1 + if idx < len(tokens) and tokens[idx].get("text") == "\n": + idx += 1 + + message_tokens = [] + while idx < len(tokens) and tokens[idx].get("text") != "<|im_end|>": + message_tokens.append(tokens[idx]) + idx += 1 + if idx < len(tokens) and tokens[idx].get("text") == "<|im_end|>": + idx += 1 + if idx < len(tokens) and tokens[idx].get("text") == "\n": + idx += 1 + segments.append({"role": role, "tokens": message_tokens}) + + if pending: + segments.append({"role": "tokens", "tokens": pending}) + + return [segment for segment in segments if segment["tokens"]] + + +def _mismatch_scale(tokens: list[dict[str, Any]], override: float | None) -> float: + if override is not None: + if override <= 0: + raise ValueError("--max-mismatch must be positive.") + return override + values = [_finite_float(token.get("mismatch_kl")) for token in tokens if token.get("loss_mask")] + values = [value for value in values if value is not None and value > 0] + if not values: + return 1.0 + return max(_percentile(values, 0.95), 0.05) + + +def _percentile(values: list[float], q: float) -> float: + ordered = sorted(values) + idx = 
min(max(round((len(ordered) - 1) * q), 0), len(ordered) - 1) + return ordered[idx] + + +def _finite_float(value: Any) -> float | None: + if value is None: + return None + value = float(value) + if not math.isfinite(value): + return None + return value + + +if __name__ == "__main__": + main() diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 1eddac38ac..9877be8ccb 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -60,6 +60,25 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`. In TOML, an empty section header (`[ckpt]`) does the same. +## RL trainer token exports + +For rollout visualization/debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes JSONL records with per-token ids, loss mask, advantage, reward, entropy, mismatch KL, inference logprob/prob, trainer logprob/prob, and masking diagnostics. It does not decode token text in the trainer. + +```toml +[trainer.experimental.token_export] +# Optional. Relative paths resolve under trainer.output_dir. +path = "token_exports/sample.jsonl" +``` + +Leave it unset for normal training. When enabled, it exports every sequence from each exporting rank. + +```bash +uv run scripts/token_export_visualizer.py outputs/token_exports/rank_0.jsonl --tokenizer Qwen/Qwen3-0.6B -o /tmp/token_export.html +uv run scripts/token_export_visualizer.py outputs/token_exports --tokenizer Qwen/Qwen3-0.6B -o outputs/token_exports/index.html +``` + +Directory inputs render every exported sequence as one navigable HTML page. For a single JSONL file, pass `--all-records` to embed every matching record instead of just one `--record-index`. 
+ ## Key files - `packages/prime-rl-configs/src/prime_rl/` — config classes under `configs/`; `utils/config.py` re-exports `BaseConfig` and `cli` diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index 9f0a923b89..9db4aefd74 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -12,6 +12,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch loss_mask = training_example.prompt_mask + training_example.completion_mask inference_logprobs = [0.0] * len(training_example.prompt_ids) + training_example.completion_logprobs advantages = [training_example.advantage] * len(input_ids) + reward = training_example.reward if training_example.reward is not None else float("nan") + rewards = [reward] * len(input_ids) position_ids = list(range(len(input_ids))) mm_token_type_ids = training_example.mm_token_type_ids assert training_example.env_name != "all", "env_name='all' is reserved for aggregate metric keys" @@ -33,6 +35,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch inference_logprobs = inference_logprobs[:seq_len] position_ids = position_ids[:seq_len] advantages = advantages[:seq_len] + rewards = rewards[:seq_len] temperatures = temperatures[:seq_len] if teacher_logprobs is not None: teacher_logprobs = teacher_logprobs[:seq_len] @@ -48,9 +51,10 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch == len(loss_mask) == len(position_ids) == len(inference_logprobs) + == len(rewards) == len(temperatures) ), ( - f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, temperatures: {len(temperatures)}" + f"input_ids: {len(input_ids)}, advantages: {len(advantages)}, loss_mask: {len(loss_mask)}, position_ids: {len(position_ids)}, inference_logprobs: {len(inference_logprobs)}, rewards: {len(rewards)}, temperatures: 
{len(temperatures)}" ) if teacher_logprobs is not None: assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}" @@ -74,6 +78,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch inference_logprobs=inference_logprobs, teacher_logprobs=teacher_logprobs, temperatures=temperatures, + rewards=rewards, routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, env_names=env_names, @@ -122,9 +127,16 @@ def packed_samples_into_micro_bs( len(bin_content.input_ids) + len(sample.input_ids) <= max_seq_len and bin_content.training_mode == sample.training_mode ): + existing_len = len(bin_content.input_ids) bin_content.input_ids.extend(sample.input_ids) bin_content.loss_mask.extend(sample.loss_mask) bin_content.advantages.extend(sample.advantages) + if sample.rewards is not None: + if bin_content.rewards is None: + bin_content.rewards = [float("nan")] * existing_len + bin_content.rewards.extend(sample.rewards) + elif bin_content.rewards is not None: + bin_content.rewards.extend([float("nan")] * len(sample.input_ids)) bin_content.inference_logprobs.extend(sample.inference_logprobs) bin_content.temperatures.extend(sample.temperatures) if sample.teacher_logprobs is not None: @@ -175,6 +187,8 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa micro_batch.input_ids.extend([1] * padding_size) micro_batch.advantages.extend([0.0] * padding_size) + if micro_batch.rewards is not None: + micro_batch.rewards.extend([float("nan")] * padding_size) micro_batch.loss_mask.extend([False] * padding_size) micro_batch.position_ids.extend(list(range(padding_size))) micro_batch.inference_logprobs.extend([0.0] * padding_size) diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index acc6b23384..73e35159af 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -20,6 +20,7 @@ class TensorMicroBatch(TypedDict): input_ids: Int[Tensor, 
"batch seq"] position_ids: Int[Tensor, "batch seq"] advantages: Float[Tensor, "batch seq"] + rewards: Float[Tensor, "batch seq"] | None inference_logprobs: Float[Tensor, "batch seq"] teacher_logprobs: Float[Tensor, "batch seq"] | None loss_mask: Bool[Tensor, "batch seq"] @@ -108,6 +109,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc "input_ids": input_ids.unsqueeze(0), "position_ids": position_ids.unsqueeze(0), "advantages": advantages.unsqueeze(0), + "rewards": None, "inference_logprobs": inference_logprobs.unsqueeze(0), "teacher_logprobs": None, "temperatures": torch.ones(input_ids.shape[0]).unsqueeze(0), @@ -135,6 +137,7 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: ), "position_ids": torch.cat([torch.arange(self.seq_len)]).unsqueeze(0), "advantages": torch.randn(self.seq_len, generator=generator).unsqueeze(0), + "rewards": None, "inference_logprobs": torch.randn(self.seq_len, generator=generator).unsqueeze(0), "teacher_logprobs": None, "temperatures": torch.ones(self.seq_len).unsqueeze(0), @@ -211,6 +214,9 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: input_ids=torch.tensor(micro_batch.input_ids, dtype=torch.long).unsqueeze(0), position_ids=torch.tensor(micro_batch.position_ids, dtype=torch.long).unsqueeze(0), advantages=torch.tensor(micro_batch.advantages, dtype=torch.float).unsqueeze(0), + rewards=torch.tensor(micro_batch.rewards, dtype=torch.float).unsqueeze(0) + if micro_batch.rewards is not None + else None, inference_logprobs=torch.tensor(micro_batch.inference_logprobs, dtype=torch.float).unsqueeze(0), teacher_logprobs=torch.tensor(micro_batch.teacher_logprobs, dtype=torch.float).unsqueeze(0) if micro_batch.teacher_logprobs is not None diff --git a/src/prime_rl/trainer/rl/token_export.py b/src/prime_rl/trainer/rl/token_export.py new file mode 100644 index 0000000000..480d6229e5 --- /dev/null +++ b/src/prime_rl/trainer/rl/token_export.py @@ -0,0 
class DisabledTokenExporter:
    """No-op exporter returned when token export is off for this rank."""

    def export(self, *args: Any, **kwargs: Any) -> None:
        return

    def close(self) -> None:
        return


class TokenExporter:
    """Append per-token rollout records for one trainer rank to a JSONL file."""

    def __init__(
        self,
        config: TokenExportConfig,
        output_dir: Path,
        rank: int,
    ) -> None:
        """Open the export file in append mode and register ``close`` with ``atexit``."""
        self.config = config
        self.rank = rank
        self.path = self._resolve_path(config.path, output_dir, rank)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._file = self.path.open("a", encoding="utf-8")
        self._closed = False
        self._current_step: int | None = None
        self._sequences_this_step = 0
        atexit.register(self.close)

    def export(
        self,
        step: int,
        micro_step: int,
        micro_batch: Mapping[str, Any],
        model_output: Mapping[str, Tensor],
        response_lengths: list[int],
        loss_config: Any,
    ) -> None:
        """Write one JSONL record per exported sequence of a micro batch.

        ``response_lengths`` gives the consecutive slice lengths used to cut
        the flattened micro batch into sequences. Sequences for which
        ``_build_record`` returns ``None`` are skipped. The per-step sequence
        counter restarts whenever ``step`` changes.
        """
        if self._current_step != step:
            self._current_step = step
            self._sequences_this_step = 0

        flat = self._flatten_micro_batch(
            micro_batch,
            model_output["logprobs"],
            model_output["entropy"],
            _compute_export_tensors(micro_batch, model_output["logprobs"], loss_config),
        )
        start = 0
        for micro_sequence_idx, length in enumerate(response_lengths):
            end = start + length
            record = self._build_record(
                step=step,
                micro_step=micro_step,
                micro_sequence_idx=micro_sequence_idx,
                flat=flat,
                start=start,
                end=end,
                training_mode=str(micro_batch["training_mode"]),
            )
            start = end
            if record is None:
                continue
            self._write(record)
            self._sequences_this_step += 1

        # Push completed records to the OS after each micro batch so a killed
        # trainer, or a reader running during training, does not see a
        # half-written final line left in the Python-level buffer.
        if not self._closed:
            self._file.flush()

    def close(self) -> None:
        """Close the export file; safe to call more than once."""
        if self._closed:
            return
        self._closed = True
        self._file.close()

    def _write(self, record: dict[str, Any]) -> None:
        """Append ``record`` as one compact JSON line.

        Raises:
            RuntimeError: if the exporter has already been closed.
        """
        if self._closed:
            raise RuntimeError(f"Token exporter is closed for {self.path}")
        self._file.write(json.dumps(record, separators=(",", ":"), allow_nan=False) + "\n")

    @staticmethod
    def _resolve_path(path: Path | None, output_dir: Path, rank: int) -> Path:
        """Return the file this rank writes to.

        With no configured path the file is
        ``output_dir/token_exports/rank_<rank>.jsonl``. A configured path is
        used as is when absolute, else resolved under ``output_dir``. Rank 0
        writes exactly there; every other rank gets ``.rank_<rank>`` inserted
        before the suffix. Without that, all exporting ranks would append to
        one file through independent buffered handles and interleave partial
        lines.
        """
        if path is None:
            return output_dir / "token_exports" / f"rank_{rank}.jsonl"
        resolved = path if path.is_absolute() else output_dir / path
        if rank == 0:
            return resolved
        return resolved.with_name(f"{resolved.stem}.rank_{rank}{resolved.suffix}")

    def _flatten_micro_batch(
        self,
        micro_batch: Mapping[str, Any],
        trainer_logprobs: Tensor,
        entropy: Tensor,
        export_tensors: Mapping[str, Tensor | None],
    ) -> dict[str, list[Any]]:
        """Convert every exported tensor into a flat Python list of equal length.

        Optional tensors that are ``None`` become lists of ``None`` of the
        same length as ``input_ids``.

        Raises:
            ValueError: if the resulting lists do not all have the same length.
        """
        input_ids = _tensor_to_ints(micro_batch["input_ids"])
        seq_len = len(input_ids)
        rewards_tensor = micro_batch.get("rewards")

        flat = {
            "input_ids": input_ids,
            "position_ids": _tensor_to_ints(micro_batch["position_ids"]),
            "loss_mask": _tensor_to_bools(micro_batch["loss_mask"]),
            "advantages": _tensor_to_floats(micro_batch["advantages"]),
            "rewards": _optional_tensor_to_floats(rewards_tensor, seq_len),
            "inference_logprobs": _tensor_to_floats(micro_batch["inference_logprobs"]),
            "trainer_logprobs": _tensor_to_floats(trainer_logprobs),
            "entropy": _tensor_to_floats(entropy),
            "mismatch_kl": _optional_tensor_to_floats(export_tensors["mismatch_kl"], seq_len),
            "log_importance_ratio": _optional_tensor_to_floats(export_tensors["log_importance_ratio"], seq_len),
            "importance_ratio": _optional_tensor_to_floats(export_tensors["importance_ratio"], seq_len),
            "prob_delta": _optional_tensor_to_floats(export_tensors["prob_delta"], seq_len),
            "is_masked": _optional_tensor_to_bools(export_tensors["is_masked"], seq_len),
            "is_masked_high": _optional_tensor_to_bools(export_tensors["is_masked_high"], seq_len),
            "is_masked_low": _optional_tensor_to_bools(export_tensors["is_masked_low"], seq_len),
            "env_names": list(micro_batch["env_names"]),
        }
        lengths = {key: len(values) for key, values in flat.items()}
        if len(set(lengths.values())) != 1:
            raise ValueError(f"Token export fields must have aligned lengths, got {lengths}")
        return flat

    def _build_record(
        self,
        *,
        step: int,
        micro_step: int,
        micro_sequence_idx: int,
        flat: dict[str, list[Any]],
        start: int,
        end: int,
        training_mode: str,
    ) -> dict[str, Any] | None:
        """Build the JSON record for ``flat[start:end]``.

        Returns ``None`` when the slice is empty after trailing padding is
        trimmed, or when none of its tokens has ``loss_mask`` set.
        """
        end = _trim_padding(flat, start, end)
        if start >= end:
            return None

        loss_mask = flat["loss_mask"][start:end]
        if not any(loss_mask):
            return None

        token_ids = flat["input_ids"][start:end]
        tokens = []
        for local_idx, absolute_idx in enumerate(range(start, end)):
            log_importance_ratio = _json_float(flat["log_importance_ratio"][absolute_idx])
            token = {
                "index": local_idx,
                "id": token_ids[local_idx],
                "position": flat["position_ids"][absolute_idx],
                "loss_mask": loss_mask[local_idx],
                "advantage": _json_float(flat["advantages"][absolute_idx]),
                "reward": _json_float(flat["rewards"][absolute_idx]),
                "entropy": _json_float(flat["entropy"][absolute_idx]),
                "mismatch_kl": _json_float(flat["mismatch_kl"][absolute_idx]),
                "log_importance_ratio": log_importance_ratio,
                "sample_kl_trainer_to_inference": log_importance_ratio,
                "sample_kl_inference_to_trainer": -log_importance_ratio if log_importance_ratio is not None else None,
                "importance_ratio": _json_float(flat["importance_ratio"][absolute_idx]),
                "prob_delta": _json_float(flat["prob_delta"][absolute_idx]),
                "inference_logprob": _json_float(flat["inference_logprobs"][absolute_idx]),
                "trainer_logprob": _json_float(flat["trainer_logprobs"][absolute_idx]),
                "is_masked": flat["is_masked"][absolute_idx],
                "is_masked_high": flat["is_masked_high"][absolute_idx],
                "is_masked_low": flat["is_masked_low"][absolute_idx],
            }
            token["inference_prob"] = _prob_from_logprob(token["inference_logprob"])
            token["trainer_prob"] = _prob_from_logprob(token["trainer_logprob"])
            tokens.append(token)

        env_name = _first_non_empty(flat["env_names"][start:end])
        return {
            "schema_version": SCHEMA_VERSION,
            "step": step,
            "rank": self.rank,
            "micro_step": micro_step,
            "micro_sequence_idx": micro_sequence_idx,
            "export_sequence_idx": self._sequences_this_step,
            "env_name": env_name,
            "training_mode": training_mode,
            "tokens": tokens,
        }


def setup_token_exporter(
    config: TrainerConfig, parallel_dims: Any, world: Any, logger: Any
) -> TokenExporter | DisabledTokenExporter:
    """Create the exporter for this rank.

    Returns a ``DisabledTokenExporter`` when ``experimental.token_export`` is
    unset, or when context parallelism is enabled and this rank's local CP
    rank is not 0. Otherwise opens a ``TokenExporter`` and logs its path.
    """
    token_export_config = config.experimental.token_export
    if token_export_config is None:
        return DisabledTokenExporter()
    if parallel_dims.cp_enabled and parallel_dims.world_mesh["cp"].get_local_rank() != 0:
        return DisabledTokenExporter()

    exporter = TokenExporter(token_export_config, config.output_dir, world.rank)
    logger.info(f"Writing token exports to {exporter.path}")
    return exporter
negative_advantages = advantages < 0 + invalid = torch.where(positive_advantages, invalid_high, invalid_low) + fields["is_masked"] = loss_mask & invalid + fields["is_masked_high"] = loss_mask & positive_advantages & invalid_high + fields["is_masked_low"] = loss_mask & negative_advantages & invalid_low + return fields + + +def _tensor_to_ints(tensor: Tensor) -> list[int]: + return [int(value) for value in tensor.detach().cpu().reshape(-1).tolist()] + + +def _tensor_to_bools(tensor: Tensor) -> list[bool]: + return [bool(value) for value in tensor.detach().cpu().reshape(-1).tolist()] + + +def _tensor_to_floats(tensor: Tensor) -> list[float]: + values = tensor.detach().to(dtype=torch.float32, device="cpu").reshape(-1).tolist() + return [float(value) for value in values] + + +def _optional_tensor_to_floats(tensor: Tensor | None, seq_len: int) -> list[float | None]: + if tensor is None: + return [None] * seq_len + return _tensor_to_floats(tensor) + + +def _optional_tensor_to_bools(tensor: Tensor | None, seq_len: int) -> list[bool | None]: + if tensor is None: + return [None] * seq_len + return _tensor_to_bools(tensor) + + +def _trim_padding(flat: dict[str, list[Any]], start: int, end: int) -> int: + env_names = flat["env_names"] + loss_mask = flat["loss_mask"] + while end > start and env_names[end - 1] == "" and not loss_mask[end - 1]: + end -= 1 + return end + + +def _json_float(value: float | None) -> float | None: + if value is None: + return None + value = float(value) + if not math.isfinite(value): + return None + return value + + +def _prob_from_logprob(logprob: float | None) -> float | None: + if logprob is None: + return None + if logprob > 709: + return None + return math.exp(logprob) + + +def _first_non_empty(values: list[str]) -> str | None: + for value in values: + if value: + return value + return None diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index e437efb9e2..d359bae5e7 100644 --- a/src/prime_rl/trainer/rl/train.py 
+++ b/src/prime_rl/trainer/rl/train.py @@ -35,6 +35,7 @@ shift_tensor_left, shift_tensor_right, ) +from prime_rl.trainer.rl.token_export import setup_token_exporter from prime_rl.trainer.model import ( forward, setup_tokenizer, @@ -239,6 +240,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: config.rollout_transport, ) + token_exporter = setup_token_exporter(config, parallel_dims, world, logger) + gc_handler = GarbageCollection(config.gc.interval) if config.gc else None logger.info(f"Starting training loop (max_steps={config.max_steps or 'infinite'})") @@ -500,6 +503,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: for env_name, indices in env_to_indices.items(): tensors[f"mismatch_kl/{env_name}"].append(mismatch_kl[indices]) + token_exporter.export(progress.step, micro_step, micro_batch, out, response_lengths, config.loss) + if is_tt_moe_model(model): load_balance_stats = get_load_balance_stats(model) for k, v in load_balance_stats.items(): @@ -678,6 +683,8 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: prof.export_chrome_trace(trace_file) logger.info(f"Saved trace to {trace_file}") + token_exporter.close() + # Write final checkpoint (only for single-run mode; multi-run checkpoints are managed by MultiCheckpointManager) if config.max_concurrent_runs == 1 and ckpt_manager is not None: if not (config.ckpt and config.ckpt.weights_only): diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index de9b246eb2..d4c947224f 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -80,3 +80,4 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): # Loss dispatch is batch-driven (rl/opd → default loss with mode-specific taus, # sft → sft loss). All samples packed into a micro batch share the same mode. training_mode: TrainingMode = "rl" + rewards: list[float] | None = None