diff --git a/tests/envs/test_tool_env_void_mask.py b/tests/envs/test_tool_env_void_mask.py new file mode 100644 index 000000000..b5bf9edb8 --- /dev/null +++ b/tests/envs/test_tool_env_void_mask.py @@ -0,0 +1,125 @@ +import json + +import pytest + +import verifiers as vf +from verifiers.envs.tool_env import ToolEnv + + +def constant_reward(**kwargs) -> float: + return 2.0 + + +def good_tool() -> str: + return "ok" + + +def bad_tool() -> str: + raise RuntimeError("boom") + + +def make_env(*, mask: bool) -> ToolEnv: + return ToolEnv( + tools=[good_tool, bad_tool], + mask_all_failed_tool_calls=mask, + rubric=vf.Rubric(funcs=[constant_reward]), + ) + + +def make_tool_call(name: str, index: int) -> vf.ToolCall: + return vf.ToolCall(id=f"call_{index}", name=name, arguments=json.dumps({})) + + +async def run_tool_calls(env: ToolEnv, tool_names: list[str]) -> vf.State: + tool_calls = [make_tool_call(name, index) for index, name in enumerate(tool_names)] + assistant = vf.AssistantMessage(content=None, tool_calls=tool_calls) + state = vf.State( + prompt=[vf.UserMessage(content="use tools")], + completion=[], + trajectory=[], + ) + tool_messages = await env.env_response([assistant], state) + state["completion"] = [assistant, *tool_messages] + env._apply_tool_call_mask(state) + return state + + +def run_assistant_only_no_tools(env: ToolEnv) -> vf.State: + assistant = vf.AssistantMessage(content="done") + state = vf.State( + prompt=[vf.UserMessage(content="answer")], + completion=[assistant], + trajectory=[], + ) + env._apply_tool_call_mask(state) + return state + + +@pytest.mark.asyncio +async def test_flag_off_records_outcomes_without_masking(): + env = make_env(mask=False) + + state = await run_tool_calls(env, ["bad_tool", "bad_tool"]) + await env.rubric.score_rollout(state) + + assert state.get("masked") in (None, False) + assert state["tool_call_outcomes"] == ["error", "error"] + assert state["reward"] == 2.0 + + +@pytest.mark.asyncio +async def test_flag_on_all_errors_masked_and_scores_zero(): + env = make_env(mask=True) + + state = await run_tool_calls(env, ["bad_tool", "bad_tool"]) + await env.rubric.score_rollout(state) + + assert state["masked"] is True + assert state["tool_call_outcomes"] == ["error", "error"] + assert state["reward"] == 0.0 + assert state["metrics"]["constant_reward"] == 0.0 + assert state["metrics"]["void_turn_rollouts"] == 1.0 + + +@pytest.mark.asyncio +async def test_flag_on_mixed_outcomes_unmasked(): + env = make_env(mask=True) + + state = await run_tool_calls(env, ["good_tool", "bad_tool"]) + + assert state["masked"] is False + assert state["tool_call_outcomes"] == ["ok", "error"] + + +def test_flag_on_no_tool_calls_unmasked(): + env = make_env(mask=True) + + state = run_assistant_only_no_tools(env) + + assert state["masked"] is False + assert state.get("tool_call_outcomes") in (None, []) + + +@pytest.mark.asyncio +async def test_stateful_tool_env_tracks_outcomes_for_masking(): + class ExampleStatefulToolEnv(vf.StatefulToolEnv): + def update_tool_args( + self, + tool_name: str, + tool_args: dict, + messages: vf.Messages, + state: vf.State, + **kwargs, + ) -> dict: + return tool_args + + env = ExampleStatefulToolEnv( + tools=[good_tool, bad_tool], + mask_all_failed_tool_calls=True, + rubric=vf.Rubric(funcs=[constant_reward]), + ) + + state = await run_tool_calls(env, ["bad_tool"]) + + assert state["masked"] is True + assert state["tool_call_outcomes"] == ["error"] diff --git a/verifiers/envs/AGENTS.md b/verifiers/envs/AGENTS.md index ecd6f9f50..8c51240e1 100644 --- a/verifiers/envs/AGENTS.md +++ b/verifiers/envs/AGENTS.md @@ -25,6 +25,15 @@ Base classes for building environments: - `SandboxEnv` - Sandboxed container execution using `prime` sandboxes. All sandbox setup logic should be included in the start command and queued via `setup_state`, but not awaited—await resources only when first needed to overlap provisioning with rollout. See `python_env.py` for an example. - `PythonEnv` - Persistent Python REPL in sandbox +## Optional Flags + +`ToolEnv(mask_all_failed_tool_calls=True)` marks a rollout as +`state["masked"] = True` when it made at least one tool call and every recorded +tool call failed. Tool outcomes are recorded in `state["tool_call_outcomes"]` as +`"ok"` or `"error"`, and masked rollouts receive zero reward while exposing the +`void_turn_rollouts` metric. This follows the SimpleTIR void-turn masking pattern +for skipping all-tool-failed rollouts during reward computation. + ## Integrations Third-party library wrappers that require additional dependencies: diff --git a/verifiers/envs/stateful_tool_env.py b/verifiers/envs/stateful_tool_env.py index cb44555a4..fb5e9d8c4 100644 --- a/verifiers/envs/stateful_tool_env.py +++ b/verifiers/envs/stateful_tool_env.py @@ -148,6 +148,7 @@ async def env_response( except Exception as e: if self._should_stop_for_error(e): raise vf.ToolParseError from e + self._record_tool_call_outcome(state, "error") tool_messages.append( ToolMessage( role="tool", @@ -162,10 +163,12 @@ async def env_response( ) try: tool_message = await self.call_tool(tool_name, tool_args, tool_call_id) + self._record_tool_call_outcome(state, "ok") tool_messages.append(tool_message) except Exception as e: if self._should_stop_for_error(e): raise vf.ToolCallError from e + self._record_tool_call_outcome(state, "error") tool_messages.append( ToolMessage( role="tool", diff --git a/verifiers/envs/tool_env.py b/verifiers/envs/tool_env.py index ee1ed07e7..dde3888cc 100644 --- a/verifiers/envs/tool_env.py +++ b/verifiers/envs/tool_env.py @@ -1,5 +1,5 @@ import json -from typing import Callable, cast +from typing import Callable, Literal, cast import verifiers as vf from verifiers.types import AssistantMessage, Messages, ToolCall, ToolMessage @@ -9,6 +9,8 @@ is_valid_tool_content_parts, ) +ToolCallOutcome = Literal["ok", "error"] + class ToolMonitorRubric(vf.Rubric): def __init__(self, tool_names: list[str] | None = None, **kwargs): @@ -78,12 +80,14 @@ def __init__( max_turns: int = 10, error_formatter: Callable[[Exception], str] = lambda e: f"{e}", stop_errors: list[type[Exception]] | None = None, + mask_all_failed_tool_calls: bool = False, **kwargs, ): self.tools = tools or [] self.max_turns = max_turns self.error_formatter = error_formatter self.stop_errors: list[type[Exception]] = stop_errors or [] + self.mask_all_failed_tool_calls = mask_all_failed_tool_calls self.tool_defs = [convert_func_to_tool_def(tool) for tool in self.tools] self.tool_map = { getattr(tool, "__name__", tool.__class__.__name__): tool @@ -94,12 +98,42 @@ def __init__( self.tool_monitor_rubric = ToolMonitorRubric( tool_names=list(self.tool_map.keys()) ) + if self.mask_all_failed_tool_calls: + self.tool_monitor_rubric.add_metric(self.void_turn_rollouts) self.add_rubric(self.tool_monitor_rubric) def _should_stop_for_error(self, err: Exception) -> bool: """Check if error is in stop_errors.""" return any(isinstance(err, err_type) for err_type in self.stop_errors) + def _tool_call_outcomes(self, state: vf.State) -> list[ToolCallOutcome]: + outcomes = state.setdefault("tool_call_outcomes", []) + return cast(list[ToolCallOutcome], outcomes) + + def _record_tool_call_outcome( + self, state: vf.State, outcome: ToolCallOutcome + ) -> None: + self._tool_call_outcomes(state).append(outcome) + + def _should_mask(self, state: vf.State) -> bool: + outcomes = state.get("tool_call_outcomes") or [] + return ( + self.mask_all_failed_tool_calls + and len(outcomes) > 0 + and all(outcome == "error" for outcome in outcomes) + ) + + def _apply_tool_call_mask(self, state: vf.State) -> None: + if self.mask_all_failed_tool_calls: + state["masked"] = self._should_mask(state) + + async def void_turn_rollouts(self, state: vf.State) -> float: + return 1.0 if state.get("masked") else 0.0 + + async def _finalize_rollout(self, state: vf.State) -> None: + self._apply_tool_call_mask(state) + await super()._finalize_rollout(state) + def add_tool(self, tool: Callable): self.tools.append(tool) if self.tool_defs is None: @@ -156,6 +190,7 @@ async def env_response( except Exception as e: if self._should_stop_for_error(e): raise vf.ToolParseError from e + self._record_tool_call_outcome(state, "error") tool_messages.append( ToolMessage( role="tool", @@ -167,10 +202,12 @@ async def env_response( try: tool_message = await self.call_tool(tool_name, tool_args, tool_call_id) + self._record_tool_call_outcome(state, "ok") tool_messages.append(tool_message) except Exception as e: if self._should_stop_for_error(e): raise vf.ToolCallError from e + self._record_tool_call_outcome(state, "error") tool_messages.append( ToolMessage( role="tool", diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 914ae4d69..51dabe53d 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -351,6 +351,26 @@ async def score_rollout(self, state: State): """ reward_funcs = self._get_individual_reward_funcs() group_reward_funcs = self._get_group_reward_funcs() + if state.get("masked"): + reward_scores = [] + for func, weight in zip( + reward_funcs, self._get_individual_reward_weights() + ): + if weight == 0.0: + reward_scores.append( + await self._call_individual_reward_func( + func=func, + state=state, + ) + ) + else: + reward_scores.append(0.0) + state["reward"] = 0.0 + state["metrics"] = { + func.__name__: reward + for func, reward in zip(reward_funcs, reward_scores) + } + return assert len(reward_funcs) > 0 and len(group_reward_funcs) == 0, ( "Rubric.score_rollout requires at least one individual-level reward function and no group-level reward functions" ) @@ -427,6 +447,12 @@ async def score_group(self, states: list[State]): aggregated_rewards[i] += score_value * weight aggregated_metrics[func_name][i] = score_value + for i, state in enumerate(states): + if state.get("masked"): + aggregated_rewards[i] = 0.0 + for func, weight in zip(self.funcs, self.weights): + if weight != 0.0 and func.__name__ in aggregated_metrics: + aggregated_metrics[func.__name__][i] = 0.0 avg_reward = sum(aggregated_rewards) / num_states for i, state in enumerate(states): state["reward"] = aggregated_rewards[i]