From d7c160be04f5829594a39db5179da910c06242e1 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Thu, 28 May 2026 18:10:05 +0200 Subject: [PATCH] feat(base): add `message_tool_names` field for per-message tool attribution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds ``RenderedTokens.message_tool_names: list[str | None]`` — a sidecar parallel to ``message_roles`` that carries the tool function name for each tool-role message in the rendered slice. For each tool message the name is taken from ``msg["name"]`` (caller- provided) or recovered by joining ``msg["tool_call_id"]`` against any prior assistant's ``tool_calls[i].function.name`` in the same list. Tool messages whose issuing assistant lives outside the slice (e.g. on a ``bridge_to_next_turn`` call where ``new_messages`` covers only the new turn) resolve to ``None``. Pure metadata: ``extract_message_tool_names`` runs independently of the render path, never mutates the caller's messages, and has no effect on the rendered token stream — HF chat-template byte parity is preserved on every renderer. Callers that want the function name to appear in the rendered scaffold (e.g. GPT-OSS Harmony's ``functions.{name}`` prefix) continue to attach ``name`` themselves before calling ``render`` — that responsibility stays with the caller (verifiers does this in ``_attach_tool_call_names``). Trainers (prime-rl) join this list with ``message_indices`` to recover per-token tool attribution — the canonical use case is SFT on tool response bodies of a specific tool while RL acts on assistant tokens. Wired into every concrete renderer's ``RenderedTokens(...)`` construction site (render + ``bridge_to_next_turn``). ``extract_message_tool_names`` is exported at package level. 
Tests: 5 unit tests covering the case matrix (empty, caller- provided wins, resolves from prior assistant, orphan tool message, non-mutation invariant) + 1 integration test that runs across every renderer in the conftest matrix to catch missed wire-up at any of the ~25 ``RenderedTokens(...)`` sites. Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/__init__.py | 2 + renderers/base.py | 90 +++++++++++++++++++++++++++++ renderers/deepseek_v3.py | 3 + renderers/default.py | 2 + renderers/glm45.py | 3 + renderers/glm5.py | 3 + renderers/gpt_oss.py | 3 + renderers/kimi_k2.py | 3 + renderers/kimi_k25.py | 5 ++ renderers/laguna_xs2.py | 3 + renderers/minimax_m2.py | 3 + renderers/nemotron3.py | 3 + renderers/qwen3.py | 3 + renderers/qwen35.py | 5 ++ renderers/qwen3_vl.py | 3 + tests/test_message_tool_names.py | 97 ++++++++++++++++++++++++++++++++ 16 files changed, 231 insertions(+) create mode 100644 tests/test_message_tool_names.py diff --git a/renderers/__init__.py b/renderers/__init__.py index c95719b..e7cd1c4 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -33,6 +33,7 @@ build_trajectory_step, create_renderer, create_renderer_pool, + extract_message_tool_names, is_multimodal, reject_assistant_in_extension, trim_to_turn_close, @@ -168,6 +169,7 @@ def __dir__() -> list[str]: "config_from_name", "create_renderer", "create_renderer_pool", + "extract_message_tool_names", "is_multimodal", "reject_assistant_in_extension", "trim_to_turn_close", diff --git a/renderers/base.py b/renderers/base.py index 5bed116..3d1b696 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -6,6 +6,7 @@ import logging import queue import threading +from collections.abc import Mapping from contextlib import contextmanager from dataclasses import dataclass, field from typing import ( @@ -117,6 +118,68 @@ class Message(TypedDict, total=False): reasoning_content: str +def extract_message_tool_names(messages: list[Message]) -> list[str | None]: + """Per-message tool 
function names parallel to ``message_roles``. + + Returns one entry per message: the function name for ``role="tool"`` + messages, ``None`` for every other message. Length matches the + input list. + + For tool messages the name is taken from ``msg["name"]`` when set + (caller-provided), otherwise recovered by joining ``msg["tool_call_id"]`` + against any assistant's ``tool_calls[i].function.name`` in the same + list (order is not checked; a reused id resolves to the last call). + Tool messages whose issuing assistant lives outside the provided + list (e.g. on a :meth:`Renderer.bridge_to_next_turn` call where + ``new_messages`` covers only the new turn) resolve to ``None``. + + Pure metadata: this never mutates the caller's messages and has + no effect on the rendered token stream. It runs independently of + the render path so the renderer can populate the field on + :class:`RenderedTokens` without breaking HF byte parity for tool + messages that carry no ``name``. Callers who *also* want the + function name to appear in the rendered scaffold (e.g. GPT-OSS + Harmony's ``functions.{name}`` prefix) must attach ``name`` to + their tool messages before calling :meth:`Renderer.render` + themselves — renderers don't synthesize ``name`` into the input, + only into this metadata field. + + Trainers join this list with :attr:`RenderedTokens.message_indices` + to recover per-token tool attribution — the canonical use case is + SFT on tool response bodies while RL acts only on assistant tokens + (tool body tokens get a constant positive advantage so the model + learns to anticipate tool outputs without learning to emit + ``<|tool_response>`` itself). + + Per-message rather than per-token because the data is naturally + per-message — storing it per-token would duplicate the same + string across every body token of the same tool message.
+ """ + lookup: dict[str, str] = {} + for m in messages: + if not isinstance(m, Mapping) or m.get("role") != "assistant": + continue + for tc in m.get("tool_calls") or []: + if not isinstance(tc, Mapping): + continue + tc_id = tc.get("id") + fn = tc.get("function") + tc_name = fn.get("name") if isinstance(fn, Mapping) else None + if isinstance(tc_id, str) and isinstance(tc_name, str): + lookup[tc_id] = tc_name + out: list[str | None] = [] + for m in messages: + if not isinstance(m, Mapping) or m.get("role") != "tool": + out.append(None) + continue + name = m.get("name") + if not (isinstance(name, str) and name): + tc_id = m.get("tool_call_id") + name = lookup.get(tc_id) if isinstance(tc_id, str) else None + out.append(name if isinstance(name, str) and name else None) + return out + + # --------------------------------------------------------------------------- # Renderer data types # --------------------------------------------------------------------------- @@ -208,6 +271,32 @@ class RenderedTokens: renderer doesn't provide the signal. ``DefaultRenderer`` leaves it empty for the same reason. + ``message_tool_names`` is the per-message tool function name list, + parallel to ``message_roles`` (same length). For tool-role + messages it carries the function name — either taken from + ``msg["name"]`` (caller-provided) or recovered by joining + ``msg["tool_call_id"]`` against a prior assistant's + ``tool_calls[i].function.name`` in the rendered slice. Every + other message is ``None``, as are tool messages whose issuing + assistant lives outside the rendered slice (e.g. on a + :meth:`Renderer.bridge_to_next_turn` call where ``new_messages`` + covers only the new turn). + + This is pure metadata, computed by :func:`extract_message_tool_names` + independently of the render path: populating it never touches the + rendered token stream, so HF chat-template byte parity is + preserved for tool messages carrying no ``name``. 
Callers who + *also* want the function name to appear in the rendered scaffold + (e.g. GPT-OSS Harmony's ``functions.{name}`` prefix) must attach + ``name`` to their tool messages before calling + :meth:`Renderer.render` themselves. + + Trainers join this with ``message_indices`` to build per-tool + selective loss masks (SFT on tool response bodies of a specific + tool while RL acts on assistant tokens). Empty + ``message_tool_names`` (``[]``) means the renderer doesn't + provide the signal. + ``multi_modal_data`` is populated by multimodal renderers (e.g. ``Qwen3VLRenderer``) when image / video content parts are present; text-only renderers leave it as ``None``. @@ -218,6 +307,7 @@ class RenderedTokens: sampled_mask: list[bool] = field(default_factory=list) is_content: list[bool] = field(default_factory=list) message_roles: list[str] = field(default_factory=list) + message_tool_names: list[str | None] = field(default_factory=list) multi_modal_data: "MultiModalData | None" = None def tokens_per_message( diff --git a/renderers/deepseek_v3.py b/renderers/deepseek_v3.py index 4efe3ef..7bec3de 100644 --- a/renderers/deepseek_v3.py +++ b/renderers/deepseek_v3.py @@ -22,6 +22,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, trim_to_turn_close, ) @@ -247,6 +248,7 @@ def emit_text_segments( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), ) def render_ids( @@ -390,6 +392,7 @@ def emit_text( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) # ------------------------------------------------------------------ diff --git a/renderers/default.py b/renderers/default.py index e969421..a662097 100644 --- a/renderers/default.py +++ 
b/renderers/default.py @@ -18,6 +18,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + extract_message_tool_names, ) from renderers.configs import DefaultRendererConfig from renderers.parsers import ( @@ -141,6 +142,7 @@ def render( token_ids=token_ids, message_indices=message_indices, message_roles=message_roles, + message_tool_names=extract_message_tool_names(messages), ) def _apply(self, messages, *, tools=None, add_generation_prompt=False) -> list[int]: diff --git a/renderers/glm45.py b/renderers/glm45.py index efea47b..7af9259 100644 --- a/renderers/glm45.py +++ b/renderers/glm45.py @@ -21,6 +21,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, should_preserve_past_thinking, ) @@ -265,6 +266,7 @@ def emit_text_segments( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), ) def render_ids( @@ -445,6 +447,7 @@ def emit_text_segments( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) def _render_assistant( diff --git a/renderers/glm5.py b/renderers/glm5.py index a42a0af..924d754 100644 --- a/renderers/glm5.py +++ b/renderers/glm5.py @@ -22,6 +22,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, should_preserve_past_thinking, ) @@ -281,6 +282,7 @@ def emit_text_segments( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), ) def render_ids( @@ -456,6 +458,7 @@ def emit_text_segments( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + 
message_tool_names=extract_message_tool_names(new_messages), ) def _render_assistant( diff --git a/renderers/gpt_oss.py b/renderers/gpt_oss.py index f1bb04a..2a9c5ca 100644 --- a/renderers/gpt_oss.py +++ b/renderers/gpt_oss.py @@ -56,6 +56,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + extract_message_tool_names, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -465,6 +466,7 @@ def emit_harmony_message( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), ) def render_ids( @@ -594,6 +596,7 @@ def bridge_to_next_turn( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) # ── message conversion ─────────────────────────────────────────────────── diff --git a/renderers/kimi_k2.py b/renderers/kimi_k2.py index 54d6f53..e99dfa7 100644 --- a/renderers/kimi_k2.py +++ b/renderers/kimi_k2.py @@ -23,6 +23,7 @@ ParsedResponse, RenderedTokens, ToolSpec, + extract_message_tool_names, reject_assistant_in_extension, trim_to_turn_close, ) @@ -305,6 +306,7 @@ def emit_text( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in caller_messages], + message_tool_names=extract_message_tool_names(caller_messages), ) def render_ids( @@ -454,6 +456,7 @@ def emit_text( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) def _render_assistant( diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index 352a9ee..ba3ca6e 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -36,6 +36,7 @@ RenderedTokens, ToolCallParseStatus, ToolSpec, + extract_message_tool_names, 
reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -946,6 +947,7 @@ def emit_image( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), multi_modal_data=mm_data, ) @@ -1188,6 +1190,7 @@ def emit_image( merged_items.setdefault(modality, []).extend(vals) bridge_roles = [m.get("role") or "" for m in new_messages] + bridge_tool_names = extract_message_tool_names(new_messages) if not (merged_hashes or merged_placeholders or merged_items): return RenderedTokens( token_ids=tokens, @@ -1195,6 +1198,7 @@ def emit_image( sampled_mask=sampled, is_content=content_mask, message_roles=bridge_roles, + message_tool_names=bridge_tool_names, ) mm_data = MultiModalData( @@ -1208,6 +1212,7 @@ def emit_image( sampled_mask=sampled, is_content=content_mask, message_roles=bridge_roles, + message_tool_names=bridge_tool_names, multi_modal_data=mm_data, ) diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py index 9d8c0b3..bd6b64f 100644 --- a/renderers/laguna_xs2.py +++ b/renderers/laguna_xs2.py @@ -36,6 +36,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, ) from renderers.configs import LagunaXS2RendererConfig @@ -275,6 +276,7 @@ def emit_text_segments( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), ) def render_ids( @@ -426,6 +428,7 @@ def emit_text_segments( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) def _render_assistant( diff --git a/renderers/minimax_m2.py b/renderers/minimax_m2.py index 39c12fa..f990274 100644 --- a/renderers/minimax_m2.py +++ b/renderers/minimax_m2.py @@ -22,6 +22,7 @@ 
RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -278,6 +279,7 @@ def emit_token_overlap_body( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), ) def render_ids( @@ -459,6 +461,7 @@ def emit_token_overlap_body( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) def _render_assistant( diff --git a/renderers/nemotron3.py b/renderers/nemotron3.py index 06d9d4d..e6398b5 100644 --- a/renderers/nemotron3.py +++ b/renderers/nemotron3.py @@ -25,6 +25,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -411,6 +412,7 @@ def emit_text_segments( sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in original_messages], + message_tool_names=extract_message_tool_names(original_messages), ) def render_ids( @@ -581,6 +583,7 @@ def emit_text_segments( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) # ------------------------------------------------------------------ diff --git a/renderers/qwen3.py b/renderers/qwen3.py index fe97561..f744b8c 100644 --- a/renderers/qwen3.py +++ b/renderers/qwen3.py @@ -19,6 +19,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -247,6 +248,7 @@ def emit_text_segments( sampled_mask=sampled, is_content=content_mask, 
message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), ) def render_ids( @@ -403,6 +405,7 @@ def emit_text_segments( sampled_mask=[False] * total_len, is_content=[False] * len(previous_ids) + ext_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), ) def _render_assistant( diff --git a/renderers/qwen35.py b/renderers/qwen35.py index b3c6af7..cdb8ee1 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -27,6 +27,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, @@ -565,6 +566,7 @@ def flush_buf() -> None: sampled_mask=sampled, is_content=content_mask, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), multi_modal_data=mm_data, ) @@ -841,6 +843,7 @@ def flush_buf() -> None: merged_items.setdefault(modality, []).extend(vals) bridge_roles = [m.get("role") or "" for m in new_messages] + bridge_tool_names = extract_message_tool_names(new_messages) if not (merged_hashes or merged_placeholders or merged_items): return RenderedTokens( token_ids=tokens, @@ -848,6 +851,7 @@ def flush_buf() -> None: sampled_mask=sampled, is_content=content_mask, message_roles=bridge_roles, + message_tool_names=bridge_tool_names, ) mm_data = MultiModalData( @@ -861,6 +865,7 @@ def flush_buf() -> None: sampled_mask=sampled, is_content=content_mask, message_roles=bridge_roles, + message_tool_names=bridge_tool_names, multi_modal_data=mm_data, ) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index 7287159..9a4ffde 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -43,6 +43,7 @@ RenderedTokens, ToolSpec, attribute_text_segments, + extract_message_tool_names, reject_assistant_in_extension, trim_to_turn_close, ) @@ -604,6 +605,7 @@ def 
render_media_content(content: Any) -> None: sampled_mask=em.sampled, is_content=em.is_content, message_roles=[m.get("role") or "" for m in messages], + message_tool_names=extract_message_tool_names(messages), multi_modal_data=mm_data, ) @@ -839,6 +841,7 @@ def render_media_content(content: Any) -> None: sampled_mask=em.sampled, is_content=em.is_content, message_roles=[m.get("role") or "" for m in new_messages], + message_tool_names=extract_message_tool_names(new_messages), multi_modal_data=mm_data, ) diff --git a/tests/test_message_tool_names.py b/tests/test_message_tool_names.py new file mode 100644 index 0000000..f88440f --- /dev/null +++ b/tests/test_message_tool_names.py @@ -0,0 +1,97 @@ +"""Tests for ``RenderedTokens.message_tool_names`` and its populating helper. + +``message_tool_names`` is a per-message sidecar parallel to +``message_roles``: for each tool-role message in the rendered slice +it carries the tool function name, ``None`` everywhere else. The name +comes from ``msg["name"]`` when set, otherwise from a +``tool_call_id`` join against any prior assistant's ``tool_calls`` in +the same slice. Pure metadata — does not affect the rendered token +stream, does not mutate the caller's messages. + +Unit tests below cover the join's case matrix without a tokenizer. +The single integration test runs every renderer in the conftest +matrix to catch any of the ~25 ``RenderedTokens(...)`` construction +sites that might fail to wire the field through. 
+""" + +from __future__ import annotations + +from renderers.base import extract_message_tool_names + + +def test_extract_empty(): + assert extract_message_tool_names([]) == [] + + +def test_extract_caller_provided_name_wins(): + """``msg['name']`` set by the caller is used verbatim — no join attempted.""" + messages = [ + {"role": "tool", "tool_call_id": "c1", "name": "caller_set", "content": "x"}, + ] + assert extract_message_tool_names(messages) == ["caller_set"] + + +def test_extract_resolves_from_prior_assistant(): + """Tool message without ``name``: recovered via tool_call_id → assistant.tool_calls.""" + messages = [ + {"role": "user", "content": "go"}, + { + "role": "assistant", + "tool_calls": [{"id": "c1", "function": {"name": "screenshot"}}], + }, + {"role": "tool", "tool_call_id": "c1", "content": "ok"}, + ] + assert extract_message_tool_names(messages) == [None, None, "screenshot"] + + +def test_extract_orphan_tool_message_is_none(): + """``tool_call_id`` matching no in-slice assistant resolves to ``None`` + (bridge case: the issuing assistant lives in the prior portion that + ``new_messages`` doesn't cover). + """ + messages = [{"role": "tool", "tool_call_id": "orphan", "content": "x"}] + assert extract_message_tool_names(messages) == [None] + + +def test_extract_does_not_mutate_caller(): + """Caller's tool message must not gain a ``name`` field after extraction — + the helper produces a sidecar list, not a mutated view of the input. + """ + messages = [ + { + "role": "assistant", + "tool_calls": [{"id": "c1", "function": {"name": "f"}}], + }, + {"role": "tool", "tool_call_id": "c1", "content": "x"}, + ] + extract_message_tool_names(messages) + assert "name" not in messages[1] + + +def test_renderer_populates_message_tool_names(model_name, renderer): + """Every renderer wires ``message_tool_names`` through ``RenderedTokens``. + + Catches missed wire-up at any of the ~25 ``RenderedTokens(...)`` + construction sites across concrete renderers. 
The input is + spec-conformant (tool message carries ``tool_call_id`` but no + ``name``) so the resolution path exercises the internal join. + """ + messages = [ + {"role": "user", "content": "go"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": "screenshot", "arguments": {}}, + } + ], + }, + {"role": "tool", "tool_call_id": "c1", "content": "ok"}, + ] + rt = renderer.render(messages) + assert rt.message_tool_names == [None, None, "screenshot"], ( + f"{model_name}: got {rt.message_tool_names}" + )