From 0f7810eefb6929c61756c235a4710ff44580ba65 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 7 Feb 2026 16:46:58 -0600 Subject: [PATCH 01/21] add MultiAgentEnv for turn-based multi-agent environments --- verifiers/__init__.py | 6 + verifiers/envs/actor.py | 33 +++++ verifiers/envs/multiagent_env.py | 233 +++++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+) create mode 100644 verifiers/envs/actor.py create mode 100644 verifiers/envs/multiagent_env.py diff --git a/verifiers/__init__.py b/verifiers/__init__.py index b39b0e3f7..a657e7cfb 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -99,8 +99,10 @@ "Terminus2", "Terminus2Config", "SignalConfig", + "Actor", "Environment", "MultiTurnEnv", + "MultiAgentEnv", "SingleTurnEnv", "PythonEnv", "SandboxEnv", @@ -169,6 +171,8 @@ "SingleTurnEnv": "verifiers.envs.singleturn_env:SingleTurnEnv", "StatefulToolEnv": "verifiers.envs.stateful_tool_env:StatefulToolEnv", "ToolEnv": "verifiers.envs.tool_env:ToolEnv", + "Actor": "verifiers.envs.actor:Actor", + "MultiAgentEnv": "verifiers.envs.multiagent_env:MultiAgentEnv", "EnvGroup": "verifiers.envs.env_group:EnvGroup", "JudgeRubric": "verifiers.rubrics.judge_rubric:JudgeRubric", "load_environment": "verifiers.utils.env_utils:load_environment", @@ -281,8 +285,10 @@ def __getattr__(name: str): from .clients.openai_completions_client import OpenAICompletionsClient # noqa: F401 from .clients.openai_responses_client import OpenAIResponsesClient # noqa: F401 from .clients.renderer_client import RendererClient # noqa: F401 + from .envs.actor import Actor # noqa: F401 from .envs.env_group import EnvGroup # noqa: F401 from .envs.environment import Environment # noqa: F401 + from .envs.multiagent_env import MultiAgentEnv # noqa: F401 from .envs.experimental.cli_agent_env import CliAgentEnv # noqa: F401 from .envs.experimental.gym_env import GymEnv # noqa: F401 from .envs.experimental.harbor_env import HarborEnv # noqa: F401 diff --git a/verifiers/envs/actor.py b/verifiers/envs/actor.py new file mode 100644 index 000000000..9ae9f7888 --- /dev/null +++ b/verifiers/envs/actor.py @@ -0,0 +1,33 @@ +""" +Actor: A trainable entity with distinct identity (system prompt) in multi-agent environments. +""" + +from dataclasses import dataclass + + +@dataclass +class Actor: + """ + A trainable actor with distinct system prompt. + + Fields: + id: Unique identifier for this actor (e.g., "player1", "guesser") + system_prompt: The actor's persona/instructions (used in build_actor_prompt) + is_trainable: Whether to compute GRPO advantages (False for frozen actors) + """ + + id: str + system_prompt: str = "" + is_trainable: bool = True + + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Actor): + return self.id == other.id + return False + + def __repr__(self) -> str: + trainable_str = "trainable" if self.is_trainable else "frozen" + return f"Actor(id={self.id!r}, {trainable_str})" diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py new file mode 100644 index 000000000..26c7c66a1 --- /dev/null +++ b/verifiers/envs/multiagent_env.py @@ -0,0 +1,233 @@ +""" +Multi-agent environment for turn-based games. + +This module provides the base class for multi-agent RL environments, extending +MultiTurnEnv with support for: +- Multiple actors with distinct system prompts +- Turn order management via get_initial_actor() / get_next_actor() +- Per-actor trajectory tagging for credit assignment + +Key concepts: +- Actor: A trainable entity with its own system prompt (defined in actor.py) + +Game Implementation: +- Subclasses implement these main hooks: + - get_initial_actor(state): Who goes first + - get_next_actor(state): Who goes next + - build_actor_prompt(actor_id, state): Build fresh prompt for this actor + - on_turn_complete(state): Update game state after each turn +""" + +from abc import abstractmethod + +import verifiers as vf +from verifiers.envs.actor import Actor +from verifiers.envs.multiturn_env import MultiTurnEnv +from verifiers.types import Messages, State, TrajectoryStep + + +class MultiAgentEnv(MultiTurnEnv): + """ + Base class for multi-agent environments. + + Subclasses must implement: + - get_initial_actor(): Who goes first + - get_next_actor(): Who goes next (for alternating turns) + - build_actor_prompt(): Build prompt for current actor + + Subclasses may optionally override: + - on_turn_complete(): Game logic after each turn + """ + + # List of actor IDs this environment uses (e.g., ["player_0", "player_1"]) + # Subclasses should override this or set in __init__ + actors: list[str] = [] + + def __init__(self, **kwargs): + """Initialize multi-agent environment.""" + super().__init__(**kwargs) + # Internal storage for Actor objects (when not using Protocol) + self._actor_registry: dict[str, Actor] = {} + + def register_actor(self, actor: Actor) -> None: + """Register an Actor object for lookup by get_actor().""" + self._actor_registry[actor.id] = actor + if actor.id not in self.actors: + self.actors.append(actor.id) + + def get_actor(self, actor_id: str) -> Actor: + """Get an actor by ID.""" + if actor_id not in self._actor_registry: + raise KeyError( + f"Actor '{actor_id}' not found. Did you call register_actor()?" + ) + return self._actor_registry[actor_id] + + # ------------------------------------------------------------------------- + # Turn Management (Subclasses Implement These) + # ------------------------------------------------------------------------- + + @abstractmethod + def get_initial_actor(self, state: State) -> str: + """ + Return the actor ID that starts the rollout. + + Example: return "player_0" + """ + pass + + @abstractmethod + def get_next_actor(self, state: State) -> str: + """ + Return the actor ID for the next turn. + + Example: Round-robin through players + """ + pass + + # ------------------------------------------------------------------------- + # Game Hooks (Subclasses Implement These) + # ------------------------------------------------------------------------- + + @abstractmethod + async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + """ + Build the prompt for the given actor's turn. + + This is called BEFORE the model generates a response. + Build a fresh prompt with whatever context this actor needs. + + Args: + actor_id: The actor who will respond (e.g., "player_0") + state: Current game state with trajectory and extras + + Returns: + Messages list with system prompt and user content + """ + pass + + async def on_turn_complete(self, state: State) -> None: + """ + Update game state after a turn completes. + + This is called AFTER the model response is stored in trajectory. + Use this for game logic: + - Update scores, counters, flags + - Check win conditions + - Parse and validate actions + + The last turn's info is in state["trajectory"][-1]: + - ["completion"][-1]["content"]: The model's response text + - ["extras"]["actor_id"]: Which actor just responded + + Args: + state: Current game state (mutate extras as needed) + """ + pass + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """Initialize multi-agent state fields.""" + state = await super().setup_state(state) + state["extras"] = state.get("extras", {}) + state["extras"]["current_actor_id"] = None + return state + + # ------------------------------------------------------------------------- + # Parent Class Requirement (env_response) + # ------------------------------------------------------------------------- + + async def env_response( + self, messages: Messages, state: State, **kwargs + ) -> Messages: + """ + Satisfy MultiTurnEnv's abstract requirement. + + MultiAgentEnv uses on_turn_complete() instead, which is called + explicitly in our rollout() after storing the response. + """ + return [] + + # ------------------------------------------------------------------------- + # Trajectory Management + # ------------------------------------------------------------------------- + + async def add_trajectory_step( + self, state: State, trajectory_step: TrajectoryStep + ) -> None: + """Tag trajectory step with actor_id.""" + current_actor_id = state["extras"].get("current_actor_id") + if current_actor_id: + trajectory_step["extras"]["actor_id"] = current_actor_id + # Copy trainability from Actor to step + actor = self.get_actor(current_actor_id) + trajectory_step["extras"]["is_trainable"] = actor.is_trainable + await super().add_trajectory_step(state, trajectory_step) + + # ------------------------------------------------------------------------- + # Main Rollout Loop + # ------------------------------------------------------------------------- + + async def rollout( + self, + input, + client, + model, + sampling_args=None, + ) -> State: + """ + Run a multi-agent episode. + + Flow: + 1. Setup state + 2. Loop until game ends: + a. Determine current actor + b. Build prompt via build_actor_prompt() + c. Get model response + d. Store in trajectory + e. Process via on_turn_complete() + 3. Return final state + """ + state = await self.init_state(input, client, model, sampling_args) + try: + state = await self.setup_state(state) + except vf.Error as e: + state["error"] = e + return state + + # Determine first actor + state["extras"]["current_actor_id"] = self.get_initial_actor(state) + + while not await self.is_completed(state): + actor_id = state["extras"]["current_actor_id"] + + try: + # 1. Build prompt for this actor + prompt_messages = await self.build_actor_prompt(actor_id, state) + + # 2. Get model response + response = await self.get_model_response(state, prompt_messages) + + # 3. Store in trajectory (tags with actor_id) + await self.add_model_response(state, prompt_messages, response) + + # 4. Process turn (game logic) + await self.on_turn_complete(state) + + # 5. Determine next actor (if game continues) + if not await self.is_completed(state): + state["extras"]["current_actor_id"] = self.get_next_actor(state) + + except vf.OverlongPromptError: + state["prompt_too_long"] = True + state["is_truncated"] = True + break + except vf.Error as e: + state["error"] = e + break + + await self.render_completion(state) + return state From a4fea39922906e5993d714c141a96802bf0e83d0 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sun, 8 Feb 2026 10:31:35 -0600 Subject: [PATCH 02/21] rename Actor to Agent, add Protocol abstraction --- verifiers/__init__.py | 11 ++- verifiers/envs/actor.py | 33 ------- verifiers/envs/agent.py | 41 +++++++++ verifiers/envs/multiagent_env.py | 150 +++++++++++++++++-------------- verifiers/envs/protocol.py | 90 +++++++++++++++++++ 5 files changed, 224 insertions(+), 101 deletions(-) delete mode 100644 verifiers/envs/actor.py create mode 100644 verifiers/envs/agent.py create mode 100644 verifiers/envs/protocol.py diff --git a/verifiers/__init__.py b/verifiers/__init__.py index a657e7cfb..5e6780e85 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -99,7 +99,9 @@ "Terminus2", "Terminus2Config", "SignalConfig", - "Actor", + "Agent", + "Protocol", + "RoundRobinProtocol", "Environment", "MultiTurnEnv", "MultiAgentEnv", @@ -171,7 +173,9 @@ "SingleTurnEnv": "verifiers.envs.singleturn_env:SingleTurnEnv", "StatefulToolEnv": "verifiers.envs.stateful_tool_env:StatefulToolEnv", "ToolEnv": "verifiers.envs.tool_env:ToolEnv", - "Actor": "verifiers.envs.actor:Actor", + "Agent": "verifiers.envs.agent:Agent", + "Protocol": "verifiers.envs.protocol:Protocol", + "RoundRobinProtocol": "verifiers.envs.protocol:RoundRobinProtocol", "MultiAgentEnv": "verifiers.envs.multiagent_env:MultiAgentEnv", "EnvGroup": "verifiers.envs.env_group:EnvGroup", "JudgeRubric": "verifiers.rubrics.judge_rubric:JudgeRubric", @@ -285,7 +289,8 @@ def __getattr__(name: str): from .clients.openai_completions_client import OpenAICompletionsClient # noqa: F401 from .clients.openai_responses_client import OpenAIResponsesClient # noqa: F401 from .clients.renderer_client import RendererClient # noqa: F401 - from .envs.actor import Actor # noqa: F401 + from .envs.agent import Agent # noqa: F401 + from .envs.protocol import Protocol, RoundRobinProtocol # noqa: F401 from .envs.env_group import EnvGroup # noqa: F401 from .envs.environment import Environment # noqa: F401 from .envs.multiagent_env import MultiAgentEnv # noqa: F401 diff --git a/verifiers/envs/actor.py b/verifiers/envs/actor.py deleted file mode 100644 index 9ae9f7888..000000000 --- a/verifiers/envs/actor.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Actor: A trainable entity with distinct identity (system prompt) in multi-agent environments. -""" - -from dataclasses import dataclass - - -@dataclass -class Actor: - """ - A trainable actor with distinct system prompt. - - Fields: - id: Unique identifier for this actor (e.g., "player1", "guesser") - system_prompt: The actor's persona/instructions (used in build_actor_prompt) - is_trainable: Whether to compute GRPO advantages (False for frozen actors) - """ - - id: str - system_prompt: str = "" - is_trainable: bool = True - - def __hash__(self) -> int: - return hash(self.id) - - def __eq__(self, other: object) -> bool: - if isinstance(other, Actor): - return self.id == other.id - return False - - def __repr__(self) -> str: - trainable_str = "trainable" if self.is_trainable else "frozen" - return f"Actor(id={self.id!r}, {trainable_str})" diff --git a/verifiers/envs/agent.py b/verifiers/envs/agent.py new file mode 100644 index 000000000..abdee5cef --- /dev/null +++ b/verifiers/envs/agent.py @@ -0,0 +1,41 @@ +""" +Agent: A participant in multi-agent environments. + +Currently contains agent metadata (id, system prompt, trainability). +In the future, when Harness is introduced, Agent will be extended to +compose with Harness and Model: Agent = Harness + Model. +""" + +from dataclasses import dataclass + + +@dataclass +class Agent: + """ + An agent in a multi-agent environment. + + Fields: + id: Unique identifier for this agent (e.g., "player_0", "guesser") + system_prompt: The agent's persona/instructions + is_trainable: Whether to compute gradients for this agent's actions + + Future: + When Harness is introduced, Agent will be extended to include + rollout logic and model binding: Agent = Harness + Model. + """ + + id: str + system_prompt: str = "" + is_trainable: bool = True + + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Agent): + return self.id == other.id + return False + + def __repr__(self) -> str: + trainable_str = "trainable" if self.is_trainable else "frozen" + return f"Agent(id={self.id!r}, {trainable_str})" diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py index 26c7c66a1..62cd11d16 100644 --- a/verifiers/envs/multiagent_env.py +++ b/verifiers/envs/multiagent_env.py @@ -3,26 +3,28 @@ This module provides the base class for multi-agent RL environments, extending MultiTurnEnv with support for: -- Multiple actors with distinct system prompts -- Turn order management via get_initial_actor() / get_next_actor() -- Per-actor trajectory tagging for credit assignment +- Multiple agents with distinct system prompts +- Turn order management via Protocol or get_initial_agent() / get_next_agent() +- Per-agent trajectory tagging for credit assignment Key concepts: -- Actor: A trainable entity with its own system prompt (defined in actor.py) +- Agent: A participant with its own identity/prompt (defined in agent.py) +- Protocol: Defines turn order and interaction patterns (defined in protocol.py) -Game Implementation: +Environment Implementation: - Subclasses implement these main hooks: - - get_initial_actor(state): Who goes first - - get_next_actor(state): Who goes next - - build_actor_prompt(actor_id, state): Build fresh prompt for this actor + - get_initial_agent(state): Who goes first (or use a Protocol) + - get_next_agent(state): Who goes next (or use a Protocol) + - build_agent_prompt(agent_id, state): Build fresh prompt for this agent - on_turn_complete(state): Update game state after each turn """ from abc import abstractmethod import verifiers as vf -from verifiers.envs.actor import Actor +from verifiers.envs.agent import Agent from verifiers.envs.multiturn_env import MultiTurnEnv +from verifiers.envs.protocol import Protocol from verifiers.types import Messages, State, TrajectoryStep @@ -30,75 +32,89 @@ class MultiAgentEnv(MultiTurnEnv): """ Base class for multi-agent environments. + Turn order can be specified either by: + 1. Passing a Protocol to __init__ (reusable turn logic) + 2. Implementing get_initial_agent() and get_next_agent() in subclass + Subclasses must implement: - - get_initial_actor(): Who goes first - - get_next_actor(): Who goes next (for alternating turns) - - build_actor_prompt(): Build prompt for current actor + - build_agent_prompt(): Build prompt for current agent Subclasses may optionally override: - on_turn_complete(): Game logic after each turn + - get_initial_agent() / get_next_agent(): If not using a Protocol """ - # List of actor IDs this environment uses (e.g., ["player_0", "player_1"]) + # List of agent IDs this environment uses (e.g., ["player_0", "player_1"]) # Subclasses should override this or set in __init__ - actors: list[str] = [] + agents: list[str] = [] + + def __init__(self, protocol: Protocol | None = None, **kwargs): + """ + Initialize multi-agent environment. - def __init__(self, **kwargs): - """Initialize multi-agent environment.""" + Args: + protocol: Optional Protocol for turn order. If not provided, + subclass must implement get_initial_agent/get_next_agent. + **kwargs: Passed to MultiTurnEnv + """ super().__init__(**kwargs) - # Internal storage for Actor objects (when not using Protocol) - self._actor_registry: dict[str, Actor] = {} - - def register_actor(self, actor: Actor) -> None: - """Register an Actor object for lookup by get_actor().""" - self._actor_registry[actor.id] = actor - if actor.id not in self.actors: - self.actors.append(actor.id) - - def get_actor(self, actor_id: str) -> Actor: - """Get an actor by ID.""" - if actor_id not in self._actor_registry: + self._protocol = protocol + self._agent_registry: dict[str, Agent] = {} + + def register_agent(self, agent: Agent) -> None: + """Register an Agent for lookup by get_agent().""" + self._agent_registry[agent.id] = agent + if agent.id not in self.agents: + self.agents.append(agent.id) + + def get_agent(self, agent_id: str) -> Agent: + """Get an agent by ID.""" + if agent_id not in self._agent_registry: raise KeyError( - f"Actor '{actor_id}' not found. Did you call register_actor()?" + f"Agent '{agent_id}' not found. Did you call register_agent()?" ) - return self._actor_registry[actor_id] + return self._agent_registry[agent_id] # ------------------------------------------------------------------------- - # Turn Management (Subclasses Implement These) + # Turn Management # ------------------------------------------------------------------------- - @abstractmethod - def get_initial_actor(self, state: State) -> str: + def get_initial_agent(self, state: State) -> str: """ - Return the actor ID that starts the rollout. + Return the agent ID that starts the rollout. - Example: return "player_0" + Default: delegates to Protocol if provided. + Override in subclass if not using a Protocol. """ - pass + if self._protocol: + return self._protocol.get_initial_agent(state) + raise NotImplementedError("Provide a Protocol or override get_initial_agent()") - @abstractmethod - def get_next_actor(self, state: State) -> str: + def get_next_agent(self, state: State) -> str: """ - Return the actor ID for the next turn. + Return the agent ID for the next turn. - Example: Round-robin through players + Default: delegates to Protocol if provided. + Override in subclass if not using a Protocol. """ - pass + if self._protocol: + return self._protocol.get_next_agent(state) + raise NotImplementedError("Provide a Protocol or override get_next_agent()") # ------------------------------------------------------------------------- - # Game Hooks (Subclasses Implement These) + # Agent Prompt Building (Subclasses Implement This) # ------------------------------------------------------------------------- @abstractmethod - async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: + async def build_agent_prompt(self, agent_id: str, state: State) -> Messages: """ - Build the prompt for the given actor's turn. + Build the prompt for the given agent's turn. This is called BEFORE the model generates a response. - Build a fresh prompt with whatever context this actor needs. + Build a fresh prompt with whatever context this agent needs. Args: - actor_id: The actor who will respond (e.g., "player_0") + agent_id: The agent who will respond (e.g., "player_0") state: Current game state with trajectory and extras Returns: @@ -106,6 +122,10 @@ async def build_actor_prompt(self, actor_id: str, state: State) -> Messages: """ pass + # ------------------------------------------------------------------------- + # Game Logic Hook + # ------------------------------------------------------------------------- + async def on_turn_complete(self, state: State) -> None: """ Update game state after a turn completes. @@ -118,7 +138,7 @@ async def on_turn_complete(self, state: State) -> None: The last turn's info is in state["trajectory"][-1]: - ["completion"][-1]["content"]: The model's response text - - ["extras"]["actor_id"]: Which actor just responded + - ["extras"]["agent_id"]: Which agent just responded Args: state: Current game state (mutate extras as needed) @@ -133,7 +153,7 @@ async def setup_state(self, state: State) -> State: """Initialize multi-agent state fields.""" state = await super().setup_state(state) state["extras"] = state.get("extras", {}) - state["extras"]["current_actor_id"] = None + state["extras"]["current_agent_id"] = None return state # ------------------------------------------------------------------------- @@ -158,13 +178,13 @@ async def env_response( async def add_trajectory_step( self, state: State, trajectory_step: TrajectoryStep ) -> None: - """Tag trajectory step with actor_id.""" - current_actor_id = state["extras"].get("current_actor_id") - if current_actor_id: - trajectory_step["extras"]["actor_id"] = current_actor_id - # Copy trainability from Actor to step - actor = self.get_actor(current_actor_id) - trajectory_step["extras"]["is_trainable"] = actor.is_trainable + """Tag trajectory step with agent_id.""" + current_agent_id = state["extras"].get("current_agent_id") + if current_agent_id: + trajectory_step["extras"]["agent_id"] = current_agent_id + # Copy trainability from Agent to step + agent = self.get_agent(current_agent_id) + trajectory_step["extras"]["is_trainable"] = agent.is_trainable await super().add_trajectory_step(state, trajectory_step) # ------------------------------------------------------------------------- @@ -184,8 +204,8 @@ async def rollout( Flow: 1. Setup state 2. Loop until game ends: - a. Determine current actor - b. Build prompt via build_actor_prompt() + a. Determine current agent + b. Build prompt via build_agent_prompt() c. Get model response d. Store in trajectory e. Process via on_turn_complete() @@ -198,28 +218,28 @@ async def rollout( state["error"] = e return state - # Determine first actor - state["extras"]["current_actor_id"] = self.get_initial_actor(state) + # Determine first agent + state["extras"]["current_agent_id"] = self.get_initial_agent(state) while not await self.is_completed(state): - actor_id = state["extras"]["current_actor_id"] + agent_id = state["extras"]["current_agent_id"] try: - # 1. Build prompt for this actor - prompt_messages = await self.build_actor_prompt(actor_id, state) + # 1. Build prompt for this agent + prompt_messages = await self.build_agent_prompt(agent_id, state) # 2. Get model response response = await self.get_model_response(state, prompt_messages) - # 3. Store in trajectory (tags with actor_id) + # 3. Store in trajectory (tags with agent_id) await self.add_model_response(state, prompt_messages, response) # 4. Process turn (game logic) await self.on_turn_complete(state) - # 5. Determine next actor (if game continues) + # 5. Determine next agent (if game continues) if not await self.is_completed(state): - state["extras"]["current_actor_id"] = self.get_next_actor(state) + state["extras"]["current_agent_id"] = self.get_next_agent(state) except vf.OverlongPromptError: state["prompt_too_long"] = True diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py new file mode 100644 index 000000000..00d640013 --- /dev/null +++ b/verifiers/envs/protocol.py @@ -0,0 +1,90 @@ +""" +Protocol: Defines how multiple agents interact in a multi-agent environment. + +A Protocol specifies turn order and agent interaction patterns, separate from +the task/environment logic. This allows the same protocol (e.g., round-robin) +to be reused across different tasks. + +Example protocols: +- RoundRobinProtocol: Agents take turns in order (player_0 → player_1 → ...) +- SimultaneousProtocol: All agents act at once (future) +- HierarchicalProtocol: Some agents coordinate others (future) +""" + +from abc import ABC, abstractmethod + +from verifiers.types import State + + +class Protocol(ABC): + """ + Abstract base class for multi-agent interaction protocols. + + A Protocol defines: + - Turn order (who goes first, who goes next) + - Agent interaction patterns + + Protocols are independent of: + - Task/environment logic (game rules, rewards) + - Model/harness details (how agents generate responses) + """ + + @abstractmethod + def get_initial_agent(self, state: State) -> str: + """ + Return the agent ID that starts the rollout. + + Args: + state: Initial game state + + Returns: + Agent ID (e.g., "player_0") + """ + pass + + @abstractmethod + def get_next_agent(self, state: State) -> str: + """ + Return the agent ID for the next turn. + + Args: + state: Current game state (use to determine whose turn it is) + + Returns: + Agent ID for the next turn + """ + pass + + +class RoundRobinProtocol(Protocol): + """ + Simple round-robin turn order: agents take turns in sequence. + + Example with 3 agents: + player_0 → player_1 → player_2 → player_0 → ... + """ + + def __init__(self, agent_ids: list[str]): + """ + Initialize round-robin protocol. + + Args: + agent_ids: List of agent IDs in turn order + """ + if not agent_ids: + raise ValueError("agent_ids must not be empty") + self.agent_ids = agent_ids + + def get_initial_agent(self, state: State) -> str: + """First agent in the list starts.""" + return self.agent_ids[0] + + def get_next_agent(self, state: State) -> str: + """Cycle through agents in order.""" + current = state["extras"].get("current_agent_id", self.agent_ids[0]) + try: + current_idx = self.agent_ids.index(current) + except ValueError: + current_idx = -1 + next_idx = (current_idx + 1) % len(self.agent_ids) + return self.agent_ids[next_idx] From 8dab76d038058503affc07dd6b9ebf526e7bec1e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sun, 8 Feb 2026 12:35:28 -0600 Subject: [PATCH 03/21] require Protocol in MultiAgentEnv, simplify docstrings --- verifiers/envs/agent.py | 10 ++------ verifiers/envs/multiagent_env.py | 40 +++++++++----------------------- 2 files changed, 13 insertions(+), 37 deletions(-) diff --git a/verifiers/envs/agent.py b/verifiers/envs/agent.py index abdee5cef..3f52d67ba 100644 --- a/verifiers/envs/agent.py +++ b/verifiers/envs/agent.py @@ -1,9 +1,7 @@ """ Agent: A participant in multi-agent environments. -Currently contains agent metadata (id, system prompt, trainability). -In the future, when Harness is introduced, Agent will be extended to -compose with Harness and Model: Agent = Harness + Model. +Contains agent metadata (id, system prompt, trainability). """ from dataclasses import dataclass @@ -16,12 +14,8 @@ class Agent: Fields: id: Unique identifier for this agent (e.g., "player_0", "guesser") - system_prompt: The agent's persona/instructions + system_prompt: The agent's specific instructions is_trainable: Whether to compute gradients for this agent's actions - - Future: - When Harness is introduced, Agent will be extended to include - rollout logic and model binding: Agent = Harness + Model. """ id: str diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py index 62cd11d16..87dbf6de6 100644 --- a/verifiers/envs/multiagent_env.py +++ b/verifiers/envs/multiagent_env.py @@ -4,7 +4,7 @@ This module provides the base class for multi-agent RL environments, extending MultiTurnEnv with support for: - Multiple agents with distinct system prompts -- Turn order management via Protocol or get_initial_agent() / get_next_agent() +- Turn order management via Protocol - Per-agent trajectory tagging for credit assignment Key concepts: @@ -12,9 +12,8 @@ - Protocol: Defines turn order and interaction patterns (defined in protocol.py) Environment Implementation: +- Subclasses must provide a Protocol for turn management - Subclasses implement these main hooks: - - get_initial_agent(state): Who goes first (or use a Protocol) - - get_next_agent(state): Who goes next (or use a Protocol) - build_agent_prompt(agent_id, state): Build fresh prompt for this agent - on_turn_complete(state): Update game state after each turn """ @@ -32,29 +31,26 @@ class MultiAgentEnv(MultiTurnEnv): """ Base class for multi-agent environments. - Turn order can be specified either by: - 1. Passing a Protocol to __init__ (reusable turn logic) - 2. Implementing get_initial_agent() and get_next_agent() in subclass + Turn order is managed by a Protocol, which must be provided at init. + This keeps turn logic reusable and separate from environment logic. Subclasses must implement: - build_agent_prompt(): Build prompt for current agent Subclasses may optionally override: - on_turn_complete(): Game logic after each turn - - get_initial_agent() / get_next_agent(): If not using a Protocol """ # List of agent IDs this environment uses (e.g., ["player_0", "player_1"]) # Subclasses should override this or set in __init__ agents: list[str] = [] - def __init__(self, protocol: Protocol | None = None, **kwargs): + def __init__(self, protocol: Protocol, **kwargs): """ Initialize multi-agent environment. Args: - protocol: Optional Protocol for turn order. If not provided, - subclass must implement get_initial_agent/get_next_agent. + protocol: Protocol for turn order management. **kwargs: Passed to MultiTurnEnv """ super().__init__(**kwargs) @@ -76,30 +72,16 @@ def get_agent(self, agent_id: str) -> Agent: return self._agent_registry[agent_id] # ------------------------------------------------------------------------- - # Turn Management + # Turn Management (delegated to Protocol) # ------------------------------------------------------------------------- def get_initial_agent(self, state: State) -> str: - """ - Return the agent ID that starts the rollout. - - Default: delegates to Protocol if provided. - Override in subclass if not using a Protocol. - """ - if self._protocol: - return self._protocol.get_initial_agent(state) - raise NotImplementedError("Provide a Protocol or override get_initial_agent()") + """Return the agent ID that starts the rollout.""" + return self._protocol.get_initial_agent(state) def get_next_agent(self, state: State) -> str: - """ - Return the agent ID for the next turn. - - Default: delegates to Protocol if provided. - Override in subclass if not using a Protocol. - """ - if self._protocol: - return self._protocol.get_next_agent(state) - raise NotImplementedError("Provide a Protocol or override get_next_agent()") + """Return the agent ID for the next turn.""" + return self._protocol.get_next_agent(state) # ------------------------------------------------------------------------- # Agent Prompt Building (Subclasses Implement This) From 27081693329bc01d626e65abc7a45d83b9cabe4f Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sun, 8 Feb 2026 12:50:24 -0600 Subject: [PATCH 04/21] update docstrings --- verifiers/envs/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py index 00d640013..b43e71e0d 100644 --- a/verifiers/envs/protocol.py +++ b/verifiers/envs/protocol.py @@ -6,7 +6,7 @@ to be reused across different tasks. Example protocols: -- RoundRobinProtocol: Agents take turns in order (player_0 → player_1 → ...) +- RoundRobinProtocol: Agents take turns in order (0 → 1 → ...) - SimultaneousProtocol: All agents act at once (future) - HierarchicalProtocol: Some agents coordinate others (future) """ @@ -61,7 +61,7 @@ class RoundRobinProtocol(Protocol): Simple round-robin turn order: agents take turns in sequence. Example with 3 agents: - player_0 → player_1 → player_2 → player_0 → ... + 0 → 1 → 2 → 0 → ... """ def __init__(self, agent_ids: list[str]): From e8c04dc276d52a3fa7f67f60a4fac3b4429d086f Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 24 Feb 2026 19:38:42 -0600 Subject: [PATCH 05/21] add multi-agent reward functions for heterogeneous rewards --- verifiers/rubrics/rubric.py | 99 +++++++++++++++++++++++++++++++++++-- verifiers/types.py | 3 +- 2 files changed, 96 insertions(+), 6 deletions(-) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 914ae4d69..53607372b 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -8,6 +8,7 @@ from verifiers.decorators import discover_decorated from verifiers.types import ( GroupRewardFunc, + MultiAgentRewardFunc, RewardFunc, RolloutScore, State, @@ -103,6 +104,8 @@ def _get_reward_weights(self) -> list[float]: def _is_group_func(self, func: RewardFunc) -> bool: """Check if a function is a GroupRewardFunc by inspecting its signature.""" + if self._is_multiagent_func(func): + return False sig = inspect.signature(func) # GroupRewardFunc has plural parameters: states, prompts, completions, etc. param_names = set(sig.parameters.keys()) @@ -178,22 +181,32 @@ def task_for_state(self, state: State, resources: object | None) -> object: return to_task(state["input"]) return None + def _is_multiagent_func(self, func: RewardFunc) -> bool: + """Check if a function is a MultiAgentRewardFunc by inspecting its return annotation.""" + sig = inspect.signature(func) + return_annotation = sig.return_annotation + return return_annotation is dict or get_origin(return_annotation) is dict + # individual-level reward helpers def _get_individual_reward_func_names(self) -> list[str]: return [ getattr(func, "__name__", repr(func)) for func in self.funcs - if not self._is_group_func(func) + if not self._is_group_func(func) and not self._is_multiagent_func(func) ] def _get_individual_reward_funcs(self) -> list[RewardFunc]: - return [func for func in self.funcs if not self._is_group_func(func)] + return [ + func + for func in self.funcs + if not self._is_group_func(func) and not self._is_multiagent_func(func) + ] def _get_individual_reward_weights(self) -> list[float]: return [ weight for func, weight in zip(self.funcs, self.weights) - if not self._is_group_func(func) + if not self._is_group_func(func) and not self._is_multiagent_func(func) ] async def _call_individual_reward_func( @@ -291,6 +304,40 @@ async def _call_group_reward_func( ans = [0.0] * len(states) return ans + # multi-agent reward helpers + def _get_multiagent_reward_funcs(self) -> list[MultiAgentRewardFunc]: + return cast( + list[MultiAgentRewardFunc], + [func for func in self.funcs if self._is_multiagent_func(func)], + ) + + async def _call_multiagent_reward_func( + self, + func: MultiAgentRewardFunc, + state: State, + ) -> dict[str, float]: + """Invoke a multi-agent reward function that returns per-agent rewards.""" + sig = inspect.signature(func) + merged = self.score_objects(state) + if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): + try: + ans = await maybe_await(func, **merged) + except Exception as e: + self.logger.error( + f"Error calling multi-agent reward function {func.__name__}: {e}" + ) + ans = {} + else: + allowed = {k: v for k, v in merged.items() if k in sig.parameters} + try: + ans = await maybe_await(func, **allowed) + except Exception as e: + self.logger.error( + f"Error calling multi-agent reward function {func.__name__}: {e}" + ) + ans = {} + return ans + async def cleanup(self, state: State): """Run all @vf.cleanup-decorated methods on this rubric.""" for handler in self._cleanup_handlers: @@ -389,18 +436,48 @@ async def score_group(self, states: list[State]): Score a group of rollouts together. All reward functions are executed in order, parallelizing across states. + Supports multi-agent reward functions that return per-agent rewards. """ num_states = len(states) if num_states == 0: self.logger.warning("No states to score") return aggregated_rewards = [0.0] * num_states + # Per-agent rewards for multi-agent envs: list of dict[agent_id, reward] + aggregated_agent_rewards: list[dict[str, float]] = [ + {} for _ in range(num_states) + ] aggregated_metrics: dict[str, list[float]] = {} # process functions in order for func, weight in zip(self.funcs, self.weights): is_group = self._is_group_func(func) - if is_group: + is_multiagent = self._is_multiagent_func(func) + + if is_multiagent: + # MultiAgentRewardFunc: returns dict[str, float] per state + multiagent_func = cast(MultiAgentRewardFunc, func) + score_tasks = [ + self._call_multiagent_reward_func(multiagent_func, state) + for state in states + ] + agent_scores_list = await asyncio.gather(*score_tasks) + + func_name = func.__name__ + for i, agent_scores in enumerate(agent_scores_list): + # Aggregate per-agent rewards + for agent_id, score_value in agent_scores.items(): + if agent_id not in aggregated_agent_rewards[i]: + aggregated_agent_rewards[i][agent_id] = 0.0 + aggregated_agent_rewards[i][agent_id] += score_value * weight + # Also compute a rollout-level reward (mean of agent rewards) + if agent_scores: + mean_score = sum(agent_scores.values()) / len(agent_scores) + aggregated_rewards[i] += mean_score * weight + if func_name not in aggregated_metrics: + aggregated_metrics[func_name] = [0.0] * num_states + aggregated_metrics[func_name][i] = mean_score + elif is_group: # GroupRewardFunc: score all states together group_func = cast(GroupRewardFunc, func) scores = await self._call_group_reward_func(group_func, states) @@ -431,11 +508,23 @@ async def score_group(self, states: list[State]): for i, state in enumerate(states): state["reward"] = aggregated_rewards[i] state["advantage"] = aggregated_rewards[i] - avg_reward + + # Store per-agent rewards if any multi-agent funcs were used + agent_rewards = aggregated_agent_rewards[i] + if agent_rewards: + state["agent_rewards"] = agent_rewards + + # Assign per-step rewards based on agent_id (for multi-agent) for t in state["trajectory"]: if t["advantage"] is None: t["advantage"] = state["advantage"] if t["reward"] is None: - t["reward"] = state["reward"] + if agent_rewards: + agent_id = t.get("extras", {}).get("agent_id") + t["reward"] = agent_rewards.get(agent_id, state["reward"]) + else: + t["reward"] = state["reward"] + state["metrics"] = { func_name: values[i] for func_name, values in aggregated_metrics.items() } diff --git a/verifiers/types.py b/verifiers/types.py index 0e4f63c19..f817db387 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -226,7 +226,8 @@ class Response(CustomBaseModel): SamplingArgs = dict[str, Any] IndividualRewardFunc = Callable[..., float | Awaitable[float]] GroupRewardFunc = Callable[..., list[float] | Awaitable[list[float]]] -RewardFunc = IndividualRewardFunc | GroupRewardFunc +MultiAgentRewardFunc = Callable[..., dict[str, float] | Awaitable[dict[str, float]]] +RewardFunc = IndividualRewardFunc | GroupRewardFunc | MultiAgentRewardFunc DatasetBuilder: TypeAlias = "Callable[[], Dataset]" From 65c28539f526d78c7c2e848278346354a8b259cb Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 24 Feb 2026 22:02:03 -0600 Subject: [PATCH 06/21] compute per-agent advantages for multi-agent rewards --- verifiers/rubrics/rubric.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 53607372b..0f50bc3c3 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -514,16 +514,17 @@ async def score_group(self, states: list[State]): if agent_rewards: state["agent_rewards"] = agent_rewards - # Assign per-step rewards based on agent_id (for multi-agent) + # Assign per-step rewards and advantages based on agent_id (for multi-agent) for t in state["trajectory"]: - if t["advantage"] is None: - t["advantage"] = state["advantage"] if t["reward"] is None: if agent_rewards: agent_id = t.get("extras", {}).get("agent_id") t["reward"] = agent_rewards.get(agent_id, state["reward"]) else: t["reward"] = state["reward"] + # Compute per-agent advantage: agent's reward - shared baseline + if t["advantage"] is None: + t["advantage"] = t["reward"] - avg_reward state["metrics"] = { func_name: values[i] for func_name, values in aggregated_metrics.items() From 2ca7c720a1dccee107d77310b2db94a76d02ec52 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 25 Feb 2026 15:35:10 -0600 Subject: [PATCH 07/21] include all rewards in per-agent rewards for multi-agent training --- verifiers/rubrics/rubric.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 0f50bc3c3..ec92a8f36 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -488,6 +488,16 @@ async def score_group(self, states: list[State]): score_value = scores[i] aggregated_rewards[i] += score_value * weight aggregated_metrics[func_name][i] = score_value + # Also add to each agent's rewards (for multi-agent compatibility) + agent_ids = set( + t.get("extras", {}).get("agent_id") + for t in states[i]["trajectory"] + if t.get("extras", {}).get("agent_id") + ) + for agent_id in agent_ids: + if agent_id not in aggregated_agent_rewards[i]: + aggregated_agent_rewards[i][agent_id] = 0.0 + aggregated_agent_rewards[i][agent_id] += score_value * weight else: reward_func = cast(RewardFunc, func) score_tasks = [ @@ -503,6 +513,16 @@ async def score_group(self, states: list[State]): score_value = scores[i] aggregated_rewards[i] += score_value * weight aggregated_metrics[func_name][i] = score_value + # Also add to each agent's rewards (for multi-agent compatibility) + agent_ids = set( + t.get("extras", {}).get("agent_id") + for t in states[i]["trajectory"] + if t.get("extras", {}).get("agent_id") + ) + for agent_id in agent_ids: + if agent_id not in aggregated_agent_rewards[i]: + aggregated_agent_rewards[i][agent_id] = 0.0 + aggregated_agent_rewards[i][agent_id] += score_value * weight avg_reward = sum(aggregated_rewards) / num_states for i, state in enumerate(states): From 0c8ce5e9dddb3d5afa7493bbbb171c5a71942305 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 25 Feb 2026 21:49:57 -0600 Subject: [PATCH 08/21] add opponent-conditioned baselines for multi-agent advantage estimation --- verifiers/rubrics/rubric.py | 65 +++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index ec92a8f36..eea7fcfc7 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -525,6 +525,51 @@ async def score_group(self, states: list[State]): aggregated_agent_rewards[i][agent_id] += score_value * weight avg_reward = sum(aggregated_rewards) / num_states + + # For multi-agent: compute opponent-conditioned baselines + # Group states by opponent behavior to isolate each agent's learning signal + has_multiagent = any(aggregated_agent_rewards[i] for i in range(num_states)) + opponent_baselines: dict[ + str, dict[str, float] + ] = {} # {agent_id: {opponent_sig: baseline}} + + if has_multiagent: + # Build opponent-conditioned baselines for each agent + # For each agent, group rollouts by what the opponent(s) did + agent_ids_in_group: set[str] = set() + for agent_rewards in aggregated_agent_rewards: + agent_ids_in_group.update(agent_rewards.keys()) + + for agent_id in agent_ids_in_group: + # For each state, extract opponent's actions (actions by agents != agent_id) + opponent_groups: dict[ + str, list[tuple[int, float]] + ] = {} # {opponent_signature: [(state_idx, agent_reward)]} + for i, state in enumerate(states): + # Get opponent's action signature from trajectory + opponent_actions = [] + for t in state["trajectory"]: + step_agent_id = t.get("extras", {}).get("agent_id") + if step_agent_id and step_agent_id != agent_id: + # Use completion content as opponent action signature + opponent_actions.append(str(t.get("completion", ""))) + opponent_sig = "|".join(opponent_actions) + + # Get this agent's reward for this state + agent_reward = aggregated_agent_rewards[i].get( + agent_id, aggregated_rewards[i] + ) + if opponent_sig not in opponent_groups: + opponent_groups[opponent_sig] = [] + opponent_groups[opponent_sig].append((i, agent_reward)) + + # Compute baseline for each opponent behavior group + opponent_baselines[agent_id] = {} + for opponent_sig, rewards_list in opponent_groups.items(): + if rewards_list: + baseline = sum(r for _, r in rewards_list) / len(rewards_list) + opponent_baselines[agent_id][opponent_sig] = baseline + for i, state in enumerate(states): state["reward"] = aggregated_rewards[i] state["advantage"] = aggregated_rewards[i] - avg_reward @@ -542,9 +587,25 @@ async def score_group(self, states: list[State]): t["reward"] = agent_rewards.get(agent_id, state["reward"]) else: t["reward"] = state["reward"] - # Compute per-agent advantage: agent's reward - shared baseline + + # Compute per-agent advantage with opponent-conditioned baseline if t["advantage"] is None: - t["advantage"] = t["reward"] - avg_reward + agent_id = t.get("extras", {}).get("agent_id") + if agent_id and agent_id in opponent_baselines: + # Get opponent's action signature for this state + opponent_actions = [] + for t2 in state["trajectory"]: + step_agent_id = t2.get("extras", {}).get("agent_id") + if step_agent_id and step_agent_id != agent_id: + opponent_actions.append(str(t2.get("completion", ""))) + opponent_sig = "|".join(opponent_actions) + # Use opponent-conditioned baseline + baseline = opponent_baselines[agent_id].get( + opponent_sig, avg_reward + ) + t["advantage"] = t["reward"] - baseline + else: + t["advantage"] = t["reward"] - avg_reward state["metrics"] = { func_name: values[i] for func_name, values in aggregated_metrics.items() From b426660ce655a3aaf697bc2d599447ea9170e9ec Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 25 Feb 2026 22:23:13 -0600 Subject: [PATCH 09/21] add debug logging for opponent-conditioned baselines --- verifiers/rubrics/rubric.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index eea7fcfc7..54090f66b 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -533,6 +533,10 @@ async def score_group(self, states: list[State]): str, dict[str, float] ] = {} # {agent_id: {opponent_sig: baseline}} + # DEBUG + print(f"[DEBUG] has_multiagent={has_multiagent}, num_states={num_states}") + print(f"[DEBUG] aggregated_agent_rewards={aggregated_agent_rewards[:3]}") + if has_multiagent: # Build opponent-conditioned baselines for each agent # For each agent, group rollouts by what the opponent(s) did @@ -570,6 +574,14 @@ async def score_group(self, states: list[State]): baseline = sum(r for _, r in rewards_list) / len(rewards_list) opponent_baselines[agent_id][opponent_sig] = baseline + # DEBUG + print( + f"[DEBUG] agent_id={agent_id}, opponent_groups keys={list(opponent_groups.keys())[:3]}" + ) + print( + f"[DEBUG] opponent_baselines[{agent_id}]={opponent_baselines[agent_id]}" + ) + for i, state in enumerate(states): state["reward"] = aggregated_rewards[i] state["advantage"] = aggregated_rewards[i] - avg_reward @@ -604,8 +616,17 @@ async def score_group(self, states: list[State]): opponent_sig, avg_reward ) t["advantage"] = t["reward"] - baseline + # DEBUG (only first few) + if i < 2: + print( + f"[DEBUG] i={i} agent={agent_id} reward={t['reward']} baseline={baseline} adv={t['advantage']}" + ) else: t["advantage"] = t["reward"] - avg_reward + if i < 2: + print( + f"[DEBUG] i={i} agent={agent_id} NOT in opponent_baselines, using avg_reward" + ) state["metrics"] = { func_name: values[i] for func_name, values in aggregated_metrics.items() From 5bc1468ab776aacf68d28e621309960f623a6713 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 25 Feb 2026 22:33:39 -0600 Subject: [PATCH 10/21] add trajectory structure debug --- verifiers/rubrics/rubric.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 54090f66b..588231057 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -536,6 +536,14 @@ async def score_group(self, states: list[State]): # DEBUG print(f"[DEBUG] has_multiagent={has_multiagent}, num_states={num_states}") print(f"[DEBUG] aggregated_agent_rewards={aggregated_agent_rewards[:3]}") + # DEBUG: inspect trajectory structure + if states and states[0].get("trajectory"): + traj = states[0]["trajectory"] + print(f"[DEBUG] trajectory len={len(traj)}") + for idx, t in enumerate(traj): + print( + f"[DEBUG] step {idx}: agent_id={t.get('extras', {}).get('agent_id')}, completion={t.get('completion')}" + ) if has_multiagent: # Build opponent-conditioned baselines for each agent From 2b37080233940c63fcb5ff4f8703ed84100faf8e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 25 Feb 2026 22:37:51 -0600 Subject: [PATCH 11/21] debug extras and state keys --- verifiers/rubrics/rubric.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 588231057..924e134ae 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -536,13 +536,16 @@ async def score_group(self, states: list[State]): # DEBUG print(f"[DEBUG] has_multiagent={has_multiagent}, num_states={num_states}") print(f"[DEBUG] aggregated_agent_rewards={aggregated_agent_rewards[:3]}") + # DEBUG: check state keys + if states: + print(f"[DEBUG] state keys={list(states[0].keys())}") # DEBUG: inspect trajectory structure if states and states[0].get("trajectory"): traj = states[0]["trajectory"] print(f"[DEBUG] trajectory len={len(traj)}") for idx, t in enumerate(traj): print( - f"[DEBUG] step {idx}: agent_id={t.get('extras', {}).get('agent_id')}, completion={t.get('completion')}" + f"[DEBUG] step {idx}: agent_id={t.get('extras', {}).get('agent_id')}, completion={t.get('completion')}, extras={t.get('extras')}" ) if has_multiagent: From dce56cd6fea9c1fc29f48aa8b3499cf4f48f7405 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Thu, 26 Feb 2026 22:38:31 -0600 Subject: [PATCH 12/21] remove opponent-conditioned baselines for comparison test --- verifiers/rubrics/rubric.py | 95 +------------------------------------ 1 file changed, 2 insertions(+), 93 deletions(-) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 924e134ae..c4f8f7f38 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -526,73 +526,6 @@ async def score_group(self, states: list[State]): avg_reward = sum(aggregated_rewards) / num_states - # For multi-agent: compute opponent-conditioned baselines - # Group states by opponent behavior to isolate each agent's learning signal - has_multiagent = any(aggregated_agent_rewards[i] for i in range(num_states)) - opponent_baselines: dict[ - str, dict[str, float] - ] = {} # {agent_id: {opponent_sig: baseline}} - - # DEBUG - print(f"[DEBUG] has_multiagent={has_multiagent}, num_states={num_states}") - print(f"[DEBUG] aggregated_agent_rewards={aggregated_agent_rewards[:3]}") - # DEBUG: check state keys - if states: - print(f"[DEBUG] state keys={list(states[0].keys())}") - # DEBUG: inspect trajectory structure - if states and states[0].get("trajectory"): - traj = states[0]["trajectory"] - print(f"[DEBUG] trajectory len={len(traj)}") - for idx, t in enumerate(traj): - print( - f"[DEBUG] step {idx}: agent_id={t.get('extras', {}).get('agent_id')}, completion={t.get('completion')}, extras={t.get('extras')}" - ) - - if has_multiagent: - # Build opponent-conditioned baselines for each agent - # For each agent, group rollouts by what the opponent(s) did - agent_ids_in_group: set[str] = set() - for agent_rewards in aggregated_agent_rewards: - agent_ids_in_group.update(agent_rewards.keys()) - - for agent_id in agent_ids_in_group: - # For each state, extract opponent's actions (actions by agents != agent_id) - opponent_groups: dict[ - str, list[tuple[int, float]] - ] = {} # {opponent_signature: [(state_idx, agent_reward)]} - for i, state in enumerate(states): - # Get opponent's action signature from trajectory - opponent_actions = [] - for t in state["trajectory"]: - step_agent_id = t.get("extras", {}).get("agent_id") - if step_agent_id and step_agent_id != agent_id: - # Use completion content as opponent action signature - opponent_actions.append(str(t.get("completion", ""))) - opponent_sig = "|".join(opponent_actions) - - # Get this agent's reward for this state - agent_reward = aggregated_agent_rewards[i].get( - agent_id, aggregated_rewards[i] - ) - if opponent_sig not in opponent_groups: - opponent_groups[opponent_sig] = [] - opponent_groups[opponent_sig].append((i, agent_reward)) - - # Compute baseline for each opponent behavior group - opponent_baselines[agent_id] = {} - for opponent_sig, rewards_list in opponent_groups.items(): - if rewards_list: - baseline = sum(r for _, r in rewards_list) / len(rewards_list) - opponent_baselines[agent_id][opponent_sig] = baseline - - # DEBUG - print( - f"[DEBUG] agent_id={agent_id}, opponent_groups keys={list(opponent_groups.keys())[:3]}" - ) - print( - f"[DEBUG] opponent_baselines[{agent_id}]={opponent_baselines[agent_id]}" - ) - for i, state in enumerate(states): state["reward"] = aggregated_rewards[i] state["advantage"] = aggregated_rewards[i] - avg_reward @@ -611,33 +544,9 @@ async def score_group(self, states: list[State]): else: t["reward"] = state["reward"] - # Compute per-agent advantage with opponent-conditioned baseline + # Compute per-agent advantage using global baseline if t["advantage"] is None: - agent_id = t.get("extras", {}).get("agent_id") - if agent_id and agent_id in opponent_baselines: - # Get opponent's action signature for this state - opponent_actions = [] - for t2 in state["trajectory"]: - step_agent_id = t2.get("extras", {}).get("agent_id") - if step_agent_id and step_agent_id != agent_id: - opponent_actions.append(str(t2.get("completion", ""))) - opponent_sig = "|".join(opponent_actions) - # Use opponent-conditioned baseline - baseline = opponent_baselines[agent_id].get( - opponent_sig, avg_reward - ) - t["advantage"] = t["reward"] - baseline - # DEBUG (only first few) - if i < 2: - print( - f"[DEBUG] i={i} agent={agent_id} reward={t['reward']} baseline={baseline} adv={t['advantage']}" - ) - else: - t["advantage"] = t["reward"] - avg_reward - if i < 2: - print( - f"[DEBUG] i={i} agent={agent_id} NOT in opponent_baselines, using avg_reward" - ) + t["advantage"] = t["reward"] - avg_reward state["metrics"] = { func_name: values[i] for func_name, values in aggregated_metrics.items() From 5425ebb009600573423e2d6e86bb71d484d4f2f9 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Tue, 3 Mar 2026 23:47:31 -0600 Subject: [PATCH 13/21] add per-agent baselines for multi-agent advantage computation --- verifiers/rubrics/rubric.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index c4f8f7f38..8df57cd6d 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -526,6 +526,25 @@ async def score_group(self, states: list[State]): avg_reward = sum(aggregated_rewards) / num_states + # Compute per-agent baselines for multi-agent environments + # Each agent's advantage is relative to that agent's mean reward + agent_baselines: dict[str, float] = {} + agent_reward_sums: dict[str, float] = {} + agent_reward_counts: dict[str, int] = {} + + for i in range(num_states): + for agent_id, reward in aggregated_agent_rewards[i].items(): + if agent_id not in agent_reward_sums: + agent_reward_sums[agent_id] = 0.0 + agent_reward_counts[agent_id] = 0 + agent_reward_sums[agent_id] += reward + agent_reward_counts[agent_id] += 1 + + for agent_id in agent_reward_sums: + agent_baselines[agent_id] = ( + agent_reward_sums[agent_id] / agent_reward_counts[agent_id] + ) + for i, state in enumerate(states): state["reward"] = aggregated_rewards[i] state["advantage"] = aggregated_rewards[i] - avg_reward @@ -544,9 +563,13 @@ async def score_group(self, states: list[State]): else: t["reward"] = state["reward"] - # Compute per-agent advantage using global baseline + # Compute per-agent advantage using per-agent baseline if t["advantage"] is None: - t["advantage"] = t["reward"] - avg_reward + agent_id = t.get("extras", {}).get("agent_id") + if agent_id and agent_id in agent_baselines: + t["advantage"] = t["reward"] - agent_baselines[agent_id] + else: + t["advantage"] = t["reward"] - avg_reward state["metrics"] = { func_name: values[i] for func_name, values in aggregated_metrics.items() From 003433325115d7543772d82f7aa9d03d7018f100 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Mar 2026 12:53:08 -0600 Subject: [PATCH 14/21] fix score_rollout to support multi-agent reward functions --- verifiers/rubrics/rubric.py | 77 ++++++++++++++++++++++++++----------- 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 8df57cd6d..2676a5ff4 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -10,7 +10,6 @@ GroupRewardFunc, MultiAgentRewardFunc, RewardFunc, - RolloutScore, State, TASK_INPUT_FIELDS, ) @@ -395,37 +394,71 @@ async def dummy_score_rollout(self, state: State): async def score_rollout(self, state: State): """ Evaluate all reward functions for a single rollout. + Supports individual and multi-agent reward functions (but not group functions). """ reward_funcs = self._get_individual_reward_funcs() + multiagent_funcs = self._get_multiagent_reward_funcs() group_reward_funcs = self._get_group_reward_funcs() - assert len(reward_funcs) > 0 and len(group_reward_funcs) == 0, ( - "Rubric.score_rollout requires at least one individual-level reward function and no group-level reward functions" + + has_reward_funcs = len(reward_funcs) > 0 or len(multiagent_funcs) > 0 + assert has_reward_funcs and len(group_reward_funcs) == 0, ( + "Rubric.score_rollout requires at least one individual-level or multi-agent " + "reward function and no group-level reward functions" ) - reward_scores = [] - for func in reward_funcs: - reward_scores.append( - await self._call_individual_reward_func( - func=func, - state=state, - ) + + aggregated_reward = 0.0 + aggregated_metrics: dict[str, float] = {} + aggregated_agent_rewards: dict[str, float] = {} + + agent_ids = { + t.get("extras", {}).get("agent_id") + for t in state["trajectory"] + if t.get("extras", {}).get("agent_id") + } + + for func, weight in zip(reward_funcs, self._get_individual_reward_weights()): + score = await self._call_individual_reward_func(func=func, state=state) + aggregated_reward += score * weight + aggregated_metrics[func.__name__] = score + for agent_id in agent_ids: + if agent_id not in aggregated_agent_rewards: + aggregated_agent_rewards[agent_id] = 0.0 + aggregated_agent_rewards[agent_id] += score * weight + + multiagent_weights = [ + weight + for func, weight in zip(self.funcs, self.weights) + if self._is_multiagent_func(func) + ] + for func, weight in zip(multiagent_funcs, multiagent_weights): + agent_scores = await self._call_multiagent_reward_func( + func=func, state=state ) + for agent_id, score_value in agent_scores.items(): + if agent_id not in aggregated_agent_rewards: + aggregated_agent_rewards[agent_id] = 0.0 + aggregated_agent_rewards[agent_id] += score_value * weight + if agent_scores: + mean_score = sum(agent_scores.values()) / len(agent_scores) + aggregated_reward += mean_score * weight + aggregated_metrics[func.__name__] = mean_score + rewards = RolloutScore( - metrics={ - func.__name__: reward - for func, reward in zip(reward_funcs, reward_scores) - }, - reward=sum( - [ - reward * weight - for reward, weight in zip( - reward_scores, self._get_individual_reward_weights() - ) - ] - ), + metrics=aggregated_metrics, + reward=aggregated_reward, ) state["reward"] = rewards["reward"] state["metrics"] = rewards["metrics"] + if aggregated_agent_rewards: + state["agent_rewards"] = aggregated_agent_rewards + for t in state["trajectory"]: + if t["reward"] is None: + agent_id = t.get("extras", {}).get("agent_id") + t["reward"] = aggregated_agent_rewards.get( + agent_id, state["reward"] + ) + async def dummy_score_group(self, states: list[State]): """Score a group of rollouts together with dummy rewards.""" for state in states: From 902e3f7edae30208724b0289fdec61a403296a3e Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Mar 2026 15:01:48 -0600 Subject: [PATCH 15/21] normalize messages from build_agent_prompt before storing in trajectory --- verifiers/envs/multiagent_env.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py index 87dbf6de6..89dd46cbd 100644 --- a/verifiers/envs/multiagent_env.py +++ b/verifiers/envs/multiagent_env.py @@ -25,6 +25,7 @@ from verifiers.envs.multiturn_env import MultiTurnEnv from verifiers.envs.protocol import Protocol from verifiers.types import Messages, State, TrajectoryStep +from verifiers.utils.message_utils import normalize_messages class MultiAgentEnv(MultiTurnEnv): @@ -209,6 +210,9 @@ async def rollout( try: # 1. Build prompt for this agent prompt_messages = await self.build_agent_prompt(agent_id, state) + prompt_messages = normalize_messages( + prompt_messages, field_name="agent_prompt" + ) # 2. Get model response response = await self.get_model_response(state, prompt_messages) From c8d3715ab8e25dd01516546a100bc3921fb8f7c6 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Fri, 6 Mar 2026 18:42:55 -0600 Subject: [PATCH 16/21] add per-agent reward metrics for multi-agent environments --- verifiers/rubrics/rubric.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 2676a5ff4..6ec472e37 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -442,6 +442,8 @@ async def score_rollout(self, state: State): mean_score = sum(agent_scores.values()) / len(agent_scores) aggregated_reward += mean_score * weight aggregated_metrics[func.__name__] = mean_score + for agent_id, score_value in agent_scores.items(): + aggregated_metrics[f"{func.__name__}/{agent_id}"] = score_value rewards = RolloutScore( metrics=aggregated_metrics, From e80aab628fdbf2c34a556374bf7283091f8a80d4 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sat, 7 Mar 2026 23:15:11 -0600 Subject: [PATCH 17/21] add per-agent model routing for multi-policy lora training --- verifiers/envs/env_group.py | 6 ++++++ verifiers/envs/environment.py | 4 ++++ verifiers/envs/multiagent_env.py | 14 ++++++++++++-- verifiers/serve/client/env_client.py | 4 ++++ verifiers/serve/server/env_worker.py | 8 ++++++++ verifiers/serve/types.py | 2 ++ 6 files changed, 36 insertions(+), 2 deletions(-) diff --git a/verifiers/envs/env_group.py b/verifiers/envs/env_group.py index f2104db98..1d17e1841 100644 --- a/verifiers/envs/env_group.py +++ b/verifiers/envs/env_group.py @@ -317,6 +317,7 @@ async def run_rollout( # type: ignore[override] max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + actor_models: dict[str, str] | None = None, ) -> vf.RolloutOutput: target_env_client = env_client or self.env_client if target_env_client is not None: @@ -331,6 +332,7 @@ async def run_rollout( # type: ignore[override] sampling_args, max_retries, state_columns, + actor_models=actor_models, ) env_name, child_input, route = self._route_child_input(input) @@ -343,6 +345,7 @@ async def run_rollout( # type: ignore[override] max_retries, state_columns, env.env_client, + actor_models=actor_models, ) return _set_info_route(output, route) # type: ignore[return-value] @@ -356,6 +359,7 @@ async def run_group( # type: ignore[override] max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + actor_models: dict[str, str] | None = None, ) -> list[vf.RolloutOutput]: target_env_client = env_client or self.env_client if target_env_client is not None: @@ -370,6 +374,7 @@ async def run_group( # type: ignore[override] sampling_args, max_retries, state_columns, + actor_models=actor_models, ) env_name, first_child_input, route = self._route_child_input(group_inputs[0]) @@ -396,6 +401,7 @@ async def run_group( # type: ignore[override] max_retries, state_columns, env.env_client, + actor_models=actor_models, ) return [_set_info_route(output, route) for output in outputs] # type: ignore[return-value] diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index ed379f086..1f972ee49 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -745,6 +745,7 @@ async def run_rollout( max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + actor_models: dict[str, str] | None = None, ) -> RolloutOutput: """Generate and, optionally, score a rollout.""" @@ -765,6 +766,7 @@ async def run_rollout( sampling_args, max_retries, state_columns, + actor_models=actor_models, ) resolved_client = resolve_client(client) @@ -791,6 +793,7 @@ async def run_group( max_retries: int = 0, state_columns: list[str] | None = None, env_client: EnvClient | None = None, + actor_models: dict[str, str] | None = None, **kwargs, ) -> list[RolloutOutput]: """Generate and, optionally, score one group.""" @@ -812,6 +815,7 @@ async def run_group( sampling_args, max_retries, state_columns, + actor_models=actor_models, ) resolved_client = resolve_client(client) diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py index 89dd46cbd..4f14812ce 100644 --- a/verifiers/envs/multiagent_env.py +++ b/verifiers/envs/multiagent_env.py @@ -46,6 +46,11 @@ class MultiAgentEnv(MultiTurnEnv): # Subclasses should override this or set in __init__ agents: list[str] = [] + # Per-agent model routing for multi-policy training (e.g. per-agent LoRA adapters). + # Maps agent_id -> model name. When set, each agent's turns use its own model. + # When None, all agents share the single model passed to rollout(). + actor_models: dict[str, str] | None = None + def __init__(self, protocol: Protocol, **kwargs): """ Initialize multi-agent environment. @@ -214,8 +219,13 @@ async def rollout( prompt_messages, field_name="agent_prompt" ) - # 2. Get model response - response = await self.get_model_response(state, prompt_messages) + # 2. Get model response (route to per-agent model if configured) + agent_model = ( + self.actor_models.get(agent_id) if self.actor_models else None + ) + response = await self.get_model_response( + state, prompt_messages, model=agent_model + ) # 3. Store in trajectory (tags with agent_id) await self.add_model_response(state, prompt_messages, response) diff --git a/verifiers/serve/client/env_client.py b/verifiers/serve/client/env_client.py index 8649fb246..f6e8f2fd0 100644 --- a/verifiers/serve/client/env_client.py +++ b/verifiers/serve/client/env_client.py @@ -50,6 +50,7 @@ async def run_rollout( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + actor_models: dict[str, str] | None = None, ) -> RolloutOutput: resolved_client_config = resolve_client_config(client_config) request = RunRolloutRequest( @@ -59,6 +60,7 @@ async def run_rollout( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + actor_models=actor_models, ) response = await self.handle_run_rollout_request(request, timeout=None) assert response.output is not None @@ -72,6 +74,7 @@ async def run_group( sampling_args: SamplingArgs, max_retries: int = 0, state_columns: list[str] | None = None, + actor_models: dict[str, str] | None = None, ) -> list[RolloutOutput]: resolved_client_config = resolve_client_config(client_config) request = RunGroupRequest( @@ -81,6 +84,7 @@ async def run_group( sampling_args=sampling_args, max_retries=max_retries, state_columns=state_columns, + actor_models=actor_models, ) response = await self.handle_run_group_request(request, timeout=None) assert response.outputs is not None diff --git a/verifiers/serve/server/env_worker.py b/verifiers/serve/server/env_worker.py index 70f77cf30..3174e71f6 100644 --- a/verifiers/serve/server/env_worker.py +++ b/verifiers/serve/server/env_worker.py @@ -135,9 +135,14 @@ async def resolve_client(self, client_config: ClientConfig) -> Client: self.clients[key] = resolve_client(resolved) return self.clients[key] + def _set_actor_models(self, actor_models: dict[str, str] | None) -> None: + if actor_models is not None and hasattr(self.env, "actor_models"): + self.env.actor_models = actor_models + async def handle_run_rollout( self, request: RunRolloutRequest ) -> RunRolloutResponse: + self._set_actor_models(request.actor_models) client = await self.resolve_client(request.client_config) output = await self.env.run_rollout( input=request.input, @@ -146,10 +151,12 @@ async def handle_run_rollout( sampling_args=request.sampling_args, max_retries=request.max_retries, state_columns=request.state_columns, + actor_models=request.actor_models, ) return RunRolloutResponse(output=output) async def handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse: + self._set_actor_models(request.actor_models) client = await self.resolve_client(request.client_config) outputs = await self.env.run_group( group_inputs=request.group_inputs, @@ -158,6 +165,7 @@ async def handle_run_group(self, request: RunGroupRequest) -> RunGroupResponse: sampling_args=request.sampling_args, max_retries=request.max_retries, state_columns=request.state_columns, + actor_models=request.actor_models, ) return RunGroupResponse(outputs=outputs) diff --git a/verifiers/serve/types.py b/verifiers/serve/types.py index 3c252c230..d7ed22d14 100644 --- a/verifiers/serve/types.py +++ b/verifiers/serve/types.py @@ -55,6 +55,7 @@ class RunRolloutRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + actor_models: dict[str, str] | None = None class RunRolloutResponse(BaseResponse): @@ -71,6 +72,7 @@ class RunGroupRequest(BaseRequest): sampling_args: SamplingArgs max_retries: int state_columns: list[str] | None + actor_models: dict[str, str] | None = None class RunGroupResponse(BaseResponse): From 616bed6fba01e79c03deb1bf815e4ad3af2ab5a9 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Sun, 8 Mar 2026 21:56:32 -0600 Subject: [PATCH 18/21] point textarena to fork with kuhn poker fixes --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 21ae65e4d..73442e494 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,9 @@ prime-pydantic-config = false renderers = false openenv-core = false +[tool.uv.sources] +textarena = { git = "https://github.com/nph4rd/TextArena.git", branch = "fix/kuhn-poker-phantom-ante" } + [tool.uv.extra-build-dependencies] flash-attn = [{ requirement = "torch", match-runtime = true }] From f430b7f6c226fd5e2156286a559281d6e807d646 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 20 May 2026 12:27:50 -0600 Subject: [PATCH 19/21] fix rubric rollout score import after rebase --- verifiers/rubrics/rubric.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 6ec472e37..b1728b2ef 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -10,6 +10,7 @@ GroupRewardFunc, MultiAgentRewardFunc, RewardFunc, + RolloutScore, State, TASK_INPUT_FIELDS, ) From 7d91582faee2a540ef1d4604ae688d5f86d5635f Mon Sep 17 00:00:00 2001 From: nph4rd Date: Wed, 20 May 2026 20:08:26 -0600 Subject: [PATCH 20/21] add SpawningProtocol for hierarchical multi-agent envs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends Protocol with should_spawn/get_spawn_specs. MultiAgentEnv.rollout() runs each spawned child env concurrently, scores it via its own rubric, and embeds the child trajectory steps into the parent's trajectory tagged with the spec's agent_id and is_trainable. Existing per-agent advantage computation (Rubric.score_group), trajectory splitting (interleave_rollout(split_by_agent=True)), and per-actor trainer metrics (MicroBatch.actor_ids) work without further changes. The use case is proposer-solver style envs where one agent's turn fans out into N rollouts of another agent — e.g. PrimeIntellect's general-agent synth-solver loop where the synthesizer creates tasks and the solver attempts them, with the synthesizer's reward depending on the solver's pass rate. RoundRobinProtocol is unchanged — the spawn branch is gated on isinstance(protocol, SpawningProtocol). --- tests/test_spawning_protocol.py | 225 +++++++++++++++++++++++++++++++ uv.lock | 14 +- verifiers/__init__.py | 14 +- verifiers/envs/multiagent_env.py | 92 ++++++++++++- verifiers/envs/protocol.py | 91 ++++++++++++- 5 files changed, 421 insertions(+), 15 deletions(-) create mode 100644 tests/test_spawning_protocol.py diff --git a/tests/test_spawning_protocol.py b/tests/test_spawning_protocol.py new file mode 100644 index 000000000..a9137f32d --- /dev/null +++ b/tests/test_spawning_protocol.py @@ -0,0 +1,225 @@ +"""End-to-end test for SpawningProtocol via a toy proposer/solver env. + +The proposer picks an integer N. The protocol then spawns k child solver +rollouts whose job is to double N. We assert that: + + - the parent's trajectory contains both proposer and child solver steps + - child steps are tagged with agent_id="solver" and is_trainable + - state["extras"]["spawns"] carries SpawnResult(s) with one State per child + - each child's reward was computed by its own rubric (score was 1.0 only + when the solver's completion equaled 2*N) +""" + +from __future__ import annotations + +import re + +import pytest +from datasets import Dataset + +import verifiers as vf +from verifiers import ( + Agent, + MultiAgentEnv, + Rubric, + SingleTurnEnv, + SpawningProtocol, + SpawnSpec, +) +from verifiers.types import Messages, State + + +PROPOSER_NUMBER = 5 # what the proposer always picks in this test +NUM_CHILDREN = 3 + + +# --------------------------------------------------------------------------- # +# Child env: simple single-turn doubling env. +# --------------------------------------------------------------------------- # + + +def _doubling_correct(completion, answer, **_) -> float: + text = completion if isinstance(completion, str) else completion[-1]["content"] + match = re.search(r"-?\d+", text) + if match is None: + return 0.0 + return 1.0 if int(match.group(0)) == int(answer) else 0.0 + + +@pytest.fixture +def child_solver_env(mock_client): + """SingleTurnEnv whose prompt asks for 2*N and rubric scores correctness.""" + dataset = Dataset.from_dict( + { + "prompt": [[{"role": "user", "content": f"Double {PROPOSER_NUMBER}."}]], + "answer": [str(2 * PROPOSER_NUMBER)], + "example_id": [0], + } + ) + rubric = Rubric(funcs=[_doubling_correct]) + + env = SingleTurnEnv( + client=mock_client, + model="test-model", + dataset=dataset, + parser=vf.Parser(), + rubric=rubric, + ) + + # Pre-stage half of the children's responses to be correct, half wrong. + # The mocked client returns the same default response unless overridden; + # since the proposer ALSO uses this client we'll just set a single default + # and use add_response for both turn types. + mock_client.add_response( + [{"role": "user", "content": f"Double {PROPOSER_NUMBER}."}], + str(2 * PROPOSER_NUMBER), + ) + return env + + +# --------------------------------------------------------------------------- # +# Parent env: proposer that emits a number, then spawns NUM_CHILDREN solvers. +# --------------------------------------------------------------------------- # + + +class _OneShotSpawnProtocol(SpawningProtocol): + """Spawns child solvers exactly once, immediately after the proposer's turn.""" + + def __init__(self, child_env, agent_id: str, num_children: int): + self._child_env = child_env + self._agent_id = agent_id + self._num_children = num_children + + def get_initial_agent(self, state: State) -> str: + return "proposer" + + def get_next_agent(self, state: State) -> str: + # Single-turn protocol: never returns a "next" agent because + # on_turn_complete sets state["is_completed"]=True. + return "proposer" + + def should_spawn(self, state: State) -> bool: + # Spawn once: only if the proposer just acted and we haven't spawned yet. + already_spawned = bool(state["extras"].get("spawns")) + last_step = state["trajectory"][-1] if state["trajectory"] else None + last_agent = (last_step or {}).get("extras", {}).get("agent_id") + return last_agent == "proposer" and not already_spawned + + def get_spawn_specs(self, state: State) -> list[SpawnSpec]: + # The proposer's "answer" is the last word of its completion. + text = state["trajectory"][-1]["completion"][-1]["content"] + n = int(re.search(r"-?\d+", text).group(0)) + prompt = [{"role": "user", "content": f"Double {n}."}] + inputs = [ + {"prompt": prompt, "answer": str(2 * n), "example_id": i} + for i in range(self._num_children) + ] + return [ + SpawnSpec( + agent_id=self._agent_id, + child_env=self._child_env, + inputs=inputs, + is_trainable=True, + ) + ] + + +class _ProposerEnv(MultiAgentEnv): + """One-turn proposer that picks a number, registered as a trainable agent.""" + + async def build_agent_prompt(self, agent_id: str, state: State) -> Messages: + return [ + { + "role": "user", + "content": "Pick an integer and the solver will try to double it.", + } + ] + + @vf.stop + async def proposer_done(self, state: State, **kwargs) -> bool: + # End the rollout once the proposer's turn has been spawned out; + # the spawn block in MultiAgentEnv.rollout() runs before the next + # iteration's is_completed check, so children finish first. + return bool(state.get("extras", {}).get("spawns")) + + +@pytest.fixture +def proposer_env(mock_client, child_solver_env): + protocol = _OneShotSpawnProtocol( + child_env=child_solver_env, agent_id="solver", num_children=NUM_CHILDREN + ) + rubric = Rubric() + env = _ProposerEnv( + protocol=protocol, + client=mock_client, + model="test-model", + dataset=Dataset.from_dict({"prompt": [[{"role": "user", "content": "go"}]], "example_id": [0]}), + parser=vf.Parser(), + rubric=rubric, + max_turns=8, + ) + env.register_agent(Agent(id="proposer", system_prompt="", is_trainable=True)) + env.register_agent(Agent(id="solver", system_prompt="", is_trainable=True)) + + mock_client.add_response( + [ + { + "role": "user", + "content": "Pick an integer and the solver will try to double it.", + } + ], + str(PROPOSER_NUMBER), + ) + return env + + +# --------------------------------------------------------------------------- # +# Tests +# --------------------------------------------------------------------------- # + + +@pytest.mark.asyncio +async def test_spawning_protocol_runs_children_and_records_spawns( + proposer_env, mock_client +): + """One proposer turn → NUM_CHILDREN solver children → all embedded + recorded.""" + state = await proposer_env.rollout( + {"prompt": [{"role": "user", "content": "go"}], "example_id": 0}, + client=mock_client, + model="test-model", + sampling_args={"temperature": 1.0}, + ) + + # 1. The parent trajectory contains the proposer's step plus one per child. + agent_ids = [s["extras"].get("agent_id") for s in state["trajectory"]] + assert agent_ids.count("proposer") == 1 + assert agent_ids.count("solver") == NUM_CHILDREN + + # 2. Spawns recorded. + spawns = state["extras"]["spawns"] + assert len(spawns) == 1 + spawn = spawns[0] + assert spawn.spec.agent_id == "solver" + assert len(spawn.states) == NUM_CHILDREN + + # 3. Children were scored by the child env's own rubric — the mocked + # solver always returns 2*N so each child's reward is 1.0. + for child_state in spawn.states: + assert child_state["reward"] == 1.0 + + +@pytest.mark.asyncio +async def test_child_trajectory_steps_carry_is_trainable_tag( + proposer_env, mock_client +): + state = await proposer_env.rollout( + {"prompt": [{"role": "user", "content": "go"}], "example_id": 0}, + client=mock_client, + model="test-model", + sampling_args={"temperature": 1.0}, + ) + child_steps = [s for s in state["trajectory"] if s["extras"].get("agent_id") == "solver"] + assert child_steps, "expected child steps in parent trajectory" + for step in child_steps: + # is_trainable was set on the SpawnSpec; it must flow through to steps. + assert step["extras"].get("is_trainable") is True diff --git a/uv.lock b/uv.lock index b1c47f140..85c9a01f2 100644 --- a/uv.lock +++ b/uv.lock @@ -19,8 +19,6 @@ conflicts = [[ ]] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. -exclude-newer-span = "P7D" [options.exclude-newer-package] prime-tunnel = false @@ -6195,8 +6193,8 @@ wheels = [ [[package]] name = "textarena" -version = "0.7.4" -source = { registry = "https://pypi.org/simple" } +version = "0.7.3" +source = { git = "https://github.com/nph4rd/TextArena.git?branch=fix%2Fkuhn-poker-phantom-ante#e3ab9a98faace0ac90fb80e5f6e886e269cf4d4d" } dependencies = [ { name = "chess" }, { name = "nltk" }, @@ -6206,10 +6204,6 @@ dependencies = [ { name = "rich" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ba/04/4a3ca42093d0be2a9c377ae3335a6c6baac1d278ae932562ec69f339d172/textarena-0.7.4.tar.gz", hash = "sha256:28bb9170d7718f2ae05e4515bea82262422731e563fc7318a9e7983de0cadd4f", size = 954969, upload-time = "2025-10-16T14:41:55.981Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/26/b4/9a9ba65154aff853c75b3d7324319d168ad9c69c6097f4aa3c16da7d9ef3/textarena-0.7.4-py3-none-any.whl", hash = "sha256:684784e78278e518066f67557ee93b47c238d16cbbd15d3abdaa3147562d3024", size = 1073570, upload-time = "2025-10-16T14:41:53.965Z" }, -] [[package]] name = "textual" @@ -6895,7 +6889,7 @@ requires-dist = [ { name = "setproctitle", specifier = ">=1.3.0" }, { name = "stagehand", marker = "extra == 'browser'", specifier = ">=3.0.0" }, { name = "tenacity", specifier = ">=8.5.0" }, - { name = "textarena", marker = "extra == 'ta'" }, + { name = "textarena", marker = "extra == 'ta'", git = "https://github.com/nph4rd/TextArena.git?branch=fix%2Fkuhn-poker-phantom-ante" }, { name = "textual" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "torch", marker = "extra == 'rl'", specifier = ">=2.8.0,<2.9.0" }, @@ -6922,7 +6916,7 @@ dev = [ { name = "renderers", specifier = ">=0.1.8.dev4" }, { name = "ruff" }, { name = "stagehand", specifier = ">=3.0.0" }, - { name = "textarena" }, + { name = "textarena", git = "https://github.com/nph4rd/TextArena.git?branch=fix%2Fkuhn-poker-phantom-ante" }, { name = "ty", specifier = ">=0.0.1a29,<0.0.22" }, ] policy = [{ name = "semgrep", specifier = ">=1.150.0" }] diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 5e6780e85..b84763942 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -102,6 +102,9 @@ "Agent", "Protocol", "RoundRobinProtocol", + "SpawningProtocol", + "SpawnSpec", + "SpawnResult", "Environment", "MultiTurnEnv", "MultiAgentEnv", @@ -176,6 +179,9 @@ "Agent": "verifiers.envs.agent:Agent", "Protocol": "verifiers.envs.protocol:Protocol", "RoundRobinProtocol": "verifiers.envs.protocol:RoundRobinProtocol", + "SpawningProtocol": "verifiers.envs.protocol:SpawningProtocol", + "SpawnSpec": "verifiers.envs.protocol:SpawnSpec", + "SpawnResult": "verifiers.envs.protocol:SpawnResult", "MultiAgentEnv": "verifiers.envs.multiagent_env:MultiAgentEnv", "EnvGroup": "verifiers.envs.env_group:EnvGroup", "JudgeRubric": "verifiers.rubrics.judge_rubric:JudgeRubric", @@ -290,7 +296,13 @@ def __getattr__(name: str): from .clients.openai_responses_client import OpenAIResponsesClient # noqa: F401 from .clients.renderer_client import RendererClient # noqa: F401 from .envs.agent import Agent # noqa: F401 - from .envs.protocol import Protocol, RoundRobinProtocol # noqa: F401 + from .envs.protocol import ( # noqa: F401 + Protocol, + RoundRobinProtocol, + SpawningProtocol, + SpawnResult, + SpawnSpec, + ) from .envs.env_group import EnvGroup # noqa: F401 from .envs.environment import Environment # noqa: F401 from .envs.multiagent_env import MultiAgentEnv # noqa: F401 diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py index 4f14812ce..872ca8a3a 100644 --- a/verifiers/envs/multiagent_env.py +++ b/verifiers/envs/multiagent_env.py @@ -18,12 +18,13 @@ - on_turn_complete(state): Update game state after each turn """ +import asyncio from abc import abstractmethod import verifiers as vf from verifiers.envs.agent import Agent from verifiers.envs.multiturn_env import MultiTurnEnv -from verifiers.envs.protocol import Protocol +from verifiers.envs.protocol import Protocol, SpawningProtocol, SpawnResult, SpawnSpec from verifiers.types import Messages, State, TrajectoryStep from verifiers.utils.message_utils import normalize_messages @@ -197,6 +198,11 @@ async def rollout( c. Get model response d. Store in trajectory e. Process via on_turn_complete() + f. If protocol is a SpawningProtocol and should_spawn(state): + run child sub-rollouts in parallel, embed their trajectory + steps into the parent's trajectory tagged with the spec's + agent_id, and store SpawnResults in state["extras"]["spawns"] + for the rubric to consume. 3. Return final state """ state = await self.init_state(input, client, model, sampling_args) @@ -233,7 +239,14 @@ async def rollout( # 4. Process turn (game logic) await self.on_turn_complete(state) - # 5. Determine next agent (if game continues) + # 5. Spawn sub-rollouts if the protocol requests it. + if isinstance(self._protocol, SpawningProtocol) and ( + self._protocol.should_spawn(state) + ): + specs = self._protocol.get_spawn_specs(state) + await self._run_spawns(specs, state, client, model, sampling_args) + + # 6. Determine next agent (if game continues) if not await self.is_completed(state): state["extras"]["current_agent_id"] = self.get_next_agent(state) @@ -247,3 +260,78 @@ async def rollout( await self.render_completion(state) return state + + # ------------------------------------------------------------------------- + # Spawning support (SpawningProtocol) + # ------------------------------------------------------------------------- + + async def _run_spawns( + self, + specs: list[SpawnSpec], + state: State, + client, + model, + sampling_args, + ) -> None: + """Execute SpawnSpecs in parallel, score each child, embed their + trajectory steps into the parent's trajectory, and record SpawnResults + in ``state["extras"]["spawns"]``. + + Each child's trajectory steps are tagged with the spec's ``agent_id`` + and ``is_trainable`` (setdefault, so a child env that already tagged + its own steps wins). This keeps existing per-agent advantage and + completion-masking machinery intact — + ``interleave_rollout(split_by_agent=True)`` and + ``MicroBatch.actor_ids`` work without modification. + """ + state["extras"].setdefault("spawns", []) + # Outer loop is sequential across specs but inner inputs run in parallel. + # Most protocols emit a single spec per turn so this is essentially + # one gather() per spawning turn. + for spec in specs: + child_model = ( + self.actor_models.get(spec.agent_id) if self.actor_models else model + ) + child_states = await asyncio.gather( + *( + self._run_one_child(spec, child_input, client, child_model, sampling_args) + for child_input in spec.inputs + ) + ) + for child_state in child_states: + for step in child_state.get("trajectory", []): + step.setdefault("extras", {}) + step["extras"].setdefault("agent_id", spec.agent_id) + step["extras"].setdefault("is_trainable", spec.is_trainable) + state["trajectory"].append(step) + state["extras"]["spawns"].append( + SpawnResult(spec=spec, states=list(child_states)) + ) + + async def _run_one_child( + self, + spec: SpawnSpec, + child_input, + client, + child_model, + sampling_args, + ) -> State: + """Roll out a single child and score it so ``state["reward"]`` is + populated before the parent's reward funcs read it. Score failures + are logged and treated as zero — they should not bring down the + parent rollout.""" + child_state = await spec.child_env.rollout( + child_input, client, child_model, sampling_args + ) + rubric = getattr(spec.child_env, "rubric", None) + if rubric is not None: + try: + await rubric.score_rollout(child_state) + except Exception as e: # pragma: no cover — defensive + self.logger.warning( + "child rubric.score_rollout failed (agent_id=%s): %s: %s", + spec.agent_id, + type(e).__name__, + e, + ) + return child_state diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py index b43e71e0d..2a0775910 100644 --- a/verifiers/envs/protocol.py +++ b/verifiers/envs/protocol.py @@ -7,14 +7,19 @@ Example protocols: - RoundRobinProtocol: Agents take turns in order (0 → 1 → ...) -- SimultaneousProtocol: All agents act at once (future) -- HierarchicalProtocol: Some agents coordinate others (future) +- SpawningProtocol: One agent can spawn sub-rollouts of another (e.g., + proposer-solver where the proposer creates problems that k solvers attempt) """ from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any from verifiers.types import State +if TYPE_CHECKING: + from verifiers.envs.environment import Environment + class Protocol(ABC): """ @@ -88,3 +93,85 @@ def get_next_agent(self, state: State) -> str: current_idx = -1 next_idx = (current_idx + 1) % len(self.agent_ids) return self.agent_ids[next_idx] + + +@dataclass +class SpawnSpec: + """A spawn request: one or more sub-rollouts of ``child_env`` for ``agent_id``. + + Fields: + agent_id: Role tag for the spawned children. Their trajectory steps + are embedded in the parent's trajectory with this id so existing + per-agent advantage/credit-assignment machinery + (Rubric.score_group, interleave_rollout(split_by_agent=True), + MicroBatch.actor_ids) picks them up without modification. + child_env: The Environment to roll out for each child. Any + ``Environment`` with a ``rollout()`` method works — does not + have to be a MultiAgentEnv. + inputs: One dataset row per child rollout. Length determines how many + children are spawned. + is_trainable: When False, the spawned tokens are still generated and + included in the parent's context but their completion masks are + zeroed out at training time (mirrors ``Agent.is_trainable``). + """ + + agent_id: str + child_env: "Environment" + inputs: list[Any] + is_trainable: bool = True + + +@dataclass +class SpawnResult: + """A SpawnSpec paired with the resulting child states. + + Stored in ``state["extras"]["spawns"]`` so MultiAgentRewardFunc + implementations can compute parent rewards from child outcomes + (e.g. goldilocks scoring on per-child verify scores). + """ + + spec: SpawnSpec + states: list[State] = field(default_factory=list) + + +class SpawningProtocol(Protocol): + """Protocol extension for hierarchical / one-to-many agent interactions. + + The base ``Protocol`` only sequences turns of a fixed agent set. A + ``SpawningProtocol`` additionally lets the current agent's turn fan out + into N sub-rollouts of another agent (or env). The canonical example + is proposer-solver: the proposer designs a problem, then k solver + children attempt it, and the proposer's reward depends on how the + solvers fared. + + Implementers provide: + - ``should_spawn(state)``: was the previous turn a spawn trigger? + - ``get_spawn_specs(state)``: what to spawn (one or more SpawnSpecs) + + MultiAgentEnv.rollout() handles the actual sub-rollout execution and + weaves children's trajectory steps into the parent's trajectory, tagged + with the child agent_id from the spec. Reward functions read the + completed children from ``state["extras"]["spawns"]`` to compute + per-agent rewards. + """ + + @abstractmethod + def should_spawn(self, state: State) -> bool: + """Return True if the most recent turn should trigger sub-rollouts. + + Called after ``MultiAgentEnv.on_turn_complete`` and before + ``get_next_agent``. Subclasses typically inspect the last trajectory + step's content (e.g. did the proposer emit a valid problem?). + """ + ... + + @abstractmethod + def get_spawn_specs(self, state: State) -> list[SpawnSpec]: + """Return the spawn requests for the current state. + + Called only when ``should_spawn`` is True. Each ``SpawnSpec`` may + request multiple child rollouts via its ``inputs`` list, and + multiple specs may be returned (e.g. spawn two solver groups for + two different problems the proposer emitted in one turn). + """ + ... From b0d8526333437d7d62a2d7d0be0cbcc988d1ce25 Mon Sep 17 00:00:00 2001 From: nph4rd Date: Thu, 21 May 2026 13:29:44 -0600 Subject: [PATCH 21/21] resolve PEP 563 string annotations in _is_multiagent_func MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Functions defined under `from __future__ import annotations` have their return annotations stored as strings (e.g. 'dict[str, float]'), so `inspect.signature(func).return_annotation` returns the literal string and `get_origin(...)` returns None. The previous check classified such functions as INDIVIDUAL reward funcs, so calls were routed through `_call_individual_reward_func` which coerces the dict return to 0.0 via a failing `float()` conversion. Fix: use `typing.get_type_hints(func)` to resolve string annotations to their actual types before the dict-origin check. Falls back to the raw annotation if get_type_hints raises (rare — unresolved forward refs). Caught while running general-agent-coevolve on a real model: every solver rollout was correctly producing reward=1.0 in its own rubric, but the parent's aggregator returned 0 because both `solver_verify_reward` and `synth_goldilocks_reward` were declared under `from __future__ import annotations` in the env package. --- tests/test_rubric.py | 20 ++++++++++++++++++++ verifiers/rubrics/rubric.py | 23 +++++++++++++++++++---- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tests/test_rubric.py b/tests/test_rubric.py index 8b2293364..dcfafe4e2 100644 --- a/tests/test_rubric.py +++ b/tests/test_rubric.py @@ -1,5 +1,7 @@ """Tests for the Rubric class.""" +from __future__ import annotations + from typing import cast import pytest @@ -8,6 +10,24 @@ from verifiers.types import RewardFunc, RolloutInput, RolloutTiming, State +# Regression for `_is_multiagent_func` mis-classifying functions defined under +# ``from __future__ import annotations`` (PEP 563): the return annotation is +# the string ``"dict[str, float]"`` rather than the resolved generic alias, so +# ``get_origin(annotation)`` returns ``None`` and the function was routed +# through the individual-reward path — where its dict return value got coerced +# to 0 by ``float(dict)`` failing. ``_is_multiagent_func`` now uses +# ``typing.get_type_hints`` to resolve the string. +async def _multiagent_under_future_annotations( + state: State, **_kwargs +) -> dict[str, float]: + return {"agent_a": 0.5, "agent_b": 1.0} + + +def test_is_multiagent_func_handles_future_annotations(): + rubric = Rubric() + assert rubric._is_multiagent_func(_multiagent_under_future_annotations) is True + + class TestRubric: """Test cases for the Rubric class.""" diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index b1728b2ef..d2f2f35f1 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -2,7 +2,7 @@ import inspect import logging from collections.abc import Callable, Mapping -from typing import Any, cast, get_origin +from typing import Any, cast, get_origin, get_type_hints import verifiers as vf from verifiers.decorators import discover_decorated @@ -182,9 +182,24 @@ def task_for_state(self, state: State, resources: object | None) -> object: return None def _is_multiagent_func(self, func: RewardFunc) -> bool: - """Check if a function is a MultiAgentRewardFunc by inspecting its return annotation.""" - sig = inspect.signature(func) - return_annotation = sig.return_annotation + """Check if a function is a MultiAgentRewardFunc by inspecting its return annotation. + + Uses ``typing.get_type_hints`` so a function defined under + ``from __future__ import annotations`` (where annotations become + strings — PEP 563) is still classified correctly. ``inspect.signature`` + alone returns the raw string form, and ``get_origin('dict[str, float]')`` + is ``None`` — so the dict check would silently miss and the function + would be routed through the individual-reward path, where its dict + return value is coerced to 0.0 by ``float(dict_result)`` failing. + """ + try: + hints = get_type_hints(func) + except Exception: + # Forward refs that fail to resolve (rare, e.g. when the function's + # globals don't contain a referenced name). Fall back to the raw + # annotation so we still classify when get_type_hints raises. + hints = {} + return_annotation = hints.get("return", inspect.signature(func).return_annotation) return return_annotation is dict or get_origin(return_annotation) is dict # individual-level reward helpers