Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions tests/envs/test_tool_env_void_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import json

import pytest

import verifiers as vf
from verifiers.envs.tool_env import ToolEnv


def constant_reward(**kwargs) -> float:
    """Reward stub: ignore every keyword argument and always return 2.0."""
    return 2.0


def good_tool() -> str:
    """Tool stub that always succeeds and returns the string "ok"."""
    return "ok"


def bad_tool() -> str:
    """Tool stub that always fails by raising ``RuntimeError("boom")``."""
    raise RuntimeError("boom")


def make_env(*, mask: bool) -> ToolEnv:
    """Build a ``ToolEnv`` exposing ``good_tool`` and ``bad_tool``.

    Args:
        mask: Value passed through as ``mask_all_failed_tool_calls``.

    Returns:
        A ``ToolEnv`` whose rubric contains only ``constant_reward``.
    """
    tools = [good_tool, bad_tool]
    rubric = vf.Rubric(funcs=[constant_reward])
    return ToolEnv(tools=tools, mask_all_failed_tool_calls=mask, rubric=rubric)


def make_tool_call(name: str, index: int) -> vf.ToolCall:
    """Create a ``vf.ToolCall`` for tool *name* with id ``call_<index>``.

    The call carries an empty JSON object as its serialized arguments.
    """
    call_id = f"call_{index}"
    serialized_args = json.dumps({})
    return vf.ToolCall(id=call_id, name=name, arguments=serialized_args)


async def run_tool_calls(env: ToolEnv, tool_names: list[str]) -> vf.State:
    """Run one assistant turn that calls each tool in *tool_names* once.

    Builds an assistant message with one tool call per name, passes it to
    ``env.env_response``, stores the assistant message followed by the returned
    tool messages in ``state["completion"]``, applies the env's tool-call mask,
    and returns the resulting state.
    """
    assistant = vf.AssistantMessage(
        content=None,
        tool_calls=[
            make_tool_call(tool_name, position)
            for position, tool_name in enumerate(tool_names)
        ],
    )
    state = vf.State(
        prompt=[vf.UserMessage(content="use tools")],
        completion=[],
        trajectory=[],
    )
    responses = await env.env_response([assistant], state)
    state["completion"] = [assistant] + list(responses)
    # Call the masking hook directly instead of going through a full rollout.
    env._apply_tool_call_mask(state)
    return state


def run_assistant_only_no_tools(env: ToolEnv) -> vf.State:
    """Return a masked-checked state whose completion has no tool calls.

    The completion is a single assistant message with content "done";
    ``env._apply_tool_call_mask`` is applied before the state is returned.
    """
    state = vf.State(
        prompt=[vf.UserMessage(content="answer")],
        completion=[vf.AssistantMessage(content="done")],
        trajectory=[],
    )
    env._apply_tool_call_mask(state)
    return state


@pytest.mark.asyncio
async def test_flag_off_records_outcomes_without_masking():
    """Flag off: outcomes are recorded, nothing is masked, reward is kept."""
    tool_env = make_env(mask=False)

    rollout = await run_tool_calls(tool_env, ["bad_tool", "bad_tool"])
    await tool_env.rubric.score_rollout(rollout)

    # "masked" may be absent or False when the flag is disabled.
    assert rollout.get("masked") in (None, False)
    assert rollout["tool_call_outcomes"] == ["error", "error"]
    assert rollout["reward"] == 2.0


@pytest.mark.asyncio
async def test_flag_on_all_errors_masked_and_scores_zero():
    """Flag on, every tool call fails: rollout is masked and scores zero."""
    tool_env = make_env(mask=True)

    rollout = await run_tool_calls(tool_env, ["bad_tool", "bad_tool"])
    await tool_env.rubric.score_rollout(rollout)

    assert rollout["masked"] is True
    assert rollout["tool_call_outcomes"] == ["error", "error"]
    assert rollout["reward"] == 0.0

    metrics = rollout["metrics"]
    assert metrics["constant_reward"] == 0.0
    assert metrics["void_turn_rollouts"] == 1.0


@pytest.mark.asyncio
async def test_flag_on_mixed_outcomes_unmasked():
    """Flag on, one success and one failure: the rollout is not masked."""
    tool_env = make_env(mask=True)

    rollout = await run_tool_calls(tool_env, ["good_tool", "bad_tool"])

    assert rollout["masked"] is False
    assert rollout["tool_call_outcomes"] == ["ok", "error"]


def test_flag_on_no_tool_calls_unmasked():
    """Flag on, no tool calls at all: the rollout is not masked."""
    tool_env = make_env(mask=True)

    rollout = run_assistant_only_no_tools(tool_env)

    assert rollout["masked"] is False
    # The outcomes key may be missing or an empty list when nothing was called.
    assert rollout.get("tool_call_outcomes") in (None, [])


