diff --git a/envs/terminus_env/README.md b/envs/terminus_env/README.md index 246ba709f..b8491fa10 100644 --- a/envs/terminus_env/README.md +++ b/envs/terminus_env/README.md @@ -10,17 +10,17 @@ base_path: /web tags: - openenv - terminus - - e2b + - hf-sandbox - coding -short_description: Single-tool E2B-backed coding environment +short_description: Single-tool coding environment --- # Terminus Environment -`terminus_env` is a single-tool coding environment backed by E2B Code -Interpreter. Each OpenEnv episode creates a fresh E2B sandbox, runs optional -setup commands, keeps shell state and files isolated for that episode, and runs -optional verify commands when the agent submits a final answer. +`terminus_env` is a single-tool coding environment. Each OpenEnv episode +creates a fresh sandbox, runs optional setup commands, keeps shell state and +files isolated for that episode, and runs optional verify commands when the +agent submits a final answer. The tool shape follows the Terminus-style "one tool" idea: agents do their work through a single terminal entrypoint rather than a notebook/toolbox surface. @@ -48,7 +48,7 @@ with TerminusEnv(base_url="http://localhost:8000").sync() as env: ```bash cd envs/terminus_env -E2B_API_KEY=e2b_... uv run --project . server +TERMINUS_SANDBOX_BACKEND=local uv run --project . server ``` The API and custom terminal web UI are served on port 8000. The UI is mounted @@ -59,14 +59,26 @@ at `/web`. ```bash cd envs/terminus_env openenv build -t terminus-env -docker run -p 8000:8000 -e E2B_API_KEY=e2b_... terminus-env +docker run -p 8000:8000 -e HF_TOKEN=hf_... terminus-env ``` ## Configuration -- `E2B_API_KEY`: required when resetting an episode. +- `TERMINUS_SANDBOX_BACKEND`: `local` for the lightweight cluster smoke + backend, or `hf` for `hf-sandbox`. Defaults to `hf`. +- `HF_TOKEN`: required by the optional `hf-sandbox` backend to launch + Hugging Face Jobs. +- `HF_SANDBOX_IMAGE`: sandbox image. Defaults to `python:3.12`. 
+- `HF_SANDBOX_FLAVOR`: Hugging Face Jobs flavor. Defaults to `cpu-basic`. +- `HF_SANDBOX_TIMEOUT`: Hugging Face Jobs timeout. Defaults to `1h`. +- `HF_SANDBOX_FORWARD_HF_TOKEN`: forward `HF_TOKEN` into the sandbox. Defaults + to `false`. - `MAX_CONCURRENT_ENVS`: maximum concurrent WebSocket sessions. Defaults to `4`. +The local backend requires `bwrap` on the server node and is intended for simple +cluster smoke tasks. Install the `hf` extra and use the `hf` backend for +stronger remote sandboxing. + ## Setup and Verify Commands `reset()` accepts either `setup` / `verify` or `setup_scripts` / diff --git a/envs/terminus_env/__init__.py b/envs/terminus_env/__init__.py index 9c888811a..22d39d554 100644 --- a/envs/terminus_env/__init__.py +++ b/envs/terminus_env/__init__.py @@ -9,12 +9,16 @@ from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction from .client import TerminusEnv +from .harness import TerminusSessionFactory, build_terminal_tool_call, terminus_reward from .models import CommandResult, TerminusState __all__ = [ "TerminusEnv", + "TerminusSessionFactory", "TerminusState", "CommandResult", "CallToolAction", "ListToolsAction", + "build_terminal_tool_call", + "terminus_reward", ] diff --git a/envs/terminus_env/client.py b/envs/terminus_env/client.py index f0ad00798..3d7bc35ec 100644 --- a/envs/terminus_env/client.py +++ b/envs/terminus_env/client.py @@ -6,10 +6,38 @@ """Client for the Terminus environment.""" +from typing import Any + from openenv.core.mcp_client import MCPToolClient +from .models import CommandResult, TerminusState + class TerminusEnv(MCPToolClient): """MCP client for calling the Terminus single-rollout tool.""" - pass + def _parse_state(self, payload: dict[str, Any]) -> TerminusState: + """Convert server state payloads to the Terminus state model.""" + + def command_results(name: str) -> list[CommandResult]: + values = payload.get(name, []) + if not isinstance(values, list): + return [] + return [ + value if 
isinstance(value, CommandResult) else CommandResult(**value) + for value in values + if isinstance(value, dict) or isinstance(value, CommandResult) + ] + + return TerminusState( + episode_id=payload.get("episode_id"), + step_count=payload.get("step_count", 0), + sandbox_id=payload.get("sandbox_id"), + setup_results=command_results("setup_results"), + verify_commands=list(payload.get("verify_commands", []) or []), + verify_results=command_results("verify_results"), + commands=command_results("commands"), + submitted_answer=payload.get("submitted_answer"), + last_reward=payload.get("last_reward"), + last_error=payload.get("last_error"), + ) diff --git a/envs/terminus_env/harness.py b/envs/terminus_env/harness.py new file mode 100644 index 000000000..0cd3b97c1 --- /dev/null +++ b/envs/terminus_env/harness.py @@ -0,0 +1,463 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Harness-oriented Terminus session adapter.""" + +from __future__ import annotations + +import ast +import json +import re +from typing import Any, Callable + +from openenv.core.env_server.mcp_types import CallToolAction, Tool +from openenv.core.harness import ( + ResourceSessionFactory, + StepEnvSessionAdapter, + ToolResult, + VerifyResult, +) + +from .client import TerminusEnv + +REWARD_RE = re.compile(r"reward=([+-]?(?:\d+(?:\.\d*)?|\.\d+))") + +_TERMINUS_TOOLS: list[Tool] = [ + Tool( + name="terminal", + description=( + "Run a shell command in the Terminus sandbox, or submit final_answer " + "to trigger verification." 
+ ), + input_schema={ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Shell command to run in the sandbox.", + }, + "final_answer": { + "type": "string", + "description": "Final answer to submit when the task is complete.", + }, + }, + "additionalProperties": False, + }, + ) +] + + +def _task_field(task: Any, *names: str, default: Any = None) -> Any: + if not isinstance(task, dict): + return default + for name in names: + value = task.get(name) + if value is not None: + return value + return default + + +def _coerce_commands(value: Any) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [value] if value.strip() else [] + return [str(item) for item in value if str(item).strip()] + + +def _format_initial_prompt(result: Any, task: Any) -> str: + if isinstance(task, str): + instruction = task + setup_commands: list[str] = [] + verify_commands: list[str] = [] + elif isinstance(task, list): + user_messages = [ + item.get("content") + for item in task + if isinstance(item, dict) and item.get("role") == "user" + ] + instruction = str(user_messages[-1] if user_messages else task) + setup_commands = [] + verify_commands = [] + elif isinstance(task, dict): + instruction = str( + _task_field(task, "instruction", "prompt", "question", "task", default="") + ) + setup_commands = _coerce_commands(_task_field(task, "setup", "setup_scripts")) + verify_commands = _coerce_commands(_task_field(task, "verify", "verify_scripts")) + else: + instruction = str(task or "") + setup_commands = [] + verify_commands = [] + + metadata = getattr(result.observation, "metadata", {}) or {} + verify_commands = _coerce_commands( + metadata.get("verify_commands") or verify_commands + ) + + parts = [] + if instruction: + parts.append(f"Task:\n{instruction}") + else: + parts.append("Task:\nUse the terminal tool to solve the current task.") + + reset_message = metadata.get("message") + if reset_message: + 
parts.append(f"Environment:\n{reset_message}") + + if setup_commands: + parts.append( + "Setup commands have already run:\n" + + "\n".join(f"- {command}" for command in setup_commands) + ) + if verify_commands: + parts.append( + "Verification commands will run after final_answer:\n" + + "\n".join(f"- {command}" for command in verify_commands) + ) + + parts.append( + 'Use {"command": "..."} to inspect and modify the sandbox. ' + 'When finished, use {"final_answer": "..."} exactly once so ' + "verification runs and emits the environment reward." + ) + return "\n\n".join(parts) + + +def _extract_tool_output(observation: Any) -> Any: + result = getattr(observation, "result", None) + if result is None: + return None + if hasattr(result, "data"): + return result.data + if isinstance(result, dict): + if "data" in result: + return result["data"] + content = result.get("content") + if isinstance(content, list): + texts = [ + str(item.get("text")) + for item in content + if isinstance(item, dict) and item.get("text") is not None + ] + if texts: + return "\n".join(texts) + return result + content = getattr(result, "content", None) + if isinstance(content, list): + texts = [ + getattr(item, "text", None) + for item in content + if getattr(item, "text", None) is not None + ] + if texts: + return "\n".join(texts) + return result + + +def _tool_error_message(observation: Any) -> str | None: + error = getattr(observation, "error", None) + if error is None: + return None + message = getattr(error, "message", None) + if message is not None: + return str(message) + if isinstance(error, dict): + return str(error.get("message") or error) + return str(error) + + +def _state_to_data(state: Any) -> Any: + if state is None: + return None + if hasattr(state, "model_dump"): + return state.model_dump() + return state + + +def _build_tool_result( + tool_name: str, + arguments: dict[str, Any], + result: Any, + state: Any, +) -> ToolResult: + output = _extract_tool_output(result.observation) + 
error = _tool_error_message(result.observation) + data = { + "tool_name": tool_name, + "arguments": dict(arguments), + "output": output, + "reward": result.reward, + "done": result.done, + } + if error: + data["error"] = error + + return ToolResult( + data=data, + done=bool(result.done), + error=error, + metadata={ + "reward": result.reward, + "state": _state_to_data(state), + }, + ) + + +def _build_verify( + transcript: list[dict[str, Any]], + final_state: Any | None, + last_result: Any | None, + state: Any, +) -> VerifyResult: + reward = None if last_result is None else last_result.reward + done = False if last_result is None else bool(last_result.done) + state_data = _state_to_data(state) + metrics = { + "done": done, + "step_count": getattr(state, "step_count", 0), + "commands": len(getattr(state, "commands", []) or []), + "verify_commands": len(getattr(state, "verify_commands", []) or []), + "setup_commands": len(getattr(state, "setup_results", []) or []), + "submitted_answer": getattr(state, "submitted_answer", None) is not None, + "sandbox_id": getattr(state, "sandbox_id", None), + } + if state is None and last_result is not None: + metrics["step_count"] = len(transcript) + return VerifyResult( + env_reward=reward, + done=done, + metrics=metrics, + artifacts={ + "final_state": state_data, + "final_rollout": final_state, + "transcript_length": len(transcript), + }, + ) + + +def _build_reset_kwargs( + task: Any, + default_setup: list[str], + default_verify: list[str], + default_sandbox: dict[str, Any], +) -> dict[str, Any]: + reset_kwargs: dict[str, Any] = dict(default_sandbox) + setup = list(default_setup) + verify = list(default_verify) + if isinstance(task, dict): + setup = _coerce_commands(_task_field(task, "setup", "setup_scripts", default=setup)) + verify = _coerce_commands( + _task_field(task, "verify", "verify_scripts", default=verify) + ) + for key in ( + "sandbox_image", + "sandbox_flavor", + "sandbox_timeout", + "hf_sandbox_image", + 
"hf_sandbox_flavor", + "hf_sandbox_timeout", + "forward_hf_token", + "sandbox_backend", + "sandbox_root", + ): + if key in task: + reset_kwargs[key] = task[key] + + if setup: + reset_kwargs["setup"] = setup + if verify: + reset_kwargs["verify"] = verify + return reset_kwargs + + +class TerminusSessionFactory(ResourceSessionFactory): + """Create Terminus-backed resource sessions for harness rollouts.""" + + def __init__( + self, + client_factory: Callable[[], TerminusEnv], + *, + default_setup: list[str] | None = None, + default_verify: list[str] | None = None, + sandbox: dict[str, Any] | None = None, + ): + self._client_factory = client_factory + self._default_setup = list(default_setup or []) + self._default_verify = list(default_verify or []) + self._sandbox = dict(sandbox or {}) + + def create( + self, + task: Any = None, + seed: int | None = None, + episode_id: str | None = None, + ) -> StepEnvSessionAdapter: + reset_kwargs = _build_reset_kwargs( + task, + self._default_setup, + self._default_verify, + self._sandbox, + ) + + return StepEnvSessionAdapter( + client=self._client_factory(), + task=task, + seed=seed, + episode_id=episode_id, + tool_specs=list(_TERMINUS_TOOLS), + action_builder=lambda name, arguments: CallToolAction( + tool_name=name, + arguments=dict(arguments), + ), + initial_messages_builder=lambda result, current_task: [ + { + "role": "user", + "content": _format_initial_prompt(result, current_task), + } + ], + tool_result_builder=_build_tool_result, + verify_builder=_build_verify, + reset_kwargs=reset_kwargs, + ) + + +def terminus_reward(completions=None, **kwargs) -> list[float]: + """Extract Terminus rewards from TRL tool messages.""" + + del kwargs + rewards = [] + for completion in completions or []: + reward = 0.0 + for message in completion if isinstance(completion, list) else []: + if not isinstance(message, dict) or message.get("role") != "tool": + continue + parsed = _parse_tool_reward(str(message.get("content", ""))) + if parsed is not 
None: + reward = parsed + rewards.append(reward) + return rewards + + +def _parse_tool_reward(content: str) -> float | None: + try: + payload = json.loads(content) + except json.JSONDecodeError: + payload = None + if isinstance(payload, dict) and payload.get("reward") is not None: + try: + return float(payload["reward"]) + except (TypeError, ValueError): + return None + match = REWARD_RE.search(content) + if match is None: + return None + return float(match.group(1)) + + +_TERMINAL_CALL_RE = re.compile(r"terminal\s*\((?P<body>.*?)\)", re.DOTALL) + + +def build_terminal_tool_call(response_text: str, *, call_id: str = "terminal-0"): + """Parse a terminal call from model text. + + The preferred format is one JSON object containing ``command`` or + ``final_answer``. The parser also accepts Pi-style ``terminal(...)`` text + because small policy models often imitate that syntax before they learn + structured tool calls. Invalid text falls back to a shell command so the + environment, not this parser, decides whether a rollout earns reward. 
+ """ + + from openenv.core.llm_client import ToolCall + + text = _strip_code_fence(response_text.strip()) + payload = _parse_terminal_json(text) + if payload is None: + payload = _parse_terminal_expression(text) + if payload is None: + payload = {"command": response_text} + + arguments = { + key: str(payload[key]) + for key in ("command", "final_answer") + if payload.get(key) is not None + } + if not arguments: + arguments = {"command": ""} + if arguments.get("command") and arguments.get("final_answer"): + arguments = {"command": arguments["command"]} + return ToolCall(id=call_id, name="terminal", args=arguments) + + +def _strip_code_fence(text: str) -> str: + if not text.startswith("```"): + return text + stripped = text.strip("`").strip() + if stripped.startswith("json"): + return stripped[4:].strip() + return stripped + + +def _parse_terminal_json(text: str) -> dict[str, Any] | None: + decoder = json.JSONDecoder() + for start, character in enumerate(text): + if character != "{": + continue + try: + payload, _ = decoder.raw_decode(text[start:]) + except json.JSONDecodeError: + continue + normalized = _normalize_terminal_payload(payload) + if normalized is not None: + return normalized + return None + + +def _normalize_terminal_payload(payload: Any) -> dict[str, Any] | None: + if not isinstance(payload, dict): + return None + if "arguments" in payload: + arguments = payload["arguments"] + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + return None + return _normalize_terminal_payload(arguments) + if any(payload.get(key) is not None for key in ("command", "final_answer")): + return payload + return None + + +def _parse_terminal_expression(text: str) -> dict[str, Any] | None: + match = _TERMINAL_CALL_RE.search(text) + if not match: + return None + try: + expression = ast.parse(f"terminal({match.group('body')})", mode="eval") + except SyntaxError: + return None + if not isinstance(expression.body, ast.Call): 
+ return None + payload: dict[str, Any] = {} + for keyword in expression.body.keywords: + if keyword.arg not in {"command", "final_answer"}: + continue + value = keyword.value + try: + payload[keyword.arg] = ast.literal_eval(value) + except (ValueError, SyntaxError): + if isinstance(value, ast.Name): + payload[keyword.arg] = value.id + return payload or None + + +__all__ = [ + "REWARD_RE", + "TerminusSessionFactory", + "build_terminal_tool_call", + "terminus_reward", +] diff --git a/envs/terminus_env/models.py b/envs/terminus_env/models.py index 3742f8a02..bbd2eedbe 100644 --- a/envs/terminus_env/models.py +++ b/envs/terminus_env/models.py @@ -13,7 +13,7 @@ class CommandResult(BaseModel): - """Outcome of one shell command run inside the E2B sandbox.""" + """Outcome of one shell command run inside the HF sandbox.""" command: str output: str = "" diff --git a/envs/terminus_env/pyproject.toml b/envs/terminus_env/pyproject.toml index 35ddee709..87e3a844c 100644 --- a/envs/terminus_env/pyproject.toml +++ b/envs/terminus_env/pyproject.toml @@ -11,11 +11,10 @@ build-backend = "setuptools.build_meta" [project] name = "openenv-terminus-env" version = "0.1.0" -description = "Single-tool E2B-backed coding environment for OpenEnv" +description = "Single-tool coding environment for OpenEnv" requires-python = ">=3.10" dependencies = [ "openenv-core[core]>=0.2.2", - "e2b-code-interpreter>=1.0.0", "fastapi>=0.115.0", "fastmcp>=3.0.0", "gradio>=4.0.0", @@ -25,6 +24,9 @@ dependencies = [ ] [project.optional-dependencies] +hf = [ + "hf-sandbox>=0.1.1", +] dev = [ "pytest>=8.0.0", "pytest-cov>=4.0.0", diff --git a/envs/terminus_env/server/e2b_sandbox.py b/envs/terminus_env/server/e2b_sandbox.py deleted file mode 100644 index 28046973c..000000000 --- a/envs/terminus_env/server/e2b_sandbox.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Small E2B Code Interpreter wrapper for terminal-style environments.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -_E2B_IMPORT_ERROR: ImportError | None = None - -try: - from e2b_code_interpreter import Sandbox -except ImportError as _e2b_import_error: # pragma: no cover - _E2B_IMPORT_ERROR = _e2b_import_error - Sandbox = None # type: ignore[assignment] - - -@dataclass -class ShellResult: - """Normalized result from a command executed in E2B.""" - - stdout: str - stderr: str - error: str | None - success: bool - - -class E2BSandbox: - """Manages one E2B sandbox for one OpenEnv episode.""" - - def __init__(self, api_key: str): - if Sandbox is None: - raise ImportError( - "e2b-code-interpreter is not installed. Install the " - "terminus_env package dependencies to use E2BSandbox. " - f"Original import error: {_E2B_IMPORT_ERROR}" - ) - self._sbx = Sandbox.create(api_key=api_key) - self.sandbox_id: str = self._sbx.sandbox_id - - def run_shell(self, command: str, timeout_s: int = 120) -> ShellResult: - shell_code = ( - "import subprocess, sys\n" - f"_result = subprocess.run({command!r}, shell=True, capture_output=True, text=True, timeout={timeout_s})\n" - "print(_result.stdout, end='')\n" - "if _result.stderr: print(_result.stderr, end='', file=sys.stderr)\n" - "if _result.returncode != 0:\n" - " raise SystemExit(_result.returncode)\n" - ) - execution = self._sbx.run_code(shell_code) - return _normalize(execution) - - def kill(self) -> None: - try: - self._sbx.kill() - except Exception: - try: - self._sbx.close() - except Exception: - pass - - -def _normalize(execution: Any) -> ShellResult: - stdout = "\n".join(execution.logs.stdout) if execution.logs.stdout else "" - stderr = "\n".join(execution.logs.stderr) if execution.logs.stderr else "" - error = None - if 
execution.error: - error = ( - f"{execution.error.name}: {execution.error.value}\n" - f"{execution.error.traceback}" - ) - return ShellResult( - stdout=stdout, - stderr=stderr, - error=error, - success=execution.error is None, - ) diff --git a/envs/terminus_env/server/gradio_ui.py b/envs/terminus_env/server/gradio_ui.py index e7d5d74e7..1493dc3ae 100644 --- a/envs/terminus_env/server/gradio_ui.py +++ b/envs/terminus_env/server/gradio_ui.py @@ -87,7 +87,7 @@ _EMPTY_TERMINAL = ( "
" - "Reset the environment to create an E2B sandbox, then run commands." + "Reset the environment to create an HF sandbox, then run commands." "
" ) @@ -202,7 +202,7 @@ def current_state() -> dict[str, Any]: with gr.Blocks(title=f"{title} - Terminal") as demo: gr.Markdown(f"# {title}") gr.Markdown( - "Single-tool terminal environment backed by an E2B sandbox. " + "Single-tool terminal environment backed by an HF sandbox. " "Reset creates a fresh session; commands run through `terminal(command=...)`." ) @@ -236,7 +236,7 @@ def current_state() -> dict[str, Any]: label="Terminal command", value="pwd && ls -la", lines=6, - placeholder="Run shell commands in the E2B sandbox", + placeholder="Run shell commands in the HF sandbox", ) run_btn = gr.Button("Run command", variant="primary") @@ -280,7 +280,7 @@ async def on_close(): return ( _closed_terminal_html(), {}, - "Session closed. The E2B sandbox was released.", + "Session closed. The HF sandbox was released.", ) async def on_run(command: str): diff --git a/envs/terminus_env/server/hf_sandbox.py b/envs/terminus_env/server/hf_sandbox.py new file mode 100644 index 000000000..d09cdc00b --- /dev/null +++ b/envs/terminus_env/server/hf_sandbox.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+"""Small hf-sandbox wrapper for terminal-style environments."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+
+_HF_SANDBOX_IMPORT_ERROR: ImportError | None = None
+
+try:
+    from hf_sandbox import Sandbox
+except ImportError as _hf_sandbox_import_error:  # pragma: no cover
+    _HF_SANDBOX_IMPORT_ERROR = _hf_sandbox_import_error
+    Sandbox = None  # type: ignore[assignment]
+
+
+DEFAULT_IMAGE = "python:3.12"  # used when no argument and no HF_SANDBOX_IMAGE is given
+DEFAULT_FLAVOR = "cpu-basic"  # used when no argument and no HF_SANDBOX_FLAVOR is given
+DEFAULT_TIMEOUT = "1h"  # used when no argument and no HF_SANDBOX_TIMEOUT is given
+
+
+@dataclass
+class ShellResult:
+    """Normalized result from a command executed in an HF sandbox."""
+
+    stdout: str
+    stderr: str
+    error: str | None  # "exit code N" when the command failed, otherwise None
+    success: bool  # True only when the command's return code was 0
+
+
+class HFSandbox:
+    """Manages one hf-sandbox job for one OpenEnv episode; explicit arguments override env vars."""
+
+    def __init__(
+        self,
+        *,
+        image: str | None = None,
+        flavor: str | None = None,
+        timeout: str | None = None,
+        forward_hf_token: bool | str | None = None,
+    ):
+        if Sandbox is None:
+            raise ImportError(
+                "hf-sandbox is not installed. Install the `hf` extra of "
+                "openenv-terminus-env to use HFSandbox. Original import error: "
+                f"{_HF_SANDBOX_IMPORT_ERROR}"
+            ) from _HF_SANDBOX_IMPORT_ERROR
+
+        resolved_forward = _coerce_bool(
+            os.getenv("HF_SANDBOX_FORWARD_HF_TOKEN", "false")
+        )
+        if forward_hf_token is not None:  # an explicit argument wins over the env var
+            resolved_forward = _coerce_bool(forward_hf_token)  # bool("false") would be True
+
+        self._sandbox = Sandbox.create(
+            image=image or os.getenv("HF_SANDBOX_IMAGE", DEFAULT_IMAGE),
+            flavor=flavor or os.getenv("HF_SANDBOX_FLAVOR", DEFAULT_FLAVOR),
+            timeout=timeout or os.getenv("HF_SANDBOX_TIMEOUT", DEFAULT_TIMEOUT),
+            forward_hf_token=resolved_forward,
+        )
+        self.sandbox_id: str = self._sandbox.job_id
+
+    def run_shell(self, command: str, timeout_s: int = 120) -> ShellResult:
+        process = self._sandbox.exec(  # runs `bash -lc <command>` in the sandbox job
+            "bash",
+            "-lc",
+            command,
+            timeout=timeout_s,
+        )
+        success = process.returncode == 0
+        return ShellResult(
+            stdout=process.stdout or "",
+            stderr=process.stderr or "",
+            error=None if success else f"exit code {process.returncode}",
+            success=success,
+        )
+
+    def kill(self) -> None:
+        try:
+            self._sandbox.terminate()
+        except Exception:  # best-effort teardown: failures to terminate are ignored
+            pass
+
+
+def _coerce_bool(value: object) -> bool:
+    return str(value).strip().lower() in {"1", "true", "yes", "on"}
diff --git a/envs/terminus_env/server/local_sandbox.py b/envs/terminus_env/server/local_sandbox.py
new file mode 100644
index 000000000..dd4a9b679
--- /dev/null
+++ b/envs/terminus_env/server/local_sandbox.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +"""Small local sandbox backend for cluster smoke training.""" + +from __future__ import annotations + +import os +import shutil +import subprocess +import tempfile +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class ShellResult: + """Normalized result from a command executed in a local sandbox.""" + + stdout: str + stderr: str + error: str | None + success: bool + + +class LocalSandbox: + """Runs shell commands in a persistent bubblewrap-backed home directory.""" + + def __init__(self, *, root: str | None = None, **_: object): + if shutil.which("bwrap") is None: + raise RuntimeError( + "local sandbox backend requires `bwrap` on the sandbox node" + ) + self._tmp = tempfile.TemporaryDirectory(prefix="terminus-sandbox-", dir=root) + self._home = Path(self._tmp.name) / "home" / "user" + self._tmp_dir = Path(self._tmp.name) / "tmp" + self._home.mkdir(parents=True, exist_ok=True) + self._tmp_dir.mkdir(parents=True, exist_ok=True) + self.sandbox_id = Path(self._tmp.name).name + + def run_shell(self, command: str, timeout_s: int = 120) -> ShellResult: + process = subprocess.run( + self._bwrap_command(command), + text=True, + capture_output=True, + timeout=timeout_s, + check=False, + ) + success = process.returncode == 0 + return ShellResult( + stdout=process.stdout or "", + stderr=process.stderr or "", + error=None if success else f"exit code {process.returncode}", + success=success, + ) + + def kill(self) -> None: + self._tmp.cleanup() + + def _bwrap_command(self, command: str) -> list[str]: + return [ + "bwrap", + "--die-with-parent", + "--dev", + "/dev", + "--proc", + "/proc", + "--tmpfs", + "/run", + "--bind", + str(self._tmp_dir), + "/tmp", + "--ro-bind", + "/usr", + "/usr", + "--ro-bind", + "/bin", + "/bin", + "--ro-bind", + "/lib", + "/lib", + "--ro-bind", + "/lib64", + "/lib64", + "--ro-bind", + "/etc", + "/etc", + "--dir", + "/home", + "--bind", + str(self._home), + "/home/user", + "--chdir", + "/home/user", + "--setenv", + 
"HOME", + "/home/user", + "--setenv", + "PATH", + os.environ.get("PATH", "/usr/local/bin:/usr/bin:/bin"), + "/bin/bash", + "-lc", + command, + ] diff --git a/envs/terminus_env/server/requirements.txt b/envs/terminus_env/server/requirements.txt index 78040bdc3..045c53f9a 100644 --- a/envs/terminus_env/server/requirements.txt +++ b/envs/terminus_env/server/requirements.txt @@ -1,9 +1,8 @@ openenv-core[core]>=0.2.2 -e2b-code-interpreter>=1.0.0 +hf-sandbox>=0.1.1 fastapi>=0.115.0 fastmcp>=3.0.0 gradio>=4.0.0 pydantic>=2.0.0 requests>=2.31.0 uvicorn>=0.24.0 - diff --git a/envs/terminus_env/server/terminus_env_environment.py b/envs/terminus_env/server/terminus_env_environment.py index c6f9e1c02..99a114dbe 100644 --- a/envs/terminus_env/server/terminus_env_environment.py +++ b/envs/terminus_env/server/terminus_env_environment.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""E2B-backed single-tool coding environment inspired by Terminus.""" +"""HF Sandbox-backed single-tool coding environment inspired by Terminus.""" from __future__ import annotations @@ -17,23 +17,25 @@ from openenv.core.env_server.types import Action, Observation try: - from .e2b_sandbox import E2BSandbox + from .hf_sandbox import HFSandbox + from .local_sandbox import LocalSandbox from ..models import CommandResult, TerminusState except ImportError: # pragma: no cover from models import CommandResult, TerminusState - from server.e2b_sandbox import E2BSandbox + from server.hf_sandbox import HFSandbox + from server.local_sandbox import LocalSandbox REWARD_FILE = "/home/user/logs/verifier/reward.txt" class TerminusEnvironment(MCPEnvironment): - """Single-tool terminal environment with one E2B sandbox per episode.""" + """Single-tool terminal environment with one sandbox per episode.""" SUPPORTS_CONCURRENT_SESSIONS = True def __init__(self): - self._sandbox: Optional[E2BSandbox] = None + self._sandbox: 
Optional[Any] = None self._state = TerminusState(episode_id=str(uuid4()), step_count=0) mcp = FastMCP("terminus_env") @@ -43,7 +45,7 @@ def terminal(command: str = "", final_answer: str = "") -> str: """Run a shell command or submit a final answer inside the sandbox. Args: - command: Shell command to execute in the episode's E2B sandbox. + command: Shell command to execute in the episode sandbox. final_answer: Optional answer string. When provided, stored as the final answer and any reset-time verify commands run. @@ -76,38 +78,36 @@ def reset( episode_id: Optional[str] = None, **kwargs: Any, ) -> Observation: - """Create a fresh E2B sandbox and run optional setup commands.""" + """Create a fresh sandbox and run optional setup commands.""" if self._sandbox: self._sandbox.kill() self._sandbox = None - api_key = os.environ.get("E2B_API_KEY") self._state = TerminusState( episode_id=episode_id or str(uuid4()), step_count=0, ) - if not api_key: - return Observation( - done=True, - reward=None, - metadata={ - "status": "error", - "error": ( - "E2B_API_KEY is not set. Configure it before resetting " - "terminus_env." 
- ), - }, - ) - + backend = str( + kwargs.get("sandbox_backend") + or os.getenv("TERMINUS_SANDBOX_BACKEND", "hf") + ).lower() + sandbox_label = ( + "HF sandbox" + if backend in {"hf", "hf-sandbox", "huggingface"} + else f"{backend} sandbox" + ) try: - self._sandbox = E2BSandbox(api_key=api_key) + self._sandbox = _create_sandbox(kwargs) except Exception as exc: # noqa: BLE001 return Observation( done=True, reward=None, metadata={ "status": "error", - "error": f"failed to create E2B sandbox: {type(exc).__name__}: {exc}", + "error": ( + f"failed to create {sandbox_label}: " + f"{type(exc).__name__}: {exc}" + ), }, ) @@ -186,6 +186,8 @@ def step( if self._state.submitted_answer is not None and self._state.last_reward is not None: obs.done = True obs.reward = self._state.last_reward + elif obs.reward is None: + obs.reward = 0.0 return obs async def step_async( @@ -199,6 +201,8 @@ async def step_async( if self._state.submitted_answer is not None and self._state.last_reward is not None: obs.done = True obs.reward = self._state.last_reward + elif obs.reward is None: + obs.reward = 0.0 return obs @property @@ -247,6 +251,23 @@ def _coerce_commands(value: Any) -> list[str]: return [str(item) for item in value if str(item).strip()] +def _create_sandbox(kwargs: dict[str, Any]) -> Any: + backend = str( + kwargs.get("sandbox_backend") + or os.getenv("TERMINUS_SANDBOX_BACKEND", "hf") + ).lower() + if backend in {"local", "bwrap", "process"}: + return LocalSandbox(root=kwargs.get("sandbox_root")) + if backend not in {"hf", "hf-sandbox", "huggingface"}: + raise ValueError(f"unknown sandbox backend: {backend}") + return HFSandbox( + image=kwargs.get("sandbox_image") or kwargs.get("hf_sandbox_image"), + flavor=kwargs.get("sandbox_flavor") or kwargs.get("hf_sandbox_flavor"), + timeout=kwargs.get("sandbox_timeout") or kwargs.get("hf_sandbox_timeout"), + forward_hf_token=kwargs.get("forward_hf_token"), + ) + + def _format_for_llm(result) -> str: parts = [] if result.stdout: @@ -258,7 
+279,7 @@ def _format_for_llm(result) -> str: return "\n".join(parts) if parts else "(no output)" -def _read_reward_override(sandbox: E2BSandbox) -> Optional[float]: +def _read_reward_override(sandbox: Any) -> Optional[float]: result = sandbox.run_shell(f"cat {REWARD_FILE} 2>/dev/null || true") raw = (result.stdout or "").strip() if not raw: diff --git a/envs/terminus_env/uv.lock b/envs/terminus_env/uv.lock index 4dd01dc19..55f642402 100644 --- a/envs/terminus_env/uv.lock +++ b/envs/terminus_env/uv.lock @@ -154,15 +154,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl", hash = "sha256:d16c9bbc61ea14637596c5f6fbff2ee99cbe3573e46a716401734ef50c3060c2", size = 1333658, upload-time = "2025-12-13T06:50:28.266Z" }, ] -[[package]] -name = "bracex" -version = "2.6" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/63/9a/fec38644694abfaaeca2798b58e276a8e61de49e2e37494ace423395febc/bracex-2.6.tar.gz", hash = "sha256:98f1347cd77e22ee8d967a30ad4e310b233f7754dbf31ff3fceb76145ba47dc7", size = 26642, upload-time = "2025-06-22T19:12:31.254Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/2a/9186535ce58db529927f6cf5990a849aa9e052eea3e2cfefe20b9e1802da/bracex-2.6-py3-none-any.whl", hash = "sha256:0b0049264e7340b3ec782b5cb99beb325f36c3782a32e36e876452fd49a09952", size = 11508, upload-time = "2025-06-22T19:12:29.781Z" }, -] - [[package]] name = "brotli" version = "1.2.0" @@ -689,15 +680,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, ] -[[package]] -name = "dockerfile-parse" -version = "2.0.1" -source = { registry = 
"https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/92/df/929ee0b5d2c8bd8d713c45e71b94ab57c7e11e322130724d54f469b2cd48/dockerfile-parse-2.0.1.tar.gz", hash = "sha256:3184ccdc513221983e503ac00e1aa504a2aa8f84e5de673c46b0b6eee99ec7bc", size = 24556, upload-time = "2023-07-18T13:36:07.897Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/6c/79cd5bc1b880d8c1a9a5550aa8dacd57353fa3bb2457227e1fb47383eb49/dockerfile_parse-2.0.1-py2.py3-none-any.whl", hash = "sha256:bdffd126d2eb26acf1066acb54cb2e336682e1d72b974a40894fac76a4df17f6", size = 14845, upload-time = "2023-07-18T13:36:06.052Z" }, -] - [[package]] name = "docstring-parser" version = "0.18.0" @@ -716,41 +698,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" }, ] -[[package]] -name = "e2b" -version = "2.20.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "attrs" }, - { name = "dockerfile-parse" }, - { name = "httpcore" }, - { name = "httpx" }, - { name = "packaging" }, - { name = "protobuf" }, - { name = "python-dateutil" }, - { name = "rich" }, - { name = "typing-extensions" }, - { name = "wcmatch" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1a/c4/c99011b3a6dcde4b0ed6f8b70052d67645c7fabb2a673efda18a2c2d1e37/e2b-2.20.2.tar.gz", hash = "sha256:ce69f65e0b07c1002ac4e386d109e4e658575efb2a774aefced92393f2cb2388", size = 157130, upload-time = "2026-04-27T21:27:00.808Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b6/43/9571d20355555f2eba6c17b9a43d2819070cfd92759511e7198535d83dc0/e2b-2.20.2-py3-none-any.whl", hash = "sha256:8ef964a4d1204a9fd61f4499662175a7a98ad173a81e7e848961799f77750276", size = 297073, upload-time = "2026-04-27T21:26:59.081Z" }, -] 
- -[[package]] -name = "e2b-code-interpreter" -version = "2.6.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "attrs" }, - { name = "e2b" }, - { name = "httpx" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/07/ff/624d4aecaff4876abca3e507b6866e2dd421890669a99527b1aea7d5b8d1/e2b_code_interpreter-2.6.1.tar.gz", hash = "sha256:2aa7de4241b9394f61ba131c5385c81f9687a727dc40c77e14506bea21863d2c", size = 10643, upload-time = "2026-04-28T02:13:10.984Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/58/5b2438a4059d8333e37e54fb32a2f9fabb02b4257b8f14e426aef625a0bf/e2b_code_interpreter-2.6.1-py3-none-any.whl", hash = "sha256:2265fb7e0fc7a35f66a997e4bf7823291ad46d17b5a77bb1ce08343f5d3e9d1c", size = 13715, upload-time = "2026-04-28T02:13:09.993Z" }, -] - [[package]] name = "email-validator" version = "2.3.0" @@ -940,6 +887,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/2d/afff2ee87e75d8eb85c92bb8cf0e15b05c23c2ebd8fd8dec781d8601ed7f/hf_gradio-0.4.1-py3-none-any.whl", hash = "sha256:76b8cb8be6abe62d74c1ad2d35b42f0629db89aa9e1a8d033cecfe7c856eeab3", size = 4482, upload-time = "2026-04-17T19:53:31.827Z" }, ] +[[package]] +name = "hf-sandbox" +version = "0.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "httpx" }, + { name = "huggingface-hub" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7e/b2/8975aafebf0deb9e137b59c141698dc0e911c68543ca84fd78fc3ffd2edf/hf_sandbox-0.1.1.tar.gz", hash = "sha256:4f4af45a9e0fef33e1d537204b21ce6c09f9f0347d65526be794be81b6987513", size = 5654, upload-time = "2026-05-10T15:48:38.276Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/aa/292aa188bbd16b68a13cbd5229e17ba07ca487ef65290995f337c6935a99/hf_sandbox-0.1.1-py3-none-any.whl", hash = "sha256:3200c89521854486945494b765f5106107cca0a4f3247b01f520e8d657e17446", size = 4906, upload-time = "2026-05-10T15:48:36.592Z" }, 
+] + [[package]] name = "hf-xet" version = "1.4.3" @@ -1031,11 +992,11 @@ wheels = [ [[package]] name = "idna" -version = "3.15" +version = "3.13" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/82/77/7b3966d0b9d1d31a36ddf1746926a11dface89a83409bf1483f0237aa758/idna-3.15.tar.gz", hash = "sha256:ca962446ea538f7092a95e057da437618e886f4d349216d2b1e294abfdb65fdc", size = 199245, upload-time = "2026-05-12T22:45:57.011Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ce/cc/762dfb036166873f0059f3b7de4565e1b5bc3d6f28a414c13da27e442f99/idna-3.13.tar.gz", hash = "sha256:585ea8fe5d69b9181ec1afba340451fba6ba764af97026f92a91d4eef164a242", size = 194210, upload-time = "2026-04-22T16:42:42.314Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/23/408243171aa9aaba178d3e2559159c24c1171a641aa83b67bdd3394ead8e/idna-3.15-py3-none-any.whl", hash = "sha256:048adeaf8c2d788c40fee287673ccaa74c24ffd8dcf09ffa555a2fbb59f10ac8", size = 72340, upload-time = "2026-05-12T22:45:55.733Z" }, + { url = "https://files.pythonhosted.org/packages/5d/13/ad7d7ca3808a898b4612b6fe93cde56b53f3034dcde235acb1f0e1df24c6/idna-3.13-py3-none-any.whl", hash = "sha256:892ea0cde124a99ce773decba204c5552b69c3c67ffd5f232eb7696135bc8bb3", size = 68629, upload-time = "2026-04-22T16:42:40.909Z" }, ] [[package]] @@ -1665,10 +1626,10 @@ name = "openenv-terminus-env" version = "0.1.0" source = { editable = "." 
} dependencies = [ - { name = "e2b-code-interpreter" }, { name = "fastapi" }, { name = "fastmcp" }, { name = "gradio" }, + { name = "hf-sandbox" }, { name = "openenv-core", extra = ["core"] }, { name = "pydantic" }, { name = "requests" }, @@ -1683,10 +1644,10 @@ dev = [ [package.metadata] requires-dist = [ - { name = "e2b-code-interpreter", specifier = ">=1.0.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "fastmcp", specifier = ">=3.0.0" }, { name = "gradio", specifier = ">=4.0.0" }, + { name = "hf-sandbox", specifier = ">=0.1.1" }, { name = "openenv-core", extras = ["core"], specifier = ">=0.2.2" }, { name = "pydantic", specifier = ">=2.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" }, @@ -2059,21 +2020,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] -[[package]] -name = "protobuf" -version = "7.34.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6b/6b/a0e95cad1ad7cc3f2c6821fcab91671bd5b78bd42afb357bb4765f29bc41/protobuf-7.34.1.tar.gz", hash = "sha256:9ce42245e704cc5027be797c1db1eb93184d44d1cdd71811fb2d9b25ad541280", size = 454708, upload-time = "2026-03-20T17:34:47.036Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/11/3325d41e6ee15bf1125654301211247b042563bcc898784351252549a8ad/protobuf-7.34.1-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:d8b2cc79c4d8f62b293ad9b11ec3aebce9af481fa73e64556969f7345ebf9fc7", size = 429247, upload-time = "2026-03-20T17:34:37.024Z" }, - { url = "https://files.pythonhosted.org/packages/eb/9d/aa69df2724ff63efa6f72307b483ce0827f4347cc6d6df24b59e26659fef/protobuf-7.34.1-cp310-abi3-manylinux2014_aarch64.whl", hash = 
"sha256:5185e0e948d07abe94bb76ec9b8416b604cfe5da6f871d67aad30cbf24c3110b", size = 325753, upload-time = "2026-03-20T17:34:38.751Z" }, - { url = "https://files.pythonhosted.org/packages/92/e8/d174c91fd48e50101943f042b09af9029064810b734e4160bbe282fa1caa/protobuf-7.34.1-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:403b093a6e28a960372b44e5eb081775c9b056e816a8029c61231743d63f881a", size = 340198, upload-time = "2026-03-20T17:34:39.871Z" }, - { url = "https://files.pythonhosted.org/packages/53/1b/3b431694a4dc6d37b9f653f0c64b0a0d9ec074ee810710c0c3da21d67ba7/protobuf-7.34.1-cp310-abi3-manylinux2014_x86_64.whl", hash = "sha256:8ff40ce8cd688f7265326b38d5a1bed9bfdf5e6723d49961432f83e21d5713e4", size = 324267, upload-time = "2026-03-20T17:34:41.1Z" }, - { url = "https://files.pythonhosted.org/packages/85/29/64de04a0ac142fb685fd09999bc3d337943fb386f3a0ec57f92fd8203f97/protobuf-7.34.1-cp310-abi3-win32.whl", hash = "sha256:34b84ce27680df7cca9f231043ada0daa55d0c44a2ddfaa58ec1d0d89d8bf60a", size = 426628, upload-time = "2026-03-20T17:34:42.536Z" }, - { url = "https://files.pythonhosted.org/packages/4d/87/cb5e585192a22b8bd457df5a2c16a75ea0db9674c3a0a39fc9347d84e075/protobuf-7.34.1-cp310-abi3-win_amd64.whl", hash = "sha256:e97b55646e6ce5cbb0954a8c28cd39a5869b59090dfaa7df4598a7fba869468c", size = 437901, upload-time = "2026-03-20T17:34:44.112Z" }, - { url = "https://files.pythonhosted.org/packages/88/95/608f665226bca68b736b79e457fded9a2a38c4f4379a4a7614303d9db3bc/protobuf-7.34.1-py3-none-any.whl", hash = "sha256:bb3812cd53aefea2b028ef42bd780f5b96407247f20c6ef7c679807e9d188f11", size = 170715, upload-time = "2026-03-20T17:34:45.384Z" }, -] - [[package]] name = "py-key-value-aio" version = "0.4.4" @@ -2872,11 +2818,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.7.0" +version = "2.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" }, + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, ] [[package]] @@ -2996,18 +2942,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/d4/ed38dd3b1767193de971e694aa544356e63353c33a85d948166b5ff58b9e/watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6f39af2eab0118338902798b5aa6664f46ff66bc0280de76fca67a7f262a49", size = 457546, upload-time = "2025-10-14T15:06:13.372Z" }, ] -[[package]] -name = "wcmatch" -version = "10.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "bracex" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/79/3e/c0bdc27cf06f4e47680bd5803a07cb3dfd17de84cde92dd217dcb9e05253/wcmatch-10.1.tar.gz", hash = "sha256:f11f94208c8c8484a16f4f48638a85d771d9513f4ab3f37595978801cb9465af", size = 117421, upload-time = "2025-06-22T19:14:02.49Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/eb/d8/0d1d2e9d3fabcf5d6840362adcf05f8cf3cd06a73358140c3a97189238ae/wcmatch-10.1-py3-none-any.whl", hash = "sha256:5848ace7dbb0476e5e55ab63c6bbd529745089343427caa5537f230cc01beb8a", size = 39854, upload-time = "2025-06-22T19:14:00.978Z" }, -] - [[package]] name = "websockets" version = "15.0.1" diff --git a/tests/envs/test_terminus_environment.py b/tests/envs/test_terminus_environment.py index 0d2752db1..f9e8571d7 100644 --- a/tests/envs/test_terminus_environment.py +++ b/tests/envs/test_terminus_environment.py @@ -8,7 +8,7 @@ from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction from terminus_env.models import TerminusState -from terminus_env.server.e2b_sandbox import ShellResult +from terminus_env.server.hf_sandbox import ShellResult from terminus_env.server.terminus_env_environment import TerminusEnvironment @@ -57,23 +57,29 @@ def test_lists_single_terminal_tool_without_reset(): assert [tool.name for tool in obs.tools] == ["terminal"] -def test_reset_without_e2b_key_fails_cleanly(monkeypatch): - monkeypatch.delenv("E2B_API_KEY", raising=False) +def test_reset_when_hf_sandbox_creation_fails_cleanly(monkeypatch): + def fail_create(**kwargs): + raise RuntimeError("missing token") + + monkeypatch.setattr( + "terminus_env.server.terminus_env_environment.HFSandbox", + fail_create, + ) env = TerminusEnvironment() obs = env.reset() assert obs.done is True assert obs.metadata["status"] == "error" - assert "E2B_API_KEY" in obs.metadata["error"] + assert "HF sandbox" in obs.metadata["error"] + assert "missing token" in obs.metadata["error"] def test_reset_runs_setup_and_stores_verify_commands(monkeypatch): - monkeypatch.setenv("E2B_API_KEY", "fake-key") fake_sandbox = FakeSandbox() monkeypatch.setattr( - "terminus_env.server.terminus_env_environment.E2BSandbox", - lambda api_key: fake_sandbox, + "terminus_env.server.terminus_env_environment.HFSandbox", + lambda **kwargs: fake_sandbox, ) env = 
TerminusEnvironment() @@ -91,10 +97,9 @@ def test_reset_runs_setup_and_stores_verify_commands(monkeypatch): def test_reset_fails_when_setup_command_fails(monkeypatch): - monkeypatch.setenv("E2B_API_KEY", "fake-key") monkeypatch.setattr( - "terminus_env.server.terminus_env_environment.E2BSandbox", - lambda api_key: FakeSandbox(), + "terminus_env.server.terminus_env_environment.HFSandbox", + lambda **kwargs: FakeSandbox(), ) env = TerminusEnvironment() @@ -119,6 +124,7 @@ def test_terminal_command_runs_inside_existing_sandbox(): ) assert obs.error is None + assert obs.reward == 0.0 assert "shell: pwd" in _extract_text(obs.result) assert env.state.step_count == 1 assert env.state.commands[0].command == "pwd" diff --git a/tests/envs/test_terminus_harness.py b/tests/envs/test_terminus_harness.py new file mode 100644 index 000000000..d0fa4494a --- /dev/null +++ b/tests/envs/test_terminus_harness.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for Terminus harness-oriented session adapter.""" + +from __future__ import annotations + +from typing import Any + +from openenv.core.client_types import StepResult +from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation +from openenv.core.env_server.types import Observation +from openenv.core.harness import ( + HarnessRunLimits, + MCPHarnessAdapter, + ModelStepResult, + ResourceSessionFactory, + build_harness_rollout_func, +) +from openenv.core.llm_client import LLMResponse, ToolCall +from terminus_env.harness import ( + TerminusSessionFactory, + build_terminal_tool_call, + terminus_reward, +) +from terminus_env.models import CommandResult, TerminusState + + +class FakeTerminusClient: + """Small Terminus-like client used for harness tests.""" + + def __init__(self): + self.closed = False + self.reset_kwargs: dict[str, Any] = {} + self.step_actions: list[CallToolAction] = [] + self._state = TerminusState( + episode_id="terminus-episode", + sandbox_id="fake-sandbox", + ) + + def reset(self, **kwargs: Any) -> StepResult[Observation]: + self.reset_kwargs = dict(kwargs) + self._state.verify_commands = list(kwargs.get("verify", []) or []) + self._state.setup_results = [ + CommandResult(command=command, output="setup", success=True) + for command in kwargs.get("setup", []) or [] + ] + return StepResult( + observation=Observation( + done=False, + reward=None, + metadata={ + "message": "Terminus environment ready.", + "verify_commands": list(self._state.verify_commands), + }, + ), + reward=None, + done=False, + ) + + def step( + self, + action: CallToolAction, + ) -> StepResult[CallToolObservation]: + self.step_actions.append(action) + self._state.step_count += 1 + arguments = action.arguments + if arguments.get("final_answer"): + self._state.submitted_answer = str(arguments["final_answer"]) + self._state.last_reward = 1.0 + output = "Verification: 1/1 passed; reward=1.0" + return StepResult( + observation=CallToolObservation( + 
tool_name="terminal", + result={"content": [{"type": "text", "text": output}]}, + done=True, + reward=1.0, + ), + reward=1.0, + done=True, + ) + + command = str(arguments.get("command", "")) + self._state.commands.append( + CommandResult(command=command, output=f"shell: {command}", success=True) + ) + return StepResult( + observation=CallToolObservation( + tool_name="terminal", + result={"content": [{"type": "text", "text": f"shell: {command}"}]}, + done=False, + reward=0.0, + ), + reward=0.0, + done=False, + ) + + def state(self) -> TerminusState: + return self._state + + def close(self) -> None: + self.closed = True + + +def test_terminus_session_factory_exposes_terminal_tool(): + client = FakeTerminusClient() + factory = TerminusSessionFactory( + client_factory=lambda: client, + default_setup=["echo setup"], + default_verify=["test -f answer.txt"], + ) + assert isinstance(factory, ResourceSessionFactory) + + session = factory.create(task="Write answer.txt") + + assert [tool.name for tool in session.list_tools()] == ["terminal"] + assert client.reset_kwargs["setup"] == ["echo setup"] + assert client.reset_kwargs["verify"] == ["test -f answer.txt"] + messages = session.initial_messages() + assert "Write answer.txt" in messages[0]["content"] + assert "Verification commands will run after final_answer" in messages[0]["content"] + + session.close() + assert client.closed is True + + +def test_terminus_tool_calls_forward_environment_rewards(): + client = FakeTerminusClient() + factory = TerminusSessionFactory(client_factory=lambda: client) + session = factory.create( + task={ + "instruction": "Create answer.txt", + "verify": ["test -f answer.txt"], + } + ) + + command_result = session.call_tool("terminal", {"command": "pwd"}) + final_result = session.call_tool("terminal", {"final_answer": "done"}) + verify_result = session.verify(transcript=[{"role": "assistant", "content": ""}]) + + assert command_result.done is False + assert command_result.metadata["reward"] == 0.0 
+ assert final_result.done is True + assert final_result.metadata["reward"] == 1.0 + assert verify_result.env_reward == 1.0 + assert verify_result.done is True + assert client.step_actions[0].tool_name == "terminal" + assert client.step_actions[0].arguments == {"command": "pwd"} + session.close() + + +def test_terminus_terminal_json_parser(): + command = build_terminal_tool_call('{"command": "pytest -q"}') + final_answer = build_terminal_tool_call('```json\n{"final_answer": "done"}\n```') + pi_command = build_terminal_tool_call('terminal(command="echo terminus > answer.txt")') + pi_final_answer = build_terminal_tool_call("terminal(final_answer='done')") + tool_call = build_terminal_tool_call( + '{"name": "terminal", "arguments": "{\\"final_answer\\": ' + '\\"done\\"}"}' + ) + mixed = build_terminal_tool_call( + '{"command": "printf terminus > answer.txt", "final_answer": "done"}' + ) + + assert command.name == "terminal" + assert command.args == {"command": "pytest -q"} + assert final_answer.name == "terminal" + assert final_answer.args == {"final_answer": "done"} + assert pi_command.name == "terminal" + assert pi_command.args == {"command": "echo terminus > answer.txt"} + assert pi_final_answer.name == "terminal" + assert pi_final_answer.args == {"final_answer": "done"} + assert tool_call.name == "terminal" + assert tool_call.args == {"final_answer": "done"} + assert mixed.name == "terminal" + assert mixed.args == {"command": "printf terminus > answer.txt"} + + +def test_terminus_reward_extracts_last_tool_reward(): + rewards = terminus_reward( + completions=[ + [ + {"role": "tool", "content": '{"reward": 0.25}'}, + {"role": "tool", "content": '{"done": true, "reward": 1.0}'}, + ], + [{"role": "tool", "content": "Verification: 1/2 passed; reward=0.5"}], + [{"role": "assistant", "content": "no reward"}], + ] + ) + + assert rewards == [1.0, 0.5, 0.0] + + +def test_terminus_session_factory_works_with_generic_rollout_helper(): + factory = TerminusSessionFactory( + 
client_factory=FakeTerminusClient, + default_verify=["test -f answer.txt"], + ) + adapter = MCPHarnessAdapter() + + def model_step_builder(trainer, session): + tool_call = ToolCall( + id="terminal-1", + name="terminal", + args={"final_answer": "done"}, + ) + return lambda messages, tools, sampling: ModelStepResult( + response=LLMResponse(content="done", tool_calls=[tool_call]), + prompt_ids=[3, 4], + completion_ids=[5, 6], + logprobs=[-0.3, -0.4], + ) + + rollout_func = build_harness_rollout_func( + session_factory=factory, + harness_adapter=adapter, + model_step_builder=model_step_builder, + limits=HarnessRunLimits(max_turns=3), + ) + + result = rollout_func(["Create answer.txt"], trainer=object()) + + assert result["prompt_ids"] == [[3, 4]] + assert result["completion_ids"] == [[5, 6]] + assert result["logprobs"] == [[-0.3, -0.4]] + assert result["env_reward"] == [1.0]