Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import Annotated, Any, Literal, TypeAlias

Expand Down Expand Up @@ -101,6 +102,15 @@ def _deprecate_max_tokens(cls, data: Any) -> Any:
return data


def _chat_template_kwargs_from_extra_body(extra_body: Mapping[str, Any], label: str) -> dict[str, Any]:
raw = extra_body.get("chat_template_kwargs", {})
if raw is None:
return {}
if not isinstance(raw, Mapping):
raise ValueError(f"{label}.extra_body.chat_template_kwargs must be a table.")
return dict(raw)


class EvalSamplingConfig(BaseConfig):
temperature: float | None = Field(None, ge=0)
"""Sampling temperature. None defers to the inference server default."""
Expand Down Expand Up @@ -958,3 +968,22 @@ def resolve_env_config(self):
env.sampling.extra_body.setdefault("min_p", 0.0)
env.sampling.extra_body.setdefault("return_token_ids", True)
return self

@model_validator(mode="after")
def validate_renderer_chat_template_kwargs(self):
if not self.use_renderer:
return self

shared_kwargs = _chat_template_kwargs_from_extra_body(self.train.sampling.extra_body, "train.sampling")
for idx, env in enumerate(self.train.env):
env_kwargs = _chat_template_kwargs_from_extra_body(
env.sampling.extra_body,
f"train.env[{idx}].sampling",
)
if env_kwargs != shared_kwargs:
raise ValueError(
"Renderer chat_template_kwargs must be shared across train envs. "
"Set orchestrator.train.sampling.extra_body.chat_template_kwargs "
"instead of per-env overrides."
)
return self
13 changes: 13 additions & 0 deletions skills/configs/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`.

In TOML, an empty section header (`[ckpt]`) does the same.

## Renderer chat template kwargs

For renderer-backed RL, configure instance-wide chat-template toggles under shared train sampling:

```toml
[orchestrator.train.sampling.extra_body.chat_template_kwargs]
enable_thinking = false
```

This is the route the renderer rollout client consumes. Do not set different `chat_template_kwargs` per train env when `orchestrator.use_renderer=true`; one local renderer is shared for token reconstruction, so values must be shared across train envs.

For SFT with `use_renderer=true`, per-example `chat_template_kwargs` are ignored by renderer-backed tokenization.

## RL trainer token exports

For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_<step>/rank_<rank>.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer.
Expand Down
2 changes: 2 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,11 +887,13 @@ async def setup_student_inference_pool(
model_name = config.student.model.name

if config.use_renderer:
chat_template_kwargs = dict(config.train.sampling.extra_body.get("chat_template_kwargs") or {})
renderer = create_renderer(
tokenizer,
renderer=config.renderer.name,
tool_parser=config.renderer.tool_parser,
reasoning_parser=config.renderer.reasoning_parser,
chat_template_kwargs=chat_template_kwargs,
preserve_all_thinking=config.renderer.preserve_all_thinking,
preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls,
)
Expand Down
5 changes: 2 additions & 3 deletions src/prime_rl/trainer/sft/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,8 @@ def should_mask(message: dict) -> bool:
if example.get("chat_template_kwargs") and not self._warned_chat_template_kwargs:
self.logger.warning(
"Example carries chat_template_kwargs but use_renderer=True; "
"renderers don't forward chat_template_kwargs (model-specific "
"renderers bake their template behavior in). These kwargs will "
"be ignored. Further warnings suppressed for this dataset."
"per-example kwargs are ignored by renderer-backed tokenization. "
"Further warnings suppressed for this dataset."
)
self._warned_chat_template_kwargs = True

Expand Down
5 changes: 2 additions & 3 deletions src/prime_rl/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ def setup_clients(
) -> list[vf.ClientConfig]:
clients = []
client_idx = 0
# Only forward preserve flags when the client actually uses a renderer —
# MITO/TITO clients ignore them and the verifiers ClientConfig may reject
# unknown extras on older versions.
# Only forward preserve flags when the client actually uses a renderer.
# MITO clients ignore them.
renderer_extra: dict = {}
if client_type == "renderer":
renderer_extra = {
Expand Down
10 changes: 7 additions & 3 deletions tests/unit/orchestrator/test_orchestrator_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ async def run() -> None:
model=SimpleNamespace(name="student-model"),
),
renderer=SimpleNamespace(
name="qwen3_vl",
name="qwen3",
tool_parser=None,
reasoning_parser=None,
pool_size=None,
preserve_all_thinking=False,
preserve_thinking_between_tool_calls=False,
),
train=SimpleNamespace(
sampling=SimpleNamespace(extra_body={"chat_template_kwargs": {"enable_thinking": False}})
),
)
logger = MagicMock()
renderer = object()
Expand All @@ -45,9 +48,10 @@ async def run() -> None:
assert returned_pool is inference_pool
create_renderer_mock.assert_called_once_with(
tokenizer,
renderer="qwen3_vl",
renderer="qwen3",
tool_parser=None,
reasoning_parser=None,
chat_template_kwargs={"enable_thinking": False},
preserve_all_thinking=False,
preserve_thinking_between_tool_calls=False,
)
Expand All @@ -56,7 +60,7 @@ async def run() -> None:
model_name="student-model",
train_client_type="renderer",
eval_client_type="openai_chat_completions",
renderer_name="qwen3_vl",
renderer_name="qwen3",
tool_parser=None,
reasoning_parser=None,
renderer_pool_size=None,
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,28 @@ def test_orchestrator_explicit_renderer_skips_unmapped_check():
assert config.renderer.name == "qwen3"


def test_orchestrator_accepts_sampling_chat_template_kwargs():
config = OrchestratorConfig.model_validate(
{
"model": {"name": "Qwen/Qwen3-0.6B"},
"train": {"sampling": {"extra_body": {"chat_template_kwargs": {"enable_thinking": False}}}},
}
)

assert config.train.sampling.extra_body["chat_template_kwargs"] == {"enable_thinking": False}
assert config.train.env[0].sampling.extra_body["chat_template_kwargs"] == {"enable_thinking": False}


def test_orchestrator_rejects_per_env_chat_template_kwargs_override():
with pytest.raises(ValidationError, match="must be shared across train envs"):
OrchestratorConfig.model_validate(
{
"model": {"name": "Qwen/Qwen3-0.6B"},
"train": {"env": [{"sampling": {"extra_body": {"chat_template_kwargs": {"enable_thinking": False}}}}]},
}
)


def test_orchestrator_use_renderer_false_skips_unmapped_check():
"""use_renderer=False means the renderer client isn't used, so MODEL_RENDERER_MAP doesn't apply."""
config = OrchestratorConfig.model_validate(
Expand Down
Loading