From ed56ed21d9aa62f2a072c30dc558aa0a017ef4fd Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Tue, 19 May 2026 06:32:07 +0530 Subject: [PATCH 1/2] Fix sandbox cleanup on failed rollouts --- tests/test_rollout_cleanup.py | 236 ++++++++++++++++++ verifiers/envs/environment.py | 142 ++++++++--- verifiers/envs/experimental/cli_agent_env.py | 15 +- .../experimental/composable/composable_env.py | 30 ++- .../composable/tasksets/swe/multi_swe.py | 6 +- .../composable/tasksets/swe/openswe.py | 6 +- .../composable/tasksets/swe/r2e_gym.py | 6 +- .../composable/tasksets/swe/swe_bench.py | 6 +- .../composable/tasksets/swe/swe_lego.py | 6 +- .../composable/tasksets/swe/swe_rebench_v2.py | 6 +- .../composable/tasksets/swe/swe_smith.py | 6 +- verifiers/rubrics/rubric.py | 13 +- verifiers/rubrics/rubric_group.py | 18 +- 13 files changed, 425 insertions(+), 71 deletions(-) create mode 100644 tests/test_rollout_cleanup.py diff --git a/tests/test_rollout_cleanup.py b/tests/test_rollout_cleanup.py new file mode 100644 index 000000000..9d6857cdd --- /dev/null +++ b/tests/test_rollout_cleanup.py @@ -0,0 +1,236 @@ +import asyncio +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from datasets import Dataset + +import verifiers as vf +from verifiers.envs.experimental.cli_agent_env import CliAgentEnv +from verifiers.envs.experimental.composable.tasksets.swe.r2e_gym import R2ERubric +from verifiers.types import RolloutInput, SamplingArgs, State + + +def _dataset() -> Dataset: + return Dataset.from_dict( + { + "question": ["q0", "q1"], + "answer": ["a0", "a1"], + } + ) + + +def _input(example_id: int) -> RolloutInput: + return { + "prompt": [{"role": "user", "content": f"q{example_id}"}], + "answer": f"a{example_id}", + "example_id": example_id, + } + + +class RecordingRubric(vf.Rubric): + def __init__( + self, + *, + score_rollout_error: Exception | None = None, + score_group_error: Exception | None = None, + ): + super().__init__() + self.cleaned: list[int] = [] + self.score_rollout_error = score_rollout_error + self.score_group_error = score_group_error + + async def score_rollout(self, state: State): + if self.score_rollout_error is not None: + raise self.score_rollout_error + state["reward"] = 1.0 + state["metrics"] = {} + + async def score_group(self, states: list[State]): + if self.score_group_error is not None: + raise self.score_group_error + for state in states: + state["reward"] = 1.0 + state["metrics"] = {} + + async def cleanup(self, state: State): + self.cleaned.append(state["example_id"]) + + +class StaticRolloutEnv(vf.Environment): + async def rollout( + self, + input: RolloutInput, + client: vf.Client, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + state = await self.init_state(input, client, model, sampling_args) + state["sandbox_id"] = f"sb-{state['example_id']}" + return state + + +def _env(rubric: vf.Rubric) -> StaticRolloutEnv: + return StaticRolloutEnv(dataset=_dataset(), parser=vf.Parser(), rubric=rubric) + + +@pytest.mark.asyncio +async def test_run_rollout_state_cleans_up_when_scoring_raises(mock_client): + rubric = RecordingRubric(score_rollout_error=RuntimeError("score failed")) + env = _env(rubric) + + with pytest.raises(RuntimeError, match="score failed"): + await env._run_rollout_state(_input(0), mock_client, "test-model", {}) + + assert rubric.cleaned == [0] + + +@pytest.mark.asyncio +async def test_run_group_states_cleans_completed_states_when_gather_raises( + mock_client, +): + first_rollout_finished = asyncio.Event() + + class PartiallyFailingEnv(StaticRolloutEnv): + async def rollout( + self, + input: RolloutInput, + client: vf.Client, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + if input["example_id"] == 1: + await first_rollout_finished.wait() + raise RuntimeError("rollout failed") + state = await super().rollout(input, client, model, sampling_args) + first_rollout_finished.set() + return state + + rubric = RecordingRubric() + env = PartiallyFailingEnv(dataset=_dataset(), parser=vf.Parser(), rubric=rubric) + + with pytest.raises(RuntimeError, match="rollout failed"): + await env._run_group_states( + [_input(0), _input(1)], + mock_client, + "test-model", + {}, + ) + + assert rubric.cleaned == [0] + + +@pytest.mark.asyncio +async def test_run_group_states_cleans_all_states_when_group_scoring_raises( + mock_client, +): + rubric = RecordingRubric(score_group_error=RuntimeError("group score failed")) + env = _env(rubric) + + with pytest.raises(RuntimeError, match="group score failed"): + await env._run_group_states( + [_input(0), _input(1)], + mock_client, + "test-model", + {}, + ) + + assert rubric.cleaned == [0, 1] + + +@pytest.mark.asyncio +async def test_environment_cleanup_failure_does_not_skip_later_handler(mock_client): + class FailingCleanupEnv(StaticRolloutEnv): + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.destroyed = False + + @vf.cleanup(priority=1) + async def failing_cleanup(self, state: State): + raise RuntimeError("early cleanup failed") + + @vf.cleanup(priority=0) + async def destroy_sandbox(self, state: State): + self.destroyed = True + + env = FailingCleanupEnv(dataset=_dataset(), parser=vf.Parser(), rubric=vf.Rubric()) + state = await env.init_state(_input(0), mock_client, "test-model") + + with pytest.raises(RuntimeError, match="early cleanup failed"): + await env.cleanup(state) + + assert env.destroyed is True + + +@pytest.mark.asyncio +async def test_rubric_group_cleanup_failure_does_not_skip_later_rubric(): + class FailingRubric(vf.Rubric): + async def cleanup(self, state: State): + raise RuntimeError("rubric cleanup failed") + + class DestroyingRubric(vf.Rubric): + async def cleanup(self, state: State): + state["destroyed"] = True + + state: State = vf.State(input={}) + rubric = vf.RubricGroup([FailingRubric(), DestroyingRubric()]) + + with pytest.raises(RuntimeError, match="rubric cleanup failed"): + await rubric.cleanup(state) + + assert state["destroyed"] is True + + +@pytest.mark.asyncio +async def test_cli_agent_destroy_sandbox_deletes_when_post_rollout_fails(): + class FailingPostRolloutEnv(CliAgentEnv): + async def post_rollout(self, state: State): + raise RuntimeError("post rollout failed") + + env = FailingPostRolloutEnv( + run_command="echo done", + dataset=_dataset(), + parser=vf.Parser(), + rubric=vf.Rubric(), + keep_sandbox_for_scoring=True, + ) + env.delete_sandbox = AsyncMock() # type: ignore[method-assign] + state: State = vf.State(input={}) + state.update({"is_completed": True, "sandbox_id": "sb-post-rollout-failed"}) + + try: + with pytest.raises(RuntimeError, match="post rollout failed"): + await env.destroy_sandbox(state) + env.delete_sandbox.assert_awaited_once_with("sb-post-rollout-failed") + finally: + env.teardown_sandbox_client() + + +@pytest.mark.asyncio +async def test_swe_rubric_model_error_skips_sandbox_scoring(): + class StubTaskSet: + def __init__(self): + self.ran_tests = False + + async def _run_tests(self, *args: Any, **kwargs: Any) -> str: + self.ran_tests = True + return "PASS" + + def _calculate_reward(self, test_output: str, info: dict[str, Any]) -> float: + return 1.0 + + taskset = StubTaskSet() + rubric = R2ERubric(taskset) # type: ignore[arg-type] + state: State = vf.State(input={}) + state.update( + { + "error": vf.ModelError("No available workers"), + "sandbox_client": object(), + "sandbox_id": "sb-leaked-without-short-circuit", + } + ) + + reward = await rubric.solved(state, info={}) + + assert reward == 0.0 + assert taskset.ran_tests is False diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index ed379f086..d506ee994 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -633,14 +633,40 @@ async def cleanup( """ Finalize rollout state and clean up rollout-local resources. """ + cleanup_error: Exception | None = None for handler in self._cleanup_handlers: - await maybe_call_with_named_args( - handler, - task=task, - state=state, - env=self, - resources=resources, - ) + try: + await maybe_call_with_named_args( + handler, + task=task, + state=state, + env=self, + resources=resources, + ) + except Exception as e: + if cleanup_error is None: + cleanup_error = e + self.logger.exception( + "Cleanup handler %s failed", + getattr(handler, "__name__", repr(handler)), + ) + if cleanup_error is not None: + raise cleanup_error + + async def _cleanup_rollout_states(self, states: list[State]) -> None: + cleanup_error: Exception | None = None + for state in states: + try: + await self.rubric.cleanup(state) + except Exception as e: + if cleanup_error is None: + cleanup_error = e + self.logger.exception( + "Rubric cleanup failed for rollout example_id=%s", + state.get("example_id"), + ) + if cleanup_error is not None: + raise cleanup_error async def _teardown(self): """ @@ -691,14 +717,26 @@ async def _run_rollout_state( sampling_args, ) - state["timing"].scoring.start = time.time() - if self.score_rollouts: - await self.rubric.score_rollout(state) - else: - await self.rubric.dummy_score_rollout(state) - state["timing"].scoring.end = time.time() - - await self.rubric.cleanup(state) + primary_error: BaseException | None = None + try: + state["timing"].scoring.start = time.time() + if self.score_rollouts: + await self.rubric.score_rollout(state) + else: + await self.rubric.dummy_score_rollout(state) + state["timing"].scoring.end = time.time() + except BaseException as e: + primary_error = e + raise + finally: + try: + await self._cleanup_rollout_states([state]) + except Exception: + if primary_error is None: + raise + self.logger.exception( + "Rubric cleanup failed after rollout scoring failed" + ) return state async def _run_group_states( @@ -709,31 +747,65 @@ async def _run_group_states( sampling_args: SamplingArgs, ) -> list[State]: rollout_tasks = [ - self.rollout( - input, - client, - model, - sampling_args, + asyncio.create_task( + self.rollout( + input, + client, + model, + sampling_args, + ) ) for input in group_inputs ] - group_states = await asyncio.gather(*rollout_tasks) - - start_scoring = time.time() - for state in group_states: - state["timing"].scoring.start = start_scoring - if self.score_rollouts: - await self.rubric.score_group(group_states) - else: - await self.rubric.dummy_score_group(group_states) - end_scoring = time.time() - for state in group_states: - state["timing"].scoring.end = end_scoring + group_states: list[State] = [] + primary_error: BaseException | None = None + try: + group_states = await asyncio.gather(*rollout_tasks) - for state in group_states: - await self.rubric.cleanup(state) + start_scoring = time.time() + for state in group_states: + state["timing"].scoring.start = start_scoring + if self.score_rollouts: + await self.rubric.score_group(group_states) + else: + await self.rubric.dummy_score_group(group_states) + end_scoring = time.time() + for state in group_states: + state["timing"].scoring.end = end_scoring + + return group_states + except BaseException as e: + primary_error = e + raise + finally: + pending_tasks = [task for task in rollout_tasks if not task.done()] + if pending_tasks: + for task in pending_tasks: + task.cancel() + await asyncio.gather(*pending_tasks, return_exceptions=True) + + cleanup_states: list[State] = [] + seen_state_ids: set[int] = set() + for state in group_states: + cleanup_states.append(state) + seen_state_ids.add(id(state)) + for task in rollout_tasks: + if not task.done() or task.cancelled(): + continue + try: + state = task.result() + except BaseException: + continue + if id(state) not in seen_state_ids: + cleanup_states.append(state) + seen_state_ids.add(id(state)) - return group_states + try: + await self._cleanup_rollout_states(cleanup_states) + except Exception: + if primary_error is None: + raise + self.logger.exception("Rubric cleanup failed after group run failed") @final async def run_rollout( diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 5b2a89324..99f24581a 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -758,14 +758,25 @@ async def destroy_sandbox(self, state: State): the sandbox is always deleted since scoring will not happen. """ completed = state.get("is_completed", False) + post_rollout_error: Exception | None = None if completed: - await self.post_rollout(state) + try: + await self.post_rollout(state) + except Exception as e: + post_rollout_error = e + self.logger.exception("Post-rollout cleanup failed") sandbox_id = state.get("sandbox_id") if sandbox_id: - if self.keep_sandbox_for_scoring and completed: + if ( + self.keep_sandbox_for_scoring + and completed + and post_rollout_error is None + ): self.deregister_sandbox(sandbox_id) else: await self.delete_sandbox(sandbox_id) + if post_rollout_error is not None: + raise post_rollout_error async def env_response( self, messages: Messages, state: State, **kwargs diff --git a/verifiers/envs/experimental/composable/composable_env.py b/verifiers/envs/experimental/composable/composable_env.py index 1a519ec21..ee92574c5 100644 --- a/verifiers/envs/experimental/composable/composable_env.py +++ b/verifiers/envs/experimental/composable/composable_env.py @@ -89,18 +89,28 @@ def _upload_tar_filter(info: tarfile.TarInfo) -> tarfile.TarInfo | None: class HarnessMetricsRubricGroup(vf.RubricGroup): async def cleanup(self, state: State) -> None: + cleanup_error: Exception | None = None for rubric in self.rubrics: - await rubric.cleanup(state) + try: + await rubric.cleanup(state) + except Exception as e: + if cleanup_error is None: + cleanup_error = e + self.logger.exception( + "Cleanup for rubric %s failed", + rubric.__class__.__name__, + ) harness_metrics = state.get("_harness_metrics") - if not isinstance(harness_metrics, dict): - return - state_metrics = state.get("metrics") - if not isinstance(state_metrics, dict): - state_metrics = {} - state["metrics"] = state_metrics - for key, value in harness_metrics.items(): - if isinstance(key, str) and isinstance(value, (int, float)): - state_metrics[key] = float(value) + if isinstance(harness_metrics, dict): + state_metrics = state.get("metrics") + if not isinstance(state_metrics, dict): + state_metrics = {} + state["metrics"] = state_metrics + for key, value in harness_metrics.items(): + if isinstance(key, str) and isinstance(value, (int, float)): + state_metrics[key] = float(value) + if cleanup_error is not None: + raise cleanup_error class ComposableEnv(CliAgentEnv): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py b/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py index 78224ae98..d2cafee1a 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py @@ -141,7 +141,7 @@ def __init__(self, taskset: "MultiSWETaskSet", **kwargs): self.add_reward_func(self.solved) async def solved(self, state, info, **kwargs) -> float: - if isinstance(state.get("error"), vf.InfraError): + if state.get("error") is not None: return 0.0 sandbox_client = state.get("sandbox_client") sandbox_id = state.get("sandbox_id") @@ -166,8 +166,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") class MultiSWETaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/openswe.py b/verifiers/envs/experimental/composable/tasksets/swe/openswe.py index c55978bbc..d4add88f5 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/openswe.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/openswe.py @@ -42,7 +42,7 @@ def __init__(self, taskset: "OpenSWETaskSet", **kwargs): self.add_reward_func(self.solved) async def solved(self, state, info, **kwargs) -> float: - if isinstance(state.get("error"), vf.InfraError): + if state.get("error") is not None: return 0.0 sandbox_client = state.get("sandbox_client") sandbox_id = state.get("sandbox_id") @@ -67,8 +67,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") class OpenSWETaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py b/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py index 58bb895ec..88897f122 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py @@ -140,7 +140,7 @@ def __init__(self, taskset: "R2EGymTaskSet", **kwargs): self.add_reward_func(self.solved) async def solved(self, state, info, **kwargs) -> float: - if isinstance(state.get("error"), vf.InfraError): + if state.get("error") is not None: return 0.0 sandbox_client = state.get("sandbox_client") sandbox_id = state.get("sandbox_id") @@ -165,8 +165,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") class R2EGymTaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py index fd23147bc..7f12c3a9a 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py @@ -313,7 +313,7 @@ def __init__(self, taskset: "SWEBenchTaskSet", **kwargs): self.add_reward_func(self.solved) async def solved(self, state, info, **kwargs) -> float: - if isinstance(state.get("error"), vf.InfraError): + if state.get("error") is not None: return 0.0 sandbox_client = state.get("sandbox_client") sandbox_id = state.get("sandbox_id") @@ -338,8 +338,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") class SWEBenchTaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py index a3d4d76d7..093208d20 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py @@ -134,7 +134,7 @@ def __init__(self, taskset: "SWELegoTaskSet", **kwargs): self.add_reward_func(self.solved) async def solved(self, state, info, **kwargs) -> float: - if isinstance(state.get("error"), vf.InfraError): + if state.get("error") is not None: return 0.0 sandbox_client = state.get("sandbox_client") sandbox_id = state.get("sandbox_id") @@ -158,8 +158,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") class SWELegoTaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py index b94ce47e6..870a2ed97 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py @@ -190,7 +190,7 @@ def __init__(self, taskset: "SWERebenchV2TaskSet", **kwargs): self.add_reward_func(self.solved) async def solved(self, state, info, **kwargs) -> float: - if isinstance(state.get("error"), vf.InfraError): + if state.get("error") is not None: return 0.0 sandbox_client = state.get("sandbox_client") sandbox_id = state.get("sandbox_id") @@ -214,8 +214,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") class SWERebenchV2TaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py index 944464756..d037835f7 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py @@ -145,7 +145,7 @@ def __init__(self, taskset: "SWESmithTaskSet", **kwargs): self.add_reward_func(self.solved) async def solved(self, state, info, **kwargs) -> float: - if isinstance(state.get("error"), vf.InfraError): + if state.get("error") is not None: return 0.0 sandbox_client = state.get("sandbox_client") sandbox_id = state.get("sandbox_id") @@ -169,8 +169,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") class SWESmithTaskSet(SandboxTaskSet): diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index 914ae4d69..c92526153 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -293,8 +293,19 @@ async def _call_group_reward_func( async def cleanup(self, state: State): """Run all @vf.cleanup-decorated methods on this rubric.""" + cleanup_error: Exception | None = None for handler in self._cleanup_handlers: - await self._call_cleanup_handler(handler, state) + try: + await self._call_cleanup_handler(handler, state) + except Exception as e: + if cleanup_error is None: + cleanup_error = e + self.logger.exception( + "Cleanup handler %s failed", + getattr(handler, "__name__", repr(handler)), + ) + if cleanup_error is not None: + raise cleanup_error async def _call_cleanup_handler(self, handler: Callable[..., Any], state: State): objects = self.cleanup_objects(handler, state) diff --git a/verifiers/rubrics/rubric_group.py b/verifiers/rubrics/rubric_group.py index 85807f902..08dd202a1 100644 --- a/verifiers/rubrics/rubric_group.py +++ b/verifiers/rubrics/rubric_group.py @@ -101,9 +101,23 @@ async def score_rollout(self, state: State): async def cleanup(self, state: State): """Run cleanup for all rubrics in the group.""" - await super().cleanup(state) + cleanup_error: Exception | None = None + try: + await super().cleanup(state) + except Exception as e: + cleanup_error = e for rubric in self.rubrics: - await rubric.cleanup(state) + try: + await rubric.cleanup(state) + except Exception as e: + if cleanup_error is None: + cleanup_error = e + self.logger.exception( + "Cleanup for rubric %s failed", + rubric.__class__.__name__, + ) + if cleanup_error is not None: + raise cleanup_error async def teardown(self): """Run teardown for all rubrics in the group.""" From 472fb00807b9e9e9de4616cf0dfc07ec027dd3f2 Mon Sep 17 00:00:00 2001 From: rasdani <73563550+rasdani@users.noreply.github.com> Date: Wed, 20 May 2026 19:06:44 +0200 Subject: [PATCH 2/2] Narrow SWE error guard changes --- tests/test_rollout_cleanup.py | 236 ------------------ verifiers/envs/environment.py | 142 +++-------- verifiers/envs/experimental/cli_agent_env.py | 15 +- .../experimental/composable/composable_env.py | 30 +-- .../composable/tasksets/swe/multi_swe.py | 4 +- .../composable/tasksets/swe/openswe.py | 4 +- .../composable/tasksets/swe/r2e_gym.py | 4 +- .../composable/tasksets/swe/swe_bench.py | 4 +- .../composable/tasksets/swe/swe_lego.py | 4 +- .../composable/tasksets/swe/swe_rebench_v2.py | 4 +- .../composable/tasksets/swe/swe_smith.py | 4 +- verifiers/rubrics/rubric.py | 13 +- verifiers/rubrics/rubric_group.py | 18 +- 13 files changed, 64 insertions(+), 418 deletions(-) delete mode 100644 tests/test_rollout_cleanup.py diff --git a/tests/test_rollout_cleanup.py b/tests/test_rollout_cleanup.py deleted file mode 100644 index 9d6857cdd..000000000 --- a/tests/test_rollout_cleanup.py +++ /dev/null @@ -1,236 +0,0 @@ -import asyncio -from typing import Any -from unittest.mock import AsyncMock - -import pytest -from datasets import Dataset - -import verifiers as vf -from verifiers.envs.experimental.cli_agent_env import CliAgentEnv -from verifiers.envs.experimental.composable.tasksets.swe.r2e_gym import R2ERubric -from verifiers.types import RolloutInput, SamplingArgs, State - - -def _dataset() -> Dataset: - return Dataset.from_dict( - { - "question": ["q0", "q1"], - "answer": ["a0", "a1"], - } - ) - - -def _input(example_id: int) -> RolloutInput: - return { - "prompt": [{"role": "user", "content": f"q{example_id}"}], - "answer": f"a{example_id}", - "example_id": example_id, - } - - -class RecordingRubric(vf.Rubric): - def __init__( - self, - *, - score_rollout_error: Exception | None = None, - score_group_error: Exception | None = None, - ): - super().__init__() - self.cleaned: list[int] = [] - self.score_rollout_error = score_rollout_error - self.score_group_error = score_group_error - - async def score_rollout(self, state: State): - if self.score_rollout_error is not None: - raise self.score_rollout_error - state["reward"] = 1.0 - state["metrics"] = {} - - async def score_group(self, states: list[State]): - if self.score_group_error is not None: - raise self.score_group_error - for state in states: - state["reward"] = 1.0 - state["metrics"] = {} - - async def cleanup(self, state: State): - self.cleaned.append(state["example_id"]) - - -class StaticRolloutEnv(vf.Environment): - async def rollout( - self, - input: RolloutInput, - client: vf.Client, - model: str, - sampling_args: SamplingArgs | None = None, - ) -> State: - state = await self.init_state(input, client, model, sampling_args) - state["sandbox_id"] = f"sb-{state['example_id']}" - return state - - -def _env(rubric: vf.Rubric) -> StaticRolloutEnv: - return StaticRolloutEnv(dataset=_dataset(), parser=vf.Parser(), rubric=rubric) - - -@pytest.mark.asyncio -async def test_run_rollout_state_cleans_up_when_scoring_raises(mock_client): - rubric = RecordingRubric(score_rollout_error=RuntimeError("score failed")) - env = _env(rubric) - - with pytest.raises(RuntimeError, match="score failed"): - await env._run_rollout_state(_input(0), mock_client, "test-model", {}) - - assert rubric.cleaned == [0] - - -@pytest.mark.asyncio -async def test_run_group_states_cleans_completed_states_when_gather_raises( - mock_client, -): - first_rollout_finished = asyncio.Event() - - class PartiallyFailingEnv(StaticRolloutEnv): - async def rollout( - self, - input: RolloutInput, - client: vf.Client, - model: str, - sampling_args: SamplingArgs | None = None, - ) -> State: - if input["example_id"] == 1: - await first_rollout_finished.wait() - raise RuntimeError("rollout failed") - state = await super().rollout(input, client, model, sampling_args) - first_rollout_finished.set() - return state - - rubric = RecordingRubric() - env = PartiallyFailingEnv(dataset=_dataset(), parser=vf.Parser(), rubric=rubric) - - with pytest.raises(RuntimeError, match="rollout failed"): - await env._run_group_states( - [_input(0), _input(1)], - mock_client, - "test-model", - {}, - ) - - assert rubric.cleaned == [0] - - -@pytest.mark.asyncio -async def test_run_group_states_cleans_all_states_when_group_scoring_raises( - mock_client, -): - rubric = RecordingRubric(score_group_error=RuntimeError("group score failed")) - env = _env(rubric) - - with pytest.raises(RuntimeError, match="group score failed"): - await env._run_group_states( - [_input(0), _input(1)], - mock_client, - "test-model", - {}, - ) - - assert rubric.cleaned == [0, 1] - - -@pytest.mark.asyncio -async def test_environment_cleanup_failure_does_not_skip_later_handler(mock_client): - class FailingCleanupEnv(StaticRolloutEnv): - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - self.destroyed = False - - @vf.cleanup(priority=1) - async def failing_cleanup(self, state: State): - raise RuntimeError("early cleanup failed") - - @vf.cleanup(priority=0) - async def destroy_sandbox(self, state: State): - self.destroyed = True - - env = FailingCleanupEnv(dataset=_dataset(), parser=vf.Parser(), rubric=vf.Rubric()) - state = await env.init_state(_input(0), mock_client, "test-model") - - with pytest.raises(RuntimeError, match="early cleanup failed"): - await env.cleanup(state) - - assert env.destroyed is True - - -@pytest.mark.asyncio -async def test_rubric_group_cleanup_failure_does_not_skip_later_rubric(): - class FailingRubric(vf.Rubric): - async def cleanup(self, state: State): - raise RuntimeError("rubric cleanup failed") - - class DestroyingRubric(vf.Rubric): - async def cleanup(self, state: State): - state["destroyed"] = True - - state: State = vf.State(input={}) - rubric = vf.RubricGroup([FailingRubric(), DestroyingRubric()]) - - with pytest.raises(RuntimeError, match="rubric cleanup failed"): - await rubric.cleanup(state) - - assert state["destroyed"] is True - - -@pytest.mark.asyncio -async def test_cli_agent_destroy_sandbox_deletes_when_post_rollout_fails(): - class FailingPostRolloutEnv(CliAgentEnv): - async def post_rollout(self, state: State): - raise RuntimeError("post rollout failed") - - env = FailingPostRolloutEnv( - run_command="echo done", - dataset=_dataset(), - parser=vf.Parser(), - rubric=vf.Rubric(), - keep_sandbox_for_scoring=True, - ) - env.delete_sandbox = AsyncMock() # type: ignore[method-assign] - state: State = vf.State(input={}) - state.update({"is_completed": True, "sandbox_id": "sb-post-rollout-failed"}) - - try: - with pytest.raises(RuntimeError, match="post rollout failed"): - await env.destroy_sandbox(state) - env.delete_sandbox.assert_awaited_once_with("sb-post-rollout-failed") - finally: - env.teardown_sandbox_client() - - -@pytest.mark.asyncio -async def test_swe_rubric_model_error_skips_sandbox_scoring(): - class StubTaskSet: - def __init__(self): - self.ran_tests = False - - async def _run_tests(self, *args: Any, **kwargs: Any) -> str: - self.ran_tests = True - return "PASS" - - def _calculate_reward(self, test_output: str, info: dict[str, Any]) -> float: - return 1.0 - - taskset = StubTaskSet() - rubric = R2ERubric(taskset) # type: ignore[arg-type] - state: State = vf.State(input={}) - state.update( - { - "error": vf.ModelError("No available workers"), - "sandbox_client": object(), - "sandbox_id": "sb-leaked-without-short-circuit", - } - ) - - reward = await rubric.solved(state, info={}) - - assert reward == 0.0 - assert taskset.ran_tests is False diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index d506ee994..ed379f086 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -633,40 +633,14 @@ async def cleanup( """ Finalize rollout state and clean up rollout-local resources. """ - cleanup_error: Exception | None = None for handler in self._cleanup_handlers: - try: - await maybe_call_with_named_args( - handler, - task=task, - state=state, - env=self, - resources=resources, - ) - except Exception as e: - if cleanup_error is None: - cleanup_error = e - self.logger.exception( - "Cleanup handler %s failed", - getattr(handler, "__name__", repr(handler)), - ) - if cleanup_error is not None: - raise cleanup_error - - async def _cleanup_rollout_states(self, states: list[State]) -> None: - cleanup_error: Exception | None = None - for state in states: - try: - await self.rubric.cleanup(state) - except Exception as e: - if cleanup_error is None: - cleanup_error = e - self.logger.exception( - "Rubric cleanup failed for rollout example_id=%s", - state.get("example_id"), - ) - if cleanup_error is not None: - raise cleanup_error + await maybe_call_with_named_args( + handler, + task=task, + state=state, + env=self, + resources=resources, + ) async def _teardown(self): """ @@ -717,26 +691,14 @@ async def _run_rollout_state( sampling_args, ) - primary_error: BaseException | None = None - try: - state["timing"].scoring.start = time.time() - if self.score_rollouts: - await self.rubric.score_rollout(state) - else: - await self.rubric.dummy_score_rollout(state) - state["timing"].scoring.end = time.time() - except BaseException as e: - primary_error = e - raise - finally: - try: - await self._cleanup_rollout_states([state]) - except Exception: - if primary_error is None: - raise - self.logger.exception( - "Rubric cleanup failed after rollout scoring failed" - ) + state["timing"].scoring.start = time.time() + if self.score_rollouts: + await self.rubric.score_rollout(state) + else: + await self.rubric.dummy_score_rollout(state) + state["timing"].scoring.end = time.time() + + await self.rubric.cleanup(state) return state async def _run_group_states( @@ -747,65 +709,31 @@ async def _run_group_states( sampling_args: SamplingArgs, ) -> list[State]: rollout_tasks = [ - asyncio.create_task( - self.rollout( - input, - client, - model, - sampling_args, - ) + self.rollout( + input, + client, + model, + sampling_args, ) for input in group_inputs ] - group_states: list[State] = [] - primary_error: BaseException | None = None - try: - group_states = await asyncio.gather(*rollout_tasks) + group_states = await asyncio.gather(*rollout_tasks) - start_scoring = time.time() - for state in group_states: - state["timing"].scoring.start = start_scoring - if self.score_rollouts: - await self.rubric.score_group(group_states) - else: - await self.rubric.dummy_score_group(group_states) - end_scoring = time.time() - for state in group_states: - state["timing"].scoring.end = end_scoring - - return group_states - except BaseException as e: - primary_error = e - raise - finally: - pending_tasks = [task for task in rollout_tasks if not task.done()] - if pending_tasks: - for task in pending_tasks: - task.cancel() - await asyncio.gather(*pending_tasks, return_exceptions=True) - - cleanup_states: list[State] = [] - seen_state_ids: set[int] = set() - for state in group_states: - cleanup_states.append(state) - seen_state_ids.add(id(state)) - for task in rollout_tasks: - if not task.done() or task.cancelled(): - continue - try: - state = task.result() - except BaseException: - continue - if id(state) not in seen_state_ids: - cleanup_states.append(state) - seen_state_ids.add(id(state)) + start_scoring = time.time() + for state in group_states: + state["timing"].scoring.start = start_scoring + if self.score_rollouts: + await self.rubric.score_group(group_states) + else: + await self.rubric.dummy_score_group(group_states) + end_scoring = time.time() + for state in group_states: + state["timing"].scoring.end = end_scoring - try: - await self._cleanup_rollout_states(cleanup_states) - except Exception: - if primary_error is None: - raise - self.logger.exception("Rubric cleanup failed after group run failed") + for state in group_states: + await self.rubric.cleanup(state) + + return group_states @final async def run_rollout( diff --git a/verifiers/envs/experimental/cli_agent_env.py b/verifiers/envs/experimental/cli_agent_env.py index 99f24581a..5b2a89324 100644 --- a/verifiers/envs/experimental/cli_agent_env.py +++ b/verifiers/envs/experimental/cli_agent_env.py @@ -758,25 +758,14 @@ async def destroy_sandbox(self, state: State): the sandbox is always deleted since scoring will not happen. """ completed = state.get("is_completed", False) - post_rollout_error: Exception | None = None if completed: - try: - await self.post_rollout(state) - except Exception as e: - post_rollout_error = e - self.logger.exception("Post-rollout cleanup failed") + await self.post_rollout(state) sandbox_id = state.get("sandbox_id") if sandbox_id: - if ( - self.keep_sandbox_for_scoring - and completed - and post_rollout_error is None - ): + if self.keep_sandbox_for_scoring and completed: self.deregister_sandbox(sandbox_id) else: await self.delete_sandbox(sandbox_id) - if post_rollout_error is not None: - raise post_rollout_error async def env_response( self, messages: Messages, state: State, **kwargs diff --git a/verifiers/envs/experimental/composable/composable_env.py b/verifiers/envs/experimental/composable/composable_env.py index ee92574c5..1a519ec21 100644 --- a/verifiers/envs/experimental/composable/composable_env.py +++ b/verifiers/envs/experimental/composable/composable_env.py @@ -89,28 +89,18 @@ def _upload_tar_filter(info: tarfile.TarInfo) -> tarfile.TarInfo | None: class HarnessMetricsRubricGroup(vf.RubricGroup): async def cleanup(self, state: State) -> None: - cleanup_error: Exception | None = None for rubric in self.rubrics: - try: - await rubric.cleanup(state) - except Exception as e: - if cleanup_error is None: - cleanup_error = e - self.logger.exception( - "Cleanup for rubric %s failed", - rubric.__class__.__name__, - ) + await rubric.cleanup(state) harness_metrics = state.get("_harness_metrics") - if isinstance(harness_metrics, dict): - state_metrics = state.get("metrics") - if not isinstance(state_metrics, dict): - state_metrics = {} - state["metrics"] = state_metrics - for key, value in harness_metrics.items(): - if isinstance(key, str) and isinstance(value, (int, float)): - state_metrics[key] = float(value) - if cleanup_error is not None: - raise cleanup_error + if not isinstance(harness_metrics, dict): + return + state_metrics = state.get("metrics") + if not isinstance(state_metrics, dict): + state_metrics = {} + state["metrics"] = state_metrics + for key, value in harness_metrics.items(): + if isinstance(key, str) and isinstance(value, (int, float)): + state_metrics[key] = float(value) class ComposableEnv(CliAgentEnv): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py b/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py index d2cafee1a..d9631099e 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/multi_swe.py @@ -166,8 +166,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception as e: - logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") + except Exception: + pass class MultiSWETaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/openswe.py b/verifiers/envs/experimental/composable/tasksets/swe/openswe.py index d4add88f5..251e4e869 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/openswe.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/openswe.py @@ -67,8 +67,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception as e: - logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") + except Exception: + pass class OpenSWETaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py b/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py index 88897f122..310948654 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/r2e_gym.py @@ -165,8 +165,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception as e: - logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") + except Exception: + pass class R2EGymTaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py index 7f12c3a9a..e7c1d8259 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_bench.py @@ -338,8 +338,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception as e: - logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") + except Exception: + pass class SWEBenchTaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py index 093208d20..13cecd871 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_lego.py @@ -158,8 +158,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception as e: - logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") + except Exception: + pass class SWELegoTaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py index 870a2ed97..bdd61a2ce 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_rebench_v2.py @@ -214,8 +214,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception as e: - logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") + except Exception: + pass class SWERebenchV2TaskSet(SandboxTaskSet): diff --git a/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py b/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py index d037835f7..7c99228a6 100644 --- a/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py +++ b/verifiers/envs/experimental/composable/tasksets/swe/swe_smith.py @@ -169,8 +169,8 @@ async def cleanup_sandbox(self, state: vf.State) -> None: if sandbox_client and sandbox_id: try: await sandbox_client.delete(sandbox_id) - except Exception as e: - logger.warning(f"Failed to delete sandbox {sandbox_id}: {e}") + except Exception: + pass class SWESmithTaskSet(SandboxTaskSet): diff --git a/verifiers/rubrics/rubric.py b/verifiers/rubrics/rubric.py index c92526153..914ae4d69 100644 --- a/verifiers/rubrics/rubric.py +++ b/verifiers/rubrics/rubric.py @@ -293,19 +293,8 @@ async def _call_group_reward_func( async def cleanup(self, state: State): """Run all @vf.cleanup-decorated methods on this rubric.""" - cleanup_error: Exception | None = None for handler in self._cleanup_handlers: - try: - await self._call_cleanup_handler(handler, state) - except Exception as e: - if cleanup_error is None: - cleanup_error = e - self.logger.exception( - "Cleanup handler %s failed", - getattr(handler, "__name__", repr(handler)), - ) - if cleanup_error is not None: - raise cleanup_error + await self._call_cleanup_handler(handler, state) async def _call_cleanup_handler(self, handler: Callable[..., Any], state: State): objects = self.cleanup_objects(handler, state) diff --git a/verifiers/rubrics/rubric_group.py b/verifiers/rubrics/rubric_group.py index 08dd202a1..85807f902 100644 --- a/verifiers/rubrics/rubric_group.py +++ b/verifiers/rubrics/rubric_group.py @@ -101,23 +101,9 @@ async def score_rollout(self, state: State): async def cleanup(self, state: State): """Run cleanup for all rubrics in the group.""" - cleanup_error: Exception | None = None - try: - await super().cleanup(state) - except Exception as e: - cleanup_error = e + await super().cleanup(state) for rubric in self.rubrics: - try: - await rubric.cleanup(state) - except Exception as e: - if cleanup_error is None: - cleanup_error = e - self.logger.exception( - "Cleanup for rubric %s failed", - rubric.__class__.__name__, - ) - if cleanup_error is not None: - raise cleanup_error + await rubric.cleanup(state) async def teardown(self): """Run teardown for all rubrics in the group."""