diff --git a/.gitignore b/.gitignore index efbf202..cca70b8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,6 @@ coverage.xml .idea/ .vscode/ *.swp + +# agent harness state +.claude/ diff --git a/README.md b/README.md index 51e4d19..b2c3f2f 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ from transformers import AutoTokenizer from renderers import create_renderer tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") -r = create_renderer(tok, renderer="auto") # → Qwen3Renderer +r = create_renderer(tok) # → Qwen3Renderer (auto-resolved) prompt_ids = r.render_ids( [{"role": "user", "content": "hi"}], @@ -71,17 +71,17 @@ Each hand-coded bridge: ### Picking a renderer ```python -r = create_renderer(tok, renderer="auto") +r = create_renderer(tok) # AutoRendererConfig is the implicit default ``` -Auto-detect matches `tokenizer.name_or_path` against `MODEL_RENDERER_MAP` by **exact match**. Prefix matching is intentionally off — same architecture can ship different chat templates (base vs instruct, fine-tune renames). Fine-tunes must pass `renderer=` explicitly; unknown names fall back to `DefaultRenderer`. +Auto-detect matches `tokenizer.name_or_path` against `MODEL_RENDERER_MAP` by **exact match**. Prefix matching is intentionally off — same architecture can ship different chat templates (base vs instruct, fine-tune renames). Fine-tunes must pass an explicit typed config (e.g. `Qwen3RendererConfig()`); unknown names fall back to `DefaultRenderer`. ### Pools ```python from renderers import create_renderer_pool -pool = create_renderer_pool("Qwen/Qwen3-8B", renderer="auto", size=16) +pool = create_renderer_pool("Qwen/Qwen3-8B", size=16) with pool.checkout() as r: ids = r.render_ids(messages) ``` @@ -108,25 +108,50 @@ Empirical delta on Qwen3.5-35B-A3B + mini-swe-agent-plus, step 0: Each break fragments a rollout into multiple training samples — every fragment re-encodes its prefix, inflating compute roughly linearly with the number of breaks. 
-## Compaction overrides +## Typed renderer configs -`create_renderer` and `create_renderer_pool` accept two constructor-only flags: +Each renderer accepts a typed pydantic config that pins its template-control kwargs at construction. `create_renderer` and `create_renderer_pool` take one positional `config` argument: ```python -preserve_all_thinking: bool = False -preserve_thinking_between_tool_calls: bool = False +from renderers import ( + create_renderer, + AutoRendererConfig, + Qwen3RendererConfig, + GLM5RendererConfig, + DefaultRendererConfig, +) + +# Auto-resolve renderer from the tokenizer's model name. Carries the +# shared preserve_* flags; template kwargs require an explicit choice. +renderer = create_renderer(tokenizer) +renderer = create_renderer(tokenizer, AutoRendererConfig(preserve_all_thinking=True)) + +# Explicit choice — the typed config exposes exactly the fields that +# renderer's chat template honours. +renderer = create_renderer(tokenizer, Qwen3RendererConfig(enable_thinking=False)) +renderer = create_renderer(tokenizer, GLM5RendererConfig(clear_thinking=False)) + +# Default renderer (apply_chat_template fallback) — extra fields are +# captured via pydantic ``extra="allow"`` and forwarded to the Jinja +# template; tool / reasoning parsers are typed. +renderer = create_renderer( + tokenizer, + DefaultRendererConfig(tool_parser="qwen3", reasoning_parser="think"), +) ``` -Defaults preserve byte-identity with the model's chat template. Flipping a flag at construction restores `reasoning_content` the template would otherwise drop: +Discriminated union: every per-renderer config is a variant of `RendererConfig`, dispatched on the `name` field. Bogus combinations (e.g. `add_vision_id` under `name="qwen3"`) error at construction with a `pydantic.ValidationError`. Downstream pydantic configs (prime-rl orchestrator, verifiers `ClientConfig`) hold a single field typed as `RendererConfig` and inherit the same strict-per-variant validation. 
+ +Two shared behaviour flags live on every variant via `BaseRendererConfig`: -- `preserve_all_thinking=True` — every past assistant's reasoning is kept. -- `preserve_thinking_between_tool_calls=True` — reasoning is kept on assistants in the in-flight tool cycle (no-op for current renderers; reserved for future templates that drop it). +- `preserve_all_thinking=True` — every past assistant's `reasoning_content` is kept, even when the chat template would drop it. +- `preserve_thinking_between_tool_calls=True` — reasoning is kept on assistants in the in-flight tool cycle (post-last-user A-T-…-A block when it contains a tool response). A new user turn closes the block and drops its thinking. -The canonical use case is **compaction**. Injecting a `user` turn like *"summarize the work so far"* puts every prior assistant in a "past cycle", so template-default rules drop their `reasoning_content` before the summarizer sees it. Build the renderer with `preserve_all_thinking=True` to keep reasoning visible end-to-end on those flows. Both flags only ever *add* tokens vs the template default. +These OR-compose with template-level toggles (e.g. GLM-5 `clear_thinking`, Nemotron-3 `truncate_history_thinking`): either flag saying "keep" wins. `preserve_*` can only ever *extend* retention — never override a template kwarg into a "drop" decision. The canonical use case is **compaction**: injecting a `user` turn like *"summarize the work so far"* puts every prior assistant in a past cycle, and `preserve_all_thinking=True` keeps reasoning visible end-to-end. ## `DefaultRenderer` -Fallback for unsupported models. Wraps `apply_chat_template` and accepts `tool_parser` / `reasoning_parser` kwargs (vLLM convention). `bridge_to_next_turn` returns `None` because the template's close is unknown, so multi-turn rollouts fall back to full re-render. 
Implementing a hand-coded renderer is a few hundred lines of Python (`render_ids` + `parse_response` + `bridge_to_next_turn`) and is the only path that closes the failure modes above by construction. +Fallback for unsupported models. Wraps `apply_chat_template` and accepts `tool_parser` / `reasoning_parser` (vLLM convention) plus arbitrary Jinja kwargs via `DefaultRendererConfig`'s `extra="allow"`. `bridge_to_next_turn` returns `None` because the template's close is unknown, so multi-turn rollouts fall back to full re-render. Implementing a hand-coded renderer is a few hundred lines of Python (`render_ids` + `parse_response` + `bridge_to_next_turn`) and is the only path that closes the failure modes above by construction. ## Roadmap diff --git a/docs/renderer-config.md b/docs/renderer-config.md new file mode 100644 index 0000000..3d5fa2a --- /dev/null +++ b/docs/renderer-config.md @@ -0,0 +1,163 @@ +# Renderer config + +`renderers.RendererConfig` is the typed input to `create_renderer` and +`create_renderer_pool`. It pins the renderer choice and its template-control +kwargs at construction. + +```python +from renderers import create_renderer, Qwen35RendererConfig + +r = create_renderer(tokenizer, Qwen35RendererConfig(enable_thinking=False)) +``` + +`RendererConfig` is a pydantic discriminated union (one variant per renderer, +dispatched on the `name` field). Selecting a variant exposes exactly the +fields that renderer's chat template honours; anything else raises a +`pydantic.ValidationError` at construction. + +## Per-renderer configs + +Each hand-coded renderer has a typed config class with the template kwargs +its Jinja chat template reads. 
For example: + +| Renderer | Config class | Template fields | +|----------------|--------------------------|----------------------------------------------------------------| +| Qwen3 | `Qwen3RendererConfig` | `enable_thinking` | +| Qwen3.5 / 3.6 | `Qwen35RendererConfig` | `enable_thinking`, `add_vision_id` | +| Qwen3-VL | `Qwen3VLRendererConfig` | `add_vision_id` | +| GLM-5 / 5.1 | `GLM5RendererConfig` | `enable_thinking`, `clear_thinking` | +| GLM-4.5 | `GLM45RendererConfig` | `enable_thinking` | +| Nemotron-3 | `Nemotron3RendererConfig`| `enable_thinking`, `truncate_history_thinking` | +| Kimi K2.5 | `KimiK25RendererConfig` | `thinking` | +| MiniMax-M2 | `MiniMaxM2RendererConfig`| `model_identity` | +| Laguna-XS.2 | `LagunaXS2RendererConfig`| `enable_thinking`, `render_assistant_messages_raw` | +| gpt-oss | `GptOssRendererConfig` | `reasoning_effort`, `conversation_start_date` | + +Field names mirror the upstream Jinja variable names. Passing +`Qwen3RendererConfig(add_vision_id=True)` raises — Qwen3 is text-only, so +the field doesn't exist on its config. Use +`type(config).template_field_names()` to introspect the fields that mirror +chat-template kwargs (parity is verified against `apply_chat_template` in +`tests/test_renderer_config_parity.py`). + +Configs are frozen. To override a field, construct a new instance or call +`config.model_copy(update={...})`. + +## Auto-resolution + +`create_renderer(tokenizer)` (no config) resolves the renderer from +`tokenizer.name_or_path` via `MODEL_RENDERER_MAP`: + +```python +r = create_renderer(tokenizer) # AutoRendererConfig() is the default +r = create_renderer(tokenizer, AutoRendererConfig(preserve_all_thinking=True)) +``` + +`AutoRendererConfig` carries only the shared `preserve_*` flags. 
Template +kwargs depend on the renderer, so overriding them requires naming the +renderer explicitly: + +```python +r = create_renderer(tokenizer, GLM5RendererConfig(clear_thinking=False)) +``` + +Auto-resolution fails loudly for VLMs that miss the exact-match lookup — +`DefaultRenderer` only knows `apply_chat_template` + text tokens, so silently +falling back for a VLM would produce token streams the trainer can't +reconstruct. Text-only fine-tunes without a registered renderer fall back to +`DefaultRenderer` and log the choice at INFO — unless a `preserve_*` flag is set on the auto config, in which case the fallback raises `NotImplementedError` (pass an explicit model-specific config instead). + +## `preserve_*` flags + +Every variant carries two renderer-agnostic flags on `BaseRendererConfig`: + +- `preserve_all_thinking: bool = False` — re-emit `reasoning_content` on + every past assistant turn, even when the chat template would drop it. +- `preserve_thinking_between_tool_calls: bool = False` — re-emit + `reasoning_content` only inside the in-flight tool cycle (the contiguous + A-T-…-A block after the most recent `user` message, when it contains at + least one `tool` response). A new user turn closes the block and drops + its thinking. + +These OR-compose with template-level toggles. GLM-5's `clear_thinking` and +Nemotron-3's `truncate_history_thinking` already gate past thinking; the +`preserve_*` flags add to that: + +| `clear_thinking` | `preserve_all_thinking` | past thinking? | +|------------------|-------------------------|----------------| +| `True` (default — drop) | `False` (default) | dropped | +| `True` | `True` | kept | +| `False` (keep) | `False` | kept | +| `False` | `True` | kept | + +`preserve_*` can only extend retention, never force a drop. The canonical +use case is **compaction**: injecting a `user` turn like *"summarize the work +so far"* puts every prior assistant in a past cycle, and +`preserve_all_thinking=True` keeps reasoning visible end-to-end. 
+ +## `DefaultRendererConfig` accepts arbitrary Jinja kwargs + +`DefaultRenderer` wraps `tokenizer.apply_chat_template` for any model that +doesn't have a hand-coded renderer. Its config sets `extra="allow"`: + +```python +from renderers import create_renderer, DefaultRendererConfig + +r = create_renderer( + tokenizer, + DefaultRendererConfig( + tool_parser="qwen3", # registered in renderers.parsers + reasoning_parser="think", + enable_thinking=False, # forwarded to apply_chat_template + custom_jinja_kwarg=True, # ditto + ), +) +``` + +`tool_parser` and `reasoning_parser` are typed because they configure +`DefaultRenderer`'s own parsing pipeline. Every other field lands in +`model_extra` and `DefaultRenderer._apply` forwards `model_extra` verbatim +to `apply_chat_template`. + +## Downstream integration + +Downstream pydantic configs (`prime-rl` orchestrator, `verifiers` +`ClientConfig`) hold a single field typed as `RendererConfig`: + +```python +from pydantic import BaseModel, Field +from renderers import AutoRendererConfig, RendererConfig + +class ClientConfig(BaseModel): + renderer: RendererConfig = Field(default_factory=AutoRendererConfig) +``` + +In TOML / YAML, the discriminator routes deserialization: + +```toml +[client.renderer] +name = "qwen3.5" +enable_thinking = false +add_vision_id = true +preserve_all_thinking = true +``` + +Pydantic dispatches on `name = "qwen3.5"` to `Qwen35RendererConfig`. Bogus +combinations (e.g. `add_vision_id` under `name = "qwen3"`) raise at +config-load with a clear message naming the offending field and the variant +that rejected it. + +To construct a config from a renderer name string (e.g. from a CLI flag): + +```python +from renderers import config_from_name + +cfg = config_from_name("glm-5") # → GLM5RendererConfig() with defaults +cfg = config_from_name("auto") # → None, the implicit "auto" form +``` + +## Renaming a renderer is a breaking change + +The discriminator key is the renderer name string. 
Renaming `"qwen3.5"` to +something else would break any downstream config that references it by +name. Add new renderers; don't rename existing ones. diff --git a/pyproject.toml b/pyproject.toml index 87f99d4..8016457 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,11 @@ dependencies = [ # around ``from_pretrained``, so subsequent ``AutoTokenizer`` calls # outside the renderers package stay vanilla. "fastokens>=0.2.0", + # ``BaseRendererConfig`` inherits from ``pydantic_config.BaseConfig`` so + # the typed-config surface stays uniform with prime-rl / verifiers config + # bases. Transitively brings pydantic, which ``renderers.configs`` also + # imports directly. + "prime-pydantic-config>=0.3.0.dev83", ] [tool.hatch.version] @@ -73,7 +78,7 @@ exclude-newer = "7 days" # MiniMax-M2's slow→fast tokenizer conversion path. Exempting it from # the project-wide 7-day cutoff lets the lockfile pick it up immediately # while the rest of the dependency graph stays gated. -exclude-newer-package = { fastokens = false } +exclude-newer-package = { fastokens = false, "prime-pydantic-config" = false } [tool.ty.environment] python-version = "3.13" diff --git a/renderers/__init__.py b/renderers/__init__.py index 7c82510..0ae78ac 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -38,9 +38,30 @@ trim_to_turn_close, ) from renderers.client import OverlongPromptError +from renderers.configs import ( + AutoRendererConfig, + BaseRendererConfig, + config_from_name, + DefaultRendererConfig, + DeepSeekV3RendererConfig, + GLM45RendererConfig, + GLM51RendererConfig, + GLM5RendererConfig, + GptOssRendererConfig, + KimiK25RendererConfig, + KimiK2RendererConfig, + LagunaXS2RendererConfig, + MiniMaxM2RendererConfig, + Nemotron3RendererConfig, + Qwen35RendererConfig, + Qwen36RendererConfig, + Qwen3RendererConfig, + Qwen3VLRendererConfig, + RendererConfig, +) from renderers.deepseek_v3 import DeepSeekV3Renderer from renderers.default import DefaultRenderer -from 
renderers.glm5 import GLM5Renderer +from renderers.glm5 import GLM5Renderer, GLM51Renderer from renderers.glm45 import GLM45Renderer from renderers.gpt_oss import GptOssRenderer from renderers.kimi_k2 import KimiK2Renderer @@ -54,34 +75,53 @@ from renderers.qwen36 import Qwen36Renderer __all__ = [ + "AutoRendererConfig", + "BaseRendererConfig", "Content", "ContentPart", "DeepSeekV3Renderer", + "DeepSeekV3RendererConfig", "DefaultRenderer", + "DefaultRendererConfig", "GLM45Renderer", + "GLM45RendererConfig", + "GLM51Renderer", + "GLM51RendererConfig", "GLM5Renderer", + "GLM5RendererConfig", "GptOssRenderer", + "GptOssRendererConfig", "ImagePart", - "KimiK2Renderer", "KimiK25Renderer", + "KimiK25RendererConfig", + "KimiK2Renderer", + "KimiK2RendererConfig", "LagunaXS2Renderer", + "LagunaXS2RendererConfig", "MULTIMODAL_MODELS", "Message", "MiniMaxM2Renderer", + "MiniMaxM2RendererConfig", "MultiModalData", "MultimodalRenderer", "Nemotron3Renderer", + "Nemotron3RendererConfig", "OverlongPromptError", "ParsedResponse", "ParsedToolCall", "PlaceholderRange", - "Qwen3Renderer", - "Qwen3VLRenderer", "Qwen35Renderer", + "Qwen35RendererConfig", "Qwen36Renderer", + "Qwen36RendererConfig", + "Qwen3Renderer", + "Qwen3RendererConfig", + "Qwen3VLRenderer", + "Qwen3VLRendererConfig", "RenderedConversation", "RenderedTokens", "Renderer", + "RendererConfig", "RendererPool", "TextPart", "ThinkingPart", @@ -94,6 +134,7 @@ "attribute_text_segments", "build_training_sample", "build_trajectory_step", + "config_from_name", "create_renderer", "create_renderer_pool", "is_multimodal", diff --git a/renderers/base.py b/renderers/base.py index b861872..856ad4b 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -6,7 +6,18 @@ import threading from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, Callable, Literal, Protocol, TypedDict, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Protocol, + 
TypedDict, + runtime_checkable, +) + +if TYPE_CHECKING: + from renderers.configs import AutoRendererConfig, RendererConfig logger = logging.getLogger("renderers.base") @@ -1166,27 +1177,22 @@ def _populate_registry(): def create_renderer_pool( tokenizer_name_or_path: str, - renderer: str = "auto", - size: int = 16, + config: RendererConfig | None = None, *, - tool_parser: str | None = None, - reasoning_parser: str | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + size: int = 16, ) -> RendererPool: """Create a RendererPool with *size* independent tokenizer copies. - Each slot loads its own tokenizer so threads never share mutable state. - HuggingFace fast tokenizers release the GIL during Rust encoding, so - threads achieve real parallelism. - - ``tool_parser`` and ``reasoning_parser`` are forwarded to - ``create_renderer`` when the pool falls back to ``DefaultRenderer``. + Each slot loads its own tokenizer so threads never share mutable + state. HuggingFace fast tokenizers release the GIL during Rust + encoding, so threads achieve real parallelism. - ``preserve_all_thinking`` and ``preserve_thinking_between_tool_calls`` - are forwarded to each pooled renderer's constructor — every slot in - the pool shares one configuration. To run with a different - configuration, build a different pool. + ``config`` is the typed renderer config (one of the variants of + :data:`renderers.RendererConfig`). Defaults to + :class:`AutoRendererConfig`, which resolves to a concrete renderer + via ``MODEL_RENDERER_MAP`` at construction time using the loaded + tokenizer's name. Every slot in the pool shares the same config; to + run a different config, build a different pool. 
Tokenizers load via ``load_tokenizer`` — see its docstring for the ``trust_remote_code`` policy (default off; Moonshot Kimi-K2 family @@ -1195,92 +1201,76 @@ def create_renderer_pool( def factory() -> Renderer: tokenizer = load_tokenizer(tokenizer_name_or_path) - return create_renderer( - tokenizer, - renderer=renderer, - tool_parser=tool_parser, - reasoning_parser=reasoning_parser, - preserve_all_thinking=preserve_all_thinking, - preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, - ) + return create_renderer(tokenizer, config) return RendererPool(factory, size=size) def create_renderer( tokenizer, - renderer: str = "auto", - *, - tool_parser: str | None = None, - reasoning_parser: str | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: RendererConfig | None = None, ) -> Renderer: - """Create a Renderer by name, or auto-detect from the tokenizer's model name. + """Create a Renderer from a typed config. Args: tokenizer: HuggingFace tokenizer instance. - renderer: Renderer name ('qwen3', 'qwen3-vl', 'qwen3.5', 'qwen3.6', - 'glm-5', 'glm-5.1', 'glm-4.5', 'minimax-m2', 'deepseek-v3', - 'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'nemotron-3', - 'gpt-oss', 'default') or 'auto' to detect from model name. - tool_parser: Name of a tool parser registered in ``renderers.parsers``. - Only consumed by DefaultRenderer. Model-specific renderers - have their own parsing wired in. - reasoning_parser: Name of a reasoning parser registered in - ``renderers.parsers``. Only consumed by DefaultRenderer. - preserve_all_thinking: Forwarded to the renderer's constructor. - When ``True``, the instance restores ``reasoning_content`` - the chat template would otherwise drop on historical - assistants — useful when a downstream pass (e.g. - compaction prompts the model with a fresh ``user`` turn - asking for a summary) would lose the trajectory's - reasoning. 
See ``Renderer.render`` and - ``should_preserve_past_thinking``. - preserve_thinking_between_tool_calls: Forwarded to the renderer's - constructor. ``True`` keeps reasoning on in-flight - tool-cycle assistants when the template would drop them. - See ``Renderer.render`` for semantics. + config: Typed renderer config — one of the variants of + :data:`renderers.RendererConfig`. ``None`` defaults to + :class:`AutoRendererConfig`, which resolves to a concrete + renderer using ``tokenizer.name_or_path`` against + ``MODEL_RENDERER_MAP``. To enable structured-output parsing + on the default renderer, pass :class:`DefaultRendererConfig` + with ``tool_parser`` / ``reasoning_parser`` set. To override + template-control kwargs (e.g. ``enable_thinking``), pass + the specific :class:`Qwen3RendererConfig`, + :class:`GLM5RendererConfig` etc. and set those fields. + + Selecting the auto-renderer for a model without a registered + renderer falls back to :class:`DefaultRenderer` for text-only models + and raises for VLMs (where ``apply_chat_template`` would silently + drop images). """ - _populate_registry() + from renderers.configs import AutoRendererConfig - default_kwargs: dict = {} - if tool_parser is not None: - default_kwargs["tool_parser"] = tool_parser - if reasoning_parser is not None: - default_kwargs["reasoning_parser"] = reasoning_parser + _populate_registry() - preserve_kwargs: dict = { - "preserve_all_thinking": preserve_all_thinking, - "preserve_thinking_between_tool_calls": preserve_thinking_between_tool_calls, - } + if config is None: + config = AutoRendererConfig() - if renderer != "auto": - cls = RENDERER_REGISTRY.get(renderer) + if not isinstance(config, AutoRendererConfig): + cls = RENDERER_REGISTRY.get(config.name) if cls is None: raise ValueError( - f"Unknown renderer {renderer!r}. Available: {', '.join(sorted(RENDERER_REGISTRY))}" + f"Unknown renderer {config.name!r}. 
" + f"Available: {', '.join(sorted(RENDERER_REGISTRY))}" ) - if renderer == "default": - return cls(tokenizer, **default_kwargs, **preserve_kwargs) - if default_kwargs: - logger.info( - "tool_parser / reasoning_parser are only consumed by " - "DefaultRenderer; ignoring for renderer=%r which has " - "built-in behavior.", - renderer, - ) - return cls(tokenizer, **preserve_kwargs) + return cls(tokenizer, config) + + return _resolve_auto(tokenizer, config) + + +def _resolve_auto(tokenizer, auto: AutoRendererConfig) -> Renderer: + """Map ``AutoRendererConfig`` → concrete typed config via the + tokenizer's ``name_or_path``, then instantiate the matching renderer. + + Fine-tunes and renamed checkpoints miss on purpose — their chat + template may differ from the original even when the architecture + matches, so silently mapping them would produce template-parity + bugs. Set ``config=`` explicitly for those. + """ + from renderers.configs import DefaultRendererConfig, _config_class_for - # Auto-detect from model name via exact match on the canonical HF id. - # Fine-tunes and renamed checkpoints miss on purpose — their chat - # template may differ from the original even when the architecture - # matches, so silently mapping them would produce template-parity - # bugs. Set ``renderer=`` explicitly for those. model_name = getattr(tokenizer, "name_or_path", "") renderer_name = MODEL_RENDERER_MAP.get(model_name) + + preserve_carry = { + "preserve_all_thinking": auto.preserve_all_thinking, + "preserve_thinking_between_tool_calls": auto.preserve_thinking_between_tool_calls, + } + if renderer_name is not None: - return RENDERER_REGISTRY[renderer_name](tokenizer, **preserve_kwargs) + cfg_cls = _config_class_for(renderer_name) + return RENDERER_REGISTRY[renderer_name](tokenizer, cfg_cls(**preserve_carry)) # No match. 
For VLMs this must be fatal: DefaultRenderer only knows # ``apply_chat_template`` + text tokens, so it would silently drop @@ -1294,20 +1284,26 @@ def create_renderer( f"No multimodal renderer registered for {model_name!r}, and " f"DefaultRenderer would silently drop images. Register a " f"renderer in MODEL_RENDERER_MAP (currently supported VLMs: " - f"{supported_vlms}), or pass ``renderer=''`` explicitly " - f"if you know what you're doing." + f"{supported_vlms}), or pass an explicit typed renderer " + f"config if you know what you're doing." ) # Text-only fall back to default (apply_chat_template). For fine-tunes - # with customized chat templates this is the *correct* choice, so we don't - # warn. Note the pick at INFO and advertise the parser knobs. + # with customized chat templates this is the *correct* choice, so we + # don't warn. Note the pick at INFO and advertise the parser knobs. + if auto.preserve_all_thinking or auto.preserve_thinking_between_tool_calls: + raise NotImplementedError( + "Auto-resolved DefaultRenderer can't selectively re-emit " + "dropped reasoning_content. Pass an explicit typed renderer " + "config (model-specific) if you need preserve_*_thinking." + ) logger.info( "No model-specific renderer matched %r. Using DefaultRenderer " - "(apply_chat_template). Pass tool_parser= or " - "reasoning_parser= to enable structured output parsing.", + "(apply_chat_template). Pass DefaultRendererConfig(tool_parser=..., " + "reasoning_parser=...) 
to enable structured output parsing.", model_name or "", ) - return RENDERER_REGISTRY["default"](tokenizer, **default_kwargs, **preserve_kwargs) + return RENDERER_REGISTRY["default"](tokenizer, DefaultRendererConfig()) # --------------------------------------------------------------------------- diff --git a/renderers/configs.py b/renderers/configs.py new file mode 100644 index 0000000..e0098ba --- /dev/null +++ b/renderers/configs.py @@ -0,0 +1,468 @@ +"""Typed renderer configs — one pydantic model per renderer, unified by a +discriminated union (``RendererConfig``). + +Each renderer accepts its own typed config; bad combinations (e.g. +``add_vision_id`` under ``name="qwen3"``) fail at config-load time with a +pydantic ``ValidationError`` rather than at runtime via an allowlist +check. The shared ``preserve_*`` flags live on ``BaseRendererConfig`` +and OR-compose with template-level toggles (e.g. GLM-5 +``clear_thinking``) inside each renderer — they extend retention, never +override the template into a drop. + +``AutoRendererConfig`` is a placeholder variant: ``create_renderer`` +resolves it via ``MODEL_RENDERER_MAP`` and constructs the matching +typed config with the auto config's ``preserve_*`` fields carried over. + +``DefaultRendererConfig`` uses ``extra="allow"`` to accept arbitrary +Jinja kwargs as ``model_extra`` — ``DefaultRenderer`` doesn't know which +keys its tokenizer's template will honour, so it can't enumerate them. +""" + +from __future__ import annotations + +from typing import Annotated, ClassVar, Literal, Union + +from pydantic import ConfigDict, Field +from pydantic_config import BaseConfig + + +class BaseRendererConfig(BaseConfig): + """Shared fields and config for every renderer config variant. + + Inherits from ``pydantic_config.BaseConfig`` so the typed-config + surface stays uniform with prime-rl / verifiers config bases. 
The + BaseConfig contract includes ``extra="forbid"`` (preserved here); + this class adds ``frozen=True`` so configs are hashable value + objects. + + ``preserve_all_thinking`` and ``preserve_thinking_between_tool_calls`` + are renderer-internal behaviour flags — they don't map to any Jinja + chat-template kwarg. They OR-compose with template-level toggles on + renderers that expose one (GLM-5 ``clear_thinking``, Nemotron-3 + ``truncate_history_thinking``): either flag saying "keep this + thinking" wins. preserve_* can only ever extend retention; setting + ``preserve_all_thinking=True`` always keeps past thinking, regardless + of the template kwarg. See ``renderers.base.should_preserve_past_thinking``. + """ + + model_config = ConfigDict(frozen=True) + + preserve_all_thinking: bool = False + """Restore ``reasoning_content`` on every past assistant turn, even + when the chat template would drop it. Strict superset of + ``preserve_thinking_between_tool_calls``.""" + + preserve_thinking_between_tool_calls: bool = False + """Restore ``reasoning_content`` only inside the in-flight tool cycle: + the contiguous A-T-...-A block after the most recent ``user`` turn, + and only if it contains at least one ``tool`` response. A new user + turn closes the block and drops its thinking (template default).""" + + # Fields that are renderer-internal — not forwarded to (or mirrored + # by) ``apply_chat_template``. Override in subclasses that hold + # non-template config (e.g. ``image_cache_max``, GptOss's + # ``use_system_prompt`` / ``knowledge_cutoff`` / ``model_identity``, + # or fields that exist as renderer conventions without a Jinja + # analogue like DeepSeek V3 / Kimi K2 ``enable_thinking``). + # + # Used by parity tests to compute the field subset that, when + # changed, must produce token streams matching + # ``apply_chat_template`` — see :meth:`template_field_names`. 
The + # renderer is the only end-to-end consumer of these fields, so this + # is a renderer-side bookkeeping concern rather than a public API. + _internal_fields: ClassVar[frozenset[str]] = frozenset() + + @classmethod + def template_field_names(cls) -> frozenset[str]: + """Subset of fields that mirror Jinja chat-template kwargs. + + Default: every non-base field except ``name`` and any field + listed in ``_internal_fields``. Used by the parity test matrix + (``tests/test_renderer_config_parity.py``) to discover the + cells that must agree with ``apply_chat_template``. + """ + base = frozenset(BaseRendererConfig.model_fields) + return frozenset(cls.model_fields) - base - {"name"} - cls._internal_fields + + +class AutoRendererConfig(BaseRendererConfig): + """Resolve the renderer from ``tokenizer.name_or_path`` at construction + time via ``MODEL_RENDERER_MAP``. Carries only the shared ``preserve_*`` + fields; template kwargs require an explicit renderer choice so that + template-dependent behaviour stays visible at the call site.""" + + name: Literal["auto"] = "auto" + + +class DefaultRendererConfig(BaseRendererConfig): + """Config for ``DefaultRenderer`` — the fallback wrapping + ``tokenizer.apply_chat_template``. Accepts arbitrary extra fields + via ``extra="allow"`` because the underlying Jinja template's kwargs + are unknown to us. ``DefaultRenderer`` forwards ``model_extra`` to + ``apply_chat_template`` verbatim. + """ + + model_config = ConfigDict(frozen=True, extra="allow") + + name: Literal["default"] = "default" + + tool_parser: str | None = None + """Name of a tool parser registered in ``renderers.parsers`` (e.g. + ``"qwen3"``, ``"glm"``). Consumed only by ``DefaultRenderer``.""" + + reasoning_parser: str | None = None + """Name of a reasoning parser registered in ``renderers.parsers`` + (e.g. ``"think"``). 
Consumed only by ``DefaultRenderer``.""" + + # tool_parser / reasoning_parser are renderer-internal — they configure + # DefaultRenderer's parsing pipeline, not the underlying Jinja + # template. Jinja kwargs live in ``model_extra`` (extra="allow"). + _internal_fields = frozenset({"tool_parser", "reasoning_parser"}) + + +class Qwen3RendererConfig(BaseRendererConfig): + """Qwen3 (text-only) renderer config.""" + + name: Literal["qwen3"] = "qwen3" + + enable_thinking: bool = True + """When ``True``, the generation prompt includes ```` so the + model continues into a thinking block. Mirrors the chat template's + ``enable_thinking`` kwarg.""" + + +class Qwen35RendererConfig(BaseRendererConfig): + """Qwen3.5 renderer config.""" + + name: Literal["qwen3.5"] = "qwen3.5" + + enable_thinking: bool | None = None + """When ``True``, the generation prompt includes ````. ``None`` + auto-detects from the tokenizer's chat-template default — Instruct + checkpoints default off, Thinking checkpoints default on. Mirrors + the chat template's ``enable_thinking`` kwarg.""" + + add_vision_id: bool = False + """When ``True``, prefix each ``<|vision_start|>`` placeholder with + ``"Picture N: "`` / ``"Video N: "`` where N is a 1-indexed counter + running across the entire conversation. Mirrors the chat template's + ``add_vision_id`` toggle.""" + + image_cache_max: int = 256 + """FIFO bound on the per-renderer image processor cache. Renderer- + internal — not a Jinja chat-template kwarg.""" + + _internal_fields = frozenset({"image_cache_max"}) + + +class Qwen36RendererConfig(BaseRendererConfig): + """Qwen3.6 renderer config. 
Inherits Qwen3.5's template surface.""" + + name: Literal["qwen3.6"] = "qwen3.6" + + enable_thinking: bool | None = None + """See :class:`Qwen35RendererConfig.enable_thinking`.""" + + add_vision_id: bool = False + """See :class:`Qwen35RendererConfig.add_vision_id`.""" + + image_cache_max: int = 256 + """See :class:`Qwen35RendererConfig.image_cache_max`.""" + + _internal_fields = frozenset({"image_cache_max"}) + + +class Qwen3VLRendererConfig(BaseRendererConfig): + """Qwen3-VL renderer config.""" + + name: Literal["qwen3-vl"] = "qwen3-vl" + + add_vision_id: bool = False + """See :class:`Qwen35RendererConfig.add_vision_id`.""" + + image_cache_max: int = 256 + """See :class:`Qwen35RendererConfig.image_cache_max`.""" + + _internal_fields = frozenset({"image_cache_max"}) + + +class GLM5RendererConfig(BaseRendererConfig): + """GLM-5 renderer config.""" + + name: Literal["glm-5"] = "glm-5" + + enable_thinking: bool = True + """When ``True``, the generation prompt includes ````. Mirrors + the chat template's ``enable_thinking`` kwarg.""" + + clear_thinking: bool = True + """When ``False``, the renderer keeps ``{reasoning}`` + on past-cycle assistant turns instead of dropping them. Mirrors the + chat template's ``clear_thinking`` toggle. 
OR-composes with + ``preserve_all_thinking`` / ``preserve_thinking_between_tool_calls`` + — see :class:`BaseRendererConfig` for the contract.""" + + +class GLM51RendererConfig(BaseRendererConfig): + """GLM-5.1 renderer config — same template surface as GLM-5, distinct + discriminator so the registry can route to ``GLM51Renderer``.""" + + name: Literal["glm-5.1"] = "glm-5.1" + + enable_thinking: bool = True + """See :class:`GLM5RendererConfig.enable_thinking`.""" + + clear_thinking: bool = True + """See :class:`GLM5RendererConfig.clear_thinking`.""" + + +class GLM45RendererConfig(BaseRendererConfig): + """GLM-4.5 Air renderer config.""" + + name: Literal["glm-4.5"] = "glm-4.5" + + enable_thinking: bool = True + """When ``True``, the generation prompt includes ````. Mirrors + the chat template's ``enable_thinking`` kwarg.""" + + +class GptOssRendererConfig(BaseRendererConfig): + """OpenAI gpt-oss (harmony) renderer config. + + Several fields here are renderer-internal: ``use_system_prompt``, + ``knowledge_cutoff``, and ``model_identity`` control how the renderer + builds the harmony ``SystemContent`` preamble and don't have direct + Jinja-kwarg analogues. They're typed config rather than Jinja kwargs + because users still want to set them — the distinction only matters + for downstream tooling that synthesises a Jinja-kwargs view (none + today, since vLLM is invoked via the token-in endpoint). + """ + + name: Literal["gpt-oss"] = "gpt-oss" + + reasoning_effort: Literal["low", "medium", "high"] = "medium" + """Harmony reasoning-effort tag. Mirrors the ``apply_chat_template`` + ``reasoning_effort`` kwarg.""" + + conversation_start_date: str | None = None + """ISO date string for the harmony preamble. ``None`` defers to + today's date at render time.""" + + use_system_prompt: bool = True + """Prepend the canonical harmony ``SystemContent`` preamble. 
Matches + HF's ``apply_chat_template`` behaviour.""" + + knowledge_cutoff: str | None = None + """Override the model's knowledge-cutoff string in the preamble. + ``None`` uses harmony's built-in default.""" + + model_identity: str | None = None + """Override the model-identity line in the preamble. ``None`` uses + harmony's built-in default.""" + + _internal_fields = frozenset( + {"use_system_prompt", "knowledge_cutoff", "model_identity"} + ) + + +class KimiK2RendererConfig(BaseRendererConfig): + """Kimi K2 renderer config. + + ``enable_thinking`` is renderer-internal here — Kimi K2's chat + template doesn't reference any thinking variable, so it's a no-op + against ``apply_chat_template`` parity. The field is kept for + protocol uniformity with the rest of the renderer family. + """ + + name: Literal["kimi-k2"] = "kimi-k2" + + enable_thinking: bool = True + """No-op for Kimi K2 (template doesn't gate on it). Stored for + introspection / cross-renderer uniformity.""" + + _internal_fields = frozenset({"enable_thinking"}) + + +class KimiK25RendererConfig(BaseRendererConfig): + """Kimi K2.5 renderer config.""" + + name: Literal["kimi-k2.5"] = "kimi-k2.5" + + thinking: bool = True + """When ``True``, the generation prompt prefills ````; when + ``False`` it prefills ````. The kwarg is named + ``thinking`` (not ``enable_thinking``) to match the upstream chat + template's native variable name.""" + + image_cache_max: int = 256 + """See :class:`Qwen35RendererConfig.image_cache_max`.""" + + _internal_fields = frozenset({"image_cache_max"}) + + +class LagunaXS2RendererConfig(BaseRendererConfig): + """Laguna XS.2 renderer config.""" + + name: Literal["laguna-xs.2"] = "laguna-xs.2" + + enable_thinking: bool = False + """When ``True``, the generation prompt includes ````. Mirrors + the chat template's ``enable_thinking`` kwarg. 
Default ``False`` + matches the upstream Jinja default for Laguna XS.2.""" + + render_assistant_messages_raw: bool = False + """When ``True``, assistant messages render as a passthrough: the + content bytes are emitted verbatim (no reasoning extraction, no + tool-call XML synthesis), and the ````/```` prefix + and ```` suffix are only added when missing. Mirrors the + chat template's ``render_assistant_messages_raw`` gate.""" + + +class MiniMaxM2RendererConfig(BaseRendererConfig): + """MiniMax M2 / M2.5 renderer config.""" + + name: Literal["minimax-m2"] = "minimax-m2" + + model_identity: str = "You are a helpful assistant. Your name is MiniMax-M2.5 and is built by MiniMax." + """Fallback persona used when no system message is supplied. Mirrors + the chat template's ``model_identity`` Jinja variable.""" + + +class Nemotron3RendererConfig(BaseRendererConfig): + """Nemotron 3 renderer config.""" + + name: Literal["nemotron-3"] = "nemotron-3" + + enable_thinking: bool = True + """When ``True``, the generation prompt includes ````. Mirrors + the chat template's ``enable_thinking`` kwarg.""" + + truncate_history_thinking: bool = True + """When ``False``, keep ``{reasoning}`` on past-cycle + assistant turns instead of dropping them. Mirrors the chat + template's ``truncate_history_thinking`` toggle. OR-composes with + ``preserve_all_thinking`` / ``preserve_thinking_between_tool_calls`` + — see :class:`BaseRendererConfig` for the contract.""" + + +class DeepSeekV3RendererConfig(BaseRendererConfig): + """DeepSeek V3 renderer config. + + ``enable_thinking`` is renderer-internal here — DeepSeek-V3's chat + template does not reference any thinking variable, so passing it to + ``apply_chat_template`` upstream is a no-op. The renderer uses it + to control the ```` prefill at the generation prompt (R1 + distill convention). 
+ """ + + name: Literal["deepseek-v3"] = "deepseek-v3" + + enable_thinking: bool = True + """Renderer convention for the R1-distill family: when ``True``, + prefill ```` at the generation prompt. The DeepSeek-V3 Jinja + template ignores this kwarg upstream; it's not a chat-template + kwarg in the strict sense.""" + + _internal_fields = frozenset({"enable_thinking"}) + + +RendererConfig = Annotated[ + Union[ + AutoRendererConfig, + DefaultRendererConfig, + Qwen3RendererConfig, + Qwen35RendererConfig, + Qwen36RendererConfig, + Qwen3VLRendererConfig, + GLM5RendererConfig, + GLM51RendererConfig, + GLM45RendererConfig, + GptOssRendererConfig, + KimiK2RendererConfig, + KimiK25RendererConfig, + LagunaXS2RendererConfig, + MiniMaxM2RendererConfig, + Nemotron3RendererConfig, + DeepSeekV3RendererConfig, + ], + Field(discriminator="name"), +] +"""Discriminated union over every renderer config variant. + +Downstream pydantic configs (prime-rl orchestrator, verifiers +``ClientConfig``) can hold a single field typed as ``RendererConfig``; +deserialization dispatches on ``name`` and exposes strictly the kwargs +that renderer supports. Bogus combinations (e.g. ``add_vision_id`` under +``name="qwen3"``) raise ``ValidationError`` at config-load time. +""" + + +# Map discriminator → config class. Used by ``create_renderer`` when +# resolving ``AutoRendererConfig`` against ``MODEL_RENDERER_MAP``: the +# resolved renderer name picks the corresponding typed config, and the +# auto config's ``preserve_*`` fields are carried over. 
+_CONFIG_BY_NAME: dict[str, type[BaseRendererConfig]] = { + "auto": AutoRendererConfig, + "default": DefaultRendererConfig, + "qwen3": Qwen3RendererConfig, + "qwen3.5": Qwen35RendererConfig, + "qwen3.6": Qwen36RendererConfig, + "qwen3-vl": Qwen3VLRendererConfig, + "glm-5": GLM5RendererConfig, + "glm-5.1": GLM51RendererConfig, + "glm-4.5": GLM45RendererConfig, + "gpt-oss": GptOssRendererConfig, + "kimi-k2": KimiK2RendererConfig, + "kimi-k2.5": KimiK25RendererConfig, + "laguna-xs.2": LagunaXS2RendererConfig, + "minimax-m2": MiniMaxM2RendererConfig, + "nemotron-3": Nemotron3RendererConfig, + "deepseek-v3": DeepSeekV3RendererConfig, +} + + +def _config_class_for(name: str) -> type[BaseRendererConfig]: + cls = _CONFIG_BY_NAME.get(name) + if cls is None: + raise ValueError( + f"No renderer config registered for name={name!r}. " + f"Known: {sorted(_CONFIG_BY_NAME)}" + ) + return cls + + +def config_from_name(name: str) -> BaseRendererConfig | None: + """Construct a default-valued config for the given renderer name. + + Convenience for callers that hold a renderer name as a string and + want the matching typed config. ``"auto"`` returns ``None`` — + :func:`renderers.create_renderer` interprets that as "run auto + resolution against ``MODEL_RENDERER_MAP``", which is what callers + expect from a bare-string name. 
+ """ + if name == "auto": + return None + return _config_class_for(name)() + + +__all__ = [ + "AutoRendererConfig", + "BaseRendererConfig", + "DefaultRendererConfig", + "DeepSeekV3RendererConfig", + "GLM45RendererConfig", + "GLM51RendererConfig", + "GLM5RendererConfig", + "GptOssRendererConfig", + "KimiK25RendererConfig", + "KimiK2RendererConfig", + "LagunaXS2RendererConfig", + "MiniMaxM2RendererConfig", + "Nemotron3RendererConfig", + "Qwen35RendererConfig", + "Qwen36RendererConfig", + "Qwen3RendererConfig", + "Qwen3VLRendererConfig", + "RendererConfig", + "config_from_name", +] diff --git a/renderers/deepseek_v3.py b/renderers/deepseek_v3.py index 507d81d..4efe3ef 100644 --- a/renderers/deepseek_v3.py +++ b/renderers/deepseek_v3.py @@ -25,6 +25,7 @@ reject_assistant_in_extension, trim_to_turn_close, ) +from renderers.configs import DeepSeekV3RendererConfig from renderers.parsing import parse_deepseek_v3 # Fullwidth vertical bar used in DeepSeek special token names. @@ -39,25 +40,25 @@ def _ds_token(name: str) -> str: class DeepSeekV3Renderer: - """Deterministic message → token renderer for DeepSeek V3 models.""" + """Deterministic message → token renderer for DeepSeek V3 models. + + DeepSeek-V3's chat template does not consult any thinking-related + variable; the ``enable_thinking`` field on the typed config controls + the renderer's ``\\n`` prefill at the generation prompt + (R1-distill convention) and is intentionally not forwarded to + ``apply_chat_template`` upstream — that would be a no-op. The + template also always emits ``{reasoning}`` when + ``reasoning_content`` is provided, so ``preserve_*`` flags are + no-ops here too; stored for protocol uniformity. 
+ """ def __init__( self, tokenizer: PreTrainedTokenizer, - *, - enable_thinking: bool = True, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: DeepSeekV3RendererConfig | None = None, ): - # DeepSeek-V3's chat template always emits ``{reasoning}`` - # when ``reasoning_content`` is provided — no drop, so the override - # flags are no-ops. Stored for introspection / Protocol parity only. self._tokenizer = tokenizer - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or DeepSeekV3RendererConfig() # ── BOS / EOS ──────────────────────────────────────────────── self._bos = self._get_special_token(f"begin{_US}of{_US}sentence") @@ -237,7 +238,7 @@ def emit_text_segments( emit_special( self._assistant_token, -1, is_sampled=False, is_content=False ) - if self._enable_thinking: + if self.config.enable_thinking: emit_text("\n", -1, is_sampled=False, is_content=False) return RenderedTokens( @@ -379,7 +380,7 @@ def emit_text( last_role = new_messages[-1].get("role") if new_messages else None if last_role != "tool": emit_special(self._assistant_token, -1) - if self._enable_thinking: + if self.config.enable_thinking: emit_text("\n", -1) total_len = len(previous_ids) + len(ext) diff --git a/renderers/default.py b/renderers/default.py index 4a15f05..e969421 100644 --- a/renderers/default.py +++ b/renderers/default.py @@ -19,9 +19,8 @@ RenderedTokens, ToolSpec, ) +from renderers.configs import DefaultRendererConfig from renderers.parsers import ( - ReasoningParser, - ToolParser, get_reasoning_parser, get_tool_parser, ) @@ -82,36 +81,30 @@ def _decode_tool_call_arguments(messages: list) -> list: class DefaultRenderer: """Fallback renderer using tokenizer.apply_chat_template(). - Works with any model. 
Pass ``tool_parser`` and/or ``reasoning_parser`` - (by name, resolved against the registries in ``renderers.parsers``) to - enable structured output extraction. + Works with any model. The config can carry ``tool_parser`` and/or + ``reasoning_parser`` (resolved against ``renderers.parsers``) to + enable structured output extraction, plus arbitrary additional Jinja + template kwargs captured as ``model_extra`` (``extra="allow"`` on + :class:`renderers.DefaultRendererConfig`). """ def __init__( self, tokenizer: PreTrainedTokenizer, - *, - tool_parser: str | ToolParser | None = None, - reasoning_parser: str | ReasoningParser | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, - **chat_template_kwargs, + config: DefaultRendererConfig | None = None, ): - if preserve_all_thinking or preserve_thinking_between_tool_calls: + cfg = config or DefaultRendererConfig() + if cfg.preserve_all_thinking or cfg.preserve_thinking_between_tool_calls: raise NotImplementedError( "DefaultRenderer falls back to apply_chat_template and can't " "selectively re-emit dropped reasoning_content. Configure a " "model-specific renderer if you need preserve_*_thinking." 
) self._tokenizer = tokenizer - self._chat_template_kwargs = chat_template_kwargs - self._tool_parser = _resolve_parser(tool_parser, tokenizer, get_tool_parser) + self.config = cfg + self._tool_parser = _resolve_parser(cfg.tool_parser, tokenizer, get_tool_parser) self._reasoning_parser = _resolve_parser( - reasoning_parser, tokenizer, get_reasoning_parser - ) - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls + cfg.reasoning_parser, tokenizer, get_reasoning_parser ) @property @@ -151,7 +144,7 @@ def render( ) def _apply(self, messages, *, tools=None, add_generation_prompt=False) -> list[int]: - kwargs = dict(self._chat_template_kwargs) + kwargs = dict(self.config.model_extra or {}) kwargs["add_generation_prompt"] = add_generation_prompt kwargs["tokenize"] = True if tools is not None: diff --git a/renderers/glm45.py b/renderers/glm45.py index 206f366..ed0e0b7 100644 --- a/renderers/glm45.py +++ b/renderers/glm45.py @@ -24,6 +24,7 @@ reject_assistant_in_extension, should_preserve_past_thinking, ) +from renderers.configs import GLM45RendererConfig from renderers.parsing import parse_glm _TOOLS_HEADER = ( @@ -53,17 +54,10 @@ class GLM45Renderer: def __init__( self, tokenizer: PreTrainedTokenizer, - *, - enable_thinking: bool = True, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: GLM45RendererConfig | None = None, ): self._tokenizer = tokenizer - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or GLM45RendererConfig() self._gmask = self._token_id("[gMASK]") self._sop = self._token_id("") @@ -204,7 +198,7 @@ def emit_text_segments( # ``/nothink`` suffix is scaffold the renderer injects # when ``enable_thinking=False``. 
user_segments: list[tuple[str, bool]] = [("\n", False), (content, True)] - if not self._enable_thinking and not content.endswith("/nothink"): + if not self.config.enable_thinking and not content.endswith("/nothink"): user_segments.append(("/nothink", False)) emit_text_segments(user_segments, i, is_sampled=False) @@ -212,8 +206,8 @@ def emit_text_segments( preserve_thinking = should_preserve_past_thinking( messages, i, - preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) self._render_assistant( msg, @@ -239,7 +233,7 @@ def emit_text_segments( # ── Generation prompt ─────────────────────────────────────── if add_generation_prompt: emit_special(self._assistant, -1, is_sampled=False, is_content=False) - if not self._enable_thinking: + if not self.config.enable_thinking: emit_text("\n", -1, is_sampled=False, is_content=False) emit_special(self._think, -1, is_sampled=False, is_content=False) emit_special(self._think_end, -1, is_sampled=False, is_content=False) @@ -378,7 +372,7 @@ def emit_text_segments( ("\n", False), (content, True), ] - if not self._enable_thinking and not content.endswith("/nothink"): + if not self.config.enable_thinking and not content.endswith("/nothink"): user_segments.append(("/nothink", False)) emit_text_segments(user_segments, i) elif role == "system": @@ -403,7 +397,7 @@ def emit_text_segments( # Generation prompt. 
emit_special(self._assistant, -1) - if not self._enable_thinking: + if not self.config.enable_thinking: emit_text("\n", -1) emit_special(self._think, -1) emit_special(self._think_end, -1) diff --git a/renderers/glm5.py b/renderers/glm5.py index 6de6ba3..f3e28e3 100644 --- a/renderers/glm5.py +++ b/renderers/glm5.py @@ -25,6 +25,7 @@ reject_assistant_in_extension, should_preserve_past_thinking, ) +from renderers.configs import GLM5RendererConfig, GLM51RendererConfig from renderers.parsing import parse_glm _TOOLS_HEADER = ( @@ -54,20 +55,18 @@ class GLM5Renderer: # GLM51Renderer; GLM-5 proper keeps this off. empty_think_on_last_assistant: bool = False + # GLM-5.1 uses the same template surface and binds the same kwargs. + # Subclassed in ``GLM51Renderer`` so the registry can dispatch on the + # ``glm-5.1`` discriminator while sharing this implementation. + _config_cls: type = GLM5RendererConfig + def __init__( self, tokenizer: PreTrainedTokenizer, - *, - enable_thinking: bool = True, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: GLM5RendererConfig | GLM51RendererConfig | None = None, ): self._tokenizer = tokenizer - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or type(self)._config_cls() self._gmask = self._token_id("[gMASK]") self._sop = self._token_id("") @@ -220,8 +219,8 @@ def emit_text_segments( preserve_thinking = should_preserve_past_thinking( messages, i, - preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) self._render_assistant( msg, @@ -250,7 +249,7 @@ def emit_text_segments( # them. 
Always is_sampled=False / is_content=False. if add_generation_prompt: emit_special(self._assistant, -1, is_sampled=False, is_content=False) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1, is_sampled=False, is_content=False) else: emit_special(self._think_end, -1, is_sampled=False, is_content=False) @@ -409,7 +408,7 @@ def emit_text_segments( # Generation prompt — match the gen-prompt branch of ``render()``. emit_special(self._assistant, -1) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1) else: emit_special(self._think_end, -1) @@ -463,9 +462,14 @@ def _render_assistant( # ``preserve_thinking`` is the override output of # ``should_preserve_past_thinking`` — it adds historical assistants # back when the renderer was constructed with - # ``preserve_all_thinking=True``. + # ``preserve_all_thinking=True``. ``clear_thinking=False`` mirrors + # the template's per-call ``clear_thinking is defined and not + # clear_thinking`` gate: a chat_template_kwarg surface for the + # same behaviour, gated explicitly by the caller per render. include_thinking = ( - msg_idx > last_user_index or preserve_thinking + msg_idx > last_user_index + or preserve_thinking + or not self.config.clear_thinking ) and reasoning_content if include_thinking: @@ -478,13 +482,23 @@ def _render_assistant( reasoning_content.strip(), msg_idx, is_sampled=True, is_content=True ) emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) - elif self.empty_think_on_last_assistant and msg_idx > last_user_index: + elif ( + self.empty_think_on_last_assistant + and msg_idx > last_user_index + and self.config.enable_thinking + ): # GLM-5.1: wrap the last assistant with an empty # even without reasoning, matching the Jinja template. With # ``enable_thinking=True`` the gen prompt already includes # ````; the model then samples ```` to close an # empty think block. So ```` is scaffolding, # ```` is sampled. 
+ # + # When ``enable_thinking=False`` the GLM-5.1 template skips + # the opening ```` for the most-recent assistant too + # — it emits only the lone ```` separator (and the + # gen prompt likewise switches to ````). Fall + # through to the else branch below so we match. emit_special(self._think, msg_idx, is_sampled=False, is_content=False) emit_special(self._think_end, msg_idx, is_sampled=True, is_content=True) else: @@ -587,6 +601,7 @@ class GLM51Renderer(GLM5Renderer): """ empty_think_on_last_assistant = True + _config_cls = GLM51RendererConfig @staticmethod def _format_tool_spec(tool: ToolSpec) -> str: diff --git a/renderers/gpt_oss.py b/renderers/gpt_oss.py index 9939de1..f1bb04a 100644 --- a/renderers/gpt_oss.py +++ b/renderers/gpt_oss.py @@ -60,6 +60,7 @@ should_preserve_past_thinking, trim_to_turn_close, ) +from renderers.configs import GptOssRendererConfig from renderers.parsing import parse_gpt_oss @@ -121,44 +122,27 @@ class GptOssRenderer: def __init__( self, tokenizer: PreTrainedTokenizer, - *, - use_system_prompt: bool = True, - reasoning_effort: str | None = "medium", - conversation_start_date: str | None = None, - knowledge_cutoff: str | None = None, - model_identity: str | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: GptOssRendererConfig | None = None, ): """Initialise the renderer. Args: tokenizer: HuggingFace tokenizer. - use_system_prompt: When True (default), prepend the canonical - harmony SystemContent preamble. Matches HF's - apply_chat_template behaviour. - reasoning_effort: ``"low" | "medium" | "high"``. Default - ``"medium"`` (matches apply_chat_template). - conversation_start_date: Optional ISO date for the preamble. - Defaults to today's date in YYYY-MM-DD form. - knowledge_cutoff: Optional knowledge cutoff string. Harmony's - default is built into ``SystemContent.new()``. - model_identity: Optional override for the model identity line. 
+ config: Typed renderer config (see + :class:`renderers.GptOssRendererConfig`). """ self._tokenizer = tokenizer + self.config = config or GptOssRendererConfig() self._enc: HarmonyEncoding = load_harmony_encoding( HarmonyEncodingName.HARMONY_GPT_OSS ) - self._use_system_prompt = use_system_prompt - self._reasoning_effort = _reasoning_effort(reasoning_effort) - self._conversation_start_date = ( - conversation_start_date or datetime.now().strftime("%Y-%m-%d") - ) - self._knowledge_cutoff = knowledge_cutoff - self._model_identity = model_identity - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls + # Materialised harmony-enum form of reasoning_effort. + self._reasoning_effort_enum = _reasoning_effort(self.config.reasoning_effort) + # ``conversation_start_date=None`` defers to today's date — + # materialise once at construction so renders within the same + # instance use a stable date. + self._conversation_start_date_resolved = ( + self.config.conversation_start_date or datetime.now().strftime("%Y-%m-%d") ) # Cache special-token IDs for the bridge / generation-prompt path. @@ -211,17 +195,21 @@ def _prefix_content_mask( # Build the same prefix with empty instructions. 
empty_prefix_msgs: list[HarmonyMessage] = [] - if self._use_system_prompt: + if self.config.use_system_prompt: sys_content = SystemContent.new().with_reasoning_effort( - self._reasoning_effort + self._reasoning_effort_enum ) sys_content = sys_content.with_conversation_start_date( - self._conversation_start_date + self._conversation_start_date_resolved ) - if self._knowledge_cutoff is not None: - sys_content = sys_content.with_knowledge_cutoff(self._knowledge_cutoff) - if self._model_identity is not None: - sys_content = sys_content.with_model_identity(self._model_identity) + if self.config.knowledge_cutoff is not None: + sys_content = sys_content.with_knowledge_cutoff( + self.config.knowledge_cutoff + ) + if self.config.model_identity is not None: + sys_content = sys_content.with_model_identity( + self.config.model_identity + ) empty_prefix_msgs.append( HarmonyMessage.from_role_and_content(Role.SYSTEM, sys_content) ) @@ -369,17 +357,21 @@ def emit_harmony_message( None, ) prefix_msgs: list[HarmonyMessage] = [] - if self._use_system_prompt: + if self.config.use_system_prompt: sys_content = SystemContent.new().with_reasoning_effort( - self._reasoning_effort + self._reasoning_effort_enum ) sys_content = sys_content.with_conversation_start_date( - self._conversation_start_date + self._conversation_start_date_resolved ) - if self._knowledge_cutoff is not None: - sys_content = sys_content.with_knowledge_cutoff(self._knowledge_cutoff) - if self._model_identity is not None: - sys_content = sys_content.with_model_identity(self._model_identity) + if self.config.knowledge_cutoff is not None: + sys_content = sys_content.with_knowledge_cutoff( + self.config.knowledge_cutoff + ) + if self.config.model_identity is not None: + sys_content = sys_content.with_model_identity( + self.config.model_identity + ) prefix_msgs.append( HarmonyMessage.from_role_and_content(Role.SYSTEM, sys_content) ) @@ -432,8 +424,8 @@ def emit_harmony_message( should_preserve_past_thinking( messages, i, - 
preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) ) for hm in self._to_harmony_messages( diff --git a/renderers/kimi_k2.py b/renderers/kimi_k2.py index 9e08141..54d6f53 100644 --- a/renderers/kimi_k2.py +++ b/renderers/kimi_k2.py @@ -26,31 +26,29 @@ reject_assistant_in_extension, trim_to_turn_close, ) +from renderers.configs import KimiK2RendererConfig from renderers.parsing import parse_kimi_k2 _DEFAULT_SYSTEM = "You are Kimi, an AI assistant created by Moonshot AI." class KimiK2Renderer: - """Deterministic message → token renderer for Kimi K2 models.""" + """Deterministic message → token renderer for Kimi K2 models. + + Kimi K2's chat template doesn't read any thinking-related variable — + ``content`` renders verbatim with no reasoning branch. The + ``enable_thinking`` / ``preserve_*`` fields on the config are stored + for protocol uniformity with the rest of the renderer family but + have no effect on the byte-level output. + """ def __init__( self, tokenizer: PreTrainedTokenizer, - *, - enable_thinking: bool = True, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: KimiK2RendererConfig | None = None, ): - # Kimi-K2's chat template doesn't read ``reasoning_content`` for - # past assistant turns, so the override flags are no-ops. Stored - # for introspection / Protocol parity only. 
self._tokenizer = tokenizer - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or KimiK2RendererConfig() self._im_user = self._token_id("<|im_user|>") self._im_assistant = self._token_id("<|im_assistant|>") diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index b2a45e6..352a9ee 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -40,6 +40,7 @@ should_preserve_past_thinking, trim_to_turn_close, ) +from renderers.configs import KimiK25RendererConfig from renderers.parsing import parse_kimi_k2_section from renderers.qwen3_vl import ( _image_hash, @@ -562,8 +563,13 @@ class KimiK25Renderer: """Deterministic message → token renderer for Kimi K2.5 models. Renders to the same ``<|im_*|>`` format as Kimi K2 but adds: - - Generation prompt prefills ```` (enable_thinking=True, default) or - ```` (enable_thinking=False) to control thinking mode. + - Generation prompt prefills ```` (thinking=True, default) or + ```` (thinking=False) to control thinking mode. The + template's native kwarg name is ``thinking`` (not the more common + ``enable_thinking``); we mirror it on + :class:`renderers.KimiK25RendererConfig` so + ``KimiK25RendererConfig(thinking=False)`` produces the same tokens + as ``apply_chat_template(..., thinking=False)``. - Image content rendering via ``<|media_begin|>image<|media_content|>...<|media_end|>``. - TypeScript-style tool declarations instead of JSON. 
@@ -573,20 +579,13 @@ class KimiK25Renderer: def __init__( self, tokenizer: PreTrainedTokenizer, + config: KimiK25RendererConfig | None = None, *, processor: Any = None, - enable_thinking: bool = True, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, - image_cache_max: int = 256, ): self._tokenizer = tokenizer self._processor = processor - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or KimiK25RendererConfig() # Core structural tokens — all must be single special tokens in the vocab self._im_user = self._token_id("<|im_user|>") @@ -627,7 +626,6 @@ def __init__( # for Kimi (we emit a single placeholder regardless), but kept for # consistency / debugging. self._image_cache: dict[str, tuple[Any, int]] = {} - self._image_cache_max = image_cache_max @property def mm_token_type_id_map(self) -> dict[int, int]: @@ -679,7 +677,7 @@ def _process_image(self, part: dict[str, Any]): # Patch count via the processor's own calculator (matches the # model's per-patch attention count); kept for debugging. 
num_patches = int(img_proc.media_tokens_calculator(media_item)) - if len(self._image_cache) >= self._image_cache_max: + if len(self._image_cache) >= self.config.image_cache_max: self._image_cache.pop(next(iter(self._image_cache))) self._image_cache[h] = (out, num_patches) return pil, out, num_patches, h @@ -877,8 +875,8 @@ def emit_image( preserve_thinking = should_preserve_past_thinking( messages, i, - preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) self._render_assistant_body( msg, @@ -927,7 +925,7 @@ def emit_image( emit_special(self._im_assistant, -1, is_sampled=False, is_content=False) emit_text("assistant", -1, is_sampled=False, is_content=False) emit_special(self._im_middle, -1, is_sampled=False, is_content=False) - if self._enable_thinking: + if self.config.thinking: # Prefill open tag to trigger thinking mode emit_text("", -1, is_sampled=False, is_content=False) else: @@ -1161,7 +1159,7 @@ def emit_image( emit_special(self._im_assistant, -1) emit_text("assistant", -1) emit_special(self._im_middle, -1) - if self._enable_thinking: + if self.config.thinking: emit_text("", -1) else: emit_text("", -1) diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py index ce85037..9d8c0b3 100644 --- a/renderers/laguna_xs2.py +++ b/renderers/laguna_xs2.py @@ -38,6 +38,7 @@ attribute_text_segments, reject_assistant_in_extension, ) +from renderers.configs import LagunaXS2RendererConfig from renderers.parsing import parse_laguna_xs2 _DEFAULT_SYSTEM_MESSAGE = ( @@ -79,20 +80,10 @@ class LagunaXS2Renderer: def __init__( self, tokenizer: PreTrainedTokenizer, - *, - enable_thinking: bool = False, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: LagunaXS2RendererConfig | None = None, ): 
self._tokenizer = tokenizer - self._enable_thinking = enable_thinking - # Accepted for protocol uniformity. The chat template renders - # reasoning on every assistant message regardless, so flipping - # these flags has no effect on the byte-level output. - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or LagunaXS2RendererConfig() self._eos = self._token_id("〈|EOS|〉") self._think = self._token_id("") @@ -225,7 +216,7 @@ def emit_text_segments( tool_text += json.dumps(tool, ensure_ascii=False) + "\n" tool_text += ( _TOOLS_FOOTER_THINKING - if self._enable_thinking + if self.config.enable_thinking else _TOOLS_FOOTER_NO_THINKING ) emit_text(tool_text, -1, is_sampled=False, is_content=False) @@ -273,7 +264,7 @@ def emit_text_segments( if add_generation_prompt: emit_special(self._assistant, -1, is_sampled=False, is_content=False) emit_text("\n", -1, is_sampled=False, is_content=False) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1, is_sampled=False, is_content=False) else: emit_special(self._think_end, -1, is_sampled=False, is_content=False) @@ -423,7 +414,7 @@ def emit_text_segments( emit_special(self._assistant, -1) emit_text("\n", -1) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1) else: emit_special(self._think_end, -1) @@ -447,6 +438,15 @@ def _render_assistant( emit_text, emit_text_segments, ) -> None: + if self.config.render_assistant_messages_raw: + self._render_assistant_raw( + msg_idx, + content, + emit_special=emit_special, + emit_text=emit_text, + ) + return + reasoning_content = "" if isinstance(msg.get("reasoning_content"), str): reasoning_content = msg["reasoning_content"] @@ -518,3 +518,55 @@ def _render_assistant( # between turns and never sampled. 
emit_special(self._assistant_end, msg_idx, is_sampled=True, is_content=True) emit_text("\n", msg_idx, is_sampled=False, is_content=False) + + def _render_assistant_raw( + self, + msg_idx: int, + content: str, + *, + emit_special, + emit_text, + ) -> None: + """Passthrough assistant rendering matching the Jinja template's + ``render_assistant_messages_raw`` branch. + + Three pieces, each conditional on the content's own bytes: + + - Open the assistant turn (``\\n``) — always. + - Prepend the gen-prompt prefix (```` if + ``enable_thinking``, else ````) only when ``content`` + doesn't already start with it. This lets callers ship content + that already includes the prefix (e.g. raw rollouts) without + duplicating it. + - Emit ``content`` verbatim. ```` and ```` + land inside the content as added-vocab specials via the + tokenizer's default ``split_special_tokens=False`` behaviour, + matching what ``apply_chat_template`` does when it tokenises + the rendered string. + - Append ``\\n`` only when ``content`` doesn't end + with ```` (or ``\\n``), then always + emit the inter-turn ``\\n``. + + Tool calls are deliberately ignored in raw mode — the template + also ignores ``message.tool_calls`` here. Callers shipping raw + content are expected to embed any tool-call payload in the + content string themselves. 
+ """ + emit_special(self._assistant, msg_idx, is_sampled=False, is_content=False) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) + + if self.config.enable_thinking: + if not content.startswith(""): + emit_special(self._think, msg_idx, is_sampled=False, is_content=False) + else: + if not content.startswith(""): + emit_special( + self._think_end, msg_idx, is_sampled=False, is_content=False + ) + + emit_text(content, msg_idx, is_sampled=True, is_content=True) + + if not (content.endswith("\n") or content.endswith("")): + emit_text("\n", msg_idx, is_sampled=False, is_content=False) + emit_special(self._assistant_end, msg_idx, is_sampled=True, is_content=True) + emit_text("\n", msg_idx, is_sampled=False, is_content=False) diff --git a/renderers/minimax_m2.py b/renderers/minimax_m2.py index f3c26c8..39c12fa 100644 --- a/renderers/minimax_m2.py +++ b/renderers/minimax_m2.py @@ -26,12 +26,9 @@ should_preserve_past_thinking, trim_to_turn_close, ) +from renderers.configs import MiniMaxM2RendererConfig from renderers.parsing import parse_minimax -_DEFAULT_SYSTEM = ( - "You are a helpful assistant. Your name is MiniMax-M2.5 and is built by MiniMax." 
-) - _TOOLS_HEADER = ( "\n\n# Tools\n" "You may call one or more tools to assist with the user query.\n" @@ -59,17 +56,10 @@ class MiniMaxM2Renderer: def __init__( self, tokenizer: PreTrainedTokenizer, - *, - default_system: str = _DEFAULT_SYSTEM, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: MiniMaxM2RendererConfig | None = None, ): self._tokenizer = tokenizer - self._default_system = default_system - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or MiniMaxM2RendererConfig() self._bos = self._token_id("]~!b[") self._role = self._token_id("]~b]") @@ -204,7 +194,7 @@ def emit_token_overlap_body( if sys_content: sys_segments.append((sys_content, True)) else: - sys_segments.append((self._default_system, False)) + sys_segments.append((self.config.model_identity, False)) if tools: sys_segments.append((_TOOLS_HEADER, False)) @@ -249,8 +239,8 @@ def emit_token_overlap_body( preserve_thinking = should_preserve_past_thinking( messages, orig_idx, - preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) self._render_assistant( msg, diff --git a/renderers/nemotron3.py b/renderers/nemotron3.py index e97790d..0d87f8b 100644 --- a/renderers/nemotron3.py +++ b/renderers/nemotron3.py @@ -29,6 +29,7 @@ should_preserve_past_thinking, trim_to_turn_close, ) +from renderers.configs import Nemotron3RendererConfig from renderers.parsing import parse_qwen35 # --------------------------------------------------------------------------- @@ -79,17 +80,10 @@ class Nemotron3Renderer: def __init__( self, tokenizer: PreTrainedTokenizer, - *, - enable_thinking: bool = True, - preserve_all_thinking: 
bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: Nemotron3RendererConfig | None = None, ): self._tokenizer = tokenizer - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or Nemotron3RendererConfig() # Look up special token IDs from the tokenizer (not hardcoded). # <|endoftext|> is optional: Nemotron-3 Nano / Super tokenizers ship @@ -369,8 +363,8 @@ def emit_text_segments( preserve_thinking = msg_orig_idx >= 0 and should_preserve_past_thinking( original_messages, msg_orig_idx, - preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) self._render_assistant( msg, @@ -403,7 +397,7 @@ def emit_text_segments( if add_generation_prompt: emit_special(self._im_start, -1, is_sampled=False, is_content=False) emit_text("assistant\n", -1, is_sampled=False, is_content=False) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1, is_sampled=False, is_content=False) emit_text("\n", -1, is_sampled=False, is_content=False) else: @@ -572,7 +566,7 @@ def emit_text_segments( # Generation prompt. emit_special(self._im_start, -1) emit_text("assistant\n", -1) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1) emit_text("\n", -1) else: @@ -641,7 +635,11 @@ def _render_assistant( # , whether the content is empty or not. 
content_suffix = "\n" if tool_calls else "" - if reasoning_content and (is_last_turn or preserve_thinking): + if reasoning_content and ( + is_last_turn + or preserve_thinking + or not self.config.truncate_history_thinking + ): emit_special(self._think, msg_idx, is_sampled=True, is_content=True) emit_text( "\n" + reasoning_content + "\n", diff --git a/renderers/qwen3.py b/renderers/qwen3.py index 4562546..fe97561 100644 --- a/renderers/qwen3.py +++ b/renderers/qwen3.py @@ -23,6 +23,7 @@ should_preserve_past_thinking, trim_to_turn_close, ) +from renderers.configs import Qwen3RendererConfig from renderers.parsing import parse_qwen3 _TOOLS_HEADER = ( @@ -48,17 +49,10 @@ class Qwen3Renderer: def __init__( self, tokenizer: PreTrainedTokenizer, - *, - enable_thinking: bool = True, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, + config: Qwen3RendererConfig | None = None, ): self._tokenizer = tokenizer - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or Qwen3RendererConfig() self._im_start = self._token_id("<|im_start|>") self._im_end = self._token_id("<|im_end|>") @@ -213,8 +207,8 @@ def emit_text_segments( preserve_thinking = should_preserve_past_thinking( messages, i, - preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) self._render_assistant( msg, @@ -242,7 +236,7 @@ def emit_text_segments( if add_generation_prompt: emit_special(self._im_start, -1, is_sampled=False, is_content=False) emit_text("assistant\n", -1, is_sampled=False, is_content=False) - if not self._enable_thinking: + if not self.config.enable_thinking: emit_text( 
"\n\n\n\n", -1, is_sampled=False, is_content=False ) @@ -399,7 +393,7 @@ def emit_text_segments( emit_special(self._im_start, -1) emit_text("assistant\n", -1) - if not self._enable_thinking: + if not self.config.enable_thinking: emit_text("\n\n\n\n", -1) total_len = len(previous_ids) + len(ext) diff --git a/renderers/qwen35.py b/renderers/qwen35.py index 2deefcf..abcacec 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -31,6 +31,7 @@ should_preserve_past_thinking, trim_to_turn_close, ) +from renderers.configs import Qwen35RendererConfig from renderers.parsing import parse_qwen35 from renderers.qwen3_vl import ( _image_hash, @@ -103,25 +104,27 @@ def _detect_enable_thinking_default(tokenizer: PreTrainedTokenizer) -> bool: class Qwen35Renderer: """Deterministic message → token renderer for Qwen3.5 models.""" + _config_cls: type = Qwen35RendererConfig + def __init__( self, tokenizer: PreTrainedTokenizer, + config: Qwen35RendererConfig | None = None, *, processor: Any = None, - enable_thinking: bool | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, - image_cache_max: int = 256, ): self._tokenizer = tokenizer self._processor = processor - if enable_thinking is None: - enable_thinking = _detect_enable_thinking_default(tokenizer) - self._enable_thinking = enable_thinking - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + cfg = config or type(self)._config_cls() + # ``enable_thinking=None`` defers to the tokenizer's chat-template + # default (Instruct → off, Thinking → on). Materialise here so + # downstream reads see a concrete bool; rebind the config with + # the resolved value so introspection sees the same. 
+ if cfg.enable_thinking is None: + cfg = cfg.model_copy( + update={"enable_thinking": _detect_enable_thinking_default(tokenizer)} + ) + self.config = cfg # Look up special token IDs from the tokenizer (not hardcoded) self._im_start = self._token_id("<|im_start|>") @@ -142,7 +145,6 @@ def __init__( # rationale (FIFO-bounded; same image seen across rollouts / # bridge re-renders). self._image_cache: dict[str, tuple[Any, int]] = {} - self._image_cache_max = image_cache_max @property def mm_token_type_id_map(self) -> dict[int, int]: @@ -187,7 +189,7 @@ def _process_image(self, part: dict[str, Any]): grid_thw = out["image_grid_thw"][0] merge_size = proc.image_processor.merge_size num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self._image_cache_max: + if len(self._image_cache) >= self.config.image_cache_max: self._image_cache.pop(next(iter(self._image_cache))) self._image_cache[h] = (out, num_image_tokens) return pil, out, num_image_tokens, h @@ -296,6 +298,13 @@ def render( mm_hashes: dict[str, list[str]] = {} mm_placeholders: dict[str, list[PlaceholderRange]] = {} mm_items: dict[str, list[dict[str, Any]]] = {} + # 1-indexed counters for ``add_vision_id`` (mirrors the Jinja's + # ``image_count`` / ``video_count`` namespaces). Increment only + # in the main message loop — the template renders the system + # message with ``do_vision_count=False`` and would raise on + # vision in system content anyway, so the renderer's + # ``emit_image`` is only reached from user / tool emission paths. + vision_counts = {"image": 0, "video": 0} def emit_ids( ids: list[int], msg_idx: int, *, is_sampled: bool, is_content: bool @@ -350,6 +359,14 @@ def emit_image(part: dict[str, Any], msg_idx: int) -> None: # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # specials are template scaffold. 
_, out, n, h = self._process_image(part) + vision_counts["image"] += 1 + if self.config.add_vision_id: + emit_text( + f"Picture {vision_counts['image']}: ", + msg_idx, + is_sampled=False, + is_content=False, + ) emit_special( self._vision_start, msg_idx, is_sampled=False, is_content=False ) @@ -488,8 +505,8 @@ def flush_buf() -> None: preserve_thinking = should_preserve_past_thinking( messages, i, - preserve_all_thinking=self._preserve_all_thinking, - preserve_thinking_between_tool_calls=self._preserve_thinking_between_tool_calls, + preserve_all_thinking=self.config.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.config.preserve_thinking_between_tool_calls, ) self._render_assistant( msg, @@ -520,7 +537,7 @@ def flush_buf() -> None: if add_generation_prompt: emit_special(self._im_start, -1, is_sampled=False, is_content=False) emit_text("assistant\n", -1, is_sampled=False, is_content=False) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1, is_sampled=False, is_content=False) emit_text("\n", -1, is_sampled=False, is_content=False) else: @@ -604,6 +621,23 @@ def bridge_to_next_turn( if previous_ids is None: return None + # ``add_vision_id`` numbers placeholders across the whole + # conversation. The bridge can only seed that counter from + # ``previous_multi_modal_data`` (raw prior token ids don't carry + # the image/video count back), so if the caller asks for + # ``add_vision_id=True`` while omitting prior mm-data on a + # conversation that already contains images, the bridged + # output would silently emit ``Picture 1:`` again. Refuse the + # bridge in that case — the caller falls back to a full + # re-render, which has the full message list and counts from + # scratch correctly. 
+ if ( + self.config.add_vision_id + and previous_multi_modal_data is None + and self._vision_start in previous_ids + ): + return None + # Seed combined-token list with prior turn so placeholder offsets # are absolute in the bridged sequence (matching ``render()``). # Parallel ``indices``/``sampled`` are seeded with ``-1``/``False`` @@ -622,6 +656,17 @@ def bridge_to_next_turn( new_hashes: dict[str, list[str]] = {} new_placeholders: dict[str, list[PlaceholderRange]] = {} new_items: dict[str, list[dict[str, Any]]] = {} + # Seed the ``add_vision_id`` counters from prior-turn images / videos + # so the bridged turn's first placeholder gets ``Picture {prev+1}``. + # Bridges can't recover the count from raw token ids, so callers + # must thread ``previous_multi_modal_data`` through to keep + # ``add_vision_id`` parity across turns. + prev_image_count = 0 + prev_video_count = 0 + if previous_multi_modal_data is not None: + prev_image_count = len(previous_multi_modal_data.mm_items.get("image", [])) + prev_video_count = len(previous_multi_modal_data.mm_items.get("video", [])) + vision_counts = {"image": prev_image_count, "video": prev_video_count} def emit_special( token_id: int, @@ -664,6 +709,9 @@ def emit_text_segments( def emit_image(part: dict[str, Any], msg_idx: int = -1) -> None: _, out, n, h = self._process_image(part) + vision_counts["image"] += 1 + if self.config.add_vision_id: + emit_text(f"Picture {vision_counts['image']}: ", msg_idx) emit_special(self._vision_start, msg_idx) offset = len(tokens) for _ in range(n): @@ -755,7 +803,7 @@ def flush_buf() -> None: # Generation prompt — matches the gen-prompt branch of ``render()``. 
emit_special(self._im_start, -1) emit_text("assistant\n", -1) - if self._enable_thinking: + if self.config.enable_thinking: emit_special(self._think, -1) emit_text("\n", -1) else: diff --git a/renderers/qwen36.py b/renderers/qwen36.py index 5848194..6adf867 100644 --- a/renderers/qwen36.py +++ b/renderers/qwen36.py @@ -7,12 +7,12 @@ ``None`` as ``null`` (not ``None``), fixing the single-turn extension-break mode where a boolean parameter's case drifted across a re-render. -The template's other delta vs Qwen3.5 (a ``preserve_thinking`` toggle that -flips historical ```` retention on or off globally) is no longer -exposed as a constructor kwarg — its default-False behaviour matches -Qwen3.5 and is now baked in. Callers who want the toggled-on behaviour -pass ``preserve_all_thinking=True`` to ``create_renderer``, the -renderer-agnostic spelling of the same intent. +Historical-thinking retention follows Qwen3.5's default (drop past +```` blocks). The upstream template carries a ``preserve_thinking`` +Jinja toggle for the opposite polarity; on the renderer side that intent +maps to the renderer-agnostic ``preserve_all_thinking`` / +``preserve_thinking_between_tool_calls`` flags on +:class:`renderers.Qwen36RendererConfig`. Everything else — tool system prompt, tool-call XML structure, thinking markers, bridge logic, parser — is identical to Qwen3.5. 
@@ -23,12 +23,15 @@ import json from typing import Any +from renderers.configs import Qwen36RendererConfig from renderers.qwen35 import Qwen35Renderer class Qwen36Renderer(Qwen35Renderer): """Deterministic message → token renderer for Qwen3.6 models.""" + _config_cls = Qwen36RendererConfig + @staticmethod def _render_arg_value(arg_value: Any) -> str: if isinstance(arg_value, str): diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 94ae13d..7287159 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -46,6 +46,7 @@ reject_assistant_in_extension, trim_to_turn_close, ) +from renderers.configs import Qwen3VLRendererConfig from renderers.parsing import parse_qwen3 _TOOLS_HEADER = ( @@ -291,35 +292,30 @@ class Qwen3VLRenderer: Constructor args: tokenizer: HF tokenizer for the model. + config: Typed renderer config (see + :class:`renderers.Qwen3VLRendererConfig`). Defaults to a + blank config with template defaults. processor: Optional ``Qwen3VLProcessor``. Required when rendering messages that contain image / video parts. If not supplied, the renderer lazy-loads it via ``AutoProcessor.from_pretrained`` keyed off ``tokenizer.name_or_path`` the first time a multimodal part is seen. - preserve_all_thinking / preserve_thinking_between_tool_calls: - No-ops on Qwen3-VL — the chat template already drops past - ```` blocks unconditionally. Stored for Protocol parity. - image_cache_max: Max entries in the per-instance image-processor - cache (FIFO eviction). Default 256 covers typical RL pools - (``rollouts_per_example`` × in-flight examples). Bump for runs - with large image sets where the working set exceeds the cap. + + ``preserve_all_thinking`` / ``preserve_thinking_between_tool_calls`` + on the config are no-ops here — the chat template drops past + ```` blocks unconditionally. Stored for Protocol parity. 
""" def __init__( self, tokenizer: PreTrainedTokenizer, + config: Qwen3VLRendererConfig | None = None, *, processor: Any = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, - image_cache_max: int = 256, ): self._tokenizer = tokenizer self._processor = processor - self._preserve_all_thinking = preserve_all_thinking - self._preserve_thinking_between_tool_calls = ( - preserve_thinking_between_tool_calls - ) + self.config = config or Qwen3VLRendererConfig() self._im_start = self._token_id("<|im_start|>") self._im_end = self._token_id("<|im_end|>") @@ -342,7 +338,6 @@ def __init__( # tuples of ``(processor_out, num_image_tokens)`` — bounded to # avoid unbounded growth on long-lived pools. self._image_cache: dict[str, tuple[Any, int]] = {} - self._image_cache_max = image_cache_max def _token_id(self, token: str) -> int: tid = self._tokenizer.convert_tokens_to_ids(token) @@ -431,7 +426,7 @@ def _process_image(self, part: dict[str, Any]): grid_thw = out["image_grid_thw"][0] merge_size = proc.image_processor.merge_size num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) - if len(self._image_cache) >= self._image_cache_max: + if len(self._image_cache) >= self.config.image_cache_max: # FIFO eviction — Python dicts preserve insertion order, so # ``next(iter(...))`` is the oldest key. self._image_cache.pop(next(iter(self._image_cache))) @@ -452,6 +447,12 @@ def render( mm_hashes: dict[str, list[str]] = {} mm_placeholders: dict[str, list[PlaceholderRange]] = {} mm_items: dict[str, list[dict[str, Any]]] = {} + # ``add_vision_id`` mirrors the Jinja's ``image_count`` / + # ``video_count`` namespaces. Counters are 1-indexed and run + # across the entire conversation; they increment unconditionally + # on each image / video (the Qwen3-VL template increments first, + # then emits ``Picture N: `` only when ``add_vision_id`` is set). 
+ vision_counts = {"image": 0, "video": 0} def emit_image(part: dict[str, Any]) -> None: # Image placeholders are prompt-side scaffolding the user @@ -462,6 +463,13 @@ def emit_image(part: dict[str, Any]) -> None: # the surrounding ``<|vision_start|>`` / ``<|vision_end|>`` # markers are renderer-emitted scaffold. _, out, n, h = self._process_image(part) + vision_counts["image"] += 1 + if self.config.add_vision_id: + em.text( + f"Picture {vision_counts['image']}: ", + is_sampled=False, + is_content=False, + ) em.special(self._vision_start, is_sampled=False, is_content=False) offset = em.cursor() for _ in range(n): @@ -663,6 +671,23 @@ def bridge_to_next_turn( if previous_ids is None: return None + # ``add_vision_id`` numbers placeholders across the whole + # conversation. The bridge can only seed that counter from + # ``previous_multi_modal_data`` (raw prior token ids don't carry + # the image/video count back), so if the caller asks for + # ``add_vision_id=True`` while omitting prior mm-data on a + # conversation that already contains images, the bridged + # output would silently emit ``Picture 1:`` again. Refuse the + # bridge in that case — the caller falls back to a full + # re-render, which has the full message list and counts + # correctly. + if ( + self.config.add_vision_id + and previous_multi_modal_data is None + and self._vision_start in previous_ids + ): + return None + # Bridge populates ``message_indices`` (relative to ``new_messages``) # and ``sampled_mask`` (uniformly ``False`` — every token the # bridge emits is template scaffolding for the next prompt, not @@ -685,9 +710,30 @@ def bridge_to_next_turn( new_hashes: dict[str, list[str]] = {} new_placeholders: dict[str, list[PlaceholderRange]] = {} new_items: dict[str, list[dict[str, Any]]] = {} + # Seed the vision counters from any prior-turn images / videos + # the bridge was handed via ``previous_multi_modal_data``. 
The + # ``add_vision_id`` template numbers placeholders across the + # whole conversation, so a new turn's first image is + # ``Picture {prev_total + 1}``. The bridge can't recover this + # count from raw token ids, so callers must thread + # ``previous_multi_modal_data`` through when they want + # ``add_vision_id`` parity across turns. + prev_image_count = 0 + prev_video_count = 0 + if previous_multi_modal_data is not None: + prev_image_count = len(previous_multi_modal_data.mm_items.get("image", [])) + prev_video_count = len(previous_multi_modal_data.mm_items.get("video", [])) + vision_counts = {"image": prev_image_count, "video": prev_video_count} def emit_image(part: dict[str, Any]) -> None: _, out, n, h = self._process_image(part) + vision_counts["image"] += 1 + if self.config.add_vision_id: + em.text( + f"Picture {vision_counts['image']}: ", + is_sampled=False, + is_content=False, + ) em.special(self._vision_start, is_sampled=False, is_content=False) offset = em.cursor() for _ in range(n): diff --git a/tests/conftest.py b/tests/conftest.py index 8eea97b..c334430 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ from renderers import create_renderer from renderers.base import load_tokenizer +from renderers.configs import config_from_name # (HuggingFace model name, renderer name or "auto") # @@ -44,7 +45,7 @@ def _load(model_name: str, renderer_name: str): key = f"{model_name}:{renderer_name}" if key not in _cache: tokenizer = load_tokenizer(model_name) - renderer = create_renderer(tokenizer, renderer=renderer_name) + renderer = create_renderer(tokenizer, config_from_name(renderer_name)) _cache[key] = (tokenizer, renderer) return _cache[key] diff --git a/tests/test_bridge.py b/tests/test_bridge.py index c5e7c24..81ff2e4 100644 --- a/tests/test_bridge.py +++ b/tests/test_bridge.py @@ -41,9 +41,10 @@ def _load(model_name: str, renderer_name: str): from renderers import create_renderer from renderers.base import load_tokenizer + from 
renderers.configs import config_from_name tok = load_tokenizer(model_name) - return tok, create_renderer(tok, renderer=renderer_name) + return tok, create_renderer(tok, config_from_name(renderer_name)) def pytest_generate_tests(metafunc): diff --git a/tests/test_gpt_oss_harmony_parity.py b/tests/test_gpt_oss_harmony_parity.py index f1a8f17..ea52af6 100644 --- a/tests/test_gpt_oss_harmony_parity.py +++ b/tests/test_gpt_oss_harmony_parity.py @@ -33,6 +33,7 @@ ToolDescription, load_harmony_encoding, ) +from renderers.configs import GptOssRendererConfig from renderers.gpt_oss import GptOssRenderer from transformers import AutoTokenizer @@ -49,7 +50,9 @@ def tokenizer(): def renderer(tokenizer): # Pin the date so the rendered preamble matches the harmony oracle # built with the same fixed date. - return GptOssRenderer(tokenizer, conversation_start_date=DATE_FOR_PARITY) + return GptOssRenderer( + tokenizer, GptOssRendererConfig(conversation_start_date=DATE_FOR_PARITY) + ) @pytest.fixture(scope="module") diff --git a/tests/test_message_indices.py b/tests/test_message_indices.py index 37da9ce..b66efb7 100644 --- a/tests/test_message_indices.py +++ b/tests/test_message_indices.py @@ -99,7 +99,7 @@ def test_kimi_k2_unknown_role_message_indices(): from renderers.base import load_tokenizer tok = load_tokenizer("moonshotai/Kimi-K2-Instruct") - renderer = create_renderer(tok, renderer="auto") + renderer = create_renderer(tok) msgs = [ {"role": "user", "content": "hi"}, diff --git a/tests/test_multimodal.py b/tests/test_multimodal.py index 118d401..28984e4 100644 --- a/tests/test_multimodal.py +++ b/tests/test_multimodal.py @@ -33,7 +33,16 @@ Qwen3VLRenderer, create_renderer, ) -from renderers.base import load_tokenizer +from renderers.base import MODEL_RENDERER_MAP, load_tokenizer +from renderers.configs import _config_class_for + + +def _config_with_add_vision_id(model_name: str, add_vision_id: bool): + """Build the typed config for ``model_name`` (resolved via + 
``MODEL_RENDERER_MAP``) with ``add_vision_id`` set. The qwen_vl + family — Qwen3.5 and Qwen3-VL — both expose this field.""" + renderer_name = MODEL_RENDERER_MAP[model_name] + return _config_class_for(renderer_name)(add_vision_id=add_vision_id) pytest.importorskip("PIL", reason="Pillow required for multimodal tests") @@ -111,7 +120,7 @@ def _load_processor_and_renderer(model_name: str): ) else: processor = AutoProcessor.from_pretrained(model_name) - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) # Inject processor so the renderer doesn't try to fetch it lazily. if hasattr(renderer, "_processor") and renderer._processor is None: renderer._processor = processor @@ -636,7 +645,7 @@ def test_modality_registry_models_route_to_renderer(): if not _hf_snapshot_cached(model_name): continue tokenizer = load_tokenizer(model_name) - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) # We expect a hand-coded VL renderer, not the default fallback. assert not type(renderer).__name__.startswith("Default"), ( f"{model_name} routed to DefaultRenderer despite being in " @@ -682,6 +691,176 @@ def test_tool_response_image_byte_parity(mm_model_name, modality, tiny_image): ) +def _qwen_vl_processor_input_ids_with_kwargs( + processor, messages, add_gp, **template_kwargs +): + """Variant of ``_qwen_vl_processor_input_ids`` that forwards + ``template_kwargs`` to ``apply_chat_template`` so the parity oracle + can exercise the same typed-config template field the renderer was + constructed with (e.g. ``add_vision_id=True``). 
+ """ + text = processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=add_gp, + **template_kwargs, + ) + images = [] + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict): + continue + if ( + item.get("type") in ("image", "image_url") + or "image" in item + or "image_url" in item + ): + if "image" in item and not isinstance(item["image"], dict): + images.append(item["image"]) + return processor(images=images, text=text, return_tensors="pt")["input_ids"][ + 0 + ].tolist() + + +# ``add_vision_id`` is exposed on the Qwen-VL family renderers +# (Qwen3.5 / Qwen3.6 / Qwen3-VL) per the chat-template audit. Kimi K2.5 +# / K2.6's template has no equivalent toggle, so it's intentionally +# absent from ``KimiK25RendererConfig`` and skipped here. +_ADD_VISION_ID_CASES = [ + (m, mo) for m, mo in _CASES if mo == "image" and _detect_family(m) == "qwen_vl" +] + + +@pytest.mark.parametrize( + "mm_model_name,modality", + _ADD_VISION_ID_CASES, + ids=[f"{m}|{mo}" for m, mo in _ADD_VISION_ID_CASES], +) +@pytest.mark.parametrize("add_vision_id", [True, False]) +def test_add_vision_id_parity_vs_processor( + mm_model_name, modality, add_vision_id, tiny_image +): + """Parity for ``add_vision_id`` across image-bearing shapes. + + When True, the renderer must prefix each image / video placeholder + with ``Picture N: `` / ``Video N: `` matching the Jinja template's + ``image_count`` / ``video_count`` namespaces. When False, the + prefix is suppressed entirely. Both branches must reproduce + ``processor.apply_chat_template(messages, add_vision_id=)`` + token-for-token after image expansion. 
+ """ + if not _hf_snapshot_cached(mm_model_name): + pytest.skip(f"{mm_model_name}: HF snapshot not cached locally") + + kit = _modality_kit(modality, mm_model_name) + tokenizer, processor, _ = _load_processor_and_renderer(mm_model_name) + # Build a fresh renderer for the kwarg under test (the shared + # fixture has ``add_vision_id=False`` baked in). + renderer = create_renderer( + tokenizer, + _config_with_add_vision_id(mm_model_name, add_vision_id), + ) + if hasattr(renderer, "_processor") and renderer._processor is None: + renderer._processor = processor + + for case in _build_cases(kit["make_part"], tiny_image): + messages, add_gp = case.values + ours = renderer.render_ids(messages, add_generation_prompt=add_gp) + theirs = _qwen_vl_processor_input_ids_with_kwargs( + processor, messages, add_gp, add_vision_id=add_vision_id + ) + assert ours == theirs, ( + f"{mm_model_name} / add_vision_id={add_vision_id} / " + f"case={case.id}: renderer diverges from processor.\n" + f" ours[:80]={ours[:80]}\n theirs[:80]={theirs[:80]}\n" + f" len(ours)={len(ours)} len(theirs)={len(theirs)}" + ) + + +@pytest.mark.parametrize( + "mm_model_name,modality", + _ADD_VISION_ID_CASES, + ids=[f"{m}|{mo}" for m, mo in _ADD_VISION_ID_CASES], +) +def test_bridge_refuses_when_add_vision_id_loses_prior_count( + mm_model_name, modality, tiny_image +): + """When ``add_vision_id=True``, the bridge needs the prior turn's + image / video count to keep the ``Picture N:`` numbering correct. + The only source of that count for the bridged turn is + ``previous_multi_modal_data``; raw prior token ids don't carry it + back unambiguously (``<|vision_start|>`` is shared between image + and video placeholders). + + If a caller omits ``previous_multi_modal_data`` on a conversation + that already contains images, naively continuing the bridge would + emit ``Picture 1:`` again for the new turn — diverging from + ``apply_chat_template`` and a full re-render. 
The bridge must + refuse (return None) so the caller falls back to a full re-render. + """ + if not _hf_snapshot_cached(mm_model_name): + pytest.skip(f"{mm_model_name}: HF snapshot not cached locally") + + kit = _modality_kit(modality, mm_model_name) + tokenizer, processor, _ = _load_processor_and_renderer(mm_model_name) + renderer = create_renderer( + tokenizer, + _config_with_add_vision_id(mm_model_name, True), + ) + if hasattr(renderer, "_processor") and renderer._processor is None: + renderer._processor = processor + + initial = [ + { + "role": "user", + "content": [ + kit["make_part"](tiny_image), + {"type": "text", "text": "Turn one."}, + ], + } + ] + new_messages = [ + { + "role": "user", + "content": [ + kit["make_part"](tiny_image), + {"type": "text", "text": "Turn two."}, + ], + } + ] + + initial_rendered = renderer.render(initial, add_generation_prompt=True) + im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + completion_ids = tokenizer.encode("Saw it.", add_special_tokens=False) + [im_end_id] + + # No previous_multi_modal_data → bridge must refuse so the caller + # falls back to a full re-render (where the counter restarts from + # the full message list and lands on Picture 2: correctly). + bridged = renderer.bridge_to_next_turn( + previous_prompt_ids=initial_rendered.token_ids, + previous_completion_ids=completion_ids, + new_messages=new_messages, + ) + assert bridged is None, ( + f"{mm_model_name}: bridge should refuse when add_vision_id=True " + "and previous_multi_modal_data is omitted but prior contains images" + ) + + # With the prior mm_data threaded through, the bridge proceeds. 
+ bridged_ok = renderer.bridge_to_next_turn( + previous_prompt_ids=initial_rendered.token_ids, + previous_completion_ids=completion_ids, + new_messages=new_messages, + previous_multi_modal_data=initial_rendered.multi_modal_data, + ) + assert bridged_ok is not None, ( + f"{mm_model_name}: bridge unexpectedly refused even with previous_multi_modal_data" + ) + + def test_qwen3_vl_renderer_exposes_image_modality(): """The flagship multimodal renderer is concretely Qwen3VLRenderer. @@ -692,7 +871,7 @@ def test_qwen3_vl_renderer_exposes_image_modality(): if not _hf_snapshot_cached(model): pytest.skip(f"{model}: HF snapshot not cached locally") tokenizer = load_tokenizer(model) - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) assert isinstance(renderer, Qwen3VLRenderer) assert "image" in MULTIMODAL_MODELS[model] diff --git a/tests/test_parse_response.py b/tests/test_parse_response.py index bc17544..9039589 100644 --- a/tests/test_parse_response.py +++ b/tests/test_parse_response.py @@ -13,7 +13,7 @@ @lru_cache def _qwen3_vl(): tokenizer = load_tokenizer("Qwen/Qwen3-VL-4B-Instruct") - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) return tokenizer, renderer @@ -101,7 +101,7 @@ def test_qwen3_vl_malformed_tool_call_surfaces_as_invalid_json(): @lru_cache def _kimi_k25(): tokenizer = load_tokenizer("moonshotai/Kimi-K2.5") - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) return tokenizer, renderer diff --git a/tests/test_parse_response_robustness.py b/tests/test_parse_response_robustness.py index 1824da0..0f99a25 100644 --- a/tests/test_parse_response_robustness.py +++ b/tests/test_parse_response_robustness.py @@ -133,8 +133,7 @@ def test_tool_calls_is_list_of_parsed_tool_call(model_name, tokenizer, renderer) Empty list = "model did not emit any tool calls". 
A list with non-OK entries = "model tried and the parser caught the failure"; those are - deliberately preserved so verifier / RL-loss code can see them. This - replaces the older list-or-None convention. + deliberately preserved so verifier / RL-loss code can see them. """ text = "Hello!" ids = tokenizer.encode(text, add_special_tokens=False) diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 3ec5bb5..1204ec9 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -114,11 +114,12 @@ def test_think_reasoning_parser_no_block(): def test_default_renderer_uses_parsers(): """DefaultRenderer + parsers should extract tool calls and reasoning.""" - from renderers import create_renderer + from renderers import DefaultRendererConfig, create_renderer tok = load_tokenizer("Qwen/Qwen3-0.6B") renderer = create_renderer( - tok, renderer="default", tool_parser="qwen3", reasoning_parser="think" + tok, + DefaultRendererConfig(tool_parser="qwen3", reasoning_parser="think"), ) assert renderer.supports_tools is True @@ -134,10 +135,10 @@ def test_default_renderer_uses_parsers(): def test_default_renderer_without_parsers_is_backward_compatible(): """Without parsers, DefaultRenderer still does basic extraction.""" - from renderers import create_renderer + from renderers import DefaultRendererConfig, create_renderer tok = load_tokenizer("Qwen/Qwen3-0.6B") - renderer = create_renderer(tok, renderer="default") + renderer = create_renderer(tok, DefaultRendererConfig()) assert renderer.supports_tools is False ids = tok.encode("ra", add_special_tokens=False) diff --git a/tests/test_preserve_thinking.py b/tests/test_preserve_thinking.py index 661d577..1ef07f5 100644 --- a/tests/test_preserve_thinking.py +++ b/tests/test_preserve_thinking.py @@ -1,9 +1,10 @@ """Smoke coverage for the ``preserve_*_thinking`` override flags. 
-Flags are constructor-only (``create_renderer(..., preserve_all_thinking=True)``) -and stored as instance attributes — there is no call-site ``render`` / -``render_ids`` override. Each test that wants a non-default flag builds -a fresh renderer for that configuration via ``_make`` below. +Flags live on the typed renderer config (e.g. +``Qwen3RendererConfig(preserve_all_thinking=True)``) and are stored on +the renderer as ``self.config.preserve_*``. Each test that wants a +non-default flag builds a fresh renderer for that configuration via +``_make`` below. Two invariants per renderer: @@ -22,15 +23,22 @@ from __future__ import annotations import pytest +from pydantic import ValidationError from renderers import create_renderer -from renderers.base import should_preserve_past_thinking +from renderers.base import MODEL_RENDERER_MAP, should_preserve_past_thinking +from renderers.configs import _config_class_for def _make(tokenizer, renderer_name, **flags): """Build a fresh renderer with the given preserve_*_thinking flags bound at construction. Reuses the cached tokenizer fixture.""" - return create_renderer(tokenizer, renderer=renderer_name, **flags) + if renderer_name == "auto": + renderer_name = MODEL_RENDERER_MAP.get( + getattr(tokenizer, "name_or_path", ""), "default" + ) + config = _config_class_for(renderer_name)(**flags) + return create_renderer(tokenizer, config) # Renderers whose template doesn't drop past-asst thinking or has no @@ -379,17 +387,19 @@ def test_default_renderer_raises_on_flags(): """``DefaultRenderer`` falls back to apply_chat_template with no selective re-emit pathway, so constructing one with either flag set must raise — fail fast, before any render is attempted.""" + from renderers import DefaultRendererConfig from renderers.base import load_tokenizer tok = load_tokenizer("Qwen/Qwen2.5-0.5B-Instruct") # No flags → constructs cleanly. 
- create_renderer(tok, renderer="default") + create_renderer(tok, DefaultRendererConfig()) # Either flag set → raises at construction. with pytest.raises(NotImplementedError): - create_renderer(tok, renderer="default", preserve_all_thinking=True) + create_renderer(tok, DefaultRendererConfig(preserve_all_thinking=True)) with pytest.raises(NotImplementedError): create_renderer( - tok, renderer="default", preserve_thinking_between_tool_calls=True + tok, + DefaultRendererConfig(preserve_thinking_between_tool_calls=True), ) @@ -399,31 +409,27 @@ def test_default_renderer_raises_on_flags(): def test_create_renderer_records_flag_state(model_name, renderer_name, tokenizer): - """Each renderer exposes the bound flag state via ``_preserve_*`` - attributes — useful for downstream code (pool cache keys, logging, - test assertions) that needs to confirm what was constructed.""" + """Each renderer exposes the bound flag state via ``self.config`` — + useful for downstream code (pool cache keys, logging, test + assertions) that needs to confirm what was constructed.""" from renderers.default import DefaultRenderer - bare = create_renderer(tokenizer, renderer=renderer_name) - assert bare._preserve_all_thinking is False - assert bare._preserve_thinking_between_tool_calls is False + bare = _make(tokenizer, renderer_name) + assert bare.config.preserve_all_thinking is False + assert bare.config.preserve_thinking_between_tool_calls is False if not isinstance(bare, DefaultRenderer): # DefaultRenderer raises at construction with either flag set — # covered by ``test_default_renderer_raises_on_flags``. 
- all_on = create_renderer( - tokenizer, renderer=renderer_name, preserve_all_thinking=True - ) - assert all_on._preserve_all_thinking is True - assert all_on._preserve_thinking_between_tool_calls is False + all_on = _make(tokenizer, renderer_name, preserve_all_thinking=True) + assert all_on.config.preserve_all_thinking is True + assert all_on.config.preserve_thinking_between_tool_calls is False - btc_on = create_renderer( - tokenizer, - renderer=renderer_name, - preserve_thinking_between_tool_calls=True, + btc_on = _make( + tokenizer, renderer_name, preserve_thinking_between_tool_calls=True ) - assert btc_on._preserve_all_thinking is False - assert btc_on._preserve_thinking_between_tool_calls is True + assert btc_on.config.preserve_all_thinking is False + assert btc_on.config.preserve_thinking_between_tool_calls is True # --------------------------------------------------------------------------- @@ -431,30 +437,31 @@ def test_create_renderer_records_flag_state(model_name, renderer_name, tokenizer # --------------------------------------------------------------------------- -def test_glm5_constructor_rejects_clear_thinking(): - """``clear_thinking`` was a chat-template-kwarg pass-through. It is - superseded by the renderer-agnostic ``preserve_all_thinking`` override - and must no longer be accepted by the constructor — its default-True - semantics are now baked into the render gate.""" +def test_glm5_config_accepts_clear_thinking(): + """``clear_thinking`` is a chat-template field on GLM-5's typed + config. The GLM-5 / GLM-5.1 Jinja templates gate historical + reasoning on ``clear_thinking is defined and not clear_thinking``, + so passing ``clear_thinking=False`` here must reach the renderer's + historical-reasoning gate. 
Parity vs ``apply_chat_template`` is + asserted in ``test_renderer_config_parity``.""" + from renderers import GLM5RendererConfig from renderers.base import load_tokenizer from renderers.glm5 import GLM5Renderer tok = load_tokenizer("zai-org/GLM-5") - with pytest.raises(TypeError): - GLM5Renderer(tok, clear_thinking=True) # type: ignore[call-arg] - with pytest.raises(TypeError): - GLM5Renderer(tok, clear_thinking=False) # type: ignore[call-arg] + # Both values must be accepted without raising. + GLM5Renderer(tok, GLM5RendererConfig(clear_thinking=True)) + GLM5Renderer(tok, GLM5RendererConfig(clear_thinking=False)) -def test_qwen36_constructor_rejects_preserve_thinking(): +def test_qwen36_config_rejects_unknown_field(): """``preserve_thinking`` on Qwen3.6 was a chat-template-kwarg - pass-through. It is superseded by the renderer-agnostic - ``preserve_all_thinking`` override and must no longer be accepted by - the constructor — its default-False semantics are now inherited from - Qwen3.5's render gate.""" - from renderers.base import load_tokenizer - from renderers.qwen36 import Qwen36Renderer - - tok = load_tokenizer("Qwen/Qwen3.6-35B-A3B") - with pytest.raises(TypeError): - Qwen36Renderer(tok, preserve_thinking=True) # type: ignore[call-arg] + pass-through in an earlier revision. It is superseded by the + renderer-agnostic ``preserve_all_thinking`` override and must not + appear on the typed config — its default-False semantics are now + inherited from Qwen3.5's render gate. 
``extra="forbid"`` on the + pydantic model enforces this at construction.""" + from renderers import Qwen36RendererConfig + + with pytest.raises(ValidationError, match="preserve_thinking"): + Qwen36RendererConfig(preserve_thinking=True) # type: ignore[call-arg] diff --git a/tests/test_qwen35_size_coverage.py b/tests/test_qwen35_size_coverage.py index c33c9d8..6bb1161 100644 --- a/tests/test_qwen35_size_coverage.py +++ b/tests/test_qwen35_size_coverage.py @@ -19,7 +19,7 @@ import pytest -from renderers import Qwen35Renderer, create_renderer +from renderers import Qwen35Renderer, Qwen35RendererConfig, create_renderer from renderers.base import MODEL_RENDERER_MAP, load_tokenizer @@ -78,11 +78,11 @@ def test_qwen35_enable_thinking_polarity_autodetected(qwen35_model, expected_def own default when no explicit flag is passed — so big / small sizes each match their own template at the gen-prompt boundary.""" tok = load_tokenizer(qwen35_model) - renderer = create_renderer(tok, renderer="qwen3.5") + renderer = create_renderer(tok, Qwen35RendererConfig()) assert isinstance(renderer, Qwen35Renderer) - assert renderer._enable_thinking is expected_default, ( + assert renderer.config.enable_thinking is expected_default, ( f"{qwen35_model}: expected enable_thinking default {expected_default}, " - f"got {renderer._enable_thinking}" + f"got {renderer.config.enable_thinking}" ) @@ -148,7 +148,7 @@ def test_qwen35_size_parity_with_apply_chat_template( share ``Qwen35Renderer`` across all seven sizes — the polarity flip on 0.8B / 2B is absorbed by the constructor's auto-detect.""" tok = load_tokenizer(qwen35_model) - renderer = create_renderer(tok, renderer="qwen3.5") + renderer = create_renderer(tok, Qwen35RendererConfig()) assert isinstance(renderer, Qwen35Renderer) ours = renderer.render_ids(messages, add_generation_prompt=add_gen_prompt) diff --git a/tests/test_render_ids.py b/tests/test_render_ids.py index 418f295..e2e4a50 100644 --- a/tests/test_render_ids.py +++ 
b/tests/test_render_ids.py @@ -338,7 +338,7 @@ def test_multi_step_tool_cycle(model_name, tokenizer, renderer): @lru_cache def _qwen3_vl(): tokenizer = load_tokenizer("Qwen/Qwen3-VL-4B-Instruct") - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) return tokenizer, renderer @@ -353,7 +353,7 @@ def test_qwen3_vl_auto_renderer(): @lru_cache def _kimi_k25(): tokenizer = load_tokenizer("moonshotai/Kimi-K2.5") - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) return tokenizer, renderer @@ -366,7 +366,7 @@ def test_kimi_k2_inline_think_tags_render_verbatim(): reasoning, producing tokens that disagreed with ``apply_chat_template``. """ tokenizer = load_tokenizer("moonshotai/Kimi-K2-Instruct") - renderer = create_renderer(tokenizer, renderer="auto") + renderer = create_renderer(tokenizer) msgs = [ {"role": "user", "content": "hi"}, {"role": "assistant", "content": "secretvisible"}, diff --git a/tests/test_renderer_config.py b/tests/test_renderer_config.py new file mode 100644 index 0000000..4bc0a31 --- /dev/null +++ b/tests/test_renderer_config.py @@ -0,0 +1,116 @@ +"""Unit tests for the typed-config surface — discriminated union, +auto-resolution, and ``extra="forbid"`` enforcement on per-renderer +configs.""" + +from types import SimpleNamespace + +import pytest +from pydantic import TypeAdapter, ValidationError + +from renderers import ( + AutoRendererConfig, + DefaultRendererConfig, + GLM5RendererConfig, + Qwen3RendererConfig, + Qwen35RendererConfig, + RendererConfig, + base, + create_renderer, +) + + +def test_per_renderer_config_rejects_unknown_fields(): + """``extra="forbid"`` on every typed variant catches bogus keys at + construction: ``add_vision_id`` doesn't exist on ``Qwen3RendererConfig`` + (Qwen3 is text-only), so passing it must raise.""" + with pytest.raises(ValidationError, match="add_vision_id"): + Qwen3RendererConfig(add_vision_id=True) + + +def 
test_discriminated_union_dispatches_on_name(): + """A dict shaped like ``{"name": "glm-5", ...}`` deserialises to the + matching typed config; the union ``RendererConfig`` is what + downstream consumers (prime-rl, verifiers) hold as a single field.""" + ta = TypeAdapter(RendererConfig) + parsed = ta.validate_python( + {"name": "glm-5", "enable_thinking": False, "clear_thinking": False} + ) + assert isinstance(parsed, GLM5RendererConfig) + assert parsed.enable_thinking is False + assert parsed.clear_thinking is False + + +def test_discriminated_union_rejects_wrong_renderer_kwargs(): + """``add_vision_id`` under ``name="qwen3"`` is invalid at deserialise + time — the discriminator narrows to ``Qwen3RendererConfig`` whose + schema does not include that field.""" + ta = TypeAdapter(RendererConfig) + with pytest.raises(ValidationError, match="add_vision_id"): + ta.validate_python({"name": "qwen3", "add_vision_id": True}) + + +def test_default_renderer_config_accepts_arbitrary_extras(): + """``DefaultRenderer`` wraps ``apply_chat_template`` for unknown + templates, so its config uses ``extra="allow"`` and surfaces extras + via ``model_extra``.""" + cfg = DefaultRendererConfig( + tool_parser="qwen3", enable_thinking=False, custom_jinja_kwarg=True + ) + assert cfg.tool_parser == "qwen3" + assert cfg.model_extra == { + "enable_thinking": False, + "custom_jinja_kwarg": True, + } + + +def test_create_renderer_forwards_typed_config_to_renderer(monkeypatch): + """``create_renderer`` dispatches on ``config.name`` via + ``RENDERER_REGISTRY``; the renderer stores the config it was given.""" + + class _FakeRenderer: + def __init__(self, tokenizer, config): + self.tokenizer = tokenizer + self.config = config + + monkeypatch.setitem(base.RENDERER_REGISTRY, "qwen3", _FakeRenderer) + + renderer = create_renderer( + SimpleNamespace(name_or_path="unused"), + Qwen3RendererConfig(enable_thinking=False), + ) + assert isinstance(renderer.config, Qwen3RendererConfig) + assert 
renderer.config.enable_thinking is False + + +def test_create_renderer_auto_resolves_via_model_map(monkeypatch): + """``AutoRendererConfig`` (or ``config=None``) routes through + ``MODEL_RENDERER_MAP`` to pick the matching renderer + typed config, + carrying the shared ``preserve_*`` flags over from the auto config.""" + + class _FakeQwen35: + def __init__(self, tokenizer, config): + self.tokenizer = tokenizer + self.config = config + + monkeypatch.setitem(base.RENDERER_REGISTRY, "qwen3.5", _FakeQwen35) + monkeypatch.setitem(base.MODEL_RENDERER_MAP, "fake/qwen35", "qwen3.5") + + renderer = create_renderer( + SimpleNamespace(name_or_path="fake/qwen35"), + AutoRendererConfig(preserve_all_thinking=True), + ) + + assert isinstance(renderer.config, Qwen35RendererConfig) + assert renderer.config.preserve_all_thinking is True + # Template-level kwargs stay at their per-renderer defaults — auto + # carries only the preserve_* flags. + assert renderer.config.add_vision_id is False + + +def test_create_renderer_default_argument_is_auto(): + """Passing no config is equivalent to passing ``AutoRendererConfig()`` + — short form for the common case.""" + tok = SimpleNamespace(name_or_path="") # no MODEL_RENDERER_MAP entry + renderer = create_renderer(tok) + # Falls through to DefaultRenderer when no match and no vision config. + assert renderer.__class__.__name__ == "DefaultRenderer" diff --git a/tests/test_renderer_config_parity.py b/tests/test_renderer_config_parity.py new file mode 100644 index 0000000..8ca2da3 --- /dev/null +++ b/tests/test_renderer_config_parity.py @@ -0,0 +1,513 @@ +"""Parity for typed-config template fields against the upstream chat +template. + +Each renderer's typed config (see ``renderers.configs``) declares the +fields that mirror chat-template kwargs via +``Config.template_field_names()``. 
``test_renderer_config.py`` covers +the typed-config wiring; this file covers the only thing that matters +downstream: that flipping a template field on the typed config produces +token streams byte-identical to +``tokenizer.apply_chat_template(messages, **{field: value})``. + +Without this, the typed surface is a promise the renderer doesn't keep. + +Discovery is automatic — the parity matrix is built from each config +class's ``template_field_names()`` crossed with the per-field value list +in ``_KWARG_VALUES``. To extend coverage to a new field: declare it on +the typed config and add the values to exercise to ``_KWARG_VALUES`` +below. + +``gpt-oss`` parity is against ``openai-harmony`` (its renderer diverges +from HF Jinja by design — see ``test_gpt_oss_harmony_parity.py``); it +lives in its own test below, with the same auto-derived discovery. +""" + +from __future__ import annotations + +from datetime import datetime +from functools import lru_cache +from typing import Any + +import pytest + +from renderers import create_renderer +from renderers.base import ( + MODEL_RENDERER_MAP, + _populate_registry, + load_tokenizer, +) +from renderers.configs import _config_class_for + + +# Models exercised by the parity tests. Mirrors ``conftest.RENDERER_MODELS`` +# in spirit — one representative model per renderer family — plus the +# ``gpt-oss`` entry that conftest skips for HF parity (gpt-oss parity is +# against harmony, handled separately below). +_RENDERER_MODELS = [ + ("Qwen/Qwen3-8B", "auto"), + ("Qwen/Qwen3.5-9B", "auto"), + ("Qwen/Qwen3.6-35B-A3B", "auto"), + ("zai-org/GLM-5", "auto"), + ("zai-org/GLM-5.1", "auto"), + ("zai-org/GLM-4.7-Flash", "auto"), + ("THUDM/GLM-4.5-Air", "auto"), + ("moonshotai/Kimi-K2.5", "auto"), + ("moonshotai/Kimi-K2.6", "auto"), + ("deepseek-ai/DeepSeek-V3", "auto"), + ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), + ("poolside/Laguna-XS.2", "auto"), + ("openai/gpt-oss-20b", "gpt-oss"), +] + + +# Per-kwarg value list. 
Each template field any renderer's typed config +# declares (via ``Config.template_field_names()``) must have an entry +# here, or the parity matrix silently skips it. The test below asserts +# coverage so a future kwarg can't slip through without an explicit +# value list. +_KWARG_VALUES: dict[str, list[Any]] = { + "enable_thinking": [True, False], + # Kimi K2.5 / K2.6 — same semantics as ``enable_thinking`` but the + # upstream template uses ``thinking`` as the variable name. The + # renderer's typed config (``KimiK25RendererConfig.thinking``) + # mirrors that name so the field maps 1:1 onto the template gate. + "thinking": [True, False], + "reasoning_effort": ["low", "medium", "high"], + # GLM-5 / GLM-5.1 — ``clear_thinking=False`` preserves the + # ``{reasoning}`` wrap on historical assistants too + # (default True collapses past-cycle reasoning to ````). + "clear_thinking": [True, False], + # Nemotron-3 — mirror of ``clear_thinking`` under a different name. + # ``truncate_history_thinking=False`` keeps reasoning on historical + # assistants instead of collapsing to ````. + "truncate_history_thinking": [True, False], + # MiniMax-M2 — fallback persona string when no system message is + # supplied. Two arbitrary values to verify the renderer threads the + # exact bytes through (whitespace included). + "model_identity": [ + "You are a helpful assistant. Your name is MiniMax-M2.5 and is built by MiniMax.", + "You are CustomBot, a research assistant.", + ], + # Laguna-XS.2 — switches assistant rendering to a verbatim + # passthrough mode. The renderer paths diverge significantly under + # this flag, so both values are exercised. + "render_assistant_messages_raw": [True, False], + # Qwen3.5 / Qwen3.6 / Qwen3-VL — when True, prefix each image / + # video placeholder with ``Picture N: `` / ``Video N: ``. + "add_vision_id": [True, False], + # gpt-oss — pin to a fixed date so the renderer's preamble matches + # the harmony oracle built with the same date. 
The default + # ``today's date`` is intentionally avoided here so the assertion + # doesn't flake on a UTC midnight crossing. + "conversation_start_date": ["2025-01-15"], +} + + +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + }, + "required": ["city"], + }, + }, + } +] + + +# (id, messages, render_kwargs). Each shape exercises a distinct branch +# at least one renderer is known to flip on a kwarg: +# +# - ``system_user_gen``: forces the generation-prompt branch (e.g. +# Qwen3 / GLM / Nemotron emit a synthesized ```` here +# under ``enable_thinking=False``). +# - ``single_turn``: terminal assistant with plain content — the +# "render historical thinking?" branch. +# - ``with_reasoning``: assistant carries ``reasoning_content`` — flips +# whether thinking is emitted into the rendered history. +# - ``multi_turn``: two assistant turns separated by a user; the +# in-flight-vs-historical distinction matters for several renderers. +# - ``tool_cycle``: assistant tool call + tool response + final +# assistant, with ``add_generation_prompt=True`` so the second +# gen-prompt branch is hit too. 
+_MESSAGE_SHAPES = [ + ( + "system_user_gen", + [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + ], + {"add_generation_prompt": True}, + ), + ( + "single_turn", + [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + {}, + ), + ( + "with_reasoning", + [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "reasoning_content": "Simple arithmetic.", + "content": "4", + }, + ], + {}, + ), + ( + "multi_turn", + [ + {"role": "user", "content": "A"}, + {"role": "assistant", "content": "B"}, + {"role": "user", "content": "C"}, + {"role": "assistant", "content": "D"}, + ], + {}, + ), + ( + "tool_cycle", + [ + {"role": "user", "content": "Weather in Paris?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": {"city": "Paris"}, + } + } + ], + }, + {"role": "tool", "content": '{"temp": 20}'}, + {"role": "assistant", "content": "It is 20 degrees."}, + ], + {"tools": TOOLS, "add_generation_prompt": True}, + ), + # ``no_system_user_gen``: no system message — exercises the + # template fallback persona (e.g. MiniMax-M2's ``model_identity``). + ( + "no_system_user_gen", + [{"role": "user", "content": "Hi"}], + {"add_generation_prompt": True}, + ), + # ``historical_reasoning``: multi-turn with ``reasoning_content`` on + # a historical assistant. Exercises ``clear_thinking`` / + # ``truncate_history_thinking`` (which only diverge from default + # behaviour when a past-cycle assistant carries reasoning). 
+ ( + "historical_reasoning", + [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "reasoning_content": "Adding small ints.", + "content": "4", + }, + {"role": "user", "content": "Now 3+3?"}, + { + "role": "assistant", + "reasoning_content": "Same idea.", + "content": "6", + }, + ], + {}, + ), +] + + +# ── Matrix discovery ─────────────────────────────────────────────────── + + +_populate_registry() + + +def _resolve_renderer_name(model: str, renderer_name: str) -> str: + """Resolve ``(model, renderer_name)`` to the concrete renderer name.""" + if renderer_name == "auto": + return MODEL_RENDERER_MAP.get(model, "default") + return renderer_name + + +def _template_fields_for(model: str, renderer_name: str) -> frozenset[str]: + """Discover the typed-config template-field set for a renderer.""" + resolved = _resolve_renderer_name(model, renderer_name) + return _config_class_for(resolved).template_field_names() + + +def _hf_parity_matrix() -> list[Any]: + """Auto-derived ``(model, renderer_name, kwarg, value)`` matrix for + every renderer with template fields, minus gpt-oss (handled + separately against harmony). + """ + out = [] + for model, name in _RENDERER_MODELS: + if name == "gpt-oss": + continue + for kwarg in sorted(_template_fields_for(model, name)): + for value in _KWARG_VALUES.get(kwarg, []): + out.append( + pytest.param( + model, name, kwarg, value, id=f"{model}-{kwarg}={value}" + ) + ) + return out + + +def _harmony_parity_matrix() -> list[Any]: + """Auto-derived ``(model, renderer_name, kwarg, value)`` matrix for + gpt-oss (parity against ``openai-harmony``). 
+ """ + out = [] + for model, name in _RENDERER_MODELS: + if name != "gpt-oss": + continue + for kwarg in sorted(_template_fields_for(model, name)): + for value in _KWARG_VALUES.get(kwarg, []): + out.append( + pytest.param( + model, name, kwarg, value, id=f"{model}-{kwarg}={value}" + ) + ) + return out + + +def test_kwarg_values_covers_every_declared_kwarg(): + """Every template field any renderer declares must have an entry in + ``_KWARG_VALUES`` — otherwise it silently drops out of parity + coverage. + """ + declared: set[str] = set() + for model, name in _RENDERER_MODELS: + declared.update(_template_fields_for(model, name)) + missing = sorted(declared - _KWARG_VALUES.keys()) + assert not missing, ( + f"Typed-config template fields declared but not covered: {missing}. " + f"Add a value list to _KWARG_VALUES in this file." + ) + + +# ── Test caches ──────────────────────────────────────────────────────── + + +@lru_cache(maxsize=None) +def _tokenizer(model_name: str): + return load_tokenizer(model_name) + + +@lru_cache(maxsize=None) +def _renderer_with_kwarg(model_name: str, renderer_name: str, kwarg: str, value: Any): + tok = _tokenizer(model_name) + resolved = _resolve_renderer_name(model_name, renderer_name) + config = _config_class_for(resolved)(**{kwarg: value}) + return create_renderer(tok, config) + + +def _expected_hf(tokenizer, messages, *, kwarg: str, value: Any, **render_kwargs): + """Render via ``apply_chat_template`` with the kwarg spread as a + top-level argument. + + transformers v5.x silently drops ``chat_template_kwargs={...}`` — + only direct kwargs propagate into the Jinja environment. The two + invocation styles are semantically the same for the Jinja template, + so we pick the one that actually fires. (Our ``create_renderer`` + API accepts the dict form because it is the standard wire format + in OpenAI-compatible servers; we translate it to constructor kwargs + on our side.) 
+ """ + render_kwargs.setdefault("add_generation_prompt", False) + result = tokenizer.apply_chat_template( + messages, + tokenize=True, + return_dict=False, + **{kwarg: value}, + **render_kwargs, + ) + if isinstance(result, dict): + return list(result["input_ids"]) + if isinstance(result, str): + return list(tokenizer.encode(result, add_special_tokens=False)) + return list(result) + + +# ── HF-Jinja parity (every renderer except gpt-oss) ──────────────────── + + +@pytest.mark.parametrize("model,renderer_name,kwarg,value", _hf_parity_matrix()) +@pytest.mark.parametrize( + "shape_id,messages,render_kwargs", + _MESSAGE_SHAPES, + ids=[s[0] for s in _MESSAGE_SHAPES], +) +def test_chat_template_kwarg_parity_hf( + model, + renderer_name, + kwarg, + value, + shape_id, + messages, + render_kwargs, +): + tokenizer = _tokenizer(model) + renderer = _renderer_with_kwarg(model, renderer_name, kwarg, value) + # Guard: the typed config must actually declare the kwarg as a + # template field. Pydantic ``extra="forbid"`` already enforces this + # at construction; asserting here gives a louder failure on a future + # config subclass that drops the field. 
+ assert kwarg in type(renderer.config).template_field_names() + + try: + expected = _expected_hf( + tokenizer, messages, kwarg=kwarg, value=value, **render_kwargs + ) + except Exception as exc: + pytest.xfail( + f"{model}: apply_chat_template raised {type(exc).__name__}: " + f"{str(exc)[:160]}" + ) + + got = renderer.render_ids(messages, **render_kwargs) + assert got == expected, ( + f"{model} / shape={shape_id} / {kwarg}={value}: renderer diverged " + f"from apply_chat_template (len got={len(got)}, expected={len(expected)})" + ) + + +# ── Harmony parity (gpt-oss only) ────────────────────────────────────── + + +_DATE_FOR_PARITY = datetime.now().strftime("%Y-%m-%d") + + +@lru_cache(maxsize=None) +def _gpt_oss_renderer(kwarg: str, value: Any): + from renderers.configs import GptOssRendererConfig + from renderers.gpt_oss import GptOssRenderer + + tok = _tokenizer("openai/gpt-oss-20b") + # Pin a default conversation_start_date so the rendered preamble + # matches the harmony oracle's fixed date. Any explicit + # ``conversation_start_date`` from the kwarg-under-test overrides + # it (the per-kwarg branch replays the same value into the oracle + # below so the assertion still holds). + kwargs: dict[str, Any] = {"conversation_start_date": _DATE_FOR_PARITY} + kwargs[kwarg] = value + return GptOssRenderer(tok, GptOssRendererConfig(**kwargs)) + + +def _harmony_expected( + kwarg: str, value: Any, messages: list[dict[str, Any]] +) -> list[int]: + from openai_harmony import ( + Conversation, + HarmonyEncodingName, + Message as HarmonyMessage, + ReasoningEffort, + Role, + SystemContent, + load_harmony_encoding, + ) + + # Base preamble pins the same default date the renderer fixture + # uses so the unrelated kwargs don't drift on date semantics. 
+ sys_content = SystemContent.new().with_conversation_start_date(_DATE_FOR_PARITY) + if kwarg == "reasoning_effort": + effort_enum = { + "low": ReasoningEffort.LOW, + "medium": ReasoningEffort.MEDIUM, + "high": ReasoningEffort.HIGH, + }[value] + sys_content = sys_content.with_reasoning_effort(effort_enum) + elif kwarg == "conversation_start_date": + # Override the pinned date with the value under test. + sys_content = sys_content.with_conversation_start_date(value) + else: + raise AssertionError( + f"Harmony oracle: unhandled gpt-oss chat_template_kwarg {kwarg!r}. " + "Add a branch here when extending GptOssRendererConfig's template fields." + ) + + harmony_msgs: list[HarmonyMessage] = [ + HarmonyMessage.from_role_and_content(Role.SYSTEM, sys_content) + ] + for m in messages: + role = m["role"] + content = m.get("content", "") or "" + if role == "user": + harmony_msgs.append( + HarmonyMessage.from_role_and_content(Role.USER, content) + ) + elif role == "assistant": + harmony_msgs.append( + HarmonyMessage.from_role_and_content( + Role.ASSISTANT, content + ).with_channel("final") + ) + else: + raise AssertionError( + f"Harmony oracle helper does not handle role={role!r}; add a " + "branch or constrain the shapes used for gpt-oss parity." + ) + encoder = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + return encoder.render_conversation_for_training( + Conversation.from_messages(harmony_msgs) + ) + + +# Harmony oracle is only wired for the simplest shapes (user-only and +# user+assistant content). Tool-call and reasoning_content shapes have +# a richer mapping that the dedicated ``test_gpt_oss_harmony_parity.py`` +# already covers — duplicating that here would only test the harness. 
+_HARMONY_SHAPES = [ + ( + "user_only_gen", + [{"role": "user", "content": "Hello!"}], + {}, + ), + ( + "user_and_assistant", + [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ], + {}, + ), +] + + +@pytest.mark.parametrize("model,renderer_name,kwarg,value", _harmony_parity_matrix()) +@pytest.mark.parametrize( + "shape_id,messages,render_kwargs", + _HARMONY_SHAPES, + ids=[s[0] for s in _HARMONY_SHAPES], +) +def test_chat_template_kwarg_parity_harmony( + model, + renderer_name, + kwarg, + value, + shape_id, + messages, + render_kwargs, +): + renderer = _gpt_oss_renderer(kwarg, value) + assert kwarg in type(renderer.config).template_field_names() + + got = renderer.render_ids(messages, **render_kwargs) + expected = _harmony_expected(kwarg, value, messages) + assert got == expected, ( + f"{model} / shape={shape_id} / {kwarg}={value}: renderer diverged " + f"from harmony oracle (len got={len(got)}, expected={len(expected)})" + ) diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index a4577fd..383bc14 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -51,11 +51,11 @@ @lru_cache(maxsize=None) def _load_renderer(model_name: str, renderer_name: str): - from renderers import create_renderer + from renderers import config_from_name, create_renderer from renderers.base import load_tokenizer tok = load_tokenizer(model_name) - return tok, create_renderer(tok, renderer=renderer_name) + return tok, create_renderer(tok, config_from_name(renderer_name)) def pytest_generate_tests(metafunc): @@ -316,7 +316,7 @@ def test_default_renderer_fallback_parser_preserves_boundary_whitespace( """ from renderers.default import DefaultRenderer - renderer = DefaultRenderer(rt_tokenizer, tool_parser=None, reasoning_parser=None) + renderer = DefaultRenderer(rt_tokenizer) # Encode `reason\nvisible` as text and run through # parse_response. 
We don't need the template to emit `` here diff --git a/tests/test_tool_arg_type_preservation.py b/tests/test_tool_arg_type_preservation.py index 61dbb9b..607c6d5 100644 --- a/tests/test_tool_arg_type_preservation.py +++ b/tests/test_tool_arg_type_preservation.py @@ -40,11 +40,11 @@ @lru_cache(maxsize=None) def _load(model: str, renderer_name: str): - from renderers import create_renderer + from renderers import config_from_name, create_renderer from renderers.base import load_tokenizer tok = load_tokenizer(model) - return tok, create_renderer(tok, renderer=renderer_name) + return tok, create_renderer(tok, config_from_name(renderer_name)) def pytest_generate_tests(metafunc): diff --git a/uv.lock b/uv.lock index f968951..8096df3 100644 --- a/uv.lock +++ b/uv.lock @@ -9,10 +9,11 @@ resolution-markers = [ ] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. +exclude-newer = "2026-05-18T21:42:54.18041997Z" exclude-newer-span = "P7D" [options.exclude-newer-package] +prime-pydantic-config = false fastokens = false [[package]] @@ -1073,6 +1074,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/80/6e/4b28b62ecb6aae56769c34a8ff1d661473ec1e9519e2d5f8b2c150086b26/pre_commit-4.6.0-py2.py3-none-any.whl", hash = "sha256:e2cf246f7299edcabcf15f9b0571fdce06058527f0a06535068a86d38089f29b", size = 226472, upload-time = "2026-04-21T20:31:40.092Z" }, ] +[[package]] +name = "prime-pydantic-config" +version = "0.3.0.dev83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/4e/bcdb244336d3abae60c7767626aaf635ef724bdc3f2cb46e317d11b23d91/prime_pydantic_config-0.3.0.dev83.tar.gz", hash = "sha256:7446e6439ba6de2f2c332acb292ba5b53da7c1a4ad60e4a25b78393b17859fdd", size = 73322, upload-time = "2026-05-24T02:58:49.485Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/1a/6b/372380c6f5cc3a6f0b6690acb453e9aa9d848e8c1ec421e883c1c3c7a293/prime_pydantic_config-0.3.0.dev83-py3-none-any.whl", hash = "sha256:91a8d883181aff069a4a07f56b636c399c32f1814d8177b9e167a8398643e313", size = 26259, upload-time = "2026-05-24T02:58:48.432Z" }, +] + [[package]] name = "pydantic" version = "2.13.3" @@ -1373,6 +1386,7 @@ dependencies = [ { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "openai" }, { name = "openai-harmony" }, + { name = "prime-pydantic-config" }, { name = "tiktoken" }, { name = "transformers" }, ] @@ -1396,6 +1410,7 @@ requires-dist = [ { name = "numpy" }, { name = "openai", specifier = ">=1.108.1" }, { name = "openai-harmony", specifier = ">=0.0.8" }, + { name = "prime-pydantic-config", specifier = ">=0.3.0.dev83" }, { name = "tiktoken" }, { name = "transformers", specifier = ">=4.50.0" }, ]