diff --git a/pyproject.toml b/pyproject.toml index 21ae65e4d..73442e494 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,9 @@ prime-pydantic-config = false renderers = false openenv-core = false +[tool.uv.sources] +textarena = { git = "https://github.com/nph4rd/TextArena.git", branch = "fix/kuhn-poker-phantom-ante" } + [tool.uv.extra-build-dependencies] flash-attn = [{ requirement = "torch", match-runtime = true }] diff --git a/tests/test_rubric.py b/tests/test_rubric.py index 8b2293364..dcfafe4e2 100644 --- a/tests/test_rubric.py +++ b/tests/test_rubric.py @@ -1,5 +1,7 @@ """Tests for the Rubric class.""" +from __future__ import annotations + from typing import cast import pytest @@ -8,6 +10,24 @@ from verifiers.types import RewardFunc, RolloutInput, RolloutTiming, State +# Regression for `_is_multiagent_func` mis-classifying functions defined under +# ``from __future__ import annotations`` (PEP 563): the return annotation is +# the string ``"dict[str, float]"`` rather than the resolved generic alias, so +# ``get_origin(annotation)`` returns ``None`` and the function was routed +# through the individual-reward path — where its dict return value got coerced +# to 0 by ``float(dict)`` failing. ``_is_multiagent_func`` now uses +# ``typing.get_type_hints`` to resolve the string. +async def _multiagent_under_future_annotations( + state: State, **_kwargs +) -> dict[str, float]: + return {"agent_a": 0.5, "agent_b": 1.0} + + +def test_is_multiagent_func_handles_future_annotations(): + rubric = Rubric() + assert rubric._is_multiagent_func(_multiagent_under_future_annotations) is True + + class TestRubric: """Test cases for the Rubric class.""" diff --git a/tests/test_spawning_protocol.py b/tests/test_spawning_protocol.py new file mode 100644 index 000000000..a9137f32d --- /dev/null +++ b/tests/test_spawning_protocol.py @@ -0,0 +1,225 @@ +"""End-to-end test for SpawningProtocol via a toy proposer/solver env. + +The proposer picks an integer N. 
The protocol then spawns k child solver +rollouts whose job is to double N. We assert that: + + - the parent's trajectory contains both proposer and child solver steps + - child steps are tagged with agent_id="solver" and is_trainable + - state["extras"]["spawns"] carries SpawnResult(s) with one State per child + - each child's reward was computed by its own rubric (score was 1.0 only + when the solver's completion equaled 2*N) +""" + +from __future__ import annotations + +import re + +import pytest +from datasets import Dataset + +import verifiers as vf +from verifiers import ( + Agent, + MultiAgentEnv, + Rubric, + SingleTurnEnv, + SpawningProtocol, + SpawnSpec, +) +from verifiers.types import Messages, State + + +PROPOSER_NUMBER = 5 # what the proposer always picks in this test +NUM_CHILDREN = 3 + + +# --------------------------------------------------------------------------- # +# Child env: simple single-turn doubling env. +# --------------------------------------------------------------------------- # + + +def _doubling_correct(completion, answer, **_) -> float: + text = completion if isinstance(completion, str) else completion[-1]["content"] + match = re.search(r"-?\d+", text) + if match is None: + return 0.0 + return 1.0 if int(match.group(0)) == int(answer) else 0.0 + + +@pytest.fixture +def child_solver_env(mock_client): + """SingleTurnEnv whose prompt asks for 2*N and rubric scores correctness.""" + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": f"Double {PROPOSER_NUMBER}."}]], + "answer": [str(2 * PROPOSER_NUMBER)], + "example_id": [0], + } + ) + rubric = Rubric(funcs=[_doubling_correct]) + + env = SingleTurnEnv( + client=mock_client, + model="test-model", + dataset=dataset, + parser=vf.Parser(), + rubric=rubric, + ) + + # Pre-stage half of the children's responses to be correct, half wrong. 
class _OneShotSpawnProtocol(SpawningProtocol):
    """Protocol that fans out to child solvers a single time, right after the proposer acts."""

    def __init__(self, child_env, agent_id: str, num_children: int):
        self._num_children = num_children
        self._agent_id = agent_id
        self._child_env = child_env

    def get_initial_agent(self, state: State) -> str:
        return "proposer"

    def get_next_agent(self, state: State) -> str:
        # The proposer is the only agent that takes turns in this protocol,
        # so it is also always the "next" agent.
        return "proposer"

    def should_spawn(self, state: State) -> bool:
        # Spawn only when the latest step belongs to the proposer and no
        # spawn has been recorded on this state yet.
        spawned_before = bool(state["extras"].get("spawns"))
        trajectory = state["trajectory"]
        latest = trajectory[-1] if trajectory else None
        speaker = (latest or {}).get("extras", {}).get("agent_id")
        if speaker != "proposer":
            return False
        return not spawned_before

    def get_spawn_specs(self, state: State) -> list[SpawnSpec]:
        # The proposer's number is the first integer in its last message.
        proposer_text = state["trajectory"][-1]["completion"][-1]["content"]
        n = int(re.search(r"-?\d+", proposer_text).group(0))
        # All children share the same prompt object and expected answer;
        # only example_id differs.
        prompt = [{"role": "user", "content": f"Double {n}."}]
        child_inputs = []
        for child_idx in range(self._num_children):
            child_inputs.append(
                {"prompt": prompt, "answer": str(2 * n), "example_id": child_idx}
            )
        solver_spec = SpawnSpec(
            agent_id=self._agent_id,
            child_env=self._child_env,
            inputs=child_inputs,
            is_trainable=True,
        )
        return [solver_spec]
--------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_spawning_protocol_runs_children_and_records_spawns( + proposer_env, mock_client +): + """One proposer turn → NUM_CHILDREN solver children → all embedded + recorded.""" + state = await proposer_env.rollout( + {"prompt": [{"role": "user", "content": "go"}], "example_id": 0}, + client=mock_client, + model="test-model", + sampling_args={"temperature": 1.0}, + ) + + # 1. The parent trajectory contains the proposer's step plus one per child. + agent_ids = [s["extras"].get("agent_id") for s in state["trajectory"]] + assert agent_ids.count("proposer") == 1 + assert agent_ids.count("solver") == NUM_CHILDREN + + # 2. Spawns recorded. + spawns = state["extras"]["spawns"] + assert len(spawns) == 1 + spawn = spawns[0] + assert spawn.spec.agent_id == "solver" + assert len(spawn.states) == NUM_CHILDREN + + # 3. Children were scored by the child env's own rubric — the mocked + # solver always returns 2*N so each child's reward is 1.0. + for child_state in spawn.states: + assert child_state["reward"] == 1.0 + + +@pytest.mark.asyncio +async def test_child_trajectory_steps_carry_is_trainable_tag( + proposer_env, mock_client +): + state = await proposer_env.rollout( + {"prompt": [{"role": "user", "content": "go"}], "example_id": 0}, + client=mock_client, + model="test-model", + sampling_args={"temperature": 1.0}, + ) + child_steps = [s for s in state["trajectory"] if s["extras"].get("agent_id") == "solver"] + assert child_steps, "expected child steps in parent trajectory" + for step in child_steps: + # is_trainable was set on the SpawnSpec; it must flow through to steps. 
+ assert step["extras"].get("is_trainable") is True diff --git a/uv.lock b/uv.lock index b1c47f140..85c9a01f2 100644 --- a/uv.lock +++ b/uv.lock @@ -19,8 +19,6 @@ conflicts = [[ ]] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. -exclude-newer-span = "P7D" [options.exclude-newer-package] prime-tunnel = false @@ -6195,8 +6193,8 @@ wheels = [ [[package]] name = "textarena" -version = "0.7.4" -source = { registry = "https://pypi.org/simple" } +version = "0.7.3" +source = { git = "https://github.com/nph4rd/TextArena.git?branch=fix%2Fkuhn-poker-phantom-ante#e3ab9a98faace0ac90fb80e5f6e886e269cf4d4d" } dependencies = [ { name = "chess" }, { name = "nltk" }, @@ -6206,10 +6204,6 @@ dependencies = [ { name = "rich" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ba/04/4a3ca42093d0be2a9c377ae3335a6c6baac1d278ae932562ec69f339d172/textarena-0.7.4.tar.gz", hash = "sha256:28bb9170d7718f2ae05e4515bea82262422731e563fc7318a9e7983de0cadd4f", size = 954969, upload-time = "2025-10-16T14:41:55.981Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/26/b4/9a9ba65154aff853c75b3d7324319d168ad9c69c6097f4aa3c16da7d9ef3/textarena-0.7.4-py3-none-any.whl", hash = "sha256:684784e78278e518066f67557ee93b47c238d16cbbd15d3abdaa3147562d3024", size = 1073570, upload-time = "2025-10-16T14:41:53.965Z" }, -] [[package]] name = "textual" @@ -6895,7 +6889,7 @@ requires-dist = [ { name = "setproctitle", specifier = ">=1.3.0" }, { name = "stagehand", marker = "extra == 'browser'", specifier = ">=3.0.0" }, { name = "tenacity", specifier = ">=8.5.0" }, - { name = "textarena", marker = "extra == 'ta'" }, + { name = "textarena", marker = "extra == 'ta'", git = "https://github.com/nph4rd/TextArena.git?branch=fix%2Fkuhn-poker-phantom-ante" }, { name = "textual" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "torch", marker = 
"extra == 'rl'", specifier = ">=2.8.0,<2.9.0" }, @@ -6922,7 +6916,7 @@ dev = [ { name = "renderers", specifier = ">=0.1.8.dev4" }, { name = "ruff" }, { name = "stagehand", specifier = ">=3.0.0" }, - { name = "textarena" }, + { name = "textarena", git = "https://github.com/nph4rd/TextArena.git?branch=fix%2Fkuhn-poker-phantom-ante" }, { name = "ty", specifier = ">=0.0.1a29,<0.0.22" }, ] policy = [{ name = "semgrep", specifier = ">=1.150.0" }] diff --git a/verifiers/__init__.py b/verifiers/__init__.py index b39b0e3f7..b84763942 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -99,8 +99,15 @@ "Terminus2", "Terminus2Config", "SignalConfig", + "Agent", + "Protocol", + "RoundRobinProtocol", + "SpawningProtocol", + "SpawnSpec", + "SpawnResult", "Environment", "MultiTurnEnv", + "MultiAgentEnv", "SingleTurnEnv", "PythonEnv", "SandboxEnv", @@ -169,6 +176,13 @@ "SingleTurnEnv": "verifiers.envs.singleturn_env:SingleTurnEnv", "StatefulToolEnv": "verifiers.envs.stateful_tool_env:StatefulToolEnv", "ToolEnv": "verifiers.envs.tool_env:ToolEnv", + "Agent": "verifiers.envs.agent:Agent", + "Protocol": "verifiers.envs.protocol:Protocol", + "RoundRobinProtocol": "verifiers.envs.protocol:RoundRobinProtocol", + "SpawningProtocol": "verifiers.envs.protocol:SpawningProtocol", + "SpawnSpec": "verifiers.envs.protocol:SpawnSpec", + "SpawnResult": "verifiers.envs.protocol:SpawnResult", + "MultiAgentEnv": "verifiers.envs.multiagent_env:MultiAgentEnv", "EnvGroup": "verifiers.envs.env_group:EnvGroup", "JudgeRubric": "verifiers.rubrics.judge_rubric:JudgeRubric", "load_environment": "verifiers.utils.env_utils:load_environment", @@ -281,8 +295,17 @@ def __getattr__(name: str): from .clients.openai_completions_client import OpenAICompletionsClient # noqa: F401 from .clients.openai_responses_client import OpenAIResponsesClient # noqa: F401 from .clients.renderer_client import RendererClient # noqa: F401 + from .envs.agent import Agent # noqa: F401 + from .envs.protocol import ( # noqa: 
@dataclass
class Agent:
    """
    A single participant in a multi-agent environment.

    Identity is determined by ``id`` alone: equality and hashing both
    ignore ``system_prompt`` and ``is_trainable``.

    Fields:
        id: Unique identifier for this agent (e.g., "player_0", "guesser")
        system_prompt: The agent's specific instructions
        is_trainable: Whether to compute gradients for this agent's actions
    """

    id: str
    system_prompt: str = ""
    is_trainable: bool = True

    def __eq__(self, other: object) -> bool:
        # Only another Agent carrying the same id compares equal.
        return isinstance(other, Agent) and self.id == other.id

    def __hash__(self) -> int:
        # Hash on the id only, consistent with __eq__.
        return hash(self.id)

    def __repr__(self) -> str:
        mode = "trainable" if self.is_trainable else "frozen"
        return f"Agent(id={self.id!r}, {mode})"
ignore[override] sampling_args, max_retries, state_columns, + actor_models=actor_models, ) env_name, child_input, route = self._route_child_input(input) @@ -343,6 +345,7 @@ async def run_rollout( # type: ignore[override] max_retries, state_columns, env.env_client, + actor_models=actor_models, ) return _set_info_route(output, route) # type: ignore[return-value] @@ -356,6 +359,7 @@ async def run_group( # type: ignore[override] max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + actor_models: dict[str, str] | None = None, ) -> list[vf.RolloutOutput]: target_env_client = env_client or self.env_client if target_env_client is not None: @@ -370,6 +374,7 @@ async def run_group( # type: ignore[override] sampling_args, max_retries, state_columns, + actor_models=actor_models, ) env_name, first_child_input, route = self._route_child_input(group_inputs[0]) @@ -396,6 +401,7 @@ async def run_group( # type: ignore[override] max_retries, state_columns, env.env_client, + actor_models=actor_models, ) return [_set_info_route(output, route) for output in outputs] # type: ignore[return-value] diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index ed379f086..1f972ee49 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -745,6 +745,7 @@ async def run_rollout( max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + actor_models: dict[str, str] | None = None, ) -> RolloutOutput: """Generate and, optionally, score a rollout.""" @@ -765,6 +766,7 @@ async def run_rollout( sampling_args, max_retries, state_columns, + actor_models=actor_models, ) resolved_client = resolve_client(client) @@ -791,6 +793,7 @@ async def run_group( max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + actor_models: dict[str, str] | None = None, **kwargs, ) -> list[RolloutOutput]: """Generate and, optionally, score one 
group.""" @@ -812,6 +815,7 @@ async def run_group( sampling_args, max_retries, state_columns, + actor_models=actor_models, ) resolved_client = resolve_client(client) diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py new file mode 100644 index 000000000..872ca8a3a --- /dev/null +++ b/verifiers/envs/multiagent_env.py @@ -0,0 +1,337 @@ +""" +Multi-agent environment for turn-based games. + +This module provides the base class for multi-agent RL environments, extending +MultiTurnEnv with support for: +- Multiple agents with distinct system prompts +- Turn order management via Protocol +- Per-agent trajectory tagging for credit assignment + +Key concepts: +- Agent: A participant with its own identity/prompt (defined in agent.py) +- Protocol: Defines turn order and interaction patterns (defined in protocol.py) + +Environment Implementation: +- Subclasses must provide a Protocol for turn management +- Subclasses implement these main hooks: + - build_agent_prompt(agent_id, state): Build fresh prompt for this agent + - on_turn_complete(state): Update game state after each turn +""" + +import asyncio +from abc import abstractmethod + +import verifiers as vf +from verifiers.envs.agent import Agent +from verifiers.envs.multiturn_env import MultiTurnEnv +from verifiers.envs.protocol import Protocol, SpawningProtocol, SpawnResult, SpawnSpec +from verifiers.types import Messages, State, TrajectoryStep +from verifiers.utils.message_utils import normalize_messages + + +class MultiAgentEnv(MultiTurnEnv): + """ + Base class for multi-agent environments. + + Turn order is managed by a Protocol, which must be provided at init. + This keeps turn logic reusable and separate from environment logic. 
+ + Subclasses must implement: + - build_agent_prompt(): Build prompt for current agent + + Subclasses may optionally override: + - on_turn_complete(): Game logic after each turn + """ + + # List of agent IDs this environment uses (e.g., ["player_0", "player_1"]) + # Subclasses should override this or set in __init__ + agents: list[str] = [] + + # Per-agent model routing for multi-policy training (e.g. per-agent LoRA adapters). + # Maps agent_id -> model name. When set, each agent's turns use its own model. + # When None, all agents share the single model passed to rollout(). + actor_models: dict[str, str] | None = None + + def __init__(self, protocol: Protocol, **kwargs): + """ + Initialize multi-agent environment. + + Args: + protocol: Protocol for turn order management. + **kwargs: Passed to MultiTurnEnv + """ + super().__init__(**kwargs) + self._protocol = protocol + self._agent_registry: dict[str, Agent] = {} + + def register_agent(self, agent: Agent) -> None: + """Register an Agent for lookup by get_agent().""" + self._agent_registry[agent.id] = agent + if agent.id not in self.agents: + self.agents.append(agent.id) + + def get_agent(self, agent_id: str) -> Agent: + """Get an agent by ID.""" + if agent_id not in self._agent_registry: + raise KeyError( + f"Agent '{agent_id}' not found. Did you call register_agent()?" 
+ ) + return self._agent_registry[agent_id] + + # ------------------------------------------------------------------------- + # Turn Management (delegated to Protocol) + # ------------------------------------------------------------------------- + + def get_initial_agent(self, state: State) -> str: + """Return the agent ID that starts the rollout.""" + return self._protocol.get_initial_agent(state) + + def get_next_agent(self, state: State) -> str: + """Return the agent ID for the next turn.""" + return self._protocol.get_next_agent(state) + + # ------------------------------------------------------------------------- + # Agent Prompt Building (Subclasses Implement This) + # ------------------------------------------------------------------------- + + @abstractmethod + async def build_agent_prompt(self, agent_id: str, state: State) -> Messages: + """ + Build the prompt for the given agent's turn. + + This is called BEFORE the model generates a response. + Build a fresh prompt with whatever context this agent needs. + + Args: + agent_id: The agent who will respond (e.g., "player_0") + state: Current game state with trajectory and extras + + Returns: + Messages list with system prompt and user content + """ + pass + + # ------------------------------------------------------------------------- + # Game Logic Hook + # ------------------------------------------------------------------------- + + async def on_turn_complete(self, state: State) -> None: + """ + Update game state after a turn completes. + + This is called AFTER the model response is stored in trajectory. 
+ Use this for game logic: + - Update scores, counters, flags + - Check win conditions + - Parse and validate actions + + The last turn's info is in state["trajectory"][-1]: + - ["completion"][-1]["content"]: The model's response text + - ["extras"]["agent_id"]: Which agent just responded + + Args: + state: Current game state (mutate extras as needed) + """ + pass + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """Initialize multi-agent state fields.""" + state = await super().setup_state(state) + state["extras"] = state.get("extras", {}) + state["extras"]["current_agent_id"] = None + return state + + # ------------------------------------------------------------------------- + # Parent Class Requirement (env_response) + # ------------------------------------------------------------------------- + + async def env_response( + self, messages: Messages, state: State, **kwargs + ) -> Messages: + """ + Satisfy MultiTurnEnv's abstract requirement. + + MultiAgentEnv uses on_turn_complete() instead, which is called + explicitly in our rollout() after storing the response. 
+ """ + return [] + + # ------------------------------------------------------------------------- + # Trajectory Management + # ------------------------------------------------------------------------- + + async def add_trajectory_step( + self, state: State, trajectory_step: TrajectoryStep + ) -> None: + """Tag trajectory step with agent_id.""" + current_agent_id = state["extras"].get("current_agent_id") + if current_agent_id: + trajectory_step["extras"]["agent_id"] = current_agent_id + # Copy trainability from Agent to step + agent = self.get_agent(current_agent_id) + trajectory_step["extras"]["is_trainable"] = agent.is_trainable + await super().add_trajectory_step(state, trajectory_step) + + # ------------------------------------------------------------------------- + # Main Rollout Loop + # ------------------------------------------------------------------------- + + async def rollout( + self, + input, + client, + model, + sampling_args=None, + ) -> State: + """ + Run a multi-agent episode. + + Flow: + 1. Setup state + 2. Loop until game ends: + a. Determine current agent + b. Build prompt via build_agent_prompt() + c. Get model response + d. Store in trajectory + e. Process via on_turn_complete() + f. If protocol is a SpawningProtocol and should_spawn(state): + run child sub-rollouts in parallel, embed their trajectory + steps into the parent's trajectory tagged with the spec's + agent_id, and store SpawnResults in state["extras"]["spawns"] + for the rubric to consume. + 3. Return final state + """ + state = await self.init_state(input, client, model, sampling_args) + try: + state = await self.setup_state(state) + except vf.Error as e: + state["error"] = e + return state + + # Determine first agent + state["extras"]["current_agent_id"] = self.get_initial_agent(state) + + while not await self.is_completed(state): + agent_id = state["extras"]["current_agent_id"] + + try: + # 1. 
Build prompt for this agent + prompt_messages = await self.build_agent_prompt(agent_id, state) + prompt_messages = normalize_messages( + prompt_messages, field_name="agent_prompt" + ) + + # 2. Get model response (route to per-agent model if configured) + agent_model = ( + self.actor_models.get(agent_id) if self.actor_models else None + ) + response = await self.get_model_response( + state, prompt_messages, model=agent_model + ) + + # 3. Store in trajectory (tags with agent_id) + await self.add_model_response(state, prompt_messages, response) + + # 4. Process turn (game logic) + await self.on_turn_complete(state) + + # 5. Spawn sub-rollouts if the protocol requests it. + if isinstance(self._protocol, SpawningProtocol) and ( + self._protocol.should_spawn(state) + ): + specs = self._protocol.get_spawn_specs(state) + await self._run_spawns(specs, state, client, model, sampling_args) + + # 6. Determine next agent (if game continues) + if not await self.is_completed(state): + state["extras"]["current_agent_id"] = self.get_next_agent(state) + + except vf.OverlongPromptError: + state["prompt_too_long"] = True + state["is_truncated"] = True + break + except vf.Error as e: + state["error"] = e + break + + await self.render_completion(state) + return state + + # ------------------------------------------------------------------------- + # Spawning support (SpawningProtocol) + # ------------------------------------------------------------------------- + + async def _run_spawns( + self, + specs: list[SpawnSpec], + state: State, + client, + model, + sampling_args, + ) -> None: + """Execute SpawnSpecs in parallel, score each child, embed their + trajectory steps into the parent's trajectory, and record SpawnResults + in ``state["extras"]["spawns"]``. + + Each child's trajectory steps are tagged with the spec's ``agent_id`` + and ``is_trainable`` (setdefault, so a child env that already tagged + its own steps wins). 
This keeps existing per-agent advantage and + completion-masking machinery intact — + ``interleave_rollout(split_by_agent=True)`` and + ``MicroBatch.actor_ids`` work without modification. + """ + state["extras"].setdefault("spawns", []) + # Outer loop is sequential across specs but inner inputs run in parallel. + # Most protocols emit a single spec per turn so this is essentially + # one gather() per spawning turn. + for spec in specs: + child_model = ( + self.actor_models.get(spec.agent_id) if self.actor_models else model + ) + child_states = await asyncio.gather( + *( + self._run_one_child(spec, child_input, client, child_model, sampling_args) + for child_input in spec.inputs + ) + ) + for child_state in child_states: + for step in child_state.get("trajectory", []): + step.setdefault("extras", {}) + step["extras"].setdefault("agent_id", spec.agent_id) + step["extras"].setdefault("is_trainable", spec.is_trainable) + state["trajectory"].append(step) + state["extras"]["spawns"].append( + SpawnResult(spec=spec, states=list(child_states)) + ) + + async def _run_one_child( + self, + spec: SpawnSpec, + child_input, + client, + child_model, + sampling_args, + ) -> State: + """Roll out a single child and score it so ``state["reward"]`` is + populated before the parent's reward funcs read it. 
Score failures + are logged and treated as zero — they should not bring down the + parent rollout.""" + child_state = await spec.child_env.rollout( + child_input, client, child_model, sampling_args + ) + rubric = getattr(spec.child_env, "rubric", None) + if rubric is not None: + try: + await rubric.score_rollout(child_state) + except Exception as e: # pragma: no cover — defensive + self.logger.warning( + "child rubric.score_rollout failed (agent_id=%s): %s: %s", + spec.agent_id, + type(e).__name__, + e, + ) + return child_state diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py new file mode 100644 index 000000000..2a0775910 --- /dev/null +++ b/verifiers/envs/protocol.py @@ -0,0 +1,177 @@ +""" +Protocol: Defines how multiple agents interact in a multi-agent environment. + +A Protocol specifies turn order and agent interaction patterns, separate from +the task/environment logic. This allows the same protocol (e.g., round-robin) +to be reused across different tasks. + +Example protocols: +- RoundRobinProtocol: Agents take turns in order (0 → 1 → ...) +- SpawningProtocol: One agent can spawn sub-rollouts of another (e.g., + proposer-solver where the proposer creates problems that k solvers attempt) +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from verifiers.types import State + +if TYPE_CHECKING: + from verifiers.envs.environment import Environment + + +class Protocol(ABC): + """ + Abstract base class for multi-agent interaction protocols. + + A Protocol defines: + - Turn order (who goes first, who goes next) + - Agent interaction patterns + + Protocols are independent of: + - Task/environment logic (game rules, rewards) + - Model/harness details (how agents generate responses) + """ + + @abstractmethod + def get_initial_agent(self, state: State) -> str: + """ + Return the agent ID that starts the rollout. 
class RoundRobinProtocol(Protocol):
    """
    Fixed cyclic turn order: each agent acts in list order, then wraps around.

    Example with 3 agents:
        0 → 1 → 2 → 0 → ...
    """

    def __init__(self, agent_ids: list[str]):
        """
        Store the turn order.

        Args:
            agent_ids: List of agent IDs in turn order

        Raises:
            ValueError: If ``agent_ids`` is empty.
        """
        if not agent_ids:
            raise ValueError("agent_ids must not be empty")
        self.agent_ids = agent_ids

    def get_initial_agent(self, state: State) -> str:
        """The rollout starts with the first listed agent."""
        return self.agent_ids[0]

    def get_next_agent(self, state: State) -> str:
        """Return the agent that follows the current one, wrapping at the end.

        The current agent is read from ``state["extras"]["current_agent_id"]``
        (defaulting to the first agent when the key is absent). An id that is
        not in the turn order restarts the cycle at the first agent.
        """
        order = self.agent_ids
        current = state["extras"].get("current_agent_id", order[0])
        # -1 makes the modular step below land on index 0 for unknown ids.
        position = order.index(current) if current in order else -1
        return order[(position + 1) % len(order)]
+ is_trainable: When False, the spawned tokens are still generated and + included in the parent's context but their completion masks are + zeroed out at training time (mirrors ``Agent.is_trainable``). + """ + + agent_id: str + child_env: "Environment" + inputs: list[Any] + is_trainable: bool = True + + +@dataclass +class SpawnResult: + """A SpawnSpec paired with the resulting child states. + + Stored in ``state["extras"]["spawns"]`` so MultiAgentRewardFunc + implementations can compute parent rewards from child outcomes + (e.g. goldilocks scoring on per-child verify scores). + """ + + spec: SpawnSpec + states: list[State] = field(default_factory=list) + + +class SpawningProtocol(Protocol): + """Protocol extension for hierarchical / one-to-many agent interactions. + + The base ``Protocol`` only sequences turns of a fixed agent set. A + ``SpawningProtocol`` additionally lets the current agent's turn fan out + into N sub-rollouts of another agent (or env). The canonical example + is proposer-solver: the proposer designs a problem, then k solver + children attempt it, and the proposer's reward depends on how the + solvers fared. + + Implementers provide: + - ``should_spawn(state)``: was the previous turn a spawn trigger? + - ``get_spawn_specs(state)``: what to spawn (one or more SpawnSpecs) + + MultiAgentEnv.rollout() handles the actual sub-rollout execution and + weaves children's trajectory steps into the parent's trajectory, tagged + with the child agent_id from the spec. Reward functions read the + completed children from ``state["extras"]["spawns"]`` to compute + per-agent rewards. + """ + + @abstractmethod + def should_spawn(self, state: State) -> bool: + """Return True if the most recent turn should trigger sub-rollouts. + + Called after ``MultiAgentEnv.on_turn_complete`` and before + ``get_next_agent``. Subclasses typically inspect the last trajectory + step's content (e.g. did the proposer emit a valid problem?). + """ + ... 
def _is_multiagent_func(self, func: "RewardFunc") -> bool:
    """Check if a function is a MultiAgentRewardFunc by inspecting its return annotation.

    A function is classified as multi-agent when its return annotation is
    ``dict`` or a parameterised dict such as ``dict[str, float]``.

    The annotation is resolved in three steps:

    1. ``typing.get_type_hints`` resolves string annotations, so a function
       defined under ``from __future__ import annotations`` (PEP 563) is
       classified correctly.
    2. If the hints cannot be resolved (for example a parameter annotation
       names something that is only imported under ``TYPE_CHECKING``) or
       carry no ``return`` entry, the raw ``inspect.signature`` return
       annotation is used instead.
    3. If that raw annotation is still a string, it is classified by its
       spelling. ``get_origin`` of a string is ``None``, so without this
       step the function would be treated as an individual reward function.

    Args:
        func: The reward callable to classify.

    Returns:
        True if the return annotation denotes a dict. False otherwise,
        including when the callable has no return annotation or no
        introspectable signature.
    """

    def _spells_dict(text: str) -> bool:
        # True only for "dict", "Dict", a dotted form such as "typing.Dict",
        # or one of those followed by a single bracketed parameter list that
        # closes at the very end, so "dict[str, float] | None" stays False,
        # matching what the resolved annotation would give.
        head, bracket, tail = text.strip().partition("[")
        if head.strip().rpartition(".")[2] not in ("dict", "Dict"):
            return False
        if not bracket:
            return True
        depth = 1
        for offset, char in enumerate(tail):
            if char == "[":
                depth += 1
            elif char == "]":
                depth -= 1
                if depth == 0:
                    return offset == len(tail) - 1
        return False

    unresolved = inspect.Signature.empty
    try:
        annotation = get_type_hints(func).get("return", unresolved)
    except Exception:
        # get_type_hints evaluates every annotation of the callable, so one
        # unresolvable name anywhere (NameError), or an object it does not
        # support (TypeError), aborts the whole lookup.
        annotation = unresolved
    if annotation is unresolved:
        try:
            annotation = inspect.signature(func).return_annotation
        except (TypeError, ValueError):
            # No introspectable signature: nothing to classify on.
            return False
    if isinstance(annotation, str):
        return _spells_dict(annotation)
    return annotation is dict or get_origin(annotation) is dict
multi-agent reward function that returns per-agent rewards.""" + sig = inspect.signature(func) + merged = self.score_objects(state) + if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): + try: + ans = await maybe_await(func, **merged) + except Exception as e: + self.logger.error( + f"Error calling multi-agent reward function {func.__name__}: {e}" + ) + ans = {} + else: + allowed = {k: v for k, v in merged.items() if k in sig.parameters} + try: + ans = await maybe_await(func, **allowed) + except Exception as e: + self.logger.error( + f"Error calling multi-agent reward function {func.__name__}: {e}" + ) + ans = {} + return ans + async def cleanup(self, state: State): """Run all @vf.cleanup-decorated methods on this rubric.""" for handler in self._cleanup_handlers: @@ -348,37 +410,73 @@ async def dummy_score_rollout(self, state: State): async def score_rollout(self, state: State): """ Evaluate all reward functions for a single rollout. + Supports individual and multi-agent reward functions (but not group functions). 
""" reward_funcs = self._get_individual_reward_funcs() + multiagent_funcs = self._get_multiagent_reward_funcs() group_reward_funcs = self._get_group_reward_funcs() - assert len(reward_funcs) > 0 and len(group_reward_funcs) == 0, ( - "Rubric.score_rollout requires at least one individual-level reward function and no group-level reward functions" + + has_reward_funcs = len(reward_funcs) > 0 or len(multiagent_funcs) > 0 + assert has_reward_funcs and len(group_reward_funcs) == 0, ( + "Rubric.score_rollout requires at least one individual-level or multi-agent " + "reward function and no group-level reward functions" ) - reward_scores = [] - for func in reward_funcs: - reward_scores.append( - await self._call_individual_reward_func( - func=func, - state=state, - ) + + aggregated_reward = 0.0 + aggregated_metrics: dict[str, float] = {} + aggregated_agent_rewards: dict[str, float] = {} + + agent_ids = { + t.get("extras", {}).get("agent_id") + for t in state["trajectory"] + if t.get("extras", {}).get("agent_id") + } + + for func, weight in zip(reward_funcs, self._get_individual_reward_weights()): + score = await self._call_individual_reward_func(func=func, state=state) + aggregated_reward += score * weight + aggregated_metrics[func.__name__] = score + for agent_id in agent_ids: + if agent_id not in aggregated_agent_rewards: + aggregated_agent_rewards[agent_id] = 0.0 + aggregated_agent_rewards[agent_id] += score * weight + + multiagent_weights = [ + weight + for func, weight in zip(self.funcs, self.weights) + if self._is_multiagent_func(func) + ] + for func, weight in zip(multiagent_funcs, multiagent_weights): + agent_scores = await self._call_multiagent_reward_func( + func=func, state=state ) + for agent_id, score_value in agent_scores.items(): + if agent_id not in aggregated_agent_rewards: + aggregated_agent_rewards[agent_id] = 0.0 + aggregated_agent_rewards[agent_id] += score_value * weight + if agent_scores: + mean_score = sum(agent_scores.values()) / len(agent_scores) + 
aggregated_reward += mean_score * weight + aggregated_metrics[func.__name__] = mean_score + for agent_id, score_value in agent_scores.items(): + aggregated_metrics[f"{func.__name__}/{agent_id}"] = score_value + rewards = RolloutScore( - metrics={ - func.__name__: reward - for func, reward in zip(reward_funcs, reward_scores) - }, - reward=sum( - [ - reward * weight - for reward, weight in zip( - reward_scores, self._get_individual_reward_weights() - ) - ] - ), + metrics=aggregated_metrics, + reward=aggregated_reward, ) state["reward"] = rewards["reward"] state["metrics"] = rewards["metrics"] + if aggregated_agent_rewards: + state["agent_rewards"] = aggregated_agent_rewards + for t in state["trajectory"]: + if t["reward"] is None: + agent_id = t.get("extras", {}).get("agent_id") + t["reward"] = aggregated_agent_rewards.get( + agent_id, state["reward"] + ) + async def dummy_score_group(self, states: list[State]): """Score a group of rollouts together with dummy rewards.""" for state in states: @@ -389,18 +487,48 @@ async def score_group(self, states: list[State]): Score a group of rollouts together. All reward functions are executed in order, parallelizing across states. + Supports multi-agent reward functions that return per-agent rewards. 
""" num_states = len(states) if num_states == 0: self.logger.warning("No states to score") return aggregated_rewards = [0.0] * num_states + # Per-agent rewards for multi-agent envs: list of dict[agent_id, reward] + aggregated_agent_rewards: list[dict[str, float]] = [ + {} for _ in range(num_states) + ] aggregated_metrics: dict[str, list[float]] = {} # process functions in order for func, weight in zip(self.funcs, self.weights): is_group = self._is_group_func(func) - if is_group: + is_multiagent = self._is_multiagent_func(func) + + if is_multiagent: + # MultiAgentRewardFunc: returns dict[str, float] per state + multiagent_func = cast(MultiAgentRewardFunc, func) + score_tasks = [ + self._call_multiagent_reward_func(multiagent_func, state) + for state in states + ] + agent_scores_list = await asyncio.gather(*score_tasks) + + func_name = func.__name__ + for i, agent_scores in enumerate(agent_scores_list): + # Aggregate per-agent rewards + for agent_id, score_value in agent_scores.items(): + if agent_id not in aggregated_agent_rewards[i]: + aggregated_agent_rewards[i][agent_id] = 0.0 + aggregated_agent_rewards[i][agent_id] += score_value * weight + # Also compute a rollout-level reward (mean of agent rewards) + if agent_scores: + mean_score = sum(agent_scores.values()) / len(agent_scores) + aggregated_rewards[i] += mean_score * weight + if func_name not in aggregated_metrics: + aggregated_metrics[func_name] = [0.0] * num_states + aggregated_metrics[func_name][i] = mean_score + elif is_group: # GroupRewardFunc: score all states together group_func = cast(GroupRewardFunc, func) scores = await self._call_group_reward_func(group_func, states) @@ -411,6 +539,16 @@ async def score_group(self, states: list[State]): score_value = scores[i] aggregated_rewards[i] += score_value * weight aggregated_metrics[func_name][i] = score_value + # Also add to each agent's rewards (for multi-agent compatibility) + agent_ids = set( + t.get("extras", {}).get("agent_id") + for t in 
states[i]["trajectory"] + if t.get("extras", {}).get("agent_id") + ) + for agent_id in agent_ids: + if agent_id not in aggregated_agent_rewards[i]: + aggregated_agent_rewards[i][agent_id] = 0.0 + aggregated_agent_rewards[i][agent_id] += score_value * weight else: reward_func = cast(RewardFunc, func) score_tasks = [ @@ -426,16 +564,64 @@ async def score_group(self, states: list[State]): score_value = scores[i] aggregated_rewards[i] += score_value * weight aggregated_metrics[func_name][i] = score_value + # Also add to each agent's rewards (for multi-agent compatibility) + agent_ids = set( + t.get("extras", {}).get("agent_id") + for t in states[i]["trajectory"] + if t.get("extras", {}).get("agent_id") + ) + for agent_id in agent_ids: + if agent_id not in aggregated_agent_rewards[i]: + aggregated_agent_rewards[i][agent_id] = 0.0 + aggregated_agent_rewards[i][agent_id] += score_value * weight avg_reward = sum(aggregated_rewards) / num_states + + # Compute per-agent baselines for multi-agent environments + # Each agent's advantage is relative to that agent's mean reward + agent_baselines: dict[str, float] = {} + agent_reward_sums: dict[str, float] = {} + agent_reward_counts: dict[str, int] = {} + + for i in range(num_states): + for agent_id, reward in aggregated_agent_rewards[i].items(): + if agent_id not in agent_reward_sums: + agent_reward_sums[agent_id] = 0.0 + agent_reward_counts[agent_id] = 0 + agent_reward_sums[agent_id] += reward + agent_reward_counts[agent_id] += 1 + + for agent_id in agent_reward_sums: + agent_baselines[agent_id] = ( + agent_reward_sums[agent_id] / agent_reward_counts[agent_id] + ) + for i, state in enumerate(states): state["reward"] = aggregated_rewards[i] state["advantage"] = aggregated_rewards[i] - avg_reward + + # Store per-agent rewards if any multi-agent funcs were used + agent_rewards = aggregated_agent_rewards[i] + if agent_rewards: + state["agent_rewards"] = agent_rewards + + # Assign per-step rewards and advantages based on agent_id 
(for multi-agent) for t in state["trajectory"]: - if t["advantage"] is None: - t["advantage"] = state["advantage"] if t["reward"] is None: - t["reward"] = state["reward"] + if agent_rewards: + agent_id = t.get("extras", {}).get("agent_id") + t["reward"] = agent_rewards.get(agent_id, state["reward"]) + else: + t["reward"] = state["reward"] + + # Compute per-agent advantage using per-agent baseline + if t["advantage"] is None: + agent_id = t.get("extras", {}).get("agent_id") + if agent_id and agent_id in agent_baselines: + t["advantage"] = t["reward"] - agent_baselines[agent_id] + else: + t["advantage"] = t["reward"] - avg_reward + state["metrics"] = { func_name: values[i] for func_name, values in aggregated_metrics.items() } diff --git a/verifiers/serve/client/env_client.py b/verifiers/serve/client/env_client.py index 8649fb246..f6e8f2fd0 100644 --- a/verifiers/serve/client/env_client.py +++ b/verifiers/serve/client/env_client.py @@ -50,6 +50,7 @@ async def run_rollout( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + actor_models: dict[str, str] | None = None, ) -> RolloutOutput: resolved_client_config = resolve_client_config(client_config) request = RunRolloutRequest( @@ -59,6 +60,7 @@ async def run_rollout( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + actor_models=actor_models, ) response = await self.handle_run_rollout_request(request, timeout=None) assert response.output is not None @@ -72,6 +74,7 @@ async def run_group( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + actor_models: dict[str, str] | None = None, ) -> list[RolloutOutput]: resolved_client_config = resolve_client_config(client_config) request = RunGroupRequest( @@ -81,6 +84,7 @@ async def run_group( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + actor_models=actor_models, ) response = await self.handle_run_group_request(request, 
timeout=None) assert response.outputs is not None diff --git a/verifiers/serve/server/env_worker.py b/verifiers/serve/server/env_worker.py index 70f77cf30..3174e71f6 100644 --- a/verifiers/serve/server/env_worker.py +++ b/verifiers/serve/server/env_worker.py @@ -135,9 +135,14 @@ async def resolve_client(self, client_config: ClientConfig) -> Client: self.clients[key] = resolve_client(resolved) return self.clients[key] + def _set_actor_models(self, actor_models: dict[str, str] | None) -> None: + if actor_models is not None and hasattr(self.env, "actor_models"): + self.env.actor_models = actor_models + async def handle_run_rollout( self, request: RunRolloutRequest ) -> RunRolloutResponse: + self._set_actor_models(request.actor_models) client = await self.resolve_client(request.client_config) output = await self.env.run_rollout( input=request.input, @@ -146,10 +151,12 @@ async def handle_run_rollout( sampling_args=request.sampling_args, max_retries=request.max_retries, state_columns=request.state_columns, + actor_models=request.actor_models, ) return RunRolloutResponse(output=output) async def handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse: + self._set_actor_models(request.actor_models) client = await self.resolve_client(request.client_config) outputs = await self.env.run_group( group_inputs=request.group_inputs, @@ -158,6 +165,7 @@ async def handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse: sampling_args=request.sampling_args, max_retries=request.max_retries, state_columns=request.state_columns, + actor_models=request.actor_models, ) return RunGroupResponse(outputs=outputs) diff --git a/verifiers/serve/types.py b/verifiers/serve/types.py index 3c252c230..d7ed22d14 100644 --- a/verifiers/serve/types.py +++ b/verifiers/serve/types.py @@ -55,6 +55,7 @@ class RunRolloutRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + actor_models: dict[str, str] | None = None class 
RunRolloutResponse(BaseResponse): @@ -71,6 +72,7 @@ class RunGroupRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + actor_models: dict[str, str] | None = None class RunGroupResponse(BaseResponse): diff --git a/verifiers/types.py b/verifiers/types.py index 0e4f63c19..f817db387 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -226,7 +226,8 @@ class Response(CustomBaseModel): SamplingArgs = dict[str, Any] IndividualRewardFunc = Callable[..., float | Awaitable[float]] GroupRewardFunc = Callable[..., list[float] | Awaitable[list[float]]] -RewardFunc = IndividualRewardFunc | GroupRewardFunc +MultiAgentRewardFunc = Callable[..., dict[str, float] | Awaitable[dict[str, float]]] +RewardFunc = IndividualRewardFunc | GroupRewardFunc | MultiAgentRewardFunc DatasetBuilder: TypeAlias = "Callable[[], Dataset]"