diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py index d932d32d3..6f9feaebe 100644 --- a/verifiers/clients/openai_chat_completions_client.py +++ b/verifiers/clients/openai_chat_completions_client.py @@ -60,6 +60,7 @@ post_chat_completion_with_routed_experts_sidecar, setup_openai_client, ) +from verifiers.utils.response_utils import parse_routed_experts def handle_openai_overlong_prompt(func): @@ -486,13 +487,14 @@ def parse_tokens(response: OpenAIChatResponse) -> ResponseTokens | None: completion_logprobs = [token["logprob"] for token in logprobs_content] choice_extra = choice.model_extra or {} + routed_experts = parse_routed_experts(choice_extra.get("routed_experts")) return ResponseTokens( prompt_ids=prompt_ids, prompt_mask=prompt_mask, completion_ids=completion_ids, completion_mask=completion_mask, completion_logprobs=completion_logprobs, - routed_experts=choice_extra.get("routed_experts"), + routed_experts=routed_experts, ) def parse_reasoning_content_from_response( diff --git a/verifiers/clients/openai_completions_client.py b/verifiers/clients/openai_completions_client.py index 137e7a49a..a170a5872 100644 --- a/verifiers/clients/openai_completions_client.py +++ b/verifiers/clients/openai_completions_client.py @@ -25,6 +25,7 @@ Usage, ) from verifiers.utils.client_utils import setup_openai_client +from verifiers.utils.response_utils import parse_routed_experts OpenAITextMessages = str OpenAITextResponse = Completion @@ -169,12 +170,15 @@ def parse_tokens(response: OpenAITextResponse) -> ResponseTokens | None: ) if completion_logprobs is None: return None + choice_extra = response.choices[0].model_extra or {} + routed_experts = parse_routed_experts(choice_extra.get("routed_experts")) return ResponseTokens( prompt_ids=prompt_ids, prompt_mask=prompt_mask, completion_ids=completion_ids, completion_mask=completion_mask, completion_logprobs=completion_logprobs, + routed_experts=routed_experts, ) 
return Response( diff --git a/verifiers/clients/renderer_client.py b/verifiers/clients/renderer_client.py index e6a95558a..dfe217b2c 100644 --- a/verifiers/clients/renderer_client.py +++ b/verifiers/clients/renderer_client.py @@ -9,6 +9,7 @@ """ import asyncio +import inspect import json import threading from collections.abc import Mapping @@ -16,28 +17,22 @@ from openai import AsyncOpenAI -from renderers import Message as RendererMessage -from renderers import OverlongPromptError as RendererOverlongPromptError from renderers import ( - MultimodalRenderer, - ParsedToolCall, RenderedTokens, Renderer, RendererPool, - ToolCallParseStatus, ToolSpec, create_renderer_pool, - is_multimodal, ) +from renderers import Message as RendererMessage from renderers import ToolCall as RendererToolCall from renderers import ToolCallFunction -from renderers.client import generate from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( handle_openai_overlong_prompt, ) -from verifiers.errors import EmptyModelResponseError, OverlongPromptError +from verifiers.errors import EmptyModelResponseError from verifiers.types import ( AssistantMessage, ClientConfig, @@ -57,6 +52,7 @@ UserMessage, ) from verifiers.utils.client_utils import setup_openai_client +from verifiers.utils.response_utils import parse_routed_experts # Module-level bridge counters. Incremented by every RendererClient instance # that tries to stitch a multi-turn prompt; callers (e.g. 
try:
    from renderers import MultimodalRenderer, is_multimodal
except ImportError:
    # NOTE(review): presumably covers ``renderers`` builds that do not export
    # the multimodal helpers -- confirm the minimum supported version.
    MultimodalRenderer = Any

    def is_multimodal(renderer: Renderer) -> bool:
        """Duck-typed stand-in for ``renderers.is_multimodal``.

        Returns True when ``renderer.bridge_to_next_turn`` accepts a
        ``previous_multi_modal_data`` parameter. Returns False when the
        method is missing or its signature cannot be introspected.
        """
        bridge = getattr(renderer, "bridge_to_next_turn", None)
        if bridge is None:
            # No bridge method at all: treat as text-only instead of raising.
            return False
        try:
            signature = inspect.signature(bridge)
        except (TypeError, ValueError):
            return False
        return "previous_multi_modal_data" in signature.parameters


async def _run_renderer(renderer: Renderer | RendererPool, fn):
    """Run sync renderer work, checking out from a pool when needed.

    RendererPool exposes checkout(), not the renderer protocol methods. Pool
    checkout can block on its queue, so keep that branch off the event loop.

    Args:
        renderer: A bare renderer or a pool of renderers.
        fn: Callable taking the concrete renderer and returning a result.

    Returns:
        Whatever ``fn`` returns.
    """
    if isinstance(renderer, RendererPool):

        def _work():
            # Hold the checked-out renderer only for the duration of ``fn``.
            with renderer.checkout() as checked_out:
                return fn(checked_out)

        return await asyncio.to_thread(_work)
    # A bare renderer runs inline on the calling task.
    return fn(renderer)


def _coerce_rendered_tokens(value: Any) -> RenderedTokens | None:
    """Normalise a bridge result into ``RenderedTokens``.

    ``None`` passes through, a plain list is wrapped as ``token_ids``, and
    anything else is returned unchanged (only cast for the type checker).
    """
    if value is None:
        return None
    if isinstance(value, list):
        return RenderedTokens(token_ids=value)
    return cast(RenderedTokens, value)


async def _generate_with_renderer(
    *,
    client: AsyncOpenAI,
    renderer: Renderer | RendererPool,
    messages: list[RendererMessage],
    model: str,
    prompt_ids: list[int] | None = None,
    tools: list[ToolSpec] | None = None,
    sampling_params: dict[str, Any] | None = None,
    multi_modal_data: Any = None,
    cache_salt: str | None = None,
    priority: int | None = None,
    extra_headers: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Call PrimeRL's generate endpoint without decoding routed_experts.

    Older renderers.client.generate decodes ``routed_experts.data`` as base85
    int32s. PrimeRL now returns a compact base64 uint8 sidecar; keep that dict
    intact so the orchestrator can decode it with the matching codec.

    Args:
        client: OpenAI client whose ``base_url`` locates the inference server.
        renderer: Bare renderer or pool used to render the prompt and parse
            the completion.
        messages: Conversation to render when ``prompt_ids`` is not supplied.
        model: Model name sent in the request body.
        prompt_ids: Pre-rendered prompt tokens; skips rendering when given.
        tools: Tool specs forwarded to the renderer.
        sampling_params: Extra sampling parameters; ``stop_token_ids`` and
            ``logprobs`` are always overwritten.
        multi_modal_data: Optional payload forwarded verbatim in the body.
        cache_salt: Optional value forwarded verbatim in the body.
        priority: Optional value forwarded verbatim in the body.
        extra_headers: Extra HTTP headers for the request.

    Returns:
        A dict with the request id, prompt/completion token ids, completion
        logprobs, parsed content/reasoning/tool calls, finish reason and the
        untouched ``routed_experts`` / ``multi_modal_data`` sidecars.

    Raises:
        ValueError: If ``tools`` is non-empty and the concrete renderer
            reports ``supports_tools`` as falsy.
    """

    def prepare(checked_out: Renderer):
        # Validate against the concrete renderer. A RendererPool is only a
        # checkout handle, so probing ``supports_tools`` on the pool itself
        # would fall back to the permissive default and let unsupported
        # tool specs through.
        if tools and not getattr(checked_out, "supports_tools", True):
            raise ValueError(
                f"{type(checked_out).__name__} does not support tools. "
                "Choose a model-specific renderer instead of the default fallback."
            )
        ids = (
            list(prompt_ids)
            if prompt_ids is not None
            else checked_out.render_ids(
                messages, tools=tools, add_generation_prompt=True
            )
        )
        return ids, checked_out.get_stop_token_ids()

    resolved_prompt_ids, stop_token_ids = await _run_renderer(renderer, prepare)

    sp: dict[str, Any] = dict(sampling_params or {})
    # These two are forced so the response always carries per-token logprobs
    # and stops on the renderer's stop tokens.
    sp["stop_token_ids"] = stop_token_ids
    sp["logprobs"] = 1
    sp.setdefault("skip_special_tokens", False)

    body: dict[str, Any] = {
        "model": model,
        "token_ids": resolved_prompt_ids,
        "sampling_params": sp,
    }
    if cache_salt is not None:
        body["cache_salt"] = cache_salt
    if priority is not None:
        body["priority"] = priority
    if multi_modal_data is not None:
        body["multi_modal_data"] = multi_modal_data

    # The generate route lives beside, not under, the OpenAI-style ``/v1``.
    base = str(client.base_url).rstrip("/").removesuffix("/v1")
    endpoint = f"{base}/inference/v1/generate"
    post_kwargs: dict[str, Any] = {
        "cast_to": cast(Any, dict[str, Any]),
        "body": body,
    }
    if extra_headers:
        post_kwargs["options"] = cast(Any, {"headers": extra_headers})
    data = await client.post(endpoint, **post_kwargs)

    # Tolerate a missing or empty ``choices`` list by using an empty choice.
    choice = (data.get("choices") or [{}])[0]
    completion_ids = choice.get("token_ids") or []

    parsed = await _run_renderer(
        renderer,
        lambda checked_out: checked_out.parse_response(completion_ids),
    )

    raw_logprobs = choice.get("logprobs") or {}
    content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None
    completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []]

    finish_reason = choice.get("finish_reason")
    # Report a plain "stop" that produced tool calls as a tool-call finish.
    if parsed.tool_calls and finish_reason == "stop":
        finish_reason = "tool_calls"

    return {
        "request_id": data.get("request_id") or "",
        "prompt_ids": list(resolved_prompt_ids),
        "completion_ids": list(completion_ids),
        "completion_logprobs": completion_logprobs,
        "content": parsed.content,
        "reasoning_content": parsed.reasoning_content,
        "tool_calls": parsed.tool_calls,
        "finish_reason": finish_reason,
        "routed_experts": choice.get("routed_experts"),
        "multi_modal_data": choice.get("multi_modal_data"),
    }