@pytest.mark.asyncio
async def test_stateful_tool_env_tracks_outcomes_for_masking():
    """A ``StatefulToolEnv`` subclass records outcomes and masks all-failed rollouts."""

    class ExampleStatefulToolEnv(vf.StatefulToolEnv):
        """Minimal subclass whose ``update_tool_args`` passes arguments through."""

        def update_tool_args(
            self,
            tool_name: str,
            tool_args: dict,
            messages: vf.Messages,
            state: vf.State,
            **kwargs,
        ) -> dict:
            return tool_args

    stateful_env = ExampleStatefulToolEnv(
        tools=[good_tool, bad_tool],
        mask_all_failed_tool_calls=True,
        rubric=vf.Rubric(funcs=[constant_reward]),
    )

    rollout = await run_tool_calls(stateful_env, ["bad_tool"])

    assert rollout["masked"] is True
    assert rollout["tool_call_outcomes"] == ["error"]
9 changes: 9 additions & 0 deletions verifiers/envs/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ Base classes for building environments:
- `SandboxEnv` - Sandboxed container execution using `prime` sandboxes. All sandbox setup logic should be included in the start command and queued via `setup_state`, but not awaited—await resources only when first needed to overlap provisioning with rollout. See `python_env.py` for an example.
- `PythonEnv` - Persistent Python REPL in sandbox

## Optional Flags

`ToolEnv(mask_all_failed_tool_calls=True)` masks a rollout by setting
`state["masked"] = True` when the rollout made at least one tool call and every
recorded tool call failed; otherwise it sets `state["masked"] = False`. Each tool
call's outcome is recorded in `state["tool_call_outcomes"]` as `"ok"` or
`"error"`, whether or not the flag is set. Masked rollouts receive zero reward,
and the `void_turn_rollouts` metric reports `1.0` for a masked rollout and `0.0`
otherwise. This follows the SimpleTIR void-turn masking pattern for skipping
all-tool-failed rollouts during reward computation.

## Integrations

