diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 1507f96079..867f19dcef 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1,5 +1,6 @@ import math import warnings +from collections.abc import Mapping from pathlib import Path from typing import Annotated, Any, Literal, TypeAlias @@ -101,6 +102,15 @@ def _deprecate_max_tokens(cls, data: Any) -> Any: return data +def _chat_template_kwargs_from_extra_body(extra_body: Mapping[str, Any], label: str) -> dict[str, Any]: + raw = extra_body.get("chat_template_kwargs", {}) + if raw is None: + return {} + if not isinstance(raw, Mapping): + raise ValueError(f"{label}.extra_body.chat_template_kwargs must be a table.") + return dict(raw) + + class EvalSamplingConfig(BaseConfig): temperature: float | None = Field(None, ge=0) """Sampling temperature. None defers to the inference server default.""" @@ -958,3 +968,22 @@ def resolve_env_config(self): env.sampling.extra_body.setdefault("min_p", 0.0) env.sampling.extra_body.setdefault("return_token_ids", True) return self + + @model_validator(mode="after") + def validate_renderer_chat_template_kwargs(self): + if not self.use_renderer: + return self + + shared_kwargs = _chat_template_kwargs_from_extra_body(self.train.sampling.extra_body, "train.sampling") + for idx, env in enumerate(self.train.env): + env_kwargs = _chat_template_kwargs_from_extra_body( + env.sampling.extra_body, + f"train.env[{idx}].sampling", + ) + if env_kwargs != shared_kwargs: + raise ValueError( + "Renderer chat_template_kwargs must be shared across train envs. " + "Set orchestrator.train.sampling.extra_body.chat_template_kwargs " + "instead of per-env overrides." + ) + return self diff --git a/skills/configs/SKILL.md b/skills/configs/SKILL.md index 83f7dd8d47..46c0381ec1 100644 --- a/skills/configs/SKILL.md +++ b/skills/configs/SKILL.md @@ -60,6 +60,19 @@ CLI: `--env.0.id reverse-text --env.1.id math-env`. In TOML, an empty section header (`[ckpt]`) does the same. +## Renderer chat template kwargs + +For renderer-backed RL, configure instance-wide chat-template toggles under shared train sampling: + +```toml +[orchestrator.train.sampling.extra_body.chat_template_kwargs] +enable_thinking = false +``` + +This is the route the renderer rollout client consumes. Do not set different `chat_template_kwargs` per train env when `orchestrator.use_renderer=true`; one local renderer is shared for token reconstruction, so values must be shared across train envs. + +For SFT with `use_renderer=true`, per-example `chat_template_kwargs` are ignored by renderer-backed tokenization. + ## RL trainer token exports For rollout debugging, enable trainer-side token export under `trainer.experimental.token_export` (or `experimental.token_export` when running the trainer entrypoint directly). It writes one JSONL record per exported sequence under `output_dir/token_exports/step_/rank_.jsonl`. Each record stores aligned per-token arrays for token ids, loss mask, advantage, reward, entropy, mismatch KL, inference/trainer logprobs, importance ratios, probability deltas, and masking diagnostics. It does not decode token text in the trainer. diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 5e5932ef58..e342aa37c1 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -887,11 +887,13 @@ async def setup_student_inference_pool( model_name = config.student.model.name if config.use_renderer: + chat_template_kwargs = dict(config.train.sampling.extra_body.get("chat_template_kwargs") or {}) renderer = create_renderer( tokenizer, renderer=config.renderer.name, tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, + chat_template_kwargs=chat_template_kwargs, preserve_all_thinking=config.renderer.preserve_all_thinking, preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) diff --git a/src/prime_rl/trainer/sft/data.py b/src/prime_rl/trainer/sft/data.py index 253acadaaa..2d4323ed7d 100644 --- a/src/prime_rl/trainer/sft/data.py +++ b/src/prime_rl/trainer/sft/data.py @@ -237,9 +237,8 @@ def should_mask(message: dict) -> bool: if example.get("chat_template_kwargs") and not self._warned_chat_template_kwargs: self.logger.warning( "Example carries chat_template_kwargs but use_renderer=True; " - "renderers don't forward chat_template_kwargs (model-specific " - "renderers bake their template behavior in). These kwargs will " - "be ignored. Further warnings suppressed for this dataset." + "per-example kwargs are ignored by renderer-backed tokenization. " + "Further warnings suppressed for this dataset." ) self._warned_chat_template_kwargs = True diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index beb41e8ab6..0d8ca52327 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -185,9 +185,8 @@ def setup_clients( ) -> list[vf.ClientConfig]: clients = [] client_idx = 0 - # Only forward preserve flags when the client actually uses a renderer — - # MITO/TITO clients ignore them and the verifiers ClientConfig may reject - # unknown extras on older versions. + # Only forward preserve flags when the client actually uses a renderer. + # MITO clients ignore them. renderer_extra: dict = {} if client_type == "renderer": renderer_extra = { diff --git a/tests/unit/orchestrator/test_orchestrator_setup.py b/tests/unit/orchestrator/test_orchestrator_setup.py index 80856549b0..aee891a542 100644 --- a/tests/unit/orchestrator/test_orchestrator_setup.py +++ b/tests/unit/orchestrator/test_orchestrator_setup.py @@ -16,13 +16,16 @@ async def run() -> None: model=SimpleNamespace(name="student-model"), ), renderer=SimpleNamespace( - name="qwen3_vl", + name="qwen3", tool_parser=None, reasoning_parser=None, pool_size=None, preserve_all_thinking=False, preserve_thinking_between_tool_calls=False, ), + train=SimpleNamespace( + sampling=SimpleNamespace(extra_body={"chat_template_kwargs": {"enable_thinking": False}}) + ), ) logger = MagicMock() renderer = object() @@ -45,9 +48,10 @@ async def run() -> None: assert returned_pool is inference_pool create_renderer_mock.assert_called_once_with( tokenizer, - renderer="qwen3_vl", + renderer="qwen3", tool_parser=None, reasoning_parser=None, + chat_template_kwargs={"enable_thinking": False}, preserve_all_thinking=False, preserve_thinking_between_tool_calls=False, ) @@ -56,7 +60,7 @@ async def run() -> None: model_name="student-model", train_client_type="renderer", eval_client_type="openai_chat_completions", - renderer_name="qwen3_vl", + renderer_name="qwen3", tool_parser=None, reasoning_parser=None, renderer_pool_size=None, diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index 66ce195bc6..f7411e027e 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -469,6 +469,28 @@ def test_orchestrator_explicit_renderer_skips_unmapped_check(): assert config.renderer.name == "qwen3" +def test_orchestrator_accepts_sampling_chat_template_kwargs(): + config = OrchestratorConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "train": {"sampling": {"extra_body": {"chat_template_kwargs": {"enable_thinking": False}}}}, + } + ) + + assert config.train.sampling.extra_body["chat_template_kwargs"] == {"enable_thinking": False} + assert config.train.env[0].sampling.extra_body["chat_template_kwargs"] == {"enable_thinking": False} + + +def test_orchestrator_rejects_per_env_chat_template_kwargs_override(): + with pytest.raises(ValidationError, match="must be shared across train envs"): + OrchestratorConfig.model_validate( + { + "model": {"name": "Qwen/Qwen3-0.6B"}, + "train": {"env": [{"sampling": {"extra_body": {"chat_template_kwargs": {"enable_thinking": False}}}}]}, + } + ) + + def test_orchestrator_use_renderer_false_skips_unmapped_check(): """use_renderer=False means the renderer client isn't used, so MODEL_RENDERER_MAP doesn't apply.""" config = OrchestratorConfig.model_validate(