if bridged is not None: prompt_ids = bridged.token_ids - multi_modal_data = bridged.multi_modal_data + multi_modal_data = getattr(bridged, "multi_modal_data", None) else: prompt_ids = None multi_modal_data = None - # ``renderers.client.generate`` discovers the engine's context-length - # cap on its own (via ``GET /v1/models``, cached) and raises - # ``renderers.OverlongPromptError`` on pre-flight overflow. Rebadge - # that into the verifiers-native ``OverlongPromptError`` so the - # ``MultiTurnEnv.prompt_too_long`` stop condition picks it up via - # the ``vf.Error`` hierarchy. The ``@handle_openai_overlong_prompt`` - # decorator still handles the fallback case (cap unknown → engine - # 4xx → vf.OverlongPromptError) for engines whose ``/v1/models`` - # doesn't expose ``max_model_len``. - try: - return await generate( - client=self.client, - renderer=renderer, - messages=prompt, - model=model, - prompt_ids=prompt_ids, - multi_modal_data=multi_modal_data, - tools=tools, - sampling_params=sampling_params, - cache_salt=args.get("cache_salt") - or sampling_params.pop("cache_salt", None), - priority=args.get("priority") or sampling_params.pop("priority", None), - extra_headers=args.get("extra_headers"), - ) - except RendererOverlongPromptError as exc: - raise OverlongPromptError(str(exc)) from exc + return await _generate_with_renderer( + client=self.client, + renderer=renderer, + messages=prompt, + model=model, + prompt_ids=prompt_ids, + tools=tools, + sampling_params=sampling_params, + cache_salt=args.get("cache_salt") + or sampling_params.pop("cache_salt", None), + priority=args.get("priority") or sampling_params.pop("priority", None), + extra_headers=extra_headers or None, + multi_modal_data=multi_modal_data, + ) async def raise_from_native_response(self, response: dict[str, Any]) -> None: if response is None: raise EmptyModelResponseError("Model returned no response") has_content = bool(response.get("content")) - # ``tool_calls`` is now ``list[ParsedToolCall]`` 
(renderers >=0.1.8.dev1) - # — a non-empty list with only malformed attempts still counts as the - # model having tried to call a tool, so we don't filter by status here. has_tool_calls = bool(response.get("tool_calls")) has_reasoning = bool(response.get("reasoning_content")) if not (has_content or has_tool_calls or has_reasoning): @@ -627,31 +733,20 @@ async def from_native_response(self, response: dict[str, Any]) -> Response: reasoning_content = response.get("reasoning_content") finish_reason = _parse_finish_reason(response.get("finish_reason")) - # renderers >=0.1.8.dev1 emits ParsedToolCall dataclasses (with .name, - # .arguments, .status, .id). Skip non-OK attempts — they're surfaced - # on the parsed response so trainers can inspect, but verifiers' - # tool-loop only acts on well-formed calls. tool_calls = None - raw_tcs = response.get("tool_calls") or [] - ok_tcs = [ - tc - for tc in raw_tcs - if isinstance(tc, ParsedToolCall) - and tc.status == ToolCallParseStatus.OK - and tc.name - ] - if ok_tcs: + raw_tcs = response.get("tool_calls") + if raw_tcs: tool_calls = [ ToolCall( - id=tc.id or f"call_{i}", - name=tc.name or "", + id=f"call_{i}", + name=tc["function"]["name"], arguments=( - tc.arguments - if isinstance(tc.arguments, str) - else json.dumps(tc.arguments or {}) + tc["function"]["arguments"] + if isinstance(tc["function"]["arguments"], str) + else json.dumps(tc["function"]["arguments"]) ), ) - for i, tc in enumerate(ok_tcs) + for i, tc in enumerate(raw_tcs) ] prompt_ids = response.get("prompt_ids", []) @@ -664,7 +759,7 @@ async def from_native_response(self, response: dict[str, Any]) -> Response: completion_ids=completion_ids, completion_mask=[1] * len(completion_ids), completion_logprobs=completion_logprobs, - routed_experts=response.get("routed_experts"), + routed_experts=parse_routed_experts(response.get("routed_experts")), multi_modal_data=response.get("multi_modal_data"), ) diff --git a/verifiers/types.py b/verifiers/types.py index 
def parse_routed_experts(raw: Any) -> "RoutedExpertsPayload | None":
    """Validate a raw ``routed_experts`` sidecar and return a clean payload.

    The value comes straight from a model response, so it is checked with
    explicit exceptions rather than ``assert`` (asserts are stripped under
    ``python -O`` and would let malformed payloads through).

    Args:
        raw: ``None``, or a dict carrying ``"data"`` and ``"shape"`` entries.

    Returns:
        ``None`` when ``raw`` is ``None``; otherwise a dict with the original
        ``data`` object and ``shape`` coerced to a list of ``int``. Any other
        keys in ``raw`` are dropped.

    Raises:
        TypeError: If ``raw`` is not a dict, ``data`` is not
            str/bytes/bytearray/memoryview, or ``shape`` is not a list.
        KeyError: If ``"data"`` or ``"shape"`` is missing.
    """
    if raw is None:
        return None
    if not isinstance(raw, dict):
        raise TypeError(
            f"routed_experts must be a dict, got {type(raw).__name__}"
        )
    data = raw["data"]
    shape = raw["shape"]
    if not isinstance(data, (str, bytes, bytearray, memoryview)):
        raise TypeError(
            "routed_experts['data'] must be str, bytes, bytearray or "
            f"memoryview, got {type(data).__name__}"
        )
    if not isinstance(shape, list):
        raise TypeError(
            f"routed_experts['shape'] must be a list, got {type(shape).__name__}"
        )
    return {"data": data, "shape": [int(dim) for dim in shape]}