Third-party library wrappers that require additional dependencies:
Expand Down
3 changes: 3 additions & 0 deletions verifiers/envs/stateful_tool_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ async def env_response(
except Exception as e:
if self._should_stop_for_error(e):
raise vf.ToolParseError from e
self._record_tool_call_outcome(state, "error")
tool_messages.append(
ToolMessage(
role="tool",
Expand All @@ -162,10 +163,12 @@ async def env_response(
)
try:
tool_message = await self.call_tool(tool_name, tool_args, tool_call_id)
self._record_tool_call_outcome(state, "ok")
tool_messages.append(tool_message)
except Exception as e:
if self._should_stop_for_error(e):
raise vf.ToolCallError from e
self._record_tool_call_outcome(state, "error")
tool_messages.append(
ToolMessage(
role="tool",
Expand Down
39 changes: 38 additions & 1 deletion verifiers/envs/tool_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Callable, cast
from typing import Callable, Literal, cast

import verifiers as vf
from verifiers.types import AssistantMessage, Messages, ToolCall, ToolMessage
Expand All @@ -9,6 +9,8 @@
is_valid_tool_content_parts,
)

ToolCallOutcome = Literal["ok", "error"]


class ToolMonitorRubric(vf.Rubric):
def __init__(self, tool_names: list[str] | None = None, **kwargs):
Expand Down Expand Up @@ -78,12 +80,14 @@ def __init__(
max_turns: int = 10,
error_formatter: Callable[[Exception], str] = lambda e: f"{e}",
stop_errors: list[type[Exception]] | None = None,
mask_all_failed_tool_calls: bool = False,
**kwargs,
):
self.tools = tools or []
self.max_turns = max_turns
self.error_formatter = error_formatter
self.stop_errors: list[type[Exception]] = stop_errors or []
self.mask_all_failed_tool_calls = mask_all_failed_tool_calls
self.tool_defs = [convert_func_to_tool_def(tool) for tool in self.tools]
self.tool_map = {
getattr(tool, "__name__", tool.__class__.__name__): tool
Expand All @@ -94,12 +98,42 @@ def __init__(
self.tool_monitor_rubric = ToolMonitorRubric(
tool_names=list(self.tool_map.keys())
)
if self.mask_all_failed_tool_calls:
self.tool_monitor_rubric.add_metric(self.void_turn_rollouts)
self.add_rubric(self.tool_monitor_rubric)

def _should_stop_for_error(self, err: Exception) -> bool:
    """Return True when *err* is an instance of any type in ``self.stop_errors``."""
    # isinstance accepts a tuple of types; an empty tuple never matches.
    return isinstance(err, tuple(self.stop_errors))

def _tool_call_outcomes(self, state: vf.State) -> list[ToolCallOutcome]:
    """Return the outcome list stored under ``state["tool_call_outcomes"]``.

    The key is created with an empty list on first access, so the returned
    list is always the one held by *state*.
    """
    return cast(
        list[ToolCallOutcome],
        state.setdefault("tool_call_outcomes", []),
    )

def _record_tool_call_outcome(
    self, state: vf.State, outcome: ToolCallOutcome
) -> None:
    """Append *outcome* to the list returned by ``self._tool_call_outcomes(state)``."""
    recorded = self._tool_call_outcomes(state)
    recorded.append(outcome)

def _should_mask(self, state: vf.State) -> bool:
    """Decide whether the rollout in *state* should be masked.

    Masking requires all of: ``self.mask_all_failed_tool_calls`` is enabled,
    at least one outcome is stored under ``state["tool_call_outcomes"]``, and
    every stored outcome equals "error".
    """
    recorded = state.get("tool_call_outcomes") or []
    # An empty list must not count as "all failed", hence the explicit check.
    every_call_failed = bool(recorded) and all(
        result == "error" for result in recorded
    )
    return self.mask_all_failed_tool_calls and every_call_failed

def _apply_tool_call_mask(self, state: vf.State) -> None:
    """Store the masking decision in ``state["masked"]`` when the flag is on.

    When ``self.mask_all_failed_tool_calls`` is falsy, *state* is left untouched.
    """
    if not self.mask_all_failed_tool_calls:
        return
    state["masked"] = self._should_mask(state)

async def void_turn_rollouts(self, state: vf.State) -> float:
    """Metric: 1.0 when ``state["masked"]`` is truthy, 0.0 otherwise."""
    is_masked = bool(state.get("masked"))
    return float(is_masked)

async def _finalize_rollout(self, state: vf.State) -> None:
    """Apply the tool-call mask to *state*, then run the parent's finalization."""
    # The mask is applied first so state["masked"] is set before the parent
    # class's _finalize_rollout runs.
    self._apply_tool_call_mask(state)
    await super()._finalize_rollout(state)

def add_tool(self, tool: Callable):
self.tools.append(tool)
if self.tool_defs is None:
Expand Down Expand Up @@ -156,6 +190,7 @@ async def env_response(
except Exception as e:
if self._should_stop_for_error(e):
raise vf.ToolParseError from e
self._record_tool_call_outcome(state, "error")
tool_messages.append(
ToolMessage(
role="tool",
Expand All @@ -167,10 +202,12 @@ async def env_response(

try:
tool_message = await self.call_tool(tool_name, tool_args, tool_call_id)
self._record_tool_call_outcome(state, "ok")
tool_messages.append(tool_message)
except Exception as e:
if self._should_stop_for_error(e):
raise vf.ToolCallError from e
self._record_tool_call_outcome(state, "error")
tool_messages.append(
ToolMessage(
role="tool",
Expand Down
26 changes: 26 additions & 0 deletions verifiers/rubrics/rubric.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,26 @@ async def score_rollout(self, state: State):
"""
reward_funcs = self._get_individual_reward_funcs()
group_reward_funcs = self._get_group_reward_funcs()
if state.get("masked"):
reward_scores = []
for func, weight in zip(
reward_funcs, self._get_individual_reward_weights()
):
if weight == 0.0:
reward_scores.append(
await self._call_individual_reward_func(
func=func,
state=state,
)
)
else:
reward_scores.append(0.0)
state["reward"] = 0.0
state["metrics"] = {
func.__name__: reward
for func, reward in zip(reward_funcs, reward_scores)
}
return
assert len(reward_funcs) > 0 and len(group_reward_funcs) == 0, (
"Rubric.score_rollout requires at least one individual-level reward function and no group-level reward functions"
)
Expand Down Expand Up @@ -427,6 +447,12 @@ async def score_group(self, states: list[State]):
aggregated_rewards[i] += score_value * weight
aggregated_metrics[func_name][i] = score_value

for i, state in enumerate(states):
if state.get("masked"):
aggregated_rewards[i] = 0.0
for func, weight in zip(self.funcs, self.weights):
if weight != 0.0 and func.__name__ in aggregated_metrics:
aggregated_metrics[func.__name__][i] = 0.0
avg_reward = sum(aggregated_rewards) / num_states
for i, state in enumerate(states):
state["reward"] = aggregated_rewards[i]
Expand Down