diff --git a/.gitignore b/.gitignore index fed309b07..f5c2a93f9 100644 --- a/.gitignore +++ b/.gitignore @@ -109,3 +109,6 @@ outputs/ .uv/ *.backup*/ + +# logs +*.log diff --git a/PR_README.md b/PR_README.md new file mode 100644 index 000000000..a98fef8bb --- /dev/null +++ b/PR_README.md @@ -0,0 +1,36 @@ +### PR: Fleet environments (OpenEnv) + +This PR documents and refines the **Fleet** runtime integration for OpenEnv. + +#### What this enables +- Run OpenEnv environments on **Fleet (remote)** with **no local Docker**. +- Keep a strict split between: + - **Orchestration (HTTP)**: `reset / step / state` + - **Agent actions (MCP)**: `tools/list + tools/call` + +#### What this is *not* +- This is **not** the local “Dockerized env server + env container” setup. +- There is **no container/provider abstraction** here; Fleet hosts the runtime remotely (HTTP env server + MCP service). The client only connects. + +#### Main abstractions +- **`FleetEnvClient` (HTTP)**: orchestrator handle for reset/step/state. +- **`FleetMCPTools` (MCP)**: agent handle for listing/calling tools. + - Unions tools across Fleet’s MCP endpoints (today often `api/v1/mcp` and `mcp`) + - Returns tools in **OpenAI “tools” dict format** (via `convert_tool_format`) + - Routes tool calls to the owning endpoint (cached after discovery) + +#### Quickstart +- Install: `pip install "openenv-core[fleet]"` +- Set: `export FLEET_API_KEY="..."` +- Run: `python examples/fleet_env_example.py ` + +#### References +- RFC 001: `rfcs/001-abstractions.md` +- RFC 003: `rfcs/003-mcp-support.md` + +#### TODOs / known sharp edges +- Endpoint discovery (avoid hardcoding `api/v1/mcp` vs `mcp`) +- Reset inconsistencies across some env keys (better errors + compatibility notes) +- Tool-name collision policy across endpoints +- Retries/backoff and clearer “endpoint down” failure modes + diff --git a/README.md b/README.md index 0a0e31d7e..79b8878f8 100644 --- a/README.md +++ b/README.md @@ -294,6 +294,25 @@ Supporters include: Meta-PyTorch, Hugging Face, [Patronus AI](https://patronus.a And we'd also like to acknowledge the team at Farama Foundation as the OpenEnv API was heavily inspired by the work you all have done on Gymnasium. Cheers! +## Fleet Telemetry + +`FleetTaskEnv` emits Logfire events to track rollout lifecycle. Every `fleet_rollout_started` gets a matching `fleet_rollout_completed` with a `failure_reason`: + +``` +started = completed + init_err + tools_err + no_computer + max_steps + abandoned +``` + +| `failure_reason` | When | +|---|---| +| *(null)* | Rollout completed normally (verifier ran) | +| `init_error` | Fleet provisioning failed | +| `tools_error` | `list_tools()` MCP call failed | +| `computer_tool_missing` | CUA modality but no `computer` tool | +| `max_steps` | Caller hit turn limit without running verifier | +| `abandoned` | Caller stopped early (context overflow, job cancelled, crash) | + +Set `LOGFIRE_TOKEN` to enable. Events include `step_count`, `reward`, `verifier_success`, and task context (env_key, version, modality). + ## License BSD 3-Clause License (see [LICENSE](./LICENSE) file) diff --git a/examples/fleet_env_example.py b/examples/fleet_env_example.py new file mode 100644 index 000000000..e4324c617 --- /dev/null +++ b/examples/fleet_env_example.py @@ -0,0 +1,153 @@ +""" +Example: Orchestrator + Agent loop using OpenEnv on Fleet. + +Demonstrates the split architecture: +1. Orchestrator: Provisions environment, resets episodes (HTTP). +2. Agent: Lists tools, calls tools (MCP). + +Prerequisites: + pip install "openenv-core[fleet]" + export FLEET_API_KEY="..." + export FLEET_ENV_KEY="..." # e.g. "browser-env" or your custom env +""" + +import asyncio +import os +import random +import sys +from typing import Any, Dict, List, Sequence + +# Ensure we can import from src/ if running from repo root +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) + +try: + # `openenv` installs top-level packages like `envs`, `core`, etc. + # This example also prepends `src/` above so it works from a repo checkout. + from envs.fleet_env import FleetEnvClient +except ImportError as e: + raise ImportError( + "Could not import `envs.fleet_env`. " + "Run from the repo root, or install OpenEnv in editable mode: " + "`python -m pip install -e '.[fleet]'`." + ) from e + +def get_openai_tool_param_enum(tool_def: Dict[str, Any], param_name: str) -> List[str]: + """Extract an enum list for a parameter from an OpenAI 'tools' dict.""" + schema = tool_def.get("function", {}).get("parameters", {}) + if not isinstance(schema, dict): + return [] + props = schema.get("properties", {}) + if not isinstance(props, dict): + return [] + param_spec = props.get(param_name, {}) + if not isinstance(param_spec, dict): + return [] + enum = param_spec.get("enum", []) + return enum if isinstance(enum, list) else [] + +SAFE_COMPUTER_ACTION_PREFERENCE: Sequence[str] = ("screenshot", "wait", "cursor_position") + + +def pick_safe_computer_action(tool_def: Dict[str, Any]) -> str: + """Pick a non-destructive default action for the Fleet 'computer' tool. + + Prefer safe actions like screenshot/wait, falling back to first enum. + """ + actions = get_openai_tool_param_enum(tool_def, "action") + if not actions: + raise ValueError("Tool 'computer' has no available actions in schema.") + + action_set = set(actions) + safe_available = [a for a in SAFE_COMPUTER_ACTION_PREFERENCE if a in action_set] + if safe_available: + return random.choice(safe_available) + return actions[0] + +def main(): + api_key = os.environ.get("FLEET_API_KEY") + + # 1. Get env_key from args or env var + env_key = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("FLEET_ENV_KEY") + + if not api_key or not env_key: + print("Usage: python fleet_env_example.py ") + print(" or: export FLEET_ENV_KEY=... && python fleet_env_example.py") + raise ValueError("Please set FLEET_API_KEY and provide an env_key.") + + print(f"Provisioning Fleet environment: {env_key}...") + + # 1. Provision & Split Handles (Synchronous) + # This must be run outside of an async loop because it manages its own loop. + try: + orch, tools = FleetEnvClient.from_fleet( + api_key=api_key, + env_key=env_key, + ttl_seconds=600, # 10 min TTL + ) + except Exception as e: + raise ValueError(f"Failed to provision environment: {e}") + + + try: + # Run the async agent loop + asyncio.run(agent_loop(orch, tools)) + except BaseException as e: + print(f"\n❌ Agent loop failed: {e}") + finally: + # 5. Cleanup (Synchronous) + print("\nOrchestrator: Closing environment...") + orch.close() + print("Done.") + + +async def agent_loop(orch, tools): + # 2. Orchestration: Start Episode (HTTP calls, sync method but we wrap or call directly) + # orch.reset() is sync (requests), so it blocks the loop briefly. That's fine for this example. + print("Orchestrator: Resetting environment...") + obs = orch.reset() + print(f"Reset complete. Initial observation keys: {list(obs.observation.metadata.keys())}") + + # 3. Agent: Discover Tools (Async) + print("\nAgent: Discovering tools...") + listed = await tools.list_tools() + tool_defs = listed.tools + print(f"Available tools ({len(tool_defs)}): {[t['function']['name'] for t in tool_defs]}") + # Print the derived schema payloads (mirrors MCP Tool.inputSchema content, but OpenAI-shaped) + print([t["function"]["parameters"] for t in tool_defs]) + + if not tool_defs: + print("No MCP tools available (all MCP endpoints may be down).") + return + + # 4. Agent: Call a Tool + target_tool_name = "computer" + target_def = next((t for t in tool_defs if t["function"]["name"] == target_tool_name), None) + + if not target_def: + print(f"Tool '{target_tool_name}' not found, picking first available.") + target_def = tool_defs[0] + target_tool_name = target_def["function"]["name"] + + print(f"\nTarget Tool: {target_tool_name}") + # Inspect schema to construct params (in a real agent, the LLM does this) + # schema = target_def["function"]["parameters"] + # print(f"Schema: {json.dumps(schema, indent=2)}") + + params = {} + if target_tool_name == "computer": + # Choose a supported action from the schema (safe default). + params = {"action": pick_safe_computer_action(target_def)} + + print(f"\nAgent: Calling tool '{target_tool_name}' with {params}...") + result = await tools.call_tool(target_tool_name, params) + + + # Result is typically a list of MCP content objects (TextContent/ImageContent) + # We'll just print a summary. + print("Agent: Tool execution result received.") + print(f"{result=}") + + +if __name__ == "__main__": + main() + diff --git a/pyproject.toml b/pyproject.toml index 37d7400a2..b62de692f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,13 @@ dependencies = [ "tomli-w>=1.2.0" ] +[project.optional-dependencies] +fleet = [ + "mcp>=1.0.0", + "fleet-python>=0.2.79", + "logfire>=3.0.0", +] + [project.scripts] openenv = "openenv_cli.__main__:main" @@ -39,6 +46,9 @@ include-package-data = true [tool.setuptools.packages.find] where = ["src"] +[tool.pytest.ini_options] +pythonpath = ["src"] + [tool.coverage.run] omit = [ "openenv_cli/templates/**", diff --git a/src/envs/__init__.py b/src/envs/__init__.py new file mode 100644 index 000000000..ca0b6c7d6 --- /dev/null +++ b/src/envs/__init__.py @@ -0,0 +1 @@ +# OpenEnv environments package diff --git a/src/envs/fleet_env/README.md b/src/envs/fleet_env/README.md new file mode 100644 index 000000000..298c9597c --- /dev/null +++ b/src/envs/fleet_env/README.md @@ -0,0 +1,284 @@ +### Fleet environments + +This integration lets you run Fleet environments through OpenEnv, simplifying the interaction and adhering to OpenEnv standards; keeping **orchestration** and **agent actions** separate. + +- **Orchestration (HTTP)**: reset / step / state (episode + lifecycle control) +- **Agent actions (MCP)**: tools/list + tools/call (what the agent can do) + +That boundary matches **RFC 001** (split planes) and lines up with **RFC 003**'s "tool-call actions". +If you want the longer-form design background, see: + +- **RFC 001**: [`rfcs/001-abstractions.md`](../../../rfcs/001-abstractions.md) +- **RFC 003**: [`rfcs/003-mcp-support.md`](../../../rfcs/003-mcp-support.md) + +### What this is *not* (container/provider abstraction) + +This Fleet integration is intentionally **not yet** a "container runtime" abstraction (no Docker provider, no local container lifecycle). +In particular, there is **no local Dockerized setup** where you spin up an "env server" container alongside an "env" container; Fleet hosts the runtime remotely (HTTP env server + MCP service), and the client connects to it. + +Fleet provisions and runs the environment remotely; on the client side we just hold two handles: + +- `FleetEnvClient` for the HTTP orchestration plane +- `FleetMCPTools` for the MCP agent plane + +### Architecture (one picture) + +```mermaid +flowchart TB + subgraph Client["OpenEnv client (local)"] + Agent["Agent / Policy"] + Orch["FleetEnvClient (HTTP)"] + Tools["FleetMCPTools (MCP)"] + end + + subgraph Runtime["Fleet runtime (remote)"] + HTTP["Instance Manager HTTP API"] + MCP3003["Per-env MCP server (port 3003)"] + MCP8081["MCP Aggregator (port 8081)"] + end + + Orch -- reset/step/state --> HTTP + Agent -- list_tools/call_tool --> Tools + Tools -- "tool_use: /mcp" --> MCP3003 + Tools -- "computer_use: /api/v1/mcp" --> MCP8081 +``` + +### MCP Endpoint Routing by Modality + +Fleet exposes two MCP endpoints per instance, on different ports: + +| Modality | Endpoint | Port | What it serves | +|----------|----------|------|----------------| +| `tool_use` | `{root}/mcp` | 3003 | Per-env API tools only | +| `computer_use` | `{root}/api/v1/mcp` | 8081 | `computer` tool + aggregated API tools | + +`FleetEnvClient.from_fleet()` / `from_fleet_async()` selects the correct endpoint based on `image_type`: +- `image_type="mcp"` (computer_use) → `/api/v1/mcp` +- `image_type="standard"` (tool_use) → `/mcp` + +This eliminates partial failure ambiguity — each modality talks to exactly one endpoint. + +### Sequence: SkyRL → OpenEnv (training rollout) + +``` +SkyRL Generator SkyRL FleetTaskEnv (env.py) OpenEnv FleetTaskEnv (task_env.py) FleetEnvClient (client.py) FleetMCPTools (mcp_tools.py) Fleet Runtime + | | | | | | + |-- _env_init(env, prompt) --------->| | | | | + | |-- init_async(prompt) ------------->| | | | + | | |-- fleet_rollout_started | | | + | | | | | | + | | |-- _ensure_provisioned() ---------->| | | + | | | image_type = "mcp" | "standard" |-- from_fleet_async() ------------->| | + | | | | sdk_image_type = "mcp" | None | | + | | | |-- async_fleet.make() --------------------------------------------->| provision instance + | | | |<-- env handle + urls -----------------------------------------------| + | | | | | | + | | | | if mcp: url = /api/v1/mcp | | + | | | | else: url = /mcp | | + | | | |-- FleetMCPTools(url) ------------->| | + | | |<-- (orch, tools) ------------------| | | + | | | | | | + | | |-- reset() (swallowed on failure) | | | + | | | | | | + | | |-- tools.list_tools() -------------------------------------------->|-- list_tools() ---------------------->| MCP endpoint + | | | FATAL if fails or empty | |<-- tools[] --------------------------| + | | | | | | + | | | filter by modality: | | | + | | | computer_use → keep "computer" | | | + | | | tool_use → exclude "computer" | | | + | | | FATAL if no tools after filter | | | + | | | | | | + | | | (computer_use) screenshot ------------------------------------------------>| call_tool("computer", screenshot)-->| + | | | | | | + | |<-- obs {prompt, tools, screenshot} | | | | + | | | | | | + | | self.tools = obs["tools"] | | | | + | | FATAL if empty | | | | + | | build system prompt + tools_json | | | | + |<-- (prompt, info) -----------------| | | | | + | | | | | | + |== AGENT LOOP (per turn) ===========|====================================|====================================|====================================|====================================| + | | | | | | + |-- step_async(action) ------------->| | | | | + | |-- step_async(action) ------------->| | | | + | | |-- tools.call_tool(name, args) ------------------------------------------->| call_tool(name, args) ------------->| + | | |<-- result -----------------------------------------------------------------|<-- result --------------------------| + | | | | | | + | | | if done: _compute_reward() | | | + | | | fleet_rollout_completed | | | + | |<-- (obs, reward, done, info) ------| | | | + |<-- (obs, reward, done, info) ------| | | | | +``` + +**Failure handling:** +- `_ensure_provisioned()` fails → `fleet_rollout_completed(failure_reason="init_error")` → raise +- `list_tools()` fails or empty → `fleet_rollout_completed(failure_reason="tools_error")` → raise +- No `computer` tool for computer_use → `fleet_rollout_completed(failure_reason="computer_tool_missing")` → raise +- `reset()` fails → warning only, continues with empty observation (non-fatal) +- `screenshot` fails → warning only, continues without screenshot (non-fatal) + +### Pseudocode + +```python +class FleetEnvClient(HTTPEnvClient): + @classmethod + def from_fleet(cls, api_key, env_key, data_key, data_version, image_type, **kwargs): + # 1) Provision a remote instance via Fleet SDK + sdk_image_type = image_type if image_type == "mcp" else None + env = Fleet(api_key=api_key).make( + env_key=env_key, image_type=sdk_image_type, data_key=f"{data_key}:{data_version}", **kwargs + ) + + # 2) Orchestrator handle talks to the Instance Manager (HTTP) + orch = cls(base_url=env.urls.manager.api, ...) + + # 3) Pick MCP endpoint based on modality + if image_type == "mcp": + mcp_urls = (f"{env.urls.root}api/v1/mcp",) # aggregator (port 8081) + else: + mcp_urls = (f"{env.urls.root}mcp",) # per-env server (port 3003) + tools = FleetMCPTools(api_key=api_key, mcp_urls=mcp_urls) + + return orch, tools +``` + +### Quickstart + +- Install: `pip install "openenv-core[fleet]"` +- Set: `export FLEET_API_KEY="..."` +- Run: `python examples/fleet_env_example.py ` + +### Walkthrough (what the example is doing) + +See `examples/fleet_env_example.py`. + +1. **Provision** a remote env on Fleet: + - `orch, tools = FleetEnvClient.from_fleet(...)` +2. **Reset** the episode via HTTP: + - `obs = orch.reset()` +3. **Discover tools** via MCP: + - `listed = await tools.list_tools()` + - `tool_defs = listed.tools` + - Each entry in `tool_defs` has `{"type": "function", "function": {"name": ..., "parameters": ...}}` +4. **Call a tool** (the example picks a "safe" action from the schema and calls `computer`) + +Here's a real run (trimmed) so you know what "healthy" looks like: + +```text +Provisioning Fleet environment: amazon... +Orchestrator: Resetting environment... +Reset complete. Initial observation keys: [] + +Agent: Discovering tools... +Available tools (1): ['computer'] +[{'type': 'object', 'properties': {'action': {'enum': ['screenshot', ..., 'cursor_position'], 'type': 'string'}, ...}, 'required': ['action']}] + +Target Tool: computer +Agent: Calling tool 'computer' with {'action': 'cursor_position'}... +Agent: Tool execution result received. +result=CallToolResult(... structuredContent={'result': {'output': 'X=683,Y=384', ...}}) +``` + +### Telemetry + +Structured error tracking via [Logfire](https://logfire.pydantic.dev/). Covers init failures, tool call failures, MCP timeouts, and verifier errors across all fleet task executions. + +**Setup:** + +```python +from envs.fleet_env import configure_fleet_telemetry + +# Default environment is "training_rollouts" - shows up in Logfire env dropdown +configure_fleet_telemetry(token="your-logfire-token") + +# Or specify a custom environment +configure_fleet_telemetry(token="your-logfire-token", environment="production") +``` + +If you never call `configure_fleet_telemetry()`, logfire silently drops all events (no noise, no crashes). + +**Consistent Schema:** + +All events include these base attributes (set automatically via task context): + +| Attribute | Description | Example | +|-----------|-------------|---------| +| `env_key` | Environment key | `github`, `amazon` | +| `env_version` | Environment version | `v0.0.12` | +| `task_key` | Task identifier | `github-create-issue-001` | +| `modality` | Task modality | `tool_use`, `computer_use` | + +**What gets tracked:** + +| Event | Level | Description | +|-------|-------|-------------| +| `fleet_rollout_started` | info | Rollout attempt started (emitted before provisioning, counts init failures too) | +| `fleet_rollout_completed` | info | Rollout terminated: includes `reward`, `step_count`, `failure_reason` | +| `fleet_provisioning_completed` | info | Instance provisioned: includes `provisioning_time_s` (queue delay + create time) | +| `fleet_make_retry` | warning | Transient `Fleet.make()` failure, retrying | +| `fleet_make_failed` | error | `Fleet.make()` permanently failed | +| `fleet_env_reset_failed` | warning | Env reset threw (non-fatal, continues with empty observation) | +| `fleet_screenshot_failed` | exception | Initial screenshot threw | +| `fleet_tool_call_failed` | exception | Agent tool call threw (Python exception after retries exhausted) | +| `fleet_mcp_tool_error` | warning | MCP server returned error in tool result (tool ran but failed) | +| `fleet_verifier_failed` | exception | Verifier **code** threw an exception (not model failure — model getting wrong answer = reward 0.0 without verifier_error) | +| `fleet_list_tools_retry` | warning | list_tools retrying | +| `fleet_list_tools_exhausted` | error | list_tools retries exhausted | +| `fleet_call_tool_retry` | warning | call_tool retrying | +| `fleet_call_tool_exhausted` | error | call_tool retries exhausted | + +**Failure reasons in `fleet_rollout_completed`:** + +| `failure_reason` | Meaning | +|------------------|---------| +| `init_error` | Provisioning failed (`_ensure_provisioned()`) | +| `tools_error` | `list_tools()` MCP call failed or returned no tools | +| `computer_tool_missing` | Tools listed but no `computer` tool for computer_use modality (MCP image config issue) | + +**Example Logfire SQL Query:** + +```sql +-- Rollout summary by env/version +SELECT + attributes->>'env_key' as env, + attributes->>'env_version' as version, + attributes->>'modality' as modality, + COUNT(*) FILTER (WHERE message = 'fleet_rollout_started') as total_rollouts, + COUNT(*) FILTER (WHERE message = 'fleet_rollout_completed') as completed, + COUNT(*) FILTER (WHERE message = 'fleet_rollout_completed' + AND attributes->>'failure_reason' = 'init_error') as init_errors, + COUNT(*) FILTER (WHERE message = 'fleet_rollout_completed' + AND attributes->>'failure_reason' = 'tools_error') as tools_errors, + COUNT(*) FILTER (WHERE message = 'fleet_rollout_completed' + AND attributes->>'failure_reason' = 'computer_tool_missing') as computer_missing, + COALESCE(SUM(CAST(attributes->>'step_count' AS INT)) + FILTER (WHERE message = 'fleet_rollout_completed'), 0) as total_steps, + COUNT(*) FILTER (WHERE message IN ( + 'fleet_tool_call_failed', 'fleet_mcp_tool_error')) as tool_errors, + COUNT(*) FILTER (WHERE message = 'fleet_verifier_failed') as verifier_errors +FROM records +WHERE service_name = 'openenv-fleet' +GROUP BY 1, 2, 3 +ORDER BY total_rollouts DESC; +``` + +```sql +-- Provisioning latency by env (detects Fleet queue serialization) +SELECT + attributes->>'env_key' as env, + COUNT(*) as instances, + ROUND(AVG(CAST(attributes->>'provisioning_time_s' AS FLOAT)), 1) as avg_provision_s, + MAX(CAST(attributes->>'provisioning_time_s' AS FLOAT)) as max_provision_s, + MIN(CAST(attributes->>'provisioning_time_s' AS FLOAT)) as min_provision_s +FROM records +WHERE service_name = 'openenv-fleet' + AND message = 'fleet_provisioning_completed' +GROUP BY 1 +ORDER BY avg_provision_s DESC; +``` + +### TODOs + +- **Reset inconsistencies**: some env keys don't behave consistently on `/reset` (needs better error reporting + a compatibility note per env type). +- **Support for all OpenEnv environments**: Starting with OpenEnv, we want to support any backend to run environments at scale. +- **GA access**: GA the Fleet platform. diff --git a/src/envs/fleet_env/__init__.py b/src/envs/fleet_env/__init__.py new file mode 100644 index 000000000..1ab3b4929 --- /dev/null +++ b/src/envs/fleet_env/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fleet Environment - client-side adapter for Fleet-hosted MCP environments.""" + +from .client import FleetEnvClient +from .context_manager import CONTEXT_TOOLS, CONTEXT_TOOL_NAMES, ContextManager +from .mcp_tools import FleetMCPTools +from .models import CallToolAction, ListToolsAction +from .task_env import FleetTaskEnv, make_fleet_task_env +from .telemetry import configure_fleet_telemetry, set_task_context, clear_task_context +from .trace import create_trace_job, upload_trace +from .task_evaluator import TaskEvaluator, evaluate_task + +__all__ = [ + "FleetEnvClient", + "FleetMCPTools", + "ListToolsAction", + "CallToolAction", + "FleetTaskEnv", + "make_fleet_task_env", + "TaskEvaluator", + "evaluate_task", + "ContextManager", + "CONTEXT_TOOLS", + "CONTEXT_TOOL_NAMES", + "configure_fleet_telemetry", + "set_task_context", + "clear_task_context", + "create_trace_job", + "upload_trace", +] diff --git a/src/envs/fleet_env/client.py b/src/envs/fleet_env/client.py new file mode 100644 index 000000000..c174c24e2 --- /dev/null +++ b/src/envs/fleet_env/client.py @@ -0,0 +1,418 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fleet Environment client (HTTP orchestration only).""" + +import asyncio +import dataclasses +import logging +from typing import Any, Dict, List, Optional, Tuple, Type + +try: + # In-repo imports + from core.env_server.types import Action, Observation, State + from core.http_env_client import HTTPEnvClient + from core.client_types import StepResult +except ImportError: + # Standalone imports + from openenv_core.env_server.types import Action, Observation, State + from openenv_core.http_env_client import HTTPEnvClient + from openenv_core.client_types import StepResult + +from .mcp_tools import FleetMCPTools +from .models import CallToolAction, ListToolsAction +from .telemetry import fleet_error, fleet_warning, fleet_info + + +class FleetEnvClient(HTTPEnvClient[Action, Observation]): + """Orchestrator-facing client for Fleet-hosted environments (HTTP only).""" + + def __init__( + self, + base_url: str, + fleet_env_handle: Any, + api_key: str, + mcp_urls: Tuple[str, ...], + **kwargs: Any, + ): + super().__init__( + base_url=base_url, + default_headers={"Authorization": f"Bearer {api_key}"}, + **kwargs, + ) + self._fleet_env = fleet_env_handle + self._api_key = api_key + self._mcp_urls = mcp_urls + + @classmethod + def from_fleet( + cls: Type["FleetEnvClient"], + api_key: str, + env_key: str, + data_key: str, + data_version: str, + image_type: str, + region: Optional[str] = None, + ttl_seconds: Optional[int] = 3600, + env_variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Tuple["FleetEnvClient", FleetMCPTools]: + try: + from fleet import Fleet + except ImportError as e: + raise ImportError( + "Fleet support requires the optional dependency set. " + "Install with `pip install openenv[fleet]`." + ) from e + + # Use synchronous Fleet client for the orchestrator handle. + # This ensures .close() and other lifecycle methods are synchronous. + fleet = Fleet(api_key=api_key) + + # Fleet SDK expects data_key in "key:version" format + data_key_spec = None + if data_key: + if data_version: + data_key_spec = f"{data_key}:{data_version}" + else: + data_key_spec = data_key + + import time + import logging + + _logger = logging.getLogger(__name__) + + _logger.info(f"Creating Fleet instance: env_key={env_key}, ttl={ttl_seconds}s") + start = time.time() + + # Retry logic for transient Fleet API failures (e.g., health check failures) + max_retries = 3 + retry_base_delay = 2.0 # seconds + env = None + + for attempt in range(max_retries): + try: + # Fleet SDK expects image_type=None for standard images + sdk_image_type = image_type if image_type == "mcp" else None + env = fleet.make( + env_key=env_key, + region=region, + ttl_seconds=ttl_seconds, + env_variables=env_variables, + image_type=sdk_image_type, + data_key=data_key_spec, + ) + break # Success + except Exception as e: + error_msg = str(e) + # Retry on transient errors (health check failures, timeouts, etc.) + is_transient = any( + x in error_msg.lower() + for x in ["health check", "timeout", "connection", "temporarily"] + ) + if attempt < max_retries - 1 and is_transient: + delay = retry_base_delay * (2**attempt) + _logger.warning( + f"[env={env_key}] Fleet.make() failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {delay:.1f}s..." + ) + fleet_warning( + "fleet_make_retry", + attempt=attempt + 1, + max_retries=max_retries, + error_type=type(e).__name__, + error_message=str(e), + retry_delay_s=delay, + ) + time.sleep(delay) + else: + _logger.error( + f"[env={env_key}] Fleet.make() failed after {attempt + 1} attempt(s): {e}" + ) + fleet_error( + "fleet_make_failed", + attempt=attempt + 1, + max_retries=max_retries, + error_type=type(e).__name__, + error_message=str(e), + ) + raise + + elapsed = time.time() - start + instance_id = getattr(env, "instance_id", "unknown") + _logger.info(f"Fleet instance ready in {elapsed:.1f}s: {instance_id}") + + root = env.urls.root + # Pick MCP endpoint based on modality: + # - computer_use: aggregator on port 8081 (has computer tool + API tools) + # - tool_use: per-env MCP server on port 3003 (API tools only) + if image_type == "mcp": + mcp_urls = (f"{root}api/v1/mcp",) + else: + mcp_urls = (f"{root}mcp",) + + orch = cls( + base_url=env.urls.manager.api, + fleet_env_handle=env, + api_key=api_key, + mcp_urls=mcp_urls, + **kwargs, + ) + tools = FleetMCPTools(api_key=api_key, mcp_urls=mcp_urls) + return orch, tools + + @classmethod + async def from_fleet_async( + cls: Type["FleetEnvClient"], + api_key: str, + env_key: str, + data_key: str, + data_version: str, + image_type: str, + region: Optional[str] = None, + ttl_seconds: Optional[int] = 3600, + env_variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Tuple["FleetEnvClient", FleetMCPTools]: + """Async version of from_fleet() — does not block the event loop. + + Uses AsyncFleet.make() for provisioning and asyncio.sleep() for retries, + allowing other async trajectories to progress while waiting. + """ + try: + from fleet._async import AsyncFleet + except ImportError as e: + raise ImportError( + "Fleet support requires the optional dependency set. " + "Install with `pip install openenv[fleet]`." + ) from e + + async_fleet = AsyncFleet(api_key=api_key) + + # Fleet SDK expects data_key in "key:version" format + data_key_spec = None + if data_key: + if data_version: + data_key_spec = f"{data_key}:{data_version}" + else: + data_key_spec = data_key + + import time + import logging + + _logger = logging.getLogger(__name__) + + _logger.info( + f"Creating Fleet instance (async): env_key={env_key}, ttl={ttl_seconds}s" + ) + start = time.time() + + # Retry logic with async sleep (non-blocking) + max_retries = 3 + retry_base_delay = 2.0 # seconds + env = None + + # Fleet SDK expects image_type=None for standard images + sdk_image_type = image_type if image_type == "mcp" else None + + for attempt in range(max_retries): + try: + env = await async_fleet.make( + env_key=env_key, + region=region, + ttl_seconds=ttl_seconds, + env_variables=env_variables, + image_type=sdk_image_type, + data_key=data_key_spec, + ) + break # Success + except Exception as e: + error_msg = str(e) + # Retry on transient errors (health check failures, timeouts, etc.) + is_transient = any( + x in error_msg.lower() + for x in ["health check", "timeout", "connection", "temporarily"] + ) + if attempt < max_retries - 1 and is_transient: + delay = retry_base_delay * (2**attempt) + _logger.warning( + f"[env={env_key}] AsyncFleet.make() failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {delay:.1f}s..." + ) + fleet_warning( + "fleet_make_retry", + attempt=attempt + 1, + max_retries=max_retries, + error_type=type(e).__name__, + error_message=str(e), + retry_delay_s=delay, + ) + await asyncio.sleep(delay) + else: + _logger.error( + f"[env={env_key}] AsyncFleet.make() failed after {attempt + 1} attempt(s): {e}" + ) + fleet_error( + "fleet_make_failed", + attempt=attempt + 1, + max_retries=max_retries, + error_type=type(e).__name__, + error_message=str(e), + ) + raise + + elapsed = time.time() - start + instance_id = getattr(env, "instance_id", "unknown") + _logger.info(f"Fleet instance ready (async) in {elapsed:.1f}s: {instance_id}") + fleet_info( + "fleet_provisioning_completed", + provisioning_time_s=round(elapsed, 1), + instance_id=instance_id, + ) + + root = env.urls.root + # Pick MCP endpoint based on modality: + # - computer_use (image_type="mcp"): aggregator on port 8081 (has computer tool + API tools) + # - tool_use: per-env MCP server on port 3003 (API tools only) + if image_type == "mcp": + mcp_urls = (f"{root}api/v1/mcp",) + else: + mcp_urls = (f"{root}mcp",) + + orch = cls( + base_url=env.urls.manager.api, + fleet_env_handle=env, + api_key=api_key, + mcp_urls=mcp_urls, + **kwargs, + ) + tools = FleetMCPTools(api_key=api_key, mcp_urls=mcp_urls) + return orch, tools + + # ------------------------------------------------------------------ + # Database query methods (delegate to Fleet SDK's SQLiteResource) + # ------------------------------------------------------------------ + + def describe_db(self, db_name: str = "seed") -> Dict[str, Any]: + """Describe the schema of a database on the provisioned Fleet instance. + + Args: + db_name: Database name — "seed" (initial state) or "current" (live). + + Returns: + Dict with keys: success, resource_name, tables, message. + Each table has: name, sql, columns (list of {name, type, notnull, primary_key}). + """ + resp = self._fleet_env.db(db_name).describe() + return resp.model_dump() if hasattr(resp, "model_dump") else resp.dict() + + def query_db( + self, + sql: str, + args: Optional[List[Any]] = None, + db_name: str = "seed", + ) -> Dict[str, Any]: + """Execute a read-only SQL query against a database on the Fleet instance. + + Args: + sql: SQL SELECT statement. + args: Optional bind parameters. + db_name: Database name — "seed" (initial state) or "current" (live). + + Returns: + Dict with keys: success, columns, rows, message. + """ + resp = self._fleet_env.db(db_name).query(sql, args) + return resp.model_dump() if hasattr(resp, "model_dump") else resp.dict() + + async def describe_db_async(self, db_name: str = "seed") -> Dict[str, Any]: + """Async version of describe_db. + + Works with both sync (Fleet) and async (AsyncFleet) env handles. + """ + resource = self._fleet_env.db(db_name) + # AsyncFleet returns AsyncSQLiteResource with async describe() + if asyncio.iscoroutinefunction(getattr(resource, "describe", None)): + resp = await resource.describe() + else: + resp = await asyncio.to_thread(resource.describe) + return resp.model_dump() if hasattr(resp, "model_dump") else resp.dict() + + async def query_db_async( + self, + sql: str, + args: Optional[List[Any]] = None, + db_name: str = "seed", + ) -> Dict[str, Any]: + """Async version of query_db. + + Works with both sync (Fleet) and async (AsyncFleet) env handles. + """ + resource = self._fleet_env.db(db_name) + # AsyncFleet returns AsyncSQLiteResource with async query() + if asyncio.iscoroutinefunction(getattr(resource, "query", None)): + resp = await resource.query(sql, args) + else: + resp = await asyncio.to_thread(resource.query, sql, args) + return resp.model_dump() if hasattr(resp, "model_dump") else resp.dict() + + def _step_payload(self, action: Action) -> dict: + """Serialize action for HTTP /step.""" + if dataclasses.is_dataclass(action): + return dataclasses.asdict(action) + if isinstance(action, dict): + return action + raise TypeError(f"Action must be a dataclass or dict, got {type(action)}") + + def _parse_result(self, payload: dict) -> StepResult[Observation]: + """Parse standard OpenEnv step response.""" + obs_payload = payload.get("observation", {}) + # Ensure obs_payload is a dict before accessing .get() + if not isinstance(obs_payload, dict): + # If observation is a primitive (e.g. string), wrap it + obs_payload = {"content": obs_payload} + + return StepResult( + observation=Observation( + metadata=obs_payload, + reward=payload.get("reward"), + done=payload.get("done", False), + ), + reward=payload.get("reward"), + done=payload.get("done", False), + ) + + def _parse_state(self, payload: Any) -> Any: + if isinstance(payload, dict): + try: + return State(**payload) + except TypeError: + pass + return payload + + def step(self, action: Action) -> StepResult[Observation]: + # Enforce separation: agent actions are MCP-only (use FleetMCPTools). + if isinstance(action, (ListToolsAction, CallToolAction)): + raise TypeError( + "Agent tool actions are MCP-only. Use FleetMCPTools.list_tools()/call_tool()." + ) + return super().step(action) + + def close(self) -> None: + """Terminate the remote Fleet instance (resource cleanup), not an episode reset.""" + if self._fleet_env: + self._fleet_env.close() + super().close() + + async def close_async(self) -> None: + """Async close — runs sync Fleet close in a thread to avoid blocking the event loop.""" + if self._fleet_env: + await asyncio.to_thread(self._fleet_env.close) + super().close() + + async def reset_async(self) -> "StepResult": + """Async reset — runs sync HTTP reset in a thread to avoid blocking the event loop.""" + return await asyncio.to_thread(self.reset) diff --git a/src/envs/fleet_env/context_manager.py b/src/envs/fleet_env/context_manager.py new file mode 100644 index 000000000..d78b27792 --- /dev/null +++ b/src/envs/fleet_env/context_manager.py @@ -0,0 +1,329 @@ +""" +Context Management for Fleet Task Environments. + +This module provides tools for managing conversation context during long trajectories, +inspired by Toolathlon's context management approach. It allows models to: +1. Check how much context they've used +2. Drop old turns to free up context space +3. Search through dropped history +4. Navigate truncated tool outputs + +These tools are designed for step-wise RL training where each turn is a separate +training sample. When context is dropped, the training framework re-tokenizes +the modified chat_history, so the model learns from the reduced context. +""" + +import json +from typing import Any, Dict, List, Optional, Tuple + +# Context management tool definitions (OpenAI function calling format) +CONTEXT_TOOLS = [ + # --- Context/History Management --- + { + "type": "function", + "function": { + "name": "check_context", + "description": "Check current context: visible/total turn counts", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "manage_context", + "description": "Drop old turns to free up context space", + "parameters": { + "type": "object", + "properties": { + "keep_recent_turns": { + "type": "integer", + "description": "Number of recent turns to keep (drops older ones)", + } + }, + "required": ["keep_recent_turns"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "search_history", + "description": "Search all history (including dropped) by pattern", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Text pattern to search", + } + }, + "required": ["pattern"], + }, + }, + }, + # --- Overlong Tool Output Handling --- + { + "type": "function", + "function": { + "name": "search_tool_output", + "description": "Search the last truncated tool output by pattern", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Text pattern to search", + } + }, + "required": ["pattern"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "view_tool_output", + "description": "View a page of the last truncated tool output", + "parameters": { + "type": "object", + "properties": { + "page": { + "type": "integer", + "description": "Page number (1-indexed)", + }, + "page_size": { + "type": "integer", + "description": "Lines per page (default 50)", + }, + }, + "required": ["page"], + }, + }, + }, +] + +CONTEXT_TOOL_NAMES = {t["function"]["name"] for t in CONTEXT_TOOLS} + + +class ContextManager: + """Manages conversation context for long-running agent trajectories. + + This class provides utilities for: + 1. Tracking full conversation history (never dropped) + 2. Managing visible context (can be trimmed) + 3. Handling truncated tool outputs + 4. Executing context management tool calls + + Designed to work with any training framework that maintains chat_history. + The framework passes its chat_history to execute_tool(), which may modify it. + + Example: + >>> ctx = ContextManager(max_output_chars=10000) + >>> # Get tools to add to the model's available tools + >>> tools = ctx.get_tools() + >>> # Track messages as they're added + >>> ctx.track_message({"role": "assistant", "content": "..."}) + >>> # Check if a tool call is a context tool + >>> if ctx.is_context_tool("manage_context"): + ... result, chat_history = ctx.execute_tool("manage_context", {"keep_recent_turns": 5}, chat_history) + >>> # Truncate long outputs + >>> output = ctx.truncate_output(long_tool_result) + """ + + def __init__(self, max_output_chars: int = 10000): + """Initialize the context manager. + + Args: + max_output_chars: Maximum characters for tool output before truncation. + Truncated outputs can be accessed via search_tool_output/view_tool_output. + """ + self.max_output_chars = max_output_chars + self.full_history: List[Dict[str, Any]] = [] + self.last_full_output: Optional[str] = None + + def reset(self): + """Reset state for a new episode.""" + self.full_history = [] + self.last_full_output = None + + def get_tools(self) -> List[Dict[str, Any]]: + """Get the context management tool definitions. + + Returns: + List of tool definitions in OpenAI function calling format. + """ + return CONTEXT_TOOLS.copy() + + def is_context_tool(self, tool_name: str) -> bool: + """Check if a tool name is a context management tool. + + Args: + tool_name: Name of the tool to check. + + Returns: + True if it's a context tool that should be handled locally. + """ + return tool_name in CONTEXT_TOOL_NAMES + + def track_message(self, message: Dict[str, Any]): + """Track a message in the full history. + + Call this for every message added to chat_history. The full_history + is never trimmed, allowing search_history to find dropped messages. + + Args: + message: Message dict with "role" and "content" keys. + """ + self.full_history.append(message.copy()) + + def truncate_output(self, output: str) -> str: + """Truncate a tool output if it exceeds max_output_chars. + + If truncated, the full output is stored and can be accessed via + search_tool_output or view_tool_output tools. + + Args: + output: The tool output string. + + Returns: + Original output if within limit, truncated version with notice otherwise. + """ + if not isinstance(output, str): + return output + + if len(output) > self.max_output_chars: + self.last_full_output = output + return ( + output[: self.max_output_chars] + + f"\n\n[TRUNCATED - {len(output)} chars total. " + + "Use search_tool_output or view_tool_output to access full content.]" + ) + else: + self.last_full_output = None + return output + + def execute_tool( + self, tool_name: str, args: Dict[str, Any], chat_history: List[Dict[str, Any]] + ) -> Tuple[str, List[Dict[str, Any]]]: + """Execute a context management tool. + + Args: + tool_name: Name of the context tool to execute. + args: Tool arguments. + chat_history: Current visible chat history (may be modified). + + Returns: + Tuple of (result_string, modified_chat_history). + The chat_history is modified in-place for manage_context. + """ + if tool_name == "check_context": + return self._check_context(chat_history), chat_history + + elif tool_name == "manage_context": + return self._manage_context(args, chat_history) + + elif tool_name == "search_history": + return self._search_history(args), chat_history + + elif tool_name == "search_tool_output": + return self._search_tool_output(args), chat_history + + elif tool_name == "view_tool_output": + return self._view_tool_output(args), chat_history + + else: + return ( + json.dumps({"error": f"Unknown context tool: {tool_name}"}), + chat_history, + ) + + def _check_context(self, chat_history: List[Dict[str, Any]]) -> str: + """Check current context: visible vs total turns.""" + visible_turns = len([m for m in chat_history if m.get("role") == "assistant"]) + total_turns = len( + [m for m in self.full_history if m.get("role") == "assistant"] + ) + return json.dumps( + { + "visible_turns": visible_turns, + "total_turns": total_turns, + "dropped_turns": total_turns - visible_turns, + } + ) + + def _manage_context( + self, args: Dict[str, Any], chat_history: List[Dict[str, Any]] + ) -> Tuple[str, List[Dict[str, Any]]]: + """Drop old turns to free up context space.""" + n = args.get("keep_recent_turns", 5) + + # Keep system message + last n turns (each turn = assistant + user message) + system = [m for m in chat_history if m.get("role") == "system"] + non_system = [m for m in chat_history if m.get("role") != "system"] + keep_count = n * 2 # n turns = n assistant + n user messages + + if len(non_system) > keep_count: + dropped = len(non_system) - keep_count + new_history = system + non_system[-keep_count:] + return ( + f"Dropped {dropped} messages. {len(new_history)} remaining.", + new_history, + ) + else: + return f"Nothing to drop. {len(chat_history)} messages.", chat_history + + def _search_history(self, args: Dict[str, Any]) -> str: + """Search all history (including dropped) by pattern.""" + pattern = args.get("pattern", "").lower() + if not pattern: + return json.dumps({"error": "pattern is required"}) + + matches = [] + for i, msg in enumerate(self.full_history): + content = msg.get("content", "") + if isinstance(content, str) and pattern in content.lower(): + matches.append( + { + "index": i, + "role": msg.get("role"), + "snippet": content[:200], + } + ) + return json.dumps({"matches": matches[:10]}) + + def _search_tool_output(self, args: Dict[str, Any]) -> str: + """Search the last truncated tool output by pattern.""" + if not self.last_full_output: + return "No truncated output available." + + pattern = args.get("pattern", "").lower() + if not pattern: + return json.dumps({"error": "pattern is required"}) + + lines = self.last_full_output.split("\n") + matches = [] + for i, line in enumerate(lines): + if pattern in line.lower(): + matches.append({"line": i + 1, "content": line[:200]}) + return json.dumps({"matches": matches[:20]}) + + def _view_tool_output(self, args: Dict[str, Any]) -> str: + """View a page of the last truncated tool output.""" + if not self.last_full_output: + return "No truncated output available." + + page = args.get("page", 1) + page_size = args.get("page_size", 50) + lines = self.last_full_output.split("\n") + total_pages = (len(lines) + page_size - 1) // page_size + start = (page - 1) * page_size + end = start + page_size + page_lines = lines[start:end] + return json.dumps( + { + "page": page, + "total_pages": total_pages, + "total_lines": len(lines), + "content": "\n".join(page_lines), + } + ) diff --git a/src/envs/fleet_env/fleet_mcp_client.py b/src/envs/fleet_env/fleet_mcp_client.py new file mode 100644 index 000000000..79eb5a0a1 --- /dev/null +++ b/src/envs/fleet_env/fleet_mcp_client.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Fleet-compatible MCP client wrapper (Streamable HTTP + initialize). + +Design note: +- We intentionally avoid exposing an async context-manager interface here. + Some MCP/AnyIO failure modes during connection setup can produce noisy + ExceptionGroup/cancel-scope traces if a partially-entered context leaks. +- Instead, this wrapper provides *one-shot* operations that open + close the + streamable HTTP transport within a single call. +""" + +from datetime import timedelta +from typing import Any, Dict, List, Optional + +try: + from mcp import ClientSession + from mcp.client.streamable_http import streamablehttp_client + from mcp.types import Tool +except ImportError as e: # pragma: no cover + raise ImportError( + "Fleet MCP support requires the optional dependency set. " + "Install with `pip install openenv-core[fleet]`." + ) from e + + +class FleetMCPClient: + # Hard timeout for entire MCP operation (connection + request) + OPERATION_TIMEOUT_S = 60 + + def __init__(self, url: str, api_key: str): + self.url = url + self.api_key = api_key + + async def _list_tools_impl(self) -> List[Tool]: + """Internal implementation without timeout wrapper.""" + async with streamablehttp_client( + url=self.url, + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=timedelta(seconds=30), + sse_read_timeout=timedelta(seconds=60), + ) as streams: + async with ClientSession( + read_stream=streams[0], write_stream=streams[1] + ) as session: + await session.initialize() + return (await session.list_tools()).tools + + async def list_tools(self) -> List[Tool]: + """List tools with hard timeout to prevent hanging.""" + import asyncio + + try: + return await asyncio.wait_for( + self._list_tools_impl(), timeout=self.OPERATION_TIMEOUT_S + ) + except asyncio.TimeoutError: + raise TimeoutError( + f"list_tools timed out after {self.OPERATION_TIMEOUT_S}s for {self.url}" + ) + + async def _call_tool_impl(self, name: str, arguments: Dict[str, Any]) -> Any: + """Internal implementation without timeout wrapper.""" + async with streamablehttp_client( + url=self.url, + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=timedelta(seconds=30), + sse_read_timeout=timedelta(seconds=60), + ) as streams: + async with ClientSession( + read_stream=streams[0], write_stream=streams[1] + ) as session: + await session.initialize() + result = await session.call_tool(name, arguments) + return self._extract_tool_result(result) + + async def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: + """Call tool with hard timeout to prevent hanging.""" + import asyncio + + try: + return await asyncio.wait_for( + self._call_tool_impl(name, arguments), timeout=self.OPERATION_TIMEOUT_S + ) + except asyncio.TimeoutError: + raise TimeoutError( + f"call_tool({name}) timed out after {self.OPERATION_TIMEOUT_S}s for {self.url}" + ) + + def _extract_tool_result(self, result: Any) -> Any: + """Extract readable content from CallToolResult. + + MCP's call_tool returns a CallToolResult with content list. + This extracts text and image content for use in agent observations. + + For VL (vision-language) models, ImageContent is converted to OpenAI-compatible + format: {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} + + Returns: + - str: For single text result + - dict: For JSON-parseable text or error + - list: For multiple text items OR any content with images (multimodal) + """ + import json + + # Handle error case + if hasattr(result, "isError") and result.isError: + if hasattr(result, "content") and result.content: + for content in result.content: + if hasattr(content, "text"): + return {"error": content.text} + return {"error": "Tool execution failed"} + + # Extract content from CallToolResult + if hasattr(result, "content") and result.content: + texts = [] + images = [] + + for content in result.content: + # Handle TextContent + if hasattr(content, "text"): + texts.append(content.text) + # Handle ImageContent (MCP format: data, mimeType) + elif hasattr(content, "data") and hasattr(content, "mimeType"): + # Convert to OpenAI-compatible image_url format + mime_type = content.mimeType or "image/png" + base64_data = content.data + data_url = f"data:{mime_type};base64,{base64_data}" + images.append({"type": "image_url", "image_url": {"url": data_url}}) + + # If there are images, return multimodal format (for VL models) + if images: + contents = [] + for text in texts: + contents.append({"type": "text", "text": text}) + contents.extend(images) + return contents + + # Text-only: preserve backward compatibility + if len(texts) == 1: + # Single text result - try to parse as JSON + try: + parsed = json.loads(texts[0]) + # Handle Fleet MCP's base64_image format - convert to OpenAI format + if isinstance(parsed, dict) and "base64_image" in parsed: + data_url = parsed["base64_image"] + if data_url is not None: + return [{"type": "image_url", "image_url": {"url": data_url}}] + # base64_image was null — screenshot capture failed, return as text + return "Screenshot capture failed (null image)" + return parsed + except json.JSONDecodeError: + return texts[0] + elif texts: + # Multiple text results - return as list + return texts + + # Fallback to structured content if available + if hasattr(result, "structuredContent") and result.structuredContent: + return result.structuredContent + + # Last resort - return string representation + return str(result) + + def has_tool(self, name: str, tools_list: Optional[List[Tool]] = None) -> bool: + if not tools_list: + return False + return any(t.name == name for t in tools_list) diff --git a/src/envs/fleet_env/mcp_tools.py b/src/envs/fleet_env/mcp_tools.py new file mode 100644 index 000000000..044c11611 --- /dev/null +++ b/src/envs/fleet_env/mcp_tools.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""MCP-only handle for agents (no reset/step/state).""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Sequence + +from .fleet_mcp_client import FleetMCPClient +from .models import ListToolsAction, convert_tool_format +from .telemetry import fleet_error, fleet_warning + +logger = logging.getLogger(__name__) + + +def _unwrap_exception(e: Exception) -> str: + """Extract meaningful error message from ExceptionGroup or nested exceptions.""" + # Handle ExceptionGroup (from asyncio.TaskGroup) + if hasattr(e, 'exceptions'): + msgs = [_unwrap_exception(sub) for sub in e.exceptions] + return "; ".join(msgs) + # Handle chained exceptions + if e.__cause__: + return f"{type(e).__name__}: {e} <- {_unwrap_exception(e.__cause__)}" + return f"{type(e).__name__}: {e}" + + +@dataclass +class FleetMCPTools: + """Agent-facing tools client (MCP only).""" + + api_key: str + mcp_urls: Sequence[str] + max_retries: int = 8 + initial_wait: float = 8.0 + max_backoff: float = 5.0 + _clients: Optional[List[FleetMCPClient]] = field(default=None, repr=False) + _tool_owner: Optional[Dict[str, FleetMCPClient]] = field(default=None, repr=False) + + def _get_clients(self) -> List[FleetMCPClient]: + if self._clients is None: + self._clients = [FleetMCPClient(url, self.api_key) for url in self.mcp_urls] + return self._clients + + def _get_owner_cache(self) -> Dict[str, FleetMCPClient]: + if self._tool_owner is None: + self._tool_owner = {} + return self._tool_owner + + async def _list_tools_single_attempt(self) -> List[Any]: + """Single attempt to list tools from all clients.""" + owner_cache = self._get_owner_cache() + tools: list[Any] = [] + seen: set[str] = set() + errors: list[str] = [] + + for client in self._get_clients(): + try: + found = await client.list_tools() + for t in found: + # Deduplicate by tool name across endpoints, but cache first-seen owner. + if t.name not in owner_cache: + owner_cache[t.name] = client + if t.name in seen: + continue + seen.add(t.name) + tools.append(convert_tool_format(t)) + except BaseException as e: + errors.append(f"{client.url}: {e}") + continue + + if errors and not tools: + # All clients failed - log and raise + raise RuntimeError(f"All MCP clients failed to list tools: {errors}") + + if errors: + # Some clients failed but we got tools from others + logger.warning(f"Some MCP clients failed to list tools: {errors}") + fleet_warning( + "fleet_list_tools_partial", + error_message="; ".join(errors), + ) + + return tools + + async def list_tools(self) -> ListToolsAction: + """List available tools (union across endpoints) as a ListToolsAction. + + The returned `.tools` payload is in OpenAI "tools" dict format + (see `convert_tool_format`), derived from MCP `Tool.inputSchema`. + + Matches the orchestrator harness: 8s initial wait for MCP services to + start, then 8 retries with exponential backoff capped at 5s. + """ + # Wait for MCP services to initialize (matches harness initial_wait=8) + if self.initial_wait > 0: + logger.info(f"Waiting {self.initial_wait:.0f}s for MCP services to initialize...") + await asyncio.sleep(self.initial_wait) + + last_error = None + + for attempt in range(self.max_retries): + try: + tools = await self._list_tools_single_attempt() + if tools: + return ListToolsAction(tools=tools) + # Got empty tools - treat as failure and retry + raise RuntimeError("No tools found from any MCP endpoint") + except Exception as e: + last_error = e + error_msg = _unwrap_exception(e) + if attempt < self.max_retries - 1: + delay = min(2 ** attempt, self.max_backoff) + logger.warning( + f"list_tools attempt {attempt + 1}/{self.max_retries} failed: {error_msg}. " + f"Retrying in {delay:.1f}s..." + ) + fleet_warning( + "fleet_list_tools_retry", + attempt=attempt + 1, + max_retries=self.max_retries, + error_message=error_msg, + ) + await asyncio.sleep(delay) + + logger.error(f"list_tools failed after {self.max_retries} attempts: {_unwrap_exception(last_error)}") + fleet_error( + "fleet_list_tools_exhausted", + attempt=self.max_retries, + max_retries=self.max_retries, + error_message=_unwrap_exception(last_error), + ) + raise RuntimeError( + f"list_tools failed after {self.max_retries} attempts" + ) from last_error + + async def _call_tool_single_attempt( + self, tool_name: str, arguments: Dict[str, Any] + ) -> Any: + """Single attempt to call a tool.""" + owner_cache = self._get_owner_cache() + clients = self._get_clients() + + if tool_name in owner_cache: + client = owner_cache[tool_name] + logger.debug(f"call_tool({tool_name}) using cached client: {client.url}") + return await client.call_tool(tool_name, arguments) + + errors: list[str] = [] + for client in clients: + try: + tools = await client.list_tools() + if client.has_tool(tool_name, tools): + owner_cache[tool_name] = client + return await client.call_tool(tool_name, arguments) + except BaseException as e: + errors.append(f"{client.url}: {e}") + continue + + if errors: + raise RuntimeError(f"Tool call failed: {errors}") + + raise ValueError(f"Tool '{tool_name}' not found on any active MCP endpoint.") + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + """Call a tool with retry logic for connection failures. + + Retries with exponential backoff on connection errors. + """ + last_error = None + + for attempt in range(self.max_retries): + try: + result = await self._call_tool_single_attempt(tool_name, arguments) + if attempt > 0: + logger.info(f"call_tool({tool_name}) succeeded on attempt {attempt + 1}") + return result + except ValueError: + # Tool not found - don't retry + raise + except Exception as e: + last_error = e + error_msg = _unwrap_exception(e) + if attempt < self.max_retries - 1: + delay = min(2 ** attempt, self.max_backoff) + logger.warning( + f"call_tool({tool_name}) attempt {attempt + 1}/{self.max_retries} failed: {error_msg}. " + f"Retrying in {delay:.1f}s..." + ) + fleet_warning( + "fleet_call_tool_retry", + tool_name=tool_name, + attempt=attempt + 1, + max_retries=self.max_retries, + error_message=error_msg, + ) + await asyncio.sleep(delay) + + logger.error( + f"call_tool({tool_name}) failed after {self.max_retries} attempts: {_unwrap_exception(last_error)}" + ) + fleet_error( + "fleet_call_tool_exhausted", + tool_name=tool_name, + attempt=self.max_retries, + max_retries=self.max_retries, + error_message=_unwrap_exception(last_error), + ) + raise RuntimeError( + f"call_tool({tool_name}) failed after {self.max_retries} attempts" + ) from last_error + + diff --git a/src/envs/fleet_env/models.py b/src/envs/fleet_env/models.py new file mode 100644 index 000000000..27c303fae --- /dev/null +++ b/src/envs/fleet_env/models.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Data models for FleetEnvClient (RFC 003 tool-call actions).""" + +from dataclasses import dataclass, field +from typing import Any, Dict, TYPE_CHECKING + +# Avoid importing OpenAI typing aliases at runtime. +# The `openai` package changes exported type names across major versions, and +# Fleet integration should work even if OpenAI isn't installed. +if TYPE_CHECKING: # pragma: no cover + try: + from openai import ChatCompletionToolUnionParam as OpenAIToolParam # type: ignore + except Exception: # noqa: BLE001 + OpenAIToolParam = Dict[str, Any] # type: ignore[misc,assignment] +else: + OpenAIToolParam = Dict[str, Any] # type: ignore[misc,assignment] + +from mcp.types import Tool + + +# Support both in-repo and standalone imports +try: + from core.env_server.types import Action +except ImportError: + from openenv_core.env_server.types import Action + +def normalize_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(schema, dict): + return schema + + result = {} + + if "anyOf" in schema: + non_null_schemas = [s for s in schema["anyOf"] if s.get("type") != "null"] + if non_null_schemas: + schema = {**schema, **non_null_schemas[0]} + del schema["anyOf"] + + for key, value in schema.items(): + if key in ["title", "default", "anyOf"]: + continue + + if key == "prefixItems": + result["items"] = ( + normalize_schema(value[0]) if value else {"type": "string"} + ) + continue + + if key == "properties" and isinstance(value, dict): + result[key] = {k: normalize_schema(v) for k, v in value.items()} + elif key == "items" and isinstance(value, dict): + result[key] = normalize_schema(value) + else: + result[key] = value + + return result + + +def convert_tool_format(tool: Tool) -> OpenAIToolParam: + normalized_properties = { + key: normalize_schema(value) + for key, value in tool.inputSchema.get("properties", {}).items() + } + + # OpenAI "tools" format: {"type": "function", "function": {...}} + openai_tool: OpenAIToolParam = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": normalized_properties, + "required": tool.inputSchema.get("required", []), + }, + }, + } + return openai_tool + + +@dataclass(kw_only=True) +class ListToolsAction(Action): + """Request list of available MCP tools from the Fleet environment.""" + + tools: list[OpenAIToolParam] = field(default_factory=list) + + +@dataclass(kw_only=True) +class CallToolAction(Action): + """Call a specific MCP tool exposed by the Fleet environment.""" + + tool_name: str + parameters: Dict[str, Any] = field(default_factory=dict) + + diff --git a/src/envs/fleet_env/task_env.py b/src/envs/fleet_env/task_env.py new file mode 100644 index 000000000..e23130d1c --- /dev/null +++ b/src/envs/fleet_env/task_env.py @@ -0,0 +1,901 @@ +""" +Fleet Task Environment - Gymnasium-compatible environment for Fleet tasks. + +This module provides a task-oriented wrapper around FleetEnvClient that: +1. Accepts task configs (from export_training_tasks.py) +2. Creates versioned environments on reset +3. Injects task prompt into observations +4. Executes verifier for reward on episode completion +""" + +import ast +import asyncio +import logging +import os +import re +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +from .client import FleetEnvClient +from .mcp_tools import FleetMCPTools +from .telemetry import ( + fleet_exception, + fleet_warning, + fleet_info, + set_task_context, + clear_task_context, +) + +# Synthetic tool injected by the harness (not from MCP). +# Mirrors orchestrator/temporal/workflows/constants.py → ANSWER_SUBMISSION_TOOL. +SUBMIT_FINAL_ANSWER_TOOL = { + "type": "function", + "function": { + "name": "submit_final_answer", + "description": ( + "Submit your final answer to complete the task. Use this when you " + "have finished the task and want to provide your answer for " + "verification. If the requested answer asks for json, then write " + "your response in the answer field using json brackets." + ), + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "Your final answer", + } + }, + "required": ["answer"], + }, + }, +} + + +def _is_tool_error(result: Any) -> Tuple[bool, Optional[str]]: + """Check if a tool result indicates an error. + + MCP server errors come back as: + - {"error": "..."} from isError=True responses + - {"status": "failed", ...} from some tools + - {"isError": true, ...} in some formats + + Returns: + (is_error, error_message) tuple + """ + if not isinstance(result, dict): + return False, None + + # Direct error field (from FleetMCPClient._extract_tool_result) + # Check for truthy value to avoid false positives on {"error": null} + if result.get("error"): + return True, str(result["error"]) + + # Status field pattern + if result.get("status") == "failed": + return True, result.get("message") or result.get("error") or "status=failed" + + # isError field pattern + if result.get("isError"): + return True, result.get("message") or result.get("error") or "isError=true" + + return False, None + + +class FleetTaskEnv: + """Gymnasium-compatible environment for Fleet tasks. + + This class wraps FleetEnvClient to provide a task-oriented interface + suitable for RL training with SkyRL. + + Args: + task_config: Task configuration dict with keys: + - task_key: Unique task identifier + - prompt: Task instruction for the agent + - env_key: Environment key (e.g., "booking-com") + - env_version: Environment version (e.g., "v1.2.3") + - data_key: Optional data key + - data_version: Optional data version + - verifier_code: Python code for verification + - task_modality: "tool_use" or "computer_use" + api_key: Fleet API key (defaults to FLEET_API_KEY env var) + ttl_seconds: Instance TTL in seconds. If None, auto-selects based on + modality: 1800s (30 min) for computer_use, 900s (15 min) for tool_use. + max_steps: Maximum steps per episode (default: 50) + request_timeout_s: HTTP request timeout in seconds (default: 60.0) + partial_reward: If True, compute partial scores from verifier + error/success accumulators instead of binary 0/1 (default: False) + + Example: + >>> task_config = { + ... "task_key": "search-flights-001", + ... "prompt": "Search for flights from NYC to LA", + ... "env_key": "booking-com", + ... "env_version": "v1.2.3", + ... "verifier_code": "async def verify(env): ...", + ... "task_modality": "tool_use", + ... } + >>> env = FleetTaskEnv(task_config) + >>> obs = env.reset() + >>> obs, reward, done, info = env.step({"tool": "search", "params": {...}}) + """ + + def __init__( + self, + task_config: Dict[str, Any], + api_key: Optional[str] = None, + ttl_seconds: Optional[int] = None, + max_steps: int = 50, + request_timeout_s: float = 60.0, + reset_timeout_s: float = 10.0, + partial_reward: bool = False, + ): + self.task = task_config + self.api_key = api_key or os.environ.get("FLEET_API_KEY") + self.partial_reward = partial_reward + # Auto-select TTL based on modality if not explicitly provided + if ttl_seconds is not None: + self.ttl_seconds = ttl_seconds + elif self.modality == "computer_use": + self.ttl_seconds = ( + 1800 # 30 min — CUA rollouts are slow (browser + inference) + ) + else: + self.ttl_seconds = ( + 900 # 15 min — tool_use rollouts need headroom for retries + ) + self.max_steps = max_steps + self.request_timeout_s = request_timeout_s + self.reset_timeout_s = reset_timeout_s + + if not self.api_key: + raise ValueError( + "Fleet API key required (pass api_key or set FLEET_API_KEY)" + ) + + self._step_count = 0 + self._done = False + self._rollout_completed_emitted = False + self._rollout_started = False + self._tools_cache: Optional[List[Dict]] = None + self._reward_computed = False + self.final_reward: Optional[float] = None + self._submitted_answer: Optional[str] = None + + # Feedback for hint generation (accumulated during rollout) + self._tool_errors: List[str] = [] + self._verifier_stdout: Optional[str] = None + self._verifier_error: Optional[str] = None + + # Set telemetry context so init failures are tracked with full context + set_task_context( + env_key=self.env_key, + env_version=self.env_version, + task_key=self.task_key, + modality=self.modality, + ) + + # Provisioning is deferred to _ensure_provisioned() (called from reset_async) + # to avoid blocking the event loop with sync Fleet.make() calls. + self._orch = None + self._tools = None + + @property + def task_key(self) -> str: + """Get the task key.""" + return self.task.get("task_key", "unknown") + + @property + def prompt(self) -> str: + """Get the task prompt.""" + return self.task.get("prompt", "") + + @property + def modality(self) -> str: + """Get the task modality.""" + return self.task.get("task_modality", "tool_use") + + @property + def env_key(self) -> str: + """Get the environment key (e.g., 'github', 'amazon').""" + return self.task.get("env_key", "unknown") + + @property + def env_version(self) -> str: + """Get the environment version (e.g., 'v0.0.12').""" + return self.task.get("env_version", "unknown") + + def _build_env_spec(self) -> str: + """Build env_key:version spec for Fleet.make().""" + env_key = self.task.get("env_key") + env_version = self.task.get("env_version") + + if not env_key: + raise ValueError("Task config missing env_key") + + if env_version: + return f"{env_key}:{env_version}" + return env_key + + def _get_data_key(self) -> Optional[str]: + """Get data_key from task config.""" + return self.task.get("data_key") + + def _get_data_version(self) -> Optional[str]: + """Get data_version from task config.""" + return self.task.get("data_version") + + def _get_env_variables(self) -> Optional[Dict[str, Any]]: + """Get env_variables from task config. + + These variables parameterize the environment with task-specific values + like names, dates, scenario configurations, etc. + """ + return self.task.get("env_variables") + + def reset(self, seed: Optional[int] = None) -> Dict[str, Any]: + """Reset the environment and return initial observation (sync wrapper). + + This is a sync wrapper around reset_async(). For async code, use reset_async() directly. + + Args: + seed: Optional random seed (passed to env reset) + + Returns: + Observation dict with keys: + - prompt: The task instruction + - observation: Raw observation from env reset + - tools: List of available tools (if tool_use modality) + - step: Current step number (0) + """ + import asyncio + + return asyncio.run(self.reset_async(seed=seed)) + + async def _ensure_provisioned(self): + """Provision the Fleet environment instance if not already done. + + Uses AsyncFleet.make() to avoid blocking the event loop. This allows + other async trajectories to progress while waiting for provisioning. + """ + if self._orch is not None: + return + + env_spec = self._build_env_spec() + # computer_use: MCP-enabled container with browser infra (port 8081 aggregator) + # tool_use: standard container with per-env MCP server (port 3003) + image_type = "mcp" if self.modality == "computer_use" else "standard" + self._orch, self._tools = await FleetEnvClient.from_fleet_async( + api_key=self.api_key, + env_key=env_spec, + data_key=self._get_data_key(), + data_version=self._get_data_version(), + env_variables=self._get_env_variables(), + image_type=image_type, + ttl_seconds=self.ttl_seconds, + request_timeout_s=self.request_timeout_s, + ) + + async def reset_async(self, seed: Optional[int] = None) -> Dict[str, Any]: + """Reset episode state and return initial observation. + + Provisions the Fleet environment on first call (async, non-blocking), + then resets episode state and returns the observation with tools. + + Args: + seed: Optional random seed (currently unused) + + Returns: + Observation dict with keys: + - prompt: The task instruction + - observation: Observation from env reset (or empty if reset fails) + - tools: List of available tools (if tool_use modality) + - step: Current step number (0) + """ + import logging + + logger = logging.getLogger(__name__) + + # Count this rollout attempt immediately — even if provisioning fails, + # it's still a rollout attempt (e.g., fostgres health check failures). + fleet_info("fleet_rollout_started") + self._rollout_started = True + self._rollout_completed_emitted = False + + # Provision Fleet env (async, non-blocking) on first call + try: + await self._ensure_provisioned() + except Exception: + # Emit rollout_completed so init failures are tracked in dashboards + fleet_info( + "fleet_rollout_completed", + step_count=0, + reward=0.0, + verifier_success=False, + failure_reason="init_error", + ) + self._rollout_completed_emitted = True + raise + + # Reset episode state + self._step_count = 0 + self._done = False + self._reward_computed = False + self.final_reward = None + self._submitted_answer = None + self._tool_errors = [] + self._verifier_stdout = None + self._verifier_error = None + + # Reset the environment (use short timeout to avoid blocking on broken manager APIs) + # reset() failure is non-fatal — env is up, just the manager API timed out + reset_metadata = {} + if self._orch: + try: + saved_timeout = self._orch._timeout + self._orch._timeout = self.reset_timeout_s + try: + reset_result = await self._orch.reset_async() + reset_metadata = ( + reset_result.observation.metadata if reset_result else {} + ) + finally: + self._orch._timeout = saved_timeout + except Exception as e: + logger.warning( + f"[env={self.env_key}] Fleet env reset failed (timeout={self.reset_timeout_s}s), continuing with empty observation: {e}" + ) + fleet_warning( + "fleet_env_reset_failed", + step_count=self._step_count, + timeout_s=self.reset_timeout_s, + error_type=type(e).__name__, + error_message=str(e)[:200], + ) + + # Fetch tools — fatal if MCP call fails (no tools = dead rollout) + try: + if self._tools: + tools_result = await self._tools.list_tools() + self._tools_cache = tools_result.tools + if not self._tools_cache: + raise RuntimeError("list_tools returned no tools") + except Exception as e: + fleet_info( + "fleet_rollout_completed", + step_count=0, + reward=0.0, + verifier_success=False, + failure_reason="tools_error", + error_message=str(e)[:200], + ) + self._rollout_completed_emitted = True + raise + + # Filter tools based on modality: + # - computer_use: keep ONLY the 'computer' tool + # - tool_use: EXCLUDE the 'computer' tool (should only use API tools) + if self.modality == "tool_use": + self._tools_cache = [ + t + for t in self._tools_cache + if t.get("name") != "computer" + and t.get("function", {}).get("name") != "computer" + ] + + # For computer_use, filter to only the 'computer' tool + if self.modality == "computer_use": + computer_tools = [ + t + for t in self._tools_cache + if t.get("name") == "computer" + or t.get("function", {}).get("name") == "computer" + ] + if not computer_tools: + available = [ + t.get("name") or t.get("function", {}).get("name") + for t in self._tools_cache + ] + fleet_info( + "fleet_rollout_completed", + step_count=0, + reward=0.0, + verifier_success=False, + failure_reason="computer_tool_missing", + available_tools=available, + ) + self._rollout_completed_emitted = True + raise RuntimeError( + f"computer_use modality but no 'computer' tool found. " + f"Available tools: {available}. Check MCP image configuration." + ) + self._tools_cache = computer_tools + + if not self._tools_cache: + fleet_info( + "fleet_rollout_completed", + step_count=0, + reward=0.0, + verifier_success=False, + failure_reason="tools_error", + error_message="No tools available after modality filtering", + ) + self._rollout_completed_emitted = True + raise RuntimeError("No tools available after filtering") + + # Inject submit_final_answer synthetic tool for tool_use tasks whose + # prompt references it. This mirrors the harness's ANSWER_SUBMISSION_TOOL + # so that models can submit answers during SkyRL training exactly as + # they would in a Fleet harness session. + if self.modality == "tool_use" and "submit_final_answer" in self.prompt: + self._tools_cache.append(SUBMIT_FINAL_ANSWER_TOOL) + + # Build observation with cached tools + obs = { + "prompt": self.prompt, + "observation": reset_metadata, + "step": 0, + "task_key": self.task_key, + "modality": self.modality, + "tools": self._tools_cache, + } + + # For computer_use, take initial screenshot so VL model can see the screen + # This is critical for VL models - without visual input they're blind + if self.modality == "computer_use" and self._tools: + try: + screenshot_result = await self._tools.call_tool( + "computer", {"action": "screenshot"} + ) + obs["initial_screenshot"] = screenshot_result + logger.info(f"Task {self.task_key}: captured initial screenshot") + except Exception as e: + logger.warning( + f"Task {self.task_key}: failed to capture initial screenshot: {e}" + ) + fleet_exception( + "fleet_screenshot_failed", + step_count=self._step_count, + ) + + return obs + + def step(self, action: Dict[str, Any]) -> Tuple[Dict[str, Any], float, bool, Dict]: + """Execute a step in the environment (sync wrapper). + + For async tool calls, use step_async() instead. + + Args: + action: Action dict. For tool_use modality: + - tool: Tool name to call + - params: Tool parameters + - done: Optional flag to signal episode completion + + Returns: + Tuple of (observation, reward, done, info) + """ + import asyncio + + return asyncio.run(self.step_async(action)) + + async def step_async( + self, action: Dict[str, Any] + ) -> Tuple[Dict[str, Any], float, bool, Dict]: + """Execute a step in the environment. + + Args: + action: Action dict. For tool_use modality: + - tool: Tool name to call + - params: Tool parameters + - done: Optional flag to signal episode completion + + Returns: + Tuple of (observation, reward, done, info) + """ + if self._done: + raise RuntimeError("Episode is done. Call reset() to start a new episode.") + + if not self._tools: + raise RuntimeError("Environment not initialized. Call reset() first.") + + self._step_count += 1 + info = {"step": self._step_count} + + # Check if agent signals completion + agent_done = action.get("done", False) + + # Check max steps + max_steps_reached = self._step_count >= self.max_steps + + # Execute tool call + tool_name = action.get("tool") + tool_params = action.get("params", {}) + tool_result = None + + if tool_name == "submit_final_answer": + # Synthetic tool — handled locally, not routed to MCP. + self._submitted_answer = tool_params.get("answer", "") + tool_result = { + "status": "submitted", + "message": "Answer recorded. Ending session.", + } + info["tool_result"] = tool_result + info["submitted_answer"] = self._submitted_answer + agent_done = True # Force episode end, same as harness behaviour + elif tool_name: + try: + tool_result = await self._tools.call_tool(tool_name, tool_params) + info["tool_result"] = tool_result + + # Check for MCP server errors (not Python exceptions) + is_error, error_msg = _is_tool_error(tool_result) + if is_error: + info["tool_error"] = error_msg + self._tool_errors.append( + f"{tool_name}(): {error_msg[:500] if error_msg else 'unknown'}" + ) + logger.warning( + f"[env={self.env_key}:{self.env_version}] step {self._step_count}/{self.max_steps} " + f"tool_error: {tool_name}() -> {error_msg[:200] if error_msg else 'unknown'}" + ) + fleet_warning( + "fleet_mcp_tool_error", + step_count=self._step_count, + max_steps=self.max_steps, + tool_name=tool_name, + error_message=error_msg[:500] if error_msg else None, + ) + except Exception as e: + info["tool_error"] = str(e) + tool_result = {"error": str(e)} + self._tool_errors.append(f"{tool_name}(): {str(e)[:500]}") + logger.warning( + f"[env={self.env_key}:{self.env_version}] step {self._step_count}/{self.max_steps} " + f"tool_call_failed: {tool_name}() -> {type(e).__name__}: {str(e)[:200]}" + ) + fleet_exception( + "fleet_tool_call_failed", + step_count=self._step_count, + max_steps=self.max_steps, + tool_name=tool_name, + ) + + # Determine if done + self._done = agent_done or max_steps_reached + info["done_reason"] = ( + "agent_done" if agent_done else "max_steps" if max_steps_reached else None + ) + + # Calculate reward (only on episode completion) + reward = 0.0 + if self._done: + reward = await self._compute_reward() + self._reward_computed = True + info["reward_computed"] = True + + # Build observation + obs = { + "prompt": self.prompt, + "observation": tool_result or {}, + "step": self._step_count, + "task_key": self.task_key, + "modality": self.modality, + } + + if self._tools_cache: + obs["tools"] = self._tools_cache + + return obs, reward, self._done, info + + @staticmethod + def _parse_partial_reward(stdout: str) -> Optional[float]: + """Parse partial reward from verifier accumulator output. + + Verifiers print error/success accumulators to stdout. This parses + them to compute a fractional score (n_success / total_checks). + + Returns: + Partial score in [0, 1], or None if accumulators not found. + """ + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n<<< SUCCESS_ACCUMULATOR <<<", + stdout, + re.DOTALL, + ) + if not err_match and not suc_match: + return None + try: + n_errors = len(ast.literal_eval(err_match.group(1))) if err_match else 0 + n_success = len(ast.literal_eval(suc_match.group(1))) if suc_match else 0 + total = n_errors + n_success + return n_success / total if total > 0 else None + except Exception: + return None + + @property + def verifier_stdout(self) -> Optional[str]: + """Raw verifier stdout (contains ERROR/SUCCESS_ACCUMULATOR blocks).""" + return self._verifier_stdout + + @property + def verifier_error(self) -> Optional[str]: + """Verifier error message, if verifier failed.""" + return self._verifier_error + + @property + def tool_errors_list(self) -> List[str]: + """Accumulated tool error messages from this rollout.""" + return self._tool_errors.copy() + + async def _compute_reward(self) -> float: + """Compute reward by executing the verifier using Fleet SDK. + + Uses Fleet SDK's Task.verify_detailed() which properly sets up the + verifier namespace with Environment type, helper functions, etc. + + Returns: + 1.0 if verifier passes, 0.0 otherwise (or partial if enabled) + """ + # Support both field names: verifier_code (OpenEnv) and verifier_func (Fleet SDK) + verifier_code = self.task.get("verifier_code") or self.task.get("verifier_func") + score = 0.0 + verifier_success = False + failure_reason = None + + if not verifier_code: + # No verifier - return neutral reward + logger.debug(f"Task {self.task_key}: no verifier_code, returning 0.0") + failure_reason = "no_verifier" + elif not self._orch: + logger.warning(f"Task {self.task_key}: no orchestrator, returning 0.0") + failure_reason = "no_orchestrator" + else: + # Get the Fleet env handle from the orchestrator + fleet_env = getattr(self._orch, "_fleet_env", None) + if not fleet_env: + logger.warning( + f"Task {self.task_key}: no Fleet env handle, returning 0.0" + ) + failure_reason = "no_fleet_env" + else: + try: + # Use Fleet SDK's Task.verify_detailed() for proper verifier execution + from fleet.tasks import Task as FleetTask + + # Create a Fleet SDK Task object with the verifier + fleet_task = FleetTask( + key=self.task_key, + prompt=self.prompt, + env_id=self.task.get("env_key", "unknown"), + verifier_func=verifier_code, + ) + + # Execute verifier in a thread to avoid blocking the event loop. + # verify_detailed() does sync HTTP calls internally. + # Pass final_answer when model used submit_final_answer, + # mirroring how the harness routes the answer to the verifier. + verify_kwargs = {} + if self._submitted_answer is not None: + verify_kwargs["final_answer"] = self._submitted_answer + response = await asyncio.to_thread( + fleet_task.verify_detailed, fleet_env, **verify_kwargs + ) + + # Extract result from response + # response.success is bool, response.result is the verifier's return value (0.0 or 1.0) + if response.success and response.result is not None: + score = float(response.result) + elif response.success: + # Verifier succeeded but returned None - treat as success + score = 1.0 + else: + # Verifier failed (exception or explicit failure) + score = 0.0 + + verifier_success = response.success + + # Capture verifier feedback for hint generation + if hasattr(response, "stdout") and response.stdout: + self._verifier_stdout = response.stdout + if not response.success: + self._verifier_error = ( + f"Verifier failed: result={response.result}" + ) + + # Partial reward: use accumulator counts instead of binary 0/1 + partial_score = None + if ( + self.partial_reward + and score == 0.0 + and hasattr(response, "stdout") + and response.stdout + ): + partial_score = self._parse_partial_reward(response.stdout) + if partial_score is not None: + score = partial_score + + logger.info( + f"Task {self.task_key}: verifier returned success={response.success}, " + f"result={response.result}, score={score}" + + ( + f", partial={partial_score:.3f}" + if partial_score is not None + else "" + ) + ) + + except ImportError as e: + logger.error(f"Fleet SDK not available for verifier execution: {e}") + failure_reason = "import_error" + self._verifier_error = f"ImportError: {e}" + except Exception as e: + logger.error( + f"Verifier execution failed for task {self.task_key}: {e}\n" + f"Verifier code:\n{verifier_code}" + ) + fleet_exception( + "fleet_verifier_failed", + step_count=self._step_count, + verifier_code_snippet=( + verifier_code[:200] if verifier_code else "" + ), + ) + failure_reason = "verifier_exception" + self._verifier_error = f"Verifier exception: {e}" + + # Always emit rollout completed event + fleet_info( + "fleet_rollout_completed", + step_count=self._step_count, + max_steps=self.max_steps, + reward=score, + verifier_success=verifier_success, + failure_reason=failure_reason, + ) + self._rollout_completed_emitted = True + return score + + def close(self): + """Close the environment and cleanup resources. + + Runs the verifier for orphaned rollouts — trajectories where SkyRL + stopped early (context overflow, its own max_turns) without OpenEnv + computing the reward. This ensures the actual reward is available + via self.final_reward instead of defaulting to 0.0. + """ + try: + # Run verifier for orphaned rollouts (started but never completed). + # _compute_reward() handles telemetry (fleet_rollout_completed). + if self._rollout_started and not self._rollout_completed_emitted: + try: + self.final_reward = asyncio.run(self._compute_reward()) + self._reward_computed = True + except RuntimeError: + # Already inside a running event loop — caller should use close_async() + # Fall back to emitting telemetry without verifier + stop_reason = ( + "max_steps" + if self._step_count >= self.max_steps + else "abandoned" + ) + fleet_info( + "fleet_rollout_completed", + step_count=self._step_count, + max_steps=self.max_steps, + reward=0.0, + verifier_success=False, + failure_reason=stop_reason, + ) + self._rollout_completed_emitted = True + + if self._orch: + try: + self._orch.close() + except Exception: + pass # Expected when instance TTL expired + finally: + # Always cleanup state, even if telemetry fails + self._orch = None + self._tools = None + self._tools_cache = None + self._done = True + self._rollout_started = False + clear_task_context() + + async def close_async(self): + """Async close — runs verifier for orphaned rollouts and terminates instance. + + If SkyRL ends the trajectory early (context overflow, its own max_turns), + the verifier never ran in step_async(). This runs it at close time so + the real reward is available via self.final_reward. + """ + try: + # Run verifier for orphaned rollouts (started but never completed). + # _compute_reward() handles telemetry (fleet_rollout_completed). + if self._rollout_started and not self._rollout_completed_emitted: + self.final_reward = await self._compute_reward() + self._reward_computed = True + + if self._orch: + try: + await self._orch.close_async() + except Exception: + pass # Expected when instance TTL expired + finally: + self._orch = None + self._tools = None + self._tools_cache = None + self._done = True + self._rollout_started = False + clear_task_context() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + @classmethod + def from_json_file(cls, json_path: str, task_key: str, **kwargs) -> "FleetTaskEnv": + """Create FleetTaskEnv from exported JSON file. + + Args: + json_path: Path to JSON file from export_training_tasks.py + task_key: Task key to load + **kwargs: Additional arguments passed to FleetTaskEnv + + Returns: + FleetTaskEnv instance for the specified task + """ + import json + + with open(json_path) as f: + data = json.load(f) + + tasks = data.get("tasks", []) + task_config = next((t for t in tasks if t["task_key"] == task_key), None) + + if not task_config: + raise ValueError(f"Task '{task_key}' not found in {json_path}") + + return cls(task_config, **kwargs) + + @classmethod + def from_json_file_all(cls, json_path: str, **kwargs) -> List["FleetTaskEnv"]: + """Create FleetTaskEnv instances for all tasks in JSON file. + + Args: + json_path: Path to JSON file from export_training_tasks.py + **kwargs: Additional arguments passed to FleetTaskEnv + + Returns: + List of FleetTaskEnv instances + """ + import json + + with open(json_path) as f: + data = json.load(f) + + tasks = data.get("tasks", []) + return [cls(task, **kwargs) for task in tasks] + + +def make_fleet_task_env(task_config: Dict[str, Any], **kwargs) -> FleetTaskEnv: + """Factory function for creating FleetTaskEnv. + + This is the recommended entry point for SkyRL integration. + + Args: + task_config: Task configuration dict + **kwargs: Additional arguments passed to FleetTaskEnv + + Returns: + FleetTaskEnv instance + """ + return FleetTaskEnv(task_config, **kwargs) diff --git a/src/envs/fleet_env/task_evaluator.py b/src/envs/fleet_env/task_evaluator.py new file mode 100644 index 000000000..3ed9b041a --- /dev/null +++ b/src/envs/fleet_env/task_evaluator.py @@ -0,0 +1,299 @@ +""" +Task Evaluator for generated tasks. + +Given a generated (prompt, verifier_code) and environment config, runs k rollouts +across m models via the Fleet harness (POST /v1/jobs) and returns structured +results for reward computation. + +This is the inner loop of the task generation RL pipeline: + 1. Task generator outputs (prompt, verifier) for an environment + 2. TaskEvaluator imports the task to Fleet, creates a harness job + 3. Harness runs k × m rollouts (env provisioning, model calls, verification) + 4. Results feed into reward computation (variance + separation) +""" + +import asyncio +import logging +import os +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# Default models for evaluation (must match Fleet models table IDs) +DEFAULT_MODELS = ["anthropic/claude-sonnet-4.5"] + + +@dataclass +class EvaluationResult: + """Aggregated results from k × m rollout evaluation.""" + + results_per_model: Dict[str, List[float]] = field(default_factory=dict) + total_duration_s: float = 0.0 + num_sessions: int = 0 + num_errors: int = 0 + job_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "results_per_model": self.results_per_model, + "total_duration_s": self.total_duration_s, + "num_rollouts": self.num_sessions, + "num_errors": self.num_errors, + "job_id": self.job_id, + } + + +class TaskEvaluator: + """Evaluates generated tasks by submitting jobs to the Fleet harness. + + For each generated task: + 1. Imports the task to Fleet via fleet.import_task() + 2. Creates a harness job via fleet.create_job() with specified models and pass_k + 3. Polls for job completion + 4. Extracts per-session verifier scores + 5. Returns results_per_model for reward computation + + Args: + api_key: Fleet API key + k_rollouts: Number of rollouts per model (pass_k in Fleet terms) + models: List of Fleet model IDs (e.g., ["claude-sonnet-4.5", "claude-opus-4.5"]) + max_steps: Maximum agent steps per session + poll_interval_s: Seconds between job status polls (default: 10) + max_poll_time_s: Maximum time to wait for job completion (default: 1800 = 30 min) + """ + + def __init__( + self, + api_key: Optional[str] = None, + k_rollouts: int = 4, + models: Optional[List[str]] = None, + max_steps: int = 30, + poll_interval_s: int = 10, + max_poll_time_s: int = 1800, + **kwargs, + ): + self.api_key = api_key or os.environ.get("FLEET_API_KEY") + if not self.api_key: + raise ValueError("Fleet API key required") + + self.k_rollouts = k_rollouts + self.models = list(models or DEFAULT_MODELS) + self.max_steps = max_steps + self.poll_interval_s = poll_interval_s + self.max_poll_time_s = max_poll_time_s + + # Initialize Fleet SDK client + self._fleet_client = None + + def _match_model_id(self, session_model_id: str) -> Optional[str]: + """Match a session model ID to one of our configured model IDs. + + Fleet may return model IDs without provider prefix (e.g., 'claude-sonnet-4.5') + while we configure them with prefix (e.g., 'anthropic/claude-sonnet-4.5'), + or vice versa. + """ + if session_model_id in self.models: + return session_model_id + + # Strip provider prefix and compare bare model names + session_bare = session_model_id.split("/", 1)[-1] if "/" in session_model_id else session_model_id + for configured_id in self.models: + configured_bare = configured_id.split("/", 1)[-1] if "/" in configured_id else configured_id + if configured_bare == session_bare: + return configured_id + + return None + + def _get_fleet_client(self): + """Lazy-init Fleet SDK client.""" + if self._fleet_client is None: + from fleet import Fleet + + self._fleet_client = Fleet(api_key=self.api_key) + return self._fleet_client + + async def evaluate( + self, + prompt: str, + verifier_code: str, + env_key: str, + env_version: str = "", + env_variables: Optional[Dict[str, Any]] = None, + data_key: Optional[str] = None, + data_version: Optional[str] = None, + ) -> Dict[str, Any]: + """Run k × m rollouts via Fleet harness and return structured results. + + Flow: + 1. Create a Fleet Task object with the generated prompt + verifier + 2. Import it to Fleet via POST /v1/tasks + 3. Create a harness job via POST /v1/jobs + 4. Poll until job completes + 5. Extract per-model, per-session verifier scores + + Args: + prompt: The generated task prompt + verifier_code: The generated verifier code + env_key: Fleet environment key + env_version: Fleet environment version + env_variables: Optional environment variables + data_key: Optional data key + data_version: Optional data version + + Returns: + Dict with 'results_per_model' mapping model_id -> list[float] + """ + start_time = time.time() + result = EvaluationResult() + for model_id in self.models: + result.results_per_model[model_id] = [] + + fleet = self._get_fleet_client() + task_key = f"taskgen_{uuid.uuid4().hex[:12]}" + + try: + # 1. Create Fleet Task object + from fleet.tasks import Task + + task = Task( + key=task_key, + prompt=prompt, + env_id=env_key, + version=env_version or None, + verifier_func=verifier_code, + data_id=data_key, + data_version=data_version, + env_variables=env_variables or {}, + ) + + # 2. Import task to Fleet + import_response = fleet.import_single_task(task) + if import_response is None: + logger.error(f"[{task_key}] Failed to import task to Fleet") + result.num_errors = 1 + result.total_duration_s = time.time() - start_time + return result.to_dict() + + logger.info(f"[{task_key}] Task imported to Fleet") + + # 3. Create harness job + job_response = fleet.create_job( + models=self.models, + task_keys=[task_key], + pass_k=self.k_rollouts, + max_steps=self.max_steps, + mode="tool-use", + name=f"taskgen-eval-{task_key}", + ) + job_id = job_response.job_id + result.job_id = job_id + logger.info( + f"[{task_key}] Harness job created: {job_id} " + f"(models={self.models}, pass_k={self.k_rollouts})" + ) + + # 4. Poll for job completion + job_status = await self._poll_job(fleet, job_id) + if job_status not in ("completed",): + logger.warning( + f"[{task_key}] Job {job_id} ended with status: {job_status}" + ) + result.num_errors = 1 + result.total_duration_s = time.time() - start_time + return result.to_dict() + + # 5. Extract per-session scores + sessions_response = fleet.list_job_sessions(job_id) + for task_group in sessions_response.tasks: + for session in task_group.sessions: + # Normalize model ID: Fleet may return "claude-sonnet-4.5" + # while we configured "anthropic/claude-sonnet-4.5" + matched_id = self._match_model_id(session.model) or session.model + score = 0.0 + if session.verifier_execution and session.verifier_execution.score is not None: + score = float(session.verifier_execution.score) + elif session.verifier_execution and session.verifier_execution.success: + score = 1.0 + + if matched_id in result.results_per_model: + result.results_per_model[matched_id].append(score) + else: + result.results_per_model[matched_id] = [score] + + result.num_sessions += 1 + + logger.info( + f"[{task_key}] Evaluation complete: " + f"{result.num_sessions} sessions across {len(self.models)} models. " + f"Results: {{{', '.join(f'{m}: {scores}' for m, scores in result.results_per_model.items())}}}" + ) + + except Exception as e: + logger.error(f"[{task_key}] Evaluation failed: {e}") + result.num_errors += 1 + + result.total_duration_s = time.time() - start_time + return result.to_dict() + + async def _poll_job(self, fleet, job_id: str) -> str: + """Poll Fleet job until completion or timeout. + + Uses asyncio.sleep to avoid blocking the event loop, allowing + trajectory timeouts to properly cancel evaluations. + + Returns: + Final job status string. + """ + start = time.time() + while time.time() - start < self.max_poll_time_s: + try: + job = fleet.get_job(job_id) + status = job.status + if status in ("completed", "cancelled", "errored"): + return status + except Exception as e: + logger.warning(f"Error polling job {job_id}: {e}") + + await asyncio.sleep(self.poll_interval_s) + + logger.error(f"Job {job_id} timed out after {self.max_poll_time_s}s") + return "timeout" + + +async def evaluate_task( + prompt: str, + verifier_code: str, + env_key: str, + env_version: str = "", + api_key: Optional[str] = None, + k_rollouts: int = 4, + models: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Convenience function for one-off task evaluation. + + Args: + prompt: Task prompt to evaluate + verifier_code: Verifier code for the task + env_key: Fleet environment key + env_version: Fleet environment version + api_key: Fleet API key + k_rollouts: Number of rollouts per model + models: List of Fleet model IDs + + Returns: + Evaluation results dict + """ + evaluator = TaskEvaluator( + api_key=api_key, + k_rollouts=k_rollouts, + models=models, + ) + return await evaluator.evaluate( + prompt=prompt, + verifier_code=verifier_code, + env_key=env_key, + env_version=env_version, + ) diff --git a/src/envs/fleet_env/telemetry.py b/src/envs/fleet_env/telemetry.py new file mode 100644 index 000000000..f824572e5 --- /dev/null +++ b/src/envs/fleet_env/telemetry.py @@ -0,0 +1,95 @@ +"""Thin Logfire wrapper for Fleet environment telemetry. + +Provides structured error/event tracking for fleet task executions. +If configure_fleet_telemetry() is never called, logfire silently drops events. + +All events include a consistent base schema: +- env_key: Environment key (e.g., "github", "amazon") +- env_version: Environment version (e.g., "v0.0.12") +- task_key: Task identifier +- modality: "tool_use" or "computer_use" +""" + +import logfire +from contextvars import ContextVar +from typing import Optional + +# Session context - set once per rollout/task execution +_session_context: ContextVar[dict] = ContextVar("fleet_session_context", default={}) + + +def configure_fleet_telemetry( + token: Optional[str] = None, + environment: str = "training_rollouts", + service_name: str = "openenv-fleet", + **kwargs, +): + """Configure Logfire for Fleet telemetry. + + Args: + token: Logfire API token (or set LOGFIRE_TOKEN env var). + environment: Environment name (default: "training_rollouts"). + service_name: Service name for Logfire (default: "openenv-fleet"). + **kwargs: Additional arguments passed to logfire.configure(). + """ + logfire.configure( + token=token, + service_name=service_name, + environment=environment, + **kwargs, + ) + + +def set_task_context( + *, + env_key: Optional[str] = None, + env_version: Optional[str] = None, + task_key: Optional[str] = None, + modality: Optional[str] = None, +): + """Set the task context for all subsequent telemetry events. + + Call this at the start of each rollout/task execution. + """ + ctx = {} + if env_key: + ctx["env_key"] = env_key + if env_version: + ctx["env_version"] = env_version + if task_key: + ctx["task_key"] = task_key + if modality: + ctx["modality"] = modality + _session_context.set(ctx) + + +def clear_task_context(): + """Clear the task context.""" + _session_context.set({}) + + +def _with_context(**attrs) -> dict: + """Merge session context with event-specific attributes.""" + ctx = _session_context.get().copy() + ctx.update(attrs) + return ctx + + +def fleet_info(msg: str, **attrs): + """Log a structured info event.""" + logfire.info(msg, **_with_context(**attrs)) + + +def fleet_warning(msg: str, **attrs): + """Log a structured warning event.""" + logfire.warn(msg, **_with_context(**attrs)) + + +def fleet_error(msg: str, **attrs): + """Log a structured error event.""" + logfire.error(msg, **_with_context(**attrs)) + + +def fleet_exception(msg: str, **attrs): + """Log a structured error with exception info (use inside except blocks).""" + logfire.error(msg, _exc_info=True, **_with_context(**attrs)) diff --git a/src/envs/fleet_env/trace.py b/src/envs/fleet_env/trace.py new file mode 100644 index 000000000..54296954d --- /dev/null +++ b/src/envs/fleet_env/trace.py @@ -0,0 +1,120 @@ +"""Fleet trace upload utilities for eval rollouts. + +Provides functions to create trace jobs and upload conversation traces +to the Fleet API for viewing in the Fleet UI (including screenshots). +""" + +import logging +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +def _convert_image_block(block: Dict[str, Any]) -> Dict[str, Any]: + """Convert an OpenAI image_url block to Fleet ingest image format. + + Fleet ingest API expects: {"type": "image", "mime_type": "image/png", "data": ""} + It then uploads base64 to S3 and replaces with URL for the UI to render. + """ + url = block.get("image_url", {}).get("url", "") + if url.startswith("data:"): + # data:image/png;base64,ABC... -> extract mime_type and base64 data + header, base64_data = url.split(",", 1) + mime_type = header.split(":")[1].split(";")[0] if ":" in header else "image/png" + return {"type": "image", "mime_type": mime_type, "data": base64_data} + else: + # HTTPS URL - pass as text since ingest API expects base64 for images + return {"type": "text", "text": url} + + +def _convert_content(content: Any) -> Any: + """Convert OpenAI-format content blocks to Anthropic format for Fleet UI.""" + if not isinstance(content, list): + return content + converted = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "image_url": + converted.append(_convert_image_block(block)) + else: + converted.append(block) + return converted + + +async def create_trace_job(api_key: str, name: str) -> str: + """Create a Fleet trace job for grouping eval traces. + + Args: + api_key: Fleet API key. + name: Name for the trace job (e.g. "run_name_step_100"). + + Returns: + The job_id string. + """ + from fleet._async import AsyncFleet + + fleet = AsyncFleet(api_key=api_key) + return await fleet.trace_job(name=name) + + +async def upload_trace( + api_key: str, + job_id: str, + task_key: str, + model: str, + chat_history: List[Dict[str, Any]], + reward: float, + instance_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> Optional[str]: + """Upload a conversation trace to the Fleet API. + + Converts chat_history (OpenAI message format) to Fleet SessionIngestMessage + format and ingests it as a trace session. + + Args: + api_key: Fleet API key. + job_id: Trace job ID from create_trace_job(). + task_key: Fleet task key. + model: Model identifier (e.g. model path or name). + chat_history: List of messages in OpenAI format (system/user/assistant). + May contain multimodal content with image_url entries. + reward: Episode reward (>0 = completed, else failed). + instance_id: Optional Fleet environment instance ID. + metadata: Optional additional metadata dict. + + Returns: + The session_id string, or None if upload failed. + """ + try: + import httpx + + # Convert chat_history to ingest message format. + # Fleet ingest API expects image blocks as: {"type": "image", "mime_type": ..., "data": ...} + messages = [ + {"role": msg["role"], "content": _convert_content(msg.get("content"))} + for msg in chat_history + ] + + payload: Dict[str, Any] = { + "messages": messages, + "job_id": job_id, + "task_key": task_key, + "model": model, + "score": reward, + } + if instance_id: + payload["instance_id"] = instance_id + if metadata: + payload["metadata"] = metadata + + async with httpx.AsyncClient(timeout=60) as client: + response = await client.post( + "https://orchestrator.fleetai.com/v1/sessions/ingest", + json=payload, + headers={"Authorization": f"Bearer {api_key}"}, + ) + response.raise_for_status() + return response.json().get("session_id") + except Exception as e: + logger.warning(f"Failed to upload trace for {task_key}: {e}") + return None diff --git a/src/pyproject.toml b/src/pyproject.toml index 067237115..7cb404917 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -35,6 +35,13 @@ dev = [ "mypy>=1.0.0", ] +# Fleet runtime integration (optional) +fleet = [ + "mcp>=1.0.0", + "fleet-python>=0.2.79", + "openai>=2.11.0", +] + [project.scripts] openenv = "openenv_cli.__main__:main" diff --git a/tests/envs/test_fleet_env.py b/tests/envs/test_fleet_env.py new file mode 100644 index 000000000..2b9d87e03 --- /dev/null +++ b/tests/envs/test_fleet_env.py @@ -0,0 +1,727 @@ +import sys +import types + +import pytest + + +class _FakeResp: + def __init__(self, payload): + self._payload = payload + self.status_code = 200 + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +class _FakeSession: + def __init__(self): + self.calls = [] + + def post(self, url, json=None, headers=None, timeout=None): + self.calls.append(("POST", url, json)) + return _FakeResp( + {"observation": {"metadata": {}}, "reward": 0.0, "done": False} + ) + + def get(self, url, headers=None, timeout=None): + self.calls.append(("GET", url, None)) + return _FakeResp({"episode_id": "e1", "step_count": 0}) + + +@pytest.fixture +def anyio_backend(): + # Avoid running the anyio test against trio (not installed in this repo env). + return "asyncio" + + +@pytest.fixture +def fake_requests_session(monkeypatch): + # Avoid importing real `requests` in this sandboxed environment (it may fail + # while loading system CA bundles). core.http_env_client only needs Session. + fake_requests = types.SimpleNamespace(Session=_FakeSession) + monkeypatch.setitem(sys.modules, "requests", fake_requests) + + +@pytest.fixture +def fake_fleet_module(monkeypatch): + # Create a fake `fleet` module with Fleet.make returning an env with urls. + class _Urls: + def __init__(self): + self.root = "https://example/" + + class _Mgr: + api = "https://example/api/v1/env" + + self.manager = _Mgr() + + class _Env: + def __init__(self): + self.urls = _Urls() + self.closed = False + + def close(self): + self.closed = True + + class _Fleet: + def __init__(self, api_key=None): + self.api_key = api_key + + def make(self, **kwargs): + return _Env() + + mod = types.SimpleNamespace(Fleet=_Fleet) + monkeypatch.setitem(sys.modules, "fleet", mod) + + +@pytest.mark.usefixtures("fake_requests_session", "fake_fleet_module") +def test_fleet_env_from_fleet_returns_orchestrator_and_tools(): + from envs.fleet_env import FleetEnvClient, FleetMCPTools + + orch, tools = FleetEnvClient.from_fleet(api_key="k", env_key="e") + assert isinstance(orch, FleetEnvClient) + assert isinstance(tools, FleetMCPTools) + + +@pytest.mark.usefixtures("fake_requests_session", "fake_fleet_module") +def test_fleet_env_reset_uses_http_manager_base_url(): + from envs.fleet_env import FleetEnvClient + + orch, _tools = FleetEnvClient.from_fleet(api_key="k", env_key="e") + # reset() should hit {base}/reset + _ = orch.reset() + # access underlying fake session calls + calls = orch._http.calls # pylint: disable=protected-access + assert calls[-1][0] == "POST" + assert calls[-1][1].endswith("/reset") + + +@pytest.mark.usefixtures("fake_requests_session", "fake_fleet_module") +def test_fleet_env_step_rejects_tool_actions(): + from envs.fleet_env import FleetEnvClient, CallToolAction + + orch, _tools = FleetEnvClient.from_fleet(api_key="k", env_key="e") + with pytest.raises(TypeError): + orch.step( + CallToolAction(tool_name="computer", parameters={"action": "screenshot"}) + ) + + +@pytest.mark.anyio +async def test_agent_tools_list_and_call_routes(monkeypatch): + from envs.fleet_env.mcp_tools import FleetMCPTools + + class _Tool: + def __init__(self, name): + self.name = name + self.description = "" + self.inputSchema = {"type": "object", "properties": {}, "required": []} + + class _FakeMCPClient: + def __init__(self, url, api_key): + self.url = url + self.api_key = api_key + self.list_calls = 0 + + async def list_tools(self): + self.list_calls += 1 + if self.url.endswith("api/v1/mcp"): + return [_Tool("computer")] + return [_Tool("search_issues")] + + async def call_tool(self, name, args): + return {"url": self.url, "name": name, "args": args} + + def has_tool(self, name, tools_list=None): + return any(t.name == name for t in (tools_list or [])) + + monkeypatch.setattr("envs.fleet_env.mcp_tools.FleetMCPClient", _FakeMCPClient) + + tools = FleetMCPTools( + api_key="k", mcp_urls=("https://x/api/v1/mcp", "https://x/mcp") + ) + listed = await tools.list_tools() + assert sorted([t["function"]["name"] for t in listed.tools]) == [ + "computer", + "search_issues", + ] + + res = await tools.call_tool("computer", {"action": "screenshot"}) + assert res["url"].endswith("api/v1/mcp") + + +class TestFleetMCPClientExtractToolResult: + """Tests for FleetMCPClient._extract_tool_result().""" + + def test_extract_single_text_content(self): + """Should extract text from single TextContent.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + # Mock CallToolResult with TextContent + class _TextContent: + type = "text" + text = "file1.txt\nfile2.txt" + + class _Result: + content = [_TextContent()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + assert result == "file1.txt\nfile2.txt" + + def test_extract_json_text_content(self): + """Should parse JSON from text content.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + class _TextContent: + type = "text" + text = '{"status": "success", "count": 42}' + + class _Result: + content = [_TextContent()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + assert result == {"status": "success", "count": 42} + + def test_extract_multiple_text_contents(self): + """Should return list when multiple text contents.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + class _TextContent1: + type = "text" + text = "first" + + class _TextContent2: + type = "text" + text = "second" + + class _Result: + content = [_TextContent1(), _TextContent2()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + assert result == ["first", "second"] + + def test_extract_error_result(self): + """Should return error dict when isError=True.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + class _TextContent: + type = "text" + text = "Tool failed: permission denied" + + class _Result: + content = [_TextContent()] + isError = True + structuredContent = None + + result = client._extract_tool_result(_Result()) + assert result == {"error": "Tool failed: permission denied"} + + def test_extract_structured_content_fallback(self): + """Should use structuredContent when no text content.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + class _Result: + content = [] + isError = False + structuredContent = {"data": [1, 2, 3]} + + result = client._extract_tool_result(_Result()) + assert result == {"data": [1, 2, 3]} + + def test_extract_empty_result(self): + """Should return string repr for empty result.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + class _Result: + content = [] + isError = False + structuredContent = None + + def __str__(self): + return "EmptyResult()" + + result = client._extract_tool_result(_Result()) + assert result == "EmptyResult()" + + def test_extract_image_content(self): + """Should extract ImageContent as OpenAI-compatible format.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + # Mock MCP ImageContent + class _ImageContent: + type = "image" + data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + mimeType = "image/png" + + class _Result: + content = [_ImageContent()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + + # Should return list with single image_url item + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["type"] == "image_url" + assert "image_url" in result[0] + assert result[0]["image_url"]["url"].startswith("data:image/png;base64,") + + def test_extract_mixed_text_and_image_content(self): + """Should extract mixed text and image content.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + class _TextContent: + type = "text" + text = "Screenshot captured" + + class _ImageContent: + type = "image" + data = "base64imagedata" + mimeType = "image/jpeg" + + class _Result: + content = [_TextContent(), _ImageContent()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + + # Should return list with both items + assert isinstance(result, list) + assert len(result) == 2 + assert result[0]["type"] == "text" + assert result[0]["text"] == "Screenshot captured" + assert result[1]["type"] == "image_url" + assert result[1]["image_url"]["url"] == "data:image/jpeg;base64,base64imagedata" + + def test_extract_image_default_mimetype(self): + """Should default to image/png when mimeType is missing.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + class _ImageContent: + type = "image" + data = "somebase64data" + mimeType = None # Missing mimeType + + class _Result: + content = [_ImageContent()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + + assert isinstance(result, list) + assert result[0]["image_url"]["url"].startswith("data:image/png;base64,") + + def test_extract_base64_image_json_format(self): + """Should convert Fleet MCP's base64_image JSON format to OpenAI format.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + # Fleet MCP returns screenshot as JSON text with base64_image key + class _TextContent: + type = "text" + text = '{"base64_image": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."}' + + class _Result: + content = [_TextContent()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + + # Should be converted to OpenAI-compatible format + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["type"] == "image_url" + assert ( + result[0]["image_url"]["url"] == "data:image/jpeg;base64,/9j/4AAQSkZJRg..." + ) + + def test_extract_base64_image_preserves_other_json(self): + """Should preserve normal JSON responses that don't have base64_image.""" + from envs.fleet_env.fleet_mcp_client import FleetMCPClient + + client = FleetMCPClient(url="http://test", api_key="test") + + # Normal JSON response without base64_image + class _TextContent: + type = "text" + text = '{"status": "success", "data": [1, 2, 3]}' + + class _Result: + content = [_TextContent()] + isError = False + structuredContent = None + + result = client._extract_tool_result(_Result()) + + # Should return parsed dict as-is + assert isinstance(result, dict) + assert result["status"] == "success" + assert result["data"] == [1, 2, 3] + + +@pytest.fixture +def fake_fleet_module_with_db(monkeypatch): + """Fake fleet module whose env handle supports .db() for query/describe.""" + + class _Urls: + def __init__(self): + self.root = "https://example/" + + class _Mgr: + api = "https://example/api/v1/env" + + self.manager = _Mgr() + + class _DescribeResponse: + def model_dump(self): + return { + "success": True, + "resource_name": "seed", + "tables": [ + { + "name": "events", + "sql": "CREATE TABLE events (id INTEGER, title TEXT)", + "columns": [ + { + "name": "id", + "type": "INTEGER", + "notnull": True, + "primary_key": True, + }, + { + "name": "title", + "type": "TEXT", + "notnull": False, + "primary_key": False, + }, + ], + } + ], + "message": "Schema retrieved", + } + + class _QueryResponse: + def __init__(self, sql): + self._sql = sql + + def model_dump(self): + return { + "success": True, + "columns": ["id", "title"], + "rows": [[1, "Concert A"], [2, "Concert B"]], + "rows_affected": None, + "last_insert_id": None, + "error": None, + "message": "Query executed successfully", + } + + class _SQLiteResource: + def __init__(self, name): + self.name = name + + def describe(self): + return _DescribeResponse() + + def query(self, sql, args=None): + return _QueryResponse(sql) + + class _Env: + def __init__(self): + self.urls = _Urls() + self.closed = False + + def db(self, name="current"): + return _SQLiteResource(name) + + def close(self): + self.closed = True + + class _Fleet: + def __init__(self, api_key=None): + self.api_key = api_key + + def make(self, **kwargs): + return _Env() + + mod = types.SimpleNamespace(Fleet=_Fleet) + monkeypatch.setitem(sys.modules, "fleet", mod) + + +@pytest.mark.usefixtures("fake_requests_session", "fake_fleet_module_with_db") +class TestFleetEnvClientDbQuery: + """Tests for FleetEnvClient.describe_db / query_db.""" + + def test_describe_db_returns_schema(self): + from envs.fleet_env import FleetEnvClient + + orch, _ = FleetEnvClient.from_fleet( + api_key="k", + env_key="e", + data_key="d", + data_version="v1", + image_type="standard", + ) + result = orch.describe_db("seed") + assert result["success"] is True + assert len(result["tables"]) == 1 + assert result["tables"][0]["name"] == "events" + assert len(result["tables"][0]["columns"]) == 2 + + def test_describe_db_defaults_to_seed(self): + from envs.fleet_env import FleetEnvClient + + orch, _ = FleetEnvClient.from_fleet( + api_key="k", + env_key="e", + data_key="d", + data_version="v1", + image_type="standard", + ) + result = orch.describe_db() + assert result["resource_name"] == "seed" + + def test_query_db_returns_rows(self): + from envs.fleet_env import FleetEnvClient + + orch, _ = FleetEnvClient.from_fleet( + api_key="k", + env_key="e", + data_key="d", + data_version="v1", + image_type="standard", + ) + result = orch.query_db("SELECT * FROM events LIMIT 2") + assert result["success"] is True + assert result["columns"] == ["id", "title"] + assert len(result["rows"]) == 2 + assert result["rows"][0] == [1, "Concert A"] + + def test_query_db_defaults_to_seed(self): + from envs.fleet_env import FleetEnvClient + + orch, _ = FleetEnvClient.from_fleet( + api_key="k", + env_key="e", + data_key="d", + data_version="v1", + image_type="standard", + ) + result = orch.query_db("SELECT 1") + assert result["success"] is True + + +@pytest.mark.usefixtures("fake_requests_session", "fake_fleet_module_with_db") +class TestFleetEnvClientDbQueryAsync: + """Tests for async describe_db_async / query_db_async.""" + + @pytest.mark.anyio + async def test_describe_db_async(self): + from envs.fleet_env import FleetEnvClient + + orch, _ = FleetEnvClient.from_fleet( + api_key="k", + env_key="e", + data_key="d", + data_version="v1", + image_type="standard", + ) + result = await orch.describe_db_async("seed") + assert result["success"] is True + assert result["tables"][0]["name"] == "events" + + @pytest.mark.anyio + async def test_query_db_async(self): + from envs.fleet_env import FleetEnvClient + + orch, _ = FleetEnvClient.from_fleet( + api_key="k", + env_key="e", + data_key="d", + data_version="v1", + image_type="standard", + ) + result = await orch.query_db_async("SELECT * FROM events") + assert result["success"] is True + assert len(result["rows"]) == 2 + + +class TestFleetTaskEnvInitFetchesTools: + """Tests for FleetTaskEnv provisioning and fetching tools during reset().""" + + def test_init_defers_provisioning(self, monkeypatch): + """__init__ should NOT provision — provisioning is deferred to reset_async().""" + from envs.fleet_env.task_env import FleetTaskEnv + + task_config = { + "task_key": "test-task", + "prompt": "Test prompt", + "env_key": "test-env", + "task_modality": "tool_use", + } + + # __init__ should not call Fleet.make() — just store config + env = FleetTaskEnv(task_config, api_key="test-key") + + # Not provisioned yet + assert env._orch is None + assert env._tools is None + assert env._tools_cache is None + + def test_reset_provisions_and_returns_tools(self, monkeypatch): + """reset() should provision asynchronously and return tools.""" + from unittest.mock import MagicMock + + mock_orch = MagicMock() + mock_tools = MagicMock() + + # Create a proper coroutine for list_tools + async def mock_list_tools(): + return MagicMock( + tools=[{"type": "function", "function": {"name": "search"}}] + ) + + mock_tools.list_tools = mock_list_tools + + # Mock from_fleet_async (async classmethod) + async def mock_from_fleet_async(**kwargs): + return (mock_orch, mock_tools) + + monkeypatch.setattr( + "envs.fleet_env.task_env.FleetEnvClient.from_fleet_async", + mock_from_fleet_async, + ) + + from envs.fleet_env.task_env import FleetTaskEnv + + task_config = { + "task_key": "test-task", + "prompt": "Test prompt", + "env_key": "test-env", + "task_modality": "tool_use", + } + + env = FleetTaskEnv(task_config, api_key="test-key") + + # reset triggers provisioning + tool fetching + obs = env.reset() + + assert "tools" in obs + assert len(obs["tools"]) == 1 + assert obs["tools"][0]["function"]["name"] == "search" + + def test_reset_sync_returns_cached_tools(self, monkeypatch): + """Sync reset() should provision and return tools.""" + from unittest.mock import MagicMock + + mock_orch = MagicMock() + mock_tools = MagicMock() + + # Create a proper coroutine for list_tools + async def mock_list_tools(): + return MagicMock(tools=[{"type": "function", "function": {"name": "bash"}}]) + + mock_tools.list_tools = mock_list_tools + + # Mock from_fleet_async (async classmethod) + async def mock_from_fleet_async(**kwargs): + return (mock_orch, mock_tools) + + monkeypatch.setattr( + "envs.fleet_env.task_env.FleetEnvClient.from_fleet_async", + mock_from_fleet_async, + ) + + from envs.fleet_env.task_env import FleetTaskEnv + + task_config = { + "task_key": "test-task", + "prompt": "Test prompt", + "env_key": "test-env", + "task_modality": "tool_use", + } + + env = FleetTaskEnv(task_config, api_key="test-key") + + # Sync reset should provision and return tools + obs = env.reset() + + assert "tools" in obs + assert len(obs["tools"]) == 1 + assert obs["tools"][0]["function"]["name"] == "bash" + + def test_init_failure_emits_rollout_completed(self, monkeypatch): + """Init failure should emit fleet_rollout_started AND fleet_rollout_completed.""" + from unittest.mock import patch + + # Mock from_fleet_async to raise (simulates health check failure) + async def mock_from_fleet_async(**kwargs): + raise RuntimeError("health check failed") + + monkeypatch.setattr( + "envs.fleet_env.task_env.FleetEnvClient.from_fleet_async", + mock_from_fleet_async, + ) + + from envs.fleet_env.task_env import FleetTaskEnv + + task_config = { + "task_key": "test-task", + "prompt": "Test prompt", + "env_key": "fostgres", + "task_modality": "tool_use", + } + + env = FleetTaskEnv(task_config, api_key="test-key") + + telemetry_events = [] + + def capture_info(msg, **attrs): + telemetry_events.append((msg, attrs)) + + with patch("envs.fleet_env.task_env.fleet_info", capture_info): + with pytest.raises(RuntimeError, match="health check"): + env.reset() + + # Should have emitted both started and completed + event_names = [e[0] for e in telemetry_events] + assert "fleet_rollout_started" in event_names + assert "fleet_rollout_completed" in event_names + + # fleet_rollout_completed should have failure_reason="init_error" + completed = next( + e for e in telemetry_events if e[0] == "fleet_rollout_completed" + ) + assert completed[1]["failure_reason"] == "init_error" + assert completed[1]["reward"] == 0.0 + assert completed[1]["step_count"] == 0 diff --git a/tests/envs/test_fleet_task_env.py b/tests/envs/test_fleet_task_env.py new file mode 100644 index 000000000..859bf7e14 --- /dev/null +++ b/tests/envs/test_fleet_task_env.py @@ -0,0 +1,586 @@ +"""Unit tests for FleetTaskEnv.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +@pytest.fixture +def sample_task_config(): + """Sample task configuration for testing.""" + return { + "task_key": "test-task-001", + "prompt": "Search for flights from NYC to LA on January 15", + "env_key": "booking-com", + "env_version": "v1.2.3", + "data_key": "consumer", + "data_version": "v0.0.12", + "verifier_code": "async def verify(env): return True", + "task_modality": "tool_use", + } + + +@pytest.fixture +def sample_task_config_no_version(): + """Task config without version info.""" + return { + "task_key": "test-task-002", + "prompt": "Test prompt", + "env_key": "test-env", + "task_modality": "tool_use", + } + + +@pytest.fixture +def mock_fleet_env_client(): + """Create a mock FleetEnvClient.from_fleet that returns mocks. + + Returns tools=None to avoid triggering asyncio.run() in __init__ + which conflicts with pytest-asyncio's event loop. + """ + mock_orch = MagicMock() + mock_orch._fleet_env = MagicMock() # Fleet env handle for verifier + + with patch("envs.fleet_env.task_env.FleetEnvClient") as MockClient: + # Return tools=None to skip the asyncio.run(list_tools()) call in __init__ + MockClient.from_fleet.return_value = (mock_orch, None) + yield mock_orch, None + + +class TestFleetTaskEnvInit: + """Tests for FleetTaskEnv initialization.""" + + def test_init_with_api_key(self, sample_task_config, mock_fleet_env_client): + """Should initialize with explicit API key.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test-api-key") + assert env.api_key == "test-api-key" + assert env.task_key == "test-task-001" + assert env.prompt == "Search for flights from NYC to LA on January 15" + assert env.modality == "tool_use" + + def test_init_from_env_var( + self, sample_task_config, mock_fleet_env_client, monkeypatch + ): + """Should use FLEET_API_KEY env var if no api_key provided.""" + from envs.fleet_env.task_env import FleetTaskEnv + + monkeypatch.setenv("FLEET_API_KEY", "env-api-key") + env = FleetTaskEnv(sample_task_config) + assert env.api_key == "env-api-key" + + def test_init_raises_without_api_key(self, sample_task_config, monkeypatch): + """Should raise if no API key available.""" + from envs.fleet_env.task_env import FleetTaskEnv + + monkeypatch.delenv("FLEET_API_KEY", raising=False) + with pytest.raises(ValueError, match="Fleet API key required"): + FleetTaskEnv(sample_task_config) + + +class TestFleetTaskEnvSpecs: + """Tests for env/data spec building.""" + + def test_build_env_spec_with_version( + self, sample_task_config, mock_fleet_env_client + ): + """Should build env_key:version spec.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + spec = env._build_env_spec() + assert spec == "booking-com:v1.2.3" + + def test_build_env_spec_without_version( + self, sample_task_config_no_version, mock_fleet_env_client + ): + """Should return just env_key when no version.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config_no_version, api_key="test") + spec = env._build_env_spec() + assert spec == "test-env" + + def test_get_data_key_with_data(self, sample_task_config, mock_fleet_env_client): + """Should return data_key from config.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + assert env._get_data_key() == "consumer" + assert env._get_data_version() == "v0.0.12" + + def test_get_data_key_without_data( + self, sample_task_config_no_version, mock_fleet_env_client + ): + """Should return None when no data_key.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config_no_version, api_key="test") + assert env._get_data_key() is None + assert env._get_data_version() is None + + def test_build_env_spec_raises_without_env_key(self, mock_fleet_env_client): + """Should raise when env_key is missing during init.""" + from envs.fleet_env.task_env import FleetTaskEnv + + task = {"task_key": "test", "prompt": "test"} + # The error is raised during __init__ when _build_env_spec is called + with pytest.raises(ValueError, match="missing env_key"): + FleetTaskEnv(task, api_key="test") + + +class TestFleetTaskEnvVerifier: + """Tests for verifier execution using Fleet SDK.""" + + @pytest.mark.anyio + async def test_compute_reward_returns_score_on_success( + self, sample_task_config, mock_fleet_env_client + ): + """Should return verifier result score when Fleet SDK verifier succeeds.""" + from envs.fleet_env.task_env import FleetTaskEnv + + mock_orch, _ = mock_fleet_env_client + env = FleetTaskEnv(sample_task_config, api_key="test") + + # Mock Fleet SDK Task.verify_detailed + mock_response = MagicMock() + mock_response.success = True + mock_response.result = 1.0 + + with patch("fleet.tasks.Task") as MockTask: + mock_task = MagicMock() + mock_task.verify_detailed.return_value = mock_response + MockTask.return_value = mock_task + + result = await env._compute_reward() + assert result == 1.0 + mock_task.verify_detailed.assert_called_once_with(mock_orch._fleet_env) + + @pytest.mark.anyio + async def test_compute_reward_returns_zero_on_failure( + self, sample_task_config, mock_fleet_env_client + ): + """Should return 0.0 when Fleet SDK verifier fails.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + + # Mock Fleet SDK Task.verify_detailed with failure + mock_response = MagicMock() + mock_response.success = False + mock_response.result = None + + with patch("fleet.tasks.Task") as MockTask: + mock_task = MagicMock() + mock_task.verify_detailed.return_value = mock_response + MockTask.return_value = mock_task + + result = await env._compute_reward() + assert result == 0.0 + + @pytest.mark.anyio + async def test_compute_reward_returns_zero_when_no_verifier( + self, sample_task_config_no_version, mock_fleet_env_client + ): + """Should return 0.0 when no verifier code is present.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config_no_version, api_key="test") + + result = await env._compute_reward() + assert result == 0.0 + + @pytest.mark.anyio + async def test_compute_reward_returns_zero_when_no_orch( + self, sample_task_config, mock_fleet_env_client + ): + """Should return 0.0 when no orchestrator is available.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + env._orch = None + + result = await env._compute_reward() + assert result == 0.0 + + @pytest.mark.anyio + async def test_compute_reward_returns_zero_when_no_fleet_env( + self, sample_task_config, mock_fleet_env_client + ): + """Should return 0.0 when no Fleet env handle is available.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + env._orch._fleet_env = None # No Fleet env handle + + result = await env._compute_reward() + assert result == 0.0 + + @pytest.mark.anyio + async def test_compute_reward_handles_verifier_exception( + self, sample_task_config, mock_fleet_env_client + ): + """Should return 0.0 when verifier raises an exception.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + + with patch("fleet.tasks.Task") as MockTask: + mock_task = MagicMock() + mock_task.verify_detailed.side_effect = Exception("Verifier error") + MockTask.return_value = mock_task + + result = await env._compute_reward() + assert result == 0.0 + + @pytest.mark.anyio + async def test_compute_reward_handles_success_with_none_result( + self, sample_task_config, mock_fleet_env_client + ): + """Should return 1.0 when verifier succeeds but returns None.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + + mock_response = MagicMock() + mock_response.success = True + mock_response.result = None + + with patch("fleet.tasks.Task") as MockTask: + mock_task = MagicMock() + mock_task.verify_detailed.return_value = mock_response + MockTask.return_value = mock_task + + result = await env._compute_reward() + assert result == 1.0 + + @pytest.mark.anyio + async def test_compute_reward_supports_verifier_func_field( + self, mock_fleet_env_client + ): + """Should support 'verifier_func' field name (Fleet SDK format).""" + from envs.fleet_env.task_env import FleetTaskEnv + + # Task config using 'verifier_func' instead of 'verifier_code' + task_config = { + "task_key": "test-task-003", + "prompt": "Test prompt", + "env_key": "test-env", + "verifier_func": "def verify(env): return 1.0", # Fleet SDK field name + "task_modality": "tool_use", + } + + env = FleetTaskEnv(task_config, api_key="test") + + mock_response = MagicMock() + mock_response.success = True + mock_response.result = 1.0 + + with patch("fleet.tasks.Task") as MockTask: + mock_task = MagicMock() + mock_task.verify_detailed.return_value = mock_response + MockTask.return_value = mock_task + + result = await env._compute_reward() + assert result == 1.0 + + +class TestFleetTaskEnvFactories: + """Tests for factory methods.""" + + def test_make_fleet_task_env(self, sample_task_config, mock_fleet_env_client): + """Should create FleetTaskEnv via factory function.""" + from envs.fleet_env.task_env import make_fleet_task_env + + env = make_fleet_task_env(sample_task_config, api_key="test") + assert env.task_key == "test-task-001" + + +class TestFleetTaskEnvContextManager: + """Tests for context manager protocol.""" + + def test_context_manager_closes_on_exit( + self, sample_task_config, mock_fleet_env_client + ): + """Should close environment on context exit.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + + with env: + pass # Context enters and exits + + # Environment should be closed + assert env._orch is None + assert env._tools is None + assert env._done is True + + +class TestFleetTaskEnvProperties: + """Tests for property accessors.""" + + def test_task_key_property(self, sample_task_config, mock_fleet_env_client): + """Should return task_key from config.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + assert env.task_key == "test-task-001" + + def test_task_key_default(self, mock_fleet_env_client): + """Should return 'unknown' when task_key missing.""" + from envs.fleet_env.task_env import FleetTaskEnv + + task = {"prompt": "test", "env_key": "test-env"} + env = FleetTaskEnv(task, api_key="test") + assert env.task_key == "unknown" + + def test_prompt_property(self, sample_task_config, mock_fleet_env_client): + """Should return prompt from config.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + assert env.prompt == "Search for flights from NYC to LA on January 15" + + def test_modality_property(self, sample_task_config, mock_fleet_env_client): + """Should return task_modality from config.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + assert env.modality == "tool_use" + + def test_modality_default(self, mock_fleet_env_client): + """Should default to 'tool_use' when modality missing.""" + from envs.fleet_env.task_env import FleetTaskEnv + + task = {"task_key": "test", "prompt": "test", "env_key": "test-env"} + env = FleetTaskEnv(task, api_key="test") + assert env.modality == "tool_use" + + +class TestFleetTaskEnvComputerUseFiltering: + """Tests for computer_use modality tool filtering.""" + + @pytest.fixture + def mock_fleet_env_with_tools(self): + """Create mock FleetEnvClient that returns tools.""" + mock_orch = MagicMock() + mock_tools = MagicMock() + + with patch("envs.fleet_env.task_env.FleetEnvClient") as MockClient: + MockClient.from_fleet.return_value = (mock_orch, mock_tools) + yield mock_orch, mock_tools + + @pytest.mark.anyio + async def test_computer_use_filters_to_computer_tool( + self, mock_fleet_env_with_tools + ): + """Should filter to only 'computer' tool for computer_use modality.""" + from envs.fleet_env.task_env import FleetTaskEnv + + mock_orch, mock_tools = mock_fleet_env_with_tools + + # Mock list_tools returning mixed tools (computer + API tools) + async def mock_list_tools(): + return MagicMock( + tools=[ + {"name": "computer", "description": "Mouse/keyboard control"}, + {"name": "search_issues", "description": "Search issues"}, + {"name": "create_ticket", "description": "Create ticket"}, + ] + ) + + mock_tools.list_tools = mock_list_tools + + task_config = { + "task_key": "test-task", + "prompt": "Click on button", + "env_key": "test-env", + "task_modality": "computer_use", + } + + env = FleetTaskEnv(task_config, api_key="test") + obs = await env.reset_async() + + # Should only have computer tool + assert len(env._tools_cache) == 1 + assert env._tools_cache[0]["name"] == "computer" + + @pytest.mark.anyio + async def test_computer_use_clears_tools_when_no_computer_tool( + self, mock_fleet_env_with_tools, caplog + ): + """Should clear tools and warn when no 'computer' tool for computer_use modality.""" + from envs.fleet_env.task_env import FleetTaskEnv + import logging + + mock_orch, mock_tools = mock_fleet_env_with_tools + + # Mock list_tools returning only API tools (no computer tool) + async def mock_list_tools(): + return MagicMock( + tools=[ + {"name": "search_issues", "description": "Search issues"}, + {"name": "create_ticket", "description": "Create ticket"}, + ] + ) + + mock_tools.list_tools = mock_list_tools + + task_config = { + "task_key": "sentry-task", + "prompt": "Click on button", + "env_key": "sentry", + "task_modality": "computer_use", + } + + env = FleetTaskEnv(task_config, api_key="test") + + with caplog.at_level(logging.WARNING): + obs = await env.reset_async() + + # Should have empty tools (filtered out API tools) + assert env._tools_cache == [] + + # Should have logged warning + assert "computer_use modality but no 'computer' tool found" in caplog.text + + @pytest.mark.anyio + async def test_tool_use_does_not_filter(self, mock_fleet_env_with_tools): + """Should NOT filter tools for tool_use modality.""" + from envs.fleet_env.task_env import FleetTaskEnv + + mock_orch, mock_tools = mock_fleet_env_with_tools + + # Mock list_tools returning mixed tools + async def mock_list_tools(): + return MagicMock( + tools=[ + {"name": "computer", "description": "Mouse/keyboard control"}, + {"name": "search_issues", "description": "Search issues"}, + {"name": "create_ticket", "description": "Create ticket"}, + ] + ) + + mock_tools.list_tools = mock_list_tools + + task_config = { + "task_key": "test-task", + "prompt": "Search for issues", + "env_key": "test-env", + "task_modality": "tool_use", # tool_use, not computer_use + } + + env = FleetTaskEnv(task_config, api_key="test") + obs = await env.reset_async() + + # Should have all 3 tools + assert len(env._tools_cache) == 3 + + @pytest.mark.anyio + async def test_computer_use_filters_function_format( + self, mock_fleet_env_with_tools + ): + """Should filter 'computer' tool from function format.""" + from envs.fleet_env.task_env import FleetTaskEnv + + mock_orch, mock_tools = mock_fleet_env_with_tools + + # Mock list_tools returning tools in OpenAI function format + async def mock_list_tools(): + return MagicMock( + tools=[ + { + "type": "function", + "function": {"name": "computer", "description": "Control"}, + }, + { + "type": "function", + "function": {"name": "api_call", "description": "API"}, + }, + ] + ) + + mock_tools.list_tools = mock_list_tools + + task_config = { + "task_key": "test-task", + "prompt": "Click button", + "env_key": "test-env", + "task_modality": "computer_use", + } + + env = FleetTaskEnv(task_config, api_key="test") + obs = await env.reset_async() + + # Should only have computer tool + assert len(env._tools_cache) == 1 + assert env._tools_cache[0]["function"]["name"] == "computer" + + +class TestSubmitFinalAnswer: + """Tests for synthetic submit_final_answer tool injection.""" + + def test_submit_final_answer_tool_definition(self, mock_fleet_env_client): + """SUBMIT_FINAL_ANSWER_TOOL has correct schema.""" + from envs.fleet_env.task_env import SUBMIT_FINAL_ANSWER_TOOL + + func = SUBMIT_FINAL_ANSWER_TOOL["function"] + assert func["name"] == "submit_final_answer" + assert "answer" in func["parameters"]["properties"] + assert func["parameters"]["required"] == ["answer"] + + def test_submitted_answer_init(self, sample_task_config, mock_fleet_env_client): + """_submitted_answer should be None on init.""" + from envs.fleet_env.task_env import FleetTaskEnv + + env = FleetTaskEnv(sample_task_config, api_key="test") + assert env._submitted_answer is None + + @pytest.mark.anyio + async def test_step_submit_final_answer_stores_answer( + self, sample_task_config, mock_fleet_env_client + ): + """Calling submit_final_answer should store the answer and mark done.""" + from envs.fleet_env.task_env import FleetTaskEnv + + mock_orch, _ = mock_fleet_env_client + env = FleetTaskEnv(sample_task_config, api_key="test") + env._orch = mock_orch + env._tools = MagicMock() + env._tools_cache = [{"type": "function", "function": {"name": "bash"}}] + env._done = False + env._rollout_started = True + + action = {"tool": "submit_final_answer", "params": {"answer": '["row1", "row2"]'}} + obs, reward, done, info = await env.step_async(action) + + assert env._submitted_answer == '["row1", "row2"]' + assert done is True + assert info["submitted_answer"] == '["row1", "row2"]' + assert info["tool_result"]["status"] == "submitted" + + @pytest.mark.anyio + async def test_step_submit_final_answer_not_routed_to_mcp( + self, sample_task_config, mock_fleet_env_client + ): + """submit_final_answer should NOT call MCP tools.call_tool.""" + from envs.fleet_env.task_env import FleetTaskEnv + + mock_orch, _ = mock_fleet_env_client + mock_tools = AsyncMock() + env = FleetTaskEnv(sample_task_config, api_key="test") + env._orch = mock_orch + env._tools = mock_tools + env._tools_cache = [{"type": "function", "function": {"name": "bash"}}] + env._done = False + env._rollout_started = True + + action = {"tool": "submit_final_answer", "params": {"answer": "42"}} + await env.step_async(action) + + mock_tools.call_tool.assert_not_called()