From ce625bf198555335abf5b280f5bdd975e4ee310f Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 02:03:41 -0800 Subject: [PATCH 01/92] init new environments --- hud/__init__.py | 2 + hud/environment/__init__.py | 46 +++ hud/environment/connection.py | 159 ++++++++ hud/environment/connectors/__init__.py | 38 ++ hud/environment/connectors/base.py | 65 ++++ hud/environment/connectors/local.py | 147 +++++++ hud/environment/connectors/mcp_config.py | 99 +++++ hud/environment/connectors/openai.py | 107 ++++++ hud/environment/connectors/remote.py | 167 ++++++++ hud/environment/connectors/task.py | 104 +++++ hud/environment/environment.py | 447 ++++++++++++++++++++++ hud/environment/integrations/__init__.py | 36 ++ hud/environment/integrations/anthropic.py | 206 ++++++++++ hud/environment/integrations/gemini.py | 93 +++++ hud/environment/integrations/langchain.py | 114 ++++++ hud/environment/integrations/openai.py | 202 ++++++++++ hud/environment/mock.py | 306 +++++++++++++++ hud/environment/router.py | 105 +++++ hud/environment/utils/__init__.py | 25 ++ hud/environment/utils/formats.py | 213 +++++++++++ hud/environment/utils/schema.py | 97 +++++ hud/trace/__init__.py | 42 ++ hud/trace/context.py | 357 +++++++++++++++++ hud/trace/mixin.py | 382 ++++++++++++++++++ hud/trace/parallel.py | 131 +++++++ hud/types.py | 4 +- 26 files changed, 3693 insertions(+), 1 deletion(-) create mode 100644 hud/environment/__init__.py create mode 100644 hud/environment/connection.py create mode 100644 hud/environment/connectors/__init__.py create mode 100644 hud/environment/connectors/base.py create mode 100644 hud/environment/connectors/local.py create mode 100644 hud/environment/connectors/mcp_config.py create mode 100644 hud/environment/connectors/openai.py create mode 100644 hud/environment/connectors/remote.py create mode 100644 hud/environment/connectors/task.py create mode 100644 hud/environment/environment.py create mode 100644 hud/environment/integrations/__init__.py create mode 100644 hud/environment/integrations/anthropic.py create mode 100644 hud/environment/integrations/gemini.py create mode 100644 hud/environment/integrations/langchain.py create mode 100644 hud/environment/integrations/openai.py create mode 100644 hud/environment/mock.py create mode 100644 hud/environment/router.py create mode 100644 hud/environment/utils/__init__.py create mode 100644 hud/environment/utils/formats.py create mode 100644 hud/environment/utils/schema.py create mode 100644 hud/trace/__init__.py create mode 100644 hud/trace/context.py create mode 100644 hud/trace/mixin.py create mode 100644 hud/trace/parallel.py diff --git a/hud/__init__.py b/hud/__init__.py index 072dde86..2f4eef69 100644 --- a/hud/__init__.py +++ b/hud/__init__.py @@ -5,6 +5,7 @@ from __future__ import annotations +from .environment import Environment from .telemetry import ( Trace, async_job, @@ -18,6 +19,7 @@ ) __all__ = [ + "Environment", "Trace", "async_job", "async_trace", diff --git a/hud/environment/__init__.py b/hud/environment/__init__.py new file mode 100644 index 00000000..57634e51 --- /dev/null +++ b/hud/environment/__init__.py @@ -0,0 +1,46 @@ +""" +HUD Environment - A unified abstraction for MCP environments. + +The Environment class is a server that you can also use as a client. +It subclasses MCPServer to get server capabilities (@env.tool, serve()) +and composes FastMCP Client instances for remote connections. 
+ +Usage: + from hud.environment import Environment + + # Create and connect + env = Environment("my-env").connect_hub("browser", prefix="web") + + async with env: + # Get tools in any format + openai_tools = env.as_openai_chat_tools() + claude_tools = env.as_claude_tools() + + # Call tools with any format - auto-parses and returns matching format + result = await env.call_tool("web_navigate", url="https://google.com") + + # Framework integrations (requires external deps) + agent_tools = env.as_openai_agent_tools() # needs openai-agents + lc_tools = env.as_langchain_tools() # needs langchain-core +""" + +from hud.environment.connection import ConnectionConfig, ConnectionType, Connector +from hud.environment.environment import Environment +from hud.environment.mock import MockMixin, generate_mock_value +from hud.environment.router import ConflictResolution, ToolRouter +from hud.environment.utils import ToolFormat, format_result, parse_tool_call, parse_tool_calls + +__all__ = [ + "ConflictResolution", + "ConnectionConfig", + "ConnectionType", + "Connector", + "Environment", + "MockMixin", + "ToolFormat", + "ToolRouter", + "format_result", + "generate_mock_value", + "parse_tool_call", + "parse_tool_calls", +] diff --git a/hud/environment/connection.py b/hud/environment/connection.py new file mode 100644 index 00000000..dd05ccde --- /dev/null +++ b/hud/environment/connection.py @@ -0,0 +1,159 @@ +"""Connection management for MCP servers.""" + +from __future__ import annotations + +import logging +from enum import Enum +from typing import TYPE_CHECKING, Any + +import mcp.types as mcp_types +from fastmcp.client import Client as FastMCPClient + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastmcp.tools.tool import Tool + +__all__ = ["ConnectionConfig", "ConnectionType", "Connector"] + +logger = logging.getLogger(__name__) + + +class ConnectionType(str, Enum): + """Type of connection - determines parallelization capability.""" + + LOCAL = "local" # Stdio/Docker - single instance, not parallelizable + REMOTE = "remote" # HTTP/URL - can spawn multiple instances + + +class ConnectionConfig: + """Configuration for filtering/transforming tools from a remote connection.""" + + def __init__( + self, + *, + prefix: str | None = None, + include: list[str] | None = None, + exclude: list[str] | None = None, + transform: Callable[[Tool], Tool | None] | None = None, + ) -> None: + self.prefix = prefix + self.include = include + self.exclude = exclude + self.transform = transform + + +class Connector: + """Manages a connection to an MCP server with tool caching.""" + + def __init__( + self, + client: FastMCPClient[Any], + config: ConnectionConfig, + name: str, + connection_type: ConnectionType, + ) -> None: + self.client = client + self.config = config + self.name = name + self.connection_type = connection_type + self._tools_cache: list[mcp_types.Tool] | None = None + + @property + def is_local(self) -> bool: + """True if this is a local (non-parallelizable) connection.""" + return self.connection_type == ConnectionType.LOCAL + + @property + def is_remote(self) -> bool: + """True if this is a remote (parallelizable) connection.""" + return self.connection_type == ConnectionType.REMOTE + + @property + def is_connected(self) -> bool: + return self.client.is_connected() + + @property + def cached_tools(self) -> list[mcp_types.Tool]: + return self._tools_cache or [] + + async def connect(self) -> None: + """Connect using FastMCP Client's context manager.""" + if not self.is_connected: + await 
self.client.__aenter__() + + async def disconnect(self) -> None: + """Disconnect and clear cache.""" + if self.is_connected: + await self.client.__aexit__(None, None, None) + self._tools_cache = None + + async def list_tools(self) -> list[mcp_types.Tool]: + """Fetch tools from server, apply filters/transforms/prefix, and cache.""" + tools = await self.client.list_tools() + + result: list[mcp_types.Tool] = [] + for tool in tools: + # Apply include/exclude filter + if self.config.include is not None and tool.name not in self.config.include: + continue + if self.config.exclude is not None and tool.name in self.config.exclude: + continue + + # Apply transform + if self.config.transform is not None: + from fastmcp.tools.tool import Tool as FastMCPTool + + fastmcp_tool = FastMCPTool.model_construct( + name=tool.name, + description=tool.description or "", + parameters=tool.inputSchema, + ) + transformed = self.config.transform(fastmcp_tool) + if transformed is None: + continue + tool = mcp_types.Tool( + name=transformed.name, + description=transformed.description, + inputSchema=transformed.parameters, + ) + + # Apply prefix + name = f"{self.config.prefix}_{tool.name}" if self.config.prefix else tool.name + result.append(mcp_types.Tool( + name=name, + description=tool.description, + inputSchema=tool.inputSchema, + )) + + self._tools_cache = result + return result + + async def call_tool( + self, name: str, arguments: dict[str, Any] | None = None + ) -> mcp_types.CallToolResult: + """Call a tool, stripping prefix if needed.""" + # Strip prefix when calling remote + if self.config.prefix and name.startswith(f"{self.config.prefix}_"): + name = name[len(self.config.prefix) + 1:] + return await self.client.call_tool_mcp(name, arguments or {}) + + async def list_resources(self) -> list[mcp_types.Resource]: + return await self.client.list_resources() + + async def list_prompts(self) -> list[mcp_types.Prompt]: + return await self.client.list_prompts() + + async def read_resource( + self, uri: str + ) -> list[mcp_types.TextResourceContents | mcp_types.BlobResourceContents]: + return await self.client.read_resource(uri) + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> mcp_types.GetPromptResult: + return await self.client.get_prompt(name, arguments) + + def __repr__(self) -> str: + t = self.connection_type.value + return f"Connector({self.name!r}, {t}, connected={self.is_connected})" diff --git a/hud/environment/connectors/__init__.py b/hud/environment/connectors/__init__.py new file mode 100644 index 00000000..e99850da --- /dev/null +++ b/hud/environment/connectors/__init__.py @@ -0,0 +1,38 @@ +"""Connection connectors - methods for connecting to various sources.""" + +from hud.environment.connectors.local import LocalConnectorMixin +from hud.environment.connectors.openai import OpenAIConnectorMixin +from hud.environment.connectors.remote import RemoteConnectorMixin +from hud.environment.connectors.task import TaskConnectorMixin + +__all__ = ["ConnectorsMixin"] + + +class ConnectorsMixin( + RemoteConnectorMixin, + LocalConnectorMixin, + TaskConnectorMixin, + OpenAIConnectorMixin, +): + """Combined connector mixin providing all connection methods. 
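+
+    Connectors return `self`, so several sources can be chained onto one
+    environment (a minimal sketch; the environment name, hub slug, and image
+    name are illustrative, taken from the examples in the connector modules):
+
+    ```python
+    env = (
+        Environment("combo")
+        .connect_hub("hud/browser", prefix="web")
+        .connect_image("mcp/fetch", alias="fetch")
+    )
+    ```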
+ + Remote connections: + connect_hub(slug) - HUD Hub environment (fetches mcp_config from API) + connect_url(url) - MCP server via URL + connect_openapi(spec) - Mount OpenAPI spec as MCP server + + Local connections (in-process): + connect_image(image) - Docker image via stdio + connect_fastapi(app) - Mount FastAPI app as MCP server + connect_server(server) - Mount MCPServer/FastMCP directly + + MCP config: + connect_mcp(config) - Single mcp_config server (auto-detects local/remote) + connect_mcp_config(mcp_config) - Multiple mcp_config servers + + Task: + connect_task(slug) - Load task from platform by slug + + Framework imports: + connect_function_tools(tools) - Import OpenAI Agents SDK FunctionTools + """ diff --git a/hud/environment/connectors/base.py b/hud/environment/connectors/base.py new file mode 100644 index 00000000..3d25e78f --- /dev/null +++ b/hud/environment/connectors/base.py @@ -0,0 +1,65 @@ +"""Base connector mixin with shared helper.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastmcp.tools.tool import Tool + + from hud.environment.connection import ConnectionType, Connector + +__all__ = ["BaseConnectorMixin"] + + +class BaseConnectorMixin: + """Base mixin providing connection helper. + + Requires: + _connections: dict[str, Connector] + """ + + _connections: dict[str, Connector] + + def _add_connection( + self, + name: str, + transport: Any, + *, + connection_type: ConnectionType, + auth: str | None = None, + prefix: str | None = None, + include: list[str] | None = None, + exclude: list[str] | None = None, + transform: Callable[[Tool], Tool | None] | None = None, + ) -> Any: + """Add a connection to the environment. + + Args: + name: Connection name/alias. + transport: FastMCP transport (URL, config dict, etc.). + connection_type: LOCAL or REMOTE - determines parallelization. + auth: Authorization header value. + prefix: Prefix for tool names. + include: Only include these tools. + exclude: Exclude these tools. + transform: Transform function for tools. + + Returns: + self for chaining. + """ + from fastmcp.client import Client as FastMCPClient + + from hud.environment.connection import ConnectionConfig, Connector + + config = ConnectionConfig( + prefix=prefix, include=include, exclude=exclude, transform=transform, + ) + client = FastMCPClient(transport=transport, auth=auth) + self._connections[name] = Connector( + client, config, name, connection_type=connection_type, + ) + return self + diff --git a/hud/environment/connectors/local.py b/hud/environment/connectors/local.py new file mode 100644 index 00000000..6ae170b8 --- /dev/null +++ b/hud/environment/connectors/local.py @@ -0,0 +1,147 @@ +"""Local connection connectors - Docker image, FastAPI, MCPServer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastmcp.tools.tool import Tool + +__all__ = ["LocalConnectorMixin"] + + +class LocalConnectorMixin(MCPConfigConnectorMixin): + """Mixin providing local connection methods. + + Methods: + connect_image(image) - Run Docker image via stdio + connect_fastapi(app) - Mount FastAPI app as MCP server + connect_server(server) - Mount any MCPServer/FastMCP directly + + Inherits connect_mcp() from MCPConfigConnectorMixin. 
+ """ + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + """Mount method from MCPServer base class.""" + raise NotImplementedError + + def connect_image( + self, + image: str, + *, + alias: str | None = None, + docker_args: list[str] | None = None, + env_vars: dict[str, str] | None = None, + prefix: str | None = None, + include: list[str] | None = None, + exclude: list[str] | None = None, + transform: Callable[[Tool], Tool | None] | None = None, + ) -> Any: + """Connect to a Docker image via stdio. + + Creates an MCP config that runs: docker run -i --rm {image} + Environment variables from `.env` files are auto-injected. + + Example: + ```python + env = Environment("my-env") + env.connect_image("mcp/fetch") + + async with env: + result = await env.call_tool("fetch", url="https://example.com") + ``` + """ + from hud.cli.utils.docker import create_docker_run_command + + cmd = create_docker_run_command( + image=image, + docker_args=docker_args, + extra_env=env_vars, + interactive=True, + remove=True, + ) + + name = alias or image + mcp_config = { + name: { + "command": cmd[0], + "args": cmd[1:], + } + } + return self.connect_mcp( + mcp_config, + alias=name, + prefix=prefix, + include=include, + exclude=exclude, + transform=transform, + ) + + def connect_fastapi( + self, + app: Any, + *, + name: str | None = None, + prefix: str | None = None, + ) -> Any: + """Mount a FastAPI application as an MCP server. + + Uses FastMCP's from_fastapi() to convert FastAPI endpoints to MCP tools. + + Example: + ```python + from fastapi import FastAPI + + api = FastAPI() + + @api.get("/users/{user_id}", operation_id="get_user") + def get_user(user_id: int): + return {"id": user_id, "name": "Alice"} + + env = Environment("my-env") + env.connect_fastapi(api) + + async with env: + result = await env.call_tool("get_user", user_id=1) + ``` + + Tip: Use operation_id in FastAPI decorators for cleaner tool names. + """ + from fastmcp import FastMCP + + server_name = name or getattr(app, "title", None) or "fastapi" + mcp_server = FastMCP.from_fastapi(app=app, name=server_name) + self.mount(mcp_server, prefix=prefix) + return self + + def connect_server( + self, + server: Any, + *, + prefix: str | None = None, + ) -> Any: + """Mount an MCPServer or FastMCP instance directly. + + Example: + ```python + from fastmcp import FastMCP + + tools = FastMCP("tools") + + @tools.tool + def greet(name: str) -> str: + return f"Hello, {name}!" 
+ + env = Environment("my-env") + env.connect_server(tools) + + async with env: + result = await env.call_tool("greet", name="World") + ``` + """ + self.mount(server, prefix=prefix) + return self diff --git a/hud/environment/connectors/mcp_config.py b/hud/environment/connectors/mcp_config.py new file mode 100644 index 00000000..95581974 --- /dev/null +++ b/hud/environment/connectors/mcp_config.py @@ -0,0 +1,99 @@ +"""MCP config connection connectors.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from hud.environment.connectors.base import BaseConnectorMixin + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastmcp.tools.tool import Tool + +__all__ = ["MCPConfigConnectorMixin"] + + +class MCPConfigConnectorMixin(BaseConnectorMixin): + """Mixin providing mcp_config connection methods.""" + + def connect_mcp( + self, + config: dict[str, dict[str, Any]], + *, + alias: str | None = None, + prefix: str | None = None, + include: list[str] | None = None, + exclude: list[str] | None = None, + transform: Callable[[Tool], Tool | None] | None = None, + ) -> Any: + """Connect using an mcp_config dictionary (single server). + + Auto-detects LOCAL (stdio) vs REMOTE (URL) based on config. + + Example: + ```python + env = Environment("my-env") + + # Stdio server + env.connect_mcp({ + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + } + }) + + async with env: + await env.call_tool("read_file", path="/tmp/test.txt") + ``` + """ + from hud.environment.connection import ConnectionType + + name = alias or next(iter(config.keys()), "mcp") + server_config = next(iter(config.values()), {}) + + is_local = "command" in server_config or "args" in server_config + conn_type = ConnectionType.LOCAL if is_local else ConnectionType.REMOTE + + return self._add_connection( + name, + config, + connection_type=conn_type, + prefix=prefix, + include=include, + exclude=exclude, + transform=transform, + ) + + def connect_mcp_config( + self, + mcp_config: dict[str, dict[str, Any]], + **kwargs: Any, + ) -> Any: + """Connect multiple servers from an mcp_config dictionary. 
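+
+        Keyword arguments (`prefix`, `include`, `exclude`, `transform`) are
+        forwarded to every server in the config.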
+ + Example: + ```python + env = Environment("my-env") + + # Claude Desktop style config + env.connect_mcp_config({ + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + }, + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": {"GITHUB_TOKEN": "..."}, + }, + }) + + async with env: + await env.call_tool("read_file", path="/tmp/test.txt") + await env.call_tool("search_repositories", query="mcp") + ``` + """ + for server_name, server_config in mcp_config.items(): + self.connect_mcp({server_name: server_config}, alias=server_name, **kwargs) + return self diff --git a/hud/environment/connectors/openai.py b/hud/environment/connectors/openai.py new file mode 100644 index 00000000..fdaea52a --- /dev/null +++ b/hud/environment/connectors/openai.py @@ -0,0 +1,107 @@ +"""OpenAI Agents SDK connectors - import tools from OpenAI agents.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + +__all__ = ["OpenAIConnectorMixin"] + +# Lazy import check +try: + from agents import FunctionTool + _HAS_OPENAI_AGENTS = True +except ImportError: + _HAS_OPENAI_AGENTS = False + FunctionTool = None # type: ignore[misc, assignment] + + +class OpenAIConnectorMixin: + """Mixin providing OpenAI Agents SDK connector methods.""" + + # These are defined on Environment/MCPServer + _tool_manager: Any + + def connect_function_tools( + self, + tools: list[Any], + *, + prefix: str | None = None, + ) -> Any: + """Import FunctionTools from the OpenAI Agents SDK. + + Wraps each tool so calls go through HUD with telemetry. + + Example: + ```python + from agents import function_tool + + @function_tool + def search(query: str) -> str: + '''Search for information.''' + return f"Results for {query}" + + @function_tool + def calculate(expression: str) -> float: + '''Evaluate a math expression.''' + return eval(expression) + + env = Environment("my-env") + env.connect_function_tools([search, calculate]) + + async with env: + result = await env.call_tool("search", query="MCP protocol") + ``` + + Note: + Requires `openai-agents`: pip install openai-agents + """ + if not _HAS_OPENAI_AGENTS: + raise ImportError( + "openai-agents is required for connect_function_tools. 
" + "Install with: pip install openai-agents" + ) + + for tool in tools: + if isinstance(tool, FunctionTool): + self._add_openai_function_tool(tool, prefix) + + return self + + def _add_openai_function_tool(self, tool: Any, prefix: str | None) -> None: + """Convert OpenAI FunctionTool to local MCP tool.""" + name = f"{prefix}_{tool.name}" if prefix else tool.name + + # Get the original invoke function + original_invoke = tool.on_invoke_tool + + # Create wrapper that calls the original + async def invoke(**arguments: Any) -> Any: + # OpenAI's on_invoke_tool expects (ToolContext, str_json_args) + # We need to create a minimal context + from agents.tool_context import ToolContext + ctx = ToolContext(context=None) + result = await original_invoke(ctx, json.dumps(arguments)) + return result + + # Set function metadata for FastMCP + invoke.__name__ = name + invoke.__doc__ = tool.description + + # Register using FastMCP's tool decorator mechanism + # We access the internal _tool_manager from MCPServer + from fastmcp.tools import Tool as FastMCPTool + + fastmcp_tool = FastMCPTool.from_function( + fn=invoke, + name=name, + description=tool.description, + ) + # Override the schema with OpenAI's (more accurate) + fastmcp_tool.parameters = tool.params_json_schema + + self._tool_manager.add_tool(fastmcp_tool) + diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py new file mode 100644 index 00000000..7a7ad4b4 --- /dev/null +++ b/hud/environment/connectors/remote.py @@ -0,0 +1,167 @@ +"""Remote connection connectors - HUD Hub, URL, OpenAPI.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, cast + +from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + +if TYPE_CHECKING: + from collections.abc import Callable + + from fastmcp.tools.tool import Tool + +__all__ = ["RemoteConnectorMixin"] + +logger = logging.getLogger(__name__) + + +class RemoteConnectorMixin(MCPConfigConnectorMixin): + """Mixin providing remote connection methods.""" + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + raise NotImplementedError + + def connect_hub( + self, + slug: str, + *, + alias: str | None = None, + prefix: str | None = None, + include: list[str] | None = None, + exclude: list[str] | None = None, + transform: Callable[[Tool], Tool | None] | None = None, + ) -> Any: + """Connect to a HUD Hub environment. + + Fetches mcp_config from api.hud.so immediately and creates connectors. 
+ + Example: + ```python + env = Environment("my-env") + env.connect_hub("hud/browser") + + async with env: + await env.call_tool("navigate", url="https://google.com") + ``` + """ + import httpx + + from hud.settings import settings + + # Fetch mcp_config synchronously + logger.info("Loading hub environment: %s", slug) + + headers = {} + if settings.api_key: + headers["Authorization"] = f"Bearer {settings.api_key}" + + with httpx.Client() as client: + response = client.get( + f"{settings.hud_api_url}/environments/{slug}/mcp-config", + headers=headers, + ) + response.raise_for_status() + data = response.json() + + mcp_config: dict[str, dict[str, Any]] = data.get("mcp_config", data) + self.connect_mcp_config( + mcp_config, prefix=prefix, include=include, exclude=exclude, transform=transform + ) + logger.info("Hub connected: %s (%d servers)", slug, len(mcp_config)) + return self + + def connect_url( + self, + url: str, + *, + headers: dict[str, str] | None = None, + alias: str | None = None, + prefix: str | None = None, + include: list[str] | None = None, + exclude: list[str] | None = None, + transform: Callable[[Tool], Tool | None] | None = None, + ) -> Any: + """Connect to an MCP server via URL. + + Example: + ```python + env = Environment("my-env") + env.connect_url( + "https://mcp.example.com", + headers={"Authorization": "Bearer token"}, + ) + + async with env: + await env.call_tool("search", query="hello") + ``` + """ + from hud.environment.connection import ConnectionType + + auth = headers.get("Authorization") if headers else None + return self._add_connection( + alias or url, + url, + connection_type=ConnectionType.REMOTE, + auth=auth, + prefix=prefix, + include=include, + exclude=exclude, + transform=transform, + ) + + def connect_openapi( + self, + openapi_spec: dict[str, Any] | str, + *, + base_url: str | None = None, + headers: dict[str, str] | None = None, + name: str | None = None, + prefix: str | None = None, + timeout: float = 30.0, + ) -> Any: + """Mount an OpenAPI specification as an MCP server. + + Converts REST API endpoints to MCP tools. Base URL is auto-inferred + from the spec URL when possible. 
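+        When the spec is provided as a dict or a local file path, `base_url`
+        must be passed explicitly.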
+ + Example: + ```python + env = Environment("my-env") + env.connect_openapi("https://petstore.swagger.io/v2/swagger.json") + + async with env: + result = await env.call_tool("getPetById", petId=1) + ``` + """ + from urllib.parse import urlparse + + import httpx + from fastmcp import FastMCP + + if isinstance(openapi_spec, str): + if openapi_spec.startswith(("http://", "https://")): + if base_url is None: + parsed = urlparse(openapi_spec) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + resp = httpx.get(openapi_spec, headers=headers) + resp.raise_for_status() + openapi_spec = resp.json() + else: + import json + with open(openapi_spec) as f: + openapi_spec = json.load(f) + + if base_url is None: + raise ValueError("base_url is required when openapi_spec is a dict or file") + + client = httpx.AsyncClient(base_url=base_url, headers=headers or {}, timeout=timeout) + mcp_server = FastMCP.from_openapi( + openapi_spec=cast("dict[str, Any]", openapi_spec), + client=client, + name=name or "openapi", + ) + self.mount(mcp_server, prefix=prefix) + return self diff --git a/hud/environment/connectors/task.py b/hud/environment/connectors/task.py new file mode 100644 index 00000000..3eae3bc4 --- /dev/null +++ b/hud/environment/connectors/task.py @@ -0,0 +1,104 @@ +"""Task connection connector.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + +if TYPE_CHECKING: + from hud.types import Task + +__all__ = ["TaskConnectorMixin"] + +logger = logging.getLogger(__name__) + + +class TaskConnectorMixin(MCPConfigConnectorMixin): + """Mixin providing connect_task() method. + + Inherits from MCPConfigConnectorMixin for connect_mcp_config(). + """ + + def setup_tool(self, call: Any, /, **kwargs: Any) -> Any: + raise NotImplementedError + + def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: + raise NotImplementedError + + def connect_task(self, slug: str) -> Any: + """Connect to a task from the HUD platform. + + Fetches the task from api.hud.so immediately and applies configuration + (mcp_config, setup_tool, evaluate_tool). + + Args: + slug: Task slug in format "evalset/task_name" or "evalset/task_name@version". + + Returns: + self for chaining. + + Example: + ```python + env = Environment("my-env").connect_task("my-org/browser-task") + + async with env: + # Task's mcp_config is connected + # Task's setup_tool runs automatically + result = await env.call_tool("navigate", url="...") + # Task's evaluate_tool runs on exit + ``` + """ + import httpx + + from hud.settings import settings + from hud.types import Task + + # Fetch task synchronously + logger.info("Loading task from platform: %s", slug) + + headers = {} + if settings.api_key: + headers["Authorization"] = f"Bearer {settings.api_key}" + + with httpx.Client() as client: + response = client.get( + f"{settings.hud_api_url}/tasks/{slug}", + headers=headers, + ) + response.raise_for_status() + data = response.json() + + task = Task(**data) + self._apply_task(task) + logger.info("Task loaded and applied: %s", slug) + return self + + def _apply_task(self, task: Task) -> None: + """Apply a Task definition to this environment. 
+ + Sets up: + - MCP connections from task.mcp_config + - Setup tool calls from task.setup_tool + - Evaluate tool calls from task.evaluate_tool + """ + # Connect MCP servers + if task.mcp_config: + self.connect_mcp_config(task.mcp_config) + + # Configure setup tool calls + if task.setup_tool: + setup_calls = task.setup_tool + if not isinstance(setup_calls, list): + setup_calls = [setup_calls] + for call in setup_calls: + self.setup_tool(call.name, **(call.arguments or {})) + + # Configure evaluate tool calls + if task.evaluate_tool: + eval_calls = task.evaluate_tool + if not isinstance(eval_calls, list): + eval_calls = [eval_calls] + for call in eval_calls: + self.evaluate_tool(call.name, **(call.arguments or {})) diff --git a/hud/environment/environment.py b/hud/environment/environment.py new file mode 100644 index 00000000..63d0f0f0 --- /dev/null +++ b/hud/environment/environment.py @@ -0,0 +1,447 @@ +"""Environment class - unified MCP server and client.""" + +from __future__ import annotations + +import asyncio +import logging +import types +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, Literal + +import mcp.types as mcp_types + +from hud.environment.connectors import ConnectorsMixin +from hud.environment.connection import Connector +from hud.environment.integrations import IntegrationsMixin +from hud.environment.mock import MockMixin +from hud.environment.router import ConflictResolution, ToolRouter +from hud.server.server import MCPServer +from hud.trace.mixin import TraceMixin +from hud.types import MCPToolResult + +if TYPE_CHECKING: + from hud.types import Task + +__all__ = ["Environment"] + +logger = logging.getLogger(__name__) + +# Type alias for async callables (no-arg functions that return awaitable) +AsyncCallable = Callable[[], Awaitable[Any]] + + +class Environment( + ConnectorsMixin, + IntegrationsMixin, + MockMixin, + TraceMixin, + MCPServer, +): + """Unified MCP environment that acts as both server and client. 
+ + Features: + - Define local tools with @env.tool decorator + - Connect to HUD Hub, URLs, or mcp_config dicts + - Automatic tool routing (local vs remote) + - Format tools for any LLM provider + - Integrate with popular agent frameworks + - Mock mode for testing without real connections + + Connector methods (connect to sources): + connect_hub(name) - HUD Hub environment + connect_url(url) - MCP server via URL + connect_mcp(config) - Single mcp_config server + connect_mcp_config(mcp_config) - Multiple mcp_config servers + connect_task(slug) - Load task from platform by slug + connect_image(image) - Docker image via stdio + connect_fastapi(app) - Mount FastAPI app as MCP server + connect_openapi(spec) - Mount OpenAPI spec as MCP server + connect_server(server) - Mount MCPServer/FastMCP directly + + Mock methods (for testing): + mock() - Enable mock mode, all tools return mock values + unmock() - Disable mock mode + mock_tool(name, output) - Set specific mock output for a tool + is_mock - Check if mock mode is enabled + + OpenAI integrations: + as_openai_chat_tools() - Chat Completions format + as_openai_responses_tools() - Responses API format + as_openai_agent_tools() - Agents SDK (requires openai-agents) + + Anthropic/Claude integrations: + as_claude_tools() - Claude API format + as_claude_programmatic_tools() - Programmatic tool use + as_anthropic_runner() - Tool runner (requires anthropic) + + Google/Gemini integrations: + as_gemini_tools() - Gemini format + as_gemini_tool_config() - Tool execution config + + LangChain integrations: + as_langchain_tools() - StructuredTools (requires langchain-core) + + Example: + ```python + env = Environment("my-env") + + @env.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + env.connect_hub("browser", prefix="browser") + + async with env: + # Get tools in any format + openai_tools = env.as_openai_chat_tools() + claude_tools = env.as_claude_tools() + + # Call tools - automatically routed + result = await env.call_tool("greet", name="World") + + # Or pass provider-specific format - auto-detected + result = await env.call_tool(response.choices[0].message.tool_calls[0]) + + # Mock mode for testing + env.mock() + env.mock_tool("browser_navigate", "Navigation successful") + async with env: + result = await env.call_tool("browser_navigate", url="https://example.com") + # Returns mock value instead of actually navigating + ``` + """ + + MAX_CONCURRENT_CONNECTIONS = 10 + + def __init__( + self, + name: str = "environment", + instructions: str | None = None, + conflict_resolution: ConflictResolution = ConflictResolution.PREFIX, + **fastmcp_kwargs: Any, + ) -> None: + super().__init__(name=name, instructions=instructions, **fastmcp_kwargs) + self._connections: dict[str, Connector] = {} + self._router = ToolRouter(conflict_resolution=conflict_resolution) + self._in_context = False + + # Tool call queues - run after connections established + self._setup_calls: list[tuple[str, dict[str, Any]]] = [] + self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] + + # Track which lifecycle tools we've warned about (only warn once per tool) + self._warned_lifecycle_tools: set[str] = set() + + # Initialize mock state + self._init_mock() + + # ========================================================================= + # Core Methods + # ========================================================================= + + def as_tools(self) -> list[mcp_types.Tool]: + """Return tools in MCP format (base format).""" + return self._router.tools + + async def 
call_tool(self, call: Any, /, **kwargs: Any) -> Any: + """Call a tool, auto-detecting format and returning matching result format. + + Accepts any format: + - String with kwargs: call_tool("navigate", url="...") + - Tuple: call_tool(("navigate", {"url": "..."})) + - MCPToolCall: call_tool(MCPToolCall(name="navigate", ...)) + - OpenAI: call_tool(response.choices[0].message.tool_calls[0]) + - Claude: call_tool(response.content[0]) # tool_use block + - Gemini: call_tool(response.candidates[0].content.parts[0]) + + Returns: + Result formatted to match input format (OpenAI -> OpenAI tool message, etc.) + """ + from hud.environment.utils import format_result, parse_tool_call + + # Parse the tool call (kwargs merged when call is string) + parsed, fmt = parse_tool_call(call, **kwargs) + self._check_lifecycle_warning(parsed.name) + result = await self._execute_tool(parsed.name, parsed.arguments or {}) + return format_result(result, parsed, fmt) + + def _check_lifecycle_warning(self, name: str) -> None: + """Warn once if calling a setup/evaluate tool manually.""" + if name in self._warned_lifecycle_tools: + return + setup = {n for n, _ in self._setup_calls} + evaluate = {n for n, _ in self._evaluate_calls} + if name not in setup and name not in evaluate: + return + self._warned_lifecycle_tools.add(name) + phase = "setup" if name in setup else "evaluate" + logger.warning( + "Tool '%s' is a %s tool (runs automatically). Manual call may duplicate.", + name, phase, + ) + + async def call_tools(self, calls: Any) -> list[Any]: + """Call multiple tools, returning results in matching formats.""" + if calls is None: + return [] + if not isinstance(calls, list): + return [await self.call_tool(calls)] + + # Filter to tool calls only (skip text blocks, etc.) + tool_calls = [] + for call in calls: + t = call.get("type") if isinstance(call, dict) else getattr(call, "type", None) + if t is None or t in ("tool_use", "function"): + tool_calls.append(call) + + return await asyncio.gather(*[self.call_tool(c) for c in tool_calls]) + + # ========================================================================= + # Lifecycle Configuration + # ========================================================================= + + def setup_tool(self, call: Any, /, **kwargs: Any) -> Environment: + """Add a tool call to execute after connections are established.""" + from hud.environment.utils import parse_tool_call + + if isinstance(call, str) and kwargs: + self._setup_calls.append((call, kwargs)) + else: + parsed, _ = parse_tool_call(call) + self._setup_calls.append((parsed.name, parsed.arguments or {})) + return self + + def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Environment: + """Add a tool call to execute before disconnecting.""" + from hud.environment.utils import parse_tool_call + + if isinstance(call, str) and kwargs: + self._evaluate_calls.append((call, kwargs)) + else: + parsed, _ = parse_tool_call(call) + self._evaluate_calls.append((parsed.name, parsed.arguments or {})) + return self + + # ========================================================================= + # Context Manager + # ========================================================================= + + async def __aenter__(self) -> Environment: + """Connect all connectors, build routing, run setup tools.""" + self._in_context = True + + # Connect to all servers (on_connect callbacks run first within connect()) + sem = asyncio.Semaphore(self.MAX_CONCURRENT_CONNECTIONS) + errors: list[tuple[str, Exception]] = [] + + async def connect_one(name: str, 
conn: Connector) -> None: + async with sem: + try: + await conn.connect() + await conn.list_tools() + except Exception as e: + errors.append((name, e)) + + if self._connections: + await asyncio.gather(*[ + connect_one(n, c) for n, c in self._connections.items() + ]) + if errors: + for conn in self._connections.values(): + if conn.is_connected: + await conn.disconnect() + name, err = errors[0] + raise ConnectionError(f"Failed to connect to {name}") from err + + await self._build_routing() + + # Setup tool calls (after connections) + for name, args in self._setup_calls: + await self._execute_tool(name, args) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> None: + """Run evaluate tools, exit queue, then disconnect.""" + # Evaluate tool calls + for name, args in self._evaluate_calls: + try: + await self._execute_tool(name, args) + except Exception as e: + logger.warning("Evaluate tool %s failed: %s", name, e) + + self._in_context = False + if self._connections: + await asyncio.gather(*[c.disconnect() for c in self._connections.values()]) + self._router.clear() + + async def _build_routing(self) -> None: + """Build tool routing from local tools and connection caches.""" + local_tools = await self._tool_manager.list_tools() + self._router.build( + local_tools=[t.to_mcp_tool() for t in local_tools], + connections=self._connections, + connection_order=list(self._connections.keys()), + ) + # Populate mock schemas for auto-generated mock values + self._populate_mock_schemas() + + # ========================================================================= + # Tool Operations + # ========================================================================= + + async def list_tools(self) -> list[mcp_types.Tool]: + """Refresh tools from all connections and rebuild routing.""" + if self._connections: + await asyncio.gather(*[c.list_tools() for c in self._connections.values()]) + await self._build_routing() + return self._router.tools + + async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolResult: + """Execute a tool by name. Routes to local or remote handler. + + If mock mode is enabled, returns a mock result instead of executing. 
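+
+        Resolution order, as implemented below: mock mode, then locally
+        registered tools, then remote connections; unknown tool names raise
+        ValueError.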
+ """ + # Check mock mode first + if self._mock_mode: + logger.debug("Mock mode: returning mock result for tool %s", name) + return self._get_mock_result(name, arguments) + + if self._router.is_local(name): + result = await self._call_tool(name, arguments) + return MCPToolResult(content=result.content, isError=False) + + connection_name = self._router.get_connection(name) + if connection_name: + conn = self._connections[connection_name] + result = await conn.call_tool(name, arguments) + return MCPToolResult(content=result.content, isError=result.isError) + + raise ValueError(f"Tool not found: {name}") + + # ========================================================================= + # Resource Operations + # ========================================================================= + + async def list_resources(self) -> list[mcp_types.Resource]: + """List all resources (local + remote).""" + local = await self._resource_manager.list_resources() + resources: list[mcp_types.Resource] = [r.to_mcp_resource() for r in local] + + if self._connections: + results = await asyncio.gather(*[ + c.list_resources() for c in self._connections.values() + ], return_exceptions=True) + for r in results: + if isinstance(r, list): + resources.extend(r) + + return resources + + async def read_resource( + self, uri: str + ) -> list[mcp_types.TextResourceContents | mcp_types.BlobResourceContents]: + """Read a resource by URI (tries local first, then remote).""" + from pydantic import AnyUrl + + try: + result = await self._resource_manager.read_resource(uri) + resource_uri = AnyUrl(uri) + if isinstance(result, str): + return [mcp_types.TextResourceContents(uri=resource_uri, text=result)] + import base64 + return [mcp_types.BlobResourceContents( + uri=resource_uri, blob=base64.b64encode(result).decode() + )] + except Exception: + pass + + for conn in self._connections.values(): + try: + return await conn.read_resource(uri) + except Exception: + continue + + raise ValueError(f"Resource not found: {uri}") + + # ========================================================================= + # Prompt Operations + # ========================================================================= + + async def list_prompts(self) -> list[mcp_types.Prompt]: + """List all prompts (local + remote).""" + local = await self._prompt_manager.list_prompts() + prompts: list[mcp_types.Prompt] = [p.to_mcp_prompt() for p in local] + + if self._connections: + results = await asyncio.gather(*[ + c.list_prompts() for c in self._connections.values() + ], return_exceptions=True) + for r in results: + if isinstance(r, list): + prompts.extend(r) + + return prompts + + async def get_prompt( + self, name: str, arguments: dict[str, Any] | None = None + ) -> mcp_types.GetPromptResult: + """Get a prompt by name (tries local first, then remote).""" + try: + return await self._prompt_manager.render_prompt(name, arguments or {}) + except Exception: + pass + + for conn in self._connections.values(): + try: + return await conn.get_prompt(name, arguments) + except Exception: + continue + + raise ValueError(f"Prompt not found: {name}") + + # ========================================================================= + # Server Methods + # ========================================================================= + + def serve( + self, + transport: Literal["stdio", "sse", "streamable-http"] = "streamable-http", + host: str = "0.0.0.0", # noqa: S104 + port: int = 8000, + **kwargs: Any, + ) -> None: + """Start serving as an MCP server.""" + self.run(transport=transport, 
host=host, port=port, **kwargs) + + # ========================================================================= + # Properties + # ========================================================================= + + @property + def connections(self) -> dict[str, Connector]: + return self._connections + + @property + def is_connected(self) -> bool: + return self._in_context + + @property + def is_parallelizable(self) -> bool: + """True if all connections are remote (can spawn multiple instances).""" + if not self._connections: + return True # No connections = can parallelize (local tools only) + return all(conn.is_remote for conn in self._connections.values()) + + @property + def local_connections(self) -> list[str]: + """Names of local (non-parallelizable) connections.""" + return [name for name, conn in self._connections.items() if conn.is_local] + + def __repr__(self) -> str: + return f"Environment({self.name!r}, connections={list(self._connections.keys())})" diff --git a/hud/environment/integrations/__init__.py b/hud/environment/integrations/__init__.py new file mode 100644 index 00000000..9794ec16 --- /dev/null +++ b/hud/environment/integrations/__init__.py @@ -0,0 +1,36 @@ +"""Provider integrations - format conversion and framework tools.""" + +from hud.environment.integrations.anthropic import AnthropicMixin +from hud.environment.integrations.gemini import GeminiMixin +from hud.environment.integrations.langchain import LangChainMixin +from hud.environment.integrations.openai import OpenAIMixin + +__all__ = ["IntegrationsMixin"] + + +class IntegrationsMixin( + OpenAIMixin, + AnthropicMixin, + GeminiMixin, + LangChainMixin, +): + """Combined integration mixin for all providers. + + OpenAI: + as_openai_chat_tools() - Chat Completions format + as_openai_responses_tools() - Responses API format + as_openai_agent_tools() - Agents SDK (requires openai-agents) + + Anthropic/Claude: + as_claude_tools() - Claude API format + as_claude_programmatic_tools() - Programmatic tool use + as_anthropic_runner() - Tool runner (requires anthropic) + + Google/Gemini: + as_gemini_tools() - Gemini format + as_gemini_tool_config() - Tool config + + LangChain: + as_langchain_tools() - StructuredTools (requires langchain-core) + """ + pass diff --git a/hud/environment/integrations/anthropic.py b/hud/environment/integrations/anthropic.py new file mode 100644 index 00000000..dc2d3a3c --- /dev/null +++ b/hud/environment/integrations/anthropic.py @@ -0,0 +1,206 @@ +"""Anthropic/Claude integrations - format conversion and tool runner.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +# Try to import anthropic +try: + from anthropic.types.beta import BetaToolResultBlockParam + _HAS_ANTHROPIC = True +except ImportError: + _HAS_ANTHROPIC = False + BetaToolResultBlockParam = None # type: ignore[misc, assignment] + +if TYPE_CHECKING: + import mcp.types as mcp_types + +__all__ = ["AnthropicMixin"] + + +class AnthropicMixin: + """Mixin providing Anthropic/Claude format conversion and tool runner. 
+ + Format methods (no deps): + as_claude_tools() - Claude API format + as_claude_programmatic_tools() - Programmatic tool use format + + Integration methods (requires anthropic): + as_anthropic_runner() - Tool runner for executing tool_use blocks + + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) + """ + + def as_tools(self) -> list[mcp_types.Tool]: + raise NotImplementedError + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + # ========================================================================= + # Format Conversion (no external deps) + # ========================================================================= + + def as_claude_tools(self, *, cache_control: bool = False) -> list[dict[str, Any]]: + """Convert to Claude/Anthropic tool format. + + Args: + cache_control: Add cache_control for prompt caching + + Returns: + List of tool definitions for Claude API. + + Example: + ```python + from anthropic import Anthropic + + client = Anthropic() + async with env: + response = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=1024, + messages=[{"role": "user", "content": "Navigate to google.com"}], + tools=env.as_claude_tools(), + ) + # Execute tool calls + for block in response.content: + if block.type == "tool_use": + result = await env.call_tool(block) + ``` + """ + tools = [] + for t in self.as_tools(): + tool: dict[str, Any] = { + "name": t.name, + "description": t.description or "", + "input_schema": t.inputSchema or {"type": "object", "properties": {}}, + } + if cache_control: + tool["cache_control"] = {"type": "ephemeral"} + tools.append(tool) + return tools + + def as_claude_programmatic_tools(self, *, cache_control: bool = False) -> list[dict[str, Any]]: + """Convert to Claude programmatic tool use format. + + Programmatic tool use allows Claude to execute tools via code execution. + + Example: + ```python + from anthropic import Anthropic + + client = Anthropic() + async with env: + response = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=1024, + messages=[{"role": "user", "content": "Analyze the data"}], + tools=env.as_claude_programmatic_tools(), + betas=["code-execution-2025-01-24"], + ) + ``` + """ + tools = [] + for t in self.as_tools(): + tool: dict[str, Any] = { + "name": t.name, + "description": t.description or "", + "input_schema": t.inputSchema or {"type": "object", "properties": {}}, + "allowed_callers": ["code_execution_20250825"], + } + if cache_control: + tool["cache_control"] = {"type": "ephemeral"} + tools.append(tool) + return tools + + # ========================================================================= + # Tool Runner Integration (requires anthropic) + # ========================================================================= + + def as_anthropic_runner(self) -> EnvToolRunner: + """Create an Anthropic tool runner for this environment. + + Requires: pip install anthropic + + Returns: + EnvToolRunner that can process tool_use blocks from Claude. 
+ + Example: + ```python + from anthropic import Anthropic + + client = Anthropic() + async with env: + runner = env.as_anthropic_runner() + + response = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=1024, + messages=[{"role": "user", "content": "Navigate to google.com"}], + tools=env.as_claude_tools(), + ) + + # Execute all tool_use blocks + results = [] + for block in response.content: + if block.type == "tool_use": + result = await runner.run(block) + results.append(result) + ``` + """ + if not _HAS_ANTHROPIC: + raise ImportError( + "Anthropic SDK not installed. Install with: pip install anthropic" + ) + + return EnvToolRunner(self) + + +class EnvToolRunner: + """Tool runner that executes tools against an Environment.""" + + def __init__(self, env: AnthropicMixin) -> None: + self.env = env + self._tool_names: set[str] | None = None + + @property + def tool_names(self) -> set[str]: + """Get available tool names.""" + if self._tool_names is None: + self._tool_names = {t.name for t in self.env.as_tools()} + return self._tool_names + + async def run(self, tool_use_block: Any) -> dict[str, Any]: + """Execute a tool_use block from Claude. + + Args: + tool_use_block: A ToolUseBlock from Claude's response. + + Returns: + Tool result dict (or BetaToolResultBlockParam if anthropic installed). + """ + name = tool_use_block.name + tool_use_id = tool_use_block.id + arguments = tool_use_block.input or {} + + try: + result = await self.env.call_tool(name, **arguments) + content = result if isinstance(result, str) else json.dumps(result) if result else "" + result_dict: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content, + } + except Exception as e: + result_dict = { + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": f"Error: {e}", + "is_error": True, + } + + # Return typed object if anthropic is available + if _HAS_ANTHROPIC and BetaToolResultBlockParam is not None: + return BetaToolResultBlockParam(**result_dict) + return result_dict diff --git a/hud/environment/integrations/gemini.py b/hud/environment/integrations/gemini.py new file mode 100644 index 00000000..e8899d99 --- /dev/null +++ b/hud/environment/integrations/gemini.py @@ -0,0 +1,93 @@ +"""Google/Gemini integrations - format conversion.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import mcp.types as mcp_types + +__all__ = ["GeminiMixin"] + + +class GeminiMixin: + """Mixin providing Google/Gemini format conversion. + + Format methods (no deps): + as_gemini_tools() - Gemini tool format + as_gemini_tool_config() - Tool execution config + + Requires: as_tools() -> list[mcp_types.Tool] + """ + + def as_tools(self) -> list[mcp_types.Tool]: + raise NotImplementedError + + def as_gemini_tools(self) -> list[dict[str, Any]]: + """Convert to Gemini/Google AI tool format. + + Returns: + List with function_declarations for Gemini API. 
+ + Example: + ```python + import google.generativeai as genai + + model = genai.GenerativeModel("gemini-1.5-pro") + async with env: + response = model.generate_content( + "Navigate to google.com", + tools=env.as_gemini_tools(), + ) + # Execute tool calls + for part in response.candidates[0].content.parts: + if fn := part.function_call: + result = await env.call_tool(part) + ``` + """ + return [{ + "function_declarations": [ + { + "name": t.name, + "description": t.description or "", + "parameters": t.inputSchema or {"type": "object", "properties": {}}, + } + for t in self.as_tools() + ] + }] + + def as_gemini_tool_config( + self, + mode: str = "AUTO", + allowed_tools: list[str] | None = None, + ) -> dict[str, Any]: + """Get Gemini tool_config for controlling tool execution. + + Args: + mode: "AUTO", "ANY", or "NONE" + allowed_tools: If mode is "ANY", list of allowed tool names + + Returns: + Tool config dict for Gemini API. + + Example: + ```python + import google.generativeai as genai + + model = genai.GenerativeModel("gemini-1.5-pro") + async with env: + # Force specific tool usage + response = model.generate_content( + "Search for cats", + tools=env.as_gemini_tools(), + tool_config=env.as_gemini_tool_config( + mode="ANY", + allowed_tools=["search"] + ), + ) + ``` + """ + config: dict[str, Any] = {"function_calling_config": {"mode": mode}} + if mode == "ANY" and allowed_tools: + config["function_calling_config"]["allowed_function_names"] = allowed_tools + return config diff --git a/hud/environment/integrations/langchain.py b/hud/environment/integrations/langchain.py new file mode 100644 index 00000000..d52eaf87 --- /dev/null +++ b/hud/environment/integrations/langchain.py @@ -0,0 +1,114 @@ +"""LangChain integration.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from hud.environment.utils.schema import schema_to_pydantic + +# Try to import langchain +try: + from langchain_core.tools import StructuredTool + _HAS_LANGCHAIN = True +except ImportError: + _HAS_LANGCHAIN = False + StructuredTool = None # type: ignore[misc, assignment] + +if TYPE_CHECKING: + import mcp.types as mcp_types + +__all__ = ["LangChainMixin"] + + +class LangChainMixin: + """Mixin providing LangChain integration. + + Integration methods (requires langchain-core): + as_langchain_tools() - LangChain StructuredTool objects + + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) + """ + + def as_tools(self) -> list[mcp_types.Tool]: + raise NotImplementedError + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + def as_langchain_tools(self) -> list[Any]: + """Convert to LangChain StructuredTool objects. + + Requires: pip install langchain-core + + Returns: + List of StructuredTool objects for LangChain agents. 
+ + Example: + ```python + from langchain_openai import ChatOpenAI + from langchain.agents import create_tool_calling_agent, AgentExecutor + from langchain_core.prompts import ChatPromptTemplate + + llm = ChatOpenAI(model="gpt-4o") + async with env: + tools = env.as_langchain_tools() + + prompt = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful assistant."), + ("human", "{input}"), + ("placeholder", "{agent_scratchpad}"), + ]) + + agent = create_tool_calling_agent(llm, tools, prompt) + executor = AgentExecutor(agent=agent, tools=tools) + result = await executor.ainvoke({"input": "Navigate to google.com"}) + ``` + """ + if not _HAS_LANGCHAIN: + raise ImportError( + "LangChain not installed. Install with: pip install langchain-core" + ) + + tools = [] + for t in self.as_tools(): + tool = _create_structured_tool(self, t) + tools.append(tool) + return tools + + +def _create_structured_tool(env: LangChainMixin, tool: mcp_types.Tool) -> Any: + """Create a StructuredTool that calls back to the environment.""" + import asyncio + + schema = tool.inputSchema or {"type": "object", "properties": {}} + + def sync_invoke(**kwargs: Any) -> str: + """Synchronous wrapper for the tool.""" + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, env.call_tool(tool.name, **kwargs)) + result = future.result() + else: + result = loop.run_until_complete(env.call_tool(tool.name, **kwargs)) + + if isinstance(result, str): + return result + return json.dumps(result) if result else "" + + async def async_invoke(**kwargs: Any) -> str: + """Async wrapper for the tool.""" + result = await env.call_tool(tool.name, **kwargs) + if isinstance(result, str): + return result + return json.dumps(result) if result else "" + + return StructuredTool( + name=tool.name, + description=tool.description or "", + func=sync_invoke, + coroutine=async_invoke, + args_schema=schema_to_pydantic(tool.name, schema), + ) diff --git a/hud/environment/integrations/openai.py b/hud/environment/integrations/openai.py new file mode 100644 index 00000000..3a188bf5 --- /dev/null +++ b/hud/environment/integrations/openai.py @@ -0,0 +1,202 @@ +"""OpenAI integrations - format conversion and Agents SDK.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from hud.environment.utils.schema import ensure_strict_schema + +# Try to import OpenAI Agents SDK +try: + from agents import FunctionTool + _HAS_AGENTS = True +except ImportError: + _HAS_AGENTS = False + FunctionTool = None # type: ignore[misc, assignment] + +if TYPE_CHECKING: + import mcp.types as mcp_types + +__all__ = ["OpenAIMixin"] + + +class OpenAIMixin: + """Mixin providing OpenAI format conversion and Agents SDK integration. + + Format methods (no deps): + as_openai_chat_tools() - Chat Completions format + as_openai_responses_tools() - Responses API format + + Integration methods (requires openai-agents): + as_openai_agent_tools() - Agents SDK FunctionTool objects + + Note: The OpenAI Agents SDK also supports: + - HostedMCPTool - MCP tools hosted by OpenAI + - MCPServerStdio/Sse/StreamableHttp - Direct MCP server connections + + For MCP server integration, use as_mcp_server() from the mcp integration. 
+ + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) + """ + + def as_tools(self) -> list[mcp_types.Tool]: + raise NotImplementedError + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + # ========================================================================= + # Format Conversion (no external deps) + # ========================================================================= + + def as_openai_chat_tools(self, *, strict: bool = False) -> list[dict[str, Any]]: + """Convert to OpenAI Chat Completions tool format. + + Args: + strict: Enable strict mode for structured outputs + + Returns: + List of tool definitions for OpenAI Chat Completions API. + + Example: + ```python + from openai import OpenAI + + client = OpenAI() + async with env: + response = client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "Navigate to google.com"}], + tools=env.as_openai_chat_tools(), + ) + # Execute tool calls and get results in OpenAI format + results = await env.call_tools(response.choices[0].message.tool_calls) + # results are {"role": "tool", "tool_call_id": ..., "content": ...} + ``` + """ + tools = [] + for t in self.as_tools(): + schema = dict(t.inputSchema) if t.inputSchema else {"type": "object", "properties": {}} + + if strict: + schema = ensure_strict_schema(schema) + + tools.append({ + "type": "function", + "function": { + "name": t.name, + "description": t.description or "", + "parameters": schema, + **({"strict": True} if strict else {}), + }, + }) + return tools + + def as_openai_responses_tools(self) -> list[dict[str, Any]]: + """Convert to OpenAI Responses API tool format. + + Note: Like Chat Completions, you must execute tools yourself. + OpenAI only auto-executes their built-in tools (code_interpreter, etc). + + Returns: + List of tool definitions for OpenAI Responses API. + + Example: + ```python + from openai import OpenAI + + client = OpenAI() + async with env: + response = client.responses.create( + model="gpt-4o", + input="Navigate to google.com", + tools=env.as_openai_responses_tools(), + ) + # Check for function calls in the response + for item in response.output: + if item.type == "function_call": + result = await env.call_tool(item.name, **item.arguments) + ``` + """ + return [{ + "type": "function", + "name": t.name, + "description": t.description or "", + "parameters": t.inputSchema or {"type": "object", "properties": {}}, + } for t in self.as_tools()] + + # ========================================================================= + # Agents SDK Integration (requires openai-agents) + # ========================================================================= + + def as_openai_agent_tools(self) -> list[Any]: + """Convert to OpenAI Agents SDK FunctionTool objects. + + This creates FunctionTool objects that automatically execute against + this environment. The Agents SDK Runner handles the tool loop. + + Note: The Agents SDK also supports other tool types: + - HostedMCPTool: MCP tools hosted by OpenAI + - MCPServerStdio/Sse/StreamableHttp: Direct MCP server connections + + For direct MCP integration, consider using as_mcp_server(). + + Requires: pip install openai-agents + + Returns: + List of FunctionTool objects for OpenAI Agents SDK. 
+ + Example: + ```python + from agents import Agent, Runner + + async with env: + agent = Agent( + name="browser-agent", + instructions="You browse the web.", + tools=env.as_openai_agent_tools(), + ) + result = await Runner.run(agent, "Go to google.com") + print(result.final_output) + ``` + """ + if not _HAS_AGENTS: + raise ImportError( + "OpenAI Agents SDK not installed. Install with: pip install openai-agents" + ) + + tools = [] + for t in self.as_tools(): + tool = _create_function_tool(self, t) + tools.append(tool) + return tools + + +def _create_function_tool(env: OpenAIMixin, tool: mcp_types.Tool) -> Any: + """Create a FunctionTool that calls back to the environment.""" + import asyncio + + schema = tool.inputSchema or {"type": "object", "properties": {}} + + def sync_wrapper(**kwargs: Any) -> str: + """Synchronous wrapper for the tool.""" + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, env.call_tool(tool.name, **kwargs)) + result = future.result() + else: + result = loop.run_until_complete(env.call_tool(tool.name, **kwargs)) + + if isinstance(result, str): + return result + return json.dumps(result) if result else "" + + return FunctionTool( + name=tool.name, + description=tool.description or "", + params_json_schema=schema, + on_invoke_tool=sync_wrapper, + ) diff --git a/hud/environment/mock.py b/hud/environment/mock.py new file mode 100644 index 00000000..37711a69 --- /dev/null +++ b/hud/environment/mock.py @@ -0,0 +1,306 @@ +"""Mock functionality for Environment.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +import mcp.types as mcp_types + +from hud.types import MCPToolResult + +if TYPE_CHECKING: + from hud.environment.environment import Environment + +__all__ = ["MockMixin", "generate_mock_value"] + +logger = logging.getLogger(__name__) + + +def generate_mock_value(schema: dict[str, Any], depth: int = 0) -> Any: + """Generate a reasonable mock value from a JSON schema. + + Args: + schema: JSON schema dict with 'type', 'properties', etc. + depth: Current recursion depth (to prevent infinite loops). + + Returns: + A mock value that matches the schema. 
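    Example (illustrative; the values follow the heuristics in this function,
    e.g. format hints and the midpoint of the default 0-100 numeric bounds):

        schema = {
            "type": "object",
            "properties": {
                "url": {"type": "string", "format": "uri"},
                "count": {"type": "integer"},
            },
            "required": ["url"],
        }
        generate_mock_value(schema)
        # -> {"url": "https://example.com", "count": 50}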
+ """ + if depth > 10: # Prevent infinite recursion + return None + + # Handle $ref - we don't resolve refs, just return placeholder + if "$ref" in schema: + return {} + + # Handle anyOf/oneOf/allOf - pick first option + if "anyOf" in schema: + return generate_mock_value(schema["anyOf"][0], depth + 1) + if "oneOf" in schema: + return generate_mock_value(schema["oneOf"][0], depth + 1) + if "allOf" in schema: + # Merge all schemas + merged: dict[str, Any] = {} + for sub_schema in schema["allOf"]: + result = generate_mock_value(sub_schema, depth + 1) + if isinstance(result, dict): + merged.update(result) + return merged + + # Check for const or enum first + if "const" in schema: + return schema["const"] + if "enum" in schema: + return schema["enum"][0] if schema["enum"] else None + + # Check for default value + if "default" in schema: + return schema["default"] + + # Handle by type + schema_type = schema.get("type") + + if schema_type == "string": + # Check for format hints + fmt = schema.get("format", "") + if fmt == "uri" or fmt == "url": + return "https://example.com" + if fmt == "email": + return "user@example.com" + if fmt == "date": + return "2024-01-01" + if fmt == "date-time": + return "2024-01-01T00:00:00Z" + if fmt == "uuid": + return "00000000-0000-0000-0000-000000000000" + # Use title/description hint if available + title = schema.get("title", "").lower() + if "url" in title or "link" in title: + return "https://example.com" + if "name" in title: + return "mock_name" + if "id" in title: + return "mock_id" + return "mock_string" + + if schema_type == "number" or schema_type == "integer": + # Check for bounds + minimum = schema.get("minimum", 0) + maximum = schema.get("maximum", 100) + if schema_type == "integer": + return int((minimum + maximum) / 2) if maximum != float("inf") else minimum + return float((minimum + maximum) / 2) if maximum != float("inf") else float(minimum) + + if schema_type == "boolean": + return True + + if schema_type == "null": + return None + + if schema_type == "array": + items_schema = schema.get("items", {}) + if items_schema: + # Generate one item + return [generate_mock_value(items_schema, depth + 1)] + return [] + + if schema_type == "object" or "properties" in schema: + result: dict[str, Any] = {} + properties = schema.get("properties", {}) + required = set(schema.get("required", [])) + + for prop_name, prop_schema in properties.items(): + # Only include required properties or first few optional ones + if prop_name in required or len(result) < 3: + result[prop_name] = generate_mock_value(prop_schema, depth + 1) + + return result + + # Handle list of types + if isinstance(schema_type, list): + # Pick first non-null type + for t in schema_type: + if t != "null": + return generate_mock_value({"type": t}, depth + 1) + return None + + # Fallback for unknown schema + return None + + +def generate_mock_tool_result(tool: mcp_types.Tool) -> MCPToolResult: + """Generate a mock result for a tool based on its output schema. + + Args: + tool: MCP Tool with inputSchema and optionally outputSchema. + + Returns: + MCPToolResult with mock content. 
+ """ + # Check if tool has an output schema + output_schema = getattr(tool, "outputSchema", None) + + if output_schema: + mock_value = generate_mock_value(output_schema) + content_text = str(mock_value) if mock_value is not None else "mock_result" + else: + # Generate a sensible default based on tool name + tool_name = tool.name + if "screenshot" in tool_name.lower() or "image" in tool_name.lower(): + content_text = "[mock image data]" + elif "get" in tool_name.lower() or "list" in tool_name.lower(): + content_text = "[]" + elif "check" in tool_name.lower() or "verify" in tool_name.lower(): + content_text = "true" + elif "count" in tool_name.lower(): + content_text = "0" + else: + content_text = "mock_success" + + return MCPToolResult( + content=[mcp_types.TextContent(type="text", text=content_text)], + isError=False, + ) + + +class MockMixin: + """Mixin that adds mock functionality to Environment. + + When mock mode is enabled: + - All tool calls return mock values instead of executing + - Specific tools can have custom mock outputs via mock_tool() + - Tools are automatically mocked with reasonable defaults based on their schemas + + Usage: + env = Environment("test").connect_hub("browser") + env.mock() # Enable mock mode + + # Set specific mock outputs + env.mock_tool("navigate", "Navigation successful") + env.mock_tool("screenshot", {"image": "base64data..."}) + + async with env: + result = await env.call_tool("navigate", url="https://example.com") + # Returns: MCPToolResult with "Navigation successful" + """ + + _mock_mode: bool + _mock_outputs: dict[str, Any] + _mock_tool_schemas: dict[str, mcp_types.Tool] + + def _init_mock(self) -> None: + """Initialize mock state. Called from Environment.__init__.""" + self._mock_mode = False + self._mock_outputs = {} + self._mock_tool_schemas = {} + + def mock(self) -> "Environment": + """Enable mock mode - all tool calls will return mock values. + + Returns: + self for chaining. + + Example: + env = Environment("test").connect_hub("browser").mock() + """ + self._mock_mode = True + logger.info("Mock mode enabled for environment %s", getattr(self, "name", "unknown")) + return self # type: ignore[return-value] + + def unmock(self) -> "Environment": + """Disable mock mode - tool calls will execute normally. + + Returns: + self for chaining. + """ + self._mock_mode = False + logger.info("Mock mode disabled for environment %s", getattr(self, "name", "unknown")) + return self # type: ignore[return-value] + + @property + def is_mock(self) -> bool: + """Check if mock mode is enabled.""" + return self._mock_mode + + def mock_tool(self, name: str, output: Any) -> "Environment": + """Set a specific mock output for a tool. + + Args: + name: Tool name (with prefix if applicable). + output: The value to return when this tool is called. + Can be a string, dict, or any JSON-serializable value. + + Returns: + self for chaining. + + Example: + env.mock_tool("navigate", "Success") + env.mock_tool("screenshot", {"type": "image", "data": "..."}) + env.mock_tool("get_elements", [{"id": "1", "text": "Button"}]) + """ + self._mock_outputs[name] = output + logger.debug("Mock output set for tool %s", name) + return self # type: ignore[return-value] + + def _get_mock_result(self, name: str, arguments: dict[str, Any]) -> MCPToolResult: + """Get mock result for a tool call. + + Priority: + 1. Custom mock output set via mock_tool() + 2. Auto-generated mock based on tool's output schema + 3. Default mock value + + Args: + name: Tool name. 
+ arguments: Tool arguments (for potential future use). + + Returns: + MCPToolResult with mock content. + """ + # Check for custom mock output + if name in self._mock_outputs: + output = self._mock_outputs[name] + # Convert to string if not already + if isinstance(output, str): + content_text = output + else: + import json + try: + content_text = json.dumps(output) + except (TypeError, ValueError): + content_text = str(output) + + return MCPToolResult( + content=[mcp_types.TextContent(type="text", text=content_text)], + isError=False, + ) + + # Try to find tool schema for auto-generation + if name in self._mock_tool_schemas: + return generate_mock_tool_result(self._mock_tool_schemas[name]) + + # Check router for tool schema + router = getattr(self, "_router", None) + if router: + for tool in router.tools: + if tool.name == name: + self._mock_tool_schemas[name] = tool + return generate_mock_tool_result(tool) + + # Default fallback + return MCPToolResult( + content=[mcp_types.TextContent(type="text", text="mock_success")], + isError=False, + ) + + def _populate_mock_schemas(self) -> None: + """Populate mock tool schemas from router after connection. + + Called after _build_routing to cache tool schemas for mock generation. + """ + router = getattr(self, "_router", None) + if router: + for tool in router.tools: + self._mock_tool_schemas[tool.name] = tool + diff --git a/hud/environment/router.py b/hud/environment/router.py new file mode 100644 index 00000000..ccdb5b83 --- /dev/null +++ b/hud/environment/router.py @@ -0,0 +1,105 @@ +"""Tool routing for Environment.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING + +import mcp.types as mcp_types + +if TYPE_CHECKING: + from hud.environment.connection import Connector + +__all__ = ["ConflictResolution", "ToolRouter", "LOCAL_CONNECTION"] + +logger = logging.getLogger(__name__) + +LOCAL_CONNECTION = "__local__" + + +class ConflictResolution(str, Enum): + """Strategy for resolving tool name conflicts.""" + PREFIX = "prefix" # Add connection name as prefix + FIRST_WINS = "first_wins" # First connection wins + LAST_WINS = "last_wins" # Last connection wins + ERROR = "error" # Raise error on conflict + + +@dataclass +class ToolRouter: + """Routes tool calls to local or remote handlers with conflict resolution.""" + + conflict_resolution: ConflictResolution = ConflictResolution.PREFIX + _tools: list[mcp_types.Tool] = field(default_factory=list) + _routing: dict[str, str] = field(default_factory=dict) # name -> connection + _local_names: set[str] = field(default_factory=set) + + @property + def tools(self) -> list[mcp_types.Tool]: + return self._tools + + def is_local(self, name: str) -> bool: + return name in self._local_names + + def get_connection(self, name: str) -> str | None: + """Get connection name for tool, None if local or not found.""" + conn = self._routing.get(name) + return None if conn == LOCAL_CONNECTION else conn + + def clear(self) -> None: + self._tools.clear() + self._routing.clear() + self._local_names.clear() + + def build( + self, + local_tools: list[mcp_types.Tool], + connections: dict[str, Connector], + connection_order: list[str], + ) -> None: + """Build routing from local tools and connection caches. + + Local tools always have priority over remote tools. 
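        Example (sketch; "navigate_tool" is a hypothetical local mcp_types.Tool
        named "navigate" and "browser_conn" a connected Connector):

            router = ToolRouter(conflict_resolution=ConflictResolution.LAST_WINS)
            router.build(
                local_tools=[navigate_tool],
                connections={"browser": browser_conn},
                connection_order=["browser"],
            )
            router.is_local("navigate")          # True - local tools always win
            router.get_connection("screenshot")  # "browser" if that connection provides it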
+ """ + self.clear() + seen: dict[str, str] = {} + + # Local tools first (always priority) + for tool in local_tools: + seen[tool.name] = LOCAL_CONNECTION + self._routing[tool.name] = LOCAL_CONNECTION + self._local_names.add(tool.name) + self._tools.append(tool) + + # Remote connections in order + for conn_name in connection_order: + if conn_name not in connections: + continue + for tool in connections[conn_name].cached_tools: + name = tool.name + if name in seen: + existing = seen[name] + if existing == LOCAL_CONNECTION: + continue # Local always wins + if not self._handle_conflict(name, existing, conn_name): + continue + self._tools = [t for t in self._tools if t.name != name] + + seen[name] = conn_name + self._routing[name] = conn_name + self._tools.append(tool) + + logger.debug("Router: %d tools (%d local)", len(self._tools), len(self._local_names)) + + def _handle_conflict(self, name: str, existing: str, new: str) -> bool: + """Handle remote-to-remote conflict. Returns True to replace existing.""" + if self.conflict_resolution == ConflictResolution.ERROR: + raise ValueError(f"Tool conflict: '{name}' in '{existing}' and '{new}'") + if self.conflict_resolution == ConflictResolution.FIRST_WINS: + return False + if self.conflict_resolution == ConflictResolution.LAST_WINS: + return True + # PREFIX - shouldn't conflict if prefixes set correctly + return False diff --git a/hud/environment/utils/__init__.py b/hud/environment/utils/__init__.py new file mode 100644 index 00000000..81d9fc36 --- /dev/null +++ b/hud/environment/utils/__init__.py @@ -0,0 +1,25 @@ +"""Environment utilities.""" + +from hud.environment.utils.formats import ( + ToolFormat, + format_result, + parse_tool_call, + parse_tool_calls, + result_to_string, +) +from hud.environment.utils.schema import ( + ensure_strict_schema, + json_type_to_python, + schema_to_pydantic, +) + +__all__ = [ + "ToolFormat", + "ensure_strict_schema", + "format_result", + "json_type_to_python", + "parse_tool_call", + "parse_tool_calls", + "result_to_string", + "schema_to_pydantic", +] diff --git a/hud/environment/utils/formats.py b/hud/environment/utils/formats.py new file mode 100644 index 00000000..2a9e3b61 --- /dev/null +++ b/hud/environment/utils/formats.py @@ -0,0 +1,213 @@ +"""Tool format parsing and conversion for OpenAI, Claude, Gemini, and MCP.""" + +from __future__ import annotations + +import json +from enum import Enum, auto +from typing import Any + +from hud.types import MCPToolCall, MCPToolResult + +__all__ = [ + "ToolFormat", + "format_result", + "parse_tool_call", + "parse_tool_calls", + "result_to_string", +] + + +class ToolFormat(Enum): + """Detected tool call format.""" + OPENAI = auto() # function.arguments as JSON string + CLAUDE = auto() # type="tool_use", input as dict + GEMINI = auto() # functionCall with args + MCP = auto() # name + arguments + + +# ----------------------------------------------------------------------------- +# Parsing +# ----------------------------------------------------------------------------- + +def _to_dict(obj: Any) -> dict[str, Any]: + """Convert object to dict for uniform processing.""" + if isinstance(obj, dict): + return obj + if hasattr(obj, "model_dump"): + return obj.model_dump() + if hasattr(obj, "__dict__"): + return vars(obj) + raise ValueError(f"Cannot convert {type(obj).__name__} to dict") + + +def _parse_json_args(args: Any) -> dict[str, Any]: + """Parse arguments, handling JSON strings.""" + if not args: + return {} + if isinstance(args, str): + try: + return json.loads(args) + except 
json.JSONDecodeError: + return {} + return args + + +def parse_tool_call(call: Any, **kwargs: Any) -> tuple[MCPToolCall, ToolFormat]: + """Parse any tool call format into (MCPToolCall, ToolFormat). + + Supports: + - String (tool name only, or with kwargs) + - Tuple: (name,), (name, args), (name, args, id) + - MCPToolCall + - OpenAI: {function: {name, arguments}, id} + - Claude: {type: "tool_use", name, input, id} + - Gemini: {functionCall: {name, args}} or {name, args} + - Generic: {name, arguments} + + Args: + call: Tool call in any supported format. + **kwargs: Additional arguments (merged when call is a string). + + Returns: + Tuple of (MCPToolCall, ToolFormat) for the parsed call. + + Raises: + ValueError: If format is unrecognized. + """ + # Primitives + if isinstance(call, str): + return MCPToolCall(name=call, arguments=kwargs or {}), ToolFormat.MCP + + if isinstance(call, tuple): + tc = MCPToolCall(name=call[0], arguments=call[1] if len(call) > 1 else {}) + if len(call) > 2: + tc.id = call[2] + return tc, ToolFormat.MCP + + if isinstance(call, MCPToolCall): + return call, ToolFormat.MCP + + # Convert to dict + d = _to_dict(call) + + # OpenAI: {function: {name, arguments}, id} + if "function" in d: + f = _to_dict(d["function"]) if not isinstance(d["function"], dict) else d["function"] + tc = MCPToolCall(name=f["name"], arguments=_parse_json_args(f.get("arguments"))) + if d.get("id"): + tc.id = d["id"] + return tc, ToolFormat.OPENAI + + # Claude: {type: "tool_use", name, input, id} + if d.get("type") == "tool_use": + tc = MCPToolCall(name=d["name"], arguments=d.get("input") or {}) + if d.get("id"): + tc.id = d["id"] + return tc, ToolFormat.CLAUDE + + # Gemini: {functionCall: {name, args}} or {name, args} + if "functionCall" in d: + fc = d["functionCall"] + return MCPToolCall(name=fc["name"], arguments=fc.get("args") or {}), ToolFormat.GEMINI + + if "args" in d and "name" in d and "arguments" not in d: + return MCPToolCall(name=d["name"], arguments=d.get("args") or {}), ToolFormat.GEMINI + + # Generic: {name, arguments/input} + if "name" in d: + tc = MCPToolCall(name=d["name"], arguments=d.get("arguments") or d.get("input") or {}) + if d.get("id"): + tc.id = d["id"] + return tc, ToolFormat.MCP + + raise ValueError(f"Unrecognized tool call format: {list(d.keys())}") + + +def _is_tool_block(item: Any) -> bool: + """Check if item is a tool call (not text/other content).""" + t = item.get("type") if isinstance(item, dict) else getattr(item, "type", None) + return t is None or t in ("tool_use", "function") + + +def parse_tool_calls(calls: Any) -> list[tuple[MCPToolCall, ToolFormat]]: + """Parse multiple tool calls, filtering non-tool content (e.g. Claude TextBlock). + + Args: + calls: Single call or list of calls in any format. + + Returns: + List of (MCPToolCall, ToolFormat) tuples. + """ + if calls is None: + return [] + if not isinstance(calls, list): + try: + return [parse_tool_call(calls)] + except ValueError: + return [] + + results = [] + for item in calls: + if not _is_tool_block(item): + continue + try: + results.append(parse_tool_call(item)) + except ValueError: + continue + return results + + +# ----------------------------------------------------------------------------- +# Result Formatting +# ----------------------------------------------------------------------------- + +def result_to_string(result: MCPToolResult) -> str: + """Convert MCPToolResult content to string. + + Args: + result: MCP tool result with content blocks. 
+ + Returns: + String representation of the result content. + """ + if not result.content: + return "" + parts = [] + for block in result.content: + if (text := getattr(block, "text", None)) is not None: + parts.append(str(text)) + elif (data := getattr(block, "data", None)) is not None: + parts.append(f"[binary: {len(data)} bytes]") + return "\n".join(parts) + + +def format_result(result: MCPToolResult, tc: MCPToolCall, fmt: ToolFormat) -> Any: + """Format MCPToolResult based on the input format. + + Args: + result: MCP tool result. + tc: Original tool call (for id/name). + fmt: Target format. + + Returns: + OpenAI: {"role": "tool", "tool_call_id": ..., "content": ...} + Claude: {"type": "tool_result", "tool_use_id": ..., "content": ..., "is_error"?: bool} + Gemini: {"functionResponse": {"name": ..., "response": {"result": ...}}} + MCP: MCPToolResult unchanged + """ + content = result_to_string(result) + + if fmt == ToolFormat.OPENAI: + return {"role": "tool", "tool_call_id": tc.id, "content": content} + + if fmt == ToolFormat.CLAUDE: + r: dict[str, Any] = {"type": "tool_result", "tool_use_id": tc.id, "content": content} + if result.isError: + r["is_error"] = True + return r + + if fmt == ToolFormat.GEMINI: + return {"functionResponse": {"name": tc.name, "response": {"result": content}}} + + return result # MCP format - return as-is + diff --git a/hud/environment/utils/schema.py b/hud/environment/utils/schema.py new file mode 100644 index 00000000..6e0c2029 --- /dev/null +++ b/hud/environment/utils/schema.py @@ -0,0 +1,97 @@ +"""Schema utilities for tool definitions.""" + +from __future__ import annotations + +from typing import Any + +__all__ = ["ensure_strict_schema", "schema_to_pydantic", "json_type_to_python"] + + +def ensure_strict_schema(schema: dict[str, Any]) -> dict[str, Any]: + """Ensure a JSON schema is compatible with OpenAI's strict mode. + + OpenAI strict mode requires: + - additionalProperties: false on all objects + - All properties must be in required + + Args: + schema: Original JSON schema. + + Returns: + Modified schema for strict mode. + """ + schema = dict(schema) + + if schema.get("type") == "object": + schema["additionalProperties"] = False + + if "properties" in schema: + # All properties must be required + schema["required"] = list(schema["properties"].keys()) + + # Recursively process nested objects + for prop_schema in schema["properties"].values(): + if isinstance(prop_schema, dict): + _ensure_strict_recursive(prop_schema) + + return schema + + +def _ensure_strict_recursive(schema: dict[str, Any]) -> None: + """Recursively apply strict mode to nested schemas.""" + if schema.get("type") == "object": + schema["additionalProperties"] = False + if "properties" in schema: + schema["required"] = list(schema["properties"].keys()) + for prop_schema in schema["properties"].values(): + if isinstance(prop_schema, dict): + _ensure_strict_recursive(prop_schema) + + elif schema.get("type") == "array" and "items" in schema: + if isinstance(schema["items"], dict): + _ensure_strict_recursive(schema["items"]) + + +def schema_to_pydantic(name: str, schema: dict[str, Any]) -> type: + """Convert JSON schema to a Pydantic model. + + Args: + name: Model name (used for class name). + schema: JSON schema with properties. + + Returns: + Dynamically created Pydantic model class. 
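    Example (illustrative sketch):

        Model = schema_to_pydantic("navigate", {
            "type": "object",
            "properties": {
                "url": {"type": "string", "description": "Target URL"},
                "timeout": {"type": "number"},
            },
            "required": ["url"],
        })
        Model(url="https://example.com")  # valid; timeout is optional and defaults to None
        # The generated class is named "navigateInput"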
+ """ + from pydantic import Field, create_model + + properties = schema.get("properties", {}) + required = set(schema.get("required", [])) + + fields = {} + for prop_name, prop_schema in properties.items(): + prop_type = json_type_to_python(prop_schema.get("type", "string")) + default = ... if prop_name in required else None + description = prop_schema.get("description", "") + fields[prop_name] = (prop_type, Field(default=default, description=description)) + + return create_model(f"{name}Input", **fields) + + +def json_type_to_python(json_type: str) -> type: + """Map JSON schema type to Python type. + + Args: + json_type: JSON schema type string. + + Returns: + Corresponding Python type. + """ + mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "array": list, + "object": dict, + } + return mapping.get(json_type, str) diff --git a/hud/trace/__init__.py b/hud/trace/__init__.py new file mode 100644 index 00000000..60afb5d4 --- /dev/null +++ b/hud/trace/__init__.py @@ -0,0 +1,42 @@ +""" +HUD Trace System - Context management for agent runs. + +The trace system provides: +- TraceContext: Core abstraction for recording agent runs +- TraceMixin: Mixin that adds trace() method to Environment +- Auto-instrumentation of httpx for inference.hud.ai +- Parallel execution with group=N + +Usage (single execution): + ```python + async with env.trace("google-search") as tc: + await tc.call_tool("navigate", {"url": "..."}) + tc.reward = 0.9 + + # tc has the results + print(tc.trace_id, tc.reward, tc.duration, tc.success) + ``` + +Usage (parallel execution): + ```python + async with env.trace("google-search", group=4) as tc: + # This body runs 4 times, each with a different tc! + await tc.call_tool("navigate", {"url": "..."}) + tc.reward = evaluate() + + # tc.results contains all parallel traces + # tc.reward is the mean reward + print(f"Mean reward: {tc.reward}") + for trace in tc.results: + print(f" {trace.trace_id}: {trace.reward}") + ``` +""" + +from hud.trace.context import TraceContext, get_current_trace_headers +from hud.trace.mixin import TraceMixin + +__all__ = [ + "TraceContext", + "TraceMixin", + "get_current_trace_headers", +] diff --git a/hud/trace/context.py b/hud/trace/context.py new file mode 100644 index 00000000..a80f5b98 --- /dev/null +++ b/hud/trace/context.py @@ -0,0 +1,357 @@ +"""TraceContext - Lightweight context for recording agent runs. + +TraceContext provides: +- Unique trace identification +- Headers for gateway integration (auto-injected to inference.hud.ai) +- Reward and status reporting to backend +- Tool call delegation + +All telemetry goes directly to the backend - nothing accumulated locally. + +Auto-instrumentation: + httpx clients are automatically instrumented when this module is imported. + Any request to inference.hud.ai will have trace headers injected. 
+""" + +from __future__ import annotations + +import contextvars +import logging +import uuid +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, Self + +from hud.settings import settings +from hud.shared import make_request +from hud.telemetry.job import get_current_job + +if TYPE_CHECKING: + from types import TracebackType + + from hud.environment import Environment + from hud.types import MCPToolResult + +logger = logging.getLogger(__name__) + +# Contextvar to store current trace headers +_current_trace_headers: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( + "current_trace_headers", default=None +) + + +def get_current_trace_headers() -> dict[str, str] | None: + """Get the current trace headers from context.""" + return _current_trace_headers.get() + + +# ============================================================================= +# Auto-instrumentation for httpx +# ============================================================================= + +def _httpx_request_hook(request: Any) -> None: + """httpx event hook that adds trace headers to inference.hud.ai requests.""" + headers = get_current_trace_headers() + if headers is None: + return + + url_str = str(request.url) + if "inference.hud.ai" not in url_str: + return + + for key, value in headers.items(): + request.headers[key] = value + + logger.debug("Added trace headers to request: %s", url_str) + + +async def _async_httpx_request_hook(request: Any) -> None: + """Async version of the httpx event hook.""" + _httpx_request_hook(request) + + +def _instrument_client(client: Any) -> None: + """Add trace hook to an httpx client instance.""" + is_async = hasattr(client, "aclose") + hook = _async_httpx_request_hook if is_async else _httpx_request_hook + + existing_hooks = client.event_hooks.get("request", []) + if hook not in existing_hooks: + existing_hooks.append(hook) + client.event_hooks["request"] = existing_hooks + + +def _patch_httpx() -> None: + """Monkey-patch httpx to auto-instrument all clients.""" + try: + import httpx + except ImportError: + logger.debug("httpx not installed, skipping auto-instrumentation") + return + + _original_async_init = httpx.AsyncClient.__init__ + + def _patched_async_init(self: Any, *args: Any, **kwargs: Any) -> None: + _original_async_init(self, *args, **kwargs) + _instrument_client(self) + + httpx.AsyncClient.__init__ = _patched_async_init # type: ignore[method-assign] + + _original_sync_init = httpx.Client.__init__ + + def _patched_sync_init(self: Any, *args: Any, **kwargs: Any) -> None: + _original_sync_init(self, *args, **kwargs) + _instrument_client(self) + + httpx.Client.__init__ = _patched_sync_init # type: ignore[method-assign] + + logger.debug("httpx auto-instrumentation enabled") + + +# Auto-patch httpx on module import +_patch_httpx() + + +# ============================================================================= +# TraceContext +# ============================================================================= + +class TraceContext: + """Lightweight context for a traced execution. 
+ + Attributes: + trace_id: Unique identifier for this trace + name: Task name + job_id: Links to parent job (auto-detected from hud.job() context) + group_id: Links parallel traces together (None for single traces) + variants: Variant assignment dict (for A/B testing) + reward: Reward value (user-settable) + error: Exception if failed + results: All trace results (for parent trace) + + Computed: + headers: Gateway headers + duration: Execution time in seconds + success: True if no error + done: True if completed + + Example: + ```python + # Simple trace + async with env.trace("task") as tc: + await tc.call_tool("navigate", {"url": "..."}) + tc.reward = 0.9 + + # With variants (A/B testing) and group (multiple runs) + async with env.trace("task", + variants={"model": ["gpt-4o", "claude"]}, + group=3, + ) as tc: + model = tc.variants["model"] # Assigned for this run + response = await call_llm(model=model) + tc.reward = evaluate(response) + + # tc.results has 6 traces (2 variants x 3 runs each) + # All share the same tc.group_id + for t in tc.results: + print(f"{t.variants}: reward={t.reward}") + ``` + """ + + def __init__( + self, + env: Environment, + name: str, + *, + trace_id: str | None = None, + api_key: str | None = None, + job_id: str | None = None, + _group_id: str | None = None, + _index: int = 0, + _variants: dict[str, Any] | None = None, + ) -> None: + # Identity + self.trace_id: str = trace_id or str(uuid.uuid4()) + self.name: str = name + + # Job linkage - auto-detect from current job context if not provided + if job_id is None: + current_job = get_current_job() + self.job_id: str | None = current_job.id if current_job else None + else: + self.job_id = job_id + + self.group_id: str | None = _group_id # Links parallel traces together + self.index: int = _index # Local only, for debugging + + # Variant assignment (for A/B testing) + self.variants: dict[str, Any] = _variants or {} + + # User-settable + self.reward: float | None = None + + # Error tracking + self.error: BaseException | None = None + + # Parallel/variant results (nested) + self.results: list[TraceContext] | None = None + + # Private + self._env = env + self._api_key = api_key + self._started_at: datetime | None = None + self._completed_at: datetime | None = None + self._token: contextvars.Token[dict[str, str] | None] | None = None + + # ========================================================================= + # Computed Properties + # ========================================================================= + + @property + def headers(self) -> dict[str, str]: + """Headers for gateway integration.""" + return {"HUD-Trace-Id": self.trace_id} + + @property + def duration(self) -> float: + """Execution duration in seconds.""" + if self._started_at is None: + return 0.0 + end = self._completed_at or datetime.now(UTC) + return (end - self._started_at).total_seconds() + + @property + def success(self) -> bool: + """True if no error occurred.""" + return self.error is None + + @property + def done(self) -> bool: + """True if execution completed.""" + return self._completed_at is not None + + def _get_api_key(self) -> str | None: + return self._api_key or settings.api_key + + # ========================================================================= + # Tool Operations + # ========================================================================= + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + ) -> MCPToolResult: + """Call a tool by name (delegates to environment).""" + return 
await self._env.call_tool(name, arguments) # type: ignore[attr-defined] + + # ========================================================================= + # Backend Integration + # ========================================================================= + + async def log(self, metrics: dict[str, Any]) -> None: + """Log metrics to the backend.""" + api_key = self._get_api_key() + if not settings.telemetry_enabled or not api_key: + return + + try: + await make_request( + method="POST", + url=f"{settings.hud_telemetry_url}/traces/{self.trace_id}/log", + json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()}, + api_key=api_key, + ) + except Exception as e: + logger.warning("Failed to log metrics: %s", e) + + async def _trace_enter(self) -> None: + """Notify backend that trace has started.""" + api_key = self._get_api_key() + if not settings.telemetry_enabled or not api_key: + return + + try: + data: dict[str, Any] = { + "task_name": self.name, + "started_at": self._started_at.isoformat() if self._started_at else None, + } + if self.job_id: + data["job_id"] = self.job_id + if self.group_id: + data["group_id"] = self.group_id + if self.variants: + data["variants"] = self.variants + + await make_request( + method="POST", + url=f"{settings.hud_telemetry_url}/trace/{self.trace_id}/enter", + json=data, + api_key=api_key, + ) + except Exception as e: + logger.warning("Failed to send trace enter: %s", e) + + async def _trace_exit(self, error_message: str | None = None) -> None: + """Notify backend that trace has completed.""" + api_key = self._get_api_key() + if not settings.telemetry_enabled or not api_key: + return + + try: + data: dict[str, Any] = { + "task_name": self.name, + "completed_at": self._completed_at.isoformat() if self._completed_at else None, + "success": self.success, + } + if self.job_id: + data["job_id"] = self.job_id + if self.group_id: + data["group_id"] = self.group_id + if self.variants: + data["variants"] = self.variants + if self.reward is not None: + data["reward"] = self.reward + if error_message: + data["error_message"] = error_message + + await make_request( + method="POST", + url=f"{settings.hud_telemetry_url}/trace/{self.trace_id}/exit", + json=data, + api_key=api_key, + ) + except Exception as e: + logger.warning("Failed to send trace exit: %s", e) + + # ========================================================================= + # Context Manager + # ========================================================================= + + async def __aenter__(self) -> Self: + self._started_at = datetime.now(UTC) + self._token = _current_trace_headers.set(self.headers) + await self._trace_enter() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self._completed_at = datetime.now(UTC) + + if self._token is not None: + _current_trace_headers.reset(self._token) + self._token = None + + error_msg: str | None = None + if exc_type is not None: + self.error = exc_val + error_msg = str(exc_val) if exc_val else "Unknown error" + + # Send exit with all data (reward, error, etc.) + await self._trace_exit(error_msg) + + def __repr__(self) -> str: + return f"TraceContext({self.trace_id[:8]}..., name={self.name!r}, reward={self.reward})" diff --git a/hud/trace/mixin.py b/hud/trace/mixin.py new file mode 100644 index 00000000..d8a5e6b4 --- /dev/null +++ b/hud/trace/mixin.py @@ -0,0 +1,382 @@ +"""TraceMixin - Adds trace() method to Environment. 
+ +This mixin provides the trace() context manager that creates TraceContext +instances for recording agent runs, with optional parallel execution and +variant-based A/B testing. +""" + +from __future__ import annotations + +import inspect +import itertools +import logging +import uuid +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +from hud.trace.context import TraceContext +from hud.trace.parallel import ( + ASTExtractionError, + _get_with_block_body, + run_parallel_traces, +) + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from hud.types import MCPToolResult + +logger = logging.getLogger(__name__) + + +def _expand_variants( + variants: dict[str, Any] | None, +) -> list[dict[str, Any]]: + """Expand variants dict into all combinations. + + Args: + variants: Dict where values can be: + - Single value: {"model": "gpt-4o"} → fixed + - List: {"model": ["gpt-4o", "claude"]} → expand + + Returns: + List of variant assignments, one per combination. + + Examples: + >>> _expand_variants(None) + [{}] + >>> _expand_variants({"model": "gpt-4o"}) + [{"model": "gpt-4o"}] + >>> _expand_variants({"model": ["gpt-4o", "claude"]}) + [{"model": "gpt-4o"}, {"model": "claude"}] + >>> _expand_variants({"model": ["a", "b"], "temp": [0.0, 0.7]}) + [{"model": "a", "temp": 0.0}, {"model": "a", "temp": 0.7}, + {"model": "b", "temp": 0.0}, {"model": "b", "temp": 0.7}] + """ + if not variants: + return [{}] + + # Normalize: single values become single-element lists + expanded: dict[str, list[Any]] = {} + for key, value in variants.items(): + if isinstance(value, list): + expanded[key] = value + else: + expanded[key] = [value] + + # Generate all combinations + keys = list(expanded.keys()) + value_lists = [expanded[k] for k in keys] + + return [ + dict(zip(keys, combo, strict=True)) + for combo in itertools.product(*value_lists) + ] + + +class TraceMixin: + """Mixin that adds trace capabilities to Environment. + + This mixin provides: + - trace(): Create a TraceContext for recording agent runs + - Parallel execution with group=N parameter + - A/B testing with variants parameter + + Example: + ```python + class Environment(TraceMixin, MCPServer): + ... + + env = Environment("my-env") + + # Single trace + async with env.trace("task") as tc: + await tc.call_tool("navigate", {"url": "..."}) + tc.reward = 0.9 + + # Parallel traces (runs 4 times) + async with env.trace("task", group=4) as tc: + await tc.call_tool("navigate", {"url": "..."}) + tc.reward = 0.9 + + # A/B testing (2 variants x 3 runs = 6 traces) + async with env.trace("task", + variants={"model": ["gpt-4o", "claude"]}, + group=3, + ) as tc: + model = tc.variants["model"] + response = await call_llm(model=model) + tc.reward = evaluate(response) + + # Access results + for t in tc.results: + print(f"{t.variants} run {t.index}: reward={t.reward}") + ``` + """ + + # These will be provided by the Environment class + name: str + + # Store last parallel results (list of completed TraceContext objects) + _last_traces: list[TraceContext] | None = None + + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> MCPToolResult: + """Placeholder - implemented by Environment.""" + raise NotImplementedError + + @property + def last_traces(self) -> list[TraceContext] | None: + """Get TraceContext objects from the last parallel execution. 
+ + Each TraceContext has: trace_id, index, reward, duration, error, success + """ + return self._last_traces + + @asynccontextmanager + async def trace( + self, + name: str, + *, + variants: dict[str, Any] | None = None, + group: int = 1, + group_ids: list[str] | None = None, + job_id: str | None = None, + trace_id: str | None = None, + api_key: str | None = None, + ) -> AsyncGenerator[TraceContext, None]: + """Create a trace context for recording an agent run. + + The trace context provides: + - Unique trace identification + - Task name linking (for training data construction) + - Headers for gateway integration (auto-injected to inference.hud.ai) + - Tool call delegation + - Reward setting + - Metrics logging + + A/B Testing: + Use `variants` to define experiment variables. Each list value + creates a variant; single values are fixed. All combinations + are expanded and run. + + Parallel Execution: + Use `group` to run multiple times per variant for statistical + significance. Total traces = len(variants combinations) x group. + + Args: + name: Task name for this trace (used for task construction) + variants: A/B test configuration. Dict where: + - List values are expanded: {"model": ["gpt-4o", "claude"]} + - Single values are fixed: {"temp": 0.7} + - All combinations are run + group: Runs per variant (default: 1) for statistical significance. + group_ids: Optional list of group IDs for each trace. + Length must match (variants x group). If not provided, + a single shared group_id is auto-generated. + job_id: Optional job ID to link this trace to. If not provided, + auto-detects from current `hud.job()` context. + trace_id: Optional trace ID (auto-generated if not provided). + For parallel execution, each trace gets a unique ID. + api_key: Optional API key for backend calls (defaults to settings.api_key) + + Yields: + TraceContext for this trace. 
Inside the body: + - `tc.variants` = current variant assignment (e.g., {"model": "gpt-4o"}) + - `tc.index` = local run index (for debugging) + - `tc.group_id` = links all traces in this parallel execution + + After execution (for variants/group > 1): + - `tc.results` = list of all TraceContext objects + - `tc.reward` = mean reward across all traces + + Example: + ```python + # Single execution + async with env.trace("task") as tc: + await tc.call_tool("search", {"query": "..."}) + tc.reward = 1.0 + + # A/B test: 2 variants x 3 runs = 6 traces + async with env.trace("task", + variants={"model": ["gpt-4o", "claude"]}, + group=3, + ) as tc: + model = tc.variants["model"] # Assigned per-trace + response = await call_llm(model=model) + tc.reward = evaluate(response) + + # Access results + for t in tc.results: + print(f"{t.variants} run {t.index}: reward={t.reward}") + ``` + + Limitations (for variants/group > 1): + - Requires source file (won't work in REPL/Jupyter) + - Outer variables captured at enter time, changes don't propagate back + - Modifying mutable objects causes race conditions + - Cannot use yield/generators inside body + """ + if group <= 0: + raise ValueError("group must be >= 1") + + # Expand variants into all combinations + variant_combos = _expand_variants(variants) + total_traces = len(variant_combos) * group + + # Validate parallelization - only remote connections allowed for group > 1 + if total_traces > 1 and not self.is_parallelizable: # type: ignore[attr-defined] + local_conns = self.local_connections # type: ignore[attr-defined] + raise ValueError( + f"Cannot run parallel traces (group={group}) with local connections.\n" + f" Local connections: {local_conns}\n" + f" Local connections (stdio/Docker) can only run one instance.\n" + f" Use remote connections (HTTP/URL) for parallel execution." + ) + + if total_traces == 1: + # Simple case: single trace + # TraceContext enters FIRST (sets headers in contextvar) + # Environment enters SECOND (can inject headers into connections) + tc = TraceContext( + env=self, # type: ignore[arg-type] + name=name, + trace_id=trace_id, + api_key=api_key, + job_id=job_id, + _variants=variant_combos[0], + ) + async with tc: + async with self: # type: ignore[attr-defined] + yield tc + else: + # Parallel execution: each trace gets its own environment instance + # Parent environment NOT entered - each child connects independently + completed = await self._run_parallel_trace( + name=name, + variant_combos=variant_combos, + group=group, + group_ids=group_ids, + job_id=job_id, + api_key=api_key, + ) + + # Create parent tc with results injected + tc = TraceContext( + env=self, # type: ignore[arg-type] + name=name, + trace_id=trace_id, + api_key=api_key, + job_id=job_id, + ) + tc.results = completed + self._last_traces = completed + + # Compute aggregate reward (mean of non-None rewards) + rewards = [t.reward for t in completed if t.reward is not None] + if rewards: + tc.reward = sum(rewards) / len(rewards) + + yield tc + + async def _run_parallel_trace( + self, + name: str, + variant_combos: list[dict[str, Any]], + group: int, + group_ids: list[str] | None, + job_id: str | None, + api_key: str | None, + ) -> list[TraceContext]: + """Run parallel trace execution using AST extraction. + + This method: + 1. Captures the caller's frame + 2. Extracts the with-block body via AST + 3. Creates (variants x group) TraceContext instances + 4. Runs the body in parallel + 5. 
Stores results in self._last_traces + + Args: + name: Task name + variant_combos: List of variant assignments (one per combination) + group: Runs per variant + group_ids: Optional list of group IDs (one per total trace) + job_id: Optional job ID (auto-detected from current job if not provided) + api_key: Optional API key + """ + # Get the caller's frame (skip this method and the trace method) + frame = inspect.currentframe() + if frame is None: + raise ASTExtractionError("Cannot get current frame") + + try: + # Go up: _run_parallel_trace -> trace -> user code + caller_frame = frame.f_back + if caller_frame is not None: + caller_frame = caller_frame.f_back + if caller_frame is None: + raise ASTExtractionError("Cannot get caller frame") + + # Extract the with-block body + body_source, captured_locals = _get_with_block_body(caller_frame) + + finally: + del frame # Avoid reference cycles + + # Calculate total traces + total_traces = len(variant_combos) * group + + # Use provided group_ids or generate one shared group_id + if group_ids: + if len(group_ids) != total_traces: + raise ValueError( + f"group_ids length ({len(group_ids)}) must match " + f"total traces ({total_traces} = {len(variant_combos)} variants x {group} runs)" + ) + resolved_group_ids = group_ids + else: + # All traces share one auto-generated group_id + shared_group_id = str(uuid.uuid4()) + resolved_group_ids = [shared_group_id] * total_traces + + # Create TraceContext for each (variant, run) combination + trace_contexts: list[TraceContext] = [] + idx = 0 + for variant in variant_combos: + for _ in range(group): + tc = TraceContext( + env=self, # type: ignore[arg-type] + name=name, + api_key=api_key, + job_id=job_id, + _group_id=resolved_group_ids[idx], + _index=idx, + _variants=variant, + ) + trace_contexts.append(tc) + idx += 1 + + # Run in parallel + total = len(trace_contexts) + logger.info( + "Running %d traces for task '%s' (%d variants x %d runs)", + total, name, len(variant_combos), group, + ) + completed = await run_parallel_traces(trace_contexts, body_source, captured_locals) + + # Store results + self._last_traces = completed + + # Calculate stats + rewards = [tc.reward for tc in completed if tc.reward is not None] + mean_reward = sum(rewards) / len(rewards) if rewards else 0.0 + success_count = sum(1 for tc in completed if tc.success) + + logger.info( + "Traces complete: %d/%d succeeded, mean_reward=%.3f", + success_count, + len(completed), + mean_reward, + ) + + return completed diff --git a/hud/trace/parallel.py b/hud/trace/parallel.py new file mode 100644 index 00000000..f20de00d --- /dev/null +++ b/hud/trace/parallel.py @@ -0,0 +1,131 @@ +"""Parallel execution support for traces. + +This module provides AST extraction and parallel execution for running +the same trace body N times concurrently. +""" + +from __future__ import annotations + +import ast +import asyncio +import linecache +import logging +import textwrap +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from hud.trace.context import TraceContext + +logger = logging.getLogger(__name__) + + +class ASTExtractionError(Exception): + """Error extracting AST from source.""" + + +def _get_with_block_body(frame: Any) -> tuple[str, dict[str, Any]]: + """Extract the body of a with-block from the calling frame. 
+ + Args: + frame: The calling frame (from inspect.currentframe()) + + Returns: + Tuple of (body_source, captured_locals) + """ + filename = frame.f_code.co_filename + lineno = frame.f_lineno + + # Check for interactive session + if filename.startswith("<") or filename in ("", ""): + raise ASTExtractionError( + "Cannot extract source from interactive session. Use a .py file." + ) + + # Read and parse source + lines = linecache.getlines(filename) + if not lines: + with open(filename, encoding="utf-8") as f: + lines = f.readlines() + + source = "".join(lines) + tree = ast.parse(source, filename=filename) + + # Find the async with containing this line + with_node = _find_async_with(tree, lineno) + if with_node is None: + raise ASTExtractionError( + f"Cannot find 'async with' statement at line {lineno}" + ) + + # Extract body source + body_source = _extract_body(lines, with_node) + + return body_source, frame.f_locals.copy() + + +def _find_async_with(tree: ast.AST, target_line: int) -> ast.AsyncWith | None: + """Find AsyncWith node containing the target line.""" + for node in ast.walk(tree): + if isinstance(node, ast.AsyncWith): + end_line = _get_end_line(node) + if node.lineno <= target_line <= end_line: + return node + return None + + +def _get_end_line(node: ast.AST) -> int: + """Get the last line number of an AST node.""" + end = getattr(node, "end_lineno", getattr(node, "lineno", 0)) + for child in ast.walk(node): + child_end = getattr(child, "end_lineno", 0) + if child_end > end: + end = child_end + return end + + +def _extract_body(lines: list[str], with_node: ast.AsyncWith) -> str: + """Extract the body source from an AsyncWith node.""" + if not with_node.body: + return "pass" + + start = with_node.body[0].lineno - 1 + end = _get_end_line(with_node.body[-1]) + + body = "".join(lines[start:end]) + return textwrap.dedent(body) + + +async def run_parallel_traces( + trace_contexts: list[TraceContext], + body_source: str, + captured_locals: dict[str, Any], +) -> list[TraceContext]: + """Run the trace body in parallel for multiple contexts. 
+ + Returns the TraceContext objects after execution - they contain: + - trace_id + - index + - reward + - duration + - Any error is captured in the context + """ + + # Create runner function + wrapped = f"async def __runner__(tc):\n{textwrap.indent(body_source, ' ')}" + code = compile(wrapped, "", "exec") + namespace = captured_locals.copy() + exec(code, namespace) # noqa: S102 + runner = namespace["__runner__"] + + async def run_one(tc: TraceContext) -> TraceContext: + try: + async with tc: + await runner(tc) + except Exception as e: + logger.warning("Parallel trace %d failed: %s", tc.index, e) + # Store error in context for inspection + tc._error = e # type: ignore[attr-defined] + return tc + + results = await asyncio.gather(*[run_one(tc) for tc in trace_contexts]) + return list(results) diff --git a/hud/types.py b/hud/types.py index 527e9051..f4fd5c5b 100644 --- a/hud/types.py +++ b/hud/types.py @@ -236,7 +236,9 @@ def __rich__(self) -> str: class MCPToolResult(CallToolResult): - """A tool result.""" + """A tool result with optional call_id for correlation.""" + + call_id: str | None = None # For correlating with provider-specific tool call IDs def _get_content_summary(self) -> str: """Extract a summary of the content.""" From 8ca286d24691f88256a09abc99e4f6a7a7b7059d Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 05:20:12 -0800 Subject: [PATCH 02/92] add tests and various output changes --- hud/datasets/runner.py | 2 +- hud/environment/__init__.py | 3 + hud/environment/connection.py | 51 ++- hud/environment/connectors/base.py | 6 +- hud/environment/connectors/mcp_config.py | 2 +- hud/environment/connectors/remote.py | 18 + hud/environment/connectors/task.py | 5 + hud/environment/environment.py | 75 ++++- hud/environment/tests/__init__.py | 2 + hud/environment/tests/test_connection.py | 312 ++++++++++++++++++ hud/environment/tests/test_connectors.py | 269 +++++++++++++++ hud/environment/tests/test_environment.py | 192 +++++++++++ hud/environment/tests/test_integrations.py | 246 ++++++++++++++ .../tests/test_local_connectors.py | 204 ++++++++++++ hud/environment/types.py | 29 ++ hud/trace/context.py | 149 ++++++--- hud/trace/mixin.py | 54 +++ hud/trace/tests/__init__.py | 2 + hud/trace/tests/test_context.py | 288 ++++++++++++++++ hud/trace/tests/test_mixin.py | 178 ++++++++++ hud/trace/tests/test_parallel.py | 156 +++++++++ 21 files changed, 2183 insertions(+), 60 deletions(-) create mode 100644 hud/environment/tests/__init__.py create mode 100644 hud/environment/tests/test_connection.py create mode 100644 hud/environment/tests/test_connectors.py create mode 100644 hud/environment/tests/test_environment.py create mode 100644 hud/environment/tests/test_integrations.py create mode 100644 hud/environment/tests/test_local_connectors.py create mode 100644 hud/environment/types.py create mode 100644 hud/trace/tests/__init__.py create mode 100644 hud/trace/tests/test_context.py create mode 100644 hud/trace/tests/test_mixin.py create mode 100644 hud/trace/tests/test_parallel.py diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 5b20092c..7875500a 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -11,7 +11,7 @@ from datasets import Dataset, load_dataset -from hud import async_job, async_trace +from hud.telemetry import async_job, async_trace from hud.datasets.utils import calculate_group_stats, submit_rollouts from hud.types import AgentType, Task, Trace diff --git a/hud/environment/__init__.py b/hud/environment/__init__.py index 
57634e51..3bd64c39 100644 --- a/hud/environment/__init__.py +++ b/hud/environment/__init__.py @@ -28,6 +28,7 @@ from hud.environment.environment import Environment from hud.environment.mock import MockMixin, generate_mock_value from hud.environment.router import ConflictResolution, ToolRouter +from hud.environment.types import EnvConfig, HubConfig from hud.environment.utils import ToolFormat, format_result, parse_tool_call, parse_tool_calls __all__ = [ @@ -35,7 +36,9 @@ "ConnectionConfig", "ConnectionType", "Connector", + "EnvConfig", "Environment", + "HubConfig", "MockMixin", "ToolFormat", "ToolRouter", diff --git a/hud/environment/connection.py b/hud/environment/connection.py index dd05ccde..a104881b 100644 --- a/hud/environment/connection.py +++ b/hud/environment/connection.py @@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Any import mcp.types as mcp_types -from fastmcp.client import Client as FastMCPClient if TYPE_CHECKING: from collections.abc import Callable + from fastmcp.client import Client as FastMCPClient from fastmcp.tools.tool import Tool __all__ = ["ConnectionConfig", "ConnectionType", "Connector"] @@ -44,19 +44,29 @@ def __init__( class Connector: - """Manages a connection to an MCP server with tool caching.""" + """Manages a connection to an MCP server with tool caching. + + Client creation is deferred to connect() so that: + 1. Each parallel trace gets fresh client instances + 2. Connection happens inside trace context (for header injection) + """ def __init__( self, - client: FastMCPClient[Any], + transport: Any, config: ConnectionConfig, name: str, connection_type: ConnectionType, + *, + auth: str | None = None, ) -> None: - self.client = client + # Store transport config - client created in connect() + self._transport = transport + self._auth = auth self.config = config self.name = name self.connection_type = connection_type + self.client: FastMCPClient[Any] | None = None self._tools_cache: list[mcp_types.Tool] | None = None @property @@ -71,25 +81,36 @@ def is_remote(self) -> bool: @property def is_connected(self) -> bool: - return self.client.is_connected() + return self.client is not None and self.client.is_connected() @property def cached_tools(self) -> list[mcp_types.Tool]: return self._tools_cache or [] async def connect(self) -> None: - """Connect using FastMCP Client's context manager.""" - if not self.is_connected: - await self.client.__aenter__() + """Create FastMCP client and connect. + + Client is created here (not in __init__) so that: + 1. Each parallel trace gets fresh client instances + 2. 
httpx auto-instrumentation can inject trace headers + """ + from fastmcp.client import Client as FastMCPClient + + # Create fresh client from stored transport config + self.client = FastMCPClient(transport=self._transport, auth=self._auth) + await self.client.__aenter__() async def disconnect(self) -> None: """Disconnect and clear cache.""" - if self.is_connected: + if self.client is not None and self.is_connected: await self.client.__aexit__(None, None, None) - self._tools_cache = None + self.client = None + self._tools_cache = None async def list_tools(self) -> list[mcp_types.Tool]: """Fetch tools from server, apply filters/transforms/prefix, and cache.""" + if self.client is None: + raise RuntimeError("Not connected - call connect() first") tools = await self.client.list_tools() result: list[mcp_types.Tool] = [] @@ -133,25 +154,35 @@ async def call_tool( self, name: str, arguments: dict[str, Any] | None = None ) -> mcp_types.CallToolResult: """Call a tool, stripping prefix if needed.""" + if self.client is None: + raise RuntimeError("Not connected - call connect() first") # Strip prefix when calling remote if self.config.prefix and name.startswith(f"{self.config.prefix}_"): name = name[len(self.config.prefix) + 1:] return await self.client.call_tool_mcp(name, arguments or {}) async def list_resources(self) -> list[mcp_types.Resource]: + if self.client is None: + raise RuntimeError("Not connected - call connect() first") return await self.client.list_resources() async def list_prompts(self) -> list[mcp_types.Prompt]: + if self.client is None: + raise RuntimeError("Not connected - call connect() first") return await self.client.list_prompts() async def read_resource( self, uri: str ) -> list[mcp_types.TextResourceContents | mcp_types.BlobResourceContents]: + if self.client is None: + raise RuntimeError("Not connected - call connect() first") return await self.client.read_resource(uri) async def get_prompt( self, name: str, arguments: dict[str, Any] | None = None ) -> mcp_types.GetPromptResult: + if self.client is None: + raise RuntimeError("Not connected - call connect() first") return await self.client.get_prompt(name, arguments) def __repr__(self) -> str: diff --git a/hud/environment/connectors/base.py b/hud/environment/connectors/base.py index 3d25e78f..997dd715 100644 --- a/hud/environment/connectors/base.py +++ b/hud/environment/connectors/base.py @@ -50,16 +50,12 @@ def _add_connection( Returns: self for chaining. 
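+
+        Example (a minimal sketch; the name, transport dict, and token are
+        assumed values, mirroring the call shape exercised in the tests added
+        later in this series):
+
+            env._add_connection(
+                "docs",
+                {"server": {"url": "https://mcp.example.com/docs"}},
+                connection_type=ConnectionType.REMOTE,
+                auth="<token>",
+                prefix="docs",
+            )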
""" - from fastmcp.client import Client as FastMCPClient - from hud.environment.connection import ConnectionConfig, Connector config = ConnectionConfig( prefix=prefix, include=include, exclude=exclude, transform=transform, ) - client = FastMCPClient(transport=transport, auth=auth) self._connections[name] = Connector( - client, config, name, connection_type=connection_type, + transport, config, name, connection_type=connection_type, auth=auth, ) return self - diff --git a/hud/environment/connectors/mcp_config.py b/hud/environment/connectors/mcp_config.py index 95581974..e9a06cdd 100644 --- a/hud/environment/connectors/mcp_config.py +++ b/hud/environment/connectors/mcp_config.py @@ -83,7 +83,7 @@ def connect_mcp_config( "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], }, "github": { - "command": "npx", + "command": "npx", "args": ["-y", "@modelcontextprotocol/server-github"], "env": {"GITHUB_TOKEN": "..."}, }, diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py index 7a7ad4b4..5dc539b3 100644 --- a/hud/environment/connectors/remote.py +++ b/hud/environment/connectors/remote.py @@ -20,6 +20,9 @@ class RemoteConnectorMixin(MCPConfigConnectorMixin): """Mixin providing remote connection methods.""" + # Store hub configs for trace serialization + _hub_configs: list[dict[str, Any]] + def mount(self, server: Any, *, prefix: str | None = None) -> None: raise NotImplementedError @@ -50,6 +53,21 @@ def connect_hub( from hud.settings import settings + # Store hub config for trace serialization + hub_config: dict[str, Any] = {"slug": slug} + if alias: + hub_config["alias"] = alias + if prefix: + hub_config["prefix"] = prefix + if include: + hub_config["include"] = include + if exclude: + hub_config["exclude"] = exclude + + if not hasattr(self, "_hub_configs"): + self._hub_configs = [] + self._hub_configs.append(hub_config) + # Fetch mcp_config synchronously logger.info("Loading hub environment: %s", slug) diff --git a/hud/environment/connectors/task.py b/hud/environment/connectors/task.py index 3eae3bc4..4298bbcf 100644 --- a/hud/environment/connectors/task.py +++ b/hud/environment/connectors/task.py @@ -79,10 +79,15 @@ def _apply_task(self, task: Task) -> None: """Apply a Task definition to this environment. 
Sets up: + - Prompt from task.prompt - MCP connections from task.mcp_config - Setup tool calls from task.setup_tool - Evaluate tool calls from task.evaluate_tool """ + # Set prompt + if task.prompt: + self.prompt = task.prompt # type: ignore[attr-defined] + # Connect MCP servers if task.mcp_config: self.connect_mcp_config(task.mcp_config) diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 63d0f0f0..45141f5d 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -129,6 +129,9 @@ def __init__( self._setup_calls: list[tuple[str, dict[str, Any]]] = [] self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] + # Task prompt - set by connect_task or manually + self.prompt: str | None = None + # Track which lifecycle tools we've warned about (only warn once per tool) self._warned_lifecycle_tools: set[str] = set() @@ -268,13 +271,22 @@ async def __aexit__( exc_tb: types.TracebackType | None, ) -> None: """Run evaluate tools, exit queue, then disconnect.""" - # Evaluate tool calls + from hud.agents.base import find_reward + + # Evaluate tool calls and collect rewards + rewards: list[float] = [] for name, args in self._evaluate_calls: try: - await self._execute_tool(name, args) + result = await self._execute_tool(name, args) + rewards.append(find_reward(result)) except Exception as e: logger.warning("Evaluate tool %s failed: %s", name, e) + # Store average reward from evaluate tools + self._evaluate_reward: float | None = None + if rewards: + self._evaluate_reward = sum(rewards) / len(rewards) + self._in_context = False if self._connections: await asyncio.gather(*[c.disconnect() for c in self._connections.values()]) @@ -443,5 +455,64 @@ def local_connections(self) -> list[str]: """Names of local (non-parallelizable) connections.""" return [name for name, conn in self._connections.items() if conn.is_local] + def _get_env_config(self) -> dict[str, Any] | None: + """Get serializable environment configuration for trace storage. + + Returns EnvConfig-compatible dict with: + - name: Environment name + - hubs: List of hub configs (connect_hub calls) + - setup_tools: Tools to run after connection (MCPToolCall format) + - evaluate_tools: Tools to run before disconnection (MCPToolCall format) + """ + hub_configs = getattr(self, "_hub_configs", []) + + # Convert setup/evaluate calls to MCPToolCall format + setup_tools = [ + {"name": name, "arguments": args} + for name, args in self._setup_calls + ] + evaluate_tools = [ + {"name": name, "arguments": args} + for name, args in self._evaluate_calls + ] + + # Only return config if there's something to store + if not hub_configs and not setup_tools and not evaluate_tools: + return None + + return { + "name": self.name, + "hubs": hub_configs, + "setup_tools": setup_tools, + "evaluate_tools": evaluate_tools, + } + + @property + def _all_hubs(self) -> bool: + """True if all tools came from connect_hub (fully reproducible). + + Returns False if there are: + - Local tools (@env.tool, connect_fastapi, connect_openapi, connect_server) + - Non-hub connections (connect_url, connect_mcp, connect_image, etc.) 
+ """ + hub_configs = getattr(self, "_hub_configs", []) + + # Check for local tools (mounted servers, @env.tool) + # _tool_manager comes from MCPServer base class + local_tool_count = len(self._tool_manager._tools) if hasattr(self, "_tool_manager") else 0 + if local_tool_count > 0: + return False + + # No hubs and no connections = trivially all hubs (empty env) + if not hub_configs and not self._connections: + return True + + # Has connections but no hubs = not all hubs + if not hub_configs: + return False + + # Compare hub count to connection count + return len(hub_configs) >= len(self._connections) + def __repr__(self) -> str: return f"Environment({self.name!r}, connections={list(self._connections.keys())})" diff --git a/hud/environment/tests/__init__.py b/hud/environment/tests/__init__.py new file mode 100644 index 00000000..9364a7c0 --- /dev/null +++ b/hud/environment/tests/__init__.py @@ -0,0 +1,2 @@ +"""Tests for hud.environment module.""" + diff --git a/hud/environment/tests/test_connection.py b/hud/environment/tests/test_connection.py new file mode 100644 index 00000000..cc6fdba4 --- /dev/null +++ b/hud/environment/tests/test_connection.py @@ -0,0 +1,312 @@ +"""Tests for hud.environment.connection module.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import mcp.types as mcp_types +import pytest + +from hud.environment.connection import ConnectionConfig, ConnectionType, Connector + + +class TestConnectionConfig: + """Tests for ConnectionConfig.""" + + def test_default_config(self) -> None: + """Config with no options set.""" + config = ConnectionConfig() + assert config.prefix is None + assert config.include is None + assert config.exclude is None + assert config.transform is None + + def test_config_with_options(self) -> None: + """Config with all options set.""" + transform_fn = lambda t: t # noqa: E731 + config = ConnectionConfig( + prefix="test", + include=["tool1", "tool2"], + exclude=["tool3"], + transform=transform_fn, + ) + assert config.prefix == "test" + assert config.include == ["tool1", "tool2"] + assert config.exclude == ["tool3"] + assert config.transform is transform_fn + + +class TestConnectionType: + """Tests for ConnectionType enum.""" + + def test_local_type(self) -> None: + """LOCAL type for stdio/Docker connections.""" + assert ConnectionType.LOCAL.value == "local" + + def test_remote_type(self) -> None: + """REMOTE type for HTTP connections.""" + assert ConnectionType.REMOTE.value == "remote" + + +class TestConnector: + """Tests for Connector class.""" + + def test_init_stores_transport_config(self) -> None: + """__init__ stores transport config, doesn't create client.""" + transport = {"server": {"url": "http://example.com"}} + config = ConnectionConfig() + + connector = Connector( + transport=transport, + config=config, + name="test", + connection_type=ConnectionType.REMOTE, + auth="test-token", + ) + + assert connector._transport == transport + assert connector._auth == "test-token" + assert connector.name == "test" + assert connector.connection_type == ConnectionType.REMOTE + assert connector.client is None # Not created yet + assert connector._tools_cache is None + + def test_is_local_property(self) -> None: + """is_local returns True for LOCAL connections.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="local-test", + connection_type=ConnectionType.LOCAL, + ) + assert connector.is_local is True + assert connector.is_remote is False + + def test_is_remote_property(self) 
-> None: + """is_remote returns True for REMOTE connections.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="remote-test", + connection_type=ConnectionType.REMOTE, + ) + assert connector.is_remote is True + assert connector.is_local is False + + def test_is_connected_false_when_no_client(self) -> None: + """is_connected returns False when client is None.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + ) + assert connector.is_connected is False + + def test_cached_tools_empty_initially(self) -> None: + """cached_tools returns empty list initially.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + ) + assert connector.cached_tools == [] + + @pytest.mark.asyncio + async def test_connect_creates_client(self) -> None: + """connect() creates FastMCPClient and enters context.""" + transport = {"server": {"url": "http://example.com"}} + connector = Connector( + transport=transport, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + auth="test-token", + ) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.is_connected = MagicMock(return_value=True) + + # Patch where it's imported from, not where it's used + with patch( + "fastmcp.client.Client", return_value=mock_client + ) as mock_cls: + await connector.connect() + + # Client was created with correct args + mock_cls.assert_called_once_with(transport=transport, auth="test-token") + # Client context was entered + mock_client.__aenter__.assert_called_once() + # Client is now set + assert connector.client is mock_client + + @pytest.mark.asyncio + async def test_disconnect_clears_client(self) -> None: + """disconnect() exits client context and clears state.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + mock_client = MagicMock() + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.is_connected = MagicMock(return_value=True) + connector.client = mock_client + connector._tools_cache = [MagicMock()] + + await connector.disconnect() + + mock_client.__aexit__.assert_called_once_with(None, None, None) + assert connector.client is None + assert connector._tools_cache is None + + @pytest.mark.asyncio + async def test_list_tools_raises_when_not_connected(self) -> None: + """list_tools() raises RuntimeError when not connected.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + with pytest.raises(RuntimeError, match="Not connected"): + await connector.list_tools() + + @pytest.mark.asyncio + async def test_list_tools_applies_include_filter(self) -> None: + """list_tools() filters tools based on include list.""" + connector = Connector( + transport={}, + config=ConnectionConfig(include=["tool1"]), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + mcp_types.Tool(name="tool2", description="Tool 2", inputSchema={}), + ]) + connector.client = mock_client + + tools = await connector.list_tools() + + assert len(tools) == 1 + assert tools[0].name == "tool1" + + @pytest.mark.asyncio + async def test_list_tools_applies_exclude_filter(self) 
-> None: + """list_tools() filters out tools in exclude list.""" + connector = Connector( + transport={}, + config=ConnectionConfig(exclude=["tool2"]), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + mcp_types.Tool(name="tool2", description="Tool 2", inputSchema={}), + ]) + connector.client = mock_client + + tools = await connector.list_tools() + + assert len(tools) == 1 + assert tools[0].name == "tool1" + + @pytest.mark.asyncio + async def test_list_tools_applies_prefix(self) -> None: + """list_tools() adds prefix to tool names.""" + connector = Connector( + transport={}, + config=ConnectionConfig(prefix="myprefix"), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + ]) + connector.client = mock_client + + tools = await connector.list_tools() + + assert len(tools) == 1 + assert tools[0].name == "myprefix_tool1" + + @pytest.mark.asyncio + async def test_list_tools_caches_results(self) -> None: + """list_tools() caches results.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + ]) + connector.client = mock_client + + tools = await connector.list_tools() + + assert connector._tools_cache == tools + assert connector.cached_tools == tools + + @pytest.mark.asyncio + async def test_call_tool_strips_prefix(self) -> None: + """call_tool() strips prefix before calling.""" + connector = Connector( + transport={}, + config=ConnectionConfig(prefix="myprefix"), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + mock_result = mcp_types.CallToolResult(content=[], isError=False) + mock_client = MagicMock() + mock_client.call_tool_mcp = AsyncMock(return_value=mock_result) + connector.client = mock_client + + await connector.call_tool("myprefix_tool1", {"arg": "value"}) + + # Prefix should be stripped + mock_client.call_tool_mcp.assert_called_once_with("tool1", {"arg": "value"}) + + @pytest.mark.asyncio + async def test_call_tool_raises_when_not_connected(self) -> None: + """call_tool() raises RuntimeError when not connected.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + ) + + with pytest.raises(RuntimeError, match="Not connected"): + await connector.call_tool("tool1", {}) + + def test_repr(self) -> None: + """__repr__ shows useful info.""" + connector = Connector( + transport={}, + config=ConnectionConfig(), + name="my-server", + connection_type=ConnectionType.REMOTE, + ) + + repr_str = repr(connector) + assert "my-server" in repr_str + assert "remote" in repr_str + assert "connected=False" in repr_str + diff --git a/hud/environment/tests/test_connectors.py b/hud/environment/tests/test_connectors.py new file mode 100644 index 00000000..d0e05883 --- /dev/null +++ b/hud/environment/tests/test_connectors.py @@ -0,0 +1,269 @@ +"""Tests for hud.environment.connectors module.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +from hud.environment.connection import ConnectionType, Connector + + +class 
TestBaseConnectorMixin: + """Tests for BaseConnectorMixin._add_connection.""" + + def test_add_connection_stores_transport_config(self) -> None: + """_add_connection stores transport, doesn't create client.""" + from hud.environment.connectors.base import BaseConnectorMixin + + class TestEnv(BaseConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + transport = {"server": {"url": "http://example.com"}} + + env._add_connection( + "test-server", + transport, + connection_type=ConnectionType.REMOTE, + auth="test-token", + prefix="myprefix", + ) + + assert "test-server" in env._connections + conn = env._connections["test-server"] + assert conn._transport == transport + assert conn._auth == "test-token" + assert conn.config.prefix == "myprefix" + assert conn.client is None # Not created yet + + def test_add_connection_returns_self(self) -> None: + """_add_connection returns self for chaining.""" + from hud.environment.connectors.base import BaseConnectorMixin + + class TestEnv(BaseConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + result = env._add_connection( + "test", + {}, + connection_type=ConnectionType.REMOTE, + ) + + assert result is env + + +class TestMCPConfigConnectorMixin: + """Tests for MCPConfigConnectorMixin.""" + + def test_connect_mcp_detects_local_connection(self) -> None: + """connect_mcp detects LOCAL type from command in config.""" + from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + + class TestEnv(MCPConfigConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + config = { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem"], + } + } + + env.connect_mcp(config) + + conn = env._connections["filesystem"] + assert conn.connection_type == ConnectionType.LOCAL + + def test_connect_mcp_detects_remote_connection(self) -> None: + """connect_mcp detects REMOTE type from URL in config.""" + from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + + class TestEnv(MCPConfigConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + config = { + "browser": { + "url": "https://mcp.hud.ai/browser", + } + } + + env.connect_mcp(config) + + conn = env._connections["browser"] + assert conn.connection_type == ConnectionType.REMOTE + + def test_connect_mcp_uses_alias(self) -> None: + """connect_mcp uses alias if provided.""" + from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + + class TestEnv(MCPConfigConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + config = {"server": {"url": "http://example.com"}} + + env.connect_mcp(config, alias="my-alias") + + assert "my-alias" in env._connections + assert "server" not in env._connections + + def test_connect_mcp_config_creates_multiple_connections(self) -> None: + """connect_mcp_config creates a connection for each server.""" + from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + + class TestEnv(MCPConfigConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + mcp_config = { + "server1": {"url": "http://example1.com"}, + "server2": {"url": "http://example2.com"}, + "server3": {"command": "npx", "args": ["server"]}, + } + + env.connect_mcp_config(mcp_config) + 
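+        # The two url-based entries resolve to REMOTE connectors and the
+        # command-based entry to LOCAL, so three connections are expected.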
+ assert len(env._connections) == 3 + assert "server1" in env._connections + assert "server2" in env._connections + assert "server3" in env._connections + + +class TestRemoteConnectorMixin: + """Tests for RemoteConnectorMixin.""" + + def test_connect_url_creates_remote_connection(self) -> None: + """connect_url creates REMOTE connection.""" + from hud.environment.connectors.remote import RemoteConnectorMixin + + class TestEnv(RemoteConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + env.connect_url("https://mcp.example.com", alias="example") + + assert "example" in env._connections + conn = env._connections["example"] + assert conn.connection_type == ConnectionType.REMOTE + + def test_connect_url_extracts_auth_from_headers(self) -> None: + """connect_url extracts Authorization from headers.""" + from hud.environment.connectors.remote import RemoteConnectorMixin + + class TestEnv(RemoteConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + env.connect_url( + "https://mcp.example.com", + headers={"Authorization": "Bearer my-token"}, + alias="example", + ) + + conn = env._connections["example"] + assert conn._auth == "Bearer my-token" + + @patch("httpx.Client") + def test_connect_hub_fetches_config(self, mock_httpx_cls: MagicMock) -> None: + """connect_hub fetches mcp_config from API.""" + from hud.environment.connectors.remote import RemoteConnectorMixin + + class TestEnv(RemoteConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + # Mock httpx response + mock_response = MagicMock() + mock_response.json.return_value = { + "mcp_config": { + "browser": {"url": "https://mcp.hud.ai/browser"}, + } + } + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=None) + mock_httpx_cls.return_value = mock_client + + env = TestEnv() + with patch("hud.settings.settings") as mock_settings: + mock_settings.hud_api_url = "https://api.hud.so" + mock_settings.api_key = "test-key" + + env.connect_hub("hud/browser") + + assert "browser" in env._connections + + +class TestTaskConnectorMixin: + """Tests for TaskConnectorMixin.""" + + @patch("httpx.Client") + def test_connect_task_fetches_and_applies_config(self, mock_httpx_cls: MagicMock) -> None: + """connect_task fetches task and applies mcp_config.""" + from hud.environment.connectors.task import TaskConnectorMixin + + class TestEnv(TaskConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + self._setup_calls: list[tuple[str, dict[str, Any]]] = [] + self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] + + def setup_tool(self, call: Any, /, **kwargs: Any) -> Any: + self._setup_calls.append((call, kwargs)) + return self + + def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: + self._evaluate_calls.append((call, kwargs)) + return self + + # Mock httpx response with task data + mock_response = MagicMock() + mock_response.json.return_value = { + "id": "task-123", + "prompt": "Test task prompt", + "mcp_config": { + "browser": {"url": 
"https://mcp.hud.ai/browser"}, + }, + "setup_tool": None, + "evaluate_tool": None, + } + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__ = MagicMock(return_value=mock_client) + mock_client.__exit__ = MagicMock(return_value=None) + mock_httpx_cls.return_value = mock_client + + env = TestEnv() + with patch("hud.settings.settings") as mock_settings: + mock_settings.hud_api_url = "https://api.hud.so" + mock_settings.api_key = "test-key" + + env.connect_task("my-org/my-task") + + assert "browser" in env._connections + diff --git a/hud/environment/tests/test_environment.py b/hud/environment/tests/test_environment.py new file mode 100644 index 00000000..b23bddf8 --- /dev/null +++ b/hud/environment/tests/test_environment.py @@ -0,0 +1,192 @@ +"""Tests for Environment class - context manager, resources, prompts, prompt feature.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import mcp.types as mcp_types +import pytest + + +class TestEnvironmentPrompt: + """Tests for Environment.prompt feature.""" + + def test_prompt_defaults_to_none(self) -> None: + """Environment.prompt defaults to None.""" + from hud.environment import Environment + + env = Environment("test") + assert env.prompt is None + + def test_prompt_can_be_set(self) -> None: + """Environment.prompt can be set manually.""" + from hud.environment import Environment + + env = Environment("test") + env.prompt = "Navigate to google.com" + assert env.prompt == "Navigate to google.com" + + def test_prompt_set_from_task(self) -> None: + """connect_task sets prompt from task.prompt.""" + from hud.environment.connectors.task import TaskConnectorMixin + from hud.environment.connection import Connector + from hud.types import Task + + class TestEnv(TaskConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + self.prompt: str | None = None + + def setup_tool(self, call: Any, /, **kwargs: Any) -> Any: + return self + + def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: + return self + + def connect_mcp_config(self, config: dict) -> Any: + return self + + env = TestEnv() + task = Task(prompt="Test prompt", mcp_config={}) + env._apply_task(task) + + assert env.prompt == "Test prompt" + + +class TestEnvironmentContextManager: + """Tests for Environment async context manager.""" + + @pytest.mark.asyncio + async def test_context_manager_sets_in_context_flag(self) -> None: + """Context manager sets _in_context flag.""" + from hud.environment import Environment + + env = Environment("test") + + assert env._in_context is False + + async with env: + assert env._in_context is True + + assert env._in_context is False + + @pytest.mark.asyncio + async def test_context_manager_no_connections(self) -> None: + """Context manager works with no connections.""" + from hud.environment import Environment + + env = Environment("test") + + async with env: + # Should work without connections + pass + + +class TestEnvironmentResources: + """Tests for Environment resource operations.""" + + @pytest.mark.asyncio + async def test_list_resources_empty(self) -> None: + """list_resources returns empty list when no resources.""" + from hud.environment import Environment + + env = Environment("test") + + async with env: + resources = await env.list_resources() + + assert resources == [] + + @pytest.mark.asyncio + async def test_read_resource_not_found(self) -> 
None: + """read_resource raises when resource not found.""" + from hud.environment import Environment + + env = Environment("test") + + async with env: + with pytest.raises(ValueError, match="Resource not found"): + await env.read_resource("file://nonexistent.txt") + + +class TestEnvironmentPrompts: + """Tests for Environment prompt operations (MCP prompts, not task prompt).""" + + @pytest.mark.asyncio + async def test_list_prompts_empty(self) -> None: + """list_prompts returns empty list when no prompts.""" + from hud.environment import Environment + + env = Environment("test") + + async with env: + prompts = await env.list_prompts() + + assert prompts == [] + + @pytest.mark.asyncio + async def test_get_prompt_not_found(self) -> None: + """get_prompt raises when prompt not found.""" + from hud.environment import Environment + + env = Environment("test") + + async with env: + with pytest.raises(ValueError, match="Prompt not found"): + await env.get_prompt("nonexistent") + + +class TestEnvironmentSetupEvaluate: + """Tests for setup_tool and evaluate_tool methods.""" + + def test_setup_tool_with_name_and_kwargs(self) -> None: + """setup_tool accepts name and kwargs.""" + from hud.environment import Environment + + env = Environment("test") + env.setup_tool("navigate", url="https://example.com") + + assert len(env._setup_calls) == 1 + assert env._setup_calls[0] == ("navigate", {"url": "https://example.com"}) + + def test_setup_tool_returns_self(self) -> None: + """setup_tool returns self for chaining.""" + from hud.environment import Environment + + env = Environment("test") + result = env.setup_tool("navigate", url="https://example.com") + + assert result is env + + def test_evaluate_tool_with_name_and_kwargs(self) -> None: + """evaluate_tool accepts name and kwargs.""" + from hud.environment import Environment + + env = Environment("test") + env.evaluate_tool("check_text", contains="success") + + assert len(env._evaluate_calls) == 1 + assert env._evaluate_calls[0] == ("check_text", {"contains": "success"}) + + def test_evaluate_tool_returns_self(self) -> None: + """evaluate_tool returns self for chaining.""" + from hud.environment import Environment + + env = Environment("test") + result = env.evaluate_tool("check_text", contains="success") + + assert result is env + + def test_chaining_multiple_setup_calls(self) -> None: + """Multiple setup_tool calls can be chained.""" + from hud.environment import Environment + + env = ( + Environment("test") + .setup_tool("navigate", url="https://example.com") + .setup_tool("wait", seconds=2) + ) + + assert len(env._setup_calls) == 2 + diff --git a/hud/environment/tests/test_integrations.py b/hud/environment/tests/test_integrations.py new file mode 100644 index 00000000..30643427 --- /dev/null +++ b/hud/environment/tests/test_integrations.py @@ -0,0 +1,246 @@ +"""Tests for format integrations - OpenAI, Anthropic, Gemini.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import mcp.types as mcp_types +import pytest + + +def create_mock_tool(name: str, description: str = "", schema: dict | None = None) -> mcp_types.Tool: + """Create a mock MCP tool for testing.""" + return mcp_types.Tool( + name=name, + description=description, + inputSchema=schema or {"type": "object", "properties": {}}, + ) + + +class TestOpenAIMixin: + """Tests for OpenAI format conversion.""" + + def test_as_openai_chat_tools_basic(self) -> None: + """as_openai_chat_tools converts MCP tools to OpenAI format.""" + from 
hud.environment.integrations.openai import OpenAIMixin + + class TestEnv(OpenAIMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [ + create_mock_tool("navigate", "Navigate to URL", { + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], + }), + ] + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + pass + + env = TestEnv() + tools = env.as_openai_chat_tools() + + assert len(tools) == 1 + assert tools[0]["type"] == "function" + assert tools[0]["function"]["name"] == "navigate" + assert tools[0]["function"]["description"] == "Navigate to URL" + assert "url" in tools[0]["function"]["parameters"]["properties"] + + def test_as_openai_chat_tools_strict_mode(self) -> None: + """as_openai_chat_tools with strict=True adds strict flag.""" + from hud.environment.integrations.openai import OpenAIMixin + + class TestEnv(OpenAIMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [create_mock_tool("test_tool")] + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + pass + + env = TestEnv() + tools = env.as_openai_chat_tools(strict=True) + + assert tools[0]["function"]["strict"] is True + + def test_as_openai_chat_tools_empty(self) -> None: + """as_openai_chat_tools returns empty list when no tools.""" + from hud.environment.integrations.openai import OpenAIMixin + + class TestEnv(OpenAIMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [] + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + pass + + env = TestEnv() + tools = env.as_openai_chat_tools() + + assert tools == [] + + def test_as_openai_responses_tools(self) -> None: + """as_openai_responses_tools converts to Responses API format.""" + from hud.environment.integrations.openai import OpenAIMixin + + class TestEnv(OpenAIMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [create_mock_tool("search", "Search the web")] + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + pass + + env = TestEnv() + tools = env.as_openai_responses_tools() + + assert len(tools) == 1 + assert tools[0]["type"] == "function" + assert tools[0]["name"] == "search" + assert tools[0]["description"] == "Search the web" + + +class TestAnthropicMixin: + """Tests for Anthropic/Claude format conversion.""" + + def test_as_claude_tools_basic(self) -> None: + """as_claude_tools converts MCP tools to Claude format.""" + from hud.environment.integrations.anthropic import AnthropicMixin + + class TestEnv(AnthropicMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [ + create_mock_tool("click", "Click element", { + "type": "object", + "properties": {"selector": {"type": "string"}}, + }), + ] + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + pass + + env = TestEnv() + tools = env.as_claude_tools() + + assert len(tools) == 1 + assert tools[0]["name"] == "click" + assert tools[0]["description"] == "Click element" + assert "input_schema" in tools[0] + assert "cache_control" not in tools[0] + + def test_as_claude_tools_with_cache_control(self) -> None: + """as_claude_tools with cache_control=True adds cache field.""" + from hud.environment.integrations.anthropic import AnthropicMixin + + class TestEnv(AnthropicMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [create_mock_tool("test")] + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + pass + + env = TestEnv() + tools = env.as_claude_tools(cache_control=True) + + assert 
tools[0]["cache_control"] == {"type": "ephemeral"} + + def test_as_claude_programmatic_tools(self) -> None: + """as_claude_programmatic_tools includes allowed_callers.""" + from hud.environment.integrations.anthropic import AnthropicMixin + + class TestEnv(AnthropicMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [create_mock_tool("analyze")] + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + pass + + env = TestEnv() + tools = env.as_claude_programmatic_tools() + + assert tools[0]["allowed_callers"] == ["code_execution_20250825"] + + +class TestGeminiMixin: + """Tests for Google/Gemini format conversion.""" + + def test_as_gemini_tools_basic(self) -> None: + """as_gemini_tools converts MCP tools to Gemini format.""" + from hud.environment.integrations.gemini import GeminiMixin + + class TestEnv(GeminiMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [ + create_mock_tool("search", "Search query", { + "type": "object", + "properties": {"query": {"type": "string"}}, + }), + ] + + env = TestEnv() + tools = env.as_gemini_tools() + + assert len(tools) == 1 + assert "function_declarations" in tools[0] + declarations = tools[0]["function_declarations"] + assert len(declarations) == 1 + assert declarations[0]["name"] == "search" + assert declarations[0]["description"] == "Search query" + + def test_as_gemini_tools_multiple(self) -> None: + """as_gemini_tools wraps multiple tools in single declaration list.""" + from hud.environment.integrations.gemini import GeminiMixin + + class TestEnv(GeminiMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [ + create_mock_tool("tool1"), + create_mock_tool("tool2"), + create_mock_tool("tool3"), + ] + + env = TestEnv() + tools = env.as_gemini_tools() + + assert len(tools) == 1 # Single wrapper object + assert len(tools[0]["function_declarations"]) == 3 + + def test_as_gemini_tool_config_auto(self) -> None: + """as_gemini_tool_config with AUTO mode.""" + from hud.environment.integrations.gemini import GeminiMixin + + class TestEnv(GeminiMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [] + + env = TestEnv() + config = env.as_gemini_tool_config(mode="AUTO") + + assert config["function_calling_config"]["mode"] == "AUTO" + + def test_as_gemini_tool_config_any_with_allowed(self) -> None: + """as_gemini_tool_config with ANY mode and allowed tools.""" + from hud.environment.integrations.gemini import GeminiMixin + + class TestEnv(GeminiMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [] + + env = TestEnv() + config = env.as_gemini_tool_config(mode="ANY", allowed_tools=["search", "navigate"]) + + assert config["function_calling_config"]["mode"] == "ANY" + assert config["function_calling_config"]["allowed_function_names"] == ["search", "navigate"] + + def test_as_gemini_tool_config_none(self) -> None: + """as_gemini_tool_config with NONE mode disables tools.""" + from hud.environment.integrations.gemini import GeminiMixin + + class TestEnv(GeminiMixin): + def as_tools(self) -> list[mcp_types.Tool]: + return [] + + env = TestEnv() + config = env.as_gemini_tool_config(mode="NONE") + + assert config["function_calling_config"]["mode"] == "NONE" + diff --git a/hud/environment/tests/test_local_connectors.py b/hud/environment/tests/test_local_connectors.py new file mode 100644 index 00000000..488fd8fc --- /dev/null +++ b/hud/environment/tests/test_local_connectors.py @@ -0,0 +1,204 @@ +"""Tests for local connectors - connect_image, connect_server, connect_fastapi.""" + +from __future__ 
import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from hud.environment.connection import ConnectionType, Connector + + +class TestConnectImage: + """Tests for LocalConnectorMixin.connect_image.""" + + @patch("hud.cli.utils.docker.create_docker_run_command") + def test_connect_image_creates_local_connection(self, mock_docker_cmd: MagicMock) -> None: + """connect_image creates LOCAL connection with docker command.""" + from hud.environment.connectors.local import LocalConnectorMixin + + mock_docker_cmd.return_value = ["docker", "run", "-i", "--rm", "mcp/fetch"] + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + env.connect_image("mcp/fetch") + + assert "mcp/fetch" in env._connections + conn = env._connections["mcp/fetch"] + assert conn.connection_type == ConnectionType.LOCAL + mock_docker_cmd.assert_called_once() + + @patch("hud.cli.utils.docker.create_docker_run_command") + def test_connect_image_with_alias(self, mock_docker_cmd: MagicMock) -> None: + """connect_image uses alias for connection name.""" + from hud.environment.connectors.local import LocalConnectorMixin + + mock_docker_cmd.return_value = ["docker", "run", "-i", "--rm", "mcp/fetch"] + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + env.connect_image("mcp/fetch", alias="fetcher") + + assert "fetcher" in env._connections + assert "mcp/fetch" not in env._connections + + @patch("hud.cli.utils.docker.create_docker_run_command") + def test_connect_image_with_prefix(self, mock_docker_cmd: MagicMock) -> None: + """connect_image passes prefix to config.""" + from hud.environment.connectors.local import LocalConnectorMixin + + mock_docker_cmd.return_value = ["docker", "run", "-i", "--rm", "mcp/fetch"] + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + env.connect_image("mcp/fetch", prefix="fetch") + + conn = env._connections["mcp/fetch"] + assert conn.config.prefix == "fetch" + + @patch("hud.cli.utils.docker.create_docker_run_command") + def test_connect_image_returns_self(self, mock_docker_cmd: MagicMock) -> None: + """connect_image returns self for chaining.""" + from hud.environment.connectors.local import LocalConnectorMixin + + mock_docker_cmd.return_value = ["docker", "run", "-i", "--rm", "mcp/fetch"] + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + result = env.connect_image("mcp/fetch") + + assert result is env + + +class TestConnectServer: + """Tests for LocalConnectorMixin.connect_server.""" + + def test_connect_server_calls_mount(self) -> None: + """connect_server calls mount with server and prefix.""" + from hud.environment.connectors.local import LocalConnectorMixin + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + self.mounted: list[tuple[Any, str | None]] = [] + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + 
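+                # Record each (server, prefix) pair so the assertions below can
+                # check exactly what connect_server mounted.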
self.mounted.append((server, prefix)) + + env = TestEnv() + mock_server = MagicMock() + env.connect_server(mock_server, prefix="tools") + + assert len(env.mounted) == 1 + assert env.mounted[0] == (mock_server, "tools") + + def test_connect_server_returns_self(self) -> None: + """connect_server returns self for chaining.""" + from hud.environment.connectors.local import LocalConnectorMixin + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + result = env.connect_server(MagicMock()) + + assert result is env + + +class TestConnectFastAPI: + """Tests for LocalConnectorMixin.connect_fastapi.""" + + @patch("fastmcp.FastMCP") + def test_connect_fastapi_creates_mcp_server(self, mock_fastmcp: MagicMock) -> None: + """connect_fastapi converts FastAPI app to MCP server.""" + from hud.environment.connectors.local import LocalConnectorMixin + + mock_mcp_server = MagicMock() + mock_fastmcp.from_fastapi.return_value = mock_mcp_server + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + self.mounted: list[tuple[Any, str | None]] = [] + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + self.mounted.append((server, prefix)) + + env = TestEnv() + mock_app = MagicMock() + mock_app.title = "My API" + env.connect_fastapi(mock_app) + + mock_fastmcp.from_fastapi.assert_called_once_with(app=mock_app, name="My API") + assert len(env.mounted) == 1 + assert env.mounted[0] == (mock_mcp_server, None) + + @patch("fastmcp.FastMCP") + def test_connect_fastapi_with_custom_name(self, mock_fastmcp: MagicMock) -> None: + """connect_fastapi uses custom name if provided.""" + from hud.environment.connectors.local import LocalConnectorMixin + + mock_fastmcp.from_fastapi.return_value = MagicMock() + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + mock_app = MagicMock() + mock_app.title = "Original" + env.connect_fastapi(mock_app, name="custom-api") + + mock_fastmcp.from_fastapi.assert_called_once_with(app=mock_app, name="custom-api") + + @patch("fastmcp.FastMCP") + def test_connect_fastapi_returns_self(self, mock_fastmcp: MagicMock) -> None: + """connect_fastapi returns self for chaining.""" + from hud.environment.connectors.local import LocalConnectorMixin + + mock_fastmcp.from_fastapi.return_value = MagicMock() + + class TestEnv(LocalConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + def mount(self, server: Any, *, prefix: str | None = None) -> None: + pass + + env = TestEnv() + result = env.connect_fastapi(MagicMock()) + + assert result is env + diff --git a/hud/environment/types.py b/hud/environment/types.py new file mode 100644 index 00000000..e911ffe8 --- /dev/null +++ b/hud/environment/types.py @@ -0,0 +1,29 @@ +"""Environment types for configuration and tracing.""" + +from __future__ import annotations + +from pydantic import BaseModel + +from hud.types import MCPToolCall + +__all__ = ["EnvConfig", "HubConfig"] + + +class HubConfig(BaseModel): + """Configuration for a single hub connection.""" + + slug: str + alias: str | None = None + prefix: str | None = None + include: list[str] | None = None + exclude: list[str] | None = None + + +class EnvConfig(BaseModel): 
+ """Environment configuration for trace reproducibility.""" + + name: str + hubs: list[HubConfig] = [] + setup_tools: list[MCPToolCall] = [] + evaluate_tools: list[MCPToolCall] = [] + diff --git a/hud/trace/context.py b/hud/trace/context.py index a80f5b98..668daad6 100644 --- a/hud/trace/context.py +++ b/hud/trace/context.py @@ -21,6 +21,9 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Self +from pydantic import BaseModel + +from hud.environment.types import EnvConfig from hud.settings import settings from hud.shared import make_request from hud.telemetry.job import get_current_job @@ -44,24 +47,74 @@ def get_current_trace_headers() -> dict[str, str] | None: return _current_trace_headers.get() +# ============================================================================= +# Payload Models +# ============================================================================= + + +class TracePayload(BaseModel): + """Base payload for trace enter/exit - sent to both endpoints.""" + + task_name: str + prompt: str | None = None + code_snippet: str | None = None + env_config: EnvConfig | None = None + all_hubs: bool = False # True if all connectors are from connect_hub + job_id: str | None = None + group_id: str | None = None + variants: dict[str, Any] | None = None + + +class TraceExitPayload(TracePayload): + """Exit payload - includes result fields.""" + + reward: float | None = None + success: bool = True + error_message: str | None = None + + # ============================================================================= # Auto-instrumentation for httpx # ============================================================================= +def _is_hud_url(url_str: str) -> bool: + """Check if URL is a HUD service (inference or MCP).""" + from urllib.parse import urlparse + + # Extract hostnames from settings URLs + gateway_host = urlparse(settings.hud_gateway_url).netloc + mcp_host = urlparse(settings.hud_mcp_url).netloc + + # Parse the request URL and check against known HUD hosts + parsed = urlparse(url_str) + request_host = parsed.netloc or url_str.split("/")[0] + + return request_host == gateway_host or request_host == mcp_host + + def _httpx_request_hook(request: Any) -> None: - """httpx event hook that adds trace headers to inference.hud.ai requests.""" - headers = get_current_trace_headers() - if headers is None: - return + """httpx event hook that adds trace headers and auth to HUD requests. 
+ For inference.hud.ai and mcp.hud.ai: + - Injects trace headers (Trace-Id) if in trace context + - Injects Authorization header if API key is set and no auth present + """ url_str = str(request.url) - if "inference.hud.ai" not in url_str: + if not _is_hud_url(url_str): return - for key, value in headers.items(): - request.headers[key] = value + # Inject trace headers if in trace context + headers = get_current_trace_headers() + if headers is not None: + for key, value in headers.items(): + request.headers[key] = value + logger.debug("Added trace headers to request: %s", url_str) - logger.debug("Added trace headers to request: %s", url_str) + # Auto-inject API key if not present + has_auth = "authorization" in {k.lower() for k in request.headers} + if not has_auth and settings.api_key: + request.headers["Authorization"] = f"Bearer {settings.api_key}" + logger.debug("Added API key auth to request: %s", url_str) async def _async_httpx_request_hook(request: Any) -> None: @@ -125,6 +178,7 @@ class TraceContext: group_id: Links parallel traces together (None for single traces) variants: Variant assignment dict (for A/B testing) reward: Reward value (user-settable) + prompt: Task prompt (defaults from env.prompt, user-settable) error: Exception if failed results: All trace results (for parent trace) @@ -168,6 +222,8 @@ def __init__( _group_id: str | None = None, _index: int = 0, _variants: dict[str, Any] | None = None, + _code_snippet: str | None = None, + _env_config: dict[str, Any] | None = None, ) -> None: # Identity self.trace_id: str = trace_id or str(uuid.uuid4()) @@ -188,6 +244,7 @@ def __init__( # User-settable self.reward: float | None = None + self.prompt: str | None = getattr(env, "prompt", None) # From env, can override # Error tracking self.error: BaseException | None = None @@ -195,6 +252,10 @@ def __init__( # Parallel/variant results (nested) self.results: list[TraceContext] | None = None + # Code and config (for reproducibility) + self.code_snippet: str | None = _code_snippet + self.env_config: dict[str, Any] | None = _env_config + # Private self._env = env self._api_key = api_key @@ -209,7 +270,7 @@ def __init__( @property def headers(self) -> dict[str, str]: """Headers for gateway integration.""" - return {"HUD-Trace-Id": self.trace_id} + return {"Trace-Id": self.trace_id} @property def duration(self) -> float: @@ -232,6 +293,27 @@ def done(self) -> bool: def _get_api_key(self) -> str | None: return self._api_key or settings.api_key + def _build_base_payload(self) -> TracePayload: + """Build the base payload for enter/exit.""" + # Check if all connectors are from hubs (fully reproducible) + all_hubs = getattr(self._env, "_all_hubs", False) + + # Convert env_config dict to EnvConfig model + env_config_model: EnvConfig | None = None + if self.env_config: + env_config_model = EnvConfig(**self.env_config) + + return TracePayload( + task_name=self.name, + prompt=self.prompt, + code_snippet=self.code_snippet, + env_config=env_config_model, + all_hubs=all_hubs, + job_id=self.job_id, + group_id=self.group_id, + variants=self.variants if self.variants else None, + ) + # ========================================================================= # Tool Operations # ========================================================================= @@ -258,7 +340,7 @@ async def log(self, metrics: dict[str, Any]) -> None: await make_request( method="POST", url=f"{settings.hud_telemetry_url}/traces/{self.trace_id}/log", - json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()}, + 
json={"metrics": metrics}, api_key=api_key, ) except Exception as e: @@ -271,21 +353,11 @@ async def _trace_enter(self) -> None: return try: - data: dict[str, Any] = { - "task_name": self.name, - "started_at": self._started_at.isoformat() if self._started_at else None, - } - if self.job_id: - data["job_id"] = self.job_id - if self.group_id: - data["group_id"] = self.group_id - if self.variants: - data["variants"] = self.variants - + payload = self._build_base_payload() await make_request( method="POST", - url=f"{settings.hud_telemetry_url}/trace/{self.trace_id}/enter", - json=data, + url=f"{settings.hud_api_url}/trace/{self.trace_id}/enter", + json=payload.model_dump(exclude_none=True), api_key=api_key, ) except Exception as e: @@ -297,27 +369,22 @@ async def _trace_exit(self, error_message: str | None = None) -> None: if not settings.telemetry_enabled or not api_key: return + # Use evaluate tool reward if not manually set + reward = self.reward + if reward is None: + reward = getattr(self._env, "_evaluate_reward", None) + try: - data: dict[str, Any] = { - "task_name": self.name, - "completed_at": self._completed_at.isoformat() if self._completed_at else None, - "success": self.success, - } - if self.job_id: - data["job_id"] = self.job_id - if self.group_id: - data["group_id"] = self.group_id - if self.variants: - data["variants"] = self.variants - if self.reward is not None: - data["reward"] = self.reward - if error_message: - data["error_message"] = error_message - + payload = TraceExitPayload( + **self._build_base_payload().model_dump(), + reward=reward, + success=self.success, + error_message=error_message, + ) await make_request( method="POST", - url=f"{settings.hud_telemetry_url}/trace/{self.trace_id}/exit", - json=data, + url=f"{settings.hud_api_url}/trace/{self.trace_id}/exit", + json=payload.model_dump(exclude_none=True), api_key=api_key, ) except Exception as e: diff --git a/hud/trace/mixin.py b/hud/trace/mixin.py index d8a5e6b4..977f58e7 100644 --- a/hud/trace/mixin.py +++ b/hud/trace/mixin.py @@ -124,6 +124,42 @@ async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> """Placeholder - implemented by Environment.""" raise NotImplementedError + def _capture_code_snippet(self) -> str | None: + """Capture the code inside the trace() with-block (best effort). + + Returns None if source cannot be extracted (e.g., REPL, Jupyter). + """ + frame = inspect.currentframe() + if frame is None: + return None + + try: + # Go up: _capture_code_snippet -> trace -> user code + caller = frame.f_back + if caller is not None: + caller = caller.f_back + if caller is None: + return None + + body_source, _ = _get_with_block_body(caller) + return body_source + except ASTExtractionError: + # Can't extract from REPL/Jupyter - that's OK + return None + except Exception as e: + logger.debug("Failed to capture code snippet: %s", e) + return None + finally: + del frame + + def _get_env_config(self) -> dict[str, Any] | None: + """Get serializable environment configuration. + + Returns dict with connections and local tools. + """ + # This will be overridden by Environment with actual implementation + return None + @property def last_traces(self) -> list[TraceContext] | None: """Get TraceContext objects from the last parallel execution. 
@@ -223,6 +259,12 @@ async def trace( variant_combos = _expand_variants(variants) total_traces = len(variant_combos) * group + # Capture code snippet (best effort - won't work in REPL/Jupyter) + code_snippet = self._capture_code_snippet() + + # Get environment config + env_config = self._get_env_config() + # Validate parallelization - only remote connections allowed for group > 1 if total_traces > 1 and not self.is_parallelizable: # type: ignore[attr-defined] local_conns = self.local_connections # type: ignore[attr-defined] @@ -244,6 +286,8 @@ async def trace( api_key=api_key, job_id=job_id, _variants=variant_combos[0], + _code_snippet=code_snippet, + _env_config=env_config, ) async with tc: async with self: # type: ignore[attr-defined] @@ -258,6 +302,8 @@ async def trace( group_ids=group_ids, job_id=job_id, api_key=api_key, + code_snippet=code_snippet, + env_config=env_config, ) # Create parent tc with results injected @@ -267,6 +313,8 @@ async def trace( trace_id=trace_id, api_key=api_key, job_id=job_id, + _code_snippet=code_snippet, + _env_config=env_config, ) tc.results = completed self._last_traces = completed @@ -286,6 +334,8 @@ async def _run_parallel_trace( group_ids: list[str] | None, job_id: str | None, api_key: str | None, + code_snippet: str | None, + env_config: dict[str, Any] | None, ) -> list[TraceContext]: """Run parallel trace execution using AST extraction. @@ -303,6 +353,8 @@ async def _run_parallel_trace( group_ids: Optional list of group IDs (one per total trace) job_id: Optional job ID (auto-detected from current job if not provided) api_key: Optional API key + code_snippet: Captured code from the with-block + env_config: Environment configuration """ # Get the caller's frame (skip this method and the trace method) frame = inspect.currentframe() @@ -352,6 +404,8 @@ async def _run_parallel_trace( _group_id=resolved_group_ids[idx], _index=idx, _variants=variant, + _code_snippet=code_snippet, + _env_config=env_config, ) trace_contexts.append(tc) idx += 1 diff --git a/hud/trace/tests/__init__.py b/hud/trace/tests/__init__.py new file mode 100644 index 00000000..79c48157 --- /dev/null +++ b/hud/trace/tests/__init__.py @@ -0,0 +1,2 @@ +"""Tests for hud.trace module.""" + diff --git a/hud/trace/tests/test_context.py b/hud/trace/tests/test_context.py new file mode 100644 index 00000000..8ccba9e2 --- /dev/null +++ b/hud/trace/tests/test_context.py @@ -0,0 +1,288 @@ +"""Tests for hud.trace.context module.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from hud.trace.context import ( + TraceContext, + _httpx_request_hook, + _is_hud_url, + get_current_trace_headers, +) + + +class TestIsHudUrl: + """Tests for _is_hud_url helper.""" + + def test_inference_hud_ai_is_hud(self) -> None: + """inference.hud.ai is a HUD URL.""" + assert _is_hud_url("https://inference.hud.ai/v1/chat") is True + assert _is_hud_url("http://inference.hud.ai/v1/chat") is True + + def test_mcp_hud_ai_is_hud(self) -> None: + """mcp.hud.ai is a HUD URL.""" + assert _is_hud_url("https://mcp.hud.ai/browser") is True + assert _is_hud_url("http://mcp.hud.ai/some/path") is True + + def test_mcp_hud_so_is_hud(self) -> None: + """mcp.hud.so is a HUD URL.""" + assert _is_hud_url("https://mcp.hud.so/browser") is True + + def test_other_urls_are_not_hud(self) -> None: + """Other URLs are not HUD URLs.""" + assert _is_hud_url("https://example.com") is False + assert _is_hud_url("https://api.openai.com") is False + assert 
_is_hud_url("https://notinference.hud.ai.fake.com") is False + + +class TestHttpxRequestHook: + """Tests for _httpx_request_hook.""" + + def test_injects_trace_headers_for_hud_urls(self) -> None: + """Hook injects trace headers for HUD URLs when in trace context.""" + mock_request = MagicMock() + mock_request.url = "https://inference.hud.ai/v1/chat" + mock_request.headers = {} + + # Set up trace context + from hud.trace.context import _current_trace_headers + token = _current_trace_headers.set({"Trace-Id": "test-trace-123"}) + + try: + _httpx_request_hook(mock_request) + + assert mock_request.headers["Trace-Id"] == "test-trace-123" + finally: + _current_trace_headers.reset(token) + + def test_injects_api_key_for_hud_urls(self) -> None: + """Hook injects API key for HUD URLs when no auth present.""" + mock_request = MagicMock() + mock_request.url = "https://mcp.hud.ai/browser" + mock_request.headers = {} + + with patch("hud.trace.context.settings") as mock_settings: + mock_settings.api_key = "test-api-key" + + _httpx_request_hook(mock_request) + + assert mock_request.headers["Authorization"] == "Bearer test-api-key" + + def test_does_not_override_existing_auth(self) -> None: + """Hook does not override existing Authorization header.""" + mock_request = MagicMock() + mock_request.url = "https://mcp.hud.ai/browser" + mock_request.headers = {"Authorization": "Bearer existing-token"} + + with patch("hud.trace.context.settings") as mock_settings: + mock_settings.api_key = "test-api-key" + + _httpx_request_hook(mock_request) + + assert mock_request.headers["Authorization"] == "Bearer existing-token" + + def test_ignores_non_hud_urls(self) -> None: + """Hook ignores non-HUD URLs.""" + mock_request = MagicMock() + mock_request.url = "https://api.openai.com/v1/chat" + mock_request.headers = {} + + # Set up trace context + from hud.trace.context import _current_trace_headers + token = _current_trace_headers.set({"Trace-Id": "test-trace-123"}) + + try: + with patch("hud.trace.context.settings") as mock_settings: + mock_settings.api_key = "test-api-key" + + _httpx_request_hook(mock_request) + + # No headers should be added + assert "Trace-Id" not in mock_request.headers + assert "Authorization" not in mock_request.headers + finally: + _current_trace_headers.reset(token) + + +class TestTraceContext: + """Tests for TraceContext.""" + + def test_init_generates_trace_id(self) -> None: + """TraceContext generates trace_id if not provided.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task") + + assert tc.trace_id is not None + assert len(tc.trace_id) == 36 # UUID format + + def test_init_uses_provided_trace_id(self) -> None: + """TraceContext uses provided trace_id.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task", trace_id="custom-id") + + assert tc.trace_id == "custom-id" + + def test_headers_contains_trace_id(self) -> None: + """headers property returns dict with trace ID.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task", trace_id="test-123") + + assert tc.headers == {"Trace-Id": "test-123"} + + def test_success_true_when_no_error(self) -> None: + """success property returns True when no error.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task") + + assert tc.success is True + + def test_success_false_when_error(self) -> None: + """success property returns False when error is set.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task") + tc.error = ValueError("test 
error") + + assert tc.success is False + + def test_done_false_initially(self) -> None: + """done property returns False initially.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task") + + assert tc.done is False + + def test_variants_empty_by_default(self) -> None: + """variants is empty dict by default.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task") + + assert tc.variants == {} + + def test_variants_set_from_init(self) -> None: + """variants set from _variants parameter.""" + mock_env = MagicMock() + tc = TraceContext( + env=mock_env, + name="test-task", + _variants={"model": "gpt-4o", "temp": 0.7}, + ) + + assert tc.variants == {"model": "gpt-4o", "temp": 0.7} + + @pytest.mark.asyncio + async def test_context_manager_sets_headers(self) -> None: + """Context manager sets trace headers in contextvar.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task", trace_id="test-123") + + # Mock telemetry calls + with patch.object(tc, "_trace_enter", new_callable=AsyncMock): + with patch.object(tc, "_trace_exit", new_callable=AsyncMock): + assert get_current_trace_headers() is None + + async with tc: + headers = get_current_trace_headers() + assert headers is not None + assert headers["Trace-Id"] == "test-123" + + assert get_current_trace_headers() is None + + @pytest.mark.asyncio + async def test_context_manager_captures_error(self) -> None: + """Context manager captures exception in error field.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task") + + with patch.object(tc, "_trace_enter", new_callable=AsyncMock): + with patch.object(tc, "_trace_exit", new_callable=AsyncMock): + with pytest.raises(ValueError): + async with tc: + raise ValueError("test error") + + assert tc.error is not None + assert str(tc.error) == "test error" + assert tc.success is False + + @pytest.mark.asyncio + async def test_call_tool_delegates_to_env(self) -> None: + """call_tool delegates to environment.""" + mock_env = MagicMock() + mock_env.call_tool = AsyncMock(return_value="result") + + tc = TraceContext(env=mock_env, name="test-task") + result = await tc.call_tool("my_tool", {"arg": "value"}) + + mock_env.call_tool.assert_called_once_with("my_tool", {"arg": "value"}) + assert result == "result" + + def test_repr(self) -> None: + """__repr__ shows useful info.""" + mock_env = MagicMock() + tc = TraceContext(env=mock_env, name="test-task", trace_id="abc12345-6789-0000-0000-000000000000") + tc.reward = 0.95 + + repr_str = repr(tc) + assert "abc12345" in repr_str + assert "test-task" in repr_str + assert "0.95" in repr_str + + +class TestTraceContextPrompt: + """Tests for TraceContext.prompt feature.""" + + def test_prompt_defaults_from_env(self) -> None: + """TraceContext.prompt defaults from env.prompt.""" + mock_env = MagicMock() + mock_env.prompt = "Task prompt from environment" + + tc = TraceContext( + env=mock_env, + name="test-task", + trace_id="test-123", + ) + + assert tc.prompt == "Task prompt from environment" + + def test_prompt_none_when_env_has_no_prompt(self) -> None: + """TraceContext.prompt is None when env has no prompt.""" + mock_env = MagicMock(spec=[]) # No prompt attribute + + tc = TraceContext( + env=mock_env, + name="test-task", + trace_id="test-123", + ) + + assert tc.prompt is None + + def test_prompt_can_be_overridden(self) -> None: + """TraceContext.prompt can be set to override env default.""" + mock_env = MagicMock() + mock_env.prompt = "Original prompt" + + tc = TraceContext( + 
env=mock_env, + name="test-task", + trace_id="test-123", + ) + + tc.prompt = "Overridden prompt" + assert tc.prompt == "Overridden prompt" + + def test_prompt_included_in_payload(self) -> None: + """Prompt is included in trace payload.""" + mock_env = MagicMock() + mock_env.prompt = "Test prompt" + mock_env._all_hubs = False + + tc = TraceContext( + env=mock_env, + name="test-task", + trace_id="test-123", + ) + + payload = tc._build_base_payload() + assert payload.prompt == "Test prompt" diff --git a/hud/trace/tests/test_mixin.py b/hud/trace/tests/test_mixin.py new file mode 100644 index 00000000..eddcaa8c --- /dev/null +++ b/hud/trace/tests/test_mixin.py @@ -0,0 +1,178 @@ +"""Tests for hud.trace.mixin module.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from hud.trace.mixin import TraceMixin, _expand_variants + + +class TestExpandVariants: + """Tests for _expand_variants helper.""" + + def test_none_returns_empty_dict(self) -> None: + """None variants returns list with empty dict.""" + result = _expand_variants(None) + assert result == [{}] + + def test_empty_dict_returns_empty_dict(self) -> None: + """Empty variants returns list with empty dict.""" + result = _expand_variants({}) + assert result == [{}] + + def test_single_value_stays_single(self) -> None: + """Single non-list value stays as single variant.""" + result = _expand_variants({"model": "gpt-4o"}) + assert result == [{"model": "gpt-4o"}] + + def test_list_expands_to_variants(self) -> None: + """List value expands to multiple variants.""" + result = _expand_variants({"model": ["gpt-4o", "claude"]}) + assert result == [{"model": "gpt-4o"}, {"model": "claude"}] + + def test_multiple_lists_create_combinations(self) -> None: + """Multiple lists create all combinations.""" + result = _expand_variants({ + "model": ["a", "b"], + "temp": [0.0, 1.0], + }) + + assert len(result) == 4 + assert {"model": "a", "temp": 0.0} in result + assert {"model": "a", "temp": 1.0} in result + assert {"model": "b", "temp": 0.0} in result + assert {"model": "b", "temp": 1.0} in result + + def test_mixed_single_and_list(self) -> None: + """Mixed single values and lists work correctly.""" + result = _expand_variants({ + "model": ["gpt-4o", "claude"], + "temp": 0.7, + }) + + assert len(result) == 2 + assert {"model": "gpt-4o", "temp": 0.7} in result + assert {"model": "claude", "temp": 0.7} in result + + +class MockEnvironment(TraceMixin): + """Mock environment for testing TraceMixin.""" + + def __init__(self) -> None: + self.name = "test-env" + self._connections: dict[str, Any] = {} + self._last_traces = None + + @property + def is_parallelizable(self) -> bool: + return all( + getattr(c, "is_remote", True) + for c in self._connections.values() + ) + + @property + def local_connections(self) -> list[str]: + return [ + name for name, c in self._connections.items() + if getattr(c, "is_local", False) + ] + + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> Any: + return {"name": name, "arguments": arguments} + + async def __aenter__(self) -> "MockEnvironment": + return self + + async def __aexit__(self, *args: Any) -> None: + pass + + +class TestTraceMixin: + """Tests for TraceMixin.""" + + @pytest.mark.asyncio + async def test_trace_single_creates_context(self) -> None: + """trace() with group=1 creates single TraceContext.""" + env = MockEnvironment() + + async with env.trace("test-task") as tc: + assert tc.name == "test-task" + 
assert tc.trace_id is not None + assert tc.variants == {} + + @pytest.mark.asyncio + async def test_trace_sets_reward(self) -> None: + """reward can be set on TraceContext.""" + env = MockEnvironment() + + async with env.trace("test-task") as tc: + tc.reward = 0.95 + + assert tc.reward == 0.95 + + @pytest.mark.asyncio + async def test_trace_with_variants_single(self) -> None: + """trace() with single variant value works.""" + env = MockEnvironment() + + async with env.trace("test-task", variants={"model": "gpt-4o"}) as tc: + assert tc.variants == {"model": "gpt-4o"} + + @pytest.mark.asyncio + async def test_trace_rejects_parallel_with_local_connections(self) -> None: + """trace() raises error for parallel with local connections.""" + env = MockEnvironment() + + # Add a local connection + mock_conn = MagicMock() + mock_conn.is_local = True + mock_conn.is_remote = False + env._connections["local-server"] = mock_conn + + with pytest.raises(ValueError, match="Cannot run parallel traces"): + async with env.trace("test-task", group=2) as tc: + pass + + @pytest.mark.asyncio + async def test_trace_allows_parallel_with_remote_connections(self) -> None: + """trace() allows parallel with only remote connections.""" + env = MockEnvironment() + + # Add a remote connection + mock_conn = MagicMock() + mock_conn.is_local = False + mock_conn.is_remote = True + env._connections["remote-server"] = mock_conn + + # This should not raise (though parallel execution is complex to test) + # Just verify it doesn't raise the local connection error + assert env.is_parallelizable is True + + @pytest.mark.asyncio + async def test_trace_rejects_zero_group(self) -> None: + """trace() raises error for group <= 0.""" + env = MockEnvironment() + + with pytest.raises(ValueError, match="group must be >= 1"): + async with env.trace("test-task", group=0) as tc: + pass + + def test_last_traces_none_initially(self) -> None: + """last_traces is None before any parallel execution.""" + env = MockEnvironment() + assert env.last_traces is None + + @pytest.mark.asyncio + async def test_trace_context_delegates_call_tool(self) -> None: + """TraceContext.call_tool delegates to environment.""" + env = MockEnvironment() + + async with env.trace("test-task") as tc: + result = await tc.call_tool("my_tool", {"arg": "value"}) + + assert result["name"] == "my_tool" + assert result["arguments"] == {"arg": "value"} + diff --git a/hud/trace/tests/test_parallel.py b/hud/trace/tests/test_parallel.py new file mode 100644 index 00000000..c0bda532 --- /dev/null +++ b/hud/trace/tests/test_parallel.py @@ -0,0 +1,156 @@ +"""Tests for hud.trace.parallel module.""" + +from __future__ import annotations + +import ast +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from hud.trace.parallel import ( + ASTExtractionError, + _extract_body, + _find_async_with, + _get_end_line, + run_parallel_traces, +) + + +class TestASTHelpers: + """Tests for AST helper functions.""" + + def test_find_async_with_finds_correct_node(self) -> None: + """_find_async_with finds the async with containing target line.""" + source = ''' +async def main(): + x = 1 + async with something as ctx: + do_stuff() + more_stuff() + y = 2 +''' + tree = ast.parse(source) + + # Line 4 is inside the async with + node = _find_async_with(tree, 5) + assert node is not None + assert isinstance(node, ast.AsyncWith) + + def test_find_async_with_returns_none_when_not_found(self) -> None: + """_find_async_with returns None when line is outside async with.""" + source = ''' +async 
def main(): + x = 1 + async with something as ctx: + do_stuff() + y = 2 +''' + tree = ast.parse(source) + + # Line 6 is outside the async with + node = _find_async_with(tree, 7) + assert node is None + + def test_get_end_line(self) -> None: + """_get_end_line returns last line of node.""" + source = ''' +async with ctx: + line1() + line2() + line3() +''' + tree = ast.parse(source) + async_with = tree.body[0] + + end_line = _get_end_line(async_with) + assert end_line >= 4 # At least through line 4 + + def test_extract_body(self) -> None: + """_extract_body extracts the body source from async with.""" + source = '''async with ctx: + do_thing() + more_thing() +''' + lines = source.split('\n') + lines = [line + '\n' for line in lines] + + tree = ast.parse(source) + async_with = tree.body[0] + + body = _extract_body(lines, async_with) + assert "do_thing()" in body + assert "more_thing()" in body + + +class TestRunParallelTraces: + """Tests for run_parallel_traces function.""" + + @pytest.mark.asyncio + async def test_runs_body_for_each_context(self) -> None: + """run_parallel_traces runs body for each TraceContext.""" + # Create mock trace contexts + mock_tcs = [] + for i in range(3): + tc = MagicMock() + tc.index = i + tc.__aenter__ = AsyncMock(return_value=tc) + tc.__aexit__ = AsyncMock(return_value=None) + mock_tcs.append(tc) + + # Simple body that sets reward + body_source = "tc.reward = tc.index * 10" + captured_locals: dict[str, object] = {} + + results = await run_parallel_traces(mock_tcs, body_source, captured_locals) + + assert len(results) == 3 + # Each context should have had __aenter__ and __aexit__ called + for tc in mock_tcs: + tc.__aenter__.assert_called_once() + tc.__aexit__.assert_called_once() + + @pytest.mark.asyncio + async def test_captures_exceptions(self) -> None: + """run_parallel_traces captures exceptions in context.""" + tc = MagicMock() + tc.index = 0 + tc.__aenter__ = AsyncMock(return_value=tc) + tc.__aexit__ = AsyncMock(return_value=None) + + # Body that raises + body_source = "raise ValueError('test error')" + captured_locals: dict[str, object] = {} + + results = await run_parallel_traces([tc], body_source, captured_locals) + + assert len(results) == 1 + # Error should be captured, not raised + assert hasattr(tc, "_error") or tc.__aexit__.called + + @pytest.mark.asyncio + async def test_uses_captured_locals(self) -> None: + """run_parallel_traces uses captured locals in body execution.""" + tc = MagicMock() + tc.index = 0 + tc.result = None + tc.__aenter__ = AsyncMock(return_value=tc) + tc.__aexit__ = AsyncMock(return_value=None) + + # Body that uses captured local + body_source = "tc.result = my_value * 2" + captured_locals = {"my_value": 21} + + results = await run_parallel_traces([tc], body_source, captured_locals) + + assert len(results) == 1 + + +class TestASTExtractionError: + """Tests for ASTExtractionError.""" + + def test_is_exception(self) -> None: + """ASTExtractionError is an exception.""" + error = ASTExtractionError("test message") + assert isinstance(error, Exception) + assert str(error) == "test message" + From 15053aec733ea9c7dc4812be11011edfc5712c2b Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 05:21:26 -0800 Subject: [PATCH 03/92] format --- hud/cli/flows/tasks.py | 2 +- hud/datasets/runner.py | 2 +- hud/environment/connection.py | 28 ++-- hud/environment/connectors/__init__.py | 10 +- hud/environment/connectors/base.py | 19 ++- hud/environment/connectors/local.py | 44 +++--- hud/environment/connectors/mcp_config.py | 56 +++---- 
hud/environment/connectors/openai.py | 49 +++--- hud/environment/connectors/remote.py | 39 ++--- hud/environment/connectors/task.py | 30 ++-- hud/environment/environment.py | 139 +++++++++--------- hud/environment/integrations/__init__.py | 10 +- hud/environment/integrations/anthropic.py | 47 +++--- hud/environment/integrations/gemini.py | 45 +++--- hud/environment/integrations/langchain.py | 40 ++--- hud/environment/integrations/openai.py | 87 ++++++----- hud/environment/mock.py | 102 ++++++------- hud/environment/router.py | 11 +- hud/environment/tests/__init__.py | 1 - hud/environment/tests/test_connection.py | 93 ++++++------ hud/environment/tests/test_connectors.py | 93 ++++++------ hud/environment/tests/test_environment.py | 25 ++-- hud/environment/tests/test_integrations.py | 45 +++--- .../tests/test_local_connectors.py | 3 - hud/environment/types.py | 5 +- hud/environment/utils/formats.py | 60 ++++---- hud/environment/utils/schema.py | 32 ++-- hud/samples/browser.py | 2 +- hud/trace/__init__.py | 4 +- hud/trace/context.py | 117 ++++++++------- hud/trace/mixin.py | 139 +++++++++--------- hud/trace/parallel.py | 36 ++--- hud/trace/tests/__init__.py | 1 - hud/trace/tests/test_context.py | 80 +++++----- hud/trace/tests/test_mixin.py | 73 +++++---- hud/trace/tests/test_parallel.py | 51 ++++--- 36 files changed, 825 insertions(+), 795 deletions(-) diff --git a/hud/cli/flows/tasks.py b/hud/cli/flows/tasks.py index c4d8304d..e6374d3f 100644 --- a/hud/cli/flows/tasks.py +++ b/hud/cli/flows/tasks.py @@ -449,7 +449,7 @@ def _one(x: Any) -> dict[str, Any]: "prompt": t.prompt, "mcp_config": { "hud": { - "url": "https://mcp.hud.ai/v3/mcp", + "url": settings.hud_mcp_url, "headers": { "Authorization": "Bearer ${HUD_API_KEY}", "Mcp-Image": remote_image, diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 7875500a..3960aeef 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -11,8 +11,8 @@ from datasets import Dataset, load_dataset -from hud.telemetry import async_job, async_trace from hud.datasets.utils import calculate_group_stats, submit_rollouts +from hud.telemetry import async_job, async_trace from hud.types import AgentType, Task, Trace if TYPE_CHECKING: diff --git a/hud/environment/connection.py b/hud/environment/connection.py index a104881b..e65869fd 100644 --- a/hud/environment/connection.py +++ b/hud/environment/connection.py @@ -21,8 +21,8 @@ class ConnectionType(str, Enum): """Type of connection - determines parallelization capability.""" - - LOCAL = "local" # Stdio/Docker - single instance, not parallelizable + + LOCAL = "local" # Stdio/Docker - single instance, not parallelizable REMOTE = "remote" # HTTP/URL - can spawn multiple instances @@ -45,7 +45,7 @@ def __init__( class Connector: """Manages a connection to an MCP server with tool caching. - + Client creation is deferred to connect() so that: 1. Each parallel trace gets fresh client instances 2. 
Connection happens inside trace context (for header injection) @@ -68,12 +68,12 @@ def __init__( self.connection_type = connection_type self.client: FastMCPClient[Any] | None = None self._tools_cache: list[mcp_types.Tool] | None = None - + @property def is_local(self) -> bool: """True if this is a local (non-parallelizable) connection.""" return self.connection_type == ConnectionType.LOCAL - + @property def is_remote(self) -> bool: """True if this is a remote (parallelizable) connection.""" @@ -89,13 +89,13 @@ def cached_tools(self) -> list[mcp_types.Tool]: async def connect(self) -> None: """Create FastMCP client and connect. - + Client is created here (not in __init__) so that: 1. Each parallel trace gets fresh client instances 2. httpx auto-instrumentation can inject trace headers """ from fastmcp.client import Client as FastMCPClient - + # Create fresh client from stored transport config self.client = FastMCPClient(transport=self._transport, auth=self._auth) await self.client.__aenter__() @@ -141,11 +141,13 @@ async def list_tools(self) -> list[mcp_types.Tool]: # Apply prefix name = f"{self.config.prefix}_{tool.name}" if self.config.prefix else tool.name - result.append(mcp_types.Tool( - name=name, - description=tool.description, - inputSchema=tool.inputSchema, - )) + result.append( + mcp_types.Tool( + name=name, + description=tool.description, + inputSchema=tool.inputSchema, + ) + ) self._tools_cache = result return result @@ -158,7 +160,7 @@ async def call_tool( raise RuntimeError("Not connected - call connect() first") # Strip prefix when calling remote if self.config.prefix and name.startswith(f"{self.config.prefix}_"): - name = name[len(self.config.prefix) + 1:] + name = name[len(self.config.prefix) + 1 :] return await self.client.call_tool_mcp(name, arguments or {}) async def list_resources(self) -> list[mcp_types.Resource]: diff --git a/hud/environment/connectors/__init__.py b/hud/environment/connectors/__init__.py index e99850da..7b8919ac 100644 --- a/hud/environment/connectors/__init__.py +++ b/hud/environment/connectors/__init__.py @@ -15,24 +15,24 @@ class ConnectorsMixin( OpenAIConnectorMixin, ): """Combined connector mixin providing all connection methods. - + Remote connections: connect_hub(slug) - HUD Hub environment (fetches mcp_config from API) connect_url(url) - MCP server via URL connect_openapi(spec) - Mount OpenAPI spec as MCP server - + Local connections (in-process): connect_image(image) - Docker image via stdio connect_fastapi(app) - Mount FastAPI app as MCP server connect_server(server) - Mount MCPServer/FastMCP directly - + MCP config: connect_mcp(config) - Single mcp_config server (auto-detects local/remote) connect_mcp_config(mcp_config) - Multiple mcp_config servers - + Task: connect_task(slug) - Load task from platform by slug - + Framework imports: connect_function_tools(tools) - Import OpenAI Agents SDK FunctionTools """ diff --git a/hud/environment/connectors/base.py b/hud/environment/connectors/base.py index 997dd715..94557e94 100644 --- a/hud/environment/connectors/base.py +++ b/hud/environment/connectors/base.py @@ -16,7 +16,7 @@ class BaseConnectorMixin: """Base mixin providing connection helper. - + Requires: _connections: dict[str, Connector] """ @@ -36,7 +36,7 @@ def _add_connection( transform: Callable[[Tool], Tool | None] | None = None, ) -> Any: """Add a connection to the environment. - + Args: name: Connection name/alias. transport: FastMCP transport (URL, config dict, etc.). 
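As the prefix handling above implies, tools from a prefixed connection are exposed as <prefix>_<tool> on the Environment, and the prefix is stripped again before the call reaches the underlying server. A small sketch of that naming flow, assuming a hypothetical remote MCP server URL:

```python
# Illustrative sketch only. The server URL is hypothetical; the point is the
# naming: tools from this connection are listed as "web_<tool>", and the
# "web_" prefix is removed before the request is sent to the server.
import asyncio

from hud.environment import Environment


async def main() -> None:
    env = Environment("prefix-demo")
    env.connect_url("https://mcp.example.com/mcp", alias="web", prefix="web")

    async with env:
        tools = await env.list_tools()  # names arrive prefixed, e.g. "web_search"
        result = await env.call_tool("web_search", query="model context protocol")
        print([t.name for t in tools], result)


asyncio.run(main())
```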
@@ -46,16 +46,23 @@ def _add_connection( include: Only include these tools. exclude: Exclude these tools. transform: Transform function for tools. - + Returns: self for chaining. """ from hud.environment.connection import ConnectionConfig, Connector - + config = ConnectionConfig( - prefix=prefix, include=include, exclude=exclude, transform=transform, + prefix=prefix, + include=include, + exclude=exclude, + transform=transform, ) self._connections[name] = Connector( - transport, config, name, connection_type=connection_type, auth=auth, + transport, + config, + name, + connection_type=connection_type, + auth=auth, ) return self diff --git a/hud/environment/connectors/local.py b/hud/environment/connectors/local.py index 6ae170b8..66633221 100644 --- a/hud/environment/connectors/local.py +++ b/hud/environment/connectors/local.py @@ -16,12 +16,12 @@ class LocalConnectorMixin(MCPConfigConnectorMixin): """Mixin providing local connection methods. - + Methods: connect_image(image) - Run Docker image via stdio connect_fastapi(app) - Mount FastAPI app as MCP server connect_server(server) - Mount any MCPServer/FastMCP directly - + Inherits connect_mcp() from MCPConfigConnectorMixin. """ @@ -42,21 +42,21 @@ def connect_image( transform: Callable[[Tool], Tool | None] | None = None, ) -> Any: """Connect to a Docker image via stdio. - + Creates an MCP config that runs: docker run -i --rm {image} Environment variables from `.env` files are auto-injected. - + Example: ```python env = Environment("my-env") env.connect_image("mcp/fetch") - + async with env: result = await env.call_tool("fetch", url="https://example.com") ``` """ from hud.cli.utils.docker import create_docker_run_command - + cmd = create_docker_run_command( image=image, docker_args=docker_args, @@ -64,7 +64,7 @@ def connect_image( interactive=True, remove=True, ) - + name = alias or image mcp_config = { name: { @@ -89,30 +89,32 @@ def connect_fastapi( prefix: str | None = None, ) -> Any: """Mount a FastAPI application as an MCP server. - + Uses FastMCP's from_fastapi() to convert FastAPI endpoints to MCP tools. - + Example: ```python from fastapi import FastAPI - + api = FastAPI() - + + @api.get("/users/{user_id}", operation_id="get_user") def get_user(user_id: int): return {"id": user_id, "name": "Alice"} - + + env = Environment("my-env") env.connect_fastapi(api) - + async with env: result = await env.call_tool("get_user", user_id=1) ``` - + Tip: Use operation_id in FastAPI decorators for cleaner tool names. """ from fastmcp import FastMCP - + server_name = name or getattr(app, "title", None) or "fastapi" mcp_server = FastMCP.from_fastapi(app=app, name=server_name) self.mount(mcp_server, prefix=prefix) @@ -125,20 +127,22 @@ def connect_server( prefix: str | None = None, ) -> Any: """Mount an MCPServer or FastMCP instance directly. - + Example: ```python from fastmcp import FastMCP - + tools = FastMCP("tools") - + + @tools.tool def greet(name: str) -> str: return f"Hello, {name}!" - + + env = Environment("my-env") env.connect_server(tools) - + async with env: result = await env.call_tool("greet", name="World") ``` diff --git a/hud/environment/connectors/mcp_config.py b/hud/environment/connectors/mcp_config.py index e9a06cdd..ebfacee5 100644 --- a/hud/environment/connectors/mcp_config.py +++ b/hud/environment/connectors/mcp_config.py @@ -28,33 +28,35 @@ def connect_mcp( transform: Callable[[Tool], Tool | None] | None = None, ) -> Any: """Connect using an mcp_config dictionary (single server). 
- + Auto-detects LOCAL (stdio) vs REMOTE (URL) based on config. - + Example: ```python env = Environment("my-env") - + # Stdio server - env.connect_mcp({ - "filesystem": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + env.connect_mcp( + { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + } } - }) - + ) + async with env: await env.call_tool("read_file", path="/tmp/test.txt") ``` """ from hud.environment.connection import ConnectionType - + name = alias or next(iter(config.keys()), "mcp") server_config = next(iter(config.values()), {}) - + is_local = "command" in server_config or "args" in server_config conn_type = ConnectionType.LOCAL if is_local else ConnectionType.REMOTE - + return self._add_connection( name, config, @@ -71,24 +73,26 @@ def connect_mcp_config( **kwargs: Any, ) -> Any: """Connect multiple servers from an mcp_config dictionary. - + Example: ```python env = Environment("my-env") - + # Claude Desktop style config - env.connect_mcp_config({ - "filesystem": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], - }, - "github": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-github"], - "env": {"GITHUB_TOKEN": "..."}, - }, - }) - + env.connect_mcp_config( + { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + }, + "github": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-github"], + "env": {"GITHUB_TOKEN": "..."}, + }, + } + ) + async with env: await env.call_tool("read_file", path="/tmp/test.txt") await env.call_tool("search_repositories", query="mcp") diff --git a/hud/environment/connectors/openai.py b/hud/environment/connectors/openai.py index fdaea52a..893e50b1 100644 --- a/hud/environment/connectors/openai.py +++ b/hud/environment/connectors/openai.py @@ -3,16 +3,14 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from collections.abc import Callable +from typing import Any __all__ = ["OpenAIConnectorMixin"] # Lazy import check try: from agents import FunctionTool + _HAS_OPENAI_AGENTS = True except ImportError: _HAS_OPENAI_AGENTS = False @@ -21,10 +19,10 @@ class OpenAIConnectorMixin: """Mixin providing OpenAI Agents SDK connector methods.""" - + # These are defined on Environment/MCPServer _tool_manager: Any - + def connect_function_tools( self, tools: list[Any], @@ -32,30 +30,33 @@ def connect_function_tools( prefix: str | None = None, ) -> Any: """Import FunctionTools from the OpenAI Agents SDK. - + Wraps each tool so calls go through HUD with telemetry. - + Example: ```python from agents import function_tool - + + @function_tool def search(query: str) -> str: '''Search for information.''' return f"Results for {query}" - - @function_tool + + + @function_tool def calculate(expression: str) -> float: '''Evaluate a math expression.''' return eval(expression) - + + env = Environment("my-env") env.connect_function_tools([search, calculate]) - + async with env: result = await env.call_tool("search", query="MCP protocol") ``` - + Note: Requires `openai-agents`: pip install openai-agents """ @@ -64,37 +65,38 @@ def calculate(expression: str) -> float: "openai-agents is required for connect_function_tools. 
" "Install with: pip install openai-agents" ) - + for tool in tools: if isinstance(tool, FunctionTool): self._add_openai_function_tool(tool, prefix) - + return self - + def _add_openai_function_tool(self, tool: Any, prefix: str | None) -> None: """Convert OpenAI FunctionTool to local MCP tool.""" name = f"{prefix}_{tool.name}" if prefix else tool.name - + # Get the original invoke function original_invoke = tool.on_invoke_tool - + # Create wrapper that calls the original async def invoke(**arguments: Any) -> Any: # OpenAI's on_invoke_tool expects (ToolContext, str_json_args) # We need to create a minimal context from agents.tool_context import ToolContext + ctx = ToolContext(context=None) result = await original_invoke(ctx, json.dumps(arguments)) return result - + # Set function metadata for FastMCP invoke.__name__ = name invoke.__doc__ = tool.description - + # Register using FastMCP's tool decorator mechanism # We access the internal _tool_manager from MCPServer from fastmcp.tools import Tool as FastMCPTool - + fastmcp_tool = FastMCPTool.from_function( fn=invoke, name=name, @@ -102,6 +104,5 @@ async def invoke(**arguments: Any) -> Any: ) # Override the schema with OpenAI's (more accurate) fastmcp_tool.parameters = tool.params_json_schema - - self._tool_manager.add_tool(fastmcp_tool) + self._tool_manager.add_tool(fastmcp_tool) diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py index 5dc539b3..f7500d64 100644 --- a/hud/environment/connectors/remote.py +++ b/hud/environment/connectors/remote.py @@ -37,22 +37,22 @@ def connect_hub( transform: Callable[[Tool], Tool | None] | None = None, ) -> Any: """Connect to a HUD Hub environment. - + Fetches mcp_config from api.hud.so immediately and creates connectors. - + Example: ```python env = Environment("my-env") env.connect_hub("hud/browser") - + async with env: await env.call_tool("navigate", url="https://google.com") ``` """ import httpx - + from hud.settings import settings - + # Store hub config for trace serialization hub_config: dict[str, Any] = {"slug": slug} if alias: @@ -63,18 +63,18 @@ def connect_hub( hub_config["include"] = include if exclude: hub_config["exclude"] = exclude - + if not hasattr(self, "_hub_configs"): self._hub_configs = [] self._hub_configs.append(hub_config) - + # Fetch mcp_config synchronously logger.info("Loading hub environment: %s", slug) - + headers = {} if settings.api_key: headers["Authorization"] = f"Bearer {settings.api_key}" - + with httpx.Client() as client: response = client.get( f"{settings.hud_api_url}/environments/{slug}/mcp-config", @@ -82,7 +82,7 @@ def connect_hub( ) response.raise_for_status() data = response.json() - + mcp_config: dict[str, dict[str, Any]] = data.get("mcp_config", data) self.connect_mcp_config( mcp_config, prefix=prefix, include=include, exclude=exclude, transform=transform @@ -102,7 +102,7 @@ def connect_url( transform: Callable[[Tool], Tool | None] | None = None, ) -> Any: """Connect to an MCP server via URL. - + Example: ```python env = Environment("my-env") @@ -110,13 +110,13 @@ def connect_url( "https://mcp.example.com", headers={"Authorization": "Bearer token"}, ) - + async with env: await env.call_tool("search", query="hello") ``` """ from hud.environment.connection import ConnectionType - + auth = headers.get("Authorization") if headers else None return self._add_connection( alias or url, @@ -140,15 +140,15 @@ def connect_openapi( timeout: float = 30.0, ) -> Any: """Mount an OpenAPI specification as an MCP server. 
- + Converts REST API endpoints to MCP tools. Base URL is auto-inferred from the spec URL when possible. - + Example: ```python env = Environment("my-env") env.connect_openapi("https://petstore.swagger.io/v2/swagger.json") - + async with env: result = await env.call_tool("getPetById", petId=1) ``` @@ -163,18 +163,19 @@ def connect_openapi( if base_url is None: parsed = urlparse(openapi_spec) base_url = f"{parsed.scheme}://{parsed.netloc}" - + resp = httpx.get(openapi_spec, headers=headers) resp.raise_for_status() openapi_spec = resp.json() else: import json + with open(openapi_spec) as f: openapi_spec = json.load(f) - + if base_url is None: raise ValueError("base_url is required when openapi_spec is a dict or file") - + client = httpx.AsyncClient(base_url=base_url, headers=headers or {}, timeout=timeout) mcp_server = FastMCP.from_openapi( openapi_spec=cast("dict[str, Any]", openapi_spec), diff --git a/hud/environment/connectors/task.py b/hud/environment/connectors/task.py index 4298bbcf..1fe1033e 100644 --- a/hud/environment/connectors/task.py +++ b/hud/environment/connectors/task.py @@ -17,7 +17,7 @@ class TaskConnectorMixin(MCPConfigConnectorMixin): """Mixin providing connect_task() method. - + Inherits from MCPConfigConnectorMixin for connect_mcp_config(). """ @@ -29,20 +29,20 @@ def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: def connect_task(self, slug: str) -> Any: """Connect to a task from the HUD platform. - + Fetches the task from api.hud.so immediately and applies configuration (mcp_config, setup_tool, evaluate_tool). - + Args: slug: Task slug in format "evalset/task_name" or "evalset/task_name@version". - + Returns: self for chaining. - + Example: ```python env = Environment("my-env").connect_task("my-org/browser-task") - + async with env: # Task's mcp_config is connected # Task's setup_tool runs automatically @@ -51,17 +51,17 @@ def connect_task(self, slug: str) -> Any: ``` """ import httpx - + from hud.settings import settings from hud.types import Task - + # Fetch task synchronously logger.info("Loading task from platform: %s", slug) - + headers = {} if settings.api_key: headers["Authorization"] = f"Bearer {settings.api_key}" - + with httpx.Client() as client: response = client.get( f"{settings.hud_api_url}/tasks/{slug}", @@ -69,7 +69,7 @@ def connect_task(self, slug: str) -> Any: ) response.raise_for_status() data = response.json() - + task = Task(**data) self._apply_task(task) logger.info("Task loaded and applied: %s", slug) @@ -77,7 +77,7 @@ def connect_task(self, slug: str) -> Any: def _apply_task(self, task: Task) -> None: """Apply a Task definition to this environment. 
- + Sets up: - Prompt from task.prompt - MCP connections from task.mcp_config @@ -87,11 +87,11 @@ def _apply_task(self, task: Task) -> None: # Set prompt if task.prompt: self.prompt = task.prompt # type: ignore[attr-defined] - + # Connect MCP servers if task.mcp_config: self.connect_mcp_config(task.mcp_config) - + # Configure setup tool calls if task.setup_tool: setup_calls = task.setup_tool @@ -99,7 +99,7 @@ def _apply_task(self, task: Task) -> None: setup_calls = [setup_calls] for call in setup_calls: self.setup_tool(call.name, **(call.arguments or {})) - + # Configure evaluate tool calls if task.evaluate_tool: eval_calls = task.evaluate_tool diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 45141f5d..f85a662e 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -6,12 +6,12 @@ import logging import types from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal import mcp.types as mcp_types -from hud.environment.connectors import ConnectorsMixin from hud.environment.connection import Connector +from hud.environment.connectors import ConnectorsMixin from hud.environment.integrations import IntegrationsMixin from hud.environment.mock import MockMixin from hud.environment.router import ConflictResolution, ToolRouter @@ -19,9 +19,6 @@ from hud.trace.mixin import TraceMixin from hud.types import MCPToolResult -if TYPE_CHECKING: - from hud.types import Task - __all__ = ["Environment"] logger = logging.getLogger(__name__) @@ -38,7 +35,7 @@ class Environment( MCPServer, ): """Unified MCP environment that acts as both server and client. - + Features: - Define local tools with @env.tool decorator - Connect to HUD Hub, URLs, or mcp_config dicts @@ -46,7 +43,7 @@ class Environment( - Format tools for any LLM provider - Integrate with popular agent frameworks - Mock mode for testing without real connections - + Connector methods (connect to sources): connect_hub(name) - HUD Hub environment connect_url(url) - MCP server via URL @@ -57,51 +54,53 @@ class Environment( connect_fastapi(app) - Mount FastAPI app as MCP server connect_openapi(spec) - Mount OpenAPI spec as MCP server connect_server(server) - Mount MCPServer/FastMCP directly - + Mock methods (for testing): mock() - Enable mock mode, all tools return mock values unmock() - Disable mock mode mock_tool(name, output) - Set specific mock output for a tool is_mock - Check if mock mode is enabled - + OpenAI integrations: as_openai_chat_tools() - Chat Completions format as_openai_responses_tools() - Responses API format as_openai_agent_tools() - Agents SDK (requires openai-agents) - + Anthropic/Claude integrations: as_claude_tools() - Claude API format as_claude_programmatic_tools() - Programmatic tool use as_anthropic_runner() - Tool runner (requires anthropic) - + Google/Gemini integrations: as_gemini_tools() - Gemini format as_gemini_tool_config() - Tool execution config - + LangChain integrations: as_langchain_tools() - StructuredTools (requires langchain-core) - + Example: ```python env = Environment("my-env") - + + @env.tool def greet(name: str) -> str: return f"Hello, {name}!" 
- + + env.connect_hub("browser", prefix="browser") - + async with env: # Get tools in any format openai_tools = env.as_openai_chat_tools() claude_tools = env.as_claude_tools() - + # Call tools - automatically routed result = await env.call_tool("greet", name="World") - + # Or pass provider-specific format - auto-detected result = await env.call_tool(response.choices[0].message.tool_calls[0]) - + # Mock mode for testing env.mock() env.mock_tool("browser_navigate", "Navigation successful") @@ -124,17 +123,17 @@ def __init__( self._connections: dict[str, Connector] = {} self._router = ToolRouter(conflict_resolution=conflict_resolution) self._in_context = False - + # Tool call queues - run after connections established self._setup_calls: list[tuple[str, dict[str, Any]]] = [] self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] - + # Task prompt - set by connect_task or manually self.prompt: str | None = None - + # Track which lifecycle tools we've warned about (only warn once per tool) self._warned_lifecycle_tools: set[str] = set() - + # Initialize mock state self._init_mock() @@ -148,7 +147,7 @@ def as_tools(self) -> list[mcp_types.Tool]: async def call_tool(self, call: Any, /, **kwargs: Any) -> Any: """Call a tool, auto-detecting format and returning matching result format. - + Accepts any format: - String with kwargs: call_tool("navigate", url="...") - Tuple: call_tool(("navigate", {"url": "..."})) @@ -156,18 +155,18 @@ async def call_tool(self, call: Any, /, **kwargs: Any) -> Any: - OpenAI: call_tool(response.choices[0].message.tool_calls[0]) - Claude: call_tool(response.content[0]) # tool_use block - Gemini: call_tool(response.candidates[0].content.parts[0]) - + Returns: Result formatted to match input format (OpenAI -> OpenAI tool message, etc.) """ from hud.environment.utils import format_result, parse_tool_call - + # Parse the tool call (kwargs merged when call is string) parsed, fmt = parse_tool_call(call, **kwargs) self._check_lifecycle_warning(parsed.name) result = await self._execute_tool(parsed.name, parsed.arguments or {}) return format_result(result, parsed, fmt) - + def _check_lifecycle_warning(self, name: str) -> None: """Warn once if calling a setup/evaluate tool manually.""" if name in self._warned_lifecycle_tools: @@ -180,7 +179,8 @@ def _check_lifecycle_warning(self, name: str) -> None: phase = "setup" if name in setup else "evaluate" logger.warning( "Tool '%s' is a %s tool (runs automatically). Manual call may duplicate.", - name, phase, + name, + phase, ) async def call_tools(self, calls: Any) -> list[Any]: @@ -189,14 +189,14 @@ async def call_tools(self, calls: Any) -> list[Any]: return [] if not isinstance(calls, list): return [await self.call_tool(calls)] - + # Filter to tool calls only (skip text blocks, etc.) 
tool_calls = [] for call in calls: t = call.get("type") if isinstance(call, dict) else getattr(call, "type", None) if t is None or t in ("tool_use", "function"): tool_calls.append(call) - + return await asyncio.gather(*[self.call_tool(c) for c in tool_calls]) # ========================================================================= @@ -206,7 +206,7 @@ async def call_tools(self, calls: Any) -> list[Any]: def setup_tool(self, call: Any, /, **kwargs: Any) -> Environment: """Add a tool call to execute after connections are established.""" from hud.environment.utils import parse_tool_call - + if isinstance(call, str) and kwargs: self._setup_calls.append((call, kwargs)) else: @@ -217,7 +217,7 @@ def setup_tool(self, call: Any, /, **kwargs: Any) -> Environment: def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Environment: """Add a tool call to execute before disconnecting.""" from hud.environment.utils import parse_tool_call - + if isinstance(call, str) and kwargs: self._evaluate_calls.append((call, kwargs)) else: @@ -232,7 +232,7 @@ def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Environment: async def __aenter__(self) -> Environment: """Connect all connectors, build routing, run setup tools.""" self._in_context = True - + # Connect to all servers (on_connect callbacks run first within connect()) sem = asyncio.Semaphore(self.MAX_CONCURRENT_CONNECTIONS) errors: list[tuple[str, Exception]] = [] @@ -246,9 +246,7 @@ async def connect_one(name: str, conn: Connector) -> None: errors.append((name, e)) if self._connections: - await asyncio.gather(*[ - connect_one(n, c) for n, c in self._connections.items() - ]) + await asyncio.gather(*[connect_one(n, c) for n, c in self._connections.items()]) if errors: for conn in self._connections.values(): if conn.is_connected: @@ -257,11 +255,11 @@ async def connect_one(name: str, conn: Connector) -> None: raise ConnectionError(f"Failed to connect to {name}") from err await self._build_routing() - + # Setup tool calls (after connections) for name, args in self._setup_calls: await self._execute_tool(name, args) - + return self async def __aexit__( @@ -272,7 +270,7 @@ async def __aexit__( ) -> None: """Run evaluate tools, exit queue, then disconnect.""" from hud.agents.base import find_reward - + # Evaluate tool calls and collect rewards rewards: list[float] = [] for name, args in self._evaluate_calls: @@ -281,12 +279,12 @@ async def __aexit__( rewards.append(find_reward(result)) except Exception as e: logger.warning("Evaluate tool %s failed: %s", name, e) - + # Store average reward from evaluate tools self._evaluate_reward: float | None = None if rewards: self._evaluate_reward = sum(rewards) / len(rewards) - + self._in_context = False if self._connections: await asyncio.gather(*[c.disconnect() for c in self._connections.values()]) @@ -316,14 +314,14 @@ async def list_tools(self) -> list[mcp_types.Tool]: async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolResult: """Execute a tool by name. Routes to local or remote handler. - + If mock mode is enabled, returns a mock result instead of executing. 
""" # Check mock mode first if self._mock_mode: logger.debug("Mock mode: returning mock result for tool %s", name) return self._get_mock_result(name, arguments) - + if self._router.is_local(name): result = await self._call_tool(name, arguments) return MCPToolResult(content=result.content, isError=False) @@ -346,9 +344,9 @@ async def list_resources(self) -> list[mcp_types.Resource]: resources: list[mcp_types.Resource] = [r.to_mcp_resource() for r in local] if self._connections: - results = await asyncio.gather(*[ - c.list_resources() for c in self._connections.values() - ], return_exceptions=True) + results = await asyncio.gather( + *[c.list_resources() for c in self._connections.values()], return_exceptions=True + ) for r in results: if isinstance(r, list): resources.extend(r) @@ -367,9 +365,12 @@ async def read_resource( if isinstance(result, str): return [mcp_types.TextResourceContents(uri=resource_uri, text=result)] import base64 - return [mcp_types.BlobResourceContents( - uri=resource_uri, blob=base64.b64encode(result).decode() - )] + + return [ + mcp_types.BlobResourceContents( + uri=resource_uri, blob=base64.b64encode(result).decode() + ) + ] except Exception: pass @@ -391,9 +392,9 @@ async def list_prompts(self) -> list[mcp_types.Prompt]: prompts: list[mcp_types.Prompt] = [p.to_mcp_prompt() for p in local] if self._connections: - results = await asyncio.gather(*[ - c.list_prompts() for c in self._connections.values() - ], return_exceptions=True) + results = await asyncio.gather( + *[c.list_prompts() for c in self._connections.values()], return_exceptions=True + ) for r in results: if isinstance(r, list): prompts.extend(r) @@ -442,14 +443,14 @@ def connections(self) -> dict[str, Connector]: @property def is_connected(self) -> bool: return self._in_context - + @property def is_parallelizable(self) -> bool: """True if all connections are remote (can spawn multiple instances).""" if not self._connections: return True # No connections = can parallelize (local tools only) return all(conn.is_remote for conn in self._connections.values()) - + @property def local_connections(self) -> list[str]: """Names of local (non-parallelizable) connections.""" @@ -457,7 +458,7 @@ def local_connections(self) -> list[str]: def _get_env_config(self) -> dict[str, Any] | None: """Get serializable environment configuration for trace storage. - + Returns EnvConfig-compatible dict with: - name: Environment name - hubs: List of hub configs (connect_hub calls) @@ -465,52 +466,46 @@ def _get_env_config(self) -> dict[str, Any] | None: - evaluate_tools: Tools to run before disconnection (MCPToolCall format) """ hub_configs = getattr(self, "_hub_configs", []) - + # Convert setup/evaluate calls to MCPToolCall format - setup_tools = [ - {"name": name, "arguments": args} - for name, args in self._setup_calls - ] - evaluate_tools = [ - {"name": name, "arguments": args} - for name, args in self._evaluate_calls - ] - + setup_tools = [{"name": name, "arguments": args} for name, args in self._setup_calls] + evaluate_tools = [{"name": name, "arguments": args} for name, args in self._evaluate_calls] + # Only return config if there's something to store if not hub_configs and not setup_tools and not evaluate_tools: return None - + return { "name": self.name, "hubs": hub_configs, "setup_tools": setup_tools, "evaluate_tools": evaluate_tools, } - + @property def _all_hubs(self) -> bool: """True if all tools came from connect_hub (fully reproducible). 
- + Returns False if there are: - Local tools (@env.tool, connect_fastapi, connect_openapi, connect_server) - Non-hub connections (connect_url, connect_mcp, connect_image, etc.) """ hub_configs = getattr(self, "_hub_configs", []) - + # Check for local tools (mounted servers, @env.tool) # _tool_manager comes from MCPServer base class local_tool_count = len(self._tool_manager._tools) if hasattr(self, "_tool_manager") else 0 if local_tool_count > 0: return False - + # No hubs and no connections = trivially all hubs (empty env) if not hub_configs and not self._connections: return True - + # Has connections but no hubs = not all hubs if not hub_configs: return False - + # Compare hub count to connection count return len(hub_configs) >= len(self._connections) diff --git a/hud/environment/integrations/__init__.py b/hud/environment/integrations/__init__.py index 9794ec16..4990abad 100644 --- a/hud/environment/integrations/__init__.py +++ b/hud/environment/integrations/__init__.py @@ -15,22 +15,22 @@ class IntegrationsMixin( LangChainMixin, ): """Combined integration mixin for all providers. - + OpenAI: as_openai_chat_tools() - Chat Completions format as_openai_responses_tools() - Responses API format as_openai_agent_tools() - Agents SDK (requires openai-agents) - + Anthropic/Claude: as_claude_tools() - Claude API format as_claude_programmatic_tools() - Programmatic tool use as_anthropic_runner() - Tool runner (requires anthropic) - + Google/Gemini: as_gemini_tools() - Gemini format as_gemini_tool_config() - Tool config - + LangChain: as_langchain_tools() - StructuredTools (requires langchain-core) """ - pass + diff --git a/hud/environment/integrations/anthropic.py b/hud/environment/integrations/anthropic.py index dc2d3a3c..d1427b4e 100644 --- a/hud/environment/integrations/anthropic.py +++ b/hud/environment/integrations/anthropic.py @@ -8,6 +8,7 @@ # Try to import anthropic try: from anthropic.types.beta import BetaToolResultBlockParam + _HAS_ANTHROPIC = True except ImportError: _HAS_ANTHROPIC = False @@ -21,14 +22,14 @@ class AnthropicMixin: """Mixin providing Anthropic/Claude format conversion and tool runner. - + Format methods (no deps): as_claude_tools() - Claude API format as_claude_programmatic_tools() - Programmatic tool use format - + Integration methods (requires anthropic): as_anthropic_runner() - Tool runner for executing tool_use blocks - + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) """ @@ -44,17 +45,17 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: def as_claude_tools(self, *, cache_control: bool = False) -> list[dict[str, Any]]: """Convert to Claude/Anthropic tool format. - + Args: cache_control: Add cache_control for prompt caching - + Returns: List of tool definitions for Claude API. - + Example: ```python from anthropic import Anthropic - + client = Anthropic() async with env: response = client.messages.create( @@ -83,13 +84,13 @@ def as_claude_tools(self, *, cache_control: bool = False) -> list[dict[str, Any] def as_claude_programmatic_tools(self, *, cache_control: bool = False) -> list[dict[str, Any]]: """Convert to Claude programmatic tool use format. - + Programmatic tool use allows Claude to execute tools via code execution. 
- + Example: ```python from anthropic import Anthropic - + client = Anthropic() async with env: response = client.messages.create( @@ -120,27 +121,27 @@ def as_claude_programmatic_tools(self, *, cache_control: bool = False) -> list[d def as_anthropic_runner(self) -> EnvToolRunner: """Create an Anthropic tool runner for this environment. - + Requires: pip install anthropic - + Returns: EnvToolRunner that can process tool_use blocks from Claude. - + Example: ```python from anthropic import Anthropic - + client = Anthropic() async with env: runner = env.as_anthropic_runner() - + response = client.messages.create( model="claude-sonnet-4-20250514", max_tokens=1024, messages=[{"role": "user", "content": "Navigate to google.com"}], tools=env.as_claude_tools(), ) - + # Execute all tool_use blocks results = [] for block in response.content: @@ -150,16 +151,14 @@ def as_anthropic_runner(self) -> EnvToolRunner: ``` """ if not _HAS_ANTHROPIC: - raise ImportError( - "Anthropic SDK not installed. Install with: pip install anthropic" - ) + raise ImportError("Anthropic SDK not installed. Install with: pip install anthropic") return EnvToolRunner(self) class EnvToolRunner: """Tool runner that executes tools against an Environment.""" - + def __init__(self, env: AnthropicMixin) -> None: self.env = env self._tool_names: set[str] | None = None @@ -173,17 +172,17 @@ def tool_names(self) -> set[str]: async def run(self, tool_use_block: Any) -> dict[str, Any]: """Execute a tool_use block from Claude. - + Args: tool_use_block: A ToolUseBlock from Claude's response. - + Returns: Tool result dict (or BetaToolResultBlockParam if anthropic installed). """ name = tool_use_block.name tool_use_id = tool_use_block.id arguments = tool_use_block.input or {} - + try: result = await self.env.call_tool(name, **arguments) content = result if isinstance(result, str) else json.dumps(result) if result else "" @@ -199,7 +198,7 @@ async def run(self, tool_use_block: Any) -> dict[str, Any]: "content": f"Error: {e}", "is_error": True, } - + # Return typed object if anthropic is available if _HAS_ANTHROPIC and BetaToolResultBlockParam is not None: return BetaToolResultBlockParam(**result_dict) diff --git a/hud/environment/integrations/gemini.py b/hud/environment/integrations/gemini.py index e8899d99..4f7895b4 100644 --- a/hud/environment/integrations/gemini.py +++ b/hud/environment/integrations/gemini.py @@ -12,11 +12,11 @@ class GeminiMixin: """Mixin providing Google/Gemini format conversion. - + Format methods (no deps): as_gemini_tools() - Gemini tool format as_gemini_tool_config() - Tool execution config - + Requires: as_tools() -> list[mcp_types.Tool] """ @@ -25,14 +25,14 @@ def as_tools(self) -> list[mcp_types.Tool]: def as_gemini_tools(self) -> list[dict[str, Any]]: """Convert to Gemini/Google AI tool format. - + Returns: List with function_declarations for Gemini API. 
- + Example: ```python import google.generativeai as genai - + model = genai.GenerativeModel("gemini-1.5-pro") async with env: response = model.generate_content( @@ -45,16 +45,18 @@ def as_gemini_tools(self) -> list[dict[str, Any]]: result = await env.call_tool(part) ``` """ - return [{ - "function_declarations": [ - { - "name": t.name, - "description": t.description or "", - "parameters": t.inputSchema or {"type": "object", "properties": {}}, - } - for t in self.as_tools() - ] - }] + return [ + { + "function_declarations": [ + { + "name": t.name, + "description": t.description or "", + "parameters": t.inputSchema or {"type": "object", "properties": {}}, + } + for t in self.as_tools() + ] + } + ] def as_gemini_tool_config( self, @@ -62,28 +64,25 @@ def as_gemini_tool_config( allowed_tools: list[str] | None = None, ) -> dict[str, Any]: """Get Gemini tool_config for controlling tool execution. - + Args: mode: "AUTO", "ANY", or "NONE" allowed_tools: If mode is "ANY", list of allowed tool names - + Returns: Tool config dict for Gemini API. - + Example: ```python import google.generativeai as genai - + model = genai.GenerativeModel("gemini-1.5-pro") async with env: # Force specific tool usage response = model.generate_content( "Search for cats", tools=env.as_gemini_tools(), - tool_config=env.as_gemini_tool_config( - mode="ANY", - allowed_tools=["search"] - ), + tool_config=env.as_gemini_tool_config(mode="ANY", allowed_tools=["search"]), ) ``` """ diff --git a/hud/environment/integrations/langchain.py b/hud/environment/integrations/langchain.py index d52eaf87..9b505a08 100644 --- a/hud/environment/integrations/langchain.py +++ b/hud/environment/integrations/langchain.py @@ -10,6 +10,7 @@ # Try to import langchain try: from langchain_core.tools import StructuredTool + _HAS_LANGCHAIN = True except ImportError: _HAS_LANGCHAIN = False @@ -23,10 +24,10 @@ class LangChainMixin: """Mixin providing LangChain integration. - + Integration methods (requires langchain-core): as_langchain_tools() - LangChain StructuredTool objects - + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) """ @@ -38,37 +39,37 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: def as_langchain_tools(self) -> list[Any]: """Convert to LangChain StructuredTool objects. - + Requires: pip install langchain-core - + Returns: List of StructuredTool objects for LangChain agents. - + Example: ```python from langchain_openai import ChatOpenAI from langchain.agents import create_tool_calling_agent, AgentExecutor from langchain_core.prompts import ChatPromptTemplate - + llm = ChatOpenAI(model="gpt-4o") async with env: tools = env.as_langchain_tools() - - prompt = ChatPromptTemplate.from_messages([ - ("system", "You are a helpful assistant."), - ("human", "{input}"), - ("placeholder", "{agent_scratchpad}"), - ]) - + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful assistant."), + ("human", "{input}"), + ("placeholder", "{agent_scratchpad}"), + ] + ) + agent = create_tool_calling_agent(llm, tools, prompt) executor = AgentExecutor(agent=agent, tools=tools) result = await executor.ainvoke({"input": "Navigate to google.com"}) ``` """ if not _HAS_LANGCHAIN: - raise ImportError( - "LangChain not installed. Install with: pip install langchain-core" - ) + raise ImportError("LangChain not installed. 
Install with: pip install langchain-core") tools = [] for t in self.as_tools(): @@ -80,20 +81,21 @@ def as_langchain_tools(self) -> list[Any]: def _create_structured_tool(env: LangChainMixin, tool: mcp_types.Tool) -> Any: """Create a StructuredTool that calls back to the environment.""" import asyncio - + schema = tool.inputSchema or {"type": "object", "properties": {}} - + def sync_invoke(**kwargs: Any) -> str: """Synchronous wrapper for the tool.""" loop = asyncio.get_event_loop() if loop.is_running(): import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(asyncio.run, env.call_tool(tool.name, **kwargs)) result = future.result() else: result = loop.run_until_complete(env.call_tool(tool.name, **kwargs)) - + if isinstance(result, str): return result return json.dumps(result) if result else "" diff --git a/hud/environment/integrations/openai.py b/hud/environment/integrations/openai.py index 3a188bf5..c261765d 100644 --- a/hud/environment/integrations/openai.py +++ b/hud/environment/integrations/openai.py @@ -10,6 +10,7 @@ # Try to import OpenAI Agents SDK try: from agents import FunctionTool + _HAS_AGENTS = True except ImportError: _HAS_AGENTS = False @@ -23,20 +24,20 @@ class OpenAIMixin: """Mixin providing OpenAI format conversion and Agents SDK integration. - + Format methods (no deps): as_openai_chat_tools() - Chat Completions format as_openai_responses_tools() - Responses API format - + Integration methods (requires openai-agents): as_openai_agent_tools() - Agents SDK FunctionTool objects - + Note: The OpenAI Agents SDK also supports: - HostedMCPTool - MCP tools hosted by OpenAI - MCPServerStdio/Sse/StreamableHttp - Direct MCP server connections - + For MCP server integration, use as_mcp_server() from the mcp integration. - + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) """ @@ -52,17 +53,17 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: def as_openai_chat_tools(self, *, strict: bool = False) -> list[dict[str, Any]]: """Convert to OpenAI Chat Completions tool format. - + Args: strict: Enable strict mode for structured outputs - + Returns: List of tool definitions for OpenAI Chat Completions API. - + Example: ```python from openai import OpenAI - + client = OpenAI() async with env: response = client.chat.completions.create( @@ -78,34 +79,36 @@ def as_openai_chat_tools(self, *, strict: bool = False) -> list[dict[str, Any]]: tools = [] for t in self.as_tools(): schema = dict(t.inputSchema) if t.inputSchema else {"type": "object", "properties": {}} - + if strict: schema = ensure_strict_schema(schema) - - tools.append({ - "type": "function", - "function": { - "name": t.name, - "description": t.description or "", - "parameters": schema, - **({"strict": True} if strict else {}), - }, - }) + + tools.append( + { + "type": "function", + "function": { + "name": t.name, + "description": t.description or "", + "parameters": schema, + **({"strict": True} if strict else {}), + }, + } + ) return tools def as_openai_responses_tools(self) -> list[dict[str, Any]]: """Convert to OpenAI Responses API tool format. - + Note: Like Chat Completions, you must execute tools yourself. OpenAI only auto-executes their built-in tools (code_interpreter, etc). - + Returns: List of tool definitions for OpenAI Responses API. 
- + Example: ```python from openai import OpenAI - + client = OpenAI() async with env: response = client.responses.create( @@ -119,12 +122,15 @@ def as_openai_responses_tools(self) -> list[dict[str, Any]]: result = await env.call_tool(item.name, **json.loads(item.arguments)) ``` """ - return [{ - "type": "function", - "name": t.name, - "description": t.description or "", - "parameters": t.inputSchema or {"type": "object", "properties": {}}, - } for t in self.as_tools()] + return [ + { + "type": "function", + "name": t.name, + "description": t.description or "", + "parameters": t.inputSchema or {"type": "object", "properties": {}}, + } + for t in self.as_tools() + ] # ========================================================================= # Agents SDK Integration (requires openai-agents) @@ -132,25 +138,25 @@ def as_openai_responses_tools(self) -> list[dict[str, Any]]: def as_openai_agent_tools(self) -> list[Any]: """Convert to OpenAI Agents SDK FunctionTool objects. - + This creates FunctionTool objects that automatically execute against this environment. The Agents SDK Runner handles the tool loop. - + Note: The Agents SDK also supports other tool types: - HostedMCPTool: MCP tools hosted by OpenAI - MCPServerStdio/Sse/StreamableHttp: Direct MCP server connections - + For direct MCP integration, consider using as_mcp_server(). - + Requires: pip install openai-agents - + Returns: List of FunctionTool objects for OpenAI Agents SDK. - + Example: ```python from agents import Agent, Runner - + async with env: agent = Agent( name="browser-agent", @@ -176,20 +182,21 @@ def as_openai_agent_tools(self) -> list[Any]: def _create_function_tool(env: OpenAIMixin, tool: mcp_types.Tool) -> Any: """Create a FunctionTool that calls back to the environment.""" import asyncio - + schema = tool.inputSchema or {"type": "object", "properties": {}} - + def sync_wrapper(**kwargs: Any) -> str: """Synchronous wrapper for the tool.""" loop = asyncio.get_event_loop() if loop.is_running(): import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(asyncio.run, env.call_tool(tool.name, **kwargs)) result = future.result() else: result = loop.run_until_complete(env.call_tool(tool.name, **kwargs)) - + if isinstance(result, str): return result return json.dumps(result) if result else "" diff --git a/hud/environment/mock.py b/hud/environment/mock.py index 37711a69..f0f70541 100644 --- a/hud/environment/mock.py +++ b/hud/environment/mock.py @@ -19,21 +19,21 @@ def generate_mock_value(schema: dict[str, Any], depth: int = 0) -> Any: """Generate a reasonable mock value from a JSON schema. - + Args: schema: JSON schema dict with 'type', 'properties', etc. depth: Current recursion depth (to prevent infinite loops). - + Returns: A mock value that matches the schema.
""" if depth > 10: # Prevent infinite recursion return None - + # Handle $ref - we don't resolve refs, just return placeholder if "$ref" in schema: return {} - + # Handle anyOf/oneOf/allOf - pick first option if "anyOf" in schema: return generate_mock_value(schema["anyOf"][0], depth + 1) @@ -47,20 +47,20 @@ def generate_mock_value(schema: dict[str, Any], depth: int = 0) -> Any: if isinstance(result, dict): merged.update(result) return merged - + # Check for const or enum first if "const" in schema: return schema["const"] if "enum" in schema: return schema["enum"][0] if schema["enum"] else None - + # Check for default value if "default" in schema: return schema["default"] - + # Handle by type schema_type = schema.get("type") - + if schema_type == "string": # Check for format hints fmt = schema.get("format", "") @@ -83,7 +83,7 @@ def generate_mock_value(schema: dict[str, Any], depth: int = 0) -> Any: if "id" in title: return "mock_id" return "mock_string" - + if schema_type == "number" or schema_type == "integer": # Check for bounds minimum = schema.get("minimum", 0) @@ -91,32 +91,32 @@ def generate_mock_value(schema: dict[str, Any], depth: int = 0) -> Any: if schema_type == "integer": return int((minimum + maximum) / 2) if maximum != float("inf") else minimum return float((minimum + maximum) / 2) if maximum != float("inf") else float(minimum) - + if schema_type == "boolean": return True - + if schema_type == "null": return None - + if schema_type == "array": items_schema = schema.get("items", {}) if items_schema: # Generate one item return [generate_mock_value(items_schema, depth + 1)] return [] - + if schema_type == "object" or "properties" in schema: result: dict[str, Any] = {} properties = schema.get("properties", {}) required = set(schema.get("required", [])) - + for prop_name, prop_schema in properties.items(): # Only include required properties or first few optional ones if prop_name in required or len(result) < 3: result[prop_name] = generate_mock_value(prop_schema, depth + 1) - + return result - + # Handle list of types if isinstance(schema_type, list): # Pick first non-null type @@ -124,23 +124,23 @@ def generate_mock_value(schema: dict[str, Any], depth: int = 0) -> Any: if t != "null": return generate_mock_value({"type": t}, depth + 1) return None - + # Fallback for unknown schema return None def generate_mock_tool_result(tool: mcp_types.Tool) -> MCPToolResult: """Generate a mock result for a tool based on its output schema. - + Args: tool: MCP Tool with inputSchema and optionally outputSchema. - + Returns: MCPToolResult with mock content. """ # Check if tool has an output schema output_schema = getattr(tool, "outputSchema", None) - + if output_schema: mock_value = generate_mock_value(output_schema) content_text = str(mock_value) if mock_value is not None else "mock_result" @@ -157,7 +157,7 @@ def generate_mock_tool_result(tool: mcp_types.Tool) -> MCPToolResult: content_text = "0" else: content_text = "mock_success" - + return MCPToolResult( content=[mcp_types.TextContent(type="text", text=content_text)], isError=False, @@ -166,74 +166,74 @@ def generate_mock_tool_result(tool: mcp_types.Tool) -> MCPToolResult: class MockMixin: """Mixin that adds mock functionality to Environment. 
- + When mock mode is enabled: - All tool calls return mock values instead of executing - Specific tools can have custom mock outputs via mock_tool() - Tools are automatically mocked with reasonable defaults based on their schemas - + Usage: env = Environment("test").connect_hub("browser") env.mock() # Enable mock mode - + # Set specific mock outputs env.mock_tool("navigate", "Navigation successful") env.mock_tool("screenshot", {"image": "base64data..."}) - + async with env: result = await env.call_tool("navigate", url="https://example.com") # Returns: MCPToolResult with "Navigation successful" """ - + _mock_mode: bool _mock_outputs: dict[str, Any] _mock_tool_schemas: dict[str, mcp_types.Tool] - + def _init_mock(self) -> None: """Initialize mock state. Called from Environment.__init__.""" self._mock_mode = False self._mock_outputs = {} self._mock_tool_schemas = {} - - def mock(self) -> "Environment": + + def mock(self) -> Environment: """Enable mock mode - all tool calls will return mock values. - + Returns: self for chaining. - + Example: env = Environment("test").connect_hub("browser").mock() """ self._mock_mode = True logger.info("Mock mode enabled for environment %s", getattr(self, "name", "unknown")) return self # type: ignore[return-value] - - def unmock(self) -> "Environment": + + def unmock(self) -> Environment: """Disable mock mode - tool calls will execute normally. - + Returns: self for chaining. """ self._mock_mode = False logger.info("Mock mode disabled for environment %s", getattr(self, "name", "unknown")) return self # type: ignore[return-value] - + @property def is_mock(self) -> bool: """Check if mock mode is enabled.""" return self._mock_mode - - def mock_tool(self, name: str, output: Any) -> "Environment": + + def mock_tool(self, name: str, output: Any) -> Environment: """Set a specific mock output for a tool. - + Args: name: Tool name (with prefix if applicable). output: The value to return when this tool is called. Can be a string, dict, or any JSON-serializable value. - + Returns: self for chaining. - + Example: env.mock_tool("navigate", "Success") env.mock_tool("screenshot", {"type": "image", "data": "..."}) @@ -242,19 +242,19 @@ def mock_tool(self, name: str, output: Any) -> "Environment": self._mock_outputs[name] = output logger.debug("Mock output set for tool %s", name) return self # type: ignore[return-value] - + def _get_mock_result(self, name: str, arguments: dict[str, Any]) -> MCPToolResult: """Get mock result for a tool call. - + Priority: 1. Custom mock output set via mock_tool() 2. Auto-generated mock based on tool's output schema 3. Default mock value - + Args: name: Tool name. arguments: Tool arguments (for potential future use). - + Returns: MCPToolResult with mock content. 
""" @@ -266,20 +266,21 @@ def _get_mock_result(self, name: str, arguments: dict[str, Any]) -> MCPToolResul content_text = output else: import json + try: content_text = json.dumps(output) except (TypeError, ValueError): content_text = str(output) - + return MCPToolResult( content=[mcp_types.TextContent(type="text", text=content_text)], isError=False, ) - + # Try to find tool schema for auto-generation if name in self._mock_tool_schemas: return generate_mock_tool_result(self._mock_tool_schemas[name]) - + # Check router for tool schema router = getattr(self, "_router", None) if router: @@ -287,20 +288,19 @@ def _get_mock_result(self, name: str, arguments: dict[str, Any]) -> MCPToolResul if tool.name == name: self._mock_tool_schemas[name] = tool return generate_mock_tool_result(tool) - + # Default fallback return MCPToolResult( content=[mcp_types.TextContent(type="text", text="mock_success")], isError=False, ) - + def _populate_mock_schemas(self) -> None: """Populate mock tool schemas from router after connection. - + Called after _build_routing to cache tool schemas for mock generation. """ router = getattr(self, "_router", None) if router: for tool in router.tools: self._mock_tool_schemas[tool.name] = tool - diff --git a/hud/environment/router.py b/hud/environment/router.py index ccdb5b83..2dc88a36 100644 --- a/hud/environment/router.py +++ b/hud/environment/router.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from hud.environment.connection import Connector -__all__ = ["ConflictResolution", "ToolRouter", "LOCAL_CONNECTION"] +__all__ = ["LOCAL_CONNECTION", "ConflictResolution", "ToolRouter"] logger = logging.getLogger(__name__) @@ -21,10 +21,11 @@ class ConflictResolution(str, Enum): """Strategy for resolving tool name conflicts.""" - PREFIX = "prefix" # Add connection name as prefix + + PREFIX = "prefix" # Add connection name as prefix FIRST_WINS = "first_wins" # First connection wins - LAST_WINS = "last_wins" # Last connection wins - ERROR = "error" # Raise error on conflict + LAST_WINS = "last_wins" # Last connection wins + ERROR = "error" # Raise error on conflict @dataclass @@ -60,7 +61,7 @@ def build( connection_order: list[str], ) -> None: """Build routing from local tools and connection caches. - + Local tools always have priority over remote tools. 
""" self.clear() diff --git a/hud/environment/tests/__init__.py b/hud/environment/tests/__init__.py index 9364a7c0..6703f70b 100644 --- a/hud/environment/tests/__init__.py +++ b/hud/environment/tests/__init__.py @@ -1,2 +1 @@ """Tests for hud.environment module.""" - diff --git a/hud/environment/tests/test_connection.py b/hud/environment/tests/test_connection.py index cc6fdba4..4ce44fa2 100644 --- a/hud/environment/tests/test_connection.py +++ b/hud/environment/tests/test_connection.py @@ -55,7 +55,7 @@ def test_init_stores_transport_config(self) -> None: """__init__ stores transport config, doesn't create client.""" transport = {"server": {"url": "http://example.com"}} config = ConnectionConfig() - + connector = Connector( transport=transport, config=config, @@ -63,7 +63,7 @@ def test_init_stores_transport_config(self) -> None: connection_type=ConnectionType.REMOTE, auth="test-token", ) - + assert connector._transport == transport assert connector._auth == "test-token" assert connector.name == "test" @@ -124,17 +124,15 @@ async def test_connect_creates_client(self) -> None: connection_type=ConnectionType.REMOTE, auth="test-token", ) - + mock_client = MagicMock() mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.is_connected = MagicMock(return_value=True) - + # Patch where it's imported from, not where it's used - with patch( - "fastmcp.client.Client", return_value=mock_client - ) as mock_cls: + with patch("fastmcp.client.Client", return_value=mock_client) as mock_cls: await connector.connect() - + # Client was created with correct args mock_cls.assert_called_once_with(transport=transport, auth="test-token") # Client context was entered @@ -151,15 +149,15 @@ async def test_disconnect_clears_client(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + mock_client = MagicMock() mock_client.__aexit__ = AsyncMock(return_value=None) mock_client.is_connected = MagicMock(return_value=True) connector.client = mock_client connector._tools_cache = [MagicMock()] - + await connector.disconnect() - + mock_client.__aexit__.assert_called_once_with(None, None, None) assert connector.client is None assert connector._tools_cache is None @@ -173,7 +171,7 @@ async def test_list_tools_raises_when_not_connected(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + with pytest.raises(RuntimeError, match="Not connected"): await connector.list_tools() @@ -186,16 +184,18 @@ async def test_list_tools_applies_include_filter(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + mock_client = MagicMock() - mock_client.list_tools = AsyncMock(return_value=[ - mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), - mcp_types.Tool(name="tool2", description="Tool 2", inputSchema={}), - ]) + mock_client.list_tools = AsyncMock( + return_value=[ + mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + mcp_types.Tool(name="tool2", description="Tool 2", inputSchema={}), + ] + ) connector.client = mock_client - + tools = await connector.list_tools() - + assert len(tools) == 1 assert tools[0].name == "tool1" @@ -208,16 +208,18 @@ async def test_list_tools_applies_exclude_filter(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + mock_client = MagicMock() - mock_client.list_tools = AsyncMock(return_value=[ - mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), - mcp_types.Tool(name="tool2", description="Tool 2", inputSchema={}), - ]) + mock_client.list_tools = AsyncMock( + return_value=[ + 
mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + mcp_types.Tool(name="tool2", description="Tool 2", inputSchema={}), + ] + ) connector.client = mock_client - + tools = await connector.list_tools() - + assert len(tools) == 1 assert tools[0].name == "tool1" @@ -230,15 +232,17 @@ async def test_list_tools_applies_prefix(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + mock_client = MagicMock() - mock_client.list_tools = AsyncMock(return_value=[ - mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), - ]) + mock_client.list_tools = AsyncMock( + return_value=[ + mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + ] + ) connector.client = mock_client - + tools = await connector.list_tools() - + assert len(tools) == 1 assert tools[0].name == "myprefix_tool1" @@ -251,15 +255,17 @@ async def test_list_tools_caches_results(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + mock_client = MagicMock() - mock_client.list_tools = AsyncMock(return_value=[ - mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), - ]) + mock_client.list_tools = AsyncMock( + return_value=[ + mcp_types.Tool(name="tool1", description="Tool 1", inputSchema={}), + ] + ) connector.client = mock_client - + tools = await connector.list_tools() - + assert connector._tools_cache == tools assert connector.cached_tools == tools @@ -272,14 +278,14 @@ async def test_call_tool_strips_prefix(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + mock_result = mcp_types.CallToolResult(content=[], isError=False) mock_client = MagicMock() mock_client.call_tool_mcp = AsyncMock(return_value=mock_result) connector.client = mock_client - + await connector.call_tool("myprefix_tool1", {"arg": "value"}) - + # Prefix should be stripped mock_client.call_tool_mcp.assert_called_once_with("tool1", {"arg": "value"}) @@ -292,7 +298,7 @@ async def test_call_tool_raises_when_not_connected(self) -> None: name="test", connection_type=ConnectionType.REMOTE, ) - + with pytest.raises(RuntimeError, match="Not connected"): await connector.call_tool("tool1", {}) @@ -304,9 +310,8 @@ def test_repr(self) -> None: name="my-server", connection_type=ConnectionType.REMOTE, ) - + repr_str = repr(connector) assert "my-server" in repr_str assert "remote" in repr_str assert "connected=False" in repr_str - diff --git a/hud/environment/tests/test_connectors.py b/hud/environment/tests/test_connectors.py index d0e05883..03c13796 100644 --- a/hud/environment/tests/test_connectors.py +++ b/hud/environment/tests/test_connectors.py @@ -14,14 +14,14 @@ class TestBaseConnectorMixin: def test_add_connection_stores_transport_config(self) -> None: """_add_connection stores transport, doesn't create client.""" from hud.environment.connectors.base import BaseConnectorMixin - + class TestEnv(BaseConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + env = TestEnv() transport = {"server": {"url": "http://example.com"}} - + env._add_connection( "test-server", transport, @@ -29,7 +29,7 @@ def __init__(self) -> None: auth="test-token", prefix="myprefix", ) - + assert "test-server" in env._connections conn = env._connections["test-server"] assert conn._transport == transport @@ -40,18 +40,18 @@ def __init__(self) -> None: def test_add_connection_returns_self(self) -> None: """_add_connection returns self for chaining.""" from hud.environment.connectors.base import BaseConnectorMixin - + class TestEnv(BaseConnectorMixin): def 
__init__(self) -> None: self._connections: dict[str, Connector] = {} - + env = TestEnv() result = env._add_connection( "test", {}, connection_type=ConnectionType.REMOTE, ) - + assert result is env @@ -61,11 +61,11 @@ class TestMCPConfigConnectorMixin: def test_connect_mcp_detects_local_connection(self) -> None: """connect_mcp detects LOCAL type from command in config.""" from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin - + class TestEnv(MCPConfigConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + env = TestEnv() config = { "filesystem": { @@ -73,65 +73,65 @@ def __init__(self) -> None: "args": ["-y", "@modelcontextprotocol/server-filesystem"], } } - + env.connect_mcp(config) - + conn = env._connections["filesystem"] assert conn.connection_type == ConnectionType.LOCAL def test_connect_mcp_detects_remote_connection(self) -> None: """connect_mcp detects REMOTE type from URL in config.""" from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin - + class TestEnv(MCPConfigConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + env = TestEnv() config = { "browser": { "url": "https://mcp.hud.ai/browser", } } - + env.connect_mcp(config) - + conn = env._connections["browser"] assert conn.connection_type == ConnectionType.REMOTE def test_connect_mcp_uses_alias(self) -> None: """connect_mcp uses alias if provided.""" from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin - + class TestEnv(MCPConfigConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + env = TestEnv() config = {"server": {"url": "http://example.com"}} - + env.connect_mcp(config, alias="my-alias") - + assert "my-alias" in env._connections assert "server" not in env._connections def test_connect_mcp_config_creates_multiple_connections(self) -> None: """connect_mcp_config creates a connection for each server.""" from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin - + class TestEnv(MCPConfigConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + env = TestEnv() mcp_config = { "server1": {"url": "http://example1.com"}, "server2": {"url": "http://example2.com"}, "server3": {"command": "npx", "args": ["server"]}, } - + env.connect_mcp_config(mcp_config) - + assert len(env._connections) == 3 assert "server1" in env._connections assert "server2" in env._connections @@ -144,17 +144,17 @@ class TestRemoteConnectorMixin: def test_connect_url_creates_remote_connection(self) -> None: """connect_url creates REMOTE connection.""" from hud.environment.connectors.remote import RemoteConnectorMixin - + class TestEnv(RemoteConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + def mount(self, server: Any, *, prefix: str | None = None) -> None: pass - + env = TestEnv() env.connect_url("https://mcp.example.com", alias="example") - + assert "example" in env._connections conn = env._connections["example"] assert conn.connection_type == ConnectionType.REMOTE @@ -162,21 +162,21 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: def test_connect_url_extracts_auth_from_headers(self) -> None: """connect_url extracts Authorization from headers.""" from hud.environment.connectors.remote import RemoteConnectorMixin - + class TestEnv(RemoteConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + def mount(self, server: Any, *, prefix: 
str | None = None) -> None: pass - + env = TestEnv() env.connect_url( "https://mcp.example.com", headers={"Authorization": "Bearer my-token"}, alias="example", ) - + conn = env._connections["example"] assert conn._auth == "Bearer my-token" @@ -184,14 +184,14 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: def test_connect_hub_fetches_config(self, mock_httpx_cls: MagicMock) -> None: """connect_hub fetches mcp_config from API.""" from hud.environment.connectors.remote import RemoteConnectorMixin - + class TestEnv(RemoteConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - + def mount(self, server: Any, *, prefix: str | None = None) -> None: pass - + # Mock httpx response mock_response = MagicMock() mock_response.json.return_value = { @@ -200,20 +200,20 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: } } mock_response.raise_for_status = MagicMock() - + mock_client = MagicMock() mock_client.get.return_value = mock_response mock_client.__enter__ = MagicMock(return_value=mock_client) mock_client.__exit__ = MagicMock(return_value=None) mock_httpx_cls.return_value = mock_client - + env = TestEnv() with patch("hud.settings.settings") as mock_settings: mock_settings.hud_api_url = "https://api.hud.so" mock_settings.api_key = "test-key" - + env.connect_hub("hud/browser") - + assert "browser" in env._connections @@ -224,21 +224,21 @@ class TestTaskConnectorMixin: def test_connect_task_fetches_and_applies_config(self, mock_httpx_cls: MagicMock) -> None: """connect_task fetches task and applies mcp_config.""" from hud.environment.connectors.task import TaskConnectorMixin - + class TestEnv(TaskConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} self._setup_calls: list[tuple[str, dict[str, Any]]] = [] self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] - + def setup_tool(self, call: Any, /, **kwargs: Any) -> Any: self._setup_calls.append((call, kwargs)) return self - + def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: self._evaluate_calls.append((call, kwargs)) return self - + # Mock httpx response with task data mock_response = MagicMock() mock_response.json.return_value = { @@ -251,19 +251,18 @@ def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: "evaluate_tool": None, } mock_response.raise_for_status = MagicMock() - + mock_client = MagicMock() mock_client.get.return_value = mock_response mock_client.__enter__ = MagicMock(return_value=mock_client) mock_client.__exit__ = MagicMock(return_value=None) mock_httpx_cls.return_value = mock_client - + env = TestEnv() with patch("hud.settings.settings") as mock_settings: mock_settings.hud_api_url = "https://api.hud.so" mock_settings.api_key = "test-key" - + env.connect_task("my-org/my-task") - - assert "browser" in env._connections + assert "browser" in env._connections diff --git a/hud/environment/tests/test_environment.py b/hud/environment/tests/test_environment.py index b23bddf8..f3eeff4e 100644 --- a/hud/environment/tests/test_environment.py +++ b/hud/environment/tests/test_environment.py @@ -3,9 +3,7 @@ from __future__ import annotations from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch -import mcp.types as mcp_types import pytest @@ -29,8 +27,8 @@ def test_prompt_can_be_set(self) -> None: def test_prompt_set_from_task(self) -> None: """connect_task sets prompt from task.prompt.""" - from hud.environment.connectors.task import TaskConnectorMixin from hud.environment.connection import 
Connector + from hud.environment.connectors.task import TaskConnectorMixin from hud.types import Task class TestEnv(TaskConnectorMixin): @@ -63,12 +61,12 @@ async def test_context_manager_sets_in_context_flag(self) -> None: from hud.environment import Environment env = Environment("test") - + assert env._in_context is False - + async with env: assert env._in_context is True - + assert env._in_context is False @pytest.mark.asyncio @@ -77,7 +75,7 @@ async def test_context_manager_no_connections(self) -> None: from hud.environment import Environment env = Environment("test") - + async with env: # Should work without connections pass @@ -92,10 +90,10 @@ async def test_list_resources_empty(self) -> None: from hud.environment import Environment env = Environment("test") - + async with env: resources = await env.list_resources() - + assert resources == [] @pytest.mark.asyncio @@ -104,7 +102,7 @@ async def test_read_resource_not_found(self) -> None: from hud.environment import Environment env = Environment("test") - + async with env: with pytest.raises(ValueError, match="Resource not found"): await env.read_resource("file://nonexistent.txt") @@ -119,10 +117,10 @@ async def test_list_prompts_empty(self) -> None: from hud.environment import Environment env = Environment("test") - + async with env: prompts = await env.list_prompts() - + assert prompts == [] @pytest.mark.asyncio @@ -131,7 +129,7 @@ async def test_get_prompt_not_found(self) -> None: from hud.environment import Environment env = Environment("test") - + async with env: with pytest.raises(ValueError, match="Prompt not found"): await env.get_prompt("nonexistent") @@ -189,4 +187,3 @@ def test_chaining_multiple_setup_calls(self) -> None: ) assert len(env._setup_calls) == 2 - diff --git a/hud/environment/tests/test_integrations.py b/hud/environment/tests/test_integrations.py index 30643427..713d0568 100644 --- a/hud/environment/tests/test_integrations.py +++ b/hud/environment/tests/test_integrations.py @@ -3,13 +3,13 @@ from __future__ import annotations from typing import Any -from unittest.mock import MagicMock import mcp.types as mcp_types -import pytest -def create_mock_tool(name: str, description: str = "", schema: dict | None = None) -> mcp_types.Tool: +def create_mock_tool( + name: str, description: str = "", schema: dict | None = None +) -> mcp_types.Tool: """Create a mock MCP tool for testing.""" return mcp_types.Tool( name=name, @@ -28,11 +28,15 @@ def test_as_openai_chat_tools_basic(self) -> None: class TestEnv(OpenAIMixin): def as_tools(self) -> list[mcp_types.Tool]: return [ - create_mock_tool("navigate", "Navigate to URL", { - "type": "object", - "properties": {"url": {"type": "string"}}, - "required": ["url"], - }), + create_mock_tool( + "navigate", + "Navigate to URL", + { + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], + }, + ), ] async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: @@ -109,10 +113,14 @@ def test_as_claude_tools_basic(self) -> None: class TestEnv(AnthropicMixin): def as_tools(self) -> list[mcp_types.Tool]: return [ - create_mock_tool("click", "Click element", { - "type": "object", - "properties": {"selector": {"type": "string"}}, - }), + create_mock_tool( + "click", + "Click element", + { + "type": "object", + "properties": {"selector": {"type": "string"}}, + }, + ), ] async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: @@ -170,10 +178,14 @@ def test_as_gemini_tools_basic(self) -> None: class TestEnv(GeminiMixin): def 
as_tools(self) -> list[mcp_types.Tool]: return [ - create_mock_tool("search", "Search query", { - "type": "object", - "properties": {"query": {"type": "string"}}, - }), + create_mock_tool( + "search", + "Search query", + { + "type": "object", + "properties": {"query": {"type": "string"}}, + }, + ), ] env = TestEnv() @@ -243,4 +255,3 @@ def as_tools(self) -> list[mcp_types.Tool]: config = env.as_gemini_tool_config(mode="NONE") assert config["function_calling_config"]["mode"] == "NONE" - diff --git a/hud/environment/tests/test_local_connectors.py b/hud/environment/tests/test_local_connectors.py index 488fd8fc..018d68cb 100644 --- a/hud/environment/tests/test_local_connectors.py +++ b/hud/environment/tests/test_local_connectors.py @@ -5,8 +5,6 @@ from typing import Any from unittest.mock import MagicMock, patch -import pytest - from hud.environment.connection import ConnectionType, Connector @@ -201,4 +199,3 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: result = env.connect_fastapi(MagicMock()) assert result is env - diff --git a/hud/environment/types.py b/hud/environment/types.py index e911ffe8..8e8fcd97 100644 --- a/hud/environment/types.py +++ b/hud/environment/types.py @@ -11,7 +11,7 @@ class HubConfig(BaseModel): """Configuration for a single hub connection.""" - + slug: str alias: str | None = None prefix: str | None = None @@ -21,9 +21,8 @@ class HubConfig(BaseModel): class EnvConfig(BaseModel): """Environment configuration for trace reproducibility.""" - + name: str hubs: list[HubConfig] = [] setup_tools: list[MCPToolCall] = [] evaluate_tools: list[MCPToolCall] = [] - diff --git a/hud/environment/utils/formats.py b/hud/environment/utils/formats.py index 2a9e3b61..b299dadd 100644 --- a/hud/environment/utils/formats.py +++ b/hud/environment/utils/formats.py @@ -19,16 +19,18 @@ class ToolFormat(Enum): """Detected tool call format.""" - OPENAI = auto() # function.arguments as JSON string - CLAUDE = auto() # type="tool_use", input as dict - GEMINI = auto() # functionCall with args - MCP = auto() # name + arguments + + OPENAI = auto() # function.arguments as JSON string + CLAUDE = auto() # type="tool_use", input as dict + GEMINI = auto() # functionCall with args + MCP = auto() # name + arguments # ----------------------------------------------------------------------------- # Parsing # ----------------------------------------------------------------------------- + def _to_dict(obj: Any) -> dict[str, Any]: """Convert object to dict for uniform processing.""" if isinstance(obj, dict): @@ -54,7 +56,7 @@ def _parse_json_args(args: Any) -> dict[str, Any]: def parse_tool_call(call: Any, **kwargs: Any) -> tuple[MCPToolCall, ToolFormat]: """Parse any tool call format into (MCPToolCall, ToolFormat). - + Supports: - String (tool name only, or with kwargs) - Tuple: (name,), (name, args), (name, args, id) @@ -63,33 +65,33 @@ def parse_tool_call(call: Any, **kwargs: Any) -> tuple[MCPToolCall, ToolFormat]: - Claude: {type: "tool_use", name, input, id} - Gemini: {functionCall: {name, args}} or {name, args} - Generic: {name, arguments} - + Args: call: Tool call in any supported format. **kwargs: Additional arguments (merged when call is a string). - + Returns: Tuple of (MCPToolCall, ToolFormat) for the parsed call. - + Raises: ValueError: If format is unrecognized. 
""" # Primitives if isinstance(call, str): return MCPToolCall(name=call, arguments=kwargs or {}), ToolFormat.MCP - + if isinstance(call, tuple): tc = MCPToolCall(name=call[0], arguments=call[1] if len(call) > 1 else {}) if len(call) > 2: tc.id = call[2] return tc, ToolFormat.MCP - + if isinstance(call, MCPToolCall): return call, ToolFormat.MCP - + # Convert to dict d = _to_dict(call) - + # OpenAI: {function: {name, arguments}, id} if "function" in d: f = _to_dict(d["function"]) if not isinstance(d["function"], dict) else d["function"] @@ -97,29 +99,29 @@ def parse_tool_call(call: Any, **kwargs: Any) -> tuple[MCPToolCall, ToolFormat]: if d.get("id"): tc.id = d["id"] return tc, ToolFormat.OPENAI - + # Claude: {type: "tool_use", name, input, id} if d.get("type") == "tool_use": tc = MCPToolCall(name=d["name"], arguments=d.get("input") or {}) if d.get("id"): tc.id = d["id"] return tc, ToolFormat.CLAUDE - + # Gemini: {functionCall: {name, args}} or {name, args} if "functionCall" in d: fc = d["functionCall"] return MCPToolCall(name=fc["name"], arguments=fc.get("args") or {}), ToolFormat.GEMINI - + if "args" in d and "name" in d and "arguments" not in d: return MCPToolCall(name=d["name"], arguments=d.get("args") or {}), ToolFormat.GEMINI - + # Generic: {name, arguments/input} if "name" in d: tc = MCPToolCall(name=d["name"], arguments=d.get("arguments") or d.get("input") or {}) if d.get("id"): tc.id = d["id"] return tc, ToolFormat.MCP - + raise ValueError(f"Unrecognized tool call format: {list(d.keys())}") @@ -131,10 +133,10 @@ def _is_tool_block(item: Any) -> bool: def parse_tool_calls(calls: Any) -> list[tuple[MCPToolCall, ToolFormat]]: """Parse multiple tool calls, filtering non-tool content (e.g. Claude TextBlock). - + Args: calls: Single call or list of calls in any format. - + Returns: List of (MCPToolCall, ToolFormat) tuples. """ @@ -145,7 +147,7 @@ def parse_tool_calls(calls: Any) -> list[tuple[MCPToolCall, ToolFormat]]: return [parse_tool_call(calls)] except ValueError: return [] - + results = [] for item in calls: if not _is_tool_block(item): @@ -161,12 +163,13 @@ def parse_tool_calls(calls: Any) -> list[tuple[MCPToolCall, ToolFormat]]: # Result Formatting # ----------------------------------------------------------------------------- + def result_to_string(result: MCPToolResult) -> str: """Convert MCPToolResult content to string. - + Args: result: MCP tool result with content blocks. - + Returns: String representation of the result content. """ @@ -183,12 +186,12 @@ def result_to_string(result: MCPToolResult) -> str: def format_result(result: MCPToolResult, tc: MCPToolCall, fmt: ToolFormat) -> Any: """Format MCPToolResult based on the input format. - + Args: result: MCP tool result. tc: Original tool call (for id/name). fmt: Target format. 
- + Returns: OpenAI: {"role": "tool", "tool_call_id": ..., "content": ...} Claude: {"type": "tool_result", "tool_use_id": ..., "content": ..., "is_error"?: bool} @@ -196,18 +199,17 @@ def format_result(result: MCPToolResult, tc: MCPToolCall, fmt: ToolFormat) -> An MCP: MCPToolResult unchanged """ content = result_to_string(result) - + if fmt == ToolFormat.OPENAI: return {"role": "tool", "tool_call_id": tc.id, "content": content} - + if fmt == ToolFormat.CLAUDE: r: dict[str, Any] = {"type": "tool_result", "tool_use_id": tc.id, "content": content} if result.isError: r["is_error"] = True return r - + if fmt == ToolFormat.GEMINI: return {"functionResponse": {"name": tc.name, "response": {"result": content}}} - - return result # MCP format - return as-is + return result # MCP format - return as-is diff --git a/hud/environment/utils/schema.py b/hud/environment/utils/schema.py index 6e0c2029..346ff2ce 100644 --- a/hud/environment/utils/schema.py +++ b/hud/environment/utils/schema.py @@ -4,36 +4,36 @@ from typing import Any -__all__ = ["ensure_strict_schema", "schema_to_pydantic", "json_type_to_python"] +__all__ = ["ensure_strict_schema", "json_type_to_python", "schema_to_pydantic"] def ensure_strict_schema(schema: dict[str, Any]) -> dict[str, Any]: """Ensure a JSON schema is compatible with OpenAI's strict mode. - + OpenAI strict mode requires: - additionalProperties: false on all objects - All properties must be in required - + Args: schema: Original JSON schema. - + Returns: Modified schema for strict mode. """ schema = dict(schema) - + if schema.get("type") == "object": schema["additionalProperties"] = False - + if "properties" in schema: # All properties must be required schema["required"] = list(schema["properties"].keys()) - + # Recursively process nested objects for prop_schema in schema["properties"].values(): if isinstance(prop_schema, dict): _ensure_strict_recursive(prop_schema) - + return schema @@ -46,7 +46,7 @@ def _ensure_strict_recursive(schema: dict[str, Any]) -> None: for prop_schema in schema["properties"].values(): if isinstance(prop_schema, dict): _ensure_strict_recursive(prop_schema) - + elif schema.get("type") == "array" and "items" in schema: if isinstance(schema["items"], dict): _ensure_strict_recursive(schema["items"]) @@ -54,35 +54,35 @@ def _ensure_strict_recursive(schema: dict[str, Any]) -> None: def schema_to_pydantic(name: str, schema: dict[str, Any]) -> type: """Convert JSON schema to a Pydantic model. - + Args: name: Model name (used for class name). schema: JSON schema with properties. - + Returns: Dynamically created Pydantic model class. """ from pydantic import Field, create_model - + properties = schema.get("properties", {}) required = set(schema.get("required", [])) - + fields = {} for prop_name, prop_schema in properties.items(): prop_type = json_type_to_python(prop_schema.get("type", "string")) default = ... if prop_name in required else None description = prop_schema.get("description", "") fields[prop_name] = (prop_type, Field(default=default, description=description)) - + return create_model(f"{name}Input", **fields) def json_type_to_python(json_type: str) -> type: """Map JSON schema type to Python type. - + Args: json_type: JSON schema type string. - + Returns: Corresponding Python type. 
""" diff --git a/hud/samples/browser.py b/hud/samples/browser.py index 17de5d99..f6268dad 100644 --- a/hud/samples/browser.py +++ b/hud/samples/browser.py @@ -17,7 +17,7 @@ class BrowserTask(Task): mcp_config: dict[str, Any] = Field( default_factory=lambda: { "browser": { - "url": "https://mcp.hud.ai/v3/mcp", + "url": settings.hud_mcp_url, "headers": { "Authorization": f"Bearer {settings.api_key}", "Mcp-Image": "hudevals/hud-remote-browser:0.1.1", diff --git a/hud/trace/__init__.py b/hud/trace/__init__.py index 60afb5d4..9022c021 100644 --- a/hud/trace/__init__.py +++ b/hud/trace/__init__.py @@ -12,7 +12,7 @@ async with env.trace("google-search") as tc: await tc.call_tool("navigate", {"url": "..."}) tc.reward = 0.9 - + # tc has the results print(tc.trace_id, tc.reward, tc.duration, tc.success) ``` @@ -23,7 +23,7 @@ # This body runs 4 times, each with a different tc! await tc.call_tool("navigate", {"url": "..."}) tc.reward = evaluate() - + # tc.results contains all parallel traces # tc.reward is the mean reward print(f"Mean reward: {tc.reward}") diff --git a/hud/trace/context.py b/hud/trace/context.py index 668daad6..5ca5975a 100644 --- a/hud/trace/context.py +++ b/hud/trace/context.py @@ -54,7 +54,7 @@ def get_current_trace_headers() -> dict[str, str] | None: class TracePayload(BaseModel): """Base payload for trace enter/exit - sent to both endpoints.""" - + task_name: str prompt: str | None = None code_snippet: str | None = None @@ -67,7 +67,7 @@ class TracePayload(BaseModel): class TraceExitPayload(TracePayload): """Exit payload - includes result fields.""" - + reward: float | None = None success: bool = True error_message: str | None = None @@ -77,24 +77,25 @@ class TraceExitPayload(TracePayload): # Auto-instrumentation for httpx # ============================================================================= + def _is_hud_url(url_str: str) -> bool: """Check if URL is a HUD service (inference or MCP).""" from urllib.parse import urlparse - + # Extract hostnames from settings URLs gateway_host = urlparse(settings.hud_gateway_url).netloc mcp_host = urlparse(settings.hud_mcp_url).netloc - + # Parse the request URL and check against known HUD hosts parsed = urlparse(url_str) request_host = parsed.netloc or url_str.split("/")[0] - + return request_host == gateway_host or request_host == mcp_host def _httpx_request_hook(request: Any) -> None: """httpx event hook that adds trace headers and auth to HUD requests. 
- + For inference.hud.ai and mcp.hud.ai: - Injects trace headers (Trace-Id) if in trace context - Injects Authorization header if API key is set and no auth present @@ -102,14 +103,14 @@ def _httpx_request_hook(request: Any) -> None: url_str = str(request.url) if not _is_hud_url(url_str): return - + # Inject trace headers if in trace context headers = get_current_trace_headers() if headers is not None: for key, value in headers.items(): request.headers[key] = value logger.debug("Added trace headers to request: %s", url_str) - + # Auto-inject API key if not present has_auth = "authorization" in {k.lower() for k in request.headers} if not has_auth and settings.api_key: @@ -126,7 +127,7 @@ def _instrument_client(client: Any) -> None: """Add trace hook to an httpx client instance.""" is_async = hasattr(client, "aclose") hook = _async_httpx_request_hook if is_async else _httpx_request_hook - + existing_hooks = client.event_hooks.get("request", []) if hook not in existing_hooks: existing_hooks.append(hook) @@ -140,23 +141,23 @@ def _patch_httpx() -> None: except ImportError: logger.debug("httpx not installed, skipping auto-instrumentation") return - + _original_async_init = httpx.AsyncClient.__init__ - + def _patched_async_init(self: Any, *args: Any, **kwargs: Any) -> None: _original_async_init(self, *args, **kwargs) _instrument_client(self) - + httpx.AsyncClient.__init__ = _patched_async_init # type: ignore[method-assign] - + _original_sync_init = httpx.Client.__init__ - + def _patched_sync_init(self: Any, *args: Any, **kwargs: Any) -> None: _original_sync_init(self, *args, **kwargs) _instrument_client(self) - + httpx.Client.__init__ = _patched_sync_init # type: ignore[method-assign] - + logger.debug("httpx auto-instrumentation enabled") @@ -168,9 +169,10 @@ def _patched_sync_init(self: Any, *args: Any, **kwargs: Any) -> None: # TraceContext # ============================================================================= + class TraceContext: """Lightweight context for a traced execution. 
- + Attributes: trace_id: Unique identifier for this trace name: Task name @@ -181,36 +183,37 @@ class TraceContext: prompt: Task prompt (defaults from env.prompt, user-settable) error: Exception if failed results: All trace results (for parent trace) - + Computed: headers: Gateway headers duration: Execution time in seconds success: True if no error done: True if completed - + Example: ```python # Simple trace async with env.trace("task") as tc: await tc.call_tool("navigate", {"url": "..."}) tc.reward = 0.9 - + # With variants (A/B testing) and group (multiple runs) - async with env.trace("task", + async with env.trace( + "task", variants={"model": ["gpt-4o", "claude"]}, group=3, ) as tc: model = tc.variants["model"] # Assigned for this run response = await call_llm(model=model) tc.reward = evaluate(response) - + # tc.results has 6 traces (2 variants x 3 runs each) # All share the same tc.group_id for t in tc.results: print(f"{t.variants}: reward={t.reward}") ``` """ - + def __init__( self, env: Environment, @@ -228,50 +231,50 @@ def __init__( # Identity self.trace_id: str = trace_id or str(uuid.uuid4()) self.name: str = name - + # Job linkage - auto-detect from current job context if not provided if job_id is None: current_job = get_current_job() self.job_id: str | None = current_job.id if current_job else None else: self.job_id = job_id - + self.group_id: str | None = _group_id # Links parallel traces together self.index: int = _index # Local only, for debugging - + # Variant assignment (for A/B testing) self.variants: dict[str, Any] = _variants or {} - + # User-settable self.reward: float | None = None self.prompt: str | None = getattr(env, "prompt", None) # From env, can override - + # Error tracking self.error: BaseException | None = None - + # Parallel/variant results (nested) self.results: list[TraceContext] | None = None - + # Code and config (for reproducibility) self.code_snippet: str | None = _code_snippet self.env_config: dict[str, Any] | None = _env_config - + # Private self._env = env self._api_key = api_key self._started_at: datetime | None = None self._completed_at: datetime | None = None self._token: contextvars.Token[dict[str, str] | None] | None = None - + # ========================================================================= # Computed Properties # ========================================================================= - + @property def headers(self) -> dict[str, str]: """Headers for gateway integration.""" return {"Trace-Id": self.trace_id} - + @property def duration(self) -> float: """Execution duration in seconds.""" @@ -279,30 +282,30 @@ def duration(self) -> float: return 0.0 end = self._completed_at or datetime.now(UTC) return (end - self._started_at).total_seconds() - + @property def success(self) -> bool: """True if no error occurred.""" return self.error is None - + @property def done(self) -> bool: """True if execution completed.""" return self._completed_at is not None - + def _get_api_key(self) -> str | None: return self._api_key or settings.api_key - + def _build_base_payload(self) -> TracePayload: """Build the base payload for enter/exit.""" # Check if all connectors are from hubs (fully reproducible) all_hubs = getattr(self._env, "_all_hubs", False) - + # Convert env_config dict to EnvConfig model env_config_model: EnvConfig | None = None if self.env_config: env_config_model = EnvConfig(**self.env_config) - + return TracePayload( task_name=self.name, prompt=self.prompt, @@ -313,11 +316,11 @@ def _build_base_payload(self) -> TracePayload: 
group_id=self.group_id, variants=self.variants if self.variants else None, ) - + # ========================================================================= # Tool Operations # ========================================================================= - + async def call_tool( self, name: str, @@ -325,17 +328,17 @@ async def call_tool( ) -> MCPToolResult: """Call a tool by name (delegates to environment).""" return await self._env.call_tool(name, arguments) # type: ignore[attr-defined] - + # ========================================================================= # Backend Integration # ========================================================================= - + async def log(self, metrics: dict[str, Any]) -> None: """Log metrics to the backend.""" api_key = self._get_api_key() if not settings.telemetry_enabled or not api_key: return - + try: await make_request( method="POST", @@ -345,13 +348,13 @@ async def log(self, metrics: dict[str, Any]) -> None: ) except Exception as e: logger.warning("Failed to log metrics: %s", e) - + async def _trace_enter(self) -> None: """Notify backend that trace has started.""" api_key = self._get_api_key() if not settings.telemetry_enabled or not api_key: return - + try: payload = self._build_base_payload() await make_request( @@ -362,18 +365,18 @@ async def _trace_enter(self) -> None: ) except Exception as e: logger.warning("Failed to send trace enter: %s", e) - + async def _trace_exit(self, error_message: str | None = None) -> None: """Notify backend that trace has completed.""" api_key = self._get_api_key() if not settings.telemetry_enabled or not api_key: return - + # Use evaluate tool reward if not manually set reward = self.reward if reward is None: reward = getattr(self._env, "_evaluate_reward", None) - + try: payload = TraceExitPayload( **self._build_base_payload().model_dump(), @@ -389,17 +392,17 @@ async def _trace_exit(self, error_message: str | None = None) -> None: ) except Exception as e: logger.warning("Failed to send trace exit: %s", e) - + # ========================================================================= # Context Manager # ========================================================================= - + async def __aenter__(self) -> Self: self._started_at = datetime.now(UTC) self._token = _current_trace_headers.set(self.headers) await self._trace_enter() return self - + async def __aexit__( self, exc_type: type[BaseException] | None, @@ -407,18 +410,18 @@ async def __aexit__( exc_tb: TracebackType | None, ) -> None: self._completed_at = datetime.now(UTC) - + if self._token is not None: _current_trace_headers.reset(self._token) self._token = None - + error_msg: str | None = None if exc_type is not None: self.error = exc_val error_msg = str(exc_val) if exc_val else "Unknown error" - + # Send exit with all data (reward, error, etc.) await self._trace_exit(error_msg) - + def __repr__(self) -> str: return f"TraceContext({self.trace_id[:8]}..., name={self.name!r}, reward={self.reward})" diff --git a/hud/trace/mixin.py b/hud/trace/mixin.py index 977f58e7..ac514e16 100644 --- a/hud/trace/mixin.py +++ b/hud/trace/mixin.py @@ -33,15 +33,15 @@ def _expand_variants( variants: dict[str, Any] | None, ) -> list[dict[str, Any]]: """Expand variants dict into all combinations. - + Args: variants: Dict where values can be: - Single value: {"model": "gpt-4o"} → fixed - List: {"model": ["gpt-4o", "claude"]} → expand - + Returns: List of variant assignments, one per combination. 
- + Examples: >>> _expand_variants(None) [{}] @@ -55,7 +55,7 @@ def _expand_variants( """ if not variants: return [{}] - + # Normalize: single values become single-element lists expanded: dict[str, list[Any]] = {} for key, value in variants.items(): @@ -63,76 +63,74 @@ def _expand_variants( expanded[key] = value else: expanded[key] = [value] - + # Generate all combinations keys = list(expanded.keys()) value_lists = [expanded[k] for k in keys] - - return [ - dict(zip(keys, combo, strict=True)) - for combo in itertools.product(*value_lists) - ] + + return [dict(zip(keys, combo, strict=True)) for combo in itertools.product(*value_lists)] class TraceMixin: """Mixin that adds trace capabilities to Environment. - + This mixin provides: - trace(): Create a TraceContext for recording agent runs - Parallel execution with group=N parameter - A/B testing with variants parameter - + Example: ```python - class Environment(TraceMixin, MCPServer): - ... - + class Environment(TraceMixin, MCPServer): ... + + env = Environment("my-env") - + # Single trace async with env.trace("task") as tc: await tc.call_tool("navigate", {"url": "..."}) tc.reward = 0.9 - + # Parallel traces (runs 4 times) async with env.trace("task", group=4) as tc: await tc.call_tool("navigate", {"url": "..."}) tc.reward = 0.9 - + # A/B testing (2 variants x 3 runs = 6 traces) - async with env.trace("task", + async with env.trace( + "task", variants={"model": ["gpt-4o", "claude"]}, group=3, ) as tc: model = tc.variants["model"] response = await call_llm(model=model) tc.reward = evaluate(response) - + # Access results for t in tc.results: print(f"{t.variants} run {t.index}: reward={t.reward}") ``` """ - + # These will be provided by the Environment class name: str - + # Store last parallel results (list of completed TraceContext objects) _last_traces: list[TraceContext] | None = None - + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> MCPToolResult: """Placeholder - implemented by Environment.""" raise NotImplementedError - + def _capture_code_snippet(self) -> str | None: """Capture the code inside the trace() with-block (best effort). - + Returns None if source cannot be extracted (e.g., REPL, Jupyter). """ frame = inspect.currentframe() if frame is None: return None - + try: # Go up: _capture_code_snippet -> trace -> user code caller = frame.f_back @@ -140,7 +138,7 @@ def _capture_code_snippet(self) -> str | None: caller = caller.f_back if caller is None: return None - + body_source, _ = _get_with_block_body(caller) return body_source except ASTExtractionError: @@ -151,23 +149,23 @@ def _capture_code_snippet(self) -> str | None: return None finally: del frame - + def _get_env_config(self) -> dict[str, Any] | None: """Get serializable environment configuration. - + Returns dict with connections and local tools. """ # This will be overridden by Environment with actual implementation return None - + @property def last_traces(self) -> list[TraceContext] | None: """Get TraceContext objects from the last parallel execution. - + Each TraceContext has: trace_id, index, reward, duration, error, success """ return self._last_traces - + @asynccontextmanager async def trace( self, @@ -181,7 +179,7 @@ async def trace( api_key: str | None = None, ) -> AsyncGenerator[TraceContext, None]: """Create a trace context for recording an agent run. 
- + The trace context provides: - Unique trace identification - Task name linking (for training data construction) @@ -189,16 +187,16 @@ async def trace( - Tool call delegation - Reward setting - Metrics logging - + A/B Testing: Use `variants` to define experiment variables. Each list value creates a variant; single values are fixed. All combinations are expanded and run. - + Parallel Execution: Use `group` to run multiple times per variant for statistical significance. Total traces = len(variants combinations) x group. - + Args: name: Task name for this trace (used for task construction) variants: A/B test configuration. Dict where: @@ -214,38 +212,39 @@ async def trace( trace_id: Optional trace ID (auto-generated if not provided). For parallel execution, each trace gets a unique ID. api_key: Optional API key for backend calls (defaults to settings.api_key) - + Yields: TraceContext for this trace. Inside the body: - `tc.variants` = current variant assignment (e.g., {"model": "gpt-4o"}) - `tc.index` = local run index (for debugging) - `tc.group_id` = links all traces in this parallel execution - + After execution (for variants/group > 1): - `tc.results` = list of all TraceContext objects - `tc.reward` = mean reward across all traces - + Example: ```python # Single execution async with env.trace("task") as tc: await tc.call_tool("search", {"query": "..."}) tc.reward = 1.0 - + # A/B test: 2 variants x 3 runs = 6 traces - async with env.trace("task", + async with env.trace( + "task", variants={"model": ["gpt-4o", "claude"]}, group=3, ) as tc: model = tc.variants["model"] # Assigned per-trace response = await call_llm(model=model) tc.reward = evaluate(response) - + # Access results for t in tc.results: print(f"{t.variants} run {t.index}: reward={t.reward}") ``` - + Limitations (for variants/group > 1): - Requires source file (won't work in REPL/Jupyter) - Outer variables captured at enter time, changes don't propagate back @@ -254,17 +253,17 @@ async def trace( """ if group <= 0: raise ValueError("group must be >= 1") - + # Expand variants into all combinations variant_combos = _expand_variants(variants) total_traces = len(variant_combos) * group - + # Capture code snippet (best effort - won't work in REPL/Jupyter) code_snippet = self._capture_code_snippet() - + # Get environment config env_config = self._get_env_config() - + # Validate parallelization - only remote connections allowed for group > 1 if total_traces > 1 and not self.is_parallelizable: # type: ignore[attr-defined] local_conns = self.local_connections # type: ignore[attr-defined] @@ -274,7 +273,7 @@ async def trace( f" Local connections (stdio/Docker) can only run one instance.\n" f" Use remote connections (HTTP/URL) for parallel execution." 
) - + if total_traces == 1: # Simple case: single trace # TraceContext enters FIRST (sets headers in contextvar) @@ -289,9 +288,8 @@ async def trace( _code_snippet=code_snippet, _env_config=env_config, ) - async with tc: - async with self: # type: ignore[attr-defined] - yield tc + async with tc, self: # type: ignore[attr-defined] + yield tc else: # Parallel execution: each trace gets its own environment instance # Parent environment NOT entered - each child connects independently @@ -305,7 +303,7 @@ async def trace( code_snippet=code_snippet, env_config=env_config, ) - + # Create parent tc with results injected tc = TraceContext( env=self, # type: ignore[arg-type] @@ -318,14 +316,14 @@ async def trace( ) tc.results = completed self._last_traces = completed - + # Compute aggregate reward (mean of non-None rewards) rewards = [t.reward for t in completed if t.reward is not None] if rewards: tc.reward = sum(rewards) / len(rewards) - + yield tc - + async def _run_parallel_trace( self, name: str, @@ -338,14 +336,14 @@ async def _run_parallel_trace( env_config: dict[str, Any] | None, ) -> list[TraceContext]: """Run parallel trace execution using AST extraction. - + This method: 1. Captures the caller's frame 2. Extracts the with-block body via AST 3. Creates (variants x group) TraceContext instances 4. Runs the body in parallel 5. Stores results in self._last_traces - + Args: name: Task name variant_combos: List of variant assignments (one per combination) @@ -360,7 +358,7 @@ async def _run_parallel_trace( frame = inspect.currentframe() if frame is None: raise ASTExtractionError("Cannot get current frame") - + try: # Go up: _run_parallel_trace -> trace -> user code caller_frame = frame.f_back @@ -368,16 +366,16 @@ async def _run_parallel_trace( caller_frame = caller_frame.f_back if caller_frame is None: raise ASTExtractionError("Cannot get caller frame") - + # Extract the with-block body body_source, captured_locals = _get_with_block_body(caller_frame) - + finally: del frame # Avoid reference cycles - + # Calculate total traces total_traces = len(variant_combos) * group - + # Use provided group_ids or generate one shared group_id if group_ids: if len(group_ids) != total_traces: @@ -390,7 +388,7 @@ async def _run_parallel_trace( # All traces share one auto-generated group_id shared_group_id = str(uuid.uuid4()) resolved_group_ids = [shared_group_id] * total_traces - + # Create TraceContext for each (variant, run) combination trace_contexts: list[TraceContext] = [] idx = 0 @@ -409,28 +407,31 @@ async def _run_parallel_trace( ) trace_contexts.append(tc) idx += 1 - + # Run in parallel total = len(trace_contexts) logger.info( "Running %d traces for task '%s' (%d variants x %d runs)", - total, name, len(variant_combos), group, + total, + name, + len(variant_combos), + group, ) completed = await run_parallel_traces(trace_contexts, body_source, captured_locals) - + # Store results self._last_traces = completed - + # Calculate stats rewards = [tc.reward for tc in completed if tc.reward is not None] mean_reward = sum(rewards) / len(rewards) if rewards else 0.0 success_count = sum(1 for tc in completed if tc.success) - + logger.info( "Traces complete: %d/%d succeeded, mean_reward=%.3f", success_count, len(completed), mean_reward, ) - + return completed diff --git a/hud/trace/parallel.py b/hud/trace/parallel.py index f20de00d..60b3bb8f 100644 --- a/hud/trace/parallel.py +++ b/hud/trace/parallel.py @@ -25,41 +25,37 @@ class ASTExtractionError(Exception): def _get_with_block_body(frame: Any) -> tuple[str, 
dict[str, Any]]: """Extract the body of a with-block from the calling frame. - + Args: frame: The calling frame (from inspect.currentframe()) - + Returns: Tuple of (body_source, captured_locals) """ filename = frame.f_code.co_filename lineno = frame.f_lineno - + # Check for interactive session if filename.startswith("<") or filename in ("", ""): - raise ASTExtractionError( - "Cannot extract source from interactive session. Use a .py file." - ) - + raise ASTExtractionError("Cannot extract source from interactive session. Use a .py file.") + # Read and parse source lines = linecache.getlines(filename) if not lines: with open(filename, encoding="utf-8") as f: lines = f.readlines() - + source = "".join(lines) tree = ast.parse(source, filename=filename) - + # Find the async with containing this line with_node = _find_async_with(tree, lineno) if with_node is None: - raise ASTExtractionError( - f"Cannot find 'async with' statement at line {lineno}" - ) - + raise ASTExtractionError(f"Cannot find 'async with' statement at line {lineno}") + # Extract body source body_source = _extract_body(lines, with_node) - + return body_source, frame.f_locals.copy() @@ -87,10 +83,10 @@ def _extract_body(lines: list[str], with_node: ast.AsyncWith) -> str: """Extract the body source from an AsyncWith node.""" if not with_node.body: return "pass" - + start = with_node.body[0].lineno - 1 end = _get_end_line(with_node.body[-1]) - + body = "".join(lines[start:end]) return textwrap.dedent(body) @@ -101,7 +97,7 @@ async def run_parallel_traces( captured_locals: dict[str, Any], ) -> list[TraceContext]: """Run the trace body in parallel for multiple contexts. - + Returns the TraceContext objects after execution - they contain: - trace_id - index @@ -109,14 +105,14 @@ async def run_parallel_traces( - duration - Any error is captured in the context """ - + # Create runner function wrapped = f"async def __runner__(tc):\n{textwrap.indent(body_source, ' ')}" code = compile(wrapped, "", "exec") namespace = captured_locals.copy() exec(code, namespace) # noqa: S102 runner = namespace["__runner__"] - + async def run_one(tc: TraceContext) -> TraceContext: try: async with tc: @@ -126,6 +122,6 @@ async def run_one(tc: TraceContext) -> TraceContext: # Store error in context for inspection tc._error = e # type: ignore[attr-defined] return tc - + results = await asyncio.gather(*[run_one(tc) for tc in trace_contexts]) return list(results) diff --git a/hud/trace/tests/__init__.py b/hud/trace/tests/__init__.py index 79c48157..93f3ee87 100644 --- a/hud/trace/tests/__init__.py +++ b/hud/trace/tests/__init__.py @@ -1,2 +1 @@ """Tests for hud.trace module.""" - diff --git a/hud/trace/tests/test_context.py b/hud/trace/tests/test_context.py index 8ccba9e2..38ccfbce 100644 --- a/hud/trace/tests/test_context.py +++ b/hud/trace/tests/test_context.py @@ -46,14 +46,15 @@ def test_injects_trace_headers_for_hud_urls(self) -> None: mock_request = MagicMock() mock_request.url = "https://inference.hud.ai/v1/chat" mock_request.headers = {} - + # Set up trace context from hud.trace.context import _current_trace_headers + token = _current_trace_headers.set({"Trace-Id": "test-trace-123"}) - + try: _httpx_request_hook(mock_request) - + assert mock_request.headers["Trace-Id"] == "test-trace-123" finally: _current_trace_headers.reset(token) @@ -63,12 +64,12 @@ def test_injects_api_key_for_hud_urls(self) -> None: mock_request = MagicMock() mock_request.url = "https://mcp.hud.ai/browser" mock_request.headers = {} - + with patch("hud.trace.context.settings") as 
mock_settings: mock_settings.api_key = "test-api-key" - + _httpx_request_hook(mock_request) - + assert mock_request.headers["Authorization"] == "Bearer test-api-key" def test_does_not_override_existing_auth(self) -> None: @@ -76,12 +77,12 @@ def test_does_not_override_existing_auth(self) -> None: mock_request = MagicMock() mock_request.url = "https://mcp.hud.ai/browser" mock_request.headers = {"Authorization": "Bearer existing-token"} - + with patch("hud.trace.context.settings") as mock_settings: mock_settings.api_key = "test-api-key" - + _httpx_request_hook(mock_request) - + assert mock_request.headers["Authorization"] == "Bearer existing-token" def test_ignores_non_hud_urls(self) -> None: @@ -89,17 +90,18 @@ def test_ignores_non_hud_urls(self) -> None: mock_request = MagicMock() mock_request.url = "https://api.openai.com/v1/chat" mock_request.headers = {} - + # Set up trace context from hud.trace.context import _current_trace_headers + token = _current_trace_headers.set({"Trace-Id": "test-trace-123"}) - + try: with patch("hud.trace.context.settings") as mock_settings: mock_settings.api_key = "test-api-key" - + _httpx_request_hook(mock_request) - + # No headers should be added assert "Trace-Id" not in mock_request.headers assert "Authorization" not in mock_request.headers @@ -114,7 +116,7 @@ def test_init_generates_trace_id(self) -> None: """TraceContext generates trace_id if not provided.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task") - + assert tc.trace_id is not None assert len(tc.trace_id) == 36 # UUID format @@ -122,21 +124,21 @@ def test_init_uses_provided_trace_id(self) -> None: """TraceContext uses provided trace_id.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task", trace_id="custom-id") - + assert tc.trace_id == "custom-id" def test_headers_contains_trace_id(self) -> None: """headers property returns dict with trace ID.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task", trace_id="test-123") - + assert tc.headers == {"Trace-Id": "test-123"} def test_success_true_when_no_error(self) -> None: """success property returns True when no error.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task") - + assert tc.success is True def test_success_false_when_error(self) -> None: @@ -144,21 +146,21 @@ def test_success_false_when_error(self) -> None: mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task") tc.error = ValueError("test error") - + assert tc.success is False def test_done_false_initially(self) -> None: """done property returns False initially.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task") - + assert tc.done is False def test_variants_empty_by_default(self) -> None: """variants is empty dict by default.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task") - + assert tc.variants == {} def test_variants_set_from_init(self) -> None: @@ -169,7 +171,7 @@ def test_variants_set_from_init(self) -> None: name="test-task", _variants={"model": "gpt-4o", "temp": 0.7}, ) - + assert tc.variants == {"model": "gpt-4o", "temp": 0.7} @pytest.mark.asyncio @@ -177,17 +179,17 @@ async def test_context_manager_sets_headers(self) -> None: """Context manager sets trace headers in contextvar.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task", trace_id="test-123") - + # Mock telemetry calls with patch.object(tc, "_trace_enter", new_callable=AsyncMock): with patch.object(tc, "_trace_exit", new_callable=AsyncMock): 
assert get_current_trace_headers() is None - + async with tc: headers = get_current_trace_headers() assert headers is not None assert headers["Trace-Id"] == "test-123" - + assert get_current_trace_headers() is None @pytest.mark.asyncio @@ -195,13 +197,13 @@ async def test_context_manager_captures_error(self) -> None: """Context manager captures exception in error field.""" mock_env = MagicMock() tc = TraceContext(env=mock_env, name="test-task") - + with patch.object(tc, "_trace_enter", new_callable=AsyncMock): with patch.object(tc, "_trace_exit", new_callable=AsyncMock): with pytest.raises(ValueError): async with tc: raise ValueError("test error") - + assert tc.error is not None assert str(tc.error) == "test error" assert tc.success is False @@ -211,19 +213,21 @@ async def test_call_tool_delegates_to_env(self) -> None: """call_tool delegates to environment.""" mock_env = MagicMock() mock_env.call_tool = AsyncMock(return_value="result") - + tc = TraceContext(env=mock_env, name="test-task") result = await tc.call_tool("my_tool", {"arg": "value"}) - + mock_env.call_tool.assert_called_once_with("my_tool", {"arg": "value"}) assert result == "result" def test_repr(self) -> None: """__repr__ shows useful info.""" mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task", trace_id="abc12345-6789-0000-0000-000000000000") + tc = TraceContext( + env=mock_env, name="test-task", trace_id="abc12345-6789-0000-0000-000000000000" + ) tc.reward = 0.95 - + repr_str = repr(tc) assert "abc12345" in repr_str assert "test-task" in repr_str @@ -237,38 +241,38 @@ def test_prompt_defaults_from_env(self) -> None: """TraceContext.prompt defaults from env.prompt.""" mock_env = MagicMock() mock_env.prompt = "Task prompt from environment" - + tc = TraceContext( env=mock_env, name="test-task", trace_id="test-123", ) - + assert tc.prompt == "Task prompt from environment" def test_prompt_none_when_env_has_no_prompt(self) -> None: """TraceContext.prompt is None when env has no prompt.""" mock_env = MagicMock(spec=[]) # No prompt attribute - + tc = TraceContext( env=mock_env, name="test-task", trace_id="test-123", ) - + assert tc.prompt is None def test_prompt_can_be_overridden(self) -> None: """TraceContext.prompt can be set to override env default.""" mock_env = MagicMock() mock_env.prompt = "Original prompt" - + tc = TraceContext( env=mock_env, name="test-task", trace_id="test-123", ) - + tc.prompt = "Overridden prompt" assert tc.prompt == "Overridden prompt" @@ -277,12 +281,12 @@ def test_prompt_included_in_payload(self) -> None: mock_env = MagicMock() mock_env.prompt = "Test prompt" mock_env._all_hubs = False - + tc = TraceContext( env=mock_env, name="test-task", trace_id="test-123", ) - + payload = tc._build_base_payload() assert payload.prompt == "Test prompt" diff --git a/hud/trace/tests/test_mixin.py b/hud/trace/tests/test_mixin.py index eddcaa8c..c6b90a33 100644 --- a/hud/trace/tests/test_mixin.py +++ b/hud/trace/tests/test_mixin.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock import pytest @@ -35,11 +35,13 @@ def test_list_expands_to_variants(self) -> None: def test_multiple_lists_create_combinations(self) -> None: """Multiple lists create all combinations.""" - result = _expand_variants({ - "model": ["a", "b"], - "temp": [0.0, 1.0], - }) - + result = _expand_variants( + { + "model": ["a", "b"], + "temp": [0.0, 1.0], + } + ) + assert len(result) == 4 assert {"model": "a", "temp": 
0.0} in result assert {"model": "a", "temp": 1.0} in result @@ -48,11 +50,13 @@ def test_multiple_lists_create_combinations(self) -> None: def test_mixed_single_and_list(self) -> None: """Mixed single values and lists work correctly.""" - result = _expand_variants({ - "model": ["gpt-4o", "claude"], - "temp": 0.7, - }) - + result = _expand_variants( + { + "model": ["gpt-4o", "claude"], + "temp": 0.7, + } + ) + assert len(result) == 2 assert {"model": "gpt-4o", "temp": 0.7} in result assert {"model": "claude", "temp": 0.7} in result @@ -60,32 +64,26 @@ def test_mixed_single_and_list(self) -> None: class MockEnvironment(TraceMixin): """Mock environment for testing TraceMixin.""" - + def __init__(self) -> None: self.name = "test-env" self._connections: dict[str, Any] = {} self._last_traces = None - + @property def is_parallelizable(self) -> bool: - return all( - getattr(c, "is_remote", True) - for c in self._connections.values() - ) - + return all(getattr(c, "is_remote", True) for c in self._connections.values()) + @property def local_connections(self) -> list[str]: - return [ - name for name, c in self._connections.items() - if getattr(c, "is_local", False) - ] - + return [name for name, c in self._connections.items() if getattr(c, "is_local", False)] + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> Any: return {"name": name, "arguments": arguments} - - async def __aenter__(self) -> "MockEnvironment": + + async def __aenter__(self) -> MockEnvironment: return self - + async def __aexit__(self, *args: Any) -> None: pass @@ -97,7 +95,7 @@ class TestTraceMixin: async def test_trace_single_creates_context(self) -> None: """trace() with group=1 creates single TraceContext.""" env = MockEnvironment() - + async with env.trace("test-task") as tc: assert tc.name == "test-task" assert tc.trace_id is not None @@ -107,17 +105,17 @@ async def test_trace_single_creates_context(self) -> None: async def test_trace_sets_reward(self) -> None: """reward can be set on TraceContext.""" env = MockEnvironment() - + async with env.trace("test-task") as tc: tc.reward = 0.95 - + assert tc.reward == 0.95 @pytest.mark.asyncio async def test_trace_with_variants_single(self) -> None: """trace() with single variant value works.""" env = MockEnvironment() - + async with env.trace("test-task", variants={"model": "gpt-4o"}) as tc: assert tc.variants == {"model": "gpt-4o"} @@ -125,13 +123,13 @@ async def test_trace_with_variants_single(self) -> None: async def test_trace_rejects_parallel_with_local_connections(self) -> None: """trace() raises error for parallel with local connections.""" env = MockEnvironment() - + # Add a local connection mock_conn = MagicMock() mock_conn.is_local = True mock_conn.is_remote = False env._connections["local-server"] = mock_conn - + with pytest.raises(ValueError, match="Cannot run parallel traces"): async with env.trace("test-task", group=2) as tc: pass @@ -140,13 +138,13 @@ async def test_trace_rejects_parallel_with_local_connections(self) -> None: async def test_trace_allows_parallel_with_remote_connections(self) -> None: """trace() allows parallel with only remote connections.""" env = MockEnvironment() - + # Add a remote connection mock_conn = MagicMock() mock_conn.is_local = False mock_conn.is_remote = True env._connections["remote-server"] = mock_conn - + # This should not raise (though parallel execution is complex to test) # Just verify it doesn't raise the local connection error assert env.is_parallelizable is True @@ -155,7 +153,7 @@ async def 
test_trace_allows_parallel_with_remote_connections(self) -> None: async def test_trace_rejects_zero_group(self) -> None: """trace() raises error for group <= 0.""" env = MockEnvironment() - + with pytest.raises(ValueError, match="group must be >= 1"): async with env.trace("test-task", group=0) as tc: pass @@ -169,10 +167,9 @@ def test_last_traces_none_initially(self) -> None: async def test_trace_context_delegates_call_tool(self) -> None: """TraceContext.call_tool delegates to environment.""" env = MockEnvironment() - + async with env.trace("test-task") as tc: result = await tc.call_tool("my_tool", {"arg": "value"}) - + assert result["name"] == "my_tool" assert result["arguments"] == {"arg": "value"} - diff --git a/hud/trace/tests/test_parallel.py b/hud/trace/tests/test_parallel.py index c0bda532..cf8056e2 100644 --- a/hud/trace/tests/test_parallel.py +++ b/hud/trace/tests/test_parallel.py @@ -3,7 +3,7 @@ from __future__ import annotations import ast -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest @@ -21,16 +21,16 @@ class TestASTHelpers: def test_find_async_with_finds_correct_node(self) -> None: """_find_async_with finds the async with containing target line.""" - source = ''' + source = """ async def main(): x = 1 async with something as ctx: do_stuff() more_stuff() y = 2 -''' +""" tree = ast.parse(source) - + # Line 4 is inside the async with node = _find_async_with(tree, 5) assert node is not None @@ -38,45 +38,45 @@ async def main(): def test_find_async_with_returns_none_when_not_found(self) -> None: """_find_async_with returns None when line is outside async with.""" - source = ''' + source = """ async def main(): x = 1 async with something as ctx: do_stuff() y = 2 -''' +""" tree = ast.parse(source) - + # Line 6 is outside the async with node = _find_async_with(tree, 7) assert node is None def test_get_end_line(self) -> None: """_get_end_line returns last line of node.""" - source = ''' + source = """ async with ctx: line1() line2() line3() -''' +""" tree = ast.parse(source) async_with = tree.body[0] - + end_line = _get_end_line(async_with) assert end_line >= 4 # At least through line 4 def test_extract_body(self) -> None: """_extract_body extracts the body source from async with.""" - source = '''async with ctx: + source = """async with ctx: do_thing() more_thing() -''' - lines = source.split('\n') - lines = [line + '\n' for line in lines] - +""" + lines = source.split("\n") + lines = [line + "\n" for line in lines] + tree = ast.parse(source) async_with = tree.body[0] - + body = _extract_body(lines, async_with) assert "do_thing()" in body assert "more_thing()" in body @@ -96,13 +96,13 @@ async def test_runs_body_for_each_context(self) -> None: tc.__aenter__ = AsyncMock(return_value=tc) tc.__aexit__ = AsyncMock(return_value=None) mock_tcs.append(tc) - + # Simple body that sets reward body_source = "tc.reward = tc.index * 10" captured_locals: dict[str, object] = {} - + results = await run_parallel_traces(mock_tcs, body_source, captured_locals) - + assert len(results) == 3 # Each context should have had __aenter__ and __aexit__ called for tc in mock_tcs: @@ -116,13 +116,13 @@ async def test_captures_exceptions(self) -> None: tc.index = 0 tc.__aenter__ = AsyncMock(return_value=tc) tc.__aexit__ = AsyncMock(return_value=None) - + # Body that raises body_source = "raise ValueError('test error')" captured_locals: dict[str, object] = {} - + results = await run_parallel_traces([tc], body_source, captured_locals) - + assert 
len(results) == 1 # Error should be captured, not raised assert hasattr(tc, "_error") or tc.__aexit__.called @@ -135,13 +135,13 @@ async def test_uses_captured_locals(self) -> None: tc.result = None tc.__aenter__ = AsyncMock(return_value=tc) tc.__aexit__ = AsyncMock(return_value=None) - + # Body that uses captured local body_source = "tc.result = my_value * 2" captured_locals = {"my_value": 21} - + results = await run_parallel_traces([tc], body_source, captured_locals) - + assert len(results) == 1 @@ -153,4 +153,3 @@ def test_is_exception(self) -> None: error = ASTExtractionError("test message") assert isinstance(error, Exception) assert str(error) == "test message" - From d6f9f18c5f330aaaf47e756b9ef233e28e85e305 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 07:11:52 -0800 Subject: [PATCH 04/92] simplify import structure --- hud/__init__.py | 21 +- hud/cli/__init__.py | 63 ++++ hud/datasets/runner.py | 24 +- hud/datasets/utils.py | 19 +- hud/environment/environment.py | 3 +- hud/environment/integrations/__init__.py | 1 - hud/otel/__init__.py | 17 +- hud/otel/instrumentation.py | 14 +- hud/telemetry/__init__.py | 26 ++ hud/telemetry/instrument.py | 379 ++++++++--------------- hud/tools/jupyter.py | 71 +++-- hud/trace/context.py | 53 +++- hud/utils/tasks.py | 8 +- pyproject.toml | 54 ++-- 14 files changed, 405 insertions(+), 348 deletions(-) diff --git a/hud/__init__.py b/hud/__init__.py index 2f4eef69..bd87d8ed 100644 --- a/hud/__init__.py +++ b/hud/__init__.py @@ -6,29 +6,16 @@ from __future__ import annotations from .environment import Environment -from .telemetry import ( - Trace, - async_job, - async_trace, - clear_trace, - create_job, - get_trace, - instrument, - job, - trace, -) +from .telemetry.instrument import instrument +from .telemetry.job import Job, create_job, get_current_job, job __all__ = [ "Environment", - "Trace", - "async_job", - "async_trace", - "clear_trace", + "Job", "create_job", - "get_trace", + "get_current_job", "instrument", "job", - "trace", ] try: diff --git a/hud/cli/__init__.py b/hud/cli/__init__.py index 13719a8c..ae7f1b16 100644 --- a/hud/cli/__init__.py +++ b/hud/cli/__init__.py @@ -365,6 +365,69 @@ def version() -> None: console.print("HUD CLI version: [cyan]unknown[/cyan]") +@app.command() +def models( + json_output: bool = typer.Option(False, "--json", help="Output as JSON"), +) -> None: + """📋 List available models from HUD inference gateway. + + [not dim]Shows models available via the HUD inference gateway at inference.hud.ai. 
+ + Examples: + hud models # List all models + hud models --json # Output as JSON[/not dim] + """ + from hud.settings import settings + + try: + response = httpx.get( + f"{settings.hud_gateway_url}/models", + headers={"Authorization": f"Bearer {settings.api_key}"} if settings.api_key else {}, + timeout=30.0, + ) + response.raise_for_status() + data = response.json() + + if json_output: + console.print_json(json.dumps(data, indent=2)) + return + + # Parse and display models + models_list = data.get("data", data) if isinstance(data, dict) else data + + if not models_list: + console.print("[yellow]No models found[/yellow]") + return + + console.print(Panel.fit("📋 [bold cyan]Available Models[/bold cyan]", border_style="cyan")) + + table = Table() + table.add_column("Name", style="cyan") + table.add_column("Model (API)", style="green") + table.add_column("Routes", style="yellow") + + for model in models_list: + if isinstance(model, dict): + name = model.get("name", "-") + api_model = model.get("model", model.get("id", "-")) + routes = model.get("routes", []) + routes_str = ", ".join(routes) if routes else "-" + table.add_row(name, api_model, routes_str) + else: + table.add_row(str(model), "-", "-") + + console.print(table) + console.print(f"\n[dim]Gateway: {settings.hud_gateway_url}[/dim]") + + except httpx.HTTPStatusError as e: + console.print(f"[red]❌ API error: {e.response.status_code}[/red]") + console.print(f"[dim]{e.response.text}[/dim]") + raise typer.Exit(1) from e + except Exception as e: + console.print(f"[red]❌ Failed to fetch models: {e}[/red]") + raise typer.Exit(1) from e + + @app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) def dev( params: list[str] = typer.Argument( # type: ignore[arg-type] # noqa: B008 diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 3960aeef..9a4103f7 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -1,4 +1,7 @@ -"""Core task runner for evaluating agents on datasets.""" +"""Core task runner for evaluating agents on datasets. + +Requires the [agents] extra: pip install hud-python[agents] +""" from __future__ import annotations @@ -9,13 +12,12 @@ import warnings from typing import TYPE_CHECKING, Any, cast -from datasets import Dataset, load_dataset - from hud.datasets.utils import calculate_group_stats, submit_rollouts -from hud.telemetry import async_job, async_trace from hud.types import AgentType, Task, Trace if TYPE_CHECKING: + from datasets import Dataset + from hud.agents import MCPAgent logger = logging.getLogger("hud.datasets") @@ -54,6 +56,8 @@ async def run_single_task( Returns: Trace result from agent execution """ + from hud.telemetry import async_trace + name = trace_name or task.prompt or task_id or "task" async with async_trace( @@ -116,6 +120,7 @@ async def run_tasks( await run_tasks(tasks, AgentType.CLAUDE, remote=True) """ import hud + from hud.telemetry import async_job from hud.utils.hud_console import HUDConsole job_metadata = metadata or {} @@ -188,6 +193,11 @@ async def run_dataset( If group_size == 1: List of results from agent.run() in dataset order. If group_size > 1: List of statistics dicts for each task group. """ + from datasets import Dataset as HFDataset + from datasets import load_dataset + + from hud.telemetry import async_job + warnings.warn( "run_dataset() is deprecated. 
Use run_tasks() instead for more flexibility.", DeprecationWarning, @@ -201,9 +211,9 @@ async def run_dataset( if isinstance(dataset, str): logger.info("Loading dataset %s from HuggingFace...", dataset) dataset_link = dataset - loaded = cast("Dataset", load_dataset(dataset, split=split)) + loaded = cast("HFDataset", load_dataset(dataset, split=split)) task_dicts = cast("list[dict[str, Any]]", list(loaded)) - elif isinstance(dataset, Dataset): + elif isinstance(dataset, HFDataset): task_dicts = cast("list[dict[str, Any]]", list(dataset)) # Try to extract dataset link try: @@ -241,6 +251,8 @@ async def _run_tasks( group_size: int, job_obj: Any, ) -> list[Any]: + from hud.telemetry import async_trace + sem = asyncio.Semaphore(max_concurrent) params = agent_params or {} diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py index c724ce45..41d761b1 100644 --- a/hud/datasets/utils.py +++ b/hud/datasets/utils.py @@ -3,11 +3,10 @@ from __future__ import annotations import logging -from statistics import mean, stdev +from statistics import mean, pstdev from typing import Any import httpx -import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator from hud.settings import settings @@ -304,7 +303,7 @@ def calculate_group_stats( ) continue - rewards = np.array([t.reward for t in task_traces]) + rewards = [t.reward for t in task_traces] errors = [t for t in task_traces if t.isError] task_stats = { @@ -312,12 +311,12 @@ def calculate_group_stats( "prompt": task.prompt or "", "group_id": group_ids[task_idx], "group_size": group_size, - "rewards": rewards.tolist(), - "mean_reward": float(np.mean(rewards)), - "std_reward": float(np.std(rewards)) if len(rewards) > 1 else 0.0, - "min_reward": float(np.min(rewards)), - "max_reward": float(np.max(rewards)), - "success_rate": float(np.sum(rewards > 0) / len(rewards)), + "rewards": rewards, + "mean_reward": mean(rewards), + "std_reward": pstdev(rewards) if len(rewards) > 1 else 0.0, + "min_reward": min(rewards), + "max_reward": max(rewards), + "success_rate": sum(1 for r in rewards if r > 0) / len(rewards), "error_rate": len(errors) / len(task_traces), "traces": task_traces, } @@ -360,7 +359,7 @@ def display_results( # Grouped evaluation stats all_means = [s["mean_reward"] for s in results] overall_mean = mean(all_means) if all_means else 0.0 - overall_std = stdev(all_means) if len(all_means) > 1 else 0.0 + overall_std = pstdev(all_means) if len(all_means) > 1 else 0.0 group_size = results[0].get("group_size", 1) total_episodes = sum(len(s.get("rewards", [])) for s in results) diff --git a/hud/environment/environment.py b/hud/environment/environment.py index f85a662e..b347d7b7 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -323,7 +323,8 @@ async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolRe return self._get_mock_result(name, arguments) if self._router.is_local(name): - result = await self._call_tool(name, arguments) + # Call tool manager directly to avoid FastMCP context requirement + result = await self._tool_manager.call_tool(name, arguments) return MCPToolResult(content=result.content, isError=False) connection_name = self._router.get_connection(name) diff --git a/hud/environment/integrations/__init__.py b/hud/environment/integrations/__init__.py index 4990abad..82610bf9 100644 --- a/hud/environment/integrations/__init__.py +++ b/hud/environment/integrations/__init__.py @@ -33,4 +33,3 @@ class IntegrationsMixin( LangChain: as_langchain_tools() - 
StructuredTools (requires langchain-core) """ - diff --git a/hud/otel/__init__.py b/hud/otel/__init__.py index 07ab487b..233e3776 100644 --- a/hud/otel/__init__.py +++ b/hud/otel/__init__.py @@ -1,7 +1,13 @@ """HUD OpenTelemetry integration. +.. deprecated:: + The `hud.otel` module is deprecated and will be removed in a future version. + Use `env.trace()` from `hud.environment.Environment` instead. + + This module requires the [agents] extra: + pip install hud-python[agents] + This package provides the internal OpenTelemetry implementation for HUD telemetry. -Users should interact with the telemetry APIs through hud.telemetry instead. Internal Components: - config: OpenTelemetry configuration and setup @@ -14,6 +20,15 @@ from __future__ import annotations +import warnings + +warnings.warn( + "The hud.otel module is deprecated. Use env.trace() instead. " + "This module requires pip install hud-python[agents].", + DeprecationWarning, + stacklevel=2, +) + from .collector import enable_trace_collection from .config import configure_telemetry, is_telemetry_configured, shutdown_telemetry from .context import ( diff --git a/hud/otel/instrumentation.py b/hud/otel/instrumentation.py index db62089e..475ac3e1 100644 --- a/hud/otel/instrumentation.py +++ b/hud/otel/instrumentation.py @@ -2,6 +2,9 @@ This module provides functions to enable MCP OpenTelemetry instrumentation for automatic tracing of MCP protocol communication. + +Note: This module requires the [agents] extra to be installed: + pip install hud-python[agents] """ from __future__ import annotations @@ -16,8 +19,17 @@ logger = logging.getLogger(__name__) +# Check if OpenTelemetry is available +_HAS_OPENTELEMETRY = False +try: + from opentelemetry import trace as _otel_trace # noqa: F401 + + _HAS_OPENTELEMETRY = True +except ImportError: + pass + -def install_mcp_instrumentation(provider: TracerProvider) -> None: +def install_mcp_instrumentation(provider: TracerProvider | None = None) -> None: """Enable community MCP OpenTelemetry instrumentation if present. Args: diff --git a/hud/telemetry/__init__.py b/hud/telemetry/__init__.py index 84e632f7..9cac60da 100644 --- a/hud/telemetry/__init__.py +++ b/hud/telemetry/__init__.py @@ -1,5 +1,22 @@ """HUD Telemetry - Tracing and job management for agent execution. +.. deprecated:: + The `hud.telemetry` module is deprecated and will be removed in a future version. + Use `env.trace()` from `hud.environment.Environment` instead. + + This module requires the [agents] extra: + pip install hud-python[agents] + + Migration: + # Old (deprecated): + async with hud.async_trace("Task"): + await agent.run(task) + + # New (recommended): + async with env.trace("Task") as tc: + await agent.run(task) + tc.reward = result.reward + Provides telemetry APIs for tracking agent execution and experiments. Async Usage (Recommended): @@ -27,6 +44,15 @@ from __future__ import annotations +import warnings + +warnings.warn( + "The hud.telemetry module is deprecated. Use env.trace() instead. " + "This module requires pip install hud-python[agents].", + DeprecationWarning, + stacklevel=2, +) + from .async_context import async_job, async_trace from .instrument import instrument from .job import Job, create_job, job diff --git a/hud/telemetry/instrument.py b/hud/telemetry/instrument.py index d50c45be..e17f4fb6 100644 --- a/hud/telemetry/instrument.py +++ b/hud/telemetry/instrument.py @@ -1,7 +1,16 @@ -"""General-purpose instrumentation decorator for HUD telemetry. +"""Simple instrumentation decorator for HUD tracing. 
-This module provides the instrument() decorator that users can use -to instrument any function with OpenTelemetry spans. +This module provides a lightweight @instrument decorator that records +function calls within the context of env.trace(). No OpenTelemetry required. + +Usage: + @hud.instrument + async def my_function(arg1, arg2): + ... + + # Within a trace context, calls are recorded + async with env.trace("task") as tc: + result = await my_function("a", "b") """ from __future__ import annotations @@ -11,14 +20,12 @@ import inspect import json import logging +import time +import uuid +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, TypeVar, overload import pydantic_core -from opentelemetry import trace -from opentelemetry.trace import SpanKind, Status, StatusCode - -from hud.otel import configure_telemetry, is_telemetry_configured -from hud.otel.context import get_current_task_run_id if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -31,40 +38,19 @@ def _serialize_value(value: Any, max_items: int = 10) -> Any: - """Serialize a value for span attributes. - - Uses pydantic_core.to_json for robust serialization of complex objects. - - Args: - value: The value to serialize - max_items: Maximum number of items for collections - - Returns: - JSON-serializable version of the value - """ - # Simple types pass through + """Serialize a value for recording.""" if isinstance(value, str | int | float | bool | type(None)): return value - # For collections, we need to limit size first if isinstance(value, list | tuple): value = value[:max_items] if len(value) > max_items else value elif isinstance(value, dict) and len(value) > max_items: value = dict(list(value.items())[:max_items]) - # Use pydantic_core for serialization - it handles: - # - Pydantic models (via model_dump) - # - Dataclasses (via asdict) - # - Bytes (encodes to string) - # - Custom objects (via __dict__ or repr) - # - Complex nested structures try: - # Convert to JSON bytes then back to Python objects - # This ensures we get JSON-serializable types json_bytes = pydantic_core.to_json(value, fallback=str) return json.loads(json_bytes) except Exception: - # Fallback if pydantic_core fails somehow return f"<{type(value).__name__}>" @@ -73,11 +59,9 @@ def instrument( func: None = None, *, name: str | None = None, - span_type: str = "function", - attributes: dict[str, Any] | None = None, + category: str = "function", record_args: bool = True, record_result: bool = True, - span_kind: SpanKind = SpanKind.INTERNAL, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ... @@ -86,11 +70,9 @@ def instrument( func: Callable[P, R], *, name: str | None = None, - span_type: str = "function", - attributes: dict[str, Any] | None = None, + category: str = "function", record_args: bool = True, record_result: bool = True, - span_kind: SpanKind = SpanKind.INTERNAL, ) -> Callable[P, R]: ... @@ -99,11 +81,9 @@ def instrument( func: Callable[P, Awaitable[R]], *, name: str | None = None, - span_type: str = "function", - attributes: dict[str, Any] | None = None, + category: str = "function", record_args: bool = True, record_result: bool = True, - span_kind: SpanKind = SpanKind.INTERNAL, ) -> Callable[P, Awaitable[R]]: ... 
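# A self-contained usage sketch of the simplified decorator (function names are
# illustrative, not part of the SDK). A span is only built when a Trace-Id is
# present in the current trace headers, so decorated functions stay cheap
# outside of env.trace().
import asyncio

import hud


@hud.instrument(category="tool", record_result=False)
async def fetch_page(url: str) -> str:
    return "<html>...</html>"


@hud.instrument  # the bare form wraps sync functions as well
def score(answer: str) -> float:
    return 1.0 if answer else 0.0


async def main() -> None:
    env = hud.Environment("my-env")
    async with env.trace("instrument-demo") as tc:
        await fetch_page("https://example.com")  # recorded against tc's Trace-Id
        tc.reward = score("ok")  # sync calls inside the trace are recorded too


asyncio.run(main())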
@@ -111,269 +91,168 @@ def instrument( func: Callable[..., Any] | None = None, *, name: str | None = None, - span_type: str = "function", - attributes: dict[str, Any] | None = None, + category: str = "function", record_args: bool = True, record_result: bool = True, - span_kind: SpanKind = SpanKind.INTERNAL, ) -> Callable[..., Any]: - """Instrument a function to emit OpenTelemetry spans. + """Instrument a function to record spans within trace context. - This decorator wraps any function to automatically create spans for - observability. It works with both sync and async functions. + This decorator records function calls as spans, compatible with env.trace(). Args: - func: The function to instrument (when used without parentheses) - name: Custom span name (defaults to fully qualified function name) - span_type: The category for this span (e.g., "agent", "mcp", "database", "validation") - attributes: Additional attributes to attach to every span - record_args: Whether to record function arguments in the request field - record_result: Whether to record function result in the result field - span_kind: OpenTelemetry span kind (INTERNAL, CLIENT, SERVER, etc.) + func: The function to instrument + name: Custom span name (defaults to module.function) + category: Span category (e.g., "agent", "tool", "function") + record_args: Whether to record function arguments + record_result: Whether to record function result Returns: - The instrumented function that emits spans + The instrumented function Examples: - # Basic usage - defaults to category="function" @hud.instrument async def process_data(items: list[str]) -> dict: return {"count": len(items)} - # Custom category - @hud.instrument( - span_type="database", # This becomes category="database" - record_args=True, - record_result=True - ) - async def query_users(filter: dict) -> list[User]: - return await db.find(filter) - - # Agent instrumentation - @hud.instrument( - span_type="agent", # category="agent" gets special handling - record_args=False, # Don't record large message arrays - record_result=True - ) - async def get_model_response(self, messages: list) -> Response: - return await self.model.complete(messages) - - # Instrument third-party functions - import requests - requests.get = hud.instrument( - span_type="http", # category="http" - span_kind=SpanKind.CLIENT - )(requests.get) - - # Conditional instrumentation - if settings.enable_db_tracing: - db.query = hud.instrument(db.query) + @hud.instrument(category="agent") + async def call_model(messages: list) -> str: + return await model.generate(messages) """ - # Don't configure telemetry at decoration time - wait until first call - # This allows users to configure alternative backends before importing agents def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - # Check if already instrumented if hasattr(func, "_hud_instrumented"): - logger.debug("Function %s already instrumented, skipping", func.__name__) return func - # Get function metadata func_module = getattr(func, "__module__", "unknown") func_name = getattr(func, "__name__", "unknown") func_qualname = getattr(func, "__qualname__", func_name) - - # Determine span name span_name = name or f"{func_module}.{func_qualname}" - # Get function signature for argument parsing try: sig = inspect.signature(func) except (ValueError, TypeError): sig = None - @functools.wraps(func) - async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - # Ensure telemetry is configured (lazy initialization) - # Only configure with defaults if user 
hasn't configured it yet - if not is_telemetry_configured(): - configure_telemetry() - - tracer = trace.get_tracer("hud-sdk") - - # Build span attributes - span_attrs = { - "category": span_type, # span_type IS the category - "function.module": func_module, - "function.name": func_name, - "function.qualname": func_qualname, + def _build_span( + trace_id: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + start_time: str, + end_time: str, + duration_ms: float, + result: Any = None, + error: str | None = None, + ) -> dict[str, Any]: + """Build a span record.""" + attributes: dict[str, Any] = { + "category": category, + "function": func_qualname, + "module": func_module, + "duration_ms": duration_ms, } - # Add custom attributes - if attributes: - span_attrs.update(attributes) - - # Add current task_run_id if available - task_run_id = get_current_task_run_id() - if task_run_id: - span_attrs["hud.task_run_id"] = task_run_id - - # Record function arguments if requested + # Record arguments if record_args and sig: try: bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - - # Serialize arguments (with safety limits) - args_dict = {} - for param_name, value in bound_args.arguments.items(): - try: - # Skip 'self' and 'cls' parameters - if param_name in ("self", "cls"): - continue - - args_dict[param_name] = _serialize_value(value) - except Exception: - args_dict[param_name] = "" - + args_dict = { + k: _serialize_value(v) + for k, v in bound_args.arguments.items() + if k not in ("self", "cls") + } if args_dict: - args_json = json.dumps(args_dict) - span_attrs["function.arguments"] = args_json - # Always set generic request field for consistency - span_attrs["request"] = args_json + attributes["request"] = json.dumps(args_dict) except Exception as e: - logger.debug("Failed to record function arguments: %s", e) + logger.debug("Failed to serialize args: %s", e) - with tracer.start_as_current_span( - span_name, - kind=span_kind, - attributes=span_attrs, - ) as span: + # Record result + if record_result and result is not None and error is None: try: - # Execute the function - result = await func(*args, **kwargs) - - # Record result if requested - if record_result: - try: - serialized = _serialize_value(result) - result_json = json.dumps(serialized) - span.set_attribute("function.result", result_json) - # Always set generic result field for consistency - span.set_attribute("result", result_json) - - # Also set result type for complex objects - if not isinstance( - result, str | int | float | bool | type(None) | list | tuple | dict - ): - span.set_attribute("function.result_type", type(result).__name__) - except Exception as e: - logger.debug("Failed to record function result: %s", e) - - span.set_status(Status(StatusCode.OK)) - return result - + attributes["result"] = json.dumps(_serialize_value(result)) except Exception as e: - # Record exception and set error status - span.record_exception(e) - span.set_status(Status(StatusCode.ERROR, str(e))) - raise - - @functools.wraps(func) - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - # Ensure telemetry is configured (lazy initialization) - # Only configure with defaults if user hasn't configured it yet - if not is_telemetry_configured(): - configure_telemetry() - - tracer = trace.get_tracer("hud-sdk") - - # Build span attributes (same as async) - span_attrs = { - "category": span_type, # span_type IS the category - "function.module": func_module, - "function.name": func_name, - "function.qualname": func_qualname, + 
logger.debug("Failed to serialize result: %s", e) + + # Record error + if error: + attributes["error"] = error + + return { + "trace_id": trace_id, + "span_id": uuid.uuid4().hex[:16], + "name": span_name, + "start_time": start_time, + "end_time": end_time, + "status_code": "ERROR" if error else "OK", + "attributes": attributes, } - if attributes: - span_attrs.update(attributes) - - task_run_id = get_current_task_run_id() - if task_run_id: - span_attrs["hud.task_run_id"] = task_run_id - - # Record function arguments if requested - if record_args and sig: - try: - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() - - args_dict = {} - for param_name, value in bound_args.arguments.items(): - try: - if param_name in ("self", "cls"): - continue + def _get_trace_id() -> str | None: + """Get trace_id from current trace context.""" + from hud.trace.context import get_current_trace_headers - args_dict[param_name] = _serialize_value(value) - except Exception: - args_dict[param_name] = "" - - if args_dict: - args_json = json.dumps(args_dict) - span_attrs["function.arguments"] = args_json - # Always set generic request field for consistency - span_attrs["request"] = args_json - except Exception as e: - logger.debug("Failed to record function arguments: %s", e) + headers = get_current_trace_headers() + if headers: + return headers.get("Trace-Id") + return None - with tracer.start_as_current_span( - span_name, - kind=span_kind, - attributes=span_attrs, - ) as span: - try: - # Execute the function - result = func(*args, **kwargs) - - # Record result if requested - if record_result: - try: - serialized = _serialize_value(result) - result_json = json.dumps(serialized) - span.set_attribute("function.result", result_json) - # Always set generic result field for consistency - span.set_attribute("result", result_json) - - # Also set result type for complex objects - if not isinstance( - result, str | int | float | bool | type(None) | list | tuple | dict - ): - span.set_attribute("function.result_type", type(result).__name__) - except Exception as e: - logger.debug("Failed to record function result: %s", e) - - span.set_status(Status(StatusCode.OK)) - return result + @functools.wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + trace_id = _get_trace_id() + start_time = datetime.now(UTC).isoformat() + start_perf = time.perf_counter() + error: str | None = None + result: Any = None + + try: + result = await func(*args, **kwargs) + return result + except Exception as e: + error = f"{type(e).__name__}: {e}" + raise + finally: + end_time = datetime.now(UTC).isoformat() + duration_ms = (time.perf_counter() - start_perf) * 1000 + + if trace_id: + _build_span( + trace_id, args, kwargs, start_time, end_time, duration_ms, result, error + ) + logger.debug("Span: %s (%.2fms)", span_name, duration_ms) - except Exception as e: - span.record_exception(e) - span.set_status(Status(StatusCode.ERROR, str(e))) - raise + @functools.wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + trace_id = _get_trace_id() + start_time = datetime.now(UTC).isoformat() + start_perf = time.perf_counter() + error: str | None = None + result: Any = None + + try: + result = func(*args, **kwargs) + return result + except Exception as e: + error = f"{type(e).__name__}: {e}" + raise + finally: + end_time = datetime.now(UTC).isoformat() + duration_ms = (time.perf_counter() - start_perf) * 1000 + + if trace_id: + _build_span( + trace_id, args, kwargs, start_time, end_time, duration_ms, result, error 
+ ) + logger.debug("Span: %s (%.2fms)", span_name, duration_ms) - # Choose wrapper based on function type wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper - - # Mark as instrumented wrapper._hud_instrumented = True # type: ignore[attr-defined] wrapper._hud_original = func # type: ignore[attr-defined] return wrapper - # Handle usage with or without parentheses if func is None: - # Called with arguments: @instrument(name="foo") return decorator - else: - # Called without arguments: @instrument - return decorator(func) + return decorator(func) + + +__all__ = ["instrument"] diff --git a/hud/tools/jupyter.py b/hud/tools/jupyter.py index 479e647b..b525caa2 100644 --- a/hud/tools/jupyter.py +++ b/hud/tools/jupyter.py @@ -1,4 +1,7 @@ -"""Jupyter execution tool.""" +"""Jupyter execution tool. + +Requires the [agents] extra: pip install hud-python[agents] +""" from __future__ import annotations @@ -8,12 +11,6 @@ from typing import TYPE_CHECKING, Any, ClassVar from uuid import uuid4 -import tornado -from tornado.escape import json_decode, json_encode, url_escape -from tornado.httpclient import AsyncHTTPClient, HTTPRequest -from tornado.ioloop import PeriodicCallback -from tornado.websocket import websocket_connect - from hud.tools.base import BaseTool from hud.tools.types import ContentResult, ToolError @@ -80,6 +77,15 @@ def __init__( kernel_id: (Optional) If set, connect to the existed kernel with kernel_id. If empty, create new kernel """ + # Check tornado is available + try: + import tornado # noqa: F401 + except ImportError as e: + raise ImportError( + "JupyterTool requires the [agents] extra. " + "Install with: pip install hud-python[agents]" + ) from e + super().__init__( env=None, name="jupyter", @@ -94,12 +100,12 @@ def __init__( # Kernel state (reuse existing or create new) self._kernel_id = kernel_id - self._ws = None + self._ws: Any = None self._initialized = False # WebSocket heartbeat self._heartbeat_interval = 10000 # 10 seconds - self._heartbeat_callback = None + self._heartbeat_callback: Any = None async def __call__(self, code: str, execution_timeout: int = 15) -> list[ContentBlock]: """Execute Python code in the Jupyter kernel. @@ -140,6 +146,12 @@ async def _ensure_kernel(self) -> None: async def _connect(self) -> None: """Connect to Jupyter kernel via WebSocket.""" + import tornado.iostream + from tornado.escape import json_decode, json_encode, url_escape + from tornado.httpclient import AsyncHTTPClient, HTTPRequest + from tornado.ioloop import PeriodicCallback + from tornado.websocket import websocket_connect + if self._ws: self._ws.close() self._ws = None @@ -177,22 +189,22 @@ async def _connect(self) -> None: # Setup heartbeat to keep connection alive if self._heartbeat_callback: self._heartbeat_callback.stop() - self._heartbeat_callback = PeriodicCallback(self._send_heartbeat, self._heartbeat_interval) - self._heartbeat_callback.start() - async def _send_heartbeat(self) -> None: - """Send heartbeat to maintain WebSocket connection.""" - if not self._ws: - return - try: - self._ws.ping() - except tornado.iostream.StreamClosedError: + async def heartbeat() -> None: + if not self._ws: + return try: - await self._connect() - except ConnectionRefusedError: - logger.warning( - "Failed to reconnect to kernel websocket - Is the kernel still running?" 
- ) + self._ws.ping() + except tornado.iostream.StreamClosedError: + try: + await self._connect() + except ConnectionRefusedError: + logger.warning( + "Failed to reconnect to kernel websocket - Is the kernel still running?" + ) + + self._heartbeat_callback = PeriodicCallback(heartbeat, self._heartbeat_interval) + self._heartbeat_callback.start() async def _execute(self, code: str, execution_timeout: int = 15) -> str: """Execute code in Jupyter kernel and return output. @@ -204,11 +216,14 @@ async def _execute(self, code: str, execution_timeout: int = 15) -> str: Returns: String output from the kernel """ + from tornado.escape import json_decode, json_encode + from tornado.httpclient import AsyncHTTPClient + if not self._ws: await self._connect() msg_id = uuid4().hex - self._ws.write_message( # type: ignore + self._ws.write_message( json_encode( { "header": { @@ -233,13 +248,13 @@ async def _execute(self, code: str, execution_timeout: int = 15) -> str: ) ) - outputs = [] + outputs: list[str] = [] async def wait_for_messages() -> bool: execution_done = False while not execution_done: - msg = await self._ws.read_message() # type: ignore - msg = json_decode(msg) # type: ignore + msg = await self._ws.read_message() + msg = json_decode(msg) msg_type = msg["msg_type"] parent_msg_id = msg["parent_header"].get("msg_id", None) @@ -285,6 +300,8 @@ async def interrupt_kernel() -> None: async def shutdown(self) -> None: """Shutdown the kernel connection.""" + from tornado.httpclient import AsyncHTTPClient + if self._kernel_id: client = AsyncHTTPClient() try: diff --git a/hud/trace/context.py b/hud/trace/context.py index 5ca5975a..312851f5 100644 --- a/hud/trace/context.py +++ b/hud/trace/context.py @@ -323,11 +323,19 @@ def _build_base_payload(self) -> TracePayload: async def call_tool( self, - name: str, - arguments: dict[str, Any] | None = None, - ) -> MCPToolResult: - """Call a tool by name (delegates to environment).""" - return await self._env.call_tool(name, arguments) # type: ignore[attr-defined] + call: Any, + /, + **kwargs: Any, + ) -> Any: + """Call a tool (delegates to environment). 
+ + Accepts any format: + - String with kwargs: call_tool("navigate", url="...") + - OpenAI tool_call: call_tool(response.choices[0].message.tool_calls[0]) + - Claude tool_use: call_tool(block) # where block.type == "tool_use" + - Gemini function_call: call_tool(part) + """ + return await self._env.call_tool(call, **kwargs) # type: ignore[attr-defined] # ========================================================================= # Backend Integration @@ -401,6 +409,7 @@ async def __aenter__(self) -> Self: self._started_at = datetime.now(UTC) self._token = _current_trace_headers.set(self.headers) await self._trace_enter() + self._print_trace_link() return self async def __aexit__( @@ -425,3 +434,37 @@ async def __aexit__( def __repr__(self) -> str: return f"TraceContext({self.trace_id[:8]}..., name={self.name!r}, reward={self.reward})" + + def _print_trace_link(self) -> None: + """Print a nicely formatted trace link to console and open in browser.""" + import contextlib + import webbrowser + + trace_url = f"https://hud.ai/trace/{self.trace_id}" + + # Try to open in browser (new tab if possible) + with contextlib.suppress(Exception): + webbrowser.open(trace_url, new=2) + + try: + from rich.console import Console + from rich.panel import Panel + from rich.align import Align + + console = Console() + + # Style: HUD colors - gold border, purple link + link_markup = f"[bold underline rgb(108,113,196)][link={trace_url}]{trace_url}[/link][/bold underline rgb(108,113,196)]" + + content = Align.center(link_markup) + + panel = Panel( + content, + title="🔗 Trace Started", + border_style="rgb(192,150,12)", # HUD gold + padding=(0, 2), + ) + console.print(panel) + except ImportError: + # Fallback if rich not available + print(f"Trace: https://hud.ai/trace/{self.trace_id}") diff --git a/hud/utils/tasks.py b/hud/utils/tasks.py index 5e92c6d8..2a5606c7 100644 --- a/hud/utils/tasks.py +++ b/hud/utils/tasks.py @@ -4,8 +4,6 @@ from pathlib import Path from typing import Any -from datasets import Dataset - from hud.types import Task from hud.utils.hud_console import HUDConsole @@ -182,5 +180,7 @@ def save_tasks( data.append(row) - dataset = Dataset.from_list(data) - dataset.push_to_hub(repo_id, **kwargs) + from datasets import Dataset + + ds = Dataset.from_list(data) + ds.push_to_hub(repo_id, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index f3eee8ef..e5cb99cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,33 +17,15 @@ dependencies = [ # MCP dependencies "hud-mcp-python-sdk>=3.13.2", "hud-fastmcp-python-sdk>=0.1.2", - "hud-mcp-use-python-sdk==2.3.20", - "langchain==0.3.27", - "pathspec>=0.12.1", - "wrapt>=1.14.0", # CLI dependencies "typer>=0.9.0", "rich>=13.0.0", "toml>=0.10.2", "watchfiles>=0.21.0", "questionary==2.1.0", - "prompt-toolkit==3.0.51", + "prompt-toolkit==3.0.51", # Locked for questionary compatibility # Terminal library with mouse support for JSON viewer "blessed>=1.20.0", - # Telemetry - "opentelemetry-instrumentation-mcp==0.47.0", - "opentelemetry-api>=1.34.1", - "opentelemetry-sdk>=1.34.1", - "opentelemetry-exporter-otlp-proto-http>=1.34.1", - # Data and evaluation - "datasets>=2.14.0", - "numpy>=1.24.0", - "pillow>=11.1.0", - # AI providers - "anthropic>=0.75", - "openai>=2.8.1", - "google-genai", - "tornado>=6.5.2", ] classifiers = [ "Development Status :: 4 - Beta", @@ -123,9 +105,34 @@ packages = ["hud"] "hud/py.typed" = "hud/py.typed" [project.optional-dependencies] +# Agent implementations, AI providers, datasets, and telemetry +agents = [ + # MCP-use client (legacy) + 
"hud-mcp-use-python-sdk==2.3.20", + "langchain==0.3.27", # Required by mcp-use + # AI providers + "anthropic>=0.75", + "openai>=2.8.1", + "google-genai", + # Dataset loading (HuggingFace) + "datasets>=2.14.0", + # Telemetry / OpenTelemetry tracing + "opentelemetry-instrumentation-mcp==0.47.0", + "opentelemetry-api>=1.34.1", + "opentelemetry-sdk>=1.34.1", + "opentelemetry-exporter-otlp-proto-http>=1.34.1", + # Image processing for screenshots/grounding + "pillow>=11.1.0", + # Jupyter kernel support + "tornado>=6.5.2", +] + +# RL training dependencies rl = [ + "hud-python[agents]", # RL needs agent dependencies "peft>=0.17.1", "vllm==0.10.1.1", + "numpy>=1.24.0", # Required for RL training "bitsandbytes>=0.41.0 ; sys_platform == 'linux'", # For 8-bit optimizers (Linux only) "liger-kernel>=0.5.0 ; sys_platform == 'linux'", # Optimized Triton kernels for LLM training (Linux only) # Note: flash-attn is recommended but optional @@ -134,7 +141,7 @@ rl = [ # Development dependencies - includes testing, linting, and automation tools dev = [ - # Include agent dependencies + "hud-python[agents]", # Include agents for dev # Jupyter support "ipykernel", "ipython <9", @@ -151,13 +158,10 @@ dev = [ # Automation and computer control "playwright", "pyautogui>=0.9.54", - "pillow>=11.1.0", ] -# Agent dependencies extend dev -agent = ["hud-python[dev]"] - -agents = ["hud-python[agent]"] +# Alias for backwards compatibility +agent = ["hud-python[agents]"] [tool.ruff] From 05a7fbc44a90ad9bf31f699c49c3b0a100463971 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 08:16:19 -0800 Subject: [PATCH 05/92] rename and clean up files --- hud/__init__.py | 9 +- hud/agents/base.py | 6 - hud/agents/utils.py | 50 --- hud/environment/environment.py | 4 +- hud/eval/__init__.py | 48 +++ hud/eval/context.py | 505 +++++++++++++++++++++++++ hud/eval/manager.py | 347 +++++++++++++++++ hud/eval/mixin.py | 338 +++++++++++++++++ hud/eval/parallel.py | 276 ++++++++++++++ hud/eval/tests/__init__.py | 2 + hud/eval/tests/test_context.py | 179 +++++++++ hud/eval/tests/test_mixin.py | 129 +++++++ hud/eval/tests/test_parallel.py | 234 ++++++++++++ hud/otel/__init__.py | 1 + hud/telemetry/__init__.py | 125 +++--- hud/telemetry/instrument.py | 20 +- hud/telemetry/tests/test_instrument.py | 29 +- hud/trace/__init__.py | 42 -- hud/trace/context.py | 470 ----------------------- hud/trace/mixin.py | 437 --------------------- hud/trace/parallel.py | 127 ------- hud/trace/tests/__init__.py | 1 - hud/trace/tests/test_context.py | 292 -------------- hud/trace/tests/test_mixin.py | 175 --------- hud/trace/tests/test_parallel.py | 155 -------- 25 files changed, 2157 insertions(+), 1844 deletions(-) delete mode 100644 hud/agents/utils.py create mode 100644 hud/eval/__init__.py create mode 100644 hud/eval/context.py create mode 100644 hud/eval/manager.py create mode 100644 hud/eval/mixin.py create mode 100644 hud/eval/parallel.py create mode 100644 hud/eval/tests/__init__.py create mode 100644 hud/eval/tests/test_context.py create mode 100644 hud/eval/tests/test_mixin.py create mode 100644 hud/eval/tests/test_parallel.py delete mode 100644 hud/trace/__init__.py delete mode 100644 hud/trace/context.py delete mode 100644 hud/trace/mixin.py delete mode 100644 hud/trace/parallel.py delete mode 100644 hud/trace/tests/__init__.py delete mode 100644 hud/trace/tests/test_context.py delete mode 100644 hud/trace/tests/test_mixin.py delete mode 100644 hud/trace/tests/test_parallel.py diff --git a/hud/__init__.py b/hud/__init__.py index 
bd87d8ed..43514b06 100644 --- a/hud/__init__.py +++ b/hud/__init__.py @@ -6,16 +6,15 @@ from __future__ import annotations from .environment import Environment +from .eval import EvalContext +from .eval import run_eval as eval from .telemetry.instrument import instrument -from .telemetry.job import Job, create_job, get_current_job, job __all__ = [ "Environment", - "Job", - "create_job", - "get_current_job", + "EvalContext", + "eval", "instrument", - "job", ] try: diff --git a/hud/agents/base.py b/hud/agents/base.py index 58d5cad0..05e12094 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -12,7 +12,6 @@ import mcp.types as types from pydantic import BaseModel, ConfigDict -from hud.agents.utils import log_agent_metadata_to_status, log_task_config_to_current_trace from hud.clients.base import AgentMCPClient from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult, Trace from hud.utils.hud_console import HUDConsole @@ -210,8 +209,6 @@ async def initialize(self, task: str | Task | None = None) -> None: f"Agent initialized with {len(self.get_available_tools())} tools: {', '.join([t.name for t in self.get_available_tools()])}" # noqa: E501 ) - await log_agent_metadata_to_status(self.model_name, self.checkpoint_name) - async def run(self, prompt_or_task: str | Task | dict[str, Any], max_steps: int = 10) -> Trace: """ Run the agent with the given prompt or task. @@ -237,9 +234,6 @@ async def run(self, prompt_or_task: str | Task | dict[str, Any], max_steps: int # Handle Task objects with full lifecycle if isinstance(prompt_or_task, Task): - # Log a compact summary of task config to the current trace (async) - await log_task_config_to_current_trace(prompt_or_task) - return await self.run_task(prompt_or_task, max_steps) # Handle simple string prompts diff --git a/hud/agents/utils.py b/hud/agents/utils.py deleted file mode 100644 index 0efc83aa..00000000 --- a/hud/agents/utils.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -import contextlib -from typing import TYPE_CHECKING - -from hud.otel.context import ( - _update_task_status_async, - get_current_task_run_id, -) - -if TYPE_CHECKING: - from hud.datasets import Task - - -async def log_task_config_to_current_trace(task: Task) -> None: - with contextlib.suppress(Exception): - task_run_id = get_current_task_run_id() - if not task_run_id: - return - - raw_config = task.model_dump() - - await _update_task_status_async( - task_run_id, - "running", - task_id=task.id, - extra_metadata={"task_config": raw_config}, - ) - - -async def log_agent_metadata_to_status( - model_name: str | None = None, checkpoint_name: str | None = None -) -> None: - """Attach agent metadata (model/checkpoint) to current trace status metadata.""" - with contextlib.suppress(Exception): - task_run_id = get_current_task_run_id() - if not task_run_id or (not model_name and not checkpoint_name): - return - - agent_meta = {} - if model_name is not None: - agent_meta["model_name"] = model_name - if checkpoint_name is not None: - agent_meta["checkpoint_name"] = checkpoint_name - - await _update_task_status_async( - task_run_id, - "running", - extra_metadata={"agent": agent_meta}, - ) diff --git a/hud/environment/environment.py b/hud/environment/environment.py index b347d7b7..55d3cd49 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -15,8 +15,8 @@ from hud.environment.integrations import IntegrationsMixin from hud.environment.mock import MockMixin from hud.environment.router import 
ConflictResolution, ToolRouter +from hud.eval.mixin import EvalMixin from hud.server.server import MCPServer -from hud.trace.mixin import TraceMixin from hud.types import MCPToolResult __all__ = ["Environment"] @@ -31,7 +31,7 @@ class Environment( ConnectorsMixin, IntegrationsMixin, MockMixin, - TraceMixin, + EvalMixin, MCPServer, ): """Unified MCP environment that acts as both server and client. diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py new file mode 100644 index 00000000..43d947cd --- /dev/null +++ b/hud/eval/__init__.py @@ -0,0 +1,48 @@ +"""HUD Eval - Evaluation context and management. + +This module provides: +- EvalContext: Environment with evaluation tracking (trace_id, reward, etc.) +- EvalMixin: Adds env.eval() method to Environment +- eval(): Standalone context manager for task-based evaluation + +Usage: + # Method on existing environment + async with env.eval("task_name") as env: + await env.call_tool("navigate", url="...") + env.reward = 0.9 + + # Standalone with task slugs + async with hud.eval("my-org/task:1") as env: + await agent.run(env) + + # Blank eval for manual reward + async with hud.eval() as env: + env.reward = compute_reward() +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +# EvalMixin is safe to import (uses lazy imports internally) +from hud.eval.mixin import EvalMixin + +# run_eval is safe to import (uses lazy imports internally) +from hud.eval.manager import run_eval + +if TYPE_CHECKING: + from hud.eval.context import EvalContext + +__all__ = [ + "EvalContext", + "EvalMixin", + "run_eval", +] + + +def __getattr__(name: str) -> object: + """Lazy import EvalContext to avoid circular imports.""" + if name == "EvalContext": + from hud.eval.context import EvalContext + return EvalContext + raise AttributeError(f"module 'hud.eval' has no attribute {name!r}") diff --git a/hud/eval/context.py b/hud/eval/context.py new file mode 100644 index 00000000..0dbc747e --- /dev/null +++ b/hud/eval/context.py @@ -0,0 +1,505 @@ +"""EvalContext - Environment with evaluation tracking. + +EvalContext IS an Environment, with additional evaluation tracking +capabilities (trace_id, reward, backend reporting). + +This makes `async with env.eval("task") as env` natural - you get +a full Environment that you can call tools on directly. 
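+
+A minimal usage sketch (illustrative only; ``agent`` stands in for any object
+exposing an async ``run()`` method and is not part of this module):
+
+    async with env.eval("checkout-flow") as env:
+        result = await agent.run(env)   # agent drives tools via env.call_tool(...)
+        env.reward = result.reward      # user-settable; reported to the backend on exit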
+""" + +from __future__ import annotations + +import contextvars +import logging +import uuid +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, Self + +from pydantic import BaseModel + +from hud.environment import Environment +from hud.environment.types import EnvConfig +from hud.settings import settings +from hud.shared import make_request +from hud.telemetry.job import get_current_job + +if TYPE_CHECKING: + from types import TracebackType + + from hud.types import Task + +logger = logging.getLogger(__name__) + +# Contextvar to store current trace headers (for httpx auto-instrumentation) +_current_trace_headers: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( + "current_trace_headers", default=None +) + + +def get_current_trace_headers() -> dict[str, str] | None: + """Get the current trace headers from context.""" + return _current_trace_headers.get() + + +# ============================================================================= +# Payload Models +# ============================================================================= + + +class EvalPayload(BaseModel): + """Base payload for eval enter/exit.""" + + task_name: str + prompt: str | None = None + code_snippet: str | None = None + env_config: EnvConfig | None = None + all_hubs: bool = False + job_id: str | None = None + group_id: str | None = None + variants: dict[str, Any] | None = None + + +class EvalExitPayload(EvalPayload): + """Exit payload with result fields.""" + + reward: float | None = None + success: bool = True + error_message: str | None = None + + +# ============================================================================= +# EvalContext +# ============================================================================= + + +class EvalContext(Environment): + """Environment with evaluation tracking capabilities. + + Attributes: + trace_id: Unique identifier for this evaluation + eval_name: Task/evaluation name (separate from env name) + job_id: Links to parent job (auto-detected from hud.job() context) + group_id: Links parallel evaluations together + variants: Variant assignment dict (for A/B testing) + reward: Reward value (user-settable) + error: Exception if failed + results: All eval results (for parallel execution) + task: Task definition (if loaded from slug) + + Example: + ```python + # From existing environment + async with env.eval("task") as ctx: + await ctx.call_tool("navigate", url="...") + ctx.reward = 0.9 + + # Standalone with slug + async with hud.eval("my-org/task:1") as ctx: + await agent.run(ctx) + ctx.reward = result.reward + + # Blank eval + async with hud.eval() as ctx: + ctx.reward = compute_reward() + ``` + """ + + def __init__( + self, + name: str = "eval", + *, + trace_id: str | None = None, + api_key: str | None = None, + job_id: str | None = None, + group_id: str | None = None, + index: int = 0, + variants: dict[str, Any] | None = None, + code_snippet: str | None = None, + env_config: dict[str, Any] | None = None, + task: Task | None = None, + **env_kwargs: Any, + ) -> None: + """Initialize EvalContext. 
+ + Args: + name: Environment/evaluation name + trace_id: Unique trace ID (auto-generated if not provided) + api_key: API key for backend calls + job_id: Job ID to link to (auto-detected if not provided) + group_id: Group ID for parallel evaluations + index: Index in parallel execution + variants: Variant assignment for A/B testing + code_snippet: Code being evaluated (for reproducibility) + env_config: Environment configuration dict + task: Task definition (if loaded from slug) + **env_kwargs: Additional kwargs passed to Environment.__init__ + """ + # Initialize Environment + super().__init__(name=name, **env_kwargs) + + # === Evaluation tracking (not in Environment) === + + # Identity + self.trace_id: str = trace_id or str(uuid.uuid4()) + self.eval_name: str = name # Separate from self.name for clarity + + # Job linkage + if job_id is None: + current_job = get_current_job() + self.job_id: str | None = current_job.id if current_job else None + else: + self.job_id = job_id + + self.group_id: str | None = group_id + self.index: int = index + + # Variant assignment + self.variants: dict[str, Any] = variants or {} + + # User-settable + self.reward: float | None = None + + # Error tracking + self.error: BaseException | None = None + + # Parallel results + self.results: list[EvalContext] | None = None + + # Code and config + self.code_snippet: str | None = code_snippet + self._eval_env_config: dict[str, Any] | None = env_config + + # Task definition (if loaded from slug) + self.task: Task | None = task + + # Apply task configuration + if task: + self._apply_task(task) + + # Private state for eval tracking + self._eval_api_key = api_key + self._started_at: datetime | None = None + self._completed_at: datetime | None = None + self._token: contextvars.Token[dict[str, str] | None] | None = None + + def _apply_task(self, task: Task) -> None: + """Apply a Task definition to this environment.""" + # Set prompt + if task.prompt: + self.prompt = task.prompt + + # Connect MCP servers + if task.mcp_config: + self.connect_mcp_config(task.mcp_config) + + # Configure setup tool calls + if task.setup_tool: + setup_calls = task.setup_tool + if not isinstance(setup_calls, list): + setup_calls = [setup_calls] + for call in setup_calls: + self.setup_tool(call.name, **(call.arguments or {})) + + # Configure evaluate tool calls + if task.evaluate_tool: + eval_calls = task.evaluate_tool + if not isinstance(eval_calls, list): + eval_calls = [eval_calls] + for call in eval_calls: + self.evaluate_tool(call.name, **(call.arguments or {})) + + @classmethod + def from_environment( + cls, + env: Environment, + name: str, + *, + trace_id: str | None = None, + api_key: str | None = None, + job_id: str | None = None, + group_id: str | None = None, + index: int = 0, + variants: dict[str, Any] | None = None, + code_snippet: str | None = None, + env_config: dict[str, Any] | None = None, + ) -> EvalContext: + """Create an EvalContext that copies configuration from an existing Environment. + + This creates a new EvalContext with the same connections as the parent. + Used by env.eval() to create evaluation contexts. 
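+        Connections, setup/evaluate tool calls, and the prompt are copied by
+        reference from the parent, so for parallel execution only remote
+        connections should be relied on.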
+ + Args: + env: Parent environment to copy from + name: Evaluation name + trace_id: Unique trace ID + api_key: API key for backend calls + job_id: Job ID to link to + group_id: Group ID for parallel evaluations + index: Index in parallel execution + variants: Variant assignment + code_snippet: Code being evaluated + env_config: Environment configuration + """ + ctx = cls( + name=name, + trace_id=trace_id, + api_key=api_key, + job_id=job_id, + group_id=group_id, + index=index, + variants=variants, + code_snippet=code_snippet, + env_config=env_config, + ) + + # Copy connections from parent + # Note: These are shared references - for parallel execution, + # only remote connections should be used + ctx._connections = env._connections.copy() + ctx._hub_configs = getattr(env, "_hub_configs", []).copy() + ctx._setup_calls = env._setup_calls.copy() + ctx._evaluate_calls = env._evaluate_calls.copy() + + # Copy prompt + if env.prompt: + ctx.prompt = env.prompt + + return ctx + + @classmethod + def from_task( + cls, + task: Task, + name: str | None = None, + *, + trace_id: str | None = None, + api_key: str | None = None, + job_id: str | None = None, + group_id: str | None = None, + index: int = 0, + variants: dict[str, Any] | None = None, + code_snippet: str | None = None, + ) -> EvalContext: + """Create an EvalContext from a Task definition. + + Used by hud.eval(slug) to create evaluation contexts from tasks. + + Args: + task: Task definition + name: Evaluation name (defaults to task.id or "eval") + trace_id: Unique trace ID + api_key: API key for backend calls + job_id: Job ID to link to + group_id: Group ID for parallel evaluations + index: Index in parallel execution + variants: Variant assignment + code_snippet: Code being evaluated + """ + eval_name = name or task.id or "eval" + + return cls( + name=eval_name, + trace_id=trace_id, + api_key=api_key, + job_id=job_id, + group_id=group_id, + index=index, + variants=variants, + code_snippet=code_snippet, + task=task, + ) + + # ========================================================================= + # Computed Properties (eval-specific) + # ========================================================================= + + @property + def headers(self) -> dict[str, str]: + """Headers for gateway integration.""" + return {"Trace-Id": self.trace_id} + + @property + def duration(self) -> float: + """Execution duration in seconds.""" + if self._started_at is None: + return 0.0 + end = self._completed_at or datetime.now(UTC) + return (end - self._started_at).total_seconds() + + @property + def success(self) -> bool: + """True if no error occurred.""" + return self.error is None + + @property + def done(self) -> bool: + """True if execution completed.""" + return self._completed_at is not None + + # ========================================================================= + # Backend Integration + # ========================================================================= + + def _get_eval_api_key(self) -> str | None: + return self._eval_api_key or settings.api_key + + def _build_base_payload(self) -> EvalPayload: + """Build the base payload for enter/exit.""" + env_config_model: EnvConfig | None = None + if self._eval_env_config: + env_config_model = EnvConfig(**self._eval_env_config) + + return EvalPayload( + task_name=self.eval_name, + prompt=self.prompt, + code_snippet=self.code_snippet, + env_config=env_config_model, + all_hubs=self._all_hubs, + job_id=self.job_id, + group_id=self.group_id, + variants=self.variants if self.variants else None, + ) + + 
async def log(self, metrics: dict[str, Any]) -> None: + """Log metrics to the backend.""" + api_key = self._get_eval_api_key() + if not settings.telemetry_enabled or not api_key: + return + + try: + await make_request( + method="POST", + url=f"{settings.hud_telemetry_url}/traces/{self.trace_id}/log", + json={"metrics": metrics}, + api_key=api_key, + ) + except Exception as e: + logger.warning("Failed to log metrics: %s", e) + + async def _eval_enter(self) -> None: + """Notify backend that eval has started.""" + api_key = self._get_eval_api_key() + if not settings.telemetry_enabled or not api_key: + return + + try: + payload = self._build_base_payload() + await make_request( + method="POST", + url=f"{settings.hud_api_url}/trace/{self.trace_id}/enter", + json=payload.model_dump(exclude_none=True), + api_key=api_key, + ) + except Exception as e: + logger.warning("Failed to send eval enter: %s", e) + + async def _eval_exit(self, error_message: str | None = None) -> None: + """Notify backend that eval has completed.""" + api_key = self._get_eval_api_key() + if not settings.telemetry_enabled or not api_key: + return + + # Use evaluate tool reward if not manually set + reward = self.reward + if reward is None: + reward = getattr(self, "_evaluate_reward", None) + + try: + payload = EvalExitPayload( + **self._build_base_payload().model_dump(), + reward=reward, + success=self.success, + error_message=error_message, + ) + await make_request( + method="POST", + url=f"{settings.hud_api_url}/trace/{self.trace_id}/exit", + json=payload.model_dump(exclude_none=True), + api_key=api_key, + ) + except Exception as e: + logger.warning("Failed to send eval exit: %s", e) + + # ========================================================================= + # Context Manager (override Environment) + # ========================================================================= + + async def __aenter__(self) -> Self: + """Enter eval context - start tracking and connect environment.""" + # Start eval tracking + self._started_at = datetime.now(UTC) + self._token = _current_trace_headers.set(self.headers) + + # Notify backend + await self._eval_enter() + self._print_eval_link() + + # Connect environment (parent class) + await super().__aenter__() + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit eval context - disconnect and report.""" + self._completed_at = datetime.now(UTC) + + # Track error + error_msg: str | None = None + if exc_type is not None: + self.error = exc_val + error_msg = str(exc_val) if exc_val else "Unknown error" + + # Disconnect environment (parent class) + await super().__aexit__(exc_type, exc_val, exc_tb) + + # Reset context var + if self._token is not None: + _current_trace_headers.reset(self._token) + self._token = None + + # Notify backend + await self._eval_exit(error_msg) + + def __repr__(self) -> str: + return f"EvalContext({self.trace_id[:8]}..., name={self.eval_name!r}, reward={self.reward})" + + def _print_eval_link(self) -> None: + """Print a nicely formatted eval link.""" + import contextlib + import webbrowser + + trace_url = f"https://hud.ai/trace/{self.trace_id}" + + with contextlib.suppress(Exception): + webbrowser.open(trace_url, new=2) + + try: + from rich.align import Align + from rich.console import Console + from rich.panel import Panel + + console = Console() + + style = "bold underline rgb(108,113,196)" + link_markup = 
f"[{style}][link={trace_url}]{trace_url}[/link][/{style}]" + + content = Align.center(link_markup) + + panel = Panel( + content, + title="🔗 Eval Started", + border_style="rgb(192,150,12)", + padding=(0, 2), + ) + console.print(panel) + except ImportError: + print(f"Eval: {trace_url}") # noqa: T201 + + +# Re-export for backwards compatibility with trace module +__all__ = ["EvalContext", "get_current_trace_headers"] diff --git a/hud/eval/manager.py b/hud/eval/manager.py new file mode 100644 index 00000000..49f735af --- /dev/null +++ b/hud/eval/manager.py @@ -0,0 +1,347 @@ +"""Standalone eval() context manager. + +Provides hud.eval() for task-based evaluation without needing an existing environment. +""" + +from __future__ import annotations + +import inspect +import logging +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +from hud.eval.parallel import ( + ASTExtractionError, + execute_parallel_evals, + expand_variants, + get_with_block_body, + resolve_group_ids, +) + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from hud.eval.context import EvalContext + from hud.types import Task + +logger = logging.getLogger(__name__) + + +def _parse_slug(slug: str) -> tuple[str, str | None]: + """Parse a task slug into (base_slug, index_or_wildcard). + + Args: + slug: Task slug like "my-org/task", "my-org/task:1", or "my-org/task:*" + + Returns: + Tuple of (base_slug, index_str or None) + - "my-org/task" -> ("my-org/task", None) + - "my-org/task:1" -> ("my-org/task", "1") + - "my-org/task:*" -> ("my-org/task", "*") + """ + if ":" in slug: + parts = slug.rsplit(":", 1) + return parts[0], parts[1] + return slug, None + + +def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: + """Load tasks from platform by slugs. + + Args: + slugs: Single slug or list of slugs. 
Slugs can be: + - "my-org/task" - single task + - "my-org/task:N" - task at index N + - "my-org/task:*" - all tasks matching pattern + + Returns: + List of Task objects + """ + import httpx + + from hud.settings import settings + from hud.types import Task + + if isinstance(slugs, str): + slugs = [slugs] + + tasks: list[Task] = [] + + headers = {} + if settings.api_key: + headers["Authorization"] = f"Bearer {settings.api_key}" + + with httpx.Client() as client: + for slug in slugs: + base_slug, index_str = _parse_slug(slug) + + if index_str == "*": + # Fetch all tasks for this evalset + logger.info("Loading all tasks for: %s", base_slug) + response = client.get( + f"{settings.hud_api_url}/tasks/{base_slug}", + headers=headers, + params={"all": "true"}, + ) + response.raise_for_status() + data = response.json() + + if isinstance(data, list): + for item in data: + tasks.append(Task(**item)) + else: + tasks.append(Task(**data)) + + elif index_str is not None: + # Fetch specific task by index + logger.info("Loading task: %s (index %s)", base_slug, index_str) + response = client.get( + f"{settings.hud_api_url}/tasks/{base_slug}", + headers=headers, + params={"index": index_str}, + ) + response.raise_for_status() + data = response.json() + tasks.append(Task(**data)) + + else: + # Fetch single task + logger.info("Loading task: %s", slug) + response = client.get( + f"{settings.hud_api_url}/tasks/{slug}", + headers=headers, + ) + response.raise_for_status() + data = response.json() + tasks.append(Task(**data)) + + return tasks + + +@asynccontextmanager +async def run_eval( + slugs: str | list[str] | None = None, + *, + variants: dict[str, Any] | None = None, + group: int = 1, + group_ids: list[str] | None = None, + job_id: str | None = None, + api_key: str | None = None, +) -> AsyncGenerator[EvalContext, None]: + """Standalone eval context manager. + + Creates an EvalContext for evaluation, optionally loading task configuration + from slugs. + + Args: + slugs: Task slug(s) to load. 
Can be: + - None: Create blank eval context + - "my-org/task": Single task + - "my-org/task:N": Task at index N + - "my-org/task:*": All tasks matching pattern + - List of any above: Multiple tasks + variants: A/B test configuration (dict with list values expanded) + group: Runs per variant for statistical significance + group_ids: Optional list of group IDs + job_id: Job ID to link to + api_key: API key for backend calls + + Yields: + EvalContext: Environment with evaluation tracking + + Example: + ```python + # Blank eval (for manual reward) + async with hud.eval() as ctx: + ctx.reward = compute_reward() + + # With task slug + async with hud.eval("my-org/browser-task:1") as ctx: + await agent.run(ctx) + ctx.reward = result.reward + + # Multiple tasks + async with hud.eval(["task:1", "task:2"]) as ctx: + await agent.run(ctx) + + # All tasks in evalset + async with hud.eval("my-org/evalset:*") as ctx: + await agent.run(ctx) + + # With variants and group + async with hud.eval( + "task", + variants={"model": ["gpt-4o", "claude"]}, + group=3, + ) as ctx: + model = ctx.variants["model"] + await run_agent(model) + ctx.reward = evaluate() + + # Access results after parallel run + for e in ctx.results: + print(f"{e.variants}: reward={e.reward}") + ``` + """ + if group <= 0: + raise ValueError("group must be >= 1") + + # Expand variants + variant_combos = expand_variants(variants) + + # Load tasks if slugs provided + tasks: list[Task] = [] + if slugs is not None: + tasks = _load_tasks_from_slugs(slugs) + + # Calculate total evaluations + # If we have tasks, each task gets (variants x group) runs + # If no tasks, we have a single blank eval with (variants x group) runs + if tasks: + total_evals = len(tasks) * len(variant_combos) * group + else: + total_evals = len(variant_combos) * group + + # Capture code snippet for parallel execution + code_snippet: str | None = None + if total_evals > 1: + frame = inspect.currentframe() + if frame is not None: + try: + caller = frame.f_back + if caller is not None: + code_snippet, _ = get_with_block_body(caller) + except ASTExtractionError: + pass + finally: + del frame + + # Lazy import to avoid circular dependency + from hud.eval.context import EvalContext + + if total_evals == 1: + # Simple case: single eval + if tasks: + # Single task + ctx = EvalContext.from_task( + task=tasks[0], + api_key=api_key, + job_id=job_id, + variants=variant_combos[0], + code_snippet=code_snippet, + ) + else: + # Blank eval + ctx = EvalContext( + name="eval", + api_key=api_key, + job_id=job_id, + variants=variant_combos[0], + code_snippet=code_snippet, + ) + + async with ctx: + yield ctx + + else: + # Parallel execution + completed = await _run_parallel_eval( + tasks=tasks, + variant_combos=variant_combos, + group=group, + group_ids=group_ids, + job_id=job_id, + api_key=api_key, + code_snippet=code_snippet, + ) + + # Create parent ctx with results + if tasks: + ctx = EvalContext.from_task( + task=tasks[0], + api_key=api_key, + job_id=job_id, + ) + else: + ctx = EvalContext( + name="eval", + api_key=api_key, + job_id=job_id, + ) + + ctx.results = completed + + # Compute aggregate reward + rewards = [e.reward for e in completed if e.reward is not None] + if rewards: + ctx.reward = sum(rewards) / len(rewards) + + yield ctx + + +async def _run_parallel_eval( + tasks: list[Task], + variant_combos: list[dict[str, Any]], + group: int, + group_ids: list[str] | None, + job_id: str | None, + api_key: str | None, + code_snippet: str | None, +) -> list[EvalContext]: + """Run parallel 
evaluation. + + Creates EvalContexts from tasks (or blank) and runs them in parallel. + """ + # Lazy import to avoid circular dependency + from hud.eval.context import EvalContext + + # Calculate total evals and resolve group IDs + if tasks: + total_evals = len(tasks) * len(variant_combos) * group + else: + total_evals = len(variant_combos) * group + + resolved_group_ids = resolve_group_ids(group_ids, total_evals) + + # Create EvalContexts + eval_contexts: list[EvalContext] = [] + idx = 0 + + if tasks: + # Create context for each (task, variant, run) combination + for task in tasks: + for variant in variant_combos: + for _ in range(group): + ctx = EvalContext.from_task( + task=task, + api_key=api_key, + job_id=job_id, + group_id=resolved_group_ids[idx], + index=idx, + variants=variant, + code_snippet=code_snippet, + ) + eval_contexts.append(ctx) + idx += 1 + else: + # Blank evals for each (variant, run) combination + for variant in variant_combos: + for _ in range(group): + ctx = EvalContext( + name="eval", + api_key=api_key, + job_id=job_id, + group_id=resolved_group_ids[idx], + index=idx, + variants=variant, + code_snippet=code_snippet, + ) + eval_contexts.append(ctx) + idx += 1 + + # Run in parallel (frame depth: _run_parallel_eval -> eval -> user code) + return await execute_parallel_evals(eval_contexts, caller_frame_depth=3) + + +__all__ = ["run_eval"] + diff --git a/hud/eval/mixin.py b/hud/eval/mixin.py new file mode 100644 index 00000000..84bb1ff3 --- /dev/null +++ b/hud/eval/mixin.py @@ -0,0 +1,338 @@ +"""EvalMixin - Adds eval() method to Environment. + +This mixin provides the eval() context manager that creates EvalContext +instances for recording agent runs, with optional parallel execution and +variant-based A/B testing. +""" + +from __future__ import annotations + +import inspect +import logging +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +from hud.eval.parallel import ( + ASTExtractionError, + execute_parallel_evals, + expand_variants, + get_with_block_body, + resolve_group_ids, +) + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from hud.eval.context import EvalContext + from hud.types import MCPToolResult + +logger = logging.getLogger(__name__) + + +class EvalMixin: + """Mixin that adds eval capabilities to Environment. + + This mixin provides: + - eval(): Create an EvalContext for recording agent runs + - Parallel execution with group=N parameter + - A/B testing with variants parameter + + Example: + ```python + class Environment(EvalMixin, MCPServer): ... 
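+        # (illustrative: mixing EvalMixin into Environment is what makes
+        #  env.eval() below available)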
+ + + env = Environment("my-env") + + # Single eval - yields EvalContext (which has Environment capabilities) + async with env.eval("task") as ctx: + await ctx.call_tool("navigate", {"url": "..."}) + ctx.reward = 0.9 + + # Parallel evals (runs 4 times) + async with env.eval("task", group=4) as ctx: + await ctx.call_tool("navigate", {"url": "..."}) + ctx.reward = 0.9 + + # A/B testing (2 variants x 3 runs = 6 evals) + async with env.eval( + "task", + variants={"model": ["gpt-4o", "claude"]}, + group=3, + ) as ctx: + model = ctx.variants["model"] + response = await call_llm(model=model) + ctx.reward = evaluate(response) + + # Access results + for e in ctx.results: + print(f"{e.variants} run {e.index}: reward={e.reward}") + ``` + """ + + # These will be provided by the Environment class + name: str + + # Store last parallel results + _last_evals: list[EvalContext] | None = None + + async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> MCPToolResult: + """Placeholder - implemented by Environment.""" + raise NotImplementedError + + def _capture_code_snippet(self) -> str | None: + """Capture the code inside the eval() with-block (best effort). + + Returns None if source cannot be extracted (e.g., REPL, Jupyter). + """ + frame = inspect.currentframe() + if frame is None: + return None + + try: + # Go up: _capture_code_snippet -> eval -> user code + caller = frame.f_back + if caller is not None: + caller = caller.f_back + if caller is None: + return None + + body_source, _ = get_with_block_body(caller) + return body_source + except ASTExtractionError: + # Can't extract from REPL/Jupyter - that's OK + return None + except Exception as e: + logger.debug("Failed to capture code snippet: %s", e) + return None + finally: + del frame + + def _get_env_config(self) -> dict[str, Any] | None: + """Get serializable environment configuration. + + Returns dict with connections and local tools. + """ + # This will be overridden by Environment with actual implementation + return None + + @property + def last_evals(self) -> list[EvalContext] | None: + """Get EvalContext objects from the last parallel execution. + + Each EvalContext has: trace_id, index, reward, duration, error, success + """ + return self._last_evals + + @asynccontextmanager + async def eval( + self, + name: str, + *, + variants: dict[str, Any] | None = None, + group: int = 1, + group_ids: list[str] | None = None, + job_id: str | None = None, + trace_id: str | None = None, + api_key: str | None = None, + ) -> AsyncGenerator[EvalContext, None]: + """Create an eval context for recording an agent run. + + The eval context provides: + - Unique trace identification + - Task name linking (for training data construction) + - Headers for gateway integration (auto-injected to inference.hud.ai) + - Tool call capabilities (call_tool, as_openai_chat_tools, etc.) + - Reward setting + - Metrics logging + + A/B Testing: + Use `variants` to define experiment variables. Each list value + creates a variant; single values are fixed. All combinations + are expanded and run. + + Parallel Execution: + Use `group` to run multiple times per variant for statistical + significance. Total evals = len(variants combinations) x group. + + Args: + name: Task name for this eval (used for task construction) + variants: A/B test configuration. 
Dict where: + - List values are expanded: {"model": ["gpt-4o", "claude"]} + - Single values are fixed: {"temp": 0.7} + - All combinations are run + group: Runs per variant (default: 1) for statistical significance. + group_ids: Optional list of group IDs for each eval. + Length must match (variants x group). If not provided, + a single shared group_id is auto-generated. + job_id: Optional job ID to link this eval to. If not provided, + auto-detects from current `hud.job()` context. + trace_id: Optional trace ID (auto-generated if not provided). + For parallel execution, each eval gets a unique ID. + api_key: Optional API key for backend calls (defaults to settings.api_key) + + Yields: + EvalContext for this evaluation. Inside the body: + - `ctx.variants` = current variant assignment (e.g., {"model": "gpt-4o"}) + - `ctx.index` = local run index (for debugging) + - `ctx.group_id` = links all evals in this parallel execution + - `ctx.call_tool(...)` = call tools on the environment + - `ctx.reward = ...` = set reward + + After execution (for variants/group > 1): + - `ctx.results` = list of all EvalContext objects + - `ctx.reward` = mean reward across all evals + + Example: + ```python + # Single execution + async with env.eval("task") as ctx: + await ctx.call_tool("search", {"query": "..."}) + ctx.reward = 1.0 + + # A/B test: 2 variants x 3 runs = 6 evals + async with env.eval( + "task", + variants={"model": ["gpt-4o", "claude"]}, + group=3, + ) as ctx: + model = ctx.variants["model"] # Assigned per-eval + response = await call_llm(model=model) + ctx.reward = evaluate(response) + + # Access results + for e in ctx.results: + print(f"{e.variants} run {e.index}: reward={e.reward}") + ``` + + Limitations (for variants/group > 1): + - Requires source file (won't work in REPL/Jupyter) + - Outer variables captured at enter time, changes don't propagate back + - Modifying mutable objects causes race conditions + - Cannot use yield/generators inside body + """ + if group <= 0: + raise ValueError("group must be >= 1") + + # Expand variants into all combinations + variant_combos = expand_variants(variants) + total_evals = len(variant_combos) * group + + # Capture code snippet (best effort - won't work in REPL/Jupyter) + code_snippet = self._capture_code_snippet() + + # Get environment config + env_config = self._get_env_config() + + # Validate parallelization - only remote connections allowed for group > 1 + if total_evals > 1 and not self.is_parallelizable: # type: ignore[attr-defined] + local_conns = self.local_connections # type: ignore[attr-defined] + raise ValueError( + f"Cannot run parallel evals (group={group}) with local connections.\n" + f" Local connections: {local_conns}\n" + f" Local connections (stdio/Docker) can only run one instance.\n" + f" Use remote connections (HTTP/URL) for parallel execution." 
+ ) + + # Lazy import to avoid circular dependency + from hud.eval.context import EvalContext + + if total_evals == 1: + # Simple case: single eval + # Create EvalContext from parent environment + ctx = EvalContext.from_environment( + env=self, # type: ignore[arg-type] + name=name, + trace_id=trace_id, + api_key=api_key, + job_id=job_id, + variants=variant_combos[0], + code_snippet=code_snippet, + env_config=env_config, + ) + async with ctx: + yield ctx + else: + # Parallel execution: each eval gets its own environment instance + completed = await self._run_parallel_eval( + name=name, + variant_combos=variant_combos, + group=group, + group_ids=group_ids, + job_id=job_id, + api_key=api_key, + code_snippet=code_snippet, + env_config=env_config, + ) + + # Create parent ctx with results injected + ctx = EvalContext.from_environment( + env=self, # type: ignore[arg-type] + name=name, + trace_id=trace_id, + api_key=api_key, + job_id=job_id, + code_snippet=code_snippet, + env_config=env_config, + ) + ctx.results = completed + self._last_evals = completed + + # Compute aggregate reward (mean of non-None rewards) + rewards = [e.reward for e in completed if e.reward is not None] + if rewards: + ctx.reward = sum(rewards) / len(rewards) + + yield ctx + + async def _run_parallel_eval( + self, + name: str, + variant_combos: list[dict[str, Any]], + group: int, + group_ids: list[str] | None, + job_id: str | None, + api_key: str | None, + code_snippet: str | None, + env_config: dict[str, Any] | None, + ) -> list[EvalContext]: + """Run parallel eval execution. + + Creates EvalContexts from parent environment and runs them in parallel. + """ + # Lazy import to avoid circular dependency + from hud.eval.context import EvalContext + + # Calculate total evals and resolve group IDs + total_evals = len(variant_combos) * group + resolved_group_ids = resolve_group_ids(group_ids, total_evals) + + # Create EvalContext for each (variant, run) combination + eval_contexts: list[EvalContext] = [] + idx = 0 + for variant in variant_combos: + for _ in range(group): + ctx = EvalContext.from_environment( + env=self, # type: ignore[arg-type] + name=name, + api_key=api_key, + job_id=job_id, + group_id=resolved_group_ids[idx], + index=idx, + variants=variant, + code_snippet=code_snippet, + env_config=env_config, + ) + eval_contexts.append(ctx) + idx += 1 + + # Run in parallel (frame depth: _run_parallel_eval -> eval -> user code) + completed = await execute_parallel_evals(eval_contexts, caller_frame_depth=3) + + # Store results + self._last_evals = completed + return completed + + +__all__ = ["EvalMixin"] + diff --git a/hud/eval/parallel.py b/hud/eval/parallel.py new file mode 100644 index 00000000..45d2237d --- /dev/null +++ b/hud/eval/parallel.py @@ -0,0 +1,276 @@ +"""Parallel execution support for evaluations. + +This module provides AST extraction and parallel execution for running +the same eval body N times concurrently. +""" + +from __future__ import annotations + +import ast +import asyncio +import itertools +import linecache +import logging +import textwrap +import uuid +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from hud.eval.context import EvalContext + +logger = logging.getLogger(__name__) + + +def expand_variants( + variants: dict[str, Any] | None, +) -> list[dict[str, Any]]: + """Expand variants dict into all combinations. 
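+    List values are crossed (a Cartesian product over every list-valued key);
+    scalar values are held fixed in every combination.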
+ + Args: + variants: Dict where values can be: + - Single value: {"model": "gpt-4o"} → fixed + - List: {"model": ["gpt-4o", "claude"]} → expand + + Returns: + List of variant assignments, one per combination. + + Examples: + >>> expand_variants(None) + [{}] + >>> expand_variants({"model": "gpt-4o"}) + [{"model": "gpt-4o"}] + >>> expand_variants({"model": ["gpt-4o", "claude"]}) + [{"model": "gpt-4o"}, {"model": "claude"}] + """ + if not variants: + return [{}] + + expanded: dict[str, list[Any]] = {} + for key, value in variants.items(): + if isinstance(value, list): + expanded[key] = value + else: + expanded[key] = [value] + + keys = list(expanded.keys()) + value_lists = [expanded[k] for k in keys] + + return [dict(zip(keys, combo, strict=True)) for combo in itertools.product(*value_lists)] + + +def resolve_group_ids( + group_ids: list[str] | None, + total_count: int, +) -> list[str]: + """Resolve group IDs for parallel execution. + + Args: + group_ids: Optional list of group IDs (must match total_count if provided) + total_count: Total number of evals + + Returns: + List of group IDs (one per eval) + + Raises: + ValueError: If group_ids length doesn't match total_count + """ + if group_ids: + if len(group_ids) != total_count: + raise ValueError( + f"group_ids length ({len(group_ids)}) must match total evals ({total_count})" + ) + return group_ids + else: + shared_group_id = str(uuid.uuid4()) + return [shared_group_id] * total_count + + +def log_eval_stats(completed: list[EvalContext], context: str = "") -> None: + """Log statistics for completed evaluations. + + Args: + completed: List of completed EvalContext objects + context: Optional context string for the log message + """ + rewards = [ctx.reward for ctx in completed if ctx.reward is not None] + mean_reward = sum(rewards) / len(rewards) if rewards else 0.0 + success_count = sum(1 for ctx in completed if ctx.success) + + logger.info( + "Evals complete%s: %d/%d succeeded, mean_reward=%.3f", + f" ({context})" if context else "", + success_count, + len(completed), + mean_reward, + ) + + +async def execute_parallel_evals( + contexts: list[EvalContext], + caller_frame_depth: int = 2, +) -> list[EvalContext]: + """Execute evaluations in parallel using AST extraction. + + This is the shared implementation for parallel execution. It: + 1. Captures the caller's frame and extracts with-block body + 2. Runs all provided EvalContexts in parallel + 3. 
Logs statistics + + Args: + contexts: Pre-created EvalContext instances to run + caller_frame_depth: How many frames to go up to find user code + (default 2: execute_parallel_evals -> caller -> user) + + Returns: + List of completed EvalContext objects with results + """ + import inspect + + # Get the caller's frame + frame = inspect.currentframe() + if frame is None: + raise ASTExtractionError("Cannot get current frame") + + try: + # Go up the specified number of frames + caller_frame = frame + for _ in range(caller_frame_depth): + if caller_frame is not None: + caller_frame = caller_frame.f_back + if caller_frame is None: + raise ASTExtractionError("Cannot get caller frame") + + body_source, captured_locals = get_with_block_body(caller_frame) + + finally: + del frame + + # Run in parallel + logger.info("Running %d parallel evals", len(contexts)) + completed = await run_parallel_evals(contexts, body_source, captured_locals) + + # Log stats + log_eval_stats(completed) + + return completed + + +class ASTExtractionError(Exception): + """Error extracting AST from source.""" + + +def get_with_block_body(frame: Any) -> tuple[str, dict[str, Any]]: + """Extract the body of a with-block from the calling frame. + + Args: + frame: The calling frame (from inspect.currentframe()) + + Returns: + Tuple of (body_source, captured_locals) + """ + filename = frame.f_code.co_filename + lineno = frame.f_lineno + + # Check for interactive session + if filename.startswith("<") or filename in ("", ""): + raise ASTExtractionError("Cannot extract source from interactive session. Use a .py file.") + + # Read and parse source + lines = linecache.getlines(filename) + if not lines: + with open(filename, encoding="utf-8") as f: + lines = f.readlines() + + source = "".join(lines) + tree = ast.parse(source, filename=filename) + + # Find the async with containing this line + with_node = _find_async_with(tree, lineno) + if with_node is None: + raise ASTExtractionError(f"Cannot find 'async with' statement at line {lineno}") + + # Extract body source + body_source = _extract_body(lines, with_node) + + return body_source, frame.f_locals.copy() + + +def _find_async_with(tree: ast.AST, target_line: int) -> ast.AsyncWith | None: + """Find AsyncWith node containing the target line.""" + for node in ast.walk(tree): + if isinstance(node, ast.AsyncWith): + end_line = _get_end_line(node) + if node.lineno <= target_line <= end_line: + return node + return None + + +def _get_end_line(node: ast.AST) -> int: + """Get the last line number of an AST node.""" + end = getattr(node, "end_lineno", getattr(node, "lineno", 0)) + for child in ast.walk(node): + child_end = getattr(child, "end_lineno", 0) + if child_end > end: + end = child_end + return end + + +def _extract_body(lines: list[str], with_node: ast.AsyncWith) -> str: + """Extract the body source from an AsyncWith node.""" + if not with_node.body: + return "pass" + + start = with_node.body[0].lineno - 1 + end = _get_end_line(with_node.body[-1]) + + body = "".join(lines[start:end]) + return textwrap.dedent(body) + + +async def run_parallel_evals( + eval_contexts: list[EvalContext], + body_source: str, + captured_locals: dict[str, Any], +) -> list[EvalContext]: + """Run the eval body in parallel for multiple contexts. 
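+    The body source is wrapped in an ``async def __runner__(env)`` function,
+    compiled once, and awaited for each context with that context bound to
+    ``env``; the captured caller locals form the execution namespace.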
+ + Returns the EvalContext objects after execution - they contain: + - trace_id + - index + - reward + - duration + - Any error is captured in the context + """ + + # Create runner function + # The variable name in the with statement is 'ctx' by convention + # but we use 'env' since that's what the user will see + wrapped = f"async def __runner__(env):\n{textwrap.indent(body_source, ' ')}" + code = compile(wrapped, "", "exec") + namespace = captured_locals.copy() + exec(code, namespace) # noqa: S102 + runner = namespace["__runner__"] + + async def run_one(ctx: EvalContext) -> EvalContext: + try: + async with ctx: + await runner(ctx) + except Exception as e: + logger.warning("Parallel eval %d failed: %s", ctx.index, e) + ctx.error = e + return ctx + + results = await asyncio.gather(*[run_one(ctx) for ctx in eval_contexts]) + return list(results) + + +__all__ = [ + "ASTExtractionError", + "execute_parallel_evals", + "expand_variants", + "get_with_block_body", + "log_eval_stats", + "resolve_group_ids", + "run_parallel_evals", +] + diff --git a/hud/eval/tests/__init__.py b/hud/eval/tests/__init__.py new file mode 100644 index 00000000..64147a3e --- /dev/null +++ b/hud/eval/tests/__init__.py @@ -0,0 +1,2 @@ +"""Tests for hud.eval module.""" + diff --git a/hud/eval/tests/test_context.py b/hud/eval/tests/test_context.py new file mode 100644 index 00000000..e4cbc8c7 --- /dev/null +++ b/hud/eval/tests/test_context.py @@ -0,0 +1,179 @@ +"""Tests for hud.eval.context module.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from hud.eval.context import ( + EvalContext, + get_current_trace_headers, +) + + +class TestEvalContext: + """Tests for EvalContext.""" + + def test_init_generates_trace_id(self) -> None: + """EvalContext generates trace_id if not provided.""" + ctx = EvalContext(name="test-task") + + assert ctx.trace_id is not None + assert len(ctx.trace_id) == 36 # UUID format + + def test_init_uses_provided_trace_id(self) -> None: + """EvalContext uses provided trace_id.""" + ctx = EvalContext(name="test-task", trace_id="custom-id") + + assert ctx.trace_id == "custom-id" + + def test_headers_contains_trace_id(self) -> None: + """headers property returns dict with trace ID.""" + ctx = EvalContext(name="test-task", trace_id="test-123") + + assert ctx.headers == {"Trace-Id": "test-123"} + + def test_success_true_when_no_error(self) -> None: + """success property returns True when no error.""" + ctx = EvalContext(name="test-task") + + assert ctx.success is True + + def test_success_false_when_error(self) -> None: + """success property returns False when error is set.""" + ctx = EvalContext(name="test-task") + ctx.error = ValueError("test error") + + assert ctx.success is False + + def test_done_false_initially(self) -> None: + """done property returns False initially.""" + ctx = EvalContext(name="test-task") + + assert ctx.done is False + + def test_variants_empty_by_default(self) -> None: + """variants is empty dict by default.""" + ctx = EvalContext(name="test-task") + + assert ctx.variants == {} + + def test_variants_set_from_init(self) -> None: + """variants set from parameter.""" + ctx = EvalContext( + name="test-task", + variants={"model": "gpt-4o", "temp": 0.7}, + ) + + assert ctx.variants == {"model": "gpt-4o", "temp": 0.7} + + @pytest.mark.asyncio + async def test_context_manager_sets_headers(self) -> None: + """Context manager sets trace headers in contextvar.""" + ctx = EvalContext(name="test-task", 
trace_id="test-123") + + # Mock telemetry calls + with ( + patch.object(ctx, "_eval_enter", new_callable=AsyncMock), + patch.object(ctx, "_eval_exit", new_callable=AsyncMock), + ): + # Mock parent Environment context manager + with patch.object(EvalContext, "__aenter__", return_value=ctx): + with patch.object(EvalContext, "__aexit__", return_value=None): + assert get_current_trace_headers() is None + + # Manually set token for test + from hud.eval.context import _current_trace_headers + + token = _current_trace_headers.set(ctx.headers) + try: + headers = get_current_trace_headers() + assert headers is not None + assert headers["Trace-Id"] == "test-123" + finally: + _current_trace_headers.reset(token) + + assert get_current_trace_headers() is None + + def test_repr(self) -> None: + """__repr__ shows useful info.""" + ctx = EvalContext( + name="test-task", trace_id="abc12345-6789-0000-0000-000000000000" + ) + ctx.reward = 0.95 + + repr_str = repr(ctx) + assert "abc12345" in repr_str + assert "test-task" in repr_str + assert "0.95" in repr_str + + +class TestEvalContextPrompt: + """Tests for EvalContext.prompt feature.""" + + def test_prompt_can_be_set(self) -> None: + """EvalContext.prompt can be set.""" + ctx = EvalContext(name="test-task") + ctx.prompt = "Test prompt" + + assert ctx.prompt == "Test prompt" + + def test_prompt_included_in_payload(self) -> None: + """Prompt is included in eval payload.""" + ctx = EvalContext(name="test-task") + ctx.prompt = "Test prompt" + + payload = ctx._build_base_payload() + assert payload.prompt == "Test prompt" + + +class TestEvalContextFromEnvironment: + """Tests for EvalContext.from_environment factory.""" + + def test_copies_connections(self) -> None: + """from_environment copies connections from parent.""" + from hud.environment import Environment + + parent = Environment("parent-env") + # Add a mock connection + mock_conn = MagicMock() + parent._connections["test-conn"] = mock_conn + + ctx = EvalContext.from_environment(parent, name="test-task") + + assert "test-conn" in ctx._connections + assert ctx._connections["test-conn"] is mock_conn + + def test_copies_prompt(self) -> None: + """from_environment copies prompt from parent.""" + from hud.environment import Environment + + parent = Environment("parent-env") + parent.prompt = "Parent prompt" + + ctx = EvalContext.from_environment(parent, name="test-task") + + assert ctx.prompt == "Parent prompt" + + def test_sets_eval_properties(self) -> None: + """from_environment sets eval-specific properties.""" + from hud.environment import Environment + + parent = Environment("parent-env") + + ctx = EvalContext.from_environment( + parent, + name="test-task", + trace_id="custom-trace", + variants={"model": "gpt-4o"}, + group_id="group-123", + index=5, + ) + + assert ctx.eval_name == "test-task" + assert ctx.trace_id == "custom-trace" + assert ctx.variants == {"model": "gpt-4o"} + assert ctx.group_id == "group-123" + assert ctx.index == 5 + diff --git a/hud/eval/tests/test_mixin.py b/hud/eval/tests/test_mixin.py new file mode 100644 index 00000000..45ee4b53 --- /dev/null +++ b/hud/eval/tests/test_mixin.py @@ -0,0 +1,129 @@ +"""Tests for hud.eval.mixin module.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from hud.eval.mixin import EvalMixin +from hud.eval.parallel import expand_variants + + +class TestExpandVariants: + """Tests for expand_variants helper.""" + + def test_none_returns_empty_dict(self) -> None: + result = 
expand_variants(None) + assert result == [{}] + + def test_single_value_stays_single(self) -> None: + result = expand_variants({"model": "gpt-4o"}) + assert result == [{"model": "gpt-4o"}] + + def test_list_expands_to_variants(self) -> None: + result = expand_variants({"model": ["gpt-4o", "claude"]}) + assert result == [{"model": "gpt-4o"}, {"model": "claude"}] + + def test_multiple_lists_create_combinations(self) -> None: + result = expand_variants({"model": ["a", "b"], "temp": [0.0, 1.0]}) + assert len(result) == 4 + assert {"model": "a", "temp": 0.0} in result + assert {"model": "b", "temp": 1.0} in result + + +class MockEnvironment(EvalMixin): + """Mock environment for testing EvalMixin.""" + + def __init__(self) -> None: + self.name = "test-env" + self._connections: dict[str, Any] = {} + self._last_evals = None + self._hub_configs: list[dict[str, Any]] = [] + self._setup_calls: list[tuple[str, dict[str, Any]]] = [] + self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] + self.prompt: str | None = None + + @property + def is_parallelizable(self) -> bool: + return all(getattr(c, "is_remote", True) for c in self._connections.values()) + + @property + def local_connections(self) -> list[str]: + return [name for name, c in self._connections.items() if getattr(c, "is_local", False)] + + +class TestEvalMixin: + """Tests for EvalMixin.""" + + @pytest.mark.asyncio + async def test_eval_single_creates_context(self) -> None: + """eval() with group=1 creates single EvalContext.""" + env = MockEnvironment() + + async with env.eval("test-task") as ctx: + assert ctx.eval_name == "test-task" + assert ctx.trace_id is not None + assert ctx.variants == {} + + @pytest.mark.asyncio + async def test_eval_sets_reward(self) -> None: + """reward can be set on EvalContext.""" + env = MockEnvironment() + + async with env.eval("test-task") as ctx: + ctx.reward = 0.95 + + assert ctx.reward == 0.95 + + @pytest.mark.asyncio + async def test_eval_with_variants_single(self) -> None: + """eval() with single variant value works.""" + env = MockEnvironment() + + async with env.eval("test-task", variants={"model": "gpt-4o"}) as ctx: + assert ctx.variants == {"model": "gpt-4o"} + + @pytest.mark.asyncio + async def test_eval_rejects_parallel_with_local_connections(self) -> None: + """eval() raises error for parallel with local connections.""" + env = MockEnvironment() + + # Add a local connection + mock_conn = MagicMock() + mock_conn.is_local = True + mock_conn.is_remote = False + env._connections["local-server"] = mock_conn + + with pytest.raises(ValueError, match="Cannot run parallel evals"): + async with env.eval("test-task", group=2) as _ctx: + pass + + @pytest.mark.asyncio + async def test_eval_allows_parallel_with_remote_connections(self) -> None: + """eval() allows parallel with only remote connections.""" + env = MockEnvironment() + + # Add a remote connection + mock_conn = MagicMock() + mock_conn.is_local = False + mock_conn.is_remote = True + env._connections["remote-server"] = mock_conn + + # Just verify it doesn't raise the local connection error + assert env.is_parallelizable is True + + @pytest.mark.asyncio + async def test_eval_rejects_zero_group(self) -> None: + """eval() raises error for group <= 0.""" + env = MockEnvironment() + + with pytest.raises(ValueError, match="group must be >= 1"): + async with env.eval("test-task", group=0) as _ctx: + pass + + def test_last_evals_none_initially(self) -> None: + """last_evals is None before any parallel execution.""" + env = MockEnvironment() + assert 
env.last_evals is None diff --git a/hud/eval/tests/test_parallel.py b/hud/eval/tests/test_parallel.py new file mode 100644 index 00000000..baff6a6b --- /dev/null +++ b/hud/eval/tests/test_parallel.py @@ -0,0 +1,234 @@ +"""Tests for hud.eval.parallel module.""" + +from __future__ import annotations + +import ast +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from hud.eval.parallel import ( + ASTExtractionError, + _extract_body, + _find_async_with, + _get_end_line, + expand_variants, + resolve_group_ids, + run_parallel_evals, +) + + +class TestExpandVariants: + """Tests for expand_variants helper.""" + + def test_none_returns_empty_dict(self) -> None: + """None variants returns list with empty dict.""" + result = expand_variants(None) + assert result == [{}] + + def test_empty_dict_returns_empty_dict(self) -> None: + """Empty variants returns list with empty dict.""" + result = expand_variants({}) + assert result == [{}] + + def test_single_value_stays_single(self) -> None: + """Single non-list value stays as single variant.""" + result = expand_variants({"model": "gpt-4o"}) + assert result == [{"model": "gpt-4o"}] + + def test_list_expands_to_variants(self) -> None: + """List value expands to multiple variants.""" + result = expand_variants({"model": ["gpt-4o", "claude"]}) + assert result == [{"model": "gpt-4o"}, {"model": "claude"}] + + def test_multiple_lists_create_combinations(self) -> None: + """Multiple lists create all combinations.""" + result = expand_variants( + { + "model": ["a", "b"], + "temp": [0.0, 1.0], + } + ) + + assert len(result) == 4 + assert {"model": "a", "temp": 0.0} in result + assert {"model": "a", "temp": 1.0} in result + assert {"model": "b", "temp": 0.0} in result + assert {"model": "b", "temp": 1.0} in result + + def test_mixed_single_and_list(self) -> None: + """Mixed single values and lists work correctly.""" + result = expand_variants( + { + "model": ["gpt-4o", "claude"], + "temp": 0.7, + } + ) + + assert len(result) == 2 + assert {"model": "gpt-4o", "temp": 0.7} in result + assert {"model": "claude", "temp": 0.7} in result + + +class TestResolveGroupIds: + """Tests for resolve_group_ids helper.""" + + def test_uses_provided_group_ids(self) -> None: + """Uses provided group_ids when given.""" + result = resolve_group_ids(["a", "b", "c"], 3) + assert result == ["a", "b", "c"] + + def test_generates_shared_group_id(self) -> None: + """Generates shared group_id when not provided.""" + result = resolve_group_ids(None, 3) + assert len(result) == 3 + # All should be the same + assert result[0] == result[1] == result[2] + # Should be a valid UUID + assert len(result[0]) == 36 + + def test_raises_on_length_mismatch(self) -> None: + """Raises ValueError when group_ids length doesn't match.""" + with pytest.raises(ValueError, match="group_ids length"): + resolve_group_ids(["a", "b"], 3) + + +class TestASTHelpers: + """Tests for AST helper functions.""" + + def test_find_async_with_finds_correct_node(self) -> None: + """_find_async_with finds the async with containing target line.""" + source = """ +async def main(): + x = 1 + async with something as ctx: + do_stuff() + more_stuff() + y = 2 +""" + tree = ast.parse(source) + + # Line 5 is inside the async with + node = _find_async_with(tree, 5) + assert node is not None + assert isinstance(node, ast.AsyncWith) + + def test_find_async_with_returns_none_when_not_found(self) -> None: + """_find_async_with returns None when line is outside async with.""" + source = """ +async def main(): + x = 1 + async 
with something as ctx: + do_stuff() + y = 2 +""" + tree = ast.parse(source) + + # Line 7 is outside the async with + node = _find_async_with(tree, 7) + assert node is None + + def test_get_end_line(self) -> None: + """_get_end_line returns last line of node.""" + source = """ +async with ctx: + line1() + line2() + line3() +""" + tree = ast.parse(source) + async_with = tree.body[0] + + end_line = _get_end_line(async_with) + assert end_line >= 4 # At least through line 4 + + def test_extract_body(self) -> None: + """_extract_body extracts the body source from async with.""" + source = """async with ctx: + do_thing() + more_thing() +""" + lines = source.split("\n") + lines = [line + "\n" for line in lines] + + tree = ast.parse(source) + async_with = tree.body[0] + assert isinstance(async_with, ast.AsyncWith) + + body = _extract_body(lines, async_with) + assert "do_thing()" in body + assert "more_thing()" in body + + +class TestRunParallelEvals: + """Tests for run_parallel_evals function.""" + + @pytest.mark.asyncio + async def test_runs_body_for_each_context(self) -> None: + """run_parallel_evals runs body for each EvalContext.""" + # Create mock eval contexts + mock_ctxs = [] + for i in range(3): + ctx = MagicMock() + ctx.index = i + ctx.__aenter__ = AsyncMock(return_value=ctx) + ctx.__aexit__ = AsyncMock(return_value=None) + mock_ctxs.append(ctx) + + # Simple body that sets reward + body_source = "env.reward = env.index * 10" + captured_locals: dict[str, object] = {} + + results = await run_parallel_evals(mock_ctxs, body_source, captured_locals) + + assert len(results) == 3 + # Each context should have had __aenter__ and __aexit__ called + for ctx in mock_ctxs: + ctx.__aenter__.assert_called_once() + ctx.__aexit__.assert_called_once() + + @pytest.mark.asyncio + async def test_captures_exceptions(self) -> None: + """run_parallel_evals captures exceptions in context.""" + ctx = MagicMock() + ctx.index = 0 + ctx.__aenter__ = AsyncMock(return_value=ctx) + ctx.__aexit__ = AsyncMock(return_value=None) + + # Body that raises + body_source = "raise ValueError('test error')" + captured_locals: dict[str, object] = {} + + results = await run_parallel_evals([ctx], body_source, captured_locals) + + assert len(results) == 1 + # Error should be captured, not raised + assert hasattr(ctx, "error") or ctx.__aexit__.called + + @pytest.mark.asyncio + async def test_uses_captured_locals(self) -> None: + """run_parallel_evals uses captured locals in body execution.""" + ctx = MagicMock() + ctx.index = 0 + ctx.result = None + ctx.__aenter__ = AsyncMock(return_value=ctx) + ctx.__aexit__ = AsyncMock(return_value=None) + + # Body that uses captured local + body_source = "env.result = my_value * 2" + captured_locals = {"my_value": 21} + + results = await run_parallel_evals([ctx], body_source, captured_locals) + + assert len(results) == 1 + + +class TestASTExtractionError: + """Tests for ASTExtractionError.""" + + def test_is_exception(self) -> None: + """ASTExtractionError is an exception.""" + error = ASTExtractionError("test message") + assert isinstance(error, Exception) + assert str(error) == "test message" + diff --git a/hud/otel/__init__.py b/hud/otel/__init__.py index 233e3776..4efc7d50 100644 --- a/hud/otel/__init__.py +++ b/hud/otel/__init__.py @@ -22,6 +22,7 @@ import warnings +# Show deprecation warning when module is imported warnings.warn( "The hud.otel module is deprecated. Use env.trace() instead. 
" "This module requires pip install hud-python[agents].", diff --git a/hud/telemetry/__init__.py b/hud/telemetry/__init__.py index 9cac60da..a6c17234 100644 --- a/hud/telemetry/__init__.py +++ b/hud/telemetry/__init__.py @@ -1,73 +1,88 @@ -"""HUD Telemetry - Tracing and job management for agent execution. - -.. deprecated:: - The `hud.telemetry` module is deprecated and will be removed in a future version. - Use `env.trace()` from `hud.environment.Environment` instead. - - This module requires the [agents] extra: - pip install hud-python[agents] - - Migration: - # Old (deprecated): - async with hud.async_trace("Task"): - await agent.run(task) - - # New (recommended): - async with env.trace("Task") as tc: - await agent.run(task) - tc.reward = result.reward - -Provides telemetry APIs for tracking agent execution and experiments. - -Async Usage (Recommended): - >>> import hud - >>> async with hud.async_trace("Task"): - ... await agent.run(task) - >>> async with hud.async_job("Evaluation") as job: - ... async with hud.async_trace("Task", job_id=job.id): - ... await agent.run(task) - -Sync Usage: - >>> import hud - >>> with hud.trace("Task"): - ... do_work() - >>> with hud.job("My Job") as job: - ... with hud.trace("Task", job_id=job.id): - ... do_work() - -APIs: - - async_trace(), async_job() - Async context managers (recommended) - - trace(), job() - Sync context managers - - flush_telemetry() - Manual span flushing (rarely needed) - - instrument() - Function instrumentation decorator +"""HUD Telemetry - Instrumentation for agent execution. + +This module provides: +- instrument: Function instrumentation decorator + +All other APIs are deprecated: +- Job, job, create_job, get_current_job - Use hud.eval() instead +- async_trace(), trace() - Use env.trace() instead +- async_job() - Use hud.eval() instead + +Migration: + # Old (deprecated): + async with hud.async_trace("Task"): + await agent.run(task) + + # New (recommended): + async with env.trace("Task") as tc: + await agent.run(task) + tc.reward = result.reward """ from __future__ import annotations -import warnings +from .instrument import instrument -warnings.warn( - "The hud.telemetry module is deprecated. Use env.trace() instead. " - "This module requires pip install hud-python[agents].", - DeprecationWarning, - stacklevel=2, -) -from .async_context import async_job, async_trace -from .instrument import instrument -from .job import Job, create_job, job -from .replay import clear_trace, get_trace -from .trace import Trace, trace +def __getattr__(name: str): # noqa: ANN202 + """Lazy load deprecated APIs and show warnings.""" + import warnings + + deprecated_apis = { + # Job APIs (deprecated) + "Job", + "job", + "create_job", + "get_current_job", + # OpenTelemetry-based APIs (deprecated, require [agents]) + "async_job", + "async_trace", + "clear_trace", + "get_trace", + "Trace", + "trace", + } + + if name in deprecated_apis: + warnings.warn( + f"hud.telemetry.{name} is deprecated. 
Use hud.eval() or env.trace() instead.", + DeprecationWarning, + stacklevel=2, + ) + + # Import from submodules + if name in ("Job", "job", "create_job", "get_current_job"): + from .job import Job, create_job, get_current_job, job + + return {"Job": Job, "job": job, "create_job": create_job, "get_current_job": get_current_job}[name] + elif name in ("async_job", "async_trace"): + from .async_context import async_job, async_trace + + return async_job if name == "async_job" else async_trace + elif name in ("clear_trace", "get_trace"): + from .replay import clear_trace, get_trace + + return clear_trace if name == "clear_trace" else get_trace + elif name in ("Trace", "trace"): + from .trace import Trace, trace + + return Trace if name == "Trace" else trace + + raise AttributeError(f"module 'hud.telemetry' has no attribute {name!r}") + __all__ = [ + # Core (always available) + "instrument", + # Deprecated "Job", "Trace", "async_job", "async_trace", "clear_trace", "create_job", + "get_current_job", "get_trace", - "instrument", "job", "trace", ] diff --git a/hud/telemetry/instrument.py b/hud/telemetry/instrument.py index e17f4fb6..0f438b83 100644 --- a/hud/telemetry/instrument.py +++ b/hud/telemetry/instrument.py @@ -9,7 +9,7 @@ async def my_function(arg1, arg2): ... # Within a trace context, calls are recorded - async with env.trace("task") as tc: + async with env.eval("task") as tc: result = await my_function("a", "b") """ @@ -60,6 +60,7 @@ def instrument( *, name: str | None = None, category: str = "function", + span_type: str | None = None, # Alias for category record_args: bool = True, record_result: bool = True, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ... @@ -71,6 +72,7 @@ def instrument( *, name: str | None = None, category: str = "function", + span_type: str | None = None, # Alias for category record_args: bool = True, record_result: bool = True, ) -> Callable[P, R]: ... @@ -82,6 +84,7 @@ def instrument( *, name: str | None = None, category: str = "function", + span_type: str | None = None, # Alias for category record_args: bool = True, record_result: bool = True, ) -> Callable[P, Awaitable[R]]: ... @@ -92,17 +95,19 @@ def instrument( *, name: str | None = None, category: str = "function", + span_type: str | None = None, # Alias for category record_args: bool = True, record_result: bool = True, ) -> Callable[..., Any]: - """Instrument a function to record spans within trace context. + """Instrument a function to record spans within eval context. - This decorator records function calls as spans, compatible with env.trace(). + This decorator records function calls as spans, compatible with env.eval(). 
Args: func: The function to instrument name: Custom span name (defaults to module.function) category: Span category (e.g., "agent", "tool", "function") + span_type: Alias for category (deprecated, use category instead) record_args: Whether to record function arguments record_result: Whether to record function result @@ -119,6 +124,9 @@ async def call_model(messages: list) -> str: return await model.generate(messages) """ + # span_type is an alias for category + effective_category = span_type if span_type is not None else category + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: if hasattr(func, "_hud_instrumented"): return func @@ -145,7 +153,7 @@ def _build_span( ) -> dict[str, Any]: """Build a span record.""" attributes: dict[str, Any] = { - "category": category, + "category": effective_category, "function": func_qualname, "module": func_module, "duration_ms": duration_ms, @@ -188,8 +196,8 @@ def _build_span( } def _get_trace_id() -> str | None: - """Get trace_id from current trace context.""" - from hud.trace.context import get_current_trace_headers + """Get trace_id from current eval context.""" + from hud.eval.context import get_current_trace_headers headers = get_current_trace_headers() if headers: diff --git a/hud/telemetry/tests/test_instrument.py b/hud/telemetry/tests/test_instrument.py index 1acf950a..2ffcf2f8 100644 --- a/hud/telemetry/tests/test_instrument.py +++ b/hud/telemetry/tests/test_instrument.py @@ -3,7 +3,6 @@ from dataclasses import dataclass import pytest -from opentelemetry.trace import SpanKind from hud.telemetry.instrument import _serialize_value, instrument @@ -102,7 +101,7 @@ async def test_func(x: int, y: int) -> int: async def test_instrument_async_with_params(): """Test instrument with custom parameters.""" - @instrument(name="custom_name", span_type="custom_type") + @instrument(name="custom_name", category="custom_type") async def test_func(x: int) -> int: return x * 2 @@ -147,10 +146,10 @@ async def test_func() -> str: @pytest.mark.asyncio -async def test_instrument_async_with_attributes(): - """Test instrument with custom attributes.""" +async def test_instrument_async_with_category(): + """Test instrument with custom category.""" - @instrument(attributes={"custom_attr": "value"}) + @instrument(category="agent") async def test_func() -> int: return 42 @@ -158,18 +157,6 @@ async def test_func() -> int: assert result == 42 -@pytest.mark.asyncio -async def test_instrument_async_with_span_kind(): - """Test instrument with custom span kind.""" - - @instrument(span_kind=SpanKind.CLIENT) - async def test_func() -> int: - return 1 - - result = await test_func() - assert result == 1 - - def test_instrument_sync_basic(): """Test instrument decorator on sync function.""" @@ -184,7 +171,7 @@ def test_func(x: int, y: int) -> int: def test_instrument_sync_with_params(): """Test instrument on sync function with parameters.""" - @instrument(name="sync_custom", span_type="sync_type") + @instrument(name="sync_custom", category="sync_type") def test_func(x: int) -> int: return x * 2 @@ -225,10 +212,10 @@ def test_func() -> str: assert result == "test" -def test_instrument_sync_with_attributes(): - """Test instrument sync with custom attributes.""" +def test_instrument_sync_with_category(): + """Test instrument sync with custom category.""" - @instrument(attributes={"sync_attr": "sync_value"}) + @instrument(category="tool") def test_func() -> int: return 42 diff --git a/hud/trace/__init__.py b/hud/trace/__init__.py deleted file mode 100644 index 
9022c021..00000000 --- a/hud/trace/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -HUD Trace System - Context management for agent runs. - -The trace system provides: -- TraceContext: Core abstraction for recording agent runs -- TraceMixin: Mixin that adds trace() method to Environment -- Auto-instrumentation of httpx for inference.hud.ai -- Parallel execution with group=N - -Usage (single execution): - ```python - async with env.trace("google-search") as tc: - await tc.call_tool("navigate", {"url": "..."}) - tc.reward = 0.9 - - # tc has the results - print(tc.trace_id, tc.reward, tc.duration, tc.success) - ``` - -Usage (parallel execution): - ```python - async with env.trace("google-search", group=4) as tc: - # This body runs 4 times, each with a different tc! - await tc.call_tool("navigate", {"url": "..."}) - tc.reward = evaluate() - - # tc.results contains all parallel traces - # tc.reward is the mean reward - print(f"Mean reward: {tc.reward}") - for trace in tc.results: - print(f" {trace.trace_id}: {trace.reward}") - ``` -""" - -from hud.trace.context import TraceContext, get_current_trace_headers -from hud.trace.mixin import TraceMixin - -__all__ = [ - "TraceContext", - "TraceMixin", - "get_current_trace_headers", -] diff --git a/hud/trace/context.py b/hud/trace/context.py deleted file mode 100644 index 312851f5..00000000 --- a/hud/trace/context.py +++ /dev/null @@ -1,470 +0,0 @@ -"""TraceContext - Lightweight context for recording agent runs. - -TraceContext provides: -- Unique trace identification -- Headers for gateway integration (auto-injected to inference.hud.ai) -- Reward and status reporting to backend -- Tool call delegation - -All telemetry goes directly to the backend - nothing accumulated locally. - -Auto-instrumentation: - httpx clients are automatically instrumented when this module is imported. - Any request to inference.hud.ai will have trace headers injected. 
-""" - -from __future__ import annotations - -import contextvars -import logging -import uuid -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Self - -from pydantic import BaseModel - -from hud.environment.types import EnvConfig -from hud.settings import settings -from hud.shared import make_request -from hud.telemetry.job import get_current_job - -if TYPE_CHECKING: - from types import TracebackType - - from hud.environment import Environment - from hud.types import MCPToolResult - -logger = logging.getLogger(__name__) - -# Contextvar to store current trace headers -_current_trace_headers: contextvars.ContextVar[dict[str, str] | None] = contextvars.ContextVar( - "current_trace_headers", default=None -) - - -def get_current_trace_headers() -> dict[str, str] | None: - """Get the current trace headers from context.""" - return _current_trace_headers.get() - - -# ============================================================================= -# Payload Models -# ============================================================================= - - -class TracePayload(BaseModel): - """Base payload for trace enter/exit - sent to both endpoints.""" - - task_name: str - prompt: str | None = None - code_snippet: str | None = None - env_config: EnvConfig | None = None - all_hubs: bool = False # True if all connectors are from connect_hub - job_id: str | None = None - group_id: str | None = None - variants: dict[str, Any] | None = None - - -class TraceExitPayload(TracePayload): - """Exit payload - includes result fields.""" - - reward: float | None = None - success: bool = True - error_message: str | None = None - - -# ============================================================================= -# Auto-instrumentation for httpx -# ============================================================================= - - -def _is_hud_url(url_str: str) -> bool: - """Check if URL is a HUD service (inference or MCP).""" - from urllib.parse import urlparse - - # Extract hostnames from settings URLs - gateway_host = urlparse(settings.hud_gateway_url).netloc - mcp_host = urlparse(settings.hud_mcp_url).netloc - - # Parse the request URL and check against known HUD hosts - parsed = urlparse(url_str) - request_host = parsed.netloc or url_str.split("/")[0] - - return request_host == gateway_host or request_host == mcp_host - - -def _httpx_request_hook(request: Any) -> None: - """httpx event hook that adds trace headers and auth to HUD requests. 
- - For inference.hud.ai and mcp.hud.ai: - - Injects trace headers (Trace-Id) if in trace context - - Injects Authorization header if API key is set and no auth present - """ - url_str = str(request.url) - if not _is_hud_url(url_str): - return - - # Inject trace headers if in trace context - headers = get_current_trace_headers() - if headers is not None: - for key, value in headers.items(): - request.headers[key] = value - logger.debug("Added trace headers to request: %s", url_str) - - # Auto-inject API key if not present - has_auth = "authorization" in {k.lower() for k in request.headers} - if not has_auth and settings.api_key: - request.headers["Authorization"] = f"Bearer {settings.api_key}" - logger.debug("Added API key auth to request: %s", url_str) - - -async def _async_httpx_request_hook(request: Any) -> None: - """Async version of the httpx event hook.""" - _httpx_request_hook(request) - - -def _instrument_client(client: Any) -> None: - """Add trace hook to an httpx client instance.""" - is_async = hasattr(client, "aclose") - hook = _async_httpx_request_hook if is_async else _httpx_request_hook - - existing_hooks = client.event_hooks.get("request", []) - if hook not in existing_hooks: - existing_hooks.append(hook) - client.event_hooks["request"] = existing_hooks - - -def _patch_httpx() -> None: - """Monkey-patch httpx to auto-instrument all clients.""" - try: - import httpx - except ImportError: - logger.debug("httpx not installed, skipping auto-instrumentation") - return - - _original_async_init = httpx.AsyncClient.__init__ - - def _patched_async_init(self: Any, *args: Any, **kwargs: Any) -> None: - _original_async_init(self, *args, **kwargs) - _instrument_client(self) - - httpx.AsyncClient.__init__ = _patched_async_init # type: ignore[method-assign] - - _original_sync_init = httpx.Client.__init__ - - def _patched_sync_init(self: Any, *args: Any, **kwargs: Any) -> None: - _original_sync_init(self, *args, **kwargs) - _instrument_client(self) - - httpx.Client.__init__ = _patched_sync_init # type: ignore[method-assign] - - logger.debug("httpx auto-instrumentation enabled") - - -# Auto-patch httpx on module import -_patch_httpx() - - -# ============================================================================= -# TraceContext -# ============================================================================= - - -class TraceContext: - """Lightweight context for a traced execution. 
- - Attributes: - trace_id: Unique identifier for this trace - name: Task name - job_id: Links to parent job (auto-detected from hud.job() context) - group_id: Links parallel traces together (None for single traces) - variants: Variant assignment dict (for A/B testing) - reward: Reward value (user-settable) - prompt: Task prompt (defaults from env.prompt, user-settable) - error: Exception if failed - results: All trace results (for parent trace) - - Computed: - headers: Gateway headers - duration: Execution time in seconds - success: True if no error - done: True if completed - - Example: - ```python - # Simple trace - async with env.trace("task") as tc: - await tc.call_tool("navigate", {"url": "..."}) - tc.reward = 0.9 - - # With variants (A/B testing) and group (multiple runs) - async with env.trace( - "task", - variants={"model": ["gpt-4o", "claude"]}, - group=3, - ) as tc: - model = tc.variants["model"] # Assigned for this run - response = await call_llm(model=model) - tc.reward = evaluate(response) - - # tc.results has 6 traces (2 variants x 3 runs each) - # All share the same tc.group_id - for t in tc.results: - print(f"{t.variants}: reward={t.reward}") - ``` - """ - - def __init__( - self, - env: Environment, - name: str, - *, - trace_id: str | None = None, - api_key: str | None = None, - job_id: str | None = None, - _group_id: str | None = None, - _index: int = 0, - _variants: dict[str, Any] | None = None, - _code_snippet: str | None = None, - _env_config: dict[str, Any] | None = None, - ) -> None: - # Identity - self.trace_id: str = trace_id or str(uuid.uuid4()) - self.name: str = name - - # Job linkage - auto-detect from current job context if not provided - if job_id is None: - current_job = get_current_job() - self.job_id: str | None = current_job.id if current_job else None - else: - self.job_id = job_id - - self.group_id: str | None = _group_id # Links parallel traces together - self.index: int = _index # Local only, for debugging - - # Variant assignment (for A/B testing) - self.variants: dict[str, Any] = _variants or {} - - # User-settable - self.reward: float | None = None - self.prompt: str | None = getattr(env, "prompt", None) # From env, can override - - # Error tracking - self.error: BaseException | None = None - - # Parallel/variant results (nested) - self.results: list[TraceContext] | None = None - - # Code and config (for reproducibility) - self.code_snippet: str | None = _code_snippet - self.env_config: dict[str, Any] | None = _env_config - - # Private - self._env = env - self._api_key = api_key - self._started_at: datetime | None = None - self._completed_at: datetime | None = None - self._token: contextvars.Token[dict[str, str] | None] | None = None - - # ========================================================================= - # Computed Properties - # ========================================================================= - - @property - def headers(self) -> dict[str, str]: - """Headers for gateway integration.""" - return {"Trace-Id": self.trace_id} - - @property - def duration(self) -> float: - """Execution duration in seconds.""" - if self._started_at is None: - return 0.0 - end = self._completed_at or datetime.now(UTC) - return (end - self._started_at).total_seconds() - - @property - def success(self) -> bool: - """True if no error occurred.""" - return self.error is None - - @property - def done(self) -> bool: - """True if execution completed.""" - return self._completed_at is not None - - def _get_api_key(self) -> str | None: - return self._api_key or 
settings.api_key - - def _build_base_payload(self) -> TracePayload: - """Build the base payload for enter/exit.""" - # Check if all connectors are from hubs (fully reproducible) - all_hubs = getattr(self._env, "_all_hubs", False) - - # Convert env_config dict to EnvConfig model - env_config_model: EnvConfig | None = None - if self.env_config: - env_config_model = EnvConfig(**self.env_config) - - return TracePayload( - task_name=self.name, - prompt=self.prompt, - code_snippet=self.code_snippet, - env_config=env_config_model, - all_hubs=all_hubs, - job_id=self.job_id, - group_id=self.group_id, - variants=self.variants if self.variants else None, - ) - - # ========================================================================= - # Tool Operations - # ========================================================================= - - async def call_tool( - self, - call: Any, - /, - **kwargs: Any, - ) -> Any: - """Call a tool (delegates to environment). - - Accepts any format: - - String with kwargs: call_tool("navigate", url="...") - - OpenAI tool_call: call_tool(response.choices[0].message.tool_calls[0]) - - Claude tool_use: call_tool(block) # where block.type == "tool_use" - - Gemini function_call: call_tool(part) - """ - return await self._env.call_tool(call, **kwargs) # type: ignore[attr-defined] - - # ========================================================================= - # Backend Integration - # ========================================================================= - - async def log(self, metrics: dict[str, Any]) -> None: - """Log metrics to the backend.""" - api_key = self._get_api_key() - if not settings.telemetry_enabled or not api_key: - return - - try: - await make_request( - method="POST", - url=f"{settings.hud_telemetry_url}/traces/{self.trace_id}/log", - json={"metrics": metrics}, - api_key=api_key, - ) - except Exception as e: - logger.warning("Failed to log metrics: %s", e) - - async def _trace_enter(self) -> None: - """Notify backend that trace has started.""" - api_key = self._get_api_key() - if not settings.telemetry_enabled or not api_key: - return - - try: - payload = self._build_base_payload() - await make_request( - method="POST", - url=f"{settings.hud_api_url}/trace/{self.trace_id}/enter", - json=payload.model_dump(exclude_none=True), - api_key=api_key, - ) - except Exception as e: - logger.warning("Failed to send trace enter: %s", e) - - async def _trace_exit(self, error_message: str | None = None) -> None: - """Notify backend that trace has completed.""" - api_key = self._get_api_key() - if not settings.telemetry_enabled or not api_key: - return - - # Use evaluate tool reward if not manually set - reward = self.reward - if reward is None: - reward = getattr(self._env, "_evaluate_reward", None) - - try: - payload = TraceExitPayload( - **self._build_base_payload().model_dump(), - reward=reward, - success=self.success, - error_message=error_message, - ) - await make_request( - method="POST", - url=f"{settings.hud_api_url}/trace/{self.trace_id}/exit", - json=payload.model_dump(exclude_none=True), - api_key=api_key, - ) - except Exception as e: - logger.warning("Failed to send trace exit: %s", e) - - # ========================================================================= - # Context Manager - # ========================================================================= - - async def __aenter__(self) -> Self: - self._started_at = datetime.now(UTC) - self._token = _current_trace_headers.set(self.headers) - await self._trace_enter() - self._print_trace_link() - return 
self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - self._completed_at = datetime.now(UTC) - - if self._token is not None: - _current_trace_headers.reset(self._token) - self._token = None - - error_msg: str | None = None - if exc_type is not None: - self.error = exc_val - error_msg = str(exc_val) if exc_val else "Unknown error" - - # Send exit with all data (reward, error, etc.) - await self._trace_exit(error_msg) - - def __repr__(self) -> str: - return f"TraceContext({self.trace_id[:8]}..., name={self.name!r}, reward={self.reward})" - - def _print_trace_link(self) -> None: - """Print a nicely formatted trace link to console and open in browser.""" - import contextlib - import webbrowser - - trace_url = f"https://hud.ai/trace/{self.trace_id}" - - # Try to open in browser (new tab if possible) - with contextlib.suppress(Exception): - webbrowser.open(trace_url, new=2) - - try: - from rich.console import Console - from rich.panel import Panel - from rich.align import Align - - console = Console() - - # Style: HUD colors - gold border, purple link - link_markup = f"[bold underline rgb(108,113,196)][link={trace_url}]{trace_url}[/link][/bold underline rgb(108,113,196)]" - - content = Align.center(link_markup) - - panel = Panel( - content, - title="🔗 Trace Started", - border_style="rgb(192,150,12)", # HUD gold - padding=(0, 2), - ) - console.print(panel) - except ImportError: - # Fallback if rich not available - print(f"Trace: https://hud.ai/trace/{self.trace_id}") diff --git a/hud/trace/mixin.py b/hud/trace/mixin.py deleted file mode 100644 index ac514e16..00000000 --- a/hud/trace/mixin.py +++ /dev/null @@ -1,437 +0,0 @@ -"""TraceMixin - Adds trace() method to Environment. - -This mixin provides the trace() context manager that creates TraceContext -instances for recording agent runs, with optional parallel execution and -variant-based A/B testing. -""" - -from __future__ import annotations - -import inspect -import itertools -import logging -import uuid -from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any - -from hud.trace.context import TraceContext -from hud.trace.parallel import ( - ASTExtractionError, - _get_with_block_body, - run_parallel_traces, -) - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator - - from hud.types import MCPToolResult - -logger = logging.getLogger(__name__) - - -def _expand_variants( - variants: dict[str, Any] | None, -) -> list[dict[str, Any]]: - """Expand variants dict into all combinations. - - Args: - variants: Dict where values can be: - - Single value: {"model": "gpt-4o"} → fixed - - List: {"model": ["gpt-4o", "claude"]} → expand - - Returns: - List of variant assignments, one per combination. 
- - Examples: - >>> _expand_variants(None) - [{}] - >>> _expand_variants({"model": "gpt-4o"}) - [{"model": "gpt-4o"}] - >>> _expand_variants({"model": ["gpt-4o", "claude"]}) - [{"model": "gpt-4o"}, {"model": "claude"}] - >>> _expand_variants({"model": ["a", "b"], "temp": [0.0, 0.7]}) - [{"model": "a", "temp": 0.0}, {"model": "a", "temp": 0.7}, - {"model": "b", "temp": 0.0}, {"model": "b", "temp": 0.7}] - """ - if not variants: - return [{}] - - # Normalize: single values become single-element lists - expanded: dict[str, list[Any]] = {} - for key, value in variants.items(): - if isinstance(value, list): - expanded[key] = value - else: - expanded[key] = [value] - - # Generate all combinations - keys = list(expanded.keys()) - value_lists = [expanded[k] for k in keys] - - return [dict(zip(keys, combo, strict=True)) for combo in itertools.product(*value_lists)] - - -class TraceMixin: - """Mixin that adds trace capabilities to Environment. - - This mixin provides: - - trace(): Create a TraceContext for recording agent runs - - Parallel execution with group=N parameter - - A/B testing with variants parameter - - Example: - ```python - class Environment(TraceMixin, MCPServer): ... - - - env = Environment("my-env") - - # Single trace - async with env.trace("task") as tc: - await tc.call_tool("navigate", {"url": "..."}) - tc.reward = 0.9 - - # Parallel traces (runs 4 times) - async with env.trace("task", group=4) as tc: - await tc.call_tool("navigate", {"url": "..."}) - tc.reward = 0.9 - - # A/B testing (2 variants x 3 runs = 6 traces) - async with env.trace( - "task", - variants={"model": ["gpt-4o", "claude"]}, - group=3, - ) as tc: - model = tc.variants["model"] - response = await call_llm(model=model) - tc.reward = evaluate(response) - - # Access results - for t in tc.results: - print(f"{t.variants} run {t.index}: reward={t.reward}") - ``` - """ - - # These will be provided by the Environment class - name: str - - # Store last parallel results (list of completed TraceContext objects) - _last_traces: list[TraceContext] | None = None - - async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> MCPToolResult: - """Placeholder - implemented by Environment.""" - raise NotImplementedError - - def _capture_code_snippet(self) -> str | None: - """Capture the code inside the trace() with-block (best effort). - - Returns None if source cannot be extracted (e.g., REPL, Jupyter). - """ - frame = inspect.currentframe() - if frame is None: - return None - - try: - # Go up: _capture_code_snippet -> trace -> user code - caller = frame.f_back - if caller is not None: - caller = caller.f_back - if caller is None: - return None - - body_source, _ = _get_with_block_body(caller) - return body_source - except ASTExtractionError: - # Can't extract from REPL/Jupyter - that's OK - return None - except Exception as e: - logger.debug("Failed to capture code snippet: %s", e) - return None - finally: - del frame - - def _get_env_config(self) -> dict[str, Any] | None: - """Get serializable environment configuration. - - Returns dict with connections and local tools. - """ - # This will be overridden by Environment with actual implementation - return None - - @property - def last_traces(self) -> list[TraceContext] | None: - """Get TraceContext objects from the last parallel execution. 
- - Each TraceContext has: trace_id, index, reward, duration, error, success - """ - return self._last_traces - - @asynccontextmanager - async def trace( - self, - name: str, - *, - variants: dict[str, Any] | None = None, - group: int = 1, - group_ids: list[str] | None = None, - job_id: str | None = None, - trace_id: str | None = None, - api_key: str | None = None, - ) -> AsyncGenerator[TraceContext, None]: - """Create a trace context for recording an agent run. - - The trace context provides: - - Unique trace identification - - Task name linking (for training data construction) - - Headers for gateway integration (auto-injected to inference.hud.ai) - - Tool call delegation - - Reward setting - - Metrics logging - - A/B Testing: - Use `variants` to define experiment variables. Each list value - creates a variant; single values are fixed. All combinations - are expanded and run. - - Parallel Execution: - Use `group` to run multiple times per variant for statistical - significance. Total traces = len(variants combinations) x group. - - Args: - name: Task name for this trace (used for task construction) - variants: A/B test configuration. Dict where: - - List values are expanded: {"model": ["gpt-4o", "claude"]} - - Single values are fixed: {"temp": 0.7} - - All combinations are run - group: Runs per variant (default: 1) for statistical significance. - group_ids: Optional list of group IDs for each trace. - Length must match (variants x group). If not provided, - a single shared group_id is auto-generated. - job_id: Optional job ID to link this trace to. If not provided, - auto-detects from current `hud.job()` context. - trace_id: Optional trace ID (auto-generated if not provided). - For parallel execution, each trace gets a unique ID. - api_key: Optional API key for backend calls (defaults to settings.api_key) - - Yields: - TraceContext for this trace. 
Inside the body: - - `tc.variants` = current variant assignment (e.g., {"model": "gpt-4o"}) - - `tc.index` = local run index (for debugging) - - `tc.group_id` = links all traces in this parallel execution - - After execution (for variants/group > 1): - - `tc.results` = list of all TraceContext objects - - `tc.reward` = mean reward across all traces - - Example: - ```python - # Single execution - async with env.trace("task") as tc: - await tc.call_tool("search", {"query": "..."}) - tc.reward = 1.0 - - # A/B test: 2 variants x 3 runs = 6 traces - async with env.trace( - "task", - variants={"model": ["gpt-4o", "claude"]}, - group=3, - ) as tc: - model = tc.variants["model"] # Assigned per-trace - response = await call_llm(model=model) - tc.reward = evaluate(response) - - # Access results - for t in tc.results: - print(f"{t.variants} run {t.index}: reward={t.reward}") - ``` - - Limitations (for variants/group > 1): - - Requires source file (won't work in REPL/Jupyter) - - Outer variables captured at enter time, changes don't propagate back - - Modifying mutable objects causes race conditions - - Cannot use yield/generators inside body - """ - if group <= 0: - raise ValueError("group must be >= 1") - - # Expand variants into all combinations - variant_combos = _expand_variants(variants) - total_traces = len(variant_combos) * group - - # Capture code snippet (best effort - won't work in REPL/Jupyter) - code_snippet = self._capture_code_snippet() - - # Get environment config - env_config = self._get_env_config() - - # Validate parallelization - only remote connections allowed for group > 1 - if total_traces > 1 and not self.is_parallelizable: # type: ignore[attr-defined] - local_conns = self.local_connections # type: ignore[attr-defined] - raise ValueError( - f"Cannot run parallel traces (group={group}) with local connections.\n" - f" Local connections: {local_conns}\n" - f" Local connections (stdio/Docker) can only run one instance.\n" - f" Use remote connections (HTTP/URL) for parallel execution." 
- ) - - if total_traces == 1: - # Simple case: single trace - # TraceContext enters FIRST (sets headers in contextvar) - # Environment enters SECOND (can inject headers into connections) - tc = TraceContext( - env=self, # type: ignore[arg-type] - name=name, - trace_id=trace_id, - api_key=api_key, - job_id=job_id, - _variants=variant_combos[0], - _code_snippet=code_snippet, - _env_config=env_config, - ) - async with tc, self: # type: ignore[attr-defined] - yield tc - else: - # Parallel execution: each trace gets its own environment instance - # Parent environment NOT entered - each child connects independently - completed = await self._run_parallel_trace( - name=name, - variant_combos=variant_combos, - group=group, - group_ids=group_ids, - job_id=job_id, - api_key=api_key, - code_snippet=code_snippet, - env_config=env_config, - ) - - # Create parent tc with results injected - tc = TraceContext( - env=self, # type: ignore[arg-type] - name=name, - trace_id=trace_id, - api_key=api_key, - job_id=job_id, - _code_snippet=code_snippet, - _env_config=env_config, - ) - tc.results = completed - self._last_traces = completed - - # Compute aggregate reward (mean of non-None rewards) - rewards = [t.reward for t in completed if t.reward is not None] - if rewards: - tc.reward = sum(rewards) / len(rewards) - - yield tc - - async def _run_parallel_trace( - self, - name: str, - variant_combos: list[dict[str, Any]], - group: int, - group_ids: list[str] | None, - job_id: str | None, - api_key: str | None, - code_snippet: str | None, - env_config: dict[str, Any] | None, - ) -> list[TraceContext]: - """Run parallel trace execution using AST extraction. - - This method: - 1. Captures the caller's frame - 2. Extracts the with-block body via AST - 3. Creates (variants x group) TraceContext instances - 4. Runs the body in parallel - 5. 
Stores results in self._last_traces - - Args: - name: Task name - variant_combos: List of variant assignments (one per combination) - group: Runs per variant - group_ids: Optional list of group IDs (one per total trace) - job_id: Optional job ID (auto-detected from current job if not provided) - api_key: Optional API key - code_snippet: Captured code from the with-block - env_config: Environment configuration - """ - # Get the caller's frame (skip this method and the trace method) - frame = inspect.currentframe() - if frame is None: - raise ASTExtractionError("Cannot get current frame") - - try: - # Go up: _run_parallel_trace -> trace -> user code - caller_frame = frame.f_back - if caller_frame is not None: - caller_frame = caller_frame.f_back - if caller_frame is None: - raise ASTExtractionError("Cannot get caller frame") - - # Extract the with-block body - body_source, captured_locals = _get_with_block_body(caller_frame) - - finally: - del frame # Avoid reference cycles - - # Calculate total traces - total_traces = len(variant_combos) * group - - # Use provided group_ids or generate one shared group_id - if group_ids: - if len(group_ids) != total_traces: - raise ValueError( - f"group_ids length ({len(group_ids)}) must match " - f"total traces ({total_traces} = {len(variant_combos)} variants x {group} runs)" - ) - resolved_group_ids = group_ids - else: - # All traces share one auto-generated group_id - shared_group_id = str(uuid.uuid4()) - resolved_group_ids = [shared_group_id] * total_traces - - # Create TraceContext for each (variant, run) combination - trace_contexts: list[TraceContext] = [] - idx = 0 - for variant in variant_combos: - for _ in range(group): - tc = TraceContext( - env=self, # type: ignore[arg-type] - name=name, - api_key=api_key, - job_id=job_id, - _group_id=resolved_group_ids[idx], - _index=idx, - _variants=variant, - _code_snippet=code_snippet, - _env_config=env_config, - ) - trace_contexts.append(tc) - idx += 1 - - # Run in parallel - total = len(trace_contexts) - logger.info( - "Running %d traces for task '%s' (%d variants x %d runs)", - total, - name, - len(variant_combos), - group, - ) - completed = await run_parallel_traces(trace_contexts, body_source, captured_locals) - - # Store results - self._last_traces = completed - - # Calculate stats - rewards = [tc.reward for tc in completed if tc.reward is not None] - mean_reward = sum(rewards) / len(rewards) if rewards else 0.0 - success_count = sum(1 for tc in completed if tc.success) - - logger.info( - "Traces complete: %d/%d succeeded, mean_reward=%.3f", - success_count, - len(completed), - mean_reward, - ) - - return completed diff --git a/hud/trace/parallel.py b/hud/trace/parallel.py deleted file mode 100644 index 60b3bb8f..00000000 --- a/hud/trace/parallel.py +++ /dev/null @@ -1,127 +0,0 @@ -"""Parallel execution support for traces. - -This module provides AST extraction and parallel execution for running -the same trace body N times concurrently. -""" - -from __future__ import annotations - -import ast -import asyncio -import linecache -import logging -import textwrap -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from hud.trace.context import TraceContext - -logger = logging.getLogger(__name__) - - -class ASTExtractionError(Exception): - """Error extracting AST from source.""" - - -def _get_with_block_body(frame: Any) -> tuple[str, dict[str, Any]]: - """Extract the body of a with-block from the calling frame. 
- - Args: - frame: The calling frame (from inspect.currentframe()) - - Returns: - Tuple of (body_source, captured_locals) - """ - filename = frame.f_code.co_filename - lineno = frame.f_lineno - - # Check for interactive session - if filename.startswith("<") or filename in ("", ""): - raise ASTExtractionError("Cannot extract source from interactive session. Use a .py file.") - - # Read and parse source - lines = linecache.getlines(filename) - if not lines: - with open(filename, encoding="utf-8") as f: - lines = f.readlines() - - source = "".join(lines) - tree = ast.parse(source, filename=filename) - - # Find the async with containing this line - with_node = _find_async_with(tree, lineno) - if with_node is None: - raise ASTExtractionError(f"Cannot find 'async with' statement at line {lineno}") - - # Extract body source - body_source = _extract_body(lines, with_node) - - return body_source, frame.f_locals.copy() - - -def _find_async_with(tree: ast.AST, target_line: int) -> ast.AsyncWith | None: - """Find AsyncWith node containing the target line.""" - for node in ast.walk(tree): - if isinstance(node, ast.AsyncWith): - end_line = _get_end_line(node) - if node.lineno <= target_line <= end_line: - return node - return None - - -def _get_end_line(node: ast.AST) -> int: - """Get the last line number of an AST node.""" - end = getattr(node, "end_lineno", getattr(node, "lineno", 0)) - for child in ast.walk(node): - child_end = getattr(child, "end_lineno", 0) - if child_end > end: - end = child_end - return end - - -def _extract_body(lines: list[str], with_node: ast.AsyncWith) -> str: - """Extract the body source from an AsyncWith node.""" - if not with_node.body: - return "pass" - - start = with_node.body[0].lineno - 1 - end = _get_end_line(with_node.body[-1]) - - body = "".join(lines[start:end]) - return textwrap.dedent(body) - - -async def run_parallel_traces( - trace_contexts: list[TraceContext], - body_source: str, - captured_locals: dict[str, Any], -) -> list[TraceContext]: - """Run the trace body in parallel for multiple contexts. 
- - Returns the TraceContext objects after execution - they contain: - - trace_id - - index - - reward - - duration - - Any error is captured in the context - """ - - # Create runner function - wrapped = f"async def __runner__(tc):\n{textwrap.indent(body_source, ' ')}" - code = compile(wrapped, "", "exec") - namespace = captured_locals.copy() - exec(code, namespace) # noqa: S102 - runner = namespace["__runner__"] - - async def run_one(tc: TraceContext) -> TraceContext: - try: - async with tc: - await runner(tc) - except Exception as e: - logger.warning("Parallel trace %d failed: %s", tc.index, e) - # Store error in context for inspection - tc._error = e # type: ignore[attr-defined] - return tc - - results = await asyncio.gather(*[run_one(tc) for tc in trace_contexts]) - return list(results) diff --git a/hud/trace/tests/__init__.py b/hud/trace/tests/__init__.py deleted file mode 100644 index 93f3ee87..00000000 --- a/hud/trace/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for hud.trace module.""" diff --git a/hud/trace/tests/test_context.py b/hud/trace/tests/test_context.py deleted file mode 100644 index 38ccfbce..00000000 --- a/hud/trace/tests/test_context.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Tests for hud.trace.context module.""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from hud.trace.context import ( - TraceContext, - _httpx_request_hook, - _is_hud_url, - get_current_trace_headers, -) - - -class TestIsHudUrl: - """Tests for _is_hud_url helper.""" - - def test_inference_hud_ai_is_hud(self) -> None: - """inference.hud.ai is a HUD URL.""" - assert _is_hud_url("https://inference.hud.ai/v1/chat") is True - assert _is_hud_url("http://inference.hud.ai/v1/chat") is True - - def test_mcp_hud_ai_is_hud(self) -> None: - """mcp.hud.ai is a HUD URL.""" - assert _is_hud_url("https://mcp.hud.ai/browser") is True - assert _is_hud_url("http://mcp.hud.ai/some/path") is True - - def test_mcp_hud_so_is_hud(self) -> None: - """mcp.hud.so is a HUD URL.""" - assert _is_hud_url("https://mcp.hud.so/browser") is True - - def test_other_urls_are_not_hud(self) -> None: - """Other URLs are not HUD URLs.""" - assert _is_hud_url("https://example.com") is False - assert _is_hud_url("https://api.openai.com") is False - assert _is_hud_url("https://notinference.hud.ai.fake.com") is False - - -class TestHttpxRequestHook: - """Tests for _httpx_request_hook.""" - - def test_injects_trace_headers_for_hud_urls(self) -> None: - """Hook injects trace headers for HUD URLs when in trace context.""" - mock_request = MagicMock() - mock_request.url = "https://inference.hud.ai/v1/chat" - mock_request.headers = {} - - # Set up trace context - from hud.trace.context import _current_trace_headers - - token = _current_trace_headers.set({"Trace-Id": "test-trace-123"}) - - try: - _httpx_request_hook(mock_request) - - assert mock_request.headers["Trace-Id"] == "test-trace-123" - finally: - _current_trace_headers.reset(token) - - def test_injects_api_key_for_hud_urls(self) -> None: - """Hook injects API key for HUD URLs when no auth present.""" - mock_request = MagicMock() - mock_request.url = "https://mcp.hud.ai/browser" - mock_request.headers = {} - - with patch("hud.trace.context.settings") as mock_settings: - mock_settings.api_key = "test-api-key" - - _httpx_request_hook(mock_request) - - assert mock_request.headers["Authorization"] == "Bearer test-api-key" - - def test_does_not_override_existing_auth(self) -> None: - """Hook does not override existing 
Authorization header.""" - mock_request = MagicMock() - mock_request.url = "https://mcp.hud.ai/browser" - mock_request.headers = {"Authorization": "Bearer existing-token"} - - with patch("hud.trace.context.settings") as mock_settings: - mock_settings.api_key = "test-api-key" - - _httpx_request_hook(mock_request) - - assert mock_request.headers["Authorization"] == "Bearer existing-token" - - def test_ignores_non_hud_urls(self) -> None: - """Hook ignores non-HUD URLs.""" - mock_request = MagicMock() - mock_request.url = "https://api.openai.com/v1/chat" - mock_request.headers = {} - - # Set up trace context - from hud.trace.context import _current_trace_headers - - token = _current_trace_headers.set({"Trace-Id": "test-trace-123"}) - - try: - with patch("hud.trace.context.settings") as mock_settings: - mock_settings.api_key = "test-api-key" - - _httpx_request_hook(mock_request) - - # No headers should be added - assert "Trace-Id" not in mock_request.headers - assert "Authorization" not in mock_request.headers - finally: - _current_trace_headers.reset(token) - - -class TestTraceContext: - """Tests for TraceContext.""" - - def test_init_generates_trace_id(self) -> None: - """TraceContext generates trace_id if not provided.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task") - - assert tc.trace_id is not None - assert len(tc.trace_id) == 36 # UUID format - - def test_init_uses_provided_trace_id(self) -> None: - """TraceContext uses provided trace_id.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task", trace_id="custom-id") - - assert tc.trace_id == "custom-id" - - def test_headers_contains_trace_id(self) -> None: - """headers property returns dict with trace ID.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task", trace_id="test-123") - - assert tc.headers == {"Trace-Id": "test-123"} - - def test_success_true_when_no_error(self) -> None: - """success property returns True when no error.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task") - - assert tc.success is True - - def test_success_false_when_error(self) -> None: - """success property returns False when error is set.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task") - tc.error = ValueError("test error") - - assert tc.success is False - - def test_done_false_initially(self) -> None: - """done property returns False initially.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task") - - assert tc.done is False - - def test_variants_empty_by_default(self) -> None: - """variants is empty dict by default.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task") - - assert tc.variants == {} - - def test_variants_set_from_init(self) -> None: - """variants set from _variants parameter.""" - mock_env = MagicMock() - tc = TraceContext( - env=mock_env, - name="test-task", - _variants={"model": "gpt-4o", "temp": 0.7}, - ) - - assert tc.variants == {"model": "gpt-4o", "temp": 0.7} - - @pytest.mark.asyncio - async def test_context_manager_sets_headers(self) -> None: - """Context manager sets trace headers in contextvar.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task", trace_id="test-123") - - # Mock telemetry calls - with patch.object(tc, "_trace_enter", new_callable=AsyncMock): - with patch.object(tc, "_trace_exit", new_callable=AsyncMock): - assert get_current_trace_headers() is None - - async with tc: - headers = 
get_current_trace_headers() - assert headers is not None - assert headers["Trace-Id"] == "test-123" - - assert get_current_trace_headers() is None - - @pytest.mark.asyncio - async def test_context_manager_captures_error(self) -> None: - """Context manager captures exception in error field.""" - mock_env = MagicMock() - tc = TraceContext(env=mock_env, name="test-task") - - with patch.object(tc, "_trace_enter", new_callable=AsyncMock): - with patch.object(tc, "_trace_exit", new_callable=AsyncMock): - with pytest.raises(ValueError): - async with tc: - raise ValueError("test error") - - assert tc.error is not None - assert str(tc.error) == "test error" - assert tc.success is False - - @pytest.mark.asyncio - async def test_call_tool_delegates_to_env(self) -> None: - """call_tool delegates to environment.""" - mock_env = MagicMock() - mock_env.call_tool = AsyncMock(return_value="result") - - tc = TraceContext(env=mock_env, name="test-task") - result = await tc.call_tool("my_tool", {"arg": "value"}) - - mock_env.call_tool.assert_called_once_with("my_tool", {"arg": "value"}) - assert result == "result" - - def test_repr(self) -> None: - """__repr__ shows useful info.""" - mock_env = MagicMock() - tc = TraceContext( - env=mock_env, name="test-task", trace_id="abc12345-6789-0000-0000-000000000000" - ) - tc.reward = 0.95 - - repr_str = repr(tc) - assert "abc12345" in repr_str - assert "test-task" in repr_str - assert "0.95" in repr_str - - -class TestTraceContextPrompt: - """Tests for TraceContext.prompt feature.""" - - def test_prompt_defaults_from_env(self) -> None: - """TraceContext.prompt defaults from env.prompt.""" - mock_env = MagicMock() - mock_env.prompt = "Task prompt from environment" - - tc = TraceContext( - env=mock_env, - name="test-task", - trace_id="test-123", - ) - - assert tc.prompt == "Task prompt from environment" - - def test_prompt_none_when_env_has_no_prompt(self) -> None: - """TraceContext.prompt is None when env has no prompt.""" - mock_env = MagicMock(spec=[]) # No prompt attribute - - tc = TraceContext( - env=mock_env, - name="test-task", - trace_id="test-123", - ) - - assert tc.prompt is None - - def test_prompt_can_be_overridden(self) -> None: - """TraceContext.prompt can be set to override env default.""" - mock_env = MagicMock() - mock_env.prompt = "Original prompt" - - tc = TraceContext( - env=mock_env, - name="test-task", - trace_id="test-123", - ) - - tc.prompt = "Overridden prompt" - assert tc.prompt == "Overridden prompt" - - def test_prompt_included_in_payload(self) -> None: - """Prompt is included in trace payload.""" - mock_env = MagicMock() - mock_env.prompt = "Test prompt" - mock_env._all_hubs = False - - tc = TraceContext( - env=mock_env, - name="test-task", - trace_id="test-123", - ) - - payload = tc._build_base_payload() - assert payload.prompt == "Test prompt" diff --git a/hud/trace/tests/test_mixin.py b/hud/trace/tests/test_mixin.py deleted file mode 100644 index c6b90a33..00000000 --- a/hud/trace/tests/test_mixin.py +++ /dev/null @@ -1,175 +0,0 @@ -"""Tests for hud.trace.mixin module.""" - -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from hud.trace.mixin import TraceMixin, _expand_variants - - -class TestExpandVariants: - """Tests for _expand_variants helper.""" - - def test_none_returns_empty_dict(self) -> None: - """None variants returns list with empty dict.""" - result = _expand_variants(None) - assert result == [{}] - - def test_empty_dict_returns_empty_dict(self) -> None: - 
"""Empty variants returns list with empty dict.""" - result = _expand_variants({}) - assert result == [{}] - - def test_single_value_stays_single(self) -> None: - """Single non-list value stays as single variant.""" - result = _expand_variants({"model": "gpt-4o"}) - assert result == [{"model": "gpt-4o"}] - - def test_list_expands_to_variants(self) -> None: - """List value expands to multiple variants.""" - result = _expand_variants({"model": ["gpt-4o", "claude"]}) - assert result == [{"model": "gpt-4o"}, {"model": "claude"}] - - def test_multiple_lists_create_combinations(self) -> None: - """Multiple lists create all combinations.""" - result = _expand_variants( - { - "model": ["a", "b"], - "temp": [0.0, 1.0], - } - ) - - assert len(result) == 4 - assert {"model": "a", "temp": 0.0} in result - assert {"model": "a", "temp": 1.0} in result - assert {"model": "b", "temp": 0.0} in result - assert {"model": "b", "temp": 1.0} in result - - def test_mixed_single_and_list(self) -> None: - """Mixed single values and lists work correctly.""" - result = _expand_variants( - { - "model": ["gpt-4o", "claude"], - "temp": 0.7, - } - ) - - assert len(result) == 2 - assert {"model": "gpt-4o", "temp": 0.7} in result - assert {"model": "claude", "temp": 0.7} in result - - -class MockEnvironment(TraceMixin): - """Mock environment for testing TraceMixin.""" - - def __init__(self) -> None: - self.name = "test-env" - self._connections: dict[str, Any] = {} - self._last_traces = None - - @property - def is_parallelizable(self) -> bool: - return all(getattr(c, "is_remote", True) for c in self._connections.values()) - - @property - def local_connections(self) -> list[str]: - return [name for name, c in self._connections.items() if getattr(c, "is_local", False)] - - async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> Any: - return {"name": name, "arguments": arguments} - - async def __aenter__(self) -> MockEnvironment: - return self - - async def __aexit__(self, *args: Any) -> None: - pass - - -class TestTraceMixin: - """Tests for TraceMixin.""" - - @pytest.mark.asyncio - async def test_trace_single_creates_context(self) -> None: - """trace() with group=1 creates single TraceContext.""" - env = MockEnvironment() - - async with env.trace("test-task") as tc: - assert tc.name == "test-task" - assert tc.trace_id is not None - assert tc.variants == {} - - @pytest.mark.asyncio - async def test_trace_sets_reward(self) -> None: - """reward can be set on TraceContext.""" - env = MockEnvironment() - - async with env.trace("test-task") as tc: - tc.reward = 0.95 - - assert tc.reward == 0.95 - - @pytest.mark.asyncio - async def test_trace_with_variants_single(self) -> None: - """trace() with single variant value works.""" - env = MockEnvironment() - - async with env.trace("test-task", variants={"model": "gpt-4o"}) as tc: - assert tc.variants == {"model": "gpt-4o"} - - @pytest.mark.asyncio - async def test_trace_rejects_parallel_with_local_connections(self) -> None: - """trace() raises error for parallel with local connections.""" - env = MockEnvironment() - - # Add a local connection - mock_conn = MagicMock() - mock_conn.is_local = True - mock_conn.is_remote = False - env._connections["local-server"] = mock_conn - - with pytest.raises(ValueError, match="Cannot run parallel traces"): - async with env.trace("test-task", group=2) as tc: - pass - - @pytest.mark.asyncio - async def test_trace_allows_parallel_with_remote_connections(self) -> None: - """trace() allows parallel with only remote connections.""" 
- env = MockEnvironment() - - # Add a remote connection - mock_conn = MagicMock() - mock_conn.is_local = False - mock_conn.is_remote = True - env._connections["remote-server"] = mock_conn - - # This should not raise (though parallel execution is complex to test) - # Just verify it doesn't raise the local connection error - assert env.is_parallelizable is True - - @pytest.mark.asyncio - async def test_trace_rejects_zero_group(self) -> None: - """trace() raises error for group <= 0.""" - env = MockEnvironment() - - with pytest.raises(ValueError, match="group must be >= 1"): - async with env.trace("test-task", group=0) as tc: - pass - - def test_last_traces_none_initially(self) -> None: - """last_traces is None before any parallel execution.""" - env = MockEnvironment() - assert env.last_traces is None - - @pytest.mark.asyncio - async def test_trace_context_delegates_call_tool(self) -> None: - """TraceContext.call_tool delegates to environment.""" - env = MockEnvironment() - - async with env.trace("test-task") as tc: - result = await tc.call_tool("my_tool", {"arg": "value"}) - - assert result["name"] == "my_tool" - assert result["arguments"] == {"arg": "value"} diff --git a/hud/trace/tests/test_parallel.py b/hud/trace/tests/test_parallel.py deleted file mode 100644 index cf8056e2..00000000 --- a/hud/trace/tests/test_parallel.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Tests for hud.trace.parallel module.""" - -from __future__ import annotations - -import ast -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from hud.trace.parallel import ( - ASTExtractionError, - _extract_body, - _find_async_with, - _get_end_line, - run_parallel_traces, -) - - -class TestASTHelpers: - """Tests for AST helper functions.""" - - def test_find_async_with_finds_correct_node(self) -> None: - """_find_async_with finds the async with containing target line.""" - source = """ -async def main(): - x = 1 - async with something as ctx: - do_stuff() - more_stuff() - y = 2 -""" - tree = ast.parse(source) - - # Line 4 is inside the async with - node = _find_async_with(tree, 5) - assert node is not None - assert isinstance(node, ast.AsyncWith) - - def test_find_async_with_returns_none_when_not_found(self) -> None: - """_find_async_with returns None when line is outside async with.""" - source = """ -async def main(): - x = 1 - async with something as ctx: - do_stuff() - y = 2 -""" - tree = ast.parse(source) - - # Line 6 is outside the async with - node = _find_async_with(tree, 7) - assert node is None - - def test_get_end_line(self) -> None: - """_get_end_line returns last line of node.""" - source = """ -async with ctx: - line1() - line2() - line3() -""" - tree = ast.parse(source) - async_with = tree.body[0] - - end_line = _get_end_line(async_with) - assert end_line >= 4 # At least through line 4 - - def test_extract_body(self) -> None: - """_extract_body extracts the body source from async with.""" - source = """async with ctx: - do_thing() - more_thing() -""" - lines = source.split("\n") - lines = [line + "\n" for line in lines] - - tree = ast.parse(source) - async_with = tree.body[0] - - body = _extract_body(lines, async_with) - assert "do_thing()" in body - assert "more_thing()" in body - - -class TestRunParallelTraces: - """Tests for run_parallel_traces function.""" - - @pytest.mark.asyncio - async def test_runs_body_for_each_context(self) -> None: - """run_parallel_traces runs body for each TraceContext.""" - # Create mock trace contexts - mock_tcs = [] - for i in range(3): - tc = MagicMock() - tc.index = 
i - tc.__aenter__ = AsyncMock(return_value=tc) - tc.__aexit__ = AsyncMock(return_value=None) - mock_tcs.append(tc) - - # Simple body that sets reward - body_source = "tc.reward = tc.index * 10" - captured_locals: dict[str, object] = {} - - results = await run_parallel_traces(mock_tcs, body_source, captured_locals) - - assert len(results) == 3 - # Each context should have had __aenter__ and __aexit__ called - for tc in mock_tcs: - tc.__aenter__.assert_called_once() - tc.__aexit__.assert_called_once() - - @pytest.mark.asyncio - async def test_captures_exceptions(self) -> None: - """run_parallel_traces captures exceptions in context.""" - tc = MagicMock() - tc.index = 0 - tc.__aenter__ = AsyncMock(return_value=tc) - tc.__aexit__ = AsyncMock(return_value=None) - - # Body that raises - body_source = "raise ValueError('test error')" - captured_locals: dict[str, object] = {} - - results = await run_parallel_traces([tc], body_source, captured_locals) - - assert len(results) == 1 - # Error should be captured, not raised - assert hasattr(tc, "_error") or tc.__aexit__.called - - @pytest.mark.asyncio - async def test_uses_captured_locals(self) -> None: - """run_parallel_traces uses captured locals in body execution.""" - tc = MagicMock() - tc.index = 0 - tc.result = None - tc.__aenter__ = AsyncMock(return_value=tc) - tc.__aexit__ = AsyncMock(return_value=None) - - # Body that uses captured local - body_source = "tc.result = my_value * 2" - captured_locals = {"my_value": 21} - - results = await run_parallel_traces([tc], body_source, captured_locals) - - assert len(results) == 1 - - -class TestASTExtractionError: - """Tests for ASTExtractionError.""" - - def test_is_exception(self) -> None: - """ASTExtractionError is an exception.""" - error = ASTExtractionError("test message") - assert isinstance(error, Exception) - assert str(error) == "test message" From 478245e19b65472e171b8cb61b78ea2efe771043 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 08:49:41 -0800 Subject: [PATCH 06/92] cleanup and quality --- hud/cli/__init__.py | 94 ----- hud/cli/rft.py | 2 +- hud/cli/rft_status.py | 2 +- hud/cli/rl/__init__.py | 180 -------- hud/cli/rl/config.py | 101 ----- hud/cli/rl/display.py | 133 ------ hud/cli/rl/gpu.py | 63 --- hud/cli/rl/gpu_utils.py | 321 -------------- hud/cli/rl/local_runner.py | 607 --------------------------- hud/cli/rl/presets.py | 96 ----- hud/cli/rl/remote_runner.py | 463 --------------------- hud/cli/rl/rl_api.py | 150 ------- hud/cli/rl/vllm.py | 179 -------- hud/cli/rl/wait_utils.py | 89 ---- hud/cli/{rl => utils}/celebrate.py | 27 +- hud/cli/{rl => utils}/viewer.py | 3 +- hud/environment/connection.py | 14 + hud/eval/context.py | 23 +- hud/eval/manager.py | 128 ++++-- hud/eval/mixin.py | 102 +++-- hud/eval/parallel.py | 87 +++- hud/rl/README.md | 30 -- hud/rl/__init__.py | 1 - hud/rl/actor.py | 178 -------- hud/rl/buffer.py | 405 ------------------ hud/rl/chat_template.jinja | 101 ----- hud/rl/config.py | 193 --------- hud/rl/distributed.py | 132 ------ hud/rl/learner.py | 648 ----------------------------- hud/rl/tests/__init__.py | 1 - hud/rl/tests/test_learner.py | 186 --------- hud/rl/train.py | 394 ------------------ hud/rl/types.py | 101 ----- hud/rl/utils.py | 524 ----------------------- hud/rl/utils/start_vllm_server.sh | 30 -- hud/rl/vllm_adapter.py | 143 ------- 36 files changed, 294 insertions(+), 5637 deletions(-) delete mode 100644 hud/cli/rl/__init__.py delete mode 100644 hud/cli/rl/config.py delete mode 100644 hud/cli/rl/display.py delete mode 100644 
hud/cli/rl/gpu.py delete mode 100644 hud/cli/rl/gpu_utils.py delete mode 100644 hud/cli/rl/local_runner.py delete mode 100644 hud/cli/rl/presets.py delete mode 100644 hud/cli/rl/remote_runner.py delete mode 100644 hud/cli/rl/rl_api.py delete mode 100644 hud/cli/rl/vllm.py delete mode 100644 hud/cli/rl/wait_utils.py rename hud/cli/{rl => utils}/celebrate.py (86%) rename hud/cli/{rl => utils}/viewer.py (98%) delete mode 100644 hud/rl/README.md delete mode 100644 hud/rl/__init__.py delete mode 100644 hud/rl/actor.py delete mode 100644 hud/rl/buffer.py delete mode 100644 hud/rl/chat_template.jinja delete mode 100644 hud/rl/config.py delete mode 100644 hud/rl/distributed.py delete mode 100644 hud/rl/learner.py delete mode 100644 hud/rl/tests/__init__.py delete mode 100644 hud/rl/tests/test_learner.py delete mode 100644 hud/rl/train.py delete mode 100644 hud/rl/types.py delete mode 100644 hud/rl/utils.py delete mode 100755 hud/rl/utils/start_vllm_server.sh delete mode 100644 hud/rl/vllm_adapter.py diff --git a/hud/cli/__init__.py b/hud/cli/__init__.py index ae7f1b16..96b4fce4 100644 --- a/hud/cli/__init__.py +++ b/hud/cli/__init__.py @@ -961,100 +961,6 @@ def get( ) -@app.command() -def rl( - tasks_file: str | None = typer.Argument( - None, - help=( - "Path to tasks file (JSON/JSONL) or HuggingFace dataset name. " - "If not provided, looks for tasks.json or tasks.jsonl in current directory." - ), - ), - model: str | None = typer.Argument( - None, - help="Model to train from https://hud.ai/models (default: interactive selection)", - ), - config_file: Path | None = typer.Option( # noqa: B008 - None, - "--config", - "-c", - help="Path to existing configuration file", - ), - output_dir: str = typer.Option( - "checkpoints", - "--output-dir", - "-o", - help="Output directory for checkpoints", - ), - restart: bool = typer.Option( - False, - "--restart", - help="Restart the vLLM server before training", - ), - verbose: bool = typer.Option( - False, - "--verbose", - "-v", - help="Enable verbose output", - ), - local: bool = typer.Option( - False, - "--local", - help="Run training locally instead of using remote API server", - ), - no_ddp: bool = typer.Option( - False, - "--no-ddp", - help="Disable DDP even with multiple GPUs", - ), - ddp_gpus: str | None = typer.Option( - None, - "--ddp-gpus", - help="Specific GPUs for DDP (e.g., '0,1,2,3')", - ), - yes: bool = typer.Option( - False, - "--yes", - "-y", - help="Auto-accept all prompts and use defaults (lazy mode)", - ), - vllm_gpu: int | None = typer.Option( - None, - "--vllm-gpu", - help="Specific GPU for vLLM server", - ), - vllm_gpu_count: int = typer.Option( - 1, - "--vllm-gpu-count", - help="Number of GPUs for vLLM server", - ), - skip_vllm_startup: bool = typer.Option( - False, - "--skip_vllm_startup", - help="Skip the vLLM server startup", - ), -) -> None: - """🎯 Run GRPO reinforcement learning training on tasks.""" - # Import from the rl module - from .rl import rl_command - - rl_command( - tasks_file=tasks_file, - model=model, - config_file=config_file, - output_dir=output_dir, - restart=restart, - verbose=verbose, - local=local, - no_ddp=no_ddp, - ddp_gpus=ddp_gpus, - vllm_gpu=vllm_gpu, - vllm_gpu_count=vllm_gpu_count, - yes=yes, - skip_vllm_startup=skip_vllm_startup, - ) - - @app.command() def convert( tasks_file: str = typer.Argument( diff --git a/hud/cli/rft.py b/hud/cli/rft.py index 1c910b73..43c35e94 100644 --- a/hud/cli/rft.py +++ b/hud/cli/rft.py @@ -243,7 +243,7 @@ def rft_command( hud_console.info("Skipping task preview in auto-accept 
mode (--yes)") else: try: - from hud.cli.rl.viewer import show_json_interactive + from hud.cli.utils.viewer import show_json_interactive hud_console.section_title("Task Preview") show_json_interactive( diff --git a/hud/cli/rft_status.py b/hud/cli/rft_status.py index e04e9b4a..55566a39 100644 --- a/hud/cli/rft_status.py +++ b/hud/cli/rft_status.py @@ -6,7 +6,7 @@ import typer from rich.console import Console -from hud.cli.rl.viewer import show_json_interactive +from hud.cli.utils.viewer import show_json_interactive from hud.settings import settings from hud.utils.hud_console import HUDConsole diff --git a/hud/cli/rl/__init__.py b/hud/cli/rl/__init__.py deleted file mode 100644 index 57b29546..00000000 --- a/hud/cli/rl/__init__.py +++ /dev/null @@ -1,180 +0,0 @@ -"""RL training command for HUD CLI.""" - -from __future__ import annotations - -import logging -import os -from typing import TYPE_CHECKING - -import typer -from rich.console import Console - -from hud.cli.utils.tasks import find_tasks_file -from hud.utils.hud_console import hud_console - -console = Console() - -if TYPE_CHECKING: - from pathlib import Path - - -def rl_command( - tasks_file: str | None = typer.Argument( - None, - help="Path to tasks file (JSON/JSONL) or HuggingFace dataset name", - ), - model: str | None = typer.Argument( - None, - help="Model to train from https://hud.ai/models (default: interactive selection)", - ), - config_file: Path | None = typer.Option( # noqa: B008 - None, - "--config", - "-c", - help="Path to existing configuration file", - ), - output_dir: str = typer.Option( - "/checkpoints", - "--output-dir", - "-o", - help="Output directory for checkpoints", - ), - restart: bool = typer.Option( - False, - "--restart", - help="Restart the vLLM server before training", - ), - verbose: bool = typer.Option( - False, - "--verbose", - "-v", - help="Enable verbose output", - ), - # DDP options - no_ddp: bool = typer.Option( - False, - "--no-ddp", - help="Disable DDP even with multiple GPUs", - ), - ddp_gpus: str | None = typer.Option( - None, - "--ddp-gpus", - help="Specific GPUs for DDP (e.g., '0,1,2,3')", - ), - vllm_gpu: int | None = typer.Option( - None, - "--vllm-gpu", - help="Specific GPU for vLLM server", - ), - # Execution mode options - local: bool = typer.Option( - False, - "--local", - help="Run training locally instead of using remote API server", - ), - yes: bool = typer.Option( - False, - "--yes", - "-y", - help="Auto-accept all prompts and use defaults (lazy mode)", - ), - vllm_gpu_count: int = typer.Option( - None, - "--vllm-gpu-count", - help="Number of GPUs for vLLM server", - ), - skip_vllm_startup: bool = typer.Option( - False, - "--skip-vllm-startup", - help="Skip local vLLM server startup (for internal use)", - ), -) -> None: - """Run GRPO reinforcement learning training on tasks.""" - # Configure logging based on verbose flag BEFORE any output - if not verbose: - os.environ["HUD_LOG_LEVEL"] = "WARNING" - logging.basicConfig(level=logging.WARNING, force=True) - root_logger = logging.getLogger() - root_logger.setLevel(logging.WARNING) - - # Suppress INFO logs from various components - for logger_name in [ - "httpx", - "hud.agents", - "hud.utils.design", - "hud", - "asyncio", - "transformers", - ]: - logging.getLogger(logger_name).setLevel(logging.WARNING) - logging.getLogger("hud.agents.base").setLevel(logging.WARNING) - else: - logging.basicConfig(level=logging.INFO) - - hud_console.header("HUD RL Training") - - # Determine execution mode - use_remote = not local - - if not tasks_file: - 
tasks_file = find_tasks_file(tasks_file) - if not tasks_file: - hud_console.warning("No tasks file found in current directory") - hud_console.hint( - "Download a HF dataset using `hud get ` (e.g., `hud get hud-evals/2048-basic`)" # noqa: E501 - ) - hud_console.hint("or create a tasks file manually.") - raise typer.Exit(1) - - # If user ran bare `hud rl`, guide them through remote task conversion flow - # before proceeding (remote only) - if use_remote: - try: - from hud.cli.flows.tasks import convert_tasks_to_remote - - console.print("[cyan]Preparing remote training tasks...[/cyan]") - tasks_file = convert_tasks_to_remote(tasks_file) - except typer.Exit: - raise - except Exception as e: - hud_console.warning(f"[red]Tasks file is not valid for remote training: {e!s}[/red]") - hud_console.hint("Either ensure the tasks file has remote urls") - hud_console.hint("Or rerun `hud rl` within an environment directory") - raise typer.Exit(1) from e - - try: - from .remote_runner import run_remote_training - - run_remote_training( - tasks_file=tasks_file, - model=model, - config_file=config_file, - output_dir=output_dir, - vllm_gpu_count=vllm_gpu_count, - yes=yes, - ) - return - except Exception as e: - console.print(f"[red]❌ Remote training failed: {e!s}[/red]") - raise typer.Exit(1) from e - - # Local execution flow delegated to local_runner (imports heavy deps lazily) - from .local_runner import run_local_training - - run_local_training( - tasks_file=tasks_file, - model=model, - config_file=config_file, - output_dir=output_dir, - yes=yes, - restart=restart, - verbose=verbose, - no_ddp=no_ddp, - ddp_gpus=ddp_gpus, - vllm_gpu=vllm_gpu, - skip_vllm_startup=skip_vllm_startup, - ) - - -# Export the command function -__all__ = ["rl_command"] diff --git a/hud/cli/rl/config.py b/hud/cli/rl/config.py deleted file mode 100644 index fd6721aa..00000000 --- a/hud/cli/rl/config.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Configuration generation and management for RL training.""" - -from __future__ import annotations - -import json -from typing import TYPE_CHECKING, Any - -from rich.console import Console - -from hud.rl.config import Config, validate_vl_model -from hud.utils.hud_console import hud_console - -from .display import display_preset_table -from .presets import estimate_memory_usage - -if TYPE_CHECKING: - from pathlib import Path -console = Console() - - -def generate_config_interactive( - model_name: str, - presets: list[dict[str, Any]], - yes: bool = False, -) -> tuple[Config, float]: - """Generate RL training configuration interactively.""" - # Validate model is a VL model - validate_vl_model(model_name) - - # Display preset options - if not yes: - display_preset_table(presets, 80.0) # Assuming A100 80GB - - # Let user select preset - if yes: - # Use default preset (Balanced if available, otherwise first) - preset_choice = 1 if len(presets) > 1 else 0 - selected_preset = presets[preset_choice] - hud_console.info(f"Auto-selecting preset: {selected_preset['name']} (--yes mode)") - else: - preset_choice = hud_console.select( - "Select a training configuration preset:", - choices=[{"name": p["name"], "value": i} for i, p in enumerate(presets)], - default=1 if len(presets) > 1 else 0, # Default to "Balanced" if available - ) - selected_preset = presets[preset_choice] # type: ignore - - # Use preset values directly - max_steps_per_episode = selected_preset["max_steps_per_episode"] - - # Calculate memory estimate - max_pixels = 256 * 28 * 28 - estimated_memory = estimate_memory_usage( - 
selected_preset["mini_batch_size"], - max_steps_per_episode, - selected_preset["max_new_tokens"], - max_pixels, - ) - - config_adds = { - "actor": { - "max_new_tokens": selected_preset["max_new_tokens"], - "max_parallel_episodes": selected_preset["batch_size"], - "max_steps_per_episode": selected_preset["max_steps_per_episode"], - "force_tool_choice": True, - }, - "training": { - "mini_batch_size": selected_preset["mini_batch_size"], - "group_size": selected_preset["group_size"], - "batch_size": selected_preset["batch_size"], - "lr": selected_preset["lr"], - "epochs": selected_preset["epochs"], - }, - "verbose": True, - } - - # Create config - config = Config.from_dict(config_adds) - - return config, estimated_memory - - -def save_config(config: Config, path: Path) -> None: - """Save configuration to a JSON file.""" - config_dict = config.to_dict() - - with open(path, "w", encoding="utf-8") as f: - json.dump(config_dict, f, indent=2) - f.write("\n") # Add newline at end of file - - if not path.name.startswith("."): # Don't show message for temp files - console.print(f"[green]✅ Configuration saved to {path}[/green]") - - -def load_config(path: Path) -> Config: - """Load configuration from a JSON file.""" - with open(path, encoding="utf-8") as f: - data = json.load(f) - - # Use Config.from_dict which handles missing fields gracefully - return Config.from_dict(data) diff --git a/hud/cli/rl/display.py b/hud/cli/rl/display.py deleted file mode 100644 index 06435cf5..00000000 --- a/hud/cli/rl/display.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Display utilities for RL training configuration.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -from rich.console import Console -from rich.table import Table - -if TYPE_CHECKING: - from hud.rl.config import Config - -console = Console() - - -def display_gpu_info(gpu_info: dict[str, Any]) -> None: - """Display GPU information in a table.""" - if not gpu_info["available"]: - console.print(f"[red]❌ CUDA not available: {gpu_info.get('error', 'Unknown error')}[/red]") - return - - gpu_table = Table(title="🖥️ Available GPUs", title_style="bold cyan") - gpu_table.add_column("Index", style="yellow") - gpu_table.add_column("Name", style="cyan") - gpu_table.add_column("Memory", style="green") - - for device in gpu_info["devices"]: - gpu_table.add_row(f"GPU {device['index']}", device["name"], f"{device['memory_gb']:.1f} GB") - - console.print(gpu_table) - - -def display_preset_table(presets: list[dict[str, Any]], gpu_memory_gb: float) -> None: - """Display training configuration presets in a table.""" - preset_table = Table(title="📊 Training Configuration Presets", title_style="bold cyan") - preset_table.add_column("Option", style="yellow") - preset_table.add_column("Steps", style="cyan") - preset_table.add_column("Mini-batch", style="cyan") - preset_table.add_column("Group", style="cyan") - preset_table.add_column("Episodes/batch", style="cyan") - - # Add time columns for A100 - if gpu_memory_gb >= 40: - preset_table.add_column("Tasks/hour", style="green") - preset_table.add_column("Updates/hour", style="green") - - for i, preset in enumerate(presets): - row = [ - f"{i + 1}. 
{preset['name']}", - str(preset["max_steps_per_episode"]), - str(preset["mini_batch_size"]), - str(preset["group_size"]), - str(preset["batch_size"]), - ] - if "tasks_per_hour" in preset: - row.extend( - [ - str(preset["tasks_per_hour"]), - str(preset["steps_per_hour"]), - ] - ) - preset_table.add_row(*row) - - console.print("\n") - console.print(preset_table) - console.print("\n") - - -def display_config_summary( - config: Config, tasks_count: int, gpu_info: dict[str, Any], estimated_memory: float -) -> None: - """Display comprehensive configuration summary for review.""" - console.print("\n[bold cyan]📋 RL Training Configuration Summary[/bold cyan]\n") - - # GPU Information - if gpu_info["available"]: - gpu_table = Table(title="🖥️ GPU Information", title_style="bold yellow") - gpu_table.add_column("Property", style="cyan") - gpu_table.add_column("Value", style="green") - - device = gpu_info["devices"][0] # Primary GPU - gpu_table.add_row("GPU 0", device["name"]) - gpu_table.add_row("Memory", f"{device['memory_gb']:.1f} GB") - gpu_table.add_row("Compute Capability", "8.0") # Assuming A100 - - console.print(gpu_table) - - # Model Configuration - model_table = Table(title="🤖 Model Configuration", title_style="bold yellow") - model_table.add_column("Parameter", style="cyan") - model_table.add_column("Value", style="green") - - model_table.add_row("Base Model", config.model.base_model) - model_table.add_row("LoRA Rank (r)", str(config.model.lora_r)) - model_table.add_row("LoRA Alpha", str(config.model.lora_alpha)) - model_table.add_row("LoRA Dropout", str(config.model.lora_dropout)) - - console.print(model_table) - - # Training Configuration - training_table = Table(title="🎯 Training Configuration", title_style="bold yellow") - training_table.add_column("Parameter", style="cyan") - training_table.add_column("Value", style="green") - - training_table.add_row("Tasks Count", str(tasks_count)) - training_table.add_row("Learning Rate", f"{config.training.lr:.1e}") - training_table.add_row("Epochs", str(config.training.epochs)) - training_table.add_row("Mini Batch Size", str(config.training.mini_batch_size)) - training_table.add_row("Batch Size", str(config.training.batch_size)) - training_table.add_row("Group Size", str(config.training.group_size)) - training_table.add_row("Training Steps", str(config.training.training_steps)) - training_table.add_row("Max Parallel Episodes", str(config.actor.max_parallel_episodes)) - - console.print(training_table) - - # Memory Estimation - memory_table = Table(title="💾 Memory Estimation", title_style="bold yellow") - memory_table.add_column("Metric", style="cyan") - memory_table.add_column("Value", style="green") - - memory_table.add_row("Estimated GPU Memory", f"{estimated_memory:.1f} GB") - if gpu_info["available"]: - available_memory = gpu_info["devices"][0]["memory_gb"] - memory_table.add_row("Available GPU Memory", f"{available_memory:.1f} GB") - - if estimated_memory > available_memory: - status = "[red]⚠️ May exceed available memory[/red]" - else: - status = "[green]✅ Within memory limits[/green]" - memory_table.add_row("Status", status) - - console.print(memory_table) - console.print("\n") diff --git a/hud/cli/rl/gpu.py b/hud/cli/rl/gpu.py deleted file mode 100644 index 56690cd7..00000000 --- a/hud/cli/rl/gpu.py +++ /dev/null @@ -1,63 +0,0 @@ -"""GPU detection and validation utilities for RL training.""" - -from __future__ import annotations - -import subprocess -from typing import Any - - -def detect_cuda_devices() -> dict[str, Any]: - """Detect 
available CUDA devices and their properties.""" - try: - # Check if CUDA is available - result = subprocess.run( - ["nvidia-smi", "--query-gpu=index,name,memory.total", "--format=csv,noheader,nounits"], # noqa: S607 - capture_output=True, - text=True, - check=True, - ) - - if result.returncode != 0: - return {"available": False, "error": "nvidia-smi command failed"} - - devices = [] - for line in result.stdout.strip().split("\n"): - parts = line.split(", ") - if len(parts) >= 3: - devices.append( - { - "index": int(parts[0]), - "name": parts[1], - "memory_gb": float(parts[2]) / 1024, # Convert MB to GB - } - ) - - return {"available": True, "devices": devices} - - except FileNotFoundError: - return { - "available": False, - "error": "nvidia-smi not found - CUDA drivers may not be installed", - } - except Exception as e: - return {"available": False, "error": str(e)} - - -def select_gpu_for_vllm(devices: list[dict[str, Any]]) -> int: - """Select the best GPU for vLLM server (typically GPU 1 if available).""" - if len(devices) > 1: - # Prefer GPU 1 for vLLM to leave GPU 0 for other processes - return 1 - return 0 - - -def validate_gpu_memory(gpu_memory_gb: float, model_size: str = "3B") -> bool: - """Validate if GPU has sufficient memory for the model.""" - min_memory_requirements = { - "3B": 12.0, # Minimum for Qwen 2.5 VL 3B - "7B": 24.0, - "14B": 40.0, - } - - min_required = min_memory_requirements.get(model_size, 12.0) - return gpu_memory_gb >= min_required diff --git a/hud/cli/rl/gpu_utils.py b/hud/cli/rl/gpu_utils.py deleted file mode 100644 index 8b999aa2..00000000 --- a/hud/cli/rl/gpu_utils.py +++ /dev/null @@ -1,321 +0,0 @@ -"""GPU utilities for DDP training.""" - -from __future__ import annotations - -import logging -import subprocess -import time -from typing import TYPE_CHECKING, Any - -from hud.utils.hud_console import HUDConsole - -if TYPE_CHECKING: - from hud.rl.config import Config -hud_console = HUDConsole(logging.getLogger(__name__)) - - -def get_gpu_memory_info() -> dict[int, dict[str, Any]]: - """Get memory usage information for all GPUs.""" - - gpu_memory = {} - try: - # Get memory info for all GPUs - cmd = [ - "nvidia-smi", - "--query-gpu=index,memory.used,memory.total,memory.free", - "--format=csv,noheader,nounits", - ] - result = subprocess.run(cmd, capture_output=True, text=True, check=True) # noqa: S603 - - for line in result.stdout.strip().split("\n"): - if not line: - continue - parts = line.split(", ") - if len(parts) >= 4: - gpu_idx = int(parts[0]) - memory_used = float(parts[1]) - memory_total = float(parts[2]) - memory_free = float(parts[3]) - gpu_memory[gpu_idx] = { - "used_mb": memory_used, - "total_mb": memory_total, - "free_mb": memory_free, - "used_pct": (memory_used / memory_total) * 100, - } - - # Get process information per GPU - for gpu_idx in gpu_memory: # noqa: PLC0206 - cmd = [ - "nvidia-smi", - "-i", - str(gpu_idx), - "--query-compute-apps=pid,used_memory", - "--format=csv,noheader,nounits", - ] - try: - result = subprocess.run(cmd, capture_output=True, text=True, check=True) # noqa: S603 - processes = [] - for line in result.stdout.strip().split("\n"): - if not line: - continue - parts = line.split(", ") - if len(parts) >= 2: - pid = int(parts[0]) - memory_mb = float(parts[1]) - processes.append({"pid": pid, "memory_mb": memory_mb}) - gpu_memory[gpu_idx]["processes"] = processes - except Exception as e: - hud_console.error(f"Failed to get process info for GPU {gpu_idx}: {e}") - gpu_memory[gpu_idx]["processes"] = [] - - except Exception as e: - 
hud_console.error(f"Failed to get GPU memory info {e}") - return {} - - return gpu_memory - - -def health_check_gpus(gpu_indices: list[int]) -> dict[str, Any]: - """Perform health check on specified GPUs including memory status. - - Returns: - Dict with: - - healthy_gpus: List of healthy GPU indices - - unhealthy_gpus: Dict of unhealthy GPU index -> error message - - all_healthy: Boolean indicating if all GPUs are healthy - - memory_issues: Boolean indicating if there are memory issues - """ - import torch - from rich.console import Console - from rich.table import Table - - console = Console() - - console.print("\n[bold cyan]🏥 GPU Health Check[/bold cyan]") - - # First get memory info - memory_info = get_gpu_memory_info() - - healthy_gpus = [] - unhealthy_gpus = {} - memory_issues = [] - - # Create a table for results - table = Table(title="GPU Health Status") - table.add_column("GPU", style="cyan") - table.add_column("Memory Usage", style="yellow") - table.add_column("Status", style="green") - table.add_column("Details", style="yellow") - - for gpu_idx in gpu_indices: - # Memory info - mem_str = "Unknown" - if gpu_idx in memory_info: - mem = memory_info[gpu_idx] - used_gb = mem["used_mb"] / 1024 - total_gb = mem["total_mb"] / 1024 - mem_str = f"{used_gb:.1f}/{total_gb:.1f} GB ({mem['used_pct']:.0f}%)" - - # Check for high memory usage - if mem["used_pct"] > 70: - memory_issues.append(gpu_idx) - proc_info = f" ({len(mem['processes'])} processes)" if mem["processes"] else "" - unhealthy_gpus[gpu_idx] = f"High memory usage{proc_info}" - table.add_row( - f"GPU {gpu_idx}", mem_str, "❌ Unhealthy", f"High memory usage{proc_info}" - ) - continue - - # If no severe memory issue, do accessibility test - try: - # Try to allocate a small tensor on the GPU - torch.cuda.set_device(gpu_idx) - device = torch.device(f"cuda:{gpu_idx}") - - # Test basic allocation - test_tensor = torch.zeros(100, 100, device=device) - - # Test computation - result = torch.matmul(test_tensor, test_tensor) - - # Force synchronization - torch.cuda.synchronize(device) - - # Clean up - del test_tensor, result - torch.cuda.empty_cache() - - healthy_gpus.append(gpu_idx) - table.add_row(f"GPU {gpu_idx}", mem_str, "✅ Healthy", "Passed all tests") - - except Exception as e: - error_msg = str(e) - if "busy or unavailable" in error_msg: - short_msg = "Device busy or unavailable" - elif "out of memory" in error_msg: - short_msg = "Insufficient memory" - else: - short_msg = error_msg[:50] + "..." if len(error_msg) > 50 else error_msg - - unhealthy_gpus[gpu_idx] = short_msg - table.add_row(f"GPU {gpu_idx}", mem_str, "❌ Unhealthy", short_msg) - - # Small delay between GPU checks - time.sleep(0.1) - - console.print(table) - - return { - "healthy_gpus": healthy_gpus, - "unhealthy_gpus": unhealthy_gpus, - "all_healthy": len(unhealthy_gpus) == 0, - "memory_issues": memory_issues, - } - - -def calculate_optimal_gpu_allocation(gpu_info: dict[str, Any], config: Config) -> dict[str, Any]: - """Calculate optimal GPU allocation for DDP GRPO training. - - Key insight: In GRPO, we want to process groups in parallel. - Optimal case: num_gpus = num_groups (each GPU processes 1 group). 
- """ - devices = gpu_info["devices"] - available_gpus = [device["index"] for device in devices] - - # Need at least 2 GPUs (1 for training, 1 for vLLM) - if len(available_gpus) < 2: - return {"use_ddp": False, "reason": "Need at least 2 GPUs"} - - # Reserve last GPU for vLLM - vllm_gpu = available_gpus[-1] - training_gpus = available_gpus[:-1] - - # Calculate number of groups - batch_size = config.training.batch_size - group_size = config.training.group_size - num_groups = batch_size // group_size - - if num_groups == 0: - num_groups = 1 - - # Optimal: Use exactly num_groups GPUs (each processes 1 group in parallel) - # But cap at available training GPUs - optimal_gpu_count = min(len(training_gpus), num_groups) - - # Only use DDP if we have more than 1 group and more than 1 GPU - use_ddp = optimal_gpu_count > 1 and num_groups > 1 - - if not use_ddp: - # Single GPU training - return { - "use_ddp": False, - "reason": f"Single GPU sufficient for {num_groups} group(s)", - "training_gpus": [training_gpus[0]], - "vllm_gpu": vllm_gpu, - "num_groups": num_groups, - } - - # Use optimal number of GPUs for DDP - training_gpus = training_gpus[:optimal_gpu_count] - - return { - "use_ddp": True, - "training_gpus": training_gpus, - "vllm_gpu": vllm_gpu, - "num_groups": num_groups, - "groups_per_gpu": num_groups / len(training_gpus), - "parallel_efficiency": min( - 1.0, num_groups / len(training_gpus) - ), # 1.0 = perfect load balance - } - - -def adjust_config_for_ddp(config: Config, num_gpus: int) -> Config: - """Adjust configuration for optimal DDP performance. - - Scaling rule: - - For 1 GPU: batch_size = 2 * group_size - - For N GPUs (N > 1): batch_size = N * group_size - - This ensures each GPU processes exactly 1 group in parallel for optimal performance. - """ - group_size = config.training.group_size - - # Apply scaling rule - if num_gpus == 1: - # Special case: 2 groups for single GPU - groups_per_gpu = 2 - config.training.batch_size = 2 * group_size - else: - groups_per_gpu = config.training.batch_size // group_size - # Multi-GPU: each GPU processes groups_per_gpu groups - config.training.batch_size = num_gpus * group_size * groups_per_gpu - - # Update max_parallel_episodes to match - config.actor.max_parallel_episodes = config.training.batch_size - - config.training.num_gpus = num_gpus - - # Log the adjustment - from rich.console import Console - - console = Console() - console.print( - f"\n[cyan]📊 Adjusted batch_size to {config.training.batch_size} ({config.training.batch_size // group_size} groups)[/cyan]" # noqa: E501 - ) - console.print( - f"[cyan] Each of the {num_gpus} GPU(s) will process {groups_per_gpu} group(s) in parallel[/cyan]" # noqa: E501 - ) - - return config - - -def kill_high_memory_processes(memory_threshold: float = 70.0) -> int: - """Kill all GPU processes using more than threshold% memory. 
- - Returns: - Number of processes killed - """ - from rich.console import Console - - console = Console() - - memory_info = get_gpu_memory_info() - killed_count = 0 - - for gpu_idx, info in memory_info.items(): - if info["used_pct"] > memory_threshold: - for proc in info.get("processes", []): - pid = proc["pid"] - try: - # Try graceful termination first - subprocess.run(["kill", "-TERM", str(pid)], check=False, capture_output=True) # noqa: S603, S607 - killed_count += 1 - console.print( - f"[yellow]Terminating PID {pid} on GPU {gpu_idx} ({proc['memory_mb'] / 1024:.1f} GB)[/yellow]" # noqa: E501 - ) - except Exception as e: - console.print(f"[red]Failed to kill PID {pid}: {e}[/red]") - - if killed_count > 0: - console.print(f"\n[yellow]Sent termination signal to {killed_count} processes...[/yellow]") - time.sleep(3) - - # Force kill any remaining - for info in memory_info.values(): - for proc in info.get("processes", []): - pid = proc["pid"] - try: - # Check if still running - subprocess.run( # noqa: S603 - ["kill", "-0", str(pid)], # noqa: S607 - check=True, - capture_output=True, - ) - # If no error, process is still running, force kill - subprocess.run(["kill", "-KILL", str(pid)], check=False) # noqa: S603, S607 - console.print(f"[red]Force killed PID {pid}[/red]") - except Exception: - hud_console.error(f"Failed to kill PID {pid}") - - return killed_count diff --git a/hud/cli/rl/local_runner.py b/hud/cli/rl/local_runner.py deleted file mode 100644 index dec63677..00000000 --- a/hud/cli/rl/local_runner.py +++ /dev/null @@ -1,607 +0,0 @@ -""" -Local runner for HUD RL training. - -This module encapsulates the local training flow and imports heavy -dependencies (torch, transformers, etc.) only when actually running -locally. The CLI entrypoint should import this module lazily to avoid -pulling heavy deps during remote-only usage. -""" - -from __future__ import annotations - -import asyncio -import os -import subprocess -import sys -from pathlib import Path - -from rich.console import Console - -from hud.rl.config import validate_vl_model -from hud.utils.hud_console import hud_console -from hud.utils.tasks import load_tasks - -console = Console() - - -def run_local_training( - *, - tasks_file: str, - model: str | None, - config_file: Path | None, - output_dir: str, - yes: bool, - restart: bool, - verbose: bool, - no_ddp: bool, - ddp_gpus: str | None, - vllm_gpu: int | None, - skip_vllm_startup: bool, -) -> None: - """Run RL training locally on the current machine. - - Heavy modules are imported inside this function to avoid import-time side effects - during remote-only runs. - """ - # Light-weight utilities - from .config import generate_config_interactive, load_config, save_config - from .display import display_config_summary, display_gpu_info - from .gpu import detect_cuda_devices, validate_gpu_memory - from .presets import get_training_presets - - # Python version compatibility warning for vLLM - python_version = sys.version_info - if python_version.major == 3 and python_version.minor >= 13: - console.print("[red]⚠️ Warning: Python 3.13+ detected![/red]") - console.print("[yellow]vLLM has compatibility issues with Python 3.13.[/yellow]") - console.print("[yellow]Recommended: Use Python 3.12 or 3.11[/yellow]") - console.print("\n[dim]To create a new environment with Python 3.12:[/dim]") - console.print("[dim] 1. Exit this shell: exit[/dim]") - console.print("[dim] 2. Remove current venv: sudo rm -rf .venv[/dim]") - console.print("[dim] 3. 
Create new venv: uv venv --python 3.12[/dim]") - console.print("[dim] 4. Install dependencies: uv pip install -e '.[rl]'[/dim]") - - try: - import typer - - if not yes: - if not typer.confirm("\nDo you want to continue anyway?", default=False): - raise typer.Exit(1) - else: - hud_console.warning("Auto-continuing despite Python 3.13+ (--yes mode)") - except Exception as e: - hud_console.warning(f"Failed to confirm: {e}") - return - - # Step 1: Validate CUDA devices - console.print("[yellow]Checking GPU availability...[/yellow]") - gpu_info = detect_cuda_devices() - - if not gpu_info["available"]: - console.print(f"[red]❌ {gpu_info['error']}[/red]") - console.print("[yellow]RL training requires CUDA-capable GPUs[/yellow]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - display_gpu_info(gpu_info) - - # Perform GPU health check (imports torch lazily) - all_gpu_indices = [device["index"] for device in gpu_info["devices"]] - from .gpu_utils import health_check_gpus # heavy import (torch) - - health_results = health_check_gpus(all_gpu_indices) - - if not health_results["all_healthy"]: - console.print("\n[yellow]⚠️ Some GPUs failed health checks![/yellow]") - console.print( - f"[yellow]Unhealthy GPUs: {list(health_results['unhealthy_gpus'].keys())}[/yellow]" - ) - - if not health_results["healthy_gpus"]: - console.print("[red]❌ No healthy GPUs available for training![/red]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - console.print( - f"\n[cyan]You have {len(health_results['healthy_gpus'])} healthy GPUs available.[/cyan]" - ) - - try: - import typer - - if yes: - continue_training = True - hud_console.info("Auto-continuing with healthy GPUs only (--yes mode)") - else: - continue_training = typer.confirm( - "\nContinue with healthy GPUs only?", default=True - ) - except Exception: - continue_training = True - - if not continue_training: - healthy_str = ",".join(map(str, health_results["healthy_gpus"])) - console.print("\n[yellow]Exiting. 
Please resolve GPU issues and try again.[/yellow]") - console.print("\n[cyan]💡 Tip: To use only healthy GPUs, you can run:[/cyan]") - console.print(f"[white]hud rl {tasks_file} --ddp-gpus {healthy_str} --local[/white]\n") - try: - import typer - - raise typer.Exit(0) - except Exception: - return - else: - # Continue with healthy GPUs only - gpu_info["devices"] = [ - d for d in gpu_info["devices"] if d["index"] in health_results["healthy_gpus"] - ] - console.print( - f"\n[green]✅ Continuing with {len(gpu_info['devices'])} healthy GPUs[/green]" - ) - - # Get primary GPU memory for configuration - primary_gpu = gpu_info["devices"][0] - gpu_memory_gb = primary_gpu["memory_gb"] - - # Validate GPU memory for 3B model - if not validate_gpu_memory(gpu_memory_gb, "3B"): - console.print(f"[red]❌ Insufficient GPU memory ({gpu_memory_gb:.1f} GB)[/red]") - console.print("[yellow]Qwen 2.5 VL 3B requires at least 12 GB of GPU memory[/yellow]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - # Step 2: Load and validate tasks - if tasks_file: - console.print(f"\n[cyan]Loading tasks from: {tasks_file}[/cyan]") - else: - possible_files = ["tasks.json", "tasks.jsonl", "browser_2048_tasks.jsonl"] - for f in possible_files: - if Path(f).exists(): - tasks_file = f - console.print(f"[green]Auto-detected tasks file: {f}[/green]") - break - - if not tasks_file: - console.print("[red]❌ No tasks file specified or auto-detected[/red]") - console.print( - "[yellow]Please provide a tasks file or create one of: tasks.json, tasks.jsonl[/yellow]" # noqa: E501 - ) - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - tasks = load_tasks(tasks_file) - console.print(f"[green]✅ Loaded {len(tasks)} tasks[/green]") - - invalid_tasks: list[str] = [] - for i, task in enumerate(tasks): - if not hasattr(task, "prompt") or not task.prompt: # type: ignore - invalid_tasks.append(f"Task {i}: missing 'prompt' field") - if not hasattr(task, "mcp_config") or not task.mcp_config: # type: ignore - invalid_tasks.append(f"Task {i}: missing 'mcp_config' field") - - if invalid_tasks: - console.print("[red]❌ Invalid tasks found:[/red]") - for error in invalid_tasks[:5]: - console.print(f" - {error}") - if len(invalid_tasks) > 5: - console.print(f" ... 
and {len(invalid_tasks) - 5} more") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - # Step 3: Model selection (if not provided) - if model is None and not config_file: - if yes: - model = "Qwen/Qwen2.5-VL-3B-Instruct" # Default model in yes mode - hud_console.info(f"Auto-selecting model: {model} (--yes mode)") - else: - model = hud_console.select( - "Select a model for RL training:", - choices=[ - { - "name": "Qwen 2.5 VL 3B (Recommended - Vision-Language)", - "value": "Qwen/Qwen2.5-VL-3B-Instruct", - }, - {"name": "Custom model", "value": "custom"}, - ], - default=0, - ) - - if model == "custom": - console.print("Enter the model name (HuggingFace ID):") - model = input().strip() - - # try to get model from config file - if config_file: - console.print(f"\n[cyan]Loading configuration from: {config_file}[/cyan]") - config = load_config(config_file) - if hasattr(config, "model") and hasattr(config.model, "base_model"): - if model is None: - model = config.model.base_model - else: - console.print( - f"[yellow]Model already set to {model}, using that instead " - f"of {config.model.base_model}[/yellow] (override)" - ) - - if model is None: - console.print("[red]❌ No model specified either through CLI or config file[/red]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - # Validate model is a VL model (whether provided via CLI or selected) - try: - validate_vl_model(model) - except ValueError as e: - console.print(f"\n[red]❌ {e}[/red]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - # Step 4: Generate or load configuration - if config_file: - console.print(f"\n[cyan]Loading configuration from: {config_file}[/cyan]") - config = load_config(config_file) - - # Validate model from config - if hasattr(config, "model") and hasattr(config.model, "base_model"): - try: - validate_vl_model(config.model.base_model) - except ValueError as e: - console.print(f"\n[red]❌ {e}[/red]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - # Estimate memory for display - from .presets import estimate_memory_usage - - estimated_memory = estimate_memory_usage( - config.training.mini_batch_size, - config.actor.max_steps_per_episode, - config.actor.max_new_tokens, - config.model.max_pixels, - ) - else: - console.print("\n[cyan]Generating training configuration...[/cyan]") - # Get number of GPUs for preset scaling - num_training_gpus = 1 # Default, will be adjusted later - if len(gpu_info["devices"]) > 2: - num_training_gpus = len(gpu_info["devices"]) - 1 # Reserve 1 for vLLM - console.print( - f"[yellow]Note: Episodes will be scaled for {num_training_gpus} training GPUs[/yellow]\n" # noqa: E501 - ) - - presets = get_training_presets(gpu_memory_gb) - config, estimated_memory = generate_config_interactive( - model_name=model, - presets=presets, - yes=yes, - ) - - # Step 5: Save temporary config and display summary - temp_config_path = Path(".rl_config_temp.json") - save_config(config, temp_config_path) - console.print(f"\n[cyan]📝 Configuration saved to: {temp_config_path.absolute()}[/cyan]") - console.print("[yellow]You can edit this file before starting training.[/yellow]") - - # Display configuration summary - display_config_summary(config, len(tasks), gpu_info, estimated_memory) - - # Step 6: Ask for confirmation (skip if config was provided or in yes mode) - if not config_file and not yes: - console.print("\n[bold yellow]Options:[/bold yellow]") - console.print(" • Type [green]'start'[/green] to 
begin training") - console.print(" • Type [cyan]'edit'[/cyan] to open config in your editor") - console.print(" • Type [red]'cancel'[/red] to abort") - console.print("\n[bold]Your choice:[/bold] ", end="") - - while True: - choice = input().strip().lower() - - if choice == "start": - config = load_config(temp_config_path) # Reload config in case it was edited - break - elif choice == "edit": - editor = os.environ.get("EDITOR", "nano") - - if editor == "nano": - console.print("\n[cyan]Opening config in nano editor...[/cyan]") - console.print("[yellow]Tips:[/yellow]") - console.print(" • Edit the configuration values as needed") - console.print(" • Press [bold]Ctrl+O[/bold] then [bold]Enter[/bold] to save") - console.print(" • Press [bold]Ctrl+X[/bold] to exit") - console.print(" • Press [bold]Ctrl+C[/bold] to cancel without saving\n") - input("Press Enter to continue...") - - try: - subprocess.run([editor, str(temp_config_path)], check=True) # noqa: S603 - # Reload and display updated config - config = load_config(temp_config_path) - from .presets import estimate_memory_usage as _estimate_memory - - estimated_memory = _estimate_memory( - config.training.mini_batch_size, - config.actor.max_steps_per_episode, - config.actor.max_new_tokens, - config.model.max_pixels, - ) - display_config_summary(config, len(tasks), gpu_info, estimated_memory) - console.print( - "\n[bold]Type 'start' to begin or 'cancel' to abort:[/bold] ", end="" - ) - except subprocess.CalledProcessError: - console.print( - "\n[yellow]Editor closed without saving or was cancelled.[/yellow]" - ) - console.print("[bold]Your choice:[/bold] ", end="") - except Exception as e: - console.print(f"\n[red]Failed to open editor: {e}[/red]") - console.print( - f"[yellow]Please edit {temp_config_path} manually and type 'start' when ready.[/yellow]" # noqa: E501 - ) - console.print("[bold]Your choice:[/bold] ", end="") - elif choice == "cancel": - console.print("[red]Training cancelled[/red]") - try: - import typer - - if yes: - # Always save in yes mode - config_path = Path("rl_config.json") - save_config(config, config_path) - hud_console.info("Auto-saved configuration (--yes mode)") - elif typer.confirm("Save this configuration for later?", default=True): - config_path = Path("rl_config.json") - save_config(config, config_path) - except Exception as e: - hud_console.warning(f"Failed to save config: {e}") - - try: - temp_config_path.unlink() - except Exception as e: - hud_console.warning(f"Failed to clean up temp config: {e}") - - try: - import typer - - raise typer.Exit(0) - except Exception: - return - else: - console.print( - "[red]Invalid choice. 
Type 'start', 'edit', or 'cancel':[/red] ", end="" - ) - elif yes: - # In yes mode, auto-start training - hud_console.info("Auto-starting training (--yes mode)") - config = load_config(temp_config_path) - else: - console.print("\n[dim]Using provided configuration file...[/dim]") - config = load_config(temp_config_path) - - # Step 7: Determine if DDP should be used (imports heavy helpers lazily) - num_gpus = len(gpu_info["devices"]) - use_ddp = False - training_gpus = [0] # Default single GPU - vllm_gpu_idx = 1 if num_gpus > 1 else 0 - - if num_gpus > 2 and not no_ddp: - console.print(f"\n[cyan]🚀 Detected {num_gpus} GPUs - checking DDP configuration...[/cyan]") - - from .gpu_utils import calculate_optimal_gpu_allocation # heavy import (torch at module) - - gpu_allocation = calculate_optimal_gpu_allocation(gpu_info, config) - - if gpu_allocation["use_ddp"]: - use_ddp = True - training_gpus = gpu_allocation["training_gpus"] - vllm_gpu_idx = gpu_allocation["vllm_gpu"] - - console.print( - f"[green]✅ Will use DDP with {len(training_gpus)} GPUs for training[/green]" - ) - console.print(f"[green]✅ GPU {vllm_gpu_idx} reserved for vLLM server[/green]") - - console.print("\n[cyan]Training Configuration:[/cyan]") - console.print(f" • Groups to process: {gpu_allocation['num_groups']}") - console.print(f" • Training GPUs: {training_gpus}") - console.print(f" • Groups per GPU: {gpu_allocation.get('groups_per_gpu', 'N/A'):.1f}") - - if gpu_allocation.get("parallel_efficiency", 1.0) < 0.8: - console.print( - f"\n[yellow]⚠️ GPU efficiency: {gpu_allocation['parallel_efficiency'] * 100:.0f}%[/yellow]" # noqa: E501 - ) - console.print( - f"[yellow]Consider adjusting batch_size to {len(training_gpus) * config.training.group_size} for optimal performance[/yellow]" # noqa: E501 - ) - else: - console.print(f"[cyan]{gpu_allocation.get('reason', 'Using single GPU')}[/cyan]") - - # Allow manual overrides - if ddp_gpus is not None: - requested_gpus = [int(x) for x in ddp_gpus.split(",")] - console.print(f"[cyan]Manual GPU selection: {requested_gpus}[/cyan]") - available_indices = [d["index"] for d in gpu_info["devices"]] - invalid_gpus = [g for g in requested_gpus if g not in available_indices] - if invalid_gpus: - console.print(f"[red]❌ Invalid/unhealthy GPU(s) requested: {invalid_gpus}[/red]") - console.print(f"[yellow]Available healthy GPUs: {available_indices}[/yellow]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - training_gpus = requested_gpus - use_ddp = len(training_gpus) > 1 - - if vllm_gpu is not None: - vllm_gpu_idx = vllm_gpu - console.print(f"[cyan]Manual vLLM GPU: {vllm_gpu_idx}[/cyan]") - available_indices = [d["index"] for d in gpu_info["devices"]] - if vllm_gpu_idx not in available_indices: - console.print(f"[red]❌ vLLM GPU {vllm_gpu_idx} is not available/healthy![/red]") - console.print(f"[yellow]Available healthy GPUs: {available_indices}[/yellow]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - # Ensure we have at least one training GPU - if not training_gpus: - console.print("[red]❌ No available GPUs for training![/red]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - # Always adjust batch_size based on number of training GPUs (lazy import) - from .gpu_utils import adjust_config_for_ddp # heavy import (torch at module) - - config = adjust_config_for_ddp(config, len(training_gpus)) - save_config(config, temp_config_path) - - # Step 8: Start vLLM server (unless we're using a remote one) - if not 
skip_vllm_startup: - console.print(f"\n[cyan]Setting up vLLM server on GPU {vllm_gpu_idx}...[/cyan]") - - from .vllm import start_vllm_server, wait_for_vllm_server - - start_vllm_server(config.model.base_model, vllm_gpu_idx, restart=restart) - server_ready = asyncio.run(wait_for_vllm_server()) - if not server_ready: - console.print("[red]❌ Failed to start vLLM server[/red]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - else: - console.print("\n[cyan]Using remote vLLM server (skipping local startup)[/cyan]") - - # Step 9: Run training (DDP or single GPU) - if use_ddp: - console.print( - f"\n[bold green]🎯 Starting DDP training on {len(training_gpus)} GPUs...[/bold green]\n" - ) - launch_ddp_training(training_gpus, tasks_file, temp_config_path, verbose) - else: - console.print("\n[bold green]🎯 Starting single-GPU training...[/bold green]\n") - try: - # Set verbose in config instead of passing as parameter - if verbose: - config.verbose = True - - # Import and run the async training function lazily - from hud.rl.train import train # heavy import - - asyncio.run(train(config, tasks)) # type: ignore - console.print("\n[green]✅ Training completed successfully![/green]") - - try: - temp_config_path.unlink() - except Exception as e: - hud_console.warning(f"Failed to clean up temp config: {e}") - - except KeyboardInterrupt: - console.print("\n[yellow]Training interrupted by user[/yellow]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - except Exception as e: - console.print(f"\n[red]❌ Training failed: {e}") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - - -def launch_ddp_training( - training_gpus: list[int], tasks_file: str, config_path: Path, verbose: bool -) -> None: - """Launch DDP training with torchrun. - - Uses subprocess to run the training module, so heavy dependencies load in - the spawned processes rather than the CLI import path. 
- """ - import subprocess as _subprocess - import sys as _sys - - env = os.environ.copy() - env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, training_gpus)) - - if not verbose: - env["HUD_LOG_LEVEL"] = "WARNING" - - cmd = [ - _sys.executable, - "-m", - "torch.distributed.run", - f"--nproc_per_node={len(training_gpus)}", - "--master_port=29500", - "-m", - "hud.rl.train", - "--config", - str(config_path), - "--tasks", - tasks_file, - ] - - if verbose: - cmd.append("--verbose") - - try: - _subprocess.run(cmd, env=env, check=True) # noqa: S603 - except _subprocess.CalledProcessError as e: - console.print(f"\n[red]❌ DDP training failed with exit code {e.returncode}[/red]") - try: - import typer - - raise typer.Exit(1) - except Exception: - return - finally: - try: - config_path.unlink() - except Exception as e: - hud_console.warning(f"Failed to clean up temp config: {e}") diff --git a/hud/cli/rl/presets.py b/hud/cli/rl/presets.py deleted file mode 100644 index ead1e560..00000000 --- a/hud/cli/rl/presets.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Training configuration presets for different GPU configurations.""" - -from __future__ import annotations - -from typing import Any - - -def get_training_presets(gpu_memory_gb: float) -> list[dict[str, Any]]: - """Get training configuration presets based on GPU memory.""" - # Time estimates based on provided benchmarks - if gpu_memory_gb >= 40: # A100 40GB or better - presets = [ - { - "name": "More Steps", - "max_steps_per_episode": 12, - "mini_batch_size": 1, - "group_size": 4, - "batch_size": 8, - "max_new_tokens": 256, - "tasks_per_hour": 847, - "steps_per_hour": 424, - "lr": 3e-5, - "epochs": 2, - }, - { - "name": "Balanced (Recommended)", - "max_steps_per_episode": 5, - "mini_batch_size": 1, - "group_size": 6, - "batch_size": 12, - "max_new_tokens": 1024, - "tasks_per_hour": 738, - "steps_per_hour": 415, - "lr": 3e-5, - "epochs": 2, - }, - { - "name": "Low Variance", - "max_steps_per_episode": 3, - "mini_batch_size": 2, - "group_size": 8, - "batch_size": 16, - "max_new_tokens": 512, - "tasks_per_hour": 900, - "steps_per_hour": 450, - "lr": 3e-5, - "epochs": 2, - }, - ] - elif gpu_memory_gb >= 24: # RTX 4090, A10, etc - presets = [ - { - "name": "Balanced (Recommended)", - "max_steps_per_episode": 4, - "mini_batch_size": 1, - "group_size": 4, - "batch_size": 16, - "lr": 1e-4, - "epochs": 2, - }, - { - "name": "Low Variance", - "max_steps_per_episode": 3, - "mini_batch_size": 2, - "group_size": 4, - "batch_size": 16, - "lr": 5e-5, - "epochs": 2, - }, - ] - else: # Smaller GPUs - presets = [ - { - "name": "Test", - "max_steps_per_episode": 5, - "mini_batch_size": 1, - "group_size": 4, - "batch_size": 8, - "lr": 1e-4, - "epochs": 1, - }, - ] - - return presets - - -def estimate_memory_usage( - mini_batch_size: int, max_steps: int, max_new_tokens: int, max_pixels: int -) -> float: - """Calculate estimated GPU memory usage using the formula from train.py.""" - INITIAL_MEMORY = 8.0 - SCALING_FACTOR = 4 / (28 * 28 * 256 * 1024) - token_estimate = mini_batch_size * max_steps * max_new_tokens - image_estimate = max_pixels - total_memory = INITIAL_MEMORY + SCALING_FACTOR * token_estimate * image_estimate - return total_memory diff --git a/hud/cli/rl/remote_runner.py b/hud/cli/rl/remote_runner.py deleted file mode 100644 index c857a973..00000000 --- a/hud/cli/rl/remote_runner.py +++ /dev/null @@ -1,463 +0,0 @@ -""" -Remote runner for HUD RL training via API server. - -This module implements the new interactive flow for RL training. 
-""" - -from __future__ import annotations - -import time -import uuid -from pathlib import Path - -from rich.console import Console - -from hud.cli.rl.celebrate import show_confetti_async -from hud.cli.rl.gpu_utils import adjust_config_for_ddp -from hud.cli.rl.viewer import show_json_interactive -from hud.cli.rl.wait_utils import wait_for_enter_cancel_or_change -from hud.utils.hud_console import hud_console -from hud.utils.tasks import load_tasks - -from . import rl_api -from .config import generate_config_interactive, load_config, save_config -from .presets import get_training_presets - -console = Console() - -# GPU pricing information -GPU_PRICING = { - "A100": {"price": "1", "memory": "80GB"}, - "H100": {"price": "2", "memory": "80GB"}, -} - - -def ensure_vllm_deployed( - model_name: str, gpu_type: str = "A100", gpu_count: int = 1, timeout: int = 600 -) -> None: - """Deploy vLLM for a model if needed and wait until it's ready. - - Args: - model_name: The name of the model to deploy vLLM for - gpu_type: GPU type to use for deployment (e.g., A100, H100) - timeout: Max seconds to wait for vLLM to be ready - """ - # Check current model status - info = rl_api.get_model(model_name) - if info.vllm_url: - hud_console.success("vLLM server already running") - return - - hud_console.info(f"Deploying vLLM server for {model_name}...") - rl_api.deploy_vllm(model_name, gpu_type=gpu_type, gpu_count=gpu_count) - hud_console.success("vLLM deployment started") - - hud_console.info("Waiting for vLLM server to be ready...") - start_time = time.time() - with hud_console.progress() as progress: - progress.update("Checking deployment status (see live status on https://hud.ai/models)") - while True: - if time.time() - start_time > timeout: - hud_console.error("Timeout waiting for vLLM deployment") - raise ValueError("vLLM deployment timeout") - info = rl_api.get_model(model_name) - if info.status == "ready": - hud_console.success( - f"vLLM server ready at http://rl.hud.ai/v1/models/{model_name}/vllm" - ) - break - time.sleep(5) - - -def run_remote_training( - tasks_file: str | None, - model: str | None, - config_file: Path | None, - output_dir: str, - vllm_gpu_count: int = 1, - yes: bool = False, -) -> None: - """Run RL training remotely via the API server following the new interactive flow.""" - from hud.settings import settings - - if not settings.api_key: - hud_console.error("API key not found") - console.print( - "[yellow]Set it in your environment or run: hud set HUD_API_KEY=your-key-here[/yellow]" - ) - raise ValueError("API key not found") - - # Step 1: CONFIRMATION - Load tasks - if tasks_file: - tasks: list[Task] = load_tasks(tasks_file) # type: ignore[assignment] - # Resolve tasks immediately after loading (validate + fill defaults) - from hud.types import Task - - resolved_tasks: list[dict] = [] - for t in tasks: - try: - resolved = Task(**t.model_dump()).model_dump() - except Exception: - resolved = t.model_dump() - resolved_tasks.append(resolved) - - # Preview resolved task - if resolved_tasks and not yes: - try: - show_json_interactive(resolved_tasks[0], title="Task Preview") - except Exception as e: - hud_console.warning(f"Interactive viewer failed: {e}") - else: - raise ValueError("Tasks file not found") - - # Show example task for confirmation - # hud_console.section_title("Example Task from Dataset") - - # if tasks: - # # Display task with truncated values - # try: - # task_data = resolved_tasks[0] - # except Exception: - # task_data = tasks[0].model_dump() - # truncated_data = {} - # 
max_value_length = 120 # Maximum characters to show per line - - # for key, value in task_data.items(): - # value_str = str(value) - # if len(value_str) > max_value_length: - # truncated_data[key] = value_str[:max_value_length] + "..." - # else: - # truncated_data[key] = value_str - - # hud_console.key_value_table(truncated_data) - - # if not hud_console.confirm("Proceed with training on this dataset?", default=True): - # hud_console.error("Training cancelled") - # return - - # Step 2: MODEL SELECTION - hud_console.section_title("Model Selection") - - # Fetch existing models - hud_console.info("Fetching your models from https://hud.ai/models") - - try: - models = rl_api.list_models() - # Filter for active/training models and sort by recency - active_models = [m for m in models if m.status in ["ready", "training"]] - active_models.sort(key=lambda m: m.created_at or "", reverse=True) - - if active_models or model is None: - # Build choices - choices = [] - for m in active_models: - status_emoji = { - "ready": "✅", - "training": "🔄", - "deploying": "🚀", - "pending": "⏳", - }.get(m.status, "❓") - - choices.append({"name": f"{status_emoji} {m.name} ({m.status})", "value": m.name}) - - choices.append({"name": "Create new model", "value": "__new__"}) - - if not model: - if yes: - # In yes mode, always create a new model to avoid conflicts - selected = "__new__" - hud_console.info("Auto-creating new model (--yes mode)") - elif choices: - selected = hud_console.select("Select a model:", choices=choices) - else: - selected = "__new__" - hud_console.hint("No existing models found. Creating new model...") - else: - # Model was provided via CLI - selected = model - - else: - selected = "__new__" - - # Handle model selection - if selected == "__new__": - # Create new model flow - hud_console.info("Creating new model...") - - # Ask for model type - if yes: - if config_file: - config = load_config(config_file) - model_type = config.model.base_model - else: - model_type = "Qwen/Qwen2.5-VL-3B-Instruct" - hud_console.info(f"Auto-selecting base model: {model_type} (--yes mode)") - else: - model_type = hud_console.select( - "Select base model type:", - choices=[ - {"name": "Qwen2.5-VL-3B-Instruct", "value": "Qwen/Qwen2.5-VL-3B-Instruct"}, - {"name": "Qwen2.5-3B-Instruct", "value": "Qwen/Qwen2.5-3B-Instruct"}, - ], - default=0, - ) - from rich.prompt import Prompt - - # Ask for model name - base_default = model_type.split("/")[-1].lower() - default_name = base_default - existing_names = {m.name for m in active_models} - suffix = 1 - while default_name in existing_names: - default_name = f"{base_default}-{suffix}" - suffix += 1 - - if yes: - model_name = default_name - hud_console.info(f"Auto-using model name: {model_name} (--yes mode)") - else: - hud_console.info(f"Enter model name (default: {default_name}):") - model_name = Prompt.ask("Model name", default=default_name) - model_name = model_name.replace("/", "-").lower() - - # Create the model with retry on name conflict - hud_console.info(f"Creating model: {model_name}") - try: - rl_api.create_model(model_name, model_type) - hud_console.success(f"Created model: {model_name}") - ensure_vllm_deployed(model_name, gpu_type="A100", gpu_count=vllm_gpu_count) - - except Exception as e: - # If the name already exists, suggest a new name and prompt once - message = str(e) - if "already exists" in message or "409" in message: - alt_name = f"{model_name}-1" - i = 1 - while True: - candidate = f"{model_name}-{str(uuid.uuid4())[:4]}" - if candidate not in existing_names: 
- alt_name = candidate - break - i += 1 - hud_console.warning( - f"Model '{model_name}' exists. Suggesting '{alt_name}' instead." - ) - try: - from rich.prompt import Prompt as _Prompt - - if yes: - chosen = alt_name - hud_console.info(f"Auto-using suggested name: {chosen} (--yes mode)") - else: - chosen = _Prompt.ask("Use different name", default=alt_name) - chosen = chosen.replace("/", "-").lower() - rl_api.create_model(chosen, model_type) - hud_console.success(f"Created model: {chosen}") - model_name = chosen - ensure_vllm_deployed(model_name, gpu_type="A100", gpu_count=vllm_gpu_count) - except Exception as e2: - hud_console.error(f"Failed to create model: {e2}") - raise - else: - hud_console.error(f"Failed to create model: {e}") - raise - - else: - # Existing model selected - model_name = selected - model_info = rl_api.get_model(model_name) - - # Check if model is in training - if model_info.status == "training": - if yes: - # In yes mode, skip training if model is already training - hud_console.warning(f"{model_name} is already training, skipping (--yes mode)") - return - elif hud_console.confirm( - f"{model_name} is currently training. Stop current training?", default=False - ): - hud_console.info(f"Stopping training for {model_name}...") - try: - rl_api.stop_training(model_name) - hud_console.success("Training stopped") - except Exception as e: - hud_console.error(f"Failed to stop training: {e}") - raise - else: - hud_console.error("Cannot start new training while model is already training") - return - - # Ensure vLLM is deployed - ensure_vllm_deployed(model_name, gpu_type="A100", gpu_count=vllm_gpu_count) - except KeyboardInterrupt: - hud_console.dim_info("Training cancelled", "") - return - except Exception as e: - hud_console.error(f"Error during model selection: {e}") - raise - - # Get final model info - model_info = rl_api.get_model(model_name) - - # Step 3: TRAINING CONFIG - hud_console.section_title("Training Configuration") - - if not config_file: - # Ask about number of GPUs with pricing - # hud_console.info("GPU Selection (Pricing per GPU):") - - # gpu_table = Table(show_header=True, header_style="bold magenta") - # gpu_table.add_column("GPU Type", style="cyan") - # gpu_table.add_column("Memory", style="green") - # gpu_table.add_column("Price/hr", style="yellow") - - # for gpu, info in GPU_PRICING.items(): - # gpu_table.add_row(gpu, info["memory"], "see pricing on hud.ai") - - # console.print(gpu_table) - - if yes: - gpu_choice = "A100" - hud_console.info(f"Auto-selecting GPU: {gpu_choice} 80GB (--yes mode)") - else: - gpu_choice = hud_console.select( - "Select GPU type:", - choices=[ - {"name": "A100 80GB", "value": "A100"}, - {"name": "H100 80GB", "value": "H100"}, - ], - default=0, - ) - - if yes: - num_gpus = 2 # Default to 2 GPUs in yes mode - hud_console.info(f"Auto-selecting {num_gpus} GPU(s) (--yes mode)") - else: - num_gpus = hud_console.select( - "Number of GPUs:", - choices=[ - {"name": "1 GPU", "value": 1}, - {"name": "2 GPUs", "value": 2}, - {"name": "4 GPUs", "value": 4}, - {"name": "8 GPUs", "value": 8}, - ], - default=1, - ) - - # Generate config with presets - hud_console.info("Generating training configuration...") - gpu_memory_gb = 80.0 if gpu_choice in ["A100", "H100"] else 48.0 - presets = get_training_presets(gpu_memory_gb) - - config, _ = generate_config_interactive( - model_name=model_info.base_model, - presets=presets, - yes=yes, - ) - - config = adjust_config_for_ddp(config, int(num_gpus)) - - config.training.gpu_type = gpu_choice - - # Use a 
short label for tasks (avoid full absolute paths) - try: - if tasks_file and Path(tasks_file).exists(): - tasks_label = Path(tasks_file).name - else: - # Fallback: last segment of a non-existent path or dataset name - tasks_label = str(tasks_file).replace("\\", "/").split("/")[-1] - except Exception: - tasks_label = str(tasks_file) - - config.job_name = f"RL {tasks_label} | {model_name}" - - # Save config so user can review/edit externally - temp_config_path = Path(f".rl_config_temp_{model_name}.json") - save_config(config, temp_config_path) - - # Interactive review loop: show preview, allow external edits, press Enter to start - hud_console.info( - f"Using training configuration from [underline cyan]{temp_config_path.absolute()}[/underline cyan]" # noqa: E501 - ) - - if yes: - # In yes mode, skip the interactive review loop - hud_console.info("Auto-accepting config (--yes mode)") - # Still show the config briefly - try: - show_json_interactive( - config.to_dict() if hasattr(config, "to_dict") else {}, - title="RL Config Preview", - prompt=False, - ) - except Exception as e: - hud_console.warning(f"Interactive viewer failed: {e}") - else: - while True: - # Reload latest config from file each cycle - try: - config = load_config(temp_config_path) - except Exception as e: - hud_console.warning(f"Failed to load config from disk, using in-memory: {e}") - - # Preview current config (no extra prompt here; main loop handles start/cancel) - try: - show_json_interactive( - config.to_dict() if hasattr(config, "to_dict") else {}, - title="RL Config Preview", - prompt=False, - ) - except Exception as e: - hud_console.warning(f"Interactive viewer failed: {e}") - - console.print( - "\n[dim]Edit the config file above if needed, then save.[/dim]\n" - "[bold]Press Enter to start training[/bold], or press 'q' to cancel." - ) - - start_training, cancelled, changed = wait_for_enter_cancel_or_change( - temp_config_path - ) - - if cancelled: - hud_console.error("Training cancelled") - return - if start_training: - break # proceed - if changed: - hud_console.info("Detected configuration changes. Reloading preview...") - - config_dict = config.to_dict() - else: - # Load provided config - hud_console.info(f"Loading configuration from: {config_file}") - config = load_config(config_file) - gpu_choice = config.training.gpu_type - num_gpus = config.training.num_gpus - - config = adjust_config_for_ddp(config, int(num_gpus)) - config_dict = config.to_dict() - - # Launch training - try: - # Little celebration before launching - try: - show_confetti_async(console) - except Exception: - hud_console.info("Launching training...") - - rl_api.launch_training( - model_name=model_name, - config=config_dict, - tasks=resolved_tasks, - gpu_type=gpu_choice, - gpu_count=int(num_gpus), - ) - - hud_console.info(f"Your model {model_name} has started training") - hud_console.hint("Launch another training run via: hud rl ") - hud_console.hint("Or evaluate the model via: hud eval ") - - except Exception as e: - hud_console.error(f"Failed to launch training: {e}") - raise diff --git a/hud/cli/rl/rl_api.py b/hud/cli/rl/rl_api.py deleted file mode 100644 index 92761b3f..00000000 --- a/hud/cli/rl/rl_api.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Direct API functions for HUD RL remote endpoints using shared requests module. - -This module provides functions for interacting with the HUD RL API server. 
-""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -from pydantic import BaseModel - -from hud.settings import settings -from hud.shared.requests import make_request_sync - -if TYPE_CHECKING: - from collections.abc import Iterator - - -class RLModelInfo(BaseModel): - """Model information from the API.""" - - name: str - base_model: str - vllm_url: str | None = None - trainer_name: str | None = None - checkpoint_volume: str | None = None - status: str = "pending" # pending, deploying, ready, training, terminated - created_at: str | None = None - updated_at: str | None = None - terminated_at: str | None = None - - -def create_model(name: str, base_model: str) -> dict[str, Any]: - """Create a new model.""" - return make_request_sync( - method="POST", - url=f"{settings.hud_rl_url}/models", - json={"name": name, "base_model": base_model}, - api_key=settings.api_key, - ) - - -def get_model(name: str) -> RLModelInfo: - """Get model information.""" - response = make_request_sync( - method="GET", url=f"{settings.hud_rl_url}/models/{name}", api_key=settings.api_key - ) - return RLModelInfo(**response) - - -def list_models() -> list[RLModelInfo]: - """List all models.""" - response = make_request_sync( - method="GET", url=f"{settings.hud_rl_url}/models", api_key=settings.api_key - ) - if not isinstance(response, list): - response = [response] - return [ - RLModelInfo(**(model if isinstance(model, dict) else model.__dict__)) for model in response - ] - - -def deploy_vllm(model_name: str, gpu_type: str = "A100", gpu_count: int = 1) -> dict[str, Any]: - """Deploy a vLLM server for a model.""" - return make_request_sync( - method="POST", - url=f"{settings.hud_rl_url}/models/{model_name}/deploy", - json={"gpu_type": gpu_type, "gpu_count": gpu_count}, - api_key=settings.api_key, - ) - - -def stop_vllm(model_name: str) -> dict[str, Any]: - """Stop the vLLM server for a model.""" - return make_request_sync( - method="DELETE", - url=f"{settings.hud_rl_url}/models/{model_name}/deploy", - api_key=settings.api_key, - ) - - -def stop_training(model_name: str) -> dict[str, Any]: - """Stop the training for a model.""" - return make_request_sync( - method="DELETE", - url=f"{settings.hud_rl_url}/models/{model_name}/training", - api_key=settings.api_key, - ) - - -def launch_training( - model_name: str, - config: dict[str, Any], - tasks: list[dict[str, Any]], - gpu_type: str = "A100", - gpu_count: int = 1, -) -> dict[str, Any]: - """Launch a training run for a model.""" - return make_request_sync( - method="POST", - url=f"{settings.hud_rl_url}/models/{model_name}/training/launch", - json={"config": config, "tasks": tasks, "gpu_type": gpu_type, "gpu_count": gpu_count}, - api_key=settings.api_key, - ) - - -def get_training_status(model_name: str) -> dict[str, Any]: - """Get the status of a training run.""" - return make_request_sync( - method="GET", - url=f"{settings.hud_rl_url}/models/{model_name}/training/status", - api_key=settings.api_key, - ) - - -def get_training_logs(model_name: str, lines: int = 100, follow: bool = False) -> Iterator[str]: - """Get training logs for a model. 
- - Args: - model_name: Name of the model - lines: Number of lines to return - follow: If True, stream logs as they arrive - - Yields: - Log lines as strings - """ - # For streaming logs, we need to use httpx directly - # as the shared requests module expects JSON responses - import httpx - - params = {"lines": lines} - if follow: - params["follow"] = True - - headers = {"Authorization": f"Bearer {settings.api_key}"} - - with ( - httpx.Client(timeout=300.0) as client, - client.stream( - "GET", - f"{settings.hud_rl_url}/models/{model_name}/training/logs", - params=params, - headers=headers, - ) as response, - ): - response.raise_for_status() - for line in response.iter_lines(): - if line: - yield line diff --git a/hud/cli/rl/vllm.py b/hud/cli/rl/vllm.py deleted file mode 100644 index 969a961c..00000000 --- a/hud/cli/rl/vllm.py +++ /dev/null @@ -1,179 +0,0 @@ -"""vLLM server management utilities.""" - -from __future__ import annotations - -import asyncio -import logging -import os -import subprocess -import time -from pathlib import Path - -import httpx -from rich.console import Console - -from hud.utils.hud_console import HUDConsole - -logger = logging.getLogger(__name__) -hud_console = HUDConsole(logger) - -console = Console() - - -def get_vllm_args(model_name: str, chat_template_path: Path | None = None) -> list[str]: - """Get common vLLM server arguments for both local and remote deployments.""" - args = [ - "serve", - model_name, - "--api-key", - "token-abc123", - "--host", - "0.0.0.0", # noqa: S104 - "--port", - "8000", - "--tensor-parallel-size", - "1", - "--trust-remote-code", - "--max-model-len", - "16384", - "--enable-lora", - "--max-lora-rank", - "64", - "--max-cpu-loras", - "4", - "--enable-auto-tool-choice", - "--tool-call-parser", - "hermes", - "--disable-log-requests", - "--dtype", - "auto", - ] - - # Add chat template if provided - if chat_template_path and chat_template_path.exists(): - args.extend(["--chat-template", str(chat_template_path.absolute())]) - - return args - - -def check_vllm_server() -> bool: - """Check if vLLM server is running.""" - try: - response = httpx.get("http://localhost:8000/health", timeout=2.0) - return response.status_code == 200 - except Exception: - return False - - -def kill_vllm_server() -> None: - """Kill any running vLLM server processes.""" - try: - # Check for PID file first - pid_file = Path("/tmp/vllm_server.pid") # noqa: S108 - if pid_file.exists(): - try: - pid = int(pid_file.read_text().strip()) - subprocess.run(["kill", "-TERM", str(pid)], check=False) # noqa: S603, S607 - time.sleep(2) - # Force kill if still running - subprocess.run(["kill", "-9", str(pid)], check=False) # noqa: S603, S607 - pid_file.unlink() - except Exception as e: - hud_console.error(f"Failed to kill vLLM server: {e}") - - # Also try to kill by process name - subprocess.run(["pkill", "-f", "vllm serve"], check=False) # noqa: S607 - subprocess.run(["pkill", "-f", "vllm.entrypoints.openai.api_server"], check=False) # noqa: S607 - time.sleep(2) - - # Check for any process using port 8000 - result = subprocess.run(["lsof", "-ti:8000"], capture_output=True, text=True, check=False) # noqa: S607 - - if result.stdout.strip(): - for pid in result.stdout.strip().split("\n"): - try: - subprocess.run(["kill", "-9", pid], check=False) # noqa: S603, S607 - except Exception as e: - hud_console.error(f"Failed to kill vLLM server: {e}") - - console.print("[yellow]Killed existing vLLM server processes[/yellow]") - except Exception as e: - hud_console.error(f"Error killing vLLM 
server: {e}") - - -def start_vllm_server(model_name: str, gpu_index: int = 1, restart: bool = False) -> None: - """Start vLLM server in the background with dynamic GPU selection.""" - if restart: - kill_vllm_server() - time.sleep(3) - - # Check if already running - if check_vllm_server(): - console.print("[green]vLLM server is already running[/green]") - return - - console.print(f"[cyan]Starting vLLM server with {model_name} on GPU {gpu_index}...[/cyan]") - - # Set up environment variables - env = os.environ.copy() - env.update( - { - "CUDA_VISIBLE_DEVICES": str(gpu_index), - "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True", - "TOKENIZERS_PARALLELISM": "false", - "VLLM_LOGGING_LEVEL": "INFO", # Changed from DEBUG to reduce noise - "CUDA_LAUNCH_BLOCKING": "1", # Better error messages - } - ) - - # Get the path to chat template - chat_template_path = Path(__file__).parent.parent.parent / "rl" / "chat_template.jinja" - - # Build the vLLM command - vllm_args = get_vllm_args(model_name, chat_template_path) - cmd = ["uv", "run", "vllm", *vllm_args] - - # Start the server in the background - with open("/tmp/vllm_server.log", "w") as log_file: # noqa: S108, - process = subprocess.Popen( # noqa: S603 - cmd, - env=env, - stdout=log_file, - stderr=subprocess.STDOUT, - preexec_fn=os.setpgrp, # type: ignore - cwd=Path.cwd(), # Use current working directory - ) - - console.print("[yellow]vLLM server starting in background...[/yellow]") - console.print(f"[yellow]Process ID: {process.pid}[/yellow]") - console.print("[yellow]Check logs at: /tmp/vllm_server.log[/yellow]") - - # Save PID for later management - pid_file = Path("/tmp/vllm_server.pid") # noqa: S108 - pid_file.write_text(str(process.pid)) - - -async def wait_for_vllm_server(timeout: int = 360) -> bool: # noqa: ASYNC109 - """Wait for vLLM server to be ready.""" - start_time = time.time() - console.print("[yellow]Waiting for vLLM server to be ready (up to 6 minutes)...[/yellow]") - - async with httpx.AsyncClient() as client: - while time.time() - start_time < timeout: - try: - response = await client.get("http://localhost:8000/health", timeout=2.0) - if response.status_code == 200: - console.print("[green]✅ vLLM server is ready![/green]") - return True - except httpx.ConnectError: - pass - except Exception as e: - hud_console.error(f"Failed to connect to vLLM server: {e}") - - await asyncio.sleep(2) - elapsed = int(time.time() - start_time) - console.print(f"[yellow]Waiting... ({elapsed}s / {timeout}s)[/yellow]", end="\r") - - console.print("\n[red]❌ vLLM server failed to start within timeout[/red]") - console.print("[yellow]Check /tmp/vllm_server.log for details[/yellow]") - return False diff --git a/hud/cli/rl/wait_utils.py b/hud/cli/rl/wait_utils.py deleted file mode 100644 index f1a587ac..00000000 --- a/hud/cli/rl/wait_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -import contextlib -import os -import select -import sys -import threading -import time as _time -from typing import TYPE_CHECKING - -from watchfiles import watch - -if TYPE_CHECKING: - from pathlib import Path - - -def wait_for_enter_cancel_or_change(file_path: Path) -> tuple[bool, bool, bool]: - """Block until Enter (start), 'q' (cancel), or file change. - - Returns (start_training, cancelled, changed). 
- - start_training: True if Enter (or any non-'q' line on POSIX) was received - - cancelled: True if 'q' was received or Ctrl-C - - changed: True if the file changed on disk - """ - start_training = False - cancelled = False - changed = False - - stop_evt: threading.Event = threading.Event() - changed_evt: threading.Event = threading.Event() - - def _watcher() -> None: - with contextlib.suppress(Exception): - for _ in watch(file_path, stop_event=stop_evt, debounce=200): - changed_evt.set() - break - - t = threading.Thread(target=_watcher, daemon=True) - t.start() - - try: - if os.name == "nt": - import msvcrt # type: ignore[attr-defined] - - while True: - if changed_evt.is_set(): - changed = True - break - - if msvcrt.kbhit(): - ch = msvcrt.getwch() - if ch in ("\r", "\n"): - start_training = True - break - if ch.lower() == "q": - cancelled = True - break - _time.sleep(0.15) - else: - while True: - if changed_evt.is_set(): - changed = True - break - - rlist, _, _ = select.select([sys.stdin], [], [], 0.25) - if rlist: - line = sys.stdin.readline() - if line is None: - continue - stripped = line.strip().lower() - if stripped == "q": - cancelled = True - break - # Any other (including empty) => start - start_training = True - break - _time.sleep(0.05) - - except KeyboardInterrupt: - cancelled = True - finally: - stop_evt.set() - with contextlib.suppress(Exception): - t.join(timeout=1) - - return start_training, cancelled, changed - - -__all__ = ["wait_for_enter_cancel_or_change"] diff --git a/hud/cli/rl/celebrate.py b/hud/cli/utils/celebrate.py similarity index 86% rename from hud/cli/rl/celebrate.py rename to hud/cli/utils/celebrate.py index e9b48cde..8e587822 100644 --- a/hud/cli/rl/celebrate.py +++ b/hud/cli/utils/celebrate.py @@ -1,4 +1,6 @@ # ruff: noqa: S311 +"""Confetti celebration animation for CLI.""" + from __future__ import annotations import random @@ -121,20 +123,20 @@ def render_with_colors(self) -> Text: return text -def show_confetti(console: Console, seconds: float = 2.5) -> None: - """Display celebratory confetti animation inspired by confetty. +def show_confetti(console: Console, seconds: float = 2.5, message: str | None = None) -> None: + """Display celebratory confetti animation. - Shows "Starting training!" message first, then creates two bursts of + Shows a message first, then creates two bursts of falling confetti particles that fall away completely. Args: console: Rich console instance seconds: Duration to show confetti + message: Custom message to display (default: "🎉 Success!") """ # Show celebratory message first - console.print( - "[bold green]🎉 Starting training! See your model on https://hud.ai/models[/bold green]" - ) + msg = message or "[bold green]🎉 Success![/bold green]" + console.print(msg) time.sleep(0.3) # Brief pause to see the message width = min(console.size.width, 120) # Cap width for performance @@ -166,22 +168,23 @@ def show_confetti(console: Console, seconds: float = 2.5) -> None: frame += 1 -def show_confetti_async(console: Console, seconds: float = 2.5) -> None: +def show_confetti_async(console: Console, seconds: float = 2.5, message: str | None = None) -> None: """Non-blocking confetti animation that runs in a background thread. - The animation will run independently while training starts immediately. + The animation will run independently while other operations continue. 
""" import threading def _run_confetti() -> None: try: - show_confetti(console, seconds) + show_confetti(console, seconds, message) except Exception: - hud_console.info("Launching training...") + hud_console.info("Continuing...") thread = threading.Thread(target=_run_confetti, daemon=True) thread.start() - # Don't wait - let training start immediately while confetti plays + # Don't wait - let operations continue while confetti plays + +__all__ = ["show_confetti", "show_confetti_async", "ConfettiSystem", "Particle"] -__all__ = ["show_confetti", "show_confetti_async"] diff --git a/hud/cli/rl/viewer.py b/hud/cli/utils/viewer.py similarity index 98% rename from hud/cli/rl/viewer.py rename to hud/cli/utils/viewer.py index 4a817acc..59c93c54 100644 --- a/hud/cli/rl/viewer.py +++ b/hud/cli/utils/viewer.py @@ -1,4 +1,4 @@ -"""Inline JSON preview with expandable view for RL flow. +"""Inline JSON preview with expandable view. Uses minimal terminal interaction for inline display. """ @@ -139,3 +139,4 @@ def show_json_interactive( input() console.print() + diff --git a/hud/environment/connection.py b/hud/environment/connection.py index e65869fd..95839eb6 100644 --- a/hud/environment/connection.py +++ b/hud/environment/connection.py @@ -69,6 +69,20 @@ def __init__( self.client: FastMCPClient[Any] | None = None self._tools_cache: list[mcp_types.Tool] | None = None + def copy(self) -> Connector: + """Create a copy of this connector with fresh (unconnected) state. + + The copy shares transport config but has its own client instance, + allowing parallel execution without conflicts. + """ + return Connector( + transport=self._transport, + config=self.config, + name=self.name, + connection_type=self.connection_type, + auth=self._auth, + ) + @property def is_local(self) -> bool: """True if this is a local (non-parallelizable) connection.""" diff --git a/hud/eval/context.py b/hud/eval/context.py index 0dbc747e..0e027d61 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -181,6 +181,8 @@ def __init__( self._started_at: datetime | None = None self._completed_at: datetime | None = None self._token: contextvars.Token[dict[str, str] | None] | None = None + self._is_summary: bool = False # True for summary contexts (skip trace) + self._suppress_link: bool = False # True to suppress printing eval link def _apply_task(self, task: Task) -> None: """Apply a Task definition to this environment.""" @@ -252,10 +254,11 @@ def from_environment( env_config=env_config, ) - # Copy connections from parent - # Note: These are shared references - for parallel execution, - # only remote connections should be used - ctx._connections = env._connections.copy() + # Copy connections from parent - each connector is copied so parallel + # execution gets fresh client instances + ctx._connections = { + name: connector.copy() for name, connector in env._connections.items() + } ctx._hub_configs = getattr(env, "_hub_configs", []).copy() ctx._setup_calls = env._setup_calls.copy() ctx._evaluate_calls = env._evaluate_calls.copy() @@ -426,6 +429,10 @@ async def _eval_exit(self, error_message: str | None = None) -> None: async def __aenter__(self) -> Self: """Enter eval context - start tracking and connect environment.""" + # Summary contexts skip trace tracking (parallel results already tracked) + if self._is_summary: + return self + # Start eval tracking self._started_at = datetime.now(UTC) self._token = _current_trace_headers.set(self.headers) @@ -446,6 +453,10 @@ async def __aexit__( exc_tb: TracebackType | None, ) -> None: """Exit 
eval context - disconnect and report.""" + # Summary contexts skip trace tracking (parallel results already tracked) + if self._is_summary: + return + self._completed_at = datetime.now(UTC) # Track error @@ -470,6 +481,10 @@ def __repr__(self) -> str: def _print_eval_link(self) -> None: """Print a nicely formatted eval link.""" + # Skip if link printing is suppressed (e.g., parallel child traces) + if self._suppress_link: + return + import contextlib import webbrowser diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 49f735af..c5b051af 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -7,16 +7,18 @@ import inspect import logging +import uuid from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any from hud.eval.parallel import ( ASTExtractionError, - execute_parallel_evals, expand_variants, + find_user_frame, get_with_block_body, resolve_group_ids, ) +from hud.telemetry.job import _print_job_complete_url, _print_job_url if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -45,6 +47,31 @@ def _parse_slug(slug: str) -> tuple[str, str | None]: return slug, None +def _get_eval_name(slugs: str | list[str] | None) -> str: + """Extract a nice name from slugs for job display. + + Args: + slugs: Single slug or list of slugs + + Returns: + Name like "evalset" or "eval" if no slugs + """ + if slugs is None: + return "eval" + + # Get the first slug + first_slug = slugs if isinstance(slugs, str) else slugs[0] + + # Remove index/wildcard suffix (":1" or ":*") + base_slug, _ = _parse_slug(first_slug) + + # Extract the evalset name (part after last "/") + if "/" in base_slug: + return base_slug.rsplit("/", 1)[1] + + return base_slug + + def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: """Load tasks from platform by slugs. 
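As a quick orientation for the `hud/eval/manager.py` hunk above: the new `_get_eval_name` helper derives the display name for the implicit parallel job from the first slug. The following is an illustrative sketch of that slug-to-name logic, not the shipped implementation; the helper names `parse_slug` and `eval_name` are simplified stand-ins inferred from the diff.

```python
def parse_slug(slug: str) -> tuple[str, str | None]:
    # Split off a trailing ":<index>" or ":*" selector, if present.
    base, sep, index = slug.partition(":")
    return base, (index or None) if sep else None


def eval_name(slugs: str | list[str] | None) -> str:
    # No slugs -> generic "eval" name.
    if slugs is None:
        return "eval"
    first = slugs if isinstance(slugs, str) else slugs[0]
    base, _ = parse_slug(first)
    # Use the evalset name (the part after the last "/") when present.
    return base.rsplit("/", 1)[-1] if "/" in base else base


assert eval_name("my-org/browser-evalset:3") == "browser-evalset"
assert eval_name(None) == "eval"
```

In the patch, this derived name is what gets passed to `_print_job_url` / `_print_job_complete_url`, so the console shows one job link for the whole parallel run instead of one link per child trace.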
@@ -210,7 +237,7 @@ async def run_eval( try: caller = frame.f_back if caller is not None: - code_snippet, _ = get_with_block_body(caller) + code_snippet, _, _ = get_with_block_body(caller) except ASTExtractionError: pass finally: @@ -244,39 +271,57 @@ async def run_eval( yield ctx else: - # Parallel execution - completed = await _run_parallel_eval( - tasks=tasks, - variant_combos=variant_combos, - group=group, - group_ids=group_ids, - job_id=job_id, - api_key=api_key, - code_snippet=code_snippet, - ) - - # Create parent ctx with results - if tasks: - ctx = EvalContext.from_task( - task=tasks[0], + # Parallel execution: create implicit job to group traces + eval_name = _get_eval_name(slugs) + implicit_job_id = job_id or str(uuid.uuid4()) + + # Print job URL (not individual trace URLs) + _print_job_url(implicit_job_id, eval_name) + + error_occurred = False + try: + # Run parallel evals with job_id + completed = await _run_parallel_eval( + tasks=tasks, + variant_combos=variant_combos, + group=group, + group_ids=group_ids, + job_id=implicit_job_id, # Propagate job_id to child traces api_key=api_key, - job_id=job_id, - ) - else: - ctx = EvalContext( - name="eval", - api_key=api_key, - job_id=job_id, + code_snippet=code_snippet, ) - ctx.results = completed + # Create summary context (no trace, just aggregates results) + if tasks: + ctx = EvalContext.from_task( + task=tasks[0], + api_key=api_key, + job_id=implicit_job_id, + ) + else: + ctx = EvalContext( + name="eval", + api_key=api_key, + job_id=implicit_job_id, + ) + + ctx._is_summary = True # Skip trace tracking + ctx.results = completed + + # Compute aggregate reward + rewards = [e.reward for e in completed if e.reward is not None] + if rewards: + ctx.reward = sum(rewards) / len(rewards) - # Compute aggregate reward - rewards = [e.reward for e in completed if e.reward is not None] - if rewards: - ctx.reward = sum(rewards) / len(rewards) + # Check if any failed + error_occurred = any(e.error is not None for e in completed) - yield ctx + yield ctx + except Exception: + error_occurred = True + raise + finally: + _print_job_complete_url(implicit_job_id, eval_name, error_occurred) async def _run_parallel_eval( @@ -294,6 +339,11 @@ async def _run_parallel_eval( """ # Lazy import to avoid circular dependency from hud.eval.context import EvalContext + from hud.eval.parallel import log_eval_stats, run_parallel_evals + + # Find user code frame and extract the with block body + caller_frame = find_user_frame() + body_source, captured_locals, context_var = get_with_block_body(caller_frame) # Calculate total evals and resolve group IDs if tasks: @@ -321,6 +371,7 @@ async def _run_parallel_eval( variants=variant, code_snippet=code_snippet, ) + ctx._suppress_link = True # Suppress individual links, job URL shown instead eval_contexts.append(ctx) idx += 1 else: @@ -336,11 +387,24 @@ async def _run_parallel_eval( variants=variant, code_snippet=code_snippet, ) + ctx._suppress_link = True # Suppress individual links, job URL shown instead eval_contexts.append(ctx) idx += 1 - # Run in parallel (frame depth: _run_parallel_eval -> eval -> user code) - return await execute_parallel_evals(eval_contexts, caller_frame_depth=3) + # Run in parallel + logger.info( + "Running %d evals (%d tasks x %d variants x %d runs)", + len(eval_contexts), + max(len(tasks), 1), + len(variant_combos), + group, + ) + completed = await run_parallel_evals(eval_contexts, body_source, captured_locals, context_var) + + # Log stats + log_eval_stats(completed) + + return completed __all__ = 
["run_eval"] diff --git a/hud/eval/mixin.py b/hud/eval/mixin.py index 84bb1ff3..49061ed8 100644 --- a/hud/eval/mixin.py +++ b/hud/eval/mixin.py @@ -9,16 +9,18 @@ import inspect import logging +import uuid from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any from hud.eval.parallel import ( ASTExtractionError, - execute_parallel_evals, expand_variants, + find_user_frame, get_with_block_body, resolve_group_ids, ) +from hud.telemetry.job import _print_job_complete_url, _print_job_url if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -97,7 +99,7 @@ def _capture_code_snippet(self) -> str | None: if caller is None: return None - body_source, _ = get_with_block_body(caller) + body_source, _, _ = get_with_block_body(caller) return body_source except ASTExtractionError: # Can't extract from REPL/Jupyter - that's OK @@ -253,37 +255,54 @@ async def eval( async with ctx: yield ctx else: - # Parallel execution: each eval gets its own environment instance - completed = await self._run_parallel_eval( - name=name, - variant_combos=variant_combos, - group=group, - group_ids=group_ids, - job_id=job_id, - api_key=api_key, - code_snippet=code_snippet, - env_config=env_config, - ) + # Parallel execution: create implicit job to group traces + implicit_job_id = job_id or str(uuid.uuid4()) - # Create parent ctx with results injected - ctx = EvalContext.from_environment( - env=self, # type: ignore[arg-type] - name=name, - trace_id=trace_id, - api_key=api_key, - job_id=job_id, - code_snippet=code_snippet, - env_config=env_config, - ) - ctx.results = completed - self._last_evals = completed + # Print job URL (not individual trace URLs) + _print_job_url(implicit_job_id, name) + + error_occurred = False + try: + # Run parallel evals with job_id + completed = await self._run_parallel_eval( + name=name, + variant_combos=variant_combos, + group=group, + group_ids=group_ids, + job_id=implicit_job_id, # Propagate job_id to child traces + api_key=api_key, + code_snippet=code_snippet, + env_config=env_config, + ) - # Compute aggregate reward (mean of non-None rewards) - rewards = [e.reward for e in completed if e.reward is not None] - if rewards: - ctx.reward = sum(rewards) / len(rewards) + # Create summary context (no trace, just aggregates results) + ctx = EvalContext.from_environment( + env=self, # type: ignore[arg-type] + name=name, + trace_id=trace_id, + api_key=api_key, + job_id=implicit_job_id, + code_snippet=code_snippet, + env_config=env_config, + ) + ctx._is_summary = True # Skip trace tracking + ctx.results = completed + self._last_evals = completed + + # Compute aggregate reward (mean of non-None rewards) + rewards = [e.reward for e in completed if e.reward is not None] + if rewards: + ctx.reward = sum(rewards) / len(rewards) + + # Check if any failed + error_occurred = any(e.error is not None for e in completed) - yield ctx + yield ctx + except Exception: + error_occurred = True + raise + finally: + _print_job_complete_url(implicit_job_id, name, error_occurred) async def _run_parallel_eval( self, @@ -302,6 +321,11 @@ async def _run_parallel_eval( """ # Lazy import to avoid circular dependency from hud.eval.context import EvalContext + from hud.eval.parallel import log_eval_stats, run_parallel_evals + + # Find user code frame and extract the with block body + caller_frame = find_user_frame() + body_source, captured_locals, context_var = get_with_block_body(caller_frame) # Calculate total evals and resolve group IDs total_evals = len(variant_combos) * group @@ -323,14 
+347,24 @@ async def _run_parallel_eval( code_snippet=code_snippet, env_config=env_config, ) + ctx._suppress_link = True # Suppress individual links, job URL shown instead eval_contexts.append(ctx) idx += 1 - # Run in parallel (frame depth: _run_parallel_eval -> eval -> user code) - completed = await execute_parallel_evals(eval_contexts, caller_frame_depth=3) - - # Store results + # Run in parallel + logger.info( + "Running %d evals for '%s' (%d variants x %d runs)", + len(eval_contexts), + name, + len(variant_combos), + group, + ) + completed = await run_parallel_evals(eval_contexts, body_source, captured_locals, context_var) + + # Store results and log stats self._last_evals = completed + log_eval_stats(completed, name) + return completed diff --git a/hud/eval/parallel.py b/hud/eval/parallel.py index 45d2237d..6eab25a5 100644 --- a/hud/eval/parallel.py +++ b/hud/eval/parallel.py @@ -8,11 +8,13 @@ import ast import asyncio +import inspect import itertools import linecache import logging import textwrap import uuid +from types import FrameType from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -20,6 +22,53 @@ logger = logging.getLogger(__name__) +# Frames to skip when walking the call stack to find user code +# These are internal implementation details that shouldn't be considered user code +_SKIP_FRAME_PATTERNS = ( + # Python stdlib + "contextlib.py", + "asyncio", + # Third-party + "site-packages", + # HUD eval internals (both Unix and Windows paths) + "hud/eval/mixin.py", + "hud/eval/manager.py", + "hud/eval/parallel.py", + "hud\\eval\\mixin.py", + "hud\\eval\\manager.py", + "hud\\eval\\parallel.py", +) + + +def find_user_frame() -> FrameType: + """Walk the call stack to find the first user code frame. + + Skips internal frames from contextlib, asyncio, site-packages, + and hud.eval internals. + + Returns: + The frame containing user code (typically the async with statement). + + Raises: + ASTExtractionError: If no user code frame can be found. + """ + frame = inspect.currentframe() + if frame is None: + raise ASTExtractionError("Cannot get current frame") + + try: + caller_frame = frame.f_back + while caller_frame is not None: + filename = caller_frame.f_code.co_filename + # Stop at first frame not matching skip patterns + if not any(pattern in filename for pattern in _SKIP_FRAME_PATTERNS): + return caller_frame + caller_frame = caller_frame.f_back + + raise ASTExtractionError("Cannot find user code frame in call stack") + finally: + del frame + def expand_variants( variants: dict[str, Any] | None, @@ -140,14 +189,14 @@ async def execute_parallel_evals( if caller_frame is None: raise ASTExtractionError("Cannot get caller frame") - body_source, captured_locals = get_with_block_body(caller_frame) + body_source, captured_locals, context_var = get_with_block_body(caller_frame) finally: del frame # Run in parallel logger.info("Running %d parallel evals", len(contexts)) - completed = await run_parallel_evals(contexts, body_source, captured_locals) + completed = await run_parallel_evals(contexts, body_source, captured_locals, context_var) # Log stats log_eval_stats(completed) @@ -159,14 +208,14 @@ class ASTExtractionError(Exception): """Error extracting AST from source.""" -def get_with_block_body(frame: Any) -> tuple[str, dict[str, Any]]: +def get_with_block_body(frame: Any) -> tuple[str, dict[str, Any], str]: """Extract the body of a with-block from the calling frame. 
Args: frame: The calling frame (from inspect.currentframe()) Returns: - Tuple of (body_source, captured_locals) + Tuple of (body_source, captured_locals, context_var_name) """ filename = frame.f_code.co_filename lineno = frame.f_lineno @@ -192,7 +241,22 @@ def get_with_block_body(frame: Any) -> tuple[str, dict[str, Any]]: # Extract body source body_source = _extract_body(lines, with_node) - return body_source, frame.f_locals.copy() + # Extract the context variable name from 'as' clause + context_var = _extract_context_var(with_node) + + return body_source, frame.f_locals.copy(), context_var + + +def _extract_context_var(with_node: ast.AsyncWith) -> str: + """Extract the variable name from the 'as' clause of an async with statement.""" + if not with_node.items or not with_node.items[0].optional_vars: + raise ASTExtractionError("async with statement must use 'as' clause for parallel execution") + + var_node = with_node.items[0].optional_vars + if not isinstance(var_node, ast.Name): + raise ASTExtractionError("async with 'as' clause must be a simple variable name") + + return var_node.id def _find_async_with(tree: ast.AST, target_line: int) -> ast.AsyncWith | None: @@ -231,6 +295,7 @@ async def run_parallel_evals( eval_contexts: list[EvalContext], body_source: str, captured_locals: dict[str, Any], + context_var: str, ) -> list[EvalContext]: """Run the eval body in parallel for multiple contexts. @@ -240,12 +305,16 @@ async def run_parallel_evals( - reward - duration - Any error is captured in the context + + Args: + eval_contexts: List of EvalContext instances to run + body_source: The source code of the with-block body + captured_locals: Local variables captured from the caller + context_var: The variable name used in the 'as' clause """ - # Create runner function - # The variable name in the with statement is 'ctx' by convention - # but we use 'env' since that's what the user will see - wrapped = f"async def __runner__(env):\n{textwrap.indent(body_source, ' ')}" + # Create runner function using the actual variable name from the 'as' clause + wrapped = f"async def __runner__({context_var}):\n{textwrap.indent(body_source, ' ')}" code = compile(wrapped, "", "exec") namespace = captured_locals.copy() exec(code, namespace) # noqa: S102 diff --git a/hud/rl/README.md b/hud/rl/README.md deleted file mode 100644 index af451e24..00000000 --- a/hud/rl/README.md +++ /dev/null @@ -1,30 +0,0 @@ -We suggest running hud rl (or with the --local flag) for optimal hyperparameters and native HuggingFace running. 
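Returning to the `hud/eval/parallel.py` changes above (before the deleted RL files): the core idea is that the body of the user's `async with ... as env:` block is re-compiled into a coroutine whose parameter is named after the `as`-clause variable, then run once per eval context. A minimal, self-contained sketch of that wrapping, simplified from what `run_parallel_evals` does (it omits the captured caller locals and the per-context error/reward bookkeeping):

```python
import asyncio
import textwrap


async def run_body_per_context(body_source: str, context_var: str, contexts: list) -> None:
    # Wrap the extracted with-block body in a coroutine whose sole
    # parameter is named after the 'as' clause variable (e.g. "env").
    wrapped = f"async def __runner__({context_var}):\n{textwrap.indent(body_source, '    ')}"
    namespace: dict = {}
    exec(compile(wrapped, "<eval-body>", "exec"), namespace)
    runner = namespace["__runner__"]
    # One invocation per context, executed concurrently.
    await asyncio.gather(*(runner(ctx) for ctx in contexts))


# Example: the user wrote `async with hud.eval(...) as env:` with this body.
body = 'print(f"running against {env!r}")'
asyncio.run(run_body_per_context(body, "env", ["ctx-a", "ctx-b"]))
```

This is why the patch adds `_extract_context_var`: the runner previously hard-coded `env` as the parameter name, whereas extracting the actual `as`-clause name lets the user bind the context to any variable and still have the body compile correctly.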
- -However, to run this independently, sping up an instance with at least 2 GPUs and run: -```bash -sudo apt-get update -y && sudo apt-get install -y cuda-toolkit-12-6 -uv pip install -e .[rl] -uv pip install ninja -uv pip install flash-attn --no-build-isolation -``` - -Launch a vllm server with: -```bash -export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True -export TOKENIZERS_PARALLELISM=false -export VLLM_LOGGING_LEVEL=INFO -export CUDA_VISIBLE_DEVICES=7 # Set this to your last GPU - -uv run vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ - --api-key token-abc123 --host 0.0.0.0 --port 8000 --tensor-parallel-size 1 --trust-remote-code \ - --max-model-len 16384 --enable-lora --max-lora-rank 64 --max-cpu-loras 4 --enable-auto-tool-choice \ - --tool-call-parser hermes --disable-log-requests --dtype auto -``` - -And training with (replace 2 with your spare GPUs): -```bash -hud get hud-evals/2048-basic -torchrun --nproc-per-node 2 -m hud.rl.train --tasks 2048-basic.json --verbose -``` - -Add a `--config path/to/config.json` flag to run a specific configuration (or change the defaults in config.py) diff --git a/hud/rl/__init__.py b/hud/rl/__init__.py deleted file mode 100644 index 604974ce..00000000 --- a/hud/rl/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""RL module for HUD.""" diff --git a/hud/rl/actor.py b/hud/rl/actor.py deleted file mode 100644 index 4c9a3390..00000000 --- a/hud/rl/actor.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Actor for episode collection using vLLM and HUD.""" - -from __future__ import annotations - -import asyncio -import logging - -import httpx -from openai import AsyncOpenAI - -import hud -from hud.agents.openai_chat import OpenAIChatAgent -from hud.clients.utils.retry_transport import create_retry_httpx_client -from hud.types import Task, Trace -from hud.utils.hud_console import HUDConsole - -from .config import Config - -logger = logging.getLogger(__name__) -hud_console = HUDConsole(logger) - - -class Actor: - """Collects episodes using vLLM-served models via HUD agents.""" - - def __init__(self, config: Config) -> None: - self.config = config - self.actor_config = config.actor - self.current_adapter = config.model.base_model - - # Setup OpenAI client for vLLM - base_url = self.actor_config.vllm_base_url.replace("localhost", "127.0.0.1") - self.openai_client = self._create_openai_client(base_url) - - def _create_openai_client(self, base_url: str) -> AsyncOpenAI: - """Create OpenAI client with optimized settings for vLLM.""" - # Match connection limits to parallel_episodes to avoid bottlenecks - # Use shorter per-request timeout and keep retries modest to avoid long blocking - http_client = create_retry_httpx_client( - timeout=httpx.Timeout(60.0), - ) - return AsyncOpenAI( - base_url=base_url, - api_key=self.actor_config.vllm_api_key, - http_client=http_client, - max_retries=2, - ) - - def create_agent(self) -> OpenAIChatAgent: - """Create an agent with the current adapter.""" - return OpenAIChatAgent( - openai_client=self.openai_client, - model_name=self.current_adapter, - allowed_tools=self.actor_config.allowed_tools, - append_setup_output=False, - system_prompt=self.actor_config.system_prompt, - verbose=self.config.verbose, - completion_kwargs={ - "temperature": self.actor_config.temperature, - "max_tokens": self.actor_config.max_new_tokens, - "tool_choice": "required" if self.actor_config.force_tool_choice else "auto", - }, - ) - - def update_adapter(self, adapter_name: str) -> None: - """Update the current adapter being used.""" - self.current_adapter = adapter_name - 
hud_console.info(f"[Actor] Using adapter: {adapter_name}") - - async def run_tasks(self, tasks: list[Task], job_id: str) -> list[Trace]: - """Run tasks and collect traces.""" - traces = [] - - # Process tasks in batches respecting max_parallel_episodes limit - for batch_start in range(0, len(tasks), self.actor_config.max_parallel_episodes): - batch_end = min(batch_start + self.actor_config.max_parallel_episodes, len(tasks)) - batch = tasks[batch_start:batch_end] - - # Run batch in parallel with per-episode timeout protection - async def run_with_timeout(t: Task) -> Trace: - try: - return await asyncio.wait_for( - self._run_task(t, job_id), - timeout=self.actor_config.episode_timeout_sec, - ) - except TimeoutError: - hud_console.warning_log(f"Episode timed out for task {t.id}") - # Attach task so buffer grouping has key - return Trace(isError=True, content="Episode timeout", task=t) - - results = await asyncio.gather( - *[run_with_timeout(t) for t in batch], - return_exceptions=True, - ) - - # Normalize exceptions to error traces and ensure task is attached - for t, res in zip(batch, results, strict=False): - if isinstance(res, Exception): - hud_console.warning_log(f"Episode error: {res}") - traces.append(Trace(isError=True, content=str(res), task=t)) - else: - traces.append(res) - - return traces - - async def _run_task(self, task: Task, job_id: str) -> Trace: - """Run a single task.""" - agent = self.create_agent() - - # Run the task - try: - async with hud.async_trace(f"Training | {task.prompt}", job_id=job_id): - result = await agent.run(task, max_steps=self.actor_config.max_steps_per_episode) - - except Exception: - logger.info("GOT EXCEPTION") - # Preserve task on exception for grouping - return Trace(isError=True, task=task) - - result.info["tool_spec"] = agent.get_tool_schemas() - - return result - - -if __name__ == "__main__": - from hud.types import Task - - async def test_actor() -> None: - """Test the actor with a single 2048 task using local hud-browser image.""" - config = Config() - config.actor.max_parallel_episodes = 1 - config.actor.max_steps_per_episode = 6 - config.actor.episode_timeout_sec = 120 - config.verbose = True - - # Create test task with local hud-browser image - task_data = { - "id": "test_2048_128", - "prompt": "Play the browser-based 2048 game and try to reach the 128 tile. Start by taking a screenshot, then make strategic moves using arrow keys.", # noqa: E501 - "mcp_config": { - "local": { - "command": "sh", - "args": [ - "-c", - "docker run --rm --platform linux/amd64 -i hud-browser:latest 2>/dev/null", - ], - } - }, - "setup_tool": {"name": "launch_app", "arguments": {"app_name": "2048"}}, - "evaluate_tool": { - "name": "evaluate", - "arguments": {"name": "game_2048_max_number", "arguments": {"target": 128}}, - }, - "agent_config": { - "system_prompt": "You are an expert 2048 game player. Use arrow keys to reach the target tile. 
First take a screenshot, then make strategic moves.", # noqa: E501 - }, - } - - task = Task(**task_data) - actor = Actor(config) - - logger.info("Testing actor with task: %s", task.id) - logger.info("Model: %s", config.model.base_model) - logger.info("VLLM: %s", config.actor.vllm_base_url) - - traces = await actor.run_tasks([task], job_id="test_2048") - - for trace in traces: - if trace.isError: - logger.info("Error: %s", trace.content) - else: - logger.info("Success!") - logger.info("Trace info: %s", trace.info if hasattr(trace, "info") else "No info") - # Check for evaluation in the trace info - if hasattr(trace, "info") and "evaluation" in trace.info: - logger.info(" Evaluation: %s", trace.info["evaluation"]) - - asyncio.run(test_actor()) diff --git a/hud/rl/buffer.py b/hud/rl/buffer.py deleted file mode 100644 index 17cdff87..00000000 --- a/hud/rl/buffer.py +++ /dev/null @@ -1,405 +0,0 @@ -"""Replay buffer for storing and sampling episodes.""" - -from __future__ import annotations - -import logging -import random -from collections import deque -from typing import TYPE_CHECKING, Generic, TypeVar - -from hud.types import Task, Trace -from hud.utils.hud_console import HUDConsole - -logger = logging.getLogger(__name__) -hud_console = HUDConsole(logger=logger) - -T = TypeVar("T") - -if TYPE_CHECKING: - from collections.abc import Callable - - from hud.rl.config import Config - - -class Buffer(Generic[T]): - """Simple buffer for a list of tasks, traces or episodes.""" - - def __init__(self, max_size: int = 10000) -> None: - self.max_size = max_size - self.buffer: deque[T] = deque(maxlen=max_size) - - def add(self, items: list[T] | T, shuffle: bool = False) -> None: - """Add items to buffer.""" - if isinstance(items, list): - for item in items: - self.buffer.append(item) - else: - self.buffer.append(items) - if shuffle: - random.shuffle(self.buffer) - - def add_fill(self, items: list[T] | T, target_size: int, shuffle: bool = False) -> None: - """Add items to buffer until the buffer is at least the target size.""" - while len(self.buffer) < target_size: - self.add(items, shuffle) - - def get(self, n: int = 0) -> list[T]: - """Get items from the buffer.""" - if n == 0: - return list(self.buffer) - if n > len(self.buffer): - raise ValueError("Not enough items in buffer") - return list(self.buffer)[-n:] - - def consume(self, n: int = 0) -> list[T]: - """Consume items from the buffer.""" - if n == 0: - return list(self.buffer) - if n > len(self.buffer): - raise ValueError("Not enough items in buffer") - - return [self.buffer.pop() for _ in range(n)] - - def get_filtered( - self, n: int = 0, filter_fn: Callable[[T], bool] | None = None, consume: bool = False - ) -> list[T]: - """Filter the buffer by a filter function.""" - filtered = ( - [item for item in self.buffer if filter_fn(item)] if filter_fn else list(self.buffer) - ) - if n == 0: - return filtered - return self.consume(n) if consume else self.get(n) - - def sample( - self, - batch_size: int, - n: int = 0, - filter_fn: Callable[[T], bool] | None = None, - consume: bool = False, - ) -> list[T]: - """Sample a batch of items with optional filtering.""" - items = self.get_filtered(n, filter_fn, consume) - - if len(items) < batch_size: - hud_console.warning(f"Buffer has {len(items)} items, requested {batch_size}") - return items - - return random.sample(items, batch_size) - - def clear(self) -> None: - """Clear the buffer.""" - self.buffer.clear() - - def __len__(self) -> int: - """Use len() directly on Buffer instances.""" - return 
len(self.buffer) - - -class DatasetBuffer(Buffer[Task]): - """ - Buffer for a dataset. - Loads in individual tasks that will be trained for a specified number of training steps. - """ - - def __init__( - self, - dataset: list[Task] | Task, - config: Config, - ) -> None: - self.config = config - - self.group_size = config.training.group_size - self.batch_size = config.training.batch_size - self.training_steps = config.training.training_steps - - if self.group_size > self.batch_size: - raise ValueError( - f"Group size is greater than batch size, {self.group_size} > {self.batch_size}" - ) - - if self.batch_size % self.group_size != 0: - raise ValueError( - f"A batch cannot have irregular groups, {self.group_size} % {self.batch_size} != 0" - ) - - if self.group_size % config.training.mini_batch_size != 0: - raise ValueError( - f"Group size is not a multiple of mini batch size, {self.group_size} % {config.training.mini_batch_size} != 0" # noqa: E501 - ) - - self.groups_per_batch = self.batch_size // self.group_size - self.number_of_tasks = self.training_steps * self.groups_per_batch - - super().__init__(self.number_of_tasks) - - dataset = dataset if isinstance(dataset, list) else [dataset] - tasks = self._validate_tasks(dataset) - if config.training.shuffle_dataset: - random.shuffle(tasks) - if len(tasks) > self.number_of_tasks: - leftovers = len(tasks) - self.number_of_tasks - hud_console.warning( - f"Training steps ({self.training_steps}) will lead to {leftovers} tasks not being trained" # noqa: E501 - ) - tasks = tasks[: self.number_of_tasks] - - # Check if the dataset is imbalanced - self.dataset_size = len(tasks) - if self.training_steps % self.dataset_size != 0: - leftovers = self.number_of_tasks % self.dataset_size - hud_console.warning( - f"Dataset imbalanced ({leftovers} tasks will be trained 1 more time)" - ) - hud_console.warning( - f"This is because the number of training steps ({self.training_steps}) is not a multiple of the dataset size ({self.dataset_size})" # noqa: E501 - ) - - if config.verbose: - hud_console.info(f"Sample task: {tasks[0]}") - - self.add_fill(tasks, self.number_of_tasks, config.training.shuffle_dataset) - - def _validate_tasks(self, tasks: list[Task]) -> list[Task]: - """Validate that all tasks are proper HUD Task objects.""" - if not tasks: - raise ValueError("No tasks provided to DatasetBuffer") - - validated_tasks = [] - for i, task in enumerate(tasks): - if not isinstance(task, Task): - raise TypeError(f"Task at index {i} is not a HUD Task object, got {type(task)}") - validated_tasks.append(task) - - return validated_tasks - - @property - def info(self) -> dict[str, int | float | str]: - """Get the info of the buffer.""" - return { - "total_items": len(self), - "total_traces": self.number_of_tasks * self.group_size, - "total_batches": self.training_steps, - "task_repeats": self.number_of_tasks // self.dataset_size, - "dataset_size": self.dataset_size, - "group_size": self.group_size, - "batch_size": self.batch_size, - } - - def get_tasks(self, consume: bool = True) -> list[Task]: - """Get tasks for a batch.""" - tasks = self.consume(self.groups_per_batch) if consume else self.get(self.groups_per_batch) - # Create groups where each group contains group_size copies of the same task - result = [] - for task in tasks: - result.extend([task] * self.group_size) - return result - - -class ReplayBuffer(Buffer[Trace]): - """Buffer for traces.""" - - def __init__(self, config: Config) -> None: - self.config = config - - self.buffer_steps = 
config.training.buffer_steps - self.select_strategy = config.training.select_strategy - self.group_size = config.training.group_size - self.batch_size = config.training.batch_size - - buffer_size = self.buffer_steps * self.batch_size - - super().__init__(buffer_size) - - def sample_traces(self) -> list[Trace]: - """Sample traces for a batch.""" - if self.select_strategy == "recent": - return self.get(self.batch_size) - elif self.select_strategy == "random": - return self.sample(self.batch_size) - elif self.select_strategy == "variance": - return self._sample_high_variance_traces() - else: - raise ValueError(f"Invalid select strategy: {self.select_strategy}") - - def _extract_group_key(self, trace: Trace) -> tuple[str, str]: - """Return a stable grouping key for a trace. - - Preference order: - 1) task.id when present (kind='id') - 2) task.prompt exact string (kind='prompt') when id is None - 3) 'NA' for missing/errored entries (kind='NA') - """ - if getattr(trace, "isError", False): - return ("NA", "NA") - - task = getattr(trace, "task", None) - if task is None: - return ("NA", "NA") - - tid = getattr(task, "id", None) - if tid is not None: - return ("id", str(tid)) - - prompt = getattr(task, "prompt", None) - if prompt: - return ("prompt", str(prompt)) - - return ("NA", "NA") - - def _validate_and_split_groups( - self, recent_traces: list[Trace] - ) -> tuple[list[list[Trace]], list[tuple[str, str]]]: - """Validate and split recent traces into homogeneous groups by id or prompt. - - - Uses id when present; otherwise falls back to prompt equality. - - Any NA/error traces are excluded and the group is filled by duplicating - existing valid members in that group. - - Always returns len == groups_per_batch groups of size == group_size. - """ - from collections import Counter - - groups_per_batch = self.batch_size // self.group_size - - window_keys = [self._extract_group_key(t) for t in recent_traces] - window_counter = Counter(k for k in window_keys if k[0] != "NA") - - validated_groups: list[list[Trace]] = [] - selected_keys: list[tuple[str, str]] = [] - - for g_idx in range(groups_per_batch): - start = g_idx * self.group_size - end = start + self.group_size - chunk = recent_traces[start:end] - - key_counts = Counter() - per_item_keys: list[tuple[str, str]] = [] - for tr in chunk: - k = self._extract_group_key(tr) - per_item_keys.append(k) - if k[0] != "NA": - key_counts[k] += 1 - - if key_counts: - best_key = key_counts.most_common(1)[0][0] - elif window_counter: - best_key = window_counter.most_common(1)[0][0] - else: - best_key = ("NA", "NA") - - homogeneous = [tr for tr, k in zip(chunk, per_item_keys, strict=False) if k == best_key] - - while len(homogeneous) < self.group_size: - if homogeneous: - homogeneous.append(homogeneous[-1]) - else: - idx = next((i for i, wk in enumerate(window_keys) if wk[0] != "NA"), None) - if idx is not None: - homogeneous.append(recent_traces[idx]) - elif chunk: - homogeneous.append(chunk[0]) - else: - homogeneous.append(recent_traces[0]) - - validated_groups.append(homogeneous) - selected_keys.append(best_key) - - return validated_groups, selected_keys - - def _sample_high_variance_traces(self) -> list[Trace]: - from collections import Counter, defaultdict, deque - - buf_list = list(self.buffer) - if len(buf_list) < self.batch_size: - hud_console.warning( - f"[group-sampler] Buffer has only {len(buf_list)} traces, need {self.batch_size}" - ) - while len(buf_list) < self.batch_size: - take = min(len(buf_list) or 1, self.batch_size - len(buf_list)) - 
buf_list.extend(buf_list[:take]) - recent_traces = buf_list[-self.batch_size :] - - recent_keys = [self._extract_group_key(t) for t in recent_traces] - hud_console.info(f"[group-sampler] recent-window histogram: {Counter(recent_keys)}") - - hud_console.info( - f"[group-sampler] Building earlier traces lookup, buffer size: {len(buf_list)}" - ) - earlier_traces_by_key: dict[tuple[str, str], deque[Trace]] = defaultdict(deque) - for tr in buf_list[: -self.batch_size]: - k = self._extract_group_key(tr) - if k[0] != "NA": - earlier_traces_by_key[k].append(tr) - - groups, group_keys = self._validate_and_split_groups(recent_traces) - - final_traces: list[Trace] = [] - for g_idx, (homogeneous, target_key) in enumerate(zip(groups, group_keys, strict=False)): - - def current_mean(h: list[Trace]) -> float: - if not h: - return 0.0 - vals = [float(getattr(t, "reward", 0.0) or 0.0) for t in h] - return sum(vals) / len(vals) - - pool = earlier_traces_by_key.get(target_key, deque()) - if pool: - pool_vals = [float(getattr(tr, "reward", 0.0) or 0.0) for tr in list(pool)] - if pool_vals: - pool_mean = sum(pool_vals) / len(pool_vals) - pool_var = sum((v - pool_mean) * (v - pool_mean) for v in pool_vals) / len( - pool_vals - ) - hud_console.info( - f"[group-sampler] Group {g_idx}: earlier-pool size={len(pool_vals)} " - f"mean={pool_mean:.4f} std={(pool_var**0.5):.4f}" - ) - - replace_k = max(1, self.group_size // 4) - replace_k = min(replace_k, len(pool), self.group_size) - - if replace_k > 0: - mu = current_mean(homogeneous) - pool_list = list(pool) - pool_indices = list(range(len(pool_list))) - pool_indices.sort( - key=lambda i: abs( - (float(getattr(pool_list[i], "reward", 0.0) or 0.0)) - mu - ), - reverse=True, - ) - chosen_pool_idx = set(pool_indices[:replace_k]) - replacements = [pool_list[i] for i in pool_indices[:replace_k]] - - remaining = [tr for i, tr in enumerate(pool_list) if i not in chosen_pool_idx] - earlier_traces_by_key[target_key] = deque(remaining) - - group_indices = list(range(len(homogeneous))) - mu = current_mean(homogeneous) - group_indices.sort( - key=lambda i: abs( - (float(getattr(homogeneous[i], "reward", 0.0) or 0.0)) - mu - ) - ) - target_positions = group_indices[:replace_k] - - for pos, new_tr in zip(target_positions, replacements, strict=False): - homogeneous[pos] = new_tr - - if any(self._extract_group_key(t) != target_key for t in homogeneous): - raise RuntimeError(f"Group {g_idx} is not homogeneous after sampling") - final_traces.extend(homogeneous) - - for i in range(0, len(final_traces), self.group_size): - block = final_traces[i : i + self.group_size] - keys = {self._extract_group_key(t) for t in block} - if len(keys) != 1: - raise RuntimeError(f"Homogeneity validation failed for block starting at index {i}") - - hud_console.info( - f"[group-sampler] final histogram: " - f"{Counter(self._extract_group_key(t) for t in final_traces)}" - ) - return final_traces - - # -------------------------------------------------------------------- diff --git a/hud/rl/chat_template.jinja b/hud/rl/chat_template.jinja deleted file mode 100644 index 00fd8c18..00000000 --- a/hud/rl/chat_template.jinja +++ /dev/null @@ -1,101 +0,0 @@ -{% set image_count = namespace(value=0) %} -{% set video_count = namespace(value=0) %} -{{- '<|im_start|>system\n' }} -{%- if messages[0]['role'] == 'system' -%} - {%- if messages[0]['content'] is string -%} - {{ messages[0]['content'] }} - {%- else -%} - {%- for content in messages[0]['content'] -%} - {%- if content['type'] == 'image' or 'image' in content or 
'image_url' in content -%} - {%- set image_count.value = image_count.value + 1 -%} - {%- if add_vision_id -%} - {{ 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif -%} - {{ '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif content['type'] == 'video' or 'video' in content -%} - {%- set video_count.value = video_count.value + 1 -%} - {%- if add_vision_id -%} - {{ 'Video ' ~ video_count.value ~ ': ' }} - {%- endif -%} - {{ '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in content -%} - {{ content['text'] }} - {%- endif -%} - {%- endfor -%} - {%- endif -%} -{%- else -%} - {{ 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }} -{%- endif -%} -{%- if tools -%} - {{ '\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\n' }} - {{- tools | map('tojson') | join('\n') -}} - {{ '\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{"name": , "arguments": }\n' }} -{%- endif -%} -{{ '<|im_end|>\n' }} -{%- for message in messages -%} - {# Skip the first system message as it was already rendered. #} - {%- if loop.first and message.role == 'system' %}{% continue %}{% endif -%} - - {# Render tool messages. The logic is slightly different with other messages. #} - {%- if message['role'] == 'tool' -%} - {%- if loop.first or messages[loop.index0 - 1]['role'] != 'tool' -%} - {{ '<|im_start|>user' }} - {%- endif -%} - {{ '\n\n' }} - {%- else -%} - {{ '<|im_start|>' ~ message['role'] ~ '\n' }} - {%- endif -%} - - {%- if message['content'] is string -%} - {{ message['content'] }} - {%- else -%} - {%- for content in message['content'] -%} - {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content -%} - {%- set image_count.value = image_count.value + 1 -%} - {%- if add_vision_id -%} - {{ 'Picture ' ~ image_count.value ~ ': ' }} - {%- endif -%} - {{ '<|vision_start|><|image_pad|><|vision_end|>' }} - {%- elif content['type'] == 'video' or 'video' in content -%} - {%- set video_count.value = video_count.value + 1 -%} - {%- if add_vision_id -%} - {{ 'Video ' ~ video_count.value ~ ': ' }} - {%- endif -%} - {{ '<|vision_start|><|video_pad|><|vision_end|>' }} - {%- elif 'text' in content and message['role'] == 'assistant' -%} - {% generation %} {{ content['text'] }} {% endgeneration %} - {%- elif 'text' in content -%} - {{ content['text'] }} - {%- endif -%} - {%- endfor -%} - {%- endif -%} - {# Render tool_calls in AI messages. 
#} - {%- if message['role'] == 'assistant' and 'tool_calls' in message -%} - {# It will be cleaner if I can use some map function and join them with '\n' #} - {%- for tool_call in message['tool_calls'] -%} - {%- if tool_call['function'] is defined -%} - {%- set tool_call = tool_call['function'] -%} - {%- endif -%} - {# Handle the case where arguments is already a JSON string (OpenAI format) #} - {%- if tool_call.arguments is string -%} - {% generation %} {{ '\n{"name": "' }}{{ tool_call.name }}{{ '", "arguments": ' }}{{ tool_call.arguments }}{{ '}\n' }} {% endgeneration %} - {%- else -%} - {% generation %} {{ '\n' }}{{ tool_call | tojson }}{{ '\n' }} {% endgeneration %} - {%- endif -%} - {%- if not loop.last -%} - {% generation %} {{ '\n' }} {% endgeneration %} - {%- endif -%} - {%- endfor -%} - {%- endif -%} - {%- if message['role'] == 'tool' -%} - {{ '\n' }} - {%- if loop.last or messages[loop.index0 + 1]['role'] != 'tool' -%} - {{ '<|im_end|>\n' }} - {%- endif -%} - {%- else -%} - {{ '<|im_end|>\n' }} - {%- endif -%} -{%- endfor -%} -{%- if add_generation_prompt -%} - {{ '<|im_start|>assistant\n' }} -{%- endif -%} diff --git a/hud/rl/config.py b/hud/rl/config.py deleted file mode 100644 index f795e3ff..00000000 --- a/hud/rl/config.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Configuration for RL training.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Literal - -# List of supported VL (Vision-Language) models -SUPPORTED_MODELS = [ - "Qwen/Qwen2.5-VL-3B-Instruct", - "Qwen/Qwen2.5-VL-7B-Instruct", - "Qwen/Qwen2.5-VL-14B-Instruct", - "Qwen/Qwen2.5-VL-32B-Instruct", - "Qwen/Qwen2.5-VL-72B-Instruct", - "Qwen/Qwen2.5-7B-Instruct", - "Qwen/Qwen2.5-3B-Instruct", -] - - -def validate_vl_model(model_name: str) -> None: - """Validate that the model is a supported VL model. - - Args: - model_name: The model name to validate - - Raises: - ValueError: If the model is not a supported VL model - """ - if not any(model_name.startswith(supported) for supported in SUPPORTED_MODELS): - raise ValueError( - f"Model '{model_name}' is not a supported VL model. " - f"Only VL (Vision-Language) models are supported for RL training.\n" - f"Supported models: {', '.join(SUPPORTED_MODELS)}\n" - f"Note: '{model_name}' appears to be a text-only model." - ) - - -@dataclass -class ModelConfig: - """Model and LoRA configuration.""" - - base_model: str = "Qwen/Qwen2.5-VL-3B-Instruct" - lora_r: int = 16 - lora_alpha: int = 32 - lora_dropout: float = 0.1 - target_modules: tuple[str, ...] 
= ( - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ) - min_pixels: int = 256 * 28 * 28 - max_pixels: int = 512 * 28 * 28 - attn_implementation: str = "flash_attention_2" - use_liger: bool = True - gradient_checkpointing: bool = True - adapter_path: str | None = None # Path to existing LoRA adapter to load as baseline - - -@dataclass -class TrainingConfig: - """Training hyperparameters.""" - - # GPU parameters - gpu_type: str = "A100" - num_gpus: int = 2 - - # Training parameters - training_steps: int = 100 - shuffle_dataset: bool = False - save_every_batches: int = 1 - - # Batching parameters - epochs: int = 1 - batch_size: int = 16 - group_size: int = 8 - mini_batch_size: int = 1 - update_after_group: bool = True # Whether to update the policy after each task group - accumulate_over_minibatches: bool = False # Whether to accumulate over minibatches - - # Advantage calculation parameters - batch_level: Literal["group", "batch"] = "group" - no_std: bool = False - leave_one_out: bool = True - - # Replay buffer parameters - buffer_steps: int = 8 - select_strategy: Literal["recent", "variance", "random"] = "variance" - - # Aggregation parameters - ppo_mode: Literal["per_token", "per_trace"] = "per_token" - token_agg: Literal["mean", "sum"] = "mean" # noqa: S105 - - # Regularization parameters - kl_beta: float = 0.001 - entropy_beta: float = 0.001 - top_eps: float = 0.2 - bottom_eps: float = 0.1 - - # Training hyperparameters - lr: float = 3e-5 - grad_clip: float = 1.0 - - # Adam hyperparameters - use_8bit_optimizer: bool = True - adam_betas: tuple[float, float] = (0.9, 0.999) - adam_eps: float = 1e-8 - - -@dataclass -class ActorConfig: - """Actor/episode collection configuration.""" - - # Execution parameters - max_steps_per_episode: int = 5 - max_parallel_episodes: int = 48 - max_new_tokens: int = 1024 - force_tool_choice: bool = True - allowed_tools: list[str] | None = None - - # Model parameters - temperature: float = 0.7 - - # Hud agent parameters - system_prompt: str = "You are an expert agent. Complete the task efficiently." 
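The batching knobs above are what DatasetBuffer and ReplayBuffer (deleted earlier in this patch) turn into concrete buffer sizes. A minimal sketch of that arithmetic, assuming the default values shown in TrainingConfig; plain Python with no hud imports, and variable names local to the sketch:

```python
# Sizing math implied by the TrainingConfig defaults above.
training_steps = 100
batch_size = 16
group_size = 8
mini_batch_size = 1
buffer_steps = 8

# Invariants DatasetBuffer enforces before training starts.
assert group_size <= batch_size
assert batch_size % group_size == 0
assert group_size % mini_batch_size == 0

groups_per_batch = batch_size // group_size          # 2 distinct tasks per training batch
number_of_tasks = training_steps * groups_per_batch  # 200 task slots held by DatasetBuffer
total_traces = number_of_tasks * group_size          # 1600 episodes collected over the run
replay_capacity = buffer_steps * batch_size          # 128 traces retained by ReplayBuffer

print(groups_per_batch, number_of_tasks, total_traces, replay_capacity)  # 2 200 1600 128
```

Each batch thus contains groups_per_batch tasks, each duplicated group_size times, which is what DatasetBuffer.get_tasks() hands to the actor.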
- vllm_base_url: str = "http://localhost:8000/v1" - vllm_api_key: str = "token-abc123" - - # Episode execution timeout (seconds) - episode_timeout_sec: int = 600 - - -@dataclass -class Config: - """Main configuration combining all sub-configs.""" - - model: ModelConfig = field(default_factory=ModelConfig) - training: TrainingConfig = field(default_factory=TrainingConfig) - actor: ActorConfig = field(default_factory=ActorConfig) - - # Telemetry configuration - job_name: str = "RL Training" - job_id: str | None = None # Use existing job ID if provided - stats_interval: int = 1 - verbose: bool = False - very_verbose: bool = False - - # Paths - out_dir: str = "./checkpoints" - adapter_prefix: str = "cua-grpo-step" - - # Misc - seed: int = 1234 - - @classmethod - def from_dict(cls, d: dict) -> Config: - """Create config from dictionary.""" - model = ModelConfig(**d.get("model", {})) - training = TrainingConfig(**d.get("training", {})) - actor = ActorConfig(**d.get("actor", {})) - - return cls( - model=model, - training=training, - actor=actor, - job_name=d.get("job_name", "RL Training"), - job_id=d.get("job_id"), - stats_interval=d.get("stats_interval", 1), - verbose=d.get("verbose", False), - very_verbose=d.get("very_verbose", False), - out_dir=d.get("out_dir", "./checkpoints"), - adapter_prefix=d.get("adapter_prefix", "cua-grpo-step"), - seed=d.get("seed", 1234), - ) - - def to_dict(self) -> dict: - """Convert config to dictionary.""" - return { - "model": self.model.__dict__, - "training": self.training.__dict__, - "actor": self.actor.__dict__, - "job_name": self.job_name, - "job_id": self.job_id, - "stats_interval": self.stats_interval, - "verbose": self.verbose, - "very_verbose": self.very_verbose, - "out_dir": self.out_dir, - "adapter_prefix": self.adapter_prefix, - "seed": self.seed, - } diff --git a/hud/rl/distributed.py b/hud/rl/distributed.py deleted file mode 100644 index 6bade77c..00000000 --- a/hud/rl/distributed.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Distributed training utilities for GRPO.""" - -from __future__ import annotations - -import os -from datetime import timedelta -from typing import Any - -import torch -import torch.distributed as dist - - -def setup_distributed() -> None: - """Initialize distributed training environment.""" - if "RANK" in os.environ and int(os.environ["WORLD_SIZE"]) > 1: - # Set device for this process - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - - # Initialize process group - # Increase watchdog timeout to accommodate long eval/sampling phases - # and enable clearer NCCL error handling. 
- os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") - dist.init_process_group("nccl", timeout=timedelta(minutes=20)) - - -def get_local_rank() -> int: - """Get local rank from environment.""" - return int(os.environ.get("LOCAL_RANK", 0)) - - -def get_global_rank() -> int: - """Get global rank from environment.""" - return int(os.environ.get("RANK", 0)) - - -def get_world_size() -> int: - """Get world size from environment.""" - return int(os.environ.get("WORLD_SIZE", 1)) - - -def cleanup_distributed() -> None: - """Clean up distributed environment.""" - if dist.is_initialized(): - dist.destroy_process_group() - - -def is_main_process() -> bool: - """Check if this is the main process (rank 0).""" - if not dist.is_initialized(): - return True - return dist.get_rank() == 0 - - -def synchronize() -> None: - """Synchronize all processes.""" - if dist.is_initialized(): - dist.barrier() - - -def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: - """Average a tensor across all processes.""" - if not dist.is_initialized(): - return tensor - - world_size = dist.get_world_size() - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - tensor /= world_size - return tensor - - -def broadcast_object(obj: Any, src: int = 0) -> Any: - """Broadcast a Python object from src rank to all ranks. - - Args: - obj: Object to broadcast (used on src rank) - src: Source rank - device: Device for temporary tensor buffer during pickling transfer - """ - if not dist.is_initialized(): - return obj - - obj_list = [obj] if dist.get_rank() == src else [None] - dist.broadcast_object_list(obj_list, src=src) - return obj_list[0] - - -def scatter_object( - obj_list: list[Any] | None, - src: int = 0, -) -> Any: - """Scatter a list of Python objects from src so each rank receives one object. - - Usage: - - On src rank: pass the full list (length == world_size) - - On non-src ranks: pass None - - Returns: - The object intended for this rank. - """ - if not dist.is_initialized(): - # Single-process: return first element if provided, else None - if obj_list is None or len(obj_list) == 0: - return None - return obj_list[0] - - out: list[Any] = [None] - if dist.get_rank() == src: - dist.scatter_object_list(out, obj_list, src=src) - else: - dist.scatter_object_list(out, None, src=src) - return out[0] - - -def gather_tensors(tensor: torch.Tensor) -> list[torch.Tensor] | None: - """Gather tensors from all ranks to rank 0. 
- - Returns: - List of tensors on rank 0, None on other ranks - """ - if not dist.is_initialized(): - return [tensor] - - world_size = dist.get_world_size() - - if dist.get_rank() == 0: - gathered = [torch.zeros_like(tensor) for _ in range(world_size)] - dist.gather(tensor, gathered, dst=0) - return gathered - else: - dist.gather(tensor, None, dst=0) - return None diff --git a/hud/rl/learner.py b/hud/rl/learner.py deleted file mode 100644 index 859d7ec4..00000000 --- a/hud/rl/learner.py +++ /dev/null @@ -1,648 +0,0 @@ -"""GRPO learner for vision-language and text models.""" - -from __future__ import annotations - -import logging -import os -from typing import TYPE_CHECKING, Any - -import torch -from peft import LoraConfig, get_peft_model -from torch.nn.parallel import DistributedDataParallel as DDP -from transformers import ( - AutoModelForCausalLM, - AutoProcessor, - AutoTokenizer, - Qwen2_5_VLForConditionalGeneration, -) - -try: - from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl # type: ignore - - LIGER_AVAILABLE = True -except ImportError: - LIGER_AVAILABLE = False - -try: - import bitsandbytes as bnb # type: ignore - - BNB_AVAILABLE = True -except ImportError: - BNB_AVAILABLE = False - -from contextlib import nullcontext - -from hud.rl.distributed import ( - get_local_rank, - get_world_size, - is_main_process, -) -from hud.rl.utils import ( - batch_training_samples, - entropy_from_logits, - get_gpu_utilization, - get_memory_usage, - prepare_inputs, -) -from hud.utils.hud_console import HUDConsole - -from .types import TrainingMetrics, TrainingSample - -logger = logging.getLogger(__name__) -hud_console = HUDConsole(logger) - -if TYPE_CHECKING: - from .config import Config - - -class GRPOLearner: - """GRPO learning algorithm for Vision-Language Models (VLMs) and Text Models.""" - - def __init__(self, config: Config) -> None: - self.config = config - self.local_rank = get_local_rank() - self.world_size = get_world_size() - self.device = torch.device( - f"cuda:{self.local_rank}" if torch.cuda.is_available() else "cpu" - ) - - # Detect model type - self.is_vl_model = "VL" in config.model.base_model - - # Load models and processor - self.processor, self.policy, self.ref, self.optimizer = self._load_models() - self.metrics: list[TrainingMetrics] = [] - - def log(self, message: str) -> None: - hud_console.info_log(f"[{self.local_rank}] {message}") - - def _load_models(self) -> tuple[Any, Any, Any, Any]: - """Load policy, reference models and optimizer.""" - model_cfg = self.config.model - - # Detect if this is a VL model or standard text model - is_vl_model = "VL" in model_cfg.base_model - model_type = "Vision-Language" if is_vl_model else "Text" - self.log(f"Loading {model_type} model: {model_cfg.base_model}") - - # Apply Liger kernel optimizations if available and enabled - if model_cfg.use_liger and LIGER_AVAILABLE: - if is_vl_model: - self.log("Applying Liger kernel optimizations to Qwen2.5-VL") - apply_liger_kernel_to_qwen2_5_vl( - rope=True, # Optimized RoPE - rms_norm=True, # Optimized RMSNorm - swiglu=True, # Optimized SwiGLU - fused_linear_cross_entropy=True, # Fused Linear+CrossEntropy for memory - ) - elif model_cfg.use_liger and not LIGER_AVAILABLE: - self.log( - "Liger kernel requested but not installed. 
Install with: pip install liger-kernel" - ) - - # Load processor/tokenizer based on model type - if is_vl_model: - # Some environments require remote code for Qwen2.5-VL processors - processor = AutoProcessor.from_pretrained( - model_cfg.base_model, - min_pixels=model_cfg.min_pixels, - max_pixels=model_cfg.max_pixels, - trust_remote_code=True, - ) - else: - processor = AutoTokenizer.from_pretrained(model_cfg.base_model) - - # Load policy model with LoRA - # Use attention implementation from config - attn_implementation = model_cfg.attn_implementation - - # Choose the appropriate model class - model_class = Qwen2_5_VLForConditionalGeneration if is_vl_model else AutoModelForCausalLM - - try: - policy = model_class.from_pretrained( - model_cfg.base_model, - torch_dtype=torch.bfloat16, - attn_implementation=attn_implementation, - trust_remote_code=True, - ) - self.log(f"Using {attn_implementation} for attention") - except (ImportError, ValueError) as e: - # Only fallback if explicitly using flash_attention_2 and it's not available - if attn_implementation == "flash_attention_2": - self.log(f"Flash Attention 2 not available ({e}), using eager attention") - policy = model_class.from_pretrained( - model_cfg.base_model, - torch_dtype=torch.bfloat16, - attn_implementation="eager", - ) - else: - raise # Re-raise if it's a different error - - # Move model to device - policy = policy.to(self.device) # type: ignore - # Enable gradient checkpointing for memory efficiency - if model_cfg.gradient_checkpointing: - policy.gradient_checkpointing_enable() - self.log("Gradient checkpointing enabled for memory efficiency") - - # Add LoRA adapters or load existing adapter - policy.config.use_cache = False - - if model_cfg.adapter_path: - # Load existing adapter as baseline - self.log(f"Loading existing LoRA adapter from: {model_cfg.adapter_path}") - from peft import PeftModel - - policy = PeftModel.from_pretrained(policy, model_cfg.adapter_path) - # Enable adapter training - policy.train() - else: - # Create new LoRA adapter - lora_config = LoraConfig( - r=model_cfg.lora_r, - lora_alpha=model_cfg.lora_alpha, - lora_dropout=model_cfg.lora_dropout, - task_type="CAUSAL_LM", - bias="none", - target_modules=list(model_cfg.target_modules), - ) - policy = get_peft_model(policy, lora_config) - - # Wrap with DDP if in distributed mode - if self.world_size > 1: - policy = DDP( - policy, - device_ids=[self.local_rank], - output_device=self.local_rank, - broadcast_buffers=False, - find_unused_parameters=True, - ) - self.log("Wrapped model (find_unused_parameters=True)") - - # Create optimizer - need to access underlying model if DDP - base_model = policy.module if hasattr(policy, "module") else policy - trainable_params = [p for _, p in base_model.named_parameters() if p.requires_grad] # type: ignore - - # Use 8-bit optimizer if configured - if self.config.training.use_8bit_optimizer and BNB_AVAILABLE: - hud_console.info("Using 8-bit AdamW optimizer from bitsandbytes") - optimizer = bnb.optim.AdamW8bit( # type: ignore - trainable_params, - lr=self.config.training.lr, - betas=self.config.training.adam_betas, - eps=self.config.training.adam_eps, - ) - else: - self.log("Using standard FP32 AdamW optimizer") - optimizer = torch.optim.AdamW( - trainable_params, - lr=self.config.training.lr, - betas=self.config.training.adam_betas, - eps=self.config.training.adam_eps, - ) - - # Log optimizer info - self.log(f"Optimizer: {type(optimizer).__name__}") - num_params = sum(p.numel() for p in trainable_params) - self.log(f"Number of 
trainable parameters: {num_params:,}") - - return processor, policy, None, optimizer - - def prepare_groups( - self, - samples: list[TrainingSample], - ) -> list[list[TrainingSample]]: - """Prepare groups of samples for training.""" - # Prepare inputs with messages - batch = [] - for sample in samples: - inputs = prepare_inputs(sample, self.processor) - # If inputs are invalid, create dummy inputs to maintain batch size - if ( - not inputs - or "input_ids" not in inputs - or inputs.get("input_ids", torch.tensor([])).numel() == 0 - ): - hud_console.warning_log("Sample has invalid inputs, using dummy values") - # Create minimal dummy inputs to keep batch size consistent - inputs = { - "input_ids": torch.zeros(1, 2, dtype=torch.long), # Minimal sequence - "attention_mask": torch.ones(1, 2, dtype=torch.long), - "assistant_mask": torch.zeros(1, 1, dtype=torch.bool), # T-1 length - } - elif "assistant_mask" not in inputs: - hud_console.warning_log("Sample missing assistant_mask, creating zero mask") - seq_len = inputs["input_ids"].shape[-1] - inputs["assistant_mask"] = torch.zeros( - inputs["input_ids"].shape[0], seq_len - 1, dtype=torch.bool - ) - - new_sample = TrainingSample(**sample.model_dump()) - new_sample.inputs = inputs - new_sample.advantage = sample.advantage - batch.append(new_sample) - - with hud_console.progress("Processing batch of traces...") as progress, torch.no_grad(): - for i, sample in enumerate(batch): - if is_main_process(): - progress.update(f"Processing batch of traces... {i}/{len(batch)}") - if sample.inputs: - sample = sample.to_device(self.device) - sample.old_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs) - # Free GPU memory for this sample immediately - sample.to_device(torch.device("cpu")) - - policy_module = self.policy.module if hasattr(self.policy, "module") else self.policy - with policy_module.disable_adapter(): - for i, sample in enumerate(batch): - if is_main_process(): - progress.update(f"Processing batch of traces... 
{i}/{len(batch)}") - if sample.inputs: - # Move back to GPU for reference computation, then free - sample = sample.to_device(self.device) - sample.ref_logprobs, _ = self.compute_logprobs(self.policy, sample.inputs) - sample.to_device(torch.device("cpu")) - - hud_console.info_log("Creating mini-batches...") - group_size = self.config.training.group_size - processed_batch = [] - if not self.config.training.accumulate_over_minibatches: - # Find minibatches and group them via batch_training_samples - # Minibatches control the batch size of the forward pass to the model - mb_size = self.config.training.mini_batch_size - group_size = group_size // mb_size - for i in range(0, len(batch), mb_size): - processed_batch.extend(batch_training_samples(batch[i : i + mb_size])) - else: - processed_batch = batch - - for sample in processed_batch: - sample.to_device(torch.device("cpu")) - - # Convert to grouped batches (if updating the model after each task group) - if self.config.training.update_after_group: - return [ - processed_batch[i : i + group_size] - for i in range(0, len(processed_batch), group_size) - ] - else: - return [processed_batch] - - def update(self, samples: list[TrainingSample]) -> TrainingMetrics: - """Perform a gradient update on a batch.""" - import time - - training_start_time = time.time() - - # Always create metrics for synchronization - self.metrics.append(TrainingMetrics()) - metrics = self.metrics[-1] - - # Prepare groups for GRPO training - groups = self.prepare_groups(samples) - self.log(f"Updating over {len(groups)} groups") - - # Update over mini batch size - with hud_console.progress("Gradient update...") as progress: - for epoch in range(self.config.training.epochs): # Do not accumulate across epochs - progress.update(f"Training epoch {epoch + 1}/{self.config.training.epochs}") - for group_idx, group in enumerate(groups): # Do not accumulate across "groups" - self.optimizer.zero_grad(set_to_none=True) - - debug_per_group = "" - grad_accum_steps = len(group) - # Tensor for distributed sync - global_skip = torch.zeros(1, device=self.device) - - for s_idx, sample_minibatch in enumerate(group): - # self.log(f"{group_idx} {sample_minibatch.inputs['assistant_mask'].sum()}") - # mini_updated = sample_minibatch.inputs["assistant_mask"].sum() > 0 - - # Update mini_updated globally - # self.log(f"{group_idx} Mini updated: {mini_updated}") - - # Do not sync until the last minibatch - if s_idx < len(group) - 1 and self.world_size > 1: - ddp_ctx = self.policy.no_sync() - else: - ddp_ctx = nullcontext() - - with ddp_ctx, torch.autocast(device_type="cuda", dtype=torch.bfloat16): - try: - # if mini_updated: - loss = self.compute_loss(sample_minibatch) / grad_accum_steps - debug_per_group += f"l{s_idx}:{round(loss.item(), 3)!s} " - loss.backward() - # else: # Dummy backward that touches all params, produces zero g - # dummy = sum(p.sum() for p in self.policy.parameters()) * 0.0 - # debug_per_group += f"d{s_idx}:{str(round(dummy.item(), 3))} " - # dummy.backward() - # self.log(f"{group_idx} GPU Backward: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB") # noqa: E501 - except torch.cuda.OutOfMemoryError: - hud_console.warning_log( - f"{group_idx} CUDA OOM for {sample_minibatch.inputs['input_ids'].numel()} tokens; skipping minibatch" # noqa: E501 - ) - # Dummy backward to keep DDP happy - dummy = torch.sum(p.sum() for p in self.policy.parameters()) * 0.0 # type: ignore - debug_per_group += f"o{s_idx}:{round(dummy.item(), 3)!s} " - dummy.backward() - # mark global skip if 
OOM - global_skip.fill_(1) - continue - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - # After minibatches loop, sync skip across ranks - if torch.distributed.is_initialized(): - torch.distributed.all_reduce(global_skip, op=torch.distributed.ReduceOp.MAX) - skip_any = bool(global_skip.item()) - - if skip_any: - self.log(f"G[{group_idx}] {debug_per_group} N/A (skipped)") - continue - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.policy.parameters(), - self.config.training.grad_clip, - error_if_nonfinite=True, - ) - self.optimizer.step() - - debug_per_group += f"g:{round(grad_norm.item(), 3)!s}" - self.log(f"G[{group_idx}] {debug_per_group}") - - metrics.update( - { - "grad_norm": grad_norm.item() - if isinstance(grad_norm, torch.Tensor) - else float(grad_norm), - } - ) - - # Calculate training time and throughput - training_time = time.time() - training_start_time - total_samples = ( - len(groups) * self.config.training.group_size * self.config.training.mini_batch_size - ) - samples_per_second = total_samples / training_time if training_time > 0 else 0.0 - - metrics.update( - { - "training_time": training_time, - "samples_per_second": samples_per_second, - } - ) - - return metrics - - def compute_loss(self, sample: TrainingSample) -> torch.Tensor: - """Compute GRPO loss for a batch of samples.""" - training_cfg = self.config.training - metrics = self.metrics[-1] if len(self.metrics) > 0 else TrainingMetrics() - - sample.to_device(self.device) - - pol_logp, pol_entropy = self.compute_logprobs( - self.policy, - sample.inputs, - ) - - sanity_check(sample, pol_logp, sample.old_logprobs, sample.ref_logprobs) - - metrics.update( - { - "gpu_util": get_gpu_utilization(), # Track peak utilization - "gpu_memory": get_memory_usage(), # Track memory usage - } - ) - self.log(f"GPU Util: {get_gpu_utilization():.1f}% | Memory: {get_memory_usage():.2f} GB") - - old_logp = sample.old_logprobs - ref_logp = sample.ref_logprobs - - if old_logp is None or ref_logp is None or sample.advantage is None: - raise ValueError("old_logp, ref_logp, or sample.advantage is None") - - # Use assistant mask to remove non-assistant tokens - m = sample.inputs["assistant_mask"] - - # Aggregate per trace or per token - if training_cfg.ppo_mode == "per_trace": - counts = m.sum(dim=1).clamp_min(1.0) - pol_logp = (pol_logp * m.float()).sum(dim=1) / counts - pol_entropy = (pol_entropy * m.float()).sum(dim=1) / counts - old_logp = (old_logp * m.float()).sum(dim=1) / counts - ref_logp = (ref_logp * m.float()).sum(dim=1) / counts - - # Clip log probability differences - log_ratio = torch.where(m, pol_logp - old_logp, torch.zeros_like(pol_logp)) - ratio_tok = torch.exp(log_ratio.clamp(-20.0, 20.0)) - - # Ensure advantage shape matches ratio_tok for broadcasting - advantage = ( - sample.advantage.view(-1, 1) if ratio_tok.dim() == 2 else sample.advantage.squeeze(-1) - ) - - unclipped = ratio_tok * advantage - clipped = ( - torch.clamp(ratio_tok, 1 - training_cfg.top_eps, 1 + training_cfg.bottom_eps) - * advantage - ) - - policy_term = -torch.minimum(unclipped, clipped) - - # Clip log probability differences in KL - log_rho = torch.where(m, pol_logp - ref_logp, torch.zeros_like(pol_logp)) - rho_tok = torch.exp(log_rho.clamp(-20.0, 20.0)) - kl_approx = rho_tok - torch.log(rho_tok) - 1 - - total_loss = ( - policy_term + training_cfg.kl_beta * kl_approx + training_cfg.entropy_beta * pol_entropy - ) - - # Aggregate loss - if training_cfg.ppo_mode == "per_trace": - total_loss = total_loss.mean() if training_cfg.token_agg 
== "mean" else total_loss.sum() # noqa: S105 - else: - if training_cfg.token_agg == "mean": # noqa: S105 - total_loss = (total_loss * m).sum() / m.sum().clamp_min(1.0) - else: - total_loss = (total_loss * m).sum() - - # Compute metrics only over masked (assistant) tokens - mask_count = m.sum().clamp_min(1.0) - metrics.update( - { - "policy_ratio": (ratio_tok * m).sum().item() / mask_count.item() - if mask_count.item() > 0 - else 1.0, - "kl": (kl_approx * m).sum().item() / mask_count.item() - if mask_count.item() > 0 - else 0.0, - "entropy": (pol_entropy * m).sum().item() / mask_count.item() - if mask_count.item() > 0 - else 0.0, - "tokens": sample.inputs["input_ids"].numel(), - "loss": total_loss.item(), - } - ) - - sample.to_device(torch.device("cpu")) - - return total_loss - - def compute_logprobs(self, model: Any, inputs: Any) -> tuple[torch.Tensor, torch.Tensor]: - """Compute masked per-token log probabilities via the model. - - Returns log probabilities for the actual next tokens. - """ - try: - model_inputs = {k: v for k, v in inputs.items() if k != "assistant_mask"} - out = model(**model_inputs) - - logits = out.logits / self.config.actor.temperature - - targets = inputs["input_ids"][:, 1:] - - # Align logits to predict next token: use logits[:, :-1, :] - next_logits = logits[:, :-1, :] - - token_log_probs = _selective_log_softmax(next_logits, targets) - - # Compute entropy only for assistant tokens to save memory - assistant_mask = inputs["assistant_mask"] - entropy = torch.zeros_like(token_log_probs) - if assistant_mask.any(): - entropy[assistant_mask] = entropy_from_logits(logits[:, :-1][assistant_mask]) - - return token_log_probs, entropy - except (IndexError, RuntimeError) as e: - # Handle empty inputs or DDP errors - hud_console.warning_log(f"Error in compute_logprobs: {e}. 
Returning dummy values.") - # Return dummy values that match expected shapes - seq_len = inputs["input_ids"].shape[1] - 1 if "input_ids" in inputs else 0 - batch_size = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 - # Create dummy tensors that still participate in autograd so backward doesn't fail - try: - # Touch params to build a graph - param_sum = torch.sum(next(self.policy.parameters())) - base = param_sum * 0.0 - except StopIteration: - base = torch.tensor(0.0, device=self.device) - dummy_logprobs = ( - base + torch.zeros(batch_size, seq_len, device=self.device) - ).requires_grad_(True) - dummy_entropy = ( - base + torch.zeros(batch_size, seq_len, device=self.device) - ).requires_grad_(True) - return dummy_logprobs, dummy_entropy - - def save(self, path: str) -> None: - """Save the current policy checkpoint (only on rank 0).""" - if is_main_process(): - os.makedirs(path, exist_ok=True) - # Unwrap DDP model if needed - model_to_save = self.policy.module if hasattr(self.policy, "module") else self.policy - model_to_save.save_pretrained(path) - self.log(f"Saved checkpoint to {path}") - - def load(self, path: str) -> None: - """Load a policy checkpoint.""" - # Would need to reload LoRA weights - self.log(f"Loading checkpoint from {path}") - # Implementation depends on PEFT version - - -def sanity_check( - sample: TrainingSample, - pol_logp: torch.Tensor, - old_logp: torch.Tensor | None, - ref_logp: torch.Tensor | None, -) -> None: - assert "assistant_mask" in sample.inputs - m = sample.inputs["assistant_mask"] - if old_logp is None or ref_logp is None: - return - with torch.no_grad(): - B, K = pol_logp.shape - assert old_logp.shape == (B, K), "old_logp shape mismatch" - assert ref_logp.shape == (B, K), "ref_logp shape mismatch" - assert m.shape == (B, K), "assistant_mask shape mismatch" - - # Check mask is subset of attention_mask[:, 1:] - att = sample.inputs.get("attention_mask", None) - if att is not None and att.dim() == 2: - att_shift = att[:, 1:].bool() - bad = (m & ~att_shift).sum().item() - if bad > 0: - hud_console.warning_log(f"assistant_mask overlaps padding: {bad} tokens") - - # Finiteness on masked entries only - def _stats(name: str, t: torch.Tensor) -> None: - sel = t[m] - if sel.numel() == 0: - hud_console.warning_log(f"{name} empty under mask") - return - finite = torch.isfinite(sel) - if finite.sum() < sel.numel(): - hud_console.warning_log( - f"{name} non-finite: {((~finite).sum().item())}/{sel.numel()}" - ) - sel = sel[finite].float() - - _stats("pol_logp", pol_logp) - _stats("old_logp", old_logp) - _stats("ref_logp", ref_logp) - - # Log-probabilities should be <= 0 (log-softmax) - if (pol_logp[m] > 1e-6).any(): - hud_console.warning_log("pol_logp has positive values under mask") - - # Precompute masked deltas and ratios for diagnostics (before exp) - masked_log_ratio = torch.zeros_like(pol_logp) - masked_log_ratio[m] = (pol_logp - old_logp)[m] - masked_log_rho = torch.zeros_like(pol_logp) - masked_log_rho[m] = (pol_logp - ref_logp)[m] - - _stats("log_ratio(masked)", masked_log_ratio) - _stats("log_rho(masked)", masked_log_rho) - - # Ratios after clamp (diagnostic only) - ratio_diag = torch.zeros_like(pol_logp) - rho_diag = torch.zeros_like(pol_logp) - ratio_diag[m] = torch.exp(masked_log_ratio[m].clamp(-20.0, 20.0)) - rho_diag[m] = torch.exp(masked_log_rho[m].clamp(-20.0, 20.0)) - _stats("ratio_tok(masked)", ratio_diag) - _stats("rho_tok(masked)", rho_diag) - - -def _selective_log_softmax( - logits_bt_v: torch.Tensor, - index_bt: torch.Tensor, -) -> 
torch.Tensor: - """Gather log softmax for selected indices with reduced peak memory. - - Uses logsumexp subtraction for float32/64; falls back to per-row - log_softmax for bf16/fp16. - logits_bt_v: [B, T, V] - index_bt: [B, T] - Returns: [B, T] - """ - if logits_bt_v.dtype in (torch.float32, torch.float64): - # Compute logsumexp per [B, T] in a loop over batch to reduce - # peak from B*T*V to T*V - logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits_bt_v]) - selected_logits = torch.gather(logits_bt_v, dim=-1, index=index_bt.unsqueeze(-1)).squeeze( - -1 - ) - return selected_logits - logsumexp_values - # Reduced precision: numerically stable route using per-row log_softmax - token_logprobs_rows: list[torch.Tensor] = [] - for logits_row, index_row in zip(logits_bt_v, index_bt, strict=True): - logprobs_row = logits_row.log_softmax(dim=-1) - token_logprobs_rows.append( - torch.gather(logprobs_row, dim=-1, index=index_row.unsqueeze(-1)).squeeze(-1) - ) - return torch.stack(token_logprobs_rows) diff --git a/hud/rl/tests/__init__.py b/hud/rl/tests/__init__.py deleted file mode 100644 index e9f6eb2b..00000000 --- a/hud/rl/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for RL module.""" diff --git a/hud/rl/tests/test_learner.py b/hud/rl/tests/test_learner.py deleted file mode 100644 index 1055e62c..00000000 --- a/hud/rl/tests/test_learner.py +++ /dev/null @@ -1,186 +0,0 @@ -from __future__ import annotations - -import pytest -import torch - -from hud.rl.config import Config -from hud.rl.learner import GRPOLearner -from hud.rl.types import TrainingSample - - -@pytest.fixture() -def learner_stub(monkeypatch): - cfg = Config() - # Speed up: tiny settings - cfg.training.epochs = 1 - cfg.training.group_size = 1 - cfg.training.mini_batch_size = 1 - cfg.training.use_8bit_optimizer = False - - # Stub _load_models to avoid heavy model init - def _stub_load_models(self): - class DummyPolicy(torch.nn.Module): - def __init__(self): - super().__init__() - self.w = torch.nn.Parameter(torch.zeros(1)) - - dummy_policy = DummyPolicy() - dummy_opt = torch.optim.SGD(dummy_policy.parameters(), lr=0.1) - return None, dummy_policy, None, dummy_opt - - monkeypatch.setattr(GRPOLearner, "_load_models", _stub_load_models, raising=True) - return GRPOLearner(cfg) - - -def make_sample( - pol_logp_tok: torch.Tensor, - old_logp_tok: torch.Tensor, - ref_logp_tok: torch.Tensor, - advantage: float, -): - # Minimal-but-correct object for GRPOLearner.compute_loss. - # Needs assistant_mask (T-1) and attention_mask (T) for sanity_check(). 
- Tm1 = pol_logp_tok.size(-1) - inputs = { - "input_ids": torch.zeros(1, Tm1 + 1, dtype=torch.long), - "attention_mask": torch.ones(1, Tm1 + 1, dtype=torch.long), - "assistant_mask": torch.ones(1, Tm1, dtype=torch.bool), - } - return TrainingSample( - inputs=inputs, - old_logprobs=old_logp_tok, - ref_logprobs=ref_logp_tok, - # advantage must be 1D so .view(-1,1) works in compute_loss - advantage=torch.tensor([advantage], dtype=torch.float32), - ) - - -def patch_compute_logprobs( - monkeypatch, learner: GRPOLearner, pol_logp_tok: torch.Tensor, pol_entropy_tok: torch.Tensor -): - # Return (pol_logp, pol_entropy) as expected by compute_loss - def _stub_compute_logprobs(self, model, inputs): - return pol_logp_tok.to(inputs["input_ids"].device), pol_entropy_tok.to( - inputs["input_ids"].device - ) - - monkeypatch.setattr(GRPOLearner, "compute_logprobs", _stub_compute_logprobs, raising=True) - - -def test_per_token_mean_vs_sum(monkeypatch, learner_stub: GRPOLearner): - # Setup - _, Tm1 = 1, 4 - pol = torch.tensor([[-1.0, -1.0, -1.0, -1.0]], dtype=torch.float32) # logp - old = torch.tensor([[-1.2, -0.8, -1.0, -1.1]], dtype=torch.float32) - ref = torch.tensor([[-1.0, -1.0, -1.0, -1.0]], dtype=torch.float32) - ent = torch.zeros_like(pol) - patch_compute_logprobs(monkeypatch, learner_stub, pol, ent) - - # Common config - learner_stub.config.training.kl_beta = 0.0 - learner_stub.config.training.entropy_beta = 0.0 - learner_stub.config.training.top_eps = 0.2 - learner_stub.config.training.bottom_eps = 0.1 - - sample = make_sample(pol, old, ref, advantage=1.0) - - # token_agg=mean - learner_stub.config.training.ppo_mode = "per_token" - learner_stub.config.training.token_agg = "mean" - loss_mean = learner_stub.compute_loss(sample).item() - - # token_agg=sum - learner_stub.config.training.token_agg = "sum" - loss_sum = learner_stub.compute_loss(sample).item() - - # Expect sum ≈ mean * num_tokens - assert pytest.approx(loss_sum, rel=1e-5) == loss_mean * Tm1 - - -def test_per_trace_vs_per_token(monkeypatch, learner_stub: GRPOLearner): - # Equal per-token deltas -> per_trace matches per_token(mean) - pol = torch.tensor([[-1.0, -1.0, -1.0]], dtype=torch.float32) - old = torch.tensor([[-1.2, -1.2, -1.2]], dtype=torch.float32) - ref = torch.tensor([[-1.1, -1.1, -1.1]], dtype=torch.float32) - ent = torch.zeros_like(pol) - patch_compute_logprobs(monkeypatch, learner_stub, pol, ent) - - learner_stub.config.training.kl_beta = 0.0 - learner_stub.config.training.entropy_beta = 0.0 - learner_stub.config.training.top_eps = 0.2 - learner_stub.config.training.bottom_eps = 0.1 - - sample = make_sample(pol, old, ref, advantage=1.0) - - learner_stub.config.training.ppo_mode = "per_token" - learner_stub.config.training.token_agg = "mean" - ltok = learner_stub.compute_loss(sample).item() - - learner_stub.config.training.ppo_mode = "per_trace" - ltraj = learner_stub.compute_loss(sample).item() - - assert pytest.approx(ltraj, rel=1e-6) == ltok - - -def test_entropy_beta_effect(monkeypatch, learner_stub: GRPOLearner): - pol = torch.tensor([[-1.0, -1.1]], dtype=torch.float32) - old = torch.tensor([[-1.0, -1.1]], dtype=torch.float32) - ref = torch.tensor([[-1.0, -1.1]], dtype=torch.float32) - ent = torch.tensor([[0.5, 1.5]], dtype=torch.float32) - patch_compute_logprobs(monkeypatch, learner_stub, pol, ent) - - # No policy/kl effect, only entropy - learner_stub.config.training.ppo_mode = "per_token" - learner_stub.config.training.token_agg = "mean" - learner_stub.config.training.kl_beta = 0.0 - - sample = make_sample(pol, old, 
ref, advantage=0.0) - - learner_stub.config.training.entropy_beta = 0.0 - l0 = learner_stub.compute_loss(sample).item() - - learner_stub.config.training.entropy_beta = 2.0 - l1 = learner_stub.compute_loss(sample).item() - - # Mean entropy = (0.5+1.5)/2 = 1.0, scaled by beta=2.0 -> +2.0 - assert pytest.approx(l1 - l0, rel=1e-6) == 2.0 - - -def test_skip_update_when_zero_adv(monkeypatch, learner_stub: GRPOLearner): - # Patch prepare_groups to yield a single group with a minibatch-like object - class MiniBatch: - def __init__(self): - self.advantage = torch.zeros(1) - - def to_device(self, device: torch.device) -> MiniBatch: - return self - - def _stub_prepare_groups(self, samples: list[TrainingSample]) -> list[list[MiniBatch]]: - return [[MiniBatch(), MiniBatch()]] - - monkeypatch.setattr(GRPOLearner, "prepare_groups", _stub_prepare_groups, raising=True) - - # Return a zero scalar loss that *depends* on params so backward works, - # but has zero gradients (no update signal). - def _zero_loss(self, sample) -> torch.Tensor: - return sum(p.sum() for p in self.policy.parameters()) * 0.0 # type: ignore - - monkeypatch.setattr(GRPOLearner, "compute_loss", _zero_loss, raising=True) - - # Count optimizer.step calls - steps = {"n": 0} - # orig_step = learner_stub.optimizer.step - - def _count_step(): - steps["n"] += 1 - - monkeypatch.setattr(learner_stub.optimizer, "step", _count_step, raising=False) - - # Ensure dummy backward can touch a parameter - assert any(p.requires_grad for p in learner_stub.policy.parameters()) - - learner_stub.update([]) - # With the current learner implementation we still call optimizer.step() - # even if the per-minibatch "advantage" is zero (the step is a no-op - # because the gradients are zero). So we expect exactly one step here. 
- assert steps["n"] == 1 diff --git a/hud/rl/train.py b/hud/rl/train.py deleted file mode 100644 index 3c7d6988..00000000 --- a/hud/rl/train.py +++ /dev/null @@ -1,394 +0,0 @@ -"""Main training loop for GRPO RL.""" - -from __future__ import annotations - -import os - -# Disable tokenizer parallelism warnings -os.environ["TOKENIZERS_PARALLELISM"] = "false" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import argparse -import asyncio -import json -import logging -from pathlib import Path -from typing import TYPE_CHECKING, cast - -import hud -from hud.rl.actor import Actor -from hud.rl.buffer import DatasetBuffer, ReplayBuffer -from hud.rl.config import Config -from hud.rl.distributed import ( - broadcast_object, - cleanup_distributed, - get_global_rank, - get_world_size, - is_main_process, - scatter_object, - setup_distributed, - synchronize, -) -from hud.rl.learner import GRPOLearner -from hud.rl.utils import ( - aggregate_metrics_across_ranks, - ensure_dir, - preprocess_advantages, - set_seed, -) -from hud.rl.vllm_adapter import VLLMAdapter -from hud.utils.hud_console import HUDConsole -from hud.utils.tasks import load_tasks - -if TYPE_CHECKING: - from hud.types import Task -hud_console = HUDConsole(logging.getLogger(__name__)) - - -async def train(config: Config, tasks: list[Task]) -> None: - """Main training loop.""" - # Setup distributed environment - setup_distributed() - - # Initialize components - set_seed(config.seed + get_global_rank()) # Different seed per rank - ensure_dir(config.out_dir) - if config.verbose: - logging.basicConfig(level=logging.INFO) - # Remove httpx logger - logging.getLogger("httpx").setLevel(logging.WARNING) - if config.very_verbose: - logging.basicConfig(level=logging.DEBUG) - # Remove httpx logger - logging.getLogger("httpx").setLevel(logging.INFO) - - if is_main_process(): - hud_console.header("Starting GRPO Training") - hud_console.section_title( - f"\n[1/3] Initializing components (world_size={get_world_size()})..." 
- ) - - num_gpus = get_world_size() - - # Actor is responsible for running tasks and collecting episodes - actor = Actor(config) if is_main_process() else None - - # Learner is responsible for updating the policy - learner = GRPOLearner(config) - - # Dataset buffer is responsible for storing tasks - dataset_buffer = DatasetBuffer(tasks, config) - if is_main_process(): - hud_console.key_value_table(dataset_buffer.info) - - if dataset_buffer.groups_per_batch % num_gpus != 0: - hud_console.warning( - f"Groups per batch {dataset_buffer.groups_per_batch} is not divisible by number of GPUs {num_gpus}" # noqa: E501 - ) - exit(1) - - # Replay buffer is responsible for storing episodes for training - trace_buffer = ReplayBuffer(config) - - # VLLM adapter is responsible for loading and unloading adapters (only on main process) - vllm = ( - VLLMAdapter(config.actor.vllm_base_url, config.actor.vllm_api_key) - if is_main_process() - else None - ) - - # Load initial adapter if provided - if is_main_process() and config.model.adapter_path and vllm: - hud_console.info(f"Loading baseline adapter from: {config.model.adapter_path}") - success = vllm.load_adapter(config.model.base_model, config.model.adapter_path) - if success and actor is not None: - hud_console.info("Successfully loaded baseline adapter as 'base_model'") - # Update actor to use the loaded adapter - actor.update_adapter(config.model.base_model) - else: - hud_console.error("Failed to load baseline adapter") - exit(1) - - # Training state - step = 0 - last_metrics = None # Store last successful metrics for error recovery - - if is_main_process(): - hud_console.section_title("\n[2/3] Running training loop...") - - # Create job on main process and distribute ID across GPUs - if is_main_process(): - hud_console.info(f"Creating job with config.job_id: {config.job_id}") - job_obj = hud.create_job( - job_id=config.job_id, - name=config.job_name, - metadata={"config": config.to_dict(), "agent_class": config.model.base_model}, - ) - hud_console.info(f"Created job with job_obj.id: {job_obj.id}") - job_obj.update_status_sync("running") - job_id = job_obj.id - else: - job_obj = None - job_id = None - - # Broadcast job ID to all ranks - job_id = broadcast_object(job_id, src=0) - - try: - while len(dataset_buffer) > 0: - if is_main_process(): - hud_console.section_title(f"Step {step + 1}/{dataset_buffer.training_steps}") - hud_console.info(f"{len(dataset_buffer)} tasks remaining") - # Get batch of tasks (all ranks need same tasks) - tasks = dataset_buffer.get_tasks() - - # Initialize variables on all ranks - global_reward_stats = None - global_advantage_stats = None - - # Step-state gate: ensure all ranks branch coherently - state = {"ok": False, "err": None, "num_samples": 0} - rank_samples = None - episode_time_value = None - - # Only rank 0 runs tasks and prepares distribution - if is_main_process() and actor is not None: - import time - - try: - episode_start_time = time.time() - traces = await actor.run_tasks(tasks, job_id=job_id) - episode_time = time.time() - episode_start_time - hud_console.info(f"Sampled {len(traces)} traces in {episode_time:.1f}s") - trace_buffer.add(traces) - global_reward_stats = [trace.reward for trace in traces] - - # Get all traces from buffer for distribution - all_traces = trace_buffer.sample_traces() - - # Preprocess traces to training samples - preprocessed_traces = preprocess_advantages(all_traces, config) - - # Store these for later use in metrics - global_advantage_stats = [sample.advantage for sample in 
preprocessed_traces] - - # Distribute preprocessed samples in groups across ranks via scatter - # Ensure list length is a multiple of num_gpus by allowing empty per-rank slices - gpu_batch_size = max(1, (len(preprocessed_traces) + num_gpus - 1) // num_gpus) - rank_samples = [ - preprocessed_traces[i : i + gpu_batch_size] - for i in range(0, len(preprocessed_traces), gpu_batch_size) - ] - # Pad rank_samples to exactly num_gpus entries - if len(rank_samples) < num_gpus: - rank_samples.extend([[] for _ in range(num_gpus - len(rank_samples))]) - - # Log distribution info - dist_msg = ( - f"Distributing {len(preprocessed_traces)} samples as {gpu_batch_size} " - f"sized batches across {num_gpus} GPUs" - ) - hud_console.info(dist_msg) - for rank in range(num_gpus): - n_samples = len(rank_samples[rank]) if rank < len(rank_samples) else 0 - hud_console.info(f" Rank {rank}: {n_samples} samples") - - hud_console.section_title(f"Training on {len(all_traces)} traces") - episode_time_value = episode_time - - state.update({"ok": True, "num_samples": len(preprocessed_traces)}) - except Exception as e: - state.update({"ok": False, "err": str(e)}) - - # Broadcast step-state to keep ranks in lockstep - state = broadcast_object(state, src=0) - if not state.get("ok", False): - hud_console.warning("Step failed on rank 0; skipping this step coherently") - synchronize() - continue - - # Scatter per-rank samples; each rank receives only its slice - my_samples = scatter_object(rank_samples if is_main_process() else None, src=0) - # Broadcast the episode time (small object) - episode_time_value = broadcast_object(episode_time_value, src=0) - - # Process only assigned samples - last_metrics = learner.update(my_samples) - - # Add episode time (same for all ranks since episodes run on rank 0) - if episode_time_value is not None: - last_metrics.update( - { - "episode_time": episode_time_value, - } - ) - - # Aggregate metrics across all GPUs for proper statistics - aggregate_metrics_across_ranks(last_metrics) - - if is_main_process() and job_obj is not None: - # Use the global statistics we collected before distribution - if global_reward_stats is not None and global_advantage_stats is not None: - last_metrics.update( - { - "advantage": global_advantage_stats, - "reward": global_reward_stats, - } - ) - else: - # Fallback: use only this rank's data - hud_console.warning("Global statistics not available, using partial data") - last_metrics.update( - { - "advantage": [sample.advantage for sample in my_samples] - if my_samples - else [], - "reward": [sample.reward for sample in my_samples] - if my_samples - else [], - } - ) - - job_obj.log_sync(last_metrics.to_dict()) - - if step % config.stats_interval == 0: - hud_console.key_value_table(last_metrics.to_dict()) - - # Increment step counter on all processes - step += 1 - - # Save checkpoint and update vLLM (only on main process) - if step % config.training.save_every_batches == 0: - if is_main_process() and vllm is not None and actor is not None: - hud_console.section_title("Saving checkpoint and updating vLLM") - checkpoint_path = Path(config.out_dir) / f"{config.adapter_prefix}-{step}" - learner.save(str(checkpoint_path)) - - # Wait for 6 seconds to ensure the checkpoint is saved - await asyncio.sleep(6) - - # If there is a previous adapter, unload it - current_adapter = vllm.get_current() - if current_adapter is not None: - vllm.unload_adapter(current_adapter) - - adapter_name = f"{config.adapter_prefix}-{step}" - if vllm.load_adapter(adapter_name, str(checkpoint_path)): 
- actor.update_adapter(adapter_name) - hud_console.info(f"✓ Checkpoint saved and loaded: {adapter_name}") - else: - hud_console.warning(f"Failed to hot-load adapter {adapter_name}") - - # Ensure all processes wait for checkpoint operations to complete - synchronize() - - if is_main_process(): - hud_console.section_title("\n[3/3] Training completed!") - # Update job status to completed - if job_obj: - job_obj.update_status_sync("completed") - except Exception as e: - # Log error and any available metrics before failing - hud_console.error(f"Training failed on rank {get_global_rank()}: {e}") - - if is_main_process(): - # Log final metrics if we have any - if last_metrics and job_obj: - try: - job_obj.log_sync(last_metrics.to_dict()) - except Exception: - hud_console.warning("Failed to log final metrics") - - # Update job status to failed - if job_obj: - job_obj.update_status_sync("failed") - - # Don't re-raise immediately to allow cleanup - raise - - finally: - # Try to sync one last time, but don't fail if it doesn't work - try: - synchronize() - except Exception: - hud_console.warning("Failed to synchronize during cleanup") - - # Clean up distributed environment - cleanup_distributed() - - -async def main() -> None: - parser = argparse.ArgumentParser(description="GRPO RL Training") - parser.add_argument("--config", type=str, help="Path to config JSON file") - parser.add_argument("--test", action="store_true", help="Run in test mode") - parser.add_argument("--debug", action="store_true", help="Enable debug mode") - parser.add_argument("--verbose", action="store_true", help="Enable verbose mode") - # Task input arguments - parser.add_argument( - "--tasks", type=str, help="Path to tasks JSONL file or HuggingFace dataset name" - ) - parser.add_argument("--tasks-json", type=json.loads, help="Tasks as JSON list string") - - args = parser.parse_args() - - # Load config - if args.config: - with open(args.config, encoding="utf-8") as f: # noqa: ASYNC230 - config_dict = json.load(f) - config = Config.from_dict(config_dict) - else: - config = Config() - - # Apply test mode settings - if args.test: - hud_console.info("[TEST MODE] Using minimal configuration") - eps = 6 - config.training.batch_size = eps - config.actor.max_parallel_episodes = 12 - config.training.group_size = eps - config.training.mini_batch_size = 3 - config.training.training_steps = 4 - config.actor.max_steps_per_episode = 4 - - # Calculate the memory usage - INITIAL_MEMORY = 8.0 - SCALING_FACTOR = 4 / (28 * 28 * 256 * 1024) - token_estimate = ( - config.training.mini_batch_size - * config.actor.max_steps_per_episode - * config.actor.max_new_tokens - ) - hud_console.info(f"Estimated tokens per forward pass: {token_estimate}") - image_estimate = config.model.max_pixels - total_memory = INITIAL_MEMORY + SCALING_FACTOR * token_estimate * image_estimate - hud_console.info(f"Estimated memory peak: {total_memory:.2f} GB") - if total_memory > 75.0: - hud_console.warning( - "Potential memory usage is too high, decrease either training steps or mini batch size" - ) - exit(1) - - # Load tasks - if args.tasks_json: - # Tasks provided as JSON list via command line - tasks = load_tasks(args.tasks_json) - elif args.tasks: - # Tasks provided as file path or HuggingFace dataset - tasks = load_tasks(args.tasks) - else: - # Default to browser_2048_tasks.jsonl if it exists - default_tasks_path = "browser_2048_tasks.jsonl" - if Path(default_tasks_path).exists(): - hud_console.info(f"No tasks specified, using default: {default_tasks_path}") - tasks = 
load_tasks(default_tasks_path) - else: - raise ValueError( - "No tasks specified. Use --tasks, --tasks-json, or specify tasks_file in config" - ) - - # Run training - tasks_typed = cast("list[Task]", tasks) - await train(config, tasks_typed) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/hud/rl/types.py b/hud/rl/types.py deleted file mode 100644 index e0fc5006..00000000 --- a/hud/rl/types.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Shared types for RL training.""" - -from __future__ import annotations - -import math -from typing import Any - -from pydantic import ConfigDict, Field -from pydantic.dataclasses import dataclass - -from hud.types import Trace - -try: - import torch -except ImportError: - raise ImportError("uv tool install hud-python[rl] to use this module") from None - - -class TrainingSample(Trace): - """A single training sample for GRPO.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - # Tokenized inputs to the model (model.forward(*inputs)) - # This includes the input tokens, logit mask, etc. - inputs: dict[str, torch.Tensor] = Field(default_factory=dict) - old_logprobs: torch.Tensor | None = Field(default=None) - ref_logprobs: torch.Tensor | None = Field(default=None) - - # Weighted advantage of group calculation - advantage: torch.Tensor | None = Field(default=None) - - def to_device(self, device: torch.device) -> TrainingSample: - """Move sample to device.""" - self.inputs = { - k: (t.to(device, non_blocking=True) if hasattr(t, "to") else t) - for k, t in self.inputs.items() - } - self.advantage = self.advantage.to(device) if self.advantage is not None else None - self.old_logprobs = self.old_logprobs.to(device) if self.old_logprobs is not None else None - self.ref_logprobs = self.ref_logprobs.to(device) if self.ref_logprobs is not None else None - return self - - -@dataclass -class Metric: - """A tuple for metrics.""" - - name: str = Field(default="") - mean: float = Field(default=0.0) - std: float = Field(default=0.0) - values: list[float] = Field(default_factory=list) - - def update( - self, value: float | torch.Tensor | list[float] | list[int] | list[torch.Tensor] - ) -> None: - """Update metric.""" - if isinstance(value, list): - self.values.extend(value.item() if isinstance(value, torch.Tensor) else value) # type: ignore - else: - self.values.append(value.item() if isinstance(value, torch.Tensor) else value) # type: ignore - mean_val = sum(self.values) / len(self.values) - self.mean = mean_val.item() if isinstance(mean_val, torch.Tensor) else float(mean_val) # type: ignore - variance = sum((x - self.mean) ** 2 for x in self.values) / len(self.values) - variance_val = variance.item() if isinstance(variance, torch.Tensor) else float(variance) # type: ignore - self.std = math.sqrt(variance_val) - - -@dataclass -class TrainingMetrics: - """Metrics for GRPO training (per training step).""" - - # Learner metrics - grad_norm: Metric = Field(default=Metric()) - loss: Metric = Field(default=Metric()) - kl: Metric = Field(default=Metric()) - reward: Metric = Field(default=Metric()) - advantage: Metric = Field(default=Metric()) - policy_ratio: Metric = Field(default=Metric()) - tokens: Metric = Field(default=Metric()) - entropy: Metric = Field(default=Metric()) - - # Computation metrics - gpu_util: Metric = Field(default=Metric()) # GPU utilization percentage - gpu_memory: Metric = Field(default=Metric()) # GPU memory usage in GB - episode_time: Metric = Field(default=Metric()) # Time to run episodes (actor) - training_time: Metric = 
Field(default=Metric()) # Time for gradient updates (learner) - samples_per_second: Metric = Field(default=Metric()) # Training throughput - - def update(self, metrics: dict[str, Any]) -> None: - """Update metrics.""" - for key, value in metrics.items(): - if key in self.__dataclass_fields__: - getattr(self, key).update(value) - - def to_dict(self) -> dict[str, Any]: - """Convert metrics to dictionary.""" - final_metrics = {} - for key in self.__dataclass_fields__: - final_metrics[f"{key}_mean"] = getattr(self, key).mean - final_metrics[f"{key}_std"] = getattr(self, key).std - return final_metrics diff --git a/hud/rl/utils.py b/hud/rl/utils.py deleted file mode 100644 index 29665f81..00000000 --- a/hud/rl/utils.py +++ /dev/null @@ -1,524 +0,0 @@ -"""Utility functions for RL training.""" - -from __future__ import annotations - -import base64 -import io -import logging -import os -import random -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import numpy as np -import torch -from PIL import Image -from transformers.utils.chat_template_utils import render_jinja_template - -from hud.utils.hud_console import HUDConsole - -from .types import TrainingSample - -if TYPE_CHECKING: - from hud.types import Trace - - from .config import Config - -logger = logging.getLogger(__name__) -hud_console = HUDConsole(logger) - - -def set_seed(seed: int) -> None: - """Set random seeds for reproducibility.""" - random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - -def load_chat_template(path: str) -> str: - """Load chat template from file.""" - with open(path) as f: - return f.read() - - -def ensure_dir(path: str) -> None: - """Create directory if it doesn't exist.""" - os.makedirs(path, exist_ok=True) - - -def get_memory_usage() -> float: - if torch.cuda.is_available(): - torch.cuda.synchronize() - return torch.cuda.memory_allocated() / 1024**3 - return 0.0 - - -def get_gpu_utilization() -> float: - """Get current GPU utilization percentage (0-100).""" - if not torch.cuda.is_available(): - return 0.0 - - try: - import nvidia_ml_py as nvml # type: ignore - - nvml.nvmlInit() - device_id = torch.cuda.current_device() - handle = nvml.nvmlDeviceGetHandleByIndex(device_id) - util = nvml.nvmlDeviceGetUtilizationRates(handle) - return float(util.gpu) - except Exception: - # Fallback: estimate based on memory usage - # This is less accurate but works without nvidia-ml-py - return min(100.0, (torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()) * 100) - - -def aggregate_metrics_across_ranks( - metrics: Any, metrics_to_aggregate: list[str] | None = None -) -> None: - """Aggregate metrics across all ranks for proper distributed statistics. - - Args: - metrics: TrainingMetrics object to update in-place - metrics_to_aggregate: List of metric names to aggregate. If None, aggregates all numeric metrics. - - This function: - 1. Gathers metric values from all ranks - 2. Computes proper mean/std across all GPUs - 3. 
Updates the metrics object in-place (only on rank 0) - """ # noqa: E501 - from hud.rl.distributed import get_local_rank, get_world_size, is_main_process - - if get_world_size() <= 1: - return # Nothing to aggregate in single GPU mode - - # Default metrics that typically vary across GPUs - if metrics_to_aggregate is None: - metrics_to_aggregate = [ - "training_time", - "samples_per_second", - "gpu_util", - "gpu_memory", - "grad_norm", - # Include core training scalars - "loss", - "kl", - "entropy", - "tokens", - "policy_ratio", - ] - - # Collect current values from this rank - local_values = {} - for metric_name in metrics_to_aggregate: - if hasattr(metrics, metric_name): - metric_obj = getattr(metrics, metric_name) - # Get the last value if available, otherwise 0 - local_values[metric_name] = metric_obj.values[-1] if metric_obj.values else 0.0 - - # Convert to tensor for distributed gathering - values_tensor = torch.tensor( - list(local_values.values()), device=f"cuda:{get_local_rank()}", dtype=torch.float32 - ) - - # Gather from all ranks using NCCL-supported all_gather - world_size = get_world_size() - gather_list = [torch.zeros_like(values_tensor) for _ in range(world_size)] - torch.distributed.all_gather(gather_list, values_tensor) - - # Update metrics on main process only - if is_main_process(): - # Reshape: [num_gpus, num_metrics] - all_values = torch.stack(gather_list).cpu().numpy() - - # Update each metric with aggregated values - for i, metric_name in enumerate(local_values.keys()): - metric_obj = getattr(metrics, metric_name) - gpu_values = all_values[:, i].tolist() - - # Replace last value with cross-rank mean for reporting - if len(metric_obj.values) == 0: - metric_obj.values.append(0.0) - metric_obj.values[-1] = float(sum(gpu_values) / len(gpu_values)) - # Recompute mean/std across history using updated last value - metric_obj.mean = float(sum(metric_obj.values) / len(metric_obj.values)) - variance = sum((x - metric_obj.mean) ** 2 for x in metric_obj.values) / len( - metric_obj.values - ) - metric_obj.std = float(variance**0.5) - - -def b64_to_pil(b64_str: str) -> Image.Image: - """Convert base64 string to PIL Image.""" - return Image.open(io.BytesIO(base64.b64decode(b64_str))).convert("RGB") - - -def build_assistant_masks( - input_ids: list[list[int]], - tokenizer: Any, -) -> list[list[int]]: - """ - Build assistant masks from token IDs by finding assistant turns. 
- - Args: - input_ids: List of token sequences - tokenizer: Tokenizer to decode tokens and get special token IDs - verbose: Whether to print verbose information - - Returns: - List of binary masks indicating assistant tokens - """ - id_im_start = tokenizer.convert_tokens_to_ids("<|im_start|>") - id_im_end = tokenizer.convert_tokens_to_ids("<|im_end|>") - id_assistant = tokenizer.convert_tokens_to_ids("assistant") - - assistant_masks: list[list[int]] = [] - - for seq in input_ids: - mask = [0] * len(seq) - i_tok = 0 - assistant_turn_count = 0 - - while i_tok < len(seq): - # Detect start of assistant turn - if ( - seq[i_tok] == id_im_start - and i_tok + 1 < len(seq) - and seq[i_tok + 1] == id_assistant - ): - assistant_turn_count += 1 - - # Skip '<|im_start|>', 'assistant' and possible newline token - i_tok += 2 - # Check for newline after 'assistant' - if i_tok < len(seq) and tokenizer.decode([seq[i_tok]]) == "\n": - i_tok += 1 - - # Skip leading spaces after assistant\n - while i_tok < len(seq) and tokenizer.decode([seq[i_tok]]).strip() == "": - i_tok += 1 - - assistant_content_start = i_tok - - # Mark tokens until we hit <|im_end|> - content_end = i_tok - while i_tok < len(seq) and seq[i_tok] != id_im_end: - content_end = i_tok + 1 # Track last non-<|im_end|> position - mask[i_tok] = 1 - i_tok += 1 - - # Remove trailing spaces from the mask - while content_end > assistant_content_start: - if ( - mask[content_end - 1] == 1 - and tokenizer.decode([seq[content_end - 1]]).strip() == "" - ): - mask[content_end - 1] = 0 - content_end -= 1 - else: - break - - # Skip the <|im_end|> token - i_tok += 1 - else: - i_tok += 1 - - assistant_masks.append(mask) - - return assistant_masks - - -def prepare_conversation_history( - conversation_history: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], list[Image.Image]]: - """Sanitize conversation history to avoid vLLM errors.""" - sanitized_messages = [] - images = [] - for m in conversation_history: - if "tool_calls" in m: - m = { - "role": m["role"], - "content": m.get("content", ""), - "tool_calls": [ - tc.model_dump() if not isinstance(tc, dict) else tc - for tc in m.get("tool_calls", []) - ], - } - elif m.get("role") == "user": - user_content = m.get("content", []) - for c in user_content: - if isinstance(c, dict) and c.get("type") == "image_url": - image_url = c.get("image_url", {}) - url = image_url.get("url", "") - if url.startswith("data:image"): - data = url.split(",", 1)[1] if "," in url else url - images.append(b64_to_pil(data)) - elif isinstance(data, bytes | bytearray): - images.append(Image.open(io.BytesIO(data)).convert("RGB")) - c = {"type": "image"} - m["content"] = user_content - sanitized_messages.append(m) - return sanitized_messages, images - - -def prepare_inputs(trace: Trace, processor: Any) -> dict[str, torch.Tensor]: - """ - Prepare inputs from a trace. 
- - Args: - trace: Trace to process - processor: Model processor - - Returns: - Inputs for the model - """ - if len(trace.messages) == 0: - return {} - - # Get images for current turn - conversation, images = prepare_conversation_history(trace.messages) - - # Get absolute path to chat template - chat_template_path = Path(__file__).parent / "chat_template.jinja" - - # For VL models, processor has a tokenizer attribute; for text models, processor IS tokenizer - tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor - - text_list, _ = render_jinja_template( - conversations=[conversation], - chat_template=load_chat_template(str(chat_template_path)), - tools=trace.info["tool_spec"] if trace.info["tool_spec"] else None, # mcp_tools - return_assistant_tokens_mask=True, - **tokenizer.special_tokens_map, - ) - # For text models, don't pass images parameter - if hasattr(processor, "tokenizer"): - # VL model - processor accepts images - inputs = processor( - images=images if len(images) > 0 else None, - text=text_list, - return_offsets_mapping=False, # we no longer need char offsets - ) - else: - # Text model - processor is tokenizer, doesn't accept images - inputs = processor( - text=text_list, - return_offsets_mapping=False, # we no longer need char offsets - ) - - assistant_masks = build_assistant_masks(inputs["input_ids"], tokenizer) - mask_tensor = torch.tensor(assistant_masks, dtype=torch.long) - - # Ensure mask_tensor is 2D before slicing - if mask_tensor.dim() == 1: - mask_tensor = mask_tensor.unsqueeze(0) - - # Slice to align with targets [B, T-1] - inputs["assistant_mask"] = mask_tensor[:, 1:].bool() - - # Log amount of assistant tokens, and the first 10 tokens that are non 0, decoded - # assistant_batches = render_assistant_tokens(mask_tensor, inputs['input_ids'], processor) - inputs.convert_to_tensors(tensor_type="pt") - - return inputs - - -def render_assistant_tokens( - mask_tensor: torch.Tensor, input_ids: torch.Tensor, processor: Any -) -> list[str]: - """Render assistant tokens as a list of continuous batches.""" - # Get the mask as a 1D tensor - mask_1d = mask_tensor[0] - - # Find continuous sequences of non-zero values - batches = [] - start_idx = None - - for i in range(len(mask_1d)): - if mask_1d[i] != 0 and start_idx is None: - # Start of a new batch - start_idx = i - elif mask_1d[i] == 0 and start_idx is not None: - # End of current batch - # Extract and decode the tokens in this batch - batch_token_ids = input_ids[0][start_idx:i].tolist() - decoded_batch = processor.decode(batch_token_ids) - batches.append(decoded_batch) - start_idx = None - - # Handle case where the last batch extends to the end - if start_idx is not None: - batch_token_ids = input_ids[0][start_idx:].tolist() - decoded_batch = processor.decode(batch_token_ids) - batches.append(decoded_batch) - - return batches - - -def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor: - """Calculate entropy from logits in a memory-efficient way.""" - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - entropy = -torch.sum(torch.exp(log_probs) * log_probs, dim=-1) - return entropy - - -def preprocess_advantages(group: list[Trace], config: Config) -> list[TrainingSample]: - """Preprocess a group of traces.""" - group_size = config.training.group_size - if config.training.batch_level == "group": - groups = [group[i : i + group_size] for i in range(0, len(group), group_size)] - elif config.training.batch_level == "batch": - groups = [group] - else: - raise ValueError(f"Invalid batch 
level: {config.training.batch_level}") - - all_samples = [] - for i, group in enumerate(groups): - rewards = np.array([trace.reward for trace in group]) - mean_reward = np.mean(rewards) - std_reward = np.std(rewards) - - # Calculate advantages - samples = [TrainingSample(**trace.model_dump()) for trace in group] - for sample, reward in zip(samples, rewards, strict=True): - if sample.isError: - sample.advantage = torch.Tensor(np.array([0.0])) - continue - # No std (non-baseline GRPO) - if config.training.no_std: - advantage_value = reward - mean_reward - else: - # Avoid division by zero - if std_reward < 1e-6: - advantage_value = torch.Tensor(np.array([0.0])) - else: - advantage_value = (reward - mean_reward) / std_reward - # Leave one out RLOO/LOOP - if config.training.leave_one_out: - advantage_value = advantage_value * len(group) / (len(group) - 1) - sample.advantage = torch.Tensor(np.array([advantage_value])) - hud_console.info_log( - f"Advantages for group {i} [{mean_reward:.4f} ± {std_reward:.4f}]:" - f"{[round(sample.advantage.item(), 4) for sample in samples if sample.advantage is not None]}" # noqa: E501 - ) - - all_samples.extend(samples) - - return all_samples - - -def batch_training_samples(samples: list[TrainingSample]) -> list[TrainingSample]: - """Create batched model inputs from a list of TrainingSample. - - Pads token sequences to the maximum length in the list and zero-pads - images to the maximum H/W when present. Returns a dictionary of batched - tensors suitable for a single forward pass. Keeps assistant_masks for - masked scoring. - """ - if not samples: - hud_console.warning("No samples to batch.") - return [] - - for s in samples: - if ( - "assistant_mask" not in s.inputs - or s.inputs["assistant_mask"].sum() == 0 - or s.advantage == 0.0 - ) and len(samples) > 1: - hud_console.info("Removing sample with zero advantage.") - samples.remove(s) - - if len(samples) == 1: - return samples - - import torch.nn.functional as F - - new_samples = [TrainingSample()] - - input_keys_to_expand = ["input_ids", "attention_mask", "assistant_mask"] - input_keys_to_cat = ["pixel_values", "image_grid_thw"] - updated_inputs: dict[str, list[torch.Tensor]] = { - k: [] for k in input_keys_to_expand + input_keys_to_cat - } - - # Sanity check dimensions - for s in samples: - for k in input_keys_to_expand + input_keys_to_cat: - val = s.inputs.get(k) - if val is not None: - if k in input_keys_to_expand: - if val.dim() == 2 and val.size(0) == 1: - val = val[0] - elif val.dim() != 1: - raise ValueError(f"{k} has unexpected dimensions: {val.shape}") - updated_inputs[k].append(val) - - # Pad 1D sequences to max length - max_len = max(t.size(-1) for t in updated_inputs["input_ids"]) - - def pad_1d(x: torch.Tensor, pad_to: int, pad_value: int) -> torch.Tensor: - pad = pad_to - x.size(-1) - return F.pad(x, (0, pad), value=pad_value) if pad > 0 else x - - stacked_inputs: dict[str, torch.Tensor] = {} - # These are 1D sequences that need padding - for k in input_keys_to_expand: - if updated_inputs[k]: - # assistant_mask is T-1, others are T - if k == "assistant_mask": - stacked_inputs[k] = torch.stack( - [pad_1d(x, max_len - 1, 0) for x in updated_inputs[k]], dim=0 - ) - else: - stacked_inputs[k] = torch.stack( - [pad_1d(x, max_len, 0) for x in updated_inputs[k]], dim=0 - ) - - for k in input_keys_to_cat: - if updated_inputs[k]: - # pixel_values and image_grid_thw are concatenated across all images from all samples - # Shape of pixel_values: (sum of all patches from all images, feature_dim) - # Shape of 
image_grid_thw: (sum of all images, 3) - stacked_inputs[k] = torch.cat(updated_inputs[k], dim=0) - else: - stacked_inputs.pop(k) - - new_samples[0].inputs = stacked_inputs - - # Pad logprobs to max length before stacking - # old_logprobs and ref_logprobs have shape [seq_len] or [1, seq_len] after gathering - def pad_logprobs(logprobs: torch.Tensor | None, max_len: int) -> torch.Tensor: - # Always work with 1D tensor, squeeze batch dim if present - if logprobs is None: - return torch.tensor([float("-inf")], dtype=torch.float32) - if logprobs.dim() == 2 and logprobs.size(0) == 1: - logprobs = logprobs.squeeze(0) - elif logprobs.dim() != 1: - raise ValueError( - f"Expected logprobs to have 1 or 2 dimensions, got {logprobs.dim()} with shape {logprobs.shape}" # noqa: E501 - ) - - # Now logprobs is [seq_len] - seq_len = logprobs.size(0) if logprobs is not None else 0 - if seq_len < max_len: - pad_size = max_len - seq_len - # Pad with -inf (log of 0 probability) along sequence dimension - return F.pad(logprobs, (0, pad_size), value=float("-inf")) - return logprobs - - # Stack padded logprobs (these are T-1 length) - old_logprobs_list = [pad_logprobs(s.old_logprobs, max_len - 1) for s in samples] - ref_logprobs_list = [pad_logprobs(s.ref_logprobs, max_len - 1) for s in samples] - - new_samples[0].old_logprobs = torch.stack(old_logprobs_list, dim=0) - new_samples[0].ref_logprobs = torch.stack(ref_logprobs_list, dim=0) - - # Stack advantages, checking for None values - advantages = [s.advantage for s in samples] - if any(adv is None for adv in advantages): - raise ValueError( - "Some samples have None advantages. Make sure advantages are computed before batching." - ) - new_samples[0].advantage = torch.stack(advantages, dim=0) # type: ignore - - return new_samples diff --git a/hud/rl/utils/start_vllm_server.sh b/hud/rl/utils/start_vllm_server.sh deleted file mode 100755 index 38ea6739..00000000 --- a/hud/rl/utils/start_vllm_server.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# Start vLLM server with OpenAI-compatible API - -echo "Starting vLLM server for Qwen2.5-VL-3B-Instruct..." 
- -# Enable runtime LoRA adapter loading -export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True - -export TOKENIZERS_PARALLELISM=false -export VLLM_LOGGING_LEVEL=DEBUG -export CUDA_LAUNCH_BLOCKING=1 # Better error messages for CUDA errors - -# Common vLLM server command -# Using CUDA_VISIBLE_DEVICES to put vLLM on GPU 1 -CUDA_VISIBLE_DEVICES=1 uv run vllm serve \ - Qwen/Qwen2.5-VL-3B-Instruct \ - --api-key token-abc123 \ - --host 0.0.0.0 \ - --port 8000 \ - --tensor-parallel-size 1 \ - --trust-remote-code \ - --max-model-len 16384 \ - --enable-lora \ - --max-lora-rank 64 \ - --max-cpu-loras 4 \ - --enable-auto-tool-choice \ - --tool-call-parser hermes \ - --chat-template chat_template.jinja \ - --enable-log-requests \ - --uvicorn-log-level=debug 2>&1 | tee vllm_debug.log \ No newline at end of file diff --git a/hud/rl/vllm_adapter.py b/hud/rl/vllm_adapter.py deleted file mode 100644 index 2937448e..00000000 --- a/hud/rl/vllm_adapter.py +++ /dev/null @@ -1,143 +0,0 @@ -"""vLLM adapter management for LoRA hot-swapping.""" - -from __future__ import annotations - -import json -import logging - -import requests - -from hud.utils.hud_console import HUDConsole - -hud_console = HUDConsole(logging.getLogger(__name__)) - - -class VLLMAdapter: - """Manages LoRA adapter loading/unloading in vLLM.""" - - def __init__(self, base_url: str, api_key: str) -> None: - self.base_url = base_url - self.api_key = api_key - self.current_adapter = None - - def load_adapter(self, adapter_name: str, adapter_path: str, timeout: int = 30) -> bool: - """ - Hot-load a LoRA adapter to vLLM. - - Args: - adapter_name: Name to register the adapter as - adapter_path: Path to the adapter checkpoint - timeout: Request timeout in seconds - - Returns: - True if successful, False otherwise - """ - url = f"{self.base_url}/load_lora_adapter" - headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} - payload = {"lora_name": adapter_name, "lora_path": adapter_path} - # Implement exponential backoff for retrying the adapter load request. - max_retries = 8 - backoff_factor = 2 - delay = 1 # initial delay in seconds - - for attempt in range(1, max_retries + 1): - try: - response = requests.post( - url, headers=headers, data=json.dumps(payload), timeout=timeout - ) - response.raise_for_status() - - self.current_adapter = adapter_name - hud_console.info(f"[VLLMAdapter] Loaded adapter: {adapter_name}") - return True - - except requests.exceptions.RequestException as e: - if attempt == max_retries: - hud_console.error( - f"[VLLMAdapter] Failed to load adapter {adapter_name} after {attempt} attempts: {e}" # noqa: E501 - ) - return False - else: - hud_console.warning( - f"[VLLMAdapter] Load adapter {adapter_name} failed (attempt {attempt}/{max_retries}): {e}. Retrying in {delay} seconds...", # noqa: E501 - ) - import time - - time.sleep(delay) - delay *= backoff_factor - - return False - - def unload_adapter(self, adapter_name: str) -> bool: - """ - Unload a LoRA adapter from vLLM. 
- - Args: - adapter_name: Name of the adapter to unload - - Returns: - True if successful, False otherwise - """ - url = f"{self.base_url}/unload_lora_adapter" - headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} - payload = {"lora_name": adapter_name} - - try: - response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=30) - response.raise_for_status() - - if self.current_adapter == adapter_name: - self.current_adapter = None - - hud_console.info(f"[VLLMAdapter] Unloaded adapter: {adapter_name}") - return True - - except requests.exceptions.RequestException as e: - hud_console.error(f"[VLLMAdapter] Failed to unload adapter {adapter_name}: {e}") - return False - - def list_adapters(self) -> list | None: - """ - List all loaded LoRA adapters in vLLM. - - Returns: - List of adapter names, or None if failed - """ - url = f"{self.base_url}/list_lora_adapters" - headers = {"Authorization": f"Bearer {self.api_key}"} - - try: - response = requests.get(url, headers=headers, timeout=10) - response.raise_for_status() - return response.json().get("adapters", []) - - except requests.exceptions.RequestException as e: - hud_console.error(f"[VLLMAdapter] Failed to list adapters: {e}") - return None - - def get_current(self) -> str | None: - """Get the name of the currently loaded adapter.""" - return self.current_adapter - - -# Convenience function for standalone use -def hotload_lora( - adapter_name: str, - adapter_path: str, - base_url: str = "http://localhost:8000/v1", - api_key: str = "token-abc123", -) -> bool: - """ - Quick function to hot-load a LoRA adapter. - - Args: - adapter_name: Name for the adapter - adapter_path: Path to adapter checkpoint - base_url: vLLM server URL - api_key: API key for vLLM - - Returns: - True if successful - """ - adapter = VLLMAdapter(base_url, api_key) - return adapter.load_adapter(adapter_name, adapter_path) From d75d30e8e27c545b8874a823239b1641e5920d95 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 08:49:50 -0800 Subject: [PATCH 07/92] deps --- pyproject.toml | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e5cb99cc..c14c7d2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,18 +127,6 @@ agents = [ "tornado>=6.5.2", ] -# RL training dependencies -rl = [ - "hud-python[agents]", # RL needs agent dependencies - "peft>=0.17.1", - "vllm==0.10.1.1", - "numpy>=1.24.0", # Required for RL training - "bitsandbytes>=0.41.0 ; sys_platform == 'linux'", # For 8-bit optimizers (Linux only) - "liger-kernel>=0.5.0 ; sys_platform == 'linux'", # Optimized Triton kernels for LLM training (Linux only) - # Note: flash-attn is recommended but optional - # Install separately with: uv pip install flash-attn --no-build-isolation -] - # Development dependencies - includes testing, linting, and automation tools dev = [ "hud-python[agents]", # Include agents for dev @@ -229,8 +217,6 @@ source = ["hud"] omit = [ "*/tests/*", "*/examples/*", - "hud/rl/*", - "hud/cli/rl/*", "hud/misc/*", ] @@ -252,8 +238,6 @@ fail_under = 58 omit = [ "*/tests/*", "*/examples/*", - "hud/rl/*", - "hud/cli/rl/*", "hud/misc/*", ] From c9622e575b62e5314daa9db0a746115de107291b Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 09:30:57 -0800 Subject: [PATCH 08/92] format and functionality adjustments --- README.md | 1 - hud/cli/utils/celebrate.py | 3 +- hud/cli/utils/viewer.py | 1 - hud/environment/environment.py | 25 +++-- hud/environment/router.py | 10 +- 
hud/environment/tests/test_environment.py | 2 +- hud/eval/__init__.py | 8 +- hud/eval/context.py | 108 ++++++++++------------ hud/eval/manager.py | 41 ++++---- hud/eval/mixin.py | 30 +++--- hud/eval/parallel.py | 22 ++++- hud/eval/tests/__init__.py | 1 - hud/eval/tests/test_context.py | 32 +++---- hud/eval/tests/test_parallel.py | 1 - hud/otel/__init__.py | 16 ++-- hud/telemetry/__init__.py | 11 ++- 16 files changed, 165 insertions(+), 147 deletions(-) diff --git a/README.md b/README.md index bd4fdfd3..0e605f15 100644 --- a/README.md +++ b/README.md @@ -392,7 +392,6 @@ Key areas: - [Environment examples](environments/) - Add new MCP environments - [Agent implementations](hud/agents/) - Add support for new LLM providers - [Tool library](hud/tools/) - Extend the built-in tool collection -- [RL training](hud/rl/) - Improve reinforcement learning pipelines Thanks to all our contributors! diff --git a/hud/cli/utils/celebrate.py b/hud/cli/utils/celebrate.py index 8e587822..66eb5fc4 100644 --- a/hud/cli/utils/celebrate.py +++ b/hud/cli/utils/celebrate.py @@ -186,5 +186,4 @@ def _run_confetti() -> None: # Don't wait - let operations continue while confetti plays -__all__ = ["show_confetti", "show_confetti_async", "ConfettiSystem", "Particle"] - +__all__ = ["ConfettiSystem", "Particle", "show_confetti", "show_confetti_async"] diff --git a/hud/cli/utils/viewer.py b/hud/cli/utils/viewer.py index 59c93c54..8ea54a98 100644 --- a/hud/cli/utils/viewer.py +++ b/hud/cli/utils/viewer.py @@ -139,4 +139,3 @@ def show_json_interactive( input() console.print() - diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 55d3cd49..8509b6ca 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -4,13 +4,11 @@ import asyncio import logging -import types from collections.abc import Awaitable, Callable -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal, Self import mcp.types as mcp_types -from hud.environment.connection import Connector from hud.environment.connectors import ConnectorsMixin from hud.environment.integrations import IntegrationsMixin from hud.environment.mock import MockMixin @@ -19,6 +17,11 @@ from hud.server.server import MCPServer from hud.types import MCPToolResult +if TYPE_CHECKING: + import types + + from hud.environment.connection import Connector + __all__ = ["Environment"] logger = logging.getLogger(__name__) @@ -229,7 +232,7 @@ def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Environment: # Context Manager # ========================================================================= - async def __aenter__(self) -> Environment: + async def __aenter__(self) -> Self: """Connect all connectors, build routing, run setup tools.""" self._in_context = True @@ -372,13 +375,14 @@ async def read_resource( uri=resource_uri, blob=base64.b64encode(result).decode() ) ] - except Exception: - pass + except Exception as e: + logger.debug("Local resource read failed for %s: %s", uri, e) for conn in self._connections.values(): try: return await conn.read_resource(uri) - except Exception: + except Exception as e: + logger.debug("Remote resource read failed for %s: %s", uri, e) continue raise ValueError(f"Resource not found: {uri}") @@ -408,13 +412,14 @@ async def get_prompt( """Get a prompt by name (tries local first, then remote).""" try: return await self._prompt_manager.render_prompt(name, arguments or {}) - except Exception: - pass + except Exception as e: + logger.debug("Local prompt render failed for %s: %s", name, e) 
for conn in self._connections.values(): try: return await conn.get_prompt(name, arguments) - except Exception: + except Exception as e: + logger.debug("Remote prompt get failed for %s: %s", name, e) continue raise ValueError(f"Prompt not found: {name}") diff --git a/hud/environment/router.py b/hud/environment/router.py index 2dc88a36..962b0802 100644 --- a/hud/environment/router.py +++ b/hud/environment/router.py @@ -7,9 +7,9 @@ from enum import Enum from typing import TYPE_CHECKING -import mcp.types as mcp_types - if TYPE_CHECKING: + import mcp.types as mcp_types + from hud.environment.connection import Connector __all__ = ["LOCAL_CONNECTION", "ConflictResolution", "ToolRouter"] @@ -100,7 +100,5 @@ def _handle_conflict(self, name: str, existing: str, new: str) -> bool: raise ValueError(f"Tool conflict: '{name}' in '{existing}' and '{new}'") if self.conflict_resolution == ConflictResolution.FIRST_WINS: return False - if self.conflict_resolution == ConflictResolution.LAST_WINS: - return True - # PREFIX - shouldn't conflict if prefixes set correctly - return False + # LAST_WINS returns True, PREFIX (shouldn't conflict) returns False + return self.conflict_resolution == ConflictResolution.LAST_WINS diff --git a/hud/environment/tests/test_environment.py b/hud/environment/tests/test_environment.py index f3eeff4e..1f75ab33 100644 --- a/hud/environment/tests/test_environment.py +++ b/hud/environment/tests/test_environment.py @@ -27,7 +27,7 @@ def test_prompt_can_be_set(self) -> None: def test_prompt_set_from_task(self) -> None: """connect_task sets prompt from task.prompt.""" - from hud.environment.connection import Connector + from hud.environment.connection import Connector # noqa: TC001 from hud.environment.connectors.task import TaskConnectorMixin from hud.types import Task diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py index 43d947cd..88bec509 100644 --- a/hud/eval/__init__.py +++ b/hud/eval/__init__.py @@ -24,12 +24,15 @@ from typing import TYPE_CHECKING -# EvalMixin is safe to import (uses lazy imports internally) -from hud.eval.mixin import EvalMixin +# Auto-instrument httpx on import +import hud.eval.instrument # noqa: F401 # run_eval is safe to import (uses lazy imports internally) from hud.eval.manager import run_eval +# EvalMixin is safe to import (uses lazy imports internally) +from hud.eval.mixin import EvalMixin + if TYPE_CHECKING: from hud.eval.context import EvalContext @@ -44,5 +47,6 @@ def __getattr__(name: str) -> object: """Lazy import EvalContext to avoid circular imports.""" if name == "EvalContext": from hud.eval.context import EvalContext + return EvalContext raise AttributeError(f"module 'hud.eval' has no attribute {name!r}") diff --git a/hud/eval/context.py b/hud/eval/context.py index 0e027d61..8cb5b09f 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -15,8 +15,6 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Self -from pydantic import BaseModel - from hud.environment import Environment from hud.environment.types import EnvConfig from hud.settings import settings @@ -28,6 +26,8 @@ from hud.types import Task +from hud.eval.types import EvalExitPayload, EvalPayload, ParallelEvalComplete + logger = logging.getLogger(__name__) # Contextvar to store current trace headers (for httpx auto-instrumentation) @@ -41,32 +41,6 @@ def get_current_trace_headers() -> dict[str, str] | None: return _current_trace_headers.get() -# ============================================================================= -# Payload Models -# 
============================================================================= - - -class EvalPayload(BaseModel): - """Base payload for eval enter/exit.""" - - task_name: str - prompt: str | None = None - code_snippet: str | None = None - env_config: EnvConfig | None = None - all_hubs: bool = False - job_id: str | None = None - group_id: str | None = None - variants: dict[str, Any] | None = None - - -class EvalExitPayload(EvalPayload): - """Exit payload with result fields.""" - - reward: float | None = None - success: bool = True - error_message: str | None = None - - # ============================================================================= # EvalContext # ============================================================================= @@ -256,9 +230,7 @@ def from_environment( # Copy connections from parent - each connector is copied so parallel # execution gets fresh client instances - ctx._connections = { - name: connector.copy() for name, connector in env._connections.items() - } + ctx._connections = {name: connector.copy() for name, connector in env._connections.items()} ctx._hub_configs = getattr(env, "_hub_configs", []).copy() ctx._setup_calls = env._setup_calls.copy() ctx._evaluate_calls = env._evaluate_calls.copy() @@ -312,6 +284,51 @@ def from_task( task=task, ) + # ========================================================================= + # Summary Context - Attribute Access Control + # ========================================================================= + + # Attributes accessible on summary context (everything else raises) + _SUMMARY_ALLOWED = frozenset( + { + # Results and metadata + "results", + "reward", + "error", + "trace_id", + "job_id", + "group_id", + "index", + "variants", + "eval_name", + "duration", + "success", + "done" + # Private attrs + "_is_summary", + "_suppress_link", + "__class__", + "__dict__", + } + ) + + def __getattribute__(self, name: str) -> Any: + """Block most attribute access on summary contexts.""" + # Always allow private/dunder and whitelisted attrs + if name.startswith("_") or name in EvalContext._SUMMARY_ALLOWED: + return super().__getattribute__(name) + + # Check if this is a summary context + try: + is_summary = super().__getattribute__("_is_summary") + except AttributeError: + is_summary = False + + if is_summary: + raise ParallelEvalComplete + + return super().__getattribute__(name) + # ========================================================================= # Computed Properties (eval-specific) # ========================================================================= @@ -485,35 +502,10 @@ def _print_eval_link(self) -> None: if self._suppress_link: return - import contextlib - import webbrowser + from hud.eval.display import print_link trace_url = f"https://hud.ai/trace/{self.trace_id}" - - with contextlib.suppress(Exception): - webbrowser.open(trace_url, new=2) - - try: - from rich.align import Align - from rich.console import Console - from rich.panel import Panel - - console = Console() - - style = "bold underline rgb(108,113,196)" - link_markup = f"[{style}][link={trace_url}]{trace_url}[/link][/{style}]" - - content = Align.center(link_markup) - - panel = Panel( - content, - title="🔗 Eval Started", - border_style="rgb(192,150,12)", - padding=(0, 2), - ) - console.print(panel) - except ImportError: - print(f"Eval: {trace_url}") # noqa: T201 + print_link(trace_url, "🔗 Eval Started") # Re-export for backwards compatibility with trace module diff --git a/hud/eval/manager.py b/hud/eval/manager.py index c5b051af..84b6a8c2 100644 --- 
a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -11,6 +11,7 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any +from hud.eval.display import print_complete, print_eval_stats, print_link from hud.eval.parallel import ( ASTExtractionError, expand_variants, @@ -18,7 +19,6 @@ get_with_block_body, resolve_group_ids, ) -from hud.telemetry.job import _print_job_complete_url, _print_job_url if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -114,8 +114,7 @@ def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: data = response.json() if isinstance(data, list): - for item in data: - tasks.append(Task(**item)) + tasks.extend(Task(**item) for item in data) else: tasks.append(Task(**data)) @@ -154,6 +153,7 @@ async def run_eval( group_ids: list[str] | None = None, job_id: str | None = None, api_key: str | None = None, + max_concurrent: int | None = None, ) -> AsyncGenerator[EvalContext, None]: """Standalone eval context manager. @@ -172,6 +172,7 @@ async def run_eval( group_ids: Optional list of group IDs job_id: Job ID to link to api_key: API key for backend calls + max_concurrent: Maximum concurrent evals (None = unlimited) Yields: EvalContext: Environment with evaluation tracking @@ -205,6 +206,10 @@ async def run_eval( await run_agent(model) ctx.reward = evaluate() + # With concurrency limit + async with hud.eval("my-org/evalset:*", max_concurrent=10) as ctx: + await agent.run(ctx) + # Access results after parallel run for e in ctx.results: print(f"{e.variants}: reward={e.reward}") @@ -224,10 +229,7 @@ async def run_eval( # Calculate total evaluations # If we have tasks, each task gets (variants x group) runs # If no tasks, we have a single blank eval with (variants x group) runs - if tasks: - total_evals = len(tasks) * len(variant_combos) * group - else: - total_evals = len(variant_combos) * group + total_evals = len(tasks) * len(variant_combos) * group if tasks else len(variant_combos) * group # Capture code snippet for parallel execution code_snippet: str | None = None @@ -274,9 +276,10 @@ async def run_eval( # Parallel execution: create implicit job to group traces eval_name = _get_eval_name(slugs) implicit_job_id = job_id or str(uuid.uuid4()) + job_url = f"https://hud.ai/jobs/{implicit_job_id}" # Print job URL (not individual trace URLs) - _print_job_url(implicit_job_id, eval_name) + print_link(job_url, f"🚀 Job '{eval_name}'") error_occurred = False try: @@ -289,6 +292,7 @@ async def run_eval( job_id=implicit_job_id, # Propagate job_id to child traces api_key=api_key, code_snippet=code_snippet, + max_concurrent=max_concurrent, ) # Create summary context (no trace, just aggregates results) @@ -321,7 +325,7 @@ async def run_eval( error_occurred = True raise finally: - _print_job_complete_url(implicit_job_id, eval_name, error_occurred) + print_complete(job_url, eval_name, error=error_occurred) async def _run_parallel_eval( @@ -332,6 +336,7 @@ async def _run_parallel_eval( job_id: str | None, api_key: str | None, code_snippet: str | None, + max_concurrent: int | None, ) -> list[EvalContext]: """Run parallel evaluation. 
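For reference, the eval fan-out above multiplies tasks by variant combinations by group size. A tiny worked sketch with hypothetical values (the task names and counts are illustrative, not from this patch):

    # Same ternary as used in run_eval/_run_parallel_eval above.
    tasks = ["checkout", "search", "login"]                   # 3 hypothetical tasks
    variant_combos = [{"model": "gpt-4o"}, {"model": "o3"}]   # 2 variant combinations
    group = 2                                                 # runs per task/variant
    total_evals = len(tasks) * len(variant_combos) * group if tasks else len(variant_combos) * group
    assert total_evals == 12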
@@ -346,11 +351,7 @@ async def _run_parallel_eval( body_source, captured_locals, context_var = get_with_block_body(caller_frame) # Calculate total evals and resolve group IDs - if tasks: - total_evals = len(tasks) * len(variant_combos) * group - else: - total_evals = len(variant_combos) * group - + total_evals = len(tasks) * len(variant_combos) * group if tasks else len(variant_combos) * group resolved_group_ids = resolve_group_ids(group_ids, total_evals) # Create EvalContexts @@ -393,19 +394,23 @@ async def _run_parallel_eval( # Run in parallel logger.info( - "Running %d evals (%d tasks x %d variants x %d runs)", + "Running %d evals (%d tasks x %d variants x %d runs)%s", len(eval_contexts), max(len(tasks), 1), len(variant_combos), group, + f", max_concurrent={max_concurrent}" if max_concurrent else "", + ) + completed = await run_parallel_evals( + eval_contexts, body_source, captured_locals, context_var, max_concurrent ) - completed = await run_parallel_evals(eval_contexts, body_source, captured_locals, context_var) - # Log stats + # Log and print stats + eval_name = completed[0].eval_name if completed else "eval" log_eval_stats(completed) + print_eval_stats(completed, eval_name) return completed __all__ = ["run_eval"] - diff --git a/hud/eval/mixin.py b/hud/eval/mixin.py index 49061ed8..10d56df9 100644 --- a/hud/eval/mixin.py +++ b/hud/eval/mixin.py @@ -7,12 +7,14 @@ from __future__ import annotations +import contextlib import inspect import logging import uuid from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any +from hud.eval.display import print_complete, print_eval_stats, print_link from hud.eval.parallel import ( ASTExtractionError, expand_variants, @@ -20,7 +22,7 @@ get_with_block_body, resolve_group_ids, ) -from hud.telemetry.job import _print_job_complete_url, _print_job_url +from hud.eval.types import ParallelEvalComplete if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -137,6 +139,7 @@ async def eval( job_id: str | None = None, trace_id: str | None = None, api_key: str | None = None, + max_concurrent: int | None = None, ) -> AsyncGenerator[EvalContext, None]: """Create an eval context for recording an agent run. @@ -172,6 +175,7 @@ async def eval( trace_id: Optional trace ID (auto-generated if not provided). For parallel execution, each eval gets a unique ID. api_key: Optional API key for backend calls (defaults to settings.api_key) + max_concurrent: Maximum concurrent evals (None = unlimited) Yields: EvalContext for this evaluation. 
Inside the body: @@ -257,9 +261,10 @@ async def eval( else: # Parallel execution: create implicit job to group traces implicit_job_id = job_id or str(uuid.uuid4()) + job_url = f"https://hud.ai/jobs/{implicit_job_id}" # Print job URL (not individual trace URLs) - _print_job_url(implicit_job_id, name) + print_link(job_url, f"🚀 Job '{name}'") error_occurred = False try: @@ -273,6 +278,7 @@ async def eval( api_key=api_key, code_snippet=code_snippet, env_config=env_config, + max_concurrent=max_concurrent, ) # Create summary context (no trace, just aggregates results) @@ -297,12 +303,10 @@ async def eval( # Check if any failed error_occurred = any(e.error is not None for e in completed) - yield ctx - except Exception: - error_occurred = True - raise + with contextlib.suppress(ParallelEvalComplete): + yield ctx finally: - _print_job_complete_url(implicit_job_id, name, error_occurred) + print_complete(job_url, name, error=error_occurred) async def _run_parallel_eval( self, @@ -314,6 +318,7 @@ async def _run_parallel_eval( api_key: str | None, code_snippet: str | None, env_config: dict[str, Any] | None, + max_concurrent: int | None, ) -> list[EvalContext]: """Run parallel eval execution. @@ -353,20 +358,23 @@ async def _run_parallel_eval( # Run in parallel logger.info( - "Running %d evals for '%s' (%d variants x %d runs)", + "Running %d evals for '%s' (%d variants x %d runs)%s", len(eval_contexts), name, len(variant_combos), group, + f", max_concurrent={max_concurrent}" if max_concurrent else "", + ) + completed = await run_parallel_evals( + eval_contexts, body_source, captured_locals, context_var, max_concurrent ) - completed = await run_parallel_evals(eval_contexts, body_source, captured_locals, context_var) - # Store results and log stats + # Store results and print stats self._last_evals = completed log_eval_stats(completed, name) + print_eval_stats(completed, name) return completed __all__ = ["EvalMixin"] - diff --git a/hud/eval/parallel.py b/hud/eval/parallel.py index 6eab25a5..d5651d6d 100644 --- a/hud/eval/parallel.py +++ b/hud/eval/parallel.py @@ -14,10 +14,11 @@ import logging import textwrap import uuid -from types import FrameType from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from types import FrameType + from hud.eval.context import EvalContext logger = logging.getLogger(__name__) @@ -244,7 +245,10 @@ def get_with_block_body(frame: Any) -> tuple[str, dict[str, Any], str]: # Extract the context variable name from 'as' clause context_var = _extract_context_var(with_node) - return body_source, frame.f_locals.copy(), context_var + # Capture both globals (imports) and locals (variables in scope) + captured = {**frame.f_globals, **frame.f_locals} + + return body_source, captured, context_var def _extract_context_var(with_node: ast.AsyncWith) -> str: @@ -296,6 +300,7 @@ async def run_parallel_evals( body_source: str, captured_locals: dict[str, Any], context_var: str, + max_concurrent: int | None = None, ) -> list[EvalContext]: """Run the eval body in parallel for multiple contexts. 
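run_parallel_evals gains a max_concurrent argument here, and the hunks that follow implement it with an optional asyncio.Semaphore around each eval body. A minimal standalone sketch of that bounded-concurrency pattern (run_bounded, items, and worker are illustrative names, not part of the library):

    import asyncio

    async def run_bounded(items, worker, max_concurrent=None):
        # Optional semaphore caps how many workers run at once; None means unlimited.
        sem = asyncio.Semaphore(max_concurrent) if max_concurrent else None

        async def run_one(item):
            if sem:
                async with sem:
                    return await worker(item)
            return await worker(item)

        # Run all items concurrently, bounded by the semaphore when one is set.
        return await asyncio.gather(*(run_one(i) for i in items))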
@@ -311,6 +316,7 @@ async def run_parallel_evals( body_source: The source code of the with-block body captured_locals: Local variables captured from the caller context_var: The variable name used in the 'as' clause + max_concurrent: Maximum concurrent evals (None = unlimited) """ # Create runner function using the actual variable name from the 'as' clause @@ -320,10 +326,17 @@ async def run_parallel_evals( exec(code, namespace) # noqa: S102 runner = namespace["__runner__"] + # Create semaphore for concurrency control + sem = asyncio.Semaphore(max_concurrent) if max_concurrent else None + async def run_one(ctx: EvalContext) -> EvalContext: try: - async with ctx: - await runner(ctx) + if sem: + async with sem, ctx: + await runner(ctx) + else: + async with ctx: + await runner(ctx) except Exception as e: logger.warning("Parallel eval %d failed: %s", ctx.index, e) ctx.error = e @@ -342,4 +355,3 @@ async def run_one(ctx: EvalContext) -> EvalContext: "resolve_group_ids", "run_parallel_evals", ] - diff --git a/hud/eval/tests/__init__.py b/hud/eval/tests/__init__.py index 64147a3e..3b6c294e 100644 --- a/hud/eval/tests/__init__.py +++ b/hud/eval/tests/__init__.py @@ -1,2 +1 @@ """Tests for hud.eval module.""" - diff --git a/hud/eval/tests/test_context.py b/hud/eval/tests/test_context.py index e4cbc8c7..2f291feb 100644 --- a/hud/eval/tests/test_context.py +++ b/hud/eval/tests/test_context.py @@ -77,30 +77,27 @@ async def test_context_manager_sets_headers(self) -> None: with ( patch.object(ctx, "_eval_enter", new_callable=AsyncMock), patch.object(ctx, "_eval_exit", new_callable=AsyncMock), + patch.object(EvalContext, "__aenter__", return_value=ctx), + patch.object(EvalContext, "__aexit__", return_value=None), ): - # Mock parent Environment context manager - with patch.object(EvalContext, "__aenter__", return_value=ctx): - with patch.object(EvalContext, "__aexit__", return_value=None): - assert get_current_trace_headers() is None + assert get_current_trace_headers() is None - # Manually set token for test - from hud.eval.context import _current_trace_headers + # Manually set token for test + from hud.eval.context import _current_trace_headers - token = _current_trace_headers.set(ctx.headers) - try: - headers = get_current_trace_headers() - assert headers is not None - assert headers["Trace-Id"] == "test-123" - finally: - _current_trace_headers.reset(token) + token = _current_trace_headers.set(ctx.headers) + try: + headers = get_current_trace_headers() + assert headers is not None + assert headers["Trace-Id"] == "test-123" + finally: + _current_trace_headers.reset(token) - assert get_current_trace_headers() is None + assert get_current_trace_headers() is None def test_repr(self) -> None: """__repr__ shows useful info.""" - ctx = EvalContext( - name="test-task", trace_id="abc12345-6789-0000-0000-000000000000" - ) + ctx = EvalContext(name="test-task", trace_id="abc12345-6789-0000-0000-000000000000") ctx.reward = 0.95 repr_str = repr(ctx) @@ -176,4 +173,3 @@ def test_sets_eval_properties(self) -> None: assert ctx.variants == {"model": "gpt-4o"} assert ctx.group_id == "group-123" assert ctx.index == 5 - diff --git a/hud/eval/tests/test_parallel.py b/hud/eval/tests/test_parallel.py index baff6a6b..9fef3f98 100644 --- a/hud/eval/tests/test_parallel.py +++ b/hud/eval/tests/test_parallel.py @@ -231,4 +231,3 @@ def test_is_exception(self) -> None: error = ASTExtractionError("test message") assert isinstance(error, Exception) assert str(error) == "test message" - diff --git a/hud/otel/__init__.py 
b/hud/otel/__init__.py index 4efc7d50..855c7622 100644 --- a/hud/otel/__init__.py +++ b/hud/otel/__init__.py @@ -22,14 +22,6 @@ import warnings -# Show deprecation warning when module is imported -warnings.warn( - "The hud.otel module is deprecated. Use env.trace() instead. " - "This module requires pip install hud-python[agents].", - DeprecationWarning, - stacklevel=2, -) - from .collector import enable_trace_collection from .config import configure_telemetry, is_telemetry_configured, shutdown_telemetry from .context import ( @@ -39,6 +31,14 @@ trace, ) +# Show deprecation warning when module is imported +warnings.warn( + "The hud.otel module is deprecated. Use env.trace() instead. " + "This module requires pip install hud-python[agents].", + DeprecationWarning, + stacklevel=2, +) + __all__ = [ "configure_telemetry", "enable_trace_collection", diff --git a/hud/telemetry/__init__.py b/hud/telemetry/__init__.py index a6c17234..1125fba0 100644 --- a/hud/telemetry/__init__.py +++ b/hud/telemetry/__init__.py @@ -54,7 +54,12 @@ def __getattr__(name: str): # noqa: ANN202 if name in ("Job", "job", "create_job", "get_current_job"): from .job import Job, create_job, get_current_job, job - return {"Job": Job, "job": job, "create_job": create_job, "get_current_job": get_current_job}[name] + return { + "Job": Job, + "job": job, + "create_job": create_job, + "get_current_job": get_current_job, + }[name] elif name in ("async_job", "async_trace"): from .async_context import async_job, async_trace @@ -72,9 +77,6 @@ def __getattr__(name: str): # noqa: ANN202 __all__ = [ - # Core (always available) - "instrument", - # Deprecated "Job", "Trace", "async_job", @@ -83,6 +85,7 @@ def __getattr__(name: str): # noqa: ANN202 "create_job", "get_current_job", "get_trace", + "instrument", "job", "trace", ] From 274e1c7c60434f291e44b8dc11a6c2e8da4ebef2 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 09:31:14 -0800 Subject: [PATCH 09/92] misc additions --- hud/eval/display.py | 209 +++++++++++++++++++++++++++++++++++++++++ hud/eval/instrument.py | 111 ++++++++++++++++++++++ hud/eval/types.py | 57 +++++++++++ 3 files changed, 377 insertions(+) create mode 100644 hud/eval/display.py create mode 100644 hud/eval/instrument.py create mode 100644 hud/eval/types.py diff --git a/hud/eval/display.py b/hud/eval/display.py new file mode 100644 index 00000000..3f6f5c43 --- /dev/null +++ b/hud/eval/display.py @@ -0,0 +1,209 @@ +"""Display helpers for eval links and job URLs. + +Provides consistent, beautiful display for HUD URLs using rich. +""" + +from __future__ import annotations + +import contextlib +import webbrowser +from statistics import mean, pstdev +from typing import TYPE_CHECKING, Any + +from hud.settings import settings + +if TYPE_CHECKING: + from hud.eval.context import EvalContext + + +def print_link(url: str, title: str, *, open_browser: bool = True) -> None: + """Print a nicely formatted link with optional browser opening. 
+ + Args: + url: The URL to display + title: Title for the panel (e.g., "🔗 Eval Started", "🚀 Job Started") + open_browser: Whether to open the URL in browser + """ + # Only print if telemetry is enabled and has API key + if not (settings.telemetry_enabled and settings.api_key): + return + + # Open in browser + if open_browser: + with contextlib.suppress(Exception): + webbrowser.open(url, new=2) + + try: + from rich.align import Align + from rich.console import Console + from rich.panel import Panel + + console = Console() + + style = "bold underline rgb(108,113,196)" + link_markup = f"[{style}][link={url}]{url}[/link][/{style}]" + + content = Align.center(link_markup) + + panel = Panel( + content, + title=title, + border_style="rgb(192,150,12)", + padding=(0, 2), + ) + console.print(panel) + except ImportError: + print(f"{title}: {url}") # noqa: T201 + + +def print_complete(url: str, name: str, *, error: bool = False) -> None: + """Print a completion message with link. + + Args: + url: The URL to display + name: Name of the eval/job + error: Whether an error occurred + """ + # Only print if telemetry is enabled and has API key + if not (settings.telemetry_enabled and settings.api_key): + return + + try: + from rich.console import Console + + console = Console() + + if error: + console.print( + f"\n[red]✗ '{name}' failed![/red] [dim]View details at:[/dim] " + f"[bold link={url}]{url}[/bold link]\n" + ) + else: + console.print( + f"\n[green]✓ '{name}' complete![/green] [dim]View results at:[/dim] " + f"[bold link={url}]{url}[/bold link]\n" + ) + except ImportError: + status = "failed" if error else "complete" + print(f"\n{name} {status}: {url}\n") # noqa: T201 + + +def print_eval_stats( + completed: list[EvalContext], + name: str = "", + *, + elapsed: float | None = None, + show_details: bool = True, +) -> None: + """Print statistics for completed evaluations. 
+ + Args: + completed: List of completed EvalContext objects + name: Optional name for the evaluation + elapsed: Optional elapsed time in seconds + show_details: Whether to show per-eval details table + """ + if not completed: + return + + try: + from rich.console import Console + from rich.table import Table + + console = Console() + except ImportError: + # Fallback to basic printing + _print_eval_stats_basic(completed, name, elapsed) + return + + # Calculate aggregate stats + rewards = [ctx.reward for ctx in completed if ctx.reward is not None] + errors = [ctx for ctx in completed if ctx.error is not None] + durations = [ctx.duration for ctx in completed if ctx.duration > 0] + + mean_reward = mean(rewards) if rewards else 0.0 + std_reward = pstdev(rewards) if len(rewards) > 1 else 0.0 + success_rate = (len(completed) - len(errors)) / len(completed) if completed else 0.0 + + # Print summary + title = f"📊 '{name}' Results" if name else "📊 Eval Results" + console.print(f"\n[bold]{title}[/bold]") + console.print(f" [dim]Evals:[/dim] {len(completed)}") + if elapsed: + rate = len(completed) / elapsed if elapsed > 0 else 0 + console.print(f" [dim]Time:[/dim] {elapsed:.1f}s ({rate:.1f} evals/s)") + if durations: + mean_duration = mean(durations) + console.print(f" [dim]Avg duration:[/dim] {mean_duration:.2f}s") + console.print(f" [dim]Mean reward:[/dim] [green]{mean_reward:.3f}[/green] ± {std_reward:.3f}") + console.print(f" [dim]Success rate:[/dim] [yellow]{success_rate * 100:.1f}%[/yellow]") + if errors: + console.print(f" [dim]Errors:[/dim] [red]{len(errors)}[/red]") + + # Show details table if requested and not too many + if show_details and len(completed) <= 50: + table = Table(title="Per-Eval Details", show_header=True, header_style="bold") + table.add_column("#", style="dim", justify="right", width=4) + table.add_column("Variants", style="cyan", max_width=30) + table.add_column("Reward", justify="right", style="green", width=8) + table.add_column("Duration", justify="right", width=10) + table.add_column("Status", justify="center", width=8) + + for ctx in completed: + idx_str = str(ctx.index) + variants_str = _format_variants(ctx.variants) if ctx.variants else "-" + reward_str = f"{ctx.reward:.3f}" if ctx.reward is not None else "-" + duration_str = f"{ctx.duration:.2f}s" if ctx.duration > 0 else "-" + + if ctx.error: + status = "[red]✗[/red]" + elif ctx.reward is not None and ctx.reward > 0.7: + status = "[green]✓[/green]" + else: + status = "[yellow]○[/yellow]" + + table.add_row(idx_str, variants_str, reward_str, duration_str, status) + + console.print(table) + + # Warn about high variance + if std_reward > 0.3: + console.print(f"\n[yellow]⚠️ High variance detected (std={std_reward:.3f})[/yellow]") + + console.print() + + +def _format_variants(variants: dict[str, Any]) -> str: + """Format variants dict for display.""" + if not variants: + return "-" + parts = [f"{k}={v}" for k, v in variants.items()] + result = ", ".join(parts) + return result[:30] + "..." 
if len(result) > 30 else result + + +def _print_eval_stats_basic( + completed: list[EvalContext], + name: str, + elapsed: float | None, +) -> None: + """Basic stats printing without rich.""" + rewards = [ctx.reward for ctx in completed if ctx.reward is not None] + errors = [ctx for ctx in completed if ctx.error is not None] + + mean_reward = mean(rewards) if rewards else 0.0 + success_rate = (len(completed) - len(errors)) / len(completed) if completed else 0.0 + + title = f"'{name}' Results" if name else "Eval Results" + print(f"\n{title}") # noqa: T201 + print(f" Evals: {len(completed)}") # noqa: T201 + if elapsed: + print(f" Time: {elapsed:.1f}s") # noqa: T201 + print(f" Mean reward: {mean_reward:.3f}") # noqa: T201 + print(f" Success rate: {success_rate * 100:.1f}%") # noqa: T201 + if errors: + print(f" Errors: {len(errors)}") # noqa: T201 + print() # noqa: T201 + + +__all__ = ["print_complete", "print_eval_stats", "print_link"] diff --git a/hud/eval/instrument.py b/hud/eval/instrument.py new file mode 100644 index 00000000..9db50c4e --- /dev/null +++ b/hud/eval/instrument.py @@ -0,0 +1,111 @@ +"""Auto-instrumentation for httpx to inject trace headers. + +This module patches httpx clients to automatically add: +- Trace-Id headers when inside an eval context +- Authorization headers for HUD API calls +""" + +from __future__ import annotations + +import logging +from typing import Any +from urllib.parse import urlparse + +from hud.settings import settings + +logger = logging.getLogger(__name__) + + +def _get_trace_headers() -> dict[str, str] | None: + """Lazy import to avoid circular dependency.""" + from hud.eval.context import get_current_trace_headers + + return get_current_trace_headers() + + +def _is_hud_url(url_str: str) -> bool: + """Check if URL is a HUD service (inference or MCP).""" + # Extract hostnames from settings URLs + gateway_host = urlparse(settings.hud_gateway_url).netloc + mcp_host = urlparse(settings.hud_mcp_url).netloc + + # Parse the request URL and check against known HUD hosts + parsed = urlparse(url_str) + request_host = parsed.netloc or url_str.split("/")[0] + + return request_host in (gateway_host, mcp_host) + + +def _httpx_request_hook(request: Any) -> None: + """httpx event hook that adds trace headers and auth to HUD requests. 
+ + For inference.hud.ai and mcp.hud.ai: + - Injects trace headers (Trace-Id) if in trace context + - Injects Authorization header if API key is set and no auth present + """ + url_str = str(request.url) + if not _is_hud_url(url_str): + return + + # Inject trace headers if in trace context + headers = _get_trace_headers() + if headers is not None: + for key, value in headers.items(): + request.headers[key] = value + logger.debug("Added trace headers to request: %s", url_str) + + # Auto-inject API key if not present + has_auth = "authorization" in {k.lower() for k in request.headers} + if not has_auth and settings.api_key: + request.headers["Authorization"] = f"Bearer {settings.api_key}" + logger.debug("Added API key auth to request: %s", url_str) + + +async def _async_httpx_request_hook(request: Any) -> None: + """Async version of the httpx event hook.""" + _httpx_request_hook(request) + + +def _instrument_client(client: Any) -> None: + """Add trace hook to an httpx client instance.""" + is_async = hasattr(client, "aclose") + hook = _async_httpx_request_hook if is_async else _httpx_request_hook + + existing_hooks = client.event_hooks.get("request", []) + if hook not in existing_hooks: + existing_hooks.append(hook) + client.event_hooks["request"] = existing_hooks + + +def _patch_httpx() -> None: + """Monkey-patch httpx to auto-instrument all clients.""" + try: + import httpx + except ImportError: + logger.debug("httpx not installed, skipping auto-instrumentation") + return + + _original_async_init = httpx.AsyncClient.__init__ + + def _patched_async_init(self: Any, *args: Any, **kwargs: Any) -> None: + _original_async_init(self, *args, **kwargs) + _instrument_client(self) + + httpx.AsyncClient.__init__ = _patched_async_init # type: ignore[method-assign] + + _original_sync_init = httpx.Client.__init__ + + def _patched_sync_init(self: Any, *args: Any, **kwargs: Any) -> None: + _original_sync_init(self, *args, **kwargs) + _instrument_client(self) + + httpx.Client.__init__ = _patched_sync_init # type: ignore[method-assign] + + logger.debug("httpx auto-instrumentation enabled") + + +# Auto-patch httpx on module import +_patch_httpx() + + +__all__ = ["_patch_httpx"] diff --git a/hud/eval/types.py b/hud/eval/types.py new file mode 100644 index 00000000..86da6957 --- /dev/null +++ b/hud/eval/types.py @@ -0,0 +1,57 @@ +"""Types and exceptions for the eval module. + +Kept separate to avoid circular imports. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + +from hud.environment.types import EnvConfig + +# ============================================================================= +# Exceptions +# ============================================================================= + + +class ParallelEvalComplete(Exception): + """Raised by summary context to skip body re-execution after parallel eval. + + This is caught by the eval() context manager to cleanly exit. + The summary context with results is still accessible after the with block. 
+ """ + + +# ============================================================================= +# Payload Models +# ============================================================================= + + +class EvalPayload(BaseModel): + """Base payload for eval enter/exit.""" + + task_name: str + prompt: str | None = None + code_snippet: str | None = None + env_config: EnvConfig | None = None + all_hubs: bool = False + job_id: str | None = None + group_id: str | None = None + variants: dict[str, Any] | None = None + + +class EvalExitPayload(EvalPayload): + """Exit payload with result fields.""" + + reward: float | None = None + success: bool = True + error_message: str | None = None + + +__all__ = [ + "EvalExitPayload", + "EvalPayload", + "ParallelEvalComplete", +] From 84ee0d145e78bce6ef11d0df5901d5bb33665460 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 10:20:41 -0800 Subject: [PATCH 10/92] runner and docs --- docs/beta/index.mdx | 2 +- docs/build-environments/index.mdx | 4 - docs/build-environments/spec.mdx | 4 +- docs/docs.json | 9 +- docs/evaluate-agents/benchmarks.mdx | 30 ++- docs/index.mdx | 12 +- docs/quickstart.mdx | 14 +- docs/reference/cli/eval.mdx | 2 +- docs/reference/cli/overview.mdx | 6 +- docs/reference/cli/rft.mdx | 5 +- docs/reference/cli/rl.mdx | 87 ------ docs/reference/eval.mdx | 405 ++++++++++++++++++++++++++++ docs/train-agents/quickstart.mdx | 126 --------- docs/train-agents/tasks.mdx | 80 ------ hud/datasets/runner.py | 218 +++++++++------ hud/eval/manager.py | 96 +++++-- hud/eval/tests/test_parallel.py | 6 +- hud/misc/__init__.py | 1 - hud/misc/claude_plays_pokemon.py | 292 -------------------- 19 files changed, 664 insertions(+), 735 deletions(-) delete mode 100644 docs/reference/cli/rl.mdx create mode 100644 docs/reference/eval.mdx delete mode 100644 docs/train-agents/quickstart.mdx delete mode 100644 docs/train-agents/tasks.mdx delete mode 100644 hud/misc/__init__.py delete mode 100644 hud/misc/claude_plays_pokemon.py diff --git a/docs/beta/index.mdx b/docs/beta/index.mdx index b318cad3..6485a3fd 100644 --- a/docs/beta/index.mdx +++ b/docs/beta/index.mdx @@ -11,5 +11,5 @@ Beta features are experimental and may change in future releases. ## Available Beta Features - Fine-tune models with reinforcement learning on your HUD tasks (invite-only) + Fine-tune models on your HUD tasks (invite-only) diff --git a/docs/build-environments/index.mdx b/docs/build-environments/index.mdx index 40ec910f..d022c1ed 100644 --- a/docs/build-environments/index.mdx +++ b/docs/build-environments/index.mdx @@ -66,9 +66,6 @@ hud eval tasks.json # Deploy to registry hud push - -# Train agents on your tasks -hud rl tasks.json ``` --- @@ -83,7 +80,6 @@ hud rl tasks.json | Troubleshoot | `hud debug my-env:dev` | | Build image | `hud build` | | Push to registry | `hud push` | -| RL training | `hud rl tasks.json` | --- diff --git a/docs/build-environments/spec.mdx b/docs/build-environments/spec.mdx index a87160df..61069b21 100644 --- a/docs/build-environments/spec.mdx +++ b/docs/build-environments/spec.mdx @@ -24,7 +24,7 @@ graph TD - No non‑MCP output on stdout (all logging to stderr). - No required file layout, framework, or endpoints. -Recommended (for HUD RL/evals): provide tools named `setup` and `evaluate`. +Recommended (for HUD evals): provide tools named `setup` and `evaluate`. ## Make it runnable remotely (mcp.hud.ai) @@ -143,7 +143,7 @@ The same structure is used by `hud init`’s template and by programmatic tasks. 
] ``` -Switching this file to remote is as simple as replacing the `mcp_config` with the `hud` section shown above (or using `hud rl`, which will help convert it automatically). +Switching this file to remote is as simple as replacing the `mcp_config` with the `hud` section shown above (or using `hud convert`, which will help convert it automatically). Run tasks with either the CLI or an agent: diff --git a/docs/docs.json b/docs/docs.json index dd814276..d2f9e789 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -50,6 +50,7 @@ { "group": "SDK Reference", "pages": [ + "reference/eval", "reference/tools", "reference/agents", "reference/types", @@ -64,13 +65,6 @@ "build-environments/spec" ] }, - { - "group": "Training (RL)", - "pages": [ - "train-agents/quickstart", - "train-agents/tasks" - ] - }, { "group": "HUD Gateway", "pages": [ @@ -103,7 +97,6 @@ "reference/cli/debug", "reference/cli/run", "reference/cli/eval", - "reference/cli/rl", "reference/cli/rft", "reference/cli/misc" ] diff --git a/docs/evaluate-agents/benchmarks.mdx b/docs/evaluate-agents/benchmarks.mdx index b63d9b17..09561a30 100644 --- a/docs/evaluate-agents/benchmarks.mdx +++ b/docs/evaluate-agents/benchmarks.mdx @@ -18,7 +18,30 @@ hud eval tasks.json hud eval hud-evals/SheetBench-50 claude --full ``` -- SDK +- SDK (Context Manager) + +```python +import hud + +# Single task evaluation +async with hud.eval("hud-evals/SheetBench-50:0") as ctx: + agent = MyAgent() + result = await agent.run(ctx) + ctx.reward = result.reward + +# All tasks with variants +async with hud.eval( + "hud-evals/SheetBench-50:*", + variants={"model": ["claude-sonnet", "gpt-4o"]}, + group=3, + max_concurrent=50, +) as ctx: + agent = create_agent(model=ctx.variants["model"]) + result = await agent.run(ctx) + ctx.reward = result.reward +``` + +- SDK (Batch Execution) ```python from hud.datasets import run_tasks @@ -108,8 +131,9 @@ results = await run_tasks( ## See Also -- [`hud eval`](/reference/cli/eval) -- [`hud rl`](/reference/cli/rl) +- [Evaluation API](/reference/eval) - SDK reference for `hud.eval()` +- [`hud eval`](/reference/cli/eval) - CLI reference +- [`hud rft`](/reference/cli/rft) - [Tasks](/reference/tasks) - [Agents (SDK)](/reference/agents) diff --git a/docs/index.mdx b/docs/index.mdx index fa09fb8b..ecccffeb 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -1,6 +1,6 @@ --- title: "Introduction" -description: "OSS RL environment + evals toolkit." +description: "OSS environment + evals toolkit for AI agents." icon: "book" --- @@ -8,7 +8,7 @@ icon: "book" **Version 0.4.73** - Latest stable release - + Test Claude, Operator, or custom agents on benchmarks like SheetBench and OSWorld @@ -16,15 +16,11 @@ icon: "book" Wrap any software in dockerized MCP for scalable and generalizable agent evaluation - - - Use reinforcement learning and GRPO on evaluations to improve agent performance - ## What is HUD? -HUD connects AI agents to software environments using the Model Context Protocol (MCP). Whether you're evaluating existing agents, building new environments, or training models with RL, HUD provides the infrastructure. +HUD connects AI agents to software environments using the Model Context Protocol (MCP). Whether you're evaluating existing agents or building new environments, HUD provides the infrastructure. 
```mermaid graph LR @@ -49,7 +45,7 @@ graph LR - **⚡ HUD Gateway**: Unified inference API for all LLMs - **🚀 Production-ready**: From local Docker to cloud scale - **🎯 Built-in benchmarks**: OSWorld-Verified, SheetBench-50, and more - - **🔧 CLI tools**: Create, develop, run, and train with `hud init`, `hud dev`, `hud run`, `hud eval`, `hud rl` +- **🔧 CLI tools**: Create, develop, and run with `hud init`, `hud dev`, `hud run`, `hud eval` diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index 650f200a..6e14401c 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -55,7 +55,19 @@ Get up and running with HUD in minutes. Follow these four steps to install the C -## Environments/CLI Quick Reference +## SDK Quick Reference + +```python +import hud + +# Run evaluation with the new eval API +async with hud.eval("hud-evals/SheetBench-50:0") as ctx: + agent = MyAgent() + result = await agent.run(ctx) + ctx.reward = result.reward +``` + +## CLI Quick Reference ```bash # Create sample environment diff --git a/docs/reference/cli/eval.mdx b/docs/reference/cli/eval.mdx index bf8553e6..9658df6c 100644 --- a/docs/reference/cli/eval.mdx +++ b/docs/reference/cli/eval.mdx @@ -224,5 +224,5 @@ hud cancel --all - [Tasks Reference](/reference/tasks) - Task configuration - [Agents Reference](/reference/agents) - Agent options -- [`hud rl`](/reference/cli/rl) - RL training +- [`hud rft`](/reference/cli/rft) - Reinforcement fine-tuning - [`hud cancel`](/reference/cli/misc) - Cancel remote jobs diff --git a/docs/reference/cli/overview.mdx b/docs/reference/cli/overview.mdx index a474e3ef..49d226a1 100644 --- a/docs/reference/cli/overview.mdx +++ b/docs/reference/cli/overview.mdx @@ -21,8 +21,7 @@ The HUD CLI provides a complete toolkit for creating, developing, and running MC - `hud debug` — 5‑phase compliance test - `hud run` — Execute (Python module/command/Docker) - `hud eval` — Run agents on tasks/datasets - - `hud rl` — Train with GRPO on tasks - - `hud rft` — Fine-tune models with RL (BETA, invite-only) + - `hud rft` — Fine-tune models (BETA, invite-only) @@ -62,8 +61,7 @@ hud --version | `hud debug` | Image/dir/config | 5‑phase compliance test | `hud debug my-env:latest` | | `hud run` | Module/command/image | Execute server (local/remote) | `hud run controller --reload` | | `hud eval` | Tasks/dataset | Run agent on tasks | `hud eval tasks.json claude` | -| `hud rl` | Tasks/dataset | Train with GRPO | `hud rl tasks.json --local` | -| `hud rft` | Tasks file | Fine-tune with RL (BETA, invite-only) | `hud rft run tasks.json` | +| `hud rft` | Tasks file | Fine-tune models (BETA, invite-only) | `hud rft run tasks.json` | ### Other Commands | Command | Description | Example | diff --git a/docs/reference/cli/rft.mdx b/docs/reference/cli/rft.mdx index 8d1d3be1..771b806d 100644 --- a/docs/reference/cli/rft.mdx +++ b/docs/reference/cli/rft.mdx @@ -1,6 +1,6 @@ --- title: "hud rft" -description: "Reinforcement Fine-Tuning commands (invite-only)" +description: "Fine-Tuning commands (invite-only)" icon: "brain-circuit" --- @@ -12,7 +12,7 @@ RFT is currently in BETA. Features and APIs may change. **Access Required**: RFT is available by invite only. Contact [founders@hud.ai](mailto:founders@hud.ai) to request access. -The `hud rft` command group provides tools for fine-tuning models using reinforcement learning on HUD tasks. +The `hud rft` command group provides tools for fine-tuning models on HUD tasks. 
## Subcommands @@ -133,4 +133,3 @@ hud rft status f5f050a3-99c1-4339-b819-ccb1325f79d8 --verbose ## See Also - [Beta RFT Documentation](/beta/rft) - Detailed guide and examples -- [hud rl](/reference/cli/rl) - Standard reinforcement learning training diff --git a/docs/reference/cli/rl.mdx b/docs/reference/cli/rl.mdx deleted file mode 100644 index f644770b..00000000 --- a/docs/reference/cli/rl.mdx +++ /dev/null @@ -1,87 +0,0 @@ ---- -title: "hud rl" -description: "Run GRPO reinforcement learning on tasks" -icon: "brain" ---- - -The `hud rl` command trains an agent with GRPO on tasks, locally or via the HUD remote service. - -## Usage - -```bash -hud rl [TASKS_FILE|DATASET] [MODEL] [OPTIONS] -``` - -## Arguments - - - Path to tasks JSON/JSONL file or HuggingFace dataset name. If omitted, looks for a tasks file in the current directory. - - - - Model to train (default: interactive selection) - - -## Options - - - Path to existing configuration file. Short: `-c` - - - - Output directory for checkpoints. Short: `-o` - - - - Restart the vLLM server before training - - - - Enable verbose output. Short: `-v` - - - - Disable DistributedDataParallel (even with multiple GPUs) - - - - Specific GPUs for DDP (e.g., `0,1,2,3`) - - - - Specific GPU for vLLM server - - - - Run training locally instead of the remote HUD server - - -## Behavior - -- If no tasks file is provided, an interactive picker helps locate one. -- Remote mode (default) converts tasks to remote MCP automatically (build/push as needed) and launches remote training. -- Local mode runs training on your machine (delegated to `local_runner`). - -## Examples - -```bash -# Remote (default): auto-convert tasks to remote, then train -hud rl tasks.json --model claude-rl - -# Local training with GPU selection -hud rl tasks.json llama3.1 --local --ddp-gpus 0,1 --vllm-gpu 0 - -# Use a dataset directly (remote) -hud rl hud-evals/SheetBench-50 --model claude-rl -``` - -## See Also - -- [`hud eval`](/reference/cli/eval) -- [`hud get`](/reference/cli/get) -- [`hud build`](/reference/cli/build) -- [`hud push`](/reference/cli/push) - -## Pricing & Billing - -See hosted vLLM and training GPU rates in the [Training Quickstart → Pricing](/train-agents/quickstart#pricing). Manage usage and billing at `https://hud.ai/project/billing`. \ No newline at end of file diff --git a/docs/reference/eval.mdx b/docs/reference/eval.mdx new file mode 100644 index 00000000..3ef0f7e8 --- /dev/null +++ b/docs/reference/eval.mdx @@ -0,0 +1,405 @@ +--- +title: "Evaluation API" +description: "SDK reference for running evaluations with hud.eval()" +icon: "flask-vial" +--- + +The HUD SDK provides a unified evaluation API through `hud.eval()` for tracking agent performance, running parallel evaluations, and integrating with the HUD platform. + +## Overview + +There are three ways to run evaluations: + +1. **`hud.eval()`** - Standalone context manager for any evaluation +2. **`env.eval()`** - Method on `Environment` for evaluating within an existing environment +3. **`run_tasks()`** - High-level batch execution with automatic agent creation + +## hud.eval() + +The primary evaluation context manager. Creates an `EvalContext` which is a full `Environment` with evaluation tracking. 
+ +```python +import hud + +async with hud.eval("my-org/browser-task:1") as ctx: + # ctx is an EvalContext (extends Environment) + tools = await ctx.list_tools() + result = await ctx.call_tool("navigate", url="https://example.com") + ctx.reward = 1.0 # Set the evaluation reward +``` + +### Parameters + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `source` | `str \| list[str] \| Task \| list[Task] \| None` | Task source (slugs or Task objects) | `None` | +| `variants` | `dict[str, Any] \| None` | A/B test configuration | `None` | +| `group` | `int` | Runs per variant for statistical significance | `1` | +| `group_ids` | `list[str] \| None` | Custom group IDs for parallel runs | `None` | +| `job_id` | `str \| None` | Job ID to link traces to | `None` | +| `api_key` | `str \| None` | API key for backend calls | `None` | +| `max_concurrent` | `int \| None` | Maximum concurrent evaluations | `None` | + +### Task Sources + +The `source` parameter accepts multiple formats: + +```python +# 1. Blank evaluation (manual reward) +async with hud.eval() as ctx: + ctx.reward = compute_reward() + +# 2. Single task slug +async with hud.eval("my-org/browser-task") as ctx: + await agent.run(ctx) + +# 3. Task at specific index +async with hud.eval("my-org/evalset:0") as ctx: + await agent.run(ctx) + +# 4. All tasks in an evalset (wildcard) +async with hud.eval("my-org/evalset:*") as ctx: + await agent.run(ctx) + +# 5. Multiple slugs +async with hud.eval(["task:0", "task:1", "task:2"]) as ctx: + await agent.run(ctx) + +# 6. Task objects directly (backwards compatible) +from hud.types import Task +tasks = [Task(prompt="Navigate to docs", mcp_config={...})] +async with hud.eval(tasks) as ctx: + await agent.run(ctx) +``` + +### Variants and Groups + +Run A/B tests with multiple configurations: + +```python +# Test different models +async with hud.eval( + "my-org/evalset:*", + variants={"model": ["gpt-4o", "claude-sonnet"]}, + group=3, # 3 runs per variant for statistical significance +) as ctx: + model = ctx.variants["model"] # Current variant assignment + agent = create_agent(model=model) + result = await agent.run(ctx) + ctx.reward = result.reward + +# Access all results after completion +for result in ctx.results: + print(f"{result.variants}: reward={result.reward}") +``` + +**How it works:** +- `variants` dict with list values creates the cartesian product +- `group` multiplies each variant combination +- Total runs = `len(tasks) × len(variant_combos) × group` +- For parallel runs (total > 1), a job is automatically created + +### Concurrency Control + +Limit concurrent evaluations to manage resources: + +```python +async with hud.eval( + "my-org/large-evalset:*", + max_concurrent=10, # Max 10 parallel evaluations +) as ctx: + await agent.run(ctx) +``` + +## env.eval() + +Create evaluation contexts from an existing `Environment`: + +```python +from hud import Environment + +async with Environment() as env: + # Connect to MCP servers + await env.connect_hub("test-browser-26") + + # Run evaluation within this environment + async with env.eval("my-evaluation", group=3) as ctx: + # ctx inherits env's connections + tools = await ctx.list_tools() + await agent.run(ctx) + ctx.reward = result.reward +``` + +### Parameters + +Same as `hud.eval()`, plus: + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `name` | `str` | Evaluation name (required) | Required | +| `trace_id` | `str \| None` | Custom trace ID | `None` | + +### 
Connection Inheritance + +When you call `env.eval()`, the `EvalContext` copies the parent environment's connections: + +```python +async with Environment() as env: + await env.connect_hub("my-hub") + + # Parallel evaluations each get their own connection copies + async with env.eval("test", group=3) as ctx: + # Each parallel run has independent connections + await ctx.call_tool("my_tool") +``` + +## EvalContext + +`EvalContext` extends `Environment` with evaluation-specific functionality. + +### Properties + +| Property | Type | Description | +|----------|------|-------------| +| `trace_id` | `str` | Unique trace identifier | +| `eval_name` | `str` | Evaluation name | +| `job_id` | `str \| None` | Parent job ID | +| `group_id` | `str \| None` | Group ID for parallel runs | +| `index` | `int` | Index in parallel execution | +| `variants` | `dict[str, Any]` | Current variant assignment | +| `reward` | `float \| None` | Evaluation reward (settable) | +| `error` | `BaseException \| None` | Error if evaluation failed | +| `results` | `list[EvalContext] \| None` | Results from parallel runs | +| `task` | `Task \| None` | Task definition (if loaded from slug) | +| `prompt` | `str \| None` | Task prompt | +| `headers` | `dict[str, str]` | Trace headers for HTTP requests | + +### Methods + +All `Environment` methods are available, plus: + +```python +# Set reward +ctx.reward = 1.0 + +# Access task configuration +if ctx.task: + print(ctx.task.prompt) + print(ctx.task.agent_config) # Agent configuration hints + +# Get trace headers for external HTTP calls +headers = ctx.headers # {"Trace-Id": "...", "Trace-Parent": "..."} +``` + +### Creating from Task + +```python +from hud.eval.context import EvalContext +from hud.types import Task + +task = Task( + prompt="Navigate to the docs page", + mcp_config={"hud": {"url": "...", "headers": {...}}}, + setup_tool={"name": "setup", "arguments": {...}}, + evaluate_tool={"name": "evaluate", "arguments": {...}}, +) + +ctx = EvalContext.from_task(task) +async with ctx: + # MCP connections configured from task.mcp_config + # setup_tool and evaluate_tool configured + tools = await ctx.list_tools() +``` + +## run_tasks() + +High-level batch execution that creates agents automatically: + +```python +from hud.datasets import run_tasks +from hud.types import AgentType +from hud.utils.tasks import load_tasks + +# Load tasks from HuggingFace or file +tasks = load_tasks("hud-evals/SheetBench-50") + +# Run with automatic agent creation +results = await run_tasks( + tasks=tasks, + agent_type=AgentType.CLAUDE, + agent_params={"checkpoint_name": "claude-sonnet-4-5"}, + max_concurrent=30, + max_steps=10, + group_size=3, # 3 runs per task +) +``` + +### Parameters + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `tasks` | `list[Task]` | List of Task objects | Required | +| `agent_type` | `AgentType` | Agent type enum | Required | +| `agent_params` | `dict[str, Any] \| None` | Agent configuration | `None` | +| `name` | `str` | Job name | `"Evaluation"` | +| `max_concurrent` | `int` | Maximum concurrent tasks | `30` | +| `metadata` | `dict[str, Any] \| None` | Job metadata | `None` | +| `max_steps` | `int` | Maximum steps per task | `10` | +| `group_size` | `int` | Runs per task | `1` | +| `remote` | `bool` | Submit to HUD platform | `False` | + +### Returns + +- If `group_size == 1`: `list[Trace]` - Results in task order +- If `group_size > 1`: `list[dict]` - Statistics per task group + +### Remote Execution + +Submit tasks to the 
HUD platform for remote execution: + +```python +await run_tasks( + tasks=tasks, + agent_type=AgentType.CLAUDE, + remote=True, # Submit to platform +) +# Returns immediately, monitor at https://hud.ai/jobs/{job_id} +``` + +## Task Configuration + +Tasks define the evaluation environment and success criteria: + +```python +from hud.types import Task + +task = Task( + id="nav-001", + prompt="Navigate to the documentation page", + mcp_config={ + "hud": { + "url": "https://mcp.hud.ai/v3/mcp", + "headers": { + "Authorization": "Bearer ${HUD_API_KEY}", + "Mcp-Image": "hudpython/hud-remote-browser:latest" + } + } + }, + setup_tool={ + "name": "setup", + "arguments": {"name": "navigate", "arguments": {"url": "https://example.com"}} + }, + evaluate_tool={ + "name": "evaluate", + "arguments": {"name": "url_match", "arguments": {"pattern": ".*/docs.*"}} + }, + agent_config={ + "allowed_tools": ["playwright", "computer"], + "system_prompt": "You are a web navigation agent." + }, + metadata={"difficulty": "easy", "category": "navigation"} +) +``` + +### Task Fields + +| Field | Type | Description | +|-------|------|-------------| +| `id` | `str \| None` | Unique task identifier | +| `prompt` | `str` | Task instruction | +| `mcp_config` | `dict[str, Any]` | MCP server configuration | +| `setup_tool` | `MCPToolCall \| list[MCPToolCall] \| None` | Setup tool calls | +| `evaluate_tool` | `MCPToolCall \| list[MCPToolCall] \| None` | Evaluation tool calls | +| `agent_config` | `BaseAgentConfig \| None` | Agent configuration hints | +| `metadata` | `dict[str, Any]` | Custom metadata | + +### Environment Variable Substitution + +MCP config supports `${VAR_NAME}` substitution: + +```python +mcp_config = { + "hud": { + "url": "${HUD_MCP_URL:https://mcp.hud.ai/v3/mcp}", # With default + "headers": { + "Authorization": "Bearer ${HUD_API_KEY}" # From environment + } + } +} +``` + +## HTTP Instrumentation + +When running inside an eval context, HTTP requests to HUD services automatically include trace headers: + +```python +import httpx + +async with hud.eval("test") as ctx: + # Trace headers are automatically injected + async with httpx.AsyncClient() as client: + # Requests to inference.hud.ai, mcp.hud.ai include Trace-Id + response = await client.post( + "https://inference.hud.ai/v1/messages", + json={...} + ) +``` + +This enables automatic telemetry linking without manual header management. + +## Best Practices + +### 1. Use Variants for A/B Testing + +```python +async with hud.eval( + "evalset:*", + variants={ + "model": ["gpt-4o", "claude"], + "temperature": [0.0, 0.7], + }, + group=3, +) as ctx: + # Runs: 2 models × 2 temps × 3 groups = 12 evaluations + ... +``` + +### 2. Set Rewards Consistently + +```python +async with hud.eval("task") as ctx: + try: + result = await agent.run(ctx) + ctx.reward = result.reward + except Exception as e: + ctx.reward = 0.0 # Explicit failure reward + raise +``` + +### 3. Use Concurrency Limits for Resource-Heavy Tasks + +```python +async with hud.eval( + "browser-tasks:*", + max_concurrent=5, # Browser instances are heavy +) as ctx: + ... +``` + +### 4. 
Access Task Agent Config + +```python +async with hud.eval("my-org/task:0") as ctx: + if ctx.task and ctx.task.agent_config: + # Apply task's agent hints + allowed_tools = ctx.task.agent_config.allowed_tools + system_prompt = ctx.task.agent_config.system_prompt +``` + +## See Also + +- [`hud eval` CLI](/reference/cli/eval) - Command-line interface +- [Benchmarks](/evaluate-agents/benchmarks) - Creating and running benchmarks +- [Tasks](/reference/tasks) - Task configuration reference +- [Environments](/reference/environments) - Building MCP environments + diff --git a/docs/train-agents/quickstart.mdx b/docs/train-agents/quickstart.mdx deleted file mode 100644 index 32e83471..00000000 --- a/docs/train-agents/quickstart.mdx +++ /dev/null @@ -1,126 +0,0 @@ ---- -title: "RL Quickstart" -icon: "graduation-cap" ---- - -## Prerequisites - -- HUD API key: Remote training requires authentication. Set `HUD_API_KEY` before running: - -```bash -export HUD_API_KEY="sk-hud-..." # get one at https://hud.ai -# Or persist it locally: -hud set HUD_API_KEY=sk-hud-... -``` - -- Docker daemon: For local runs (using `--local`) or when training against a local Docker image, ensure Docker Desktop is installed and the Docker daemon is running. - -## Quickstart - -Install and download a taskset: - -```bash -uv tool install hud-python@latest --python 3.12 -hud get hud-evals/2048-basic -``` - -### 1) Simple: Train (remote by default) - -```bash -hud rl 2048-basic.json -``` - -This launches training remotely and automatically provisions a vLLM server and a trainer for you. You can monitor progress on https://hud.ai. The server persists between runs, so you can rerun training or evaluate against the same endpoint. - -Optional baseline first (Claude or Operator): - -```bash -hud eval 2048-basic.json -``` - -### 2) Run on your own machine/remote - -Use any provider with at least 2 GPUs (one for inference, one for training). Run locally with the flag `--local`: - -```bash -uv tool install hud-python@latest --python 3.12 -hud get hud-evals/2048-basic -hud rl 2048-basic.json --local -``` - -### Recommended setups - -- 2× A100: quick iteration, shorter runs -- 8× A100: higher throughput for larger tasksets - -Training throughput depends on task complexity and parallelism (`max_parallel_episodes`). - -### 3) Build your own environment (hud init) - -Create a new MCP environment, develop with hot-reload, and train on a production image: - -```bash -hud init my-env && cd my-env -hud dev --interactive -# When ready to run: -hud rl -``` - -Change the tasks.json to include other tasks you want to train on. - -See [hud init](/reference/cli/init) for options and details. - - -## Getting the best performance - -Often training a good model requires many iterations over the parameters of the trainer. Take the config generated by `hud rl` and modify it to various values to do a hyperparameter sweep. - -For easy launching, specify the tasks and config upfront, and add `--yes` to automatically launch vllm and training. - -```bash -hud rl taskset.json --config rl-config.json --yes -``` - -Additionally, sometimes it may be helpful to run an initial analysis on the dataset to determine which tasks would be the most informative to trian on. In that case either start with a deployed model or run `hud rl` without training, and then: - -```bash -hud eval taskset.json --full --group-size 6 --max-steps 5 -``` - -This will prompt you for the model choice, produce a table of accuracies per task. Prefer tasks which are 10%-60% accurate for training. 
- -Some general findings from our internal training runs: -- As many different tasks per gradient update as possible (runs with 4+ GPUs and batch size of 50+ are much more stable than single GPU runs) -- Batch size should be somewhere around 2/X where X is the accuracy of that given task on an untrained model. - -### Pricing - -Below is the pricing by GPU type. Actual prices vary — see https://hud.ai/project/billing for current rates. - -vLLM GPU Pricing (2 Hosted GPUs) - -| GPU type | Memory | Est. price/hr | -| --- | --- | --- | -| A100 80GB | 80 GB | $4.95 | -| H100 80GB | 80 GB | $7.95 | - -Training GPU Pricing - -| GPU type | Memory | Est. price/hr | -| --- | --- | --- | -| A100 80GB | 80 GB | $3.95 | -| H100 80GB | 80 GB | $5.40 | - ---- - -### Learn more - - - - Complete guide to building environments from scratch - - - - Full `hud rl` command options and usage - - \ No newline at end of file diff --git a/docs/train-agents/tasks.mdx b/docs/train-agents/tasks.mdx deleted file mode 100644 index 58131b6f..00000000 --- a/docs/train-agents/tasks.mdx +++ /dev/null @@ -1,80 +0,0 @@ ---- -title: Dataset Design -icon: table ---- - -## Tasks format - -HUD tasksets can be provided in two primary formats (both supported): - -1) A single JSON file containing a list of task objects (recommended) - -```json -[ - { - "id": "browser_2048_128", - "prompt": "Reach 128 in 2048.", - "mcp_config": { - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - "headers": { - "Authorization": "Bearer ${HUD_API_KEY}", - "Mcp-Image": "hudevals/hud-browser:0.1.3" - } - } - }, - "setup_tool": {"name": "launch_app", "arguments": {"app_name": "2048"}}, - "evaluate_tool": {"name": "evaluate", "arguments": {"name": "game_2048_max_number", "arguments": {"target": 128}}} - } -] -``` - -Save as `2048-basic.json` and run: - -```bash -hud eval 2048-basic.json -hud rl 2048-basic.json -``` - -2) JSONL file with one task object per line - -- prompt: instruction for the agent -- mcp_config: where to run the environment (local docker or remote MCP) -- setup_tool (optional): a tool call to prepare the environment -- evaluate_tool: a tool call to compute reward -- system_prompt (optional): extra guidance for the agent - -## Hosting on HuggingFace - -You can host tasksets on the Hub and fetch them with: - -```bash -hud get hud-evals/2048-basic -``` - -The command downloads the JSONL task file and places it in your project directory. 
- -This allows running the full dataset or training with simply: - -```bash -hud eval hud-evals/2048-basic -hud rl hud-evals/2048-basic -``` - -## Tips - -- Keep tasks self-contained; use `setup_tool` to open apps or load data -- Ensure `evaluate_tool` returns a numeric reward per episode -- Use small task counts to iterate quickly; scale up once stable - - - - Learn how to run benchmarks - - - - Deep-dive into MCP configs and tools - - - - diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 9a4103f7..7f5a7aee 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -56,20 +56,24 @@ async def run_single_task( Returns: Trace result from agent execution """ - from hud.telemetry import async_trace + from hud.eval.context import EvalContext name = trace_name or task.prompt or task_id or "task" - async with async_trace( - name, + ctx = EvalContext.from_task( + task=task, + name=name, + trace_id=trace_id, job_id=job_id, - task_id=task_id, group_id=group_id, - trace_id=trace_id, - attrs=metadata or {}, - ): + ) + + async with ctx: agent = agent_type.cls.create(**(agent_params or {})) - return await agent.run(task, max_steps=max_steps) + result = await agent.run(task, max_steps=max_steps) + # Transfer reward to context for tracking + ctx.reward = result.reward + return result async def run_tasks( @@ -119,8 +123,7 @@ async def run_tasks( # Submit for remote execution await run_tasks(tasks, AgentType.CLAUDE, remote=True) """ - import hud - from hud.telemetry import async_job + from hud.eval.display import print_complete, print_link from hud.utils.hud_console import HUDConsole job_metadata = metadata or {} @@ -131,9 +134,11 @@ async def run_tasks( job_metadata["total_episodes"] = len(tasks) * group_size if remote: + from hud.telemetry.job import create_job + hud_console = HUDConsole() - job = hud.create_job(name, metadata=job_metadata) + job = create_job(name, metadata=job_metadata) job.update_status_sync("created") await submit_rollouts( @@ -149,13 +154,102 @@ async def run_tasks( hud_console.info(f"Monitor progress at: https://hud.ai/jobs/{job.id}") return [] - # Local execution + # Local execution using new eval system agent_class = agent_type.cls + job_id = str(uuid.uuid4()) + job_url = f"https://hud.ai/jobs/{job_id}" + + # Print job URL + print_link(job_url, f"🚀 Job '{name}'") - async with async_job(name, metadata=job_metadata) as job_obj: - return await _run_tasks( - tasks, agent_class, agent_params, max_concurrent, max_steps, group_size, job_obj + error_occurred = False + try: + results = await _run_tasks_with_eval( + tasks=tasks, + agent_class=agent_class, + agent_params=agent_params, + max_concurrent=max_concurrent, + max_steps=max_steps, + group_size=group_size, + job_id=job_id, ) + error_occurred = any(r is None or (isinstance(r, Trace) and r.isError) for r in results) + return results + except Exception: + error_occurred = True + raise + finally: + print_complete(job_url, name, error=error_occurred) + + +async def _run_tasks_with_eval( + tasks: list[Task], + agent_class: type[MCPAgent], + agent_params: dict[str, Any] | None, + max_concurrent: int, + max_steps: int, + group_size: int, + job_id: str, +) -> list[Any]: + """Run tasks using the new EvalContext system.""" + from hud.eval.context import EvalContext + + sem = asyncio.Semaphore(max_concurrent) + params = agent_params or {} + + # Generate group IDs for each task (used for telemetry grouping) + group_ids = {i: str(uuid.uuid4()) for i in range(len(tasks))} + + # Expand tasks: each task runs group_size times 
+ expanded: list[tuple[int, int, Task]] = [] # (flat_idx, task_idx, task) + for task_idx, task in enumerate(tasks): + for _ in range(group_size): + expanded.append((len(expanded), task_idx, task)) + + traces: list[Trace | None] = [None] * len(expanded) + + async def worker(flat_idx: int, task_idx: int, run_idx: int, task: Task) -> None: + async with sem: + try: + base_task_id = str(task.id) if task.id is not None else f"task_{task_idx}" + trace_name = task.prompt or base_task_id + + # Create EvalContext for this task run + ctx = EvalContext.from_task( + task=task, + name=trace_name, + job_id=job_id, + group_id=group_ids[task_idx] if group_size > 1 else None, + ) + ctx._suppress_link = True # Don't print individual trace links + + async with ctx: + agent = agent_class.create(**params) + result = await agent.run(task, max_steps=max_steps) + ctx.reward = result.reward + traces[flat_idx] = result + + except Exception as e: + if group_size == 1: + logger.exception("Task %s failed: %s", task_idx, e) + traces[flat_idx] = None + else: + logger.warning("Episode %s failed: %s", flat_idx, e) + traces[flat_idx] = Trace(isError=True, content=str(e), reward=0.0, done=True) + + await asyncio.gather( + *[ + worker(flat_idx, task_idx, flat_idx % group_size, task) + for flat_idx, task_idx, task in expanded + ], + return_exceptions=True, + ) + + # Return format depends on group_size + if group_size == 1: + return list(traces) + else: + return calculate_group_stats(tasks, traces, group_size, group_ids) async def run_dataset( @@ -196,7 +290,7 @@ async def run_dataset( from datasets import Dataset as HFDataset from datasets import load_dataset - from hud.telemetry import async_job + from hud.eval.display import print_complete, print_link warnings.warn( "run_dataset() is deprecated. 
Use run_tasks() instead for more flexibility.", @@ -236,75 +330,27 @@ async def run_dataset( job_metadata["group_size"] = group_size job_metadata["total_episodes"] = len(tasks) * group_size - async with async_job(name, metadata=job_metadata) as job_obj: - return await _run_tasks( - tasks, agent_class, agent_config, max_concurrent, max_steps, group_size, job_obj - ) - - -async def _run_tasks( - tasks: list[Task], - agent_class: type[MCPAgent], - agent_params: dict[str, Any] | None, - max_concurrent: int, - max_steps: int, - group_size: int, - job_obj: Any, -) -> list[Any]: - from hud.telemetry import async_trace - - sem = asyncio.Semaphore(max_concurrent) - params = agent_params or {} - - # Generate group IDs for each task (used for telemetry grouping) - group_ids = {i: str(uuid.uuid4()) for i in range(len(tasks))} + # Use new eval system + job_id = str(uuid.uuid4()) + job_url = f"https://hud.ai/jobs/{job_id}" - # Expand tasks: each task runs group_size times - expanded: list[tuple[int, int, Task]] = [] # (flat_idx, task_idx, task) - for task_idx, task in enumerate(tasks): - for _ in range(group_size): - expanded.append((len(expanded), task_idx, task)) - - traces: list[Trace | None] = [None] * len(expanded) + print_link(job_url, f"🚀 Job '{name}'") - async def worker(flat_idx: int, task_idx: int, run_idx: int, task: Task) -> None: - async with sem: - try: - base_task_id = str(task.id) if task.id is not None else f"task_{task_idx}" - trace_name = task.prompt or base_task_id - - if group_size == 1: - async with async_trace(trace_name, job_id=job_obj.id, task_id=base_task_id): - agent = agent_class.create(**params) - traces[flat_idx] = await agent.run(task, max_steps=max_steps) - else: - task_id_with_run = f"{base_task_id}_{run_idx}" - async with async_trace( - trace_name, - job_id=job_obj.id, - task_id=task_id_with_run, - group_id=group_ids[task_idx], - ): - agent = agent_class.create(**params) - traces[flat_idx] = await agent.run(task, max_steps=max_steps) - except Exception as e: - if group_size == 1: - logger.exception("Task %s failed: %s", task_idx, e) - traces[flat_idx] = None - else: - logger.warning("Episode %s failed: %s", flat_idx, e) - traces[flat_idx] = Trace(isError=True, content=str(e), reward=0.0, done=True) - - await asyncio.gather( - *[ - worker(flat_idx, task_idx, flat_idx % group_size, task) - for flat_idx, task_idx, task in expanded - ], - return_exceptions=True, - ) - - # Return format depends on group_size - if group_size == 1: - return list(traces) - else: - return calculate_group_stats(tasks, traces, group_size, group_ids) + error_occurred = False + try: + results = await _run_tasks_with_eval( + tasks=tasks, + agent_class=agent_class, + agent_params=agent_config, + max_concurrent=max_concurrent, + max_steps=max_steps, + group_size=group_size, + job_id=job_id, + ) + error_occurred = any(r is None or (isinstance(r, Trace) and r.isError) for r in results) + return results + except Exception: + error_occurred = True + raise + finally: + print_complete(job_url, name, error=error_occurred) diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 84b6a8c2..5a2a261e 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -29,6 +29,10 @@ logger = logging.getLogger(__name__) +# Type alias for task source: can be slug strings or Task objects +TaskSource = "str | list[str] | Task | list[Task] | None" + + def _parse_slug(slug: str) -> tuple[str, str | None]: """Parse a task slug into (base_slug, index_or_wildcard). 
@@ -47,29 +51,47 @@ def _parse_slug(slug: str) -> tuple[str, str | None]: return slug, None -def _get_eval_name(slugs: str | list[str] | None) -> str: - """Extract a nice name from slugs for job display. +def _get_eval_name( + source: str | list[str] | None = None, + tasks: list[Task] | None = None, +) -> str: + """Extract a nice name for job display. Args: - slugs: Single slug or list of slugs + source: Single slug or list of slugs (if string-based) + tasks: List of Task objects (if using direct tasks) Returns: - Name like "evalset" or "eval" if no slugs + Name like "evalset", task ID, or "eval" if no source """ - if slugs is None: - return "eval" + # If we have tasks with IDs, use first task ID + if tasks: + first_task = tasks[0] + if first_task.id: + # Extract name from task ID (might be "evalset/task_name") + task_id = str(first_task.id) + if "/" in task_id: + return task_id.rsplit("/", 1)[1] + return task_id + # Fall back to prompt excerpt + if first_task.prompt: + return first_task.prompt[:30].strip() - # Get the first slug - first_slug = slugs if isinstance(slugs, str) else slugs[0] + # If we have string slugs + if source is not None: + # Get the first slug + first_slug = source if isinstance(source, str) else source[0] - # Remove index/wildcard suffix (":1" or ":*") - base_slug, _ = _parse_slug(first_slug) + # Remove index/wildcard suffix (":1" or ":*") + base_slug, _ = _parse_slug(first_slug) - # Extract the evalset name (part after last "/") - if "/" in base_slug: - return base_slug.rsplit("/", 1)[1] + # Extract the evalset name (part after last "/") + if "/" in base_slug: + return base_slug.rsplit("/", 1)[1] - return base_slug + return base_slug + + return "eval" def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: @@ -146,7 +168,7 @@ def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: @asynccontextmanager async def run_eval( - slugs: str | list[str] | None = None, + source: str | list[str] | Task | list[Task] | None = None, *, variants: dict[str, Any] | None = None, group: int = 1, @@ -158,15 +180,15 @@ async def run_eval( """Standalone eval context manager. Creates an EvalContext for evaluation, optionally loading task configuration - from slugs. + from slugs or using Task objects directly. Args: - slugs: Task slug(s) to load. Can be: + source: Task source. 
Can be: - None: Create blank eval context - - "my-org/task": Single task - - "my-org/task:N": Task at index N - - "my-org/task:*": All tasks matching pattern - - List of any above: Multiple tasks + - str: Task slug like "my-org/task", "my-org/task:N", "my-org/task:*" + - list[str]: Multiple task slugs + - Task: Single Task object (for backwards compat with run_tasks) + - list[Task]: List of Task objects (for backwards compat with run_tasks) variants: A/B test configuration (dict with list values expanded) group: Runs per variant for statistical significance group_ids: Optional list of group IDs @@ -196,6 +218,13 @@ async def run_eval( async with hud.eval("my-org/evalset:*") as ctx: await agent.run(ctx) + # With Task objects directly + from hud.types import Task + + tasks = [Task(prompt="Do X", mcp_config={...})] + async with hud.eval(tasks) as ctx: + await agent.run(ctx) + # With variants and group async with hud.eval( "task", @@ -215,16 +244,33 @@ async def run_eval( print(f"{e.variants}: reward={e.reward}") ``` """ + from hud.types import Task + if group <= 0: raise ValueError("group must be >= 1") # Expand variants variant_combos = expand_variants(variants) - # Load tasks if slugs provided + # Parse source into tasks list tasks: list[Task] = [] - if slugs is not None: - tasks = _load_tasks_from_slugs(slugs) + slugs: str | list[str] | None = None # Track if we had string slugs (for naming) + + if source is not None: + if isinstance(source, Task): + # Single Task object + tasks = [source] + elif isinstance(source, list) and source and isinstance(source[0], Task): + # List of Task objects + tasks = source # type: ignore[assignment] + elif isinstance(source, str): + # String slug + slugs = source + tasks = _load_tasks_from_slugs(source) + elif isinstance(source, list) and source and isinstance(source[0], str): + # List of string slugs + slugs = source # type: ignore[assignment] + tasks = _load_tasks_from_slugs(source) # type: ignore[arg-type] # Calculate total evaluations # If we have tasks, each task gets (variants x group) runs @@ -274,7 +320,7 @@ async def run_eval( else: # Parallel execution: create implicit job to group traces - eval_name = _get_eval_name(slugs) + eval_name = _get_eval_name(source=slugs, tasks=tasks) implicit_job_id = job_id or str(uuid.uuid4()) job_url = f"https://hud.ai/jobs/{implicit_job_id}" diff --git a/hud/eval/tests/test_parallel.py b/hud/eval/tests/test_parallel.py index 9fef3f98..8750b447 100644 --- a/hud/eval/tests/test_parallel.py +++ b/hud/eval/tests/test_parallel.py @@ -179,7 +179,7 @@ async def test_runs_body_for_each_context(self) -> None: body_source = "env.reward = env.index * 10" captured_locals: dict[str, object] = {} - results = await run_parallel_evals(mock_ctxs, body_source, captured_locals) + results = await run_parallel_evals(mock_ctxs, body_source, captured_locals, "env") assert len(results) == 3 # Each context should have had __aenter__ and __aexit__ called @@ -199,7 +199,7 @@ async def test_captures_exceptions(self) -> None: body_source = "raise ValueError('test error')" captured_locals: dict[str, object] = {} - results = await run_parallel_evals([ctx], body_source, captured_locals) + results = await run_parallel_evals([ctx], body_source, captured_locals, "env") assert len(results) == 1 # Error should be captured, not raised @@ -218,7 +218,7 @@ async def test_uses_captured_locals(self) -> None: body_source = "env.result = my_value * 2" captured_locals = {"my_value": 21} - results = await run_parallel_evals([ctx], body_source, captured_locals) 
+ results = await run_parallel_evals([ctx], body_source, captured_locals, "env") assert len(results) == 1 diff --git a/hud/misc/__init__.py b/hud/misc/__init__.py deleted file mode 100644 index 40fb1d81..00000000 --- a/hud/misc/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Miscellaneous utilities for HUD SDK.""" diff --git a/hud/misc/claude_plays_pokemon.py b/hud/misc/claude_plays_pokemon.py deleted file mode 100644 index 96b78ae2..00000000 --- a/hud/misc/claude_plays_pokemon.py +++ /dev/null @@ -1,292 +0,0 @@ -# pyright: reportGeneralTypeIssues=false -from __future__ import annotations - -import json -import logging -from typing import TYPE_CHECKING, Any, cast - -from anthropic import AsyncAnthropic - -from hud.adapters import Adapter -from hud.adapters.common.types import CLA - -# Update import to current API; if this script is legacy, keep it optional -try: - from hud.agents import MCPAgent as Agent # type: ignore[assignment] -except Exception: # pragma: no cover - optional example script - from hud.agents import MCPAgent as Agent # fallback -from hud.settings import settings - -if TYPE_CHECKING: - from anthropic.types.beta import ( - BetaImageBlockParam, - BetaMessageParam, - BetaTextBlockParam, - ) - - from hud.env.environment import Observation - -logger = logging.getLogger(__name__) - -# Constants -DEFAULT_MODEL = "claude-3-7-sonnet-20250219" -DEFAULT_MAX_TOKENS = 4096 -DEFAULT_MAX_ITERATIONS = 10 -DEFAULT_TEMPERATURE = 0.7 -DEFAULT_MAX_MESSAGE_MEMORY = 20 - - -def generate_system_prompt(game_name: str) -> str: - """Generate the system prompt for the AI agent. - - Args: - game_name: Name of the game being played - - Returns: - str: The system prompt for the AI agent - """ - return """You are a specialized AI assistant designed to play Pokémon games via screenshot analysis and text instructions. Your task is to understand the current game state from visual input, determine appropriate actions, and respond with structured outputs that control the game. - -For each turn, you will receive: -1. A screenshot of the current game state -2. Contextual information about the game progress, recent events, and objectives - -Based on this information, you must analyze the situation, determine the best course of action, and provide a structured JSON response. - -## Response Format -Your response MUST follow this exact JSON format with no additional markers, tags, or block delimiters: - -{ - "analysis": "Brief analysis of the current game situation, visible UI elements, and important context (1-3 sentences)", - "current_objective": "The immediate goal based on the game state (single sentence)", - "reasoning": "Step-by-step logic explaining your chosen action sequence (2-4 sentences)", - "progress_assessment": "Evaluation of whether previous action(s) achieved their intended goal and why/why not (1-2 sentences)", - "actions": [ - { - "type": "press", - "keys": ["up"|"down"|"left"|"right"|"a"|"b"|"start"|"select"|"pause"] - }, - { - "type": "wait", - "time": milliseconds_to_wait - } - ] -} - -IMPORTANT: Do not include any conversation markers like <> or <> around your response. Provide only the clean JSON object. - -## Action Types -- Button presses: {"type": "press", "keys": ["button_name"]} - Valid buttons are: up, down, left, right, a, b, start, select, pause -- Wait for processing: {"type": "wait", "time": milliseconds} - -## Important Rules -1. Never use "wait" commands while the game is paused. The game state will not change while paused, so waiting is ineffective. -2. 
If you detect the game is paused, your next action should be to unpause by using {"type": "press", "keys": ["pause"]} before attempting other actions. -3. Maintain awareness of whether the game is in a paused state based on visual cues in the screenshot. - -## Game Play Guidelines -1. **Navigation**: Use directional buttons to move the character or navigate menus -2. **Interaction**: Use 'a' to confirm selections and interact with objects/NPCs, 'b' to cancel or exit menus -3. **Menu Access**: Use 'start' to access the game menu -4. **Battle Strategy**: Analyze Pokémon types, moves, and stats to make optimal battle decisions -5. **Progressive Play**: Work toward completing the current objective while being mindful of longer-term goals like leveling Pokémon, collecting badges, and advancing the story -6. **Resource Management**: Monitor and manage HP, PP, items, and Pokéballs effectively -7. **Memory**: Maintain awareness of the game history and your previous actions to avoid repetitive behaviors - -Always provide thoughtful analysis and clear reasoning for your decisions. If you're uncertain about the best course of action, prioritize safe moves that gather more information. -""" # noqa: E501 - - -def extract_action_from_response_block(block: dict[str, Any]) -> list[dict[str, Any]]: - """Extract actions from a response block. - - Args: - block: The response block containing actions - - Returns: - list[dict[str, Any]]: List of actions extracted from the block - """ - if "actions" in block: - actions = block["actions"] - if isinstance(actions, list): - return actions - return [] - - -def extract_json_from_response(response: str) -> str: - """Extract JSON from a response string. - - Args: - response: The response string containing JSON - - Returns: - str: The extracted JSON string - """ - # Try to find JSON block with markdown code block markers - start = response.find("```json") - end = response.rfind("```") - if start != -1 and end != -1: - start += len("```json") - return response[start:end].strip() - - # Try to find JSON object directly - start = response.find("{") - end = response.rfind("}") - if start != -1 and end != -1: - return response[start : end + 1].strip() - - return response.strip() - - -class ClaudePlaysPokemon(Agent[AsyncAnthropic, CLA]): - """AI agent that plays Pokémon games using Claude.""" - - def __init__( - self, - client: AsyncAnthropic | None = None, - adapter: Adapter | None = None, - model: str = DEFAULT_MODEL, - max_tokens: int = DEFAULT_MAX_TOKENS, - max_iterations: int = DEFAULT_MAX_ITERATIONS, - temperature: float = DEFAULT_TEMPERATURE, - max_message_memory: int = DEFAULT_MAX_MESSAGE_MEMORY, - ) -> None: - """Initialize the Claude Plays Pokémon agent. 
- - Args: - client: Anthropic API client - adapter: Game adapter - model: Claude model to use - max_tokens: Maximum tokens for response - max_iterations: Maximum number of iterations - temperature: Response temperature - max_message_memory: Maximum number of messages to remember - - Raises: - ValueError: If API key is not provided - """ - if client is None: - api_key = settings.anthropic_api_key - if not api_key: - raise ValueError("Anthropic API key is required") - client = AsyncAnthropic(api_key=api_key) - - if adapter is None: - adapter = Adapter() - - super().__init__( - client=client, - adapter=adapter, - ) - - self.model = model - self.max_tokens = max_tokens - self.max_iterations = max_iterations - self.temperature = temperature - self.max_message_memory = max_message_memory - - self.system_prompts: list[BetaMessageParam] = [ - { - "role": "assistant", - "content": generate_system_prompt("Pokemon Red"), - } - ] - - self.messages: list[BetaMessageParam] = [] - - async def fetch_response(self, observation: Observation) -> tuple[list[dict[str, Any]], bool]: - """Fetch a response from Claude based on the current observation. - - Args: - observation: The current game observation - - Returns: - tuple[list[dict[str, Any]], bool, list[LogType] | None]: List of actions, whether the game is done, and a list of strings or dictionaries of logs. - - Raises: - ValueError: If client is not initialized - """ # noqa: E501 - if not self.client: - raise ValueError("Client is not initialized") - - user_content: list[BetaTextBlockParam | BetaImageBlockParam] = [] - - if observation.text: - user_content.append( - { - "type": "text", - "text": observation.text, - } - ) - - if observation.screenshot: - logger.debug("Processing screenshot data") - user_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": observation.screenshot, - }, - } - ) - - self.messages.append( - { - "role": "user", - "content": user_content, - } - ) - - logger.debug( - "Sending messages to Claude", extra={"messages": self.system_prompts + self.messages} - ) - - response = await self.client.beta.messages.create( - model=self.model, - messages=self.system_prompts + self.messages, - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - - response_content = response.content - self.messages.append( - cast( - "BetaMessageParam", - { - "role": "user", - "content": response_content, - }, - ) - ) - - # Maintain message memory limit - while len(self.messages) > self.max_message_memory: - self.messages.pop(0) - - action_list: list[dict[str, Any]] = [] - - # Parse response content to extract actions - for block in response_content: - if block.type == "text": - text_json = extract_json_from_response(block.text) - try: - text = json.loads(text_json) - if not isinstance(text, dict): - logger.error("Invalid response format", extra={"text": text}) - raise ValueError("Response is not a dictionary") - - action_list.extend(extract_action_from_response_block(text)) - - except json.JSONDecodeError as e: - logger.error( - "Failed to parse response", extra={"error": str(e), "text": text_json} - ) - - else: - logger.error("Unexpected block type", extra={"type": type(block)}) - - logger.debug("Extracted actions", extra={"actions": action_list}) - - return action_list, False From 05c212ce2d5e9b969bb79a2fce5685092f3167e1 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 10:27:48 -0800 Subject: [PATCH 11/92] typing --- hud/clients/fastmcp.py | 2 +- 
hud/environment/connectors/openai.py | 2 +- hud/environment/integrations/anthropic.py | 2 +- hud/environment/integrations/langchain.py | 1 + hud/environment/integrations/openai.py | 21 ++---- hud/telemetry/__init__.py | 86 ++--------------------- hud/tools/shell.py | 19 +++-- hud/utils/mcp.py | 2 +- 8 files changed, 31 insertions(+), 104 deletions(-) diff --git a/hud/clients/fastmcp.py b/hud/clients/fastmcp.py index 695d4ae0..04880ba7 100644 --- a/hud/clients/fastmcp.py +++ b/hud/clients/fastmcp.py @@ -110,7 +110,7 @@ async def _connect(self, mcp_config: dict[str, dict[str, Any]]) -> None: hasattr(self._client, "_session_state") and self._client._session_state.session is not None ): - self._client._session_state.session._validate_structured_outputs = ( + self._client._session_state.session._validate_structured_outputs = ( # type: ignore[attr-defined] self._strict_validation ) except ImportError: diff --git a/hud/environment/connectors/openai.py b/hud/environment/connectors/openai.py index 893e50b1..5ce2df3b 100644 --- a/hud/environment/connectors/openai.py +++ b/hud/environment/connectors/openai.py @@ -67,7 +67,7 @@ def calculate(expression: str) -> float: ) for tool in tools: - if isinstance(tool, FunctionTool): + if FunctionTool is not None and isinstance(tool, FunctionTool): self._add_openai_function_tool(tool, prefix) return self diff --git a/hud/environment/integrations/anthropic.py b/hud/environment/integrations/anthropic.py index d1427b4e..584dea02 100644 --- a/hud/environment/integrations/anthropic.py +++ b/hud/environment/integrations/anthropic.py @@ -170,7 +170,7 @@ def tool_names(self) -> set[str]: self._tool_names = {t.name for t in self.env.as_tools()} return self._tool_names - async def run(self, tool_use_block: Any) -> dict[str, Any]: + async def run(self, tool_use_block: Any) -> Any: """Execute a tool_use block from Claude. 
Args: diff --git a/hud/environment/integrations/langchain.py b/hud/environment/integrations/langchain.py index 9b505a08..52bd10d6 100644 --- a/hud/environment/integrations/langchain.py +++ b/hud/environment/integrations/langchain.py @@ -107,6 +107,7 @@ async def async_invoke(**kwargs: Any) -> str: return result return json.dumps(result) if result else "" + assert StructuredTool is not None # Checked in as_langchain_tools return StructuredTool( name=tool.name, description=tool.description or "", diff --git a/hud/environment/integrations/openai.py b/hud/environment/integrations/openai.py index c261765d..54893e1f 100644 --- a/hud/environment/integrations/openai.py +++ b/hud/environment/integrations/openai.py @@ -181,29 +181,20 @@ def as_openai_agent_tools(self) -> list[Any]: def _create_function_tool(env: OpenAIMixin, tool: mcp_types.Tool) -> Any: """Create a FunctionTool that calls back to the environment.""" - import asyncio - schema = tool.inputSchema or {"type": "object", "properties": {}} - def sync_wrapper(**kwargs: Any) -> str: - """Synchronous wrapper for the tool.""" - loop = asyncio.get_event_loop() - if loop.is_running(): - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, env.call_tool(tool.name, **kwargs)) - result = future.result() - else: - result = loop.run_until_complete(env.call_tool(tool.name, **kwargs)) - + async def async_wrapper(ctx: Any, args_json: str) -> str: + """Async wrapper for the tool that matches FunctionTool signature.""" + kwargs = json.loads(args_json) if args_json else {} + result = await env.call_tool(tool.name, **kwargs) if isinstance(result, str): return result return json.dumps(result) if result else "" + assert FunctionTool is not None # Checked in as_openai_agent_tools return FunctionTool( name=tool.name, description=tool.description or "", params_json_schema=schema, - on_invoke_tool=sync_wrapper, + on_invoke_tool=async_wrapper, ) diff --git a/hud/telemetry/__init__.py b/hud/telemetry/__init__.py index 1125fba0..a243af80 100644 --- a/hud/telemetry/__init__.py +++ b/hud/telemetry/__init__.py @@ -3,89 +3,17 @@ This module provides: - instrument: Function instrumentation decorator -All other APIs are deprecated: -- Job, job, create_job, get_current_job - Use hud.eval() instead -- async_trace(), trace() - Use env.trace() instead -- async_job() - Use hud.eval() instead +For other APIs, import directly from submodules: +- hud.telemetry.job: Job, job, create_job, get_current_job +- hud.telemetry.trace: Trace, trace +- hud.telemetry.async_context: async_job, async_trace +- hud.telemetry.replay: clear_trace, get_trace -Migration: - # Old (deprecated): - async with hud.async_trace("Task"): - await agent.run(task) - - # New (recommended): - async with env.trace("Task") as tc: - await agent.run(task) - tc.reward = result.reward +Recommended: Use hud.eval() or env.eval() instead. """ from __future__ import annotations from .instrument import instrument - -def __getattr__(name: str): # noqa: ANN202 - """Lazy load deprecated APIs and show warnings.""" - import warnings - - deprecated_apis = { - # Job APIs (deprecated) - "Job", - "job", - "create_job", - "get_current_job", - # OpenTelemetry-based APIs (deprecated, require [agents]) - "async_job", - "async_trace", - "clear_trace", - "get_trace", - "Trace", - "trace", - } - - if name in deprecated_apis: - warnings.warn( - f"hud.telemetry.{name} is deprecated. 
Use hud.eval() or env.trace() instead.", - DeprecationWarning, - stacklevel=2, - ) - - # Import from submodules - if name in ("Job", "job", "create_job", "get_current_job"): - from .job import Job, create_job, get_current_job, job - - return { - "Job": Job, - "job": job, - "create_job": create_job, - "get_current_job": get_current_job, - }[name] - elif name in ("async_job", "async_trace"): - from .async_context import async_job, async_trace - - return async_job if name == "async_job" else async_trace - elif name in ("clear_trace", "get_trace"): - from .replay import clear_trace, get_trace - - return clear_trace if name == "clear_trace" else get_trace - elif name in ("Trace", "trace"): - from .trace import Trace, trace - - return Trace if name == "Trace" else trace - - raise AttributeError(f"module 'hud.telemetry' has no attribute {name!r}") - - -__all__ = [ - "Job", - "Trace", - "async_job", - "async_trace", - "clear_trace", - "create_job", - "get_current_job", - "get_trace", - "instrument", - "job", - "trace", -] +__all__ = ["instrument"] diff --git a/hud/tools/shell.py b/hud/tools/shell.py index eff208b8..fe6a7efa 100644 --- a/hud/tools/shell.py +++ b/hud/tools/shell.py @@ -11,6 +11,7 @@ import asyncio import os +import sys from dataclasses import dataclass from typing import Any, Literal @@ -81,15 +82,21 @@ async def start(self) -> None: await asyncio.sleep(0) return - def demote() -> None: - # This only runs in the child process - os.setsid() - os.setgid(1000) - os.setuid(1000) + # preexec_fn and user demotion only available on Unix + preexec_fn = None + if sys.platform != "win32": + + def demote() -> None: + # This only runs in the child process (Unix only) + os.setsid() # type: ignore[attr-defined] + os.setgid(1000) # type: ignore[attr-defined] + os.setuid(1000) # type: ignore[attr-defined] + + preexec_fn = demote self._process = await asyncio.create_subprocess_shell( # noqa: S604 self.command, - preexec_fn=demote, + preexec_fn=preexec_fn, shell=True, bufsize=0, stdin=asyncio.subprocess.PIPE, diff --git a/hud/utils/mcp.py b/hud/utils/mcp.py index fe5044c9..c42f1346 100644 --- a/hud/utils/mcp.py +++ b/hud/utils/mcp.py @@ -67,7 +67,7 @@ def setup_hud_telemetry( return None from hud.otel import get_current_task_run_id - from hud.telemetry import trace + from hud.telemetry.trace import trace run_id = get_current_task_run_id() auto_trace_cm = None From dcdf9f093b9ef0accaa3b25a809aeb79020e9d59 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 10:53:48 -0800 Subject: [PATCH 12/92] format and test fixes --- hud/cli/tests/test_eval.py | 44 ++++++-- hud/environment/connectors/openai.py | 17 +-- hud/environment/integrations/anthropic.py | 19 +--- hud/environment/integrations/langchain.py | 22 ++-- hud/environment/integrations/openai.py | 20 +--- hud/tests/test_datasets_extended.py | 130 ++++++++-------------- hud/tests/test_init.py | 8 +- hud/tests/test_init_module.py | 11 +- 8 files changed, 113 insertions(+), 158 deletions(-) diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index 04d26475..d367c447 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -280,9 +280,16 @@ async def test_agent_config_intersection_union_via_run_dataset( "validate_api_key": False, } + # Create mock context + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_ctx.__aexit__ = AsyncMock(return_value=None) + mock_ctx._suppress_link = False + with ( - patch("hud.job"), - patch("hud.trace"), + 
patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), patch.object(ClaudeAgent, "_run_context", mock_run_context), patch.object(ClaudeAgent, "call_tools", mock_call_tools), patch("hud.clients.MCPClient", return_value=mock_client_instance), @@ -349,9 +356,16 @@ async def test_no_allowed_tools_keeps_all_tools_except_disallowed( "validate_api_key": False, } + # Create mock context + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_ctx.__aexit__ = AsyncMock(return_value=None) + mock_ctx._suppress_link = False + with ( - patch("hud.job"), - patch("hud.trace"), + patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), patch.object(ClaudeAgent, "_run_context", mock_run_context), patch.object(ClaudeAgent, "call_tools", mock_call_tools), patch("hud.clients.MCPClient", return_value=mock_client_instance), @@ -445,9 +459,16 @@ async def test_task_system_prompt_only( # Agent config with no custom system_prompt (will use default) agent_init_config = {"validate_api_key": False, "system_prompt": SYSTEM_PROMPT} + # Create mock context + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_ctx.__aexit__ = AsyncMock(return_value=None) + mock_ctx._suppress_link = False + with ( - patch("hud.job"), - patch("hud.trace"), + patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), patch.object(ClaudeAgent, "_run_context", mock_run_context), patch.object(ClaudeAgent, "call_tools", mock_call_tools), patch("hud.clients.MCPClient", return_value=mock_mcp_client), @@ -497,9 +518,16 @@ async def test_both_agent_and_task_system_prompts( "validate_api_key": False, } + # Create mock context + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_ctx.__aexit__ = AsyncMock(return_value=None) + mock_ctx._suppress_link = False + with ( - patch("hud.job"), - patch("hud.trace"), + patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), patch.object(ClaudeAgent, "_run_context", mock_run_context), patch.object(ClaudeAgent, "call_tools", mock_call_tools), patch("hud.clients.MCPClient", return_value=mock_mcp_client), diff --git a/hud/environment/connectors/openai.py b/hud/environment/connectors/openai.py index 5ce2df3b..4b08929a 100644 --- a/hud/environment/connectors/openai.py +++ b/hud/environment/connectors/openai.py @@ -7,15 +7,6 @@ __all__ = ["OpenAIConnectorMixin"] -# Lazy import check -try: - from agents import FunctionTool - - _HAS_OPENAI_AGENTS = True -except ImportError: - _HAS_OPENAI_AGENTS = False - FunctionTool = None # type: ignore[misc, assignment] - class OpenAIConnectorMixin: """Mixin providing OpenAI Agents SDK connector methods.""" @@ -60,14 +51,16 @@ def calculate(expression: str) -> float: Note: Requires `openai-agents`: pip install openai-agents """ - if not _HAS_OPENAI_AGENTS: + try: + from agents import FunctionTool + except ImportError as e: raise ImportError( "openai-agents is required for connect_function_tools. 
" "Install with: pip install openai-agents" - ) + ) from e for tool in tools: - if FunctionTool is not None and isinstance(tool, FunctionTool): + if isinstance(tool, FunctionTool): self._add_openai_function_tool(tool, prefix) return self diff --git a/hud/environment/integrations/anthropic.py b/hud/environment/integrations/anthropic.py index 584dea02..66f84b4f 100644 --- a/hud/environment/integrations/anthropic.py +++ b/hud/environment/integrations/anthropic.py @@ -5,15 +5,6 @@ import json from typing import TYPE_CHECKING, Any -# Try to import anthropic -try: - from anthropic.types.beta import BetaToolResultBlockParam - - _HAS_ANTHROPIC = True -except ImportError: - _HAS_ANTHROPIC = False - BetaToolResultBlockParam = None # type: ignore[misc, assignment] - if TYPE_CHECKING: import mcp.types as mcp_types @@ -150,9 +141,6 @@ def as_anthropic_runner(self) -> EnvToolRunner: results.append(result) ``` """ - if not _HAS_ANTHROPIC: - raise ImportError("Anthropic SDK not installed. Install with: pip install anthropic") - return EnvToolRunner(self) @@ -200,6 +188,9 @@ async def run(self, tool_use_block: Any) -> Any: } # Return typed object if anthropic is available - if _HAS_ANTHROPIC and BetaToolResultBlockParam is not None: + try: + from anthropic.types.beta import BetaToolResultBlockParam + return BetaToolResultBlockParam(**result_dict) - return result_dict + except ImportError: + return result_dict diff --git a/hud/environment/integrations/langchain.py b/hud/environment/integrations/langchain.py index 52bd10d6..f86e936f 100644 --- a/hud/environment/integrations/langchain.py +++ b/hud/environment/integrations/langchain.py @@ -7,15 +7,6 @@ from hud.environment.utils.schema import schema_to_pydantic -# Try to import langchain -try: - from langchain_core.tools import StructuredTool - - _HAS_LANGCHAIN = True -except ImportError: - _HAS_LANGCHAIN = False - StructuredTool = None # type: ignore[misc, assignment] - if TYPE_CHECKING: import mcp.types as mcp_types @@ -68,17 +59,21 @@ def as_langchain_tools(self) -> list[Any]: result = await executor.ainvoke({"input": "Navigate to google.com"}) ``` """ - if not _HAS_LANGCHAIN: - raise ImportError("LangChain not installed. Install with: pip install langchain-core") + try: + from langchain_core.tools import StructuredTool + except ImportError as e: + raise ImportError( + "LangChain not installed. 
Install with: pip install langchain-core" + ) from e tools = [] for t in self.as_tools(): - tool = _create_structured_tool(self, t) + tool = _create_structured_tool(self, t, StructuredTool) tools.append(tool) return tools -def _create_structured_tool(env: LangChainMixin, tool: mcp_types.Tool) -> Any: +def _create_structured_tool(env: LangChainMixin, tool: mcp_types.Tool, StructuredTool: type) -> Any: """Create a StructuredTool that calls back to the environment.""" import asyncio @@ -107,7 +102,6 @@ async def async_invoke(**kwargs: Any) -> str: return result return json.dumps(result) if result else "" - assert StructuredTool is not None # Checked in as_langchain_tools return StructuredTool( name=tool.name, description=tool.description or "", diff --git a/hud/environment/integrations/openai.py b/hud/environment/integrations/openai.py index 54893e1f..015e8ada 100644 --- a/hud/environment/integrations/openai.py +++ b/hud/environment/integrations/openai.py @@ -7,15 +7,6 @@ from hud.environment.utils.schema import ensure_strict_schema -# Try to import OpenAI Agents SDK -try: - from agents import FunctionTool - - _HAS_AGENTS = True -except ImportError: - _HAS_AGENTS = False - FunctionTool = None # type: ignore[misc, assignment] - if TYPE_CHECKING: import mcp.types as mcp_types @@ -167,19 +158,21 @@ def as_openai_agent_tools(self) -> list[Any]: print(result.final_output) ``` """ - if not _HAS_AGENTS: + try: + from agents import FunctionTool + except ImportError as e: raise ImportError( "OpenAI Agents SDK not installed. Install with: pip install openai-agents" - ) + ) from e tools = [] for t in self.as_tools(): - tool = _create_function_tool(self, t) + tool = _create_function_tool(self, t, FunctionTool) tools.append(tool) return tools -def _create_function_tool(env: OpenAIMixin, tool: mcp_types.Tool) -> Any: +def _create_function_tool(env: OpenAIMixin, tool: mcp_types.Tool, FunctionTool: type) -> Any: """Create a FunctionTool that calls back to the environment.""" schema = tool.inputSchema or {"type": "object", "properties": {}} @@ -191,7 +184,6 @@ async def async_wrapper(ctx: Any, args_json: str) -> str: return result return json.dumps(result) if result else "" - assert FunctionTool is not None # Checked in as_openai_agent_tools return FunctionTool( name=tool.name, description=tool.description or "", diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 69afb59b..1bbf5d42 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -157,13 +157,9 @@ async def test_run_dataset_empty(self): """Test running empty dataset.""" with ( patch("hud.clients.MCPClient"), - patch("hud.async_job") as mock_job_func, - patch("hud.async_trace") as mock_trace, + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), ): - mock_job_obj = MagicMock() - mock_job_obj.id = "job-empty" - mock_job_func.return_value.__aenter__.return_value = mock_job_obj - # Create a mock agent class with proper type from hud.agents import MCPAgent @@ -176,16 +172,16 @@ async def test_run_dataset_empty(self): ) assert results == [] - mock_trace.assert_not_called() @pytest.mark.asyncio async def test_run_dataset_with_metadata(self): """Test run_dataset with custom metadata.""" from hud.agents import MCPAgent + from hud.types import Trace # Create a proper mock agent class mock_agent_instance = AsyncMock() - mock_agent_instance.run.return_value = {"status": "complete"} + mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) 
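As context for the `_create_function_tool` change earlier in this patch (the sync event-loop wrapper replaced by an async `on_invoke_tool(ctx, args_json)` callable), here is a rough usage sketch based on the `as_openai_agent_tools` docstring in the same file; `Agent`, `Runner`, and `env` follow that docstring, and the agent name, instructions, and prompt strings are assumptions.

```python
from agents import Agent, Runner  # openai-agents SDK, per the ImportError hint above


async def run_with_agents_sdk(env) -> None:
    # Each returned FunctionTool now carries an async on_invoke_tool(ctx, args_json)
    # wrapper, so no thread-pool / asyncio.run juggling happens at call time.
    agent = Agent(
        name="hud-env-agent",                    # assumed name
        instructions="Use the provided tools.",  # assumed instructions
        tools=env.as_openai_agent_tools(),
    )
    result = await Runner.run(agent, "Navigate to google.com")  # prompt is illustrative
    print(result.final_output)
```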
mock_agent_class = type( "MockAgent", @@ -198,46 +194,30 @@ async def test_run_dataset_with_metadata(self): tasks = [{"prompt": "Task 1", "mcp_config": {"url": "test1"}}] - custom_metadata = { - "experiment_id": "exp-123", - "tags": ["test", "v2"], - "config": {"temperature": 0.7}, - } + # Mock EvalContext to avoid actual MCP connections + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_ctx.__aexit__ = AsyncMock(return_value=None) with ( - patch("hud.clients.MCPClient") as MockClient, - patch("hud.datasets.runner.async_job") as mock_job_func, - patch("hud.datasets.runner.async_trace") as mock_trace, + patch("hud.clients.MCPClient"), + patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), ): - mock_job = AsyncMock() - mock_job.id = "job-meta" - mock_job_func.return_value.__aenter__.return_value = mock_job - mock_trace.return_value.__aenter__.return_value = "trace-id" - - mock_client = AsyncMock() - MockClient.return_value = mock_client - + # Should run without error await run_dataset( "metadata_run", tasks, mock_agent_class, # type: ignore {"verbose": True}, - metadata=custom_metadata, ) - # Verify job was created with merged metadata - expected_metadata = { - "experiment_id": "exp-123", - "tags": ["test", "v2"], - "config": {"temperature": 0.7}, - "agent_config": {"verbose": True}, - } - - mock_job_func.assert_called_once_with("metadata_run", metadata=expected_metadata) - @pytest.mark.asyncio async def test_run_dataset_exception_handling(self): """Test exception handling during task execution.""" + from hud.types import Trace + # Track execution by task index executed_task_indices: set[int] = set() @@ -252,7 +232,7 @@ async def mock_run(task, **run_kwargs): if task_idx == 1: # Second task (index 1) should fail raise RuntimeError("Task 2 failed") - return {"result": f"success-{task_idx + 1}"} + return Trace(reward=1.0, done=True, content=f"success-{task_idx + 1}") agent.run = mock_run return agent @@ -264,19 +244,20 @@ async def mock_run(task, **run_kwargs): tasks = [{"prompt": f"Task {i}", "mcp_config": {"url": f"test{i}"}} for i in range(3)] + # Create mock contexts for each task + def create_mock_ctx(*args, **kwargs): + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=ctx) + ctx.__aexit__ = AsyncMock(return_value=None) + ctx._suppress_link = False + return ctx + with ( - patch("hud.clients.MCPClient") as MockClient, - patch("hud.async_job") as mock_job_func, - patch("hud.async_trace") as mock_trace, + patch("hud.clients.MCPClient"), + patch("hud.eval.context.EvalContext.from_task", side_effect=create_mock_ctx), + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), ): - mock_job = MagicMock() - mock_job.id = "job-error" - mock_job_func.return_value.__aenter__.return_value = mock_job - mock_trace.return_value.__aenter__.return_value = "trace-id" - - mock_client = AsyncMock() - MockClient.return_value = mock_client - # Should complete without raising results = await run_dataset("error_run", tasks, mock_agent_class) # type: ignore @@ -284,64 +265,47 @@ async def mock_run(task, **run_kwargs): assert len(executed_task_indices) == 3 assert executed_task_indices == {0, 1, 2} - # First and third should succeed - assert results[0] == {"result": "success-1"} - assert results[2] == {"result": "success-3"} # Second result should be None due to exception assert results[1] is None @pytest.mark.asyncio async def 
test_run_dataset_client_cleanup(self): - """Test that MCP clients are properly cleaned up.""" + """Test that run_dataset completes successfully.""" from hud.agents import MCPAgent - - # Track client instances - client_instances = [] - - def create_client(**kwargs): - client = AsyncMock() - client_instances.append(client) - return client - - # Mock agent that creates a client - def mock_agent_init(self, client=None, **kwargs): - if client is None: - # Create client if not provided - this simulates real agent behavior - from hud.clients import MCPClient - - self.client = MCPClient() # This will use our mocked version - else: - self.client = client + from hud.types import Trace mock_agent_instance = AsyncMock() - mock_agent_instance.run.return_value = {"done": True} + mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) mock_agent_class = type( "MockAgent", (MCPAgent,), { - "__init__": mock_agent_init, + "__init__": lambda self, **kwargs: None, "__new__": lambda cls, **kwargs: mock_agent_instance, }, ) tasks = [{"prompt": f"Task {i}", "mcp_config": {"url": f"test{i}"}} for i in range(3)] + # Create mock contexts + def create_mock_ctx(*args, **kwargs): + ctx = AsyncMock() + ctx.__aenter__ = AsyncMock(return_value=ctx) + ctx.__aexit__ = AsyncMock(return_value=None) + ctx._suppress_link = False + return ctx + with ( - patch("hud.clients.MCPClient", side_effect=create_client), - patch("hud.job") as mock_job_func, - patch("hud.trace") as mock_trace, + patch("hud.clients.MCPClient"), + patch("hud.eval.context.EvalContext.from_task", side_effect=create_mock_ctx), + patch("hud.eval.display.print_link"), + patch("hud.eval.display.print_complete"), ): - mock_job = MagicMock() - mock_job.id = "job-cleanup" - mock_job_func.return_value.__enter__.return_value = mock_job - mock_trace.return_value.__enter__.return_value = "trace-id" - - await run_dataset("cleanup_run", tasks, mock_agent_class) # type: ignore + results = await run_dataset("cleanup_run", tasks, mock_agent_class) # type: ignore - # Since agents might not create clients in our current implementation, - # just verify the test completes successfully - assert len(client_instances) >= 0 # Accept any number of clients created + # Verify results were returned + assert len(results) == 3 @pytest.mark.asyncio async def test_run_dataset_validation_error(self): diff --git a/hud/tests/test_init.py b/hud/tests/test_init.py index 8e7ecf4b..4c264405 100644 --- a/hud/tests/test_init.py +++ b/hud/tests/test_init.py @@ -41,12 +41,10 @@ def test_all_exports_available(self): import hud expected_exports = [ - "clear_trace", - "create_job", - "get_trace", + "Environment", + "EvalContext", + "eval", "instrument", - "job", - "trace", ] for export in expected_exports: diff --git a/hud/tests/test_init_module.py b/hud/tests/test_init_module.py index 6f76d52c..2fba8a0c 100644 --- a/hud/tests/test_init_module.py +++ b/hud/tests/test_init_module.py @@ -21,15 +21,10 @@ def test_all_exports(self): import hud expected = [ - "Trace", - "async_job", - "async_trace", - "clear_trace", - "create_job", - "get_trace", + "Environment", + "EvalContext", + "eval", "instrument", - "job", - "trace", ] assert set(hud.__all__) == set(expected) From 6d68ea2dea761a10266b83d6110347657629afd8 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 10:57:03 -0800 Subject: [PATCH 13/92] test adjustments --- hud/eval/tests/test_context.py | 10 +++++++--- hud/tests/test_settings.py | 7 ++++--- hud/tools/tests/test_apply_patch.py | 4 +++- hud/tools/tests/test_jupyter_tool.py | 
16 ++++++++-------- hud/utils/mcp.py | 12 ++++++++++-- 5 files changed, 32 insertions(+), 17 deletions(-) diff --git a/hud/eval/tests/test_context.py b/hud/eval/tests/test_context.py index 2f291feb..f749377a 100644 --- a/hud/eval/tests/test_context.py +++ b/hud/eval/tests/test_context.py @@ -129,18 +129,22 @@ class TestEvalContextFromEnvironment: """Tests for EvalContext.from_environment factory.""" def test_copies_connections(self) -> None: - """from_environment copies connections from parent.""" + """from_environment copies connections from parent (deep copy).""" from hud.environment import Environment parent = Environment("parent-env") - # Add a mock connection + # Add a mock connection with copy method mock_conn = MagicMock() + mock_conn_copy = MagicMock() + mock_conn.copy.return_value = mock_conn_copy parent._connections["test-conn"] = mock_conn ctx = EvalContext.from_environment(parent, name="test-task") + # Verify connection was copied (not same object) assert "test-conn" in ctx._connections - assert ctx._connections["test-conn"] is mock_conn + mock_conn.copy.assert_called_once() + assert ctx._connections["test-conn"] is mock_conn_copy def test_copies_prompt(self) -> None: """from_environment copies prompt from parent.""" diff --git a/hud/tests/test_settings.py b/hud/tests/test_settings.py index 538427cb..47a605ac 100644 --- a/hud/tests/test_settings.py +++ b/hud/tests/test_settings.py @@ -13,10 +13,11 @@ def test_get_settings(): def test_settings_defaults(): - """Test that settings have expected default values.""" + """Test that settings have expected default values or env overrides.""" s = get_settings() - assert s.hud_telemetry_url == "https://telemetry.hud.ai/v3/api" - assert s.hud_mcp_url == "https://mcp.hud.ai/v3/mcp" + # These URLs may be overridden by environment variables + assert s.hud_telemetry_url.endswith("/v3/api") + assert s.hud_mcp_url.endswith("/v3/mcp") # Default may be overridden in CI; just assert the field exists and is bool assert isinstance(s.telemetry_enabled, bool) assert s.hud_logging is True diff --git a/hud/tools/tests/test_apply_patch.py b/hud/tools/tests/test_apply_patch.py index d27263ea..87255f75 100644 --- a/hud/tools/tests/test_apply_patch.py +++ b/hud/tools/tests/test_apply_patch.py @@ -423,7 +423,9 @@ def test_validate_path_valid(self): with tempfile.TemporaryDirectory() as tmpdir: tool = ApplyPatchTool(base_path=tmpdir) result = tool._validate_path("subdir/file.txt") - assert result == os.path.join(tmpdir, "subdir/file.txt") + # Normalize path separators for cross-platform compatibility + expected = os.path.normpath(os.path.join(tmpdir, "subdir/file.txt")) + assert result == expected @pytest.mark.asyncio async def test_call_missing_type(self): diff --git a/hud/tools/tests/test_jupyter_tool.py b/hud/tools/tests/test_jupyter_tool.py index 02bd1c6b..3dec0025 100644 --- a/hud/tools/tests/test_jupyter_tool.py +++ b/hud/tools/tests/test_jupyter_tool.py @@ -82,9 +82,9 @@ async def test_connect_new_kernel(self): mock_client = MagicMock(fetch=AsyncMock(return_value=mock_response)) with ( - patch("hud.tools.jupyter.AsyncHTTPClient", return_value=mock_client), - patch("hud.tools.jupyter.websocket_connect", new_callable=AsyncMock), - patch("hud.tools.jupyter.PeriodicCallback"), + patch("tornado.httpclient.AsyncHTTPClient", return_value=mock_client), + patch("tornado.websocket.websocket_connect", new_callable=AsyncMock), + patch("tornado.ioloop.PeriodicCallback"), ): await tool._connect() assert tool._kernel_id == "new-kernel-123" @@ -94,9 +94,9 @@ async def 
test_connect_existing_kernel(self): """Test connecting to an existing kernel.""" tool = JupyterTool(kernel_id="existing-kernel-456") with ( - patch("hud.tools.jupyter.AsyncHTTPClient"), - patch("hud.tools.jupyter.websocket_connect", new_callable=AsyncMock), - patch("hud.tools.jupyter.PeriodicCallback"), + patch("tornado.httpclient.AsyncHTTPClient"), + patch("tornado.websocket.websocket_connect", new_callable=AsyncMock), + patch("tornado.ioloop.PeriodicCallback"), ): await tool._connect() assert tool._kernel_id == "existing-kernel-456" @@ -150,7 +150,7 @@ async def hang_forever(): with ( patch("hud.tools.jupyter.uuid4") as mock_uuid, - patch("hud.tools.jupyter.AsyncHTTPClient", return_value=mock_client), + patch("tornado.httpclient.AsyncHTTPClient", return_value=mock_client), ): mock_uuid.return_value.hex = "test-msg" result = await tool._execute("while True: pass", execution_timeout=1) @@ -164,7 +164,7 @@ async def test_shutdown(self): tool._ws = MagicMock() tool._heartbeat_callback = MagicMock() - with patch("hud.tools.jupyter.AsyncHTTPClient"): + with patch("tornado.httpclient.AsyncHTTPClient"): await tool.shutdown() assert tool._kernel_id == "" assert tool._ws is None diff --git a/hud/utils/mcp.py b/hud/utils/mcp.py index c42f1346..87c35fae 100644 --- a/hud/utils/mcp.py +++ b/hud/utils/mcp.py @@ -20,11 +20,19 @@ class MCPConfigPatch(BaseModel): def _is_hud_server(url: str) -> bool: """Check if a URL is a HUD MCP server. - Matches any mcp.hud.* domain (including .ai, .so, and future domains). + Matches: + - Any mcp.hud.* domain (including .ai, .so, and future domains) + - Staging servers (orcstaging.hud.so) + - Any *.hud.ai or *.hud.so domain """ if not url: return False - return "mcp.hud." in url.lower() + url_lower = url.lower() + return ( + "mcp.hud." 
in url_lower + or ".hud.ai" in url_lower + or ".hud.so" in url_lower + ) def patch_mcp_config(mcp_config: dict[str, dict[str, Any]], patch: MCPConfigPatch) -> None: From 8a508648bcb3fddb8a0b0c21fda69f0329b5b541 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 10:59:49 -0800 Subject: [PATCH 14/92] mock tests --- hud/cli/tests/test_convert.py | 1 + hud/tests/test_datasets_extended.py | 2 +- hud/utils/tests/test_tasks.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/hud/cli/tests/test_convert.py b/hud/cli/tests/test_convert.py index 02c66481..7b8f3d04 100644 --- a/hud/cli/tests/test_convert.py +++ b/hud/cli/tests/test_convert.py @@ -70,6 +70,7 @@ def test_convert_tasks_basic( """Test basic task conversion from local to remote.""" # Setup mocks mock_settings.api_key = "test-api-key" + mock_settings.hud_mcp_url = "https://mcp.hud.ai/v3/mcp" mock_find_env.return_value = mock_env_dir # Mock the push check to return updated lock data diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 1bbf5d42..264771a1 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -128,7 +128,7 @@ class TestDatasetOperations: def test_save_taskconfigs_empty_list(self): """Test saving empty task list.""" - with patch("hud.utils.tasks.Dataset") as MockDataset: + with patch("datasets.Dataset") as MockDataset: mock_instance = MagicMock() MockDataset.from_list.return_value = mock_instance mock_instance.push_to_hub.return_value = None diff --git a/hud/utils/tests/test_tasks.py b/hud/utils/tests/test_tasks.py index 8a038e43..9979e752 100644 --- a/hud/utils/tests/test_tasks.py +++ b/hud/utils/tests/test_tasks.py @@ -195,7 +195,7 @@ def test_save_tasks_basic(): {"id": "2", "prompt": "test2", "mcp_config": {"key2": "value2"}}, ] - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = MagicMock() mock_dataset_class.from_list.return_value = mock_dataset @@ -215,7 +215,7 @@ def test_save_tasks_with_specific_fields(): {"id": "1", "prompt": "test", "mcp_config": {"key": "value"}, "extra": "data"}, ] - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = MagicMock() mock_dataset_class.from_list.return_value = mock_dataset @@ -233,7 +233,7 @@ def test_save_tasks_with_list_field(): {"id": "1", "tags": ["tag1", "tag2"], "count": 5}, ] - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = MagicMock() mock_dataset_class.from_list.return_value = mock_dataset @@ -257,7 +257,7 @@ def test_save_tasks_with_primitive_types(): }, ] - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = MagicMock() mock_dataset_class.from_list.return_value = mock_dataset @@ -282,7 +282,7 @@ def __str__(self): {"id": "1", "custom": CustomObj()}, ] - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = MagicMock() mock_dataset_class.from_list.return_value = mock_dataset @@ -315,7 +315,7 @@ def test_save_tasks_with_kwargs(): """Test save_tasks passes kwargs to push_to_hub.""" tasks = [{"id": "1", "prompt": "test"}] - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = 
MagicMock() mock_dataset_class.from_list.return_value = mock_dataset @@ -332,7 +332,7 @@ def test_save_tasks_field_not_in_dict(): {"id": "1", "prompt": "test"}, ] - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = MagicMock() mock_dataset_class.from_list.return_value = mock_dataset @@ -346,7 +346,7 @@ def test_save_tasks_field_not_in_dict(): def test_save_tasks_empty_list(): """Test save_tasks with empty list.""" - with patch("hud.utils.tasks.Dataset") as mock_dataset_class: + with patch("datasets.Dataset") as mock_dataset_class: mock_dataset = MagicMock() mock_dataset_class.from_list.return_value = mock_dataset From e7cd40acb2d9a4f463be6f12022c5fdfdc9f3021 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 11:00:15 -0800 Subject: [PATCH 15/92] small adjustment --- hud/cli/tests/test_convert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hud/cli/tests/test_convert.py b/hud/cli/tests/test_convert.py index 7b8f3d04..cbdb6c8b 100644 --- a/hud/cli/tests/test_convert.py +++ b/hud/cli/tests/test_convert.py @@ -190,6 +190,7 @@ def test_convert_with_env_vars( ): """Test conversion includes environment variables as headers.""" mock_settings.api_key = "test-api-key" + mock_settings.hud_mcp_url = "https://mcp.hud.ai/v3/mcp" mock_find_env.return_value = mock_env_dir mock_confirm.return_value = True # Always confirm in tests From 3f0243972e74883b0b8eea245d5407ef5031c1d2 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Tue, 9 Dec 2025 11:09:38 -0800 Subject: [PATCH 16/92] fix tests --- .github/workflows/ci.yml | 2 +- hud/utils/mcp.py | 6 +----- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 61d713d0..b070add5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,4 +59,4 @@ jobs: uses: astral-sh/setup-uv@v5 - name: Run pyright - run: uv run --with=".[rl,dev]" pyright + run: uv run --with=".[dev]" pyright diff --git a/hud/utils/mcp.py b/hud/utils/mcp.py index 87c35fae..e9335d54 100644 --- a/hud/utils/mcp.py +++ b/hud/utils/mcp.py @@ -28,11 +28,7 @@ def _is_hud_server(url: str) -> bool: if not url: return False url_lower = url.lower() - return ( - "mcp.hud." in url_lower - or ".hud.ai" in url_lower - or ".hud.so" in url_lower - ) + return "mcp.hud." 
in url_lower or ".hud.ai" in url_lower or ".hud.so" in url_lower def patch_mcp_config(mcp_config: dict[str, dict[str, Any]], patch: MCPConfigPatch) -> None: diff --git a/pyproject.toml b/pyproject.toml index 5f788d2b..cf1742d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "watchfiles>=0.21.0", "questionary==2.1.0", "prompt-toolkit==3.0.51", # Locked for questionary compatibility - # Terminal library with mouse support for JSON viewer "blessed>=1.20.0", "scarf-sdk>=0.1.0", ] @@ -115,6 +114,7 @@ agents = [ "anthropic>=0.75", "openai>=2.8.1", "google-genai", + "openai-agents", # Dataset loading (HuggingFace) "datasets>=2.14.0", # Telemetry / OpenTelemetry tracing From 8316abd1fc71c7cfc972b81ed33e5c53974e2d2a Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 07:51:37 -0800 Subject: [PATCH 17/92] rewrite hud.eval --- hud/cli/__init__.py | 35 ++- hud/cli/build.py | 49 +++- hud/cli/dev.py | 7 +- hud/cli/flows/init.py | 154 +++++++++++ hud/cli/flows/templates.py | 138 ++++++++++ hud/cli/utils/docker.py | 10 +- hud/cli/utils/source_hash.py | 4 +- hud/environment/__init__.py | 2 + hud/environment/connectors/remote.py | 21 +- hud/environment/environment.py | 100 ++++++- hud/environment/scripts.py | 206 ++++++++++++++ hud/environment/tests/test_scripts.py | 162 +++++++++++ hud/eval/__init__.py | 33 ++- hud/eval/context.py | 11 +- hud/eval/eval.py | 247 +++++++++++++++++ hud/eval/manager.py | 323 ++++++++++++++++------ hud/eval/mixin.py | 380 -------------------------- hud/eval/parallel.py | 104 +------ hud/eval/tests/test_eval.py | 236 ++++++++++++++++ hud/eval/tests/test_manager.py | 133 +++++++++ hud/eval/tests/test_mixin.py | 129 --------- hud/eval/tests/test_parallel.py | 65 ----- hud/types.py | 18 ++ 23 files changed, 1759 insertions(+), 808 deletions(-) create mode 100644 hud/cli/flows/init.py create mode 100644 hud/cli/flows/templates.py create mode 100644 hud/environment/scripts.py create mode 100644 hud/environment/tests/test_scripts.py create mode 100644 hud/eval/eval.py delete mode 100644 hud/eval/mixin.py create mode 100644 hud/eval/tests/test_eval.py create mode 100644 hud/eval/tests/test_manager.py delete mode 100644 hud/eval/tests/test_mixin.py diff --git a/hud/cli/__init__.py b/hud/cli/__init__.py index 96b4fce4..8d3c0e4b 100644 --- a/hud/cli/__init__.py +++ b/hud/cli/__init__.py @@ -26,7 +26,6 @@ from .debug import debug_mcp_stdio from .dev import run_mcp_dev_server from .eval import eval_command -from .init import create_environment from .pull import pull_command from .push import push_command from .remove import remove_command @@ -889,31 +888,37 @@ def remove( @app.command() def init( - name: str = typer.Argument(None, help="Environment name (default: chosen preset name)"), + name: str = typer.Argument(None, help="Environment name (default: directory name)"), + directory: str = typer.Option(".", "--dir", "-d", help="Target directory"), + force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"), preset: str | None = typer.Option( None, "--preset", "-p", - help="Preset to use: blank, deep-research, browser, rubrics. If omitted, you'll choose interactively.", # noqa: E501 + help="Download a preset: blank, deep-research, browser, rubrics", ), - directory: str = typer.Option(".", "--dir", "-d", help="Parent directory for the environment"), - force: bool = typer.Option(False, "--force", "-f", help="Overwrite existing files"), ) -> None: - """🚀 Initialize a new HUD environment with minimal boilerplate. 
+ """🚀 Initialize a HUD environment. + + [not dim]• Empty directory: Choose a preset interactively + • Existing project: Add Dockerfile.hud and hud.py - [not dim]Creates a working MCP environment with: - - Dockerfile for containerization - - pyproject.toml for dependencies - - Minimal MCP server with context - - Required setup/evaluate tools + Use --preset to skip selection and download a specific template. Examples: - hud init # Choose preset interactively, create ./preset-name/ - hud init my-env # Create new directory ./my-env/ - hud init my-env --dir /tmp # Create in /tmp/my-env/[/not dim] + hud init # Auto-detect mode + hud init my-env # Initialize with custom name + hud init --preset browser # Download browser preset[/not dim] """ - create_environment(name, directory, force, preset) + if preset: + from hud.cli.init import create_environment + + create_environment(name, directory, force, preset) + else: + from hud.cli.flows.init import smart_init + + smart_init(name, directory, force) @app.command() diff --git a/hud/cli/build.py b/hud/cli/build.py index b7f169a1..e191f0bc 100644 --- a/hud/cli/build.py +++ b/hud/cli/build.py @@ -25,6 +25,30 @@ from .utils.registry import save_to_registry +def find_dockerfile(directory: Path) -> Path | None: + """Find the Dockerfile in a directory, preferring Dockerfile.hud. + + Checks for Dockerfile.hud first (HUD-specific), then falls back to Dockerfile. + + Args: + directory: Directory to search in + + Returns: + Path to the Dockerfile if found, None otherwise + """ + # Prefer Dockerfile.hud for HUD environments + hud_dockerfile = directory / "Dockerfile.hud" + if hud_dockerfile.exists(): + return hud_dockerfile + + # Fall back to standard Dockerfile + standard_dockerfile = directory / "Dockerfile" + if standard_dockerfile.exists(): + return standard_dockerfile + + return None + + def parse_version(version_str: str) -> tuple[int, int, int]: """Parse version string like '1.0.0' or '1.0' into tuple of integers.""" # Remove 'v' prefix if present @@ -530,16 +554,21 @@ def build_docker_image( hud_console = HUDConsole() build_args = build_args or {} - # Check if Dockerfile exists - dockerfile = directory / "Dockerfile" - if not dockerfile.exists(): + # Check if Dockerfile exists (prefer Dockerfile.hud) + dockerfile = find_dockerfile(directory) + if dockerfile is None: hud_console.error(f"No Dockerfile found in {directory}") + hud_console.info("Expected: Dockerfile.hud or Dockerfile") return False # Build command - use buildx when remote cache is enabled effective_platform = platform if platform is not None else "linux/amd64" cmd = ["docker", "buildx", "build"] if remote_cache else ["docker", "build"] + # Specify dockerfile explicitly if not the default name + if dockerfile.name != "Dockerfile": + cmd.extend(["-f", str(dockerfile)]) + if effective_platform: cmd.extend(["--platform", effective_platform]) cmd.extend(["-t", tag]) @@ -653,15 +682,17 @@ def build_environment( # Step 2: If no lock, check for Dockerfile if not base_name: - dockerfile_path = env_dir / "Dockerfile" - if not dockerfile_path.exists(): + dockerfile_path = find_dockerfile(env_dir) + if dockerfile_path is None: hud_console.error(f"Not a valid environment directory: {directory}") - hud_console.info("Expected: Dockerfile or hud.lock.yaml") + hud_console.info("Expected: Dockerfile.hud, Dockerfile, or hud.lock.yaml") raise typer.Exit(1) # First build - use directory name base_name = env_dir.name hud_console.info(f"First build - using base name: {base_name}") + if dockerfile_path.name == 
"Dockerfile.hud": + hud_console.info("Using Dockerfile.hud") # If user provides --tag, respect it; otherwise use base name only (version added later) if tag: @@ -725,7 +756,7 @@ def build_environment( hud_console.success(tool_msg) # Extract environment variables from Dockerfile - dockerfile_path = env_dir / "Dockerfile" + dockerfile_path = find_dockerfile(env_dir) or env_dir / "Dockerfile" required_env, optional_env = extract_env_vars_from_dockerfile(dockerfile_path) # Show env vars detected from .env file @@ -885,6 +916,10 @@ def build_environment( # Build command - use buildx when remote cache is enabled label_cmd = ["docker", "buildx", "build"] if remote_cache else ["docker", "build"] + # Specify dockerfile explicitly if not the default name + if dockerfile_path and dockerfile_path.name != "Dockerfile": + label_cmd.extend(["-f", str(dockerfile_path)]) + # Use same defaulting for the second build step label_platform = platform if platform is not None else "linux/amd64" if label_platform: diff --git a/hud/cli/dev.py b/hud/cli/dev.py index 38bccda1..74f39c8b 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -136,8 +136,11 @@ def auto_detect_module() -> tuple[str, Path | None] | tuple[None, None]: def should_use_docker_mode(cwd: Path) -> bool: - """Check if environment requires Docker mode (has Dockerfile in current dir).""" - return (cwd / "Dockerfile").exists() + """Check if environment requires Docker mode (has Dockerfile in current dir). + + Checks for Dockerfile.hud first (HUD-specific), then falls back to Dockerfile. + """ + return (cwd / "Dockerfile.hud").exists() or (cwd / "Dockerfile").exists() async def run_mcp_module( diff --git a/hud/cli/flows/init.py b/hud/cli/flows/init.py new file mode 100644 index 00000000..205dbac5 --- /dev/null +++ b/hud/cli/flows/init.py @@ -0,0 +1,154 @@ +"""Smart HUD environment initialization.""" + +from __future__ import annotations + +import subprocess +from pathlib import Path + +from hud.utils.hud_console import HUDConsole + +from .templates import DOCKERFILE_HUD, HUD_PY, PYPROJECT_TOML + +# Files that indicate this might be an existing project +PROJECT_INDICATORS = { + "pyproject.toml", + "package.json", + "requirements.txt", + "setup.py", + "Cargo.toml", + "go.mod", +} + + +def _normalize_name(name: str) -> str: + """Normalize name for Python identifiers.""" + name = name.replace("-", "_").replace(" ", "_") + return "".join(c if c.isalnum() or c == "_" else "_" for c in name) + + +def _add_hud_dependency(directory: Path) -> bool: + """Add hud-python using uv if available.""" + try: + result = subprocess.run( + ["uv", "add", "hud-python", "openai"], # noqa: S607 + capture_output=True, + text=True, + cwd=directory, + check=False, + ) + return result.returncode == 0 or "already" in result.stderr.lower() + except FileNotFoundError: + return False + + +def _is_empty_or_trivial(directory: Path) -> bool: + """Check if directory is empty or only has trivial files.""" + if not directory.exists(): + return True + files = list(directory.iterdir()) + # Empty + if not files: + return True + # Only has hidden files or common trivial files + trivial = {".git", ".gitignore", ".DS_Store", "README.md", "LICENSE"} + return all(f.name in trivial or f.name.startswith(".") for f in files) + + +def _has_project_files(directory: Path) -> bool: + """Check if directory has files indicating an existing project.""" + if not directory.exists(): + return False + return any(f.name in PROJECT_INDICATORS for f in directory.iterdir()) + + +def smart_init( + name: str | None = 
None, + directory: str = ".", + force: bool = False, +) -> None: + """Initialize HUD environment files in a directory. + + - If directory is empty: delegate to preset selection + - If directory has project files: add HUD files to existing project + - Otherwise: create new HUD environment + """ + hud_console = HUDConsole() + target = Path(directory).resolve() + + # If directory is empty, use preset selection + if _is_empty_or_trivial(target): + from hud.cli.init import create_environment + + hud_console.info("Empty directory - showing preset selection") + create_environment(name, directory, force, preset=None) + return + + # Directory has files - use smart init + target.mkdir(parents=True, exist_ok=True) + env_name = _normalize_name(name or target.name) + has_pyproject = (target / "pyproject.toml").exists() + + hud_console.header(f"HUD Init: {env_name}") + + if has_pyproject: + hud_console.info("Found pyproject.toml - adding HUD files") + else: + hud_console.info("Creating HUD environment in existing directory") + + created = [] + + # Create pyproject.toml if needed + if not has_pyproject: + pyproject = target / "pyproject.toml" + pyproject.write_text(PYPROJECT_TOML.format(name=env_name.replace("_", "-"))) + created.append("pyproject.toml") + + # Create Dockerfile.hud + dockerfile = target / "Dockerfile.hud" + if not dockerfile.exists() or force: + dockerfile.write_text(DOCKERFILE_HUD) + created.append("Dockerfile.hud") + else: + hud_console.warning("Dockerfile.hud exists, skipping (use --force)") + + # Create hud.py + hud_py = target / "hud.py" + if not hud_py.exists() or force: + hud_py.write_text(HUD_PY.format(env_name=env_name)) + created.append("hud.py") + else: + hud_console.warning("hud.py exists, skipping (use --force)") + + # Add dependency + if _add_hud_dependency(target): + hud_console.success("Added hud-python dependency") + else: + hud_console.info("Run manually: uv add hud-python openai") + + # Summary + if created: + hud_console.section_title("Created") + for f in created: + hud_console.status_item(f, "✓") + + hud_console.section_title("Next Steps") + hud_console.info("1. Edit hud.py:") + hud_console.info(" - Add your tools with @env.tool()") + hud_console.info(" - Connect existing servers (FastAPI, MCP, OpenAPI)") + hud_console.info("") + hud_console.info("2. 
Edit Dockerfile.hud:") + hud_console.info(" - Add system dependencies (apt-get install)") + hud_console.info(" - Set up data sources for production") + hud_console.info("") + hud_console.command_example("python hud.py", "Test locally") + hud_console.command_example("hud dev hud:env", "Development server") + hud_console.command_example("hud build", "Build Docker image") + hud_console.info("") + hud_console.section_title("Tips") + hud_console.info("• For production environments you want to mock locally,") + hud_console.info(" configure data sources in Dockerfile.hud before deploying") + hud_console.info("• For testing without real connections, use env.mock()") + hud_console.info("• See hud.py DEPLOYMENT section for remote deployment") + + +__all__ = ["smart_init"] diff --git a/hud/cli/flows/templates.py b/hud/cli/flows/templates.py new file mode 100644 index 00000000..9d22f146 --- /dev/null +++ b/hud/cli/flows/templates.py @@ -0,0 +1,138 @@ +"""Templates for hud init command.""" + +DOCKERFILE_HUD = """\ +FROM python:3.11-slim + +RUN apt-get update && apt-get install -y --no-install-recommends curl \\ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app +COPY pyproject.toml uv.lock* ./ +RUN pip install uv && uv sync --frozen --no-dev 2>/dev/null || uv sync --no-dev +COPY . . + +CMD ["uv", "run", "python", "-m", "hud", "dev", "hud:env", "--stdio"] +""" + +# fmt: off +HUD_PY = '''\ +"""{env_name} - HUD Environment""" + +import asyncio +import os + +from hud.environment import Environment + +env = Environment("{env_name}") + + +# ============================================================================= +# 1. ADD FUNCTIONS AS TOOLS +# ============================================================================= +# Decorate any function with @env.tool() to expose it as a tool. + +@env.tool() +def hud(query: str) -> str: + """A tool that returns the answer to any question.""" + return f"Oh, I know the answer to '{{query}}', it's 42." + + +# ============================================================================= +# 2. IMPORT FROM EXISTING SERVERS +# ============================================================================= + +# --- FastAPI app --- +# from my_app import app +# env.connect_fastapi(app) + +# --- FastMCP / MCPServer --- +# from my_server import mcp +# env.connect_server(mcp) + +# --- OpenAPI spec (URL or file path) --- +# env.connect_openapi("https://api.example.com/openapi.json") + + +# ============================================================================= +# 3. CONNECT REMOTE SERVERS +# ============================================================================= + +# --- MCP config (stdio or SSE) --- +# env.connect_mcp_config({{ +# "my-server": {{"command": "uvx", "args": ["some-mcp-server"]}} +# }}) + +# --- HUD hub (requires deployment, see below) --- +# env.connect_hub("my-org/my-env", prefix="remote") + + +# ============================================================================= +# TEST - Run with: python hud.py +# ============================================================================= + +async def test(): + from openai import AsyncOpenAI + + async with env.task("test") as ctx: + # 1. List tools + tools = await env.list_tools() + print(f"Tools: {{[t.name for t in tools]}}") + + # 2. Call the hud tool + result = await env.call_tool("hud", query="What is HUD?") + print(f"HUD result: {{result}}") + + # 3. 
Call inference.hud.ai + client = AsyncOpenAI( + base_url="https://inference.hud.ai/v1", + api_key=os.environ.get("HUD_API_KEY", ""), + ) + response = await client.chat.completions.create( + model="claude-sonnet-4-5", + messages=[{{"role": "user", "content": "Say hello in one word."}}], + ) + print(f"LLM: {{response.choices[0].message.content}}") + + # 4. Assign reward + ctx.reward = 1.0 if "42" in str(result) else 0.0 + print(f"Reward: {{ctx.reward}}") + + +if __name__ == "__main__": + asyncio.run(test()) + + +# ============================================================================= +# DEPLOYMENT +# ============================================================================= +# To deploy this environment on HUD: +# +# 1. Push this repo to GitHub +# 2. Go to hud.ai -> New -> Environment +# 3. Choose "From GitHub URL" and paste your repo URL +# 4. This deploys the environment for remote connection +# +# Once deployed, connect to it from other environments: +# env.connect_hub("{env_name}") +# +# Remote deployment enables: +# - Parallelized evaluations (run many agents simultaneously) +# - Training data collection at scale +# - Shared environments across team members +# +# Note: The test() function above is just for local testing. +# It's not required for the deployed environment. +''' +# fmt: on + +PYPROJECT_TOML = """\ +[project] +name = "{name}" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = ["hud-python", "openai"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" +""" diff --git a/hud/cli/utils/docker.py b/hud/cli/utils/docker.py index 8ef850af..27cbd54b 100644 --- a/hud/cli/utils/docker.py +++ b/hud/cli/utils/docker.py @@ -121,7 +121,7 @@ def detect_environment_dir(start_dir: Path | None = None) -> Path | None: - Current directory containing `hud.lock.yaml` - Parent directory containing `hud.lock.yaml` - Current directory that looks like an environment if it has either a - `Dockerfile` or a `pyproject.toml` (looser than `is_environment_directory`). + `Dockerfile.hud`, `Dockerfile`, or a `pyproject.toml` (looser than `is_environment_directory`). Returns the detected directory path or None if not found. """ @@ -132,8 +132,12 @@ def detect_environment_dir(start_dir: Path | None = None) -> Path | None: if (candidate / "hud.lock.yaml").exists(): return candidate - # Fallback: treat as env if it has Dockerfile OR pyproject.toml - if (base / "Dockerfile").exists() or (base / "pyproject.toml").exists(): + # Fallback: treat as env if it has Dockerfile.hud, Dockerfile, or pyproject.toml + if ( + (base / "Dockerfile.hud").exists() + or (base / "Dockerfile").exists() + or (base / "pyproject.toml").exists() + ): return base return None diff --git a/hud/cli/utils/source_hash.py b/hud/cli/utils/source_hash.py index 71af3bfc..22123396 100644 --- a/hud/cli/utils/source_hash.py +++ b/hud/cli/utils/source_hash.py @@ -1,7 +1,7 @@ """Utilities to compute a fast, deterministic source hash for environments. 
This intentionally focuses on the typical HUD environment layout and aims to be fast: -- Always include: Dockerfile, pyproject.toml +- Always include: Dockerfile.hud, Dockerfile, pyproject.toml - Include directories: controller/, environment/, src/ - Exclude common build/runtime caches and lock files @@ -40,7 +40,7 @@ "hud.lock.yaml", } -INCLUDE_FILES = {"Dockerfile", "pyproject.toml"} +INCLUDE_FILES = {"Dockerfile", "Dockerfile.hud", "pyproject.toml"} INCLUDE_DIRS = {"server", "mcp", "controller", "environment"} diff --git a/hud/environment/__init__.py b/hud/environment/__init__.py index 3bd64c39..1746606d 100644 --- a/hud/environment/__init__.py +++ b/hud/environment/__init__.py @@ -28,6 +28,7 @@ from hud.environment.environment import Environment from hud.environment.mock import MockMixin, generate_mock_value from hud.environment.router import ConflictResolution, ToolRouter +from hud.environment.scripts import ScriptMixin from hud.environment.types import EnvConfig, HubConfig from hud.environment.utils import ToolFormat, format_result, parse_tool_call, parse_tool_calls @@ -40,6 +41,7 @@ "Environment", "HubConfig", "MockMixin", + "ScriptMixin", "ToolFormat", "ToolRouter", "format_result", diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py index f7500d64..0b2e66c3 100644 --- a/hud/environment/connectors/remote.py +++ b/hud/environment/connectors/remote.py @@ -12,6 +12,8 @@ from fastmcp.tools.tool import Tool + from hud.environment.types import HubConfig + __all__ = ["RemoteConnectorMixin"] logger = logging.getLogger(__name__) @@ -21,7 +23,7 @@ class RemoteConnectorMixin(MCPConfigConnectorMixin): """Mixin providing remote connection methods.""" # Store hub configs for trace serialization - _hub_configs: list[dict[str, Any]] + _hub_configs: list[HubConfig] def mount(self, server: Any, *, prefix: str | None = None) -> None: raise NotImplementedError @@ -51,18 +53,17 @@ def connect_hub( """ import httpx + from hud.environment.types import HubConfig from hud.settings import settings # Store hub config for trace serialization - hub_config: dict[str, Any] = {"slug": slug} - if alias: - hub_config["alias"] = alias - if prefix: - hub_config["prefix"] = prefix - if include: - hub_config["include"] = include - if exclude: - hub_config["exclude"] = exclude + hub_config = HubConfig( + slug=slug, + alias=alias, + prefix=prefix, + include=include, + exclude=exclude, + ) if not hasattr(self, "_hub_configs"): self._hub_configs = [] diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 8509b6ca..482034da 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -13,7 +13,7 @@ from hud.environment.integrations import IntegrationsMixin from hud.environment.mock import MockMixin from hud.environment.router import ConflictResolution, ToolRouter -from hud.eval.mixin import EvalMixin +from hud.environment.scripts import ScriptMixin from hud.server.server import MCPServer from hud.types import MCPToolResult @@ -21,6 +21,7 @@ import types from hud.environment.connection import Connector + from hud.eval.eval import Eval __all__ = ["Environment"] @@ -34,7 +35,7 @@ class Environment( ConnectorsMixin, IntegrationsMixin, MockMixin, - EvalMixin, + ScriptMixin, MCPServer, ): """Unified MCP environment that acts as both server and client. 
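# A minimal sketch of the typed hub-config flow introduced by the connect_hub()
# change above and the _get_env_config()/from_config() hunks below, assuming
# HubConfig (imported from hud.environment.types, not shown in this diff) is a
# pydantic model whose fields mirror the connect_hub() keyword arguments:
#
#     env = Environment("my-env").connect_hub("my-org/browser", prefix="web")
#     env._hub_configs[0]             # HubConfig(slug="my-org/browser", prefix="web", ...)
#     config = env._get_env_config()  # {"name": ..., "hubs": [h.model_dump() for h in ...], ...}
#     Environment.from_config(config) # rebuilds the environment and reconnects the hubs
#
# The exact HubConfig defaults and the full shape of model_dump() output are
# assumptions; only the constructor arguments appear in this patch.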
@@ -140,6 +141,9 @@ def __init__( # Initialize mock state self._init_mock() + # Initialize script state + self._init_scripts() + # ========================================================================= # Core Methods # ========================================================================= @@ -483,7 +487,7 @@ def _get_env_config(self) -> dict[str, Any] | None: return { "name": self.name, - "hubs": hub_configs, + "hubs": [h.model_dump() for h in hub_configs], "setup_tools": setup_tools, "evaluate_tools": evaluate_tools, } @@ -517,3 +521,93 @@ def _all_hubs(self) -> bool: def __repr__(self) -> str: return f"Environment({self.name!r}, connections={list(self._connections.keys())})" + + # ========================================================================= + # Eval Creation + # ========================================================================= + + def __call__(self, script: str | None = None, **args: Any) -> Eval: + """Create an Eval from this environment. + + Returns an Eval that can be entered as a context manager or passed + to hud.eval() for orchestration. + + Args: + script: Optional script name to run (from @env.script) + **args: Arguments for the script + + Returns: + Eval: A runnable evaluation unit + + Example: + ```python + env = Environment("my-env").connect_hub("browser") + + @env.script() + async def checkout(user_id: str): + yield "Complete checkout" + yield 1.0 + + # Simple use - Eval is context manager + async with env("checkout", user_id="alice") as ctx: + await agent.run(ctx.prompt) + + # Empty - just env + async with env() as ctx: + await ctx.call_tool("navigate", url="...") + + # Orchestrated via hud.eval + evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: + ... + ``` + """ + from hud.eval.eval import Eval + + return Eval( + env_config=self._get_env_config(), + script=script, + args=args, + ) + + @classmethod + def from_config(cls, config: dict[str, Any] | None) -> Environment: + """Create an Environment from a configuration dict. 
+ + Args: + config: EnvConfig-compatible dict with: + - name: Environment name + - hubs: List of hub configs (HubConfig dicts) + - setup_tools: Tools to run after connection + - evaluate_tools: Tools to run before disconnection + + Returns: + Environment: Configured environment instance + """ + if config is None: + return cls("eval") + + env = cls(name=config.get("name", "eval")) + + # Connect hubs + for hub in config.get("hubs", []): + if isinstance(hub, dict): + env.connect_hub( + hub.get("slug", ""), + alias=hub.get("alias"), + prefix=hub.get("prefix"), + include=hub.get("include"), + exclude=hub.get("exclude"), + ) + + # Add setup tools + for tool in config.get("setup_tools", []): + if isinstance(tool, dict): + env.setup_tool(tool.get("name", ""), **(tool.get("arguments") or {})) + + # Add evaluate tools + for tool in config.get("evaluate_tools", []): + if isinstance(tool, dict): + env.evaluate_tool(tool.get("name", ""), **(tool.get("arguments") or {})) + + return env diff --git a/hud/environment/scripts.py b/hud/environment/scripts.py new file mode 100644 index 00000000..00ebd88e --- /dev/null +++ b/hud/environment/scripts.py @@ -0,0 +1,206 @@ +"""Script decorator for Environment - defines setup/evaluate phases.""" + +from __future__ import annotations + +import inspect +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Callable + + from fastmcp.prompts import PromptManager + from fastmcp.resources import ResourceManager + +__all__ = ["ScriptMixin"] + +logger = logging.getLogger(__name__) + + +class ScriptMixin: + """Mixin providing @env.script decorator for setup/evaluate phases. + + Scripts are async generators that yield twice: + - First yield: prompt string (setup phase) + - Second yield: reward float (evaluate phase) + + The decorator registers both an MCP prompt and resource with the same + identifier (script:{name}), linked by session state. + + Example: + @env.script() + async def search_cats(url: str): + await env.call_tool("navigate", url=url) + yield "Find all cat images on the page" + result = await env.call_tool("count_cats") + yield float(result > 0) + """ + + # These come from Environment/MCPServer + name: str + _prompt_manager: PromptManager + _resource_manager: ResourceManager + + # Script state + _scripts: dict[str, Callable[..., AsyncGenerator[Any, None]]] + _script_sessions: dict[str, AsyncGenerator[Any, None]] # session_id -> generator + _script_latest: dict[str, str] # script_name -> latest session_id + + def _init_scripts(self) -> None: + """Initialize script state. Called from Environment.__init__.""" + self._scripts = {} + self._script_sessions = {} + self._script_latest = {} + + def script( + self, + name: str | None = None, + description: str | None = None, + ) -> Callable[ + [Callable[..., AsyncGenerator[Any, None]]], + Callable[..., AsyncGenerator[Any, None]], + ]: + """Decorator to register a script with setup and evaluate phases. + + Creates both a prompt and resource with identifier script:{name}. 
+ The script function should yield twice: + - First yield: the prompt string (returned from prompt) + - Second yield: the reward float (returned from resource) + + Args: + name: Optional name for the script (defaults to function name) + description: Optional description of what the script does + + Example: + @env.script() + async def search_cats(url: str): + await env.call_tool("navigate", url=url) + yield "Find cat images" + result = await env.call_tool("count_cats") + yield float(result > 0) + + # MCP client usage: + # 1. get_prompt("{env_name}:search_cats", {url: "..."}) -> prompt messages + # 2. agent runs... + # 3. read_resource("{env_name}:search_cats") -> {"reward": 0.95} + """ + + def decorator( + fn: Callable[..., AsyncGenerator[Any, None]], + ) -> Callable[..., AsyncGenerator[Any, None]]: + script_name = name or fn.__name__ + script_id = f"{self.name}:{script_name}" + script_desc = description or fn.__doc__ or f"Script: {script_name}" + + # Store the generator function + self._scripts[script_name] = fn + + # Get function signature for prompt arguments + sig = inspect.signature(fn) + prompt_args = [ + {"name": p.name, "required": p.default is inspect.Parameter.empty} + for p in sig.parameters.values() + ] + + # Register PROMPT - runs setup, returns prompt messages + # We need a reference to self and the outer variables + script_self = self + script_fn = fn + script_name_ref = script_name + + async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: + # Create generator instance + gen = script_fn(**handler_args) + + # Run setup phase (code before first yield) + prompt_text = await gen.__anext__() + + # Store generator with session ID + session_id = uuid.uuid4().hex[:8] + script_self._script_sessions[session_id] = gen + script_self._script_latest[script_name_ref] = session_id + + logger.debug( + "Script %s setup complete, session=%s, prompt=%s", + script_name_ref, + session_id, + prompt_text[:50] if isinstance(prompt_text, str) else prompt_text, + ) + + return [{"role": "user", "content": str(prompt_text)}] + + # Register prompt using FastMCP - create FunctionPrompt directly + # to bypass the **kwargs validation in from_function() + from fastmcp.prompts.prompt import FunctionPrompt, PromptArgument + + prompt = FunctionPrompt( + name=script_id, + description=f"[Setup] {script_desc}", + arguments=[ + PromptArgument(name=arg["name"], required=arg["required"]) + for arg in prompt_args + ], + fn=prompt_handler, + ) + self._prompt_manager.add_prompt(prompt) + + # Register RESOURCE - runs evaluate, returns reward + async def resource_handler() -> str: + # Get latest session for this script + session_id = self._script_latest.get(script_name) + if not session_id: + raise ValueError( + f"No active session for script '{script_name}'. " + "Call the prompt first to run setup." + ) + + gen = self._script_sessions.pop(session_id, None) + if gen is None: + raise ValueError( + f"Session '{session_id}' not found or already evaluated." 
+ ) + + # Run evaluate phase (code after first yield) + try: + reward = await gen.__anext__() + except StopAsyncIteration: + # Generator ended without second yield - assume success + reward = 1.0 + + logger.debug( + "Script %s evaluate complete, session=%s, reward=%s", + script_name, + session_id, + reward, + ) + + # Clean up latest pointer if it matches + if self._script_latest.get(script_name) == session_id: + del self._script_latest[script_name] + + return json.dumps({"reward": float(reward)}) + + # Register as resource with same script: URI + from fastmcp.resources.resource import FunctionResource + + resource = FunctionResource.from_function( + fn=resource_handler, + uri=script_id, + name=script_name, + description=f"[Evaluate] {script_desc}", + mime_type="application/json", + ) + self._resource_manager.add_resource(resource) + + logger.debug( + "Registered script '%s' as prompt and resource: %s", + script_name, + script_id, + ) + + return fn + + return decorator + diff --git a/hud/environment/tests/test_scripts.py b/hud/environment/tests/test_scripts.py new file mode 100644 index 00000000..f7b3fd06 --- /dev/null +++ b/hud/environment/tests/test_scripts.py @@ -0,0 +1,162 @@ +"""Tests for Environment script decorator.""" + +from __future__ import annotations + +import pytest + +from hud.environment import Environment + + +class TestScriptDecorator: + """Tests for @env.script decorator.""" + + def test_script_registers_function(self) -> None: + """@env.script registers the function.""" + env = Environment("test-env") + + @env.script("greet") + async def greet_script(name: str): + yield f"Hello, {name}!" + yield 1.0 + + assert "greet" in env._scripts + + def test_script_creates_mcp_prompt(self) -> None: + """@env.script creates an MCP prompt.""" + env = Environment("test-env") + + @env.script("greet", description="Greeting script") + async def greet_script(name: str): + yield f"Hello, {name}!" + yield 1.0 + + # Check that prompt was registered via prompt manager + prompt_names = list(env._prompt_manager._prompts.keys()) + assert "test-env:greet" in prompt_names + + def test_script_creates_mcp_resource(self) -> None: + """@env.script creates an MCP resource.""" + env = Environment("test-env") + + @env.script("greet") + async def greet_script(name: str): + yield f"Hello, {name}!" 
+ yield 1.0 + + # Check that resource was registered via resource manager + resource_uris = list(env._resource_manager._resources.keys()) + assert "test-env:greet" in resource_uris + + def test_script_extracts_arguments(self) -> None: + """@env.script extracts function arguments for prompt.""" + env = Environment("test-env") + + @env.script("checkout") + async def checkout_script(user_id: str, amount: int = 100): + yield f"Checkout for {user_id}: ${amount}" + yield 1.0 + + # Find the prompt + prompt = env._prompt_manager._prompts.get("test-env:checkout") + assert prompt is not None + + # Check arguments + arg_names = [arg.name for arg in prompt.arguments] + assert "user_id" in arg_names + assert "amount" in arg_names + + +class TestScriptExecution: + """Tests for script execution flow.""" + + @pytest.mark.asyncio + async def test_script_setup_phase(self) -> None: + """Script setup phase yields prompt.""" + env = Environment("test-env") + setup_ran = False + + @env.script("test") + async def test_script(): + nonlocal setup_ran + setup_ran = True + yield "Test prompt" + yield 1.0 + + # Get the prompt handler + prompt = env._prompt_manager._prompts.get("test-env:test") + assert prompt is not None + + # Run setup via prompt render (which calls fn) - no need for context + result = await prompt.render({}) + + assert setup_ran + # Result is list of PromptMessage + assert len(result) > 0 + assert "Test prompt" in str(result[0].content) + + @pytest.mark.asyncio + async def test_script_stores_session(self) -> None: + """Script stores generator in session for evaluate phase.""" + env = Environment("test-env") + + @env.script("test") + async def test_script(): + yield "Test prompt" + yield 1.0 + + # Run setup via prompt - no need for context + prompt = env._prompt_manager._prompts.get("test-env:test") + await prompt.render({}) + + # Check session was stored + assert "test" in env._script_latest + + @pytest.mark.asyncio + async def test_script_full_flow(self) -> None: + """Script runs setup and evaluate phases correctly.""" + env = Environment("test-env") + phases = [] + + @env.script("test") + async def test_script(): + phases.append("setup") + yield "Test prompt" + phases.append("evaluate") + yield 0.95 + + # Setup phase - no context needed for prompt/resource + prompt = env._prompt_manager._prompts.get("test-env:test") + await prompt.render({}) + assert "setup" in phases + assert "evaluate" not in phases + + # Evaluate phase + resource = env._resource_manager._resources.get("test-env:test") + reward_result = await resource.read() + assert "evaluate" in phases + + +class TestScriptWithArgs: + """Tests for scripts with arguments.""" + + @pytest.mark.asyncio + async def test_script_receives_args(self) -> None: + """Script receives arguments from prompt call.""" + env = Environment("test-env") + received_args = {} + + @env.script("checkout") + async def checkout_script(user_id: str, amount: int = 100): + received_args["user_id"] = user_id + received_args["amount"] = amount + yield f"Checkout {user_id}: ${amount}" + yield 1.0 + + prompt = env._prompt_manager._prompts.get("test-env:checkout") + + # No context needed for prompt render + await prompt.render({"user_id": "alice", "amount": 50}) + + assert received_args["user_id"] == "alice" + assert received_args["amount"] == 50 + diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py index 88bec509..da8ce3fe 100644 --- a/hud/eval/__init__.py +++ b/hud/eval/__init__.py @@ -1,23 +1,32 @@ """HUD Eval - Evaluation context and management. 
This module provides: +- Eval: A runnable evaluation unit (from env()) - EvalContext: Environment with evaluation tracking (trace_id, reward, etc.) -- EvalMixin: Adds env.eval() method to Environment - eval(): Standalone context manager for task-based evaluation Usage: - # Method on existing environment - async with env.eval("task_name") as env: - await env.call_tool("navigate", url="...") - env.reward = 0.9 + # Using env() to create Eval + env = Environment("my-env").connect_hub("browser") + + async with env() as ctx: + await ctx.call_tool("navigate", url="...") + + async with env("checkout", user_id="alice") as ctx: + await agent.run(ctx.prompt) # Standalone with task slugs - async with hud.eval("my-org/task:1") as env: - await agent.run(env) + async with hud.eval("my-org/task:1") as ctx: + await agent.run(ctx) + + # Orchestrated with Eval objects + evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: + await agent.run(ctx.prompt) # Blank eval for manual reward - async with hud.eval() as env: - env.reward = compute_reward() + async with hud.eval() as ctx: + ctx.reward = compute_reward() """ from __future__ import annotations @@ -30,15 +39,15 @@ # run_eval is safe to import (uses lazy imports internally) from hud.eval.manager import run_eval -# EvalMixin is safe to import (uses lazy imports internally) -from hud.eval.mixin import EvalMixin +# Eval is safe to import +from hud.eval.eval import Eval if TYPE_CHECKING: from hud.eval.context import EvalContext __all__ = [ + "Eval", "EvalContext", - "EvalMixin", "run_eval", ] diff --git a/hud/eval/context.py b/hud/eval/context.py index 8cb5b09f..73ceab04 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -257,7 +257,8 @@ def from_task( ) -> EvalContext: """Create an EvalContext from a Task definition. - Used by hud.eval(slug) to create evaluation contexts from tasks. + .. deprecated:: 0.5.0 + Use Eval objects from env() instead of Task objects. Args: task: Task definition @@ -270,6 +271,14 @@ def from_task( variants: Variant assignment code_snippet: Code being evaluated """ + import warnings + + warnings.warn( + "EvalContext.from_task() is deprecated. Use Eval objects from env() instead.", + DeprecationWarning, + stacklevel=2, + ) + eval_name = name or task.id or "eval" return cls( diff --git a/hud/eval/eval.py b/hud/eval/eval.py new file mode 100644 index 00000000..2c9830e0 --- /dev/null +++ b/hud/eval/eval.py @@ -0,0 +1,247 @@ +"""Eval - A runnable evaluation unit (data class). + +An Eval holds the configuration needed to run an evaluation: +- Environment configuration (how to create/connect) +- Optional script name and args + +When entered as a context manager, it creates an EvalContext. + +Usage: + env = Environment("my-env").connect_hub("browser") + + # Empty - just env + async with env() as ctx: + await ctx.call_tool("navigate", url="...") + + # With script + async with env("checkout", user_id="alice") as ctx: + await agent.run(ctx.prompt) + + # Orchestrated via hud.eval + evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: + ... 
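# A sketch of how the two yields of an @env.script generator map onto the Eval
# lifecycle, pieced together from Eval.__aenter__/__aexit__ below (names such as
# `agent` are illustrative, following the examples in this patch):
#
#     @env.script()
#     async def checkout(user_id: str):
#         yield f"Buy one item as {user_id}"   # first yield  -> becomes ctx.prompt (setup)
#         yield 1.0                            # second yield -> becomes ctx.reward (evaluate)
#
#     async with env("checkout", user_id="alice") as ctx:
#         await agent.run(ctx.prompt)          # agent runs between the two yields
#     # on exit, Eval resumes the generator and sets ctx.reward = 1.0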
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import TracebackType + + from hud.eval.context import EvalContext + +__all__ = ["Eval"] + +logger = logging.getLogger(__name__) + + +@dataclass +class Eval: + """A runnable evaluation unit (data class). + + Holds the configuration to create an EvalContext: + - env_config: How to create/connect the environment + - script: Optional script name to run (from @env.script) + - args: Arguments for the script + + When entered as a context manager, creates an EvalContext. + + Attributes: + env_config: Serializable environment configuration + script: Script name to run (None for env-only) + args: Script arguments + """ + + # Core config + env_config: dict[str, Any] | None = None + script: str | None = None + args: dict[str, Any] = field(default_factory=dict) + + # EvalContext creation params (set by hud.eval for parallel execution) + trace_id: str | None = field(default=None, repr=False) + api_key: str | None = field(default=None, repr=False) + job_id: str | None = field(default=None, repr=False) + group_id: str | None = field(default=None, repr=False) + index: int = field(default=0, repr=False) + variants: dict[str, Any] = field(default_factory=dict, repr=False) + code_snippet: str | None = field(default=None, repr=False) + _suppress_link: bool = field(default=False, repr=False) + + # Runtime state + _ctx: EvalContext | None = field(default=None, repr=False) + + def copy(self) -> Eval: + """Create a copy of this Eval for parallel execution.""" + return Eval( + env_config=self.env_config, + script=self.script, + args=self.args.copy(), + trace_id=None, # Each copy gets unique trace_id + api_key=self.api_key, + job_id=self.job_id, + group_id=self.group_id, + index=self.index, + variants=self.variants.copy(), + code_snippet=self.code_snippet, + _suppress_link=self._suppress_link, + ) + + def to_eval_context(self) -> EvalContext: + """Convert this Eval to an EvalContext. + + Creates an EvalContext with environment from env_config and + script info stored for setup/evaluate phases. 
+ """ + from hud.environment import Environment + from hud.eval.context import EvalContext + + # Create environment from config + env = Environment.from_config(self.env_config) if self.env_config else Environment("eval") + + # Create EvalContext from environment + ctx = EvalContext.from_environment( + env=env, + name=self.script or "eval", + trace_id=self.trace_id, + api_key=self.api_key, + job_id=self.job_id, + group_id=self.group_id, + index=self.index, + variants=self.variants, + code_snippet=self.code_snippet, + env_config=self.env_config, + ) + ctx._suppress_link = self._suppress_link + + return ctx + + async def __aenter__(self) -> EvalContext: + """Enter eval context - create EvalContext and enter it.""" + self._ctx = self.to_eval_context() + await self._ctx.__aenter__() + + # If we have a script, run its setup phase + if self.script: + await self._run_script_setup() + + return self._ctx + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit eval context - run script evaluate and exit EvalContext.""" + if self._ctx is None: + return + + # If we have a script and no error, run its evaluate phase + if self.script and exc_type is None: + await self._run_script_evaluate() + + # Exit the EvalContext + await self._ctx.__aexit__(exc_type, exc_val, exc_tb) + self._ctx = None + + async def _run_script_setup(self) -> None: + """Run the script's setup phase (get prompt).""" + if self._ctx is None or self.script is None: + return + + # Check if script is registered locally + scripts = getattr(self._ctx, "_scripts", {}) + if self.script in scripts: + # Local script - run setup via generator + import uuid + + script_fn = scripts[self.script] + gen = script_fn(**self.args) + + # Run setup phase (code before first yield) + prompt = await gen.__anext__() + + # Store generator for evaluate phase + session_id = uuid.uuid4().hex[:8] + script_sessions = getattr(self._ctx, "_script_sessions", {}) + script_latest = getattr(self._ctx, "_script_latest", {}) + script_sessions[session_id] = gen + script_latest[self.script] = session_id + + # Set prompt on context + self._ctx.prompt = str(prompt) + + logger.debug( + "Script %s setup complete, session=%s", + self.script, + session_id, + ) + else: + # Remote script - call via MCP prompt + # Format: {env_name}:{script_name} + env_name = self._ctx.name if self._ctx else "eval" + prompt_id = f"{env_name}:{self.script}" + try: + result = await self._ctx.get_prompt(prompt_id, self.args) + if result.messages: + # Extract prompt from first message + first_msg = result.messages[0] + content = first_msg.content + # Handle TextContent which has .text attribute + if hasattr(content, "text") and isinstance(content.text, str): # type: ignore[union-attr] + self._ctx.prompt = content.text # type: ignore[union-attr] + elif isinstance(content, str): + self._ctx.prompt = content + except Exception as e: + logger.warning("Failed to get script prompt: %s", e) + + async def _run_script_evaluate(self) -> None: + """Run the script's evaluate phase (get reward).""" + if self._ctx is None or self.script is None: + return + + # Check if we have a stored generator (local script) + script_latest = getattr(self._ctx, "_script_latest", {}) + session_id = script_latest.get(self.script) + if session_id: + script_sessions = getattr(self._ctx, "_script_sessions", {}) + gen = script_sessions.pop(session_id, None) + if gen: + try: + reward = await gen.__anext__() + self._ctx.reward = 
float(reward) + logger.debug( + "Script %s evaluate complete, reward=%s", + self.script, + reward, + ) + except StopAsyncIteration: + # Generator ended without second yield - assume success + self._ctx.reward = 1.0 + + # Clean up latest pointer + if script_latest.get(self.script) == session_id: + del script_latest[self.script] + return + + # Remote script - read via MCP resource + # Format: {env_name}:{script_name} + env_name = self._ctx.name if self._ctx else "eval" + resource_id = f"{env_name}:{self.script}" + try: + import json + + contents = await self._ctx.read_resource(resource_id) + if contents: + first_content = contents[0] + # Handle TextResourceContents which has .text attribute + if hasattr(first_content, "text") and isinstance(first_content.text, str): # type: ignore[union-attr] + data = json.loads(first_content.text) # type: ignore[union-attr] + if "reward" in data: + self._ctx.reward = float(data["reward"]) + except Exception as e: + logger.warning("Failed to get script reward: %s", e) diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 5a2a261e..3abe6d49 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -24,13 +24,14 @@ from collections.abc import AsyncGenerator from hud.eval.context import EvalContext + from hud.eval.eval import Eval from hud.types import Task logger = logging.getLogger(__name__) -# Type alias for task source: can be slug strings or Task objects -TaskSource = "str | list[str] | Task | list[Task] | None" +# Type alias for eval source: slug strings, Eval objects, or deprecated Task objects +EvalSource = "str | list[str] | Eval | list[Eval] | Task | list[Task] | None" def _parse_slug(slug: str) -> tuple[str, str | None]: @@ -53,18 +54,27 @@ def _parse_slug(slug: str) -> tuple[str, str | None]: def _get_eval_name( source: str | list[str] | None = None, - tasks: list[Task] | None = None, + evals: list[Eval] | None = None, + tasks: list[Task] | None = None, # Deprecated ) -> str: """Extract a nice name for job display. Args: source: Single slug or list of slugs (if string-based) - tasks: List of Task objects (if using direct tasks) + evals: List of Eval objects (primary path) + tasks: List of Task objects (deprecated) Returns: - Name like "evalset", task ID, or "eval" if no source + Name like "evalset", script name, or "eval" if no source """ - # If we have tasks with IDs, use first task ID + # If we have Eval objects, use first script name + if evals: + first_eval = evals[0] + if first_eval.script: + return first_eval.script + return "eval" + + # Deprecated: If we have tasks with IDs, use first task ID if tasks: first_task = tasks[0] if first_task.id: @@ -94,27 +104,27 @@ def _get_eval_name( return "eval" -def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: - """Load tasks from platform by slugs. +def _load_evals_from_slugs(slugs: str | list[str]) -> list[Eval]: + """Load Eval configs from platform by slugs. Args: slugs: Single slug or list of slugs. 
Slugs can be: - - "my-org/task" - single task - - "my-org/task:N" - task at index N - - "my-org/task:*" - all tasks matching pattern + - "my-org/eval" - single eval + - "my-org/eval:N" - eval at index N + - "my-org/eval:*" - all evals matching pattern Returns: - List of Task objects + List of Eval objects """ import httpx + from hud.eval.eval import Eval from hud.settings import settings - from hud.types import Task if isinstance(slugs, str): slugs = [slugs] - tasks: list[Task] = [] + evals: list[Eval] = [] headers = {} if settings.api_key: @@ -125,10 +135,10 @@ def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: base_slug, index_str = _parse_slug(slug) if index_str == "*": - # Fetch all tasks for this evalset - logger.info("Loading all tasks for: %s", base_slug) + # Fetch all evals for this evalset + logger.info("Loading all evals for: %s", base_slug) response = client.get( - f"{settings.hud_api_url}/tasks/{base_slug}", + f"{settings.hud_api_url}/evals/{base_slug}", headers=headers, params={"all": "true"}, ) @@ -136,39 +146,58 @@ def _load_tasks_from_slugs(slugs: str | list[str]) -> list[Task]: data = response.json() if isinstance(data, list): - tasks.extend(Task(**item) for item in data) + evals.extend(_eval_from_api(item) for item in data) else: - tasks.append(Task(**data)) + evals.append(_eval_from_api(data)) elif index_str is not None: - # Fetch specific task by index - logger.info("Loading task: %s (index %s)", base_slug, index_str) + # Fetch specific eval by index + logger.info("Loading eval: %s (index %s)", base_slug, index_str) response = client.get( - f"{settings.hud_api_url}/tasks/{base_slug}", + f"{settings.hud_api_url}/evals/{base_slug}", headers=headers, params={"index": index_str}, ) response.raise_for_status() data = response.json() - tasks.append(Task(**data)) + evals.append(_eval_from_api(data)) else: - # Fetch single task - logger.info("Loading task: %s", slug) + # Fetch single eval + logger.info("Loading eval: %s", slug) response = client.get( - f"{settings.hud_api_url}/tasks/{slug}", + f"{settings.hud_api_url}/evals/{slug}", headers=headers, ) response.raise_for_status() data = response.json() - tasks.append(Task(**data)) + evals.append(_eval_from_api(data)) + + return evals + - return tasks +def _eval_from_api(data: dict[str, Any]) -> Eval: + """Convert API response to Eval object. + + Expected API response format: + { + "env_config": {...}, # EnvConfig dict + "script": "script_name", # Optional + "args": {...}, # Script arguments + } + """ + from hud.eval.eval import Eval + + return Eval( + env_config=data.get("env_config"), + script=data.get("script"), + args=data.get("args", {}), + ) @asynccontextmanager async def run_eval( - source: str | list[str] | Task | list[Task] | None = None, + source: str | list[str] | Task | list[Task] | Eval | list[Eval] | None = None, *, variants: dict[str, Any] | None = None, group: int = 1, @@ -180,15 +209,17 @@ async def run_eval( """Standalone eval context manager. Creates an EvalContext for evaluation, optionally loading task configuration - from slugs or using Task objects directly. + from slugs, using Task objects, or using Eval objects directly. Args: - source: Task source. Can be: + source: Eval source. 
Can be: - None: Create blank eval context - str: Task slug like "my-org/task", "my-org/task:N", "my-org/task:*" - list[str]: Multiple task slugs - Task: Single Task object (for backwards compat with run_tasks) - list[Task]: List of Task objects (for backwards compat with run_tasks) + - Eval: Single Eval object (from env()) + - list[Eval]: List of Eval objects (from env()) variants: A/B test configuration (dict with list values expanded) group: Runs per variant for statistical significance group_ids: Optional list of group IDs @@ -218,12 +249,11 @@ async def run_eval( async with hud.eval("my-org/evalset:*") as ctx: await agent.run(ctx) - # With Task objects directly - from hud.types import Task - - tasks = [Task(prompt="Do X", mcp_config={...})] - async with hud.eval(tasks) as ctx: - await agent.run(ctx) + # With Eval objects (from env()) + env = Environment("my-env").connect_hub("browser") + evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: + await agent.run(ctx.prompt) # With variants and group async with hud.eval( @@ -244,6 +274,9 @@ async def run_eval( print(f"{e.variants}: reward={e.reward}") ``` """ + import warnings + + from hud.eval.eval import Eval from hud.types import Task if group <= 0: @@ -252,30 +285,51 @@ async def run_eval( # Expand variants variant_combos = expand_variants(variants) - # Parse source into tasks list - tasks: list[Task] = [] + # Parse source into evals list (or deprecated tasks list) + evals: list[Eval] = [] + tasks: list[Task] = [] # Deprecated path slugs: str | list[str] | None = None # Track if we had string slugs (for naming) if source is not None: - if isinstance(source, Task): - # Single Task object + if isinstance(source, Eval): + # Single Eval object + evals = [source] + elif isinstance(source, list) and source and isinstance(source[0], Eval): + # List of Eval objects + evals = source # type: ignore[assignment] + elif isinstance(source, Task): + # Single Task object (deprecated) + warnings.warn( + "Passing Task objects to hud.eval() is deprecated. " + "Use Eval objects from env() or string slugs instead.", + DeprecationWarning, + stacklevel=2, + ) tasks = [source] elif isinstance(source, list) and source and isinstance(source[0], Task): - # List of Task objects + # List of Task objects (deprecated) + warnings.warn( + "Passing Task objects to hud.eval() is deprecated. 
" + "Use Eval objects from env() or string slugs instead.", + DeprecationWarning, + stacklevel=2, + ) tasks = source # type: ignore[assignment] elif isinstance(source, str): - # String slug + # String slug - load as Eval slugs = source - tasks = _load_tasks_from_slugs(source) + evals = _load_evals_from_slugs(source) elif isinstance(source, list) and source and isinstance(source[0], str): - # List of string slugs + # List of string slugs - load as Eval slugs = source # type: ignore[assignment] - tasks = _load_tasks_from_slugs(source) # type: ignore[arg-type] + evals = _load_evals_from_slugs(source) # type: ignore[arg-type] # Calculate total evaluations + # If we have evals, each eval gets (variants x group) runs # If we have tasks, each task gets (variants x group) runs - # If no tasks, we have a single blank eval with (variants x group) runs - total_evals = len(tasks) * len(variant_combos) * group if tasks else len(variant_combos) * group + # If neither, we have a single blank eval with (variants x group) runs + base_count = len(evals) or len(tasks) or 1 + total_evals = base_count * len(variant_combos) * group # Capture code snippet for parallel execution code_snippet: str | None = None @@ -296,7 +350,16 @@ async def run_eval( if total_evals == 1: # Simple case: single eval - if tasks: + if evals: + # Single Eval object - enter it directly + single_eval = evals[0] + single_eval.api_key = api_key + single_eval.job_id = job_id + single_eval.variants = variant_combos[0] + single_eval.code_snippet = code_snippet + async with single_eval as ctx: + yield ctx + elif tasks: # Single task ctx = EvalContext.from_task( task=tasks[0], @@ -305,6 +368,8 @@ async def run_eval( variants=variant_combos[0], code_snippet=code_snippet, ) + async with ctx: + yield ctx else: # Blank eval ctx = EvalContext( @@ -314,13 +379,12 @@ async def run_eval( variants=variant_combos[0], code_snippet=code_snippet, ) - - async with ctx: - yield ctx + async with ctx: + yield ctx else: # Parallel execution: create implicit job to group traces - eval_name = _get_eval_name(source=slugs, tasks=tasks) + eval_name = _get_eval_name(source=slugs, evals=evals, tasks=tasks) implicit_job_id = job_id or str(uuid.uuid4()) job_url = f"https://hud.ai/jobs/{implicit_job_id}" @@ -331,6 +395,7 @@ async def run_eval( try: # Run parallel evals with job_id completed = await _run_parallel_eval( + evals=evals, tasks=tasks, variant_combos=variant_combos, group=group, @@ -342,7 +407,15 @@ async def run_eval( ) # Create summary context (no trace, just aggregates results) - if tasks: + if evals: + # Create summary from first eval's env_config + ctx = EvalContext( + name=evals[0].script or "eval", + api_key=api_key, + job_id=implicit_job_id, + env_config=evals[0].env_config, + ) + elif tasks: ctx = EvalContext.from_task( task=tasks[0], api_key=api_key, @@ -375,6 +448,7 @@ async def run_eval( async def _run_parallel_eval( + evals: list[Eval], tasks: list[Task], variant_combos: list[dict[str, Any]], group: int, @@ -386,77 +460,174 @@ async def _run_parallel_eval( ) -> list[EvalContext]: """Run parallel evaluation. - Creates EvalContexts from tasks (or blank) and runs them in parallel. + Creates EvalContexts from Evals, tasks (or blank) and runs them in parallel. 
""" + import asyncio + import textwrap + # Lazy import to avoid circular dependency from hud.eval.context import EvalContext - from hud.eval.parallel import log_eval_stats, run_parallel_evals + from hud.eval.eval import Eval + from hud.eval.parallel import log_eval_stats # Find user code frame and extract the with block body caller_frame = find_user_frame() body_source, captured_locals, context_var = get_with_block_body(caller_frame) # Calculate total evals and resolve group IDs - total_evals = len(tasks) * len(variant_combos) * group if tasks else len(variant_combos) * group + base_count = len(evals) or len(tasks) or 1 + total_evals = base_count * len(variant_combos) * group resolved_group_ids = resolve_group_ids(group_ids, total_evals) - # Create EvalContexts - eval_contexts: list[EvalContext] = [] + # Create Eval objects for parallel execution + eval_objects: list[Eval] = [] idx = 0 - if tasks: - # Create context for each (task, variant, run) combination + if evals: + # Create Eval for each (eval, variant, run) combination + for base_eval in evals: + for variant in variant_combos: + for _ in range(group): + eval_copy = base_eval.copy() + eval_copy.api_key = api_key + eval_copy.job_id = job_id + eval_copy.group_id = resolved_group_ids[idx] + eval_copy.index = idx + eval_copy.variants = variant + eval_copy.code_snippet = code_snippet + eval_copy._suppress_link = True + eval_objects.append(eval_copy) + idx += 1 + elif tasks: + # Create Eval from Task for each (task, variant, run) combination for task in tasks: for variant in variant_combos: for _ in range(group): - ctx = EvalContext.from_task( - task=task, + # Convert Task to Eval (backwards compatibility) + task_eval = Eval( + env_config=None, # Task has its own mcp_config + script=None, + args={}, api_key=api_key, job_id=job_id, group_id=resolved_group_ids[idx], index=idx, variants=variant, code_snippet=code_snippet, + _suppress_link=True, ) - ctx._suppress_link = True # Suppress individual links, job URL shown instead - eval_contexts.append(ctx) + # Store task reference for EvalContext creation + task_eval._task = task # type: ignore[attr-defined] + eval_objects.append(task_eval) idx += 1 else: # Blank evals for each (variant, run) combination for variant in variant_combos: for _ in range(group): - ctx = EvalContext( - name="eval", + blank_eval = Eval( + env_config=None, + script=None, + args={}, api_key=api_key, job_id=job_id, group_id=resolved_group_ids[idx], index=idx, variants=variant, code_snippet=code_snippet, + _suppress_link=True, ) - ctx._suppress_link = True # Suppress individual links, job URL shown instead - eval_contexts.append(ctx) + eval_objects.append(blank_eval) idx += 1 + # Create runner function using the actual variable name from the 'as' clause + wrapped = f"async def __runner__({context_var}):\n{textwrap.indent(body_source, ' ')}" + code = compile(wrapped, "", "exec") + namespace = captured_locals.copy() + exec(code, namespace) # noqa: S102 + runner = namespace["__runner__"] + + # Create semaphore for concurrency control + sem = asyncio.Semaphore(max_concurrent) if max_concurrent else None + + async def run_one(eval_obj: Eval) -> EvalContext: + """Run a single Eval and return its EvalContext.""" + # Check if this is a Task-based eval (backwards compat) + task = getattr(eval_obj, "_task", None) + + try: + if task is not None: + # Task-based: use EvalContext.from_task + ctx = EvalContext.from_task( + task=task, + api_key=eval_obj.api_key, + job_id=eval_obj.job_id, + group_id=eval_obj.group_id, + index=eval_obj.index, 
+ variants=eval_obj.variants, + code_snippet=eval_obj.code_snippet, + ) + ctx._suppress_link = eval_obj._suppress_link + + if sem: + async with sem, ctx: + await runner(ctx) + else: + async with ctx: + await runner(ctx) + return ctx + else: + # Eval-based: enter the Eval directly + if sem: + async with sem, eval_obj as ctx: + await runner(ctx) + else: + async with eval_obj as ctx: + await runner(ctx) + return ctx + except Exception as e: + logger.warning("Parallel eval %d failed: %s", eval_obj.index, e) + # Create a failed context + if task is not None: + ctx = EvalContext.from_task( + task=task, + api_key=eval_obj.api_key, + job_id=eval_obj.job_id, + group_id=eval_obj.group_id, + index=eval_obj.index, + variants=eval_obj.variants, + code_snippet=eval_obj.code_snippet, + ) + else: + ctx = EvalContext( + name=eval_obj.script or "eval", + api_key=eval_obj.api_key, + job_id=eval_obj.job_id, + group_id=eval_obj.group_id, + index=eval_obj.index, + variants=eval_obj.variants, + code_snippet=eval_obj.code_snippet, + env_config=eval_obj.env_config, + ) + ctx.error = e + return ctx + # Run in parallel logger.info( - "Running %d evals (%d tasks x %d variants x %d runs)%s", - len(eval_contexts), - max(len(tasks), 1), + "Running %d evals (%d base x %d variants x %d runs)%s", + len(eval_objects), + base_count, len(variant_combos), group, f", max_concurrent={max_concurrent}" if max_concurrent else "", ) - completed = await run_parallel_evals( - eval_contexts, body_source, captured_locals, context_var, max_concurrent - ) + completed = await asyncio.gather(*[run_one(e) for e in eval_objects]) # Log and print stats eval_name = completed[0].eval_name if completed else "eval" log_eval_stats(completed) print_eval_stats(completed, eval_name) - return completed + return list(completed) __all__ = ["run_eval"] diff --git a/hud/eval/mixin.py b/hud/eval/mixin.py deleted file mode 100644 index 10d56df9..00000000 --- a/hud/eval/mixin.py +++ /dev/null @@ -1,380 +0,0 @@ -"""EvalMixin - Adds eval() method to Environment. - -This mixin provides the eval() context manager that creates EvalContext -instances for recording agent runs, with optional parallel execution and -variant-based A/B testing. -""" - -from __future__ import annotations - -import contextlib -import inspect -import logging -import uuid -from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any - -from hud.eval.display import print_complete, print_eval_stats, print_link -from hud.eval.parallel import ( - ASTExtractionError, - expand_variants, - find_user_frame, - get_with_block_body, - resolve_group_ids, -) -from hud.eval.types import ParallelEvalComplete - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator - - from hud.eval.context import EvalContext - from hud.types import MCPToolResult - -logger = logging.getLogger(__name__) - - -class EvalMixin: - """Mixin that adds eval capabilities to Environment. - - This mixin provides: - - eval(): Create an EvalContext for recording agent runs - - Parallel execution with group=N parameter - - A/B testing with variants parameter - - Example: - ```python - class Environment(EvalMixin, MCPServer): ... 
- - - env = Environment("my-env") - - # Single eval - yields EvalContext (which has Environment capabilities) - async with env.eval("task") as ctx: - await ctx.call_tool("navigate", {"url": "..."}) - ctx.reward = 0.9 - - # Parallel evals (runs 4 times) - async with env.eval("task", group=4) as ctx: - await ctx.call_tool("navigate", {"url": "..."}) - ctx.reward = 0.9 - - # A/B testing (2 variants x 3 runs = 6 evals) - async with env.eval( - "task", - variants={"model": ["gpt-4o", "claude"]}, - group=3, - ) as ctx: - model = ctx.variants["model"] - response = await call_llm(model=model) - ctx.reward = evaluate(response) - - # Access results - for e in ctx.results: - print(f"{e.variants} run {e.index}: reward={e.reward}") - ``` - """ - - # These will be provided by the Environment class - name: str - - # Store last parallel results - _last_evals: list[EvalContext] | None = None - - async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> MCPToolResult: - """Placeholder - implemented by Environment.""" - raise NotImplementedError - - def _capture_code_snippet(self) -> str | None: - """Capture the code inside the eval() with-block (best effort). - - Returns None if source cannot be extracted (e.g., REPL, Jupyter). - """ - frame = inspect.currentframe() - if frame is None: - return None - - try: - # Go up: _capture_code_snippet -> eval -> user code - caller = frame.f_back - if caller is not None: - caller = caller.f_back - if caller is None: - return None - - body_source, _, _ = get_with_block_body(caller) - return body_source - except ASTExtractionError: - # Can't extract from REPL/Jupyter - that's OK - return None - except Exception as e: - logger.debug("Failed to capture code snippet: %s", e) - return None - finally: - del frame - - def _get_env_config(self) -> dict[str, Any] | None: - """Get serializable environment configuration. - - Returns dict with connections and local tools. - """ - # This will be overridden by Environment with actual implementation - return None - - @property - def last_evals(self) -> list[EvalContext] | None: - """Get EvalContext objects from the last parallel execution. - - Each EvalContext has: trace_id, index, reward, duration, error, success - """ - return self._last_evals - - @asynccontextmanager - async def eval( - self, - name: str, - *, - variants: dict[str, Any] | None = None, - group: int = 1, - group_ids: list[str] | None = None, - job_id: str | None = None, - trace_id: str | None = None, - api_key: str | None = None, - max_concurrent: int | None = None, - ) -> AsyncGenerator[EvalContext, None]: - """Create an eval context for recording an agent run. - - The eval context provides: - - Unique trace identification - - Task name linking (for training data construction) - - Headers for gateway integration (auto-injected to inference.hud.ai) - - Tool call capabilities (call_tool, as_openai_chat_tools, etc.) - - Reward setting - - Metrics logging - - A/B Testing: - Use `variants` to define experiment variables. Each list value - creates a variant; single values are fixed. All combinations - are expanded and run. - - Parallel Execution: - Use `group` to run multiple times per variant for statistical - significance. Total evals = len(variants combinations) x group. - - Args: - name: Task name for this eval (used for task construction) - variants: A/B test configuration. 
Dict where: - - List values are expanded: {"model": ["gpt-4o", "claude"]} - - Single values are fixed: {"temp": 0.7} - - All combinations are run - group: Runs per variant (default: 1) for statistical significance. - group_ids: Optional list of group IDs for each eval. - Length must match (variants x group). If not provided, - a single shared group_id is auto-generated. - job_id: Optional job ID to link this eval to. If not provided, - auto-detects from current `hud.job()` context. - trace_id: Optional trace ID (auto-generated if not provided). - For parallel execution, each eval gets a unique ID. - api_key: Optional API key for backend calls (defaults to settings.api_key) - max_concurrent: Maximum concurrent evals (None = unlimited) - - Yields: - EvalContext for this evaluation. Inside the body: - - `ctx.variants` = current variant assignment (e.g., {"model": "gpt-4o"}) - - `ctx.index` = local run index (for debugging) - - `ctx.group_id` = links all evals in this parallel execution - - `ctx.call_tool(...)` = call tools on the environment - - `ctx.reward = ...` = set reward - - After execution (for variants/group > 1): - - `ctx.results` = list of all EvalContext objects - - `ctx.reward` = mean reward across all evals - - Example: - ```python - # Single execution - async with env.eval("task") as ctx: - await ctx.call_tool("search", {"query": "..."}) - ctx.reward = 1.0 - - # A/B test: 2 variants x 3 runs = 6 evals - async with env.eval( - "task", - variants={"model": ["gpt-4o", "claude"]}, - group=3, - ) as ctx: - model = ctx.variants["model"] # Assigned per-eval - response = await call_llm(model=model) - ctx.reward = evaluate(response) - - # Access results - for e in ctx.results: - print(f"{e.variants} run {e.index}: reward={e.reward}") - ``` - - Limitations (for variants/group > 1): - - Requires source file (won't work in REPL/Jupyter) - - Outer variables captured at enter time, changes don't propagate back - - Modifying mutable objects causes race conditions - - Cannot use yield/generators inside body - """ - if group <= 0: - raise ValueError("group must be >= 1") - - # Expand variants into all combinations - variant_combos = expand_variants(variants) - total_evals = len(variant_combos) * group - - # Capture code snippet (best effort - won't work in REPL/Jupyter) - code_snippet = self._capture_code_snippet() - - # Get environment config - env_config = self._get_env_config() - - # Validate parallelization - only remote connections allowed for group > 1 - if total_evals > 1 and not self.is_parallelizable: # type: ignore[attr-defined] - local_conns = self.local_connections # type: ignore[attr-defined] - raise ValueError( - f"Cannot run parallel evals (group={group}) with local connections.\n" - f" Local connections: {local_conns}\n" - f" Local connections (stdio/Docker) can only run one instance.\n" - f" Use remote connections (HTTP/URL) for parallel execution." 
- ) - - # Lazy import to avoid circular dependency - from hud.eval.context import EvalContext - - if total_evals == 1: - # Simple case: single eval - # Create EvalContext from parent environment - ctx = EvalContext.from_environment( - env=self, # type: ignore[arg-type] - name=name, - trace_id=trace_id, - api_key=api_key, - job_id=job_id, - variants=variant_combos[0], - code_snippet=code_snippet, - env_config=env_config, - ) - async with ctx: - yield ctx - else: - # Parallel execution: create implicit job to group traces - implicit_job_id = job_id or str(uuid.uuid4()) - job_url = f"https://hud.ai/jobs/{implicit_job_id}" - - # Print job URL (not individual trace URLs) - print_link(job_url, f"🚀 Job '{name}'") - - error_occurred = False - try: - # Run parallel evals with job_id - completed = await self._run_parallel_eval( - name=name, - variant_combos=variant_combos, - group=group, - group_ids=group_ids, - job_id=implicit_job_id, # Propagate job_id to child traces - api_key=api_key, - code_snippet=code_snippet, - env_config=env_config, - max_concurrent=max_concurrent, - ) - - # Create summary context (no trace, just aggregates results) - ctx = EvalContext.from_environment( - env=self, # type: ignore[arg-type] - name=name, - trace_id=trace_id, - api_key=api_key, - job_id=implicit_job_id, - code_snippet=code_snippet, - env_config=env_config, - ) - ctx._is_summary = True # Skip trace tracking - ctx.results = completed - self._last_evals = completed - - # Compute aggregate reward (mean of non-None rewards) - rewards = [e.reward for e in completed if e.reward is not None] - if rewards: - ctx.reward = sum(rewards) / len(rewards) - - # Check if any failed - error_occurred = any(e.error is not None for e in completed) - - with contextlib.suppress(ParallelEvalComplete): - yield ctx - finally: - print_complete(job_url, name, error=error_occurred) - - async def _run_parallel_eval( - self, - name: str, - variant_combos: list[dict[str, Any]], - group: int, - group_ids: list[str] | None, - job_id: str | None, - api_key: str | None, - code_snippet: str | None, - env_config: dict[str, Any] | None, - max_concurrent: int | None, - ) -> list[EvalContext]: - """Run parallel eval execution. - - Creates EvalContexts from parent environment and runs them in parallel. 
- """ - # Lazy import to avoid circular dependency - from hud.eval.context import EvalContext - from hud.eval.parallel import log_eval_stats, run_parallel_evals - - # Find user code frame and extract the with block body - caller_frame = find_user_frame() - body_source, captured_locals, context_var = get_with_block_body(caller_frame) - - # Calculate total evals and resolve group IDs - total_evals = len(variant_combos) * group - resolved_group_ids = resolve_group_ids(group_ids, total_evals) - - # Create EvalContext for each (variant, run) combination - eval_contexts: list[EvalContext] = [] - idx = 0 - for variant in variant_combos: - for _ in range(group): - ctx = EvalContext.from_environment( - env=self, # type: ignore[arg-type] - name=name, - api_key=api_key, - job_id=job_id, - group_id=resolved_group_ids[idx], - index=idx, - variants=variant, - code_snippet=code_snippet, - env_config=env_config, - ) - ctx._suppress_link = True # Suppress individual links, job URL shown instead - eval_contexts.append(ctx) - idx += 1 - - # Run in parallel - logger.info( - "Running %d evals for '%s' (%d variants x %d runs)%s", - len(eval_contexts), - name, - len(variant_combos), - group, - f", max_concurrent={max_concurrent}" if max_concurrent else "", - ) - completed = await run_parallel_evals( - eval_contexts, body_source, captured_locals, context_var, max_concurrent - ) - - # Store results and print stats - self._last_evals = completed - log_eval_stats(completed, name) - print_eval_stats(completed, name) - - return completed - - -__all__ = ["EvalMixin"] diff --git a/hud/eval/parallel.py b/hud/eval/parallel.py index d5651d6d..59d27d91 100644 --- a/hud/eval/parallel.py +++ b/hud/eval/parallel.py @@ -155,56 +155,6 @@ def log_eval_stats(completed: list[EvalContext], context: str = "") -> None: ) -async def execute_parallel_evals( - contexts: list[EvalContext], - caller_frame_depth: int = 2, -) -> list[EvalContext]: - """Execute evaluations in parallel using AST extraction. - - This is the shared implementation for parallel execution. It: - 1. Captures the caller's frame and extracts with-block body - 2. Runs all provided EvalContexts in parallel - 3. 
Logs statistics - - Args: - contexts: Pre-created EvalContext instances to run - caller_frame_depth: How many frames to go up to find user code - (default 2: execute_parallel_evals -> caller -> user) - - Returns: - List of completed EvalContext objects with results - """ - import inspect - - # Get the caller's frame - frame = inspect.currentframe() - if frame is None: - raise ASTExtractionError("Cannot get current frame") - - try: - # Go up the specified number of frames - caller_frame = frame - for _ in range(caller_frame_depth): - if caller_frame is not None: - caller_frame = caller_frame.f_back - if caller_frame is None: - raise ASTExtractionError("Cannot get caller frame") - - body_source, captured_locals, context_var = get_with_block_body(caller_frame) - - finally: - del frame - - # Run in parallel - logger.info("Running %d parallel evals", len(contexts)) - completed = await run_parallel_evals(contexts, body_source, captured_locals, context_var) - - # Log stats - log_eval_stats(completed) - - return completed - - class ASTExtractionError(Exception): """Error extracting AST from source.""" @@ -295,63 +245,11 @@ def _extract_body(lines: list[str], with_node: ast.AsyncWith) -> str: return textwrap.dedent(body) -async def run_parallel_evals( - eval_contexts: list[EvalContext], - body_source: str, - captured_locals: dict[str, Any], - context_var: str, - max_concurrent: int | None = None, -) -> list[EvalContext]: - """Run the eval body in parallel for multiple contexts. - - Returns the EvalContext objects after execution - they contain: - - trace_id - - index - - reward - - duration - - Any error is captured in the context - - Args: - eval_contexts: List of EvalContext instances to run - body_source: The source code of the with-block body - captured_locals: Local variables captured from the caller - context_var: The variable name used in the 'as' clause - max_concurrent: Maximum concurrent evals (None = unlimited) - """ - - # Create runner function using the actual variable name from the 'as' clause - wrapped = f"async def __runner__({context_var}):\n{textwrap.indent(body_source, ' ')}" - code = compile(wrapped, "", "exec") - namespace = captured_locals.copy() - exec(code, namespace) # noqa: S102 - runner = namespace["__runner__"] - - # Create semaphore for concurrency control - sem = asyncio.Semaphore(max_concurrent) if max_concurrent else None - - async def run_one(ctx: EvalContext) -> EvalContext: - try: - if sem: - async with sem, ctx: - await runner(ctx) - else: - async with ctx: - await runner(ctx) - except Exception as e: - logger.warning("Parallel eval %d failed: %s", ctx.index, e) - ctx.error = e - return ctx - - results = await asyncio.gather(*[run_one(ctx) for ctx in eval_contexts]) - return list(results) - - __all__ = [ "ASTExtractionError", - "execute_parallel_evals", "expand_variants", + "find_user_frame", "get_with_block_body", "log_eval_stats", "resolve_group_ids", - "run_parallel_evals", ] diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py new file mode 100644 index 00000000..26471244 --- /dev/null +++ b/hud/eval/tests/test_eval.py @@ -0,0 +1,236 @@ +"""Tests for hud.eval.eval module (Eval class).""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from hud.eval.eval import Eval + + +class TestEvalDataclass: + """Tests for Eval as a data class.""" + + def test_init_defaults(self) -> None: + """Eval initializes with sensible defaults.""" + ev = Eval() + + assert ev.env_config is None + 
assert ev.script is None + assert ev.args == {} + assert ev.variants == {} + assert ev.index == 0 + + def test_init_with_config(self) -> None: + """Eval can be initialized with env_config and script.""" + config = {"name": "test-env", "hubs": []} + ev = Eval(env_config=config, script="checkout", args={"user_id": "alice"}) + + assert ev.env_config == config + assert ev.script == "checkout" + assert ev.args == {"user_id": "alice"} + + def test_copy_creates_new_instance(self) -> None: + """copy() creates a new Eval instance.""" + original = Eval( + env_config={"name": "test"}, + script="checkout", + args={"user_id": "alice"}, + variants={"model": "gpt-4o"}, + ) + copied = original.copy() + + assert copied is not original + assert copied.env_config == original.env_config + assert copied.script == original.script + assert copied.args == original.args + assert copied.args is not original.args # Deep copy + assert copied.variants == original.variants + assert copied.variants is not original.variants # Deep copy + + def test_copy_clears_trace_id(self) -> None: + """copy() clears trace_id for fresh instance.""" + original = Eval(trace_id="original-trace") + copied = original.copy() + + assert copied.trace_id is None + + +class TestEvalToEvalContext: + """Tests for Eval.to_eval_context().""" + + def test_creates_eval_context(self) -> None: + """to_eval_context() creates an EvalContext.""" + from hud.eval.context import EvalContext + + ev = Eval(script="checkout") + ctx = ev.to_eval_context() + + assert isinstance(ctx, EvalContext) + assert ctx.eval_name == "checkout" + + def test_uses_eval_as_name_when_no_script(self) -> None: + """to_eval_context() uses 'eval' as name when no script.""" + ev = Eval() + ctx = ev.to_eval_context() + + assert ctx.eval_name == "eval" + + def test_passes_through_properties(self) -> None: + """to_eval_context() passes through properties.""" + ev = Eval( + script="checkout", + trace_id="test-trace", + api_key="test-key", + job_id="test-job", + group_id="test-group", + index=5, + variants={"model": "gpt-4o"}, + ) + ctx = ev.to_eval_context() + + assert ctx.trace_id == "test-trace" + assert ctx._eval_api_key == "test-key" + assert ctx.job_id == "test-job" + assert ctx.group_id == "test-group" + assert ctx.index == 5 + assert ctx.variants == {"model": "gpt-4o"} + + +class TestEvalContextManager: + """Tests for Eval as async context manager.""" + + @pytest.mark.asyncio + async def test_aenter_returns_eval_context(self) -> None: + """__aenter__ returns an EvalContext.""" + from hud.eval.context import EvalContext + + ev = Eval() # No script to avoid script lookup + + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), + ): + ctx = await ev.__aenter__() + assert isinstance(ctx, EvalContext) + # Clean up manually since we patched __aexit__ + ev._ctx = None + + @pytest.mark.asyncio + async def test_context_clears_on_exit(self) -> None: + """__aexit__ clears internal context reference.""" + from hud.eval.context import EvalContext + + ev = Eval() + + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + patch.object(EvalContext, "__aexit__", new_callable=AsyncMock) as mock_exit, + ): + ctx = await ev.__aenter__() + assert ev._ctx is not None + + # Manually call __aexit__ on Eval (which will call mocked ctx.__aexit__) + await 
ev.__aexit__(None, None, None) + assert ev._ctx is None + + @pytest.mark.asyncio + async def test_reward_accessible_after_exit(self) -> None: + """Reward set in context is accessible after exit.""" + from hud.eval.context import EvalContext + + ev = Eval() + + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), + ): + ctx = await ev.__aenter__() + ctx.reward = 0.95 + + await ev.__aexit__(None, None, None) + # Context reference is cleared but reward was set on the actual context + + +class TestEvalFromApi: + """Tests for _eval_from_api helper.""" + + def test_creates_eval_from_api_response(self) -> None: + """_eval_from_api creates Eval from API response.""" + from hud.eval.manager import _eval_from_api + + data = { + "env_config": {"name": "test-env", "hubs": []}, + "script": "checkout", + "args": {"user_id": "alice"}, + } + + ev = _eval_from_api(data) + + assert ev.env_config == {"name": "test-env", "hubs": []} + assert ev.script == "checkout" + assert ev.args == {"user_id": "alice"} + + def test_handles_missing_optional_fields(self) -> None: + """_eval_from_api handles missing optional fields.""" + from hud.eval.manager import _eval_from_api + + data = {} # Minimal response + + ev = _eval_from_api(data) + + assert ev.env_config is None + assert ev.script is None + assert ev.args == {} + + +class TestEnvironmentCall: + """Tests for Environment.__call__ returning Eval.""" + + def test_call_returns_eval(self) -> None: + """Environment() returns an Eval object.""" + from hud.environment import Environment + + env = Environment("test-env") + ev = env() + + assert isinstance(ev, Eval) + + def test_call_with_script_sets_script(self) -> None: + """Environment(script) sets script name.""" + from hud.environment import Environment + + env = Environment("test-env") + ev = env("checkout") + + assert ev.script == "checkout" + + def test_call_with_args_sets_args(self) -> None: + """Environment(script, **args) sets args.""" + from hud.environment import Environment + + env = Environment("test-env") + ev = env("checkout", user_id="alice", amount=100) + + assert ev.args == {"user_id": "alice", "amount": 100} + + def test_call_captures_env_config_when_configured(self) -> None: + """Environment() captures env config when there's something to store.""" + from hud.environment import Environment + + # Plain env has no config (nothing to reconstruct) + env = Environment("test-env") + ev = env() + assert ev.env_config is None # Nothing to store + + # Env with setup_tool has config + env2 = Environment("test-env").setup_tool("navigate", url="https://example.com") + ev2 = env2() + assert ev2.env_config is not None + assert ev2.env_config["name"] == "test-env" + assert len(ev2.env_config["setup_tools"]) == 1 + diff --git a/hud/eval/tests/test_manager.py b/hud/eval/tests/test_manager.py new file mode 100644 index 00000000..53e1de80 --- /dev/null +++ b/hud/eval/tests/test_manager.py @@ -0,0 +1,133 @@ +"""Tests for hud.eval.manager module (hud.eval() function).""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest + +from hud.eval.context import EvalContext, get_current_trace_headers +from hud.eval.manager import run_eval + + +class TestRunEvalNoArgs: + """Tests for hud.eval() with no arguments (blank eval).""" + + @pytest.mark.asyncio + async def test_blank_eval_creates_context(self) -> None: + """hud.eval() with 
no args creates an EvalContext.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + async with run_eval() as ctx: + assert isinstance(ctx, EvalContext) + assert ctx.eval_name == "eval" + + @pytest.mark.asyncio + async def test_blank_eval_generates_trace_id(self) -> None: + """hud.eval() with no args generates a trace_id.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + async with run_eval() as ctx: + assert ctx.trace_id is not None + assert len(ctx.trace_id) == 36 # UUID format + + @pytest.mark.asyncio + async def test_blank_eval_sets_trace_headers(self) -> None: + """hud.eval() sets trace headers in contextvar during context.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + # Before context, no headers + assert get_current_trace_headers() is None + + async with run_eval() as ctx: + # Inside context, headers are set + headers = get_current_trace_headers() + assert headers is not None + assert headers["Trace-Id"] == ctx.trace_id + + # After context, headers are cleared + assert get_current_trace_headers() is None + + @pytest.mark.asyncio + async def test_blank_eval_reward_can_be_set(self) -> None: + """hud.eval() allows setting reward on context.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + async with run_eval() as ctx: + assert ctx.reward is None + ctx.reward = 0.95 + + assert ctx.reward == 0.95 + + @pytest.mark.asyncio + async def test_blank_eval_reports_reward_on_exit(self) -> None: + """hud.eval() reports reward to backend on exit.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit: + async with run_eval() as ctx: + ctx.reward = 0.85 + + # _eval_exit should have been called (with no error) + mock_exit.assert_called_once_with(None) + + @pytest.mark.asyncio + async def test_blank_eval_empty_variants(self) -> None: + """hud.eval() with no args has empty variants dict.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + async with run_eval() as ctx: + assert ctx.variants == {} + + @pytest.mark.asyncio + async def test_blank_eval_has_headers_property(self) -> None: + """hud.eval() context has headers property for gateway integration.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + async with run_eval() as ctx: + headers = ctx.headers + assert "Trace-Id" in headers + assert headers["Trace-Id"] == ctx.trace_id + + +class TestRunEvalWithApiKey: + """Tests for hud.eval() with api_key parameter.""" + + @pytest.mark.asyncio + async def test_api_key_passed_to_context(self) -> None: + """hud.eval(api_key=...) 
passes api_key to context.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + async with run_eval(api_key="test-key") as ctx: + assert ctx._eval_api_key == "test-key" + + +class TestRunEvalWithJobId: + """Tests for hud.eval() with job_id parameter.""" + + @pytest.mark.asyncio + async def test_job_id_passed_to_context(self) -> None: + """hud.eval(job_id=...) passes job_id to context.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): + async with run_eval(job_id="job-123") as ctx: + assert ctx.job_id == "job-123" + + +class TestRunEvalErrorHandling: + """Tests for hud.eval() error handling.""" + + @pytest.mark.asyncio + async def test_error_tracked_on_exception(self) -> None: + """hud.eval() tracks error when exception occurs.""" + with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): + with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit: + with pytest.raises(ValueError): + async with run_eval() as ctx: + raise ValueError("test error") + + # _eval_exit should have been called with error message + mock_exit.assert_called_once() + error_msg = mock_exit.call_args[0][0] + assert error_msg is not None + assert "test error" in error_msg + diff --git a/hud/eval/tests/test_mixin.py b/hud/eval/tests/test_mixin.py deleted file mode 100644 index 45ee4b53..00000000 --- a/hud/eval/tests/test_mixin.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Tests for hud.eval.mixin module.""" - -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -import pytest - -from hud.eval.mixin import EvalMixin -from hud.eval.parallel import expand_variants - - -class TestExpandVariants: - """Tests for expand_variants helper.""" - - def test_none_returns_empty_dict(self) -> None: - result = expand_variants(None) - assert result == [{}] - - def test_single_value_stays_single(self) -> None: - result = expand_variants({"model": "gpt-4o"}) - assert result == [{"model": "gpt-4o"}] - - def test_list_expands_to_variants(self) -> None: - result = expand_variants({"model": ["gpt-4o", "claude"]}) - assert result == [{"model": "gpt-4o"}, {"model": "claude"}] - - def test_multiple_lists_create_combinations(self) -> None: - result = expand_variants({"model": ["a", "b"], "temp": [0.0, 1.0]}) - assert len(result) == 4 - assert {"model": "a", "temp": 0.0} in result - assert {"model": "b", "temp": 1.0} in result - - -class MockEnvironment(EvalMixin): - """Mock environment for testing EvalMixin.""" - - def __init__(self) -> None: - self.name = "test-env" - self._connections: dict[str, Any] = {} - self._last_evals = None - self._hub_configs: list[dict[str, Any]] = [] - self._setup_calls: list[tuple[str, dict[str, Any]]] = [] - self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] - self.prompt: str | None = None - - @property - def is_parallelizable(self) -> bool: - return all(getattr(c, "is_remote", True) for c in self._connections.values()) - - @property - def local_connections(self) -> list[str]: - return [name for name, c in self._connections.items() if getattr(c, "is_local", False)] - - -class TestEvalMixin: - """Tests for EvalMixin.""" - - @pytest.mark.asyncio - async def test_eval_single_creates_context(self) -> None: - """eval() with group=1 creates single EvalContext.""" - env = MockEnvironment() - - async with env.eval("test-task") as ctx: - assert 
ctx.eval_name == "test-task" - assert ctx.trace_id is not None - assert ctx.variants == {} - - @pytest.mark.asyncio - async def test_eval_sets_reward(self) -> None: - """reward can be set on EvalContext.""" - env = MockEnvironment() - - async with env.eval("test-task") as ctx: - ctx.reward = 0.95 - - assert ctx.reward == 0.95 - - @pytest.mark.asyncio - async def test_eval_with_variants_single(self) -> None: - """eval() with single variant value works.""" - env = MockEnvironment() - - async with env.eval("test-task", variants={"model": "gpt-4o"}) as ctx: - assert ctx.variants == {"model": "gpt-4o"} - - @pytest.mark.asyncio - async def test_eval_rejects_parallel_with_local_connections(self) -> None: - """eval() raises error for parallel with local connections.""" - env = MockEnvironment() - - # Add a local connection - mock_conn = MagicMock() - mock_conn.is_local = True - mock_conn.is_remote = False - env._connections["local-server"] = mock_conn - - with pytest.raises(ValueError, match="Cannot run parallel evals"): - async with env.eval("test-task", group=2) as _ctx: - pass - - @pytest.mark.asyncio - async def test_eval_allows_parallel_with_remote_connections(self) -> None: - """eval() allows parallel with only remote connections.""" - env = MockEnvironment() - - # Add a remote connection - mock_conn = MagicMock() - mock_conn.is_local = False - mock_conn.is_remote = True - env._connections["remote-server"] = mock_conn - - # Just verify it doesn't raise the local connection error - assert env.is_parallelizable is True - - @pytest.mark.asyncio - async def test_eval_rejects_zero_group(self) -> None: - """eval() raises error for group <= 0.""" - env = MockEnvironment() - - with pytest.raises(ValueError, match="group must be >= 1"): - async with env.eval("test-task", group=0) as _ctx: - pass - - def test_last_evals_none_initially(self) -> None: - """last_evals is None before any parallel execution.""" - env = MockEnvironment() - assert env.last_evals is None diff --git a/hud/eval/tests/test_parallel.py b/hud/eval/tests/test_parallel.py index 8750b447..4e55b8fb 100644 --- a/hud/eval/tests/test_parallel.py +++ b/hud/eval/tests/test_parallel.py @@ -3,7 +3,6 @@ from __future__ import annotations import ast -from unittest.mock import AsyncMock, MagicMock import pytest @@ -14,7 +13,6 @@ _get_end_line, expand_variants, resolve_group_ids, - run_parallel_evals, ) @@ -160,69 +158,6 @@ def test_extract_body(self) -> None: assert "more_thing()" in body -class TestRunParallelEvals: - """Tests for run_parallel_evals function.""" - - @pytest.mark.asyncio - async def test_runs_body_for_each_context(self) -> None: - """run_parallel_evals runs body for each EvalContext.""" - # Create mock eval contexts - mock_ctxs = [] - for i in range(3): - ctx = MagicMock() - ctx.index = i - ctx.__aenter__ = AsyncMock(return_value=ctx) - ctx.__aexit__ = AsyncMock(return_value=None) - mock_ctxs.append(ctx) - - # Simple body that sets reward - body_source = "env.reward = env.index * 10" - captured_locals: dict[str, object] = {} - - results = await run_parallel_evals(mock_ctxs, body_source, captured_locals, "env") - - assert len(results) == 3 - # Each context should have had __aenter__ and __aexit__ called - for ctx in mock_ctxs: - ctx.__aenter__.assert_called_once() - ctx.__aexit__.assert_called_once() - - @pytest.mark.asyncio - async def test_captures_exceptions(self) -> None: - """run_parallel_evals captures exceptions in context.""" - ctx = MagicMock() - ctx.index = 0 - ctx.__aenter__ = AsyncMock(return_value=ctx) - 
ctx.__aexit__ = AsyncMock(return_value=None) - - # Body that raises - body_source = "raise ValueError('test error')" - captured_locals: dict[str, object] = {} - - results = await run_parallel_evals([ctx], body_source, captured_locals, "env") - - assert len(results) == 1 - # Error should be captured, not raised - assert hasattr(ctx, "error") or ctx.__aexit__.called - - @pytest.mark.asyncio - async def test_uses_captured_locals(self) -> None: - """run_parallel_evals uses captured locals in body execution.""" - ctx = MagicMock() - ctx.index = 0 - ctx.result = None - ctx.__aenter__ = AsyncMock(return_value=ctx) - ctx.__aexit__ = AsyncMock(return_value=None) - - # Body that uses captured local - body_source = "env.result = my_value * 2" - captured_locals = {"my_value": 21} - - results = await run_parallel_evals([ctx], body_source, captured_locals, "env") - - assert len(results) == 1 - - class TestASTExtractionError: """Tests for ASTExtractionError.""" diff --git a/hud/types.py b/hud/types.py index f4fd5c5b..455a9ed9 100644 --- a/hud/types.py +++ b/hud/types.py @@ -71,11 +71,17 @@ class BaseAgentConfig(BaseModel): class Task(BaseModel): """ + DEPRECATED: Use Eval from env() instead. + A task configuration that can be used to create a task. The mcp_config field supports environment variable substitution using template placeholders in the format ${VAR_NAME} or ${VAR_NAME:default_value}. + .. deprecated:: 0.5.0 + Task is deprecated. Use `env("script_name", **args)` to create Eval objects, + or use string slugs with `hud.eval("org/evalset:*")`. + Example: mcp_config: { "hud": { @@ -97,6 +103,18 @@ class Task(BaseModel): agent_config: BaseAgentConfig | None = None metadata: dict[str, Any] = Field(default_factory=dict) + def __init__(self, **data: Any) -> None: + """Initialize Task with deprecation warning.""" + import warnings + + warnings.warn( + "Task is deprecated. 
Use env('script_name', **args) to create Eval objects, " + "or use string slugs with hud.eval('org/evalset:*').", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(**data) + @field_validator("mcp_config", "metadata", mode="before") @classmethod def parse_json_strings(cls, v: Any) -> Any: From 51ff45983b699da659e95071e8cded10aa51592d Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 11:34:12 -0800 Subject: [PATCH 18/92] docs updates --- docs/docs.json | 86 ++++++++++++- docs/gateway/{index.mdx => index-legacy.mdx} | 3 +- docs/quick-links/ab-testing.mdx | 61 +++++++++ docs/quick-links/deploy.mdx | 56 ++++++++ docs/quick-links/environments.mdx | 104 +++++++++++++++ docs/quick-links/gateway.mdx | 128 +++++++++++++++++++ 6 files changed, 436 insertions(+), 2 deletions(-) rename docs/gateway/{index.mdx => index-legacy.mdx} (99%) create mode 100644 docs/quick-links/ab-testing.mdx create mode 100644 docs/quick-links/deploy.mdx create mode 100644 docs/quick-links/environments.mdx create mode 100644 docs/quick-links/gateway.mdx diff --git a/docs/docs.json b/docs/docs.json index d2f9e789..b3629164 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -28,6 +28,90 @@ }, "navigation": { "versions": [ + { + "version": "0.5.0", + "groups": [ + { + "group": "Get Started", + "pages": [ + "index", + "quickstart", + "llm-quickstart" + ] + }, + { + "group": "Essentials", + "pages": [ + "quick-links/gateway", + "quick-links/ab-testing", + "quick-links/environments", + "quick-links/deploy" + ] + }, + { + "group": "Core Concepts", + "pages": [ + "core-concepts/architecture", + "core-concepts/mcp-protocol", + "core-concepts/task-system" + ] + }, + { + "group": "SDK Reference", + "pages": [ + "reference/eval", + "reference/tools", + "reference/agents", + "reference/types", + "reference/environments", + "reference/tasks" + ] + }, + { + "group": "Environments", + "pages": [ + "build-environments/index", + "build-environments/spec" + ] + }, + { + "group": "Beta Features", + "pages": [ + "beta/index", + "beta/rft" + ] + }, + { + "group": "Agents", + "pages": [ + "evaluate-agents/create-agents", + "evaluate-agents/benchmarks" + ] + }, + { + "group": "CLI Reference", + "pages": [ + "reference/cli/overview", + "reference/cli/init", + "reference/cli/dev", + "reference/cli/build", + "reference/cli/push", + "reference/cli/analyze", + "reference/cli/debug", + "reference/cli/run", + "reference/cli/eval", + "reference/cli/rft", + "reference/cli/misc" + ] + }, + { + "group": "Community", + "pages": [ + "contributing" + ] + } + ] + }, { "version": "0.4.73", "groups": [ @@ -68,7 +152,7 @@ { "group": "HUD Gateway", "pages": [ - "gateway/index" + "gateway/index-legacy" ] }, { diff --git a/docs/gateway/index.mdx b/docs/gateway/index-legacy.mdx similarity index 99% rename from docs/gateway/index.mdx rename to docs/gateway/index-legacy.mdx index ea235980..a60b6811 100644 --- a/docs/gateway/index.mdx +++ b/docs/gateway/index-legacy.mdx @@ -1,5 +1,5 @@ --- -title: "HUD Gateway" +title: "Gateway" description: "Unified LLM inference service with built-in auth and credit management." icon: "server" --- @@ -128,3 +128,4 @@ This example demonstrates: - Automatic token usage and latency tracking View your traces on the [HUD Dashboard](https://hud.ai/home). 
+ diff --git a/docs/quick-links/ab-testing.mdx b/docs/quick-links/ab-testing.mdx new file mode 100644 index 00000000..0c1215f8 --- /dev/null +++ b/docs/quick-links/ab-testing.mdx @@ -0,0 +1,61 @@ +--- +title: "A/B Evals" +description: "Find out which model actually performs best for your use case." +icon: "flask-vial" +--- + +LLM outputs vary from run to run—ask the same question twice and you might get different quality answers. To find out which model actually performs best, you need to test each one multiple times and look at the spread. **Variants** let you test different models side-by-side. **Groups** repeat each test so you see the full distribution, not just one lucky or unlucky result. + +## Variants + +Pass the configurations you want to test: + +```python +import hud + +async with hud.eval(variants={"model": ["gpt-4o", "claude-sonnet-4-5"]}) as ctx: + response = await client.chat.completions.create( + model=ctx.variants["model"], + messages=[{"role": "user", "content": "What is 2+2?"}] + ) + ctx.reward = 1.0 if "4" in response.choices[0].message.content else 0.0 + +for result in ctx.results: + print(f"{result.variants}: reward={result.reward}") +``` + +## Groups + +Run each variant multiple times to get a distribution: + +```python +async with hud.eval( + variants={"model": ["gpt-4o", "claude-sonnet-4-5"]}, + group=5 # 10 runs total: 2 models × 5 each +) as ctx: + ... +``` + +The `hud.eval` manager will parallelize your evals automatically and show the distribution across all your runs on [hud.ai](https://hud.ai/home). + +## Remote Rollouts + +Once you've [deployed an environment](/quick-links/deploy#deploying-environments) and created evals, run them by name: + +```python +async with hud.eval("my-org/checkout-laptop", variants={"model": ["gpt-4o", "claude"]}) as ctx: + response = await client.chat.completions.create( + model=ctx.variants["model"], + messages=[{"role": "user", "content": ctx.prompt}] + ) +``` + +The platform loads everything—environment, prompt, evaluation logic, comparisons across models. You just provide the agent. + +Or via CLI: + +```bash +hud eval my-org/checkout-laptop --model gpt-4o --group-size 5 +``` + +Or run directly on the platform—see [Running at Scale](/quick-links/deploy#running-at-scale). diff --git a/docs/quick-links/deploy.mdx b/docs/quick-links/deploy.mdx new file mode 100644 index 00000000..4d5f137f --- /dev/null +++ b/docs/quick-links/deploy.mdx @@ -0,0 +1,56 @@ +--- +title: "Deploy" +description: "Deploy environments. Create evals. Run and train at scale." +icon: "rocket" +--- + +You've built an environment with tools and scripts. Deploy it to the platform and you can run evals at scale—hundreds of parallel runs across models, all traced, all generating training data. + +## Deploying Environments + +Start with `hud init` ([see Environments](/quick-links/environments)) to scaffold locally. When ready: + +1. Go to [hud.ai](https://hud.ai) → **New Environment** +2. Connect your GitHub repo and name your environment +3. Push changes and it rebuilds automatically, like Vercel + +Your environment—tools, scripts, everything—is now live. Connect from anywhere: + +```python +env.connect_hub("my-org/my-env") +``` + +## Running at Scale + +Once deployed, create evals on [hud.ai](https://hud.ai) from your scripts. Each eval is a frozen configuration—same prompt, same scoring, every time. 
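+
+This is the same kind of object you can build locally by calling the environment with script arguments; a sketch (the script and argument names below are illustrative):
+
+```python
+# Locally, an eval is just the environment called with frozen script arguments
+eval = env("checkout", product_name="Laptop", apply_coupon=False)
+```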
+ +Your script might take arguments: + +```python +@env.script("checkout") +async def checkout_flow(product_name: str, apply_coupon: bool = False): + yield f"Complete checkout for {product_name}" + (" with coupon" if apply_coupon else "") + yield 1.0 if order_confirmed() else 0.0 +``` + +On the platform, click **New Eval** → select your script → fill in the arguments. Create multiple evals from the same script: + +| Eval Name | Arguments | +|-----------|-----------| +| `checkout-laptop` | `product_name="Laptop"`, `apply_coupon=False` | +| `checkout-phone-coupon` | `product_name="Phone"`, `apply_coupon=True` | +| `checkout-headphones` | `product_name="Headphones"`, `apply_coupon=False` | + +Then run them—select an eval, choose variants and groups, launch hundreds of runs in parallel. Every run is traced. Results show scores, distributions, and side-by-side model comparisons. These become your training data. + +For A/B testing with variants and groups, see [A/B Evals](/quick-links/ab-testing). + +## What's Next? + +With your environment deployed: + +- **Scale**: Launch thousands of rollouts. Every run generates traces—prompts, tool calls, rewards. +- **Analyze**: See which evals agents struggle with. Compare models across your entire benchmark. +- **Train**: Use runs as training data. Fine-tune on successful completions. Run reinforcement learning to optimize for your specific environment. + +The loop: deploy → eval at scale → analyze → train → redeploy. Agents get better at *your* environment. diff --git a/docs/quick-links/environments.mdx b/docs/quick-links/environments.mdx new file mode 100644 index 00000000..64c9a6af --- /dev/null +++ b/docs/quick-links/environments.mdx @@ -0,0 +1,104 @@ +--- +title: "Environments" +description: "Turn your code into agent-callable tools. Define how agents are evaluated." +icon: "cube" +--- + +An environment is everything an agent can interact with—your APIs, services, databases, wrapped as tools. But it's more than that: the environment also defines how agents are *evaluated* through **scripts**. When you deploy an environment, you're creating a sandbox that agents can learn from at scale. + +## Tools + +Start with `hud init` to scaffold an environment—works with existing codebases or from scratch: + +```bash +hud init my-env +``` + +Every tool is just a function. Decorate it with `@env.tool()` and agents can call it: + +```python +from hud import Environment + +env = Environment("my-env") + +@env.tool() +async def search(query: str) -> str: + """Search the knowledge base.""" + return db.search(query) +``` + +Got a FastAPI app? One line: + +```python +env.connect_fastapi(app) +``` + +All your routes become tools. Run it: + +```python +async with env() as ctx: + tools = await ctx.list_tools() + result = await ctx.call_tool("search", query="test") +``` + +## Scripts + +To evaluate an agent, you need two things: what to tell it, and how to score what it did. Scripts capture both with two `yield` statements: + +```python +@env.script("checkout") +async def checkout_flow(product_name: str): + # Yield the prompt, receive the agent's final answer + answer = yield f"Add '{product_name}' to cart and complete checkout" + + # Score based on environment state and/or the answer + order_exists = await check_order_status(product_name) + yield 1.0 if order_exists else 0.0 +``` + +The agent runs between the yields. First yield sends the prompt and returns the agent's answer. Second yield checks environment state—database rows, files, API calls—and returns a reward. 
Scripts live with the environment because only the environment knows how to verify what happened. + +## Evals + +Call the environment with a script name and arguments to create an eval: + +```python +eval = env("checkout", product_name="Laptop") + +async with hud.eval(eval, group=4) as ctx: + response = await client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": ctx.prompt}], + tools=ctx.as_openai_chat_tools() + ) + # Handle tool calls, run agent loop... + + ctx.submit(response.choices[0].message.content) + +print(ctx.reward) +``` + +This creates a trace on [hud.ai](https://hud.ai/home). Add [variants](/quick-links/ab-testing) to A/B test across models. To run evals at scale, [deploy your environment](/quick-links/deploy). + +## Mock Mode + +Testing your agent loop without hitting real services? Mock mode returns fake responses based on tool schemas: + +```python +env.mock() +env.mock_tool("search", "Mock search results") + +async with hud.eval(env(), group=4) as ctx: + tools = env.as_openai_chat_tools() + + response = await client.chat.completions.create( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "Search for X"}], + tools=tools + ) + + # Returns mock value instead of hitting real service + result = await env.call_tool(response.choices[0].message.tool_calls[0]) +``` +Your agent code stays the same—just toggle `env.mock()` for local testing. + diff --git a/docs/quick-links/gateway.mdx b/docs/quick-links/gateway.mdx new file mode 100644 index 00000000..11d5d73d --- /dev/null +++ b/docs/quick-links/gateway.mdx @@ -0,0 +1,128 @@ +--- +title: "Gateway" +description: "One endpoint for every model. One API key. Full observability." +icon: "server" +--- + +Stop juggling API keys. HUD Gateway routes to Anthropic, OpenAI, Gemini, xAI, and more through a single OpenAI-compatible endpoint—with built-in telemetry. Swap `model="gpt-4o"` for `model="claude-sonnet-4-5"` and you're [A/B testing](/quick-links/ab-testing) across providers. Continuous RL from production coming soon. + +## Quick Start + +Point any OpenAI-compatible client at `inference.hud.ai`: + + + +```python Python +from openai import AsyncOpenAI +import os + +client = AsyncOpenAI( + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +response = await client.chat.completions.create( + model="claude-sonnet-4-5", # or gpt-4o, gemini-2.5-pro, grok-4-1-fast... + messages=[{"role": "user", "content": "Hello!"}] +) +``` + +```bash curl +curl -X POST https://inference.hud.ai/chat/completions \ + -H "Authorization: Bearer $HUD_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "claude-sonnet-4-5", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + + + +## Supported Models + +Full list at [hud.ai/models](https://hud.ai/models). 
+ + +| Model | Routes | +|-------|--------| +| `claude-sonnet-4-5` | chat, messages | +| `claude-haiku-4-5` | chat, messages | +| `claude-opus-4-5` | chat, messages | +| `claude-opus-4-1` | chat, messages | + + + +| Model | Routes | +|-------|--------| +| `gpt-5.1` | chat, responses | +| `gpt-5-mini` | chat, responses | +| `gpt-4o` | chat, responses | +| `gpt-4o-mini` | chat, responses | +| `operator` | responses | + + + +| Model | Routes | +|-------|--------| +| `gemini-3-pro-preview` | chat | +| `gemini-2.5-pro` | chat | +| `gemini-2.5-computer-use-preview` | gemini | + + + +| Model | Routes | +|-------|--------| +| `grok-4-1-fast` | chat | +| `z-ai/glm-4.5v` | chat | + + +## Telemetry + +Wrap code in a plain `hud.eval()` to group inference calls. In the trace you'll see the full conversation in sequence, not scattered API calls. + +```python +async with hud.eval(): + response = await client.chat.completions.create( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "Hello!"}] + ) +``` + +Or inject a trace ID manually if you're not using `hud.eval()`. Generate a UUID and pass it with each request in a task: + + + +```python Python +import uuid + +trace_id = str(uuid.uuid4()) # e.g. "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + +response = await client.chat.completions.create( + model="claude-sonnet-4-5", + messages=[{"role": "user", "content": "Hello!"}], + extra_headers={"Trace-Id": trace_id} +) +``` + +```bash curl +curl -X POST https://inference.hud.ai/chat/completions \ + -H "Authorization: Bearer $HUD_API_KEY" \ + -H "Content-Type: application/json" \ + -H "Trace-Id: a1b2c3d4-e5f6-7890-abcd-ef1234567890" \ + -d '{ + "model": "claude-sonnet-4-5", + "messages": [{"role": "user", "content": "Hello!"}] + }' +``` + + + +View traces at [hud.ai/home](https://hud.ai/home). + +## Routes + +- **chat** — `/chat/completions` (OpenAI-compatible) +- **messages** — `/messages` (Anthropic-compatible) +- **responses** — `/responses` (OpenAI Responses API) +- **gemini** — Google Gemini native API From b6b0372facac6e4b0df5c37b2a302273ce5b9eef Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 14:20:50 -0800 Subject: [PATCH 19/92] fixes to server and cli --- hud/cli/dev.py | 41 +++-- hud/cli/flows/init.py | 89 ++++++++--- hud/cli/flows/templates.py | 83 ++++++---- hud/environment/connectors/local.py | 48 ++++-- hud/environment/connectors/remote.py | 10 +- hud/environment/environment.py | 78 ++++++++- hud/environment/integrations/openai.py | 26 +-- hud/environment/router.py | 12 +- hud/environment/scripts.py | 213 +++++++++++++++++++++++-- hud/eval/context.py | 98 +++++++++--- hud/eval/display.py | 17 +- hud/eval/eval.py | 206 ++++++++++++------------ hud/eval/instrument.py | 18 ++- hud/eval/manager.py | 146 ++++++++--------- hud/eval/tests/test_eval.py | 2 +- hud/eval/types.py | 3 +- hud/server/server.py | 1 - 17 files changed, 752 insertions(+), 339 deletions(-) diff --git a/hud/cli/dev.py b/hud/cli/dev.py index 74f39c8b..5809f118 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -144,7 +144,7 @@ def should_use_docker_mode(cwd: Path) -> bool: async def run_mcp_module( - module_name: str, + module_spec: str, transport: str, port: int, verbose: bool, @@ -152,7 +152,19 @@ async def run_mcp_module( interactive: bool, new_trace: bool = False, ) -> None: - """Run an MCP module directly.""" + """Run an MCP module directly. 
+ + Args: + module_spec: Module specification in format "module" or "module:attribute" + e.g., "server" (looks for mcp), "env:env" (looks for env) + """ + # Parse module:attribute format (like uvicorn/gunicorn) + if ":" in module_spec: + module_name, attr_name = module_spec.rsplit(":", 1) + else: + module_name = module_spec + attr_name = "mcp" # Default attribute + # Check if this is a reload (not first run) is_reload = os.environ.get("_HUD_DEV_RELOAD") == "1" @@ -165,8 +177,10 @@ async def run_mcp_module( # Suppress tracebacks in logs unless verbose logging.basicConfig(stream=sys.stderr, level=logging.INFO, format="%(message)s") - # Suppress FastMCP's verbose error logging + # Suppress FastMCP's verbose logging logging.getLogger("fastmcp.tools.tool_manager").setLevel(logging.WARNING) + logging.getLogger("fastmcp.server.server").setLevel(logging.WARNING) + logging.getLogger("fastmcp.server.openapi").setLevel(logging.WARNING) # On reload, suppress most startup logs if is_reload: @@ -211,8 +225,7 @@ async def run_mcp_module( hud_console.info(traceback.format_exc()) sys.exit(1) - # Look for 'mcp' attribute - check module __dict__ directly - # Debug: print what's in the module + # Look for the specified attribute if verbose: hud_console.info(f"Module attributes: {dir(module)}") module_dict = module.__dict__ if hasattr(module, "__dict__") else {} @@ -220,22 +233,22 @@ async def run_mcp_module( mcp_server = None - # Try different ways to access the mcp variable - if hasattr(module, "mcp"): - mcp_server = module.mcp - elif hasattr(module, "__dict__") and "mcp" in module.__dict__: - mcp_server = module.__dict__["mcp"] + # Try different ways to access the attribute + if hasattr(module, attr_name): + mcp_server = getattr(module, attr_name) + elif hasattr(module, "__dict__") and attr_name in module.__dict__: + mcp_server = module.__dict__[attr_name] if mcp_server is None: - hud_console.error(f"Module '{module_name}' does not have 'mcp' defined") + hud_console.error(f"Module '{module_name}' does not have '{attr_name}' defined") hud_console.info("") available = [k for k in dir(module) if not k.startswith("_")] hud_console.info(f"Available in module: {available}") hud_console.info("") hud_console.info("[bold cyan]Expected structure:[/bold cyan]") - hud_console.info(" from hud.server import MCPServer") - hud_console.info(" mcp = MCPServer(name='my-server')") - raise AttributeError(f"Module '{module_name}' must define 'mcp'") + hud_console.info(" from hud.environment import Environment") + hud_console.info(f" {attr_name} = Environment('my-env')") + raise AttributeError(f"Module '{module_name}' must define '{attr_name}'") # Only show full header on first run, brief message on reload if is_reload: diff --git a/hud/cli/flows/init.py b/hud/cli/flows/init.py index 205dbac5..94c4c471 100644 --- a/hud/cli/flows/init.py +++ b/hud/cli/flows/init.py @@ -7,7 +7,7 @@ from hud.utils.hud_console import HUDConsole -from .templates import DOCKERFILE_HUD, HUD_PY, PYPROJECT_TOML +from .templates import DOCKERFILE_HUD, ENV_PY, PYPROJECT_TOML # Files that indicate this might be an existing project PROJECT_INDICATORS = { @@ -26,8 +26,24 @@ def _normalize_name(name: str) -> str: return "".join(c if c.isalnum() or c == "_" else "_" for c in name) -def _add_hud_dependency(directory: Path) -> bool: - """Add hud-python using uv if available.""" +def _has_hud_dependency(directory: Path) -> bool: + """Check if hud-python is already in pyproject.toml.""" + pyproject = directory / "pyproject.toml" + if not pyproject.exists(): + return 
False + content = pyproject.read_text() + return "hud-python" in content or "hud_python" in content + + +def _add_hud_dependency(directory: Path) -> str: + """Add hud-python using uv if available. + + Returns: + "exists" if already present, "added" if added, "failed" if failed + """ + if _has_hud_dependency(directory): + return "exists" + try: result = subprocess.run( ["uv", "add", "hud-python", "openai"], # noqa: S607 @@ -36,9 +52,11 @@ def _add_hud_dependency(directory: Path) -> bool: cwd=directory, check=False, ) - return result.returncode == 0 or "already" in result.stderr.lower() + if result.returncode == 0 or "already" in result.stderr.lower(): + return "added" + return "failed" except FileNotFoundError: - return False + return "failed" def _is_empty_or_trivial(directory: Path) -> bool: @@ -72,7 +90,21 @@ def smart_init( - If directory has project files: add HUD files to existing project - Otherwise: create new HUD environment """ + from hud.settings import settings + hud_console = HUDConsole() + + # Check for API key first + if not settings.api_key: + hud_console.error("HUD_API_KEY not found") + hud_console.info("") + hud_console.info("Set your API key:") + hud_console.info(" hud set HUD_API_KEY=your-key-here") + hud_console.info(" Or: export HUD_API_KEY=your-key") + hud_console.info("") + hud_console.info("Get your key at: https://hud.ai/settings/api-keys") + return + target = Path(directory).resolve() # If directory is empty, use preset selection @@ -111,17 +143,20 @@ def smart_init( else: hud_console.warning("Dockerfile.hud exists, skipping (use --force)") - # Create hud.py - hud_py = target / "hud.py" - if not hud_py.exists() or force: - hud_py.write_text(HUD_PY.format(env_name=env_name)) - created.append("hud.py") + # Create env.py + env_py = target / "env.py" + if not env_py.exists() or force: + env_py.write_text(ENV_PY.format(env_name=env_name)) + created.append("env.py") else: - hud_console.warning("hud.py exists, skipping (use --force)") + hud_console.warning("env.py exists, skipping (use --force)") # Add dependency - if _add_hud_dependency(target): + dep_result = _add_hud_dependency(target) + if dep_result == "added": hud_console.success("Added hud-python dependency") + elif dep_result == "exists": + hud_console.info("hud-python already in dependencies") else: hud_console.info("Run manually: uv add hud-python openai") @@ -132,23 +167,25 @@ def smart_init( hud_console.status_item(f, "✓") hud_console.section_title("Next Steps") - hud_console.info("1. Edit hud.py:") - hud_console.info(" - Add your tools with @env.tool()") - hud_console.info(" - Connect existing servers (FastAPI, MCP, OpenAPI)") hud_console.info("") - hud_console.info("2. Edit Dockerfile.hud:") - hud_console.info(" - Add system dependencies (apt-get install)") - hud_console.info(" - Set up data sources for production") + hud_console.info("1. Define your tools in env.py") + hud_console.info(" Tools are functions the agent can call. Wrap existing code") + hud_console.info(" with @env.tool() or connect FastAPI/OpenAPI servers.") + hud_console.info("") + hud_console.info("2. Write scripts that test agent behavior") + hud_console.info(" Scripts define prompts and scoring. The agent runs between") + hud_console.info(" two yields: first sends the task, second scores the result.") + hud_console.info("") + hud_console.info("3. 
Run locally to iterate") + hud_console.command_example("python env.py", "Run the test script") hud_console.info("") - hud_console.command_example("python hud.py", "Test locally") - hud_console.command_example("hud dev hud:env", "Development server") - hud_console.command_example("hud build", "Build Docker image") + hud_console.info("4. Deploy for scale") + hud_console.info(" Push to GitHub, connect on hud.ai. Then run hundreds of") + hud_console.info(" evals in parallel and collect training data.") hud_console.info("") - hud_console.section_title("Tips") - hud_console.info("• For production environments you want to mock locally,") - hud_console.info(" configure data sources in Dockerfile.hud before deploying") - hud_console.info("• For testing without real connections, use env.mock()") - hud_console.info("• See hud.py DEPLOYMENT section for remote deployment") + hud_console.section_title("Files") + hud_console.info("• env.py Your tools, scripts, and test code") + hud_console.info("• Dockerfile.hud Container config for remote deployment") __all__ = ["smart_init"] diff --git a/hud/cli/flows/templates.py b/hud/cli/flows/templates.py index 9d22f146..b96c7752 100644 --- a/hud/cli/flows/templates.py +++ b/hud/cli/flows/templates.py @@ -11,34 +11,52 @@ RUN pip install uv && uv sync --frozen --no-dev 2>/dev/null || uv sync --no-dev COPY . . -CMD ["uv", "run", "python", "-m", "hud", "dev", "hud:env", "--stdio"] +# Most of the time this command should not change, except if you change your env path +# or launch some other service before running the environment +CMD ["uv", "run", "python", "-m", "hud", "dev", "env:env", "--stdio"] """ # fmt: off -HUD_PY = '''\ +ENV_PY = '''\ """{env_name} - HUD Environment""" import asyncio -import os +import hud +from hud.settings import settings +from openai import AsyncOpenAI, Omit from hud.environment import Environment env = Environment("{env_name}") # ============================================================================= -# 1. ADD FUNCTIONS AS TOOLS +# 1. TOOLS - Functions the agent can call # ============================================================================= -# Decorate any function with @env.tool() to expose it as a tool. @env.tool() -def hud(query: str) -> str: - """A tool that returns the answer to any question.""" - return f"Oh, I know the answer to '{{query}}', it's 42." +def count_letter(text: str, letter: str) -> int: + """Count occurrences of a letter in text.""" + return text.lower().count(letter.lower()) # ============================================================================= -# 2. IMPORT FROM EXISTING SERVERS +# 2. SCRIPTS - Define prompts and evaluation logic +# ============================================================================= + +@env.script("count") +async def count_script(sentence: str, letter: str, fmt: str = "integer"): + """Agent must count a letter. We check if they got it right.""" + # Yield the prompt, receive the agent's final answer + answer = yield f"How many times does '{{letter}}' appear in: '{{sentence}}'? Format: {{fmt}}." + + # Score: 1.0 if correct, 0.0 otherwise + correct = str(sentence.lower().count(letter.lower())) + yield correct in answer + + +# ============================================================================= +# 3. 
CONNECT EXISTING SERVERS (optional) # ============================================================================= # --- FastAPI app --- @@ -52,11 +70,6 @@ def hud(query: str) -> str: # --- OpenAPI spec (URL or file path) --- # env.connect_openapi("https://api.example.com/openapi.json") - -# ============================================================================= -# 3. CONNECT REMOTE SERVERS -# ============================================================================= - # --- MCP config (stdio or SSE) --- # env.connect_mcp_config({{ # "my-server": {{"command": "uvx", "args": ["some-mcp-server"]}} @@ -67,35 +80,35 @@ def hud(query: str) -> str: # ============================================================================= -# TEST - Run with: python hud.py +# TEST - Run with: python env.py # ============================================================================= async def test(): - from openai import AsyncOpenAI + client = AsyncOpenAI( + base_url=settings.hud_gateway_url, + api_key=settings.api_key, + ) - async with env.task("test") as ctx: - # 1. List tools - tools = await env.list_tools() - print(f"Tools: {{[t.name for t in tools]}}") + # Create an eval from the script + eval = env("count", sentence="Strawberry world", letter="r") - # 2. Call the hud tool - result = await env.call_tool("hud", query="What is HUD?") - print(f"HUD result: {{result}}") - - # 3. Call inference.hud.ai - client = AsyncOpenAI( - base_url="https://inference.hud.ai/v1", - api_key=os.environ.get("HUD_API_KEY", ""), - ) + # Test with and without tools + async with hud.eval(eval, variants={{"tools": [True, False]}}) as ctx: response = await client.chat.completions.create( - model="claude-sonnet-4-5", - messages=[{{"role": "user", "content": "Say hello in one word."}}], + model="gpt-4o-mini", + messages=[{{"role": "user", "content": ctx.prompt}}], + tools=ctx.as_openai_chat_tools() if ctx.variants["tools"] else Omit(), ) - print(f"LLM: {{response.choices[0].message.content}}") - # 4. Assign reward - ctx.reward = 1.0 if "42" in str(result) else 0.0 - print(f"Reward: {{ctx.reward}}") + # Handle tool calls if present + message = response.choices[0].message + if message.tool_calls: + result = await ctx.call_tool(message.tool_calls[0]) + answer = str(result["content"]) + else: + answer = message.content + + await ctx.submit(answer or "") if __name__ == "__main__": diff --git a/hud/environment/connectors/local.py b/hud/environment/connectors/local.py index 66633221..1deea4b7 100644 --- a/hud/environment/connectors/local.py +++ b/hud/environment/connectors/local.py @@ -23,11 +23,9 @@ class LocalConnectorMixin(MCPConfigConnectorMixin): connect_server(server) - Mount any MCPServer/FastMCP directly Inherits connect_mcp() from MCPConfigConnectorMixin. - """ - def mount(self, server: Any, *, prefix: str | None = None) -> None: - """Mount method from MCPServer base class.""" - raise NotImplementedError + Note: include_router() is inherited from MCPServer (via FastMCP). + """ def connect_image( self, @@ -87,10 +85,18 @@ def connect_fastapi( *, name: str | None = None, prefix: str | None = None, + include_hidden: bool = True, ) -> Any: - """Mount a FastAPI application as an MCP server. + """Import a FastAPI application's routes as MCP tools. - Uses FastMCP's from_fastapi() to convert FastAPI endpoints to MCP tools. + Uses FastMCP's from_fastapi() to convert FastAPI endpoints to MCP tools, + then imports them synchronously so they're available immediately. 
+ + Args: + app: FastAPI application instance + name: Custom name for the server (defaults to app.title) + prefix: Optional prefix for tool names + include_hidden: If True (default), includes routes with include_in_schema=False Example: ```python @@ -115,9 +121,29 @@ def get_user(user_id: int): """ from fastmcp import FastMCP - server_name = name or getattr(app, "title", None) or "fastapi" - mcp_server = FastMCP.from_fastapi(app=app, name=server_name) - self.mount(mcp_server, prefix=prefix) + # Temporarily enable hidden routes for OpenAPI generation + hidden_routes: list[Any] = [] + if include_hidden: + for route in getattr(app, "routes", []): + if hasattr(route, "include_in_schema") and not route.include_in_schema: + hidden_routes.append(route) + route.include_in_schema = True + # Clear cached openapi schema so it regenerates + if hasattr(app, "openapi_schema"): + app.openapi_schema = None + + try: + server_name = name or getattr(app, "title", None) or "fastapi" + mcp_server = FastMCP.from_fastapi(app=app, name=server_name) + # Use include_router for synchronous import (tools available immediately) + self.include_router(mcp_server, prefix=prefix) # type: ignore + finally: + # Restore original states + for route in hidden_routes: + route.include_in_schema = False + if hidden_routes and hasattr(app, "openapi_schema"): + app.openapi_schema = None # Clear cache again + return self def connect_server( @@ -126,7 +152,7 @@ def connect_server( *, prefix: str | None = None, ) -> Any: - """Mount an MCPServer or FastMCP instance directly. + """Import an MCPServer or FastMCP instance's tools directly. Example: ```python @@ -147,5 +173,5 @@ def greet(name: str) -> str: result = await env.call_tool("greet", name="World") ``` """ - self.mount(server, prefix=prefix) + self.include_router(server, prefix=prefix) # type: ignore return self diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py index 0b2e66c3..599c5af0 100644 --- a/hud/environment/connectors/remote.py +++ b/hud/environment/connectors/remote.py @@ -20,14 +20,14 @@ class RemoteConnectorMixin(MCPConfigConnectorMixin): - """Mixin providing remote connection methods.""" + """Mixin providing remote connection methods. + + Note: include_router() is inherited from MCPServer (via FastMCP). 
+ """ # Store hub configs for trace serialization _hub_configs: list[HubConfig] - def mount(self, server: Any, *, prefix: str | None = None) -> None: - raise NotImplementedError - def connect_hub( self, slug: str, @@ -183,5 +183,5 @@ def connect_openapi( client=client, name=name or "openapi", ) - self.mount(mcp_server, prefix=prefix) + self.include_router(mcp_server, prefix=prefix) # type: ignore return self diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 482034da..0bc48a52 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -27,6 +27,10 @@ logger = logging.getLogger(__name__) +# Suppress verbose fastmcp logging +logging.getLogger("fastmcp.server.server").setLevel(logging.WARNING) +logging.getLogger("fastmcp.server.openapi").setLevel(logging.WARNING) + # Type alias for async callables (no-arg functions that return awaitable) AsyncCallable = Callable[[], Awaitable[Any]] @@ -132,7 +136,7 @@ def __init__( self._setup_calls: list[tuple[str, dict[str, Any]]] = [] self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] - # Task prompt - set by connect_task or manually + # Default prompt - set by connect_task (EvalContext has per-run prompt) self.prompt: str | None = None # Track which lifecycle tools we've warned about (only warn once per tool) @@ -190,6 +194,58 @@ def _check_lifecycle_warning(self, name: str) -> None: phase, ) + def _connections_with_tool(self, tool_name: str) -> set[str]: + """Get connection names that have a specific tool. + + Uses cached_tools from each Connector to check availability. + """ + result = set() + for name, connector in self._connections.items(): + tool_names = {t.name for t in connector.cached_tools} + if tool_name in tool_names: + result.add(name) + return result + + async def _broadcast_tool( + self, + tool_name: str, + **kwargs: Any, + ) -> dict[str, Any]: + """Broadcast a tool call to all connections that have the tool. + + Automatically filters to only connections where the tool exists + (based on cached_tools from initial discovery). 
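(Not part of the patch: a sketch of consuming the per-connection result dict, assuming a connected hub that exposes `_hud_submit`; the hub slug is hypothetical and `_broadcast_tool` is an internal helper rather than public API.)

```python
import asyncio

from hud.environment import Environment


async def main() -> None:
    env = Environment("demo")
    env.connect_hub("my-org/browser")  # hypothetical hub slug

    async with env:
        # Connections without the tool are skipped; per-connection failures
        # come back as exception values instead of being raised.
        results = await env._broadcast_tool("_hud_submit", script="checkout", answer="done")
        for name, outcome in results.items():
            status = f"failed: {outcome}" if isinstance(outcome, Exception) else "ok"
            print(f"{name}: {status}")


asyncio.run(main())
```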
+ + Args: + tool_name: Name of the tool to call + **kwargs: Arguments to pass to the tool + + Returns: + Dict mapping connection name to result (or exception) + """ + import asyncio + + # Only call connections that have this tool + targets = self._connections_with_tool(tool_name) + if not targets: + return {} + + results: dict[str, Any] = {} + + async def call_one(name: str) -> None: + connector = self._connections.get(name) + if not connector or not connector.client: + return + try: + results[name] = await connector.client.call_tool(tool_name, **kwargs) + logger.debug("Broadcast '%s' to '%s' succeeded", tool_name, name) + except Exception as e: + results[name] = e + logger.debug("Broadcast '%s' to '%s' failed: %s", tool_name, name, e) + + await asyncio.gather(*[call_one(n) for n in targets], return_exceptions=True) + return results + async def call_tools(self, calls: Any) -> list[Any]: """Call multiple tools, returning results in matching formats.""" if calls is None: @@ -299,7 +355,10 @@ async def __aexit__( async def _build_routing(self) -> None: """Build tool routing from local tools and connection caches.""" - local_tools = await self._tool_manager.list_tools() + # Use get_tools() not list_tools() - it includes mounted servers without + # requiring MCP server communication (via_server=False) + local_tools_dict = await self._tool_manager.get_tools() + local_tools = list(local_tools_dict.values()) self._router.build( local_tools=[t.to_mcp_tool() for t in local_tools], connections=self._connections, @@ -526,7 +585,14 @@ def __repr__(self) -> str: # Eval Creation # ========================================================================= - def __call__(self, script: str | None = None, **args: Any) -> Eval: + def __call__( + self, + script: str | None = None, + *, + _trace: bool = True, + _quiet: bool = False, + **args: Any, + ) -> Eval: """Create an Eval from this environment. 
Returns an Eval that can be entered as a context manager or passed @@ -534,6 +600,8 @@ def __call__(self, script: str | None = None, **args: Any) -> Eval: Args: script: Optional script name to run (from @env.script) + _trace: Whether to send trace data to backend (default True) + _quiet: Whether to suppress printing links (default False) **args: Arguments for the script Returns: @@ -565,9 +633,11 @@ async def checkout(user_id: str): from hud.eval.eval import Eval return Eval( - env_config=self._get_env_config(), + env=self, # Pass live environment for local tools/scripts script=script, args=args, + _trace=_trace, + _quiet=_quiet, ) @classmethod diff --git a/hud/environment/integrations/openai.py b/hud/environment/integrations/openai.py index 015e8ada..9d553ac0 100644 --- a/hud/environment/integrations/openai.py +++ b/hud/environment/integrations/openai.py @@ -3,12 +3,13 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from hud.environment.utils.schema import ensure_strict_schema if TYPE_CHECKING: import mcp.types as mcp_types + from openai.types.chat import ChatCompletionToolUnionParam __all__ = ["OpenAIMixin"] @@ -42,7 +43,7 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: # Format Conversion (no external deps) # ========================================================================= - def as_openai_chat_tools(self, *, strict: bool = False) -> list[dict[str, Any]]: + def as_openai_chat_tools(self, *, strict: bool = False) -> list[ChatCompletionToolUnionParam]: """Convert to OpenAI Chat Completions tool format. Args: @@ -67,7 +68,7 @@ def as_openai_chat_tools(self, *, strict: bool = False) -> list[dict[str, Any]]: # results are {"role": "tool", "tool_call_id": ..., "content": ...} ``` """ - tools = [] + tools: list[ChatCompletionToolUnionParam] = [] for t in self.as_tools(): schema = dict(t.inputSchema) if t.inputSchema else {"type": "object", "properties": {}} @@ -75,15 +76,18 @@ def as_openai_chat_tools(self, *, strict: bool = False) -> list[dict[str, Any]]: schema = ensure_strict_schema(schema) tools.append( - { - "type": "function", - "function": { - "name": t.name, - "description": t.description or "", - "parameters": schema, - **({"strict": True} if strict else {}), + cast( + "ChatCompletionToolUnionParam", + { + "type": "function", + "function": { + "name": t.name, + "description": t.description or "", + "parameters": schema, + **({"strict": True} if strict else {}), + }, }, - } + ) ) return tools diff --git a/hud/environment/router.py b/hud/environment/router.py index 962b0802..a1f423bf 100644 --- a/hud/environment/router.py +++ b/hud/environment/router.py @@ -63,16 +63,21 @@ def build( """Build routing from local tools and connection caches. Local tools always have priority over remote tools. + Tools starting with '_' are internal and hidden from listing + (but still callable directly). 
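(Illustrative sketch of the underscore convention, assuming `as_tools()` reflects the router's visible list as described above; the tool names are throwaway.)

```python
import asyncio

from hud.environment import Environment

env = Environment("demo")


@env.tool()
def visible(x: int) -> int:
    """Listed to the agent."""
    return x + 1


@env.tool()
def _internal(x: int) -> int:
    """Hidden from listings but still routable."""
    return x * 2


async def main() -> None:
    async with env:
        names = [t.name for t in env.as_tools()]
        print(names)  # expected to include "visible" but not "_internal"
        print(await env.call_tool("_internal", x=3))  # still callable directly


asyncio.run(main())
```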
""" self.clear() seen: dict[str, str] = {} # Local tools first (always priority) for tool in local_tools: + # Always add to routing (so tool is callable) seen[tool.name] = LOCAL_CONNECTION self._routing[tool.name] = LOCAL_CONNECTION self._local_names.add(tool.name) - self._tools.append(tool) + # Only add to visible list if not internal (underscore prefix) + if not tool.name.startswith("_"): + self._tools.append(tool) # Remote connections in order for conn_name in connection_order: @@ -88,9 +93,12 @@ def build( continue self._tools = [t for t in self._tools if t.name != name] + # Always add to routing (so tool is callable) seen[name] = conn_name self._routing[name] = conn_name - self._tools.append(tool) + # Only add to visible list if not internal (underscore prefix) + if not name.startswith("_"): + self._tools.append(tool) logger.debug("Router: %d tools (%d local)", len(self._tools), len(self._local_names)) diff --git a/hud/environment/scripts.py b/hud/environment/scripts.py index 00ebd88e..7b077a46 100644 --- a/hud/environment/scripts.py +++ b/hud/environment/scripts.py @@ -13,6 +13,7 @@ from fastmcp.prompts import PromptManager from fastmcp.resources import ResourceManager + from fastmcp.tools import ToolManager __all__ = ["ScriptMixin"] @@ -26,33 +27,210 @@ class ScriptMixin: - First yield: prompt string (setup phase) - Second yield: reward float (evaluate phase) + The script can receive the agent's answer via yield: + answer = yield "Do the task" + yield 1.0 if "success" in answer else 0.0 + + The answer is passed via the hud_submit tool or ctx.submit(). + The decorator registers both an MCP prompt and resource with the same - identifier (script:{name}), linked by session state. + identifier ({env_name}:{script_name}), linked by session state. Example: @env.script() async def search_cats(url: str): await env.call_tool("navigate", url=url) - yield "Find all cat images on the page" + answer = yield "Find all cat images on the page" result = await env.call_tool("count_cats") - yield float(result > 0) + yield float(result > 0 or "found" in answer.lower()) """ # These come from Environment/MCPServer name: str _prompt_manager: PromptManager _resource_manager: ResourceManager + _tool_manager: ToolManager # Script state - _scripts: dict[str, Callable[..., AsyncGenerator[Any, None]]] - _script_sessions: dict[str, AsyncGenerator[Any, None]] # session_id -> generator + _scripts: dict[str, Callable[..., AsyncGenerator[Any, Any]]] + _script_sessions: dict[str, AsyncGenerator[Any, Any]] # session_id -> generator _script_latest: dict[str, str] # script_name -> latest session_id + _script_answers: dict[str, str] # script_name -> submitted answer def _init_scripts(self) -> None: """Initialize script state. Called from Environment.__init__.""" self._scripts = {} self._script_sessions = {} self._script_latest = {} + self._script_answers = {} + + # Register _hud_submit tool (underscore = hidden from agent) + self._register_hud_submit_tool() + + async def submit(self, script: str, answer: str) -> None: + """Submit the agent's answer for a script's evaluate phase. + + This stores the answer locally and broadcasts to connected hubs + that have the _hud_submit tool (auto-detected by Environment). 
+ + Args: + script: Name of the script (without env prefix) + answer: The agent's answer/result to submit + + Example: + # Direct call with script name + await env.submit("checkout", "Order completed successfully") + + # Or via EvalContext (knows its own script) + await ctx.submit("Order completed successfully") + """ + # Store locally for our scripts + self._script_answers[script] = answer + logger.debug("Stored answer for script '%s': %s...", + script, answer[:50] if len(answer) > 50 else answer) + + # Broadcast to connections that have _hud_submit + # Environment._broadcast_tool auto-filters to connections with the tool + await self._broadcast_tool( # type: ignore[attr-defined] + "_hud_submit", + script=script, + answer=answer, + ) + + def _register_hud_submit_tool(self) -> None: + """Register the _hud_submit tool for receiving agent answers. + + Named with underscore prefix to hide from agent tool listings. + """ + from fastmcp.tools import Tool + + script_self = self + + async def _hud_submit(script: str, answer: str) -> str: + """Submit the agent's answer for a script's evaluate phase. + + Internal tool - called by Environment.submit() on connected hubs. + + Args: + script: Name of the script (without env prefix) + answer: The agent's answer/result to submit + """ + # Store locally (don't broadcast - we ARE the target) + script_self._script_answers[script] = answer + logger.debug("_hud_submit received answer for script '%s': %s...", + script, answer[:50] if len(answer) > 50 else answer) + return f"Answer submitted for script '{script}'" + + # Register the tool with underscore name + tool = Tool.from_function(_hud_submit) + self._tool_manager.add_tool(tool) + logger.debug("Registered _hud_submit tool") + + async def run_script_setup(self, script_name: str, args: dict[str, Any]) -> str | None: + """Run a script's setup phase and return the prompt. + + Handles both local scripts (registered via @env.script) and remote + scripts (via MCP prompt). 
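(A minimal end-to-end sketch of the setup → submit → evaluate flow these helpers implement, using a locally registered script; names and values are illustrative.)

```python
import asyncio

from hud.environment import Environment

env = Environment("demo")


@env.script("echo")
async def echo(expected: str):
    # Setup phase: yield the prompt, receive the submitted answer back
    answer = yield f"Reply with exactly: {expected}"
    # Evaluate phase: yield the reward
    yield 1.0 if answer and expected in answer else 0.0


async def main() -> None:
    async with env:
        prompt = await env.run_script_setup("echo", {"expected": "hello"})
        print(prompt)                      # "Reply with exactly: hello"
        await env.submit("echo", "hello")  # stand-in for the agent's answer
        print(await env.run_script_evaluate("echo"))  # expected 1.0


asyncio.run(main())
```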
+ + Args: + script_name: Name of the script to run + args: Arguments to pass to the script + + Returns: + The prompt string from the script's setup phase, or None if failed + """ + # Check if script is registered locally + if script_name in self._scripts: + # Local script - run setup via generator + script_fn = self._scripts[script_name] + gen = script_fn(**args) + + # Run setup phase (code before first yield) + prompt = await gen.__anext__() + + # Store generator for evaluate phase + session_id = uuid.uuid4().hex[:8] + self._script_sessions[session_id] = gen + self._script_latest[script_name] = session_id + + logger.debug( + "Script %s setup complete, session=%s", + script_name, + session_id, + ) + return str(prompt) + else: + # Remote script - call via MCP prompt + # Format: {env_name}:{script_name} (use source env name if available) + env_name = getattr(self, "_source_env_name", None) or self.name + safe_env_name = env_name.replace("_", "-") + prompt_id = f"{safe_env_name}:{script_name}" + try: + result = await self.get_prompt(prompt_id, args) # type: ignore[attr-defined] + if result.messages: + first_msg = result.messages[0] + content = first_msg.content + if hasattr(content, "text") and isinstance(content.text, str): # type: ignore[union-attr] + return content.text # type: ignore[union-attr] + elif isinstance(content, str): + return content + except Exception as e: + logger.warning("Failed to get script prompt: %s", e) + return None + + async def run_script_evaluate(self, script_name: str) -> float | None: + """Run a script's evaluate phase and return the reward. + + Uses the submitted answer (if any) via gen.asend(). + Handles both local and remote scripts. + + Args: + script_name: Name of the script to evaluate + + Returns: + The reward from the script's evaluate phase, or None if failed + """ + # Check if we have a stored generator (local script) + session_id = self._script_latest.get(script_name) + if session_id: + gen = self._script_sessions.pop(session_id, None) + if gen: + # Get submitted answer (if any) + answer = self._script_answers.pop(script_name, None) + + try: + # Use asend to pass the answer to the script + reward = await gen.asend(answer) + logger.debug( + "Script %s evaluate complete, answer=%s, reward=%s", + script_name, + answer[:50] if answer and len(answer) > 50 else answer, + reward, + ) + return float(reward) + except StopAsyncIteration: + # Generator ended without second yield - assume success + return 1.0 + finally: + # Clean up latest pointer + if self._script_latest.get(script_name) == session_id: + del self._script_latest[script_name] + + # Remote script - read via MCP resource (use source env name if available) + env_name = getattr(self, "_source_env_name", None) or self.name + safe_env_name = env_name.replace("_", "-") + resource_id = f"{safe_env_name}:{script_name}" + try: + contents = await self.read_resource(resource_id) # type: ignore[attr-defined] + if contents: + first_content = contents[0] + if hasattr(first_content, "text") and isinstance(first_content.text, str): # type: ignore[union-attr] + data = json.loads(first_content.text) # type: ignore[union-attr] + if "reward" in data: + return float(data["reward"]) + except Exception as e: + logger.warning("Failed to get script reward: %s", e) + return None def script( self, @@ -91,7 +269,9 @@ def decorator( fn: Callable[..., AsyncGenerator[Any, None]], ) -> Callable[..., AsyncGenerator[Any, None]]: script_name = name or fn.__name__ - script_id = f"{self.name}:{script_name}" + # Sanitize env name for 
URI scheme (no underscores allowed) + safe_env_name = self.name.replace("_", "-") + script_id = f"{safe_env_name}:{script_name}" script_desc = description or fn.__doc__ or f"Script: {script_name}" # Store the generator function @@ -149,36 +329,41 @@ async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: # Register RESOURCE - runs evaluate, returns reward async def resource_handler() -> str: # Get latest session for this script - session_id = self._script_latest.get(script_name) + session_id = script_self._script_latest.get(script_name_ref) if not session_id: raise ValueError( - f"No active session for script '{script_name}'. " + f"No active session for script '{script_name_ref}'. " "Call the prompt first to run setup." ) - gen = self._script_sessions.pop(session_id, None) + gen = script_self._script_sessions.pop(session_id, None) if gen is None: raise ValueError( f"Session '{session_id}' not found or already evaluated." ) + # Get submitted answer (if any) + answer = script_self._script_answers.pop(script_name_ref, None) + # Run evaluate phase (code after first yield) + # Use asend to pass the answer (or None if not submitted) try: - reward = await gen.__anext__() + reward = await gen.asend(answer) except StopAsyncIteration: # Generator ended without second yield - assume success reward = 1.0 logger.debug( - "Script %s evaluate complete, session=%s, reward=%s", - script_name, + "Script %s evaluate complete, session=%s, answer=%s, reward=%s", + script_name_ref, session_id, + answer[:50] if answer and len(answer) > 50 else answer, reward, ) # Clean up latest pointer if it matches - if self._script_latest.get(script_name) == session_id: - del self._script_latest[script_name] + if script_self._script_latest.get(script_name_ref) == session_id: + del script_self._script_latest[script_name_ref] return json.dumps({"reward": float(reward)}) diff --git a/hud/eval/context.py b/hud/eval/context.py index 73ceab04..b2ef831c 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -91,6 +91,8 @@ def __init__( code_snippet: str | None = None, env_config: dict[str, Any] | None = None, task: Task | None = None, + trace: bool = True, + quiet: bool = False, **env_kwargs: Any, ) -> None: """Initialize EvalContext. 
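(Sketch of how the new flags are expected to reach the context: `_trace=False` should skip backend reporting and `_quiet=True` requests no printed link; the script, tool, and reward check are illustrative.)

```python
import asyncio

from hud.environment import Environment

env = Environment("demo")


@env.tool()
def ping() -> str:
    """Trivial tool for a smoke test."""
    return "pong"


@env.script("smoke")
async def smoke():
    answer = yield "Call the ping tool and reply with its output."
    yield 1.0 if answer and "pong" in answer else 0.0


async def main() -> None:
    ev = env("smoke", _trace=False, _quiet=True)  # flags flow onto the EvalContext
    async with ev as ctx:
        result = await ctx.call_tool("ping")
        await ctx.submit(str(result))
    print(ctx.reward)  # expected 1.0 if "pong" made it into the answer


asyncio.run(main())
```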
@@ -106,6 +108,8 @@ def __init__( code_snippet: Code being evaluated (for reproducibility) env_config: Environment configuration dict task: Task definition (if loaded from slug) + trace: Whether to send trace data to backend (default True) + quiet: Whether to suppress printing links (default False) **env_kwargs: Additional kwargs passed to Environment.__init__ """ # Initialize Environment @@ -130,8 +134,10 @@ def __init__( # Variant assignment self.variants: dict[str, Any] = variants or {} - # User-settable + # User-settable (per-run values, override Environment defaults) + self.prompt: str | None = None # From script setup or task self.reward: float | None = None + self.answer: str | None = None # Agent's submitted answer # Error tracking self.error: BaseException | None = None @@ -156,7 +162,9 @@ def __init__( self._completed_at: datetime | None = None self._token: contextvars.Token[dict[str, str] | None] | None = None self._is_summary: bool = False # True for summary contexts (skip trace) - self._suppress_link: bool = False # True to suppress printing eval link + self._suppress_link: bool = quiet # True to suppress printing eval link + self._trace_enabled: bool = trace # Whether to send trace data to backend + self._script_name: str | None = None # Current script name (for submit) def _apply_task(self, task: Task) -> None: """Apply a Task definition to this environment.""" @@ -198,6 +206,8 @@ def from_environment( variants: dict[str, Any] | None = None, code_snippet: str | None = None, env_config: dict[str, Any] | None = None, + trace: bool = True, + quiet: bool = False, ) -> EvalContext: """Create an EvalContext that copies configuration from an existing Environment. @@ -226,6 +236,8 @@ def from_environment( variants=variants, code_snippet=code_snippet, env_config=env_config, + trace=trace, + quiet=quiet, ) # Copy connections from parent - each connector is copied so parallel @@ -235,6 +247,23 @@ def from_environment( ctx._setup_calls = env._setup_calls.copy() ctx._evaluate_calls = env._evaluate_calls.copy() + # Copy scripts (definitions) by reference - they don't change + ctx._scripts = getattr(env, "_scripts", {}) + # Create fresh session state for this eval (parallel evals each need their own) + ctx._script_sessions = {} + ctx._script_latest = {} + ctx._script_answers = {} + + # Store source env name for remote script lookups + ctx._source_env_name = env.name + + # Copy managers by reference (they hold local tools, prompts, resources) + # This allows ctx.call_tool(), ctx.get_prompt(), ctx.read_resource() to work + # for locally defined tools/scripts + ctx._tool_manager = env._tool_manager + ctx._prompt_manager = env._prompt_manager + ctx._resource_manager = env._resource_manager + # Copy prompt if env.prompt: ctx.prompt = env.prompt @@ -254,6 +283,8 @@ def from_task( index: int = 0, variants: dict[str, Any] | None = None, code_snippet: str | None = None, + trace: bool = True, + quiet: bool = False, ) -> EvalContext: """Create an EvalContext from a Task definition. 
@@ -270,6 +301,8 @@ def from_task( index: Index in parallel execution variants: Variant assignment code_snippet: Code being evaluated + trace: Whether to send trace data to backend + quiet: Whether to suppress printing links """ import warnings @@ -291,28 +324,27 @@ def from_task( variants=variants, code_snippet=code_snippet, task=task, + trace=trace, + quiet=quiet, ) # ========================================================================= # Summary Context - Attribute Access Control # ========================================================================= - # Attributes accessible on summary context (everything else raises) + # Attributes accessible on summary context (everything else raises ParallelEvalComplete) _SUMMARY_ALLOWED = frozenset( { # Results and metadata "results", "reward", "error", + "success", + # IDs "trace_id", "job_id", "group_id", "index", - "variants", - "eval_name", - "duration", - "success", - "done" # Private attrs "_is_summary", "_suppress_link", @@ -379,11 +411,10 @@ def _build_base_payload(self) -> EvalPayload: env_config_model = EnvConfig(**self._eval_env_config) return EvalPayload( - task_name=self.eval_name, + job_name=self.eval_name, prompt=self.prompt, code_snippet=self.code_snippet, env_config=env_config_model, - all_hubs=self._all_hubs, job_id=self.job_id, group_id=self.group_id, variants=self.variants if self.variants else None, @@ -405,8 +436,36 @@ async def log(self, metrics: dict[str, Any]) -> None: except Exception as e: logger.warning("Failed to log metrics: %s", e) + async def submit(self, answer: str) -> None: + """Submit the agent's answer for script evaluation. + + Delegates to Environment.submit() with the current script name. + The answer will be passed to the script's evaluate phase via + `yield`, e.g.: `answer = yield "Do the task"` + + Args: + answer: The agent's final answer/result to submit + + Example: + async with env("checkout", product="laptop") as ctx: + response = await agent.run(ctx.prompt) + await ctx.submit(response) + # On exit, script's evaluate phase receives the answer + """ + if not self._script_name: + logger.warning("submit() called but no script is running") + return + + # Store answer on context for display + self.answer = answer + + # Delegate to Environment.submit() which handles storage + broadcast + await super().submit(self._script_name, answer) + async def _eval_enter(self) -> None: """Notify backend that eval has started.""" + if not self._trace_enabled: + return api_key = self._get_eval_api_key() if not settings.telemetry_enabled or not api_key: return @@ -424,6 +483,8 @@ async def _eval_enter(self) -> None: async def _eval_exit(self, error_message: str | None = None) -> None: """Notify backend that eval has completed.""" + if not self._trace_enabled: + return api_key = self._get_eval_api_key() if not settings.telemetry_enabled or not api_key: return @@ -454,20 +515,15 @@ async def _eval_exit(self, error_message: str | None = None) -> None: # ========================================================================= async def __aenter__(self) -> Self: - """Enter eval context - start tracking and connect environment.""" - # Summary contexts skip trace tracking (parallel results already tracked) + """Enter eval context - connect environment and set trace headers.""" if self._is_summary: return self - # Start eval tracking + # Start tracking self._started_at = datetime.now(UTC) self._token = _current_trace_headers.set(self.headers) - # Notify backend - await self._eval_enter() - self._print_eval_link() - - # 
Connect environment (parent class) + # Connect environment (MCP servers, tools) await super().__aenter__() return self @@ -477,11 +533,12 @@ async def __aexit__( exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> None: + ) -> bool: """Exit eval context - disconnect and report.""" # Summary contexts skip trace tracking (parallel results already tracked) + # Suppress ParallelEvalComplete - it's expected for skipping body re-execution if self._is_summary: - return + return exc_type is ParallelEvalComplete self._completed_at = datetime.now(UTC) @@ -501,6 +558,7 @@ async def __aexit__( # Notify backend await self._eval_exit(error_msg) + return False def __repr__(self) -> str: return f"EvalContext({self.trace_id[:8]}..., name={self.eval_name!r}, reward={self.reward})" diff --git a/hud/eval/display.py b/hud/eval/display.py index 3f6f5c43..a7798504 100644 --- a/hud/eval/display.py +++ b/hud/eval/display.py @@ -144,7 +144,8 @@ def print_eval_stats( if show_details and len(completed) <= 50: table = Table(title="Per-Eval Details", show_header=True, header_style="bold") table.add_column("#", style="dim", justify="right", width=4) - table.add_column("Variants", style="cyan", max_width=30) + table.add_column("Variants", style="cyan", max_width=35) + table.add_column("Answer", style="white", max_width=25) table.add_column("Reward", justify="right", style="green", width=8) table.add_column("Duration", justify="right", width=10) table.add_column("Status", justify="center", width=8) @@ -152,6 +153,7 @@ def print_eval_stats( for ctx in completed: idx_str = str(ctx.index) variants_str = _format_variants(ctx.variants) if ctx.variants else "-" + answer_str = _truncate(ctx.answer, 30) if ctx.answer else "-" reward_str = f"{ctx.reward:.3f}" if ctx.reward is not None else "-" duration_str = f"{ctx.duration:.2f}s" if ctx.duration > 0 else "-" @@ -162,7 +164,7 @@ def print_eval_stats( else: status = "[yellow]○[/yellow]" - table.add_row(idx_str, variants_str, reward_str, duration_str, status) + table.add_row(idx_str, variants_str, answer_str, reward_str, duration_str, status) console.print(table) @@ -179,7 +181,16 @@ def _format_variants(variants: dict[str, Any]) -> str: return "-" parts = [f"{k}={v}" for k, v in variants.items()] result = ", ".join(parts) - return result[:30] + "..." if len(result) > 30 else result + return result[:35] + "..." if len(result) > 35 else result + + +def _truncate(text: str | None, max_len: int) -> str: + """Truncate text to max length.""" + if not text: + return "-" + # Replace newlines with spaces for display + text = text.replace("\n", " ").strip() + return text[:max_len] + "..." if len(text) > max_len else text def _print_eval_stats_basic( diff --git a/hud/eval/eval.py b/hud/eval/eval.py index 2c9830e0..2800c48d 100644 --- a/hud/eval/eval.py +++ b/hud/eval/eval.py @@ -34,30 +34,49 @@ from hud.eval.context import EvalContext -__all__ = ["Eval"] +__all__ = ["Eval", "build_eval_name"] logger = logging.getLogger(__name__) +def build_eval_name(script: str | None, args: dict[str, Any] | None) -> str: + """Build descriptive name: 'script with val1, val2, ...'""" + if not script: + return "eval" + if not args: + return script + + val_parts = [] + for v in list(args.values())[:3]: # Max 3 values + v_str = repr(v) if isinstance(v, str) else str(v) + if len(v_str) > 25: + v_str = v_str[:22] + "..." 
+ val_parts.append(v_str) + + if val_parts: + return f"{script} with {', '.join(val_parts)}" + return script + + @dataclass class Eval: """A runnable evaluation unit (data class). Holds the configuration to create an EvalContext: - - env_config: How to create/connect the environment + - env: The environment (live instance or serialized config) - script: Optional script name to run (from @env.script) - args: Arguments for the script When entered as a context manager, creates an EvalContext. Attributes: - env_config: Serializable environment configuration + env: Environment instance (local) or EnvConfig dict (remote) or None (blank) script: Script name to run (None for env-only) args: Script arguments """ - # Core config - env_config: dict[str, Any] | None = None + # Core config - env can be live Environment or serialized config + env: Any = None # Environment | dict[str, Any] | None script: str | None = None args: dict[str, Any] = field(default_factory=dict) @@ -70,14 +89,28 @@ class Eval: variants: dict[str, Any] = field(default_factory=dict, repr=False) code_snippet: str | None = field(default=None, repr=False) _suppress_link: bool = field(default=False, repr=False) + _trace: bool = field(default=True, repr=False) + _quiet: bool = field(default=False, repr=False) # Runtime state _ctx: EvalContext | None = field(default=None, repr=False) + # Backwards compat alias + @property + def env_config(self) -> dict[str, Any] | None: + """Get serializable env config (for backwards compat and backend).""" + from hud.environment import Environment + + if isinstance(self.env, Environment): + return self.env._get_env_config() + elif isinstance(self.env, dict): + return self.env + return None + def copy(self) -> Eval: """Create a copy of this Eval for parallel execution.""" return Eval( - env_config=self.env_config, + env=self.env, # Share reference - from_environment handles copying script=self.script, args=self.args.copy(), trace_id=None, # Each copy gets unique trace_id @@ -88,24 +121,59 @@ def copy(self) -> Eval: variants=self.variants.copy(), code_snippet=self.code_snippet, _suppress_link=self._suppress_link, + _trace=self._trace, + _quiet=self._quiet, ) def to_eval_context(self) -> EvalContext: """Convert this Eval to an EvalContext. - Creates an EvalContext with environment from env_config and - script info stored for setup/evaluate phases. + Creates an EvalContext from the environment (live or from config). + Also handles deprecated Task objects stored in _task attribute. """ from hud.environment import Environment from hud.eval.context import EvalContext - # Create environment from config - env = Environment.from_config(self.env_config) if self.env_config else Environment("eval") + # Check for deprecated Task (backwards compat) + task = getattr(self, "_task", None) + if task is not None: + import warnings + warnings.warn( + "Task objects are deprecated. 
Use Eval from env() instead.", + DeprecationWarning, + stacklevel=3, + ) + ctx = EvalContext.from_task( + task=task, + api_key=self.api_key, + job_id=self.job_id, + group_id=self.group_id, + index=self.index, + variants=self.variants, + code_snippet=self.code_snippet, + trace=self._trace, + quiet=self._quiet, + ) + ctx._suppress_link = self._suppress_link + return ctx + + # Get or create environment + if isinstance(self.env, Environment): + # Local - use live environment (from_environment handles copying) + source_env = self.env + elif isinstance(self.env, dict): + # Remote/config - create fresh from config + source_env = Environment.from_config(self.env) + else: + # Blank + source_env = Environment("eval") + + eval_name = build_eval_name(self.script, self.args) # Create EvalContext from environment ctx = EvalContext.from_environment( - env=env, - name=self.script or "eval", + env=source_env, + name=eval_name, trace_id=self.trace_id, api_key=self.api_key, job_id=self.job_id, @@ -116,18 +184,31 @@ def to_eval_context(self) -> EvalContext: env_config=self.env_config, ) ctx._suppress_link = self._suppress_link + ctx._trace_enabled = self._trace return ctx async def __aenter__(self) -> EvalContext: - """Enter eval context - create EvalContext and enter it.""" + """Enter eval context. + + Order of operations: + 1. Create EvalContext from environment config + 2. Connect environment (MCP servers, etc.) + 3. Run script setup (if script) → sets ctx.prompt + 4. Notify backend (with prompt now set) + 5. Print trace link + """ self._ctx = self.to_eval_context() - await self._ctx.__aenter__() + await self._ctx.__aenter__() # Connect env, set trace headers - # If we have a script, run its setup phase + # Run script setup (sets prompt) if self.script: await self._run_script_setup() + # Notify backend with prompt included + await self._ctx._eval_enter() + self._ctx._print_eval_link() + return self._ctx async def __aexit__( @@ -153,95 +234,20 @@ async def _run_script_setup(self) -> None: if self._ctx is None or self.script is None: return - # Check if script is registered locally - scripts = getattr(self._ctx, "_scripts", {}) - if self.script in scripts: - # Local script - run setup via generator - import uuid - - script_fn = scripts[self.script] - gen = script_fn(**self.args) - - # Run setup phase (code before first yield) - prompt = await gen.__anext__() + # Store script name on context for ctx.submit() + self._ctx._script_name = self.script - # Store generator for evaluate phase - session_id = uuid.uuid4().hex[:8] - script_sessions = getattr(self._ctx, "_script_sessions", {}) - script_latest = getattr(self._ctx, "_script_latest", {}) - script_sessions[session_id] = gen - script_latest[self.script] = session_id - - # Set prompt on context - self._ctx.prompt = str(prompt) - - logger.debug( - "Script %s setup complete, session=%s", - self.script, - session_id, - ) - else: - # Remote script - call via MCP prompt - # Format: {env_name}:{script_name} - env_name = self._ctx.name if self._ctx else "eval" - prompt_id = f"{env_name}:{self.script}" - try: - result = await self._ctx.get_prompt(prompt_id, self.args) - if result.messages: - # Extract prompt from first message - first_msg = result.messages[0] - content = first_msg.content - # Handle TextContent which has .text attribute - if hasattr(content, "text") and isinstance(content.text, str): # type: ignore[union-attr] - self._ctx.prompt = content.text # type: ignore[union-attr] - elif isinstance(content, str): - self._ctx.prompt = content - except Exception 
as e: - logger.warning("Failed to get script prompt: %s", e) + # Delegate to ScriptMixin.run_script_setup + prompt = await self._ctx.run_script_setup(self.script, self.args) + if prompt: + self._ctx.prompt = prompt async def _run_script_evaluate(self) -> None: """Run the script's evaluate phase (get reward).""" if self._ctx is None or self.script is None: return - # Check if we have a stored generator (local script) - script_latest = getattr(self._ctx, "_script_latest", {}) - session_id = script_latest.get(self.script) - if session_id: - script_sessions = getattr(self._ctx, "_script_sessions", {}) - gen = script_sessions.pop(session_id, None) - if gen: - try: - reward = await gen.__anext__() - self._ctx.reward = float(reward) - logger.debug( - "Script %s evaluate complete, reward=%s", - self.script, - reward, - ) - except StopAsyncIteration: - # Generator ended without second yield - assume success - self._ctx.reward = 1.0 - - # Clean up latest pointer - if script_latest.get(self.script) == session_id: - del script_latest[self.script] - return - - # Remote script - read via MCP resource - # Format: {env_name}:{script_name} - env_name = self._ctx.name if self._ctx else "eval" - resource_id = f"{env_name}:{self.script}" - try: - import json - - contents = await self._ctx.read_resource(resource_id) - if contents: - first_content = contents[0] - # Handle TextResourceContents which has .text attribute - if hasattr(first_content, "text") and isinstance(first_content.text, str): # type: ignore[union-attr] - data = json.loads(first_content.text) # type: ignore[union-attr] - if "reward" in data: - self._ctx.reward = float(data["reward"]) - except Exception as e: - logger.warning("Failed to get script reward: %s", e) + # Delegate to ScriptMixin.run_script_evaluate + reward = await self._ctx.run_script_evaluate(self.script) + if reward is not None: + self._ctx.reward = reward diff --git a/hud/eval/instrument.py b/hud/eval/instrument.py index 9db50c4e..0024034e 100644 --- a/hud/eval/instrument.py +++ b/hud/eval/instrument.py @@ -25,15 +25,19 @@ def _get_trace_headers() -> dict[str, str] | None: def _is_hud_url(url_str: str) -> bool: """Check if URL is a HUD service (inference or MCP).""" - # Extract hostnames from settings URLs - gateway_host = urlparse(settings.hud_gateway_url).netloc - mcp_host = urlparse(settings.hud_mcp_url).netloc - - # Parse the request URL and check against known HUD hosts parsed = urlparse(url_str) request_host = parsed.netloc or url_str.split("/")[0] - - return request_host in (gateway_host, mcp_host) + + # Check for known HUD domains (works for any subdomain) + if request_host.endswith((".hud.ai", ".hud.so")): + return True + + # Also check settings URLs + known_hosts = { + urlparse(settings.hud_gateway_url).netloc, + urlparse(settings.hud_mcp_url).netloc, + } + return request_host in known_hosts def _httpx_request_hook(request: Any) -> None: diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 3abe6d49..9eb56a9f 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any from hud.eval.display import print_complete, print_eval_stats, print_link +from hud.eval.types import ParallelEvalComplete from hud.eval.parallel import ( ASTExtractionError, expand_variants, @@ -65,14 +66,13 @@ def _get_eval_name( tasks: List of Task objects (deprecated) Returns: - Name like "evalset", script name, or "eval" if no source + Name like "script with val1, val2" or "eval" if no source """ - # If we have Eval objects, use first 
script name - if evals: - first_eval = evals[0] - if first_eval.script: - return first_eval.script - return "eval" + from hud.eval.eval import build_eval_name + + # If we have Eval objects, derive name from first one + if evals and evals[0].script: + return build_eval_name(evals[0].script, evals[0].args) # Deprecated: If we have tasks with IDs, use first task ID if tasks: @@ -189,7 +189,7 @@ def _eval_from_api(data: dict[str, Any]) -> Eval: from hud.eval.eval import Eval return Eval( - env_config=data.get("env_config"), + env=data.get("env_config"), # Serialized config from backend script=data.get("script"), args=data.get("args", {}), ) @@ -205,6 +205,8 @@ async def run_eval( job_id: str | None = None, api_key: str | None = None, max_concurrent: int | None = None, + trace: bool = True, + quiet: bool = False, ) -> AsyncGenerator[EvalContext, None]: """Standalone eval context manager. @@ -226,6 +228,8 @@ async def run_eval( job_id: Job ID to link to api_key: API key for backend calls max_concurrent: Maximum concurrent evals (None = unlimited) + trace: Whether to send trace data to backend (default True) + quiet: Whether to suppress printing links (default False) Yields: EvalContext: Environment with evaluation tracking @@ -349,38 +353,45 @@ async def run_eval( from hud.eval.context import EvalContext if total_evals == 1: - # Simple case: single eval + # Simple case: single eval - always use Eval for consistent flow if evals: - # Single Eval object - enter it directly single_eval = evals[0] - single_eval.api_key = api_key - single_eval.job_id = job_id - single_eval.variants = variant_combos[0] - single_eval.code_snippet = code_snippet - async with single_eval as ctx: - yield ctx elif tasks: - # Single task - ctx = EvalContext.from_task( - task=tasks[0], + # Wrap deprecated Task in Eval + single_eval = Eval( + env=None, + script=None, api_key=api_key, job_id=job_id, variants=variant_combos[0], code_snippet=code_snippet, + _trace=trace, + _quiet=quiet, ) - async with ctx: - yield ctx + single_eval._task = tasks[0] # type: ignore[attr-defined] else: # Blank eval - ctx = EvalContext( - name="eval", + single_eval = Eval( + env=None, + script=None, api_key=api_key, job_id=job_id, variants=variant_combos[0], code_snippet=code_snippet, + _trace=trace, + _quiet=quiet, ) - async with ctx: - yield ctx + + # Apply common settings + single_eval.api_key = api_key + single_eval.job_id = job_id + single_eval.variants = variant_combos[0] + single_eval.code_snippet = code_snippet + single_eval._trace = trace + single_eval._quiet = quiet + + async with single_eval as ctx: + yield ctx else: # Parallel execution: create implicit job to group traces @@ -389,7 +400,8 @@ async def run_eval( job_url = f"https://hud.ai/jobs/{implicit_job_id}" # Print job URL (not individual trace URLs) - print_link(job_url, f"🚀 Job '{eval_name}'") + if not quiet: + print_link(job_url, f"🚀 {eval_name}") error_occurred = False try: @@ -404,13 +416,15 @@ async def run_eval( api_key=api_key, code_snippet=code_snippet, max_concurrent=max_concurrent, + trace=trace, + quiet=quiet, ) # Create summary context (no trace, just aggregates results) if evals: # Create summary from first eval's env_config ctx = EvalContext( - name=evals[0].script or "eval", + name=eval_name, # Use the same smart name api_key=api_key, job_id=implicit_job_id, env_config=evals[0].env_config, @@ -440,6 +454,9 @@ async def run_eval( error_occurred = any(e.error is not None for e in completed) yield ctx + except ParallelEvalComplete: + # Expected - body re-executed on 
summary context, skip it + pass except Exception: error_occurred = True raise @@ -457,6 +474,8 @@ async def _run_parallel_eval( api_key: str | None, code_snippet: str | None, max_concurrent: int | None, + trace: bool = True, + quiet: bool = False, ) -> list[EvalContext]: """Run parallel evaluation. @@ -495,7 +514,9 @@ async def _run_parallel_eval( eval_copy.index = idx eval_copy.variants = variant eval_copy.code_snippet = code_snippet - eval_copy._suppress_link = True + eval_copy._suppress_link = True # Individual traces don't print links + eval_copy._trace = trace + eval_copy._quiet = quiet eval_objects.append(eval_copy) idx += 1 elif tasks: @@ -505,7 +526,7 @@ async def _run_parallel_eval( for _ in range(group): # Convert Task to Eval (backwards compatibility) task_eval = Eval( - env_config=None, # Task has its own mcp_config + env=None, # Task has its own mcp_config script=None, args={}, api_key=api_key, @@ -515,6 +536,8 @@ async def _run_parallel_eval( variants=variant, code_snippet=code_snippet, _suppress_link=True, + _trace=trace, + _quiet=quiet, ) # Store task reference for EvalContext creation task_eval._task = task # type: ignore[attr-defined] @@ -525,7 +548,7 @@ async def _run_parallel_eval( for variant in variant_combos: for _ in range(group): blank_eval = Eval( - env_config=None, + env=None, script=None, args={}, api_key=api_key, @@ -535,6 +558,8 @@ async def _run_parallel_eval( variants=variant, code_snippet=code_snippet, _suppress_link=True, + _trace=trace, + _quiet=quiet, ) eval_objects.append(blank_eval) idx += 1 @@ -551,63 +576,18 @@ async def _run_parallel_eval( async def run_one(eval_obj: Eval) -> EvalContext: """Run a single Eval and return its EvalContext.""" - # Check if this is a Task-based eval (backwards compat) - task = getattr(eval_obj, "_task", None) - try: - if task is not None: - # Task-based: use EvalContext.from_task - ctx = EvalContext.from_task( - task=task, - api_key=eval_obj.api_key, - job_id=eval_obj.job_id, - group_id=eval_obj.group_id, - index=eval_obj.index, - variants=eval_obj.variants, - code_snippet=eval_obj.code_snippet, - ) - ctx._suppress_link = eval_obj._suppress_link - - if sem: - async with sem, ctx: - await runner(ctx) - else: - async with ctx: - await runner(ctx) - return ctx + if sem: + async with sem, eval_obj as ctx: + await runner(ctx) else: - # Eval-based: enter the Eval directly - if sem: - async with sem, eval_obj as ctx: - await runner(ctx) - else: - async with eval_obj as ctx: - await runner(ctx) - return ctx + async with eval_obj as ctx: + await runner(ctx) + return ctx except Exception as e: logger.warning("Parallel eval %d failed: %s", eval_obj.index, e) - # Create a failed context - if task is not None: - ctx = EvalContext.from_task( - task=task, - api_key=eval_obj.api_key, - job_id=eval_obj.job_id, - group_id=eval_obj.group_id, - index=eval_obj.index, - variants=eval_obj.variants, - code_snippet=eval_obj.code_snippet, - ) - else: - ctx = EvalContext( - name=eval_obj.script or "eval", - api_key=eval_obj.api_key, - job_id=eval_obj.job_id, - group_id=eval_obj.group_id, - index=eval_obj.index, - variants=eval_obj.variants, - code_snippet=eval_obj.code_snippet, - env_config=eval_obj.env_config, - ) + # Create a failed context from the eval + ctx = eval_obj.to_eval_context() ctx.error = e return ctx diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index 26471244..ff9e8411 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -25,7 +25,7 @@ def test_init_defaults(self) -> None: def 
test_init_with_config(self) -> None: """Eval can be initialized with env_config and script.""" config = {"name": "test-env", "hubs": []} - ev = Eval(env_config=config, script="checkout", args={"user_id": "alice"}) + ev = Eval(env=config, script="checkout", args={"user_id": "alice"}) assert ev.env_config == config assert ev.script == "checkout" diff --git a/hud/eval/types.py b/hud/eval/types.py index 86da6957..a6c8b376 100644 --- a/hud/eval/types.py +++ b/hud/eval/types.py @@ -32,11 +32,10 @@ class ParallelEvalComplete(Exception): class EvalPayload(BaseModel): """Base payload for eval enter/exit.""" - task_name: str prompt: str | None = None code_snippet: str | None = None env_config: EnvConfig | None = None - all_hubs: bool = False + job_name: str | None = None job_id: str | None = None group_id: str | None = None variants: dict[str, Any] | None = None diff --git a/hud/server/server.py b/hud/server/server.py index 431592ea..497c1019 100644 --- a/hud/server/server.py +++ b/hud/server/server.py @@ -486,7 +486,6 @@ def _sync_import_router( for key, prompt in router._prompt_manager._prompts.items(): new_key = f"{prefix}_{key}" if prefix else key self._prompt_manager._prompts[new_key] = prompt - # await self.import_server(hidden_router, prefix=None, **kwargs) def _get_docker_logs( self, From 8766595fb8b10156f5e2b7aea9169966ee0a07cd Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 14:21:00 -0800 Subject: [PATCH 20/92] docs update --- docs/quick-links/environments.mdx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/quick-links/environments.mdx b/docs/quick-links/environments.mdx index 64c9a6af..5c583bdc 100644 --- a/docs/quick-links/environments.mdx +++ b/docs/quick-links/environments.mdx @@ -11,7 +11,7 @@ An environment is everything an agent can interact with—your APIs, services, d Start with `hud init` to scaffold an environment—works with existing codebases or from scratch: ```bash -hud init my-env +hud init ``` Every tool is just a function. Decorate it with `@env.tool()` and agents can call it: @@ -66,13 +66,13 @@ Call the environment with a script name and arguments to create an eval: eval = env("checkout", product_name="Laptop") async with hud.eval(eval, group=4) as ctx: + # Connect your agent here. Handle tool calls, run agent loop... response = await client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": ctx.prompt}], tools=ctx.as_openai_chat_tools() ) - # Handle tool calls, run agent loop... - + ctx.submit(response.choices[0].message.content) print(ctx.reward) @@ -86,7 +86,7 @@ Testing your agent loop without hitting real services? Mock mode returns fake re ```python env.mock() -env.mock_tool("search", "Mock search results") +env.mock_tool("search", "Mock search results") # Manual override of mock async with hud.eval(env(), group=4) as ctx: tools = env.as_openai_chat_tools() From 62a553e84c204503692dcc3cea6f962ed594430a Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 14:57:42 -0800 Subject: [PATCH 21/92] lazy mcp use --- hud/cli/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/hud/cli/__init__.py b/hud/cli/__init__.py index 8d3c0e4b..61868bef 100644 --- a/hud/cli/__init__.py +++ b/hud/cli/__init__.py @@ -16,11 +16,6 @@ from hud.utils.hud_console import HUDConsole from . 
import list_func as list_module -from .analyze import ( - analyze_environment, - analyze_environment_from_config, - analyze_environment_from_mcp_config, -) from .build import build_command from .clone import clone_repository, get_clone_message, print_error, print_tutorial from .debug import debug_mcp_stdio @@ -103,6 +98,13 @@ def analyze( hud analyze --config mcp-config.json # From MCP config hud analyze --cursor text-2048-dev # From Cursor config[/not dim] """ + # Lazy import to avoid loading mcp_use on simple CLI commands + from .analyze import ( + analyze_environment, + analyze_environment_from_config, + analyze_environment_from_mcp_config, + ) + if config: # Load config from JSON file (always live for configs) asyncio.run(analyze_environment_from_config(config, output_format, verbose)) From b396b1304485808c378614af364086fa6fea4fa3 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 15:21:49 -0800 Subject: [PATCH 22/92] docs update and deps --- docs/docs.json | 39 +-- docs/reference/environments.mdx | 570 +++++++++++--------------------- docs/reference/eval.mdx | 405 ----------------------- docs/reference/evals.mdx | 208 ++++++++++++ docs/reference/mcpserver.mdx | 510 ++++++++++++++++++++++++++++ docs/reference/types.mdx | 191 ++++++----- hud/cli/analyze.py | 7 +- hud/cli/build.py | 4 +- hud/cli/debug.py | 7 +- hud/cli/utils/interactive.py | 9 +- 10 files changed, 1032 insertions(+), 918 deletions(-) delete mode 100644 docs/reference/eval.mdx create mode 100644 docs/reference/evals.mdx create mode 100644 docs/reference/mcpserver.mdx diff --git a/docs/docs.json b/docs/docs.json index b3629164..cf3736cf 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -48,44 +48,15 @@ "quick-links/deploy" ] }, - { - "group": "Core Concepts", - "pages": [ - "core-concepts/architecture", - "core-concepts/mcp-protocol", - "core-concepts/task-system" - ] - }, { "group": "SDK Reference", "pages": [ - "reference/eval", + "reference/evals", + "reference/environments", "reference/tools", + "reference/mcpserver", "reference/agents", - "reference/types", - "reference/environments", - "reference/tasks" - ] - }, - { - "group": "Environments", - "pages": [ - "build-environments/index", - "build-environments/spec" - ] - }, - { - "group": "Beta Features", - "pages": [ - "beta/index", - "beta/rft" - ] - }, - { - "group": "Agents", - "pages": [ - "evaluate-agents/create-agents", - "evaluate-agents/benchmarks" + "reference/types" ] }, { @@ -138,7 +109,7 @@ "reference/tools", "reference/agents", "reference/types", - "reference/environments", + "reference/mcpserver", "reference/tasks" ] }, diff --git a/docs/reference/environments.mdx b/docs/reference/environments.mdx index 889477f9..e58323d4 100644 --- a/docs/reference/environments.mdx +++ b/docs/reference/environments.mdx @@ -1,490 +1,288 @@ --- title: "Environments" -description: "SDK reference for building MCP environments" +description: "SDK reference for the Environment class - tools, connectors, and integrations" icon: "cube" --- -The HUD SDK provides `MCPServer` for building MCP-compatible environments that work with any MCP client. +`Environment` is the unified class for defining tools, connecting to services, and formatting for any LLM provider. -## MCPServer +## Environment ```python -from hud.server import MCPServer +from hud import Environment + +env = Environment("my-env") ``` -Enhanced FastMCP server with Docker-friendly features for building HUD environments. 
+### Constructor -**Constructor Parameters:** | Parameter | Type | Description | Default | |-----------|------|-------------|---------| -| `name` | `str` | Server name for MCP handshake | Required | -| `instructions` | `str` | Server instructions/description | `None` | -| `**fastmcp_kwargs` | `Any` | Additional FastMCP parameters | - | - -**Key Features:** -1. **SIGTERM handling** - Graceful shutdown in containers via custom runner -2. **Initialize decorator** - Async setup during MCP initialize request (stdout is temporarily redirected to stderr during initialization to avoid corrupting MCP output) -3. **Shutdown decorator** - Runs only on SIGTERM (container termination), not on hot‑reload/SIGINT -4. **Enhanced add_tool()** - Automatically handles `BaseTool` instances and raw FastMCP Tool objects -5. **Tool decorator passthrough** - `@mcp.tool` returns the original function for easy composition -6. **FastMCP inheritance** - All FastMCP methods available (`mount`, `resource`, `tool`) - -### Decorators +| `name` | `str` | Environment name | `"environment"` | +| `instructions` | `str \| None` | Description/instructions | `None` | +| `conflict_resolution` | `ConflictResolution` | How to handle tool name conflicts | `PREFIX` | -#### @initialize +### Context Manager -Run async setup during MCP initialize request: +Environment must be used as an async context manager to connect: ```python -mcp = MCPServer(name="my-env") - -@mcp.initialize -async def setup_environment(ctx): - """ - Initialize environment resources. - - Args: - ctx: RequestContext with: - - ctx.meta: Client metadata dict - - ctx.session: MCP ServerSession - """ - # Access metadata from agent (if provided) - if ctx.meta: - progress_token = ctx.meta.get("progressToken") - display_width = ctx.meta.get("display_width", 1920) - display_height = ctx.meta.get("display_height", 1080) - - # Send progress notifications - if progress_token: - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=50, - total=100, - message="Initializing environment..." - ) +async with env: + tools = env.as_openai_chat_tools() + result = await env.call_tool("my_tool", arg="value") ``` -#### @shutdown +## Defining Tools -Run cleanup on SIGTERM (container termination only): +### @env.tool() + +Register functions as callable tools: ```python -@mcp.shutdown -async def cleanup(): - """Clean up resources on shutdown.""" - if browser_provider: - browser_provider.close() - logger.info("Cleanup complete") +@env.tool() +def count_letter(text: str, letter: str) -> int: + """Count occurrences of a letter in text.""" + return text.lower().count(letter.lower()) + +@env.tool() +async def fetch_data(url: str) -> dict: + """Fetch JSON data from URL.""" + async with httpx.AsyncClient() as client: + response = await client.get(url) + return response.json() ``` -### Tool Registration +Tools are automatically documented from type hints and docstrings. + +## Scripts -Three ways to register tools: +Scripts define evaluation logic with two yields: ```python -# 1. Decorator for simple functions -@mcp.tool() -async def my_tool(param: str) -> dict: - return {"result": param} - -# 2. Add BaseTool instances -from hud.tools import BashTool -bash = BashTool() -mcp.add_tool(bash) # Automatically uses bash.mcp internally - -# 3. 
Add non-BaseTool instances directly -from custom import PlaywrightTool -playwright = PlaywrightTool() -mcp.add_tool(playwright) # Added as-is +@env.script("checkout") +async def checkout_flow(product: str): + # First yield: send prompt, receive answer + answer = yield f"Add '{product}' to cart and checkout" + + # Second yield: return reward based on result + order_exists = await check_order(product) + yield 1.0 if order_exists else 0.0 ``` -### Hub Pattern (mount) - -Use BaseHub for organized tool namespaces: +Create Evals from scripts: ```python -from hud.tools import BaseHub +eval = env("checkout", product="laptop") -# Create hub -setup_hub = BaseHub("setup") - -# Add internal tools (hidden from agents) -@setup_hub.tool("board") -async def setup_board(size: int = 4): - game = setup_hub.env - game.reset(size=size) - return [TextContent(text=f"{size}x{size} board initialized")] +async with hud.eval(eval) as ctx: + await agent.run(ctx.prompt) + await ctx.submit(agent.response) +``` -# Mount hub on server -mcp.mount(setup_hub) +## Connectors -# Agents call via dispatcher: setup(name="board", arguments={"size": 4}) -``` +Connect to external services as tool sources. -### Resources +### connect_hub() -Expose metadata via MCP resources: +Connect to a deployed HUD environment: ```python -@mcp.resource("telemetry://live") -async def get_telemetry(): - """Expose live telemetry data.""" - return { - "provider": os.getenv("BROWSER_PROVIDER"), - "status": "running" if browser_provider else "stopped", - "live_url": browser_provider.get_live_view_url() if browser_provider else None, - "timestamp": datetime.now().isoformat() - } +env.connect_hub("my-org/browser", prefix="browser") +# Tools available as browser_navigate, browser_click, etc. ``` -### Running the Server +### connect_fastapi() + +Import FastAPI routes as tools: ```python -if __name__ == "__main__": - # Run with SIGTERM handling (stdio by default) - mcp.run() +from fastapi import FastAPI - # Or use development transports (HTTP/SSE) - mcp.run(transport="http", port=8765) - mcp.run(transport="sse", port=8080) -``` +api = FastAPI() -When using HTTP/SSE, HUD development helper endpoints are available: +@api.get("/users/{user_id}", operation_id="get_user") +def get_user(user_id: int): + return {"id": user_id, "name": "Alice"} -- `GET /hud` – overview -- `GET /hud/tools` – list tools with schemas -- `GET /hud/resources` – list resources -- `GET /hud/prompts` – list prompts +env.connect_fastapi(api) +# Tool available as get_user +``` -## Real Environment Examples +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `app` | `FastAPI` | FastAPI application | Required | +| `name` | `str \| None` | Server name | `app.title` | +| `prefix` | `str \| None` | Tool name prefix | `None` | +| `include_hidden` | `bool` | Include routes with `include_in_schema=False` | `True` | -### Minimal Environment +### connect_openapi() + +Import from OpenAPI spec: ```python -# src/hud_controller/server.py -from hud.server import MCPServer -from mcp.types import TextContent - -mcp = MCPServer(name="counter-env") -counter = {"value": 0} - -@mcp.tool() -async def setup(start_value: int = 0): - """Initialize counter.""" - counter["value"] = start_value - return {"status": "ready", "counter": counter["value"]} - -@mcp.tool() -async def increment(): - """Increment counter.""" - counter["value"] += 1 - return [TextContent(text=f"Counter: {counter['value']}", type="text")] - -@mcp.tool() -async def evaluate(target: int): - """Check if 
target reached.""" - from hud.tools.types import EvaluationResult - return EvaluationResult( - reward=1.0 if counter["value"] >= target else 0.0, - done=counter["value"] >= target - ) - -if __name__ == "__main__": - mcp.run() +env.connect_openapi("https://api.example.com/openapi.json") ``` -### text_2048 Environment +### connect_server() -From `environments/text_2048/src/hud_controller/server.py`: +Mount an MCPServer or FastMCP directly: ```python -from hud.server import MCPServer -from .game import Game2048 -from .tools import MoveTool -from .setup import setup as setup_hub -from .evaluate import evaluate as evaluate_hub - -mcp = MCPServer(name="text-2048") -game = None - -@mcp.initialize -async def initialize_environment(ctx): - global game - - # Progress notifications - progress_token = getattr(ctx.meta, "progressToken", None) if ctx.meta else None - - async def send_progress(progress: int, message: str): - if progress_token: - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=100, - message=message - ) - - await send_progress(0, "Starting 2048 game environment...") - - # Create game - game = Game2048() - game.reset() - - await send_progress(50, "Setting up game board...") - - # Set game on hubs - setup_hub.env = game - evaluate_hub.env = game - - # Mount hubs - mcp.mount(setup_hub) - mcp.mount(evaluate_hub) - - await send_progress(70, "Configuring tools...") - - # Add move tool - mcp.add_tool(MoveTool(env=game)) - - await send_progress(100, "2048 environment ready") +from fastmcp import FastMCP + +tools = FastMCP("tools") + +@tools.tool +def greet(name: str) -> str: + return f"Hello, {name}!" + +env.connect_server(tools) ``` -### remote_browser Environment +### connect_mcp_config() -From `environments/remote_browser/src/hud_controller/server.py`: +Connect via MCP config dict: ```python -from hud.server import MCPServer -from hud.tools.computer import HudComputerTool, AnthropicComputerTool, OpenAIComputerTool -from .tools import PlaywrightToolWithMemory, BrowserExecutor -from .setup import setup as setup_hub -from .evaluate import evaluate as evaluate_hub -from .providers import get_provider - -mcp = MCPServer( - name="HUD Remote Browser Environment", - instructions="""Remote browser automation environment...""" -) - -# Global state -browser_provider = None -playwright_tool = None - -@mcp.resource("telemetry://live") -async def get_telemetry_resource(): - """MCP resource with live browser status.""" - return { - "provider": os.getenv("BROWSER_PROVIDER", "unknown"), - "status": "running" if browser_provider else "stopped", - "live_url": browser_provider.get_live_view_url() if browser_provider else None, - "cdp_url": browser_provider.cdp_url if browser_provider else None +env.connect_mcp_config({ + "my-server": { + "command": "uvx", + "args": ["some-mcp-server"] } - -@mcp.initialize -async def initialize_environment(ctx): - global browser_provider, playwright_tool - - # Get metadata - metadata = ctx.meta - progress_token = metadata.get("progressToken", None) - - # Initialize provider - provider_name = os.getenv("BROWSER_PROVIDER") - provider_class = get_provider(provider_name) - browser_provider = provider_class(config) - - # Launch browser - cdp_url = await browser_provider.launch() - - # Create playwright tool - playwright_tool = PlaywrightToolWithMemory(cdp_url=cdp_url) - await playwright_tool._ensure_browser() - - # Add playwright tool (not a BaseTool, added directly) - mcp.add_tool(playwright_tool) - - # Create computer tools 
- executor = BrowserExecutor(playwright_tool) - tool_kwargs = {"executor": executor} - - # Add display dimensions from metadata - if metadata: - width = metadata.get("display_width") - height = metadata.get("display_height") - if width and height: - tool_kwargs["width"] = width - tool_kwargs["height"] = height - - # Add computer tools (all are BaseTool subclasses) - mcp.add_tool(HudComputerTool(**tool_kwargs)) - mcp.add_tool(AnthropicComputerTool(**tool_kwargs)) - mcp.add_tool(OpenAIComputerTool(**tool_kwargs)) - - # Mount hubs - setup_hub.env = playwright_tool - evaluate_hub.env = playwright_tool - mcp.mount(setup_hub) - mcp.mount(evaluate_hub) - -@mcp.shutdown -async def shutdown_environment(): - """Cleanup browser resources.""" - global browser_provider - if browser_provider: - browser_provider.close() - browser_provider = None +}) ``` -## Standard Structure +### connect_image() -### Directory Layout +Connect to a Docker image via stdio: -``` -my-environment/ -├── Dockerfile -├── pyproject.toml -├── controller/ # MCP controller (stdio) -│ ├── __init__.py # mcp = MCPServer() -│ ├── __main__.py # python -m controller → mcp.run() -│ ├── hooks.py # @mcp.initialize / @mcp.shutdown -│ └── tools.py # @mcp.tool(...) -└── environment/ # Optional backend (HTTP/IPC) - └── server.py # e.g., FastAPI app +```python +env.connect_image("mcp/fetch") ``` -### Dockerfile +## Tool Formatting -```dockerfile -FROM python:3.11-slim +Convert tools to provider-specific formats. -WORKDIR /app +### OpenAI -# Copy and install -COPY pyproject.toml ./ -COPY controller/ ./controller/ -COPY environment/ ./environment/ -RUN pip install --no-cache-dir -e . +```python +# Chat Completions API +tools = env.as_openai_chat_tools() +response = await client.chat.completions.create( + model="gpt-4o", + messages=messages, + tools=tools, +) -ENV ENV_SERVER_PORT=8005 +# Responses API +tools = env.as_openai_responses_tools() -# Start optional backend, then MCP controller on stdio -CMD ["sh", "-c", "uvicorn environment.server:app --host 0.0.0.0 --port $ENV_SERVER_PORT --log-level warning & python -m controller"] +# Agents SDK (requires openai-agents) +tools = env.as_openai_agent_tools() ``` -### Hub Module Pattern - -Example from text_2048: +### Anthropic/Claude ```python -# src/hud_controller/setup/__init__.py -from hud.tools.base import BaseHub - -setup = BaseHub("setup") +tools = env.as_claude_tools() +response = await client.messages.create( + model="claude-sonnet-4-5", + messages=messages, + tools=tools, +) +``` -# Import all setup functions to register them -from . import board +### Gemini -__all__ = ["setup"] +```python +tools = env.as_gemini_tools() +config = env.as_gemini_tool_config() +``` -# src/hud_controller/setup/board.py -from . import setup +### LangChain -@setup.tool("board") -async def setup_board(board_size: int = 4): - """Initialize game board.""" - game = setup.env # Access environment from hub - game.reset(size=board_size) - return [TextContent(text=f"{board_size}x{board_size} game initialized")] +```python +# Requires langchain-core +tools = env.as_langchain_tools() ``` -## Key Concepts +## Calling Tools -### Environment State +### call_tool() -Three patterns for managing state: +Execute tools with auto-format detection: -1. **Global variables** (simple environments): - ```python - game = None - - @mcp.initialize - async def initialize_environment(ctx): - global game - game = Game2048() - ``` - -2. 
**Context class** (complex environments): - ```python - class EnvironmentContext: - def __init__(self): - self.browser = None - self.page = None - - env = EnvironmentContext() - ``` +```python +# Simple call +result = await env.call_tool("my_tool", arg="value") -3. **Hub env attribute** (for tool access): - ```python - setup_hub.env = game # Tools access via hub.env - ``` +# From OpenAI tool call +result = await env.call_tool(response.choices[0].message.tool_calls[0]) -### Tool Lifecycle +# From Claude tool use +result = await env.call_tool(response.content[0]) # tool_use block +``` -1. **Setup tools** - Hidden from agents, prepare environment state -2. **Interaction tools** - Available to agents for control -3. **Evaluate tools** - Hidden from agents, score performance +Returns result in matching format (OpenAI tool call → OpenAI tool message, etc.). -### Progress Notifications +## Mock Mode -Send [progress updates](https://modelcontextprotocol.io/specification/basic/utilities/progress) during long-running operations: +Test without real connections: ```python -async def send_progress(progress: int, message: str): - if progress_token: - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=100, - message=message - ) +env.mock() # Enable mock mode + +# Set specific mock outputs +env.mock_tool("navigate", "Navigation successful") +env.mock_tool("screenshot", b"fake_image_data") + +async with env: + result = await env.call_tool("navigate", url="https://example.com") + # Returns "Navigation successful" instead of actually navigating + +env.unmock() # Disable mock mode ``` - -Progress notifications follow the [MCP progress specification](https://modelcontextprotocol.io/specification/basic/utilities/progress#progress-flow). The `progressToken` comes from the client's request [metadata](https://modelcontextprotocol.io/specification/basic/index#_meta). - +| Method | Description | +|--------|-------------| +| `mock(enable=True)` | Enable/disable mock mode | +| `unmock()` | Disable mock mode | +| `mock_tool(name, output)` | Set specific mock output | +| `is_mock` | Check if mock mode is enabled | -### Metadata Access +## Properties -Agent metadata flows through initialization: +| Property | Type | Description | +|----------|------|-------------| +| `name` | `str` | Environment name | +| `prompt` | `str \| None` | Default prompt (set by connect_task) | +| `is_connected` | `bool` | True if in context | +| `connections` | `dict[str, Connector]` | Active connections | + +## Creating Evals + +Call the environment to create an Eval: ```python -@mcp.initialize -async def initialize_environment(ctx): - # From agent's metadata class variable - width = ctx.meta.get("display_width", 1920) if ctx.meta else 1920 - height = ctx.meta.get("display_height", 1080) if ctx.meta else 1080 -``` +# With script +eval = env("checkout", product="laptop") -## Testing +# Without script (just the environment) +eval = env() +``` -```bash -# CLI testing -hud debug my-env:latest -hud analyze my-env:latest +Then run with `hud.eval()`: -# Python testing -async def test(): - from hud.clients import MCPClient - - client = MCPClient({ - "env": {"command": "docker", "args": ["run", "-i", "my-env"]} - }) - - async with client: - tools = await client.list_tools() - result = await client.call_tool("setup", {"value": 0}) +```python +async with hud.eval(eval, variants={"model": ["gpt-4o"]}) as ctx: + ... 
``` ## See Also -- [Build Environments](/build-environments) - Getting started guide -- [Tools](/reference/tools) - Tool implementation reference -- [Environment Spec](/build-environments/spec) - Technical specification and architecture \ No newline at end of file +- [Evals](/reference/evals) - hud.eval() reference +- [MCPServer](/reference/mcpserver) - Building MCP servers +- [Environments Guide](/quick-links/environments) - Getting started guide + diff --git a/docs/reference/eval.mdx b/docs/reference/eval.mdx deleted file mode 100644 index 3ef0f7e8..00000000 --- a/docs/reference/eval.mdx +++ /dev/null @@ -1,405 +0,0 @@ ---- -title: "Evaluation API" -description: "SDK reference for running evaluations with hud.eval()" -icon: "flask-vial" ---- - -The HUD SDK provides a unified evaluation API through `hud.eval()` for tracking agent performance, running parallel evaluations, and integrating with the HUD platform. - -## Overview - -There are three ways to run evaluations: - -1. **`hud.eval()`** - Standalone context manager for any evaluation -2. **`env.eval()`** - Method on `Environment` for evaluating within an existing environment -3. **`run_tasks()`** - High-level batch execution with automatic agent creation - -## hud.eval() - -The primary evaluation context manager. Creates an `EvalContext` which is a full `Environment` with evaluation tracking. - -```python -import hud - -async with hud.eval("my-org/browser-task:1") as ctx: - # ctx is an EvalContext (extends Environment) - tools = await ctx.list_tools() - result = await ctx.call_tool("navigate", url="https://example.com") - ctx.reward = 1.0 # Set the evaluation reward -``` - -### Parameters - -| Parameter | Type | Description | Default | -|-----------|------|-------------|---------| -| `source` | `str \| list[str] \| Task \| list[Task] \| None` | Task source (slugs or Task objects) | `None` | -| `variants` | `dict[str, Any] \| None` | A/B test configuration | `None` | -| `group` | `int` | Runs per variant for statistical significance | `1` | -| `group_ids` | `list[str] \| None` | Custom group IDs for parallel runs | `None` | -| `job_id` | `str \| None` | Job ID to link traces to | `None` | -| `api_key` | `str \| None` | API key for backend calls | `None` | -| `max_concurrent` | `int \| None` | Maximum concurrent evaluations | `None` | - -### Task Sources - -The `source` parameter accepts multiple formats: - -```python -# 1. Blank evaluation (manual reward) -async with hud.eval() as ctx: - ctx.reward = compute_reward() - -# 2. Single task slug -async with hud.eval("my-org/browser-task") as ctx: - await agent.run(ctx) - -# 3. Task at specific index -async with hud.eval("my-org/evalset:0") as ctx: - await agent.run(ctx) - -# 4. All tasks in an evalset (wildcard) -async with hud.eval("my-org/evalset:*") as ctx: - await agent.run(ctx) - -# 5. Multiple slugs -async with hud.eval(["task:0", "task:1", "task:2"]) as ctx: - await agent.run(ctx) - -# 6. 
Task objects directly (backwards compatible) -from hud.types import Task -tasks = [Task(prompt="Navigate to docs", mcp_config={...})] -async with hud.eval(tasks) as ctx: - await agent.run(ctx) -``` - -### Variants and Groups - -Run A/B tests with multiple configurations: - -```python -# Test different models -async with hud.eval( - "my-org/evalset:*", - variants={"model": ["gpt-4o", "claude-sonnet"]}, - group=3, # 3 runs per variant for statistical significance -) as ctx: - model = ctx.variants["model"] # Current variant assignment - agent = create_agent(model=model) - result = await agent.run(ctx) - ctx.reward = result.reward - -# Access all results after completion -for result in ctx.results: - print(f"{result.variants}: reward={result.reward}") -``` - -**How it works:** -- `variants` dict with list values creates the cartesian product -- `group` multiplies each variant combination -- Total runs = `len(tasks) × len(variant_combos) × group` -- For parallel runs (total > 1), a job is automatically created - -### Concurrency Control - -Limit concurrent evaluations to manage resources: - -```python -async with hud.eval( - "my-org/large-evalset:*", - max_concurrent=10, # Max 10 parallel evaluations -) as ctx: - await agent.run(ctx) -``` - -## env.eval() - -Create evaluation contexts from an existing `Environment`: - -```python -from hud import Environment - -async with Environment() as env: - # Connect to MCP servers - await env.connect_hub("test-browser-26") - - # Run evaluation within this environment - async with env.eval("my-evaluation", group=3) as ctx: - # ctx inherits env's connections - tools = await ctx.list_tools() - await agent.run(ctx) - ctx.reward = result.reward -``` - -### Parameters - -Same as `hud.eval()`, plus: - -| Parameter | Type | Description | Default | -|-----------|------|-------------|---------| -| `name` | `str` | Evaluation name (required) | Required | -| `trace_id` | `str \| None` | Custom trace ID | `None` | - -### Connection Inheritance - -When you call `env.eval()`, the `EvalContext` copies the parent environment's connections: - -```python -async with Environment() as env: - await env.connect_hub("my-hub") - - # Parallel evaluations each get their own connection copies - async with env.eval("test", group=3) as ctx: - # Each parallel run has independent connections - await ctx.call_tool("my_tool") -``` - -## EvalContext - -`EvalContext` extends `Environment` with evaluation-specific functionality. 
- -### Properties - -| Property | Type | Description | -|----------|------|-------------| -| `trace_id` | `str` | Unique trace identifier | -| `eval_name` | `str` | Evaluation name | -| `job_id` | `str \| None` | Parent job ID | -| `group_id` | `str \| None` | Group ID for parallel runs | -| `index` | `int` | Index in parallel execution | -| `variants` | `dict[str, Any]` | Current variant assignment | -| `reward` | `float \| None` | Evaluation reward (settable) | -| `error` | `BaseException \| None` | Error if evaluation failed | -| `results` | `list[EvalContext] \| None` | Results from parallel runs | -| `task` | `Task \| None` | Task definition (if loaded from slug) | -| `prompt` | `str \| None` | Task prompt | -| `headers` | `dict[str, str]` | Trace headers for HTTP requests | - -### Methods - -All `Environment` methods are available, plus: - -```python -# Set reward -ctx.reward = 1.0 - -# Access task configuration -if ctx.task: - print(ctx.task.prompt) - print(ctx.task.agent_config) # Agent configuration hints - -# Get trace headers for external HTTP calls -headers = ctx.headers # {"Trace-Id": "...", "Trace-Parent": "..."} -``` - -### Creating from Task - -```python -from hud.eval.context import EvalContext -from hud.types import Task - -task = Task( - prompt="Navigate to the docs page", - mcp_config={"hud": {"url": "...", "headers": {...}}}, - setup_tool={"name": "setup", "arguments": {...}}, - evaluate_tool={"name": "evaluate", "arguments": {...}}, -) - -ctx = EvalContext.from_task(task) -async with ctx: - # MCP connections configured from task.mcp_config - # setup_tool and evaluate_tool configured - tools = await ctx.list_tools() -``` - -## run_tasks() - -High-level batch execution that creates agents automatically: - -```python -from hud.datasets import run_tasks -from hud.types import AgentType -from hud.utils.tasks import load_tasks - -# Load tasks from HuggingFace or file -tasks = load_tasks("hud-evals/SheetBench-50") - -# Run with automatic agent creation -results = await run_tasks( - tasks=tasks, - agent_type=AgentType.CLAUDE, - agent_params={"checkpoint_name": "claude-sonnet-4-5"}, - max_concurrent=30, - max_steps=10, - group_size=3, # 3 runs per task -) -``` - -### Parameters - -| Parameter | Type | Description | Default | -|-----------|------|-------------|---------| -| `tasks` | `list[Task]` | List of Task objects | Required | -| `agent_type` | `AgentType` | Agent type enum | Required | -| `agent_params` | `dict[str, Any] \| None` | Agent configuration | `None` | -| `name` | `str` | Job name | `"Evaluation"` | -| `max_concurrent` | `int` | Maximum concurrent tasks | `30` | -| `metadata` | `dict[str, Any] \| None` | Job metadata | `None` | -| `max_steps` | `int` | Maximum steps per task | `10` | -| `group_size` | `int` | Runs per task | `1` | -| `remote` | `bool` | Submit to HUD platform | `False` | - -### Returns - -- If `group_size == 1`: `list[Trace]` - Results in task order -- If `group_size > 1`: `list[dict]` - Statistics per task group - -### Remote Execution - -Submit tasks to the HUD platform for remote execution: - -```python -await run_tasks( - tasks=tasks, - agent_type=AgentType.CLAUDE, - remote=True, # Submit to platform -) -# Returns immediately, monitor at https://hud.ai/jobs/{job_id} -``` - -## Task Configuration - -Tasks define the evaluation environment and success criteria: - -```python -from hud.types import Task - -task = Task( - id="nav-001", - prompt="Navigate to the documentation page", - mcp_config={ - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - 
"headers": { - "Authorization": "Bearer ${HUD_API_KEY}", - "Mcp-Image": "hudpython/hud-remote-browser:latest" - } - } - }, - setup_tool={ - "name": "setup", - "arguments": {"name": "navigate", "arguments": {"url": "https://example.com"}} - }, - evaluate_tool={ - "name": "evaluate", - "arguments": {"name": "url_match", "arguments": {"pattern": ".*/docs.*"}} - }, - agent_config={ - "allowed_tools": ["playwright", "computer"], - "system_prompt": "You are a web navigation agent." - }, - metadata={"difficulty": "easy", "category": "navigation"} -) -``` - -### Task Fields - -| Field | Type | Description | -|-------|------|-------------| -| `id` | `str \| None` | Unique task identifier | -| `prompt` | `str` | Task instruction | -| `mcp_config` | `dict[str, Any]` | MCP server configuration | -| `setup_tool` | `MCPToolCall \| list[MCPToolCall] \| None` | Setup tool calls | -| `evaluate_tool` | `MCPToolCall \| list[MCPToolCall] \| None` | Evaluation tool calls | -| `agent_config` | `BaseAgentConfig \| None` | Agent configuration hints | -| `metadata` | `dict[str, Any]` | Custom metadata | - -### Environment Variable Substitution - -MCP config supports `${VAR_NAME}` substitution: - -```python -mcp_config = { - "hud": { - "url": "${HUD_MCP_URL:https://mcp.hud.ai/v3/mcp}", # With default - "headers": { - "Authorization": "Bearer ${HUD_API_KEY}" # From environment - } - } -} -``` - -## HTTP Instrumentation - -When running inside an eval context, HTTP requests to HUD services automatically include trace headers: - -```python -import httpx - -async with hud.eval("test") as ctx: - # Trace headers are automatically injected - async with httpx.AsyncClient() as client: - # Requests to inference.hud.ai, mcp.hud.ai include Trace-Id - response = await client.post( - "https://inference.hud.ai/v1/messages", - json={...} - ) -``` - -This enables automatic telemetry linking without manual header management. - -## Best Practices - -### 1. Use Variants for A/B Testing - -```python -async with hud.eval( - "evalset:*", - variants={ - "model": ["gpt-4o", "claude"], - "temperature": [0.0, 0.7], - }, - group=3, -) as ctx: - # Runs: 2 models × 2 temps × 3 groups = 12 evaluations - ... -``` - -### 2. Set Rewards Consistently - -```python -async with hud.eval("task") as ctx: - try: - result = await agent.run(ctx) - ctx.reward = result.reward - except Exception as e: - ctx.reward = 0.0 # Explicit failure reward - raise -``` - -### 3. Use Concurrency Limits for Resource-Heavy Tasks - -```python -async with hud.eval( - "browser-tasks:*", - max_concurrent=5, # Browser instances are heavy -) as ctx: - ... -``` - -### 4. Access Task Agent Config - -```python -async with hud.eval("my-org/task:0") as ctx: - if ctx.task and ctx.task.agent_config: - # Apply task's agent hints - allowed_tools = ctx.task.agent_config.allowed_tools - system_prompt = ctx.task.agent_config.system_prompt -``` - -## See Also - -- [`hud eval` CLI](/reference/cli/eval) - Command-line interface -- [Benchmarks](/evaluate-agents/benchmarks) - Creating and running benchmarks -- [Tasks](/reference/tasks) - Task configuration reference -- [Environments](/reference/environments) - Building MCP environments - diff --git a/docs/reference/evals.mdx b/docs/reference/evals.mdx new file mode 100644 index 00000000..58eb60f5 --- /dev/null +++ b/docs/reference/evals.mdx @@ -0,0 +1,208 @@ +--- +title: "Evals" +description: "SDK reference for hud.eval() - the unified evaluation context manager" +icon: "flask-vial" +--- + +`hud.eval()` is the primary way to run evaluations. 
It creates an `EvalContext` with telemetry, handles parallel execution, and integrates with the HUD platform. + +## hud.eval() + +```python +import hud + +async with hud.eval() as ctx: + # ctx is an EvalContext (extends Environment) + response = await client.chat.completions.create(...) + ctx.reward = 1.0 +``` + +### Parameters + +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `source` | `Eval \| list[Eval] \| str \| None` | Eval objects from `env()`, task slugs, or None | `None` | +| `variants` | `dict[str, Any] \| None` | A/B test configuration (lists expand to combinations) | `None` | +| `group` | `int` | Runs per variant for statistical significance | `1` | +| `group_ids` | `list[str] \| None` | Custom group IDs for parallel runs | `None` | +| `job_id` | `str \| None` | Job ID to link traces to | `None` | +| `api_key` | `str \| None` | API key for backend calls | `None` | +| `max_concurrent` | `int \| None` | Maximum concurrent evaluations | `None` | +| `trace` | `bool` | Send telemetry to backend | `True` | +| `quiet` | `bool` | Suppress console output | `False` | + +### Source Types + +The `source` parameter accepts: + +```python +# 1. Blank eval - manual setup and reward +async with hud.eval() as ctx: + ctx.reward = compute_reward() + +# 2. Eval from Environment (recommended) +env = Environment("my-env") +eval = env("checkout", product="laptop") # Creates Eval from script +async with hud.eval(eval) as ctx: + await agent.run(ctx.prompt) + +# 3. Task slug (loads from platform) +async with hud.eval("my-org/browser-task") as ctx: + await agent.run(ctx) + +# 4. Multiple evals +evals = [env("checkout", product="laptop"), env("checkout", product="phone")] +async with hud.eval(evals) as ctx: + await agent.run(ctx.prompt) +``` + +### Variants + +Test multiple configurations in parallel: + +```python +async with hud.eval( + eval, + variants={"model": ["gpt-4o", "claude-sonnet-4-5"]}, +) as ctx: + model = ctx.variants["model"] # Current variant + response = await client.chat.completions.create(model=model, ...) +``` + +Lists expand to all combinations: + +```python +variants = { + "model": ["gpt-4o", "claude"], + "temperature": [0.0, 0.7], +} +# Creates 4 combinations: gpt-4o+0.0, gpt-4o+0.7, claude+0.0, claude+0.7 +``` + +### Groups + +Run each variant multiple times for statistical significance: + +```python +async with hud.eval(eval, variants={"model": ["gpt-4o"]}, group=5) as ctx: + # Runs 5 times - see the distribution of results + ... +``` + +Total runs = `len(evals) × len(variant_combinations) × group` + +### Concurrency Control + +```python +async with hud.eval( + evals, + max_concurrent=10, # Max 10 parallel evaluations +) as ctx: + ... +``` + +## EvalContext + +`EvalContext` extends `Environment` with evaluation tracking. 
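Because parallel runs surface as individual `EvalContext` objects (see `results` below), per-variant statistics can be computed with a small helper. This is a sketch, not part of the SDK, and it relies only on the `variants` and `reward` fields documented in the table that follows:

```python
from collections import defaultdict
from statistics import mean

def rewards_by_variant(results):
    """Average the reward of each parallel run, grouped by its variant assignment."""
    buckets: dict[tuple, list[float]] = defaultdict(list)
    for run in results:
        key = tuple(sorted(run.variants.items()))
        if run.reward is not None:
            buckets[key].append(run.reward)
    return {key: mean(values) for key, values in buckets.items()}
```

For example, `rewards_by_variant(ctx.results)` after a `group=3` run yields one mean reward per variant combination.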
+ +### Properties + +| Property | Type | Description | +|----------|------|-------------| +| `trace_id` | `str` | Unique trace identifier | +| `eval_name` | `str` | Evaluation name | +| `prompt` | `str \| None` | Task prompt (from script or task) | +| `variants` | `dict[str, Any]` | Current variant assignment | +| `reward` | `float \| None` | Evaluation reward (settable) | +| `answer` | `str \| None` | Submitted answer | +| `error` | `BaseException \| None` | Error if failed | +| `results` | `list[EvalContext]` | Results from parallel runs | +| `headers` | `dict[str, str]` | Trace headers for HTTP requests | +| `job_id` | `str \| None` | Parent job ID | +| `group_id` | `str \| None` | Group ID for parallel runs | +| `index` | `int` | Index in parallel execution | + +### Methods + +All `Environment` methods are available, plus: + +```python +# Submit answer (passes to script for evaluation) +await ctx.submit(answer) + +# Set reward directly +ctx.reward = 1.0 + +# Access tools in provider formats +tools = ctx.as_openai_chat_tools() + +# Call tools +result = await ctx.call_tool("my_tool", arg="value") +``` + +### Headers for Telemetry + +Inside an eval context, trace headers are automatically injected into HTTP requests: + +```python +async with hud.eval() as ctx: + # Requests to HUD services include Trace-Id automatically + response = await client.chat.completions.create(...) + + # Manual access + print(ctx.headers) # {"Trace-Id": "..."} +``` + +## Working with Environments + +The recommended pattern is to create Evals from an Environment: + +```python +from hud import Environment +import hud + +env = Environment("my-env") + +@env.tool() +def count_letter(text: str, letter: str) -> int: + return text.lower().count(letter.lower()) + +@env.script("count") +async def count_script(sentence: str, letter: str): + answer = yield f"How many '{letter}' in '{sentence}'?" + correct = str(sentence.lower().count(letter.lower())) + yield correct in answer + +# Create an Eval from the script +eval = env("count", sentence="Strawberry", letter="r") + +# Run with variants +async with hud.eval(eval, variants={"model": ["gpt-4o", "claude"]}) as ctx: + response = await client.chat.completions.create( + model=ctx.variants["model"], + messages=[{"role": "user", "content": ctx.prompt}], + tools=ctx.as_openai_chat_tools(), + ) + await ctx.submit(response.choices[0].message.content or "") +``` + +## Results + +After parallel runs complete, access results on the summary context: + +```python +async with hud.eval(eval, variants={"model": ["gpt-4o", "claude"]}, group=3) as ctx: + ... + +# ctx.results contains all individual EvalContexts +for result in ctx.results: + print(f"{result.variants}: reward={result.reward}, answer={result.answer}") +``` + +## See Also + +- [Environments](/reference/environments) - Environment class reference +- [A/B Evals](/quick-links/ab-testing) - Variants and groups guide +- [Deploy](/quick-links/deploy) - Running evals at scale +- [`hud eval` CLI](/reference/cli/eval) - Command-line interface + diff --git a/docs/reference/mcpserver.mdx b/docs/reference/mcpserver.mdx new file mode 100644 index 00000000..42d33e2b --- /dev/null +++ b/docs/reference/mcpserver.mdx @@ -0,0 +1,510 @@ +--- +title: "MCPServer" +description: "SDK reference for building MCP servers" +icon: "server" +--- + +`MCPServer` is the base class for building MCP-compatible servers that work with any MCP client. It extends FastMCP with Docker-friendly features. + +## Why MCP? 
+ +Traditional agent frameworks couple agents tightly to specific environments. MCP decouples them: + + + + - Agent code hardcoded for each environment + - No standardization across tools + - Difficult to swap agents or environments + + + + - Any agent works with any environment + - Standard protocol for all interactions + - Easy to swap components + + + +MCP standardizes agent-environment communication through JSON-RPC messages. Agents call tools exposed by servers and receive structured responses. + +## MCPServer + +```python +from hud.server import MCPServer +``` + +Enhanced FastMCP server with Docker-friendly features for building HUD environments. + +**Constructor Parameters:** +| Parameter | Type | Description | Default | +|-----------|------|-------------|---------| +| `name` | `str` | Server name for MCP handshake | Required | +| `instructions` | `str` | Server instructions/description | `None` | +| `**fastmcp_kwargs` | `Any` | Additional FastMCP parameters | - | + +**Key Features:** +1. **SIGTERM handling** - Graceful shutdown in containers via custom runner +2. **Initialize decorator** - Async setup during MCP initialize request (stdout is temporarily redirected to stderr during initialization to avoid corrupting MCP output) +3. **Shutdown decorator** - Runs only on SIGTERM (container termination), not on hot‑reload/SIGINT +4. **Enhanced add_tool()** - Automatically handles `BaseTool` instances and raw FastMCP Tool objects +5. **Tool decorator passthrough** - `@mcp.tool` returns the original function for easy composition +6. **FastMCP inheritance** - All FastMCP methods available (`mount`, `resource`, `tool`) + +### Decorators + +#### @initialize + +Run async setup during MCP initialize request: + +```python +mcp = MCPServer(name="my-env") + +@mcp.initialize +async def setup_environment(ctx): + """ + Initialize environment resources. + + Args: + ctx: RequestContext with: + - ctx.meta: Client metadata dict + - ctx.session: MCP ServerSession + """ + # Access metadata from agent (if provided) + if ctx.meta: + progress_token = ctx.meta.get("progressToken") + display_width = ctx.meta.get("display_width", 1920) + display_height = ctx.meta.get("display_height", 1080) + + # Send progress notifications + if progress_token: + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=50, + total=100, + message="Initializing environment..." + ) +``` + +#### @shutdown + +Run cleanup on SIGTERM (container termination only): + +```python +@mcp.shutdown +async def cleanup(): + """Clean up resources on shutdown.""" + if browser_provider: + browser_provider.close() + logger.info("Cleanup complete") +``` + +### Tool Registration + +Three ways to register tools: + +```python +# 1. Decorator for simple functions +@mcp.tool() +async def my_tool(param: str) -> dict: + return {"result": param} + +# 2. Add BaseTool instances +from hud.tools import BashTool +bash = BashTool() +mcp.add_tool(bash) # Automatically uses bash.mcp internally + +# 3. 
Add non-BaseTool instances directly +from custom import PlaywrightTool +playwright = PlaywrightTool() +mcp.add_tool(playwright) # Added as-is +``` + +### Hub Pattern (mount) + +Use BaseHub for organized tool namespaces: + +```python +from hud.tools import BaseHub + +# Create hub +setup_hub = BaseHub("setup") + +# Add internal tools (hidden from agents) +@setup_hub.tool("board") +async def setup_board(size: int = 4): + game = setup_hub.env + game.reset(size=size) + return [TextContent(text=f"{size}x{size} board initialized")] + +# Mount hub on server +mcp.mount(setup_hub) + +# Agents call via dispatcher: setup(name="board", arguments={"size": 4}) +``` + +### Resources + +Expose metadata via MCP resources: + +```python +@mcp.resource("telemetry://live") +async def get_telemetry(): + """Expose live telemetry data.""" + return { + "provider": os.getenv("BROWSER_PROVIDER"), + "status": "running" if browser_provider else "stopped", + "live_url": browser_provider.get_live_view_url() if browser_provider else None, + "timestamp": datetime.now().isoformat() + } +``` + +### Running the Server + +```python +if __name__ == "__main__": + # Run with SIGTERM handling (stdio by default) + mcp.run() + + # Or use development transports (HTTP/SSE) + mcp.run(transport="http", port=8765) + mcp.run(transport="sse", port=8080) +``` + +When using HTTP/SSE, HUD development helper endpoints are available: + +- `GET /hud` – overview +- `GET /hud/tools` – list tools with schemas +- `GET /hud/resources` – list resources +- `GET /hud/prompts` – list prompts + +## Real Environment Examples + +### Minimal Environment + +```python +# src/hud_controller/server.py +from hud.server import MCPServer +from mcp.types import TextContent + +mcp = MCPServer(name="counter-env") +counter = {"value": 0} + +@mcp.tool() +async def setup(start_value: int = 0): + """Initialize counter.""" + counter["value"] = start_value + return {"status": "ready", "counter": counter["value"]} + +@mcp.tool() +async def increment(): + """Increment counter.""" + counter["value"] += 1 + return [TextContent(text=f"Counter: {counter['value']}", type="text")] + +@mcp.tool() +async def evaluate(target: int): + """Check if target reached.""" + from hud.tools.types import EvaluationResult + return EvaluationResult( + reward=1.0 if counter["value"] >= target else 0.0, + done=counter["value"] >= target + ) + +if __name__ == "__main__": + mcp.run() +``` + +### text_2048 Environment + +From `environments/text_2048/src/hud_controller/server.py`: + +```python +from hud.server import MCPServer +from .game import Game2048 +from .tools import MoveTool +from .setup import setup as setup_hub +from .evaluate import evaluate as evaluate_hub + +mcp = MCPServer(name="text-2048") +game = None + +@mcp.initialize +async def initialize_environment(ctx): + global game + + # Progress notifications + progress_token = getattr(ctx.meta, "progressToken", None) if ctx.meta else None + + async def send_progress(progress: int, message: str): + if progress_token: + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=100, + message=message + ) + + await send_progress(0, "Starting 2048 game environment...") + + # Create game + game = Game2048() + game.reset() + + await send_progress(50, "Setting up game board...") + + # Set game on hubs + setup_hub.env = game + evaluate_hub.env = game + + # Mount hubs + mcp.mount(setup_hub) + mcp.mount(evaluate_hub) + + await send_progress(70, "Configuring tools...") + + # Add move tool + 
mcp.add_tool(MoveTool(env=game)) + + await send_progress(100, "2048 environment ready") +``` + +### remote_browser Environment + +From `environments/remote_browser/src/hud_controller/server.py`: + +```python +from hud.server import MCPServer +from hud.tools.computer import HudComputerTool, AnthropicComputerTool, OpenAIComputerTool +from .tools import PlaywrightToolWithMemory, BrowserExecutor +from .setup import setup as setup_hub +from .evaluate import evaluate as evaluate_hub +from .providers import get_provider + +mcp = MCPServer( + name="HUD Remote Browser Environment", + instructions="""Remote browser automation environment...""" +) + +# Global state +browser_provider = None +playwright_tool = None + +@mcp.resource("telemetry://live") +async def get_telemetry_resource(): + """MCP resource with live browser status.""" + return { + "provider": os.getenv("BROWSER_PROVIDER", "unknown"), + "status": "running" if browser_provider else "stopped", + "live_url": browser_provider.get_live_view_url() if browser_provider else None, + "cdp_url": browser_provider.cdp_url if browser_provider else None + } + +@mcp.initialize +async def initialize_environment(ctx): + global browser_provider, playwright_tool + + # Get metadata + metadata = ctx.meta + progress_token = metadata.get("progressToken", None) + + # Initialize provider + provider_name = os.getenv("BROWSER_PROVIDER") + provider_class = get_provider(provider_name) + browser_provider = provider_class(config) + + # Launch browser + cdp_url = await browser_provider.launch() + + # Create playwright tool + playwright_tool = PlaywrightToolWithMemory(cdp_url=cdp_url) + await playwright_tool._ensure_browser() + + # Add playwright tool (not a BaseTool, added directly) + mcp.add_tool(playwright_tool) + + # Create computer tools + executor = BrowserExecutor(playwright_tool) + tool_kwargs = {"executor": executor} + + # Add display dimensions from metadata + if metadata: + width = metadata.get("display_width") + height = metadata.get("display_height") + if width and height: + tool_kwargs["width"] = width + tool_kwargs["height"] = height + + # Add computer tools (all are BaseTool subclasses) + mcp.add_tool(HudComputerTool(**tool_kwargs)) + mcp.add_tool(AnthropicComputerTool(**tool_kwargs)) + mcp.add_tool(OpenAIComputerTool(**tool_kwargs)) + + # Mount hubs + setup_hub.env = playwright_tool + evaluate_hub.env = playwright_tool + mcp.mount(setup_hub) + mcp.mount(evaluate_hub) + +@mcp.shutdown +async def shutdown_environment(): + """Cleanup browser resources.""" + global browser_provider + if browser_provider: + browser_provider.close() + browser_provider = None +``` + +## Standard Structure + +### Directory Layout + +``` +my-environment/ +├── Dockerfile +├── pyproject.toml +├── controller/ # MCP controller (stdio) +│ ├── __init__.py # mcp = MCPServer() +│ ├── __main__.py # python -m controller → mcp.run() +│ ├── hooks.py # @mcp.initialize / @mcp.shutdown +│ └── tools.py # @mcp.tool(...) +└── environment/ # Optional backend (HTTP/IPC) + └── server.py # e.g., FastAPI app +``` + +### Dockerfile + +```dockerfile +FROM python:3.11-slim + +WORKDIR /app + +# Copy and install +COPY pyproject.toml ./ +COPY controller/ ./controller/ +COPY environment/ ./environment/ +RUN pip install --no-cache-dir -e . 
+ +ENV ENV_SERVER_PORT=8005 + +# Start optional backend, then MCP controller on stdio +CMD ["sh", "-c", "uvicorn environment.server:app --host 0.0.0.0 --port $ENV_SERVER_PORT --log-level warning & python -m controller"] +``` + +### Hub Module Pattern + +Example from text_2048: + +```python +# src/hud_controller/setup/__init__.py +from hud.tools.base import BaseHub + +setup = BaseHub("setup") + +# Import all setup functions to register them +from . import board + +__all__ = ["setup"] + +# src/hud_controller/setup/board.py +from . import setup + +@setup.tool("board") +async def setup_board(board_size: int = 4): + """Initialize game board.""" + game = setup.env # Access environment from hub + game.reset(size=board_size) + return [TextContent(text=f"{board_size}x{board_size} game initialized")] +``` + +## Key Concepts + +### Environment State + +Three patterns for managing state: + +1. **Global variables** (simple environments): + ```python + game = None + + @mcp.initialize + async def initialize_environment(ctx): + global game + game = Game2048() + ``` + +2. **Context class** (complex environments): + ```python + class EnvironmentContext: + def __init__(self): + self.browser = None + self.page = None + + env = EnvironmentContext() + ``` + +3. **Hub env attribute** (for tool access): + ```python + setup_hub.env = game # Tools access via hub.env + ``` + +### Tool Lifecycle + +1. **Setup tools** - Hidden from agents, prepare environment state +2. **Interaction tools** - Available to agents for control +3. **Evaluate tools** - Hidden from agents, score performance + +### Progress Notifications + +Send [progress updates](https://modelcontextprotocol.io/specification/basic/utilities/progress) during long-running operations: + +```python +async def send_progress(progress: int, message: str): + if progress_token: + await ctx.session.send_progress_notification( + progress_token=progress_token, + progress=progress, + total=100, + message=message + ) +``` + + +Progress notifications follow the [MCP progress specification](https://modelcontextprotocol.io/specification/basic/utilities/progress#progress-flow). The `progressToken` comes from the client's request [metadata](https://modelcontextprotocol.io/specification/basic/index#_meta). + + +### Metadata Access + +Agent metadata flows through initialization: + +```python +@mcp.initialize +async def initialize_environment(ctx): + # From agent's metadata class variable + width = ctx.meta.get("display_width", 1920) if ctx.meta else 1920 + height = ctx.meta.get("display_height", 1080) if ctx.meta else 1080 +``` + +## Testing + +```bash +# CLI testing +hud debug my-env:latest +hud analyze my-env:latest + +# Python testing +async def test(): + from hud.clients import MCPClient + + client = MCPClient({ + "env": {"command": "docker", "args": ["run", "-i", "my-env"]} + }) + + async with client: + tools = await client.list_tools() + result = await client.call_tool("setup", {"value": 0}) +``` + +## See Also + +- [Environments](/reference/environments) - Environment class (client-side) +- [Tools](/reference/tools) - Tool implementation reference +- [Evals](/reference/evals) - Running evaluations \ No newline at end of file diff --git a/docs/reference/types.mdx b/docs/reference/types.mdx index 8361353a..da7ed17b 100644 --- a/docs/reference/types.mdx +++ b/docs/reference/types.mdx @@ -6,133 +6,126 @@ icon: "code" Core types used throughout the HUD SDK. -## Trace +## Eval -Returned by `agent.run()`. Contains the result of an agent execution. 
+Created by calling an Environment. Holds configuration for running an evaluation. ```python -from hud.types import Trace +from hud import Environment + +env = Environment("my-env") +eval = env("script_name", arg1="value") # Returns Eval ``` | Field | Type | Description | |-------|------|-------------| -| `reward` | `float` | Evaluation score (0.0-1.0) | -| `done` | `bool` | Whether execution completed | -| `content` | `str \| None` | Final response content | -| `isError` | `bool` | Whether an error occurred | -| `info` | `dict[str, Any]` | Additional metadata | -| `task` | `Task \| None` | The executed task | -| `trace` | `list[TraceStep]` | Execution trace steps | -| `messages` | `list[Any]` | Final conversation state | +| `env` | `Environment \| dict \| None` | Source environment | +| `script` | `str \| None` | Script name to run | +| `args` | `dict[str, Any]` | Script arguments | +| `trace_id` | `str \| None` | Trace identifier | +| `job_id` | `str \| None` | Parent job ID | +| `group_id` | `str \| None` | Group ID for parallel runs | +| `index` | `int` | Index in parallel execution | +| `variants` | `dict[str, Any] \| None` | Variant assignment | -## AgentResponse +## EvalContext -Returned by agent `get_response()` methods. Represents a single model response. +Returned by `hud.eval()`. Extends Environment with evaluation tracking. ```python -from hud.types import AgentResponse +async with hud.eval(eval) as ctx: + print(ctx.prompt) # Task prompt + print(ctx.variants) # Current variant + ctx.reward = 1.0 # Set reward ``` -| Field | Type | Description | -|-------|------|-------------| -| `tool_calls` | `list[MCPToolCall]` | Tools to execute | -| `done` | `bool` | Whether agent should stop | -| `content` | `str \| None` | Response text | -| `reasoning` | `str \| None` | Model reasoning/thinking | -| `info` | `dict[str, Any]` | Provider-specific metadata | -| `isError` | `bool` | Error flag | +| Property | Type | Description | +|----------|------|-------------| +| `trace_id` | `str` | Unique trace identifier | +| `eval_name` | `str` | Evaluation name | +| `prompt` | `str \| None` | Task prompt | +| `variants` | `dict[str, Any]` | Current variant assignment | +| `reward` | `float \| None` | Evaluation reward | +| `answer` | `str \| None` | Submitted answer | +| `error` | `BaseException \| None` | Error if failed | +| `results` | `list[EvalContext]` | Results from parallel runs | +| `headers` | `dict[str, str]` | Trace headers | ## MCPToolCall -Represents a tool call to be executed. +Represents a tool call to execute. ```python from hud.types import MCPToolCall + +call = MCPToolCall( + name="navigate", + arguments={"url": "https://example.com"} +) ``` | Field | Type | Description | |-------|------|-------------| -| `id` | `str` | Unique identifier (auto-generated if not provided) | +| `id` | `str` | Unique identifier (auto-generated) | | `name` | `str` | Tool name | | `arguments` | `dict[str, Any]` | Tool arguments | -**Example:** - -```python -tool_call = MCPToolCall( - name="playwright", - arguments={"action": "click", "selector": "#submit"} -) -``` - ## MCPToolResult Result from executing a tool call. 
```python from hud.types import MCPToolResult + +result = MCPToolResult( + content=[TextContent(text="Success", type="text")], + isError=False +) ``` | Field | Type | Description | |-------|------|-------------| | `content` | `list[ContentBlock]` | Result content blocks | -| `structuredContent` | `dict[str, Any] \| None` | Structured result data | -| `isError` | `bool` | Whether the tool call failed | +| `structuredContent` | `dict \| None` | Structured result data | +| `isError` | `bool` | Whether the call failed | -## Task +## Trace -Defines an agent task with prompt, environment config, and lifecycle tools. +Returned by `agent.run()`. Contains the result of an agent execution. ```python -from hud.types import Task +from hud.types import Trace + +result = await agent.run(task, max_steps=20) +print(result.reward, result.done) ``` | Field | Type | Description | |-------|------|-------------| -| `prompt` | `str` | Instruction for the agent | -| `mcp_config` | `dict` | Environment connection config | -| `id` | `str \| None` | Unique identifier (required for datasets) | -| `system_prompt` | `str \| None` | Custom system prompt | -| `setup_tool` | `dict \| list[dict] \| None` | Tool(s) to initialize state | -| `evaluate_tool` | `dict \| list[dict] \| None` | Tool(s) to score performance | -| `agent_config` | `BaseAgentConfig \| None` | Task-specific agent config | -| `metadata` | `dict \| None` | Additional task metadata | - -**Example:** - -```python -task = Task( - prompt="Navigate to example.com and click login", - mcp_config={ - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - "headers": { - "Authorization": "Bearer ${HUD_API_KEY}", - "Mcp-Image": "hudpython/hud-remote-browser:latest" - } - } - }, - setup_tool={"name": "playwright", "arguments": {"action": "navigate", "url": "https://example.com"}}, - evaluate_tool={"name": "evaluate", "arguments": {"name": "url_contains", "substring": "/login"}} -) -``` +| `reward` | `float` | Evaluation score (0.0-1.0) | +| `done` | `bool` | Whether execution completed | +| `content` | `str \| None` | Final response content | +| `isError` | `bool` | Whether an error occurred | +| `info` | `dict[str, Any]` | Additional metadata | +| `trace` | `list[TraceStep]` | Execution trace steps | +| `messages` | `list[Any]` | Final conversation state | -## BaseAgentConfig +## AgentResponse -Standard agent configuration that tasks can override. +Returned by agent `get_response()` methods. ```python -from hud.types import BaseAgentConfig +from hud.types import AgentResponse ``` -| Field | Type | Description | Default | -|-------|------|-------------|---------| -| `allowed_tools` | `list[str] \| None` | Tool patterns to expose | `None` (all) | -| `disallowed_tools` | `list[str] \| None` | Tool patterns to hide | `None` | -| `system_prompt` | `str \| None` | Custom system prompt | `None` | -| `append_setup_output` | `bool` | Include setup output in first turn | `True` | -| `initial_screenshot` | `bool` | Include screenshot in initial context | `True` | -| `response_tool_name` | `str \| None` | Lifecycle tool for responses | `None` | +| Field | Type | Description | +|-------|------|-------------| +| `tool_calls` | `list[MCPToolCall]` | Tools to execute | +| `done` | `bool` | Whether agent should stop | +| `content` | `str \| None` | Response text | +| `reasoning` | `str \| None` | Model reasoning/thinking | +| `info` | `dict[str, Any]` | Provider-specific metadata | +| `isError` | `bool` | Error flag | ## AgentType @@ -140,6 +133,9 @@ Enum of supported agent types. 
```python from hud.types import AgentType + +agent_cls = AgentType.CLAUDE.cls +agent = agent_cls.create() ``` | Value | Agent Class | @@ -150,25 +146,44 @@ from hud.types import AgentType | `AgentType.GEMINI` | `GeminiAgent` | | `AgentType.OPENAI_COMPATIBLE` | `OpenAIChatAgent` | -**Example:** +## ContentBlock + +MCP content types (from `mcp.types`): ```python -from hud.types import AgentType +from mcp.types import TextContent, ImageContent -agent_cls = AgentType.CLAUDE.cls # Returns ClaudeAgent class -agent = agent_cls.create() +# Text +TextContent(text="Hello", type="text") + +# Image +ImageContent(data="base64...", mimeType="image/png", type="image") ``` -## ContentBlock +## EvaluationResult -MCP content block types (from `mcp.types`): +Returned by evaluation tools. -- `TextContent` - Text content with `text` field -- `ImageContent` - Image with `data` (base64) and `mimeType` -- `EmbeddedResource` - Embedded resource reference +```python +from hud.tools.types import EvaluationResult -## See Also +result = EvaluationResult( + reward=0.8, + done=True, + content="Task completed", + info={"score": 80} +) +``` -- [Agents Reference](/reference/agents) - Agent classes and configuration -- [Tasks Reference](/reference/tasks) - Task configuration details +| Field | Type | Description | +|-------|------|-------------| +| `reward` | `float` | Score (0.0-1.0) | +| `done` | `bool` | Task complete | +| `content` | `str \| None` | Details | +| `info` | `dict` | Metadata | + +## See Also +- [Evals](/reference/evals) - hud.eval() reference +- [Environments](/reference/environments) - Environment class +- [Agents](/reference/agents) - Agent classes diff --git a/hud/cli/analyze.py b/hud/cli/analyze.py index cd58ad23..541617d4 100644 --- a/hud/cli/analyze.py +++ b/hud/cli/analyze.py @@ -12,7 +12,6 @@ from rich.table import Table from rich.tree import Tree -from hud.clients import MCPClient from hud.utils.hud_console import HUDConsole console = Console() @@ -45,6 +44,9 @@ async def analyze_environment(docker_cmd: list[str], output_format: str, verbose ) as progress: task = progress.add_task("Initializing MCP client...", total=None) + # Lazy import to avoid loading mcp_use on simple CLI commands + from hud.clients import MCPClient + client = MCPClient(mcp_config=mcp_config, verbose=verbose, auto_trace=False) try: @@ -344,6 +346,9 @@ async def _analyze_with_config( ) as progress: task = progress.add_task("Initializing MCP client...", total=None) + # Lazy import to avoid loading mcp_use on simple CLI commands + from hud.clients import MCPClient + client = MCPClient(mcp_config=mcp_config, verbose=verbose) try: diff --git a/hud/cli/build.py b/hud/cli/build.py index e191f0bc..bf300fef 100644 --- a/hud/cli/build.py +++ b/hud/cli/build.py @@ -18,7 +18,6 @@ import yaml from hud.cli.utils.source_hash import compute_source_hash, list_source_files -from hud.clients import MCPClient from hud.utils.hud_console import HUDConsole from hud.version import __version__ as hud_version @@ -451,6 +450,9 @@ async def analyze_mcp_environment( mcp_config = parse_docker_command(docker_cmd) # Initialize client and measure timing + # Lazy import to avoid loading mcp_use on simple CLI commands + from hud.clients import MCPClient + start_time = time.time() client = MCPClient(mcp_config=mcp_config, verbose=verbose, auto_trace=False) initialized = False diff --git a/hud/cli/debug.py b/hud/cli/debug.py index bd656f59..252546e0 100644 --- a/hud/cli/debug.py +++ b/hud/cli/debug.py @@ -11,7 +11,6 @@ from rich.console import Console -from 
hud.clients import MCPClient from hud.utils.hud_console import HUDConsole from .utils.logging import CaptureLogger, Colors, analyze_error_for_hints @@ -246,6 +245,9 @@ def read_stderr() -> None: logger.command(command) logger.info("Creating MCP client via hud...") + # Lazy import to avoid loading mcp_use on simple CLI commands + from hud.clients import MCPClient + client = MCPClient(mcp_config=mcp_config, verbose=False, auto_trace=False) await client.initialize() @@ -350,6 +352,9 @@ def read_stderr() -> None: try: logger.info("Creating 3 concurrent MCP clients...") + # Lazy import to avoid loading mcp_use on simple CLI commands + from hud.clients import MCPClient + for i in range(3): client_config = { f"test_concurrent_{i}": { diff --git a/hud/cli/utils/interactive.py b/hud/cli/utils/interactive.py index 22a698d4..1f8d1da7 100644 --- a/hud/cli/utils/interactive.py +++ b/hud/cli/utils/interactive.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import Any +from typing import TYPE_CHECKING, Any import questionary from mcp.types import ImageContent, TextContent @@ -13,9 +13,11 @@ from rich.syntax import Syntax from rich.tree import Tree -from hud.clients import MCPClient from hud.utils.hud_console import HUDConsole +if TYPE_CHECKING: + from hud.clients import MCPClient + console = Console() @@ -38,6 +40,9 @@ def __init__(self, server_url: str, verbose: bool = False) -> None: async def connect(self) -> bool: """Connect to the MCP server.""" try: + # Lazy import to avoid loading mcp_use on simple CLI commands + from hud.clients import MCPClient + # Create MCP config for HTTP transport # Note: We explicitly set auth to None to prevent OAuth discovery attempts config = {"server": {"url": self.server_url, "auth": None}} From b3fe2968f63aa55f521144d11044ce7071b1750d Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 17:12:23 -0800 Subject: [PATCH 23/92] change deps and patches --- hud/__init__.py | 3 + hud/patches/__init__.py | 14 ++++ hud/patches/mcp_patches.py | 153 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 8 +- 4 files changed, 174 insertions(+), 4 deletions(-) create mode 100644 hud/patches/__init__.py create mode 100644 hud/patches/mcp_patches.py diff --git a/hud/__init__.py b/hud/__init__.py index 43514b06..edc0af89 100644 --- a/hud/__init__.py +++ b/hud/__init__.py @@ -5,6 +5,9 @@ from __future__ import annotations +# Apply patches to third-party libraries early, before other imports +from . import patches as _patches # noqa: F401 + from .environment import Environment from .eval import EvalContext from .eval import run_eval as eval diff --git a/hud/patches/__init__.py b/hud/patches/__init__.py new file mode 100644 index 00000000..6b4e0934 --- /dev/null +++ b/hud/patches/__init__.py @@ -0,0 +1,14 @@ +""" +HUD runtime patches for third-party libraries. + +This module applies monkey-patches to fix issues in dependencies +without requiring forked packages. +""" + +from hud.patches.mcp_patches import apply_all_patches, suppress_fastmcp_logging + +# Apply patches on import +apply_all_patches() + +__all__ = ["apply_all_patches", "suppress_fastmcp_logging"] + diff --git a/hud/patches/mcp_patches.py b/hud/patches/mcp_patches.py new file mode 100644 index 00000000..73797bff --- /dev/null +++ b/hud/patches/mcp_patches.py @@ -0,0 +1,153 @@ +""" +Runtime patches for the standard mcp package. + +These patches apply fixes from the HUD fork without requiring a separate package. +Import this module early (e.g., in hud/__init__.py) to apply patches. 
+""" + +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def patch_streamable_http_error_handling() -> None: + """ + Patch StreamableHTTPTransport.post_writer to handle request errors properly. + + The original implementation doesn't catch errors in handle_request_async, + which can cause silent failures. This patch wraps the handler to send + errors to the read stream so clients know the request failed. + """ + try: + from mcp.client.streamable_http import StreamableHTTPTransport + + async def patched_post_writer( + self: Any, + client: Any, + write_stream_reader: Any, + read_stream_writer: Any, + write_stream: Any, + start_get_stream: Any, + tg: Any, + ) -> None: + """Patched post_writer with error handling for handle_request_async.""" + from mcp.client.streamable_http import RequestContext + from mcp.shared.message import ClientMessageMetadata + from mcp.types import JSONRPCRequest + + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + message = session_message.message + metadata = ( + session_message.metadata + if isinstance(session_message.metadata, ClientMessageMetadata) + else None + ) + + is_resumption = bool(metadata and metadata.resumption_token) + + logger.debug("Sending client message: %s", message) + + if self._is_initialized_notification(message): + start_get_stream() + + ctx = RequestContext( + client=client, + headers=self.request_headers, + session_id=self.session_id, + session_message=session_message, + metadata=metadata, + read_stream_writer=read_stream_writer, + sse_read_timeout=self.sse_read_timeout, + ) + + # Patched: Accept ctx and is_resumption as params, add error handling + async def handle_request_async( + ctx: RequestContext = ctx, + is_resumption: bool = is_resumption, + ) -> None: + try: + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) + except Exception as e: + # Send error to read stream so client knows request failed + logger.error("Request handler error: %s", e) + await ctx.read_stream_writer.send(e) + + if isinstance(message.root, JSONRPCRequest): + tg.start_soon(handle_request_async, ctx, is_resumption) + else: + await handle_request_async(ctx, is_resumption) + + except Exception: + logger.exception("Error in post_writer") + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + StreamableHTTPTransport.post_writer = patched_post_writer + logger.debug("Patched StreamableHTTPTransport.post_writer") + + except ImportError: + logger.debug("mcp.client.streamable_http not available, skipping patch") + except Exception as e: + logger.warning("Failed to patch streamable_http: %s", e) + + +def patch_client_session_validation() -> None: + """ + Patch ClientSession to skip structured output validation. + + The original validation is strict and raises errors for non-conforming + but usable responses. We replace it with a no-op. 
+ """ + try: + from mcp.client.session import ClientSession + + async def noop_validate(self: Any, name: str, result: Any) -> None: + """Skip structured output validation entirely.""" + pass + + ClientSession._validate_tool_result = noop_validate + logger.debug("Patched ClientSession._validate_tool_result to skip validation") + + except ImportError: + logger.debug("mcp.client.session not available, skipping patch") + except Exception as e: + logger.warning("Failed to patch client session: %s", e) + + +def suppress_fastmcp_logging(level: int = logging.WARNING) -> None: + """ + Suppress verbose fastmcp logging. + + FastMCP logs a lot of INFO-level messages that clutter output. + This sets all fastmcp loggers to the specified level. + + Args: + level: Logging level to set (default: WARNING) + """ + loggers_to_suppress = [ + "fastmcp", + "fastmcp.server.server", + "fastmcp.server.openapi", + "fastmcp.tools.tool_manager", + ] + for logger_name in loggers_to_suppress: + logging.getLogger(logger_name).setLevel(level) + logger.debug("Suppressed fastmcp logging to level %s", level) + + +def apply_all_patches() -> None: + """Apply all MCP patches.""" + patch_streamable_http_error_handling() + patch_client_session_validation() + suppress_fastmcp_logging() + logger.debug("All MCP patches applied") + diff --git a/pyproject.toml b/pyproject.toml index cf1742d2..72744511 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ dependencies = [ "pydantic>=2.6,<3", "pydantic-settings>=2.2,<3", # MCP dependencies - "hud-mcp-python-sdk>=3.13.2", - "hud-fastmcp-python-sdk>=0.1.2", + "mcp>1.21.1,<1.23", + "fastmcp==2.13.3", # CLI dependencies "typer>=0.9.0", "rich>=13.0.0", @@ -108,8 +108,8 @@ packages = ["hud"] # Agent implementations, AI providers, datasets, and telemetry agents = [ # MCP-use client (legacy) - "hud-mcp-use-python-sdk==2.3.20", - "langchain==0.3.27", # Required by mcp-use + "mcp-use==1.5.0", + "langchain>=1.0.0", # Required by mcp-use # AI providers "anthropic>=0.75", "openai>=2.8.1", From b98b2fe35ed79eddb428ac8d1bc25fb698772961 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 17:27:09 -0800 Subject: [PATCH 24/92] touchups --- docs/docs.json | 3 +- docs/index-legacy.mdx | 113 ++++++++++++++++ docs/index.mdx | 179 +++++++++++++++----------- hud/__init__.py | 1 - hud/agents/gemini_cua.py | 4 +- hud/cli/dev.py | 4 +- hud/cli/flows/init.py | 10 +- hud/cli/utils/config.py | 19 ++- hud/cli/utils/docker.py | 4 +- hud/cli/utils/viewer.py | 2 +- hud/environment/connectors/local.py | 2 +- hud/environment/connectors/remote.py | 4 +- hud/environment/environment.py | 20 +-- hud/environment/scripts.py | 37 +++--- hud/environment/tests/test_scripts.py | 9 +- hud/eval/__init__.py | 6 +- hud/eval/context.py | 2 +- hud/eval/eval.py | 7 +- hud/eval/instrument.py | 4 +- hud/eval/manager.py | 8 +- hud/eval/parallel.py | 1 - hud/eval/tests/test_eval.py | 11 +- hud/eval/tests/test_manager.py | 149 +++++++++++---------- hud/patches/__init__.py | 1 - hud/patches/mcp_patches.py | 10 +- hud/server/server.py | 4 +- 26 files changed, 394 insertions(+), 220 deletions(-) create mode 100644 docs/index-legacy.mdx diff --git a/docs/docs.json b/docs/docs.json index cf3736cf..c4ec7980 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -35,7 +35,6 @@ "group": "Get Started", "pages": [ "index", - "quickstart", "llm-quickstart" ] }, @@ -89,7 +88,7 @@ { "group": "Get Started", "pages": [ - "index", + "index-legacy", "quickstart", "llm-quickstart" ] diff --git a/docs/index-legacy.mdx 
b/docs/index-legacy.mdx new file mode 100644 index 00000000..ecccffeb --- /dev/null +++ b/docs/index-legacy.mdx @@ -0,0 +1,113 @@ +--- +title: "Introduction" +description: "OSS environment + evals toolkit for AI agents." +icon: "book" +--- + + +**Version 0.4.73** - Latest stable release + + + + + Test Claude, Operator, or custom agents on benchmarks like SheetBench and OSWorld + + + + Wrap any software in dockerized MCP for scalable and generalizable agent evaluation + + + +## What is HUD? + +HUD connects AI agents to software environments using the Model Context Protocol (MCP). Whether you're evaluating existing agents or building new environments, HUD provides the infrastructure. + +```mermaid +graph LR + Agent["🤖 Any Agent
(Claude, Operator, etc.)"] + MCP["🔌 MCP Protocol<br>(Tool Calls)"] + Env["📦 Any Environment
(Browser, OS, etc.)"] + + Agent -->|"call_tool()"| MCP + MCP -->|"click(x, y)"| Env + Env -->|"screenshot"| MCP + MCP -->|"get_response()"| Agent + + style Agent fill:#3b82f6,stroke:#1e40af,stroke-width:2px,color:#ffffff + style MCP fill:#f59e0b,stroke:#d97706,stroke-width:2px,color:#ffffff + style Env fill:#10b981,stroke:#047857,stroke-width:2px,color:#ffffff +``` + +## Why HUD? + +- **🔌 MCP-native**: Any agent can connect to any environment +- **📡 Live telemetry**: Debug every tool call at [hud.ai](https://hud.ai) +- **⚡ HUD Gateway**: Unified inference API for all LLMs +- **🚀 Production-ready**: From local Docker to cloud scale +- **🎯 Built-in benchmarks**: OSWorld-Verified, SheetBench-50, and more +- **🔧 CLI tools**: Create, develop, and run with `hud init`, `hud dev`, `hud run`, `hud eval` + + + + Run your first agent evaluation with zero setup + + + + Unified inference API for OpenAI, Anthropic, Gemini, and Open Source Models + + + + Give your AI assistant full knowledge of HUD docs + + + + + +## Quick Example + +```python +import asyncio, os, hud +from hud.datasets import Task +from hud.agents import ClaudeAgent + +async def main(): + # Define evaluation task with remote MCP + task = Task( + prompt="Win a game of 2048 by reaching the 128 tile", + mcp_config={ + "hud": { + "url": "https://mcp.hud.ai/v3/mcp", + "headers": { + "Authorization": f"Bearer {os.getenv('HUD_API_KEY')}", + "Mcp-Image": "hudevals/hud-text-2048:0.1.3" + } + } + }, + setup_tool={"name": "setup", "arguments": {"name": "board", "arguments": { "board_size": 4}}}, + evaluate_tool={"name": "evaluate", "arguments": {"name": "max_number", "arguments": {"target": 64}}} + ) + + # Run agent (auto-creates MCP client) + agent = ClaudeAgent.create() + result = await agent.run(task) + print(f"Score: {result.reward}") + +asyncio.run(main()) +``` + +## Community + + + + Star the repo and contribute + + + + Join our community + + + +### Are you an enterprise building agents? + +[📅 Hop on a call](https://cal.com/jay-hud) or [📧 founders@hud.ai](mailto:founders@hud.ai) + diff --git a/docs/index.mdx b/docs/index.mdx index ecccffeb..116cfc5a 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -1,100 +1,124 @@ --- title: "Introduction" -description: "OSS environment + evals toolkit for AI agents." +description: "Build, evaluate, and train AI agents." icon: "book" --- - -**Version 0.4.73** - Latest stable release - +HUD gives you three things: a unified API for every model, a way to turn your code into agent-callable tools, and infrastructure to run evaluations at scale. - - - Test Claude, Operator, or custom agents on benchmarks like SheetBench and OSWorld - +## Install - - Wrap any software in dockerized MCP for scalable and generalizable agent evaluation - - +```bash +# Install CLI +uv tool install hud-python --python 3.12 -## What is HUD? - -HUD connects AI agents to software environments using the Model Context Protocol (MCP). Whether you're evaluating existing agents or building new environments, HUD provides the infrastructure. - -```mermaid -graph LR - Agent["🤖 Any Agent
(Claude, Operator, etc.)"] - MCP["🔌 MCP Protocol<br>(Tool Calls)"] - Env["📦 Any Environment
(Browser, OS, etc.)"] - - Agent -->|"call_tool()"| MCP - MCP -->|"click(x, y)"| Env - Env -->|"screenshot"| MCP - MCP -->|"get_response()"| Agent - - style Agent fill:#3b82f6,stroke:#1e40af,stroke-width:2px,color:#ffffff - style MCP fill:#f59e0b,stroke:#d97706,stroke-width:2px,color:#ffffff - style Env fill:#10b981,stroke:#047857,stroke-width:2px,color:#ffffff +# Set your API key +hud set HUD_API_KEY=your-key-here ``` -## Why HUD? +Get your API key at [hud.ai/settings/api-keys](https://hud.ai/settings/api-keys). -- **🔌 MCP-native**: Any agent can connect to any environment -- **📡 Live telemetry**: Debug every tool call at [hud.ai](https://hud.ai) -- **⚡ HUD Gateway**: Unified inference API for all LLMs -- **🚀 Production-ready**: From local Docker to cloud scale -- **🎯 Built-in benchmarks**: OSWorld-Verified, SheetBench-50, and more -- **🔧 CLI tools**: Create, develop, and run with `hud init`, `hud dev`, `hud run`, `hud eval` +## 1. Gateway: Any Model, One API - - - Run your first agent evaluation with zero setup - +Stop juggling API keys. Point any OpenAI-compatible client at `inference.hud.ai` and use Claude, GPT, Gemini, or Grok: - - Unified inference API for OpenAI, Anthropic, Gemini, and Open Source Models - +```python +from openai import AsyncOpenAI +import os + +client = AsyncOpenAI( + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +response = await client.chat.completions.create( + model="claude-sonnet-4-5", # or gpt-4o, gemini-2.5-pro, grok-4-1-fast... + messages=[{"role": "user", "content": "Hello!"}] +) +``` - - Give your AI assistant full knowledge of HUD docs - - +Every call is traced. View them at [hud.ai/home](https://hud.ai/home). + +→ [More on Gateway](/quick-links/gateway) + +## 2. Environments: Your Code, Agent-Ready + +Turn your code into tools agents can call. Define scripts that evaluate what agents do: + +```python +from hud import Environment + +env = Environment("my-env") + +@env.tool() +def search(query: str) -> str: + """Search the knowledge base.""" + return db.search(query) + +@env.script("find-answer") +async def find_answer(question: str): + answer = yield f"Find the answer to: {question}" + yield 1.0 if "correct" in answer.lower() else 0.0 +``` +Scripts define the prompt (first yield) and the scoring logic (second yield). The agent runs in between. +→ [More on Environments](/quick-links/environments) -## Quick Example +## 3. Evals: Test and Improve + +Run your script with different models. 
Compare results: ```python -import asyncio, os, hud -from hud.datasets import Task -from hud.agents import ClaudeAgent - -async def main(): - # Define evaluation task with remote MCP - task = Task( - prompt="Win a game of 2048 by reaching the 128 tile", - mcp_config={ - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - "headers": { - "Authorization": f"Bearer {os.getenv('HUD_API_KEY')}", - "Mcp-Image": "hudevals/hud-text-2048:0.1.3" - } - } - }, - setup_tool={"name": "setup", "arguments": {"name": "board", "arguments": { "board_size": 4}}}, - evaluate_tool={"name": "evaluate", "arguments": {"name": "max_number", "arguments": {"target": 64}}} +import hud + +eval = env("find-answer", question="What is 2+2?") + +async with hud.eval(eval, variants={"model": ["gpt-4o", "claude-sonnet-4-5"]}, group=5) as ctx: + response = await client.chat.completions.create( + model=ctx.variants["model"], + messages=[{"role": "user", "content": ctx.prompt}] ) - - # Run agent (auto-creates MCP client) - agent = ClaudeAgent.create() - result = await agent.run(task) - print(f"Score: {result.reward}") + await ctx.submit(response.choices[0].message.content) +``` + +**Variants** test different configurations. **Groups** repeat each to see the distribution. Results show up on [hud.ai](https://hud.ai/home) with scores, traces, and side-by-side comparisons. -asyncio.run(main()) +→ [More on A/B Evals](/quick-links/ab-testing) + +## 4. Deploy and Scale + +Push your environment to GitHub, connect it on [hud.ai](https://hud.ai), and run thousands of evals in parallel. Every run generates training data. + +```bash +hud init # Scaffold environment +git push # Push to GitHub +# Connect on hud.ai → New Environment +hud eval my-org/my-eval --model gpt-4o --group-size 100 ``` +→ [More on Deploy](/quick-links/deploy) + +## Next Steps + + + + One endpoint for every model. Full observability. + + + + Tools, scripts, and local testing. + + + + Variants, groups, and finding what works. + + + + Run at scale. Generate training data. + + + ## Community @@ -103,11 +127,12 @@ asyncio.run(main()) - Join our community + Join the community -### Are you an enterprise building agents? +## Enterprise -[📅 Hop on a call](https://cal.com/jay-hud) or [📧 founders@hud.ai](mailto:founders@hud.ai) +Building agents at scale? We work with teams on custom environments, benchmarks, and training pipelines. +[📅 Book a call](https://cal.com/jay-hud) · [📧 founders@hud.ai](mailto:founders@hud.ai) diff --git a/hud/__init__.py b/hud/__init__.py index edc0af89..be6e8ee9 100644 --- a/hud/__init__.py +++ b/hud/__init__.py @@ -7,7 +7,6 @@ # Apply patches to third-party libraries early, before other imports from . 
import patches as _patches # noqa: F401 - from .environment import Environment from .eval import EvalContext from .eval import run_eval as eval diff --git a/hud/agents/gemini_cua.py b/hud/agents/gemini_cua.py index 3a67477a..8ad02b21 100644 --- a/hud/agents/gemini_cua.py +++ b/hud/agents/gemini_cua.py @@ -249,7 +249,7 @@ def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None: # Map common argument shapes used by Gemini Computer Use # 1) Coordinate arrays → x/y coord = raw_args.get("coordinate") or raw_args.get("coordinates") - if isinstance(coord, (list, tuple)) and len(coord) >= 2: + if isinstance(coord, list | tuple) and len(coord) >= 2: try: normalized_args["x"] = int(coord[0]) normalized_args["y"] = int(coord[1]) @@ -263,7 +263,7 @@ def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None: or raw_args.get("destination_coordinate") or raw_args.get("destinationCoordinate") ) - if isinstance(dest, (list, tuple)) and len(dest) >= 2: + if isinstance(dest, list | tuple) and len(dest) >= 2: try: normalized_args["destination_x"] = int(dest[0]) normalized_args["destination_y"] = int(dest[1]) diff --git a/hud/cli/dev.py b/hud/cli/dev.py index 5809f118..913cc582 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -153,7 +153,7 @@ async def run_mcp_module( new_trace: bool = False, ) -> None: """Run an MCP module directly. - + Args: module_spec: Module specification in format "module" or "module:attribute" e.g., "server" (looks for mcp), "env:env" (looks for env) @@ -164,7 +164,7 @@ async def run_mcp_module( else: module_name = module_spec attr_name = "mcp" # Default attribute - + # Check if this is a reload (not first run) is_reload = os.environ.get("_HUD_DEV_RELOAD") == "1" diff --git a/hud/cli/flows/init.py b/hud/cli/flows/init.py index 94c4c471..b58ef3ad 100644 --- a/hud/cli/flows/init.py +++ b/hud/cli/flows/init.py @@ -37,13 +37,13 @@ def _has_hud_dependency(directory: Path) -> bool: def _add_hud_dependency(directory: Path) -> str: """Add hud-python using uv if available. - + Returns: "exists" if already present, "added" if added, "failed" if failed """ if _has_hud_dependency(directory): return "exists" - + try: result = subprocess.run( ["uv", "add", "hud-python", "openai"], # noqa: S607 @@ -91,9 +91,9 @@ def smart_init( - Otherwise: create new HUD environment """ from hud.settings import settings - + hud_console = HUDConsole() - + # Check for API key first if not settings.api_key: hud_console.error("HUD_API_KEY not found") @@ -104,7 +104,7 @@ def smart_init( hud_console.info("") hud_console.info("Get your key at: https://hud.ai/settings/api-keys") return - + target = Path(directory).resolve() # If directory is empty, use preset selection diff --git a/hud/cli/utils/config.py b/hud/cli/utils/config.py index 439cee3d..d13e6669 100644 --- a/hud/cli/utils/config.py +++ b/hud/cli/utils/config.py @@ -27,7 +27,9 @@ def parse_env_file(contents: str) -> dict[str, str]: """Parse simple KEY=VALUE lines into a dict. - Ignores blank lines and lines starting with '#'. - - Does not perform variable substitution or quoting. + - Strips inline comments (# and everything after) from unquoted values. + - Respects single and double quoted values (comments inside quotes are preserved). + - Does not perform variable substitution. 
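A small usage sketch of the rules above; the expected output is inferred from the parsing logic that follows, and the keys and values are illustrative.

```python
from hud.cli.utils.config import parse_env_file

contents = """\
# full-line comments and blank lines are ignored
API_URL=https://example.com  # inline comment is stripped
NAME="value with a # inside"  # quotes protect the hash
"""

print(parse_env_file(contents))
# {'API_URL': 'https://example.com', 'NAME': 'value with a # inside'}
```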
""" data: dict[str, str] = {} for raw_line in contents.splitlines(): @@ -39,6 +41,21 @@ def parse_env_file(contents: str) -> dict[str, str]: key, value = line.split("=", 1) key = key.strip() value = value.strip() + + # Handle quoted values - preserve everything inside quotes + if value and value[0] in ('"', "'"): + quote_char = value[0] + # Find the closing quote + end_quote = value.find(quote_char, 1) + # Extract value without quotes (or strip opening quote if no closing quote) + value = value[1:end_quote] if end_quote != -1 else value[1:] + else: + # Unquoted value - strip inline comments + # Find # that's not escaped and treat as comment start + comment_idx = value.find("#") + if comment_idx != -1: + value = value[:comment_idx].rstrip() + if key: data[key] = value return data diff --git a/hud/cli/utils/docker.py b/hud/cli/utils/docker.py index 27cbd54b..3fed4551 100644 --- a/hud/cli/utils/docker.py +++ b/hud/cli/utils/docker.py @@ -121,7 +121,7 @@ def detect_environment_dir(start_dir: Path | None = None) -> Path | None: - Current directory containing `hud.lock.yaml` - Parent directory containing `hud.lock.yaml` - Current directory that looks like an environment if it has either a - `Dockerfile.hud`, `Dockerfile`, or a `pyproject.toml` (looser than `is_environment_directory`). + `Dockerfile.hud`, `Dockerfile`, or a `pyproject.toml` (looser than `is_environment_directory`) Returns the detected directory path or None if not found. """ @@ -204,7 +204,7 @@ def create_docker_run_command( # Load env from `.env` in detected env directory env_dir_path: Path | None = ( - Path(env_dir).resolve() if isinstance(env_dir, (str, Path)) else detect_environment_dir() + Path(env_dir).resolve() if isinstance(env_dir, str | Path) else detect_environment_dir() ) merged_env: dict[str, str] = {} diff --git a/hud/cli/utils/viewer.py b/hud/cli/utils/viewer.py index 8ea54a98..2d6efe28 100644 --- a/hud/cli/utils/viewer.py +++ b/hud/cli/utils/viewer.py @@ -46,7 +46,7 @@ def _truncate_value(value: Any, max_len: int = 60) -> str: if len(value) > max_len: return value[:max_len] + "…" return value - elif isinstance(value, (dict, list)): + elif isinstance(value, dict | list): s = json.dumps(value, separators=(",", ":")) if len(s) > max_len: return s[:max_len] + "…" diff --git a/hud/environment/connectors/local.py b/hud/environment/connectors/local.py index 1deea4b7..a8ef946a 100644 --- a/hud/environment/connectors/local.py +++ b/hud/environment/connectors/local.py @@ -173,5 +173,5 @@ def greet(name: str) -> str: result = await env.call_tool("greet", name="World") ``` """ - self.include_router(server, prefix=prefix) # type: ignore + self.include_router(server, prefix=prefix) # type: ignore return self diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py index 599c5af0..d9179786 100644 --- a/hud/environment/connectors/remote.py +++ b/hud/environment/connectors/remote.py @@ -21,7 +21,7 @@ class RemoteConnectorMixin(MCPConfigConnectorMixin): """Mixin providing remote connection methods. - + Note: include_router() is inherited from MCPServer (via FastMCP). 
""" @@ -183,5 +183,5 @@ def connect_openapi( client=client, name=name or "openapi", ) - self.include_router(mcp_server, prefix=prefix) # type: ignore + self.include_router(mcp_server, prefix=prefix) # type: ignore return self diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 0bc48a52..df731ac8 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -196,7 +196,7 @@ def _check_lifecycle_warning(self, name: str) -> None: def _connections_with_tool(self, tool_name: str) -> set[str]: """Get connection names that have a specific tool. - + Uses cached_tools from each Connector to check availability. """ result = set() @@ -207,19 +207,19 @@ def _connections_with_tool(self, tool_name: str) -> set[str]: return result async def _broadcast_tool( - self, - tool_name: str, + self, + tool_name: str, **kwargs: Any, ) -> dict[str, Any]: """Broadcast a tool call to all connections that have the tool. - + Automatically filters to only connections where the tool exists (based on cached_tools from initial discovery). - + Args: tool_name: Name of the tool to call **kwargs: Arguments to pass to the tool - + Returns: Dict mapping connection name to result (or exception) """ @@ -229,9 +229,9 @@ async def _broadcast_tool( targets = self._connections_with_tool(tool_name) if not targets: return {} - + results: dict[str, Any] = {} - + async def call_one(name: str) -> None: connector = self._connections.get(name) if not connector or not connector.client: @@ -242,7 +242,7 @@ async def call_one(name: str) -> None: except Exception as e: results[name] = e logger.debug("Broadcast '%s' to '%s' failed: %s", tool_name, name, e) - + await asyncio.gather(*[call_one(n) for n in targets], return_exceptions=True) return results @@ -611,11 +611,13 @@ def __call__( ```python env = Environment("my-env").connect_hub("browser") + @env.script() async def checkout(user_id: str): yield "Complete checkout" yield 1.0 + # Simple use - Eval is context manager async with env("checkout", user_id="alice") as ctx: await agent.run(ctx.prompt) diff --git a/hud/environment/scripts.py b/hud/environment/scripts.py index 7b077a46..4fd06218 100644 --- a/hud/environment/scripts.py +++ b/hud/environment/scripts.py @@ -63,7 +63,7 @@ def _init_scripts(self) -> None: self._script_sessions = {} self._script_latest = {} self._script_answers = {} - + # Register _hud_submit tool (underscore = hidden from agent) self._register_hud_submit_tool() @@ -80,14 +80,17 @@ async def submit(self, script: str, answer: str) -> None: Example: # Direct call with script name await env.submit("checkout", "Order completed successfully") - + # Or via EvalContext (knows its own script) await ctx.submit("Order completed successfully") """ # Store locally for our scripts self._script_answers[script] = answer - logger.debug("Stored answer for script '%s': %s...", - script, answer[:50] if len(answer) > 50 else answer) + logger.debug( + "Stored answer for script '%s': %s...", + script, + answer[:50] if len(answer) > 50 else answer, + ) # Broadcast to connections that have _hud_submit # Environment._broadcast_tool auto-filters to connections with the tool @@ -99,7 +102,7 @@ async def submit(self, script: str, answer: str) -> None: def _register_hud_submit_tool(self) -> None: """Register the _hud_submit tool for receiving agent answers. - + Named with underscore prefix to hide from agent tool listings. 
""" from fastmcp.tools import Tool @@ -117,8 +120,11 @@ async def _hud_submit(script: str, answer: str) -> str: """ # Store locally (don't broadcast - we ARE the target) script_self._script_answers[script] = answer - logger.debug("_hud_submit received answer for script '%s': %s...", - script, answer[:50] if len(answer) > 50 else answer) + logger.debug( + "_hud_submit received answer for script '%s': %s...", + script, + answer[:50] if len(answer) > 50 else answer, + ) return f"Answer submitted for script '{script}'" # Register the tool with underscore name @@ -128,14 +134,14 @@ async def _hud_submit(script: str, answer: str) -> str: async def run_script_setup(self, script_name: str, args: dict[str, Any]) -> str | None: """Run a script's setup phase and return the prompt. - + Handles both local scripts (registered via @env.script) and remote scripts (via MCP prompt). - + Args: script_name: Name of the script to run args: Arguments to pass to the script - + Returns: The prompt string from the script's setup phase, or None if failed """ @@ -180,13 +186,13 @@ async def run_script_setup(self, script_name: str, args: dict[str, Any]) -> str async def run_script_evaluate(self, script_name: str) -> float | None: """Run a script's evaluate phase and return the reward. - + Uses the submitted answer (if any) via gen.asend(). Handles both local and remote scripts. - + Args: script_name: Name of the script to evaluate - + Returns: The reward from the script's evaluate phase, or None if failed """ @@ -338,9 +344,7 @@ async def resource_handler() -> str: gen = script_self._script_sessions.pop(session_id, None) if gen is None: - raise ValueError( - f"Session '{session_id}' not found or already evaluated." - ) + raise ValueError(f"Session '{session_id}' not found or already evaluated.") # Get submitted answer (if any) answer = script_self._script_answers.pop(script_name_ref, None) @@ -388,4 +392,3 @@ async def resource_handler() -> str: return fn return decorator - diff --git a/hud/environment/tests/test_scripts.py b/hud/environment/tests/test_scripts.py index f7b3fd06..cfec80b0 100644 --- a/hud/environment/tests/test_scripts.py +++ b/hud/environment/tests/test_scripts.py @@ -59,6 +59,7 @@ async def checkout_script(user_id: str, amount: int = 100): # Find the prompt prompt = env._prompt_manager._prompts.get("test-env:checkout") assert prompt is not None + assert prompt.arguments is not None # Check arguments arg_names = [arg.name for arg in prompt.arguments] @@ -106,6 +107,7 @@ async def test_script(): # Run setup via prompt - no need for context prompt = env._prompt_manager._prompts.get("test-env:test") + assert prompt is not None await prompt.render({}) # Check session was stored @@ -126,13 +128,15 @@ async def test_script(): # Setup phase - no context needed for prompt/resource prompt = env._prompt_manager._prompts.get("test-env:test") + assert prompt is not None await prompt.render({}) assert "setup" in phases assert "evaluate" not in phases # Evaluate phase resource = env._resource_manager._resources.get("test-env:test") - reward_result = await resource.read() + assert resource is not None + await resource.read() assert "evaluate" in phases @@ -153,10 +157,9 @@ async def checkout_script(user_id: str, amount: int = 100): yield 1.0 prompt = env._prompt_manager._prompts.get("test-env:checkout") - + assert prompt is not None # No context needed for prompt render await prompt.render({"user_id": "alice", "amount": 50}) assert received_args["user_id"] == "alice" assert received_args["amount"] == 50 - diff 
--git a/hud/eval/__init__.py b/hud/eval/__init__.py index da8ce3fe..78cab2cd 100644 --- a/hud/eval/__init__.py +++ b/hud/eval/__init__.py @@ -36,12 +36,12 @@ # Auto-instrument httpx on import import hud.eval.instrument # noqa: F401 -# run_eval is safe to import (uses lazy imports internally) -from hud.eval.manager import run_eval - # Eval is safe to import from hud.eval.eval import Eval +# run_eval is safe to import (uses lazy imports internally) +from hud.eval.manager import run_eval + if TYPE_CHECKING: from hud.eval.context import EvalContext diff --git a/hud/eval/context.py b/hud/eval/context.py index b2ef831c..4ecafef0 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -253,7 +253,7 @@ def from_environment( ctx._script_sessions = {} ctx._script_latest = {} ctx._script_answers = {} - + # Store source env name for remote script lookups ctx._source_env_name = env.name diff --git a/hud/eval/eval.py b/hud/eval/eval.py index 2800c48d..4eb681f9 100644 --- a/hud/eval/eval.py +++ b/hud/eval/eval.py @@ -45,14 +45,14 @@ def build_eval_name(script: str | None, args: dict[str, Any] | None) -> str: return "eval" if not args: return script - + val_parts = [] for v in list(args.values())[:3]: # Max 3 values v_str = repr(v) if isinstance(v, str) else str(v) if len(v_str) > 25: v_str = v_str[:22] + "..." val_parts.append(v_str) - + if val_parts: return f"{script} with {', '.join(val_parts)}" return script @@ -138,6 +138,7 @@ def to_eval_context(self) -> EvalContext: task = getattr(self, "_task", None) if task is not None: import warnings + warnings.warn( "Task objects are deprecated. Use Eval from env() instead.", DeprecationWarning, @@ -190,7 +191,7 @@ def to_eval_context(self) -> EvalContext: async def __aenter__(self) -> EvalContext: """Enter eval context. - + Order of operations: 1. Create EvalContext from environment config 2. Connect environment (MCP servers, etc.) 
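For reference, the naming behavior implemented in `build_eval_name` above can be illustrated like this; the expected strings are inferred from the logic, not taken from test output.

```python
from hud.eval.eval import build_eval_name

print(build_eval_name("checkout", None))                  # "checkout"
print(build_eval_name("checkout", {"user_id": "alice"}))  # "checkout with 'alice'"
# String values are shown via repr() and truncated to 25 characters,
# and at most three argument values are included in the name.
```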
diff --git a/hud/eval/instrument.py b/hud/eval/instrument.py index 0024034e..e950522c 100644 --- a/hud/eval/instrument.py +++ b/hud/eval/instrument.py @@ -27,11 +27,11 @@ def _is_hud_url(url_str: str) -> bool: """Check if URL is a HUD service (inference or MCP).""" parsed = urlparse(url_str) request_host = parsed.netloc or url_str.split("/")[0] - + # Check for known HUD domains (works for any subdomain) if request_host.endswith((".hud.ai", ".hud.so")): return True - + # Also check settings URLs known_hosts = { urlparse(settings.hud_gateway_url).netloc, diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 9eb56a9f..789c9df4 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -12,7 +12,6 @@ from typing import TYPE_CHECKING, Any from hud.eval.display import print_complete, print_eval_stats, print_link -from hud.eval.types import ParallelEvalComplete from hud.eval.parallel import ( ASTExtractionError, expand_variants, @@ -20,6 +19,7 @@ get_with_block_body, resolve_group_ids, ) +from hud.eval.types import ParallelEvalComplete if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -118,7 +118,6 @@ def _load_evals_from_slugs(slugs: str | list[str]) -> list[Eval]: """ import httpx - from hud.eval.eval import Eval from hud.settings import settings if isinstance(slugs, str): @@ -381,7 +380,7 @@ async def run_eval( _trace=trace, _quiet=quiet, ) - + # Apply common settings single_eval.api_key = api_key single_eval.job_id = job_id @@ -389,7 +388,7 @@ async def run_eval( single_eval.code_snippet = code_snippet single_eval._trace = trace single_eval._quiet = quiet - + async with single_eval as ctx: yield ctx @@ -485,7 +484,6 @@ async def _run_parallel_eval( import textwrap # Lazy import to avoid circular dependency - from hud.eval.context import EvalContext from hud.eval.eval import Eval from hud.eval.parallel import log_eval_stats diff --git a/hud/eval/parallel.py b/hud/eval/parallel.py index 59d27d91..9f70f333 100644 --- a/hud/eval/parallel.py +++ b/hud/eval/parallel.py @@ -7,7 +7,6 @@ from __future__ import annotations import ast -import asyncio import inspect import itertools import linecache diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index ff9e8411..512d0150 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -2,7 +2,7 @@ from __future__ import annotations -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -34,7 +34,7 @@ def test_init_with_config(self) -> None: def test_copy_creates_new_instance(self) -> None: """copy() creates a new Eval instance.""" original = Eval( - env_config={"name": "test"}, + env={"name": "test"}, script="checkout", args={"user_id": "alice"}, variants={"model": "gpt-4o"}, @@ -42,7 +42,7 @@ def test_copy_creates_new_instance(self) -> None: copied = original.copy() assert copied is not original - assert copied.env_config == original.env_config + assert copied.env == original.env assert copied.script == original.script assert copied.args == original.args assert copied.args is not original.args # Deep copy @@ -128,9 +128,9 @@ async def test_context_clears_on_exit(self) -> None: with ( patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), - patch.object(EvalContext, "__aexit__", new_callable=AsyncMock) as mock_exit, + patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), ): - ctx = await ev.__aenter__() + await ev.__aenter__() assert ev._ctx is not 
None # Manually call __aexit__ on Eval (which will call mocked ctx.__aexit__) @@ -233,4 +233,3 @@ def test_call_captures_env_config_when_configured(self) -> None: assert ev2.env_config is not None assert ev2.env_config["name"] == "test-env" assert len(ev2.env_config["setup_tools"]) == 1 - diff --git a/hud/eval/tests/test_manager.py b/hud/eval/tests/test_manager.py index 53e1de80..75aa6ad7 100644 --- a/hud/eval/tests/test_manager.py +++ b/hud/eval/tests/test_manager.py @@ -16,77 +16,91 @@ class TestRunEvalNoArgs: @pytest.mark.asyncio async def test_blank_eval_creates_context(self) -> None: """hud.eval() with no args creates an EvalContext.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - async with run_eval() as ctx: - assert isinstance(ctx, EvalContext) - assert ctx.eval_name == "eval" + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + async with run_eval() as ctx: + assert isinstance(ctx, EvalContext) + assert ctx.eval_name == "eval" @pytest.mark.asyncio async def test_blank_eval_generates_trace_id(self) -> None: """hud.eval() with no args generates a trace_id.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - async with run_eval() as ctx: - assert ctx.trace_id is not None - assert len(ctx.trace_id) == 36 # UUID format + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + async with run_eval() as ctx: + assert ctx.trace_id is not None + assert len(ctx.trace_id) == 36 # UUID format @pytest.mark.asyncio async def test_blank_eval_sets_trace_headers(self) -> None: """hud.eval() sets trace headers in contextvar during context.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - # Before context, no headers - assert get_current_trace_headers() is None - - async with run_eval() as ctx: - # Inside context, headers are set - headers = get_current_trace_headers() - assert headers is not None - assert headers["Trace-Id"] == ctx.trace_id - - # After context, headers are cleared - assert get_current_trace_headers() is None + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + # Before context, no headers + assert get_current_trace_headers() is None + + async with run_eval() as ctx: + # Inside context, headers are set + headers = get_current_trace_headers() + assert headers is not None + assert headers["Trace-Id"] == ctx.trace_id + + # After context, headers are cleared + assert get_current_trace_headers() is None @pytest.mark.asyncio async def test_blank_eval_reward_can_be_set(self) -> None: """hud.eval() allows setting reward on context.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - async with run_eval() as ctx: - assert ctx.reward is None - ctx.reward = 0.95 + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + async with run_eval() as ctx: + assert ctx.reward is None + ctx.reward = 0.95 - assert ctx.reward == 0.95 + 
assert ctx.reward == 0.95 @pytest.mark.asyncio async def test_blank_eval_reports_reward_on_exit(self) -> None: """hud.eval() reports reward to backend on exit.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit: - async with run_eval() as ctx: - ctx.reward = 0.85 + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit, + ): + async with run_eval() as ctx: + ctx.reward = 0.85 - # _eval_exit should have been called (with no error) - mock_exit.assert_called_once_with(None) + # _eval_exit should have been called (with no error) + mock_exit.assert_called_once_with(None) @pytest.mark.asyncio async def test_blank_eval_empty_variants(self) -> None: """hud.eval() with no args has empty variants dict.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - async with run_eval() as ctx: - assert ctx.variants == {} + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + async with run_eval() as ctx: + assert ctx.variants == {} @pytest.mark.asyncio async def test_blank_eval_has_headers_property(self) -> None: """hud.eval() context has headers property for gateway integration.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - async with run_eval() as ctx: - headers = ctx.headers - assert "Trace-Id" in headers - assert headers["Trace-Id"] == ctx.trace_id + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + async with run_eval() as ctx: + headers = ctx.headers + assert "Trace-Id" in headers + assert headers["Trace-Id"] == ctx.trace_id class TestRunEvalWithApiKey: @@ -95,10 +109,12 @@ class TestRunEvalWithApiKey: @pytest.mark.asyncio async def test_api_key_passed_to_context(self) -> None: """hud.eval(api_key=...) passes api_key to context.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - async with run_eval(api_key="test-key") as ctx: - assert ctx._eval_api_key == "test-key" + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + async with run_eval(api_key="test-key") as ctx: + assert ctx._eval_api_key == "test-key" class TestRunEvalWithJobId: @@ -107,10 +123,12 @@ class TestRunEvalWithJobId: @pytest.mark.asyncio async def test_job_id_passed_to_context(self) -> None: """hud.eval(job_id=...) 
passes job_id to context.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock): - async with run_eval(job_id="job-123") as ctx: - assert ctx.job_id == "job-123" + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), + ): + async with run_eval(job_id="job-123") as ctx: + assert ctx.job_id == "job-123" class TestRunEvalErrorHandling: @@ -119,15 +137,16 @@ class TestRunEvalErrorHandling: @pytest.mark.asyncio async def test_error_tracked_on_exception(self) -> None: """hud.eval() tracks error when exception occurs.""" - with patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock): - with patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit: - with pytest.raises(ValueError): - async with run_eval() as ctx: - raise ValueError("test error") - - # _eval_exit should have been called with error message - mock_exit.assert_called_once() - error_msg = mock_exit.call_args[0][0] - assert error_msg is not None - assert "test error" in error_msg - + with ( + patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), + patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit, + ): + with pytest.raises(ValueError): + async with run_eval(): + raise ValueError("test error") + + # _eval_exit should have been called with error message + mock_exit.assert_called_once() + error_msg = mock_exit.call_args[0][0] + assert error_msg is not None + assert "test error" in error_msg diff --git a/hud/patches/__init__.py b/hud/patches/__init__.py index 6b4e0934..96c3ec0e 100644 --- a/hud/patches/__init__.py +++ b/hud/patches/__init__.py @@ -11,4 +11,3 @@ apply_all_patches() __all__ = ["apply_all_patches", "suppress_fastmcp_logging"] - diff --git a/hud/patches/mcp_patches.py b/hud/patches/mcp_patches.py index 73797bff..e987152c 100644 --- a/hud/patches/mcp_patches.py +++ b/hud/patches/mcp_patches.py @@ -16,7 +16,7 @@ def patch_streamable_http_error_handling() -> None: """ Patch StreamableHTTPTransport.post_writer to handle request errors properly. - + The original implementation doesn't catch errors in handle_request_async, which can cause silent failures. This patch wraps the handler to send errors to the read stream so clients know the request failed. @@ -103,7 +103,7 @@ async def handle_request_async( def patch_client_session_validation() -> None: """ Patch ClientSession to skip structured output validation. - + The original validation is strict and raises errors for non-conforming but usable responses. We replace it with a no-op. """ @@ -112,7 +112,6 @@ def patch_client_session_validation() -> None: async def noop_validate(self: Any, name: str, result: Any) -> None: """Skip structured output validation entirely.""" - pass ClientSession._validate_tool_result = noop_validate logger.debug("Patched ClientSession._validate_tool_result to skip validation") @@ -126,10 +125,10 @@ async def noop_validate(self: Any, name: str, result: Any) -> None: def suppress_fastmcp_logging(level: int = logging.WARNING) -> None: """ Suppress verbose fastmcp logging. - + FastMCP logs a lot of INFO-level messages that clutter output. This sets all fastmcp loggers to the specified level. 
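For reference, a minimal way to call this manually with a stricter level; the function is also re-exported from `hud.patches`.

```python
import logging

from hud.patches import suppress_fastmcp_logging

# Silence fastmcp output below ERROR instead of the default WARNING.
suppress_fastmcp_logging(logging.ERROR)
```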
- + Args: level: Logging level to set (default: WARNING) """ @@ -150,4 +149,3 @@ def apply_all_patches() -> None: patch_client_session_validation() suppress_fastmcp_logging() logger.debug("All MCP patches applied") - diff --git a/hud/server/server.py b/hud/server/server.py index 497c1019..7497aa3e 100644 --- a/hud/server/server.py +++ b/hud/server/server.py @@ -593,9 +593,9 @@ async def tool_endpoint(request: Request) -> Response: # Recursively serialize MCP objects def serialize_obj(obj: Any) -> Any: """Recursively serialize MCP objects to JSON-compatible format.""" - if obj is None or isinstance(obj, (str, int, float, bool)): + if obj is None or isinstance(obj, str | int | float | bool): return obj - if isinstance(obj, (list, tuple)): + if isinstance(obj, list | tuple): return [serialize_obj(item) for item in obj] if isinstance(obj, dict): return {k: serialize_obj(v) for k, v in obj.items()} From 206120405031ad18b92f005edeb14fdf3164b6bd Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 17:35:25 -0800 Subject: [PATCH 25/92] fix typing --- hud/clients/mcp_use.py | 14 +++----------- hud/datasets/runner.py | 3 ++- hud/environment/environment.py | 4 ++-- hud/environment/tests/test_integrations.py | 8 ++++---- hud/eval/context.py | 1 + 5 files changed, 12 insertions(+), 18 deletions(-) diff --git a/hud/clients/mcp_use.py b/hud/clients/mcp_use.py index 19915165..36c1b144 100644 --- a/hud/clients/mcp_use.py +++ b/hud/clients/mcp_use.py @@ -9,9 +9,8 @@ from mcp import Implementation, types from mcp.shared.exceptions import McpError -from mcp_use.client import MCPClient as MCPUseClient -from mcp_use.session import MCPSession as MCPUseSession -from mcp_use.types.http import HttpOptions +from mcp_use.client.client import MCPClient as MCPUseClient +from mcp_use.client.session import MCPSession as MCPUseSession from pydantic import AnyUrl from hud.settings import settings @@ -20,7 +19,6 @@ from hud.version import __version__ as hud_version from .base import BaseHUDClient -from .utils.retry_transport import create_retry_httpx_client logger = logging.getLogger(__name__) hud_console = HUDConsole(logger=logger) @@ -58,12 +56,6 @@ def __init__( str, tuple[str, types.Tool, types.Tool] ] = {} # server_name, original_tool, prefixed_tool self._client: Any | None = None # Will be MCPUseClient when available - # Transport options for MCP-use (disable_sse_fallback, httpx_client_factory, etc.) 
- # Default to retry-enabled HTTPX client if factory not provided - self._http_options: HttpOptions = HttpOptions( - httpx_client_factory=create_retry_httpx_client, - disable_sse_fallback=True, - ) async def _connect(self, mcp_config: dict[str, dict[str, Any]]) -> None: """Create all sessions for MCP-use client.""" @@ -88,7 +80,7 @@ async def _connect(self, mcp_config: dict[str, dict[str, Any]]) -> None: config = {"mcpServers": mcp_config} if MCPUseClient is None: raise ImportError("MCPUseClient is not available") - self._client = MCPUseClient.from_dict(config, http_options=self._http_options) + self._client = MCPUseClient.from_dict(config) try: assert self._client is not None self._sessions = await self._client.create_all_sessions() diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 7f5a7aee..5cc6dcd6 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -68,12 +68,13 @@ async def run_single_task( group_id=group_id, ) + result: Trace async with ctx: agent = agent_type.cls.create(**(agent_params or {})) result = await agent.run(task, max_steps=max_steps) # Transfer reward to context for tracking ctx.reward = result.reward - return result + return result async def run_tasks( diff --git a/hud/environment/environment.py b/hud/environment/environment.py index df731ac8..2f5f2196 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -407,7 +407,7 @@ async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolRe async def list_resources(self) -> list[mcp_types.Resource]: """List all resources (local + remote).""" - local = await self._resource_manager.list_resources() + local = list((await self._resource_manager.get_resources()).values()) resources: list[mcp_types.Resource] = [r.to_mcp_resource() for r in local] if self._connections: @@ -456,7 +456,7 @@ async def read_resource( async def list_prompts(self) -> list[mcp_types.Prompt]: """List all prompts (local + remote).""" - local = await self._prompt_manager.list_prompts() + local = list((await self._prompt_manager.get_prompts()).values()) prompts: list[mcp_types.Prompt] = [p.to_mcp_prompt() for p in local] if self._connections: diff --git a/hud/environment/tests/test_integrations.py b/hud/environment/tests/test_integrations.py index 713d0568..90e84931 100644 --- a/hud/environment/tests/test_integrations.py +++ b/hud/environment/tests/test_integrations.py @@ -47,9 +47,9 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: assert len(tools) == 1 assert tools[0]["type"] == "function" - assert tools[0]["function"]["name"] == "navigate" - assert tools[0]["function"]["description"] == "Navigate to URL" - assert "url" in tools[0]["function"]["parameters"]["properties"] + assert tools[0]["function"]["name"] == "navigate" # type: ignore[typeddict-item] + assert tools[0]["function"]["description"] == "Navigate to URL" # type: ignore[typeddict-item] + assert "url" in tools[0]["function"]["parameters"]["properties"] # type: ignore[typeddict-item, operator] def test_as_openai_chat_tools_strict_mode(self) -> None: """as_openai_chat_tools with strict=True adds strict flag.""" @@ -65,7 +65,7 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: env = TestEnv() tools = env.as_openai_chat_tools(strict=True) - assert tools[0]["function"]["strict"] is True + assert tools[0]["function"]["strict"] is True # type: ignore[typeddict-item] def test_as_openai_chat_tools_empty(self) -> None: """as_openai_chat_tools returns empty list when no tools.""" 
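Tying these pieces together, a hedged sketch of handing environment tools to an OpenAI-style chat client; the model name and prompt are placeholders, and whether a tool call actually comes back depends on the model.

```python
import asyncio
import json

from openai import AsyncOpenAI

from hud.environment import Environment

env = Environment("demo")

@env.tool()
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b

async def main() -> None:
    async with env:
        tools = env.as_openai_chat_tools()
        client = AsyncOpenAI()  # any OpenAI-compatible endpoint works here
        response = await client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "Use the add tool to compute 2 + 3."}],
            tools=tools,
        )
        for call in response.choices[0].message.tool_calls or []:
            args = json.loads(call.function.arguments)
            print(await env.call_tool(call.function.name, **args))

asyncio.run(main())
```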
diff --git a/hud/eval/context.py b/hud/eval/context.py index 4ecafef0..68cd1df3 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -165,6 +165,7 @@ def __init__( self._suppress_link: bool = quiet # True to suppress printing eval link self._trace_enabled: bool = trace # Whether to send trace data to backend self._script_name: str | None = None # Current script name (for submit) + self._source_env_name: str | None = None # Source env name for remote lookups def _apply_task(self, task: Task) -> None: """Apply a Task definition to this environment.""" From dbe94a8dc796198bc45fba10acb6e1fd33e5dd64 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 17:49:13 -0800 Subject: [PATCH 26/92] scripts --- docs/llm-quickstart.mdx | 2 +- hud/environment/scripts.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/llm-quickstart.mdx b/docs/llm-quickstart.mdx index 7bde2a04..bcd99d95 100644 --- a/docs/llm-quickstart.mdx +++ b/docs/llm-quickstart.mdx @@ -30,5 +30,5 @@ icon: "sparkles" -Try asking your assistant: "How do I create a custom agent in HUD?" or "Help me debug MCP tool calls" +Try asking: "How do I create an Environment with tools?" or "How do scripts and evals work in HUD?" \ No newline at end of file diff --git a/hud/environment/scripts.py b/hud/environment/scripts.py index 4fd06218..63fd8703 100644 --- a/hud/environment/scripts.py +++ b/hud/environment/scripts.py @@ -280,6 +280,12 @@ def decorator( script_id = f"{safe_env_name}:{script_name}" script_desc = description or fn.__doc__ or f"Script: {script_name}" + # Capture source code for reproducibility + try: + source_code = inspect.getsource(fn) + except (OSError, TypeError): + source_code = None + # Store the generator function self._scripts[script_name] = fn @@ -321,6 +327,9 @@ async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: # to bypass the **kwargs validation in from_function() from fastmcp.prompts.prompt import FunctionPrompt, PromptArgument + # Build meta with source code + script_meta = {"code": source_code} if source_code else None + prompt = FunctionPrompt( name=script_id, description=f"[Setup] {script_desc}", @@ -329,6 +338,7 @@ async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: for arg in prompt_args ], fn=prompt_handler, + meta=script_meta, ) self._prompt_manager.add_prompt(prompt) @@ -380,6 +390,7 @@ async def resource_handler() -> str: name=script_name, description=f"[Evaluate] {script_desc}", mime_type="application/json", + meta=script_meta, ) self._resource_manager.add_resource(resource) From 0685e688c6f9e121d2cba86d78644c66b46cd1b9 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Wed, 10 Dec 2025 18:24:38 -0800 Subject: [PATCH 27/92] tests --- .../tests/test_local_connectors.py | 30 +-- hud/environment/tests/test_scripts.py | 115 ++++++++++ hud/environment/tests/test_tools.py | 208 ++++++++++++++++++ hud/eval/tests/test_context.py | 27 ++- hud/eval/tests/test_eval.py | 3 + hud/eval/tests/test_manager.py | 20 +- 6 files changed, 367 insertions(+), 36 deletions(-) create mode 100644 hud/environment/tests/test_tools.py diff --git a/hud/environment/tests/test_local_connectors.py b/hud/environment/tests/test_local_connectors.py index 018d68cb..d8e3de0a 100644 --- a/hud/environment/tests/test_local_connectors.py +++ b/hud/environment/tests/test_local_connectors.py @@ -96,24 +96,24 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: class TestConnectServer: """Tests for LocalConnectorMixin.connect_server.""" - def 
test_connect_server_calls_mount(self) -> None: - """connect_server calls mount with server and prefix.""" + def test_connect_server_calls_include_router(self) -> None: + """connect_server calls include_router with server and prefix.""" from hud.environment.connectors.local import LocalConnectorMixin class TestEnv(LocalConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - self.mounted: list[tuple[Any, str | None]] = [] + self.routers: list[tuple[Any, str | None]] = [] - def mount(self, server: Any, *, prefix: str | None = None) -> None: - self.mounted.append((server, prefix)) + def include_router(self, server: Any, *, prefix: str | None = None) -> None: + self.routers.append((server, prefix)) env = TestEnv() mock_server = MagicMock() env.connect_server(mock_server, prefix="tools") - assert len(env.mounted) == 1 - assert env.mounted[0] == (mock_server, "tools") + assert len(env.routers) == 1 + assert env.routers[0] == (mock_server, "tools") def test_connect_server_returns_self(self) -> None: """connect_server returns self for chaining.""" @@ -123,7 +123,7 @@ class TestEnv(LocalConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - def mount(self, server: Any, *, prefix: str | None = None) -> None: + def include_router(self, server: Any, *, prefix: str | None = None) -> None: pass env = TestEnv() @@ -146,10 +146,10 @@ def test_connect_fastapi_creates_mcp_server(self, mock_fastmcp: MagicMock) -> No class TestEnv(LocalConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - self.mounted: list[tuple[Any, str | None]] = [] + self.routers: list[tuple[Any, str | None]] = [] - def mount(self, server: Any, *, prefix: str | None = None) -> None: - self.mounted.append((server, prefix)) + def include_router(self, server: Any, *, prefix: str | None = None) -> None: + self.routers.append((server, prefix)) env = TestEnv() mock_app = MagicMock() @@ -157,8 +157,8 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: env.connect_fastapi(mock_app) mock_fastmcp.from_fastapi.assert_called_once_with(app=mock_app, name="My API") - assert len(env.mounted) == 1 - assert env.mounted[0] == (mock_mcp_server, None) + assert len(env.routers) == 1 + assert env.routers[0] == (mock_mcp_server, None) @patch("fastmcp.FastMCP") def test_connect_fastapi_with_custom_name(self, mock_fastmcp: MagicMock) -> None: @@ -171,7 +171,7 @@ class TestEnv(LocalConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - def mount(self, server: Any, *, prefix: str | None = None) -> None: + def include_router(self, server: Any, *, prefix: str | None = None) -> None: pass env = TestEnv() @@ -192,7 +192,7 @@ class TestEnv(LocalConnectorMixin): def __init__(self) -> None: self._connections: dict[str, Connector] = {} - def mount(self, server: Any, *, prefix: str | None = None) -> None: + def include_router(self, server: Any, *, prefix: str | None = None) -> None: pass env = TestEnv() diff --git a/hud/environment/tests/test_scripts.py b/hud/environment/tests/test_scripts.py index cfec80b0..e07481f0 100644 --- a/hud/environment/tests/test_scripts.py +++ b/hud/environment/tests/test_scripts.py @@ -163,3 +163,118 @@ async def checkout_script(user_id: str, amount: int = 100): assert received_args["user_id"] == "alice" assert received_args["amount"] == 50 + + +class TestScriptSubmit: + """Tests for script submit and answer flow.""" + + @pytest.mark.asyncio + async def 
test_submit_stores_answer(self) -> None: + """submit() stores answer for script.""" + env = Environment("test-env") + + @env.script("test") + async def test_script(): + yield "What is 2+2?" + yield 1.0 + + # Run setup + prompt = env._prompt_manager._prompts.get("test-env:test") + assert prompt is not None + await prompt.render({}) + + # Submit answer + await env.submit("test", "4") + + assert env._script_answers.get("test") == "4" + + @pytest.mark.asyncio + async def test_script_receives_answer(self) -> None: + """Script receives submitted answer via yield.""" + env = Environment("test-env") + received_answer = None + + @env.script("qa") + async def qa_script(): + nonlocal received_answer + answer = yield "What is 2+2?" + received_answer = answer + yield 1.0 if answer == "4" else 0.0 + + # Run setup + prompt = env._prompt_manager._prompts.get("test-env:qa") + assert prompt is not None + await prompt.render({}) + + # Submit answer + env._script_answers["qa"] = "4" + + # Run evaluate + resource = env._resource_manager._resources.get("test-env:qa") + assert resource is not None + await resource.read() + + assert received_answer == "4" + + @pytest.mark.asyncio + async def test_script_evaluates_answer(self) -> None: + """Script evaluates answer and returns reward.""" + env = Environment("test-env") + + @env.script("grading") + async def grading_script(): + answer = yield "What is the capital of France?" + yield 1.0 if "paris" in answer.lower() else 0.0 + + # Run setup + prompt = env._prompt_manager._prompts.get("test-env:grading") + assert prompt is not None + await prompt.render({}) + + # Submit correct answer + env._script_answers["grading"] = "Paris" + + # Run evaluate + resource = env._resource_manager._resources.get("test-env:grading") + assert resource is not None + result = await resource.read() + + import json + + data = json.loads(result) + assert data["reward"] == 1.0 + + +class TestScriptMeta: + """Tests for script _meta containing code.""" + + def test_script_captures_source_code(self) -> None: + """@env.script captures function source in meta.""" + env = Environment("test-env") + + @env.script("example") + async def example_script(x: int): + yield f"Process {x}" + yield 1.0 + + prompt = env._prompt_manager._prompts.get("test-env:example") + assert prompt is not None + assert prompt.meta is not None + assert "code" in prompt.meta + assert "async def example_script" in prompt.meta["code"] + assert "yield" in prompt.meta["code"] + + def test_script_meta_on_resource(self) -> None: + """Resource also has source code in meta.""" + env = Environment("test-env") + + @env.script("example") + async def example_script(): + yield "Test" + yield 1.0 + + resource = env._resource_manager._resources.get("test-env:example") + assert resource is not None + assert resource.meta is not None + assert "code" in resource.meta + assert "async def example_script" in resource.meta["code"] diff --git a/hud/environment/tests/test_tools.py b/hud/environment/tests/test_tools.py new file mode 100644 index 00000000..8c99a01b --- /dev/null +++ b/hud/environment/tests/test_tools.py @@ -0,0 +1,208 @@ +"""Tests for @env.tool() decorator and tool operations.""" + +from __future__ import annotations + +import pytest + +from hud.environment import Environment + + +class TestToolDecorator: + """Tests for @env.tool() decorator.""" + + def test_tool_registers_function(self) -> None: + """@env.tool registers the function in tool manager.""" + env = Environment("test-env") + + @env.tool() + def add(a: int, b: int) -> int: + 
"""Add two numbers.""" + return a + b + + # Check tool was registered + tool_names = list(env._tool_manager._tools.keys()) + assert "add" in tool_names + + def test_tool_with_custom_name(self) -> None: + """@env.tool(name=...) uses custom name.""" + env = Environment("test-env") + + @env.tool(name="custom_add") + def add(a: int, b: int) -> int: + return a + b + + tool_names = list(env._tool_manager._tools.keys()) + assert "custom_add" in tool_names + assert "add" not in tool_names + + def test_tool_preserves_docstring(self) -> None: + """@env.tool preserves function docstring as description.""" + env = Environment("test-env") + + @env.tool() + def greet(name: str) -> str: + """Greet someone by name.""" + return f"Hello, {name}!" + + tool = env._tool_manager._tools.get("greet") + assert tool is not None + assert "Greet someone by name" in (tool.description or "") + + def test_tool_async_function(self) -> None: + """@env.tool works with async functions.""" + env = Environment("test-env") + + @env.tool() + async def fetch_data(url: str) -> str: + """Fetch data from URL.""" + return f"Data from {url}" + + tool_names = list(env._tool_manager._tools.keys()) + assert "fetch_data" in tool_names + + def test_tool_returns_function(self) -> None: + """@env.tool returns the original function.""" + env = Environment("test-env") + + @env.tool() + def add(a: int, b: int) -> int: + return a + b + + # Should be able to call it directly + assert add(2, 3) == 5 + + +class TestListTools: + """Tests for list_tools and as_tools.""" + + @pytest.mark.asyncio + async def test_as_tools_returns_registered_tools(self) -> None: + """as_tools returns list of registered MCP tools.""" + env = Environment("test-env") + + @env.tool() + def tool1() -> str: + return "1" + + @env.tool() + def tool2() -> str: + return "2" + + async with env: + tools = env.as_tools() + tool_names = [t.name for t in tools] + assert "tool1" in tool_names + assert "tool2" in tool_names + + @pytest.mark.asyncio + async def test_as_tools_empty_when_no_tools(self) -> None: + """as_tools returns empty list when no tools registered.""" + env = Environment("test-env") + async with env: + tools = env.as_tools() + # May have built-in _hud_submit tool + user_tools = [t for t in tools if not t.name.startswith("_")] + assert len(user_tools) == 0 + + +class TestCallTool: + """Tests for call_tool method.""" + + @pytest.mark.asyncio + async def test_call_tool_executes_function(self) -> None: + """call_tool executes registered tool function.""" + env = Environment("test-env") + executed = [] + + @env.tool() + def greet(name: str) -> str: + executed.append(name) + return f"Hello, {name}!" + + async with env: + result = await env.call_tool("greet", name="Alice") + + assert executed == ["Alice"] + assert result is not None + + @pytest.mark.asyncio + async def test_call_tool_async_function(self) -> None: + """call_tool works with async tool functions.""" + env = Environment("test-env") + + @env.tool() + async def async_greet(name: str) -> str: + return f"Hello, {name}!" 
+ + async with env: + result = await env.call_tool("async_greet", name="Bob") + + assert result is not None + + @pytest.mark.asyncio + async def test_call_tool_not_found(self) -> None: + """call_tool raises for unknown tool.""" + env = Environment("test-env") + + async with env: + with pytest.raises(ValueError, match="Tool not found"): + await env.call_tool("nonexistent") + + +class TestMockMode: + """Tests for mock mode.""" + + def test_mock_mode_default_false(self) -> None: + """Mock mode is False by default.""" + env = Environment("test-env") + assert env._mock_mode is False + assert env.is_mock is False + + def test_mock_enables_mock_mode(self) -> None: + """mock() enables mock mode.""" + env = Environment("test-env") + env.mock() + assert env._mock_mode is True + assert env.is_mock is True + + def test_unmock_disables_mock_mode(self) -> None: + """unmock() disables mock mode.""" + env = Environment("test-env") + env.mock() + env.unmock() + assert env._mock_mode is False + + def test_mock_returns_self_for_chaining(self) -> None: + """mock() returns self for chaining.""" + env = Environment("test-env") + result = env.mock() + assert result is env + + def test_mock_tool_sets_custom_output(self) -> None: + """mock_tool() sets custom output for a tool.""" + env = Environment("test-env") + env.mock_tool("navigate", "Custom result") + assert env._mock_outputs["navigate"] == "Custom result" + + @pytest.mark.asyncio + async def test_mock_mode_returns_mock_response(self) -> None: + """Mock mode returns mock response instead of executing tool.""" + env = Environment("test-env") + call_count = 0 + + @env.tool() + def real_tool() -> str: + nonlocal call_count + call_count += 1 + return "real result" + + env.mock() + env.mock_tool("real_tool", "mocked result") + + async with env: + result = await env.call_tool("real_tool") + + # Tool should not be called in mock mode + assert call_count == 0 + # Should get the mock result + assert result is not None diff --git a/hud/eval/tests/test_context.py b/hud/eval/tests/test_context.py index f749377a..275d3fef 100644 --- a/hud/eval/tests/test_context.py +++ b/hud/eval/tests/test_context.py @@ -17,45 +17,45 @@ class TestEvalContext: def test_init_generates_trace_id(self) -> None: """EvalContext generates trace_id if not provided.""" - ctx = EvalContext(name="test-task") + ctx = EvalContext(name="test-task", quiet=True) assert ctx.trace_id is not None assert len(ctx.trace_id) == 36 # UUID format def test_init_uses_provided_trace_id(self) -> None: """EvalContext uses provided trace_id.""" - ctx = EvalContext(name="test-task", trace_id="custom-id") + ctx = EvalContext(name="test-task", trace_id="custom-id", quiet=True) assert ctx.trace_id == "custom-id" def test_headers_contains_trace_id(self) -> None: """headers property returns dict with trace ID.""" - ctx = EvalContext(name="test-task", trace_id="test-123") + ctx = EvalContext(name="test-task", trace_id="test-123", quiet=True) assert ctx.headers == {"Trace-Id": "test-123"} def test_success_true_when_no_error(self) -> None: """success property returns True when no error.""" - ctx = EvalContext(name="test-task") + ctx = EvalContext(name="test-task", quiet=True) assert ctx.success is True def test_success_false_when_error(self) -> None: """success property returns False when error is set.""" - ctx = EvalContext(name="test-task") + ctx = EvalContext(name="test-task", quiet=True) ctx.error = ValueError("test error") assert ctx.success is False def test_done_false_initially(self) -> None: """done property returns 
False initially.""" - ctx = EvalContext(name="test-task") + ctx = EvalContext(name="test-task", quiet=True) assert ctx.done is False def test_variants_empty_by_default(self) -> None: """variants is empty dict by default.""" - ctx = EvalContext(name="test-task") + ctx = EvalContext(name="test-task", quiet=True) assert ctx.variants == {} @@ -64,6 +64,7 @@ def test_variants_set_from_init(self) -> None: ctx = EvalContext( name="test-task", variants={"model": "gpt-4o", "temp": 0.7}, + quiet=True, ) assert ctx.variants == {"model": "gpt-4o", "temp": 0.7} @@ -71,7 +72,7 @@ def test_variants_set_from_init(self) -> None: @pytest.mark.asyncio async def test_context_manager_sets_headers(self) -> None: """Context manager sets trace headers in contextvar.""" - ctx = EvalContext(name="test-task", trace_id="test-123") + ctx = EvalContext(name="test-task", trace_id="test-123", quiet=True) # Mock telemetry calls with ( @@ -97,7 +98,11 @@ async def test_context_manager_sets_headers(self) -> None: def test_repr(self) -> None: """__repr__ shows useful info.""" - ctx = EvalContext(name="test-task", trace_id="abc12345-6789-0000-0000-000000000000") + ctx = EvalContext( + name="test-task", + trace_id="abc12345-6789-0000-0000-000000000000", + quiet=True, + ) ctx.reward = 0.95 repr_str = repr(ctx) @@ -111,14 +116,14 @@ class TestEvalContextPrompt: def test_prompt_can_be_set(self) -> None: """EvalContext.prompt can be set.""" - ctx = EvalContext(name="test-task") + ctx = EvalContext(name="test-task", quiet=True) ctx.prompt = "Test prompt" assert ctx.prompt == "Test prompt" def test_prompt_included_in_payload(self) -> None: """Prompt is included in eval payload.""" - ctx = EvalContext(name="test-task") + ctx = EvalContext(name="test-task", quiet=True) ctx.prompt = "Test prompt" payload = ctx._build_base_payload() diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index 512d0150..38c11f58 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -112,6 +112,7 @@ async def test_aenter_returns_eval_context(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), + patch.object(EvalContext, "_print_eval_link"), # Suppress link printing ): ctx = await ev.__aenter__() assert isinstance(ctx, EvalContext) @@ -129,6 +130,7 @@ async def test_context_clears_on_exit(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), + patch.object(EvalContext, "_print_eval_link"), # Suppress link printing ): await ev.__aenter__() assert ev._ctx is not None @@ -148,6 +150,7 @@ async def test_reward_accessible_after_exit(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), + patch.object(EvalContext, "_print_eval_link"), # Suppress link printing ): ctx = await ev.__aenter__() ctx.reward = 0.95 diff --git a/hud/eval/tests/test_manager.py b/hud/eval/tests/test_manager.py index 75aa6ad7..9b237382 100644 --- a/hud/eval/tests/test_manager.py +++ b/hud/eval/tests/test_manager.py @@ -20,7 +20,7 @@ async def test_blank_eval_creates_context(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, 
"_eval_exit", new_callable=AsyncMock), ): - async with run_eval() as ctx: + async with run_eval(quiet=True) as ctx: assert isinstance(ctx, EvalContext) assert ctx.eval_name == "eval" @@ -31,7 +31,7 @@ async def test_blank_eval_generates_trace_id(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), ): - async with run_eval() as ctx: + async with run_eval(quiet=True) as ctx: assert ctx.trace_id is not None assert len(ctx.trace_id) == 36 # UUID format @@ -45,7 +45,7 @@ async def test_blank_eval_sets_trace_headers(self) -> None: # Before context, no headers assert get_current_trace_headers() is None - async with run_eval() as ctx: + async with run_eval(quiet=True) as ctx: # Inside context, headers are set headers = get_current_trace_headers() assert headers is not None @@ -61,7 +61,7 @@ async def test_blank_eval_reward_can_be_set(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), ): - async with run_eval() as ctx: + async with run_eval(quiet=True) as ctx: assert ctx.reward is None ctx.reward = 0.95 @@ -74,7 +74,7 @@ async def test_blank_eval_reports_reward_on_exit(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit, ): - async with run_eval() as ctx: + async with run_eval(quiet=True) as ctx: ctx.reward = 0.85 # _eval_exit should have been called (with no error) @@ -87,7 +87,7 @@ async def test_blank_eval_empty_variants(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), ): - async with run_eval() as ctx: + async with run_eval(quiet=True) as ctx: assert ctx.variants == {} @pytest.mark.asyncio @@ -97,7 +97,7 @@ async def test_blank_eval_has_headers_property(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), ): - async with run_eval() as ctx: + async with run_eval(quiet=True) as ctx: headers = ctx.headers assert "Trace-Id" in headers assert headers["Trace-Id"] == ctx.trace_id @@ -113,7 +113,7 @@ async def test_api_key_passed_to_context(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), ): - async with run_eval(api_key="test-key") as ctx: + async with run_eval(api_key="test-key", quiet=True) as ctx: assert ctx._eval_api_key == "test-key" @@ -127,7 +127,7 @@ async def test_job_id_passed_to_context(self) -> None: patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), ): - async with run_eval(job_id="job-123") as ctx: + async with run_eval(job_id="job-123", quiet=True) as ctx: assert ctx.job_id == "job-123" @@ -142,7 +142,7 @@ async def test_error_tracked_on_exception(self) -> None: patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock) as mock_exit, ): with pytest.raises(ValueError): - async with run_eval(): + async with run_eval(quiet=True): raise ValueError("test error") # _eval_exit should have been called with error message From 867e976e7e252b15de7e353274f824e4f5ba52cd Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Thu, 11 Dec 2025 11:29:20 -0800 Subject: [PATCH 28/92] fix new langchain version --- pyproject.toml | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 72744511..8333e6d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ packages = ["hud"] agents = [ # MCP-use client (legacy) "mcp-use==1.5.0", - "langchain>=1.0.0", # Required by mcp-use + "langchain>=1.1.0", # Required by mcp-use # AI providers "anthropic>=0.75", "openai>=2.8.1", From fb10f4832148f1f921013ee93f592663b354c2ae Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Thu, 11 Dec 2025 11:36:29 -0800 Subject: [PATCH 29/92] analyze includes scripts --- hud/cli/analyze.py | 37 +++++- hud/cli/build.py | 2 +- hud/cli/dev.py | 4 +- hud/cli/tests/test_build.py | 4 + hud/cli/tests/test_registry.py | 2 +- hud/cli/utils/metadata.py | 69 +++++++++++ hud/cli/utils/server.py | 4 +- hud/cli/utils/tests/test_env_check.py | 2 +- hud/clients/base.py | 98 +++++++++++++++- hud/clients/fastmcp.py | 6 + hud/clients/mcp_use.py | 28 +++++ hud/clients/tests/test_analyze_scenarios.py | 122 ++++++++++++++++++++ hud/patches/__init__.py | 8 +- hud/patches/warnings.py | 56 +++++++++ 14 files changed, 431 insertions(+), 11 deletions(-) create mode 100644 hud/clients/tests/test_analyze_scenarios.py create mode 100644 hud/patches/warnings.py diff --git a/hud/cli/analyze.py b/hud/cli/analyze.py index 541617d4..6fdf7441 100644 --- a/hud/cli/analyze.py +++ b/hud/cli/analyze.py @@ -143,8 +143,8 @@ def display_interactive(analysis: dict) -> None: tool_node.add(f"[bright_black]{tool['description']}[/bright_black]") # Show input schema if verbose - if analysis.get("verbose") and tool.get("input_schema"): - schema_str = json.dumps(tool["input_schema"], indent=2) + if analysis.get("verbose") and tool.get("inputSchema"): + schema_str = json.dumps(tool["inputSchema"], indent=2) syntax = Syntax(schema_str, "json", theme="monokai", line_numbers=False) tool_node.add(syntax) @@ -170,6 +170,28 @@ def display_interactive(analysis: dict) -> None: console.print(tools_tree) + # Scenarios (Environment scripts exposed as prompt+resource) + if analysis.get("scenarios"): + hud_console.section_title("🎬 Scenarios") + scenarios_table = Table() + scenarios_table.add_column("Scenario", style="bright_white") + scenarios_table.add_column("Env", style="bright_black") + scenarios_table.add_column("Setup/Eval", style="bright_black") + + for s in analysis["scenarios"][:20]: + setup = "✓" if s.get("has_setup_prompt") else "✗" + eval_ = "✓" if s.get("has_evaluate_resource") else "✗" + scenarios_table.add_row( + str(s.get("name", "")), + str(s.get("env", "")), + f"setup {setup} / eval {eval_}", + ) + + console.print(scenarios_table) + if len(analysis["scenarios"]) > 20: + remaining = len(analysis["scenarios"]) - 20 + console.print(f"[bright_black]... 
and {remaining} more scenarios[/bright_black]") + # Resources if analysis["resources"]: hud_console.section_title("📚 Available Resources") @@ -285,6 +307,17 @@ def display_markdown(analysis: dict) -> None: md.extend([f"| {uri} | {name} | {mime_type} |"]) md.append("") + # Scenarios + if analysis.get("scenarios"): + md.append("## Scenarios\n") + for s in analysis["scenarios"]: + name = s.get("name", "") + env = s.get("env", "") + setup = "✓" if s.get("has_setup_prompt") else "✗" + eval_ = "✓" if s.get("has_evaluate_resource") else "✗" + md.append(f"- **{name}** ({env}) — setup {setup} / eval {eval_}") + md.append("") + # Telemetry (only for live analysis) if analysis.get("telemetry"): md.append("## Telemetry") diff --git a/hud/cli/build.py b/hud/cli/build.py index bf300fef..26963f3f 100644 --- a/hud/cli/build.py +++ b/hud/cli/build.py @@ -804,7 +804,7 @@ def build_environment( # Create lock file content with images subsection at top lock_content = { - "version": "1.2", # Lock file format version + "version": "1.3", # Lock file format version "images": { "local": f"{base_name}:{new_version}", # Local tag with version "full": None, # Will be set with digest after build diff --git a/hud/cli/dev.py b/hud/cli/dev.py index 913cc582..a7564360 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -189,9 +189,9 @@ async def run_mcp_module( logging.getLogger("mcp.server.streamable_http_manager").setLevel(logging.ERROR) # Suppress deprecation warnings on reload - import warnings + from hud.patches.warnings import apply_default_warning_filters - warnings.filterwarnings("ignore", category=DeprecationWarning) + apply_default_warning_filters(verbose=False) # Ensure proper directory is in sys.path based on module name cwd = Path.cwd() diff --git a/hud/cli/tests/test_build.py b/hud/cli/tests/test_build.py index 9a7bb77b..9e76977b 100644 --- a/hud/cli/tests/test_build.py +++ b/hud/cli/tests/test_build.py @@ -403,6 +403,9 @@ def test_build_environment_success( with open(lock_file) as f: lock_data = yaml.safe_load(f) + # Lock file format version + assert lock_data["version"] == "1.3" + assert lock_data["images"]["full"] == "test-env:0.1.0@sha256:abc123" assert lock_data["images"]["local"] == "test-env:0.1.0" assert lock_data["build"]["version"] == "0.1.0" @@ -472,6 +475,7 @@ def test_build_environment_internal_tools( lock_file = env_dir / "hud.lock.yaml" with open(lock_file) as f: data = yaml.safe_load(f) + assert data["version"] == "1.3" assert data["environment"]["internalToolCount"] == 2 assert data["tools"][0]["name"] == "setup" assert data["tools"][0]["internalTools"] == ["board", "seed"] diff --git a/hud/cli/tests/test_registry.py b/hud/cli/tests/test_registry.py index 6dd92b6a..5a09c283 100644 --- a/hud/cli/tests/test_registry.py +++ b/hud/cli/tests/test_registry.py @@ -189,7 +189,7 @@ def test_load_success(self, tmp_path): digest_dir = registry_dir / "abc123" digest_dir.mkdir(parents=True) - lock_data = {"image": "test:latest", "version": "1.0"} + lock_data = {"image": "test:latest", "version": "1.3"} lock_file = digest_dir / "hud.lock.yaml" lock_file.write_text(yaml.dump(lock_data)) diff --git a/hud/cli/utils/metadata.py b/hud/cli/utils/metadata.py index d19a344c..f9241752 100644 --- a/hud/cli/utils/metadata.py +++ b/hud/cli/utils/metadata.py @@ -173,6 +173,8 @@ async def analyze_from_metadata(reference: str, output_format: str, verbose: boo "tools": [], "resources": [], "prompts": [], + "scenarios": [], + "verbose": verbose, } # Add basic info @@ -206,6 +208,73 @@ async def 
analyze_from_metadata(reference: str, output_format: str, verbose: boo } ) + # Extract resources + if "resources" in lock_data: + for resource in lock_data["resources"]: + analysis["resources"].append( + { + "uri": resource.get("uri", ""), + "name": resource.get("name", ""), + "description": resource.get("description", ""), + "mime_type": resource.get("mimeType", resource.get("mime_type", "")), + } + ) + + # Extract prompts + if "prompts" in lock_data: + for prompt in lock_data["prompts"]: + analysis["prompts"].append( + { + "name": prompt.get("name", ""), + "description": prompt.get("description", ""), + "arguments": prompt.get("arguments", []), + } + ) + + # Derive scenarios from script prompts/resources if present + scenarios_by_id: dict[str, dict] = {} + for p in analysis["prompts"]: + desc = (p.get("description") or "").strip() + if not desc.startswith("[Setup]"): + continue + scenario_id = p.get("name") + if not scenario_id: + continue + env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + scenarios_by_id[scenario_id] = { + "id": scenario_id, + "env": env_name, + "name": script_name or scenario_id, + "setup_description": desc, + "arguments": p.get("arguments") or [], + "has_setup_prompt": True, + "has_evaluate_resource": False, + } + for r in analysis["resources"]: + desc = (r.get("description") or "").strip() + if not desc.startswith("[Evaluate]"): + continue + scenario_id = r.get("uri") + if not scenario_id: + continue + env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + if scenario_id not in scenarios_by_id: + scenarios_by_id[scenario_id] = { + "id": scenario_id, + "env": env_name, + "name": script_name or scenario_id, + "arguments": [], + "has_setup_prompt": False, + "has_evaluate_resource": True, + } + scenarios_by_id[scenario_id]["evaluate_description"] = desc + scenarios_by_id[scenario_id]["has_evaluate_resource"] = True + + analysis["scenarios"] = sorted( + scenarios_by_id.values(), + key=lambda s: (str(s.get("env") or ""), str(s.get("name") or "")), + ) + # Display results hud_console.info("") if source == "local": diff --git a/hud/cli/utils/server.py b/hud/cli/utils/server.py index 3f3bcc18..6d942d07 100644 --- a/hud/cli/utils/server.py +++ b/hud/cli/utils/server.py @@ -138,9 +138,9 @@ async def run_http_server( logging.getLogger("uvicorn.access").setLevel(logging.ERROR) logging.getLogger("uvicorn.error").setLevel(logging.ERROR) - import warnings + from hud.patches.warnings import apply_default_warning_filters - warnings.filterwarnings("ignore", category=DeprecationWarning) + apply_default_warning_filters(verbose=False) try: await proxy.run_async( diff --git a/hud/cli/utils/tests/test_env_check.py b/hud/cli/utils/tests/test_env_check.py index 1ec55c77..134549d0 100644 --- a/hud/cli/utils/tests/test_env_check.py +++ b/hud/cli/utils/tests/test_env_check.py @@ -50,7 +50,7 @@ def test_find_environment_dir_prefers_lock(tmp_path: Path): tasks.write_text("[]") env = tmp_path / "env" env.mkdir() - (env / "hud.lock.yaml").write_text("version: 1.0") + (env / "hud.lock.yaml").write_text("version: 1.3") # Set cwd to env so it's in the candidate list with patch("pathlib.Path.cwd", return_value=env): found = find_environment_dir(tasks) diff --git a/hud/clients/base.py b/hud/clients/base.py index f60b7b5a..b760058b 100644 --- a/hud/clients/base.py +++ b/hud/clients/base.py @@ -105,6 +105,7 @@ def __init__( self._initialized = False self._telemetry_data = {} # Initialize telemetry data self._cached_resources: list[types.Resource] = [] # Cache for resources + 
self._cached_prompts: list[types.Prompt] = [] # Cache for prompts if self.verbose: self._setup_verbose_logging() @@ -172,6 +173,7 @@ async def shutdown(self) -> None: await self._disconnect() self._initialized = False self._cached_resources.clear() + self._cached_prompts.clear() hud_console.info("Environment Shutdown completed") else: hud_console.debug("Client was not initialized, skipping disconnect") @@ -231,6 +233,23 @@ async def _list_resources_impl(self) -> list[types.Resource]: """Implementation-specific resource listing. Subclasses must implement this.""" raise NotImplementedError + async def list_prompts(self) -> list[types.Prompt]: + """List all available prompts. + + Uses cached prompts if available, otherwise fetches from the server. + Prompts are optional in MCP; default implementation returns an empty list. + """ + if not self._cached_prompts: + self._cached_prompts = await self._list_prompts_impl() + return self._cached_prompts + + async def _list_prompts_impl(self) -> list[types.Prompt]: + """Implementation-specific prompt listing (optional). + + Subclasses can override to support prompt discovery. + """ + return [] + @abstractmethod async def _call_tool(self, tool_call: MCPToolCall) -> MCPToolResult: """Execute a tool by name.""" @@ -347,6 +366,9 @@ async def analyze_environment(self) -> dict[str, Any]: "hub_tools": {}, "telemetry": self._telemetry_data, "resources": [], + "prompts": [], + "scenarios": [], + "verbose": self.verbose, "metadata": { "servers": list(self._mcp_config.keys()), # type: ignore "initialized": self._initialized, @@ -387,7 +409,81 @@ async def analyze_environment(self) -> dict[str, Any]: analysis["resources"].append(resource_info) except Exception as e: if self.verbose: - hud_console.debug(f"Could not list resources: {e}") + hud_console.debug("Could not list resources: " + str(e)) + + # Get all prompts (optional) + try: + prompts = await self.list_prompts() + for prompt in prompts: + raw_args = getattr(prompt, "arguments", []) or [] + args: list[dict[str, Any]] = [ + { + "name": getattr(a, "name", None), + "required": getattr(a, "required", None), + "description": getattr(a, "description", None), + } + for a in raw_args + ] + + analysis["prompts"].append( + { + "name": prompt.name, + "description": prompt.description, + "arguments": args, + } + ) + except Exception as e: + if self.verbose: + hud_console.debug("Could not list prompts: " + str(e)) + + # Derive "scenarios" from Environment.@script prompts/resources. 
+ # A scenario is exposed as: + # - Prompt: name "{env}:{script}" with description prefix "[Setup]" + # - Resource: uri "{env}:{script}" with description prefix "[Evaluate]" + scenarios_by_id: dict[str, dict[str, Any]] = {} + + for p in analysis.get("prompts", []): + desc = (p.get("description") or "").strip() + if not desc.startswith("[Setup]"): + continue + scenario_id = p.get("name") + if not scenario_id: + continue + env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + scenarios_by_id[scenario_id] = { + "id": scenario_id, + "env": env_name, + "name": script_name or scenario_id, + "setup_description": desc, + "arguments": p.get("arguments") or [], + "has_setup_prompt": True, + "has_evaluate_resource": False, + } + + for r in analysis.get("resources", []): + desc = (r.get("description") or "").strip() + if not desc.startswith("[Evaluate]"): + continue + scenario_id = r.get("uri") + if not scenario_id: + continue + env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + if scenario_id not in scenarios_by_id: + scenarios_by_id[scenario_id] = { + "id": scenario_id, + "env": env_name, + "name": script_name or scenario_id, + "arguments": [], + "has_setup_prompt": False, + "has_evaluate_resource": True, + } + scenarios_by_id[scenario_id]["evaluate_description"] = desc + scenarios_by_id[scenario_id]["has_evaluate_resource"] = True + + analysis["scenarios"] = sorted( + scenarios_by_id.values(), + key=lambda s: (str(s.get("env") or ""), str(s.get("name") or "")), + ) return analysis diff --git a/hud/clients/fastmcp.py b/hud/clients/fastmcp.py index 04880ba7..fc68c68b 100644 --- a/hud/clients/fastmcp.py +++ b/hud/clients/fastmcp.py @@ -124,6 +124,12 @@ async def list_tools(self) -> list[types.Tool]: raise ValueError("Client is not connected, call initialize() first") return await self._client.list_tools() + async def _list_prompts_impl(self) -> list[types.Prompt]: + """List all available prompts (FastMCP supports this).""" + if self._client is None: + raise ValueError("Client is not connected, call initialize() first") + return await self._client.list_prompts() + async def _call_tool(self, tool_call: MCPToolCall) -> MCPToolResult: """Execute a tool by name.""" if self._client is None: diff --git a/hud/clients/mcp_use.py b/hud/clients/mcp_use.py index 36c1b144..0926328c 100644 --- a/hud/clients/mcp_use.py +++ b/hud/clients/mcp_use.py @@ -262,6 +262,34 @@ async def _list_resources_impl(self) -> list[types.Resource]: continue return [] + async def _list_prompts_impl(self) -> list[types.Prompt]: + """Implementation of prompt listing for MCP-use client (best-effort).""" + if self._client is None or not self._sessions: + raise ValueError("Client is not connected, call initialize() first") + + all_prompts: list[types.Prompt] = [] + for server_name, session in self._sessions.items(): + try: + if not hasattr(session, "connector") or not hasattr( + session.connector, "client_session" + ): + continue + if session.connector.client_session is None: + continue + + if not hasattr(session.connector.client_session, "list_prompts"): + continue + + prompts_result = await session.connector.client_session.list_prompts() + all_prompts.extend(prompts_result.prompts) + except Exception as e: + if self.verbose: + hud_console.debug( + f"Could not list prompts from server '{server_name}': {e}" + ) + continue + return all_prompts + async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult | None: """Read a resource by URI from any server that provides it.""" if self._client is None or 
not self._sessions: diff --git a/hud/clients/tests/test_analyze_scenarios.py b/hud/clients/tests/test_analyze_scenarios.py new file mode 100644 index 00000000..9f18ea7b --- /dev/null +++ b/hud/clients/tests/test_analyze_scenarios.py @@ -0,0 +1,122 @@ +"""Tests for scenario discovery via prompts/resources in analyze_environment().""" + +from __future__ import annotations + +from typing import Any + +import pytest +from mcp import types +from pydantic import AnyUrl + +from hud.clients.base import BaseHUDClient +from hud.types import MCPToolCall, MCPToolResult + + +class _MockClient(BaseHUDClient): + """Minimal BaseHUDClient for testing analyze_environment scenario derivation.""" + + def __init__( + self, + *, + prompts: list[types.Prompt], + resources: list[types.Resource], + ) -> None: + super().__init__(mcp_config={"test": {"url": "mock://test"}}, verbose=True, auto_trace=False) + self._mock_prompts = prompts + self._mock_resources = resources + # Skip initialize() (which fetches telemetry); we just need analyze_environment(). + self._initialized = True + + async def _connect(self, mcp_config: dict[str, dict[str, Any]]) -> None: # pragma: no cover + return None + + async def list_tools(self) -> list[types.Tool]: + return [] + + async def _list_resources_impl(self) -> list[types.Resource]: + return self._mock_resources + + async def _list_prompts_impl(self) -> list[types.Prompt]: + return self._mock_prompts + + async def _call_tool(self, tool_call: MCPToolCall) -> MCPToolResult: # pragma: no cover + raise NotImplementedError + + async def read_resource(self, uri: str) -> types.ReadResourceResult | None: # pragma: no cover + return None + + async def _disconnect(self) -> None: # pragma: no cover + return None + + +@pytest.mark.asyncio +async def test_analyze_environment_derives_scenarios_from_script_prompt_and_resource() -> None: + prompts = [ + types.Prompt( + name="my-env:checkout", + description="[Setup] Checkout flow", + arguments=[], + ) + ] + resources = [ + types.Resource( + uri=AnyUrl("my-env:checkout"), + name="checkout", + description="[Evaluate] Checkout flow", + ) + ] + + client = _MockClient(prompts=prompts, resources=resources) + analysis = await client.analyze_environment() + + assert "scenarios" in analysis + assert len(analysis["scenarios"]) == 1 + scenario = analysis["scenarios"][0] + assert scenario["id"] == "my-env:checkout" + assert scenario["env"] == "my-env" + assert scenario["name"] == "checkout" + assert scenario["has_setup_prompt"] is True + assert scenario["has_evaluate_resource"] is True + + +@pytest.mark.asyncio +async def test_analyze_environment_scenario_from_setup_only() -> None: + prompts = [ + types.Prompt( + name="env-x:only_setup", + description="[Setup] Setup only scenario", + arguments=[], + ) + ] + resources: list[types.Resource] = [] + + client = _MockClient(prompts=prompts, resources=resources) + analysis = await client.analyze_environment() + + assert len(analysis["scenarios"]) == 1 + scenario = analysis["scenarios"][0] + assert scenario["id"] == "env-x:only_setup" + assert scenario["has_setup_prompt"] is True + assert scenario["has_evaluate_resource"] is False + + +@pytest.mark.asyncio +async def test_analyze_environment_scenario_from_evaluate_only() -> None: + prompts: list[types.Prompt] = [] + resources = [ + types.Resource( + uri=AnyUrl("env-y:only_eval"), + name="only_eval", + description="[Evaluate] Evaluate only scenario", + ) + ] + + client = _MockClient(prompts=prompts, resources=resources) + analysis = await client.analyze_environment() + + 
assert len(analysis["scenarios"]) == 1 + scenario = analysis["scenarios"][0] + assert scenario["id"] == "env-y:only_eval" + assert scenario["has_setup_prompt"] is False + assert scenario["has_evaluate_resource"] is True + diff --git a/hud/patches/__init__.py b/hud/patches/__init__.py index 96c3ec0e..64397eb2 100644 --- a/hud/patches/__init__.py +++ b/hud/patches/__init__.py @@ -6,8 +6,14 @@ """ from hud.patches.mcp_patches import apply_all_patches, suppress_fastmcp_logging +from hud.patches.warnings import apply_default_warning_filters, suppress_mcp_use_import_warnings # Apply patches on import apply_all_patches() -__all__ = ["apply_all_patches", "suppress_fastmcp_logging"] +__all__ = [ + "apply_all_patches", + "apply_default_warning_filters", + "suppress_fastmcp_logging", + "suppress_mcp_use_import_warnings", +] diff --git a/hud/patches/warnings.py b/hud/patches/warnings.py new file mode 100644 index 00000000..1a7afd39 --- /dev/null +++ b/hud/patches/warnings.py @@ -0,0 +1,56 @@ +""" +Centralized warning filters for noisy third-party dependencies. + +Keep these helpers here so the rest of the codebase can stay clean and avoid +scattering warning filters across unrelated modules. +""" + +from __future__ import annotations + +import warnings +from contextlib import contextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Iterator + + +def apply_default_warning_filters(*, verbose: bool) -> None: + """Apply our default warning filters for non-verbose CLI/server modes.""" + if verbose: + return + + warnings.filterwarnings("ignore", category=DeprecationWarning) + + # Pydantic v2 emits PydanticDeprecatedSince20 for v1-style config usage in deps. + try: + from pydantic.warnings import PydanticDeprecatedSince20 + except Exception: + return + + warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) + + +@contextmanager +def suppress_mcp_use_import_warnings() -> Iterator[None]: + """Suppress known noisy warnings emitted during `mcp_use` imports.""" + try: + from pydantic.warnings import PydanticDeprecatedSince20 + except Exception: # pragma: no cover + PydanticDeprecatedSince20 = None # type: ignore[assignment] + + with warnings.catch_warnings(): + # mcp_use currently emits DeprecationWarning from its package __init__.py. + warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"mcp_use(\..*)?$") + + # mcp_use currently defines Pydantic v1-style `class Config` in oauth models. 
+ if PydanticDeprecatedSince20 is not None: + warnings.filterwarnings( + "ignore", + category=PydanticDeprecatedSince20, + module=r"mcp_use\.client\.auth\.oauth$", + ) + + yield + + From b4f1ab6f5de841bb77fe890eb4d8fe7b628a9937 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 01:06:17 -0800 Subject: [PATCH 30/92] update lowlevel server init --- hud/server/low_level.py | 3 ++- hud/server/server.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/hud/server/low_level.py b/hud/server/low_level.py index 65460ee4..05758a4c 100644 --- a/hud/server/low_level.py +++ b/hud/server/low_level.py @@ -89,11 +89,12 @@ class LowLevelServerWithInit(_BaseLL): def __init__( self, + fastmcp: Any, *args: Any, init_fn: Callable[[RequestContext], Awaitable[None]] | None = None, **kwargs: Any, ) -> None: - super().__init__(*args, **kwargs) + super().__init__(fastmcp, *args, **kwargs) self._init_fn = init_fn async def run( diff --git a/hud/server/server.py b/hud/server/server.py index 7497aa3e..aa020fa5 100644 --- a/hud/server/server.py +++ b/hud/server/server.py @@ -242,6 +242,7 @@ async def _run_init(ctx: object | None = None) -> None: old_notification_handlers = self._mcp_server.notification_handlers self._mcp_server = LowLevelServerWithInit( + self, # Pass FastMCP instance as required by parent class name=self.name, version=self.version, instructions=self.instructions, From e06daa0f7ddba58713711be1001093cd5aa2ef51 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 01:21:02 -0800 Subject: [PATCH 31/92] update docs --- docs/docs.json | 3 +- docs/index.mdx | 2 + docs/migration.mdx | 107 ++++++++++++++++++++++++++++++ docs/quick-links/environments.mdx | 4 ++ 4 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 docs/migration.mdx diff --git a/docs/docs.json b/docs/docs.json index c4ec7980..c1525789 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -35,7 +35,8 @@ "group": "Get Started", "pages": [ "index", - "llm-quickstart" + "llm-quickstart", + "migration" ] }, { diff --git a/docs/index.mdx b/docs/index.mdx index 116cfc5a..8d1e7121 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -43,6 +43,8 @@ Every call is traced. View them at [hud.ai/home](https://hud.ai/home). ## 2. Environments: Your Code, Agent-Ready +A production API is one live instance with shared state—you can't run 1,000 parallel tests without them stepping on each other. Environments spin up fresh for every evaluation: isolated, deterministic, reproducible. Each generates training data. + Turn your code into tools agents can call. Define scripts that evaluate what agents do: ```python diff --git a/docs/migration.mdx b/docs/migration.mdx new file mode 100644 index 00000000..b4d3b0c0 --- /dev/null +++ b/docs/migration.mdx @@ -0,0 +1,107 @@ +--- +title: "Migrating from v4" +description: "Transition from Task-based environments to the unified Environment class" +icon: "arrow-right-arrow-left" +--- + +v4 separated environments (Docker containers) from evaluation logic (Task objects). v5 unifies everything in the `Environment` class—tools, setup, and scoring live together. + +## Good News: Your Code Still Works + +`Environment` inherits from `MCPServer`. Same API, same behavior. Just change the import: + +```python +# Before +from hud.server import MCPServer +mcp = MCPServer("my-env") + +@mcp.tool() +def my_tool(): ... + +mcp.run() +``` + +```python +# After +from hud import Environment +env = Environment("my-env") + +@env.tool() +def my_tool(): ... + +env.run() +``` + +That's it. 
Your Dockerfile, your tools, your `run()` call—all unchanged. Environment adds scripts, connectors, and integrations on top. + +## Recommended: Add Scripts + +v4 defined setup and evaluation externally in Task objects. v5 lets you define them inside the environment with `@env.script()`. This is optional but recommended—platform features like trace analysis and training work best with scripts. + +```python +@env.script("checkout") +async def checkout_flow(product: str): + # Setup: code before first yield + await env.call_tool("reset_cart") + + # Yield the prompt + answer = yield f"Add '{product}' to cart and checkout" + + # Evaluate: code after first yield, second yield returns reward + yield 1.0 if cart.contains(product) else 0.0 +``` + +Your existing `setup_tool` and `evaluate_tool` definitions still work. Scripts just keep the logic with the environment instead of scattered across task files. + +## Recommended: Use env() for Evals + +v4 created Task objects: + +```python +task = Task(prompt="...", mcp_config={...}, setup_tool={...}, evaluate_tool={...}) +``` + +v5 creates Evals by calling the environment with a script name: + +```python +eval = env("checkout", product="laptop") +``` + +Both work. But `env()` connects to scripts, which means setup/evaluate run automatically and you get structured traces. + +## Optional: Bring Your Own Agent + +v4 required using HUD's agent classes: + +```python +agent = ClaudeAgent.create() +result = await agent.run(task) +``` + +v5 gives you the `hud.eval()` context manager. Use any agent, any model, any framework: + +```python +async with hud.eval(env("checkout", product="laptop")) as ctx: + # Use OpenAI, Anthropic, your own agent—whatever you want + response = await client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": ctx.prompt}], + tools=ctx.as_openai_chat_tools() + ) + + # Handle tool calls, run your agent loop... + await ctx.submit(response.choices[0].message.content) + +print(ctx.reward) +``` + +The old `ClaudeAgent` and `OperatorAgent` still work—even with the new `hud.eval()` system. But now you're not locked into a specific agent spec. Pair with the [Gateway](/quick-links/gateway) to use any model through one API. + +## Quick Reference + +| v4 | v5 | +|----|-----| +| `MCPServer` | `Environment` (drop-in replacement) | +| `setup_tool` / `evaluate_tool` | `@env.script()` (recommended) | +| `Task(...)` | `env("script", ...)` (recommended) | +| `agent.run(task)` | `hud.eval()` + any agent (optional) | diff --git a/docs/quick-links/environments.mdx b/docs/quick-links/environments.mdx index 5c583bdc..4ccb49ce 100644 --- a/docs/quick-links/environments.mdx +++ b/docs/quick-links/environments.mdx @@ -6,6 +6,10 @@ icon: "cube" An environment is everything an agent can interact with—your APIs, services, databases, wrapped as tools. But it's more than that: the environment also defines how agents are *evaluated* through **scripts**. When you deploy an environment, you're creating a sandbox that agents can learn from at scale. +## Why Environments, Not API Servers? + +Your production API is a single live instance with shared state—you can't run 500 tests against it in parallel without causing chaos. Environments spin up fresh for every evaluation: isolated, deterministic, reproducible. Run thousands in parallel, each starting from the exact state you define, each generating training data. An API server is a live system you observe. An environment is a sandbox you control. 
+ ## Tools Start with `hud init` to scaffold an environment—works with existing codebases or from scratch: From 186e23b9646d4f1fd084aedc78ad6e8ebfed749f Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 01:39:50 -0800 Subject: [PATCH 32/92] analyze uses fastncp --- hud/cli/analyze.py | 12 ++++++------ hud/cli/build.py | 6 +++--- hud/clients/__init__.py | 3 ++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/hud/cli/analyze.py b/hud/cli/analyze.py index 6fdf7441..ea44fb07 100644 --- a/hud/cli/analyze.py +++ b/hud/cli/analyze.py @@ -44,10 +44,10 @@ async def analyze_environment(docker_cmd: list[str], output_format: str, verbose ) as progress: task = progress.add_task("Initializing MCP client...", total=None) - # Lazy import to avoid loading mcp_use on simple CLI commands - from hud.clients import MCPClient + # Use FastMCP client directly - no mcp_use deprecation warnings + from hud.clients.fastmcp import FastMCPHUDClient - client = MCPClient(mcp_config=mcp_config, verbose=verbose, auto_trace=False) + client = FastMCPHUDClient(mcp_config=mcp_config, verbose=verbose, auto_trace=False) try: await client.initialize() @@ -379,10 +379,10 @@ async def _analyze_with_config( ) as progress: task = progress.add_task("Initializing MCP client...", total=None) - # Lazy import to avoid loading mcp_use on simple CLI commands - from hud.clients import MCPClient + # Use FastMCP client directly - no mcp_use deprecation warnings + from hud.clients.fastmcp import FastMCPHUDClient - client = MCPClient(mcp_config=mcp_config, verbose=verbose) + client = FastMCPHUDClient(mcp_config=mcp_config, verbose=verbose) try: await client.initialize() diff --git a/hud/cli/build.py b/hud/cli/build.py index 26963f3f..77cbebbc 100644 --- a/hud/cli/build.py +++ b/hud/cli/build.py @@ -450,11 +450,11 @@ async def analyze_mcp_environment( mcp_config = parse_docker_command(docker_cmd) # Initialize client and measure timing - # Lazy import to avoid loading mcp_use on simple CLI commands - from hud.clients import MCPClient + # Use FastMCP client directly - no mcp_use deprecation warnings + from hud.clients.fastmcp import FastMCPHUDClient start_time = time.time() - client = MCPClient(mcp_config=mcp_config, verbose=verbose, auto_trace=False) + client = FastMCPHUDClient(mcp_config=mcp_config, verbose=verbose, auto_trace=False) initialized = False try: diff --git a/hud/clients/__init__.py b/hud/clients/__init__.py index 0ffce9e4..4f93eec0 100644 --- a/hud/clients/__init__.py +++ b/hud/clients/__init__.py @@ -6,7 +6,7 @@ from .fastmcp import FastMCPHUDClient from .mcp_use import MCPUseHUDClient -# Default to MCP-use for new features +# Default to MCP-use for agents (has multi-server session support) MCPClient = MCPUseHUDClient __all__ = [ @@ -14,4 +14,5 @@ "BaseHUDClient", "FastMCPHUDClient", "MCPClient", + "MCPUseHUDClient", ] From c6d8d755057e602f1705f0b3ce5b1e8374a5e754 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 02:36:16 -0800 Subject: [PATCH 33/92] add build analysis --- hud/cli/build.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/hud/cli/build.py b/hud/cli/build.py index 77cbebbc..ba00dcb7 100644 --- a/hud/cli/build.py +++ b/hud/cli/build.py @@ -519,6 +519,11 @@ async def analyze_mcp_environment( } if hub_map: result["hub_tools"] = hub_map + # Include prompts and resources from analysis + if full_analysis.get("prompts"): + result["prompts"] = full_analysis["prompts"] + if full_analysis.get("resources"): + result["resources"] = 
full_analysis["resources"] return result except TimeoutError: from hud.shared.exceptions import HudException @@ -753,8 +758,18 @@ def build_environment( finally: loop.close() - # Show analysis results including hub tools - tool_msg = f"Analyzed environment: {analysis['toolCount']} tools found" + # Show analysis results including hub tools, prompts, resources + tool_count = analysis["toolCount"] + prompt_count = len(analysis.get("prompts") or []) + resource_count = len(analysis.get("resources") or []) + + parts = [f"{tool_count} tools"] + if prompt_count: + parts.append(f"{prompt_count} prompts") + if resource_count: + parts.append(f"{resource_count} resources") + + tool_msg = f"Analyzed environment: {', '.join(parts)} found" hud_console.success(tool_msg) # Extract environment variables from Dockerfile @@ -885,6 +900,16 @@ def build_environment( if hub_tools: lock_content["hubTools"] = hub_tools + # Add prompts if present + prompts = analysis.get("prompts") + if prompts: + lock_content["prompts"] = prompts + + # Add resources if present + resources = analysis.get("resources") + if resources: + lock_content["resources"] = resources + # Write lock file lock_path = env_dir / "hud.lock.yaml" with open(lock_path, "w") as f: From 6e026bc62b7fd4625246d56cb1ead7a4b2c6a617 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 02:44:26 -0800 Subject: [PATCH 34/92] docs update --- docs/docs.json | 10 +- docs/guides/agent-frameworks.mdx | 430 +++++++++++++++++++++++++++++++ docs/reference/environments.mdx | 14 + 3 files changed, 452 insertions(+), 2 deletions(-) create mode 100644 docs/guides/agent-frameworks.mdx diff --git a/docs/docs.json b/docs/docs.json index c1525789..20aafc43 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -35,8 +35,7 @@ "group": "Get Started", "pages": [ "index", - "llm-quickstart", - "migration" + "llm-quickstart" ] }, { @@ -59,6 +58,13 @@ "reference/types" ] }, + { + "group": "Guides", + "pages": [ + "guides/agent-frameworks", + "migration" + ] + }, { "group": "CLI Reference", "pages": [ diff --git a/docs/guides/agent-frameworks.mdx b/docs/guides/agent-frameworks.mdx new file mode 100644 index 00000000..7fc1e5c6 --- /dev/null +++ b/docs/guides/agent-frameworks.mdx @@ -0,0 +1,430 @@ +--- +title: "Agent Frameworks" +description: "Use any agent framework with HUD environments" +icon: "robot" +--- + +HUD environments work with any agent framework. The `Environment` class provides format converters for all major providers, and `hud.eval()` handles setup, evaluation, and tracing automatically. + +Every example on this page uses the `eval` defined below and the [Gateway](/quick-links/gateway) for inference. + +## The Example Environment + +```python +import hud + +CEOS = {"hud": "Jay Ram", "openai": "Sam Altman", "anthropic": "Dario Amodei"} + +env = hud.Environment("trivia") + +@env.tool() +def lookup_ceo(company: str) -> str: + """Look up the CEO of a company.""" + return CEOS.get(company.lower(), "Unknown") + +@env.script("initials") +async def find_initials(company: str): + answer = yield f"What are the initials of the CEO of {company}?" + ceo = CEOS.get(company.lower()) + correct = "".join(word[0] for word in ceo.split()) if ceo else None + yield 1.0 if answer and correct and correct in answer.upper() else 0.0 + +eval = env("initials", company="HUD") +``` + +--- + +## OpenAI + +The OpenAI SDK supports three APIs: Chat Completions, Responses, and the Agents SDK. 
+ +### Chat Completions + +```python +import os +from openai import AsyncOpenAI +import hud + +client = AsyncOpenAI( + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +async with hud.eval(eval) as ctx: + messages = [{"role": "user", "content": ctx.prompt}] + + while True: + response = await client.chat.completions.create( + model="gpt-4o", + messages=messages, + tools=ctx.as_openai_chat_tools() + ) + + msg = response.choices[0].message + messages.append(msg) + + if not msg.tool_calls: + break + + for tool_call in msg.tool_calls: + result = await ctx.call_tool(tool_call) + messages.append(result) + + await ctx.submit(msg.content or "") +``` + +### Responses API + +```python +async with hud.eval(eval) as ctx: + response = await client.responses.create( + model="gpt-4o", + input=ctx.prompt, + tools=ctx.as_openai_responses_tools() + ) + + for item in response.output: + if item.type == "function_call": + await ctx.call_tool(item) + + await ctx.submit(response.output_text) +``` + +### Agents SDK + +```python +from agents import Agent, Runner +import hud + +async with hud.eval(eval) as ctx: + agent = Agent( + name="trivia-agent", + instructions="Answer trivia questions. Use tools to look up information.", + tools=ctx.as_openai_agent_tools() + ) + + result = await Runner.run(agent, ctx.prompt) + await ctx.submit(result.final_output) +``` + +Requires: `pip install openai-agents` + +--- + +## Anthropic + +Claude's Messages API with tool use. + +```python +import os +from anthropic import AsyncAnthropic +import hud + +client = AsyncAnthropic( + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +async with hud.eval(eval) as ctx: + messages = [{"role": "user", "content": ctx.prompt}] + + while True: + response = await client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=1024, + messages=messages, + tools=ctx.as_claude_tools() + ) + + tool_uses = [b for b in response.content if b.type == "tool_use"] + if not tool_uses: + break + + tool_results = [await ctx.call_tool(block) for block in tool_uses] + + messages.append({"role": "assistant", "content": response.content}) + messages.append({"role": "user", "content": tool_results}) + + text = next((b.text for b in response.content if b.type == "text"), "") + await ctx.submit(text) +``` + +Requires: `pip install anthropic` + +--- + +## Gemini + +Google's Gemini API with function calling. + +```python +import os +import google.generativeai as genai +import hud + +genai.configure(api_key=os.environ["GOOGLE_API_KEY"]) +model = genai.GenerativeModel("gemini-2.0-flash") + +async with hud.eval(eval) as ctx: + chat = model.start_chat() + + response = chat.send_message( + ctx.prompt, + tools=ctx.as_gemini_tools(), + tool_config=ctx.as_gemini_tool_config() + ) + + while True: + part = response.candidates[0].content.parts[0] + if not hasattr(part, "function_call") or not part.function_call: + break + + result = await ctx.call_tool(part) + response = chat.send_message(result) + + await ctx.submit(response.text) +``` + +Requires: `pip install google-generativeai` + +--- + +## browser-use + +Browser automation for web agents. 
+ +```python +import os +from browser_use import Agent +from langchain_openai import ChatOpenAI +import hud + +llm = ChatOpenAI( + model="gpt-4o", + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +async with hud.eval(eval) as ctx: + agent = Agent(task=ctx.prompt, llm=llm) + result = await agent.run() + await ctx.submit(str(result)) +``` + +Requires: `pip install browser-use playwright && playwright install` + +--- + +## LangChain + +LangChain's agent framework with tool calling. + +```python +import os +from langchain_openai import ChatOpenAI +from langchain.agents import create_tool_calling_agent, AgentExecutor +from langchain_core.prompts import ChatPromptTemplate +import hud + +llm = ChatOpenAI( + model="gpt-4o", + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +async with hud.eval(eval) as ctx: + tools = ctx.as_langchain_tools() + + prompt = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful assistant."), + ("human", "{input}"), + ("placeholder", "{agent_scratchpad}"), + ]) + + agent = create_tool_calling_agent(llm, tools, prompt) + executor = AgentExecutor(agent=agent, tools=tools) + + result = await executor.ainvoke({"input": ctx.prompt}) + await ctx.submit(result["output"]) +``` + +Requires: `pip install langchain langchain-openai langchain-core` + +--- + +## LlamaIndex + +LlamaIndex's ReAct agent with tool integration. + +```python +import os +from llama_index.llms.openai import OpenAI +from llama_index.core.agent import ReActAgent +import hud + +llm = OpenAI( + model="gpt-4o", + api_base="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +async with hud.eval(eval) as ctx: + tools = ctx.as_llamaindex_tools() + + agent = ReActAgent.from_tools(tools, llm=llm, verbose=True) + response = await agent.achat(ctx.prompt) + + await ctx.submit(str(response)) +``` + +Requires: `pip install llama-index-core llama-index-llms-openai` + +--- + +## Google ADK + +Google's Agent Development Kit for Gemini-powered agents. + +```python +import os +from google.adk.agents import Agent +from google.adk.runners import Runner +import hud + +async with hud.eval(eval) as ctx: + agent = Agent( + name="trivia-agent", + model="gemini-2.0-flash", + instruction="Answer trivia questions. Use tools to look up information.", + tools=ctx.as_adk_tools() + ) + + runner = Runner(agent=agent) + result = await runner.run(ctx.prompt) + + await ctx.submit(result.output) +``` + +Requires: `pip install google-adk` + +--- + +## CrewAI + +Multi-agent orchestration with roles and tasks. + +```python +import os +from crewai import Agent, Task, Crew +from langchain_openai import ChatOpenAI +import hud + +llm = ChatOpenAI( + model="gpt-4o", + base_url="https://inference.hud.ai", + api_key=os.environ["HUD_API_KEY"] +) + +async with hud.eval(eval) as ctx: + tools = ctx.as_langchain_tools() + + researcher = Agent( + role="Researcher", + goal="Find accurate information", + backstory="Expert at finding information", + tools=tools, + llm=llm + ) + + task = Task( + description=ctx.prompt, + expected_output="The initials of the CEO", + agent=researcher + ) + + crew = Crew(agents=[researcher], tasks=[task]) + result = crew.kickoff() + await ctx.submit(str(result)) +``` + +Requires: `pip install crewai langchain-openai` + +--- + +## AutoGen + +Microsoft's multi-agent conversation framework. 
+ +```python +import os +from autogen import AssistantAgent, UserProxyAgent +import hud + +async with hud.eval(eval) as ctx: + config_list = [{ + "model": "gpt-4o", + "base_url": "https://inference.hud.ai", + "api_key": os.environ["HUD_API_KEY"] + }] + + assistant = AssistantAgent( + name="assistant", + llm_config={"config_list": config_list} + ) + + for tool in ctx.as_tools(): + @assistant.register_for_execution() + async def tool_fn(name=tool.name, **kwargs): + return await ctx.call_tool(name, **kwargs) + + user = UserProxyAgent( + name="user", + human_input_mode="NEVER", + code_execution_config=False + ) + + result = await user.a_initiate_chat(assistant, message=ctx.prompt) + await ctx.submit(result.summary) +``` + +Requires: `pip install pyautogen` + +--- + +## Format Reference + +| Method | Returns | Use With | +|--------|---------|----------| +| `as_openai_chat_tools()` | OpenAI Chat format | OpenAI Chat Completions | +| `as_openai_responses_tools()` | OpenAI Responses format | OpenAI Responses API | +| `as_openai_agent_tools()` | FunctionTool objects | OpenAI Agents SDK | +| `as_claude_tools()` | Anthropic format | Claude API | +| `as_gemini_tools()` | Gemini format | Google AI | +| `as_adk_tools()` | ADK FunctionTool objects | Google ADK | +| `as_langchain_tools()` | StructuredTool objects | LangChain, CrewAI | +| `as_llamaindex_tools()` | FunctionTool objects | LlamaIndex | +| `as_tools()` | MCP Tool objects | Raw MCP, AutoGen | + +All `call_tool()` calls auto-detect the input format and return matching output format. + +--- + +## Bring Your Own + +Don't see your framework? The pattern is simple: + +1. Get tools in your framework's format (or use `as_tools()` for raw MCP) +2. Run your agent loop +3. Call `ctx.call_tool()` for each tool invocation +4. Call `ctx.submit()` with the final answer + +```python +async with hud.eval(eval) as ctx: + tools = ctx.as_tools() # Raw MCP format + + result = await my_custom_agent(ctx.prompt, tools, ctx.call_tool) + + await ctx.submit(result) +``` + +The environment handles setup, evaluation, and tracing. You handle the agent logic. 
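
One way to flesh out `my_custom_agent` is to reuse the Chat Completions pattern from earlier in this guide. This is only a sketch—it takes the context directly instead of the `(prompt, tools, call_tool)` arguments shown above, and the model name and step cap are arbitrary placeholders:

```python
import os
from openai import AsyncOpenAI
import hud

client = AsyncOpenAI(
    base_url="https://inference.hud.ai",
    api_key=os.environ["HUD_API_KEY"]
)

async def my_custom_agent(ctx, max_steps: int = 10) -> str:
    # Plain tool-calling loop: ask the model, run its tool calls, repeat.
    messages = [{"role": "user", "content": ctx.prompt}]
    for _ in range(max_steps):
        response = await client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
            tools=ctx.as_openai_chat_tools(),
        )
        msg = response.choices[0].message
        messages.append(msg)
        if not msg.tool_calls:
            return msg.content or ""
        for tool_call in msg.tool_calls:
            # call_tool auto-detects the format and returns a matching message
            messages.append(await ctx.call_tool(tool_call))
    return ""

async with hud.eval(eval) as ctx:
    await ctx.submit(await my_custom_agent(ctx))
```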
diff --git a/docs/reference/environments.mdx b/docs/reference/environments.mdx index e58323d4..f9632318 100644 --- a/docs/reference/environments.mdx +++ b/docs/reference/environments.mdx @@ -208,6 +208,20 @@ config = env.as_gemini_tool_config() tools = env.as_langchain_tools() ``` +### LlamaIndex + +```python +# Requires llama-index-core +tools = env.as_llamaindex_tools() +``` + +### Google ADK + +```python +# Requires google-adk +tools = env.as_adk_tools() +``` + ## Calling Tools ### call_tool() From 079f739999de8dff3f500d789f81c28fd2aae9c2 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 04:43:00 -0800 Subject: [PATCH 35/92] update docs --- docs/docs.json | 15 +- ...{agent-frameworks.mdx => integrations.mdx} | 2 +- docs/guides/sandboxing.mdx | 199 ++++++++++++++++++ 3 files changed, 208 insertions(+), 8 deletions(-) rename docs/guides/{agent-frameworks.mdx => integrations.mdx} (99%) create mode 100644 docs/guides/sandboxing.mdx diff --git a/docs/docs.json b/docs/docs.json index 20aafc43..db2d1406 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -47,6 +47,14 @@ "quick-links/deploy" ] }, + { + "group": "Guides", + "pages": [ + "guides/integrations", + "guides/sandboxing", + "migration" + ] + }, { "group": "SDK Reference", "pages": [ @@ -58,13 +66,6 @@ "reference/types" ] }, - { - "group": "Guides", - "pages": [ - "guides/agent-frameworks", - "migration" - ] - }, { "group": "CLI Reference", "pages": [ diff --git a/docs/guides/agent-frameworks.mdx b/docs/guides/integrations.mdx similarity index 99% rename from docs/guides/agent-frameworks.mdx rename to docs/guides/integrations.mdx index 7fc1e5c6..b929ab0d 100644 --- a/docs/guides/agent-frameworks.mdx +++ b/docs/guides/integrations.mdx @@ -1,5 +1,5 @@ --- -title: "Agent Frameworks" +title: "Integrations" description: "Use any agent framework with HUD environments" icon: "robot" --- diff --git a/docs/guides/sandboxing.mdx b/docs/guides/sandboxing.mdx new file mode 100644 index 00000000..dd55b7fe --- /dev/null +++ b/docs/guides/sandboxing.mdx @@ -0,0 +1,199 @@ +--- +title: "Sandboxing" +description: "Turn your existing services into agent-testable environments" +icon: "cube" +--- + +You have a production stack. You want an agent on it. But you can't just point an agent at production—it'll make real changes, hit real APIs, affect real users. And you can't test at scale against a single live instance with shared state. + +HUD lets you mock your production environment so agents can run against it safely. Connect your services in a few lines. Write evals that tell agents what to do and grade how well they did it. HUD handles the sandboxing, the parallelization, the state extraction, the tracing. You get a reliable test bed where thousands of agents can run in parallel—each isolated, each reproducible, each generating useful data. + +## Connecting Your Stack + +HUD wraps your existing infrastructure. Your code stays where it is—you connect it: + +```python +from hud import Environment + +env = Environment("my-env") + +# Your FastAPI app → all routes become tools +env.connect_fastapi(app) + +# Your MCP servers +env.connect_server(mcp_server) + +# Any REST API with an OpenAPI spec +env.connect_openapi("https://api.example.com/openapi.json") +``` + +Docker images work with `env.connect_image("my-service:v1")`. Other HUD environments compose with `env.connect_hub("my-org/other-env")`. See the full list in the [Environment Reference](/reference/environments). 
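+
+If it helps to see those calls in place, both attach to the same `Environment` instance as the block above (the image tag and hub slug here are made-up examples):
+
+```python
+env.connect_image("my-service:v1")    # a Dockerized service
+env.connect_hub("my-org/other-env")   # another HUD environment, composed in
+```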
+ +Run `hud init` to scaffold an environment in an existing project—it adds the HUD files without touching your code. Once connected, deploy and run evals at scale. + +### Making It Safe + +HUD runs each eval in its own container—isolated, reproducible, safe. But your environment might connect to external services. Here's how to handle them: + +**Databases.** Each agent needs its own sandbox. Use in-memory SQLite (fast, resets per eval), transaction rollback, or seed fresh data at start: + +```python +@env.script("update-order") +async def update_order(order_id: str): + await db.seed_from("fixtures/orders.sql") + + answer = yield f"Update order {order_id} status to 'shipped'" + + order = await db.query("SELECT status FROM orders WHERE id = ?", order_id) + yield 1.0 if order and order["status"] == "shipped" else 0.0 +``` + +**Third-party APIs.** Use mock mode to return fake responses without hitting real services: + +```python +env.mock() # All tools return fake responses based on schemas +env.mock_tool("send_email", {"status": "sent", "id": "mock-123"}) # Override specific tools +``` + +**Credentials.** If you need a live service, use staging keys. Point evals at staging, not production. + +## Good Environments + +A good environment gives agents what they need to succeed—and gives you what you need to evaluate them. + +### Observable State + +Agents need access to the right information. If they can't see the data they need, they can't complete the task. Design tools that expose useful state: + +```python +# ❌ Bad: Agent can't see what was created +@env.tool() +def create_user(name: str) -> str: + db.insert("users", name=name) + return "User created" + +# ✅ Good: Agent gets actionable data back +@env.tool() +def create_user(name: str) -> dict: + user_id = db.insert("users", name=name) + return {"id": user_id, "name": name, "created": True} +``` + +For grading, you also need to observe what happened. If the agent creates a database row, you need to query that database. If it uploads a file, you need to read that file. Be cognizant of what you can and cannot observe—only ask agents to do things you can verify. + +### Deterministic Setup + +Each eval should seed the state it needs. HUD handles container isolation—you handle making sure your script sets up the right data before the agent runs. + +```python +# ❌ Bad: Depends on whatever state exists +@env.script("find-user") +async def find_user(name: str): + answer = yield f"Find the user named {name}" + yield 1.0 if name in answer else 0.0 + +# ✅ Good: Seeds known state before eval +@env.script("find-user") +async def find_user(name: str): + await db.clear() + await db.insert("users", name=name, email=f"{name}@example.com") + + answer = yield f"Find the user named {name}" + yield 1.0 if name in answer else 0.0 +``` + +### Isolated Execution + +HUD sandboxes each eval—containers don't share state. But if your environment connects to external services, think about stateful vs stateless. + +**Stateless services** are fine. Multiple agents can hit the same read-only API without interference. + +**Stateful services** need care. If 100 agents all hit the same database endpoint that modifies data, they'll step on each other. Use per-eval instances, transaction isolation, or target different records. + +## Good Evals + +An eval combines a prompt (the first `yield`) with grading logic (everything after). The prompt tells agents what to do—write short-to-medium length instructions that ask for an unambiguous change you can verify. 
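+
+As a minimal skeleton—the script name, prompt, and expected text below are all placeholders:
+
+```python
+@env.script("example")
+async def example():
+    # setup: seed whatever state the prompt needs (optional)
+    answer = yield "The instruction the agent sees"
+    # grading: everything after the first yield; end with a float in [0.0, 1.0]
+    yield 1.0 if "expected text" in answer.lower() else 0.0
+```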
+ +### Be Specific + +Ambiguous prompts lead to ambiguous grading. Say exactly what you want: + +``` +❌ "Update the user settings" +✅ "Change the email for user alice@example.com to alice.new@example.com" +``` + +Real-world example: *"Add a column to the Portfolio snapshot with the 'Phase' of the engagement. C-11X should be 'Phase 2', all else are 'Phase 1'."* + +### Only Ask for Testable Things + +If you can't observe the result, you can't grade it. Don't ask an agent to "think about" something—ask it to do something you can verify. + +``` +❌ "Consider the best approach to optimize the query" +✅ "Rewrite the query to use an index on the email column" +``` + +### Create Variations + +Evals are easier to write when you have a specific failure mode in mind. If you've observed agents struggling with something, incorporate that into future evals. + +Create different versions with more or less explicit instructions—step-by-step guidance vs. high-level goals. Use [variants](/quick-links/ab-testing) to test these systematically. Variations make it easier to tune difficulty later. + +## Good Graders + +The grading logic after the first `yield` determines the grade. Fair grading means useful signal. + +### Match the Prompt + +If the prompt says "create a document with a Japanese car brand", check for any Japanese car brand—not just "Toyota". But don't accept any document either. Exactly as strict as the prompt implies. + +```python +# ❌ Bad: Too strict—only accepts one answer +@env.script("add-car") +async def add_car(): + answer = yield "Add a Japanese car brand to the document" + yield 1.0 if answer == "Toyota" else 0.0 + +# ✅ Good: Accepts any valid answer +@env.script("add-car") +async def add_car(): + answer = yield "Add a Japanese car brand to the document" + japanese_brands = ["toyota", "honda", "nissan", "mazda", "subaru"] + yield 1.0 if any(brand in answer.lower() for brand in japanese_brands) else 0.0 +``` + +### Use Partial Credit + +Partial grades help you see where agents fail. Did they add to cart but not checkout? That's useful signal. Break complex grading into sub-checks with weighted grades: + +```python +@env.script("checkout") +async def checkout(product: str): + answer = yield f"Add {product} to cart and checkout" + + score = 0.0 + if await product_in_cart(product): + score += 0.3 # Partial credit for first step + if await order_completed(product): + score += 0.7 # Most credit for completion + yield score +``` + +### Sanity Check + +At minimum, verify two cases: unchanged state → 0.0, correct completion → 1.0. For grading logic you'll reuse across many evals, write unit tests. Load a known state snapshot, verify the grade matches what you expect. + +## What's Next + +Once your environment is connected and your evals are written, you're ready to run at scale. + +**Deploy.** Push to GitHub, connect on [hud.ai](https://hud.ai), and your environment goes live. See [Deploy](/quick-links/deploy). + +**Run with any agent.** Use [Integrations](/guides/integrations) to connect OpenAI, Anthropic, LangChain, or your own agent loop. + +**Find the right difficulty.** A good eval set has range—target 20-30% average success rate. You want high variance: some runs should grade 0.0, others 1.0. If every run grades the same, there's no signal to learn from. Having both positive and negative examples on the same eval is what makes improvement possible. + +**Iterate.** Create an eval, test it manually, run it at scale, check the difficulty. 
If it's too easy or too hard, adjust the prompt or grading. Use your best evals as templates for more. + +**Train.** Every eval generates data—prompts, tool calls, grades. Use successful runs for fine-tuning. The loop: eval → analyze → train → eval again. From 0f98f236414607eb0d337f4b7f5eb797f63359c9 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 04:43:14 -0800 Subject: [PATCH 36/92] adjust agent class and envs --- hud/agents/base.py | 100 ++++++++++- hud/agents/tests/test_run_eval.py | 190 +++++++++++++++++++++ hud/clients/__init__.py | 2 + hud/clients/environment.py | 51 ++++++ hud/environment/integrations/__init__.py | 10 ++ hud/environment/integrations/adk.py | 67 ++++++++ hud/environment/integrations/langchain.py | 51 ++---- hud/environment/integrations/llamaindex.py | 68 ++++++++ hud/environment/utils/__init__.py | 10 ++ hud/environment/utils/tool_wrappers.py | 111 ++++++++++++ 10 files changed, 616 insertions(+), 44 deletions(-) create mode 100644 hud/agents/tests/test_run_eval.py create mode 100644 hud/clients/environment.py create mode 100644 hud/environment/integrations/adk.py create mode 100644 hud/environment/integrations/llamaindex.py create mode 100644 hud/environment/utils/tool_wrappers.py diff --git a/hud/agents/base.py b/hud/agents/base.py index 05e12094..803d3380 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -19,6 +19,8 @@ if TYPE_CHECKING: from hud.datasets import Task + from hud.environment import Environment + from hud.eval.context import EvalContext logger = logging.getLogger(__name__) @@ -209,24 +211,66 @@ async def initialize(self, task: str | Task | None = None) -> None: f"Agent initialized with {len(self.get_available_tools())} tools: {', '.join([t.name for t in self.get_available_tools()])}" # noqa: E501 ) - async def run(self, prompt_or_task: str | Task | dict[str, Any], max_steps: int = 10) -> Trace: + async def run( + self, + prompt_or_task: str | Task | EvalContext | Environment | dict[str, Any], + max_steps: int = 10, + ) -> Trace: """ - Run the agent with the given prompt or task. + Run the agent with the given prompt, task, or environment. 
Args: - prompt_or_task: Either a string prompt for simple execution or a Task object + prompt_or_task: One of: + - str: Simple text prompt + - Task: Task object with mcp_config, setup_tool, evaluate_tool + - EvalContext: From hud.eval() - uses ctx.prompt, ctx.call_tool, ctx.submit + - Environment: Connected environment to use for tool calls + - dict: Task-like dict (converted to Task) max_steps: Maximum number of steps (-1 for infinite) Returns: Trace with reward, done, content, isError fields and trace steps + + Example: + # With EvalContext from hud.eval + async with hud.eval(evals) as ctx: + result = await agent.run(ctx) + # result.reward comes from script evaluate """ # Import here to avoid circular imports from hud.datasets import Task + from hud.environment import Environment + from hud.eval.context import EvalContext + + # Handle EvalContext - delegate to run_eval + if isinstance(prompt_or_task, EvalContext): + return await self.run_eval(prompt_or_task, max_steps=max_steps) + + # Handle Environment (non-eval) - wrap with EnvironmentClient + if isinstance(prompt_or_task, Environment) and not isinstance(prompt_or_task, EvalContext): + from hud.clients.environment import EnvironmentClient + + env = prompt_or_task + if not env.prompt: + raise ValueError("Environment.prompt is not set") + + client = EnvironmentClient(env) + self.mcp_client = client + + try: + await self.initialize(env.prompt) + result = await self._run_context(text_to_blocks(env.prompt), max_steps=max_steps) + return result + finally: + self.mcp_client = None if isinstance(prompt_or_task, dict): prompt_or_task = Task(**prompt_or_task) elif not isinstance(prompt_or_task, str) and not isinstance(prompt_or_task, Task): - raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}") + raise TypeError( + f"prompt_or_task must be str, Task, EvalContext, or Environment, " + f"got {type(prompt_or_task)}" + ) try: # Establish the connection with the MCP server/Environment @@ -376,6 +420,54 @@ async def run_task(self, task: Task, max_steps: int = 10) -> Trace: return prompt_result + async def run_eval(self, ctx: EvalContext, *, max_steps: int = 10) -> Trace: + """ + Run the agent with an EvalContext from hud.eval(). + + This method integrates with the hud.eval framework: + - Uses ctx.prompt as the starting prompt + - Uses ctx for tool calls via EnvironmentClient adapter + - Calls ctx.submit(response) when the agent finishes + - Reward is available on ctx.reward after the hud.eval block exits + + Args: + ctx: EvalContext from hud.eval() - already connected and has prompt set + max_steps: Maximum number of agent steps (-1 for infinite) + + Returns: + Trace with agent output. Note: ctx.reward is set by script evaluate + phase which runs when the hud.eval block exits. 
+ + Example: + ```python + async with hud.eval(evals) as ctx: + result = await agent.run_eval(ctx) + # ctx.reward is now set by the script's evaluate phase + print(f"Reward: {ctx.reward}") + ``` + """ + from hud.clients.environment import EnvironmentClient + from hud.eval.context import EvalContext + + if not isinstance(ctx, EvalContext): + raise TypeError(f"ctx must be EvalContext, got {type(ctx)}") + + if not ctx.prompt: + raise ValueError("EvalContext.prompt is not set - did the script setup run?") + + self.mcp_client = EnvironmentClient(ctx) + try: + await self.initialize(ctx.prompt) + result = await self._run_context(text_to_blocks(ctx.prompt), max_steps=max_steps) + if result.content: + await ctx.submit(result.content) + return result + except Exception as e: + logger.exception("Error running agent with EvalContext:") + return Trace(reward=0.0, done=True, content=str(e), isError=True) + finally: + self.mcp_client = None + async def _run_context( self, context: list[types.ContentBlock], *, max_steps: int = 10 ) -> Trace: diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py new file mode 100644 index 00000000..746c80d6 --- /dev/null +++ b/hud/agents/tests/test_run_eval.py @@ -0,0 +1,190 @@ +"""Tests for run_eval and EnvironmentClient.""" + +from __future__ import annotations + +from typing import Any, ClassVar + +import pytest +from mcp import types + +from hud.agents import MCPAgent +from hud.agents.base import BaseCreateParams +from hud.clients.environment import EnvironmentClient +from hud.eval.context import EvalContext +from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult + + +class MockConfig(BaseAgentConfig): + model_name: str = "MockAgent" + checkpoint_name: str = "mock-model" + + +class MockCreateParams(BaseCreateParams, MockConfig): + pass + + +class MockMCPAgent(MCPAgent): + """Mock agent for testing run_eval.""" + + metadata: ClassVar[dict[str, Any] | None] = {} + config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig + + def __init__(self, **kwargs: Any) -> None: + params = MockCreateParams(**kwargs) + super().__init__(params) + self._response = AgentResponse(content="Test response", tool_calls=[], done=True) + + def set_response(self, response: AgentResponse) -> None: + self._response = response + + async def create_initial_messages( + self, prompt: str, initial_screenshot: bool = False + ) -> list[dict[str, Any]]: + return [{"role": "user", "content": prompt}] + + async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: + return self._response + + async def format_tool_results( + self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] + ) -> list[dict[str, Any]]: + return [{"role": "tool", "content": str(r)} for r in tool_results] + + async def create_user_message(self, text: str) -> Any: + return {"role": "user", "content": text} + + async def get_system_messages(self) -> list[Any]: + return [] + + async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: + return [{"type": "text", "text": b.text} for b in blocks if hasattr(b, "text")] + + +class MockEvalContext(EvalContext): + """Mock EvalContext for testing - inherits from real EvalContext.""" + + def __init__(self, prompt: str = "Test prompt", tools: list[types.Tool] | None = None) -> None: + # Skip parent __init__, just set what we need + self.prompt = prompt + self._tools = tools or [ + types.Tool(name="test_tool", description="Test", inputSchema={}) + ] + self._submitted: str | None = None + 
self.reward: float | None = None + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, name: str, **kwargs: Any) -> MCPToolResult: + return MCPToolResult( + content=[types.TextContent(type="text", text=f"Result from {name}")], + isError=False, + ) + + async def submit(self, answer: str) -> None: + self._submitted = answer + + +class TestEnvironmentClient: + """Tests for EnvironmentClient adapter.""" + + @pytest.mark.asyncio + async def test_initialize(self) -> None: + """Test client initialization.""" + ctx = MockEvalContext() + client = EnvironmentClient(ctx) + + assert not client.is_connected + await client.initialize() + assert client.is_connected + + @pytest.mark.asyncio + async def test_list_tools(self) -> None: + """Test listing tools through adapter.""" + ctx = MockEvalContext() + client = EnvironmentClient(ctx) + + tools = await client.list_tools() + assert len(tools) == 1 + assert tools[0].name == "test_tool" + + @pytest.mark.asyncio + async def test_call_tool(self) -> None: + """Test calling tools through adapter.""" + ctx = MockEvalContext() + client = EnvironmentClient(ctx) + + result = await client.call_tool(MCPToolCall(name="test_tool", arguments={})) + assert not result.isError + assert len(result.content) == 1 + + @pytest.mark.asyncio + async def test_mcp_config_empty(self) -> None: + """Test mcp_config is empty for environment clients.""" + ctx = MockEvalContext() + client = EnvironmentClient(ctx) + assert client.mcp_config == {} + + @pytest.mark.asyncio + async def test_shutdown(self) -> None: + """Test shutdown resets initialized state.""" + ctx = MockEvalContext() + client = EnvironmentClient(ctx) + + await client.initialize() + assert client.is_connected + + await client.shutdown() + assert not client.is_connected + + +class TestRunEval: + """Tests for MCPAgent.run_eval().""" + + @pytest.mark.asyncio + async def test_run_eval_basic(self) -> None: + """Test basic run_eval flow.""" + ctx = MockEvalContext(prompt="Do the task") + agent = MockMCPAgent() + + result = await agent.run_eval(ctx) + + assert result.done + assert result.content == "Test response" + assert ctx._submitted == "Test response" + + @pytest.mark.asyncio + async def test_run_eval_no_prompt_raises(self) -> None: + """Test run_eval raises when prompt is not set.""" + ctx = MockEvalContext(prompt="") + agent = MockMCPAgent() + + with pytest.raises(ValueError, match="prompt is not set"): + await agent.run_eval(ctx) + + @pytest.mark.asyncio + async def test_run_eval_wrong_type_raises(self) -> None: + """Test run_eval raises TypeError for non-EvalContext.""" + agent = MockMCPAgent() + + with pytest.raises(TypeError, match="must be EvalContext"): + await agent.run_eval("not an eval context") # type: ignore[arg-type] + + @pytest.mark.asyncio + async def test_run_eval_clears_client(self) -> None: + """Test run_eval clears mcp_client after completion.""" + ctx = MockEvalContext(prompt="Do the task") + agent = MockMCPAgent() + + await agent.run_eval(ctx) + assert agent.mcp_client is None + + @pytest.mark.asyncio + async def test_run_eval_no_submit_on_empty_content(self) -> None: + """Test run_eval doesn't submit when content is empty.""" + ctx = MockEvalContext(prompt="Do the task") + agent = MockMCPAgent() + agent.set_response(AgentResponse(content="", tool_calls=[], done=True)) + + await agent.run_eval(ctx) + assert ctx._submitted is None diff --git a/hud/clients/__init__.py b/hud/clients/__init__.py index 4f93eec0..31692021 100644 --- a/hud/clients/__init__.py +++ 
b/hud/clients/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from .base import AgentMCPClient, BaseHUDClient +from .environment import EnvironmentClient from .fastmcp import FastMCPHUDClient from .mcp_use import MCPUseHUDClient @@ -12,6 +13,7 @@ __all__ = [ "AgentMCPClient", "BaseHUDClient", + "EnvironmentClient", "FastMCPHUDClient", "MCPClient", "MCPUseHUDClient", diff --git a/hud/clients/environment.py b/hud/clients/environment.py new file mode 100644 index 00000000..6f42f368 --- /dev/null +++ b/hud/clients/environment.py @@ -0,0 +1,51 @@ +"""Environment-based client adapter for agents.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import mcp.types as types + +from hud.types import MCPToolCall, MCPToolResult + +if TYPE_CHECKING: + from hud.environment import Environment + from hud.eval.context import EvalContext + +__all__ = ["EnvironmentClient"] + + +class EnvironmentClient: + """Adapter wrapping Environment/EvalContext as AgentMCPClient.""" + + def __init__(self, env: Environment | EvalContext) -> None: + self._env = env + self._initialized = False + + @property + def mcp_config(self) -> dict[str, dict[str, Any]]: + return {} + + @property + def is_connected(self) -> bool: + return self._initialized + + async def initialize(self, mcp_config: dict[str, dict[str, Any]] | None = None) -> None: + if not self._initialized: + await self._env.list_tools() + self._initialized = True + + async def list_tools(self) -> list[types.Tool]: + return await self._env.list_tools() + + async def call_tool(self, tool_call: MCPToolCall) -> MCPToolResult: + result = await self._env.call_tool(tool_call.name, **(tool_call.arguments or {})) + if isinstance(result, MCPToolResult): + return result + return MCPToolResult( + content=[types.TextContent(type="text", text=str(result))], + isError=False, + ) + + async def shutdown(self) -> None: + self._initialized = False diff --git a/hud/environment/integrations/__init__.py b/hud/environment/integrations/__init__.py index 82610bf9..412f283f 100644 --- a/hud/environment/integrations/__init__.py +++ b/hud/environment/integrations/__init__.py @@ -1,8 +1,10 @@ """Provider integrations - format conversion and framework tools.""" +from hud.environment.integrations.adk import ADKMixin from hud.environment.integrations.anthropic import AnthropicMixin from hud.environment.integrations.gemini import GeminiMixin from hud.environment.integrations.langchain import LangChainMixin +from hud.environment.integrations.llamaindex import LlamaIndexMixin from hud.environment.integrations.openai import OpenAIMixin __all__ = ["IntegrationsMixin"] @@ -13,6 +15,8 @@ class IntegrationsMixin( AnthropicMixin, GeminiMixin, LangChainMixin, + LlamaIndexMixin, + ADKMixin, ): """Combined integration mixin for all providers. 
@@ -30,6 +34,12 @@ class IntegrationsMixin( as_gemini_tools() - Gemini format as_gemini_tool_config() - Tool config + Google ADK: + as_adk_tools() - ADK FunctionTool objects (requires google-adk) + LangChain: as_langchain_tools() - StructuredTools (requires langchain-core) + + LlamaIndex: + as_llamaindex_tools() - FunctionTools (requires llama-index-core) """ diff --git a/hud/environment/integrations/adk.py b/hud/environment/integrations/adk.py new file mode 100644 index 00000000..2d33887f --- /dev/null +++ b/hud/environment/integrations/adk.py @@ -0,0 +1,67 @@ +"""Google ADK integration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from hud.environment.utils.tool_wrappers import create_async_tool_fn + +if TYPE_CHECKING: + import mcp.types as mcp_types + +__all__ = ["ADKMixin"] + + +class ADKMixin: + """Mixin providing Google ADK (Agent Development Kit) integration. + + Integration methods (requires google-adk): + as_adk_tools() - ADK FunctionTool objects + + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) + """ + + def as_tools(self) -> list[mcp_types.Tool]: + raise NotImplementedError + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + def as_adk_tools(self) -> list[Any]: + """Convert to Google ADK FunctionTool objects. + + Requires: pip install google-adk + + Returns: + List of FunctionTool objects for Google ADK agents. + + Example: + ```python + from google.adk.agents import Agent + from google.adk.runners import Runner + + async with env: + agent = Agent( + name="assistant", + model="gemini-2.0-flash", + instruction="You are a helpful assistant.", + tools=env.as_adk_tools() + ) + runner = Runner(agent=agent) + result = await runner.run("Find information about Python") + ``` + """ + try: + from google.adk.tools import FunctionTool + except ImportError as e: + raise ImportError( + "Google ADK not installed. 
Install with: pip install google-adk" + ) from e + + tools = [] + for t in self.as_tools(): + # ADK only needs async function - it wraps it in FunctionTool + async_fn = create_async_tool_fn(self, t.name, t.description) + tool = FunctionTool(async_fn) + tools.append(tool) + return tools diff --git a/hud/environment/integrations/langchain.py b/hud/environment/integrations/langchain.py index f86e936f..09d0d52f 100644 --- a/hud/environment/integrations/langchain.py +++ b/hud/environment/integrations/langchain.py @@ -2,10 +2,10 @@ from __future__ import annotations -import json from typing import TYPE_CHECKING, Any from hud.environment.utils.schema import schema_to_pydantic +from hud.environment.utils.tool_wrappers import create_tool_fns if TYPE_CHECKING: import mcp.types as mcp_types @@ -68,44 +68,15 @@ def as_langchain_tools(self) -> list[Any]: tools = [] for t in self.as_tools(): - tool = _create_structured_tool(self, t, StructuredTool) + schema = t.inputSchema or {"type": "object", "properties": {}} + sync_fn, async_fn = create_tool_fns(self, t) + + tool = StructuredTool( + name=t.name, + description=t.description or "", + func=sync_fn, + coroutine=async_fn, + args_schema=schema_to_pydantic(t.name, schema), + ) tools.append(tool) return tools - - -def _create_structured_tool(env: LangChainMixin, tool: mcp_types.Tool, StructuredTool: type) -> Any: - """Create a StructuredTool that calls back to the environment.""" - import asyncio - - schema = tool.inputSchema or {"type": "object", "properties": {}} - - def sync_invoke(**kwargs: Any) -> str: - """Synchronous wrapper for the tool.""" - loop = asyncio.get_event_loop() - if loop.is_running(): - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(asyncio.run, env.call_tool(tool.name, **kwargs)) - result = future.result() - else: - result = loop.run_until_complete(env.call_tool(tool.name, **kwargs)) - - if isinstance(result, str): - return result - return json.dumps(result) if result else "" - - async def async_invoke(**kwargs: Any) -> str: - """Async wrapper for the tool.""" - result = await env.call_tool(tool.name, **kwargs) - if isinstance(result, str): - return result - return json.dumps(result) if result else "" - - return StructuredTool( - name=tool.name, - description=tool.description or "", - func=sync_invoke, - coroutine=async_invoke, - args_schema=schema_to_pydantic(tool.name, schema), - ) diff --git a/hud/environment/integrations/llamaindex.py b/hud/environment/integrations/llamaindex.py new file mode 100644 index 00000000..0815d05a --- /dev/null +++ b/hud/environment/integrations/llamaindex.py @@ -0,0 +1,68 @@ +"""LlamaIndex integration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from hud.environment.utils.tool_wrappers import create_tool_fns + +if TYPE_CHECKING: + import mcp.types as mcp_types + +__all__ = ["LlamaIndexMixin"] + + +class LlamaIndexMixin: + """Mixin providing LlamaIndex integration. + + Integration methods (requires llama-index-core): + as_llamaindex_tools() - LlamaIndex FunctionTool objects + + Requires: as_tools() -> list[mcp_types.Tool], call_tool(name, args) + """ + + def as_tools(self) -> list[mcp_types.Tool]: + raise NotImplementedError + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: + raise NotImplementedError + + def as_llamaindex_tools(self) -> list[Any]: + """Convert to LlamaIndex FunctionTool objects. 
+ + Requires: pip install llama-index-core + + Returns: + List of FunctionTool objects for LlamaIndex agents. + + Example: + ```python + from llama_index.llms.openai import OpenAI + from llama_index.core.agent import ReActAgent + + llm = OpenAI(model="gpt-4o") + async with env: + tools = env.as_llamaindex_tools() + agent = ReActAgent.from_tools(tools, llm=llm, verbose=True) + response = await agent.achat("Find information about Python") + ``` + """ + try: + from llama_index.core.tools import FunctionTool + except ImportError as e: + raise ImportError( + "LlamaIndex not installed. Install with: pip install llama-index-core" + ) from e + + tools = [] + for t in self.as_tools(): + sync_fn, async_fn = create_tool_fns(self, t) + + tool = FunctionTool.from_defaults( + fn=sync_fn, + async_fn=async_fn, + name=t.name, + description=t.description or "", + ) + tools.append(tool) + return tools diff --git a/hud/environment/utils/__init__.py b/hud/environment/utils/__init__.py index 81d9fc36..1e0318bd 100644 --- a/hud/environment/utils/__init__.py +++ b/hud/environment/utils/__init__.py @@ -12,9 +12,18 @@ json_type_to_python, schema_to_pydantic, ) +from hud.environment.utils.tool_wrappers import ( + create_async_tool_fn, + create_sync_tool_fn, + create_tool_fns, + stringify_result, +) __all__ = [ "ToolFormat", + "create_async_tool_fn", + "create_sync_tool_fn", + "create_tool_fns", "ensure_strict_schema", "format_result", "json_type_to_python", @@ -22,4 +31,5 @@ "parse_tool_calls", "result_to_string", "schema_to_pydantic", + "stringify_result", ] diff --git a/hud/environment/utils/tool_wrappers.py b/hud/environment/utils/tool_wrappers.py new file mode 100644 index 00000000..876c5632 --- /dev/null +++ b/hud/environment/utils/tool_wrappers.py @@ -0,0 +1,111 @@ +"""Shared tool wrapper utilities for agent framework integrations.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + import mcp.types as mcp_types + +__all__ = [ + "create_sync_tool_fn", + "create_async_tool_fn", + "stringify_result", + "create_tool_fns", +] + + +def stringify_result(result: Any) -> str: + """Convert a tool result to string format. + + Args: + result: The tool result (str, dict, or other). + + Returns: + String representation of the result. + """ + if isinstance(result, str): + return result + return json.dumps(result) if result else "" + + +def create_async_tool_fn( + env: Any, + tool_name: str, + description: str | None = None, +) -> Callable[..., Any]: + """Create an async function that calls a tool on the environment. + + Args: + env: Environment with call_tool method. + tool_name: Name of the tool to call. + description: Optional description for the function docstring. + + Returns: + Async function that calls the tool and returns string result. + """ + + async def async_fn(**kwargs: Any) -> str: + result = await env.call_tool(tool_name, **kwargs) + return stringify_result(result) + + async_fn.__name__ = tool_name + async_fn.__doc__ = description or f"Tool: {tool_name}" + return async_fn + + +def create_sync_tool_fn( + env: Any, + tool_name: str, + description: str | None = None, +) -> Callable[..., Any]: + """Create a sync function that calls a tool on the environment. + + This handles the complexity of running async code from sync context, + including when already in an async event loop. + + Args: + env: Environment with call_tool method. + tool_name: Name of the tool to call. + description: Optional description for the function docstring. 
+ + Returns: + Sync function that calls the tool and returns string result. + """ + import asyncio + + def sync_fn(**kwargs: Any) -> str: + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, env.call_tool(tool_name, **kwargs)) + result = future.result() + else: + result = loop.run_until_complete(env.call_tool(tool_name, **kwargs)) + + return stringify_result(result) + + sync_fn.__name__ = tool_name + sync_fn.__doc__ = description or f"Tool: {tool_name}" + return sync_fn + + +def create_tool_fns( + env: Any, + tool: mcp_types.Tool, +) -> tuple[Callable[..., str], Callable[..., Any]]: + """Create both sync and async functions for a tool. + + Args: + env: Environment with call_tool method. + tool: MCP tool definition. + + Returns: + Tuple of (sync_fn, async_fn). + """ + sync_fn = create_sync_tool_fn(env, tool.name, tool.description) + async_fn = create_async_tool_fn(env, tool.name, tool.description) + return sync_fn, async_fn From 9c4269ba757a7c23f44ee957a9a45665660c63d5 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 05:03:17 -0800 Subject: [PATCH 37/92] docs --- docs/guides/sandboxing.mdx | 2 +- docs/quick-links/environments.mdx | 2 +- docs/reference/environments.mdx | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/guides/sandboxing.mdx b/docs/guides/sandboxing.mdx index dd55b7fe..5a6ca742 100644 --- a/docs/guides/sandboxing.mdx +++ b/docs/guides/sandboxing.mdx @@ -1,7 +1,7 @@ --- title: "Sandboxing" description: "Turn your existing services into agent-testable environments" -icon: "cube" +icon: "shield" --- You have a production stack. You want an agent on it. But you can't just point an agent at production—it'll make real changes, hit real APIs, affect real users. And you can't test at scale against a single live instance with shared state. diff --git a/docs/quick-links/environments.mdx b/docs/quick-links/environments.mdx index 4ccb49ce..9a442f47 100644 --- a/docs/quick-links/environments.mdx +++ b/docs/quick-links/environments.mdx @@ -1,7 +1,7 @@ --- title: "Environments" description: "Turn your code into agent-callable tools. Define how agents are evaluated." -icon: "cube" +icon: "box" --- An environment is everything an agent can interact with—your APIs, services, databases, wrapped as tools. But it's more than that: the environment also defines how agents are *evaluated* through **scripts**. When you deploy an environment, you're creating a sandbox that agents can learn from at scale. diff --git a/docs/reference/environments.mdx b/docs/reference/environments.mdx index f9632318..8dfd5e81 100644 --- a/docs/reference/environments.mdx +++ b/docs/reference/environments.mdx @@ -1,7 +1,7 @@ --- -title: "Environments" +title: "Environment" description: "SDK reference for the Environment class - tools, connectors, and integrations" -icon: "cube" +icon: "desktop" --- `Environment` is the unified class for defining tools, connecting to services, and formatting for any LLM provider. 
From 6041aeed8368f034071fb733a38e6566fd669ba8 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 05:12:28 -0800 Subject: [PATCH 38/92] small docs updates --- docs/quick-links/deploy.mdx | 10 ++++++++++ docs/quick-links/environments.mdx | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/quick-links/deploy.mdx b/docs/quick-links/deploy.mdx index 4d5f137f..ba7ec3f7 100644 --- a/docs/quick-links/deploy.mdx +++ b/docs/quick-links/deploy.mdx @@ -54,3 +54,13 @@ With your environment deployed: - **Train**: Use runs as training data. Fine-tune on successful completions. Run reinforcement learning to optimize for your specific environment. The loop: deploy → eval at scale → analyze → train → redeploy. Agents get better at *your* environment. + + + + Connect OpenAI, Anthropic, LangChain, and more. + + + + Turn production services into safe test environments. + + diff --git a/docs/quick-links/environments.mdx b/docs/quick-links/environments.mdx index 9a442f47..4ccb49ce 100644 --- a/docs/quick-links/environments.mdx +++ b/docs/quick-links/environments.mdx @@ -1,7 +1,7 @@ --- title: "Environments" description: "Turn your code into agent-callable tools. Define how agents are evaluated." -icon: "box" +icon: "cube" --- An environment is everything an agent can interact with—your APIs, services, databases, wrapped as tools. But it's more than that: the environment also defines how agents are *evaluated* through **scripts**. When you deploy an environment, you're creating a sandbox that agents can learn from at scale. From 6468aa94b38c965933a00ca85a6bda6fa02141b5 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 23:46:57 -0800 Subject: [PATCH 39/92] updates to logic all round --- examples/01_agent_lifecycle.py | 206 ++-- examples/02_claude_agent.py | 84 +- examples/03_openai_compatible_agent.py | 37 +- examples/04_grounded_agent.py | 11 +- examples/05_custom_agent.py | 12 +- examples/README.md | 17 +- examples/integration_otel.py | 26 +- examples/run_evaluation.py | 65 +- hud/agents/base.py | 502 ++------ hud/agents/claude.py | 8 +- hud/agents/gemini.py | 8 +- hud/agents/grounded_openai.py | 21 +- hud/agents/misc/integration_test_agent.py | 71 +- hud/agents/openai.py | 5 +- hud/agents/tests/conftest.py | 114 +- hud/agents/tests/test_base.py | 828 ++++--------- hud/agents/tests/test_base_runtime.py | 224 ++-- hud/agents/tests/test_claude.py | 466 ++++--- hud/agents/tests/test_gemini.py | 539 ++------ hud/agents/tests/test_openai.py | 1084 ++++------------- hud/agents/tests/test_operator.py | 114 +- hud/agents/tests/test_run_eval.py | 129 +- hud/cli/eval.py | 157 +-- hud/cli/flows/tasks.py | 8 +- hud/cli/flows/templates.py | 8 +- hud/cli/rft.py | 4 +- hud/cli/tests/test_convert.py | 10 +- hud/cli/tests/test_eval.py | 642 +++------- hud/cli/utils/metadata.py | 10 +- hud/clients/base.py | 14 +- hud/clients/tests/test_analyze_scenarios.py | 2 +- hud/datasets/__init__.py | 10 +- hud/datasets/loader.py | 177 +++ hud/datasets/runner.py | 369 +----- hud/datasets/tests/test_loader.py | 196 +++ hud/datasets/tests/test_utils.py | 22 +- hud/datasets/utils.py | 32 +- hud/environment/__init__.py | 7 +- hud/environment/connectors/__init__.py | 7 +- hud/environment/connectors/remote.py | 52 +- hud/environment/connectors/task.py | 109 -- hud/environment/environment.py | 139 +-- hud/environment/{scripts.py => scenarios.py} | 228 ++-- hud/environment/tests/test_connectors.py | 49 - hud/environment/tests/test_environment.py | 25 - .../{test_scripts.py => test_scenarios.py} | 136 
+-- hud/environment/types.py | 35 +- hud/eval/__init__.py | 16 +- hud/eval/context.py | 141 +-- hud/eval/eval.py | 254 ---- hud/eval/manager.py | 426 ++----- hud/eval/task.py | 437 +++++++ hud/eval/tests/test_eval.py | 277 +++-- hud/eval/types.py | 3 - hud/samples/browser.py | 6 +- hud/server/server.py | 2 +- hud/tests/test_datasets_extended.py | 241 ++-- hud/tests/test_types.py | 38 +- hud/tools/grounding/grounded_tool.py | 31 +- hud/types.py | 41 +- hud/utils/tasks.py | 26 +- hud/utils/tests/test_tasks.py | 24 +- 62 files changed, 3234 insertions(+), 5748 deletions(-) create mode 100644 hud/datasets/loader.py create mode 100644 hud/datasets/tests/test_loader.py delete mode 100644 hud/environment/connectors/task.py rename hud/environment/{scripts.py => scenarios.py} (60%) rename hud/environment/tests/{test_scripts.py => test_scenarios.py} (64%) delete mode 100644 hud/eval/eval.py create mode 100644 hud/eval/task.py diff --git a/examples/01_agent_lifecycle.py b/examples/01_agent_lifecycle.py index c094a183..8d15e1ce 100644 --- a/examples/01_agent_lifecycle.py +++ b/examples/01_agent_lifecycle.py @@ -2,140 +2,100 @@ """ Complete Agent Lifecycle Example -This example demonstrates the full agent lifecycle: -- Task definition with setup and evaluation tools -- Agent initialization -- Setup phase -- Agent execution loop -- Tool call handling -- Evaluation phase -- Cleanup - -The entire flow is wrapped in hud.trace() to provide RUN_ID context. +This example demonstrates the full agent lifecycle using Task.from_v4(): +- Task definition with setup and evaluation tools (v4 LegacyTask format) +- Conversion to v5 Task using Task.from_v4() +- hud.eval() context for connection and tracing +- Agent initialization and execution +- Automatic setup/evaluate tool execution +- Result collection + +For simpler usage, just use `await agent.run(ctx)` which handles everything. +This example shows what happens under the hood. 
""" import asyncio import hud -from hud.datasets import Task -from hud.clients import MCPClient -from hud.agents.claude import ClaudeAgent -from hud.agents.base import find_reward, find_content +from hud.datasets import LegacyTask +from hud.eval.task import Task +from hud.agents import ClaudeAgent async def main(): - # Wrap everything in trace to provide RUN_ID for the task - with hud.trace("Agent Lifecycle Demo"): - # Define a complete task with setup and evaluation - task_dict = { - "prompt": "Create a new todo item with the title 'Buy groceries' and description 'Milk, eggs, bread'", - "mcp_config": { - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - "headers": { - "Authorization": "Bearer ${HUD_API_KEY}", # Automatically filled from env - "Mcp-Image": "hudevals/hud-browser:latest", - }, - } - }, - "setup_tool": {"name": "launch_app", "arguments": {"app_name": "todo"}}, - "evaluate_tool": { - "name": "evaluate", - "arguments": {"name": "todo_exists", "arguments": {"title": "Buy groceries"}}, - }, - } - task = Task(**task_dict) - - # Create MCP client with resolved config - client = MCPClient(mcp_config=task.mcp_config) - - # Create agent - agent = ClaudeAgent.create( - mcp_client=client, - checkpoint_name="claude-sonnet-4-5", - allowed_tools=["anthropic_computer"], - initial_screenshot=True, - ) - - try: - # Phase 1: Initialize agent with task context - print("🔧 Initializing agent...") - await agent.initialize(task) - - # Phase 2: Run setup tool - print("📋 Running setup...") - setup_result = await agent.call_tools(task.setup_tool) - setup_content = setup_result[0].content - print("✅ Setup complete") - - # Phase 3: Add context and first messages - print(f"\n🤖 Running task: {task.prompt}") - messages = await agent.get_system_messages() - - # Add context - context = await agent.format_message( - [ - *setup_content, - task.prompt, - ] - ) - - messages.extend(context) - print(f"Messages: {messages}") - - # Phase 4: Run agent loop - done = False - steps = 0 - max_steps = 10 - - # Use messages as the state for the agent - while not done and steps < max_steps: - # Get model response - response = await agent.get_response(messages) - print(f"\n Step {steps + 1}:") - - if response.content: - print(f" 💭 Agent: {response.content[:100]}...") - - if response.tool_calls: - # Execute tool calls - tool_results = await agent.call_tools(response.tool_calls) - - # Format results back into messages - messages.extend( - await agent.format_tool_results(response.tool_calls, tool_results) - ) - else: - # No more tool calls, we're done - done = True - - steps += 1 - - # Phase 4: Run evaluation - print("\n📊 Running evaluation...") - eval_result = await agent.call_tools(task.evaluate_tool) - - if eval_result[0].isError: - print(f"❌ Evaluation failed: {eval_result[0].content}") - else: - reward = find_reward(eval_result[0]) - eval_content = find_content(eval_result[0]) - print(f"✅ Evaluation complete - Reward: {reward}") - print(f"✅ Evaluation complete - Content: {eval_content}") - - # Summary - print("\n📈 Summary:") - print(f" Total steps: {steps}") - print(f" Task completed: {done}") + print("🚀 Agent Lifecycle Example") + print("=" * 50) - finally: - # Phase 5: Cleanup - print("\n🧹 Cleaning up...") - await client.shutdown() + # Phase 1: Define task in v4 LegacyTask format + # This format includes setup_tool and evaluate_tool + print("📋 Defining task...") + legacy_task = LegacyTask( + prompt="Create a new todo item with the title 'Buy groceries' and description 'Milk, eggs, bread'", + mcp_config={ + "hud": { + 
"url": "https://mcp.hud.ai/v3/mcp", + "headers": { + "Authorization": "Bearer ${HUD_API_KEY}", # Auto-resolved from env + "Mcp-Image": "hudevals/hud-browser:latest", + }, + } + }, + setup_tool={"name": "launch_app", "arguments": {"app_name": "todo"}}, + evaluate_tool={ + "name": "evaluate", + "arguments": {"name": "todo_exists", "arguments": {"title": "Buy groceries"}}, + }, + ) + + # Phase 2: Convert to v5 Task + # Task.from_v4() creates an Environment with: + # - mcp_config connection (connects on context entry) + # - setup_tool calls (run on context entry) + # - evaluate_tool calls (run on context exit) + print("🔄 Converting to v5 Task...") + task = Task.from_v4(legacy_task) + + # Phase 3: Create agent + print("🤖 Creating Claude agent...") + agent = ClaudeAgent.create( + checkpoint_name="claude-sonnet-4-5", + allowed_tools=["anthropic_computer"], + initial_screenshot=True, + ) + + # Phase 4: Enter eval context and run agent + # The context manager handles: + # - Environment connection (MCP servers start) + # - Setup tools execution (launch_app) + # - Trace creation for telemetry + print("🔧 Entering eval context...") + async with task as ctx: + print(f" ✅ Environment connected") + print(f" ✅ Setup tools executed") + print(f" 📝 Prompt: {ctx.prompt[:50]}...") + + # Phase 5: Run the agent + # agent.run() handles the agentic loop: + # - Gets system messages + # - Sends prompt to model + # - Processes tool calls + # - Continues until done or max_steps + print("\n🏃 Running agent loop...") + result = await agent.run(ctx, max_steps=10) + + print(f"\n Agent finished:") + print(f" - Done: {result.done}") + print(f" - Has error: {result.isError}") + if result.content: + print(f" - Response: {result.content[:100]}...") + + # Phase 6: After exit, evaluate_tool was automatically called + # and ctx.reward is set based on the evaluation + print("\n📊 Evaluation complete (via evaluate_tool)") + print(f" Reward: {ctx.reward}") + print(f" Success: {ctx.success}") print("\n✨ Agent lifecycle demo complete!") if __name__ == "__main__": - print("🚀 Agent Lifecycle Example") - print("=" * 50) asyncio.run(main()) diff --git a/examples/02_claude_agent.py b/examples/02_claude_agent.py index 0c5150be..9cf5b1b8 100644 --- a/examples/02_claude_agent.py +++ b/examples/02_claude_agent.py @@ -15,18 +15,36 @@ import asyncio import hud from hud.agents import ClaudeAgent -from hud.clients import MCPClient +from hud.datasets import LegacyTask +from hud.eval.task import Task from hud.settings import settings async def main(): - with hud.trace("Claude Agent Demo"): - # For any environment, you can run : - # hud debug to see the logs - # hud analyze to get a report about its capabilities (tools, resources, etc.) - # e.g. hud analyze hudpython/hud-remote-browser:latest + # For any environment, you can run : + # hud debug to see the logs + # hud analyze to get a report about its capabilities (tools, resources, etc.) + # e.g. hud analyze hudpython/hud-remote-browser:latest - mcp_config = { + initial_url = "https://httpbin.org/forms/post" + + prompt = f""" + Please help me test a web form: + 1. Navigate to {initial_url} + 2. Fill in the customer name as "Claude Test" + 3. Enter the telephone as "555-0123" + 4. Type "Testing form submission with Claude" in the comments + 5. Select a small pizza size + 6. Choose "bacon" as a topping + 7. Set delivery time to "20:30" + 8. Submit the form + 9. 
Verify the submission was successful + """ + + # Create LegacyTask with mcp_config and setup + legacy_task = LegacyTask( + prompt=prompt, + mcp_config={ "hud": { "url": "https://mcp.hud.ai/v3/mcp", "headers": { @@ -34,44 +52,32 @@ async def main(): "Mcp-Image": "hudpython/hud-remote-browser:latest", }, } - } - - # Create Claude-specific agent - client = MCPClient(mcp_config=mcp_config) - agent = ClaudeAgent.create( - mcp_client=client, - checkpoint_name="claude-sonnet-4-5", - allowed_tools=["anthropic_computer"], - initial_screenshot=True, - ) - - initial_url = "https://httpbin.org/forms/post" + }, + setup_tool={ + "name": "setup", + "arguments": {"name": "navigate_to_url", "arguments": {"url": initial_url}}, + }, + ) - prompt = f""" - Please help me test a web form: - 1. Navigate to {initial_url} - 2. Fill in the customer name as "Claude Test" - 3. Enter the telephone as "555-0123" - 4. Type "Testing form submission with Claude" in the comments - 5. Select a small pizza size - 6. Choose "bacon" as a topping - 7. Set delivery time to "20:30" - 8. Submit the form - 9. Verify the submission was successful - """ + # Convert to v5 Task + task = Task.from_v4(legacy_task) - print(f"📋 Task: Multi-step form interaction") - print(f"🚀 Running Claude agent...\n") + # Create Claude-specific agent + agent = ClaudeAgent.create( + checkpoint_name="claude-sonnet-4-5", + allowed_tools=["anthropic_computer"], + initial_screenshot=True, + ) - await client.call_tool( - name="setup", - arguments={"name": "navigate_to_url", "arguments": {"url": initial_url}}, - ) + print(f"📋 Task: Multi-step form interaction") + print(f"🚀 Running Claude agent...\n") - # Run the task - await agent.run(prompt, max_steps=15) + # Run with hud.eval() context + async with task as ctx: + result = await agent.run(ctx, max_steps=15) print("\n✨ Claude agent demo complete!") + print(f" Reward: {result.reward}") if __name__ == "__main__": diff --git a/examples/03_openai_compatible_agent.py b/examples/03_openai_compatible_agent.py index 5e1fbd5f..51578398 100644 --- a/examples/03_openai_compatible_agent.py +++ b/examples/03_openai_compatible_agent.py @@ -24,7 +24,7 @@ import hud from hud.agents.openai_chat import OpenAIChatAgent -from hud.datasets import Task +from hud.datasets import LegacyTask def _system_prompt(mode: Literal["text", "browser"]) -> str: @@ -66,7 +66,7 @@ def _system_prompt(mode: Literal["text", "browser"]) -> str: ) -def _task_for_mode(mode: Literal["text", "browser"], target: int) -> Task: +def _task_for_mode(mode: Literal["text", "browser"], target: int) -> LegacyTask: if mode == "browser": mcp_config = { "local": { @@ -100,7 +100,7 @@ def _task_for_mode(mode: Literal["text", "browser"], target: int) -> Task: "arguments": {"name": "max_number", "arguments": {"target": target}}, } - return Task( + return LegacyTask( prompt=prompt, mcp_config=mcp_config, setup_tool=setup_tool, # type: ignore[arg-type] @@ -136,20 +136,23 @@ async def run_example(mode: Literal["text", "browser"], target: int) -> None: ) title = "OpenAI 2048 Game (Browser)" if mode == "browser" else "OpenAI 2048 Game (Text)" - async with hud.async_job(title, metadata={"model": checkpoint, "mode": mode}) as job: - print("🎮 Starting 2048 game with OpenAI-compatible agent...") - print(f"🤖 Model: {agent.config.checkpoint_name}") - print(f"🧩 Mode: {mode}") - print("=" * 50) - - async with hud.async_trace("Game Execution", job_id=job.id): - result = await agent.run(task, max_steps=100) - - print("=" * 50) - print("✅ Game completed!") - print(f"🏆 Final Score/Max 
Tile: {result.reward}") - if result.info: - print(f"📊 Game Stats: {result.info}") + print("🎮 Starting 2048 game with OpenAI-compatible agent...") + print(f"🤖 Model: {agent.config.checkpoint_name}") + print(f"🧩 Mode: {mode}") + print("=" * 50) + + # Use hud.eval() with Task.from_v4() for legacy task format + from hud.eval.task import Task + + v5_task = Task.from_v4(task) + async with hud.eval(v5_task, variants={"model": checkpoint, "mode": mode}) as ctx: + result = await agent.run(ctx, max_steps=100) + + print("=" * 50) + print("✅ Game completed!") + print(f"🏆 Final Score/Max Tile: {result.reward}") + if result.info: + print(f"📊 Game Stats: {result.info}") def _parse_args() -> argparse.Namespace: diff --git a/examples/04_grounded_agent.py b/examples/04_grounded_agent.py index e2a31685..636baaf6 100644 --- a/examples/04_grounded_agent.py +++ b/examples/04_grounded_agent.py @@ -54,7 +54,7 @@ async def main(): try: # Create a task with MCP config - from hud.datasets import Task + from hud.datasets import LegacyTask form_url = "https://hb.cran.dev/forms/post" @@ -68,7 +68,7 @@ async def main(): 6. Submit the form """ - task = Task( + legacy_task = LegacyTask( prompt=form_prompt, mcp_config=mcp_config, setup_tool={ @@ -80,7 +80,12 @@ async def main(): print(f"📋 Task: Form interaction") print(f"🚀 Running grounded agent...\n") - result = await agent.run(task, max_steps=10) + # Convert LegacyTask to Task and run with hud.eval() + from hud.eval.task import Task + + task = Task.from_v4(legacy_task) + async with task as ctx: + result = await agent.run(ctx, max_steps=10) print(f"Result: {result.content}\n") except Exception as e: diff --git a/examples/05_custom_agent.py b/examples/05_custom_agent.py index 262e7e78..094cb0f3 100644 --- a/examples/05_custom_agent.py +++ b/examples/05_custom_agent.py @@ -20,7 +20,7 @@ from hud import instrument from hud.agents.base import MCPAgent -from hud.datasets import Task +from hud.datasets import LegacyTask from hud.settings import settings from hud.types import AgentResponse, MCPToolCall, MCPToolResult @@ -207,7 +207,7 @@ async def main(): ) # Define a task with HUD MCP environment - task = Task( + legacy_task = LegacyTask( prompt="Go to example.com and tell me the page title", mcp_config={ "hud": { @@ -220,9 +220,15 @@ async def main(): }, ) + # Convert to v5 Task and run with context manager + from hud.eval.task import Task + + task = Task.from_v4(legacy_task) + # Run the agent - traces are automatically captured print("Running agent with HUD Gateway inference...") - result = await agent.run(task, max_steps=5) + async with task as ctx: + result = await agent.run(ctx, max_steps=5) print("\n=== Results ===") print(f"Done: {result.done}") diff --git a/examples/README.md b/examples/README.md index c7997cb6..303ca16c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -35,21 +35,24 @@ python examples/03_browser_agent_loop.py --app todo ## Core Patterns -### 02_agent_lifecycle.py +### 01_agent_lifecycle.py Demonstrates the full agent lifecycle with telemetry and state management. -- Task creation and configuration +- Task creation using LegacyTask format - Trace context for debugging -- State persistence between runs +- Setup and evaluation tool calls ### run_evaluation.py -Generic dataset evaluation runner supporting multiple agents. +Generic dataset evaluation runner using the programmatic API. 
```bash -# Run single task +# Run all tasks python examples/run_evaluation.py hud-evals/SheetBench-50 -# Run full dataset -python examples/run_evaluation.py hud-evals/SheetBench-50 --full +# Run specific tasks by index +python examples/run_evaluation.py hud-evals/SheetBench-50 --task-ids 0 1 2 + +# Use different agent +python examples/run_evaluation.py hud-evals/OSWorld-Verified-Gold --agent operator ``` ## Integration Examples diff --git a/examples/integration_otel.py b/examples/integration_otel.py index 4fe4ec7e..9644b67d 100644 --- a/examples/integration_otel.py +++ b/examples/integration_otel.py @@ -39,7 +39,7 @@ import hud from hud.agents import ClaudeAgent from hud.clients import MCPClient -from hud.datasets import Task +from hud.datasets import LegacyTask async def main(): @@ -59,23 +59,27 @@ async def main(): "arguments": {"name": "max_number"}, }, } - task = Task(**task_dict) + task = LegacyTask(**task_dict) # Create client and agent - mcp_client = MCPClient(mcp_config=task.mcp_config) # Create agent - its methods are already instrumented with @hud.instrument - agent = ClaudeAgent.create( - mcp_client=mcp_client, - ) + agent = ClaudeAgent.create() - # Run with hud.trace() - this creates the root span in Jaeger + # Convert to v5 Task and run with hud.eval() + from hud.eval.task import Task + + v5_task = Task.from_v4(task) + + # Run with hud.trace() and hud.eval() - this creates spans in Jaeger with hud.trace("play_2048_game"): print(f"🎮 Starting 2048 game") - # Agent will play the game with setup and evaluate phases - # Each call to get_model_response() and execute_tools() - # will create child spans in Jaeger automatically - result = await agent.run(task, max_steps=20) + # Use Task as context manager to get EvalContext + async with v5_task as ctx: + # Agent will play the game with setup and evaluate phases + # Each call to get_model_response() and execute_tools() + # will create child spans in Jaeger automatically + result = await agent.run(ctx, max_steps=20) print(f"\n🏁 Game finished!") print(f" Final reward: {result.reward}") diff --git a/examples/run_evaluation.py b/examples/run_evaluation.py index 7e200001..1171f4f1 100644 --- a/examples/run_evaluation.py +++ b/examples/run_evaluation.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Example: Running evaluations programmatically with run_tasks. +"""Example: Running evaluations programmatically with run_dataset. For CLI usage, prefer `hud eval` which handles config files, interactive agent selection, and more. This example shows the programmatic API. 
@@ -14,61 +14,62 @@ import argparse import asyncio -from typing import Any, cast - -from datasets import load_dataset - -from hud.datasets import run_tasks, display_results -from hud.types import AgentType, Task async def main() -> None: parser = argparse.ArgumentParser(description="Run evaluation on a HUD dataset") - parser.add_argument("dataset", help="HuggingFace dataset ID (e.g., hud-evals/SheetBench-50)") + parser.add_argument("dataset", help="Dataset source (e.g., hud-evals/SheetBench-50)") parser.add_argument("--agent", choices=["claude", "operator"], default="claude") parser.add_argument("--model", default=None, help="Model name override") parser.add_argument("--max-concurrent", type=int, default=30, help="Max concurrent tasks") parser.add_argument("--max-steps", type=int, default=50, help="Max steps per task") parser.add_argument("--group-size", type=int, default=1, help="Runs per task (for variance)") - parser.add_argument("--task-ids", nargs="*", help="Specific task IDs to run (optional)") + parser.add_argument("--task-ids", nargs="*", help="Specific task indices to run (optional)") args = parser.parse_args() - # Load dataset and convert to Task objects + # Import here to avoid import errors if agents not installed + from hud.datasets import load_dataset, run_dataset, display_results + + # Load dataset as Task objects print(f"Loading {args.dataset}...") - raw_dataset = load_dataset(args.dataset, split="train") - tasks = [Task(**cast("dict[str, Any]", row)) for row in raw_dataset] + tasks = load_dataset(args.dataset) - # Filter by task IDs if specified + # Filter by index if specified if args.task_ids: - tasks = [t for t in tasks if t.id in args.task_ids] - print(f"Filtered to {len(tasks)} tasks: {args.task_ids}") + indices = [int(tid) for tid in args.task_ids] + tasks = [tasks[i] for i in indices if i < len(tasks)] + print(f"Filtered to {len(tasks)} tasks at indices: {args.task_ids}") - # Select agent type and params + # Create agent instance based on type if args.agent == "operator": - agent_type = AgentType.OPERATOR - agent_params = { - "checkpoint_name": args.model or "computer-use-preview", - "validate_api_key": False, - } + from hud.agents import OperatorAgent + + agent = OperatorAgent.create( + checkpoint_name=args.model or "computer-use-preview", + ) else: - agent_type = AgentType.CLAUDE - agent_params = { - "checkpoint_name": args.model or "claude-sonnet-4-5", - "validate_api_key": False, - } + from hud.agents import ClaudeAgent + + agent = ClaudeAgent.create( + checkpoint_name=args.model or "claude-sonnet-4-5", + ) # Run evaluation - results = await run_tasks( + print(f"Running {len(tasks)} tasks with {args.agent} agent...") + results = await run_dataset( tasks=tasks, - agent_type=agent_type, - agent_params=agent_params, - name=f"Eval: {args.dataset.split('/')[-1]}", - max_concurrent=args.max_concurrent, + agent=agent, max_steps=args.max_steps, + max_concurrent=args.max_concurrent, group_size=args.group_size, ) - display_results(results, tasks=tasks) + # Display results + print(f"\n{'='*50}") + print(f"Completed {len(results)} tasks") + for i, ctx in enumerate(results): + reward = ctx.reward if hasattr(ctx, "reward") else "N/A" + print(f" Task {i}: reward={reward}") if __name__ == "__main__": diff --git a/hud/agents/base.py b/hud/agents/base.py index 803d3380..8c03a156 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -3,7 +3,6 @@ from __future__ import annotations import asyncio -import fnmatch import json import logging from abc import ABC, abstractmethod @@ 
-12,13 +11,10 @@ import mcp.types as types from pydantic import BaseModel, ConfigDict -from hud.clients.base import AgentMCPClient from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult, Trace from hud.utils.hud_console import HUDConsole -from hud.utils.mcp import MCPConfigPatch, patch_mcp_config, setup_hud_telemetry if TYPE_CHECKING: - from hud.datasets import Task from hud.environment import Environment from hud.eval.context import EvalContext @@ -31,7 +27,9 @@ class BaseCreateParams(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - mcp_client: AgentMCPClient | None = None + # Primary way to bind agent to execution context (v5) + ctx: Any | None = None # EvalContext or Environment - agent uses this for tool calls + auto_trace: bool = True auto_respond: bool = False verbose: bool = False @@ -41,19 +39,14 @@ class MCPAgent(ABC): """ Base class for MCP-enabled agents. - Provides common behavior for agents that interact with MCP servers, including: - - Client management: accepts an `AgentMCPClient` or auto-creates one at - runtime when `run()` is called with a `Task` that includes `mcp_config`. - - Tool lifecycle: discovery, filtering (`allowed_tools`, `disallowed_tools`), - and automatic marking of lifecycle tools (setup/evaluate) from a `Task`. - - Messaging: system prompt handling, optional inclusion of setup output on - the first turn, and control over initial screenshots. - - Telemetry & UX: standardized logging/printing via `HUDConsole` and optional - automatic tracing (`auto_trace`). + Agents interact with MCP servers through an EvalContext: + - run(ctx): Main entry point - takes EvalContext from hud.eval() + - ctx.call_tool(): Used internally for all tool execution + - ctx.submit(): Called automatically with agent's final response Subclasses implement provider-specific formatting and response fetching - by overriding these abstract methods: `get_system_messages`, `get_response`, - `format_blocks`, and `format_tool_results`. + by overriding: `get_system_messages`, `get_response`, `format_blocks`, + and `format_tool_results`. 
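    A minimal subclass sketch, assuming only the abstract methods above need to be
    overridden (provider-specific configuration is omitted):

    ```python
    from typing import Any, ClassVar

    import mcp.types as types

    from hud.agents import MCPAgent
    from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult


    class EchoAgent(MCPAgent):
        """Toy agent that answers immediately without calling any tools."""

        config_cls: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig

        async def get_system_messages(self) -> list[Any]:
            return []

        async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]:
            # Keep only the text portions of the prompt blocks.
            return [{"type": "text", "text": getattr(block, "text", "")} for block in blocks]

        async def get_response(self, messages: list[Any]) -> AgentResponse:
            return AgentResponse(content="echo", tool_calls=[], done=True)

        async def format_tool_results(
            self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
        ) -> list[Any]:
            return [
                {"role": "tool", "name": call.name, "content": str(result)}
                for call, result in zip(tool_calls, tool_results, strict=True)
            ]
    ```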
""" metadata: ClassVar[dict[str, Any] | None] = None @@ -82,7 +75,9 @@ def __init__(self, params: BaseCreateParams | None = None, **kwargs: Any) -> Non } self.config = self.config_cls(**config_kwargs) - self.mcp_client = params.mcp_client + # v5: Store execution context (EvalContext/Environment) - agent uses ctx.call_tool() + self.ctx: EvalContext | Environment | None = params.ctx + self.model_name: str = getattr(params, "model_name", "MCPAgent") self.checkpoint_name: str = getattr(params, "checkpoint_name", "unknown") self.auto_respond = params.auto_respond @@ -92,15 +87,11 @@ def __init__(self, params: BaseCreateParams | None = None, **kwargs: Any) -> Non if params.verbose: self.console.set_verbose(True) - self.allowed_tools = self.config.allowed_tools - self.disallowed_tools = self.config.disallowed_tools self.system_prompt = self.config.system_prompt - self.append_setup_output = self.config.append_setup_output - self.initial_screenshot = self.config.initial_screenshot - self.response_tool_name = self.config.response_tool_name self._available_tools: list[types.Tool] | None = None self._tool_map: dict[str, types.Tool] = {} + self._initialized: bool = False # Trace self._auto_trace = params.auto_trace @@ -118,85 +109,20 @@ def create(cls, **kwargs: Any) -> MCPAgent: ) return cls(params=CreateParams(**kwargs)) - async def initialize(self, task: str | Task | None = None) -> None: - """Initialize the agent with task-specific configuration.""" - from hud.datasets import Task - - # Create client if needed - if self.mcp_client is None and isinstance(task, Task) and task.mcp_config: - from hud.clients import MCPClient + async def _initialize_from_ctx(self, ctx: EvalContext) -> None: + """Initialize agent from EvalContext - discovers tools and sets up state. - self.mcp_client = MCPClient(mcp_config=task.mcp_config) - self.console.debug("Auto-created MCPClient from task.mcp_config") + This is the v5 initialization path. The agent uses ctx.call_tool() directly + for tool execution (no EnvironmentClient wrapper needed). + """ + from hud.eval.context import EvalContext - # Ensure we have a client - if self.mcp_client is None: - raise ValueError( - "No MCPClient. Please provide one when initializing the agent or pass a Task with mcp_config." 
# noqa: E501 - ) + if not isinstance(ctx, EvalContext): + raise TypeError(f"ctx must be EvalContext, got {type(ctx).__name__}") - try: - client_cfg = getattr(self.mcp_client, "mcp_config", None) - except Exception: - client_cfg = None - await self._setup_config(client_cfg) - - # Initialize client if needed - try: - await self.mcp_client.initialize() - except Exception as e: - self.console.error_log(f"Failed to initialize MCP client: {e}") - self._handle_connection_error(e) - - # If task is provided, apply agent_config and add lifecycle tools - if isinstance(task, Task) and task.agent_config: - agent_cfg = task.agent_config - if agent_cfg.system_prompt: - if self.system_prompt is None: - self.system_prompt = agent_cfg.system_prompt - else: - self.system_prompt += "\n\n" + agent_cfg.system_prompt - if "append_setup_output" in agent_cfg.model_fields_set: - self.append_setup_output = agent_cfg.append_setup_output - if "initial_screenshot" in agent_cfg.model_fields_set: - self.initial_screenshot = agent_cfg.initial_screenshot - if agent_cfg.allowed_tools is not None: - # If allowed_tools has already been set, we take the intersection of the two - # If the list had been empty, we were allowing all tools, so we overwrite this - if isinstance(self.allowed_tools, list) and len(self.allowed_tools) > 0: - # If task allows "*", keep CLI's allowed_tools unchanged - if "*" not in agent_cfg.allowed_tools: - self.allowed_tools = [ - tool for tool in self.allowed_tools if tool in agent_cfg.allowed_tools - ] - # else: task allows all tools, so CLI's allowed_tools takes precedence - else: # If allowed_tools is None, we overwrite it - self.allowed_tools = agent_cfg.allowed_tools - if agent_cfg.disallowed_tools is not None: - # If disallowed_tools has already been set, we take the union of the two - if isinstance(self.disallowed_tools, list): - self.disallowed_tools.extend(agent_cfg.disallowed_tools) - else: # If disallowed_tools is None, we overwrite it - self.disallowed_tools = agent_cfg.disallowed_tools - if agent_cfg.response_tool_name is not None: - self.response_tool_name = agent_cfg.response_tool_name - - all_tools = await self.mcp_client.list_tools() - self._available_tools = [] - - # Filter tools based on allowed and disallowed patterns - # No allowed tools and no disallowed tools -> we accept all tools - # No allowed tools and disallowed tools -> we accept all tools except the disallowed ones - for tool in all_tools: - if self.allowed_tools is not None and not any( - fnmatch.fnmatch(tool.name, pattern) for pattern in self.allowed_tools - ): - continue - if self.disallowed_tools is not None and any( - fnmatch.fnmatch(tool.name, pattern) for pattern in self.disallowed_tools - ): - continue - self._available_tools.append(tool) + # Get tools from the context (tool filtering is done at Environment/Task level) + self._available_tools = await ctx.list_tools() + self._tool_map = {t.name: t for t in self._available_tools} # Validate required tools are present available_tool_names = {t.name for t in self._available_tools} @@ -208,265 +134,88 @@ async def initialize(self, task: str | Task | None = None) -> None: ) self.console.info( - f"Agent initialized with {len(self.get_available_tools())} tools: {', '.join([t.name for t in self.get_available_tools()])}" # noqa: E501 + f"Agent initialized with {len(self._available_tools)} tools: " + f"{', '.join([t.name for t in self._available_tools])}" ) - async def run( - self, - prompt_or_task: str | Task | EvalContext | Environment | dict[str, Any], - max_steps: int = 
10, - ) -> Trace: - """ - Run the agent with the given prompt, task, or environment. - - Args: - prompt_or_task: One of: - - str: Simple text prompt - - Task: Task object with mcp_config, setup_tool, evaluate_tool - - EvalContext: From hud.eval() - uses ctx.prompt, ctx.call_tool, ctx.submit - - Environment: Connected environment to use for tool calls - - dict: Task-like dict (converted to Task) - max_steps: Maximum number of steps (-1 for infinite) - - Returns: - Trace with reward, done, content, isError fields and trace steps - - Example: - # With EvalContext from hud.eval - async with hud.eval(evals) as ctx: - result = await agent.run(ctx) - # result.reward comes from script evaluate - """ - # Import here to avoid circular imports - from hud.datasets import Task - from hud.environment import Environment - from hud.eval.context import EvalContext - - # Handle EvalContext - delegate to run_eval - if isinstance(prompt_or_task, EvalContext): - return await self.run_eval(prompt_or_task, max_steps=max_steps) - - # Handle Environment (non-eval) - wrap with EnvironmentClient - if isinstance(prompt_or_task, Environment) and not isinstance(prompt_or_task, EvalContext): - from hud.clients.environment import EnvironmentClient - - env = prompt_or_task - if not env.prompt: - raise ValueError("Environment.prompt is not set") - - client = EnvironmentClient(env) - self.mcp_client = client - - try: - await self.initialize(env.prompt) - result = await self._run_context(text_to_blocks(env.prompt), max_steps=max_steps) - return result - finally: - self.mcp_client = None - - if isinstance(prompt_or_task, dict): - prompt_or_task = Task(**prompt_or_task) - elif not isinstance(prompt_or_task, str) and not isinstance(prompt_or_task, Task): - raise TypeError( - f"prompt_or_task must be str, Task, EvalContext, or Environment, " - f"got {type(prompt_or_task)}" - ) + # Call hook for subclass-specific initialization (e.g., tool format conversion) + self._on_tools_ready() - try: - # Establish the connection with the MCP server/Environment - await self.initialize(prompt_or_task) - - # Handle Task objects with full lifecycle - if isinstance(prompt_or_task, Task): - return await self.run_task(prompt_or_task, max_steps) - - # Handle simple string prompts - elif isinstance(prompt_or_task, str): - context = text_to_blocks(prompt_or_task) - return await self._run_context(context, max_steps=max_steps) - - except Exception as e: - logger.exception("Error while running agent:") - # Always return a Trace object for any exception - if self._is_connection_error(e): - # Return error trace for connection failures - return Trace( - reward=0.0, - done=True, - content=self._get_connection_error_message(e), - isError=True, - ) - else: - # Return error trace for any other exception - return Trace( - reward=0.0, - done=True, - content=f"Task failed with error: {e}", - isError=True, - info={"error": str(e)}, - ) - finally: - # Cleanup auto-created resources - await self._cleanup() + self._initialized = True - async def run_task(self, task: Task, max_steps: int = 10) -> Trace: - """ - Execute a task with setup and evaluate phases. + def _on_tools_ready(self) -> None: + """Hook called after tools are discovered and validated. - Args: - task: Task object with prompt, setup, and evaluate configs - max_steps: Maximum steps for task execution (-1 for infinite) + Subclasses can override this to perform provider-specific setup, + such as converting MCP tools to the provider's format. 
- Returns: - Trace with reward from evaluation + Called by _initialize_from_ctx() after _available_tools is populated. """ - try: - # Setup phase - start_context: list[types.ContentBlock] = [] - - # Extract the initial task information - if task.prompt: - start_context.extend(text_to_blocks(task.prompt)) - - # Execute the setup tool and append the initial observation to the context - if task.setup_tool is not None: - self.console.progress_log(f"Setting up tool phase: {task.setup_tool}") - results = await self.call_tools(task.setup_tool) - if any(result.isError for result in results): - for result in results: - if result.isError: - self.console.error_log(f"Error in setup tool: {result}") - - return Trace( - reward=0.0, - done=True, - content=f"Setup tool failed: {results}", - isError=True, - task=task, - ) - - if self.append_setup_output and isinstance(results[0].content, list): - start_context.extend(results[0].content) - if not self.initial_screenshot: - start_context = await self._filter_messages(start_context, include_types=["text"]) - - # Execute the task (agent loop) - this returns a empty trace object with the final response # noqa: E501 - prompt_result = await self._run_context(start_context, max_steps=max_steps) - - except Exception as e: - self.console.error_log(f"Task execution failed: {e}") - # Create an error result but don't return yet - we still want to evaluate - prompt_result = Trace(reward=0.0, done=True, content=str(e), isError=True, task=task) - prompt_result.populate_from_context() - - # Always evaluate if we have evaluate tool, regardless of errors - if task.evaluate_tool is not None: - try: - results = await self.call_tools(task.evaluate_tool) - - if any(result.isError for result in results): - self.console.warning_log(f"Evaluate tool returned error: {results}") - # Still extract what we can from the error response - if prompt_result is None: - prompt_result = Trace( - reward=0.0, - done=True, - content="Task failed before evaluation", - isError=True, - task=task, - ) - prompt_result.reward = 0.0 # Default to 0 on error - else: - # Extract reward and content from evaluation - if results: - reward = find_reward(results[0]) - self.console.info_log(f"Eval: {reward:.4f} {task.evaluate_tool}") - eval_content = find_content(results[0]) - - # Update the prompt result with evaluation reward - if prompt_result is None: - prompt_result = Trace( - reward=reward, - done=True, - content=eval_content or "", - isError=False, - task=task, - ) - else: - prompt_result.reward = reward - - # Update the prompt result with evaluation content (if available) - if eval_content: - # Prompt result may already have final response content, - # so we append to it - if prompt_result.content: - prompt_result.content += "\n\n" + eval_content - else: - prompt_result.content = eval_content + return # Default no-op - subclasses override for provider-specific setup - except Exception as e: - self.console.error_log(f"Evaluation phase failed: {e}") - # Ensure we have a result even if evaluation failed - if prompt_result is None: - prompt_result = Trace( - reward=0.0, - done=True, - content=f"Evaluation failed: {e}", - isError=True, - task=task, - ) - - prompt_result.task = task - - return prompt_result - - async def run_eval(self, ctx: EvalContext, *, max_steps: int = 10) -> Trace: + async def run( + self, + ctx: EvalContext, + *, + max_steps: int = 10, + ) -> Trace: """ - Run the agent with an EvalContext from hud.eval(). + Run the agent on the given evaluation context. 
- This method integrates with the hud.eval framework: - - Uses ctx.prompt as the starting prompt - - Uses ctx for tool calls via EnvironmentClient adapter - - Calls ctx.submit(response) when the agent finishes - - Reward is available on ctx.reward after the hud.eval block exits + The agent uses ctx.prompt as the task and ctx.call_tool() for tool execution. + Automatically calls ctx.submit() with the final answer. Args: - ctx: EvalContext from hud.eval() - already connected and has prompt set + ctx: EvalContext from hud.eval() - contains prompt and tools max_steps: Maximum number of agent steps (-1 for infinite) Returns: - Trace with agent output. Note: ctx.reward is set by script evaluate - phase which runs when the hud.eval block exits. + Trace with done, content, isError fields Example: ```python - async with hud.eval(evals) as ctx: - result = await agent.run_eval(ctx) - # ctx.reward is now set by the script's evaluate phase - print(f"Reward: {ctx.reward}") + async with hud.eval(task) as ctx: + agent = ClaudeAgent.create() + await agent.run(ctx) + # ctx.reward is set by the scenario's evaluate phase ``` """ - from hud.clients.environment import EnvironmentClient from hud.eval.context import EvalContext if not isinstance(ctx, EvalContext): - raise TypeError(f"ctx must be EvalContext, got {type(ctx)}") + raise TypeError(f"ctx must be EvalContext, got {type(ctx).__name__}") if not ctx.prompt: - raise ValueError("EvalContext.prompt is not set - did the script setup run?") + raise ValueError("ctx.prompt is not set - did the scenario setup run?") + + # Store context for tool calls + self.ctx = ctx + + # Initialize tools from context + if not self._initialized: + await self._initialize_from_ctx(ctx) - self.mcp_client = EnvironmentClient(ctx) try: - await self.initialize(ctx.prompt) result = await self._run_context(text_to_blocks(ctx.prompt), max_steps=max_steps) + + # Submit final answer to context if result.content: await ctx.submit(result.content) + return result + except Exception as e: - logger.exception("Error running agent with EvalContext:") - return Trace(reward=0.0, done=True, content=str(e), isError=True) + logger.exception("Error while running agent:") + return Trace( + reward=0.0, + done=True, + content=f"Agent failed with error: {e}", + isError=True, + info={"error": str(e)}, + ) finally: - self.mcp_client = None + # Cleanup auto-created resources + await self._cleanup() async def _run_context( self, context: list[types.ContentBlock], *, max_steps: int = 10 @@ -521,9 +270,6 @@ async def _run_context( except Exception as e: self.console.warning_log(f"Auto-respond failed: {e}") if decision == "STOP": - # Try to submit response through lifecycle tool - await self._maybe_submit_response(response, messages) - self.console.debug("Stopping execution") final_response = response break @@ -595,7 +341,7 @@ async def call_tools( self, tool_call: MCPToolCall | list[MCPToolCall] | None = None ) -> list[MCPToolResult]: """ - Call a tool through the MCP client. + Call tools through the bound EvalContext. 
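        For example (a fragment meant to run inside an agent method, after
        `run(ctx)` has bound `self.ctx`; the `computer` tool name is illustrative):

        ```python
        call = MCPToolCall(name="computer", arguments={"action": "screenshot"})
        results = await self.call_tools(call)
        if results[0].isError:
            self.console.error_log("Screenshot tool call failed")
        ```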
Args: tool_call: MCPToolCall or list of MCPToolCall @@ -609,20 +355,17 @@ async def call_tools( if isinstance(tool_call, MCPToolCall): tool_call = [tool_call] - if self.mcp_client is None: - raise ValueError("Client is not initialized") + if self.ctx is None: + raise ValueError("Agent not bound to context - call run(ctx) first") results: list[MCPToolResult] = [] for tc in tool_call: try: self.console.debug(f"Calling tool: {tc}") - results.append(await self.mcp_client.call_tool(tc)) + result = await self.ctx.call_tool(tc) + results.append(MCPToolResult(content=result.content, isError=result.isError)) except TimeoutError as e: self.console.error_log(f"Tool execution timed out: {e}") - try: - await self.mcp_client.shutdown() - except Exception as close_err: - self.console.debug(f"Failed to close MCP client cleanly: {close_err}") raise except Exception as e: self.console.error_log(f"Tool execution failed: {e}") @@ -702,45 +445,6 @@ async def format_message( return await self.format_blocks(blocks) - async def _maybe_submit_response(self, response: AgentResponse, messages: list[Any]) -> None: - """Submit response through lifecycle tool if available. - - Args: - response: The agent's response - messages: The current message history (will be modified in-place) - """ - if self.response_tool_name: - self.console.debug(f"Calling response lifecycle tool: {self.response_tool_name}") - try: - # Call the response tool with the agent's response - response_tool_call = MCPToolCall( - name=self.response_tool_name, arguments={"response": response.content} - ) - response_results = await self.call_tools(response_tool_call) - - # Format and add the response tool results to messages - response_messages = await self.format_tool_results( - [response_tool_call], response_results - ) - messages.extend(response_messages) - - # Mark the task as done - self.console.debug("Response lifecycle tool executed, marking task as done") - except Exception as e: - self.console.error_log(f"Response lifecycle tool failed: {e}") - - async def _setup_config(self, mcp_config: dict[str, dict[str, Any]] | None) -> None: - """Inject metadata into the metadata of the initialize request.""" - if not isinstance(mcp_config, dict): - return - - if self.metadata: - patch_mcp_config( - mcp_config, - MCPConfigPatch(meta=self.metadata), - ) - self._auto_trace_cm = setup_hud_telemetry(mcp_config, auto_trace=self._auto_trace) - def get_available_tools(self) -> list[types.Tool]: """Get list of available MCP tools for LLM use (excludes lifecycle tools).""" if self._available_tools is None: @@ -793,54 +497,8 @@ async def _cleanup(self) -> None: finally: self._auto_trace_cm = None - # Always clean up the client - if self.mcp_client: - try: - await self.mcp_client.shutdown() - self.console.debug("Closed auto-created MCPClient") - except Exception as e: - self.console.warning_log(f"Failed to close auto-created client: {e}") - finally: - self.mcp_client = None - - def _is_connection_error(self, e: Exception) -> bool: - """Check if an exception is a connection error.""" - error_msg = str(e).lower() - return any( - pattern in error_msg - for pattern in [ - "connection", - "connect", - "refused", - "failed", - "could not connect", - "mcp server", - ] - ) - - def _get_connection_error_message(self, e: Exception) -> str: - """Extract a helpful connection error message.""" - import re - - url_match = re.search(r"https?://[^\s]+", str(e)) - url = url_match.group(0) if url_match else "the MCP server" - return f"Connection failed: Could not connect to {url}. 
Is your MCP client/server running?" - - def _handle_connection_error(self, e: Exception) -> None: - """Handle connection errors with helpful messages.""" - if self._is_connection_error(e): - msg = self._get_connection_error_message(e) - # Always show connection errors, not just when logging is enabled - self.console.error(f"❌ {msg}") - self.console.info("💡 Make sure the MCP server is started before running the agent.") - - # For localhost, provide specific instructions - error_str = str(e).lower() - if "localhost" in error_str or "127.0.0.1" in error_str: - self.console.info(" Run 'hud dev' in another terminal to start the MCP server") - - raise RuntimeError(msg) from e - raise + # Clear context reference + self.ctx = None def _format_error_result(error_message: str) -> MCPToolResult: diff --git a/hud/agents/claude.py b/hud/agents/claude.py index a9c9ad28..d8124246 100644 --- a/hud/agents/claude.py +++ b/hud/agents/claude.py @@ -26,7 +26,7 @@ import hud if TYPE_CHECKING: - from hud.datasets import Task + from hud.datasets import LegacyTask import mcp.types as types from pydantic import ConfigDict @@ -103,10 +103,8 @@ def __init__(self, params: ClaudeCreateParams | None = None, **kwargs: Any) -> N self.tool_mapping: dict[str, str] = {} self.claude_tools: list[BetaToolUnionParam] = [] - async def initialize(self, task: str | Task | None = None) -> None: - """Initialize the agent and build tool mappings.""" - await super().initialize(task) - # Build tool mappings after tools are discovered + def _on_tools_ready(self) -> None: + """Build Claude-specific tool mappings after tools are discovered.""" self._convert_tools_for_claude() async def get_system_messages(self) -> list[Any]: diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py index c2377b95..91942e49 100644 --- a/hud/agents/gemini.py +++ b/hud/agents/gemini.py @@ -12,7 +12,7 @@ import hud if TYPE_CHECKING: - from hud.datasets import Task + from hud.datasets import LegacyTask import mcp.types as types @@ -89,10 +89,8 @@ def __init__(self, params: GeminiCreateParams | None = None, **kwargs: Any) -> N self._gemini_to_mcp_tool_map: dict[str, str] = {} self.gemini_tools: genai_types.ToolListUnion = [] - async def initialize(self, task: str | Task | None = None) -> None: - """Initialize the agent and build tool mappings.""" - await super().initialize(task) - # Build tool mappings after tools are discovered + def _on_tools_ready(self) -> None: + """Build Gemini-specific tool mappings after tools are discovered.""" self._convert_tools_for_gemini() async def get_system_messages(self) -> list[Any]: diff --git a/hud/agents/grounded_openai.py b/hud/agents/grounded_openai.py index 6372bea6..e427bcb6 100644 --- a/hud/agents/grounded_openai.py +++ b/hud/agents/grounded_openai.py @@ -83,15 +83,12 @@ def __init__(self, params: GroundedOpenAICreateParams | None = None, **kwargs: A self.grounder = Grounder(self.config.grounder_config) self.grounded_tool: GroundedComputerTool | None = None - async def initialize(self, task: Any = None) -> None: - """Initialize the agent and create the grounded tool with mcp_client.""" - # Call parent initialization first - await super().initialize(task) - - if self.mcp_client is None: - raise ValueError("mcp_client must be initialized before creating grounded tool") + def _on_tools_ready(self) -> None: + """Create the grounded tool after context is bound.""" + if self.ctx is None: + raise ValueError("ctx must be set before creating grounded tool") self.grounded_tool = GroundedComputerTool( - grounder=self.grounder, 
mcp_client=self.mcp_client, computer_tool_name="computer" + grounder=self.grounder, ctx=self.ctx, computer_tool_name="computer" ) def get_tool_schemas(self) -> list[Any]: @@ -141,10 +138,10 @@ async def get_response(self, messages: Any) -> AgentResponse: ) if not has_image: - if self.mcp_client is None: - raise ValueError("mcp_client is not initialized") - screenshot_result = await self.mcp_client.call_tool( - MCPToolCall(name="computer", arguments={"action": "screenshot"}) + if self.ctx is None: + raise ValueError("ctx is not initialized") + screenshot_result = await self.ctx.call_tool( + ("computer", {"action": "screenshot"}) ) for block in screenshot_result.content: diff --git a/hud/agents/misc/integration_test_agent.py b/hud/agents/misc/integration_test_agent.py index 254b6669..ec9d71e1 100644 --- a/hud/agents/misc/integration_test_agent.py +++ b/hud/agents/misc/integration_test_agent.py @@ -1,12 +1,21 @@ from __future__ import annotations -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar -from hud.agents.base import MCPAgent, find_reward -from hud.types import AgentResponse, BaseAgentConfig, Task, Trace +from hud.agents.base import MCPAgent +from hud.types import AgentResponse, BaseAgentConfig, Trace + +if TYPE_CHECKING: + from hud.eval.context import EvalContext class IntegrationTestRunner(MCPAgent): + """Special agent that runs integration tests by executing tools directly. + + Unlike regular agents, this doesn't run an LLM loop - it executes + integration_test_tool and evaluate_tool in sequence to verify tool behavior. + """ + metadata: ClassVar[dict[str, Any] | None] = {} config_cls: ClassVar[type[BaseAgentConfig]] = BaseAgentConfig @@ -14,38 +23,50 @@ def __init__(self, **kwargs: Any) -> None: kwargs["auto_trace"] = False super().__init__(**kwargs) - async def run(self, prompt_or_task: str | Task | dict[str, Any], max_steps: int = 10) -> Trace: + async def run( + self, + ctx: EvalContext, + *, + max_steps: int = 10, + ) -> Trace: + """Run integration test by executing tools directly. + + The EvalContext should have integration_test_tool and evaluate_tool + configured in its metadata or environment setup. 
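        A hypothetical wiring sketch; how `_integration_test_calls` is normally
        populated (the CLI `--integration-test` path) is assumed here rather than
        shown, and the tool name is a placeholder:

        ```python
        # Assumed attribute name matches the getattr() lookup in the method body.
        ctx._integration_test_calls = [
            ("navigate_to_url", {"url": "https://example.com"}),
        ]
        runner = IntegrationTestRunner.create()
        trace = await runner.run(ctx, max_steps=1)
        print(trace.reward)
        ```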
+ """ + from hud.eval.context import EvalContext + + if not isinstance(ctx, EvalContext): + raise TypeError(f"ctx must be EvalContext, got {type(ctx).__name__}") + + self.ctx = ctx + try: - # Initialize using base to set up client and telemetry correctly - if isinstance(prompt_or_task, str): - task = Task(prompt=prompt_or_task, mcp_config={}) - elif isinstance(prompt_or_task, dict): - task = Task(**prompt_or_task) - else: - task = prompt_or_task - await self.initialize(task) + # Initialize tools from context + if not self._initialized: + await self._initialize_from_ctx(ctx) self.console.info(f"Full system prompt: {self.system_prompt}") - # Validate task shape - if not getattr(task, "integration_test_tool", None): + # For integration tests, we expect the context's environment to have + # _setup_calls, _integration_test_calls, and _evaluate_calls configured + env = ctx + + # Run integration test tool (stored in environment metadata or separate list) + integration_test_calls = getattr(env, "_integration_test_calls", []) + if not integration_test_calls: raise ValueError( - "--integration-test requires task.integration_test_tool (single call)" + "--integration-test requires integration_test_tool to be configured" ) - elif not getattr(task, "evaluate_tool", None): - raise ValueError("--integration-test requires task.evaluate_tool (single call)") - - if task.setup_tool: - _ = await self.call_tools(task.setup_tool) - _ = await self.call_tools(task.integration_test_tool) - evaluate_result = await self.call_tools(task.evaluate_tool) + for name, args in integration_test_calls: + await ctx.call_tool((name, args)) - reward = float(find_reward(evaluate_result[0])) if evaluate_result else 0.0 + # The evaluate phase runs automatically when ctx exits, + # but we can also get the reward from ctx.reward after + return Trace(done=True, reward=ctx.reward or 0.0, info={}) - return Trace(done=True, reward=reward, info={}) finally: - # Ensure resources are cleaned up so the CLI can exit cleanly await self._cleanup() # Stub implementations to satisfy abstract base class; not used in --integration-test path diff --git a/hud/agents/openai.py b/hud/agents/openai.py index a803ead7..2a19c776 100644 --- a/hud/agents/openai.py +++ b/hud/agents/openai.py @@ -106,9 +106,8 @@ def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> N self.last_response_id: str | None = None self._message_cursor = 0 - async def initialize(self, task: Any | None = None) -> None: - """Initialize agent and build tool metadata.""" - await super().initialize(task) + def _on_tools_ready(self) -> None: + """Build OpenAI-specific tool mappings after tools are discovered.""" self._convert_tools_for_openai() def _to_openai_tool( diff --git a/hud/agents/tests/conftest.py b/hud/agents/tests/conftest.py index 871b96b6..41f55b9b 100644 --- a/hud/agents/tests/conftest.py +++ b/hud/agents/tests/conftest.py @@ -2,77 +2,72 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any import pytest - -if TYPE_CHECKING: - from collections.abc import Callable from mcp import types +from hud.eval.context import EvalContext from hud.types import MCPToolCall, MCPToolResult -class MockMCPClient: - """Mock MCP client that satisfies AgentMCPClient protocol.""" +class MockEvalContext(EvalContext): + """Mock EvalContext for testing agents. 
- _initialized: bool = False + This provides a minimal EvalContext implementation that can be used + to test agent initialization and tool calling without a real environment. + """ def __init__( self, + prompt: str = "Test prompt", tools: list[types.Tool] | None = None, - call_tool_handler: Callable[[MCPToolCall], MCPToolResult] | None = None, - initialize_error: Exception | None = None, + call_tool_handler: Any = None, ) -> None: - self._mcp_config: dict[str, dict[str, Any]] = {"test": {"url": "http://test"}} + self.prompt = prompt self._tools = tools or [] + self._submitted: str | None = None + self.reward: float | None = None self._call_tool_handler = call_tool_handler - self._initialize_error = initialize_error - self.call_tool_calls: list[MCPToolCall] = [] - self.shutdown_called = False - - @property - def mcp_config(self) -> dict[str, dict[str, Any]]: - return self._mcp_config - - @property - def is_connected(self) -> bool: - return self._initialized - - async def initialize(self, mcp_config: dict[str, dict[str, Any]] | None = None) -> None: - if self._initialize_error: - raise self._initialize_error - self._initialized = True - - async def shutdown(self) -> None: - self.shutdown_called = True + self.tool_calls: list[tuple[str, dict[str, Any]]] = [] async def list_tools(self) -> list[types.Tool]: return self._tools - async def call_tool(self, tool_call: MCPToolCall) -> MCPToolResult: - self.call_tool_calls.append(tool_call) + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + # Parse the call + if isinstance(call, tuple): + name, args = call[0], call[1] if len(call) > 1 else {} + elif hasattr(call, "name"): + name, args = call.name, getattr(call, "arguments", {}) or {} + else: + name, args = str(call), kwargs + + self.tool_calls.append((name, args)) + if self._call_tool_handler: - return self._call_tool_handler(tool_call) - return MCPToolResult(content=[]) + tc = MCPToolCall(name=name, arguments=args) + return self._call_tool_handler(tc) - def get_available_tools(self) -> list[types.Tool]: - return self._tools + return MCPToolResult( + content=[types.TextContent(type="text", text=f"Result from {name}")], + isError=False, + ) - def get_tool_map(self) -> dict[str, types.Tool]: - return {t.name: t for t in self._tools} + async def submit(self, answer: str) -> None: + self._submitted = answer @pytest.fixture -def mock_mcp_client() -> MockMCPClient: - """Create a mock MCP client that satisfies the AgentMCPClient protocol.""" - return MockMCPClient() +def mock_eval_context() -> MockEvalContext: + """Create a basic mock EvalContext.""" + return MockEvalContext() @pytest.fixture -def mock_mcp_client_with_tools() -> MockMCPClient: - """Create a mock MCP client with a test tool.""" - return MockMCPClient( +def mock_eval_context_with_tools() -> MockEvalContext: + """Create a mock EvalContext with test tools.""" + return MockEvalContext( tools=[ types.Tool( name="test_tool", @@ -84,41 +79,26 @@ def mock_mcp_client_with_tools() -> MockMCPClient: @pytest.fixture -def mock_mcp_client_openai_computer() -> MockMCPClient: - """Create a mock MCP client with openai_computer tool for Operator tests.""" - return MockMCPClient( - tools=[ - types.Tool( - name="openai_computer", - description="OpenAI computer use tool", - inputSchema={}, - ) - ] - ) - - -@pytest.fixture -def mock_mcp_client_gemini_computer() -> MockMCPClient: - """Create a mock MCP client with gemini_computer tool for Gemini tests.""" - return MockMCPClient( +def mock_eval_context_computer() -> MockEvalContext: + 
"""Create a mock EvalContext with computer tool.""" + return MockEvalContext( tools=[ types.Tool( - name="gemini_computer", - description="Gemini computer use tool", - inputSchema={}, + name="computer", + description="Computer use tool", + inputSchema={"type": "object"}, ) ] ) @pytest.fixture -def mock_mcp_client_browser_tools() -> MockMCPClient: - """Create a mock MCP client with browser-like tools for extended tests.""" - return MockMCPClient( +def mock_eval_context_browser_tools() -> MockEvalContext: + """Create a mock EvalContext with browser-like tools.""" + return MockEvalContext( tools=[ types.Tool(name="screenshot", description="Take screenshot", inputSchema={}), types.Tool(name="click", description="Click at coordinates", inputSchema={}), types.Tool(name="type", description="Type text", inputSchema={}), - types.Tool(name="bad_tool", description="A tool that fails", inputSchema={}), ] ) diff --git a/hud/agents/tests/test_base.py b/hud/agents/tests/test_base.py index b2b544c2..25fab1d8 100644 --- a/hud/agents/tests/test_base.py +++ b/hud/agents/tests/test_base.py @@ -1,27 +1,16 @@ -"""Tests for BaseMCPAgent using simulated actions.""" +"""Tests for MCPAgent base class with v5 EvalContext pattern.""" from __future__ import annotations from typing import Any, ClassVar -from unittest.mock import MagicMock - -# Import AsyncMock from unittest.mock if available (Python 3.8+) -try: - from unittest.mock import AsyncMock -except ImportError: - # Fallback for older Python versions - from unittest.mock import MagicMock as AsyncMock import pytest from mcp import types from hud.agents import MCPAgent from hud.agents.base import BaseCreateParams -from hud.datasets import Task -from hud.tools.executors.base import BaseExecutor -from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult, Trace - -from .conftest import MockMCPClient +from hud.eval.context import EvalContext +from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult class MockConfig(BaseAgentConfig): @@ -33,705 +22,302 @@ class MockCreateParams(BaseCreateParams, MockConfig): pass +class MockEvalContext(EvalContext): + """Mock EvalContext for testing.""" + + def __init__( + self, + prompt: str = "Test prompt", + tools: list[types.Tool] | None = None, + ) -> None: + self.prompt = prompt + self._tools = tools or [ + types.Tool(name="test_tool", description="A test tool", inputSchema={}), + types.Tool(name="another_tool", description="Another tool", inputSchema={}), + ] + self._submitted: str | None = None + self.reward: float | None = None + self._tool_calls: list[tuple[str, dict[str, Any]]] = [] + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + # Parse the call + if isinstance(call, tuple): + name, args = call[0], call[1] if len(call) > 1 else {} + elif hasattr(call, "name"): + name, args = call.name, getattr(call, "arguments", {}) or {} + else: + name, args = str(call), kwargs + self._tool_calls.append((name, args)) + return MCPToolResult( + content=[types.TextContent(type="text", text=f"Result from {name}")], + isError=False, + ) + + async def submit(self, answer: str) -> None: + self._submitted = answer + + class MockMCPAgent(MCPAgent): - """Concrete implementation of BaseMCPAgent for testing.""" + """Concrete implementation of MCPAgent for testing.""" metadata: ClassVar[dict[str, Any] | None] = {} config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig - def __init__(self, mcp_client: 
Any = None, **kwargs: Any) -> None: - if mcp_client is None: - mcp_client = MockMCPClient() - - kwargs.setdefault("mcp_client", mcp_client) + def __init__(self, **kwargs: Any) -> None: params = MockCreateParams(**kwargs) super().__init__(params) - self.executor = BaseExecutor() - self._messages: list[dict[str, Any]] = [] + self._response = AgentResponse(content="Mock response", tool_calls=[], done=True) - async def create_initial_messages( - self, prompt: str, initial_screenshot: bool = False - ) -> list[dict[str, Any]]: - """Mock create initial messages.""" - messages = [{"role": "user", "content": prompt}] - if initial_screenshot: - messages.append({"role": "assistant", "content": "Screenshot: mock_screenshot"}) - return messages + def set_response(self, response: AgentResponse) -> None: + self._response = response async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: - """Mock get response.""" - return AgentResponse(content="Mock response", tool_calls=[], done=True) + return self._response async def format_tool_results( self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] ) -> list[dict[str, Any]]: - """Mock format tool results.""" formatted = [] - for tool_call, result in zip(tool_calls, tool_results): + for tool_call, result in zip(tool_calls, tool_results, strict=True): formatted.append({"role": "tool", "name": tool_call.name, "content": str(result)}) return formatted - async def create_user_message(self, text: str) -> Any: - """Mock create user message.""" - return {"role": "user", "content": text} - async def get_system_messages(self) -> list[Any]: - """Mock get system messages.""" return [] async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - """Mock format blocks.""" - formatted = [] - for block in blocks: - if isinstance(block, types.TextContent): - formatted.append({"type": "text", "text": block.text}) - elif isinstance(block, types.ImageContent): - formatted.append({"type": "image", "data": block.data}) - elif hasattr(block, "type"): - formatted.append({"type": getattr(block, "type", "unknown")}) - return formatted - + return [{"type": "text", "text": getattr(b, "text", "")} for b in blocks] -class TestBaseMCPAgent: - """Tests for BaseMCPAgent with simulated actions.""" - def test_init_defaults(self): - """Test initialization with default values.""" - agent = MockMCPAgent() +class TestMCPAgentInit: + """Tests for MCPAgent initialization.""" - assert agent.mcp_client is not None - assert agent.allowed_tools is None - assert agent.disallowed_tools is None - assert agent.initial_screenshot is True - - def test_init_with_params(self, mock_mcp_client): - """Test initialization with custom parameters.""" - agent = MockMCPAgent( - mcp_client=mock_mcp_client, - allowed_tools=["tool1", "tool2"], - disallowed_tools=["bad_tool"], - initial_screenshot=True, - system_prompt="Custom prompt", - ) + def test_init_defaults(self) -> None: + """Test agent initializes with default config.""" + agent = MockMCPAgent(auto_trace=False) + assert agent.ctx is None + assert agent._initialized is False + assert agent.system_prompt is None - assert agent.mcp_client == mock_mcp_client - assert agent.allowed_tools == ["tool1", "tool2"] - assert agent.disallowed_tools == ["bad_tool"] - assert agent.initial_screenshot is True + def test_init_with_system_prompt(self) -> None: + """Test agent with custom system prompt.""" + agent = MockMCPAgent(auto_trace=False, system_prompt="Custom prompt") assert agent.system_prompt == "Custom prompt" - 
@pytest.mark.asyncio - async def test_init_no_client_no_task(self): - """Test initialize fails without client and without task.""" - - # Create a minimal concrete implementation to test the ValueError - class TestAgentConfig(BaseAgentConfig): - model_name: str = "TestAgent" - checkpoint_name: str = "test-model" - - class TestAgentCreateParams(BaseCreateParams, TestAgentConfig): - pass - - class TestAgent(MCPAgent): - config_cls = TestAgentConfig - def __init__(self, **kwargs: Any) -> None: - params = TestAgentCreateParams(**kwargs) - super().__init__(params) +class TestMCPAgentRun: + """Tests for MCPAgent.run() with EvalContext.""" - async def create_initial_messages( - self, prompt: str, initial_screenshot: bool = False - ) -> list[dict[str, Any]]: - return [] - - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[dict[str, Any]]: - return [] - - async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: - return AgentResponse(content="test", tool_calls=[], done=True) + @pytest.mark.asyncio + async def test_run_basic(self) -> None: + """Test basic run flow with EvalContext.""" + ctx = MockEvalContext(prompt="Do something") + agent = MockMCPAgent(auto_trace=False) - async def get_system_messages(self) -> list[Any]: - return [] + result = await agent.run(ctx) - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - return [] + assert result.done is True + assert result.content == "Mock response" + assert ctx._submitted == "Mock response" - # Agent can be created with None client - agent = TestAgent(mcp_client=None) + @pytest.mark.asyncio + async def test_run_initializes_agent(self) -> None: + """Test run() initializes the agent with context.""" + ctx = MockEvalContext(prompt="Do something") + agent = MockMCPAgent(auto_trace=False) - # But initialize should fail without client or task - with pytest.raises(ValueError, match="No MCPClient"): - await agent.initialize() + assert not agent._initialized + await agent.run(ctx) + assert agent._initialized @pytest.mark.asyncio - async def test_initialize_with_sessions(self): - """Test initialize with existing sessions.""" - agent = MockMCPAgent() - - # Create proper async mock for session - mock_session = MagicMock() - - # Set up the connector and client_session structure - mock_session.connector = MagicMock() - mock_session.connector.client_session = MagicMock() - - # Mock list_tools on the client_session - async def mock_list_tools(): - return types.ListToolsResult( - tools=[ - types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}), - types.Tool(name="tool2", description="Tool 2", inputSchema={"type": "object"}), - types.Tool( - name="setup", description="Setup tool", inputSchema={"type": "object"} - ), - ] - ) - - mock_session.connector.client_session.list_tools = mock_list_tools + async def test_run_discovers_tools(self) -> None: + """Test run() discovers tools from context.""" + tools = [ + types.Tool(name="tool1", description="Tool 1", inputSchema={}), + types.Tool(name="tool2", description="Tool 2", inputSchema={}), + ] + ctx = MockEvalContext(prompt="Do something", tools=tools) + agent = MockMCPAgent(auto_trace=False) - assert agent.mcp_client is not None + # We need to check tools before cleanup + # Store a reference to check + discovered_tools = [] - # Mock the list_tools method on mcp_client to return the tools - agent.mcp_client.list_tools = AsyncMock( - return_value=[ - types.Tool(name="tool1", 
description="Tool 1", inputSchema={"type": "object"}), - types.Tool(name="tool2", description="Tool 2", inputSchema={"type": "object"}), - types.Tool(name="setup", description="Setup tool", inputSchema={"type": "object"}), - ] - ) + original_run = agent._run_context - await agent.initialize() + async def capture_tools(*args: Any, **kwargs: Any) -> Any: + discovered_tools.extend(agent.get_available_tools()) + return await original_run(*args, **kwargs) - # Check available tools were populated (excludes lifecycle tools) - tools = agent.get_available_tools() - assert len(tools) == 3 # All tools (setup is not in default lifecycle tools) + agent._run_context = capture_tools # type: ignore + await agent.run(ctx) - # Ensure names exist in available tools - names = {t.name for t in tools} - assert {"tool1", "tool2", "setup"} <= names + assert len(discovered_tools) == 2 + assert discovered_tools[0].name == "tool1" + assert discovered_tools[1].name == "tool2" @pytest.mark.asyncio - async def test_initialize_with_filtering(self): - """Test initialize with tool filtering.""" - agent = MockMCPAgent(allowed_tools=["tool1"], disallowed_tools=["tool3"]) - - # Create proper async mock for session - mock_session = MagicMock() - - # Set up the connector and client_session structure - mock_session.connector = MagicMock() - mock_session.connector.client_session = MagicMock() - - async def mock_list_tools(): - return types.ListToolsResult( - tools=[ - types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}), - types.Tool(name="tool2", description="Tool 2", inputSchema={"type": "object"}), - types.Tool(name="tool3", description="Tool 3", inputSchema={"type": "object"}), - types.Tool(name="setup", description="Setup", inputSchema={"type": "object"}), - ] - ) - - mock_session.connector.client_session.list_tools = mock_list_tools - - assert agent.mcp_client is not None + async def test_run_requires_eval_context(self) -> None: + """Test run() raises TypeError for non-EvalContext.""" + agent = MockMCPAgent(auto_trace=False) - # Mock the list_tools method on mcp_client to return the tools - agent.mcp_client.list_tools = AsyncMock( - return_value=[ - types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}), - types.Tool(name="tool2", description="Tool 2", inputSchema={"type": "object"}), - types.Tool(name="tool3", description="Tool 3", inputSchema={"type": "object"}), - types.Tool(name="setup", description="Setup", inputSchema={"type": "object"}), - ] - ) - - await agent.initialize() + with pytest.raises(TypeError, match="must be EvalContext"): + await agent.run("not a context") # type: ignore - # Check filtering worked - get_available_tools excludes lifecycle tools - tools = agent.get_available_tools() - tool_names = [t.name for t in tools] - assert len(tools) == 1 # Only tool1 (tool2 and tool3 are filtered out) - assert "tool1" in tool_names - assert "setup" not in tool_names # Lifecycle tool excluded from available tools - assert "tool2" not in tool_names # Not in allowed list - assert "tool3" not in tool_names # In disallowed list + @pytest.mark.asyncio + async def test_run_requires_prompt(self) -> None: + """Test run() raises ValueError when prompt is empty.""" + ctx = MockEvalContext(prompt="") + agent = MockMCPAgent(auto_trace=False) - # Make sure tool schemas are correct - schemas = agent.get_tool_schemas() - assert len(schemas) == 1 - assert schemas[0]["name"] == "tool1" - assert schemas[0]["description"] == "Tool 1" - assert schemas[0]["parameters"] == {"type": 
"object"} + with pytest.raises(ValueError, match="prompt is not set"): + await agent.run(ctx) @pytest.mark.asyncio - async def test_call_tool_success(self): - """Test successful tool call.""" - agent = MockMCPAgent() - - # Initialize with a tool - mock_session = MagicMock() - mock_session.connector = MagicMock() - mock_session.connector.client_session = MagicMock() - - async def mock_list_tools(): - return types.ListToolsResult( - tools=[ - types.Tool(name="test_tool", description="Test", inputSchema={"type": "object"}) - ] - ) + async def test_run_clears_context_after(self) -> None: + """Test run() clears ctx after completion.""" + ctx = MockEvalContext(prompt="Do something") + agent = MockMCPAgent(auto_trace=False) - mock_session.connector.client_session.list_tools = mock_list_tools - - # Mock the call_tool method on the client session - mock_result = types.CallToolResult( - content=[types.TextContent(type="text", text="Tool result")], isError=False - ) + await agent.run(ctx) + assert agent.ctx is None - async def mock_call_tool(name, args): - return mock_result + @pytest.mark.asyncio + async def test_run_no_submit_on_empty_content(self) -> None: + """Test run() doesn't submit when content is empty.""" + ctx = MockEvalContext(prompt="Do something") + agent = MockMCPAgent(auto_trace=False) + agent.set_response(AgentResponse(content="", tool_calls=[], done=True)) - mock_session.connector.client_session.call_tool = mock_call_tool + await agent.run(ctx) + assert ctx._submitted is None - assert agent.mcp_client is not None - # Mock the client's call_tool method directly - agent.mcp_client.call_tool = AsyncMock(return_value=mock_result) +class TestMCPAgentToolCalling: + """Tests for tool calling through context.""" - # Mock the list_tools method to return the test tool - agent.mcp_client.list_tools = AsyncMock( - return_value=[ - types.Tool(name="test_tool", description="Test", inputSchema={"type": "object"}) - ] - ) + @pytest.mark.asyncio + async def test_call_tools_uses_context(self) -> None: + """Test call_tools routes through ctx.call_tool.""" + ctx = MockEvalContext(prompt="Do something") + agent = MockMCPAgent(auto_trace=False) - await agent.initialize() + # Bind context manually + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - # Call the tool - tool_call = MCPToolCall(name="test_tool", arguments={"param": "value"}) - results = await agent.call_tools(tool_call) + # Call a tool + results = await agent.call_tools(MCPToolCall(name="test_tool", arguments={"arg": "value"})) assert len(results) == 1 - assert results[0] == mock_result assert not results[0].isError + assert ("test_tool", {"arg": "value"}) in ctx._tool_calls @pytest.mark.asyncio - async def test_call_tool_not_found(self): - """Test calling non-existent tool.""" - agent = MockMCPAgent() - - # Initialize without tools - mock_session = MagicMock() + async def test_call_tools_without_context_raises(self) -> None: + """Test call_tools raises when no context bound.""" + agent = MockMCPAgent(auto_trace=False) - async def mock_list_tools(): - return types.ListToolsResult(tools=[]) + with pytest.raises(ValueError, match="not bound to context"): + await agent.call_tools(MCPToolCall(name="test_tool", arguments={})) - mock_session.list_tools = mock_list_tools - assert agent.mcp_client is not None - await agent.initialize() - - # Try to call unknown tool - call_tools doesn't raise for unknown tools - tool_call = MCPToolCall(name="unknown_tool", arguments={}) - await agent.call_tools(tool_call) +class TestMCPAgentRequiredTools: + 
"""Tests for required_tools validation.""" @pytest.mark.asyncio - async def test_call_tool_no_name(self): - """Test calling tool without name.""" - # MCPToolCall accepts empty names - agent = MockMCPAgent() - tool_call = MCPToolCall(name="", arguments={}) - - # call_tools doesn't validate empty names, it will return error - await agent.call_tools(tool_call) - - def test_get_tool_schemas(self): - """Test getting tool schemas.""" - agent = MockMCPAgent() - - agent._available_tools = [ - types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}), - types.Tool(name="setup", description="Setup", inputSchema={"type": "object"}), - ] - - schemas = agent.get_tool_schemas() - - # Should include non-lifecycle tools - assert len(schemas) == 2 - assert schemas[0]["name"] == "tool1" - - def test_get_tools_by_server(self): - """Test getting tools grouped by server.""" - agent = MockMCPAgent() - - # Set up tools from different servers - tool1 = types.Tool(name="tool1", description="Tool 1", inputSchema={"type": "object"}) - tool2 = types.Tool(name="tool2", description="Tool 2", inputSchema={"type": "object"}) - - agent._available_tools = [tool1, tool2] - tools = agent.get_available_tools() - assert {t.name for t in tools} == {"tool1", "tool2"} - - @pytest.mark.asyncio - async def test_executor_integration(self): - """Test integration with BaseExecutor for simulated actions.""" - agent = MockMCPAgent() - - # Test various executor actions - click_result = await agent.executor.click(100, 200, take_screenshot=False) - assert click_result.output is not None - assert "[SIMULATED] Click at (100, 200)" in click_result.output - - type_result = await agent.executor.write("Test input", take_screenshot=False) - assert type_result.output is not None - assert "[SIMULATED] Type 'Test input'" in type_result.output - - scroll_result = await agent.executor.scroll(x=50, y=50, scroll_y=5, take_screenshot=False) - assert scroll_result.output is not None - assert "[SIMULATED] Scroll" in scroll_result.output - - # Test screenshot - screenshot = await agent.executor.screenshot() - assert isinstance(screenshot, str) - assert screenshot.startswith("iVBORw0KGgo") # PNG header - - -class MockAgentExtended(MCPAgent): - """Mock agent for testing with predefined responses.""" - - metadata: ClassVar[dict[str, Any] | None] = {} - config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig + async def test_missing_required_tools_raises(self) -> None: + """Test run() raises when required tools are missing.""" - def __init__(self, responses: list[Any] | None = None, **kwargs: Any): - if kwargs.get("mcp_client") is None: - kwargs["mcp_client"] = MockMCPClient() - params = MockCreateParams(**kwargs) - super().__init__(params) - self.responses = responses or [] - self.call_count = 0 - - async def create_initial_messages( - self, prompt: str, initial_screenshot: bool = False - ) -> list[dict[str, Any]]: - """Create initial messages.""" - messages = [{"role": "user", "content": prompt}] - if initial_screenshot: - # capture_screenshot doesn't exist, just mock it - screenshot = "mock_screenshot_data" - messages.append({"role": "assistant", "content": f"Screenshot: {screenshot}"}) - return messages - - async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: - """Return predefined responses - must be async.""" - if self.call_count < len(self.responses): - response_dict = self.responses[self.call_count] - self.call_count += 1 - # Convert dict to AgentResponse - return AgentResponse( - 
content=response_dict.get("content", ""), - tool_calls=response_dict.get("tool_calls", []), - done=response_dict.get("done", not bool(response_dict.get("tool_calls"))), - ) - return AgentResponse(content="Done", tool_calls=[], done=True) - - async def format_tool_results( - self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> list[dict[str, Any]]: - """Format tool results.""" - formatted = [] - for tool_call, result in zip(tool_calls, tool_results): - formatted.append({"role": "tool", "name": tool_call.name, "content": str(result)}) - return formatted + class AgentWithRequiredTools(MockMCPAgent): + required_tools: ClassVar[list[str]] = ["must_have_tool"] - async def create_user_message(self, text: str) -> Any: - """Create user message.""" - return {"role": "user", "content": text} + ctx = MockEvalContext(prompt="Do something", tools=[]) + agent = AgentWithRequiredTools(auto_trace=False) - async def get_system_messages(self) -> list[Any]: - """Mock get system messages.""" - return [] - - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - """Mock format blocks.""" - formatted = [] - for block in blocks: - if isinstance(block, types.TextContent): - formatted.append({"type": "text", "text": block.text}) - elif isinstance(block, types.ImageContent): - formatted.append({"type": "image", "data": block.data}) - elif hasattr(block, "type"): - formatted.append({"type": getattr(block, "type", "unknown")}) - return formatted - - -class TestMCPAgentExtended: - """Extended tests for MCPAgent.""" - - @pytest.fixture - def agent_with_tools(self, mock_mcp_client_browser_tools): - """Create agent with mock tools.""" - return MockAgentExtended(mcp_client=mock_mcp_client_browser_tools) + with pytest.raises(ValueError, match="Required tools are missing"): + await agent.run(ctx) @pytest.mark.asyncio - async def test_run_with_task_object(self, agent_with_tools): - """Test running agent with Task object.""" - from hud.types import MCPToolResult - - task = Task( - id="test_task", - prompt="Click the button", - mcp_config={"test_server": {"url": "http://localhost:8080"}}, - setup_tool={"name": "navigate", "arguments": {"url": "https://example.com"}}, # type: ignore[arg-type] - evaluate_tool={"name": "check_result", "arguments": {}}, # type: ignore[arg-type] - ) + async def test_required_tools_present_succeeds(self) -> None: + """Test run() succeeds when required tools are present.""" - # Set up responses - agent_with_tools.responses = [ - { - "role": "assistant", - "content": "I'll click the button", - "tool_calls": [MCPToolCall(name="click", arguments={"x": 100, "y": 200})], - } - ] - - # Mock the evaluation to return a reward - agent_with_tools.mcp_client.call_tool = AsyncMock( - side_effect=[ - # Setup tool - MCPToolResult( - content=[types.TextContent(type="text", text="Navigated")], - isError=False, - ), - # Click tool - MCPToolResult( - content=[types.TextContent(type="text", text="Clicked")], - isError=False, - ), - # Evaluate tool with reward - MCPToolResult( - content=[types.TextContent(type="text", text="Success")], - isError=False, - structuredContent={"reward": 1.0}, - ), - ] - ) + class AgentWithRequiredTools(MockMCPAgent): + required_tools: ClassVar[list[str]] = ["required_tool"] - result = await agent_with_tools.run(task) + tools = [types.Tool(name="required_tool", description="Required", inputSchema={})] + ctx = MockEvalContext(prompt="Do something", tools=tools) + agent = AgentWithRequiredTools(auto_trace=False) - assert isinstance(result, 
Trace) - assert result.reward == 1.0 - assert not result.isError + result = await agent.run(ctx) assert result.done - @pytest.mark.asyncio - async def test_run_with_setup_error(self, agent_with_tools): - """Test task execution with setup phase error.""" - from hud.types import MCPToolResult - - task = Task( - id="test_task", - prompt="Do something", - mcp_config={"test_server": {"url": "http://localhost:8080"}}, - setup_tool={"name": "bad_setup", "arguments": {}}, # type: ignore[arg-type] - ) - - # Mock setup tool to fail - agent_with_tools.mcp_client.call_tool = AsyncMock( - return_value=MCPToolResult( - content=[types.TextContent(type="text", text="Setup failed")], - isError=True, - ) - ) - - result = await agent_with_tools.run(task) - assert isinstance(result, Trace) - assert result.isError - # Error content is the string representation of the MCPToolResult list - assert result.content is not None - assert "Setup failed" in result.content - assert "MCPToolResult" in result.content +class TestMCPAgentOnToolsReady: + """Tests for _on_tools_ready hook.""" @pytest.mark.asyncio - async def test_run_with_multiple_setup_tools(self, agent_with_tools): - """Test task with multiple setup tools.""" - - task = Task( - id="test_task", - prompt="Test multiple setup", - mcp_config={"test_server": {"url": "http://localhost:8080"}}, - setup_tool=[ - MCPToolCall(name="setup1", arguments={}), - MCPToolCall(name="setup2", arguments={}), - ], - ) - - agent_with_tools.responses = [{"role": "assistant", "content": "Done", "tool_calls": []}] + async def test_on_tools_ready_called(self) -> None: + """Test _on_tools_ready is called during initialization.""" + hook_called = [False] - setup_calls = [] - agent_with_tools.mcp_client.call_tool = AsyncMock( - side_effect=lambda tool_call: setup_calls.append(tool_call) - or MCPToolResult( - content=[types.TextContent(type="text", text=f"{tool_call.name} done")], - isError=False, - ) - ) + class AgentWithHook(MockMCPAgent): + def _on_tools_ready(self) -> None: + hook_called[0] = True - result = await agent_with_tools.run(task) + ctx = MockEvalContext(prompt="Do something") + agent = AgentWithHook(auto_trace=False) - # Check that the tool names match - setup_names = [call.name for call in setup_calls] - assert "setup1" in setup_names - assert "setup2" in setup_names - assert not result.isError + await agent.run(ctx) + assert hook_called[0] @pytest.mark.asyncio - async def test_allowed_tools_filtering(self): - """Test that allowed_tools filters available tools.""" - mock_client = MockMCPClient( - tools=[ - types.Tool(name="tool1", description="Tool 1", inputSchema={}), - types.Tool(name="tool2", description="Tool 2", inputSchema={}), - types.Tool(name="tool3", description="Tool 3", inputSchema={}), - ] - ) - - agent = MockAgentExtended(mcp_client=mock_client, allowed_tools=["tool1", "tool3"]) - await agent.initialize("test") + async def test_on_tools_ready_has_access_to_tools(self) -> None: + """Test _on_tools_ready can access discovered tools.""" + captured_tools: list[types.Tool] = [] - available_names = [tool.name for tool in agent.get_available_tools()] - assert "tool1" in available_names - assert "tool3" in available_names - assert "tool2" not in available_names + class AgentWithHook(MockMCPAgent): + def _on_tools_ready(self) -> None: + captured_tools.extend(self.get_available_tools()) - @pytest.mark.asyncio - async def test_disallowed_tools_filtering(self): - """Test that disallowed_tools filters available tools.""" - mock_client = MockMCPClient( - tools=[ - 
types.Tool(name="tool1", description="Tool 1", inputSchema={}), - types.Tool(name="tool2", description="Tool 2", inputSchema={}), - types.Tool(name="tool3", description="Tool 3", inputSchema={}), - ] - ) - - agent = MockAgentExtended(mcp_client=mock_client, disallowed_tools=["tool2"]) - await agent.initialize("test") + tools = [ + types.Tool(name="tool1", description="Tool 1", inputSchema={}), + types.Tool(name="tool2", description="Tool 2", inputSchema={}), + ] + ctx = MockEvalContext(prompt="Do something", tools=tools) + agent = AgentWithHook(auto_trace=False) - available_names = [tool.name for tool in agent.get_available_tools()] - assert "tool1" in available_names - assert "tool3" in available_names - assert "tool2" not in available_names + await agent.run(ctx) - @pytest.mark.asyncio - async def test_lifecycle_tools(self): - """Test lifecycle tools are called in run_prompt.""" - mock_client = MockMCPClient( - tools=[types.Tool(name="screenshot", description="Take screenshot", inputSchema={})] - ) + assert len(captured_tools) == 2 + assert captured_tools[0].name == "tool1" - agent = MockAgentExtended( - mcp_client=mock_client, - responses=[{"role": "assistant", "content": "Done", "tool_calls": []}], - ) - # Initialize to make tools available - await agent.initialize() - - result = await agent.run("Test lifecycle", max_steps=1) - assert not result.isError - - # This test is commented out as screenshot history management may have changed - # @pytest.mark.asyncio - # async def test_screenshot_history_management(self, agent_with_tools): - # """Test screenshot history is maintained.""" - # agent_with_tools.initial_screenshot = True - - # # Set up responses with tool calls - # agent_with_tools.responses = [ - # { - # "role": "assistant", - # "content": "Action 1", - # "tool_calls": [MCPToolCall(name="click", arguments={"x": 1, "y": 1})], - # }, - # { - # "role": "assistant", - # "content": "Action 2", - # "tool_calls": [MCPToolCall(name="click", arguments={"x": 2, "y": 2})], - # }, - # { - # "role": "assistant", - # "content": "Action 3", - # "tool_calls": [MCPToolCall(name="click", arguments={"x": 3, "y": 3})], - # }, - # ] - - # await agent_with_tools.run("Test screenshots", max_steps=3) - - # # Should have screenshots in history - # assert len(agent_with_tools.screenshot_history) > 0 +class TestMCPAgentToolSchemas: + """Tests for tool schema generation.""" @pytest.mark.asyncio - async def test_run_with_invalid_prompt_type(self, agent_with_tools): - """Test run with invalid prompt type raises TypeError.""" - with pytest.raises(TypeError, match="prompt_or_task must be str or Task"): - await agent_with_tools.run(123) # Invalid type - - @pytest.mark.asyncio - async def test_evaluate_phase_with_multiple_tools(self, agent_with_tools): - """Test evaluation phase with multiple evaluation tools.""" - from hud.types import MCPToolResult - - task = Task( - id="test_task", - prompt="Test evaluation", - mcp_config={"test_server": {"url": "http://localhost:8080"}}, - evaluate_tool=[ - MCPToolCall(name="eval1", arguments={}), - MCPToolCall(name="eval2", arguments={"reward": True}), - ], - ) - - agent_with_tools.responses = [{"role": "assistant", "content": "Done", "tool_calls": []}] - - eval_calls = [] - agent_with_tools.mcp_client.call_tool = AsyncMock( - side_effect=lambda tool_call: eval_calls.append(tool_call) - or MCPToolResult( - content=[types.TextContent(type="text", text=f"{tool_call.name} result")], - isError=False, - structuredContent={"reward": 0.5} if tool_call.name == "eval1" else 
{"reward": 1.0}, + async def test_get_tool_schemas(self) -> None: + """Test get_tool_schemas returns correct format.""" + tools = [ + types.Tool( + name="my_tool", + description="My tool description", + inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, ) - ) - - result = await agent_with_tools.run(task) - - # Check that the tool names match - eval_names = [call.name for call in eval_calls] - assert "eval1" in eval_names - assert "eval2" in eval_names - assert result.reward == 0.5 # From eval1 (first evaluation tool) - - @pytest.mark.asyncio - async def test_trace_population_on_error(self, agent_with_tools): - """Test that trace is populated on task execution error.""" - - task = Task( - id="test_task", - prompt="Test error", - mcp_config={"test_server": {"url": "http://localhost:8080"}}, - setup_tool={"name": "failing_setup", "arguments": {}}, # type: ignore[arg-type] - ) - - # Make setup fail with exception - agent_with_tools.mcp_client.call_tool = AsyncMock(side_effect=Exception("Setup explosion")) + ] + ctx = MockEvalContext(prompt="Do something", tools=tools) + agent = MockMCPAgent(auto_trace=False) - result = await agent_with_tools.run(task) + # Initialize agent + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - assert result.isError - # Error content is the string representation of the MCPToolResult list - assert "Setup explosion" in result.content - assert "MCPToolResult" in result.content - assert result.done + schemas = agent.get_tool_schemas() + assert len(schemas) == 1 + assert schemas[0]["name"] == "my_tool" + assert schemas[0]["description"] == "My tool description" diff --git a/hud/agents/tests/test_base_runtime.py b/hud/agents/tests/test_base_runtime.py index 2ea24756..83502e31 100644 --- a/hud/agents/tests/test_base_runtime.py +++ b/hud/agents/tests/test_base_runtime.py @@ -1,16 +1,16 @@ +"""Runtime tests for MCPAgent base class.""" + from __future__ import annotations from typing import Any -from unittest import mock import mcp.types as types import pytest from hud.agents.base import BaseCreateParams, MCPAgent, find_content, find_reward, text_to_blocks +from hud.eval.context import EvalContext from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult -from .conftest import MockMCPClient - class DummyConfig(BaseAgentConfig): model_name: str = "DummyAgent" @@ -21,43 +21,69 @@ class DummyCreateParams(BaseCreateParams, DummyConfig): pass +class MockEvalContext(EvalContext): + """Mock EvalContext for testing.""" + + def __init__( + self, + prompt: str = "Test prompt", + tools: list[types.Tool] | None = None, + ) -> None: + self.prompt = prompt + self._tools = tools or [] + self._submitted: str | None = None + self.reward: float | None = None + self._call_tool_handler: Any = None + + def set_call_tool_handler(self, handler: Any) -> None: + self._call_tool_handler = handler + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + if self._call_tool_handler: + # Parse the call + if isinstance(call, tuple): + tc = MCPToolCall(name=call[0], arguments=call[1] if len(call) > 1 else {}) + elif hasattr(call, "name"): + tc = call + else: + tc = MCPToolCall(name=str(call), arguments=kwargs) + return self._call_tool_handler(tc) + return MCPToolResult( + content=[types.TextContent(type="text", text="ok")], + isError=False, + ) + + async def submit(self, answer: str) -> None: + self._submitted = answer + + class DummyAgent(MCPAgent): config_cls 
= DummyConfig def __init__(self, **kwargs: Any) -> None: - # Only create MockMCPClient if mcp_client not specified at all - if "mcp_client" not in kwargs: - kwargs["mcp_client"] = MockMCPClient() params = DummyCreateParams(**kwargs) super().__init__(params) async def get_system_messages(self) -> list[types.ContentBlock]: return [types.TextContent(type="text", text="sys")] - async def get_response(self, messages): - # Single step: no tool calls -> done + async def get_response(self, messages: list[Any]) -> AgentResponse: return AgentResponse(content="ok", tool_calls=[], done=True) - async def format_blocks(self, blocks): - # Return as-is + async def format_blocks(self, blocks: list[Any]) -> list[Any]: return blocks - async def format_tool_results(self, tool_calls, tool_results): + async def format_tool_results( + self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] + ) -> list[Any]: return [types.TextContent(text="tools", type="text")] -@pytest.mark.asyncio -async def test_run_with_string_prompt_auto_client(monkeypatch): - fake_client = MockMCPClient() - - # Patch MCPClient construction inside initialize() - with mock.patch("hud.clients.MCPClient", return_value=fake_client): - agent = DummyAgent(mcp_client=fake_client, auto_trace=False) - result = await agent.run("hello", max_steps=1) - assert result.done is True and result.isError is False - - -def test_find_reward_and_content_extractors(): +def test_find_reward_and_content_extractors() -> None: + """Test reward and content extraction from tool results.""" # Structured content r = MCPToolResult( content=text_to_blocks("{}"), isError=False, structuredContent={"reward": 0.7} @@ -70,108 +96,108 @@ def test_find_reward_and_content_extractors(): assert find_content(r2) == "hi" -@pytest.mark.asyncio -async def test_call_tools_error_paths(): - call_count = [0] - ok_result = MCPToolResult(content=text_to_blocks("ok"), isError=False) - - def handler(tool_call: MCPToolCall) -> MCPToolResult: - call_count[0] += 1 - if call_count[0] == 1: - return ok_result - raise RuntimeError("boom") - - fake_client = MockMCPClient(call_tool_handler=handler) - agent = DummyAgent(mcp_client=fake_client, auto_trace=False) - results = await agent.call_tools( - [MCPToolCall(name="a", arguments={}), MCPToolCall(name="b", arguments={})] - ) - assert results[0].isError is False - assert results[1].isError is True - - -@pytest.mark.asyncio -async def test_initialize_without_client_raises_valueerror(): - agent = DummyAgent(mcp_client=None, auto_trace=False) - with pytest.raises(ValueError): - await agent.initialize(None) - - -def test_get_available_tools_before_initialize_raises(): - agent = DummyAgent(mcp_client=MockMCPClient(), auto_trace=False) +def test_get_available_tools_before_run_raises() -> None: + """Test that get_available_tools raises before initialization.""" + agent = DummyAgent(auto_trace=False) with pytest.raises(RuntimeError): agent.get_available_tools() @pytest.mark.asyncio -async def test_format_message_invalid_type_raises(): - agent = DummyAgent(mcp_client=MockMCPClient(), auto_trace=False) +async def test_format_message_invalid_type_raises() -> None: + """Test that format_message raises for invalid types.""" + agent = DummyAgent(auto_trace=False) with pytest.raises(ValueError): await agent.format_message({"oops": 1}) # type: ignore -@pytest.mark.asyncio -async def test_call_tools_timeout_error_shutdown_called(): - def handler(tool_call: MCPToolCall) -> MCPToolResult: - raise TimeoutError("timeout") - - fake_client = 
MockMCPClient(call_tool_handler=handler) - agent = DummyAgent(mcp_client=fake_client, auto_trace=False) - with pytest.raises(TimeoutError): - await agent.call_tools(MCPToolCall(name="x", arguments={})) - assert fake_client.shutdown_called - - -def test_text_to_blocks_shapes(): +def test_text_to_blocks_shapes() -> None: + """Test text_to_blocks returns correct structure.""" blocks = text_to_blocks("x") assert isinstance(blocks, list) and blocks and isinstance(blocks[0], types.TextContent) @pytest.mark.asyncio -async def test_run_returns_connection_error_trace(monkeypatch): - fake_client = MockMCPClient( - initialize_error=RuntimeError("Connection refused http://localhost:1234") - ) +async def test_run_with_eval_context() -> None: + """Test basic run() with EvalContext.""" + ctx = MockEvalContext(prompt="hello") + agent = DummyAgent(auto_trace=False) + result = await agent.run(ctx, max_steps=1) + assert result.done is True + assert result.isError is False + - class DummyCM: - def __exit__(self, *args, **kwargs): - return False +@pytest.mark.asyncio +async def test_run_requires_eval_context() -> None: + """Test run() raises TypeError for non-EvalContext.""" + agent = DummyAgent(auto_trace=False) + with pytest.raises(TypeError, match="must be EvalContext"): + await agent.run("hello") # type: ignore - monkeypatch.setattr("hud.utils.mcp.setup_hud_telemetry", lambda *args, **kwargs: DummyCM()) - agent = DummyAgent(mcp_client=fake_client, auto_trace=False) - result = await agent.run("p", max_steps=1) - assert result.isError is True - assert "Could not connect" in (result.content or "") +@pytest.mark.asyncio +async def test_run_requires_prompt() -> None: + """Test run() raises ValueError when prompt is empty.""" + ctx = MockEvalContext(prompt="") + agent = DummyAgent(auto_trace=False) + with pytest.raises(ValueError, match="prompt is not set"): + await agent.run(ctx) @pytest.mark.asyncio -async def test_run_calls_response_tool_when_configured(monkeypatch): - ok = MCPToolResult(content=text_to_blocks("ok"), isError=False) - fake_client = MockMCPClient(call_tool_handler=lambda _: ok) +async def test_call_tools_error_paths() -> None: + """Test call_tools handles errors correctly.""" + call_count = [0] + ok_result = MCPToolResult(content=text_to_blocks("ok"), isError=False) - class DummyCM: - def __exit__(self, *args, **kwargs): - return False + def handler(tool_call: MCPToolCall) -> MCPToolResult: + call_count[0] += 1 + if call_count[0] == 1: + return ok_result + raise RuntimeError("boom") - monkeypatch.setattr("hud.utils.mcp.setup_hud_telemetry", lambda *args, **kwargs: DummyCM()) + ctx = MockEvalContext(prompt="test") + ctx.set_call_tool_handler(handler) + agent = DummyAgent(auto_trace=False) - agent = DummyAgent(mcp_client=fake_client, auto_trace=False, response_tool_name="submit") - result = await agent.run("hello", max_steps=1) - assert result.isError is False - assert len(fake_client.call_tool_calls) > 0 + # Initialize the agent with context + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) + + results = await agent.call_tools( + [MCPToolCall(name="a", arguments={}), MCPToolCall(name="b", arguments={})] + ) + assert results[0].isError is False + assert results[1].isError is True @pytest.mark.asyncio -async def test_get_available_tools_after_initialize(monkeypatch): - fake_client = MockMCPClient() +async def test_call_tools_timeout_raises() -> None: + """Test call_tools raises TimeoutError.""" + def handler(tool_call: MCPToolCall) -> MCPToolResult: + raise TimeoutError("timeout") - class 
DummyCM: - def __exit__(self, *args, **kwargs): - return False + ctx = MockEvalContext(prompt="test") + ctx.set_call_tool_handler(handler) + agent = DummyAgent(auto_trace=False) - monkeypatch.setattr("hud.utils.mcp.setup_hud_telemetry", lambda *args, **kwargs: DummyCM()) + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - agent = DummyAgent(mcp_client=fake_client, auto_trace=False) - await agent.initialize(None) - assert agent.get_available_tools() == [] + with pytest.raises(TimeoutError): + await agent.call_tools(MCPToolCall(name="x", arguments={})) + + +@pytest.mark.asyncio +async def test_get_available_tools_after_run() -> None: + """Test get_available_tools works after initialization.""" + tools = [types.Tool(name="test_tool", description="Test", inputSchema={})] + ctx = MockEvalContext(prompt="hello", tools=tools) + agent = DummyAgent(auto_trace=False) + + # Run initializes the agent + await agent.run(ctx, max_steps=1) + + # After cleanup, we can't access tools (ctx is cleared) + # But during run, tools were available + assert agent._initialized is True diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py index d84951df..0ac8c87c 100644 --- a/hud/agents/tests/test_claude.py +++ b/hud/agents/tests/test_claude.py @@ -2,12 +2,11 @@ from __future__ import annotations -from types import SimpleNamespace -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock, MagicMock, patch import pytest -from anthropic import AsyncAnthropic, BadRequestError +from anthropic import AsyncAnthropic from mcp import types from hud.agents.claude import ( @@ -16,41 +15,65 @@ text_to_content_block, tool_use_content_block, ) +from hud.eval.context import EvalContext from hud.types import MCPToolCall, MCPToolResult if TYPE_CHECKING: - from anthropic.types.beta import BetaImageBlockParam, BetaMessageParam, BetaTextBlockParam + from anthropic.types.beta import BetaImageBlockParam, BetaTextBlockParam + + +class MockEvalContext(EvalContext): + """Mock EvalContext for testing.""" + + def __init__(self, tools: list[types.Tool] | None = None) -> None: + self.prompt = "Test prompt" + self._tools = tools or [] + self._submitted: str | None = None + self.reward: float | None = None + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + return MCPToolResult( + content=[types.TextContent(type="text", text="ok")], + isError=False, + ) + + async def submit(self, answer: str) -> None: + self._submitted = answer class MockStreamContextManager: """Mock for Claude's streaming context manager.""" - def __init__(self, response: MagicMock): + def __init__(self, response: MagicMock) -> None: self.response = response - async def __aenter__(self): + async def __aenter__(self) -> MockStreamContextManager: return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, exc_type: type | None, exc_val: Exception | None, exc_tb: Any + ) -> bool: return False - def __aiter__(self): + def __aiter__(self) -> MockStreamContextManager: return self - async def __anext__(self): - # No events to yield, end iteration immediately + async def __anext__(self) -> None: raise StopAsyncIteration - async def get_final_message(self): + async def get_final_message(self) -> MagicMock: return self.response class TestClaudeHelperFunctions: """Test helper functions for Claude message formatting.""" - def test_base64_to_content_block(self): + 
def test_base64_to_content_block(self) -> None: """Test base64 image conversion.""" - base64_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" # noqa: E501 + base64_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk" result = base64_to_content_block(base64_data) assert result["type"] == "image" @@ -58,7 +81,7 @@ def test_base64_to_content_block(self): assert result["source"]["media_type"] == "image/png" assert result["source"]["data"] == base64_data - def test_text_to_content_block(self): + def test_text_to_content_block(self) -> None: """Test text conversion.""" text = "Hello, world!" result = text_to_content_block(text) @@ -66,7 +89,7 @@ def test_text_to_content_block(self): assert result["type"] == "text" assert result["text"] == text - def test_tool_use_content_block(self): + def test_tool_use_content_block(self) -> None: """Test tool result content block creation.""" tool_use_id = "tool_123" content: list[BetaTextBlockParam | BetaImageBlockParam] = [ @@ -84,294 +107,235 @@ class TestClaudeAgent: """Test ClaudeAgent class.""" @pytest.fixture - def mock_anthropic(self): - """Create a stub AsyncAnthropic client and patch constructor.""" - client = AsyncAnthropic(api_key="test_key") - client.__dict__["beta"] = SimpleNamespace(messages=AsyncMock()) - with patch("hud.agents.claude.AsyncAnthropic", return_value=client): + def mock_anthropic(self) -> AsyncAnthropic: + """Create a stub Anthropic client.""" + with patch("hud.agents.claude.AsyncAnthropic") as mock_class, patch( + "hud.agents.claude.Anthropic" + ) as mock_sync: + # Mock the sync client's models.list() for validation + mock_sync.return_value.models.list.return_value = [] + + client = MagicMock(spec=AsyncAnthropic) + client.api_key = "test-key" + mock_class.return_value = client yield client @pytest.mark.asyncio - async def test_init(self, mock_mcp_client, mock_anthropic): - """Test agent initialization.""" + async def test_init_with_client(self, mock_anthropic: AsyncAnthropic) -> None: + """Test agent initialization with provided client.""" agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, model_client=mock_anthropic, - checkpoint_name="claude-3-opus-20240229", - max_tokens=1000, - validate_api_key=False, # Skip validation in tests + checkpoint_name="claude-sonnet-4-20250514", + validate_api_key=False, ) assert agent.model_name == "Claude" - assert agent.max_tokens == 1000 + assert agent.config.checkpoint_name == "claude-sonnet-4-20250514" assert agent.anthropic_client == mock_anthropic @pytest.mark.asyncio - async def test_init_without_model_client(self, mock_mcp_client, mock_anthropic): - """Test agent initialization without model client.""" - with patch("hud.settings.settings.anthropic_api_key", "test_key"): - agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, - checkpoint_name="claude-3-opus-20240229", - validate_api_key=False, # Skip validation in tests - ) + async def test_init_with_parameters(self, mock_anthropic: AsyncAnthropic) -> None: + """Test agent initialization with various parameters.""" + agent = ClaudeAgent.create( + model_client=mock_anthropic, + checkpoint_name="claude-sonnet-4-20250514", + max_tokens=4096, + validate_api_key=False, + ) - assert agent.model_name == "Claude" - assert agent.anthropic_client is not None + assert agent.max_tokens == 4096 @pytest.mark.asyncio - async def test_format_blocks(self, mock_mcp_client, mock_anthropic): - """Test formatting content blocks into Claude messages.""" + async def 
test_format_blocks_text_only(self, mock_anthropic: AsyncAnthropic) -> None: + """Test formatting text content blocks.""" agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, model_client=mock_anthropic, - validate_api_key=False, # Skip validation in tests + validate_api_key=False, ) - # Test with text only - text_blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Hello, Claude!") + blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Hello, world!"), + types.TextContent(type="text", text="How are you?"), ] - messages = await agent.format_blocks(text_blocks) + + messages = await agent.format_blocks(blocks) assert len(messages) == 1 assert messages[0]["role"] == "user" - content = messages[0]["content"] - assert isinstance(content, list) - assert len(content) == 1 - assert content[0]["type"] == "text" - assert content[0]["text"] == "Hello, Claude!" + assert len(messages[0]["content"]) == 2 + assert messages[0]["content"][0]["type"] == "text" + assert messages[0]["content"][0]["text"] == "Hello, world!" - # Test with screenshot - image_blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Look at this"), + @pytest.mark.asyncio + async def test_format_blocks_with_image(self, mock_anthropic: AsyncAnthropic) -> None: + """Test formatting image content blocks.""" + agent = ClaudeAgent.create( + model_client=mock_anthropic, + validate_api_key=False, + ) + + blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Look at this:"), types.ImageContent(type="image", data="base64data", mimeType="image/png"), ] - messages = await agent.format_blocks(image_blocks) + + messages = await agent.format_blocks(blocks) assert len(messages) == 1 - assert messages[0]["role"] == "user" - content = messages[0]["content"] - assert isinstance(content, list) - assert len(content) == 2 - # Content blocks are in order - assert content[0]["type"] == "text" - assert content[0]["text"] == "Look at this" - assert content[1]["type"] == "image" - assert content[1]["source"]["data"] == "base64data" + assert len(messages[0]["content"]) == 2 + assert messages[0]["content"][1]["type"] == "image" @pytest.mark.asyncio - async def test_format_tool_results_method(self, mock_mcp_client, mock_anthropic): - """Test the agent's format_tool_results method.""" + async def test_format_tool_results_text(self, mock_anthropic: AsyncAnthropic) -> None: + """Test formatting tool results with text content.""" agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, model_client=mock_anthropic, - validate_api_key=False, # Skip validation in tests + validate_api_key=False, ) - tool_calls = [ - MCPToolCall(name="test_tool", arguments={}, id="id1"), - ] - + tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] tool_results = [ - MCPToolResult(content=[types.TextContent(type="text", text="Success")], isError=False), + MCPToolResult( + content=[types.TextContent(type="text", text="Tool output")], + isError=False, + ) ] messages = await agent.format_tool_results(tool_calls, tool_results) - - # format_tool_results returns a single user message with tool result content assert len(messages) == 1 assert messages[0]["role"] == "user" - # The content is wrapped in a tool result block - content = list(messages[0]["content"]) + content = messages[0]["content"] assert len(content) == 1 - assert content[0]["type"] == "tool_result" # type: ignore - assert content[0]["tool_use_id"] == "id1" # type: ignore - # The actual content is nested inside - 
inner_content = list(content[0]["content"]) # type: ignore - assert inner_content[0]["type"] == "text" # type: ignore - assert inner_content[0]["text"] == "Success" # type: ignore + assert content[0]["type"] == "tool_result" + assert content[0]["tool_use_id"] == "call_123" @pytest.mark.asyncio - async def test_get_response(self, mock_mcp_client, mock_anthropic): - """Test getting model response from Claude API.""" - # Disable telemetry for this test to avoid backend configuration issues - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_anthropic, - validate_api_key=False, # Skip validation in tests + async def test_format_tool_results_with_error(self, mock_anthropic: AsyncAnthropic) -> None: + """Test formatting tool results with error.""" + agent = ClaudeAgent.create( + model_client=mock_anthropic, + validate_api_key=False, + ) + + tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] + tool_results = [ + MCPToolResult( + content=[types.TextContent(type="text", text="Error message")], + isError=True, ) + ] - # Mock the API response - mock_response = MagicMock() - - # Create text block - text_block = MagicMock() - text_block.type = "text" - text_block.text = "Hello!" - - # Create tool use block - tool_block = MagicMock() - tool_block.type = "tool_use" - tool_block.id = "tool_123" - tool_block.name = "test_tool" - tool_block.input = {"param": "value"} - - mock_response.content = [text_block, tool_block] - mock_response.usage = MagicMock(input_tokens=10, output_tokens=20) - - # Mock the streaming context manager - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - messages = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - ) - ] - response = await agent.get_response(messages) - - assert response.content == "Hello!" 
- assert len(response.tool_calls) == 1 - assert response.tool_calls[0].name == "test_tool" - assert response.tool_calls[0].arguments == {"param": "value"} - # The test was checking for Claude-specific attributes that aren't part of ModelResponse - # These would need to be accessed from the original Claude response if needed - - # Verify API was called correctly - mock_anthropic.beta.messages.stream.assert_called_once() + messages = await agent.format_tool_results(tool_calls, tool_results) + assert len(messages) == 1 + content = messages[0]["content"] + # Error content should include "Error:" prefix + assert any("Error" in str(block) for block in content[0]["content"]) @pytest.mark.asyncio - async def test_get_model_response_text_only(self, mock_mcp_client, mock_anthropic): - """Test getting text-only response.""" - # Disable telemetry for this test to avoid backend configuration issues - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_anthropic, - validate_api_key=False, # Skip validation in tests - ) + async def test_get_system_messages(self, mock_anthropic: AsyncAnthropic) -> None: + """Test that system messages return empty (Claude uses system param).""" + agent = ClaudeAgent.create( + model_client=mock_anthropic, + system_prompt="You are a helpful assistant.", + validate_api_key=False, + ) - mock_response = MagicMock() - # Create text block - text_block = MagicMock() - text_block.type = "text" - text_block.text = "Just text" - mock_response.content = [text_block] - mock_response.usage = MagicMock(input_tokens=5, output_tokens=10) - - # Mock the streaming context manager - mock_stream = MockStreamContextManager(mock_response) - mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) - - messages = [ - cast( - "BetaMessageParam", - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - ) - ] - response = await agent.get_response(messages) - - assert response.content == "Just text" - assert response.tool_calls == [] + messages = await agent.get_system_messages() + # Claude doesn't use system messages in the message list + assert messages == [] @pytest.mark.asyncio - async def test_get_model_response_error(self, mock_mcp_client, mock_anthropic): - """Test handling API errors.""" - # Disable telemetry for this test to avoid backend configuration issues - with patch("hud.settings.settings.telemetry_enabled", False): - agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_anthropic, - validate_api_key=False, # Skip validation in tests + async def test_convert_tools_for_claude(self, mock_anthropic: AsyncAnthropic) -> None: + """Test converting MCP tools to Claude format.""" + tools = [ + types.Tool( + name="my_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, ) + ] + ctx = MockEvalContext(tools=tools) + agent = ClaudeAgent.create( + model_client=mock_anthropic, + validate_api_key=False, + ) + + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - # Mock API error - stream() raises when entering context - error = BadRequestError( - message="Invalid request", - response=MagicMock(status_code=400), - body={"error": {"message": "Invalid request"}}, + # Check that tools were converted + assert len(agent.claude_tools) == 1 + assert agent.claude_tools[0]["name"] == "my_tool" + + @pytest.mark.asyncio + async def test_computer_tool_detection(self, mock_anthropic: AsyncAnthropic) -> None: + 
"""Test that computer tools are detected for beta API.""" + tools = [ + types.Tool( + name="computer", + description="Control computer", + inputSchema={"type": "object"}, ) + ] + ctx = MockEvalContext(tools=tools) + agent = ClaudeAgent.create( + model_client=mock_anthropic, + validate_api_key=False, + ) - class MockErrorStreamContextManager: - """Mock stream that raises error on enter.""" + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - async def __aenter__(self): - raise error + assert agent.has_computer_tool is True - async def __aexit__(self, exc_type, exc_val, exc_tb): - return False + @pytest.mark.asyncio + async def test_get_response_with_text(self, mock_anthropic: AsyncAnthropic) -> None: + """Test getting response with text output.""" + # Create mock response + mock_response = MagicMock() + mock_response.content = [MagicMock(type="text", text="Hello!")] - mock_anthropic.beta.messages.stream = MagicMock( - return_value=MockErrorStreamContextManager() - ) + mock_stream = MockStreamContextManager(mock_response) + mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) + + agent = ClaudeAgent.create( + model_client=mock_anthropic, + validate_api_key=False, + ) + agent.claude_tools = [] + agent.tool_mapping = {} + agent.has_computer_tool = False + agent._initialized = True - messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}] - - with pytest.raises(BadRequestError): - await agent.get_response(messages) # type: ignore - - # This test is commented out as it's testing complex integration scenarios - # that may have changed in the implementation - # @pytest.mark.asyncio - # async def test_run_with_tools(self, mock_mcp_client, mock_anthropic): - # """Test running agent with tool usage.""" - # # Disable telemetry for this test to avoid backend configuration issues - # with patch("hud.settings.settings.telemetry_enabled", False): - # agent = ClaudeAgent.create(mcp_client=mock_mcp_client, model_client=mock_anthropic) - - # # Mock tool availability - # agent._available_tools = [ - # types.Tool( - # name="calculator", description="Calculator", inputSchema={"type": "object"} - # ) - # ] - # agent._tool_map = { - # "calculator": types.Tool( - # name="calculator", description="Calculator", inputSchema={"type": "object"} - # ) - # } - - # # Mock initial response with tool use - # initial_response = MagicMock() - # # Create tool use block - # tool_block = MagicMock() - # tool_block.type = "tool_use" - # tool_block.id = "calc_123" - # tool_block.name = "calculator" - # tool_block.input = {"operation": "add", "a": 2, "b": 3} - # initial_response.content = [tool_block] - # initial_response.usage = MagicMock(input_tokens=10, output_tokens=15) - - # # Mock follow-up response - # final_response = MagicMock() - # text_block = MagicMock() - # text_block.type = "text" - # text_block.text = "2 + 3 = 5" - # final_response.content = [text_block] - # final_response.usage = MagicMock(input_tokens=20, output_tokens=10) - - # mock_anthropic.beta.messages.create = AsyncMock( - # side_effect=[initial_response, final_response] - # ) - - # # Mock tool execution - # mock_mcp_client.call_tool = AsyncMock( - # return_value=MCPToolResult( - # content=[types.TextContent(type="text", text="5")], isError=False - # ) - # ) - - # # Mock the mcp_client properties - # mock_mcp_client.mcp_config = {"test_server": {"url": "http://localhost"}} - # mock_mcp_client.list_tools = AsyncMock(return_value=agent._available_tools) - # mock_mcp_client.initialize = AsyncMock() - - # # 
Initialize the agent - # await agent.initialize() - - # # Use a string prompt instead of a task - # result = await agent.run("What is 2 + 3?") - - # assert result.content == "2 + 3 = 5" - # assert result.done is True + response = await agent.get_response([]) + assert response.content == "Hello!" + assert response.done is True + assert len(response.tool_calls) == 0 + + @pytest.mark.asyncio + async def test_get_response_with_tool_call(self, mock_anthropic: AsyncAnthropic) -> None: + """Test getting response with tool call.""" + mock_tool_use = MagicMock() + mock_tool_use.type = "tool_use" + mock_tool_use.id = "call_123" + mock_tool_use.name = "my_tool" + mock_tool_use.input = {"x": "value"} + + mock_response = MagicMock() + mock_response.content = [mock_tool_use] + + mock_stream = MockStreamContextManager(mock_response) + mock_anthropic.beta.messages.stream = MagicMock(return_value=mock_stream) + + agent = ClaudeAgent.create( + model_client=mock_anthropic, + validate_api_key=False, + ) + agent.claude_tools = [] + agent.tool_mapping = {"my_tool": "my_tool"} + agent.has_computer_tool = False + agent._initialized = True + + response = await agent.get_response([]) + assert response.done is False + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "my_tool" + assert response.tool_calls[0].arguments == {"x": "value"} diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py index 0a98ecd9..242ce725 100644 --- a/hud/agents/tests/test_gemini.py +++ b/hud/agents/tests/test_gemini.py @@ -3,6 +3,7 @@ from __future__ import annotations import base64 +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -11,15 +12,37 @@ from mcp import types from hud.agents.gemini import GeminiAgent -from hud.agents.gemini_cua import GeminiCUAAgent +from hud.eval.context import EvalContext from hud.types import MCPToolCall, MCPToolResult +class MockEvalContext(EvalContext): + """Mock EvalContext for testing.""" + + def __init__(self, tools: list[types.Tool] | None = None) -> None: + self.prompt = "Test prompt" + self._tools = tools or [] + self._submitted: str | None = None + self.reward: float | None = None + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + return MCPToolResult( + content=[types.TextContent(type="text", text="ok")], + isError=False, + ) + + async def submit(self, answer: str) -> None: + self._submitted = answer + + class TestGeminiAgent: """Test GeminiAgent base class.""" @pytest.fixture - def mock_gemini_client(self): + def mock_gemini_client(self) -> genai.Client: """Create a stub Gemini client.""" client = genai.Client(api_key="test_key") client.models.list = MagicMock(return_value=iter([])) @@ -27,13 +50,12 @@ def mock_gemini_client(self): return client @pytest.mark.asyncio - async def test_init(self, mock_mcp_client, mock_gemini_client): + async def test_init(self, mock_gemini_client: genai.Client) -> None: """Test agent initialization.""" agent = GeminiAgent.create( - mcp_client=mock_mcp_client, model_client=mock_gemini_client, checkpoint_name="gemini-2.5-flash", - validate_api_key=False, # Skip validation in tests + validate_api_key=False, ) assert agent.model_name == "Gemini" @@ -41,7 +63,7 @@ async def test_init(self, mock_mcp_client, mock_gemini_client): assert agent.gemini_client == mock_gemini_client @pytest.mark.asyncio - async def test_init_without_model_client(self, mock_mcp_client): + async def 
test_init_without_model_client(self) -> None: """Test agent initialization without model client.""" with ( patch("hud.settings.settings.gemini_api_key", "test_key"), @@ -54,509 +76,164 @@ async def test_init_without_model_client(self, mock_mcp_client): mock_client_class.return_value = mock_client agent = GeminiAgent.create( - mcp_client=mock_mcp_client, checkpoint_name="gemini-2.5-flash", validate_api_key=False, ) - assert agent.model_name == "Gemini" assert agent.gemini_client is not None @pytest.mark.asyncio - async def test_format_blocks(self, mock_mcp_client, mock_gemini_client): - """Test formatting content blocks into Gemini messages.""" + async def test_format_blocks_text_only(self, mock_gemini_client: genai.Client) -> None: + """Test formatting text content blocks.""" agent = GeminiAgent.create( - mcp_client=mock_mcp_client, model_client=mock_gemini_client, validate_api_key=False, ) - # Test with text only - text_blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Hello, Gemini!") - ] - messages = await agent.format_blocks(text_blocks) - assert len(messages) == 1 - assert messages[0].role == "user" - parts = messages[0].parts - assert parts is not None - assert len(parts) == 1 - assert parts[0].text == "Hello, Gemini!" - - # Test with screenshot - image_blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Look at this"), - types.ImageContent( - type="image", - data=base64.b64encode(b"fakeimage").decode("utf-8"), - mimeType="image/png", - ), + blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Hello, world!"), + types.TextContent(type="text", text="How are you?"), ] - messages = await agent.format_blocks(image_blocks) + + messages = await agent.format_blocks(blocks) assert len(messages) == 1 assert messages[0].role == "user" - parts = messages[0].parts - assert parts is not None - assert len(parts) == 2 - # First part is text - assert parts[0].text == "Look at this" - # Second part is image - check that it was created from bytes - assert parts[1].inline_data is not None + assert len(messages[0].parts) == 2 @pytest.mark.asyncio - async def test_format_tool_results(self, mock_mcp_client, mock_gemini_client): - """Test the agent's format_tool_results method for non-computer tools.""" + async def test_format_blocks_with_image(self, mock_gemini_client: genai.Client) -> None: + """Test formatting image content blocks.""" agent = GeminiAgent.create( - mcp_client=mock_mcp_client, model_client=mock_gemini_client, validate_api_key=False, ) - tool_calls = [ - MCPToolCall( - name="calculator", - arguments={"operation": "add", "a": 1, "b": 2}, - id="call_1", # type: ignore - gemini_name="calculator", # type: ignore - ), - ] + # Create a tiny valid base64 PNG + png_data = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode() - tool_results = [ - MCPToolResult( - content=[ - types.TextContent(type="text", text="Result: 3"), - ], - isError=False, - ), + blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Look at this:"), + types.ImageContent(type="image", data=png_data, mimeType="image/png"), ] - messages = await agent.format_tool_results(tool_calls, tool_results) - - # format_tool_results returns a single user message with function responses + messages = await agent.format_blocks(blocks) assert len(messages) == 1 - assert messages[0].role == "user" - # The content contains function response parts - parts = messages[0].parts - assert parts is not None - assert len(parts) == 1 - function_response = 
parts[0].function_response - assert function_response is not None - assert function_response.name == "calculator" - response_payload = function_response.response or {} - assert response_payload.get("success") is True - assert response_payload.get("output") == "Result: 3" + assert len(messages[0].parts) == 2 @pytest.mark.asyncio - async def test_format_tool_results_with_error(self, mock_mcp_client, mock_gemini_client): - """Test formatting tool results with errors.""" + async def test_format_tool_results(self, mock_gemini_client: genai.Client) -> None: + """Test formatting tool results.""" agent = GeminiAgent.create( - mcp_client=mock_mcp_client, model_client=mock_gemini_client, validate_api_key=False, ) - tool_calls = [ - MCPToolCall( - name="calculator", - arguments={"operation": "divide", "a": 1, "b": 0}, - id="call_error", # type: ignore - gemini_name="calculator", # type: ignore - ), - ] - + tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] tool_results = [ MCPToolResult( - content=[types.TextContent(type="text", text="Division by zero error")], - isError=True, - ), + content=[types.TextContent(type="text", text="Tool output")], + isError=False, + ) ] messages = await agent.format_tool_results(tool_calls, tool_results) - - # Check that error is in the response assert len(messages) == 1 assert messages[0].role == "user" - parts = messages[0].parts - assert parts is not None - function_response = parts[0].function_response - assert function_response is not None - response_payload = function_response.response or {} - assert "error" in response_payload - - @pytest.mark.asyncio - async def test_get_response_text_only(self, mock_mcp_client, mock_gemini_client): - """Test getting text-only response.""" - # Disable telemetry for this test - with patch("hud.settings.settings.telemetry_enabled", False): - agent = GeminiAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_gemini_client, - validate_api_key=False, - ) - - # Mock the API response with text only - mock_response = MagicMock() - mock_candidate = MagicMock() - - text_part = MagicMock() - text_part.text = "Task completed successfully" - text_part.function_call = None - - mock_candidate.content = MagicMock() - mock_candidate.content.parts = [text_part] - - mock_response.candidates = [mock_candidate] - - mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) - - messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Status?")])] - response = await agent.get_response(messages) - - assert response.content == "Task completed successfully" - assert response.tool_calls == [] - assert response.done is True @pytest.mark.asyncio - async def test_convert_tools_for_gemini(self, mock_mcp_client, mock_gemini_client): - """Test converting MCP tools to Gemini format.""" + async def test_get_system_messages(self, mock_gemini_client: genai.Client) -> None: + """Test that system messages return empty (Gemini uses system_instruction).""" agent = GeminiAgent.create( - mcp_client=mock_mcp_client, model_client=mock_gemini_client, + system_prompt="You are a helpful assistant.", validate_api_key=False, ) - # Set up available tools (no computer tool for base agent) - agent._available_tools = [ - types.Tool( - name="calculator", - description="Calculator tool", - inputSchema={ - "type": "object", - "properties": {"operation": {"type": "string"}}, - }, - ), - types.Tool( - name="weather", - description="Weather tool", - inputSchema={ - "type": "object", - "properties": {"location": 
{"type": "string"}}, - }, - ), - ] - - gemini_tools = agent._convert_tools_for_gemini() - - # Should have 2 function declaration tools - assert len(gemini_tools) == 2 - - # Both should be function declarations - assert gemini_tools[0].function_declarations is not None # type: ignore[reportAttributeAccessIssue] - assert len(gemini_tools[0].function_declarations) == 1 # type: ignore[reportAttributeAccessIssue] - assert gemini_tools[0].function_declarations[0].name == "calculator" # type: ignore[reportAttributeAccessIssue] - - assert gemini_tools[1].function_declarations is not None # type: ignore[reportAttributeAccessIssue] - assert len(gemini_tools[1].function_declarations) == 1 # type: ignore[reportAttributeAccessIssue] - assert gemini_tools[1].function_declarations[0].name == "weather" # type: ignore[reportAttributeAccessIssue] + messages = await agent.get_system_messages() + # Gemini doesn't use system messages in the message list + assert messages == [] @pytest.mark.asyncio - async def test_create_user_message(self, mock_mcp_client, mock_gemini_client): - """Test creating a user message.""" + async def test_convert_tools_for_gemini(self, mock_gemini_client: genai.Client) -> None: + """Test converting MCP tools to Gemini format.""" + tools = [ + types.Tool( + name="my_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, + ) + ] + ctx = MockEvalContext(tools=tools) agent = GeminiAgent.create( - mcp_client=mock_mcp_client, model_client=mock_gemini_client, validate_api_key=False, ) - message = await agent.create_user_message("Hello Gemini") + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - assert message.role == "user" - parts = message.parts - assert parts is not None - assert len(parts) == 1 - assert parts[0].text == "Hello Gemini" - - @pytest.mark.asyncio - async def test_handle_empty_response(self, mock_mcp_client, mock_gemini_client): - """Test handling empty response from API.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = GeminiAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_gemini_client, - validate_api_key=False, - ) - - # Mock empty response - mock_response = MagicMock() - mock_response.candidates = [] - - mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) - - messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Hi")])] - response = await agent.get_response(messages) - - assert response.content == "" - assert response.tool_calls == [] - assert response.done is True + # Check that tools were converted + assert len(agent.gemini_tools) == 1 + assert agent.gemini_tools[0]["name"] == "my_tool" -class TestGeminiCUAAgent: - """Test GeminiCUAAgent computer use agent.""" +class TestGeminiToolConversion: + """Tests for tool conversion to Gemini format.""" @pytest.fixture - def mock_gemini_client(self): + def mock_gemini_client(self) -> genai.Client: """Create a stub Gemini client.""" client = genai.Client(api_key="test_key") client.models.list = MagicMock(return_value=iter([])) - client.models.generate_content = MagicMock() return client @pytest.mark.asyncio - async def test_init(self, mock_mcp_client_gemini_computer, mock_gemini_client): - """Test agent initialization.""" - agent = GeminiCUAAgent.create( - mcp_client=mock_mcp_client_gemini_computer, - model_client=mock_gemini_client, - checkpoint_name="gemini-2.5-computer-use-preview", - validate_api_key=False, # Skip validation in tests - ) - - assert agent.model_name == 
"GeminiCUA" - assert agent.config.checkpoint_name == "gemini-2.5-computer-use-preview" - assert agent.gemini_client == mock_gemini_client - - @pytest.mark.asyncio - async def test_format_tool_results_with_screenshot( - self, mock_mcp_client_gemini_computer, mock_gemini_client - ): - """Test the agent's format_tool_results method with screenshots.""" - agent = GeminiCUAAgent.create( - mcp_client=mock_mcp_client_gemini_computer, - model_client=mock_gemini_client, - validate_api_key=False, - ) - - tool_calls = [ - MCPToolCall( - name="gemini_computer", - arguments={"action": "click_at", "x": 100, "y": 200}, - id="call_1", # type: ignore - gemini_name="click_at", # type: ignore - ), - ] - - tool_results = [ - MCPToolResult( - content=[ - types.TextContent(type="text", text="__URL__:https://example.com"), - types.ImageContent( - type="image", - data=base64.b64encode(b"screenshot").decode("utf-8"), - mimeType="image/png", - ), - ], - isError=False, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - # format_tool_results returns a single user message with function responses - assert len(messages) == 1 - assert messages[0].role == "user" - # The content contains function response parts - parts = messages[0].parts - assert parts is not None - assert len(parts) == 1 - function_response = parts[0].function_response - assert function_response is not None - assert function_response.name == "click_at" - response_payload = function_response.response or {} - assert response_payload.get("success") is True - assert response_payload.get("url") == "https://example.com" - - @pytest.mark.asyncio - async def test_format_tool_results_with_error( - self, mock_mcp_client_gemini_computer, mock_gemini_client - ): - """Test formatting tool results with errors.""" - agent = GeminiCUAAgent.create( - mcp_client=mock_mcp_client_gemini_computer, - model_client=mock_gemini_client, - validate_api_key=False, - ) - - tool_calls = [ - MCPToolCall( - name="gemini_computer", - arguments={"action": "invalid"}, - id="call_error", # type: ignore - gemini_name="click_at", # type: ignore - ), - ] - - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Action failed: invalid action")], - isError=True, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - # Check that error is in the response - assert len(messages) == 1 - assert messages[0].role == "user" - parts = messages[0].parts - assert parts is not None - function_response = parts[0].function_response - assert function_response is not None - response_payload = function_response.response or {} - assert "error" in response_payload - - @pytest.mark.asyncio - async def test_get_response(self, mock_mcp_client_gemini_computer, mock_gemini_client): - """Test getting model response from Gemini API.""" - # Disable telemetry for this test - with patch("hud.settings.settings.telemetry_enabled", False): - agent = GeminiCUAAgent.create( - mcp_client=mock_mcp_client_gemini_computer, - model_client=mock_gemini_client, - validate_api_key=False, - ) - - # Set up available tools - agent._available_tools = [ - types.Tool(name="gemini_computer", description="Computer tool", inputSchema={}) - ] - - # Mock the API response - mock_response = MagicMock() - mock_candidate = MagicMock() - - # Create text part - text_part = MagicMock() - text_part.text = "I will click at coordinates" - text_part.function_call = None - - # Create function call part - function_call_part = MagicMock() - function_call_part.text = 
None - function_call_part.function_call = MagicMock() - function_call_part.function_call.name = "click_at" - function_call_part.function_call.args = {"coordinate": [100, 200]} - - mock_candidate.content = MagicMock() - mock_candidate.content.parts = [text_part, function_call_part] - - mock_response.candidates = [mock_candidate] - - mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) - - messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Click")])] - response = await agent.get_response(messages) - - assert response.content == "I will click at coordinates" - assert len(response.tool_calls) == 1 - # Check normalized arguments - assert response.tool_calls[0].arguments == {"action": "click_at", "x": 100, "y": 200} - assert response.done is False - - @pytest.mark.asyncio - async def test_convert_tools_for_gemini( - self, mock_mcp_client_gemini_computer, mock_gemini_client - ): - """Test converting MCP tools to Gemini format.""" - agent = GeminiCUAAgent.create( - mcp_client=mock_mcp_client_gemini_computer, - model_client=mock_gemini_client, - validate_api_key=False, - ) - - # Set up available tools - agent._available_tools = [ + async def test_tool_with_properties(self, mock_gemini_client: genai.Client) -> None: + """Test tool with input properties.""" + tools = [ types.Tool( - name="gemini_computer", - description="Computer tool", - inputSchema={"type": "object"}, - ), - types.Tool( - name="calculator", - description="Calculator tool", + name="search", + description="Search the web", inputSchema={ "type": "object", - "properties": {"operation": {"type": "string"}}, + "properties": { + "query": {"type": "string", "description": "Search query"}, + "limit": {"type": "integer", "description": "Max results"}, + }, + "required": ["query"], }, - ), + ) ] - - gemini_tools = agent._convert_tools_for_gemini() - - # Should have 2 tools: computer_use and calculator - assert len(gemini_tools) == 2 - - # First should be computer use tool - assert gemini_tools[0].computer_use is not None # type: ignore[reportAttributeAccessIssue] - assert ( - gemini_tools[0].computer_use.environment == genai_types.Environment.ENVIRONMENT_BROWSER # type: ignore[reportAttributeAccessIssue] - ) - - # Second should be calculator as function declaration - assert gemini_tools[1].function_declarations is not None # type: ignore[reportAttributeAccessIssue] - assert len(gemini_tools[1].function_declarations) == 1 # type: ignore[reportAttributeAccessIssue] - assert gemini_tools[1].function_declarations[0].name == "calculator" # type: ignore[reportAttributeAccessIssue] - - @pytest.mark.asyncio - async def test_extract_tool_call_normalizes_coordinates( - self, mock_mcp_client_gemini_computer, mock_gemini_client - ): - """Test that _extract_tool_call normalizes coordinate arrays to x/y.""" - agent = GeminiCUAAgent.create( - mcp_client=mock_mcp_client_gemini_computer, + ctx = MockEvalContext(tools=tools) + agent = GeminiAgent.create( model_client=mock_gemini_client, validate_api_key=False, ) - # Set up tool mapping - agent._gemini_to_mcp_tool_map = {"click_at": "gemini_computer"} - - # Create a mock part with function call - part = MagicMock() - part.function_call = MagicMock() - part.function_call.name = "click_at" - part.function_call.args = {"coordinate": [150, 250]} + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - tool_call = agent._extract_tool_call(part) - - assert tool_call is not None - assert tool_call.name == "gemini_computer" - assert tool_call.arguments["action"] 
== "click_at" # type: ignore[reportAttributeAccessIssue] - assert tool_call.arguments["x"] == 150 # type: ignore[reportAttributeAccessIssue] - assert tool_call.arguments["y"] == 250 # type: ignore[reportAttributeAccessIssue] + assert len(agent.gemini_tools) == 1 + tool = agent.gemini_tools[0] + assert tool["name"] == "search" + assert "parameters" in tool @pytest.mark.asyncio - async def test_extract_tool_call_preserves_non_computer_args( - self, mock_mcp_client_gemini_computer, mock_gemini_client - ): - """Test that _extract_tool_call preserves arguments for non-computer tools.""" - agent = GeminiCUAAgent.create( - mcp_client=mock_mcp_client_gemini_computer, + async def test_tool_without_schema(self, mock_gemini_client: genai.Client) -> None: + """Test tool without input schema raises error.""" + tools = [ + types.Tool( + name="incomplete", + description=None, + inputSchema=None, + ) + ] + ctx = MockEvalContext(tools=tools) + agent = GeminiAgent.create( model_client=mock_gemini_client, validate_api_key=False, ) - # Set up tool mapping - agent._gemini_to_mcp_tool_map = {"calculator": "calculator"} - - # Create a mock part with function call for non-computer tool - part = MagicMock() - part.function_call = MagicMock() - part.function_call.name = "calculator" - part.function_call.args = {"operation": "add", "a": 1, "b": 2} - - tool_call = agent._extract_tool_call(part) - - assert tool_call is not None - assert tool_call.name == "calculator" - # Arguments should be passed as-is, no normalization - assert tool_call.arguments == {"operation": "add", "a": 1, "b": 2} + agent.ctx = ctx + with pytest.raises(ValueError, match="requires both a description"): + await agent._initialize_from_ctx(ctx) diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py index 26e27144..bc58eb58 100644 --- a/hud/agents/tests/test_openai.py +++ b/hud/agents/tests/test_openai.py @@ -2,8 +2,8 @@ from __future__ import annotations -from typing import Any, cast -from unittest.mock import AsyncMock, MagicMock, patch +from typing import Any +from unittest.mock import AsyncMock, patch import pytest from mcp import types @@ -18,14 +18,37 @@ from pydantic import AnyUrl from hud.agents.openai import OpenAIAgent +from hud.eval.context import EvalContext from hud.types import MCPToolCall, MCPToolResult +class MockEvalContext(EvalContext): + """Mock EvalContext for testing.""" + + def __init__(self, tools: list[types.Tool] | None = None) -> None: + self.prompt = "Test prompt" + self._tools = tools or [] + self._submitted: str | None = None + self.reward: float | None = None + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + return MCPToolResult( + content=[types.TextContent(type="text", text="ok")], + isError=False, + ) + + async def submit(self, answer: str) -> None: + self._submitted = answer + + class TestOpenAIAgent: """Test OpenAIAgent class.""" @pytest.fixture - def mock_openai(self): + def mock_openai(self) -> AsyncOpenAI: """Create a stub OpenAI client.""" with patch("hud.agents.openai.AsyncOpenAI") as mock_class: client = AsyncOpenAI(api_key="test", base_url="http://localhost") @@ -35,12 +58,10 @@ def mock_openai(self): yield client @pytest.mark.asyncio - async def test_init_with_client(self, mock_mcp_client): + async def test_init_with_client(self, mock_openai: AsyncOpenAI) -> None: """Test agent initialization with provided client.""" - mock_model_client = AsyncOpenAI(api_key="test", 
base_url="http://localhost") agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_model_client, + model_client=mock_openai, checkpoint_name="gpt-4o", validate_api_key=False, ) @@ -48,17 +69,15 @@ async def test_init_with_client(self, mock_mcp_client): assert agent.model_name == "OpenAI" assert agent.config.checkpoint_name == "gpt-4o" assert agent.checkpoint_name == "gpt-4o" - assert agent.openai_client == mock_model_client + assert agent.openai_client == mock_openai assert agent.max_output_tokens is None assert agent.temperature is None @pytest.mark.asyncio - async def test_init_with_parameters(self, mock_mcp_client): + async def test_init_with_parameters(self, mock_openai: AsyncOpenAI) -> None: """Test agent initialization with various parameters.""" - mock_model_client = AsyncOpenAI(api_key="test", base_url="http://localhost") agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_model_client, + model_client=mock_openai, checkpoint_name="gpt-4o", max_output_tokens=2048, temperature=0.7, @@ -75,18 +94,17 @@ async def test_init_with_parameters(self, mock_mcp_client): assert agent.parallel_tool_calls is True @pytest.mark.asyncio - async def test_init_without_client_no_api_key(self, mock_mcp_client): + async def test_init_without_client_no_api_key(self) -> None: """Test agent initialization fails without API key.""" with patch("hud.agents.openai.settings") as mock_settings: mock_settings.openai_api_key = None with pytest.raises(ValueError, match="OpenAI API key not found"): - OpenAIAgent.create(mcp_client=mock_mcp_client) + OpenAIAgent.create() @pytest.mark.asyncio - async def test_format_blocks_text_only(self, mock_mcp_client, mock_openai): + async def test_format_blocks_text_only(self, mock_openai: AsyncOpenAI) -> None: """Test formatting text content blocks.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) @@ -98,986 +116,300 @@ async def test_format_blocks_text_only(self, mock_mcp_client, mock_openai): messages = await agent.format_blocks(blocks) assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - assert msg["role"] == "user" - content = cast("list[dict[str, Any]]", msg["content"]) - assert len(content) == 2 - assert content[0] == {"type": "input_text", "text": "Hello, world!"} - assert content[1] == {"type": "input_text", "text": "How are you?"} + assert messages[0]["role"] == "user" + assert len(messages[0]["content"]) == 2 + assert messages[0]["content"][0]["type"] == "input_text" + assert messages[0]["content"][0]["text"] == "Hello, world!" 
@pytest.mark.asyncio - async def test_format_blocks_with_image(self, mock_mcp_client, mock_openai): - """Test formatting content blocks with images.""" + async def test_format_blocks_with_image(self, mock_openai: AsyncOpenAI) -> None: + """Test formatting image content blocks.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) blocks: list[types.ContentBlock] = [ - types.TextContent(type="text", text="Check this out:"), - types.ImageContent(type="image", data="base64imagedata", mimeType="image/jpeg"), + types.TextContent(type="text", text="Look at this:"), + types.ImageContent(type="image", data="base64data", mimeType="image/png"), ] messages = await agent.format_blocks(blocks) assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - assert msg["role"] == "user" - content = cast("list[dict[str, Any]]", msg["content"]) - assert len(content) == 2 - assert content[0] == {"type": "input_text", "text": "Check this out:"} - assert content[1] == { - "type": "input_image", - "image_url": "data:image/jpeg;base64,base64imagedata", - "detail": "auto", - } + assert len(messages[0]["content"]) == 2 + assert messages[0]["content"][1]["type"] == "input_image" + assert messages[0]["content"][1]["image_url"] == "data:image/png;base64,base64data" @pytest.mark.asyncio - async def test_format_blocks_empty(self, mock_mcp_client, mock_openai): + async def test_format_blocks_empty(self, mock_openai: AsyncOpenAI) -> None: """Test formatting empty content blocks.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) - blocks: list[types.ContentBlock] = [] - - messages = await agent.format_blocks(blocks) + messages = await agent.format_blocks([]) assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - assert msg["role"] == "user" - content = cast("list[dict[str, Any]]", msg["content"]) - assert len(content) == 1 - assert content[0] == {"type": "input_text", "text": ""} + assert messages[0]["content"] == [] @pytest.mark.asyncio - async def test_format_tool_results_text(self, mock_mcp_client, mock_openai): + async def test_format_tool_results_text(self, mock_openai: AsyncOpenAI) -> None: """Test formatting tool results with text content.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) - tool_calls = [ - MCPToolCall(name="test_tool", arguments={"arg": "value"}, id="call_123"), # type: ignore - ] - + tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] tool_results = [ MCPToolResult( - content=[types.TextContent(type="text", text="Tool executed successfully")], + content=[types.TextContent(type="text", text="Tool output")], isError=False, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - assert msg["type"] == "function_call_output" - assert msg["call_id"] == "call_123" - output = cast("list[dict[str, Any]]", msg["output"]) - assert len(output) == 1 - assert output[0]["type"] == "input_text" - assert output[0]["text"] == "Tool executed successfully" - - @pytest.mark.asyncio - async def test_format_tool_results_with_image(self, mock_mcp_client, mock_openai): - """Test formatting tool results with image content.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - tool_calls = [ - MCPToolCall(name="screenshot", 
arguments={}, id="call_456"), # type: ignore - ] - - tool_results = [ - MCPToolResult( - content=[ - types.ImageContent(type="image", data="screenshot_data", mimeType="image/png") - ], - isError=False, - ), + ) ] messages = await agent.format_tool_results(tool_calls, tool_results) - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - assert msg["type"] == "function_call_output" - assert msg["call_id"] == "call_456" - output = cast("list[dict[str, Any]]", msg["output"]) - assert len(output) == 1 - assert output[0]["type"] == "input_image" - assert output[0]["image_url"] == "data:image/png;base64,screenshot_data" + assert messages[0]["type"] == "function_call_output" + assert messages[0]["call_id"] == "call_123" + assert messages[0]["output"] == "Tool output" @pytest.mark.asyncio - async def test_format_tool_results_with_error(self, mock_mcp_client, mock_openai): - """Test formatting tool results with errors.""" + async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> None: + """Test formatting tool results with error.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) - tool_calls = [ - MCPToolCall(name="failing_tool", arguments={}, id="call_error"), # type: ignore - ] - + tool_calls = [MCPToolCall(id="call_123", name="test_tool", arguments={})] tool_results = [ MCPToolResult( - content=[types.TextContent(type="text", text="Error: Something went wrong")], + content=[types.TextContent(type="text", text="Error message")], isError=True, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - assert msg["type"] == "function_call_output" - assert msg["call_id"] == "call_error" - output = cast("list[dict[str, Any]]", msg["output"]) - assert len(output) == 2 - assert output[0]["type"] == "input_text" - assert output[0]["text"] == "[tool_error] true" - assert output[1]["type"] == "input_text" - assert output[1]["text"] == "Error: Something went wrong" - - @pytest.mark.asyncio - async def test_format_tool_results_with_structured_content(self, mock_mcp_client, mock_openai): - """Test formatting tool results with structured content.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - tool_calls = [ - MCPToolCall(name="data_tool", arguments={}, id="call_789"), # type: ignore - ] - - tool_results = [ - MCPToolResult( - content=[], - structuredContent={"key": "value", "number": 42}, - isError=False, - ), + ) ] messages = await agent.format_tool_results(tool_calls, tool_results) - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - assert msg["type"] == "function_call_output" - assert msg["call_id"] == "call_789" - output = cast("list[dict[str, Any]]", msg["output"]) - assert len(output) == 1 - assert output[0]["type"] == "input_text" - # Structured content is JSON serialized - import json - - parsed = json.loads(output[0]["text"]) - assert parsed == {"key": "value", "number": 42} - - @pytest.mark.asyncio - async def test_format_tool_results_multiple(self, mock_mcp_client, mock_openai): - """Test formatting multiple tool results.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - tool_calls = [ - MCPToolCall(name="tool1", arguments={}, id="call_1"), # type: ignore - MCPToolCall(name="tool2", arguments={}, id="call_2"), # type: ignore - 
] - - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Result 1")], - isError=False, - ), - MCPToolResult( - content=[types.TextContent(type="text", text="Result 2")], - isError=False, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - assert len(messages) == 2 - msg0 = cast("dict[str, Any]", messages[0]) - assert msg0["call_id"] == "call_1" - msg1 = cast("dict[str, Any]", messages[1]) - assert msg1["call_id"] == "call_2" - - @pytest.mark.asyncio - async def test_format_tool_results_missing_call_id(self, mock_mcp_client, mock_openai): - """Test formatting tool results with missing call_id.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - tool_calls = [ - MCPToolCall(name="tool_no_id", arguments={}, id=""), # Empty string instead of None - ] - - tool_results = [ - MCPToolResult( - content=[types.TextContent(type="text", text="Some result")], - isError=False, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) - - # Should skip tools without call_id (empty string is falsy) - assert len(messages) == 0 - - @pytest.mark.asyncio - async def test_get_response_with_text(self, mock_mcp_client, mock_openai): - """Test getting model response with text output.""" - # Disable telemetry for this test - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Mock OpenAI API response - mock_response = MagicMock() - mock_response.id = "response_123" - - # Create properly typed output text with all required fields - mock_output_text = ResponseOutputText( - type="output_text", - text="This is the response text", - annotations=[], # Required field - ) - - # Create properly typed output message with all required fields - mock_output_message = ResponseOutputMessage( - type="message", - id="msg_123", # Required field - role="assistant", # Required field - status="completed", # Required field - content=[mock_output_text], - ) - - mock_response.output = [mock_output_message] - - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - # Test with initial message - messages = [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}] - response = await agent.get_response(messages) - - assert response.content == "This is the response text" - assert response.done is True - assert response.tool_calls == [] - assert agent.last_response_id == "response_123" - - @pytest.mark.asyncio - async def test_get_response_with_tool_call(self, mock_mcp_client, mock_openai): - """Test getting model response with tool call.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Set up tool name map - agent._tool_name_map = {"test_tool": "test_tool"} - - # Mock OpenAI API response with properly typed function call - mock_response = MagicMock() - mock_response.id = "response_456" - - # Create properly typed function call with correct type value - mock_function_call = ResponseFunctionToolCall( - type="function_call", # Correct type value - call_id="call_123", - name="test_tool", - arguments='{"param": "value"}', - ) - - mock_response.output = [mock_function_call] - - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [ - {"role": "user", 
"content": [{"type": "input_text", "text": "Do something"}]} - ] - response = await agent.get_response(messages) - - assert response.done is False - assert len(response.tool_calls) == 1 - assert response.tool_calls[0].name == "test_tool" - assert response.tool_calls[0].id == "call_123" - assert response.tool_calls[0].arguments == {"param": "value"} - - @pytest.mark.asyncio - async def test_get_response_with_reasoning(self, mock_mcp_client, mock_openai): - """Test getting model response with reasoning.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Mock OpenAI API response with properly typed reasoning - mock_response = MagicMock() - mock_response.id = "response_789" - - # Create a properly typed reasoning item with all required fields - mock_summary = Summary( - type="summary_text", # Correct literal type value - text="Let me think about this...", - ) - - mock_reasoning = ResponseReasoningItem( - type="reasoning", - id="reasoning_1", # Required field - summary=[mock_summary], # Required field - status="completed", # Required field - ) - - # Create properly typed output message with all required fields - mock_output_text = ResponseOutputText( - type="output_text", - text="Final answer", - annotations=[], # Required field - ) - mock_output_message = ResponseOutputMessage( - type="message", - id="msg_789", # Required field - role="assistant", # Required field - status="completed", # Required field - content=[mock_output_text], - ) - - mock_response.output = [mock_reasoning, mock_output_message] - - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [ - {"role": "user", "content": [{"type": "input_text", "text": "Hard question"}]} - ] - response = await agent.get_response(messages) - - assert "Thinking: Let me think about this..." 
in response.content - assert "Final answer" in response.content - - @pytest.mark.asyncio - async def test_get_response_empty_messages(self, mock_mcp_client, mock_openai): - """Test getting model response with empty messages.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Mock empty response - mock_response = MagicMock() - mock_response.id = "response_empty" - mock_response.output = [] - - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [] - response = await agent.get_response(messages) - - assert response.content == "" - assert response.tool_calls == [] - - @pytest.mark.asyncio - async def test_get_response_no_new_messages_with_previous_id( - self, mock_mcp_client, mock_openai - ): - """Test getting model response when no new messages and previous response exists.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - agent.last_response_id = "prev_response" - agent._message_cursor = 1 - - messages = [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}] - response = await agent.get_response(messages) - - # Should return early without calling API - assert response.content == "" - assert response.done is True - mock_openai.responses.create.assert_not_called() - - @pytest.mark.asyncio - async def test_get_response_passes_correct_payload(self, mock_mcp_client, mock_openai): - """Test that get_response passes correct parameters to OpenAI API.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - checkpoint_name="gpt-4o", - max_output_tokens=1024, - temperature=0.5, - reasoning={"effort": "high"}, - tool_choice="auto", - parallel_tool_calls=True, - validate_api_key=False, - ) - - agent._openai_tools = [cast("Any", {"type": "function", "name": "test"})] - agent.system_prompt = "You are a helpful assistant" - agent.last_response_id = "prev_123" - - # Mock the API response - mock_response = MagicMock() - mock_response.id = "response_new" - mock_response.output = [] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [{"role": "user", "content": [{"type": "input_text", "text": "Hi"}]}] - await agent.get_response(messages) - - # Verify the API was called with the correct parameters - mock_openai.responses.create.assert_called_once() - call_kwargs = mock_openai.responses.create.call_args.kwargs - - assert call_kwargs["model"] == "gpt-4o" - assert call_kwargs["input"] == messages - assert call_kwargs["instructions"] == "You are a helpful assistant" - assert call_kwargs["max_output_tokens"] == 1024 - assert call_kwargs["temperature"] == 0.5 - assert call_kwargs["reasoning"] == {"effort": "high"} - assert call_kwargs["tool_choice"] == "auto" - assert call_kwargs["parallel_tool_calls"] is True - assert call_kwargs["tools"] == [{"type": "function", "name": "test"}] - assert call_kwargs["previous_response_id"] == "prev_123" - - @pytest.mark.asyncio - async def test_get_response_passes_minimal_payload(self, mock_mcp_client, mock_openai): - """Test that get_response passes minimal parameters when not configured.""" - from openai import Omit - - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - 
mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Mock the API response - mock_response = MagicMock() - mock_response.id = "response_new" - mock_response.output = [] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [{"role": "user", "content": [{"type": "input_text", "text": "Hi"}]}] - await agent.get_response(messages) - - # Verify the API was called with minimal parameters - mock_openai.responses.create.assert_called_once() - call_kwargs = mock_openai.responses.create.call_args.kwargs - - assert call_kwargs["model"] == "gpt-5.1" # default - assert call_kwargs["input"] == messages - assert call_kwargs["max_output_tokens"] is None - assert call_kwargs["temperature"] is None - # tool_choice should be Omit() when not set - assert isinstance(call_kwargs["tool_choice"], Omit) - # tools should be Omit() when empty - assert isinstance(call_kwargs["tools"], Omit) - # previous_response_id should be Omit() when not set - assert isinstance(call_kwargs["previous_response_id"], Omit) + assert "Error message" in messages[0]["output"] @pytest.mark.asyncio - async def test_reset_response_state(self, mock_mcp_client, mock_openai): - """Test resetting response state.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Set some state - agent.last_response_id = "some_id" - agent._message_cursor = 5 - - # Reset - agent._reset_response_state() - - assert agent.last_response_id is None - assert agent._message_cursor == 0 - - @pytest.mark.asyncio - async def test_get_system_messages(self, mock_mcp_client, mock_openai): + async def test_get_system_messages(self, mock_openai: AsyncOpenAI) -> None: """Test getting system messages.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, + system_prompt="You are a helpful assistant.", validate_api_key=False, ) - # OpenAI agent returns empty list (uses instructions field instead) messages = await agent.get_system_messages() - assert messages == [] + assert len(messages) == 1 + assert messages[0]["type"] == "message" + assert messages[0]["role"] == "developer" @pytest.mark.asyncio - async def test_convert_tools_for_openai(self, mock_mcp_client, mock_openai): + async def test_convert_tools_for_openai(self, mock_openai: AsyncOpenAI) -> None: """Test converting MCP tools to OpenAI format.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Mock MCP tools - mock_tools = [ + tools = [ types.Tool( - name="tool1", - description="First tool", - inputSchema={ - "type": "object", - "properties": {"arg1": {"type": "string"}}, - "required": ["arg1"], - "additionalProperties": False, - }, - ), - types.Tool( - name="tool2", - description="Second tool", - inputSchema={ - "type": "object", - "properties": {}, - "additionalProperties": False, - }, - ), + name="my_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {"x": {"type": "string"}}}, + ) ] - - agent._available_tools = mock_tools - agent._convert_tools_for_openai() - - assert len(agent._openai_tools) == 2 - assert agent._tool_name_map == {"tool1": "tool1", "tool2": "tool2"} - - tool1 = cast("dict[str, Any]", agent._openai_tools[0]) - assert tool1["type"] == "function" - assert tool1["name"] == "tool1" - assert tool1["description"] == "First tool" - assert tool1["strict"] is True - - @pytest.mark.asyncio - async def 
test_convert_tools_raises_on_incomplete(self, mock_mcp_client, mock_openai): - """Test that converting tools raises error for incomplete tool definitions.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Create mock tools directly as objects that bypass pydantic validation - incomplete1 = MagicMock(spec=types.Tool) - incomplete1.name = "incomplete1" - incomplete1.description = None - incomplete1.inputSchema = {"type": "object"} - - agent._available_tools = [incomplete1] - - # Should raise ValueError for tool without description - with pytest.raises(ValueError, match="requires both a description and inputSchema"): - agent._convert_tools_for_openai() - - @pytest.mark.asyncio - async def test_convert_tools_for_openai_via_initialize(self, mock_mcp_client, mock_openai): - """Test that initialize properly converts tools.""" + ctx = MockEvalContext(tools=tools) agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) - # Mock the list_tools to return our test tools - mock_mcp_client.list_tools = AsyncMock( - return_value=[ - types.Tool( - name="complete", - description="Complete tool", - inputSchema={"type": "object", "properties": {}, "additionalProperties": False}, - ) - ] - ) - - await agent.initialize() + # Initialize with context to trigger tool conversion + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - # Should have the complete tool converted - assert len(agent._openai_tools) == 1 - tool = cast("dict[str, Any]", agent._openai_tools[0]) - assert tool["name"] == "complete" + # Check that tools were converted + assert len(agent.openai_tools) >= 1 + # Find our tool + tool = next((t for t in agent.openai_tools if t.get("name") == "my_tool"), None) + assert tool is not None + assert tool["type"] == "function" @pytest.mark.asyncio - async def test_get_response_converts_function_tool_call(self, mock_mcp_client, mock_openai): - """Test that get_response properly converts OpenAI function tool calls to MCP format.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Set up tool name map (simulating tool conversion) - agent._tool_name_map = {"openai_name": "mcp_name"} - - # Mock OpenAI API response with function call - mock_response = MagicMock() - mock_response.id = "response_123" - - mock_function_call = ResponseFunctionToolCall( - type="function_call", - call_id="call_123", - name="openai_name", - arguments='{"key": "value", "number": 42}', + async def test_convert_tools_raises_on_incomplete(self, mock_openai: AsyncOpenAI) -> None: + """Test that tools without description raise error.""" + tools = [ + types.Tool( + name="incomplete_tool", + description=None, # Missing description + inputSchema={"type": "object"}, ) - - mock_response.output = [mock_function_call] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [ - {"role": "user", "content": [{"type": "input_text", "text": "Do something"}]} - ] - response = await agent.get_response(messages) - - # Verify the tool call was converted correctly - assert len(response.tool_calls) == 1 - assert response.tool_calls[0].name == "mcp_name" - assert response.tool_calls[0].id == "call_123" - assert response.tool_calls[0].arguments == {"key": "value", "number": 42} - - @pytest.mark.asyncio - async def 
test_convert_function_tool_call_invalid_json(self, mock_mcp_client, mock_openai): - """Test converting function tool call with invalid JSON.""" - _agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, + ] + ctx = MockEvalContext(tools=tools) + agent = OpenAIAgent.create( model_client=mock_openai, validate_api_key=False, ) - async def test_get_response_raises_on_invalid_json_arguments( - self, mock_mcp_client, mock_openai - ): - """Test that get_response raises error on invalid JSON in function call arguments. - - With strict mode being mandatory, invalid JSON arguments should never occur - in practice since schemas are validated. This test verifies that if it does - happen, we get an appropriate error rather than silently failing. - """ - import json - - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - agent._tool_name_map = {"tool": "tool"} - - # Mock OpenAI API response with function call that has invalid JSON - mock_response = MagicMock() - mock_response.id = "response_456" - - mock_function_call = ResponseFunctionToolCall( - type="function_call", - call_id="call_456", - name="tool", - arguments="invalid json {{", - ) - - mock_response.output = [mock_function_call] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [ - {"role": "user", "content": [{"type": "input_text", "text": "Do something"}]} - ] - - # With strict mode mandatory, invalid JSON should raise an error - with pytest.raises(json.JSONDecodeError): - await agent.get_response(messages) + agent.ctx = ctx + with pytest.raises(ValueError, match="requires both a description"): + await agent._initialize_from_ctx(ctx) @pytest.mark.asyncio - async def test_get_response_handles_tool_name_mapping(self, mock_mcp_client, mock_openai): - """Test that get_response correctly maps tool names that aren't in the map.""" - with patch("hud.settings.settings.telemetry_enabled", False): - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Tool name is NOT in the map, should fall back to the original name - agent._tool_name_map = {} - - mock_response = MagicMock() - mock_response.id = "response_789" - - mock_function_call = ResponseFunctionToolCall( - type="function_call", - call_id="call_789", - name="unmapped_tool", - arguments="{}", + async def test_get_response_with_text(self, mock_openai: AsyncOpenAI) -> None: + """Test getting response with text output.""" + # Setup mock response + mock_response = AsyncMock() + mock_response.output = [ + ResponseOutputMessage( + id="msg_123", + type="message", + role="assistant", + status="completed", + content=[ResponseOutputText(type="output_text", text="Hello!")], ) + ] + mock_openai.responses.create = AsyncMock(return_value=mock_response) - mock_response.output = [mock_function_call] - mock_openai.responses.create = AsyncMock(return_value=mock_response) - - messages = [ - {"role": "user", "content": [{"type": "input_text", "text": "Do something"}]} - ] - response = await agent.get_response(messages) - - # Should use the original tool name when not in map - assert len(response.tool_calls) == 1 - assert response.tool_calls[0].name == "unmapped_tool" - assert response.tool_calls[0].arguments == {} - - @pytest.mark.asyncio - async def test_convert_tools_for_openai_shell_tool(self, mock_mcp_client, mock_openai): - """Test that shell tool is converted to OpenAI native 
shell type.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) + # Set empty tools to avoid needing initialization + agent.openai_tools = [] + agent._initialized = True - # Mock a shell tool - shell_tool = types.Tool( - name="shell", - description="Execute shell commands", - inputSchema={"type": "object", "properties": {}}, - ) - - agent._available_tools = [shell_tool] - agent._convert_tools_for_openai() - - assert len(agent._openai_tools) == 1 - tool = cast("dict[str, Any]", agent._openai_tools[0]) - assert tool["type"] == "shell" + response = await agent.get_response([]) + assert response.content == "Hello!" + assert response.done is True + assert len(response.tool_calls) == 0 @pytest.mark.asyncio - async def test_convert_tools_for_openai_apply_patch_tool(self, mock_mcp_client, mock_openai): - """Test that apply_patch tool is converted to OpenAI native apply_patch type.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - # Mock an apply_patch tool - apply_patch_tool = types.Tool( - name="apply_patch", - description="Apply patches to files", - inputSchema={"type": "object", "properties": {}}, - ) - - agent._available_tools = [apply_patch_tool] - agent._convert_tools_for_openai() - - assert len(agent._openai_tools) == 1 - tool = cast("dict[str, Any]", agent._openai_tools[0]) - assert tool["type"] == "apply_patch" + async def test_get_response_with_tool_call(self, mock_openai: AsyncOpenAI) -> None: + """Test getting response with tool call.""" + mock_response = AsyncMock() + mock_response.output = [ + ResponseOutputMessage( + id="msg_123", + type="message", + role="assistant", + status="completed", + content=[ + ResponseFunctionToolCall( + id="call_123", + type="function_call", + call_id="call_123", + name="my_tool", + arguments='{"x": "value"}', + ) + ], + ) + ] + mock_openai.responses.create = AsyncMock(return_value=mock_response) - @pytest.mark.asyncio - async def test_convert_tools_for_openai_strict_schema_failure( - self, mock_mcp_client, mock_openai - ): - """Test that tool conversion raises error when strict schema conversion fails.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) + agent.openai_tools = [] + agent.tool_mapping = {"my_tool": "my_tool"} + agent._initialized = True - # Mock a tool with a schema that will fail strict conversion - # Using a schema without additionalProperties which is required for strict mode - mock_tool = types.Tool( - name="non_strict_tool", - description="A tool with non-strict schema", - inputSchema={ - "type": "object", - "properties": {"arg": {"type": "string"}}, - # Missing additionalProperties and required - will fail strict conversion - }, - ) - - agent._available_tools = [mock_tool] - - # Mock ensure_strict_json_schema to raise an exception - with patch("hud.agents.openai.ensure_strict_json_schema") as mock_strict: - mock_strict.side_effect = ValueError("Schema not strict compatible") - # Now strict compatibility is mandatory, so this should raise - with pytest.raises(ValueError, match="Schema not strict compatible"): - agent._convert_tools_for_openai() + response = await agent.get_response([]) + assert response.done is False + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "my_tool" + assert response.tool_calls[0].arguments == {"x": "value"} @pytest.mark.asyncio - async def 
test_format_tool_results_with_resource_link(self, mock_mcp_client, mock_openai): - """Test formatting tool results with ResourceLink content.""" - agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_openai, - validate_api_key=False, - ) - - tool_calls = [ - MCPToolCall(name="resource_tool", arguments={}, id="call_resource"), - ] - - # Create a ResourceLink content - resource_link = types.ResourceLink( - type="resource_link", - name="test_resource", - uri=AnyUrl("file:///test/resource"), - ) - - tool_results = [ - MCPToolResult( - content=[resource_link], - isError=False, + async def test_get_response_with_reasoning(self, mock_openai: AsyncOpenAI) -> None: + """Test getting response with reasoning.""" + mock_response = AsyncMock() + mock_response.output = [ + ResponseReasoningItem( + id="reason_123", + type="reasoning", + summary=[Summary(type="summary_text", text="Thinking about it...")], + ), + ResponseOutputMessage( + id="msg_123", + type="message", + role="assistant", + status="completed", + content=[ResponseOutputText(type="output_text", text="Answer!")], ), ] + mock_openai.responses.create = AsyncMock(return_value=mock_response) - messages = await agent.format_tool_results(tool_calls, tool_results) - - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - output = cast("list[dict[str, Any]]", msg["output"]) - assert len(output) == 1 - assert output[0]["type"] == "input_file" - assert output[0]["file_url"] == "file:///test/resource" - - @pytest.mark.asyncio - async def test_format_tool_results_with_embedded_text_resource( - self, mock_mcp_client, mock_openai - ): - """Test formatting tool results with EmbeddedResource containing text.""" agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) + agent.openai_tools = [] + agent._initialized = True - tool_calls = [ - MCPToolCall(name="embed_tool", arguments={}, id="call_embed"), - ] - - # Create an EmbeddedResource with TextResourceContents - text_resource = types.TextResourceContents( - uri=AnyUrl("file:///test.txt"), - mimeType="text/plain", - text="Embedded text content", - ) - embedded = types.EmbeddedResource( - type="resource", - resource=text_resource, - ) + response = await agent.get_response([]) + assert "Thinking about it..." in (response.reasoning or "") + assert response.content == "Answer!" 
- tool_results = [ - MCPToolResult( - content=[embedded], - isError=False, - ), - ] - messages = await agent.format_tool_results(tool_calls, tool_results) +class TestOpenAIToolConversion: + """Tests for tool conversion to OpenAI format.""" - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - output = cast("list[dict[str, Any]]", msg["output"]) - assert len(output) == 1 - assert output[0]["type"] == "input_text" - assert output[0]["text"] == "Embedded text content" + @pytest.fixture + def mock_openai(self) -> AsyncOpenAI: + """Create a stub OpenAI client.""" + with patch("hud.agents.openai.AsyncOpenAI") as mock_class: + client = AsyncOpenAI(api_key="test", base_url="http://localhost") + client.responses.create = AsyncMock() + mock_class.return_value = client + yield client @pytest.mark.asyncio - async def test_format_tool_results_with_embedded_blob_resource( - self, mock_mcp_client, mock_openai - ): - """Test formatting tool results with EmbeddedResource containing blob.""" + async def test_shell_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: + """Test that shell tool is converted to native format.""" + tools = [ + types.Tool( + name="shell", + description="Execute shell commands", + inputSchema={"type": "object"}, + ) + ] + ctx = MockEvalContext(tools=tools) agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) - tool_calls = [ - MCPToolCall(name="blob_tool", arguments={}, id="call_blob"), - ] - - # Create an EmbeddedResource with BlobResourceContents - blob_resource = types.BlobResourceContents( - uri=AnyUrl("file:///test.bin"), - mimeType="application/octet-stream", - blob="YmluYXJ5IGRhdGE=", # base64 encoded "binary data" - ) - embedded = types.EmbeddedResource( - type="resource", - resource=blob_resource, - ) - - tool_results = [ - MCPToolResult( - content=[embedded], - isError=False, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - output = cast("list[dict[str, Any]]", msg["output"]) - assert len(output) == 1 - assert output[0]["type"] == "input_file" - assert output[0]["file_data"] == "YmluYXJ5IGRhdGE=" + # Check for native shell tool + shell_tool = next((t for t in agent.openai_tools if t.get("type") == "shell"), None) + assert shell_tool is not None @pytest.mark.asyncio - async def test_format_tool_results_empty_content(self, mock_mcp_client, mock_openai): - """Test formatting tool results with completely empty content.""" + async def test_computer_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: + """Test that computer tool is converted to native format.""" + tools = [ + types.Tool( + name="computer", + description="Control computer", + inputSchema={"type": "object"}, + ) + ] + ctx = MockEvalContext(tools=tools) agent = OpenAIAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) - tool_calls = [ - MCPToolCall(name="empty_tool", arguments={}, id="call_empty"), - ] - - tool_results = [ - MCPToolResult( - content=[], # Empty content - isError=False, - ), - ] - - messages = await agent.format_tool_results(tool_calls, tool_results) + agent.ctx = ctx + await agent._initialize_from_ctx(ctx) - assert len(messages) == 1 - msg = cast("dict[str, Any]", messages[0]) - output = cast("list[dict[str, Any]]", msg["output"]) - # Should have fallback empty text when no content - assert len(output) == 1 
- assert output[0]["type"] == "input_text" - assert output[0]["text"] == "" + # Check for native computer tool + computer_tool = next( + (t for t in agent.openai_tools if t.get("type") == "computer_use_preview"), + None, + ) + assert computer_tool is not None diff --git a/hud/agents/tests/test_operator.py b/hud/agents/tests/test_operator.py index b9900247..94861522 100644 --- a/hud/agents/tests/test_operator.py +++ b/hud/agents/tests/test_operator.py @@ -11,43 +11,75 @@ from openai.types.responses.response_computer_tool_call import PendingSafetyCheck from hud.agents.operator import OperatorAgent +from hud.eval.context import EvalContext from hud.types import MCPToolCall, MCPToolResult +class MockEvalContext(EvalContext): + """Mock EvalContext for testing.""" + + def __init__(self, tools: list[types.Tool] | None = None) -> None: + self.prompt = "Test prompt" + self._tools = tools or [] + self._submitted: str | None = None + self.reward: float | None = None + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + return MCPToolResult( + content=[types.TextContent(type="text", text="ok")], + isError=False, + ) + + async def submit(self, answer: str) -> None: + self._submitted = answer + + class TestOperatorAgent: """Test OperatorAgent class.""" @pytest.fixture - def mock_openai(self): + def mock_openai(self) -> AsyncOpenAI: """Create a mock OpenAI client.""" client = AsyncOpenAI(api_key="test", base_url="http://localhost") client.responses.create = AsyncMock() with patch("hud.agents.openai.AsyncOpenAI", return_value=client): yield client + @pytest.fixture + def mock_eval_context_computer(self) -> MockEvalContext: + """Create a mock EvalContext with computer tool.""" + return MockEvalContext( + tools=[ + types.Tool( + name="openai_computer", + description="OpenAI computer use tool", + inputSchema={}, + ) + ] + ) + @pytest.mark.asyncio - async def test_init(self, mock_mcp_client_openai_computer): + async def test_init(self, mock_openai: AsyncOpenAI) -> None: """Test agent initialization.""" - mock_model_client = AsyncOpenAI(api_key="test") agent = OperatorAgent.create( - mcp_client=mock_mcp_client_openai_computer, - model_client=mock_model_client, + model_client=mock_openai, checkpoint_name="gpt-4", - validate_api_key=False, # Skip validation in tests + validate_api_key=False, ) assert agent.model_name == "Operator" assert agent.config.checkpoint_name == "gpt-4" - assert agent.openai_client == mock_model_client + assert agent.openai_client == mock_openai @pytest.mark.asyncio - async def test_format_blocks(self, mock_mcp_client_openai_computer): + async def test_format_blocks(self, mock_openai: AsyncOpenAI) -> None: """Test formatting content blocks.""" - mock_model_client = AsyncOpenAI(api_key="test") agent = OperatorAgent.create( - mcp_client=mock_mcp_client_openai_computer, - model_client=mock_model_client, - validate_api_key=False, # Skip validation in tests + model_client=mock_openai, + validate_api_key=False, ) # Test with text blocks @@ -85,17 +117,16 @@ async def test_format_blocks(self, mock_mcp_client_openai_computer): } @pytest.mark.asyncio - async def test_format_tool_results(self, mock_mcp_client_openai_computer, mock_openai): + async def test_format_tool_results(self, mock_openai: AsyncOpenAI) -> None: """Test formatting tool results.""" agent = OperatorAgent.create( - mcp_client=mock_mcp_client_openai_computer, model_client=mock_openai, - validate_api_key=False, # Skip validation in 
tests + validate_api_key=False, ) tool_calls = [ - MCPToolCall(name="test_tool", arguments={}, id="call_123"), # type: ignore - MCPToolCall(name="screenshot", arguments={}, id="call_456"), # type: ignore + MCPToolCall(name="test_tool", arguments={}, id="call_123"), + MCPToolCall(name="screenshot", arguments={}, id="call_456"), ] tool_results = [ @@ -126,18 +157,15 @@ async def test_format_tool_results(self, mock_mcp_client_openai_computer, mock_o assert output1[0]["image_url"] == "data:image/png;base64,base64data" @pytest.mark.asyncio - async def test_format_tool_results_with_error( - self, mock_mcp_client_openai_computer, mock_openai - ): + async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> None: """Test formatting tool results with errors.""" agent = OperatorAgent.create( - mcp_client=mock_mcp_client_openai_computer, model_client=mock_openai, - validate_api_key=False, # Skip validation in tests + validate_api_key=False, ) tool_calls = [ - MCPToolCall(name="failing_tool", arguments={}, id="call_error"), # type: ignore + MCPToolCall(name="failing_tool", arguments={}, id="call_error"), ] tool_results = [ @@ -160,20 +188,19 @@ async def test_format_tool_results_with_error( assert output[1]["text"] == "Something went wrong" @pytest.mark.asyncio - async def test_get_model_response(self, mock_mcp_client_openai_computer, mock_openai): + async def test_get_model_response( + self, mock_openai: AsyncOpenAI, mock_eval_context_computer: MockEvalContext + ) -> None: """Test getting model response from OpenAI API.""" - # Disable telemetry for this test to avoid backend configuration issues with patch("hud.settings.settings.telemetry_enabled", False): agent = OperatorAgent.create( - mcp_client=mock_mcp_client_openai_computer, model_client=mock_openai, - validate_api_key=False, # Skip validation in tests + validate_api_key=False, ) - # Set up available tools so agent doesn't return "No computer use tools available" - agent._available_tools = [ - types.Tool(name="computer_openai", description="Computer tool", inputSchema={}) - ] + # Initialize with context + agent.ctx = mock_eval_context_computer + await agent._initialize_from_ctx(mock_eval_context_computer) # Mock OpenAI API response for a successful computer use response mock_response = MagicMock() @@ -195,24 +222,22 @@ async def test_get_model_response(self, mock_mcp_client_openai_computer, mock_op messages = [{"prompt": "What's on the screen?", "screenshot": None}] response = await agent.get_response(messages) - # The test should verify that the response is processed correctly - # Since the isinstance checks will fail, content will be empty, but done should be True assert response.done is True assert response.tool_calls == [] @pytest.mark.asyncio - async def test_handle_empty_response(self, mock_mcp_client_openai_computer, mock_openai): + async def test_handle_empty_response( + self, mock_openai: AsyncOpenAI, mock_eval_context_computer: MockEvalContext + ) -> None: """Test handling empty response from API.""" agent = OperatorAgent.create( - mcp_client=mock_mcp_client_openai_computer, model_client=mock_openai, - validate_api_key=False, # Skip validation in tests + validate_api_key=False, ) - # Set up available tools - agent._available_tools = [ - types.Tool(name="openai_computer", description="Computer tool", inputSchema={}) - ] + # Initialize with context + agent.ctx = mock_eval_context_computer + await agent._initialize_from_ctx(mock_eval_context_computer) # Mock empty response mock_response = MagicMock() @@ -229,10 
+254,9 @@ async def test_handle_empty_response(self, mock_mcp_client_openai_computer, mock assert response.tool_calls == [] @pytest.mark.asyncio - async def test_pending_safety_checks_initialization(self, mock_mcp_client, mock_openai): + async def test_pending_safety_checks_initialization(self, mock_openai: AsyncOpenAI) -> None: """Test that OperatorAgent initializes pending_call_id and pending_safety_checks.""" agent = OperatorAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) @@ -253,10 +277,9 @@ async def test_pending_safety_checks_initialization(self, mock_mcp_client, mock_ assert agent.pending_safety_checks[0].id == "safety_check_id" @pytest.mark.asyncio - async def test_extract_tool_call_computer(self, mock_mcp_client, mock_openai): + async def test_extract_tool_call_computer(self, mock_openai: AsyncOpenAI) -> None: """Test that _extract_tool_call routes computer_call to openai_computer.""" agent = OperatorAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) @@ -281,10 +304,9 @@ async def test_extract_tool_call_computer(self, mock_mcp_client, mock_openai): assert agent.pending_safety_checks == mock_item.pending_safety_checks @pytest.mark.asyncio - async def test_extract_tool_call_delegates_to_super(self, mock_mcp_client, mock_openai): + async def test_extract_tool_call_delegates_to_super(self, mock_openai: AsyncOpenAI) -> None: """Test that _extract_tool_call delegates non-computer calls to parent.""" agent = OperatorAgent.create( - mcp_client=mock_mcp_client, model_client=mock_openai, validate_api_key=False, ) diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py index 746c80d6..1f9a7fc1 100644 --- a/hud/agents/tests/test_run_eval.py +++ b/hud/agents/tests/test_run_eval.py @@ -1,4 +1,4 @@ -"""Tests for run_eval and EnvironmentClient.""" +"""Tests for MCPAgent.run() with EvalContext.""" from __future__ import annotations @@ -9,7 +9,6 @@ from hud.agents import MCPAgent from hud.agents.base import BaseCreateParams -from hud.clients.environment import EnvironmentClient from hud.eval.context import EvalContext from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult @@ -24,7 +23,7 @@ class MockCreateParams(BaseCreateParams, MockConfig): class MockMCPAgent(MCPAgent): - """Mock agent for testing run_eval.""" + """Mock agent for testing run().""" metadata: ClassVar[dict[str, Any] | None] = {} config_cls: ClassVar[type[BaseAgentConfig]] = MockConfig @@ -37,11 +36,6 @@ def __init__(self, **kwargs: Any) -> None: def set_response(self, response: AgentResponse) -> None: self._response = response - async def create_initial_messages( - self, prompt: str, initial_screenshot: bool = False - ) -> list[dict[str, Any]]: - return [{"role": "user", "content": prompt}] - async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: return self._response @@ -50,9 +44,6 @@ async def format_tool_results( ) -> list[dict[str, Any]]: return [{"role": "tool", "content": str(r)} for r in tool_results] - async def create_user_message(self, text: str) -> Any: - return {"role": "user", "content": text} - async def get_system_messages(self) -> list[Any]: return [] @@ -71,11 +62,19 @@ def __init__(self, prompt: str = "Test prompt", tools: list[types.Tool] | None = ] self._submitted: str | None = None self.reward: float | None = None + self._initialized = True async def list_tools(self) -> list[types.Tool]: return self._tools - async def call_tool(self, name: str, 
**kwargs: Any) -> MCPToolResult: + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + # Handle tuple format (name, args) + if isinstance(call, tuple): + name = call[0] + elif hasattr(call, "name"): + name = call.name + else: + name = str(call) return MCPToolResult( content=[types.TextContent(type="text", text=f"Result from {name}")], isError=False, @@ -85,106 +84,70 @@ async def submit(self, answer: str) -> None: self._submitted = answer -class TestEnvironmentClient: - """Tests for EnvironmentClient adapter.""" - - @pytest.mark.asyncio - async def test_initialize(self) -> None: - """Test client initialization.""" - ctx = MockEvalContext() - client = EnvironmentClient(ctx) - - assert not client.is_connected - await client.initialize() - assert client.is_connected - - @pytest.mark.asyncio - async def test_list_tools(self) -> None: - """Test listing tools through adapter.""" - ctx = MockEvalContext() - client = EnvironmentClient(ctx) - - tools = await client.list_tools() - assert len(tools) == 1 - assert tools[0].name == "test_tool" - - @pytest.mark.asyncio - async def test_call_tool(self) -> None: - """Test calling tools through adapter.""" - ctx = MockEvalContext() - client = EnvironmentClient(ctx) - - result = await client.call_tool(MCPToolCall(name="test_tool", arguments={})) - assert not result.isError - assert len(result.content) == 1 - - @pytest.mark.asyncio - async def test_mcp_config_empty(self) -> None: - """Test mcp_config is empty for environment clients.""" - ctx = MockEvalContext() - client = EnvironmentClient(ctx) - assert client.mcp_config == {} - - @pytest.mark.asyncio - async def test_shutdown(self) -> None: - """Test shutdown resets initialized state.""" - ctx = MockEvalContext() - client = EnvironmentClient(ctx) - - await client.initialize() - assert client.is_connected - - await client.shutdown() - assert not client.is_connected - - -class TestRunEval: - """Tests for MCPAgent.run_eval().""" +class TestRun: + """Tests for MCPAgent.run() with EvalContext.""" @pytest.mark.asyncio - async def test_run_eval_basic(self) -> None: - """Test basic run_eval flow.""" + async def test_run_basic(self) -> None: + """Test basic run() flow.""" ctx = MockEvalContext(prompt="Do the task") agent = MockMCPAgent() - result = await agent.run_eval(ctx) + result = await agent.run(ctx) assert result.done assert result.content == "Test response" assert ctx._submitted == "Test response" @pytest.mark.asyncio - async def test_run_eval_no_prompt_raises(self) -> None: - """Test run_eval raises when prompt is not set.""" + async def test_run_no_prompt_raises(self) -> None: + """Test run() raises when prompt is not set.""" ctx = MockEvalContext(prompt="") agent = MockMCPAgent() with pytest.raises(ValueError, match="prompt is not set"): - await agent.run_eval(ctx) + await agent.run(ctx) @pytest.mark.asyncio - async def test_run_eval_wrong_type_raises(self) -> None: - """Test run_eval raises TypeError for non-EvalContext.""" + async def test_run_wrong_type_raises(self) -> None: + """Test run() raises TypeError for non-EvalContext.""" agent = MockMCPAgent() with pytest.raises(TypeError, match="must be EvalContext"): - await agent.run_eval("not an eval context") # type: ignore[arg-type] + await agent.run("not an eval context") # type: ignore[arg-type] @pytest.mark.asyncio - async def test_run_eval_clears_client(self) -> None: - """Test run_eval clears mcp_client after completion.""" + async def test_run_clears_ctx(self) -> None: + """Test run() clears ctx after completion.""" ctx = 
MockEvalContext(prompt="Do the task") agent = MockMCPAgent() - await agent.run_eval(ctx) - assert agent.mcp_client is None + await agent.run(ctx) + assert agent.ctx is None @pytest.mark.asyncio - async def test_run_eval_no_submit_on_empty_content(self) -> None: - """Test run_eval doesn't submit when content is empty.""" + async def test_run_no_submit_on_empty_content(self) -> None: + """Test run() doesn't submit when content is empty.""" ctx = MockEvalContext(prompt="Do the task") agent = MockMCPAgent() agent.set_response(AgentResponse(content="", tool_calls=[], done=True)) - await agent.run_eval(ctx) + await agent.run(ctx) assert ctx._submitted is None + + @pytest.mark.asyncio + async def test_run_initializes_tools(self) -> None: + """Test run() initializes tools from context.""" + ctx = MockEvalContext( + prompt="Do the task", + tools=[ + types.Tool(name="tool1", description="Tool 1", inputSchema={}), + types.Tool(name="tool2", description="Tool 2", inputSchema={}), + ], + ) + agent = MockMCPAgent() + + await agent.run(ctx) + + assert agent._initialized + # After cleanup, ctx is None but tools were discovered diff --git a/hud/cli/eval.py b/hud/cli/eval.py index 23ee244e..2685282e 100644 --- a/hud/cli/eval.py +++ b/hud/cli/eval.py @@ -19,14 +19,12 @@ from rich import box from rich.table import Table -from hud.cli.utils.env_check import ensure_built, find_environment_dir from hud.settings import settings from hud.types import AgentType from hud.utils.hud_console import HUDConsole if TYPE_CHECKING: from hud.agents.base import MCPAgent - from hud.types import Task logger = logging.getLogger(__name__) hud_console = HUDConsole() @@ -487,93 +485,34 @@ def display(self) -> None: hud_console.console.print(table) -# ============================================================================= -# Task loading -# ============================================================================= - - -def _load_tasks_from_source(source: str) -> list[Task]: - """Load tasks from file or HuggingFace dataset.""" - from hud.utils.tasks import load_tasks - - path = Path(source) - if path.exists() and path.suffix in {".json", ".jsonl"}: - hud_console.info("📊 Loading task file…") - tasks = load_tasks(str(path)) - try: - env_dir = find_environment_dir(path) - if env_dir is not None: - ensure_built(env_dir, interactive=False) - except Exception as exc: - hud_console.debug(f"Eval preflight env check skipped: {exc}") - else: - hud_console.info(f"📊 Loading tasks from: {source}…") - tasks = load_tasks(source) - - if not tasks: - hud_console.error(f"No tasks found in: {source}") - raise typer.Exit(1) - - return tasks # type: ignore[return-value] - - -def _warn_local_mcp(tasks: list[Task], source: str) -> None: - """Warn user if tasks use local MCP configs.""" - try: - has_local = any( - isinstance(server_cfg, dict) and "command" in server_cfg and not server_cfg.get("url") - for t in tasks - for server_cfg in (getattr(t, "mcp_config", {}) or {}).values() - if isinstance(getattr(t, "mcp_config", {}), dict) - ) - - if not has_local: - return - - hud_console.warning("Detected local MCP configurations (uses 'command' instead of 'url').") - hud_console.info("When running concurrently, exposed host ports from Docker may conflict.") - - if not hud_console.confirm("Proceed with local MCP servers?", default=True): - hint_file = Path(source).name if Path(source).exists() else "" - hud_console.hint(f"Convert to remote: hud convert {hint_file}") - raise typer.Exit(1) - - hint_file = Path(source).name if Path(source).exists() else "" 
- hud_console.hint(f"Convert to remote to avoid port conflicts: hud convert {hint_file}") - - except typer.Exit: - raise - except Exception as e: - hud_console.debug(f"Local MCP check skipped: {e}") - - # ============================================================================= # Evaluation runner # ============================================================================= -async def _run_evaluation(cfg: EvalConfig) -> tuple[list[Any], list[Task]]: - """Run evaluation with the given config.""" - from hud.datasets import run_single_task, run_tasks +async def _run_evaluation(cfg: EvalConfig) -> tuple[list[Any], list[Any]]: + """Run evaluation with the given config using run_dataset().""" + from hud.datasets import load_dataset, run_dataset if cfg.source is None or cfg.agent_type is None: raise ValueError("source and agent_type must be set") - tasks = _load_tasks_from_source(cfg.source) + # Load tasks using unified loader (handles v4→v5 conversion automatically) + hud_console.info(f"📊 Loading tasks from: {cfg.source}…") + tasks = load_dataset(cfg.source) - if not cfg.remote and (cfg.group_size > 1 or cfg.full): - _warn_local_mcp(tasks, cfg.source) - - agent_kwargs = cfg.get_agent_kwargs() - - path = Path(cfg.source) - dataset_name = path.name if path.exists() else cfg.source.split("/")[-1] - max_steps = cfg.max_steps or (100 if cfg.full else 10) + if not tasks: + hud_console.error(f"No tasks found in: {cfg.source}") + raise typer.Exit(1) # Filter by task IDs if provided if cfg.task_ids: id_set = set(cfg.task_ids) - filtered = [t for t in tasks if str(getattr(t, "id", "")) in id_set] + # Match by task.id or index + filtered = [ + t for i, t in enumerate(tasks) + if t.id in id_set or str(i) in id_set + ] if not filtered: hud_console.error(f"No tasks found matching IDs: {', '.join(cfg.task_ids)}") raise typer.Exit(1) @@ -584,57 +523,51 @@ async def _run_evaluation(cfg: EvalConfig) -> tuple[list[Any], list[Task]]: tasks = [tasks[0]] hud_console.info("Using first task (run with --full or --task-ids for more)…") - auto_respond = cfg.auto_respond if cfg.auto_respond is not None else cfg.full + hud_console.info(f"Loaded {len(tasks)} task(s)") + # Prepare agent kwargs + agent_kwargs = cfg.get_agent_kwargs() + auto_respond = cfg.auto_respond if cfg.auto_respond is not None else cfg.full if auto_respond: agent_kwargs = {**agent_kwargs, "auto_respond": True} + max_steps = cfg.max_steps or (100 if cfg.full else 10) + + # Remote execution not yet supported in new flow if cfg.remote: - hud_console.info(f"🚀 Submitting {len(tasks)} tasks for remote execution…") - await run_tasks( - tasks=tasks, - agent_type=cfg.agent_type, - agent_params=agent_kwargs, - name=f"Evaluation {dataset_name}", - metadata={"dataset": cfg.source}, - max_steps=max_steps, - group_size=cfg.group_size, - remote=True, - ) - return [], tasks + hud_console.error("Remote execution not yet supported. 
Use local execution.") + raise typer.Exit(1) + + # Create agent + agent = cfg.agent_type.cls.create(**agent_kwargs) + # Single task mode - show extra info if len(tasks) == 1 and cfg.group_size == 1: - task = tasks[0] logging.getLogger("hud.agents").setLevel(logging.INFO) logging.getLogger("hud.agents.base").setLevel(logging.INFO) - - hud_console.info(task.prompt) - result = await run_single_task( - task=task, - agent_type=cfg.agent_type, - agent_params=agent_kwargs, - max_steps=max_steps, - trace_name=task.prompt, + # Get prompt from args (v4 tasks) or show scenario name + prompt = tasks[0].args.get("prompt") if tasks[0].args else tasks[0].scenario + if prompt: + hud_console.info(f"Prompt: {prompt}") + else: + hud_console.info( + f"🚀 Running evaluation (max_concurrent: {cfg.max_concurrent}, " + f"group_size: {cfg.group_size})…" ) - hud_console.success(f"Reward: {result.reward}") - return [result], tasks - # Local batch execution - hud_console.info( - f"🚀 Running evaluation (max_concurrent: {cfg.max_concurrent}, " - f"group_size: {cfg.group_size})…" - ) - - results = await run_tasks( - tasks=tasks, - agent_type=cfg.agent_type, - agent_params=agent_kwargs, - name=f"Evaluation {dataset_name}", - max_concurrent=cfg.max_concurrent, - metadata={"dataset": cfg.source}, + # Run using run_dataset + results = await run_dataset( + tasks, + agent, max_steps=max_steps, + max_concurrent=cfg.max_concurrent, group_size=cfg.group_size, ) + + # Show reward for single task + if len(tasks) == 1 and cfg.group_size == 1 and results: + hud_console.success(f"Reward: {results[0].reward}") + return results, tasks diff --git a/hud/cli/flows/tasks.py b/hud/cli/flows/tasks.py index e6374d3f..a0921766 100644 --- a/hud/cli/flows/tasks.py +++ b/hud/cli/flows/tasks.py @@ -17,7 +17,7 @@ from hud.utils.tasks import load_tasks if TYPE_CHECKING: - from hud.types import Task + from hud.types import LegacyTask logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ def _is_remote_url(url: str) -> bool: return bool(re.match(r"^(https?:\/\/)?(www\.)?[a-zA-Z0-9\-\.]+\.[a-zA-Z]{2,}(\/\S*)?$", url)) -def _validate_tasks(tasks: list[Task]) -> bool: +def _validate_tasks(tasks: list[LegacyTask]) -> bool: """Validate the tasks file: return True if tasks already reference a remote MCP URL. A task is considered remote if any "url" field anywhere inside mcp_config @@ -115,7 +115,7 @@ def _derive_remote_image(lock_data: dict[str, Any]) -> str: raise typer.Exit(1) -def _extract_existing_images(tasks: list[Task]) -> set[str]: +def _extract_existing_images(tasks: list[LegacyTask]) -> set[str]: """Extract all Mcp-Image references from tasks.""" images = set() @@ -268,7 +268,7 @@ def convert_tasks_to_remote(tasks_file: str) -> str: tasks_path = Path(tasks_file).resolve() # Load validated tasks for decision-making (may resolve env vars) - tasks: list[Task] = load_tasks(str(tasks_path)) # type: ignore[assignment] + tasks: list[LegacyTask] = load_tasks(str(tasks_path)) # type: ignore[assignment] # Load raw tasks to preserve placeholders when writing back to disk raw_tasks: list[dict[str, Any]] = load_tasks(str(tasks_path), raw=True) # type: ignore[assignment] diff --git a/hud/cli/flows/templates.py b/hud/cli/flows/templates.py index b96c7752..1dfb2b05 100644 --- a/hud/cli/flows/templates.py +++ b/hud/cli/flows/templates.py @@ -44,7 +44,7 @@ def count_letter(text: str, letter: str) -> int: # 2. 
SCRIPTS - Define prompts and evaluation logic # ============================================================================= -@env.script("count") +@env.scenario("count") async def count_script(sentence: str, letter: str, fmt: str = "integer"): """Agent must count a letter. We check if they got it right.""" # Yield the prompt, receive the agent's final answer @@ -89,11 +89,11 @@ async def test(): api_key=settings.api_key, ) - # Create an eval from the script - eval = env("count", sentence="Strawberry world", letter="r") + # Create a task from the scenario + task = env("count", sentence="Strawberry world", letter="r") # Test with and without tools - async with hud.eval(eval, variants={{"tools": [True, False]}}) as ctx: + async with hud.eval(task, variants={{"tools": [True, False]}}) as ctx: response = await client.chat.completions.create( model="gpt-4o-mini", messages=[{{"role": "user", "content": ctx.prompt}}], diff --git a/hud/cli/rft.py b/hud/cli/rft.py index 43c35e94..53d336ce 100644 --- a/hud/cli/rft.py +++ b/hud/cli/rft.py @@ -193,9 +193,9 @@ def rft_command( # Load and validate tasks try: # Load tasks with env vars already resolved - from hud.types import Task # noqa: TC001 + from hud.types import LegacyTask # noqa: TC001 - tasks_objects: list[Task] = load_tasks(tasks_file) # type: ignore[assignment] + tasks_objects: list[LegacyTask] = load_tasks(tasks_file) # type: ignore[assignment] # Convert to dicts for patching and serialization tasks: list[dict[str, Any]] = [t.model_dump() for t in tasks_objects] if not tasks: diff --git a/hud/cli/tests/test_convert.py b/hud/cli/tests/test_convert.py index cbdb6c8b..3dc5fb19 100644 --- a/hud/cli/tests/test_convert.py +++ b/hud/cli/tests/test_convert.py @@ -8,7 +8,7 @@ import typer from hud.cli.flows.tasks import convert_tasks_to_remote -from hud.types import Task +from hud.types import LegacyTask class TestConvertCommand: @@ -84,7 +84,7 @@ def test_convert_tasks_basic( # Mock derive remote image mock_derive_remote.return_value = "registry.hud.ai/test-org/test-env:v1.0.0" - task = Task( + task = LegacyTask( prompt="Test task", mcp_config={ "local": {"command": "docker", "args": ["run", "--rm", "-i", "test-image:latest"]} @@ -133,7 +133,7 @@ def test_convert_already_remote( mock_find_env.return_value = None # No env dir needed for remote tasks # Create task that's already remote - task = Task( + task = LegacyTask( prompt="Test task", mcp_config={ "remote": { @@ -159,7 +159,7 @@ def test_convert_no_environment( mock_settings.api_key = "test-api-key" mock_find_env.return_value = None - task = Task( + task = LegacyTask( prompt="Test task", mcp_config={ "local": {"command": "docker", "args": ["run", "--rm", "-i", "test-image:latest"]} @@ -209,7 +209,7 @@ def test_convert_with_env_vars( env_file = mock_env_dir / ".env" env_file.write_text("OPENAI_API_KEY=sk-test123\nANTHROPIC_API_KEY=sk-ant456") - task = Task( + task = LegacyTask( prompt="Test task", mcp_config={ "local": { diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index d367c447..272b6b6b 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -1,555 +1,213 @@ -"""Tests for hud.cli.eval module.""" +"""Tests for hud.cli.eval module and run_dataset function.""" from __future__ import annotations -from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest -from anthropic import AsyncAnthropic from mcp import types -from hud.agents.tests.conftest import MockMCPClient -from hud.types import 
Task, Trace +from hud.eval.context import EvalContext +from hud.types import MCPToolResult -class TestToolFiltering: - """Test wildcard tool filtering via agent_config in tasks.""" +class MockEvalContext(EvalContext): + """Mock EvalContext for testing.""" - @pytest.fixture - def mock_mcp_client(self): - """Fixture for mock MCP client.""" - return MockMCPClient() - - @pytest.fixture - def mock_model_client(self): - """Fixture for a lightweight Anthropic client.""" - client = AsyncAnthropic(api_key="test_key") - client.__dict__["beta"] = SimpleNamespace(messages=AsyncMock()) - return client - - async def _run_agent_with_tools( + def __init__( self, - mock_mcp_client: MagicMock, - mock_model_client: MagicMock, - tools: list[types.Tool], - agent_config: dict[str, Any] | None = None, - ) -> list[types.Tool]: - """Helper to create agent, initialize with tools and config, return filtered tools.""" - from hud.agents import ClaudeAgent - from hud.types import BaseAgentConfig - - mock_mcp_client.list_tools = AsyncMock(return_value=tools) - - task = Task( - prompt="Test", - mcp_config={"local": {"url": "http://localhost"}}, - agent_config=BaseAgentConfig(**agent_config) if agent_config else None, + prompt: str = "Test prompt", + tools: list[types.Tool] | None = None, + ) -> None: + self.prompt = prompt + self._tools = tools or [] + self._submitted: str | None = None + self.reward: float | None = None + self.results: list[EvalContext] = [] + + async def list_tools(self) -> list[types.Tool]: + return self._tools + + async def call_tool(self, call: Any, /, **kwargs: Any) -> MCPToolResult: + return MCPToolResult( + content=[types.TextContent(type="text", text="ok")], + isError=False, ) - agent = ClaudeAgent.create( - mcp_client=mock_mcp_client, - model_client=mock_model_client, - checkpoint_name="test", - validate_api_key=False, - ) - await agent.initialize(task) - return agent.get_available_tools() + async def submit(self, answer: str) -> None: + self._submitted = answer - @pytest.mark.asyncio - async def test_no_filters_returns_all_tools(self, mock_mcp_client, mock_model_client) -> None: - """Test that no filters in agent_config returns all tools.""" - tools = [ - types.Tool( - name="tool1", - description="Tool 1", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="tool2", - description="Tool 2", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="debug_tool", - description="Debug", - inputSchema={"type": "object", "properties": {}}, - ), - ] - result = await self._run_agent_with_tools(mock_mcp_client, mock_model_client, tools) +class MockAgent: + """Mock agent for testing run_dataset.""" - assert len(result) == 3 + def __init__(self) -> None: + self.run_count = 0 - @pytest.mark.asyncio - async def test_allowed_tools_filters_correctly( - self, mock_mcp_client, mock_model_client - ) -> None: - """Test that allowed_tools in agent_config filters to matching patterns.""" - tools = [ - types.Tool( - name="screenshot_take", - description="Tool 1", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="screenshot_full", - description="Tool 2", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="click", - description="Tool 3", - inputSchema={"type": "object", "properties": {}}, - ), - ] - agent_config = {"allowed_tools": ["screenshot_*"]} + async def run(self, ctx: EvalContext, *, max_steps: int = 10) -> Any: + self.run_count += 1 + ctx.reward = 1.0 + # Return a mock Trace-like object + return 
MagicMock(reward=1.0, done=True, content="Done") - result = await self._run_agent_with_tools( - mock_mcp_client, mock_model_client, tools, agent_config - ) - assert len(result) == 2 - assert all("screenshot" in t.name for t in result) +class TestRunDataset: + """Test the new run_dataset function.""" @pytest.mark.asyncio - async def test_disallowed_tools_excludes_correctly( - self, mock_mcp_client, mock_model_client - ) -> None: - """Test that disallowed_tools in agent_config excludes matching patterns.""" - tools = [ - types.Tool( - name="tool1", - description="Tool 1", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="debug_tool", - description="Tool 2", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="internal_secret", - description="Tool 3", - inputSchema={"type": "object", "properties": {}}, - ), + async def test_run_dataset_with_task_list(self) -> None: + """Test run_dataset with a list of tasks.""" + from hud.eval.task import Task + + tasks = [ + Task(id="task1", scenario="test"), + Task(id="task2", scenario="test"), ] - agent_config = {"disallowed_tools": ["debug_*", "internal_*"]} + agent = MockAgent() - result = await self._run_agent_with_tools( - mock_mcp_client, mock_model_client, tools, agent_config - ) + # Mock hud.eval to return our mock context + mock_ctx = MockEvalContext() - assert len(result) == 1 - assert result[0].name == "tool1" + with patch("hud.datasets.runner.hud.eval") as mock_eval: + # Set up the async context manager + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - @pytest.mark.asyncio - async def test_both_filters_applies_allowed_then_disallowed( - self, mock_mcp_client, mock_model_client - ) -> None: - """Test that both filters in agent_config work together (disallowed takes precedence).""" - tools = [ - types.Tool( - name="browser_click", - description="Tool 1", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="browser_debug", - description="Tool 2", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="system_click", - description="Tool 3", - inputSchema={"type": "object", "properties": {}}, - ), - ] - agent_config = {"allowed_tools": ["browser_*"], "disallowed_tools": ["*_debug"]} + from hud.datasets.runner import run_dataset - result = await self._run_agent_with_tools( - mock_mcp_client, mock_model_client, tools, agent_config - ) + await run_dataset(tasks, agent, max_steps=5) # type: ignore[arg-type] - assert len(result) == 1 - assert result[0].name == "browser_click" - - -class TestRunDatasetToolFiltering: - """Test tool filtering via run_dataset with agent_config in both init and task.""" - - @pytest.fixture - def all_tools(self): - """Fixture for a standard set of tools.""" - return [ - types.Tool( - name="browser_click", - description="Click", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="browser_type", - description="Type", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="browser_debug", - description="Debug", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="system_screenshot", - description="Screenshot", - inputSchema={"type": "object", "properties": {}}, - ), - types.Tool( - name="system_execute", - description="Execute", - inputSchema={"type": "object", "properties": {}}, - ), - ] - - @pytest.fixture - def captured_agent_fixture(self): - """Fixture that 
returns a dictionary to capture the agent instance.""" - return {"agent": None} + # Verify hud.eval was called with correct params + mock_eval.assert_called_once() + call_kwargs = mock_eval.call_args[1] + assert call_kwargs["group"] == 1 + assert call_kwargs["max_concurrent"] == 30 - @pytest.fixture - def mock_run_context(self, captured_agent_fixture): - """Fixture for mocking _run_context.""" + # Agent should have run + assert agent.run_count == 1 - async def _mock(self, context, max_steps=10): - captured_agent_fixture["agent"] = self - return Trace(reward=1.0, done=True, content="Done") + @pytest.mark.asyncio + async def test_run_dataset_with_string_source(self) -> None: + """Test run_dataset with a string source (loads via load_dataset).""" + from hud.eval.task import Task - return _mock + mock_tasks = [Task(id="loaded_task", scenario="loaded")] + agent = MockAgent() + mock_ctx = MockEvalContext() - @pytest.fixture - def mock_call_tools(self): - """Fixture for mocking call_tools.""" + with ( + patch("hud.datasets.runner.load_dataset", return_value=mock_tasks) as mock_load, + patch("hud.datasets.runner.hud.eval") as mock_eval, + ): + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - async def _mock(self, tool_call=None): - return [] + from hud.datasets.runner import run_dataset - return _mock + await run_dataset("my-tasks.json", agent) # type: ignore[arg-type] - @pytest.fixture - def mock_client_instance(self, all_tools): - """Fixture for mock MCP client instance.""" - mock_client = MagicMock() - mock_client.initialize = AsyncMock() - mock_client.list_tools = AsyncMock(return_value=all_tools) - mock_client.shutdown = AsyncMock() - mock_client.mcp_config = {"local": {"url": "http://localhost:8765/mcp"}} - return mock_client + # Verify load_dataset was called + mock_load.assert_called_once_with("my-tasks.json") @pytest.mark.asyncio - async def test_agent_config_intersection_union_via_run_dataset( - self, - all_tools, - captured_agent_fixture, - mock_run_context, - mock_call_tools, - mock_client_instance, - ) -> None: - """Test that allowed_tools intersect and disallowed_tools union when set in both __init__ and task.agent_config.""" # noqa: E501 - from hud.agents import ClaudeAgent - from hud.datasets.runner import run_dataset - - # Create a task with its own agent_config - task_dict = { - "prompt": "Test task", - "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}}, - "agent_config": { - "allowed_tools": [ - "browser_*", - "system_screenshot", - ], # Task wants browser_* and system_screenshot - "disallowed_tools": [ - "*_debug", - "*_execute", - ], # Task disallows *_debug and *_execute - }, - } - - # Agent config passed to __init__ via run_dataset - agent_init_config = { - "allowed_tools": ["browser_*", "system_*"], # Agent init wants browser_* and system_* - "disallowed_tools": ["browser_debug"], # Agent init disallows browser_debug - "validate_api_key": False, - } - - # Create mock context - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=None) - mock_ctx._suppress_link = False + async def test_run_dataset_empty_tasks_raises(self) -> None: + """Test run_dataset raises ValueError for empty tasks.""" + agent = MockAgent() - with ( - patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), - patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), - patch.object(ClaudeAgent, 
"_run_context", mock_run_context), - patch.object(ClaudeAgent, "call_tools", mock_call_tools), - patch("hud.clients.MCPClient", return_value=mock_client_instance), - patch("hud.settings.settings.anthropic_api_key", "sk-test-key"), - ): - # Run the dataset - await run_dataset( - name="test_job", - dataset=[task_dict], - agent_class=ClaudeAgent, - agent_config=agent_init_config, - max_steps=10, - ) - - # Verify agent was created and ran - captured_agent = captured_agent_fixture["agent"] - assert captured_agent is not None - - # Get the filtered tools - filtered_tools = captured_agent.get_available_tools() - filtered_names = {tool.name for tool in filtered_tools} - - # Expected behavior: - # 1. allowed_tools intersection: ["browser_*", "system_*"] ∩ ["browser_*", "system_screenshot"] # noqa: E501 - # Exact string intersection: only "browser_*" is in both lists - # So only tools matching browser_* are allowed: browser_click, browser_type, browser_debug # noqa: E501 - # 2. disallowed_tools union: ["browser_debug"] U ["*_debug", "*_execute"] - # Result: ["browser_debug", "*_debug", "*_execute"] (all patterns included) - # 3. Final: {browser_click, browser_type, browser_debug} - {browser_debug} - # Result: browser_click, browser_type - - expected_tools = {"browser_click", "browser_type"} - assert filtered_names == expected_tools, ( - f"Expected {expected_tools}, got {filtered_names}" - ) + with patch("hud.datasets.runner.load_dataset", return_value=[]): + from hud.datasets.runner import run_dataset + + with pytest.raises(ValueError, match="No tasks to run"): + await run_dataset([], agent) # type: ignore[arg-type] @pytest.mark.asyncio - async def test_no_allowed_tools_keeps_all_tools_except_disallowed( - self, - all_tools, - captured_agent_fixture, - mock_run_context, - mock_call_tools, - mock_client_instance, - ) -> None: - """Test that when allowed_tools is not set, all tools are available except disallowed ones.""" # noqa: E501 - from hud.agents import ClaudeAgent - from hud.datasets.runner import run_dataset - - # Create a task with its own agent_config (no allowed_tools) - task_dict = { - "prompt": "Test task", - "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}}, - "agent_config": { - # No allowed_tools set - should allow all tools - "disallowed_tools": ["*_execute"], # Task disallows *_execute - }, - } - - # Agent config passed to __init__ via run_dataset (no allowed_tools) - agent_init_config = { - # No allowed_tools set - should allow all tools - "disallowed_tools": ["browser_debug"], # Agent init disallows browser_debug - "validate_api_key": False, - } - - # Create mock context - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=None) - mock_ctx._suppress_link = False + async def test_run_dataset_with_group_size(self) -> None: + """Test run_dataset passes group_size to hud.eval.""" + from hud.eval.task import Task - with ( - patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), - patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), - patch.object(ClaudeAgent, "_run_context", mock_run_context), - patch.object(ClaudeAgent, "call_tools", mock_call_tools), - patch("hud.clients.MCPClient", return_value=mock_client_instance), - patch("hud.settings.settings.anthropic_api_key", "sk-test-key"), - ): - # Run the dataset - await run_dataset( - name="test_job", - dataset=[task_dict], - agent_class=ClaudeAgent, - agent_config=agent_init_config, - max_steps=10, - ) + 
tasks = [Task(id="task1", scenario="test")] + agent = MockAgent() + mock_ctx = MockEvalContext() - # Verify agent was created and ran - captured_agent = captured_agent_fixture["agent"] - assert captured_agent is not None + with patch("hud.datasets.runner.hud.eval") as mock_eval: + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - # Get the filtered tools - filtered_tools = captured_agent.get_available_tools() - filtered_names = {tool.name for tool in filtered_tools} + from hud.datasets.runner import run_dataset - # Expected behavior: - # 1. allowed_tools: None (no allowed_tools set in either init or task) - # Result: All tools are initially allowed - # 2. disallowed_tools union: ["browser_debug"] U ["*_execute"] - # Result: ["browser_debug", "*_execute"] (all patterns included) - # 3. Final: {all tools} - {browser_debug, system_execute} - # Result: browser_click, browser_type, system_screenshot + await run_dataset(tasks, agent, group_size=3) # type: ignore[arg-type] - expected_tools = {"browser_click", "browser_type", "system_screenshot"} - assert filtered_names == expected_tools, ( - f"Expected {expected_tools}, got {filtered_names}" - ) + call_kwargs = mock_eval.call_args[1] + assert call_kwargs["group"] == 3 + @pytest.mark.asyncio + async def test_run_dataset_with_max_concurrent(self) -> None: + """Test run_dataset passes max_concurrent to hud.eval.""" + from hud.eval.task import Task -SYSTEM_PROMPT = "You are an assistant that can use tools to help the user. You will be given a task and you will need to use the tools to complete the task." # noqa: E501 + tasks = [Task(id="task1", scenario="test")] + agent = MockAgent() + mock_ctx = MockEvalContext() + with patch("hud.datasets.runner.hud.eval") as mock_eval: + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) -class TestSystemPromptHandling: - """Test system prompt handling through run_dataset flow.""" + from hud.datasets.runner import run_dataset - @pytest.fixture - def mock_mcp_client(self): - """Fixture for mock MCP client.""" - return MockMCPClient() + await run_dataset(tasks, agent, max_concurrent=10) # type: ignore[arg-type] - @pytest.fixture - def captured_agent_fixture(self): - """Fixture that returns a dictionary to capture the agent instance.""" - return {"agent": None} + call_kwargs = mock_eval.call_args[1] + assert call_kwargs["max_concurrent"] == 10 - @pytest.fixture - def mock_run_context(self, captured_agent_fixture): - """Fixture for mocking _run_context to capture agent.""" + @pytest.mark.asyncio + async def test_run_dataset_returns_results(self) -> None: + """Test run_dataset returns EvalContext results.""" + from hud.eval.task import Task - async def _mock(self, context, max_steps=10): - captured_agent_fixture["agent"] = self - return Trace(reward=1.0, done=True, content="Done") + tasks = [Task(id="task1", scenario="test")] + agent = MockAgent() + mock_ctx = MockEvalContext() - return _mock + with patch("hud.datasets.runner.hud.eval") as mock_eval: + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - @pytest.fixture - def mock_call_tools(self): - """Fixture for mocking call_tools.""" + from hud.datasets.runner import run_dataset - async def _mock(self, tool_call=None): - return [] + results = await run_dataset(tasks, agent) # type: ignore[arg-type] - return _mock + 
# Should return list with the context + assert len(results) == 1 + assert results[0] is mock_ctx @pytest.mark.asyncio - async def test_task_system_prompt_only( - self, captured_agent_fixture, mock_run_context, mock_call_tools, mock_mcp_client - ) -> None: - """Test that task system_prompt is appended when agent has default system prompt.""" - from hud.agents import ClaudeAgent - from hud.datasets.runner import run_dataset - - task_system_prompt = "Task prompt" - - # Create a task with its own system_prompt in agent_config - task_dict = { - "prompt": "Test task", - "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}}, - "agent_config": { - "system_prompt": task_system_prompt, - }, - } - - # Agent config with no custom system_prompt (will use default) - agent_init_config = {"validate_api_key": False, "system_prompt": SYSTEM_PROMPT} - - # Create mock context - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=None) - mock_ctx._suppress_link = False + async def test_run_dataset_parallel_results(self) -> None: + """Test run_dataset returns ctx.results for parallel execution.""" + from hud.eval.task import Task - with ( - patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), - patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), - patch.object(ClaudeAgent, "_run_context", mock_run_context), - patch.object(ClaudeAgent, "call_tools", mock_call_tools), - patch("hud.clients.MCPClient", return_value=mock_mcp_client), - patch("hud.settings.settings.anthropic_api_key", "sk-test-key"), - ): - # Run the dataset - await run_dataset( - name="test_job", - dataset=[task_dict], - agent_class=ClaudeAgent, - agent_config=agent_init_config, - max_steps=10, - ) - - # Verify agent was created and ran - captured_agent = captured_agent_fixture["agent"] - assert captured_agent is not None - - # Verify the task system prompt was appended - assert captured_agent.system_prompt.endswith(f"\n\n{task_system_prompt}") - # Verify it starts with the base global system prompt - assert captured_agent.system_prompt.startswith(SYSTEM_PROMPT) + tasks = [Task(id="task1", scenario="test")] + agent = MockAgent() - @pytest.mark.asyncio - async def test_both_agent_and_task_system_prompts( - self, captured_agent_fixture, mock_run_context, mock_call_tools, mock_mcp_client - ) -> None: - """Test that both agent init and task system prompts are present when both are set.""" - from hud.agents import ClaudeAgent - from hud.datasets.runner import run_dataset - - agent_custom_prompt = "Agent init prompt" - task_system_prompt = "Task prompt" - - # Create a task with its own system_prompt in agent_config - task_dict = { - "prompt": "Test task", - "mcp_config": {"local": {"url": "http://localhost:8765/mcp"}}, - "agent_config": { - "system_prompt": task_system_prompt, - }, - } - - # Agent config WITH custom system_prompt - agent_init_config = { - "system_prompt": agent_custom_prompt, - "validate_api_key": False, - } - - # Create mock context - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=None) - mock_ctx._suppress_link = False + # Create mock context with results (parallel execution) + mock_result1 = MockEvalContext(prompt="result1") + mock_result1.reward = 0.8 + mock_result2 = MockEvalContext(prompt="result2") + mock_result2.reward = 0.9 - with ( - patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), - 
patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), - patch.object(ClaudeAgent, "_run_context", mock_run_context), - patch.object(ClaudeAgent, "call_tools", mock_call_tools), - patch("hud.clients.MCPClient", return_value=mock_mcp_client), - patch("hud.settings.settings.anthropic_api_key", "sk-test-key"), - ): - # Run the dataset - await run_dataset( - name="test_job", - dataset=[task_dict], - agent_class=ClaudeAgent, - agent_config=agent_init_config, - max_steps=10, - ) - - # Verify agent was created and ran - captured_agent = captured_agent_fixture["agent"] - assert captured_agent is not None - - # Verify the task system prompt was appended at the end - assert captured_agent.system_prompt.endswith(f"\n\n{task_system_prompt}") - # Verify it starts with the agent custom prompt - assert captured_agent.system_prompt.startswith(agent_custom_prompt) - # Verify both prompts are present - assert agent_custom_prompt in captured_agent.system_prompt - assert task_system_prompt in captured_agent.system_prompt + mock_ctx = MockEvalContext() + mock_ctx.results = [mock_result1, mock_result2] + + with patch("hud.datasets.runner.hud.eval") as mock_eval: + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) + + from hud.datasets.runner import run_dataset + + results = await run_dataset(tasks, agent) # type: ignore[arg-type] + + # Should return the parallel results + assert len(results) == 2 + assert results[0].reward == 0.8 + assert results[1].reward == 0.9 diff --git a/hud/cli/utils/metadata.py b/hud/cli/utils/metadata.py index f9241752..f26db9ea 100644 --- a/hud/cli/utils/metadata.py +++ b/hud/cli/utils/metadata.py @@ -231,7 +231,7 @@ async def analyze_from_metadata(reference: str, output_format: str, verbose: boo } ) - # Derive scenarios from script prompts/resources if present + # Derive scenarios from scenario prompts/resources if present scenarios_by_id: dict[str, dict] = {} for p in analysis["prompts"]: desc = (p.get("description") or "").strip() @@ -240,11 +240,11 @@ async def analyze_from_metadata(reference: str, output_format: str, verbose: boo scenario_id = p.get("name") if not scenario_id: continue - env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + env_name, scenario_name = ([*scenario_id.split(":", 1), ""])[:2] scenarios_by_id[scenario_id] = { "id": scenario_id, "env": env_name, - "name": script_name or scenario_id, + "name": scenario_name or scenario_id, "setup_description": desc, "arguments": p.get("arguments") or [], "has_setup_prompt": True, @@ -257,12 +257,12 @@ async def analyze_from_metadata(reference: str, output_format: str, verbose: boo scenario_id = r.get("uri") if not scenario_id: continue - env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + env_name, scenario_name = ([*scenario_id.split(":", 1), ""])[:2] if scenario_id not in scenarios_by_id: scenarios_by_id[scenario_id] = { "id": scenario_id, "env": env_name, - "name": script_name or scenario_id, + "name": scenario_name or scenario_id, "arguments": [], "has_setup_prompt": False, "has_evaluate_resource": True, diff --git a/hud/clients/base.py b/hud/clients/base.py index b760058b..ee5ad5a6 100644 --- a/hud/clients/base.py +++ b/hud/clients/base.py @@ -436,10 +436,10 @@ async def analyze_environment(self) -> dict[str, Any]: if self.verbose: hud_console.debug("Could not list prompts: " + str(e)) - # Derive "scenarios" from Environment.@script prompts/resources. 
+ # Derive "scenarios" from Environment.@scenario prompts/resources. # A scenario is exposed as: - # - Prompt: name "{env}:{script}" with description prefix "[Setup]" - # - Resource: uri "{env}:{script}" with description prefix "[Evaluate]" + # - Prompt: name "{env}:{scenario}" with description prefix "[Setup]" + # - Resource: uri "{env}:{scenario}" with description prefix "[Evaluate]" scenarios_by_id: dict[str, dict[str, Any]] = {} for p in analysis.get("prompts", []): @@ -449,11 +449,11 @@ async def analyze_environment(self) -> dict[str, Any]: scenario_id = p.get("name") if not scenario_id: continue - env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + env_name, scenario_name = ([*scenario_id.split(":", 1), ""])[:2] scenarios_by_id[scenario_id] = { "id": scenario_id, "env": env_name, - "name": script_name or scenario_id, + "name": scenario_name or scenario_id, "setup_description": desc, "arguments": p.get("arguments") or [], "has_setup_prompt": True, @@ -467,12 +467,12 @@ async def analyze_environment(self) -> dict[str, Any]: scenario_id = r.get("uri") if not scenario_id: continue - env_name, script_name = ([*scenario_id.split(":", 1), ""])[:2] + env_name, scenario_name = ([*scenario_id.split(":", 1), ""])[:2] if scenario_id not in scenarios_by_id: scenarios_by_id[scenario_id] = { "id": scenario_id, "env": env_name, - "name": script_name or scenario_id, + "name": scenario_name or scenario_id, "arguments": [], "has_setup_prompt": False, "has_evaluate_resource": True, diff --git a/hud/clients/tests/test_analyze_scenarios.py b/hud/clients/tests/test_analyze_scenarios.py index 9f18ea7b..46505507 100644 --- a/hud/clients/tests/test_analyze_scenarios.py +++ b/hud/clients/tests/test_analyze_scenarios.py @@ -50,7 +50,7 @@ async def _disconnect(self) -> None: # pragma: no cover @pytest.mark.asyncio -async def test_analyze_environment_derives_scenarios_from_script_prompt_and_resource() -> None: +async def test_analyze_environment_derives_scenarios_from_scenario_prompt_and_resource() -> None: prompts = [ types.Prompt( name="my-env:checkout", diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py index 951a32d7..e67ac560 100644 --- a/hud/datasets/__init__.py +++ b/hud/datasets/__init__.py @@ -7,10 +7,11 @@ # Execution functions from __future__ import annotations -from hud.types import Task +from hud.types import LegacyTask from hud.utils.tasks import save_tasks -from .runner import run_dataset, run_single_task, run_tasks +from .loader import load_dataset +from .runner import run_dataset from .utils import ( BatchRequest, SingleTaskRequest, @@ -22,12 +23,11 @@ __all__ = [ "BatchRequest", "SingleTaskRequest", - "Task", + "LegacyTask", "calculate_group_stats", "display_results", + "load_dataset", "run_dataset", - "run_single_task", - "run_tasks", "save_tasks", "submit_rollouts", ] diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py new file mode 100644 index 00000000..984c3437 --- /dev/null +++ b/hud/datasets/loader.py @@ -0,0 +1,177 @@ +"""Dataset loading utilities for HUD. 
+ +Unified interface for loading evaluation datasets from: +- HUD API (v5 format) +- Local JSON/JSONL files (v4 LegacyTask format, auto-converted) +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from hud.eval.task import Task + +logger = logging.getLogger(__name__) + +__all__ = ["load_dataset"] + + +def _is_legacy_task_format(item: dict[str, Any]) -> bool: + """Check if a dict is in v4 LegacyTask format. + + LegacyTask has: prompt, mcp_config (required), setup_tool, evaluate_tool (optional) + v5 Task has: env, scenario, args + """ + # If it has prompt + mcp_config, it's legacy format + # If it has setup_tool or evaluate_tool, it's legacy + return ( + ("prompt" in item and "mcp_config" in item) + or "setup_tool" in item + or "evaluate_tool" in item + ) + + +def _task_from_dict(item: dict[str, Any]) -> Task: + """Convert a dict to Task, auto-detecting v4 vs v5 format.""" + from hud.eval.task import Task + from hud.types import MCPToolCall + + if _is_legacy_task_format(item): + # v4 LegacyTask format - convert via Task.from_v4() + return Task.from_v4(item) + else: + # v5 format - env is EnvConfig dict with name, include, exclude + # Convert validation dicts to MCPToolCall objects + validation = None + if item.get("validation"): + validation = [MCPToolCall(**v) for v in item["validation"]] + + return Task( + id=item.get("id"), + env=item.get("env"), # EnvConfig dict: {"name": "browser", "include": [...], ...} + scenario=item.get("scenario"), + args=item.get("args", {}), + validation=validation, + ) + + +def _load_from_file(path: Path) -> list[Task]: + """Load tasks from a local JSON or JSONL file.""" + tasks: list[Task] = [] + + if path.suffix == ".jsonl": + # JSONL: one task per line + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + item = json.loads(line) + # Handle case where line contains a list + if isinstance(item, list): + tasks.extend(_task_from_dict(i) for i in item) + elif isinstance(item, dict): + tasks.append(_task_from_dict(item)) + else: + raise ValueError( + f"Invalid JSONL format: expected dict or list, got {type(item)}" + ) + else: + # JSON: array of tasks + with open(path, encoding="utf-8") as f: + data = json.load(f) + + if isinstance(data, list): + tasks = [_task_from_dict(item) for item in data] + elif isinstance(data, dict): + tasks = [_task_from_dict(data)] + else: + raise ValueError(f"JSON file must contain an array or object, got {type(data)}") + + return tasks + + +def _load_from_api(dataset_name: str) -> list[Task]: + """Load tasks from HUD API.""" + import httpx + + from hud.settings import settings + + headers = {} + if settings.api_key: + headers["Authorization"] = f"Bearer {settings.api_key}" + + with httpx.Client() as client: + response = client.get( + f"{settings.hud_api_url}/evals/{dataset_name}", + headers=headers, + params={"all": "true"}, + ) + response.raise_for_status() + data = response.json() + + tasks: list[Task] = [] + if isinstance(data, list): + tasks = [_task_from_dict(item) for item in data] + else: + tasks = [_task_from_dict(data)] + + return tasks + + +def load_dataset(source: str) -> list[Task]: + """Load tasks from a dataset source. + + Supports multiple sources with auto-detection: + - Local file path (JSON or JSONL) + - HUD API dataset slug (e.g., "hud-evals/SheetBench-50") + + Automatically detects and converts v4 LegacyTask format to v5 Task. 
+ + Args: + source: Dataset source. Can be: + - Path to a local JSON/JSONL file + - HUD API dataset slug (e.g., "hud-evals/SheetBench-50") + + Returns: + List of Task objects ready to use with hud.eval() + + Example: + ```python + import hud + from hud.datasets import load_dataset + + # Load from HUD API + tasks = load_dataset("hud-evals/SheetBench-50") + + # Load from local file (v4 format auto-converted) + tasks = load_dataset("./my-tasks.json") + + # Run evaluation + async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: + await agent.run(ctx) + ``` + + Raises: + ValueError: If dataset loading fails + """ + # Check if it's a local file + path = Path(source) + if path.exists() and path.suffix in {".json", ".jsonl"}: + logger.info("Loading tasks from file: %s", source) + tasks = _load_from_file(path) + logger.info("Loaded %d tasks from %s", len(tasks), source) + return tasks + + # Otherwise, try HUD API + logger.info("Loading dataset from HUD API: %s", source) + try: + tasks = _load_from_api(source) + logger.info("Loaded %d tasks from %s", len(tasks), source) + return tasks + except Exception as e: + raise ValueError(f"Failed to load dataset '{source}' from HUD API: {e}") from e diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 5cc6dcd6..c1e291f9 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -5,353 +5,78 @@ from __future__ import annotations -import asyncio -import json import logging -import uuid -import warnings -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING -from hud.datasets.utils import calculate_group_stats, submit_rollouts -from hud.types import AgentType, Task, Trace +import hud if TYPE_CHECKING: - from datasets import Dataset - from hud.agents import MCPAgent - -logger = logging.getLogger("hud.datasets") - - -async def run_single_task( - task: Task, - agent_type: AgentType, - agent_params: dict[str, Any] | None = None, - max_steps: int = 10, - job_id: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - trace_id: str | None = None, - trace_name: str | None = None, - metadata: dict[str, Any] | None = None, -) -> Trace: - """Execute a single task with tracing. - - This is the core execution primitive for running a single task. - - Args: - task: Task to execute - agent_type: Agent type to use - agent_params: Parameters passed to agent.create(). Should include fields - from BaseCreateParams (auto_trace, auto_respond, verbose) plus - agent-specific config fields (e.g., use_computer_beta for ClaudeConfig). 
- max_steps: Maximum steps for agent execution - job_id: Job ID for telemetry grouping - task_id: Task ID for telemetry - group_id: Group ID for variance estimation runs - trace_id: Trace ID for telemetry (auto-generated if not provided) - trace_name: Name for the trace (defaults to task prompt) - metadata: Additional trace metadata - - Returns: - Trace result from agent execution - """ from hud.eval.context import EvalContext + from hud.eval.task import Task - name = trace_name or task.prompt or task_id or "task" - - ctx = EvalContext.from_task( - task=task, - name=name, - trace_id=trace_id, - job_id=job_id, - group_id=group_id, - ) - - result: Trace - async with ctx: - agent = agent_type.cls.create(**(agent_params or {})) - result = await agent.run(task, max_steps=max_steps) - # Transfer reward to context for tracking - ctx.reward = result.reward - return result +logger = logging.getLogger("hud.datasets") -async def run_tasks( - tasks: list[Task], - agent_type: AgentType, - agent_params: dict[str, Any] | None = None, +async def run_dataset( + tasks: str | list[Task], + agent: MCPAgent, *, - name: str = "Evaluation", - max_concurrent: int = 30, - metadata: dict[str, Any] | None = None, max_steps: int = 10, + max_concurrent: int = 30, group_size: int = 1, - remote: bool = False, -) -> list[Any]: - """Run a list of tasks with automatic job and telemetry tracking. +) -> list[EvalContext]: + """Run an agent on a dataset of tasks. - This is the core evaluation function. Use this when you have a list of tasks - to run, whether loaded from a dataset, filtered, or constructed programmatically. + This is the primary entry point for running evaluations programmatically. Args: - tasks: List of Task objects - agent_type: AgentType specifying which agent to use - agent_params: Parameters passed to agent.create(). Should include fields - from BaseCreateParams (auto_trace, auto_respond, verbose) plus - agent-specific config fields (e.g., checkpoint_name for ClaudeConfig). - name: Name for the job - max_concurrent: Maximum concurrent tasks - metadata: Optional job metadata - max_steps: Maximum steps per task - group_size: Number of times to run each task (for variance estimation) - remote: If True, submit tasks to HUD platform for remote execution + tasks: Either a source string (file path, API slug) or list of Task objects. + If a string, tasks are loaded via load_dataset(). + agent: The agent instance to run. + max_steps: Maximum steps per task. + max_concurrent: Maximum concurrent tasks (for parallel execution). + group_size: Number of times to run each task (for variance estimation). Returns: - If remote: Empty list (fire-and-forget submission) - If group_size == 1: List of Trace results in task order. - If group_size > 1: List of statistics dicts for each task group. + List of EvalContext results from each task execution. Access `.reward` on each. 
Example: - # Run specific tasks locally - all_tasks = load_tasks("hud-evals/SheetBench-50") - selected = [t for t in all_tasks if t.id in ["task_1", "task_5"]] - results = await run_tasks(selected, AgentType.CLAUDE, {"checkpoint_name": "..."}) - - # Run with variance estimation - stats = await run_tasks(tasks, AgentType.CLAUDE, group_size=3) - - # Submit for remote execution - await run_tasks(tasks, AgentType.CLAUDE, remote=True) - """ - from hud.eval.display import print_complete, print_link - from hud.utils.hud_console import HUDConsole - - job_metadata = metadata or {} - job_metadata["agent_params"] = json.dumps(agent_params or {}) - job_metadata["agent_type"] = agent_type.value - if group_size > 1: - job_metadata["group_size"] = group_size - job_metadata["total_episodes"] = len(tasks) * group_size - - if remote: - from hud.telemetry.job import create_job - - hud_console = HUDConsole() - - job = create_job(name, metadata=job_metadata) - job.update_status_sync("created") - - await submit_rollouts( - tasks=tasks, - job_id=job.id, - agent_type=agent_type, - agent_params=agent_params, - max_steps=max_steps, - group_size=group_size, - metadata=metadata, - ) - hud_console.success(f"Submitted {len(tasks) * group_size} rollouts for remote execution") - hud_console.info(f"Monitor progress at: https://hud.ai/jobs/{job.id}") - return [] - - # Local execution using new eval system - agent_class = agent_type.cls - job_id = str(uuid.uuid4()) - job_url = f"https://hud.ai/jobs/{job_id}" + ```python + from hud.agents import ClaudeAgent + from hud.datasets import load_dataset, run_dataset - # Print job URL - print_link(job_url, f"🚀 Job '{name}'") + # Load tasks + tasks = load_dataset("my-tasks.json") - error_occurred = False - try: - results = await _run_tasks_with_eval( - tasks=tasks, - agent_class=agent_class, - agent_params=agent_params, - max_concurrent=max_concurrent, - max_steps=max_steps, - group_size=group_size, - job_id=job_id, - ) - error_occurred = any(r is None or (isinstance(r, Trace) and r.isError) for r in results) - return results - except Exception: - error_occurred = True - raise - finally: - print_complete(job_url, name, error=error_occurred) + # Create agent + agent = ClaudeAgent.create(checkpoint_name="claude-sonnet-4-20250514") - -async def _run_tasks_with_eval( - tasks: list[Task], - agent_class: type[MCPAgent], - agent_params: dict[str, Any] | None, - max_concurrent: int, - max_steps: int, - group_size: int, - job_id: str, -) -> list[Any]: - """Run tasks using the new EvalContext system.""" - from hud.eval.context import EvalContext - - sem = asyncio.Semaphore(max_concurrent) - params = agent_params or {} - - # Generate group IDs for each task (used for telemetry grouping) - group_ids = {i: str(uuid.uuid4()) for i in range(len(tasks))} - - # Expand tasks: each task runs group_size times - expanded: list[tuple[int, int, Task]] = [] # (flat_idx, task_idx, task) - for task_idx, task in enumerate(tasks): - for _ in range(group_size): - expanded.append((len(expanded), task_idx, task)) - - traces: list[Trace | None] = [None] * len(expanded) - - async def worker(flat_idx: int, task_idx: int, run_idx: int, task: Task) -> None: - async with sem: - try: - base_task_id = str(task.id) if task.id is not None else f"task_{task_idx}" - trace_name = task.prompt or base_task_id - - # Create EvalContext for this task run - ctx = EvalContext.from_task( - task=task, - name=trace_name, - job_id=job_id, - group_id=group_ids[task_idx] if group_size > 1 else None, - ) - ctx._suppress_link = True # Don't 
print individual trace links - - async with ctx: - agent = agent_class.create(**params) - result = await agent.run(task, max_steps=max_steps) - ctx.reward = result.reward - traces[flat_idx] = result - - except Exception as e: - if group_size == 1: - logger.exception("Task %s failed: %s", task_idx, e) - traces[flat_idx] = None - else: - logger.warning("Episode %s failed: %s", flat_idx, e) - traces[flat_idx] = Trace(isError=True, content=str(e), reward=0.0, done=True) - - await asyncio.gather( - *[ - worker(flat_idx, task_idx, flat_idx % group_size, task) - for flat_idx, task_idx, task in expanded - ], - return_exceptions=True, - ) - - # Return format depends on group_size - if group_size == 1: - return list(traces) - else: - return calculate_group_stats(tasks, traces, group_size, group_ids) - - -async def run_dataset( - name: str, - dataset: str | Dataset | list[dict[str, Any]], - agent_class: type[MCPAgent], - agent_config: dict[str, Any] | None = None, - max_concurrent: int = 30, - metadata: dict[str, Any] | None = None, - max_steps: int = 10, - split: str = "train", - auto_respond: bool = False, - group_size: int = 1, -) -> list[Any]: - """Load and run all tasks from a dataset. - - .. deprecated:: - Use `run_tasks()` for new code. This function remains for backwards - compatibility but `run_tasks()` offers more flexibility (filtering, - custom task lists, etc.). - - Args: - name: Name for the job - dataset: HuggingFace dataset identifier, Dataset object, or list of dicts - agent_class: Agent class to instantiate - agent_config: Configuration kwargs for agent initialization - max_concurrent: Maximum concurrent tasks - metadata: Optional job metadata - max_steps: Maximum steps per task - split: Dataset split to use when loading from string - auto_respond: Whether to use auto-response agent - group_size: Number of times to run each task (for variance estimation) - - Returns: - If group_size == 1: List of results from agent.run() in dataset order. - If group_size > 1: List of statistics dicts for each task group. + # Run evaluation + results = await run_dataset(tasks, agent, max_steps=50) + for ctx in results: + print(f"Reward: {ctx.reward}") + ``` """ - from datasets import Dataset as HFDataset - from datasets import load_dataset - - from hud.eval.display import print_complete, print_link - - warnings.warn( - "run_dataset() is deprecated. 
Use run_tasks() instead for more flexibility.", - DeprecationWarning, - stacklevel=2, - ) + from hud.datasets.loader import load_dataset - # Load dataset and convert to Task objects - task_dicts: list[dict[str, Any]] - dataset_link: str | None = None + # Load tasks if string provided + task_list = load_dataset(tasks) if isinstance(tasks, str) else tasks - if isinstance(dataset, str): - logger.info("Loading dataset %s from HuggingFace...", dataset) - dataset_link = dataset - loaded = cast("HFDataset", load_dataset(dataset, split=split)) - task_dicts = cast("list[dict[str, Any]]", list(loaded)) - elif isinstance(dataset, HFDataset): - task_dicts = cast("list[dict[str, Any]]", list(dataset)) - # Try to extract dataset link - try: - general_info = next(iter(dataset.info.__dict__["download_checksums"].keys())).split("/") - dataset_link = f"{general_info[3]}/{general_info[4].split('@')[0]}" - except Exception: # noqa: S110 - pass - else: - task_dicts = dataset + if not task_list: + raise ValueError("No tasks to run") - # Convert dicts to Task objects - tasks = [Task(**d) for d in task_dicts] - - # Add dataset link to metadata - job_metadata = metadata or {} - job_metadata["agent_config"] = agent_config or {} - if dataset_link: - job_metadata["dataset_link"] = dataset_link - if group_size > 1: - job_metadata["group_size"] = group_size - job_metadata["total_episodes"] = len(tasks) * group_size - - # Use new eval system - job_id = str(uuid.uuid4()) - job_url = f"https://hud.ai/jobs/{job_id}" + # Use hud.eval() for both single and parallel execution + async with hud.eval( + task_list, + group=group_size, + max_concurrent=max_concurrent, + ) as ctx: + result = await agent.run(ctx, max_steps=max_steps) + ctx.reward = result.reward - print_link(job_url, f"🚀 Job '{name}'") + # For parallel execution, results are collected via ctx.results + if hasattr(ctx, "results") and ctx.results: + return ctx.results - error_occurred = False - try: - results = await _run_tasks_with_eval( - tasks=tasks, - agent_class=agent_class, - agent_params=agent_config, - max_concurrent=max_concurrent, - max_steps=max_steps, - group_size=group_size, - job_id=job_id, - ) - error_occurred = any(r is None or (isinstance(r, Trace) and r.isError) for r in results) - return results - except Exception: - error_occurred = True - raise - finally: - print_complete(job_url, name, error=error_occurred) + return [ctx] diff --git a/hud/datasets/tests/test_loader.py b/hud/datasets/tests/test_loader.py new file mode 100644 index 00000000..9bc617ff --- /dev/null +++ b/hud/datasets/tests/test_loader.py @@ -0,0 +1,196 @@ +"""Tests for hud.datasets.loader module.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from hud.datasets.loader import load_dataset + + +class TestLoadDataset: + """Tests for load_dataset() function.""" + + @patch("hud.datasets.loader.httpx.Client") + @patch("hud.datasets.loader.settings") + def test_load_dataset_success( + self, mock_settings: MagicMock, mock_client_class: MagicMock + ) -> None: + """load_dataset() successfully loads tasks from API.""" + mock_settings.hud_api_url = "https://api.hud.ai" + mock_settings.api_key = "test_key" + + mock_response = MagicMock() + mock_response.json.return_value = [ + {"env": {"name": "test"}, "scenario": "checkout", "args": {"user": "alice"}}, + {"env": {"name": "test"}, "scenario": "login", "args": {"user": "bob"}}, + ] + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value 
= mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + mock_client_class.return_value = mock_client + + tasks = load_dataset("test-org/test-dataset") + + assert len(tasks) == 2 + assert tasks[0].scenario == "checkout" + assert tasks[0].args == {"user": "alice"} + assert tasks[1].scenario == "login" + mock_client.get.assert_called_once_with( + "https://api.hud.ai/evals/test-org/test-dataset", + headers={"Authorization": "Bearer test_key"}, + params={"all": "true"}, + ) + + @patch("hud.datasets.loader.httpx.Client") + @patch("hud.datasets.loader.settings") + def test_load_dataset_single_task( + self, mock_settings: MagicMock, mock_client_class: MagicMock + ) -> None: + """load_dataset() handles single task (non-list) response.""" + mock_settings.hud_api_url = "https://api.hud.ai" + mock_settings.api_key = "test_key" + + mock_response = MagicMock() + mock_response.json.return_value = { + "env": {"name": "test"}, + "scenario": "checkout", + "args": {"user": "alice"}, + } + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + mock_client_class.return_value = mock_client + + tasks = load_dataset("test-org/test-dataset") + + assert len(tasks) == 1 + assert tasks[0].scenario == "checkout" + + @patch("hud.datasets.loader.httpx.Client") + @patch("hud.datasets.loader.settings") + def test_load_dataset_no_api_key( + self, mock_settings: MagicMock, mock_client_class: MagicMock + ) -> None: + """load_dataset() works without API key.""" + mock_settings.hud_api_url = "https://api.hud.ai" + mock_settings.api_key = None + + mock_response = MagicMock() + mock_response.json.return_value = [] + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + mock_client_class.return_value = mock_client + + tasks = load_dataset("test-org/test-dataset") + + mock_client.get.assert_called_once_with( + "https://api.hud.ai/evals/test-org/test-dataset", + headers={}, + params={"all": "true"}, + ) + + @patch("hud.datasets.loader.httpx.Client") + @patch("hud.datasets.loader.settings") + def test_load_dataset_http_error( + self, mock_settings: MagicMock, mock_client_class: MagicMock + ) -> None: + """load_dataset() raises ValueError on HTTP error.""" + import httpx + + mock_settings.hud_api_url = "https://api.hud.ai" + mock_settings.api_key = "test_key" + + mock_client = MagicMock() + mock_client.get.side_effect = httpx.HTTPError("Network error") + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + mock_client_class.return_value = mock_client + + with pytest.raises(ValueError, match="Failed to load dataset"): + load_dataset("test-org/test-dataset") + + @patch("hud.datasets.loader.httpx.Client") + @patch("hud.datasets.loader.settings") + def test_load_dataset_json_error( + self, mock_settings: MagicMock, mock_client_class: MagicMock + ) -> None: + """load_dataset() raises ValueError on JSON processing error.""" + mock_settings.hud_api_url = "https://api.hud.ai" + mock_settings.api_key = "test_key" + + mock_response = MagicMock() + mock_response.json.side_effect = Exception("Invalid JSON") + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response 
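+        # load_dataset() is expected to use httpx.Client as a context manager,
+        # so the mock also stubs the __enter__/__exit__ protocol around the GET.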
+ mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + mock_client_class.return_value = mock_client + + with pytest.raises(ValueError, match="Error processing dataset"): + load_dataset("test-org/test-dataset") + + @patch("hud.datasets.loader.httpx.Client") + @patch("hud.datasets.loader.settings") + def test_load_dataset_empty( + self, mock_settings: MagicMock, mock_client_class: MagicMock + ) -> None: + """load_dataset() handles empty dataset.""" + mock_settings.hud_api_url = "https://api.hud.ai" + mock_settings.api_key = "test_key" + + mock_response = MagicMock() + mock_response.json.return_value = [] + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + mock_client_class.return_value = mock_client + + tasks = load_dataset("test-org/test-dataset") + + assert len(tasks) == 0 + + @patch("hud.datasets.loader.httpx.Client") + @patch("hud.datasets.loader.settings") + def test_load_dataset_missing_fields( + self, mock_settings: MagicMock, mock_client_class: MagicMock + ) -> None: + """load_dataset() handles tasks with missing optional fields.""" + mock_settings.hud_api_url = "https://api.hud.ai" + mock_settings.api_key = "test_key" + + mock_response = MagicMock() + mock_response.json.return_value = [ + {"scenario": "test"}, # Missing env and args + ] + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + mock_client.get.return_value = mock_response + mock_client.__enter__.return_value = mock_client + mock_client.__exit__.return_value = None + mock_client_class.return_value = mock_client + + tasks = load_dataset("test-org/test-dataset") + + assert len(tasks) == 1 + assert tasks[0].scenario == "test" + assert tasks[0].env is None + assert tasks[0].args == {} + diff --git a/hud/datasets/tests/test_utils.py b/hud/datasets/tests/test_utils.py index cda201ac..79a69544 100644 --- a/hud/datasets/tests/test_utils.py +++ b/hud/datasets/tests/test_utils.py @@ -16,7 +16,7 @@ display_results, submit_rollouts, ) -from hud.types import AgentType, Task, Trace +from hud.types import AgentType, LegacyTask, Trace class TestSingleTaskRequest: @@ -161,8 +161,8 @@ class TestCalculateGroupStats: def test_basic_stats(self): """Test basic group statistics calculation.""" tasks = [ - Task(prompt="Task 1", mcp_config={}), - Task(prompt="Task 2", mcp_config={}), + LegacyTask(prompt="Task 1", mcp_config={}), + LegacyTask(prompt="Task 2", mcp_config={}), ] traces: list[Trace | None] = [ Trace(reward=0.8, done=True), @@ -180,7 +180,7 @@ def test_basic_stats(self): def test_all_none_traces(self): """Test when all traces are None.""" - tasks = [Task(prompt="Task 1", mcp_config={})] + tasks = [LegacyTask(prompt="Task 1", mcp_config={})] traces: list[Trace | None] = [None, None] group_ids = {0: "group-0"} @@ -192,7 +192,7 @@ def test_all_none_traces(self): def test_mixed_success_failure(self): """Test with mixed success and failure traces.""" - tasks = [Task(prompt="Task 1", mcp_config={})] + tasks = [LegacyTask(prompt="Task 1", mcp_config={})] traces: list[Trace | None] = [ Trace(reward=1.0, done=True), Trace(reward=0.0, done=True, isError=True), @@ -211,8 +211,8 @@ class TestDisplayResults: def test_display_with_traces(self): """Test displaying single-run trace results.""" tasks = [ - Task(id="t1", prompt="Test task 1", mcp_config={}), - Task(id="t2", prompt="Test task 2", mcp_config={}), + 
LegacyTask(id="t1", prompt="Test task 1", mcp_config={}), + LegacyTask(id="t2", prompt="Test task 2", mcp_config={}), ] results = [ Trace(reward=0.9, done=True), @@ -225,7 +225,7 @@ def test_display_with_traces(self): def test_display_with_group_stats(self): """Test displaying group statistics.""" tasks = [ - Task(id="t1", prompt="Test task 1", mcp_config={}), + LegacyTask(id="t1", prompt="Test task 1", mcp_config={}), ] results = [ { @@ -246,7 +246,7 @@ def test_display_with_group_stats(self): def test_display_empty_results(self): """Test displaying when no valid results.""" - tasks = [Task(prompt="Test", mcp_config={})] + tasks = [LegacyTask(prompt="Test", mcp_config={})] results: list[Trace | None] = [None] # Should not raise @@ -259,7 +259,7 @@ class TestSubmitRollouts: @pytest.mark.asyncio async def test_submit_single_task(self): """Test submitting a single task.""" - tasks = [Task(id="task-1", prompt="Test prompt", mcp_config={})] + tasks = [LegacyTask(id="task-1", prompt="Test prompt", mcp_config={})] with patch("hud.datasets.utils.httpx.AsyncClient") as mock_client_cls: mock_response = MagicMock() @@ -288,7 +288,7 @@ async def test_submit_single_task(self): @pytest.mark.asyncio async def test_submit_with_group_size(self): """Test submitting with group_size > 1 creates multiple requests per task.""" - tasks = [Task(id="task-1", prompt="Test prompt", mcp_config={})] + tasks = [LegacyTask(id="task-1", prompt="Test prompt", mcp_config={})] with patch("hud.datasets.utils.httpx.AsyncClient") as mock_client_cls: mock_response = MagicMock() diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py index 41d761b1..84e4a604 100644 --- a/hud/datasets/utils.py +++ b/hud/datasets/utils.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator from hud.settings import settings -from hud.types import AgentType, Task, Trace +from hud.types import AgentType, LegacyTask, Trace from hud.utils.hud_console import HUDConsole logger = logging.getLogger(__name__) @@ -21,7 +21,7 @@ class SingleTaskRequest(BaseModel): """Request to run a single task remotely - mirrors run_single_task() args.""" task: dict[str, Any] = Field( - description="Task definition compatible with hud.types.Task.", + description="Task definition compatible with hud.types.LegacyTask.", ) agent_type: AgentType = Field(description="Agent type to execute the task.") agent_params: dict[str, Any] = Field( @@ -43,7 +43,7 @@ class SingleTaskRequest(BaseModel): @model_validator(mode="after") def _validate_task(self) -> SingleTaskRequest: try: - Task(**self.task) + LegacyTask(**self.task) except Exception as exc: raise ValueError(f"Invalid task payload: {exc}") from exc return self @@ -67,7 +67,7 @@ class BatchRequest(BaseModel): async def submit_rollouts( - tasks: list[Task], + tasks: list[LegacyTask], job_id: str, agent_type: AgentType, agent_params: dict[str, Any] | None = None, @@ -259,7 +259,7 @@ async def cancel_all_jobs() -> dict[str, Any]: def calculate_group_stats( - tasks: list[Task], + tasks: list[LegacyTask], traces: list[Trace | None], group_size: int, group_ids: dict[int, str], @@ -328,15 +328,15 @@ def calculate_group_stats( def display_results( results: list[Any], *, - tasks: list[Task], + tasks: list[Any], elapsed: float | None = None, show_details: bool = True, ) -> None: """Display evaluation results in a formatted table. 
Args: - results: List of Trace objects or grouped statistics dicts - tasks: List of Task objects corresponding to results + results: List of EvalContext objects or grouped statistics dicts + tasks: List of Task or LegacyTask objects corresponding to results elapsed: Optional elapsed time in seconds show_details: Whether to show per-task details table """ @@ -380,8 +380,12 @@ def display_results( for i, (stat, task) in enumerate(zip(results, tasks, strict=False)): task_id = (task.id or "")[:20] - prompt = (task.prompt or "")[:40] - if len(task.prompt or "") > 40: + # Handle both v4 (prompt attr) and v5 (prompt in args) tasks + raw_prompt = getattr(task, "prompt", None) or ( + task.args.get("prompt") if hasattr(task, "args") else None + ) or task.scenario or "" + prompt = raw_prompt[:40] + if len(raw_prompt) > 40: prompt += "..." table.add_row( str(i + 1), @@ -428,8 +432,12 @@ def display_results( for i, r in enumerate(results): task = tasks[i] task_id = (task.id or "")[:20] - prompt = (task.prompt or "")[:40] - if len(task.prompt or "") > 40: + # Handle both v4 (prompt attr) and v5 (prompt in args) tasks + raw_prompt = getattr(task, "prompt", None) or ( + task.args.get("prompt") if hasattr(task, "args") else None + ) or getattr(task, "scenario", None) or "" + prompt = raw_prompt[:40] + if len(raw_prompt) > 40: prompt += "..." if r is None: diff --git a/hud/environment/__init__.py b/hud/environment/__init__.py index 1746606d..9aad37a0 100644 --- a/hud/environment/__init__.py +++ b/hud/environment/__init__.py @@ -28,8 +28,8 @@ from hud.environment.environment import Environment from hud.environment.mock import MockMixin, generate_mock_value from hud.environment.router import ConflictResolution, ToolRouter -from hud.environment.scripts import ScriptMixin -from hud.environment.types import EnvConfig, HubConfig +from hud.environment.scenarios import ScenarioMixin +from hud.environment.types import EnvConfig from hud.environment.utils import ToolFormat, format_result, parse_tool_call, parse_tool_calls __all__ = [ @@ -39,9 +39,8 @@ "Connector", "EnvConfig", "Environment", - "HubConfig", "MockMixin", - "ScriptMixin", + "ScenarioMixin", "ToolFormat", "ToolRouter", "format_result", diff --git a/hud/environment/connectors/__init__.py b/hud/environment/connectors/__init__.py index 7b8919ac..e88778e1 100644 --- a/hud/environment/connectors/__init__.py +++ b/hud/environment/connectors/__init__.py @@ -3,7 +3,6 @@ from hud.environment.connectors.local import LocalConnectorMixin from hud.environment.connectors.openai import OpenAIConnectorMixin from hud.environment.connectors.remote import RemoteConnectorMixin -from hud.environment.connectors.task import TaskConnectorMixin __all__ = ["ConnectorsMixin"] @@ -11,13 +10,12 @@ class ConnectorsMixin( RemoteConnectorMixin, LocalConnectorMixin, - TaskConnectorMixin, OpenAIConnectorMixin, ): """Combined connector mixin providing all connection methods. 
Remote connections: - connect_hub(slug) - HUD Hub environment (fetches mcp_config from API) + connect_hub(slug) - HUD Hub environment connect_url(url) - MCP server via URL connect_openapi(spec) - Mount OpenAPI spec as MCP server @@ -30,9 +28,6 @@ class ConnectorsMixin( connect_mcp(config) - Single mcp_config server (auto-detects local/remote) connect_mcp_config(mcp_config) - Multiple mcp_config servers - Task: - connect_task(slug) - Load task from platform by slug - Framework imports: connect_function_tools(tools) - Import OpenAI Agents SDK FunctionTools """ diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py index d9179786..866b13a5 100644 --- a/hud/environment/connectors/remote.py +++ b/hud/environment/connectors/remote.py @@ -12,8 +12,6 @@ from fastmcp.tools.tool import Tool - from hud.environment.types import HubConfig - __all__ = ["RemoteConnectorMixin"] logger = logging.getLogger(__name__) @@ -25,9 +23,6 @@ class RemoteConnectorMixin(MCPConfigConnectorMixin): Note: include_router() is inherited from MCPServer (via FastMCP). """ - # Store hub configs for trace serialization - _hub_configs: list[HubConfig] - def connect_hub( self, slug: str, @@ -40,55 +35,36 @@ def connect_hub( ) -> Any: """Connect to a HUD Hub environment. - Fetches mcp_config from api.hud.so immediately and creates connectors. + Creates an MCP connection to the HUD API with the hub slug in headers. Example: ```python env = Environment("my-env") - env.connect_hub("hud/browser") + env.connect_hub("browser") async with env: await env.call_tool("navigate", url="https://google.com") ``` """ - import httpx - - from hud.environment.types import HubConfig from hud.settings import settings - # Store hub config for trace serialization - hub_config = HubConfig( - slug=slug, - alias=alias, - prefix=prefix, - include=include, - exclude=exclude, - ) - - if not hasattr(self, "_hub_configs"): - self._hub_configs = [] - self._hub_configs.append(hub_config) + logger.info("Connecting to hub environment: %s", slug) - # Fetch mcp_config synchronously - logger.info("Loading hub environment: %s", slug) - - headers = {} - if settings.api_key: - headers["Authorization"] = f"Bearer {settings.api_key}" - - with httpx.Client() as client: - response = client.get( - f"{settings.hud_api_url}/environments/{slug}/mcp-config", - headers=headers, - ) - response.raise_for_status() - data = response.json() + # Create mcp_config with standard MCP URL and hub slug in headers + mcp_config = { + "hud": { + "url": settings.hud_mcp_url, + "headers": { + "Authorization": f"Bearer {settings.api_key}", + "Environment-Name": slug, + }, + } + } - mcp_config: dict[str, dict[str, Any]] = data.get("mcp_config", data) self.connect_mcp_config( mcp_config, prefix=prefix, include=include, exclude=exclude, transform=transform ) - logger.info("Hub connected: %s (%d servers)", slug, len(mcp_config)) + logger.info("Hub connected: %s", slug) return self def connect_url( diff --git a/hud/environment/connectors/task.py b/hud/environment/connectors/task.py deleted file mode 100644 index 1fe1033e..00000000 --- a/hud/environment/connectors/task.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Task connection connector.""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Any - -from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin - -if TYPE_CHECKING: - from hud.types import Task - -__all__ = ["TaskConnectorMixin"] - -logger = logging.getLogger(__name__) - - -class 
TaskConnectorMixin(MCPConfigConnectorMixin): - """Mixin providing connect_task() method. - - Inherits from MCPConfigConnectorMixin for connect_mcp_config(). - """ - - def setup_tool(self, call: Any, /, **kwargs: Any) -> Any: - raise NotImplementedError - - def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: - raise NotImplementedError - - def connect_task(self, slug: str) -> Any: - """Connect to a task from the HUD platform. - - Fetches the task from api.hud.so immediately and applies configuration - (mcp_config, setup_tool, evaluate_tool). - - Args: - slug: Task slug in format "evalset/task_name" or "evalset/task_name@version". - - Returns: - self for chaining. - - Example: - ```python - env = Environment("my-env").connect_task("my-org/browser-task") - - async with env: - # Task's mcp_config is connected - # Task's setup_tool runs automatically - result = await env.call_tool("navigate", url="...") - # Task's evaluate_tool runs on exit - ``` - """ - import httpx - - from hud.settings import settings - from hud.types import Task - - # Fetch task synchronously - logger.info("Loading task from platform: %s", slug) - - headers = {} - if settings.api_key: - headers["Authorization"] = f"Bearer {settings.api_key}" - - with httpx.Client() as client: - response = client.get( - f"{settings.hud_api_url}/tasks/{slug}", - headers=headers, - ) - response.raise_for_status() - data = response.json() - - task = Task(**data) - self._apply_task(task) - logger.info("Task loaded and applied: %s", slug) - return self - - def _apply_task(self, task: Task) -> None: - """Apply a Task definition to this environment. - - Sets up: - - Prompt from task.prompt - - MCP connections from task.mcp_config - - Setup tool calls from task.setup_tool - - Evaluate tool calls from task.evaluate_tool - """ - # Set prompt - if task.prompt: - self.prompt = task.prompt # type: ignore[attr-defined] - - # Connect MCP servers - if task.mcp_config: - self.connect_mcp_config(task.mcp_config) - - # Configure setup tool calls - if task.setup_tool: - setup_calls = task.setup_tool - if not isinstance(setup_calls, list): - setup_calls = [setup_calls] - for call in setup_calls: - self.setup_tool(call.name, **(call.arguments or {})) - - # Configure evaluate tool calls - if task.evaluate_tool: - eval_calls = task.evaluate_tool - if not isinstance(eval_calls, list): - eval_calls = [eval_calls] - for call in eval_calls: - self.evaluate_tool(call.name, **(call.arguments or {})) diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 2f5f2196..0a9d3867 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -13,7 +13,7 @@ from hud.environment.integrations import IntegrationsMixin from hud.environment.mock import MockMixin from hud.environment.router import ConflictResolution, ToolRouter -from hud.environment.scripts import ScriptMixin +from hud.environment.scenarios import ScenarioMixin from hud.server.server import MCPServer from hud.types import MCPToolResult @@ -21,7 +21,7 @@ import types from hud.environment.connection import Connector - from hud.eval.eval import Eval + from hud.eval.task import Task __all__ = ["Environment"] @@ -39,7 +39,7 @@ class Environment( ConnectorsMixin, IntegrationsMixin, MockMixin, - ScriptMixin, + ScenarioMixin, MCPServer, ): """Unified MCP environment that acts as both server and client. 
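For orientation, the rename above (ScriptMixin to ScenarioMixin) is the heart of the new flow: scenarios supply the setup and evaluate phases that Environment runs around an agent. The snippet below is a minimal sketch assembled from the docstring examples elsewhere in this patch (Environment, connect_hub, @env.scenario, and env(...) returning a Task); the main() wrapper, the prompt text, and the example URL are illustrative assumptions rather than part of the patch.

```python
from hud.environment import Environment

# Connect a hub-backed environment (per the connect_hub docstring).
env = Environment("my-env").connect_hub("browser")


@env.scenario()
async def checkout(user_id: str):
    # Setup phase: code before the first yield; the yielded string is the
    # prompt handed to the agent.
    yield f"Complete checkout for {user_id}"
    # Evaluate phase: code after the first yield; the second yield is the reward.
    yield 1.0


async def main() -> None:
    # env(...) builds a Task; entering it runs the scenario's setup phase.
    async with env("checkout", user_id="alice") as ctx:
        await ctx.call_tool("navigate", url="https://example.com")
    # On exit, the scenario's evaluate phase runs and determines the reward.
```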
@@ -57,7 +57,6 @@ class Environment( connect_url(url) - MCP server via URL connect_mcp(config) - Single mcp_config server connect_mcp_config(mcp_config) - Multiple mcp_config servers - connect_task(slug) - Load task from platform by slug connect_image(image) - Docker image via stdio connect_fastapi(app) - Mount FastAPI app as MCP server connect_openapi(spec) - Mount OpenAPI spec as MCP server @@ -136,7 +135,7 @@ def __init__( self._setup_calls: list[tuple[str, dict[str, Any]]] = [] self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] - # Default prompt - set by connect_task (EvalContext has per-run prompt) + # Default prompt (EvalContext has per-run prompt) self.prompt: str | None = None # Track which lifecycle tools we've warned about (only warn once per tool) @@ -145,8 +144,8 @@ def __init__( # Initialize mock state self._init_mock() - # Initialize script state - self._init_scripts() + # Initialize scenario state + self._init_scenarios() # ========================================================================= # Core Methods @@ -525,100 +524,47 @@ def local_connections(self) -> list[str]: """Names of local (non-parallelizable) connections.""" return [name for name, conn in self._connections.items() if conn.is_local] - def _get_env_config(self) -> dict[str, Any] | None: - """Get serializable environment configuration for trace storage. - - Returns EnvConfig-compatible dict with: - - name: Environment name - - hubs: List of hub configs (connect_hub calls) - - setup_tools: Tools to run after connection (MCPToolCall format) - - evaluate_tools: Tools to run before disconnection (MCPToolCall format) - """ - hub_configs = getattr(self, "_hub_configs", []) - - # Convert setup/evaluate calls to MCPToolCall format - setup_tools = [{"name": name, "arguments": args} for name, args in self._setup_calls] - evaluate_tools = [{"name": name, "arguments": args} for name, args in self._evaluate_calls] - - # Only return config if there's something to store - if not hub_configs and not setup_tools and not evaluate_tools: - return None - - return { - "name": self.name, - "hubs": [h.model_dump() for h in hub_configs], - "setup_tools": setup_tools, - "evaluate_tools": evaluate_tools, - } - - @property - def _all_hubs(self) -> bool: - """True if all tools came from connect_hub (fully reproducible). - - Returns False if there are: - - Local tools (@env.tool, connect_fastapi, connect_openapi, connect_server) - - Non-hub connections (connect_url, connect_mcp, connect_image, etc.) 
- """ - hub_configs = getattr(self, "_hub_configs", []) - - # Check for local tools (mounted servers, @env.tool) - # _tool_manager comes from MCPServer base class - local_tool_count = len(self._tool_manager._tools) if hasattr(self, "_tool_manager") else 0 - if local_tool_count > 0: - return False - - # No hubs and no connections = trivially all hubs (empty env) - if not hub_configs and not self._connections: - return True - - # Has connections but no hubs = not all hubs - if not hub_configs: - return False - - # Compare hub count to connection count - return len(hub_configs) >= len(self._connections) - def __repr__(self) -> str: return f"Environment({self.name!r}, connections={list(self._connections.keys())})" # ========================================================================= - # Eval Creation + # Task Creation # ========================================================================= def __call__( self, - script: str | None = None, + scenario: str | None = None, *, _trace: bool = True, _quiet: bool = False, **args: Any, - ) -> Eval: - """Create an Eval from this environment. + ) -> Task: + """Create a Task from this environment. - Returns an Eval that can be entered as a context manager or passed + Returns a Task that can be entered as a context manager or passed to hud.eval() for orchestration. Args: - script: Optional script name to run (from @env.script) + scenario: Optional scenario name to run (from @env.scenario) _trace: Whether to send trace data to backend (default True) _quiet: Whether to suppress printing links (default False) - **args: Arguments for the script + **args: Arguments for the scenario Returns: - Eval: A runnable evaluation unit + Task: A runnable evaluation unit Example: ```python env = Environment("my-env").connect_hub("browser") - @env.script() + @env.scenario() async def checkout(user_id: str): yield "Complete checkout" yield 1.0 - # Simple use - Eval is context manager + # Simple use - Task is context manager async with env("checkout", user_id="alice") as ctx: await agent.run(ctx.prompt) @@ -627,59 +573,18 @@ async def checkout(user_id: str): await ctx.call_tool("navigate", url="...") # Orchestrated via hud.eval - evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] - async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: + tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: ... ``` """ - from hud.eval.eval import Eval + from hud.eval.task import Task - return Eval( - env=self, # Pass live environment for local tools/scripts - script=script, + return Task( + env=self, # Pass live environment for local tools/scenarios + scenario=scenario, args=args, _trace=_trace, _quiet=_quiet, ) - @classmethod - def from_config(cls, config: dict[str, Any] | None) -> Environment: - """Create an Environment from a configuration dict. 
- - Args: - config: EnvConfig-compatible dict with: - - name: Environment name - - hubs: List of hub configs (HubConfig dicts) - - setup_tools: Tools to run after connection - - evaluate_tools: Tools to run before disconnection - - Returns: - Environment: Configured environment instance - """ - if config is None: - return cls("eval") - - env = cls(name=config.get("name", "eval")) - - # Connect hubs - for hub in config.get("hubs", []): - if isinstance(hub, dict): - env.connect_hub( - hub.get("slug", ""), - alias=hub.get("alias"), - prefix=hub.get("prefix"), - include=hub.get("include"), - exclude=hub.get("exclude"), - ) - - # Add setup tools - for tool in config.get("setup_tools", []): - if isinstance(tool, dict): - env.setup_tool(tool.get("name", ""), **(tool.get("arguments") or {})) - - # Add evaluate tools - for tool in config.get("evaluate_tools", []): - if isinstance(tool, dict): - env.evaluate_tool(tool.get("name", ""), **(tool.get("arguments") or {})) - - return env diff --git a/hud/environment/scripts.py b/hud/environment/scenarios.py similarity index 60% rename from hud/environment/scripts.py rename to hud/environment/scenarios.py index 63fd8703..5efc6499 100644 --- a/hud/environment/scripts.py +++ b/hud/environment/scenarios.py @@ -1,4 +1,4 @@ -"""Script decorator for Environment - defines setup/evaluate phases.""" +"""Scenario decorator for Environment - defines setup/evaluate phases.""" from __future__ import annotations @@ -15,29 +15,29 @@ from fastmcp.resources import ResourceManager from fastmcp.tools import ToolManager -__all__ = ["ScriptMixin"] +__all__ = ["ScenarioMixin"] logger = logging.getLogger(__name__) -class ScriptMixin: - """Mixin providing @env.script decorator for setup/evaluate phases. +class ScenarioMixin: + """Mixin providing @env.scenario decorator for setup/evaluate phases. - Scripts are async generators that yield twice: + Scenarios are async generators that yield twice: - First yield: prompt string (setup phase) - Second yield: reward float (evaluate phase) - The script can receive the agent's answer via yield: + The scenario can receive the agent's answer via yield: answer = yield "Do the task" yield 1.0 if "success" in answer else 0.0 The answer is passed via the hud_submit tool or ctx.submit(). The decorator registers both an MCP prompt and resource with the same - identifier ({env_name}:{script_name}), linked by session state. + identifier ({env_name}:{scenario_name}), linked by session state. Example: - @env.script() + @env.scenario() async def search_cats(url: str): await env.call_tool("navigate", url=url) answer = yield "Find all cat images on the page" @@ -51,44 +51,44 @@ async def search_cats(url: str): _resource_manager: ResourceManager _tool_manager: ToolManager - # Script state - _scripts: dict[str, Callable[..., AsyncGenerator[Any, Any]]] - _script_sessions: dict[str, AsyncGenerator[Any, Any]] # session_id -> generator - _script_latest: dict[str, str] # script_name -> latest session_id - _script_answers: dict[str, str] # script_name -> submitted answer + # Scenario state + _scenarios: dict[str, Callable[..., AsyncGenerator[Any, Any]]] + _scenario_sessions: dict[str, AsyncGenerator[Any, Any]] # session_id -> generator + _scenario_latest: dict[str, str] # scenario_name -> latest session_id + _scenario_answers: dict[str, str] # scenario_name -> submitted answer - def _init_scripts(self) -> None: - """Initialize script state. 
Called from Environment.__init__.""" - self._scripts = {} - self._script_sessions = {} - self._script_latest = {} - self._script_answers = {} + def _init_scenarios(self) -> None: + """Initialize scenario state. Called from Environment.__init__.""" + self._scenarios = {} + self._scenario_sessions = {} + self._scenario_latest = {} + self._scenario_answers = {} # Register _hud_submit tool (underscore = hidden from agent) self._register_hud_submit_tool() - async def submit(self, script: str, answer: str) -> None: - """Submit the agent's answer for a script's evaluate phase. + async def submit(self, scenario: str, answer: str) -> None: + """Submit the agent's answer for a scenario's evaluate phase. This stores the answer locally and broadcasts to connected hubs that have the _hud_submit tool (auto-detected by Environment). Args: - script: Name of the script (without env prefix) + scenario: Name of the scenario (without env prefix) answer: The agent's answer/result to submit Example: - # Direct call with script name + # Direct call with scenario name await env.submit("checkout", "Order completed successfully") - # Or via EvalContext (knows its own script) + # Or via EvalContext (knows its own scenario) await ctx.submit("Order completed successfully") """ - # Store locally for our scripts - self._script_answers[script] = answer + # Store locally for our scenarios + self._scenario_answers[scenario] = answer logger.debug( - "Stored answer for script '%s': %s...", - script, + "Stored answer for scenario '%s': %s...", + scenario, answer[:50] if len(answer) > 50 else answer, ) @@ -96,7 +96,7 @@ async def submit(self, script: str, answer: str) -> None: # Environment._broadcast_tool auto-filters to connections with the tool await self._broadcast_tool( # type: ignore[attr-defined] "_hud_submit", - script=script, + scenario=scenario, answer=answer, ) @@ -107,70 +107,70 @@ def _register_hud_submit_tool(self) -> None: """ from fastmcp.tools import Tool - script_self = self + scenario_self = self - async def _hud_submit(script: str, answer: str) -> str: - """Submit the agent's answer for a script's evaluate phase. + async def _hud_submit(scenario: str, answer: str) -> str: + """Submit the agent's answer for a scenario's evaluate phase. Internal tool - called by Environment.submit() on connected hubs. Args: - script: Name of the script (without env prefix) + scenario: Name of the scenario (without env prefix) answer: The agent's answer/result to submit """ # Store locally (don't broadcast - we ARE the target) - script_self._script_answers[script] = answer + scenario_self._scenario_answers[scenario] = answer logger.debug( - "_hud_submit received answer for script '%s': %s...", - script, + "_hud_submit received answer for scenario '%s': %s...", + scenario, answer[:50] if len(answer) > 50 else answer, ) - return f"Answer submitted for script '{script}'" + return f"Answer submitted for scenario '{scenario}'" # Register the tool with underscore name tool = Tool.from_function(_hud_submit) self._tool_manager.add_tool(tool) logger.debug("Registered _hud_submit tool") - async def run_script_setup(self, script_name: str, args: dict[str, Any]) -> str | None: - """Run a script's setup phase and return the prompt. + async def run_scenario_setup(self, scenario_name: str, args: dict[str, Any]) -> str | None: + """Run a scenario's setup phase and return the prompt. - Handles both local scripts (registered via @env.script) and remote - scripts (via MCP prompt). 
+ Handles both local scenarios (registered via @env.scenario) and remote + scenarios (via MCP prompt). Args: - script_name: Name of the script to run - args: Arguments to pass to the script + scenario_name: Name of the scenario to run + args: Arguments to pass to the scenario Returns: - The prompt string from the script's setup phase, or None if failed + The prompt string from the scenario's setup phase, or None if failed """ - # Check if script is registered locally - if script_name in self._scripts: - # Local script - run setup via generator - script_fn = self._scripts[script_name] - gen = script_fn(**args) + # Check if scenario is registered locally + if scenario_name in self._scenarios: + # Local scenario - run setup via generator + scenario_fn = self._scenarios[scenario_name] + gen = scenario_fn(**args) # Run setup phase (code before first yield) prompt = await gen.__anext__() # Store generator for evaluate phase session_id = uuid.uuid4().hex[:8] - self._script_sessions[session_id] = gen - self._script_latest[script_name] = session_id + self._scenario_sessions[session_id] = gen + self._scenario_latest[scenario_name] = session_id logger.debug( - "Script %s setup complete, session=%s", - script_name, + "Scenario %s setup complete, session=%s", + scenario_name, session_id, ) return str(prompt) else: - # Remote script - call via MCP prompt - # Format: {env_name}:{script_name} (use source env name if available) + # Remote scenario - call via MCP prompt + # Format: {env_name}:{scenario_name} (use source env name if available) env_name = getattr(self, "_source_env_name", None) or self.name safe_env_name = env_name.replace("_", "-") - prompt_id = f"{safe_env_name}:{script_name}" + prompt_id = f"{safe_env_name}:{scenario_name}" try: result = await self.get_prompt(prompt_id, args) # type: ignore[attr-defined] if result.messages: @@ -181,35 +181,35 @@ async def run_script_setup(self, script_name: str, args: dict[str, Any]) -> str elif isinstance(content, str): return content except Exception as e: - logger.warning("Failed to get script prompt: %s", e) + logger.warning("Failed to get scenario prompt: %s", e) return None - async def run_script_evaluate(self, script_name: str) -> float | None: - """Run a script's evaluate phase and return the reward. + async def run_scenario_evaluate(self, scenario_name: str) -> float | None: + """Run a scenario's evaluate phase and return the reward. Uses the submitted answer (if any) via gen.asend(). - Handles both local and remote scripts. + Handles both local and remote scenarios. 
Args: - script_name: Name of the script to evaluate + scenario_name: Name of the scenario to evaluate Returns: - The reward from the script's evaluate phase, or None if failed + The reward from the scenario's evaluate phase, or None if failed """ - # Check if we have a stored generator (local script) - session_id = self._script_latest.get(script_name) + # Check if we have a stored generator (local scenario) + session_id = self._scenario_latest.get(scenario_name) if session_id: - gen = self._script_sessions.pop(session_id, None) + gen = self._scenario_sessions.pop(session_id, None) if gen: # Get submitted answer (if any) - answer = self._script_answers.pop(script_name, None) + answer = self._scenario_answers.pop(scenario_name, None) try: - # Use asend to pass the answer to the script + # Use asend to pass the answer to the scenario reward = await gen.asend(answer) logger.debug( - "Script %s evaluate complete, answer=%s, reward=%s", - script_name, + "Scenario %s evaluate complete, answer=%s, reward=%s", + scenario_name, answer[:50] if answer and len(answer) > 50 else answer, reward, ) @@ -219,13 +219,13 @@ async def run_script_evaluate(self, script_name: str) -> float | None: return 1.0 finally: # Clean up latest pointer - if self._script_latest.get(script_name) == session_id: - del self._script_latest[script_name] + if self._scenario_latest.get(scenario_name) == session_id: + del self._scenario_latest[scenario_name] - # Remote script - read via MCP resource (use source env name if available) + # Remote scenario - read via MCP resource (use source env name if available) env_name = getattr(self, "_source_env_name", None) or self.name safe_env_name = env_name.replace("_", "-") - resource_id = f"{safe_env_name}:{script_name}" + resource_id = f"{safe_env_name}:{scenario_name}" try: contents = await self.read_resource(resource_id) # type: ignore[attr-defined] if contents: @@ -235,10 +235,10 @@ async def run_script_evaluate(self, script_name: str) -> float | None: if "reward" in data: return float(data["reward"]) except Exception as e: - logger.warning("Failed to get script reward: %s", e) + logger.warning("Failed to get scenario reward: %s", e) return None - def script( + def scenario( self, name: str | None = None, description: str | None = None, @@ -246,19 +246,19 @@ def script( [Callable[..., AsyncGenerator[Any, None]]], Callable[..., AsyncGenerator[Any, None]], ]: - """Decorator to register a script with setup and evaluate phases. + """Decorator to register a scenario with setup and evaluate phases. - Creates both a prompt and resource with identifier script:{name}. - The script function should yield twice: + Creates both a prompt and resource with identifier scenario:{name}. 
+ The scenario function should yield twice: - First yield: the prompt string (returned from prompt) - Second yield: the reward float (returned from resource) Args: - name: Optional name for the script (defaults to function name) - description: Optional description of what the script does + name: Optional name for the scenario (defaults to function name) + description: Optional description of what the scenario does Example: - @env.script() + @env.scenario() async def search_cats(url: str): await env.call_tool("navigate", url=url) yield "Find cat images" @@ -274,11 +274,11 @@ async def search_cats(url: str): def decorator( fn: Callable[..., AsyncGenerator[Any, None]], ) -> Callable[..., AsyncGenerator[Any, None]]: - script_name = name or fn.__name__ + scenario_name = name or fn.__name__ # Sanitize env name for URI scheme (no underscores allowed) safe_env_name = self.name.replace("_", "-") - script_id = f"{safe_env_name}:{script_name}" - script_desc = description or fn.__doc__ or f"Script: {script_name}" + scenario_id = f"{safe_env_name}:{scenario_name}" + scenario_desc = description or fn.__doc__ or f"Scenario: {scenario_name}" # Capture source code for reproducibility try: @@ -287,7 +287,7 @@ def decorator( source_code = None # Store the generator function - self._scripts[script_name] = fn + self._scenarios[scenario_name] = fn # Get function signature for prompt arguments sig = inspect.signature(fn) @@ -298,25 +298,25 @@ def decorator( # Register PROMPT - runs setup, returns prompt messages # We need a reference to self and the outer variables - script_self = self - script_fn = fn - script_name_ref = script_name + scenario_self = self + scenario_fn = fn + scenario_name_ref = scenario_name async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: # Create generator instance - gen = script_fn(**handler_args) + gen = scenario_fn(**handler_args) # Run setup phase (code before first yield) prompt_text = await gen.__anext__() # Store generator with session ID session_id = uuid.uuid4().hex[:8] - script_self._script_sessions[session_id] = gen - script_self._script_latest[script_name_ref] = session_id + scenario_self._scenario_sessions[session_id] = gen + scenario_self._scenario_latest[scenario_name_ref] = session_id logger.debug( - "Script %s setup complete, session=%s, prompt=%s", - script_name_ref, + "Scenario %s setup complete, session=%s, prompt=%s", + scenario_name_ref, session_id, prompt_text[:50] if isinstance(prompt_text, str) else prompt_text, ) @@ -328,36 +328,36 @@ async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: from fastmcp.prompts.prompt import FunctionPrompt, PromptArgument # Build meta with source code - script_meta = {"code": source_code} if source_code else None + scenario_meta = {"code": source_code} if source_code else None prompt = FunctionPrompt( - name=script_id, - description=f"[Setup] {script_desc}", + name=scenario_id, + description=f"[Setup] {scenario_desc}", arguments=[ PromptArgument(name=arg["name"], required=arg["required"]) for arg in prompt_args ], fn=prompt_handler, - meta=script_meta, + meta=scenario_meta, ) self._prompt_manager.add_prompt(prompt) # Register RESOURCE - runs evaluate, returns reward async def resource_handler() -> str: - # Get latest session for this script - session_id = script_self._script_latest.get(script_name_ref) + # Get latest session for this scenario + session_id = scenario_self._scenario_latest.get(scenario_name_ref) if not session_id: raise ValueError( - f"No active session for script 
'{script_name_ref}'. " + f"No active session for scenario '{scenario_name_ref}'. " "Call the prompt first to run setup." ) - gen = script_self._script_sessions.pop(session_id, None) + gen = scenario_self._scenario_sessions.pop(session_id, None) if gen is None: raise ValueError(f"Session '{session_id}' not found or already evaluated.") # Get submitted answer (if any) - answer = script_self._script_answers.pop(script_name_ref, None) + answer = scenario_self._scenario_answers.pop(scenario_name_ref, None) # Run evaluate phase (code after first yield) # Use asend to pass the answer (or None if not submitted) @@ -368,36 +368,36 @@ async def resource_handler() -> str: reward = 1.0 logger.debug( - "Script %s evaluate complete, session=%s, answer=%s, reward=%s", - script_name_ref, + "Scenario %s evaluate complete, session=%s, answer=%s, reward=%s", + scenario_name_ref, session_id, answer[:50] if answer and len(answer) > 50 else answer, reward, ) # Clean up latest pointer if it matches - if script_self._script_latest.get(script_name_ref) == session_id: - del script_self._script_latest[script_name_ref] + if scenario_self._scenario_latest.get(scenario_name_ref) == session_id: + del scenario_self._scenario_latest[scenario_name_ref] return json.dumps({"reward": float(reward)}) - # Register as resource with same script: URI + # Register as resource with same scenario: URI from fastmcp.resources.resource import FunctionResource resource = FunctionResource.from_function( fn=resource_handler, - uri=script_id, - name=script_name, - description=f"[Evaluate] {script_desc}", + uri=scenario_id, + name=scenario_name, + description=f"[Evaluate] {scenario_desc}", mime_type="application/json", - meta=script_meta, + meta=scenario_meta, ) self._resource_manager.add_resource(resource) logger.debug( - "Registered script '%s' as prompt and resource: %s", - script_name, - script_id, + "Registered scenario '%s' as prompt and resource: %s", + scenario_name, + scenario_id, ) return fn diff --git a/hud/environment/tests/test_connectors.py b/hud/environment/tests/test_connectors.py index 03c13796..f6047a23 100644 --- a/hud/environment/tests/test_connectors.py +++ b/hud/environment/tests/test_connectors.py @@ -217,52 +217,3 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: assert "browser" in env._connections -class TestTaskConnectorMixin: - """Tests for TaskConnectorMixin.""" - - @patch("httpx.Client") - def test_connect_task_fetches_and_applies_config(self, mock_httpx_cls: MagicMock) -> None: - """connect_task fetches task and applies mcp_config.""" - from hud.environment.connectors.task import TaskConnectorMixin - - class TestEnv(TaskConnectorMixin): - def __init__(self) -> None: - self._connections: dict[str, Connector] = {} - self._setup_calls: list[tuple[str, dict[str, Any]]] = [] - self._evaluate_calls: list[tuple[str, dict[str, Any]]] = [] - - def setup_tool(self, call: Any, /, **kwargs: Any) -> Any: - self._setup_calls.append((call, kwargs)) - return self - - def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: - self._evaluate_calls.append((call, kwargs)) - return self - - # Mock httpx response with task data - mock_response = MagicMock() - mock_response.json.return_value = { - "id": "task-123", - "prompt": "Test task prompt", - "mcp_config": { - "browser": {"url": "https://mcp.hud.ai/browser"}, - }, - "setup_tool": None, - "evaluate_tool": None, - } - mock_response.raise_for_status = MagicMock() - - mock_client = MagicMock() - mock_client.get.return_value = mock_response - 
mock_client.__enter__ = MagicMock(return_value=mock_client) - mock_client.__exit__ = MagicMock(return_value=None) - mock_httpx_cls.return_value = mock_client - - env = TestEnv() - with patch("hud.settings.settings") as mock_settings: - mock_settings.hud_api_url = "https://api.hud.so" - mock_settings.api_key = "test-key" - - env.connect_task("my-org/my-task") - - assert "browser" in env._connections diff --git a/hud/environment/tests/test_environment.py b/hud/environment/tests/test_environment.py index 1f75ab33..39f85d9e 100644 --- a/hud/environment/tests/test_environment.py +++ b/hud/environment/tests/test_environment.py @@ -25,31 +25,6 @@ def test_prompt_can_be_set(self) -> None: env.prompt = "Navigate to google.com" assert env.prompt == "Navigate to google.com" - def test_prompt_set_from_task(self) -> None: - """connect_task sets prompt from task.prompt.""" - from hud.environment.connection import Connector # noqa: TC001 - from hud.environment.connectors.task import TaskConnectorMixin - from hud.types import Task - - class TestEnv(TaskConnectorMixin): - def __init__(self) -> None: - self._connections: dict[str, Connector] = {} - self.prompt: str | None = None - - def setup_tool(self, call: Any, /, **kwargs: Any) -> Any: - return self - - def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Any: - return self - - def connect_mcp_config(self, config: dict) -> Any: - return self - - env = TestEnv() - task = Task(prompt="Test prompt", mcp_config={}) - env._apply_task(task) - - assert env.prompt == "Test prompt" class TestEnvironmentContextManager: diff --git a/hud/environment/tests/test_scripts.py b/hud/environment/tests/test_scenarios.py similarity index 64% rename from hud/environment/tests/test_scripts.py rename to hud/environment/tests/test_scenarios.py index e07481f0..875ac846 100644 --- a/hud/environment/tests/test_scripts.py +++ b/hud/environment/tests/test_scenarios.py @@ -1,4 +1,4 @@ -"""Tests for Environment script decorator.""" +"""Tests for Environment scenario decorator.""" from __future__ import annotations @@ -7,26 +7,26 @@ from hud.environment import Environment -class TestScriptDecorator: - """Tests for @env.script decorator.""" +class TestScenarioDecorator: + """Tests for @env.scenario decorator.""" - def test_script_registers_function(self) -> None: - """@env.script registers the function.""" + def test_scenario_registers_function(self) -> None: + """@env.scenario registers the function.""" env = Environment("test-env") - @env.script("greet") - async def greet_script(name: str): + @env.scenario("greet") + async def greet_scenario(name: str): yield f"Hello, {name}!" yield 1.0 - assert "greet" in env._scripts + assert "greet" in env._scenarios - def test_script_creates_mcp_prompt(self) -> None: - """@env.script creates an MCP prompt.""" + def test_scenario_creates_mcp_prompt(self) -> None: + """@env.scenario creates an MCP prompt.""" env = Environment("test-env") - @env.script("greet", description="Greeting script") - async def greet_script(name: str): + @env.scenario("greet", description="Greeting scenario") + async def greet_scenario(name: str): yield f"Hello, {name}!" 
yield 1.0 @@ -34,12 +34,12 @@ async def greet_script(name: str): prompt_names = list(env._prompt_manager._prompts.keys()) assert "test-env:greet" in prompt_names - def test_script_creates_mcp_resource(self) -> None: - """@env.script creates an MCP resource.""" + def test_scenario_creates_mcp_resource(self) -> None: + """@env.scenario creates an MCP resource.""" env = Environment("test-env") - @env.script("greet") - async def greet_script(name: str): + @env.scenario("greet") + async def greet_scenario(name: str): yield f"Hello, {name}!" yield 1.0 @@ -47,12 +47,12 @@ async def greet_script(name: str): resource_uris = list(env._resource_manager._resources.keys()) assert "test-env:greet" in resource_uris - def test_script_extracts_arguments(self) -> None: - """@env.script extracts function arguments for prompt.""" + def test_scenario_extracts_arguments(self) -> None: + """@env.scenario extracts function arguments for prompt.""" env = Environment("test-env") - @env.script("checkout") - async def checkout_script(user_id: str, amount: int = 100): + @env.scenario("checkout") + async def checkout_scenario(user_id: str, amount: int = 100): yield f"Checkout for {user_id}: ${amount}" yield 1.0 @@ -67,17 +67,17 @@ async def checkout_script(user_id: str, amount: int = 100): assert "amount" in arg_names -class TestScriptExecution: - """Tests for script execution flow.""" +class TestScenarioExecution: + """Tests for scenario execution flow.""" @pytest.mark.asyncio - async def test_script_setup_phase(self) -> None: - """Script setup phase yields prompt.""" + async def test_scenario_setup_phase(self) -> None: + """Scenario setup phase yields prompt.""" env = Environment("test-env") setup_ran = False - @env.script("test") - async def test_script(): + @env.scenario("test") + async def test_scenario(): nonlocal setup_ran setup_ran = True yield "Test prompt" @@ -96,12 +96,12 @@ async def test_script(): assert "Test prompt" in str(result[0].content) @pytest.mark.asyncio - async def test_script_stores_session(self) -> None: - """Script stores generator in session for evaluate phase.""" + async def test_scenario_stores_session(self) -> None: + """Scenario stores generator in session for evaluate phase.""" env = Environment("test-env") - @env.script("test") - async def test_script(): + @env.scenario("test") + async def test_scenario(): yield "Test prompt" yield 1.0 @@ -111,16 +111,16 @@ async def test_script(): await prompt.render({}) # Check session was stored - assert "test" in env._script_latest + assert "test" in env._scenario_latest @pytest.mark.asyncio - async def test_script_full_flow(self) -> None: - """Script runs setup and evaluate phases correctly.""" + async def test_scenario_full_flow(self) -> None: + """Scenario runs setup and evaluate phases correctly.""" env = Environment("test-env") phases = [] - @env.script("test") - async def test_script(): + @env.scenario("test") + async def test_scenario(): phases.append("setup") yield "Test prompt" phases.append("evaluate") @@ -140,17 +140,17 @@ async def test_script(): assert "evaluate" in phases -class TestScriptWithArgs: - """Tests for scripts with arguments.""" +class TestScenarioWithArgs: + """Tests for scenarios with arguments.""" @pytest.mark.asyncio - async def test_script_receives_args(self) -> None: - """Script receives arguments from prompt call.""" + async def test_scenario_receives_args(self) -> None: + """Scenario receives arguments from prompt call.""" env = Environment("test-env") received_args = {} - @env.script("checkout") - async def 
checkout_script(user_id: str, amount: int = 100): + @env.scenario("checkout") + async def checkout_scenario(user_id: str, amount: int = 100): received_args["user_id"] = user_id received_args["amount"] = amount yield f"Checkout {user_id}: ${amount}" @@ -165,16 +165,16 @@ async def checkout_script(user_id: str, amount: int = 100): assert received_args["amount"] == 50 -class TestScriptSubmit: - """Tests for script submit and answer flow.""" +class TestScenarioSubmit: + """Tests for scenario submit and answer flow.""" @pytest.mark.asyncio async def test_submit_stores_answer(self) -> None: - """submit() stores answer for script.""" + """submit() stores answer for scenario.""" env = Environment("test-env") - @env.script("test") - async def test_script(): + @env.scenario("test") + async def test_scenario(): yield "What is 2+2?" yield 1.0 @@ -186,16 +186,16 @@ async def test_script(): # Submit answer await env.submit("test", "4") - assert env._script_answers.get("test") == "4" + assert env._scenario_answers.get("test") == "4" @pytest.mark.asyncio - async def test_script_receives_answer(self) -> None: - """Script receives submitted answer via yield.""" + async def test_scenario_receives_answer(self) -> None: + """Scenario receives submitted answer via yield.""" env = Environment("test-env") received_answer = None - @env.script("qa") - async def qa_script(): + @env.scenario("qa") + async def qa_scenario(): nonlocal received_answer answer = yield "What is 2+2?" received_answer = answer @@ -207,7 +207,7 @@ async def qa_script(): await prompt.render({}) # Submit answer - env._script_answers["qa"] = "4" + env._scenario_answers["qa"] = "4" # Run evaluate resource = env._resource_manager._resources.get("test-env:qa") @@ -217,12 +217,12 @@ async def qa_script(): assert received_answer == "4" @pytest.mark.asyncio - async def test_script_evaluates_answer(self) -> None: - """Script evaluates answer and returns reward.""" + async def test_scenario_evaluates_answer(self) -> None: + """Scenario evaluates answer and returns reward.""" env = Environment("test-env") - @env.script("grading") - async def grading_script(): + @env.scenario("grading") + async def grading_scenario(): answer = yield "What is the capital of France?" 
yield 1.0 if "paris" in answer.lower() else 0.0 @@ -232,7 +232,7 @@ async def grading_script(): await prompt.render({}) # Submit correct answer - env._script_answers["grading"] = "Paris" + env._scenario_answers["grading"] = "Paris" # Run evaluate resource = env._resource_manager._resources.get("test-env:grading") @@ -245,15 +245,15 @@ async def grading_script(): assert data["reward"] == 1.0 -class TestScriptMeta: - """Tests for script _meta containing code.""" +class TestScenarioMeta: + """Tests for scenario _meta containing code.""" - def test_script_captures_source_code(self) -> None: - """@env.script captures function source in meta.""" + def test_scenario_captures_source_code(self) -> None: + """@env.scenario captures function source in meta.""" env = Environment("test-env") - @env.script("example") - async def example_script(x: int): + @env.scenario("example") + async def example_scenario(x: int): yield f"Process {x}" yield 1.0 @@ -261,15 +261,15 @@ async def example_script(x: int): assert prompt is not None assert prompt.meta is not None assert "code" in prompt.meta - assert "async def example_script" in prompt.meta["code"] + assert "async def example_scenario" in prompt.meta["code"] assert "yield" in prompt.meta["code"] - def test_script_meta_on_resource(self) -> None: + def test_scenario_meta_on_resource(self) -> None: """Resource also has source code in meta.""" env = Environment("test-env") - @env.script("example") - async def example_script(): + @env.scenario("example") + async def example_scenario(): yield "Test" yield 1.0 @@ -277,4 +277,4 @@ async def example_script(): assert resource is not None assert resource.meta is not None assert "code" in resource.meta - assert "async def example_script" in resource.meta["code"] + assert "async def example_scenario" in resource.meta["code"] diff --git a/hud/environment/types.py b/hud/environment/types.py index 8e8fcd97..dfa76abd 100644 --- a/hud/environment/types.py +++ b/hud/environment/types.py @@ -2,27 +2,22 @@ from __future__ import annotations -from pydantic import BaseModel +from pydantic import BaseModel, Field -from hud.types import MCPToolCall - -__all__ = ["EnvConfig", "HubConfig"] - - -class HubConfig(BaseModel): - """Configuration for a single hub connection.""" - - slug: str - alias: str | None = None - prefix: str | None = None - include: list[str] | None = None - exclude: list[str] | None = None +__all__ = ["EnvConfig"] class EnvConfig(BaseModel): - """Environment configuration for trace reproducibility.""" - - name: str - hubs: list[HubConfig] = [] - setup_tools: list[MCPToolCall] = [] - evaluate_tools: list[MCPToolCall] = [] + """Environment configuration for Tasks. + + Specifies which hub to connect to and optional tool filtering. + + Attributes: + name: Hub name to connect via connect_hub() (e.g., "browser", "sheets") + include: Optional whitelist of tool names to include + exclude: Optional blacklist of tool names to exclude + """ + + name: str = Field(description="Hub name to connect to") + include: list[str] | None = Field(default=None, description="Whitelist of tool names") + exclude: list[str] | None = Field(default=None, description="Blacklist of tool names") diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py index 78cab2cd..93c5f699 100644 --- a/hud/eval/__init__.py +++ b/hud/eval/__init__.py @@ -1,12 +1,12 @@ """HUD Eval - Evaluation context and management. 
This module provides: -- Eval: A runnable evaluation unit (from env()) +- Task: A runnable evaluation unit (from env()) - EvalContext: Environment with evaluation tracking (trace_id, reward, etc.) - eval(): Standalone context manager for task-based evaluation Usage: - # Using env() to create Eval + # Using env() to create Task env = Environment("my-env").connect_hub("browser") async with env() as ctx: @@ -19,9 +19,9 @@ async with hud.eval("my-org/task:1") as ctx: await agent.run(ctx) - # Orchestrated with Eval objects - evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] - async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: + # Orchestrated with Task objects + tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: await agent.run(ctx.prompt) # Blank eval for manual reward @@ -36,8 +36,8 @@ # Auto-instrument httpx on import import hud.eval.instrument # noqa: F401 -# Eval is safe to import -from hud.eval.eval import Eval +# Task is safe to import +from hud.eval.task import Task # run_eval is safe to import (uses lazy imports internally) from hud.eval.manager import run_eval @@ -46,7 +46,7 @@ from hud.eval.context import EvalContext __all__ = [ - "Eval", + "Task", "EvalContext", "run_eval", ] diff --git a/hud/eval/context.py b/hud/eval/context.py index 68cd1df3..2d35199a 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Any, Self from hud.environment import Environment -from hud.environment.types import EnvConfig from hud.settings import settings from hud.shared import make_request from hud.telemetry.job import get_current_job @@ -24,7 +23,7 @@ if TYPE_CHECKING: from types import TracebackType - from hud.types import Task + from hud.types import LegacyTask from hud.eval.types import EvalExitPayload, EvalPayload, ParallelEvalComplete @@ -89,8 +88,6 @@ def __init__( index: int = 0, variants: dict[str, Any] | None = None, code_snippet: str | None = None, - env_config: dict[str, Any] | None = None, - task: Task | None = None, trace: bool = True, quiet: bool = False, **env_kwargs: Any, @@ -106,8 +103,6 @@ def __init__( index: Index in parallel execution variants: Variant assignment for A/B testing code_snippet: Code being evaluated (for reproducibility) - env_config: Environment configuration dict - task: Task definition (if loaded from slug) trace: Whether to send trace data to backend (default True) quiet: Whether to suppress printing links (default False) **env_kwargs: Additional kwargs passed to Environment.__init__ @@ -135,7 +130,7 @@ def __init__( self.variants: dict[str, Any] = variants or {} # User-settable (per-run values, override Environment defaults) - self.prompt: str | None = None # From script setup or task + self.prompt: str | None = None # From scenario setup or task self.reward: float | None = None self.answer: str | None = None # Agent's submitted answer @@ -145,16 +140,8 @@ def __init__( # Parallel results self.results: list[EvalContext] | None = None - # Code and config + # Code snippet for reproducibility self.code_snippet: str | None = code_snippet - self._eval_env_config: dict[str, Any] | None = env_config - - # Task definition (if loaded from slug) - self.task: Task | None = task - - # Apply task configuration - if task: - self._apply_task(task) # Private state for eval tracking self._eval_api_key = api_key @@ -164,34 +151,9 @@ def __init__( self._is_summary: bool = 
False # True for summary contexts (skip trace) self._suppress_link: bool = quiet # True to suppress printing eval link self._trace_enabled: bool = trace # Whether to send trace data to backend - self._script_name: str | None = None # Current script name (for submit) + self._scenario_name: str | None = None # Current scenario name (for submit) self._source_env_name: str | None = None # Source env name for remote lookups - def _apply_task(self, task: Task) -> None: - """Apply a Task definition to this environment.""" - # Set prompt - if task.prompt: - self.prompt = task.prompt - - # Connect MCP servers - if task.mcp_config: - self.connect_mcp_config(task.mcp_config) - - # Configure setup tool calls - if task.setup_tool: - setup_calls = task.setup_tool - if not isinstance(setup_calls, list): - setup_calls = [setup_calls] - for call in setup_calls: - self.setup_tool(call.name, **(call.arguments or {})) - - # Configure evaluate tool calls - if task.evaluate_tool: - eval_calls = task.evaluate_tool - if not isinstance(eval_calls, list): - eval_calls = [eval_calls] - for call in eval_calls: - self.evaluate_tool(call.name, **(call.arguments or {})) @classmethod def from_environment( @@ -206,7 +168,6 @@ def from_environment( index: int = 0, variants: dict[str, Any] | None = None, code_snippet: str | None = None, - env_config: dict[str, Any] | None = None, trace: bool = True, quiet: bool = False, ) -> EvalContext: @@ -225,7 +186,6 @@ def from_environment( index: Index in parallel execution variants: Variant assignment code_snippet: Code being evaluated - env_config: Environment configuration """ ctx = cls( name=name, @@ -236,7 +196,6 @@ def from_environment( index=index, variants=variants, code_snippet=code_snippet, - env_config=env_config, trace=trace, quiet=quiet, ) @@ -244,23 +203,22 @@ def from_environment( # Copy connections from parent - each connector is copied so parallel # execution gets fresh client instances ctx._connections = {name: connector.copy() for name, connector in env._connections.items()} - ctx._hub_configs = getattr(env, "_hub_configs", []).copy() ctx._setup_calls = env._setup_calls.copy() ctx._evaluate_calls = env._evaluate_calls.copy() - # Copy scripts (definitions) by reference - they don't change - ctx._scripts = getattr(env, "_scripts", {}) + # Copy scenarios (definitions) by reference - they don't change + ctx._scenarios = getattr(env, "_scenarios", {}) # Create fresh session state for this eval (parallel evals each need their own) - ctx._script_sessions = {} - ctx._script_latest = {} - ctx._script_answers = {} + ctx._scenario_sessions = {} + ctx._scenario_latest = {} + ctx._scenario_answers = {} - # Store source env name for remote script lookups + # Store source env name for remote scenario lookups ctx._source_env_name = env.name # Copy managers by reference (they hold local tools, prompts, resources) # This allows ctx.call_tool(), ctx.get_prompt(), ctx.read_resource() to work - # for locally defined tools/scripts + # for locally defined tools/scenarios ctx._tool_manager = env._tool_manager ctx._prompt_manager = env._prompt_manager ctx._resource_manager = env._resource_manager @@ -271,64 +229,6 @@ def from_environment( return ctx - @classmethod - def from_task( - cls, - task: Task, - name: str | None = None, - *, - trace_id: str | None = None, - api_key: str | None = None, - job_id: str | None = None, - group_id: str | None = None, - index: int = 0, - variants: dict[str, Any] | None = None, - code_snippet: str | None = None, - trace: bool = True, - quiet: bool = False, - 
) -> EvalContext: - """Create an EvalContext from a Task definition. - - .. deprecated:: 0.5.0 - Use Eval objects from env() instead of Task objects. - - Args: - task: Task definition - name: Evaluation name (defaults to task.id or "eval") - trace_id: Unique trace ID - api_key: API key for backend calls - job_id: Job ID to link to - group_id: Group ID for parallel evaluations - index: Index in parallel execution - variants: Variant assignment - code_snippet: Code being evaluated - trace: Whether to send trace data to backend - quiet: Whether to suppress printing links - """ - import warnings - - warnings.warn( - "EvalContext.from_task() is deprecated. Use Eval objects from env() instead.", - DeprecationWarning, - stacklevel=2, - ) - - eval_name = name or task.id or "eval" - - return cls( - name=eval_name, - trace_id=trace_id, - api_key=api_key, - job_id=job_id, - group_id=group_id, - index=index, - variants=variants, - code_snippet=code_snippet, - task=task, - trace=trace, - quiet=quiet, - ) - # ========================================================================= # Summary Context - Attribute Access Control # ========================================================================= @@ -407,15 +307,10 @@ def _get_eval_api_key(self) -> str | None: def _build_base_payload(self) -> EvalPayload: """Build the base payload for enter/exit.""" - env_config_model: EnvConfig | None = None - if self._eval_env_config: - env_config_model = EnvConfig(**self._eval_env_config) - return EvalPayload( job_name=self.eval_name, prompt=self.prompt, code_snippet=self.code_snippet, - env_config=env_config_model, job_id=self.job_id, group_id=self.group_id, variants=self.variants if self.variants else None, @@ -438,10 +333,10 @@ async def log(self, metrics: dict[str, Any]) -> None: logger.warning("Failed to log metrics: %s", e) async def submit(self, answer: str) -> None: - """Submit the agent's answer for script evaluation. + """Submit the agent's answer for scenario evaluation. - Delegates to Environment.submit() with the current script name. - The answer will be passed to the script's evaluate phase via + Delegates to Environment.submit() with the current scenario name. + The answer will be passed to the scenario's evaluate phase via `yield`, e.g.: `answer = yield "Do the task"` Args: @@ -451,17 +346,17 @@ async def submit(self, answer: str) -> None: async with env("checkout", product="laptop") as ctx: response = await agent.run(ctx.prompt) await ctx.submit(response) - # On exit, script's evaluate phase receives the answer + # On exit, scenario's evaluate phase receives the answer """ - if not self._script_name: - logger.warning("submit() called but no script is running") + if not self._scenario_name: + logger.warning("submit() called but no scenario is running") return # Store answer on context for display self.answer = answer # Delegate to Environment.submit() which handles storage + broadcast - await super().submit(self._script_name, answer) + await super().submit(self._scenario_name, answer) async def _eval_enter(self) -> None: """Notify backend that eval has started.""" diff --git a/hud/eval/eval.py b/hud/eval/eval.py deleted file mode 100644 index 4eb681f9..00000000 --- a/hud/eval/eval.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Eval - A runnable evaluation unit (data class). - -An Eval holds the configuration needed to run an evaluation: -- Environment configuration (how to create/connect) -- Optional script name and args - -When entered as a context manager, it creates an EvalContext. 
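Editor's note: `EvalContext.submit()` above delegates to `Environment.submit()` using the current scenario name. A hedged end-to-end sketch of that flow, following the docstring example; the hard-coded `response` stands in for an agent call, and a real run is assumed to have an API key configured since entering the context notifies the backend:

```python
import asyncio

from hud.environment import Environment

env = Environment("qa-env")


@env.scenario("capital")
async def capital_scenario():
    answer = yield "What is the capital of France?"
    yield 1.0 if answer and "paris" in answer.lower() else 0.0


async def main() -> None:
    async with env("capital") as ctx:
        # Stand-in for: response = await agent.run(ctx.prompt)
        response = "Paris"
        await ctx.submit(response)
        # On exit, the scenario's evaluate phase receives "Paris" via the
        # first yield and sets ctx.reward from the final yield.


asyncio.run(main())
```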
- -Usage: - env = Environment("my-env").connect_hub("browser") - - # Empty - just env - async with env() as ctx: - await ctx.call_tool("navigate", url="...") - - # With script - async with env("checkout", user_id="alice") as ctx: - await agent.run(ctx.prompt) - - # Orchestrated via hud.eval - evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] - async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: - ... -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from types import TracebackType - - from hud.eval.context import EvalContext - -__all__ = ["Eval", "build_eval_name"] - -logger = logging.getLogger(__name__) - - -def build_eval_name(script: str | None, args: dict[str, Any] | None) -> str: - """Build descriptive name: 'script with val1, val2, ...'""" - if not script: - return "eval" - if not args: - return script - - val_parts = [] - for v in list(args.values())[:3]: # Max 3 values - v_str = repr(v) if isinstance(v, str) else str(v) - if len(v_str) > 25: - v_str = v_str[:22] + "..." - val_parts.append(v_str) - - if val_parts: - return f"{script} with {', '.join(val_parts)}" - return script - - -@dataclass -class Eval: - """A runnable evaluation unit (data class). - - Holds the configuration to create an EvalContext: - - env: The environment (live instance or serialized config) - - script: Optional script name to run (from @env.script) - - args: Arguments for the script - - When entered as a context manager, creates an EvalContext. - - Attributes: - env: Environment instance (local) or EnvConfig dict (remote) or None (blank) - script: Script name to run (None for env-only) - args: Script arguments - """ - - # Core config - env can be live Environment or serialized config - env: Any = None # Environment | dict[str, Any] | None - script: str | None = None - args: dict[str, Any] = field(default_factory=dict) - - # EvalContext creation params (set by hud.eval for parallel execution) - trace_id: str | None = field(default=None, repr=False) - api_key: str | None = field(default=None, repr=False) - job_id: str | None = field(default=None, repr=False) - group_id: str | None = field(default=None, repr=False) - index: int = field(default=0, repr=False) - variants: dict[str, Any] = field(default_factory=dict, repr=False) - code_snippet: str | None = field(default=None, repr=False) - _suppress_link: bool = field(default=False, repr=False) - _trace: bool = field(default=True, repr=False) - _quiet: bool = field(default=False, repr=False) - - # Runtime state - _ctx: EvalContext | None = field(default=None, repr=False) - - # Backwards compat alias - @property - def env_config(self) -> dict[str, Any] | None: - """Get serializable env config (for backwards compat and backend).""" - from hud.environment import Environment - - if isinstance(self.env, Environment): - return self.env._get_env_config() - elif isinstance(self.env, dict): - return self.env - return None - - def copy(self) -> Eval: - """Create a copy of this Eval for parallel execution.""" - return Eval( - env=self.env, # Share reference - from_environment handles copying - script=self.script, - args=self.args.copy(), - trace_id=None, # Each copy gets unique trace_id - api_key=self.api_key, - job_id=self.job_id, - group_id=self.group_id, - index=self.index, - variants=self.variants.copy(), - code_snippet=self.code_snippet, - _suppress_link=self._suppress_link, - _trace=self._trace, - 
_quiet=self._quiet, - ) - - def to_eval_context(self) -> EvalContext: - """Convert this Eval to an EvalContext. - - Creates an EvalContext from the environment (live or from config). - Also handles deprecated Task objects stored in _task attribute. - """ - from hud.environment import Environment - from hud.eval.context import EvalContext - - # Check for deprecated Task (backwards compat) - task = getattr(self, "_task", None) - if task is not None: - import warnings - - warnings.warn( - "Task objects are deprecated. Use Eval from env() instead.", - DeprecationWarning, - stacklevel=3, - ) - ctx = EvalContext.from_task( - task=task, - api_key=self.api_key, - job_id=self.job_id, - group_id=self.group_id, - index=self.index, - variants=self.variants, - code_snippet=self.code_snippet, - trace=self._trace, - quiet=self._quiet, - ) - ctx._suppress_link = self._suppress_link - return ctx - - # Get or create environment - if isinstance(self.env, Environment): - # Local - use live environment (from_environment handles copying) - source_env = self.env - elif isinstance(self.env, dict): - # Remote/config - create fresh from config - source_env = Environment.from_config(self.env) - else: - # Blank - source_env = Environment("eval") - - eval_name = build_eval_name(self.script, self.args) - - # Create EvalContext from environment - ctx = EvalContext.from_environment( - env=source_env, - name=eval_name, - trace_id=self.trace_id, - api_key=self.api_key, - job_id=self.job_id, - group_id=self.group_id, - index=self.index, - variants=self.variants, - code_snippet=self.code_snippet, - env_config=self.env_config, - ) - ctx._suppress_link = self._suppress_link - ctx._trace_enabled = self._trace - - return ctx - - async def __aenter__(self) -> EvalContext: - """Enter eval context. - - Order of operations: - 1. Create EvalContext from environment config - 2. Connect environment (MCP servers, etc.) - 3. Run script setup (if script) → sets ctx.prompt - 4. Notify backend (with prompt now set) - 5. 
Print trace link - """ - self._ctx = self.to_eval_context() - await self._ctx.__aenter__() # Connect env, set trace headers - - # Run script setup (sets prompt) - if self.script: - await self._run_script_setup() - - # Notify backend with prompt included - await self._ctx._eval_enter() - self._ctx._print_eval_link() - - return self._ctx - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit eval context - run script evaluate and exit EvalContext.""" - if self._ctx is None: - return - - # If we have a script and no error, run its evaluate phase - if self.script and exc_type is None: - await self._run_script_evaluate() - - # Exit the EvalContext - await self._ctx.__aexit__(exc_type, exc_val, exc_tb) - self._ctx = None - - async def _run_script_setup(self) -> None: - """Run the script's setup phase (get prompt).""" - if self._ctx is None or self.script is None: - return - - # Store script name on context for ctx.submit() - self._ctx._script_name = self.script - - # Delegate to ScriptMixin.run_script_setup - prompt = await self._ctx.run_script_setup(self.script, self.args) - if prompt: - self._ctx.prompt = prompt - - async def _run_script_evaluate(self) -> None: - """Run the script's evaluate phase (get reward).""" - if self._ctx is None or self.script is None: - return - - # Delegate to ScriptMixin.run_script_evaluate - reward = await self._ctx.run_script_evaluate(self.script) - if reward is not None: - self._ctx.reward = reward diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 789c9df4..b9f0a065 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -25,178 +25,40 @@ from collections.abc import AsyncGenerator from hud.eval.context import EvalContext - from hud.eval.eval import Eval - from hud.types import Task + from hud.eval.task import Task logger = logging.getLogger(__name__) -# Type alias for eval source: slug strings, Eval objects, or deprecated Task objects -EvalSource = "str | list[str] | Eval | list[Eval] | Task | list[Task] | None" - - -def _parse_slug(slug: str) -> tuple[str, str | None]: - """Parse a task slug into (base_slug, index_or_wildcard). - - Args: - slug: Task slug like "my-org/task", "my-org/task:1", or "my-org/task:*" - - Returns: - Tuple of (base_slug, index_str or None) - - "my-org/task" -> ("my-org/task", None) - - "my-org/task:1" -> ("my-org/task", "1") - - "my-org/task:*" -> ("my-org/task", "*") - """ - if ":" in slug: - parts = slug.rsplit(":", 1) - return parts[0], parts[1] - return slug, None - - -def _get_eval_name( - source: str | list[str] | None = None, - evals: list[Eval] | None = None, - tasks: list[Task] | None = None, # Deprecated -) -> str: +def _get_eval_name(tasks: list[Task] | None = None) -> str: """Extract a nice name for job display. 
Args: - source: Single slug or list of slugs (if string-based) - evals: List of Eval objects (primary path) - tasks: List of Task objects (deprecated) + tasks: List of Task objects Returns: - Name like "script with val1, val2" or "eval" if no source + Name like "scenario with val1, val2" or "eval" if no tasks """ - from hud.eval.eval import build_eval_name + from hud.eval.task import build_eval_name - # If we have Eval objects, derive name from first one - if evals and evals[0].script: - return build_eval_name(evals[0].script, evals[0].args) - - # Deprecated: If we have tasks with IDs, use first task ID + # If we have Task objects, derive name from first one if tasks: - first_task = tasks[0] - if first_task.id: - # Extract name from task ID (might be "evalset/task_name") - task_id = str(first_task.id) - if "/" in task_id: - return task_id.rsplit("/", 1)[1] - return task_id - # Fall back to prompt excerpt - if first_task.prompt: - return first_task.prompt[:30].strip() - - # If we have string slugs - if source is not None: - # Get the first slug - first_slug = source if isinstance(source, str) else source[0] - - # Remove index/wildcard suffix (":1" or ":*") - base_slug, _ = _parse_slug(first_slug) - - # Extract the evalset name (part after last "/") - if "/" in base_slug: - return base_slug.rsplit("/", 1)[1] - - return base_slug + if tasks[0].scenario: + return build_eval_name(tasks[0].scenario, tasks[0].args) + # Fall back to env name or prompt + if tasks[0].env and hasattr(tasks[0].env, "name"): + return tasks[0].env.name + if tasks[0].env and hasattr(tasks[0].env, "prompt") and tasks[0].env.prompt: + return tasks[0].env.prompt[:30].strip() + if tasks[0].id: + return tasks[0].id return "eval" -def _load_evals_from_slugs(slugs: str | list[str]) -> list[Eval]: - """Load Eval configs from platform by slugs. - - Args: - slugs: Single slug or list of slugs. Slugs can be: - - "my-org/eval" - single eval - - "my-org/eval:N" - eval at index N - - "my-org/eval:*" - all evals matching pattern - - Returns: - List of Eval objects - """ - import httpx - - from hud.settings import settings - - if isinstance(slugs, str): - slugs = [slugs] - - evals: list[Eval] = [] - - headers = {} - if settings.api_key: - headers["Authorization"] = f"Bearer {settings.api_key}" - - with httpx.Client() as client: - for slug in slugs: - base_slug, index_str = _parse_slug(slug) - - if index_str == "*": - # Fetch all evals for this evalset - logger.info("Loading all evals for: %s", base_slug) - response = client.get( - f"{settings.hud_api_url}/evals/{base_slug}", - headers=headers, - params={"all": "true"}, - ) - response.raise_for_status() - data = response.json() - - if isinstance(data, list): - evals.extend(_eval_from_api(item) for item in data) - else: - evals.append(_eval_from_api(data)) - - elif index_str is not None: - # Fetch specific eval by index - logger.info("Loading eval: %s (index %s)", base_slug, index_str) - response = client.get( - f"{settings.hud_api_url}/evals/{base_slug}", - headers=headers, - params={"index": index_str}, - ) - response.raise_for_status() - data = response.json() - evals.append(_eval_from_api(data)) - - else: - # Fetch single eval - logger.info("Loading eval: %s", slug) - response = client.get( - f"{settings.hud_api_url}/evals/{slug}", - headers=headers, - ) - response.raise_for_status() - data = response.json() - evals.append(_eval_from_api(data)) - - return evals - - -def _eval_from_api(data: dict[str, Any]) -> Eval: - """Convert API response to Eval object. 
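Editor's note: `_get_eval_name()` now leans on `build_eval_name()` from `hud.eval.task` (added later in this patch) instead of parsing slugs. Its naming rules can be read off directly; the values below are illustrative:

```python
from hud.eval.task import build_eval_name

# Scenario plus args -> "<scenario> with <up to three arg values>"
assert build_eval_name("checkout", {"user_id": "alice", "amount": 100}) == (
    "checkout with 'alice', 100"
)

# No scenario falls back to the generic name; no args keeps just the scenario.
assert build_eval_name(None, {}) == "eval"
assert build_eval_name("checkout", None) == "checkout"
```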
- - Expected API response format: - { - "env_config": {...}, # EnvConfig dict - "script": "script_name", # Optional - "args": {...}, # Script arguments - } - """ - from hud.eval.eval import Eval - - return Eval( - env=data.get("env_config"), # Serialized config from backend - script=data.get("script"), - args=data.get("args", {}), - ) - - @asynccontextmanager async def run_eval( - source: str | list[str] | Task | list[Task] | Eval | list[Eval] | None = None, + source: Task | list[Task] | None = None, *, variants: dict[str, Any] | None = None, group: int = 1, @@ -209,18 +71,16 @@ async def run_eval( ) -> AsyncGenerator[EvalContext, None]: """Standalone eval context manager. - Creates an EvalContext for evaluation, optionally loading task configuration - from slugs, using Task objects, or using Eval objects directly. + Creates an EvalContext for evaluation using Task objects (or deprecated LegacyTask). + For loading tasks from datasets, use load_dataset() first. Args: - source: Eval source. Can be: + source: Task source. Can be: - None: Create blank eval context - - str: Task slug like "my-org/task", "my-org/task:N", "my-org/task:*" - - list[str]: Multiple task slugs - - Task: Single Task object (for backwards compat with run_tasks) - - list[Task]: List of Task objects (for backwards compat with run_tasks) - - Eval: Single Eval object (from env()) - - list[Eval]: List of Eval objects (from env()) + - Task: Single Task object (from env() or load_dataset()) + - list[Task]: List of Task objects + - LegacyTask: Single LegacyTask object (deprecated, use Task.from_v4()) + - list[LegacyTask]: List of LegacyTask objects (deprecated) variants: A/B test configuration (dict with list values expanded) group: Runs per variant for statistical significance group_ids: Optional list of group IDs @@ -235,32 +95,26 @@ async def run_eval( Example: ```python + from hud.datasets import load_dataset + # Blank eval (for manual reward) async with hud.eval() as ctx: ctx.reward = compute_reward() - # With task slug - async with hud.eval("my-org/browser-task:1") as ctx: - await agent.run(ctx) - ctx.reward = result.reward - - # Multiple tasks - async with hud.eval(["task:1", "task:2"]) as ctx: - await agent.run(ctx) - - # All tasks in evalset - async with hud.eval("my-org/evalset:*") as ctx: - await agent.run(ctx) - - # With Eval objects (from env()) + # With Task objects (from env()) env = Environment("my-env").connect_hub("browser") - evals = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] - async with hud.eval(evals, variants={"model": ["gpt-4o"]}, group=4) as ctx: + tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: await agent.run(ctx.prompt) + # Load tasks from dataset first + tasks = load_dataset("hud-evals/SheetBench-50") + async with hud.eval(tasks) as ctx: + await agent.run(ctx) + # With variants and group async with hud.eval( - "task", + tasks, variants={"model": ["gpt-4o", "claude"]}, group=3, ) as ctx: @@ -269,7 +123,7 @@ async def run_eval( ctx.reward = evaluate() # With concurrency limit - async with hud.eval("my-org/evalset:*", max_concurrent=10) as ctx: + async with hud.eval(tasks, max_concurrent=10) as ctx: await agent.run(ctx) # Access results after parallel run @@ -277,10 +131,8 @@ async def run_eval( print(f"{e.variants}: reward={e.reward}") ``` """ - import warnings - - from hud.eval.eval import Eval - from hud.types import Task + from hud.eval.task import Task + from hud.types 
import LegacyTask if group <= 0: raise ValueError("group must be >= 1") @@ -288,50 +140,40 @@ async def run_eval( # Expand variants variant_combos = expand_variants(variants) - # Parse source into evals list (or deprecated tasks list) - evals: list[Eval] = [] - tasks: list[Task] = [] # Deprecated path - slugs: str | list[str] | None = None # Track if we had string slugs (for naming) + # Parse source into tasks list - only Task objects accepted + tasks: list[Task] = [] if source is not None: - if isinstance(source, Eval): - # Single Eval object - evals = [source] - elif isinstance(source, list) and source and isinstance(source[0], Eval): - # List of Eval objects - evals = source # type: ignore[assignment] - elif isinstance(source, Task): - # Single Task object (deprecated) - warnings.warn( - "Passing Task objects to hud.eval() is deprecated. " - "Use Eval objects from env() or string slugs instead.", - DeprecationWarning, - stacklevel=2, - ) + if isinstance(source, Task): + # Single Task object tasks = [source] elif isinstance(source, list) and source and isinstance(source[0], Task): - # List of Task objects (deprecated) - warnings.warn( - "Passing Task objects to hud.eval() is deprecated. " - "Use Eval objects from env() or string slugs instead.", - DeprecationWarning, - stacklevel=2, - ) + # List of Task objects tasks = source # type: ignore[assignment] + elif isinstance(source, LegacyTask) or ( + isinstance(source, list) and source and isinstance(source[0], LegacyTask) + ): + # LegacyTask no longer accepted - user must convert first + raise TypeError( + "LegacyTask is no longer accepted by hud.eval(). " + "Convert first with Task.from_v4(legacy_task), or use load_dataset()." + ) elif isinstance(source, str): - # String slug - load as Eval - slugs = source - evals = _load_evals_from_slugs(source) + # String slugs no longer supported - use load_dataset() + raise TypeError( + f"String slugs are no longer supported in hud.eval(). " + f"Use load_dataset('{source}') first, then pass the tasks list." + ) elif isinstance(source, list) and source and isinstance(source[0], str): - # List of string slugs - load as Eval - slugs = source # type: ignore[assignment] - evals = _load_evals_from_slugs(source) # type: ignore[arg-type] + # List of string slugs no longer supported + raise TypeError( + "String slugs are no longer supported in hud.eval(). " + "Use load_dataset() first, then pass the tasks list." 
+ ) # Calculate total evaluations - # If we have evals, each eval gets (variants x group) runs - # If we have tasks, each task gets (variants x group) runs - # If neither, we have a single blank eval with (variants x group) runs - base_count = len(evals) or len(tasks) or 1 + # Each task gets (variants x group) runs; no tasks = single blank eval + base_count = len(tasks) or 1 total_evals = base_count * len(variant_combos) * group # Capture code snippet for parallel execution @@ -352,27 +194,14 @@ async def run_eval( from hud.eval.context import EvalContext if total_evals == 1: - # Simple case: single eval - always use Eval for consistent flow - if evals: - single_eval = evals[0] - elif tasks: - # Wrap deprecated Task in Eval - single_eval = Eval( - env=None, - script=None, - api_key=api_key, - job_id=job_id, - variants=variant_combos[0], - code_snippet=code_snippet, - _trace=trace, - _quiet=quiet, - ) - single_eval._task = tasks[0] # type: ignore[attr-defined] + # Simple case: single eval - always use Task for consistent flow + if tasks: + single_task = tasks[0] else: # Blank eval - single_eval = Eval( + single_task = Task( env=None, - script=None, + scenario=None, api_key=api_key, job_id=job_id, variants=variant_combos[0], @@ -382,19 +211,19 @@ async def run_eval( ) # Apply common settings - single_eval.api_key = api_key - single_eval.job_id = job_id - single_eval.variants = variant_combos[0] - single_eval.code_snippet = code_snippet - single_eval._trace = trace - single_eval._quiet = quiet - - async with single_eval as ctx: + single_task.api_key = api_key + single_task.job_id = job_id + single_task.variants = variant_combos[0] + single_task.code_snippet = code_snippet + single_task._trace = trace + single_task._quiet = quiet + + async with single_task as ctx: yield ctx else: # Parallel execution: create implicit job to group traces - eval_name = _get_eval_name(source=slugs, evals=evals, tasks=tasks) + eval_name = _get_eval_name(tasks=tasks) implicit_job_id = job_id or str(uuid.uuid4()) job_url = f"https://hud.ai/jobs/{implicit_job_id}" @@ -406,7 +235,6 @@ async def run_eval( try: # Run parallel evals with job_id completed = await _run_parallel_eval( - evals=evals, tasks=tasks, variant_combos=variant_combos, group=group, @@ -420,19 +248,12 @@ async def run_eval( ) # Create summary context (no trace, just aggregates results) - if evals: - # Create summary from first eval's env_config + if tasks: + # Create summary from first task ctx = EvalContext( name=eval_name, # Use the same smart name api_key=api_key, job_id=implicit_job_id, - env_config=evals[0].env_config, - ) - elif tasks: - ctx = EvalContext.from_task( - task=tasks[0], - api_key=api_key, - job_id=implicit_job_id, ) else: ctx = EvalContext( @@ -464,7 +285,6 @@ async def run_eval( async def _run_parallel_eval( - evals: list[Eval], tasks: list[Task], variant_combos: list[dict[str, Any]], group: int, @@ -478,13 +298,13 @@ async def _run_parallel_eval( ) -> list[EvalContext]: """Run parallel evaluation. - Creates EvalContexts from Evals, tasks (or blank) and runs them in parallel. + Creates EvalContexts from Tasks (or blank) and runs them in parallel. 
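Editor's note: because string slugs and LegacyTask values now raise `TypeError` in `hud.eval()`, callers that previously passed a slug need to resolve tasks up front. A hedged sketch of the replacement flow; `load_dataset` and the dataset name come from the docstring above, and `agent.run` is a placeholder as in the surrounding examples:

```python
import hud
from hud.datasets import load_dataset


async def main(agent) -> None:
    # Previously: async with hud.eval("hud-evals/SheetBench-50:*") as ctx: ...
    tasks = load_dataset("hud-evals/SheetBench-50")

    async with hud.eval(tasks, max_concurrent=10) as ctx:
        await agent.run(ctx)
```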
""" import asyncio import textwrap # Lazy import to avoid circular dependency - from hud.eval.eval import Eval + from hud.eval.task import Task from hud.eval.parallel import log_eval_stats # Find user code frame and extract the with block body @@ -492,62 +312,38 @@ async def _run_parallel_eval( body_source, captured_locals, context_var = get_with_block_body(caller_frame) # Calculate total evals and resolve group IDs - base_count = len(evals) or len(tasks) or 1 + base_count = len(tasks) or 1 total_evals = base_count * len(variant_combos) * group resolved_group_ids = resolve_group_ids(group_ids, total_evals) - # Create Eval objects for parallel execution - eval_objects: list[Eval] = [] + # Create Task objects for parallel execution + task_objects: list[Task] = [] idx = 0 - if evals: - # Create Eval for each (eval, variant, run) combination - for base_eval in evals: - for variant in variant_combos: - for _ in range(group): - eval_copy = base_eval.copy() - eval_copy.api_key = api_key - eval_copy.job_id = job_id - eval_copy.group_id = resolved_group_ids[idx] - eval_copy.index = idx - eval_copy.variants = variant - eval_copy.code_snippet = code_snippet - eval_copy._suppress_link = True # Individual traces don't print links - eval_copy._trace = trace - eval_copy._quiet = quiet - eval_objects.append(eval_copy) - idx += 1 - elif tasks: - # Create Eval from Task for each (task, variant, run) combination - for task in tasks: + if tasks: + # Create Task for each (task, variant, run) combination + for base_task in tasks: for variant in variant_combos: for _ in range(group): - # Convert Task to Eval (backwards compatibility) - task_eval = Eval( - env=None, # Task has its own mcp_config - script=None, - args={}, - api_key=api_key, - job_id=job_id, - group_id=resolved_group_ids[idx], - index=idx, - variants=variant, - code_snippet=code_snippet, - _suppress_link=True, - _trace=trace, - _quiet=quiet, - ) - # Store task reference for EvalContext creation - task_eval._task = task # type: ignore[attr-defined] - eval_objects.append(task_eval) + task_copy = base_task.copy() + task_copy.api_key = api_key + task_copy.job_id = job_id + task_copy.group_id = resolved_group_ids[idx] + task_copy.index = idx + task_copy.variants = variant + task_copy.code_snippet = code_snippet + task_copy._suppress_link = True # Individual traces don't print links + task_copy._trace = trace + task_copy._quiet = quiet + task_objects.append(task_copy) idx += 1 else: - # Blank evals for each (variant, run) combination + # Blank tasks for each (variant, run) combination for variant in variant_combos: for _ in range(group): - blank_eval = Eval( + blank_task = Task( env=None, - script=None, + scenario=None, args={}, api_key=api_key, job_id=job_id, @@ -559,7 +355,7 @@ async def _run_parallel_eval( _trace=trace, _quiet=quiet, ) - eval_objects.append(blank_eval) + task_objects.append(blank_task) idx += 1 # Create runner function using the actual variable name from the 'as' clause @@ -572,33 +368,33 @@ async def _run_parallel_eval( # Create semaphore for concurrency control sem = asyncio.Semaphore(max_concurrent) if max_concurrent else None - async def run_one(eval_obj: Eval) -> EvalContext: - """Run a single Eval and return its EvalContext.""" + async def run_one(task_obj: Task) -> EvalContext: + """Run a single Task and return its EvalContext.""" try: if sem: - async with sem, eval_obj as ctx: + async with sem, task_obj as ctx: await runner(ctx) else: - async with eval_obj as ctx: + async with task_obj as ctx: await runner(ctx) return ctx 
except Exception as e: - logger.warning("Parallel eval %d failed: %s", eval_obj.index, e) - # Create a failed context from the eval - ctx = eval_obj.to_eval_context() + logger.warning("Parallel eval %d failed: %s", task_obj.index, e) + # Create a failed context from the task + ctx = task_obj.to_eval_context() ctx.error = e return ctx # Run in parallel logger.info( - "Running %d evals (%d base x %d variants x %d runs)%s", - len(eval_objects), + "Running %d tasks (%d base x %d variants x %d runs)%s", + len(task_objects), base_count, len(variant_combos), group, f", max_concurrent={max_concurrent}" if max_concurrent else "", ) - completed = await asyncio.gather(*[run_one(e) for e in eval_objects]) + completed = await asyncio.gather(*[run_one(t) for t in task_objects]) # Log and print stats eval_name = completed[0].eval_name if completed else "eval" diff --git a/hud/eval/task.py b/hud/eval/task.py new file mode 100644 index 00000000..c2a145a5 --- /dev/null +++ b/hud/eval/task.py @@ -0,0 +1,437 @@ +"""Task - A runnable evaluation unit (data class). + +A Task holds the configuration needed to run an evaluation: +- Environment configuration (how to create/connect) +- Optional scenario name and args + +When entered as a context manager, it creates an EvalContext. + +Usage: + env = Environment("my-env").connect_hub("browser") + + # Empty - just env + async with env() as ctx: + await ctx.call_tool("navigate", url="...") + + # With scenario + async with env("checkout", user_id="alice") as ctx: + await agent.run(ctx.prompt) + + # Orchestrated via hud.eval + tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] + async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: + ... +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from hud.types import MCPToolCall + +if TYPE_CHECKING: + from types import TracebackType + + from hud.environment import Environment + from hud.environment.types import EnvConfig + from hud.eval.context import EvalContext + +__all__ = ["Task", "build_eval_name"] + +logger = logging.getLogger(__name__) + + +def _warn_local_mcp(mcp_config: dict[str, Any] | None) -> None: + """Warn if mcp_config uses local MCP servers (command without url). + + Local MCP servers can cause port conflicts when running tasks concurrently. + """ + if not mcp_config: + return + + has_local = any( + isinstance(server_cfg, dict) + and "command" in server_cfg + and not server_cfg.get("url") + for server_cfg in mcp_config.values() + if isinstance(server_cfg, dict) + ) + + if has_local: + import warnings + + warnings.warn( + "Task uses local MCP configuration (command without url). " + "This may cause port conflicts when running tasks concurrently. " + "Consider using remote MCP servers for parallel execution.", + UserWarning, + stacklevel=4, # Skip through from_v4 -> _warn_local_mcp -> warn + ) + + +def build_eval_name(scenario: str | None, args: dict[str, Any] | None) -> str: + """Build descriptive name: 'scenario with val1, val2, ...'""" + if not scenario: + return "eval" + if not args: + return scenario + + val_parts = [] + for v in list(args.values())[:3]: # Max 3 values + v_str = repr(v) if isinstance(v, str) else str(v) + if len(v_str) > 25: + v_str = v_str[:22] + "..." + val_parts.append(v_str) + + if val_parts: + return f"{scenario} with {', '.join(val_parts)}" + return scenario + + +@dataclass +class Task: + """A runnable evaluation unit (data class). 
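Editor's note: the parallel path sizes a run as `len(tasks) * len(variant_combos) * group`, which is also what the log line above reports. A short sketch of how that expansion plays out; the model names are illustrative and `agent.run` is a placeholder:

```python
import hud
from hud.environment import Environment

env = Environment("my-env").connect_hub("browser")
tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")]


async def main(agent) -> None:
    # 2 tasks x 2 variants x group=3 -> 12 runs, at most 5 in flight at once.
    async with hud.eval(
        tasks,
        variants={"model": ["gpt-4o", "claude"]},
        group=3,
        max_concurrent=5,
    ) as ctx:
        await agent.run(ctx.prompt)

    # Per-run contexts are aggregated on the summary context.
    for result in ctx.results or []:
        print(result.variants, result.reward)
```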
+ + Simplified v5 Task format: + - env: Environment instance OR EnvConfig with hub name + filters + - scenario: Scenario name to run + - args: Scenario arguments + - validation: Optional list of tool calls representing successful completion + + When entered as a context manager, creates an EvalContext. + + Attributes: + id: Optional task identifier for filtering/tracking + env: Environment instance (auto-created from dict/EnvConfig in __post_init__) + scenario: Scenario name to run (from @env.scenario) + args: Scenario arguments + validation: Optional list of MCPToolCall objects representing successful completion + + Example (v5 format): + ```python + from hud.eval import Task + + # Pass dict - auto-converts to Environment + task = Task( + env={"name": "browser", "include": ["navigate", "screenshot"]}, + scenario="checkout", + args={"user_id": "alice"}, + validation=[{"name": "check_cart", "arguments": {}}] + ) + # task.env is now Environment connected to browser hub! + + # Or pass live Environment directly + env = Environment("my-env").connect_hub("browser") + task = Task(env=env, scenario="checkout", args={"user_id": "alice"}) + ``` + + Migration from v4: + Use Task.from_v4() to convert LegacyTask objects: + + ```python + task = Task.from_v4(legacy_task) + # or + task = Task.from_v4({"prompt": "...", "mcp_config": {...}, ...}) + ``` + """ + + # Core v5 task definition + id: str | None = None + env: Environment | None = None + scenario: str | None = None + args: dict[str, Any] = field(default_factory=dict) + validation: list[MCPToolCall] | None = None + + # EvalContext creation params (set by hud.eval for parallel execution) + trace_id: str | None = field(default=None, repr=False) + api_key: str | None = field(default=None, repr=False) + job_id: str | None = field(default=None, repr=False) + group_id: str | None = field(default=None, repr=False) + index: int = field(default=0, repr=False) + variants: dict[str, Any] = field(default_factory=dict, repr=False) + code_snippet: str | None = field(default=None, repr=False) + _suppress_link: bool = field(default=False, repr=False) + _trace: bool = field(default=True, repr=False) + _quiet: bool = field(default=False, repr=False) + + # Runtime state + _ctx: EvalContext | None = field(default=None, repr=False) + + def __post_init__(self) -> None: + """Validate and normalize env and validation fields after initialization. + + Auto-converts dict or EnvConfig to Environment by connecting to the hub. + Auto-converts validation dicts to MCPToolCall objects. + """ + from hud.environment import Environment + from hud.environment.types import EnvConfig + + # Convert env field + if not isinstance(self.env, (Environment, type(None))): + # Convert dict to EnvConfig first (with validation) + if isinstance(self.env, dict): + try: + config = EnvConfig(**self.env) + except Exception as e: + raise ValueError( + f"Invalid env config: {e}. Expected fields: name (str), " + f"include (list[str] | None), exclude (list[str] | None)" + ) from e + elif isinstance(self.env, EnvConfig): + config = self.env + else: + raise TypeError( + f"Task.env must be Environment, EnvConfig, dict, or None. 
" + f"Got {type(self.env).__name__}" + ) + + # Convert EnvConfig to Environment + env = Environment(config.name) + env.connect_hub(config.name, include=config.include, exclude=config.exclude) + self.env = env + + # Convert validation dicts to MCPToolCall objects + if self.validation and isinstance(self.validation, list): + converted_validation = [] + for item in self.validation: + if isinstance(item, dict): + converted_validation.append(MCPToolCall(**item)) + elif isinstance(item, MCPToolCall): + converted_validation.append(item) + else: + raise TypeError( + f"validation items must be dict or MCPToolCall, " + f"got {type(item).__name__}" + ) + self.validation = converted_validation + + @classmethod + def from_v4( + cls, + source: Any, # LegacyTask | dict[str, Any] | str + ) -> Task: + """Convert a v4 LegacyTask to a v5 Task. + + This is the recommended migration path for existing v4 code. The returned + Task automatically runs setup_tool at the start and evaluate_tool at the + end, matching the old LegacyTask behavior. + + Args: + source: One of: + - LegacyTask object + - dict with LegacyTask fields (prompt, mcp_config, etc.) + - JSON string of LegacyTask fields + + Returns: + Task with Environment configured to mimic LegacyTask behavior. + + Example: + ```python + from hud.eval import Task + + # From existing LegacyTask + task = Task.from_v4(legacy_task) + + # From dict (e.g., loaded from JSON file) + task = Task.from_v4({ + "prompt": "Navigate to google.com", + "mcp_config": {"hud": {...}}, + "setup_tool": {"name": "navigate", "arguments": {"url": "..."}}, + "evaluate_tool": {"name": "check_url", "arguments": {}} + }) + + # Use with hud.eval() or as context manager + async with task as ctx: + result = await agent.run(ctx) + ``` + + Note: + For new code, prefer using @env.scenario() instead: + - setup_tool code goes BEFORE the first yield + - evaluate_tool code goes AFTER the first yield + See https://docs.hud.ai/migration for the full migration guide. 
+ """ + import json as json_module + + from hud.environment import Environment + from hud.types import LegacyTask + + # Parse JSON string + if isinstance(source, str): + try: + source = json_module.loads(source) + except json_module.JSONDecodeError as e: + from hud.shared.exceptions import HudConfigError + + raise HudConfigError(f"Invalid JSON string for Task.from_v4: {e}") from e + + # Convert dict to LegacyTask (suppress the deprecation warning since we're migrating) + if isinstance(source, dict): + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + legacy_task = LegacyTask(**source) + elif isinstance(source, LegacyTask): + legacy_task = source + else: + raise TypeError( + f"Task.from_v4() expects LegacyTask, dict, or JSON string, " + f"got {type(source).__name__}" + ) + + # Warn if using local MCP configs (command without url) + _warn_local_mcp(legacy_task.mcp_config) + + # Create Environment and connect via mcp_config + env = Environment(legacy_task.id or "v4-legacy") + env.connect_mcp_config(legacy_task.mcp_config) + + # Set the prompt + env.prompt = legacy_task.prompt + + # Add setup_tool calls (run after connection via Environment._setup_calls) + if legacy_task.setup_tool: + setup_calls = legacy_task.setup_tool + if not isinstance(setup_calls, list): + setup_calls = [setup_calls] + for call in setup_calls: + env.setup_tool(call.name, **(call.arguments or {})) + + # Add evaluate_tool calls (run before disconnection via Environment._evaluate_calls) + if legacy_task.evaluate_tool: + evaluate_calls = legacy_task.evaluate_tool + if not isinstance(evaluate_calls, list): + evaluate_calls = [evaluate_calls] + for call in evaluate_calls: + env.evaluate_tool(call.name, **(call.arguments or {})) + + logger.debug( + "Created Task from v4 LegacyTask: %s", + legacy_task.prompt[:50] if legacy_task.prompt else "no prompt", + ) + + return cls( + id=legacy_task.id, + env=env, # Live Environment with mcp_config, setup_tool, evaluate_tool + scenario=None, # No scenario - uses prompt directly + args={}, + validation=None, + ) + + # Backwards compat alias + + def copy(self) -> Task: + """Create a copy of this Task for parallel execution.""" + return Task( + env=self.env, # Share reference - from_environment handles copying + scenario=self.scenario, + args=self.args.copy(), + trace_id=None, # Each copy gets unique trace_id + api_key=self.api_key, + job_id=self.job_id, + group_id=self.group_id, + index=self.index, + variants=self.variants.copy(), + code_snippet=self.code_snippet, + _suppress_link=self._suppress_link, + _trace=self._trace, + _quiet=self._quiet, + ) + + def to_eval_context(self) -> EvalContext: + """Convert this Task to an EvalContext. + + Creates an EvalContext from the environment (live or from config). + If env is EnvConfig or dict, creates Environment by connecting to the hub. 
+ """ + from hud.environment import Environment + from hud.eval.context import EvalContext + + # Get environment (or create blank if None) + source_env = self.env if self.env is not None else Environment("eval") + + eval_name = build_eval_name(self.scenario, self.args) + + # Create EvalContext from environment + ctx = EvalContext.from_environment( + env=source_env, + name=eval_name, + trace_id=self.trace_id, + api_key=self.api_key, + job_id=self.job_id, + group_id=self.group_id, + index=self.index, + variants=self.variants, + code_snippet=self.code_snippet, + ) + ctx._suppress_link = self._suppress_link + ctx._trace_enabled = self._trace + + return ctx + + async def __aenter__(self) -> EvalContext: + """Enter eval context. + + Order of operations: + 1. Create EvalContext from environment config + 2. Connect environment (MCP servers, etc.) + 3. Run scenario setup (if scenario) → sets ctx.prompt + 4. Notify backend (with prompt now set) + 5. Print trace link + """ + self._ctx = self.to_eval_context() + await self._ctx.__aenter__() # Connect env, set trace headers + + # Run scenario setup (sets prompt) + if self.scenario: + await self._run_scenario_setup() + + # Notify backend with prompt included + await self._ctx._eval_enter() + self._ctx._print_eval_link() + + return self._ctx + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit eval context - run scenario evaluate and exit EvalContext.""" + if self._ctx is None: + return + + # If we have a scenario and no error, run its evaluate phase + if self.scenario and exc_type is None: + await self._run_scenario_evaluate() + + # Exit the EvalContext + await self._ctx.__aexit__(exc_type, exc_val, exc_tb) + self._ctx = None + + async def _run_scenario_setup(self) -> None: + """Run the scenario's setup phase (get prompt).""" + if self._ctx is None or self.scenario is None: + return + + # Store scenario name on context for ctx.submit() + self._ctx._scenario_name = self.scenario + + # Delegate to ScenarioMixin.run_scenario_setup + prompt = await self._ctx.run_scenario_setup(self.scenario, self.args) + if prompt: + self._ctx.prompt = prompt + + async def _run_scenario_evaluate(self) -> None: + """Run the scenario's evaluate phase (get reward).""" + if self._ctx is None or self.scenario is None: + return + + # Delegate to ScenarioMixin.run_scenario_evaluate + reward = await self._ctx.run_scenario_evaluate(self.scenario) + if reward is not None: + self._ctx.reward = reward diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index 38c11f58..1fa9d655 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -1,4 +1,4 @@ -"""Tests for hud.eval.eval module (Eval class).""" +"""Tests for hud.eval.task module (Task class).""" from __future__ import annotations @@ -6,36 +6,42 @@ import pytest -from hud.eval.eval import Eval +from hud.eval.task import Task -class TestEvalDataclass: - """Tests for Eval as a data class.""" +class TestTaskDataclass: + """Tests for Task as a data class.""" def test_init_defaults(self) -> None: - """Eval initializes with sensible defaults.""" - ev = Eval() + """Task initializes with sensible defaults.""" + task = Task() - assert ev.env_config is None - assert ev.script is None - assert ev.args == {} - assert ev.variants == {} - assert ev.index == 0 + assert task.env is None + assert task.scenario is None + assert task.args == {} + assert task.variants == {} + assert task.index == 0 - def 
test_init_with_config(self) -> None: - """Eval can be initialized with env_config and script.""" - config = {"name": "test-env", "hubs": []} - ev = Eval(env=config, script="checkout", args={"user_id": "alice"}) + def test_init_with_env_dict(self) -> None: + """Task auto-converts env dict to Environment in __post_init__.""" + from hud.environment import Environment + + task = Task( + env={"name": "browser", "include": ["navigate"]}, + scenario="checkout", + args={"user_id": "alice"}, + ) - assert ev.env_config == config - assert ev.script == "checkout" - assert ev.args == {"user_id": "alice"} + # env dict is auto-converted to Environment + assert isinstance(task.env, Environment) + assert task.scenario == "checkout" + assert task.args == {"user_id": "alice"} def test_copy_creates_new_instance(self) -> None: - """copy() creates a new Eval instance.""" - original = Eval( + """copy() creates a new Task instance.""" + original = Task( env={"name": "test"}, - script="checkout", + scenario="checkout", args={"user_id": "alice"}, variants={"model": "gpt-4o"}, ) @@ -43,7 +49,7 @@ def test_copy_creates_new_instance(self) -> None: assert copied is not original assert copied.env == original.env - assert copied.script == original.script + assert copied.scenario == original.scenario assert copied.args == original.args assert copied.args is not original.args # Deep copy assert copied.variants == original.variants @@ -51,36 +57,36 @@ def test_copy_creates_new_instance(self) -> None: def test_copy_clears_trace_id(self) -> None: """copy() clears trace_id for fresh instance.""" - original = Eval(trace_id="original-trace") + original = Task(trace_id="original-trace") copied = original.copy() assert copied.trace_id is None -class TestEvalToEvalContext: - """Tests for Eval.to_eval_context().""" +class TestTaskToEvalContext: + """Tests for Task.to_eval_context().""" def test_creates_eval_context(self) -> None: """to_eval_context() creates an EvalContext.""" from hud.eval.context import EvalContext - ev = Eval(script="checkout") - ctx = ev.to_eval_context() + task = Task(scenario="checkout") + ctx = task.to_eval_context() assert isinstance(ctx, EvalContext) assert ctx.eval_name == "checkout" - def test_uses_eval_as_name_when_no_script(self) -> None: - """to_eval_context() uses 'eval' as name when no script.""" - ev = Eval() - ctx = ev.to_eval_context() + def test_uses_eval_as_name_when_no_scenario(self) -> None: + """to_eval_context() uses 'eval' as name when no scenario.""" + task = Task() + ctx = task.to_eval_context() assert ctx.eval_name == "eval" def test_passes_through_properties(self) -> None: """to_eval_context() passes through properties.""" - ev = Eval( - script="checkout", + task = Task( + scenario="checkout", trace_id="test-trace", api_key="test-key", job_id="test-job", @@ -88,7 +94,7 @@ def test_passes_through_properties(self) -> None: index=5, variants={"model": "gpt-4o"}, ) - ctx = ev.to_eval_context() + ctx = task.to_eval_context() assert ctx.trace_id == "test-trace" assert ctx._eval_api_key == "test-key" @@ -98,15 +104,15 @@ def test_passes_through_properties(self) -> None: assert ctx.variants == {"model": "gpt-4o"} -class TestEvalContextManager: - """Tests for Eval as async context manager.""" +class TestTaskContextManager: + """Tests for Task as async context manager.""" @pytest.mark.asyncio async def test_aenter_returns_eval_context(self) -> None: """__aenter__ returns an EvalContext.""" from hud.eval.context import EvalContext - ev = Eval() # No script to avoid script lookup + task = Task() # 
No scenario to avoid scenario lookup with ( patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), @@ -114,17 +120,17 @@ async def test_aenter_returns_eval_context(self) -> None: patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), patch.object(EvalContext, "_print_eval_link"), # Suppress link printing ): - ctx = await ev.__aenter__() + ctx = await task.__aenter__() assert isinstance(ctx, EvalContext) # Clean up manually since we patched __aexit__ - ev._ctx = None + task._ctx = None @pytest.mark.asyncio async def test_context_clears_on_exit(self) -> None: """__aexit__ clears internal context reference.""" from hud.eval.context import EvalContext - ev = Eval() + task = Task() with ( patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), @@ -132,19 +138,19 @@ async def test_context_clears_on_exit(self) -> None: patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), patch.object(EvalContext, "_print_eval_link"), # Suppress link printing ): - await ev.__aenter__() - assert ev._ctx is not None + await task.__aenter__() + assert task._ctx is not None - # Manually call __aexit__ on Eval (which will call mocked ctx.__aexit__) - await ev.__aexit__(None, None, None) - assert ev._ctx is None + # Manually call __aexit__ on Task (which will call mocked ctx.__aexit__) + await task.__aexit__(None, None, None) + assert task._ctx is None @pytest.mark.asyncio async def test_reward_accessible_after_exit(self) -> None: """Reward set in context is accessible after exit.""" from hud.eval.context import EvalContext - ev = Eval() + task = Task() with ( patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), @@ -152,87 +158,158 @@ async def test_reward_accessible_after_exit(self) -> None: patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), patch.object(EvalContext, "_print_eval_link"), # Suppress link printing ): - ctx = await ev.__aenter__() + ctx = await task.__aenter__() ctx.reward = 0.95 - await ev.__aexit__(None, None, None) + await task.__aexit__(None, None, None) # Context reference is cleared but reward was set on the actual context -class TestEvalFromApi: - """Tests for _eval_from_api helper.""" - - def test_creates_eval_from_api_response(self) -> None: - """_eval_from_api creates Eval from API response.""" - from hud.eval.manager import _eval_from_api - - data = { - "env_config": {"name": "test-env", "hubs": []}, - "script": "checkout", - "args": {"user_id": "alice"}, - } - - ev = _eval_from_api(data) - - assert ev.env_config == {"name": "test-env", "hubs": []} - assert ev.script == "checkout" - assert ev.args == {"user_id": "alice"} - - def test_handles_missing_optional_fields(self) -> None: - """_eval_from_api handles missing optional fields.""" - from hud.eval.manager import _eval_from_api - - data = {} # Minimal response - - ev = _eval_from_api(data) - - assert ev.env_config is None - assert ev.script is None - assert ev.args == {} class TestEnvironmentCall: - """Tests for Environment.__call__ returning Eval.""" + """Tests for Environment.__call__ returning Task.""" - def test_call_returns_eval(self) -> None: - """Environment() returns an Eval object.""" + def test_call_returns_task(self) -> None: + """Environment() returns a Task object.""" from hud.environment import Environment env = Environment("test-env") - ev = env() + task = env() - assert isinstance(ev, Eval) + assert isinstance(task, Task) - def test_call_with_script_sets_script(self) -> None: - """Environment(script) sets script name.""" + def 
test_call_with_scenario_sets_scenario(self) -> None: + """Environment(scenario) sets scenario name.""" from hud.environment import Environment env = Environment("test-env") - ev = env("checkout") + task = env("checkout") - assert ev.script == "checkout" + assert task.scenario == "checkout" def test_call_with_args_sets_args(self) -> None: - """Environment(script, **args) sets args.""" + """Environment(scenario, **args) sets args.""" from hud.environment import Environment env = Environment("test-env") - ev = env("checkout", user_id="alice", amount=100) + task = env("checkout", user_id="alice", amount=100) - assert ev.args == {"user_id": "alice", "amount": 100} + assert task.args == {"user_id": "alice", "amount": 100} - def test_call_captures_env_config_when_configured(self) -> None: - """Environment() captures env config when there's something to store.""" + def test_call_returns_task_with_env(self) -> None: + """Environment() returns Task with env reference.""" from hud.environment import Environment - # Plain env has no config (nothing to reconstruct) env = Environment("test-env") - ev = env() - assert ev.env_config is None # Nothing to store + task = env() + + # Task has reference to the Environment + assert task.env is env - # Env with setup_tool has config + # With setup_tool (v4 legacy) env2 = Environment("test-env").setup_tool("navigate", url="https://example.com") - ev2 = env2() - assert ev2.env_config is not None - assert ev2.env_config["name"] == "test-env" - assert len(ev2.env_config["setup_tools"]) == 1 + task2 = env2() + assert task2.env is env2 + assert len(task2.env._setup_calls) == 1 + + +class TestTaskFromV4: + """Tests for Task.from_v4() migration helper.""" + + def test_from_v4_with_legacy_task(self) -> None: + """Task.from_v4() accepts LegacyTask object.""" + import warnings + + # Suppress the deprecation warning from LegacyTask + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + from hud.types import LegacyTask + + legacy = LegacyTask( + prompt="Navigate to google.com", + mcp_config={"hud": {"url": "https://mcp.hud.ai"}}, + ) + + task = Task.from_v4(legacy) + + assert isinstance(task, Task) + assert task.env is not None + assert task.env.prompt == "Navigate to google.com" + assert task.scenario is None # Uses setup/evaluate_tool, not scenarios + + def test_from_v4_with_dict(self) -> None: + """Task.from_v4() accepts dict with LegacyTask fields.""" + task = Task.from_v4({ + "prompt": "Navigate to google.com", + "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + }) + + assert isinstance(task, Task) + assert task.env is not None + assert task.env.prompt == "Navigate to google.com" + + def test_from_v4_with_json_string(self) -> None: + """Task.from_v4() accepts JSON string.""" + import json + + data = { + "prompt": "Navigate to google.com", + "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + } + task = Task.from_v4(json.dumps(data)) + + assert isinstance(task, Task) + assert task.env is not None + assert task.env.prompt == "Navigate to google.com" + + def test_from_v4_with_setup_tool(self) -> None: + """Task.from_v4() preserves setup_tool via env._setup_calls.""" + task = Task.from_v4({ + "prompt": "Check URL", + "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + "setup_tool": {"name": "navigate", "arguments": {"url": "https://google.com"}}, + }) + + # setup_tool is converted to env._setup_calls + assert len(task.env._setup_calls) == 1 + assert task.env._setup_calls[0] == ("navigate", {"url": 
"https://google.com"}) + + def test_from_v4_with_evaluate_tool(self) -> None: + """Task.from_v4() preserves evaluate_tool via env._evaluate_calls.""" + task = Task.from_v4({ + "prompt": "Check URL", + "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + "evaluate_tool": {"name": "check_url", "arguments": {"expected": "google"}}, + }) + + # evaluate_tool is converted to env._evaluate_calls + assert len(task.env._evaluate_calls) == 1 + assert task.env._evaluate_calls[0] == ("check_url", {"expected": "google"}) + + def test_from_v4_with_invalid_type_raises(self) -> None: + """Task.from_v4() raises TypeError for invalid input.""" + with pytest.raises(TypeError, match="expects LegacyTask, dict, or JSON string"): + Task.from_v4(12345) # type: ignore[arg-type] + + def test_from_v4_with_invalid_json_raises(self) -> None: + """Task.from_v4() raises HudConfigError for invalid JSON.""" + from hud.shared.exceptions import HudConfigError + + with pytest.raises(HudConfigError, match="Invalid JSON string"): + Task.from_v4("not valid json") + + def test_from_v4_does_not_warn_on_use(self) -> None: + """Task.from_v4() suppresses LegacyTask deprecation warning.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + Task.from_v4({ + "prompt": "test", + "mcp_config": {"hud": {}}, + }) + + # Should not trigger deprecation warning since we're migrating + legacy_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(legacy_warnings) == 0 diff --git a/hud/eval/types.py b/hud/eval/types.py index a6c8b376..d844eb10 100644 --- a/hud/eval/types.py +++ b/hud/eval/types.py @@ -9,8 +9,6 @@ from pydantic import BaseModel -from hud.environment.types import EnvConfig - # ============================================================================= # Exceptions # ============================================================================= @@ -34,7 +32,6 @@ class EvalPayload(BaseModel): prompt: str | None = None code_snippet: str | None = None - env_config: EnvConfig | None = None job_name: str | None = None job_id: str | None = None group_id: str | None = None diff --git a/hud/samples/browser.py b/hud/samples/browser.py index f6268dad..a6fdc695 100644 --- a/hud/samples/browser.py +++ b/hud/samples/browser.py @@ -7,11 +7,11 @@ from pydantic import Field from hud.settings import settings -from hud.types import MCPToolCall, Task +from hud.types import LegacyTask, MCPToolCall -class BrowserTask(Task): - """Task subclass with browser defaults for BrowserTask(prompt=...).""" +class BrowserTask(LegacyTask): + """LegacyTask subclass with browser defaults for BrowserTask(prompt=...).""" prompt: str = "Open Google and be ready to search." 
mcp_config: dict[str, Any] = Field( diff --git a/hud/server/server.py b/hud/server/server.py index aa020fa5..19b9b7a3 100644 --- a/hud/server/server.py +++ b/hud/server/server.py @@ -18,7 +18,7 @@ from hud.datasets import run_tasks from hud.server.low_level import LowLevelServerWithInit -from hud.types import Task +from hud.types import LegacyTask if TYPE_CHECKING: from collections.abc import AsyncGenerator, Callable diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 264771a1..682c077f 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -8,7 +8,7 @@ import pytest from hud.datasets import ( - Task, + LegacyTask, run_dataset, ) from hud.types import MCPToolCall @@ -16,14 +16,14 @@ class TestTaskExtended: - """Extended tests for Task functionality.""" + """Extended tests for LegacyTask functionality.""" def test_taskconfig_with_all_fields(self): - """Test Task with all possible fields.""" + """Test LegacyTask with all possible fields.""" setup_tool = MCPToolCall(name="setup", arguments={"board_size": 4}) evaluate_tool = MCPToolCall(name="evaluate", arguments={"metric": "score"}) - task = Task( + task = LegacyTask( id="test-123", prompt="Play the game", mcp_config={ @@ -43,13 +43,13 @@ def test_taskconfig_with_all_fields(self): assert task.metadata["version"] == 2 def test_taskconfig_list_tools(self): - """Test Task with list of tools.""" + """Test LegacyTask with list of tools.""" setup_tools = [ MCPToolCall(name="init", arguments={}), MCPToolCall(name="configure", arguments={"mode": "test"}), ] - task = Task(prompt="Multi-setup task", mcp_config={"test": True}, setup_tool=setup_tools) + task = LegacyTask(prompt="Multi-setup task", mcp_config={"test": True}, setup_tool=setup_tools) assert isinstance(task.setup_tool, list) assert len(task.setup_tool) == 2 @@ -77,7 +77,7 @@ def test_env_var_complex_resolution(self, monkeypatch): "hud_telemetry_url": "https://api.example.com", } - task = Task( + task = LegacyTask( prompt="Complex env test", mcp_config={ "auth": { @@ -104,7 +104,7 @@ def test_env_var_complex_resolution(self, monkeypatch): def test_non_string_values_preserved(self): """Test that non-string values are preserved during env resolution.""" - task = Task( + task = LegacyTask( prompt="Test non-strings", mcp_config={ "string": "${MISSING}", @@ -139,13 +139,13 @@ def test_save_taskconfigs_empty_list(self): mock_instance.push_to_hub.assert_called_once_with("test-org/empty-dataset") def test_save_taskconfigs_mixed_rejection(self): - """Test that mixing dicts and Task objects is rejected.""" + """Test that mixing dicts and LegacyTask objects is rejected.""" valid_dict = {"prompt": "Dict task", "mcp_config": {"test": True}} - task_object = Task(prompt="Object task", mcp_config={"resolved": "${SOME_VAR}"}) + task_object = LegacyTask(prompt="Object task", mcp_config={"resolved": "${SOME_VAR}"}) # First item is dict, second is object - with pytest.raises(ValueError, match="Item 1 is a Task object"): + with pytest.raises(ValueError, match="Item 1 is a LegacyTask object"): save_tasks([valid_dict, task_object], "test-org/mixed") # type: ignore @@ -154,178 +154,113 @@ class TestRunDatasetExtended: @pytest.mark.asyncio async def test_run_dataset_empty(self): - """Test running empty dataset.""" - with ( - patch("hud.clients.MCPClient"), - patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), - ): - # Create a mock agent class with proper type - from hud.agents import MCPAgent - - 
mock_agent_class = type("MockAgent", (MCPAgent,), {}) + """Test running empty dataset raises ValueError.""" + from hud.agents import MCPAgent + from hud.types import Trace - results = await run_dataset( - "empty_run", - [], # Empty task list - mock_agent_class, - ) + # Create mock agent + mock_agent = AsyncMock(spec=MCPAgent) + mock_agent.run.return_value = Trace(reward=1.0, done=True) - assert results == [] + # Empty task list should raise ValueError + with pytest.raises(ValueError, match="No tasks to run"): + await run_dataset([], mock_agent) @pytest.mark.asyncio - async def test_run_dataset_with_metadata(self): - """Test run_dataset with custom metadata.""" + async def test_run_dataset_with_task_list(self): + """Test run_dataset with Task objects.""" from hud.agents import MCPAgent + from hud.eval.task import Task from hud.types import Trace - # Create a proper mock agent class - mock_agent_instance = AsyncMock() - mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) + # Create mock agent + mock_agent = AsyncMock(spec=MCPAgent) + mock_agent.run.return_value = Trace(reward=1.0, done=True) + + # Create mock tasks (with mocked Environment to avoid real connections) + mock_env = MagicMock() + mock_env.name = "test" + + tasks = [ + Task(env=mock_env, scenario="test1"), + Task(env=mock_env, scenario="test2"), + ] - mock_agent_class = type( - "MockAgent", - (MCPAgent,), - { - "__init__": lambda self, **kwargs: None, - "__new__": lambda cls, **kwargs: mock_agent_instance, - }, - ) + # Mock hud.eval to avoid real eval context + mock_ctx = AsyncMock() + mock_ctx.results = None + mock_ctx.reward = None - tasks = [{"prompt": "Task 1", "mcp_config": {"url": "test1"}}] + with patch("hud.datasets.runner.hud.eval") as mock_eval: + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - # Mock EvalContext to avoid actual MCP connections - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=None) + results = await run_dataset(tasks, mock_agent, max_steps=5) - with ( - patch("hud.clients.MCPClient"), - patch("hud.eval.context.EvalContext.from_task", return_value=mock_ctx), - patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), - ): - # Should run without error - await run_dataset( - "metadata_run", - tasks, - mock_agent_class, # type: ignore - {"verbose": True}, - ) + # Should return list with ctx + assert len(results) == 1 + mock_agent.run.assert_called_once() @pytest.mark.asyncio - async def test_run_dataset_exception_handling(self): - """Test exception handling during task execution.""" + async def test_run_dataset_from_source_string(self): + """Test run_dataset with source string calls load_dataset.""" + from hud.agents import MCPAgent + from hud.eval.task import Task from hud.types import Trace - # Track execution by task index - executed_task_indices: set[int] = set() - - # Create a mock agent class where behavior depends on the task being run - def create_mock_agent(**kwargs): - agent = AsyncMock() - - async def mock_run(task, **run_kwargs): - # Extract task index from prompt "Task {i}" - task_idx = int(task.prompt.split()[-1]) - executed_task_indices.add(task_idx) - - if task_idx == 1: # Second task (index 1) should fail - raise RuntimeError("Task 2 failed") - return Trace(reward=1.0, done=True, content=f"success-{task_idx + 1}") + # Create mock agent + mock_agent = AsyncMock(spec=MCPAgent) + 
mock_agent.run.return_value = Trace(reward=1.0, done=True) - agent.run = mock_run - return agent + mock_env = MagicMock() + mock_tasks = [Task(env=mock_env, scenario="loaded")] - # Mock the agent class itself - runner calls agent_class.create() - mock_agent_class = MagicMock() - mock_agent_class.create = MagicMock(side_effect=create_mock_agent) - mock_agent_class.__name__ = "MockAgent" - - tasks = [{"prompt": f"Task {i}", "mcp_config": {"url": f"test{i}"}} for i in range(3)] - - # Create mock contexts for each task - def create_mock_ctx(*args, **kwargs): - ctx = AsyncMock() - ctx.__aenter__ = AsyncMock(return_value=ctx) - ctx.__aexit__ = AsyncMock(return_value=None) - ctx._suppress_link = False - return ctx + mock_ctx = AsyncMock() + mock_ctx.results = None with ( - patch("hud.clients.MCPClient"), - patch("hud.eval.context.EvalContext.from_task", side_effect=create_mock_ctx), - patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), + patch("hud.datasets.runner.load_dataset", return_value=mock_tasks) as mock_load, + patch("hud.datasets.runner.hud.eval") as mock_eval, ): - # Should complete without raising - results = await run_dataset("error_run", tasks, mock_agent_class) # type: ignore + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - # All tasks should be attempted - assert len(executed_task_indices) == 3 - assert executed_task_indices == {0, 1, 2} + await run_dataset("test-org/dataset", mock_agent) - # Second result should be None due to exception - assert results[1] is None + # Should call load_dataset with the source string + mock_load.assert_called_once_with("test-org/dataset") @pytest.mark.asyncio - async def test_run_dataset_client_cleanup(self): - """Test that run_dataset completes successfully.""" + async def test_run_dataset_passes_parameters(self): + """Test that run_dataset passes parameters correctly to hud.eval.""" from hud.agents import MCPAgent + from hud.eval.task import Task from hud.types import Trace - mock_agent_instance = AsyncMock() - mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) - - mock_agent_class = type( - "MockAgent", - (MCPAgent,), - { - "__init__": lambda self, **kwargs: None, - "__new__": lambda cls, **kwargs: mock_agent_instance, - }, - ) - - tasks = [{"prompt": f"Task {i}", "mcp_config": {"url": f"test{i}"}} for i in range(3)] - - # Create mock contexts - def create_mock_ctx(*args, **kwargs): - ctx = AsyncMock() - ctx.__aenter__ = AsyncMock(return_value=ctx) - ctx.__aexit__ = AsyncMock(return_value=None) - ctx._suppress_link = False - return ctx + mock_agent = AsyncMock(spec=MCPAgent) + mock_agent.run.return_value = Trace(reward=1.0, done=True) - with ( - patch("hud.clients.MCPClient"), - patch("hud.eval.context.EvalContext.from_task", side_effect=create_mock_ctx), - patch("hud.eval.display.print_link"), - patch("hud.eval.display.print_complete"), - ): - results = await run_dataset("cleanup_run", tasks, mock_agent_class) # type: ignore + mock_env = MagicMock() + tasks = [Task(env=mock_env, scenario="test")] - # Verify results were returned - assert len(results) == 3 - - @pytest.mark.asyncio - async def test_run_dataset_validation_error(self): - """Test that tasks without required fields cause validation errors.""" - from pydantic import ValidationError - - from hud.agents import MCPAgent - - # Create a task without mcp_config (required field) - task: dict[str, Any] = { - "prompt": "Test task", - # No mcp_config - should 
cause validation error during Task(**task_dict) - } + mock_ctx = AsyncMock() + mock_ctx.results = None - mock_agent_class = type("MockAgent", (MCPAgent,), {}) + with patch("hud.datasets.runner.hud.eval") as mock_eval: + mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - # Validation errors should be raised immediately when Task objects are created - with pytest.raises(ValidationError): await run_dataset( - "validation_run", - [task], # Pass the task directly - mock_agent_class, # type: ignore + tasks, + mock_agent, + max_steps=25, + max_concurrent=10, + group_size=3 + ) + + # Verify hud.eval was called with correct params + mock_eval.assert_called_once_with( + tasks, + group=3, + max_concurrent=10, ) diff --git a/hud/tests/test_types.py b/hud/tests/test_types.py index d202f707..abd052f7 100644 --- a/hud/tests/test_types.py +++ b/hud/tests/test_types.py @@ -5,12 +5,12 @@ import pytest from mcp.types import ImageContent, TextContent -from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Task, Trace, TraceStep +from hud.types import AgentResponse, LegacyTask, MCPToolCall, MCPToolResult, Trace, TraceStep def test_task_with_json_strings(): - """Test Task with JSON strings for config fields.""" - task = Task( + """Test LegacyTask with JSON strings for config fields.""" + task = LegacyTask( prompt="test", mcp_config='{"test": "config"}', # type: ignore metadata='{"key": "value"}', # type: ignore @@ -23,19 +23,19 @@ def test_task_with_json_strings(): def test_task_json_parse_error(): - """Test Task raises error on invalid JSON.""" + """Test LegacyTask raises error on invalid JSON.""" from hud.shared.exceptions import HudConfigError with pytest.raises(HudConfigError, match="Invalid JSON string"): - Task(prompt="test", mcp_config="{invalid json}") # type: ignore + LegacyTask(prompt="test", mcp_config="{invalid json}") # type: ignore def test_task_agent_config_rejects_extra_fields(): - """Test Task agent_config rejects unknown fields.""" + """Test LegacyTask agent_config rejects unknown fields.""" from pydantic import ValidationError with pytest.raises(ValidationError): - Task( + LegacyTask( prompt="test", mcp_config={}, agent_config={"model": "test", "unknown_field": "value"}, # type: ignore @@ -43,8 +43,8 @@ def test_task_agent_config_rejects_extra_fields(): def test_task_setup_tool_from_json_string(): - """Test Task converts JSON string to tool call.""" - task = Task( + """Test LegacyTask converts JSON string to tool call.""" + task = LegacyTask( prompt="test", mcp_config={}, setup_tool='{"name": "test_tool", "arguments": {"x": 1}}', # type: ignore @@ -54,16 +54,16 @@ def test_task_setup_tool_from_json_string(): def test_task_setup_tool_json_error(): - """Test Task raises error on invalid tool JSON.""" + """Test LegacyTask raises error on invalid tool JSON.""" from hud.shared.exceptions import HudConfigError with pytest.raises(HudConfigError, match="Invalid JSON string"): - Task(prompt="test", mcp_config={}, setup_tool="{invalid}") # type: ignore + LegacyTask(prompt="test", mcp_config={}, setup_tool="{invalid}") # type: ignore def test_task_setup_tool_from_list(): - """Test Task converts list of dicts to list of tool calls.""" - task = Task( + """Test LegacyTask converts list of dicts to list of tool calls.""" + task = LegacyTask( prompt="test", mcp_config={}, setup_tool=[ @@ -77,9 +77,9 @@ def test_task_setup_tool_from_list(): def test_task_env_var_substitution(): - """Test Task resolves environment 
variables.""" + """Test LegacyTask resolves environment variables.""" with patch.dict("os.environ", {"TEST_VAR": "test_value"}): - task = Task( + task = LegacyTask( prompt="test", mcp_config={"url": "${TEST_VAR}"}, ) @@ -87,9 +87,9 @@ def test_task_env_var_substitution(): def test_task_env_var_nested(): - """Test Task resolves env vars in nested structures.""" + """Test LegacyTask resolves env vars in nested structures.""" with patch.dict("os.environ", {"NESTED_VAR": "nested_value"}): - task = Task( + task = LegacyTask( prompt="test", mcp_config={"level1": {"level2": {"url": "${NESTED_VAR}"}}}, ) @@ -97,9 +97,9 @@ def test_task_env_var_nested(): def test_task_env_var_in_list(): - """Test Task resolves env vars in lists.""" + """Test LegacyTask resolves env vars in lists.""" with patch.dict("os.environ", {"LIST_VAR": "list_value"}): - task = Task( + task = LegacyTask( prompt="test", mcp_config={"items": ["${LIST_VAR}", "static"]}, ) diff --git a/hud/tools/grounding/grounded_tool.py b/hud/tools/grounding/grounded_tool.py index bc9d0345..21537afb 100644 --- a/hud/tools/grounding/grounded_tool.py +++ b/hud/tools/grounding/grounded_tool.py @@ -3,14 +3,15 @@ from __future__ import annotations import logging -from typing import Any +from typing import TYPE_CHECKING, Any from mcp import ErrorData, McpError from mcp.types import INVALID_PARAMS, ContentBlock -from hud.clients.base import AgentMCPClient # noqa: TC001 from hud.tools.grounding.grounder import Grounder # noqa: TC001 -from hud.types import MCPToolCall + +if TYPE_CHECKING: + from hud.environment import Environment logger = logging.getLogger(__name__) @@ -33,18 +34,18 @@ def __init__( self, *, grounder: Grounder, - mcp_client: AgentMCPClient, + ctx: Environment, computer_tool_name: str = "computer", ) -> None: """Initialize the grounded computer tool. Args: grounder: Grounder instance for visual grounding - mcp_client: MCP client to call the environment's computer tool + ctx: Environment or EvalContext to call tools through computer_tool_name: Name of the computer tool in the environment """ self._grounder = grounder - self._mcp_client = mcp_client + self._ctx = ctx self._computer_tool_name = computer_tool_name def get_openai_tool_schema(self) -> dict: @@ -172,10 +173,8 @@ async def __call__( if keys is not None: computer_args["keys"] = keys - result = await self._mcp_client.call_tool( - MCPToolCall( - name=self._computer_tool_name, arguments={**computer_args, **kwargs} - ) + result = await self._ctx.call_tool( + (self._computer_tool_name, {**computer_args, **kwargs}) ) return result.content @@ -224,10 +223,8 @@ async def __call__( if scroll_y is not None: computer_args["scroll_y"] = scroll_y - result = await self._mcp_client.call_tool( - MCPToolCall( - name=self._computer_tool_name, arguments={**computer_args, **kwargs} - ) + result = await self._ctx.call_tool( + (self._computer_tool_name, {**computer_args, **kwargs}) ) return result.content @@ -292,10 +289,8 @@ async def __call__( if button: computer_args["button"] = button - result = await self._mcp_client.call_tool( - MCPToolCall( - name=self._computer_tool_name, arguments={**computer_args, **kwargs} - ) + result = await self._ctx.call_tool( + (self._computer_tool_name, {**computer_args, **kwargs}) ) return result.content diff --git a/hud/types.py b/hud/types.py index 455a9ed9..30bae742 100644 --- a/hud/types.py +++ b/hud/types.py @@ -55,23 +55,28 @@ def cls(self) -> type: class BaseAgentConfig(BaseModel): - """Standard agent configuration that tasks can override. 
- Provider-specific configs should not be included here. + """Agent configuration for LLM-specific settings. + + Note: allowed_tools, disallowed_tools, append_setup_output, and initial_screenshot + are kept for backwards compatibility with v4 task configs but are no longer applied + at the agent level. These should be configured on the Environment/Task instead. """ model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + # LLM-specific setting + system_prompt: str | None = None + + # Deprecated: kept for backwards compat with v4 task configs, not applied by agent allowed_tools: list[str] | None = None disallowed_tools: list[str] | None = None - response_tool_name: str | None = None - system_prompt: str | None = None append_setup_output: bool = True initial_screenshot: bool = True -class Task(BaseModel): +class LegacyTask(BaseModel): """ - DEPRECATED: Use Eval from env() instead. + DEPRECATED: Use Task from env() instead. A task configuration that can be used to create a task. @@ -79,10 +84,18 @@ class Task(BaseModel): template placeholders in the format ${VAR_NAME} or ${VAR_NAME:default_value}. .. deprecated:: 0.5.0 - Task is deprecated. Use `env("script_name", **args)` to create Eval objects, - or use string slugs with `hud.eval("org/evalset:*")`. + LegacyTask is deprecated in v0.5.0 and will be removed in v0.6.0 + (no earlier than March 1st, 2025). + + Use one of these migration paths: + + 1. Quick conversion: ``Task.from_v4(legacy_task)`` converts LegacyTask to Task + 2. Full migration: Use ``@env.scenario()`` with setup code before first yield + and evaluate code after first yield + + See https://docs.hud.ai/migration for the full migration guide. - Example: + Example (deprecated): mcp_config: { "hud": { "url": "${HUD_MCP_URL:https://mcp.hud.ai/v3/mcp}", @@ -104,12 +117,14 @@ class Task(BaseModel): metadata: dict[str, Any] = Field(default_factory=dict) def __init__(self, **data: Any) -> None: - """Initialize Task with deprecation warning.""" + """Initialize LegacyTask with deprecation warning.""" import warnings warnings.warn( - "Task is deprecated. Use env('script_name', **args) to create Eval objects, " - "or use string slugs with hud.eval('org/evalset:*').", + "LegacyTask is deprecated in v0.5.0 and will be removed in v0.6.0 " + "(no earlier than March 1st, 2025). " + "Use Task.from_v4() for quick conversion, or migrate to @env.scenario(). " + "See https://docs.hud.ai/migration for details.", DeprecationWarning, stacklevel=2, ) @@ -378,7 +393,7 @@ class Trace(BaseModel): isError: bool = Field(default=False) # Metadata - task: Task | None = Field(default=None) + task: LegacyTask | None = Field(default=None) # Trace trace: list[TraceStep] = Field(default_factory=list) diff --git a/hud/utils/tasks.py b/hud/utils/tasks.py index 2a5606c7..90528830 100644 --- a/hud/utils/tasks.py +++ b/hud/utils/tasks.py @@ -4,13 +4,13 @@ from pathlib import Path from typing import Any -from hud.types import Task +from hud.types import LegacyTask from hud.utils.hud_console import HUDConsole hud_console = HUDConsole() -def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[Task] | list[dict]: +def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[LegacyTask] | list[dict]: """Load tasks from various sources. 
Args: @@ -22,10 +22,10 @@ def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[Task raw: If True, return raw dicts without validation or env substitution Returns: - - If raw=False (default): list[Task] + - If raw=False (default): list[LegacyTask] - If raw=True: list[dict] """ - tasks: list[Task] | list[dict] = [] + tasks: list[LegacyTask] | list[dict] = [] if isinstance(tasks_input, list): # Direct list of task dicts @@ -33,7 +33,7 @@ def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[Task if raw: return [item for item in tasks_input if isinstance(item, dict)] for item in tasks_input: - task = Task(**item) + task = LegacyTask(**item) tasks.append(task) elif isinstance(tasks_input, str): @@ -52,7 +52,7 @@ def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[Task if raw: return [item for item in data if isinstance(item, dict)] for item in data: - task = Task(**item) + task = LegacyTask(**item) tasks.append(task) # Handle JSONL files (one task per line) @@ -74,7 +74,7 @@ def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[Task if raw: return raw_items for it in raw_items: - task = Task(**it) + task = LegacyTask(**it) tasks.append(task) # Check if it's a HuggingFace dataset @@ -107,7 +107,7 @@ def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[Task if raw: return raw_rows for row in raw_rows: - task = Task(**row) + task = LegacyTask(**row) tasks.append(task) except ImportError as e: @@ -147,18 +147,18 @@ def save_tasks( **kwargs: Extra kwargs forwarded to `Dataset.push_to_hub`. """ - if tasks and isinstance(tasks[0], Task): + if tasks and isinstance(tasks[0], LegacyTask): raise ValueError( - "save_tasks expects dictionaries, not Task objects. " - "Task objects have resolved environment variables which would expose secrets. " + "save_tasks expects dictionaries, not LegacyTask objects. " + "LegacyTask objects have resolved environment variables which would expose secrets. " "Please pass raw dictionaries with template strings like '${HUD_API_KEY}' preserved." ) data: list[dict[str, Any]] = [] for index, task_dict in enumerate(tasks): - if isinstance(task_dict, Task): + if isinstance(task_dict, LegacyTask): raise ValueError( - f"Item {index} is a Task object, not a dictionary. " + f"Item {index} is a LegacyTask object, not a dictionary. " "This would expose resolved environment variables. " "Please convert to dictionary format with template strings preserved." 
) diff --git a/hud/utils/tests/test_tasks.py b/hud/utils/tests/test_tasks.py index 9979e752..18bc778c 100644 --- a/hud/utils/tests/test_tasks.py +++ b/hud/utils/tests/test_tasks.py @@ -7,7 +7,7 @@ import pytest -from hud.types import Task +from hud.types import LegacyTask from hud.utils.tasks import load_tasks, save_tasks @@ -21,7 +21,7 @@ def test_load_tasks_from_list(): tasks = load_tasks(task_dicts) assert len(tasks) == 2 - assert all(isinstance(t, Task) for t in tasks) + assert all(isinstance(t, LegacyTask) for t in tasks) assert tasks[0].prompt == "Test task 1" # type: ignore assert tasks[1].prompt == "Test task 2" # type: ignore @@ -55,7 +55,7 @@ def test_load_tasks_from_json_file(): tasks = load_tasks(temp_path) assert len(tasks) == 2 - assert all(isinstance(t, Task) for t in tasks) + assert all(isinstance(t, LegacyTask) for t in tasks) assert tasks[0].prompt == "Test task 1" # type: ignore finally: Path(temp_path).unlink() @@ -99,7 +99,7 @@ def test_load_tasks_from_jsonl_file(): tasks = load_tasks(temp_path) assert len(tasks) == 2 - assert all(isinstance(t, Task) for t in tasks) + assert all(isinstance(t, LegacyTask) for t in tasks) assert tasks[0].prompt == "Test task 1" # type: ignore finally: Path(temp_path).unlink() @@ -124,7 +124,7 @@ def test_load_tasks_from_jsonl_file_with_empty_lines(): tasks = load_tasks(temp_path) assert len(tasks) == 2 - assert all(isinstance(t, Task) for t in tasks) + assert all(isinstance(t, LegacyTask) for t in tasks) finally: Path(temp_path).unlink() @@ -143,7 +143,7 @@ def test_load_tasks_from_jsonl_file_with_list(): tasks = load_tasks(temp_path) assert len(tasks) == 2 - assert all(isinstance(t, Task) for t in tasks) + assert all(isinstance(t, LegacyTask) for t in tasks) finally: Path(temp_path).unlink() @@ -293,21 +293,21 @@ def __str__(self): def test_save_tasks_rejects_task_objects(): - """Test save_tasks raises error for Task objects.""" - task = Task(prompt="test", mcp_config={}) + """Test save_tasks raises error for LegacyTask objects.""" + task = LegacyTask(prompt="test", mcp_config={}) - with pytest.raises(ValueError, match="expects dictionaries, not Task objects"): + with pytest.raises(ValueError, match="expects dictionaries, not LegacyTask objects"): save_tasks([task], "test/repo") # type: ignore def test_save_tasks_rejects_task_objects_in_list(): - """Test save_tasks raises error when Task object is in the list.""" + """Test save_tasks raises error when LegacyTask object is in the list.""" tasks = [ {"id": "1", "prompt": "test", "mcp_config": {}}, - Task(prompt="test2", mcp_config={}), # Task object + LegacyTask(prompt="test2", mcp_config={}), # LegacyTask object ] - with pytest.raises(ValueError, match="Item 1 is a Task object"): + with pytest.raises(ValueError, match="Item 1 is a LegacyTask object"): save_tasks(tasks, "test/repo") # type: ignore From f25f80f87a9f619976c67a47dcad82c5a5c7d046 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Fri, 12 Dec 2025 23:48:35 -0800 Subject: [PATCH 40/92] misc docs updates --- docs/guides/integrations.mdx | 6 +- docs/guides/sandboxing.mdx | 14 ++-- docs/index.mdx | 8 +- docs/migration.mdx | 134 +++++++++++++++++++++++------- docs/quick-links/deploy.mdx | 6 +- docs/quick-links/environments.mdx | 8 +- docs/reference/agents.mdx | 10 +-- docs/reference/environments.mdx | 24 +++--- docs/reference/evals.mdx | 28 +++---- docs/reference/tasks.mdx | 16 ++-- docs/reference/types.mdx | 8 +- 11 files changed, 171 insertions(+), 91 deletions(-) diff --git a/docs/guides/integrations.mdx 
b/docs/guides/integrations.mdx index b929ab0d..0d826e07 100644 --- a/docs/guides/integrations.mdx +++ b/docs/guides/integrations.mdx @@ -22,14 +22,14 @@ def lookup_ceo(company: str) -> str: """Look up the CEO of a company.""" return CEOS.get(company.lower(), "Unknown") -@env.script("initials") +@env.scenario("initials") async def find_initials(company: str): answer = yield f"What are the initials of the CEO of {company}?" ceo = CEOS.get(company.lower()) correct = "".join(word[0] for word in ceo.split()) if ceo else None yield 1.0 if answer and correct and correct in answer.upper() else 0.0 -eval = env("initials", company="HUD") +task = env("initials", company="HUD") ``` --- @@ -336,7 +336,7 @@ async with hud.eval(eval) as ctx: llm=llm ) - task = Task( + task = LegacyTask( description=ctx.prompt, expected_output="The initials of the CEO", agent=researcher diff --git a/docs/guides/sandboxing.mdx b/docs/guides/sandboxing.mdx index 5a6ca742..f3eef18a 100644 --- a/docs/guides/sandboxing.mdx +++ b/docs/guides/sandboxing.mdx @@ -38,7 +38,7 @@ HUD runs each eval in its own container—isolated, reproducible, safe. But your **Databases.** Each agent needs its own sandbox. Use in-memory SQLite (fast, resets per eval), transaction rollback, or seed fresh data at start: ```python -@env.script("update-order") +@env.scenario("update-order") async def update_order(order_id: str): await db.seed_from("fixtures/orders.sql") @@ -83,17 +83,17 @@ For grading, you also need to observe what happened. If the agent creates a data ### Deterministic Setup -Each eval should seed the state it needs. HUD handles container isolation—you handle making sure your script sets up the right data before the agent runs. +Each eval should seed the state it needs. HUD handles container isolation—you handle making sure your scenario sets up the right data before the agent runs. ```python # ❌ Bad: Depends on whatever state exists -@env.script("find-user") +@env.scenario("find-user") async def find_user(name: str): answer = yield f"Find the user named {name}" yield 1.0 if name in answer else 0.0 # ✅ Good: Seeds known state before eval -@env.script("find-user") +@env.scenario("find-user") async def find_user(name: str): await db.clear() await db.insert("users", name=name, email=f"{name}@example.com") @@ -150,13 +150,13 @@ If the prompt says "create a document with a Japanese car brand", check for any ```python # ❌ Bad: Too strict—only accepts one answer -@env.script("add-car") +@env.scenario("add-car") async def add_car(): answer = yield "Add a Japanese car brand to the document" yield 1.0 if answer == "Toyota" else 0.0 # ✅ Good: Accepts any valid answer -@env.script("add-car") +@env.scenario("add-car") async def add_car(): answer = yield "Add a Japanese car brand to the document" japanese_brands = ["toyota", "honda", "nissan", "mazda", "subaru"] @@ -168,7 +168,7 @@ async def add_car(): Partial grades help you see where agents fail. Did they add to cart but not checkout? That's useful signal. 
Break complex grading into sub-checks with weighted grades: ```python -@env.script("checkout") +@env.scenario("checkout") async def checkout(product: str): answer = yield f"Add {product} to cart and checkout" diff --git a/docs/index.mdx b/docs/index.mdx index 8d1e7121..8841bbea 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -57,7 +57,7 @@ def search(query: str) -> str: """Search the knowledge base.""" return db.search(query) -@env.script("find-answer") +@env.scenario("find-answer") async def find_answer(question: str): answer = yield f"Find the answer to: {question}" yield 1.0 if "correct" in answer.lower() else 0.0 @@ -69,14 +69,14 @@ Scripts define the prompt (first yield) and the scoring logic (second yield). Th ## 3. Evals: Test and Improve -Run your script with different models. Compare results: +Run your scenario with different models. Compare results: ```python import hud -eval = env("find-answer", question="What is 2+2?") +task = env("find-answer", question="What is 2+2?") -async with hud.eval(eval, variants={"model": ["gpt-4o", "claude-sonnet-4-5"]}, group=5) as ctx: +async with hud.eval(task, variants={"model": ["gpt-4o", "claude-sonnet-4-5"]}, group=5) as ctx: response = await client.chat.completions.create( model=ctx.variants["model"], messages=[{"role": "user", "content": ctx.prompt}] diff --git a/docs/migration.mdx b/docs/migration.mdx index b4d3b0c0..81640469 100644 --- a/docs/migration.mdx +++ b/docs/migration.mdx @@ -6,6 +6,10 @@ icon: "arrow-right-arrow-left" v4 separated environments (Docker containers) from evaluation logic (Task objects). v5 unifies everything in the `Environment` class—tools, setup, and scoring live together. + +**Deprecation Notice**: `LegacyTask`, `setup_tool`, and `evaluate_tool` are deprecated in v0.5.0 and will be removed in v0.6.0 (no earlier than March 1st, 2025). Use `Task.from_v4()` for quick migration or `@env.scenario()` for new code. + + ## Good News: Your Code Still Works `Environment` inherits from `MCPServer`. Same API, same behavior. Just change the import: @@ -34,51 +38,122 @@ env.run() That's it. Your Dockerfile, your tools, your `run()` call—all unchanged. Environment adds scripts, connectors, and integrations on top. -## Recommended: Add Scripts +## Migration Path 1: Quick Conversion with Task.from_v4() -v4 defined setup and evaluation externally in Task objects. v5 lets you define them inside the environment with `@env.script()`. This is optional but recommended—platform features like trace analysis and training work best with scripts. 
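A minimal sketch, based on the `Task.from_v4()` tests earlier in this patch series: the converter also accepts a plain dict or a JSON string, and legacy `setup_tool`/`evaluate_tool` entries are carried over onto the converted task's environment.

```python
from hud.eval import Task

# Field values mirror the placeholders used in the tests above.
task = Task.from_v4({
    "prompt": "Navigate to google.com",
    "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}},
    "setup_tool": {"name": "navigate", "arguments": {"url": "https://google.com"}},
    "evaluate_tool": {"name": "check_url", "arguments": {"expected": "google"}},
})

# The tests assert that these legacy hooks land on the task's environment
# (internal env._setup_calls / env._evaluate_calls), so they still run at the
# start and end of the evaluation.
assert task.env is not None
```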
+The fastest way to migrate existing v4 code—no changes to task definitions needed: ```python -@env.script("checkout") -async def checkout_flow(product: str): - # Setup: code before first yield - await env.call_tool("reset_cart") - - # Yield the prompt - answer = yield f"Add '{product}' to cart and checkout" - - # Evaluate: code after first yield, second yield returns reward - yield 1.0 if cart.contains(product) else 0.0 +# BEFORE (deprecated in v0.6.0) +from hud.datasets import LegacyTask + +legacy_task = LegacyTask( + prompt="Navigate to google.com", + mcp_config={"hud": {...}}, + setup_tool={"name": "navigate", "arguments": {"url": "https://google.com"}}, + evaluate_tool={"name": "check_url", "arguments": {}} +) + +# AFTER - One-line conversion +from hud.eval import Task + +task = Task.from_v4(legacy_task) # Converts LegacyTask → Task +# Also works with: Task.from_v4(dict), Task.from_v4(json_string) + +# Works the same with agents +agent = ClaudeAgent.create() +result = await agent.run(task) ``` -Your existing `setup_tool` and `evaluate_tool` definitions still work. Scripts just keep the logic with the environment instead of scattered across task files. +`Task.from_v4()` automatically: +- Runs `setup_tool` at the start of evaluation +- Runs `evaluate_tool` at the end to compute reward +- Preserves all existing behavior -## Recommended: Use env() for Evals +## Migration Path 2: Full Scenario Migration (Recommended) -v4 created Task objects: +For new code or when refactoring, migrate `setup_tool` and `evaluate_tool` to `@env.scenario()`. + +**The rule is simple:** +- `setup_tool` code → **before the first yield** +- `evaluate_tool` code → **after the first yield** ```python -task = Task(prompt="...", mcp_config={...}, setup_tool={...}, evaluate_tool={...}) +# BEFORE (deprecated in v0.6.0) +task = LegacyTask( + prompt="What's the current URL?", + mcp_config={"hud": {...}}, + setup_tool={"name": "navigate", "arguments": {"url": "https://google.com"}}, + evaluate_tool={"name": "check_url", "arguments": {"expected": "google.com"}} +) + +# AFTER +from hud import Environment + +env = Environment("browser").connect_hub("hud-evals/browser") + +@env.scenario("navigate-google") +async def navigate_google(): + # ===== SETUP SECTION (replaces setup_tool) ===== + await env.call_tool("navigate", url="https://google.com") + + # ===== PROMPT (first yield) ===== + answer = yield "What's the current URL?" + + # ===== EVALUATE SECTION (replaces evaluate_tool) ===== + result = await env.call_tool("check_url", expected="google.com") + + # ===== REWARD (second yield) ===== + yield 1.0 if result else 0.0 + +# Create task from scenario +task = env("navigate-google") ``` -v5 creates Evals by calling the environment with a script name: +### Multiple setup_tool Calls + +If you have multiple setup tools, just call them in sequence: ```python -eval = env("checkout", product="laptop") +# BEFORE +setup_tool=[ + {"name": "navigate", "arguments": {"url": "..."}}, + {"name": "login", "arguments": {"user": "..."}}, + {"name": "go_to_page", "arguments": {"page": "settings"}} +] + +# AFTER +@env.scenario("settings-test") +async def settings_test(): + # Multiple setup steps - just call them in order + await env.call_tool("navigate", url="...") + await env.call_tool("login", user="...") + await env.call_tool("go_to_page", page="settings") + + answer = yield "Verify the settings page loaded correctly" + + result = await env.call_tool("check_settings") + yield 1.0 if result else 0.0 ``` -Both work. 
But `env()` connects to scripts, which means setup/evaluate run automatically and you get structured traces. - -## Optional: Bring Your Own Agent +## Using with Built-in Agents -v4 required using HUD's agent classes: +Built-in agents (ClaudeAgent, OpenAIAgent, etc.) work with both patterns: ```python +from hud.agents import ClaudeAgent + agent = ClaudeAgent.create() -result = await agent.run(task) + +# Works with Task from scenario +result = await agent.run(env("navigate-google")) + +# Works with Task.from_v4() conversion +result = await agent.run(Task.from_v4(legacy_task)) ``` -v5 gives you the `hud.eval()` context manager. Use any agent, any model, any framework: +## Optional: Bring Your Own Agent + +v5 gives you the `hud.eval()` context manager for maximum flexibility: ```python async with hud.eval(env("checkout", product="laptop")) as ctx: @@ -99,9 +174,10 @@ The old `ClaudeAgent` and `OperatorAgent` still work—even with the new `hud.ev ## Quick Reference -| v4 | v5 | -|----|-----| +| v4 (deprecated in v0.6.0) | v5 | +|---------------------------|-----| +| `LegacyTask(...)` | `Task.from_v4(...)` (quick) or `env("scenario", ...)` (recommended) | +| `setup_tool` | Code before first yield in `@env.scenario()` | +| `evaluate_tool` | Code after first yield in `@env.scenario()` | | `MCPServer` | `Environment` (drop-in replacement) | -| `setup_tool` / `evaluate_tool` | `@env.script()` (recommended) | -| `Task(...)` | `env("script", ...)` (recommended) | -| `agent.run(task)` | `hud.eval()` + any agent (optional) | +| `agent.run(task)` | Still works, or use `hud.eval()` for BYOA | diff --git a/docs/quick-links/deploy.mdx b/docs/quick-links/deploy.mdx index ba7ec3f7..dd55b2c1 100644 --- a/docs/quick-links/deploy.mdx +++ b/docs/quick-links/deploy.mdx @@ -24,16 +24,16 @@ env.connect_hub("my-org/my-env") Once deployed, create evals on [hud.ai](https://hud.ai) from your scripts. Each eval is a frozen configuration—same prompt, same scoring, every time. -Your script might take arguments: +Your scenario might take arguments: ```python -@env.script("checkout") +@env.scenario("checkout") async def checkout_flow(product_name: str, apply_coupon: bool = False): yield f"Complete checkout for {product_name}" + (" with coupon" if apply_coupon else "") yield 1.0 if order_confirmed() else 0.0 ``` -On the platform, click **New Eval** → select your script → fill in the arguments. Create multiple evals from the same script: +On the platform, click **New Eval** → select your scenario → fill in the arguments. Create multiple evals from the same scenario: | Eval Name | Arguments | |-----------|-----------| diff --git a/docs/quick-links/environments.mdx b/docs/quick-links/environments.mdx index 4ccb49ce..9857807d 100644 --- a/docs/quick-links/environments.mdx +++ b/docs/quick-links/environments.mdx @@ -50,7 +50,7 @@ async with env() as ctx: To evaluate an agent, you need two things: what to tell it, and how to score what it did. Scripts capture both with two `yield` statements: ```python -@env.script("checkout") +@env.scenario("checkout") async def checkout_flow(product_name: str): # Yield the prompt, receive the agent's final answer answer = yield f"Add '{product_name}' to cart and complete checkout" @@ -64,12 +64,12 @@ The agent runs between the yields. 
First yield sends the prompt and returns the ## Evals -Call the environment with a script name and arguments to create an eval: +Call the environment with a scenario name and arguments to create a task: ```python -eval = env("checkout", product_name="Laptop") +task = env("checkout", product_name="Laptop") -async with hud.eval(eval, group=4) as ctx: +async with hud.eval(task, group=4) as ctx: # Connect your agent here. Handle tool calls, run agent loop... response = await client.chat.completions.create( model="gpt-4o", diff --git a/docs/reference/agents.mdx b/docs/reference/agents.mdx index fa092a50..06316c24 100644 --- a/docs/reference/agents.mdx +++ b/docs/reference/agents.mdx @@ -97,7 +97,7 @@ Claude-specific implementation using Anthropic's API. ```python from hud.agents import ClaudeAgent -from hud.datasets import Task +from hud.datasets import LegacyTask agent = ClaudeAgent.create( checkpoint_name="claude-sonnet-4-5", @@ -105,7 +105,7 @@ agent = ClaudeAgent.create( ) result = await agent.run( - Task( + LegacyTask( prompt="Navigate to example.com", mcp_config={ "hud": { @@ -245,12 +245,12 @@ agent = OpenAIChatAgent.create( ```python from hud.agents import ClaudeAgent -from hud.datasets import Task +from hud.datasets import LegacyTask agent = ClaudeAgent.create() result = await agent.run( - Task( + LegacyTask( prompt="Click the submit button", mcp_config={ "hud": { @@ -270,7 +270,7 @@ print(f"Reward: {result.reward}, Done: {result.done}") ### With Setup and Evaluation ```python -task = Task( +task = LegacyTask( prompt="Find the price of the product", mcp_config={ "hud": { diff --git a/docs/reference/environments.mdx b/docs/reference/environments.mdx index 8dfd5e81..94942849 100644 --- a/docs/reference/environments.mdx +++ b/docs/reference/environments.mdx @@ -59,7 +59,7 @@ Tools are automatically documented from type hints and docstrings. Scripts define evaluation logic with two yields: ```python -@env.script("checkout") +@env.scenario("checkout") async def checkout_flow(product: str): # First yield: send prompt, receive answer answer = yield f"Add '{product}' to cart and checkout" @@ -69,12 +69,12 @@ async def checkout_flow(product: str): yield 1.0 if order_exists else 0.0 ``` -Create Evals from scripts: +Create Tasks from scripts: ```python -eval = env("checkout", product="laptop") +task = env("checkout", product="laptop") -async with hud.eval(eval) as ctx: +async with hud.eval(task) as ctx: await agent.run(ctx.prompt) await ctx.submit(agent.response) ``` @@ -271,26 +271,26 @@ env.unmock() # Disable mock mode | Property | Type | Description | |----------|------|-------------| | `name` | `str` | Environment name | -| `prompt` | `str \| None` | Default prompt (set by connect_task) | +| `prompt` | `str \| None` | Default prompt (set by scenarios or agent code) | | `is_connected` | `bool` | True if in context | | `connections` | `dict[str, Connector]` | Active connections | -## Creating Evals +## Creating Tasks -Call the environment to create an Eval: +Call the environment to create a Task: ```python -# With script -eval = env("checkout", product="laptop") +# With scenario +task = env("checkout", product="laptop") -# Without script (just the environment) -eval = env() +# Without scenario (just the environment) +task = env() ``` Then run with `hud.eval()`: ```python -async with hud.eval(eval, variants={"model": ["gpt-4o"]}) as ctx: +async with hud.eval(task, variants={"model": ["gpt-4o"]}) as ctx: ... 
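    # Inside the block, ctx.prompt holds the scenario prompt, ctx.variants the
    # current variant assignment, and ctx.reward / ctx.submit() record the
    # outcome (see the evals reference for the full EvalContext API).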
``` diff --git a/docs/reference/evals.mdx b/docs/reference/evals.mdx index 58eb60f5..425e461e 100644 --- a/docs/reference/evals.mdx +++ b/docs/reference/evals.mdx @@ -21,7 +21,7 @@ async with hud.eval() as ctx: | Parameter | Type | Description | Default | |-----------|------|-------------|---------| -| `source` | `Eval \| list[Eval] \| str \| None` | Eval objects from `env()`, task slugs, or None | `None` | +| `source` | `Task \| list[Task] \| str \| None` | Task objects from `env()`, task slugs, or None | `None` | | `variants` | `dict[str, Any] \| None` | A/B test configuration (lists expand to combinations) | `None` | | `group` | `int` | Runs per variant for statistical significance | `1` | | `group_ids` | `list[str] \| None` | Custom group IDs for parallel runs | `None` | @@ -40,19 +40,19 @@ The `source` parameter accepts: async with hud.eval() as ctx: ctx.reward = compute_reward() -# 2. Eval from Environment (recommended) +# 2. Task from Environment (recommended) env = Environment("my-env") -eval = env("checkout", product="laptop") # Creates Eval from script -async with hud.eval(eval) as ctx: +task = env("checkout", product="laptop") # Creates Task from scenario +async with hud.eval(task) as ctx: await agent.run(ctx.prompt) # 3. Task slug (loads from platform) async with hud.eval("my-org/browser-task") as ctx: await agent.run(ctx) -# 4. Multiple evals -evals = [env("checkout", product="laptop"), env("checkout", product="phone")] -async with hud.eval(evals) as ctx: +# 4. Multiple tasks +tasks = [env("checkout", product="laptop"), env("checkout", product="phone")] +async with hud.eval(tasks) as ctx: await agent.run(ctx.prompt) ``` @@ -111,7 +111,7 @@ async with hud.eval( |----------|------|-------------| | `trace_id` | `str` | Unique trace identifier | | `eval_name` | `str` | Evaluation name | -| `prompt` | `str \| None` | Task prompt (from script or task) | +| `prompt` | `str \| None` | Task prompt (from scenario or task) | | `variants` | `dict[str, Any]` | Current variant assignment | | `reward` | `float \| None` | Evaluation reward (settable) | | `answer` | `str \| None` | Submitted answer | @@ -127,7 +127,7 @@ async with hud.eval( All `Environment` methods are available, plus: ```python -# Submit answer (passes to script for evaluation) +# Submit answer (passes to scenario for evaluation) await ctx.submit(answer) # Set reward directly @@ -167,17 +167,17 @@ env = Environment("my-env") def count_letter(text: str, letter: str) -> int: return text.lower().count(letter.lower()) -@env.script("count") -async def count_script(sentence: str, letter: str): +@env.scenario("count") +async def count_scenario(sentence: str, letter: str): answer = yield f"How many '{letter}' in '{sentence}'?" 
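    # `answer` is the agent's final answer, e.g. whatever is passed to ctx.submit().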
correct = str(sentence.lower().count(letter.lower())) yield correct in answer -# Create an Eval from the script -eval = env("count", sentence="Strawberry", letter="r") +# Create a Task from the scenario +task = env("count", sentence="Strawberry", letter="r") # Run with variants -async with hud.eval(eval, variants={"model": ["gpt-4o", "claude"]}) as ctx: +async with hud.eval(task, variants={"model": ["gpt-4o", "claude"]}) as ctx: response = await client.chat.completions.create( model=ctx.variants["model"], messages=[{"role": "user", "content": ctx.prompt}], diff --git a/docs/reference/tasks.mdx b/docs/reference/tasks.mdx index 0bd6d76a..44f93138 100644 --- a/docs/reference/tasks.mdx +++ b/docs/reference/tasks.mdx @@ -4,12 +4,16 @@ description: "SDK reference for task configuration and dataset utilities" icon: "list-check" --- -The HUD SDK provides the `Task` class for defining agent objectives and dataset utilities for managing task collections. +The HUD SDK provides the `LegacyTask` class for defining agent objectives and dataset utilities for managing task collections. -## Task Class + +`LegacyTask` is deprecated. For new code, use `env("scenario_name", **args)` to create Task objects. See [Environments](/reference/environments) for the recommended approach. + + +## LegacyTask Class ```python -from hud.datasets import Task +from hud.datasets import LegacyTask ``` Pydantic model that defines an agent's objective, setup, and evaluation criteria. @@ -31,7 +35,7 @@ Pydantic model that defines an agent's objective, setup, and evaluation criteria The `mcp_config` field automatically resolves environment variables using `${VAR_NAME}` syntax: ```python -task = Task( +task = LegacyTask( prompt="Navigate to the dashboard", mcp_config={ "browser": { @@ -45,7 +49,7 @@ task = Task( ) ``` -Variables are resolved when Task is created from a dict - this is why datasets should store raw dictionaries. +Variables are resolved when LegacyTask is created from a dict - this is why datasets should store raw dictionaries. ## Running Tasks @@ -208,7 +212,7 @@ The `agent_config` field on tasks supports: | `initial_screenshot` | `bool` | Take screenshot before first action | ```python -task = Task( +task = LegacyTask( prompt="Complete the form", mcp_config={...}, agent_config={ diff --git a/docs/reference/types.mdx b/docs/reference/types.mdx index da7ed17b..57f8cdb5 100644 --- a/docs/reference/types.mdx +++ b/docs/reference/types.mdx @@ -6,7 +6,7 @@ icon: "code" Core types used throughout the HUD SDK. -## Eval +## Task Created by calling an Environment. Holds configuration for running an evaluation. @@ -14,13 +14,13 @@ Created by calling an Environment. Holds configuration for running an evaluation from hud import Environment env = Environment("my-env") -eval = env("script_name", arg1="value") # Returns Eval +task = env("scenario_name", arg1="value") # Returns Task ``` | Field | Type | Description | |-------|------|-------------| | `env` | `Environment \| dict \| None` | Source environment | -| `script` | `str \| None` | Script name to run | +| `scenario` | `str \| None` | Scenario name to run | | `args` | `dict[str, Any]` | Script arguments | | `trace_id` | `str \| None` | Trace identifier | | `job_id` | `str \| None` | Parent job ID | @@ -33,7 +33,7 @@ eval = env("script_name", arg1="value") # Returns Eval Returned by `hud.eval()`. Extends Environment with evaluation tracking. 
```python -async with hud.eval(eval) as ctx: +async with hud.eval(task) as ctx: print(ctx.prompt) # Task prompt print(ctx.variants) # Current variant ctx.reward = 1.0 # Set reward From f10fa9bd86da3453415389934acb852db571bff0 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sat, 13 Dec 2025 00:17:54 -0800 Subject: [PATCH 41/92] add meta into analyze --- hud/clients/base.py | 40 ++++++++--- hud/clients/tests/test_analyze_scenarios.py | 75 +++++++++++++++++++++ 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/hud/clients/base.py b/hud/clients/base.py index ee5ad5a6..b7ce86a8 100644 --- a/hud/clients/base.py +++ b/hud/clients/base.py @@ -400,12 +400,16 @@ async def analyze_environment(self) -> dict[str, Any]: try: resources = await self.list_resources() for resource in resources: - resource_info = { + resource_info: dict[str, Any] = { "uri": str(resource.uri), "name": resource.name, "description": resource.description, "mime_type": getattr(resource, "mimeType", None), } + # Include meta field if present (contains scenario source code) + meta = getattr(resource, "meta", None) + if meta: + resource_info["meta"] = meta analysis["resources"].append(resource_info) except Exception as e: if self.verbose: @@ -425,13 +429,16 @@ async def analyze_environment(self) -> dict[str, Any]: for a in raw_args ] - analysis["prompts"].append( - { - "name": prompt.name, - "description": prompt.description, - "arguments": args, - } - ) + prompt_info: dict[str, Any] = { + "name": prompt.name, + "description": prompt.description, + "arguments": args, + } + # Include meta field if present (contains scenario source code) + meta = getattr(prompt, "meta", None) + if meta: + prompt_info["meta"] = meta + analysis["prompts"].append(prompt_info) except Exception as e: if self.verbose: hud_console.debug("Could not list prompts: " + str(e)) @@ -440,6 +447,7 @@ async def analyze_environment(self) -> dict[str, Any]: # A scenario is exposed as: # - Prompt: name "{env}:{scenario}" with description prefix "[Setup]" # - Resource: uri "{env}:{scenario}" with description prefix "[Evaluate]" + # Both prompt and resource contain meta.code with the scenario source code scenarios_by_id: dict[str, dict[str, Any]] = {} for p in analysis.get("prompts", []): @@ -450,7 +458,7 @@ async def analyze_environment(self) -> dict[str, Any]: if not scenario_id: continue env_name, scenario_name = ([*scenario_id.split(":", 1), ""])[:2] - scenarios_by_id[scenario_id] = { + scenario_info: dict[str, Any] = { "id": scenario_id, "env": env_name, "name": scenario_name or scenario_id, @@ -459,6 +467,11 @@ async def analyze_environment(self) -> dict[str, Any]: "has_setup_prompt": True, "has_evaluate_resource": False, } + # Extract code from meta field if present + meta = p.get("meta") + if meta and isinstance(meta, dict) and "code" in meta: + scenario_info["code"] = meta["code"] + scenarios_by_id[scenario_id] = scenario_info for r in analysis.get("resources", []): desc = (r.get("description") or "").strip() @@ -479,6 +492,15 @@ async def analyze_environment(self) -> dict[str, Any]: } scenarios_by_id[scenario_id]["evaluate_description"] = desc scenarios_by_id[scenario_id]["has_evaluate_resource"] = True + # Extract code from meta field if not already present (from prompt) + meta = r.get("meta") + if ( + meta + and isinstance(meta, dict) + and "code" in meta + and "code" not in scenarios_by_id[scenario_id] + ): + scenarios_by_id[scenario_id]["code"] = meta["code"] analysis["scenarios"] = sorted( scenarios_by_id.values(), diff --git 
a/hud/clients/tests/test_analyze_scenarios.py b/hud/clients/tests/test_analyze_scenarios.py index 46505507..5715fe82 100644 --- a/hud/clients/tests/test_analyze_scenarios.py +++ b/hud/clients/tests/test_analyze_scenarios.py @@ -120,3 +120,78 @@ async def test_analyze_environment_scenario_from_evaluate_only() -> None: assert scenario["has_setup_prompt"] is False assert scenario["has_evaluate_resource"] is True + +@pytest.mark.asyncio +async def test_analyze_environment_extracts_scenario_code_from_meta() -> None: + """Test that scenario code is extracted from the meta field.""" + scenario_code = """@env.scenario() +async def checkout(product_id: str): + await env.call_tool("navigate", url="/checkout") + yield "Complete the checkout" + result = await env.call_tool("check_order") + yield 1.0 if result else 0.0 +""" + # Use model_validate with _meta alias (Pydantic alias for the meta field) + prompts = [ + types.Prompt.model_validate({ + "name": "my-env:checkout", + "description": "[Setup] Checkout flow", + "arguments": [{"name": "product_id", "required": True}], + "_meta": {"code": scenario_code}, + }) + ] + resources = [ + types.Resource.model_validate({ + "uri": "my-env:checkout", + "name": "checkout", + "description": "[Evaluate] Checkout flow", + "_meta": {"code": scenario_code}, + }) + ] + + client = _MockClient(prompts=prompts, resources=resources) + analysis = await client.analyze_environment() + + assert len(analysis["scenarios"]) == 1 + scenario = analysis["scenarios"][0] + assert scenario["id"] == "my-env:checkout" + assert "code" in scenario + assert scenario["code"] == scenario_code + assert "async def checkout" in scenario["code"] + + +@pytest.mark.asyncio +async def test_analyze_environment_extracts_meta_on_prompts_and_resources() -> None: + """Test that meta field is included in prompts and resources analysis.""" + meta_data = {"code": "test code", "extra": "value"} + # Use model_validate with _meta alias (Pydantic alias for the meta field) + prompts = [ + types.Prompt.model_validate({ + "name": "test-prompt", + "description": "A test prompt", + "arguments": [], + "_meta": meta_data, + }) + ] + resources = [ + types.Resource.model_validate({ + "uri": "file:///test", + "name": "test-resource", + "description": "A test resource", + "_meta": meta_data, + }) + ] + + client = _MockClient(prompts=prompts, resources=resources) + analysis = await client.analyze_environment() + + # Check prompts have meta + assert len(analysis["prompts"]) == 1 + assert "meta" in analysis["prompts"][0] + assert analysis["prompts"][0]["meta"] == meta_data + + # Check resources have meta + assert len(analysis["resources"]) == 1 + assert "meta" in analysis["resources"][0] + assert analysis["resources"][0]["meta"] == meta_data + From 5f30f18bc37b98c5b317c0d8d981dd1c424a414d Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sat, 13 Dec 2025 01:36:50 -0800 Subject: [PATCH 42/92] update tests --- hud/agents/tests/test_gemini.py | 13 ++- .../tests/test_grounded_openai_agent.py | 4 +- hud/agents/tests/test_openai.py | 72 ++++++++------- hud/datasets/__init__.py | 3 +- hud/datasets/loader.py | 11 ++- hud/datasets/runner.py | 42 ++++++++- hud/datasets/tests/test_loader.py | 91 ++++++++++++------- hud/environment/tests/test_connectors.py | 3 +- hud/tools/tests/test_jupyter_tool.py | 5 + 9 files changed, 164 insertions(+), 80 deletions(-) diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py index 242ce725..a4a91cbf 100644 --- a/hud/agents/tests/test_gemini.py +++ 
b/hud/agents/tests/test_gemini.py @@ -174,7 +174,8 @@ async def test_convert_tools_for_gemini(self, mock_gemini_client: genai.Client) # Check that tools were converted assert len(agent.gemini_tools) == 1 - assert agent.gemini_tools[0]["name"] == "my_tool" + # Gemini tools have function_declarations + assert agent.gemini_tools[0].function_declarations[0].name == "my_tool" class TestGeminiToolConversion: @@ -215,17 +216,19 @@ async def test_tool_with_properties(self, mock_gemini_client: genai.Client) -> N assert len(agent.gemini_tools) == 1 tool = agent.gemini_tools[0] - assert tool["name"] == "search" - assert "parameters" in tool + # Gemini tools have function_declarations + assert tool.function_declarations[0].name == "search" + assert tool.function_declarations[0].parameters_json_schema is not None @pytest.mark.asyncio async def test_tool_without_schema(self, mock_gemini_client: genai.Client) -> None: - """Test tool without input schema raises error.""" + """Test tool without description raises error.""" + # Create a tool with inputSchema but no description tools = [ types.Tool( name="incomplete", description=None, - inputSchema=None, + inputSchema={"type": "object"}, ) ] ctx = MockEvalContext(tools=tools) diff --git a/hud/agents/tests/test_grounded_openai_agent.py b/hud/agents/tests/test_grounded_openai_agent.py index 34d1ef00..dfe2d806 100644 --- a/hud/agents/tests/test_grounded_openai_agent.py +++ b/hud/agents/tests/test_grounded_openai_agent.py @@ -70,20 +70,20 @@ def get_openai_tool_schema(self) -> dict: @pytest.mark.asyncio async def test_call_tools_injects_screenshot_and_delegates(monkeypatch: pytest.MonkeyPatch) -> None: - # Agent with fake OpenAI client and fake MCP client + # Agent with fake OpenAI client grounder_cfg = GrounderConfig(api_base="http://example", model="qwen") fake_openai = AsyncOpenAI(api_key="test") agent = GroundedOpenAIChatAgent.create( grounder_config=grounder_cfg, openai_client=fake_openai, checkpoint_name="gpt-4o-mini", - mcp_client=FakeMCPClient(), initial_screenshot=False, ) # Inject a dummy grounded tool to observe args without full initialization dummy_tool = DummyGroundedTool() agent.grounded_tool = dummy_tool # type: ignore + agent._initialized = True # Mark as initialized to skip context initialization # Seed conversation history with a user image png_b64 = ( diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py index bc58eb58..75fe2d8f 100644 --- a/hud/agents/tests/test_openai.py +++ b/hud/agents/tests/test_openai.py @@ -150,7 +150,10 @@ async def test_format_blocks_empty(self, mock_openai: AsyncOpenAI) -> None: messages = await agent.format_blocks([]) assert len(messages) == 1 - assert messages[0]["content"] == [] + # Empty blocks produce a single empty text item + assert len(messages[0]["content"]) == 1 + assert messages[0]["content"][0]["type"] == "input_text" + assert messages[0]["content"][0]["text"] == "" @pytest.mark.asyncio async def test_format_tool_results_text(self, mock_openai: AsyncOpenAI) -> None: @@ -172,7 +175,9 @@ async def test_format_tool_results_text(self, mock_openai: AsyncOpenAI) -> None: assert len(messages) == 1 assert messages[0]["type"] == "function_call_output" assert messages[0]["call_id"] == "call_123" - assert messages[0]["output"] == "Tool output" + # Output is a list of content items + assert len(messages[0]["output"]) == 1 + assert messages[0]["output"][0]["text"] == "Tool output" @pytest.mark.asyncio async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> None: @@ 
-192,21 +197,23 @@ async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> messages = await agent.format_tool_results(tool_calls, tool_results) assert len(messages) == 1 - assert "Error message" in messages[0]["output"] + # Output is a list; first item is error indicator, second is the message + output = messages[0]["output"] + assert any(item.get("text") == "[tool_error] true" for item in output) + assert any(item.get("text") == "Error message" for item in output) @pytest.mark.asyncio async def test_get_system_messages(self, mock_openai: AsyncOpenAI) -> None: - """Test getting system messages.""" + """Test getting system messages - OpenAI uses instructions field instead.""" agent = OpenAIAgent.create( model_client=mock_openai, system_prompt="You are a helpful assistant.", validate_api_key=False, ) + # OpenAI agent returns empty list - system prompt is passed via instructions messages = await agent.get_system_messages() - assert len(messages) == 1 - assert messages[0]["type"] == "message" - assert messages[0]["role"] == "developer" + assert len(messages) == 0 @pytest.mark.asyncio async def test_convert_tools_for_openai(self, mock_openai: AsyncOpenAI) -> None: @@ -229,9 +236,9 @@ async def test_convert_tools_for_openai(self, mock_openai: AsyncOpenAI) -> None: await agent._initialize_from_ctx(ctx) # Check that tools were converted - assert len(agent.openai_tools) >= 1 + assert len(agent._openai_tools) >= 1 # Find our tool - tool = next((t for t in agent.openai_tools if t.get("name") == "my_tool"), None) + tool = next((t for t in agent._openai_tools if t.get("name") == "my_tool"), None) assert tool is not None assert tool["type"] == "function" @@ -266,7 +273,7 @@ async def test_get_response_with_text(self, mock_openai: AsyncOpenAI) -> None: type="message", role="assistant", status="completed", - content=[ResponseOutputText(type="output_text", text="Hello!")], + content=[ResponseOutputText(type="output_text", text="Hello!", annotations=[])], ) ] mock_openai.responses.create = AsyncMock(return_value=mock_response) @@ -276,7 +283,7 @@ async def test_get_response_with_text(self, mock_openai: AsyncOpenAI) -> None: validate_api_key=False, ) # Set empty tools to avoid needing initialization - agent.openai_tools = [] + agent._openai_tools = [] agent._initialized = True response = await agent.get_response([]) @@ -288,21 +295,14 @@ async def test_get_response_with_text(self, mock_openai: AsyncOpenAI) -> None: async def test_get_response_with_tool_call(self, mock_openai: AsyncOpenAI) -> None: """Test getting response with tool call.""" mock_response = AsyncMock() + # Tool calls come as separate output items, not inside message content mock_response.output = [ - ResponseOutputMessage( - id="msg_123", - type="message", - role="assistant", - status="completed", - content=[ - ResponseFunctionToolCall( - id="call_123", - type="function_call", - call_id="call_123", - name="my_tool", - arguments='{"x": "value"}', - ) - ], + ResponseFunctionToolCall( + id="call_123", + type="function_call", + call_id="call_123", + name="my_tool", + arguments='{"x": "value"}', ) ] mock_openai.responses.create = AsyncMock(return_value=mock_response) @@ -311,8 +311,8 @@ async def test_get_response_with_tool_call(self, mock_openai: AsyncOpenAI) -> No model_client=mock_openai, validate_api_key=False, ) - agent.openai_tools = [] - agent.tool_mapping = {"my_tool": "my_tool"} + agent._openai_tools = [] + agent._tool_name_map = {"my_tool": "my_tool"} agent._initialized = True response = await 
agent.get_response([]) @@ -336,7 +336,7 @@ async def test_get_response_with_reasoning(self, mock_openai: AsyncOpenAI) -> No type="message", role="assistant", status="completed", - content=[ResponseOutputText(type="output_text", text="Answer!")], + content=[ResponseOutputText(type="output_text", text="Answer!", annotations=[])], ), ] mock_openai.responses.create = AsyncMock(return_value=mock_response) @@ -345,12 +345,13 @@ async def test_get_response_with_reasoning(self, mock_openai: AsyncOpenAI) -> No model_client=mock_openai, validate_api_key=False, ) - agent.openai_tools = [] + agent._openai_tools = [] agent._initialized = True response = await agent.get_response([]) - assert "Thinking about it..." in (response.reasoning or "") - assert response.content == "Answer!" + # Reasoning is prepended to content in OpenAI agent + assert "Thinking about it..." in response.content + assert "Answer!" in response.content class TestOpenAIToolConversion: @@ -385,12 +386,12 @@ async def test_shell_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: await agent._initialize_from_ctx(ctx) # Check for native shell tool - shell_tool = next((t for t in agent.openai_tools if t.get("type") == "shell"), None) + shell_tool = next((t for t in agent._openai_tools if t.get("type") == "shell"), None) assert shell_tool is not None @pytest.mark.asyncio async def test_computer_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: - """Test that computer tool is converted to native format.""" + """Test that computer tool is converted to function format.""" tools = [ types.Tool( name="computer", @@ -407,9 +408,10 @@ async def test_computer_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: agent.ctx = ctx await agent._initialize_from_ctx(ctx) - # Check for native computer tool + # Computer tool is converted to a regular function tool computer_tool = next( - (t for t in agent.openai_tools if t.get("type") == "computer_use_preview"), + (t for t in agent._openai_tools if t.get("name") == "computer"), None, ) assert computer_tool is not None + assert computer_tool.get("type") == "function" diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py index e67ac560..3e9f110f 100644 --- a/hud/datasets/__init__.py +++ b/hud/datasets/__init__.py @@ -11,7 +11,7 @@ from hud.utils.tasks import save_tasks from .loader import load_dataset -from .runner import run_dataset +from .runner import run_dataset, run_tasks from .utils import ( BatchRequest, SingleTaskRequest, @@ -28,6 +28,7 @@ "display_results", "load_dataset", "run_dataset", + "run_tasks", "save_tasks", "submit_rollouts", ] diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index 984c3437..c4f22eb3 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -114,11 +114,14 @@ def _load_from_api(dataset_name: str) -> list[Task]: response.raise_for_status() data = response.json() + # Extract tasks dict from response + tasks_dict = data.get("tasks", {}) + tasks: list[Task] = [] - if isinstance(data, list): - tasks = [_task_from_dict(item) for item in data] - else: - tasks = [_task_from_dict(data)] + for task_id, task_data in tasks_dict.items(): + if task_data.get("id") is None: + task_data["id"] = task_id + tasks.append(_task_from_dict(task_data)) return tasks diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index c1e291f9..fd8492d7 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -6,9 +6,10 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, 
Any import hud +from hud.types import AgentType if TYPE_CHECKING: from hud.agents import MCPAgent @@ -18,6 +19,45 @@ logger = logging.getLogger("hud.datasets") +async def run_tasks( + tasks: list[Task], + *, + agent_type: str, + agent_params: dict[str, Any] | None = None, + max_steps: int = 10, + max_concurrent: int = 30, + group_size: int = 1, +) -> list[EvalContext]: + """Run tasks with an agent created from type and parameters. + + This is a convenience wrapper around run_dataset that creates the agent + from a type string and parameters dictionary. + + Args: + tasks: List of Task objects to run. + agent_type: Type of agent to create (e.g., "claude", "openai", "gemini"). + agent_params: Parameters to pass to agent.create(). + max_steps: Maximum steps per task. + max_concurrent: Maximum concurrent tasks. + group_size: Number of times to run each task. + + Returns: + List of EvalContext results from each task execution. + """ + # Use AgentType enum to get the agent class (same pattern as CLI) + agent_type_enum = AgentType(agent_type) + agent_cls = agent_type_enum.cls + agent = agent_cls.create(**(agent_params or {})) + + return await run_dataset( + tasks, + agent, + max_steps=max_steps, + max_concurrent=max_concurrent, + group_size=group_size, + ) + + async def run_dataset( tasks: str | list[Task], agent: MCPAgent, diff --git a/hud/datasets/tests/test_loader.py b/hud/datasets/tests/test_loader.py index 9bc617ff..34b333b5 100644 --- a/hud/datasets/tests/test_loader.py +++ b/hud/datasets/tests/test_loader.py @@ -12,8 +12,8 @@ class TestLoadDataset: """Tests for load_dataset() function.""" - @patch("hud.datasets.loader.httpx.Client") - @patch("hud.datasets.loader.settings") + @patch("httpx.Client") + @patch("hud.settings.settings") def test_load_dataset_success( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: @@ -22,10 +22,23 @@ def test_load_dataset_success( mock_settings.api_key = "test_key" mock_response = MagicMock() - mock_response.json.return_value = [ - {"env": {"name": "test"}, "scenario": "checkout", "args": {"user": "alice"}}, - {"env": {"name": "test"}, "scenario": "login", "args": {"user": "bob"}}, - ] + # New EvalsetTasksResponse format: tasks keyed by task ID + mock_response.json.return_value = { + "evalset_id": "evalset-123", + "evalset_name": "test-dataset", + "tasks": { + "task-1": { + "env": {"name": "test"}, + "scenario": "checkout", + "args": {"user": "alice"}, + }, + "task-2": { + "env": {"name": "test"}, + "scenario": "login", + "args": {"user": "bob"}, + }, + }, + } mock_response.raise_for_status = MagicMock() mock_client = MagicMock() @@ -37,29 +50,38 @@ def test_load_dataset_success( tasks = load_dataset("test-org/test-dataset") assert len(tasks) == 2 - assert tasks[0].scenario == "checkout" - assert tasks[0].args == {"user": "alice"} - assert tasks[1].scenario == "login" + # Tasks are keyed by ID in dict, order may vary + scenarios = {t.scenario for t in tasks} + assert scenarios == {"checkout", "login"} + # Check task IDs are set from dict keys + task_ids = {t.id for t in tasks} + assert task_ids == {"task-1", "task-2"} mock_client.get.assert_called_once_with( "https://api.hud.ai/evals/test-org/test-dataset", headers={"Authorization": "Bearer test_key"}, params={"all": "true"}, ) - @patch("hud.datasets.loader.httpx.Client") - @patch("hud.datasets.loader.settings") + @patch("httpx.Client") + @patch("hud.settings.settings") def test_load_dataset_single_task( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - 
"""load_dataset() handles single task (non-list) response.""" + """load_dataset() handles single task in EvalsetTasksResponse.""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = "test_key" mock_response = MagicMock() mock_response.json.return_value = { - "env": {"name": "test"}, - "scenario": "checkout", - "args": {"user": "alice"}, + "evalset_id": "evalset-123", + "evalset_name": "test-dataset", + "tasks": { + "task-1": { + "env": {"name": "test"}, + "scenario": "checkout", + "args": {"user": "alice"}, + }, + }, } mock_response.raise_for_status = MagicMock() @@ -73,9 +95,10 @@ def test_load_dataset_single_task( assert len(tasks) == 1 assert tasks[0].scenario == "checkout" + assert tasks[0].id == "task-1" - @patch("hud.datasets.loader.httpx.Client") - @patch("hud.datasets.loader.settings") + @patch("httpx.Client") + @patch("hud.settings.settings") def test_load_dataset_no_api_key( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: @@ -84,7 +107,11 @@ def test_load_dataset_no_api_key( mock_settings.api_key = None mock_response = MagicMock() - mock_response.json.return_value = [] + mock_response.json.return_value = { + "evalset_id": "evalset-123", + "evalset_name": "test-dataset", + "tasks": {}, + } mock_response.raise_for_status = MagicMock() mock_client = MagicMock() @@ -95,14 +122,15 @@ def test_load_dataset_no_api_key( tasks = load_dataset("test-org/test-dataset") + assert len(tasks) == 0 mock_client.get.assert_called_once_with( "https://api.hud.ai/evals/test-org/test-dataset", headers={}, params={"all": "true"}, ) - @patch("hud.datasets.loader.httpx.Client") - @patch("hud.datasets.loader.settings") + @patch("httpx.Client") + @patch("hud.settings.settings") def test_load_dataset_http_error( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: @@ -121,8 +149,8 @@ def test_load_dataset_http_error( with pytest.raises(ValueError, match="Failed to load dataset"): load_dataset("test-org/test-dataset") - @patch("hud.datasets.loader.httpx.Client") - @patch("hud.datasets.loader.settings") + @patch("httpx.Client") + @patch("hud.settings.settings") def test_load_dataset_json_error( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: @@ -140,11 +168,11 @@ def test_load_dataset_json_error( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - with pytest.raises(ValueError, match="Error processing dataset"): + with pytest.raises(ValueError, match="Failed to load dataset"): load_dataset("test-org/test-dataset") - @patch("hud.datasets.loader.httpx.Client") - @patch("hud.datasets.loader.settings") + @patch("httpx.Client") + @patch("hud.settings.settings") def test_load_dataset_empty( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: @@ -153,7 +181,7 @@ def test_load_dataset_empty( mock_settings.api_key = "test_key" mock_response = MagicMock() - mock_response.json.return_value = [] + mock_response.json.return_value = {"tasks": {}} mock_response.raise_for_status = MagicMock() mock_client = MagicMock() @@ -166,8 +194,8 @@ def test_load_dataset_empty( assert len(tasks) == 0 - @patch("hud.datasets.loader.httpx.Client") - @patch("hud.datasets.loader.settings") + @patch("httpx.Client") + @patch("hud.settings.settings") def test_load_dataset_missing_fields( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: @@ -176,9 +204,9 @@ def test_load_dataset_missing_fields( mock_settings.api_key = "test_key" mock_response = MagicMock() - 
mock_response.json.return_value = [ - {"scenario": "test"}, # Missing env and args - ] + mock_response.json.return_value = { + "tasks": {"task-1": {"scenario": "test"}}, + } mock_response.raise_for_status = MagicMock() mock_client = MagicMock() @@ -191,6 +219,7 @@ def test_load_dataset_missing_fields( assert len(tasks) == 1 assert tasks[0].scenario == "test" + assert tasks[0].id == "task-1" assert tasks[0].env is None assert tasks[0].args == {} diff --git a/hud/environment/tests/test_connectors.py b/hud/environment/tests/test_connectors.py index f6047a23..8a05f281 100644 --- a/hud/environment/tests/test_connectors.py +++ b/hud/environment/tests/test_connectors.py @@ -214,6 +214,7 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: env.connect_hub("hud/browser") - assert "browser" in env._connections + # connect_hub creates a connection named "hud" (the server name) + assert "hud" in env._connections diff --git a/hud/tools/tests/test_jupyter_tool.py b/hud/tools/tests/test_jupyter_tool.py index 3dec0025..932631a3 100644 --- a/hud/tools/tests/test_jupyter_tool.py +++ b/hud/tools/tests/test_jupyter_tool.py @@ -7,6 +7,11 @@ import pytest from mcp.types import TextContent +# Import tornado modules before tests to avoid forward reference issues with mocking +import tornado.httpclient # noqa: F401 +import tornado.ioloop # noqa: F401 +import tornado.websocket # noqa: F401 + from hud.tools.jupyter import JupyterTool, strip_ansi From f7b3c6c8a5965d44c08e20d86f1ed9d8ebc23025 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sat, 13 Dec 2025 01:44:02 -0800 Subject: [PATCH 43/92] fix types --- examples/run_evaluation.py | 2 +- hud/agents/claude.py | 11 +--- hud/agents/gemini.py | 9 +-- hud/agents/grounded_openai.py | 4 +- hud/agents/tests/test_base_runtime.py | 1 + hud/agents/tests/test_claude.py | 9 +-- hud/agents/tests/test_gemini.py | 1 - hud/agents/tests/test_openai.py | 1 - hud/agents/tests/test_run_eval.py | 4 +- hud/cli/eval.py | 5 +- hud/clients/mcp_use.py | 4 +- hud/clients/tests/test_analyze_scenarios.py | 67 ++++++++++++--------- hud/datasets/__init__.py | 2 +- hud/datasets/loader.py | 2 +- hud/datasets/tests/test_loader.py | 1 - hud/datasets/utils.py | 18 ++++-- hud/environment/environment.py | 1 - hud/environment/integrations/adk.py | 2 +- hud/environment/tests/test_connectors.py | 2 - hud/environment/tests/test_environment.py | 3 - hud/environment/types.py | 4 +- hud/environment/utils/tool_wrappers.py | 8 ++- hud/eval/__init__.py | 8 +-- hud/eval/context.py | 2 - hud/eval/manager.py | 3 +- hud/eval/task.py | 30 +++++---- hud/eval/tests/test_eval.py | 48 ++++++++------- hud/patches/warnings.py | 2 - hud/server/server.py | 7 ++- hud/tests/test_datasets_extended.py | 25 +++----- hud/tools/tests/test_jupyter_tool.py | 6 +- hud/types.py | 4 ++ hud/utils/tasks.py | 4 +- 33 files changed, 148 insertions(+), 152 deletions(-) diff --git a/examples/run_evaluation.py b/examples/run_evaluation.py index 1171f4f1..d996f9e7 100644 --- a/examples/run_evaluation.py +++ b/examples/run_evaluation.py @@ -65,7 +65,7 @@ async def main() -> None: ) # Display results - print(f"\n{'='*50}") + print(f"\n{'=' * 50}") print(f"Completed {len(results)} tasks") for i, ctx in enumerate(results): reward = ctx.reward if hasattr(ctx, "reward") else "N/A" diff --git a/hud/agents/claude.py b/hud/agents/claude.py index d8124246..3c5e2a43 100644 --- a/hud/agents/claude.py +++ b/hud/agents/claude.py @@ -5,8 +5,9 @@ import copy import logging from inspect import cleandoc -from typing import TYPE_CHECKING, Any, 
ClassVar, Literal, cast +from typing import Any, ClassVar, Literal, cast +import mcp.types as types from anthropic import Anthropic, AsyncAnthropic, Omit from anthropic.types import CacheControlEphemeralParam from anthropic.types.beta import ( @@ -22,15 +23,9 @@ BetaToolTextEditor20250728Param, BetaToolUnionParam, ) - -import hud - -if TYPE_CHECKING: - from hud.datasets import LegacyTask - -import mcp.types as types from pydantic import ConfigDict +import hud from hud.settings import settings from hud.tools.computer.settings import computer_settings from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py index 91942e49..f6d42739 100644 --- a/hud/agents/gemini.py +++ b/hud/agents/gemini.py @@ -3,19 +3,14 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import Any, ClassVar, cast +import mcp.types as types from google import genai from google.genai import types as genai_types from pydantic import ConfigDict import hud - -if TYPE_CHECKING: - from hud.datasets import LegacyTask - -import mcp.types as types - from hud.settings import settings from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult from hud.utils.hud_console import HUDConsole diff --git a/hud/agents/grounded_openai.py b/hud/agents/grounded_openai.py index e427bcb6..e05c6c11 100644 --- a/hud/agents/grounded_openai.py +++ b/hud/agents/grounded_openai.py @@ -140,9 +140,7 @@ async def get_response(self, messages: Any) -> AgentResponse: if not has_image: if self.ctx is None: raise ValueError("ctx is not initialized") - screenshot_result = await self.ctx.call_tool( - ("computer", {"action": "screenshot"}) - ) + screenshot_result = await self.ctx.call_tool(("computer", {"action": "screenshot"})) for block in screenshot_result.content: # Check for ImageContent type from MCP diff --git a/hud/agents/tests/test_base_runtime.py b/hud/agents/tests/test_base_runtime.py index 83502e31..f066c8f7 100644 --- a/hud/agents/tests/test_base_runtime.py +++ b/hud/agents/tests/test_base_runtime.py @@ -174,6 +174,7 @@ def handler(tool_call: MCPToolCall) -> MCPToolResult: @pytest.mark.asyncio async def test_call_tools_timeout_raises() -> None: """Test call_tools raises TimeoutError.""" + def handler(tool_call: MCPToolCall) -> MCPToolResult: raise TimeoutError("timeout") diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py index 0ac8c87c..eca41461 100644 --- a/hud/agents/tests/test_claude.py +++ b/hud/agents/tests/test_claude.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from anthropic import AsyncAnthropic @@ -109,9 +109,10 @@ class TestClaudeAgent: @pytest.fixture def mock_anthropic(self) -> AsyncAnthropic: """Create a stub Anthropic client.""" - with patch("hud.agents.claude.AsyncAnthropic") as mock_class, patch( - "hud.agents.claude.Anthropic" - ) as mock_sync: + with ( + patch("hud.agents.claude.AsyncAnthropic") as mock_class, + patch("hud.agents.claude.Anthropic") as mock_sync, + ): # Mock the sync client's models.list() for validation mock_sync.return_value.models.list.return_value = [] diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py index a4a91cbf..d4e7cc7b 100644 --- a/hud/agents/tests/test_gemini.py +++ b/hud/agents/tests/test_gemini.py @@ -8,7 +8,6 @@ import 
pytest from google import genai -from google.genai import types as genai_types from mcp import types from hud.agents.gemini import GeminiAgent diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py index 75fe2d8f..eff539b3 100644 --- a/hud/agents/tests/test_openai.py +++ b/hud/agents/tests/test_openai.py @@ -15,7 +15,6 @@ ResponseReasoningItem, ) from openai.types.responses.response_reasoning_item import Summary -from pydantic import AnyUrl from hud.agents.openai import OpenAIAgent from hud.eval.context import EvalContext diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py index 1f9a7fc1..46eea596 100644 --- a/hud/agents/tests/test_run_eval.py +++ b/hud/agents/tests/test_run_eval.py @@ -57,9 +57,7 @@ class MockEvalContext(EvalContext): def __init__(self, prompt: str = "Test prompt", tools: list[types.Tool] | None = None) -> None: # Skip parent __init__, just set what we need self.prompt = prompt - self._tools = tools or [ - types.Tool(name="test_tool", description="Test", inputSchema={}) - ] + self._tools = tools or [types.Tool(name="test_tool", description="Test", inputSchema={})] self._submitted: str | None = None self.reward: float | None = None self._initialized = True diff --git a/hud/cli/eval.py b/hud/cli/eval.py index 2685282e..0d984879 100644 --- a/hud/cli/eval.py +++ b/hud/cli/eval.py @@ -509,10 +509,7 @@ async def _run_evaluation(cfg: EvalConfig) -> tuple[list[Any], list[Any]]: if cfg.task_ids: id_set = set(cfg.task_ids) # Match by task.id or index - filtered = [ - t for i, t in enumerate(tasks) - if t.id in id_set or str(i) in id_set - ] + filtered = [t for i, t in enumerate(tasks) if t.id in id_set or str(i) in id_set] if not filtered: hud_console.error(f"No tasks found matching IDs: {', '.join(cfg.task_ids)}") raise typer.Exit(1) diff --git a/hud/clients/mcp_use.py b/hud/clients/mcp_use.py index 0926328c..91b53ef3 100644 --- a/hud/clients/mcp_use.py +++ b/hud/clients/mcp_use.py @@ -284,9 +284,7 @@ async def _list_prompts_impl(self) -> list[types.Prompt]: all_prompts.extend(prompts_result.prompts) except Exception as e: if self.verbose: - hud_console.debug( - f"Could not list prompts from server '{server_name}': {e}" - ) + hud_console.debug(f"Could not list prompts from server '{server_name}': {e}") continue return all_prompts diff --git a/hud/clients/tests/test_analyze_scenarios.py b/hud/clients/tests/test_analyze_scenarios.py index 5715fe82..e19535b8 100644 --- a/hud/clients/tests/test_analyze_scenarios.py +++ b/hud/clients/tests/test_analyze_scenarios.py @@ -2,14 +2,16 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import pytest from mcp import types from pydantic import AnyUrl from hud.clients.base import BaseHUDClient -from hud.types import MCPToolCall, MCPToolResult + +if TYPE_CHECKING: + from hud.types import MCPToolCall, MCPToolResult class _MockClient(BaseHUDClient): @@ -21,7 +23,9 @@ def __init__( prompts: list[types.Prompt], resources: list[types.Resource], ) -> None: - super().__init__(mcp_config={"test": {"url": "mock://test"}}, verbose=True, auto_trace=False) + super().__init__( + mcp_config={"test": {"url": "mock://test"}}, verbose=True, auto_trace=False + ) self._mock_prompts = prompts self._mock_resources = resources # Skip initialize() (which fetches telemetry); we just need analyze_environment(). 
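A minimal sketch of the `_meta` round-trip these tests exercise (the checkout code string is a stand-in): validating a `Prompt` with a `_meta` key surfaces it as `.meta`, which is what `analyze_environment()` reads via `getattr(prompt, "meta", None)` to recover scenario source.

```python
from mcp import types

# "_meta" is the Pydantic alias for the meta field on MCP models.
prompt = types.Prompt.model_validate(
    {
        "name": "my-env:checkout",
        "description": "[Setup] Checkout flow",
        "arguments": [{"name": "product_id", "required": True}],
        "_meta": {"code": "async def checkout(product_id: str): ..."},
    }
)
assert prompt.meta == {"code": "async def checkout(product_id: str): ..."}
```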
@@ -133,20 +137,24 @@ async def checkout(product_id: str): """ # Use model_validate with _meta alias (Pydantic alias for the meta field) prompts = [ - types.Prompt.model_validate({ - "name": "my-env:checkout", - "description": "[Setup] Checkout flow", - "arguments": [{"name": "product_id", "required": True}], - "_meta": {"code": scenario_code}, - }) + types.Prompt.model_validate( + { + "name": "my-env:checkout", + "description": "[Setup] Checkout flow", + "arguments": [{"name": "product_id", "required": True}], + "_meta": {"code": scenario_code}, + } + ) ] resources = [ - types.Resource.model_validate({ - "uri": "my-env:checkout", - "name": "checkout", - "description": "[Evaluate] Checkout flow", - "_meta": {"code": scenario_code}, - }) + types.Resource.model_validate( + { + "uri": "my-env:checkout", + "name": "checkout", + "description": "[Evaluate] Checkout flow", + "_meta": {"code": scenario_code}, + } + ) ] client = _MockClient(prompts=prompts, resources=resources) @@ -166,20 +174,24 @@ async def test_analyze_environment_extracts_meta_on_prompts_and_resources() -> N meta_data = {"code": "test code", "extra": "value"} # Use model_validate with _meta alias (Pydantic alias for the meta field) prompts = [ - types.Prompt.model_validate({ - "name": "test-prompt", - "description": "A test prompt", - "arguments": [], - "_meta": meta_data, - }) + types.Prompt.model_validate( + { + "name": "test-prompt", + "description": "A test prompt", + "arguments": [], + "_meta": meta_data, + } + ) ] resources = [ - types.Resource.model_validate({ - "uri": "file:///test", - "name": "test-resource", - "description": "A test resource", - "_meta": meta_data, - }) + types.Resource.model_validate( + { + "uri": "file:///test", + "name": "test-resource", + "description": "A test resource", + "_meta": meta_data, + } + ) ] client = _MockClient(prompts=prompts, resources=resources) @@ -194,4 +206,3 @@ async def test_analyze_environment_extracts_meta_on_prompts_and_resources() -> N assert len(analysis["resources"]) == 1 assert "meta" in analysis["resources"][0] assert analysis["resources"][0]["meta"] == meta_data - diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py index 3e9f110f..eb8d040e 100644 --- a/hud/datasets/__init__.py +++ b/hud/datasets/__init__.py @@ -22,8 +22,8 @@ __all__ = [ "BatchRequest", - "SingleTaskRequest", "LegacyTask", + "SingleTaskRequest", "calculate_group_stats", "display_results", "load_dataset", diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index c4f22eb3..d5823bca 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -49,7 +49,7 @@ def _task_from_dict(item: dict[str, Any]) -> Task: validation = None if item.get("validation"): validation = [MCPToolCall(**v) for v in item["validation"]] - + return Task( id=item.get("id"), env=item.get("env"), # EnvConfig dict: {"name": "browser", "include": [...], ...} diff --git a/hud/datasets/tests/test_loader.py b/hud/datasets/tests/test_loader.py index 34b333b5..b68a1b1c 100644 --- a/hud/datasets/tests/test_loader.py +++ b/hud/datasets/tests/test_loader.py @@ -222,4 +222,3 @@ def test_load_dataset_missing_fields( assert tasks[0].id == "task-1" assert tasks[0].env is None assert tasks[0].args == {} - diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py index 84e4a604..1ac829dd 100644 --- a/hud/datasets/utils.py +++ b/hud/datasets/utils.py @@ -381,9 +381,12 @@ def display_results( for i, (stat, task) in enumerate(zip(results, tasks, strict=False)): task_id = (task.id or "")[:20] # Handle both v4 (prompt attr) 
and v5 (prompt in args) tasks - raw_prompt = getattr(task, "prompt", None) or ( - task.args.get("prompt") if hasattr(task, "args") else None - ) or task.scenario or "" + raw_prompt = ( + getattr(task, "prompt", None) + or (task.args.get("prompt") if hasattr(task, "args") else None) + or task.scenario + or "" + ) prompt = raw_prompt[:40] if len(raw_prompt) > 40: prompt += "..." @@ -433,9 +436,12 @@ def display_results( task = tasks[i] task_id = (task.id or "")[:20] # Handle both v4 (prompt attr) and v5 (prompt in args) tasks - raw_prompt = getattr(task, "prompt", None) or ( - task.args.get("prompt") if hasattr(task, "args") else None - ) or getattr(task, "scenario", None) or "" + raw_prompt = ( + getattr(task, "prompt", None) + or (task.args.get("prompt") if hasattr(task, "args") else None) + or getattr(task, "scenario", None) + or "" + ) prompt = raw_prompt[:40] if len(raw_prompt) > 40: prompt += "..." diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 0a9d3867..feeaf025 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -587,4 +587,3 @@ async def checkout(user_id: str): _trace=_trace, _quiet=_quiet, ) - diff --git a/hud/environment/integrations/adk.py b/hud/environment/integrations/adk.py index 2d33887f..93d0cf42 100644 --- a/hud/environment/integrations/adk.py +++ b/hud/environment/integrations/adk.py @@ -45,7 +45,7 @@ def as_adk_tools(self) -> list[Any]: name="assistant", model="gemini-2.0-flash", instruction="You are a helpful assistant.", - tools=env.as_adk_tools() + tools=env.as_adk_tools(), ) runner = Runner(agent=agent) result = await runner.run("Find information about Python") diff --git a/hud/environment/tests/test_connectors.py b/hud/environment/tests/test_connectors.py index 8a05f281..1a42e666 100644 --- a/hud/environment/tests/test_connectors.py +++ b/hud/environment/tests/test_connectors.py @@ -216,5 +216,3 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: # connect_hub creates a connection named "hud" (the server name) assert "hud" in env._connections - - diff --git a/hud/environment/tests/test_environment.py b/hud/environment/tests/test_environment.py index 39f85d9e..44febe88 100644 --- a/hud/environment/tests/test_environment.py +++ b/hud/environment/tests/test_environment.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Any - import pytest @@ -26,7 +24,6 @@ def test_prompt_can_be_set(self) -> None: assert env.prompt == "Navigate to google.com" - class TestEnvironmentContextManager: """Tests for Environment async context manager.""" diff --git a/hud/environment/types.py b/hud/environment/types.py index dfa76abd..fca74c7c 100644 --- a/hud/environment/types.py +++ b/hud/environment/types.py @@ -9,9 +9,9 @@ class EnvConfig(BaseModel): """Environment configuration for Tasks. - + Specifies which hub to connect to and optional tool filtering. 
- + Attributes: name: Hub name to connect via connect_hub() (e.g., "browser", "sheets") include: Optional whitelist of tool names to include diff --git a/hud/environment/utils/tool_wrappers.py b/hud/environment/utils/tool_wrappers.py index 876c5632..d1089242 100644 --- a/hud/environment/utils/tool_wrappers.py +++ b/hud/environment/utils/tool_wrappers.py @@ -3,16 +3,18 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from collections.abc import Callable + import mcp.types as mcp_types __all__ = [ - "create_sync_tool_fn", "create_async_tool_fn", - "stringify_result", + "create_sync_tool_fn", "create_tool_fns", + "stringify_result", ] diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py index 93c5f699..45011413 100644 --- a/hud/eval/__init__.py +++ b/hud/eval/__init__.py @@ -36,18 +36,18 @@ # Auto-instrument httpx on import import hud.eval.instrument # noqa: F401 -# Task is safe to import -from hud.eval.task import Task - # run_eval is safe to import (uses lazy imports internally) from hud.eval.manager import run_eval +# Task is safe to import +from hud.eval.task import Task + if TYPE_CHECKING: from hud.eval.context import EvalContext __all__ = [ - "Task", "EvalContext", + "Task", "run_eval", ] diff --git a/hud/eval/context.py b/hud/eval/context.py index 2d35199a..07bcdf4a 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -23,7 +23,6 @@ if TYPE_CHECKING: from types import TracebackType - from hud.types import LegacyTask from hud.eval.types import EvalExitPayload, EvalPayload, ParallelEvalComplete @@ -154,7 +153,6 @@ def __init__( self._scenario_name: str | None = None # Current scenario name (for submit) self._source_env_name: str | None = None # Source env name for remote lookups - @classmethod def from_environment( cls, diff --git a/hud/eval/manager.py b/hud/eval/manager.py index b9f0a065..c7e0e897 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -303,9 +303,10 @@ async def _run_parallel_eval( import asyncio import textwrap + from hud.eval.parallel import log_eval_stats + # Lazy import to avoid circular dependency from hud.eval.task import Task - from hud.eval.parallel import log_eval_stats # Find user code frame and extract the with block body caller_frame = find_user_frame() diff --git a/hud/eval/task.py b/hud/eval/task.py index c2a145a5..5cc55cc5 100644 --- a/hud/eval/task.py +++ b/hud/eval/task.py @@ -35,7 +35,6 @@ from types import TracebackType from hud.environment import Environment - from hud.environment.types import EnvConfig from hud.eval.context import EvalContext __all__ = ["Task", "build_eval_name"] @@ -52,9 +51,7 @@ def _warn_local_mcp(mcp_config: dict[str, Any] | None) -> None: return has_local = any( - isinstance(server_cfg, dict) - and "command" in server_cfg - and not server_cfg.get("url") + isinstance(server_cfg, dict) and "command" in server_cfg and not server_cfg.get("url") for server_cfg in mcp_config.values() if isinstance(server_cfg, dict) ) @@ -112,16 +109,16 @@ class Task: Example (v5 format): ```python from hud.eval import Task - + # Pass dict - auto-converts to Environment task = Task( env={"name": "browser", "include": ["navigate", "screenshot"]}, scenario="checkout", args={"user_id": "alice"}, - validation=[{"name": "check_cart", "arguments": {}}] + validation=[{"name": "check_cart", "arguments": {}}], ) # task.env is now Environment connected to browser hub! 
- + # Or pass live Environment directly env = Environment("my-env").connect_hub("browser") task = Task(env=env, scenario="checkout", args={"user_id": "alice"}) @@ -161,7 +158,7 @@ class Task: def __post_init__(self) -> None: """Validate and normalize env and validation fields after initialization. - + Auto-converts dict or EnvConfig to Environment by connecting to the hub. Auto-converts validation dicts to MCPToolCall objects. """ @@ -202,8 +199,7 @@ def __post_init__(self) -> None: converted_validation.append(item) else: raise TypeError( - f"validation items must be dict or MCPToolCall, " - f"got {type(item).__name__}" + f"validation items must be dict or MCPToolCall, got {type(item).__name__}" ) self.validation = converted_validation @@ -235,12 +231,14 @@ def from_v4( task = Task.from_v4(legacy_task) # From dict (e.g., loaded from JSON file) - task = Task.from_v4({ - "prompt": "Navigate to google.com", - "mcp_config": {"hud": {...}}, - "setup_tool": {"name": "navigate", "arguments": {"url": "..."}}, - "evaluate_tool": {"name": "check_url", "arguments": {}} - }) + task = Task.from_v4( + { + "prompt": "Navigate to google.com", + "mcp_config": {"hud": {...}}, + "setup_tool": {"name": "navigate", "arguments": {"url": "..."}}, + "evaluate_tool": {"name": "check_url", "arguments": {}}, + } + ) # Use with hud.eval() or as context manager async with task as ctx: diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index 1fa9d655..2ee293fb 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -165,8 +165,6 @@ async def test_reward_accessible_after_exit(self) -> None: # Context reference is cleared but reward was set on the actual context - - class TestEnvironmentCall: """Tests for Environment.__call__ returning Task.""" @@ -203,7 +201,7 @@ def test_call_returns_task_with_env(self) -> None: env = Environment("test-env") task = env() - + # Task has reference to the Environment assert task.env is env @@ -240,10 +238,12 @@ def test_from_v4_with_legacy_task(self) -> None: def test_from_v4_with_dict(self) -> None: """Task.from_v4() accepts dict with LegacyTask fields.""" - task = Task.from_v4({ - "prompt": "Navigate to google.com", - "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, - }) + task = Task.from_v4( + { + "prompt": "Navigate to google.com", + "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + } + ) assert isinstance(task, Task) assert task.env is not None @@ -265,11 +265,13 @@ def test_from_v4_with_json_string(self) -> None: def test_from_v4_with_setup_tool(self) -> None: """Task.from_v4() preserves setup_tool via env._setup_calls.""" - task = Task.from_v4({ - "prompt": "Check URL", - "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, - "setup_tool": {"name": "navigate", "arguments": {"url": "https://google.com"}}, - }) + task = Task.from_v4( + { + "prompt": "Check URL", + "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + "setup_tool": {"name": "navigate", "arguments": {"url": "https://google.com"}}, + } + ) # setup_tool is converted to env._setup_calls assert len(task.env._setup_calls) == 1 @@ -277,11 +279,13 @@ def test_from_v4_with_setup_tool(self) -> None: def test_from_v4_with_evaluate_tool(self) -> None: """Task.from_v4() preserves evaluate_tool via env._evaluate_calls.""" - task = Task.from_v4({ - "prompt": "Check URL", - "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, - "evaluate_tool": {"name": "check_url", "arguments": {"expected": "google"}}, - }) + task = Task.from_v4( + { + "prompt": "Check URL", + "mcp_config": 
{"hud": {"url": "https://mcp.hud.ai"}}, + "evaluate_tool": {"name": "check_url", "arguments": {"expected": "google"}}, + } + ) # evaluate_tool is converted to env._evaluate_calls assert len(task.env._evaluate_calls) == 1 @@ -305,10 +309,12 @@ def test_from_v4_does_not_warn_on_use(self) -> None: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - Task.from_v4({ - "prompt": "test", - "mcp_config": {"hud": {}}, - }) + Task.from_v4( + { + "prompt": "test", + "mcp_config": {"hud": {}}, + } + ) # Should not trigger deprecation warning since we're migrating legacy_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] diff --git a/hud/patches/warnings.py b/hud/patches/warnings.py index 1a7afd39..0944ebb3 100644 --- a/hud/patches/warnings.py +++ b/hud/patches/warnings.py @@ -52,5 +52,3 @@ def suppress_mcp_use_import_warnings() -> Iterator[None]: ) yield - - diff --git a/hud/server/server.py b/hud/server/server.py index 19b9b7a3..9b6b4d68 100644 --- a/hud/server/server.py +++ b/hud/server/server.py @@ -17,6 +17,7 @@ from starlette.responses import JSONResponse, Response from hud.datasets import run_tasks +from hud.eval.task import Task from hud.server.low_level import LowLevelServerWithInit from hud.types import LegacyTask @@ -753,10 +754,10 @@ async def run_eval(request: Request) -> Response: ) # Add MCP config to each task and validate basic structure - task_objects: list[Task] = [] + task_objects: list[LegacyTask] = [] for task_data in eval_request.tasks: task_data["mcp_config"] = docker_config - task_objects.append(Task.model_validate(task_data)) + task_objects.append(LegacyTask.model_validate(task_data)) agent_params: dict[str, Any] = {} if eval_request.model: @@ -765,7 +766,7 @@ async def run_eval(request: Request) -> Response: # Fire and forget - launch evaluation in background async def run_eval_background() -> None: await run_tasks( - task_objects, + [Task.from_v4(task) for task in task_objects], agent_type=agent_type, agent_params=agent_params, max_steps=eval_request.max_steps, diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 682c077f..e28e395d 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, cast +from typing import cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -49,7 +49,9 @@ def test_taskconfig_list_tools(self): MCPToolCall(name="configure", arguments={"mode": "test"}), ] - task = LegacyTask(prompt="Multi-setup task", mcp_config={"test": True}, setup_tool=setup_tools) + task = LegacyTask( + prompt="Multi-setup task", mcp_config={"test": True}, setup_tool=setup_tools + ) assert isinstance(task.setup_tool, list) assert len(task.setup_tool) == 2 @@ -177,10 +179,9 @@ async def test_run_dataset_with_task_list(self): mock_agent = AsyncMock(spec=MCPAgent) mock_agent.run.return_value = Trace(reward=1.0, done=True) - # Create mock tasks (with mocked Environment to avoid real connections) - mock_env = MagicMock() - mock_env.name = "test" - + # Create mock tasks with env as dict (to avoid real connections) + mock_env = {"name": "test"} + tasks = [ Task(env=mock_env, scenario="test1"), Task(env=mock_env, scenario="test2"), @@ -212,7 +213,7 @@ async def test_run_dataset_from_source_string(self): mock_agent = AsyncMock(spec=MCPAgent) mock_agent.run.return_value = Trace(reward=1.0, done=True) - mock_env = MagicMock() + mock_env = {"name": "test"} mock_tasks = 
[Task(env=mock_env, scenario="loaded")] mock_ctx = AsyncMock() @@ -240,7 +241,7 @@ async def test_run_dataset_passes_parameters(self): mock_agent = AsyncMock(spec=MCPAgent) mock_agent.run.return_value = Trace(reward=1.0, done=True) - mock_env = MagicMock() + mock_env = {"name": "test"} tasks = [Task(env=mock_env, scenario="test")] mock_ctx = AsyncMock() @@ -250,13 +251,7 @@ async def test_run_dataset_passes_parameters(self): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - await run_dataset( - tasks, - mock_agent, - max_steps=25, - max_concurrent=10, - group_size=3 - ) + await run_dataset(tasks, mock_agent, max_steps=25, max_concurrent=10, group_size=3) # Verify hud.eval was called with correct params mock_eval.assert_called_once_with( diff --git a/hud/tools/tests/test_jupyter_tool.py b/hud/tools/tests/test_jupyter_tool.py index 932631a3..cb27a4b9 100644 --- a/hud/tools/tests/test_jupyter_tool.py +++ b/hud/tools/tests/test_jupyter_tool.py @@ -5,12 +5,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from mcp.types import TextContent # Import tornado modules before tests to avoid forward reference issues with mocking -import tornado.httpclient # noqa: F401 -import tornado.ioloop # noqa: F401 +import tornado.httpclient +import tornado.ioloop import tornado.websocket # noqa: F401 +from mcp.types import TextContent from hud.tools.jupyter import JupyterTool, strip_ansi diff --git a/hud/types.py b/hud/types.py index 30bae742..cda6bed0 100644 --- a/hud/types.py +++ b/hud/types.py @@ -425,11 +425,15 @@ def populate_from_context(self) -> None: self.trace = collected_trace.trace +# Re-export Task for backwards compatibility (after module defs to avoid circular import) +from hud.eval.task import Task # noqa: E402 + __all__ = [ "AgentResponse", "AgentType", "MCPToolCall", "MCPToolResult", + "Task", "Trace", "TraceStep", ] diff --git a/hud/utils/tasks.py b/hud/utils/tasks.py index 90528830..bf44b798 100644 --- a/hud/utils/tasks.py +++ b/hud/utils/tasks.py @@ -10,7 +10,9 @@ hud_console = HUDConsole() -def load_tasks(tasks_input: str | list[dict], *, raw: bool = False) -> list[LegacyTask] | list[dict]: +def load_tasks( + tasks_input: str | list[dict], *, raw: bool = False +) -> list[LegacyTask] | list[dict]: """Load tasks from various sources. 
Args: From dfdb94f0a6b73765e3ec27a0b621b2513bb924f4 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sat, 13 Dec 2025 05:47:13 -0800 Subject: [PATCH 44/92] update a bunch of things --- hud/agents/base.py | 3 - hud/cli/tests/test_analyze.py | 10 +- hud/cli/tests/test_analyze_module.py | 8 +- hud/cli/tests/test_build.py | 6 +- hud/cli/tests/test_cli_root.py | 19 +-- hud/cli/tests/test_debug.py | 12 +- hud/cli/tests/test_eval.py | 18 +-- hud/datasets/loader.py | 12 +- hud/datasets/tests/test_loader.py | 5 +- hud/environment/environment.py | 26 ++--- hud/environment/scenarios.py | 39 +++++-- hud/eval/context.py | 87 +++++++++++++- hud/eval/manager.py | 167 +++++++++++++++++---------- hud/eval/task.py | 149 +++--------------------- hud/eval/types.py | 11 +- hud/server/router.py | 8 +- hud/tests/test_datasets_extended.py | 2 +- hud/tools/base.py | 8 +- 18 files changed, 311 insertions(+), 279 deletions(-) diff --git a/hud/agents/base.py b/hud/agents/base.py index 8c03a156..3189dde7 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -332,9 +332,6 @@ async def _run_context( } trace_result = Trace(**trace_params) - # Populate trace steps from current context - trace_result.populate_from_context() - return trace_result async def call_tools( diff --git a/hud/cli/tests/test_analyze.py b/hud/cli/tests/test_analyze.py index 74d2b6a8..7bc1440c 100644 --- a/hud/cli/tests/test_analyze.py +++ b/hud/cli/tests/test_analyze.py @@ -50,7 +50,7 @@ async def test_analyze_environment_success(self) -> None: } with ( - patch("hud.cli.analyze.MCPClient") as MockClient, + patch("hud.clients.fastmcp.FastMCPHUDClient") as MockClient, patch("hud.cli.analyze.console"), patch("hud.cli.analyze.display_interactive") as mock_interactive, ): @@ -80,7 +80,7 @@ async def test_analyze_environment_success(self) -> None: async def test_analyze_environment_failure(self) -> None: """Test handling analysis failure.""" with ( - patch("hud.cli.analyze.MCPClient") as MockClient, + patch("hud.clients.fastmcp.FastMCPHUDClient") as MockClient, patch("hud.cli.analyze.console") as mock_console, patch("platform.system", return_value="Windows"), ): @@ -119,7 +119,7 @@ async def test_analyze_environment_formats(self) -> None: for output_format in ["json", "markdown", "interactive"]: with ( - patch("hud.cli.analyze.MCPClient") as MockClient, + patch("hud.clients.fastmcp.FastMCPHUDClient") as MockClient, patch("hud.cli.analyze.console") as mock_console, patch("hud.cli.analyze.display_interactive") as mock_interactive, patch("hud.cli.analyze.display_markdown") as mock_markdown, @@ -163,7 +163,7 @@ async def test_analyze_with_config_success(self) -> None: } with ( - patch("hud.cli.analyze.MCPClient") as MockClient, + patch("hud.clients.fastmcp.FastMCPHUDClient") as MockClient, patch("hud.cli.analyze.console"), patch("hud.cli.analyze.display_interactive") as mock_interactive, ): @@ -190,7 +190,7 @@ async def test_analyze_with_config_exception(self) -> None: mock_config = {"server": {"command": "test"}} with ( - patch("hud.cli.analyze.MCPClient") as MockClient, + patch("hud.clients.fastmcp.FastMCPHUDClient") as MockClient, patch("hud.cli.analyze.console"), ): # Setup mock client that fails diff --git a/hud/cli/tests/test_analyze_module.py b/hud/cli/tests/test_analyze_module.py index 468389f1..0996b270 100644 --- a/hud/cli/tests/test_analyze_module.py +++ b/hud/cli/tests/test_analyze_module.py @@ -29,7 +29,7 @@ def test_parse_docker_command(): @pytest.mark.asyncio -@patch("hud.cli.analyze.MCPClient") 
+@patch("hud.clients.fastmcp.FastMCPHUDClient") @patch("hud.cli.analyze.console") async def test_analyze_environment_success_json(mock_console, MockClient): client = AsyncMock() @@ -46,7 +46,7 @@ async def test_analyze_environment_success_json(mock_console, MockClient): @pytest.mark.asyncio -@patch("hud.cli.analyze.MCPClient") +@patch("hud.clients.fastmcp.FastMCPHUDClient") @patch("hud.cli.analyze.console") async def test_analyze_environment_failure(mock_console, MockClient): client = AsyncMock() @@ -93,7 +93,7 @@ def test_display_markdown_both_paths(capsys): assert "MCP Environment Analysis" in captured.out -@patch("hud.cli.analyze.MCPClient") +@patch("hud.clients.fastmcp.FastMCPHUDClient") async def test_analyze_environment_from_config(MockClient, tmp_path: Path): client = AsyncMock() client.initialize.return_value = None @@ -107,7 +107,7 @@ async def test_analyze_environment_from_config(MockClient, tmp_path: Path): assert client.initialize.awaited and client.shutdown.awaited -@patch("hud.cli.analyze.MCPClient") +@patch("hud.clients.fastmcp.FastMCPHUDClient") async def test_analyze_environment_from_mcp_config(MockClient): client = AsyncMock() client.initialize.return_value = None diff --git a/hud/cli/tests/test_build.py b/hud/cli/tests/test_build.py index 9e76977b..78287fa6 100644 --- a/hud/cli/tests/test_build.py +++ b/hud/cli/tests/test_build.py @@ -206,7 +206,7 @@ def test_extract_no_dockerfile(self, tmp_path): class TestAnalyzeMcpEnvironment: """Test analyzing MCP environment.""" - @mock.patch("hud.cli.build.MCPClient") + @mock.patch("hud.clients.fastmcp.FastMCPHUDClient") async def test_analyze_success(self, mock_client_class): """Test successful environment analysis.""" # Setup mock client @@ -240,7 +240,7 @@ async def test_analyze_success(self, mock_client_class): assert result["tools"][0]["name"] == "test_tool" assert "initializeMs" in result - @mock.patch("hud.cli.build.MCPClient") + @mock.patch("hud.clients.fastmcp.FastMCPHUDClient") async def test_analyze_failure(self, mock_client_class): """Test failed environment analysis.""" # Setup mock client to fail @@ -253,7 +253,7 @@ async def test_analyze_failure(self, mock_client_class): with pytest.raises(HudException, match="Connection failed"): await analyze_mcp_environment("test:latest") - @mock.patch("hud.cli.build.MCPClient") + @mock.patch("hud.clients.fastmcp.FastMCPHUDClient") async def test_analyze_verbose_mode(self, mock_client_class): """Test analysis in verbose mode.""" mock_client = mock.AsyncMock() diff --git a/hud/cli/tests/test_cli_root.py b/hud/cli/tests/test_cli_root.py index d0951d74..62500268 100644 --- a/hud/cli/tests/test_cli_root.py +++ b/hud/cli/tests/test_cli_root.py @@ -7,6 +7,10 @@ import hud.cli as cli +# Import the function directly from the __init__ module to avoid namespace conflict with analyze.py +import hud.cli.__init__ as cli_init +analyze_fn = cli_init.analyze + if TYPE_CHECKING: from pathlib import Path @@ -15,7 +19,7 @@ @patch("asyncio.run") def test_analyze_params_metadata(mock_run, mock_analyze): # image only -> metadata path - cli.analyze(params=["img:latest"], output_format="json", verbose=False) + analyze_fn(params=["img:latest"], output_format="json", verbose=False) assert mock_run.called @@ -25,7 +29,7 @@ def test_analyze_params_metadata(mock_run, mock_analyze): def test_analyze_params_live(mock_run, mock_build_cmd, mock_analyze_env): mock_build_cmd.return_value = ["docker", "run", "img", "-e", "K=V"] # docker args trigger live path - cli.analyze(params=["img:latest", "-e", "K=V"], 
output_format="json", verbose=True) + analyze_fn(params=["img:latest", "-e", "K=V"], output_format="json", verbose=True) assert mock_run.called @@ -34,7 +38,7 @@ def test_analyze_no_params_errors(): # When no params provided, analyze prints help and exits(1) with pytest.raises(typer.Exit): - cli.analyze(params=None, config=None, cursor=None, output_format="json", verbose=False) # type: ignore + analyze_fn(params=None, config=None, cursor=None, output_format="json", verbose=False) # type: ignore @patch("hud.cli.analyze.analyze_environment_from_config", new_callable=AsyncMock) @@ -42,16 +46,17 @@ def test_analyze_no_params_errors(): def test_analyze_from_config(mock_run, mock_func, tmp_path: Path): cfg = tmp_path / "cfg.json" cfg.write_text("{}") - cli.analyze(params=None, config=cfg, cursor=None, output_format="json", verbose=False) # type: ignore + analyze_fn(params=None, config=cfg, cursor=None, output_format="json", verbose=False) # type: ignore assert mock_run.called -@patch("hud.cli.parse_cursor_config") +@patch("hud.cli.console") +@patch("hud.cli.__init__.parse_cursor_config") @patch("hud.cli.analyze.analyze_environment_from_mcp_config", new_callable=AsyncMock) @patch("asyncio.run") -def test_analyze_from_cursor(mock_run, mock_analyze, mock_parse): +def test_analyze_from_cursor(mock_run, mock_analyze, mock_parse, mock_console): mock_parse.return_value = (["cmd", "arg"], None) - cli.analyze(params=None, config=None, cursor="server", output_format="json", verbose=False) # type: ignore + analyze_fn(params=None, config=None, cursor="server", output_format="json", verbose=False) # type: ignore assert mock_run.called diff --git a/hud/cli/tests/test_debug.py b/hud/cli/tests/test_debug.py index 6c4c5d90..19b9e16c 100644 --- a/hud/cli/tests/test_debug.py +++ b/hud/cli/tests/test_debug.py @@ -207,7 +207,7 @@ async def test_phase_3_tool_discovery(self) -> None: with ( patch("subprocess.run", return_value=mock_run_result), patch("subprocess.Popen", return_value=mock_proc), - patch("hud.cli.debug.MCPClient") as MockClient, + patch("hud.clients.MCPClient") as MockClient, ): mock_client = MockClient.return_value mock_client.initialize = AsyncMock() @@ -240,7 +240,7 @@ async def test_phase_3_no_tools(self) -> None: with ( patch("subprocess.run", return_value=mock_run_result), patch("subprocess.Popen", return_value=mock_proc), - patch("hud.cli.debug.MCPClient") as MockClient, + patch("hud.clients.MCPClient") as MockClient, ): mock_client = MockClient.return_value mock_client.initialize = AsyncMock() @@ -277,7 +277,7 @@ async def test_phase_4_remote_deployment(self) -> None: with ( patch("subprocess.run", return_value=mock_run_result), patch("subprocess.Popen", return_value=mock_proc), - patch("hud.cli.debug.MCPClient") as MockClient, + patch("hud.clients.MCPClient") as MockClient, ): mock_client = MockClient.return_value mock_client.initialize = AsyncMock() @@ -311,7 +311,7 @@ async def test_phase_4_slow_initialization(self) -> None: with ( patch("subprocess.run", return_value=mock_run_result), patch("subprocess.Popen", return_value=mock_proc), - patch("hud.cli.debug.MCPClient") as MockClient, + patch("hud.clients.MCPClient") as MockClient, ): mock_client = MockClient.return_value mock_client.initialize = AsyncMock() @@ -349,7 +349,7 @@ async def test_phase_5_concurrent_clients(self) -> None: with ( patch("subprocess.run", return_value=mock_run_result), patch("subprocess.Popen", return_value=mock_proc), - patch("hud.cli.debug.MCPClient") as MockClient, + patch("hud.clients.MCPClient") as MockClient, 
): # Create different mock instances for each client mock_clients = [] @@ -393,7 +393,7 @@ async def test_phase_5_concurrent_failure(self) -> None: with ( patch("subprocess.run", return_value=mock_run_result), patch("subprocess.Popen", return_value=mock_proc), - patch("hud.cli.debug.MCPClient") as MockClient, + patch("hud.clients.MCPClient") as MockClient, ): # Set up for phase 1-4 success first test_tool = Mock() diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index 272b6b6b..db89ad60 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -61,8 +61,8 @@ async def test_run_dataset_with_task_list(self) -> None: from hud.eval.task import Task tasks = [ - Task(id="task1", scenario="test"), - Task(id="task2", scenario="test"), + Task(env={"name": "test"}, id="task1", scenario="test"), + Task(env={"name": "test"}, id="task2", scenario="test"), ] agent = MockAgent() @@ -92,12 +92,12 @@ async def test_run_dataset_with_string_source(self) -> None: """Test run_dataset with a string source (loads via load_dataset).""" from hud.eval.task import Task - mock_tasks = [Task(id="loaded_task", scenario="loaded")] + mock_tasks = [Task(env={"name": "test"}, id="loaded_task", scenario="loaded")] agent = MockAgent() mock_ctx = MockEvalContext() with ( - patch("hud.datasets.runner.load_dataset", return_value=mock_tasks) as mock_load, + patch("hud.datasets.loader.load_dataset", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) @@ -115,7 +115,7 @@ async def test_run_dataset_empty_tasks_raises(self) -> None: """Test run_dataset raises ValueError for empty tasks.""" agent = MockAgent() - with patch("hud.datasets.runner.load_dataset", return_value=[]): + with patch("hud.datasets.loader.load_dataset", return_value=[]): from hud.datasets.runner import run_dataset with pytest.raises(ValueError, match="No tasks to run"): @@ -126,7 +126,7 @@ async def test_run_dataset_with_group_size(self) -> None: """Test run_dataset passes group_size to hud.eval.""" from hud.eval.task import Task - tasks = [Task(id="task1", scenario="test")] + tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] agent = MockAgent() mock_ctx = MockEvalContext() @@ -146,7 +146,7 @@ async def test_run_dataset_with_max_concurrent(self) -> None: """Test run_dataset passes max_concurrent to hud.eval.""" from hud.eval.task import Task - tasks = [Task(id="task1", scenario="test")] + tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] agent = MockAgent() mock_ctx = MockEvalContext() @@ -166,7 +166,7 @@ async def test_run_dataset_returns_results(self) -> None: """Test run_dataset returns EvalContext results.""" from hud.eval.task import Task - tasks = [Task(id="task1", scenario="test")] + tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] agent = MockAgent() mock_ctx = MockEvalContext() @@ -187,7 +187,7 @@ async def test_run_dataset_parallel_results(self) -> None: """Test run_dataset returns ctx.results for parallel execution.""" from hud.eval.task import Task - tasks = [Task(id="task1", scenario="test")] + tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] agent = MockAgent() # Create mock context with results (parallel execution) diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index d5823bca..57806f96 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -44,16 +44,20 @@ def _task_from_dict(item: dict[str, Any]) -> Task: # v4 
LegacyTask format - convert via Task.from_v4() return Task.from_v4(item) else: - # v5 format - env is EnvConfig dict with name, include, exclude + # v5 format - env is required, scenario is optional + env = item.get("env") + if env is None: + raise ValueError(f"Task missing required 'env' field: {item}") + # Convert validation dicts to MCPToolCall objects validation = None if item.get("validation"): validation = [MCPToolCall(**v) for v in item["validation"]] return Task( - id=item.get("id"), - env=item.get("env"), # EnvConfig dict: {"name": "browser", "include": [...], ...} + env=env, # EnvConfig dict: {"name": "browser", "include": [...], ...} scenario=item.get("scenario"), + id=item.get("id"), args=item.get("args", {}), validation=validation, ) @@ -107,7 +111,7 @@ def _load_from_api(dataset_name: str) -> list[Task]: with httpx.Client() as client: response = client.get( - f"{settings.hud_api_url}/evals/{dataset_name}", + f"{settings.hud_api_url}/tasks/evalset/{dataset_name}", headers=headers, params={"all": "true"}, ) diff --git a/hud/datasets/tests/test_loader.py b/hud/datasets/tests/test_loader.py index b68a1b1c..7ff31544 100644 --- a/hud/datasets/tests/test_loader.py +++ b/hud/datasets/tests/test_loader.py @@ -199,13 +199,13 @@ def test_load_dataset_empty( def test_load_dataset_missing_fields( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - """load_dataset() handles tasks with missing optional fields.""" + """load_dataset() handles tasks with missing optional fields (but env is required).""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = "test_key" mock_response = MagicMock() mock_response.json.return_value = { - "tasks": {"task-1": {"scenario": "test"}}, + "tasks": {"task-1": {"env": {"name": "test-env"}, "scenario": "test"}}, } mock_response.raise_for_status = MagicMock() @@ -220,5 +220,4 @@ def test_load_dataset_missing_fields( assert len(tasks) == 1 assert tasks[0].scenario == "test" assert tasks[0].id == "task-1" - assert tasks[0].env is None assert tasks[0].args == {} diff --git a/hud/environment/environment.py b/hud/environment/environment.py index feeaf025..4d6804d7 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -533,21 +533,15 @@ def __repr__(self) -> str: def __call__( self, - scenario: str | None = None, - *, - _trace: bool = True, - _quiet: bool = False, + scenario: str, **args: Any, ) -> Task: """Create a Task from this environment. - Returns a Task that can be entered as a context manager or passed - to hud.eval() for orchestration. + Returns a Task that can be passed to hud.eval() for orchestration. Args: - scenario: Optional scenario name to run (from @env.scenario) - _trace: Whether to send trace data to backend (default True) - _quiet: Whether to suppress printing links (default False) + scenario: Scenario name to run (from @env.scenario) **args: Arguments for the scenario Returns: @@ -564,15 +558,11 @@ async def checkout(user_id: str): yield 1.0 - # Simple use - Task is context manager - async with env("checkout", user_id="alice") as ctx: + # Single task via hud.eval + async with hud.eval(env("checkout", user_id="alice")) as ctx: await agent.run(ctx.prompt) - # Empty - just env - async with env() as ctx: - await ctx.call_tool("navigate", url="...") - - # Orchestrated via hud.eval + # Multiple tasks with variants tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")] async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: ... 
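Putting the `_task_from_dict` changes above together, a v5 item must carry an `env` block, while `scenario`, `id`, `args`, and `validation` stay optional, and validation entries are coerced into MCPToolCall objects. A hedged sketch of the accepted shape; the env, scenario, and tool names are invented for illustration, the MCPToolCall field names (name/arguments) are assumed from MCP conventions, and the private helper is imported only to show the conversion:

```python
# Hedged sketch of the v5 dict shape the loader accepts after this patch.
# "browser", "checkout", and "navigate" are placeholder names.
from hud.datasets.loader import _task_from_dict

v5_item = {
    "id": "task-001",
    "env": {"name": "browser", "include": ["navigate", "click"]},
    "scenario": "checkout",
    "args": {"user_id": "alice"},
    "validation": [{"name": "navigate", "arguments": {"url": "https://example.com"}}],
}
task = _task_from_dict(v5_item)  # validation dicts become MCPToolCall objects

# Dropping "env" is now an error rather than a silently-None field:
try:
    _task_from_dict({"scenario": "checkout"})
except ValueError as exc:
    print(exc)  # Task missing required 'env' field: ...
```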
@@ -581,9 +571,7 @@ async def checkout(user_id: str): from hud.eval.task import Task return Task( - env=self, # Pass live environment for local tools/scenarios + env=self, scenario=scenario, args=args, - _trace=_trace, - _quiet=_quiet, ) diff --git a/hud/environment/scenarios.py b/hud/environment/scenarios.py index 5efc6499..ea87102c 100644 --- a/hud/environment/scenarios.py +++ b/hud/environment/scenarios.py @@ -167,10 +167,16 @@ async def run_scenario_setup(self, scenario_name: str, args: dict[str, Any]) -> return str(prompt) else: # Remote scenario - call via MCP prompt - # Format: {env_name}:{scenario_name} (use source env name if available) - env_name = getattr(self, "_source_env_name", None) or self.name - safe_env_name = env_name.replace("_", "-") - prompt_id = f"{safe_env_name}:{scenario_name}" + # If scenario_name already contains ":", it's already namespaced - use directly + # Otherwise, prefix with env name: {env_name}:{scenario_name} + if ":" in scenario_name: + prompt_id = scenario_name + logger.debug("Remote scenario (already namespaced): prompt_id=%s", prompt_id) + else: + env_name = getattr(self, "_source_env_name", None) or self.name + safe_env_name = env_name.replace("_", "-") + prompt_id = f"{safe_env_name}:{scenario_name}" + logger.debug("Remote scenario (adding namespace): prompt_id=%s", prompt_id) try: result = await self.get_prompt(prompt_id, args) # type: ignore[attr-defined] if result.messages: @@ -222,10 +228,14 @@ async def run_scenario_evaluate(self, scenario_name: str) -> float | None: if self._scenario_latest.get(scenario_name) == session_id: del self._scenario_latest[scenario_name] - # Remote scenario - read via MCP resource (use source env name if available) - env_name = getattr(self, "_source_env_name", None) or self.name - safe_env_name = env_name.replace("_", "-") - resource_id = f"{safe_env_name}:{scenario_name}" + # Remote scenario - read via MCP resource + # If scenario_name already contains ":", it's already namespaced - use directly + if ":" in scenario_name: + resource_id = scenario_name + else: + env_name = getattr(self, "_source_env_name", None) or self.name + safe_env_name = env_name.replace("_", "-") + resource_id = f"{safe_env_name}:{scenario_name}" try: contents = await self.read_resource(resource_id) # type: ignore[attr-defined] if contents: @@ -283,7 +293,12 @@ def decorator( # Capture source code for reproducibility try: source_code = inspect.getsource(fn) - except (OSError, TypeError): + except (OSError, TypeError) as e: + logger.warning( + "Could not capture source code for scenario '%s': %s", + scenario_name, + e, + ) source_code = None # Store the generator function @@ -302,7 +317,7 @@ def decorator( scenario_fn = fn scenario_name_ref = scenario_name - async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: + async def prompt_handler(**handler_args: Any) -> list[str]: # Create generator instance gen = scenario_fn(**handler_args) @@ -321,7 +336,9 @@ async def prompt_handler(**handler_args: Any) -> list[dict[str, Any]]: prompt_text[:50] if isinstance(prompt_text, str) else prompt_text, ) - return [{"role": "user", "content": str(prompt_text)}] + # Return just the string - FastMCP wraps it in PromptMessage + # Don't return dict or it gets JSON-serialized as text content + return [str(prompt_text)] # Register prompt using FastMCP - create FunctionPrompt directly # to bypass the **kwargs validation in from_function() diff --git a/hud/eval/context.py b/hud/eval/context.py index 07bcdf4a..5163ef36 100644 --- 
a/hud/eval/context.py +++ b/hud/eval/context.py @@ -23,6 +23,8 @@ if TYPE_CHECKING: from types import TracebackType + from hud.eval.task import Task + from hud.eval.types import EvalExitPayload, EvalPayload, ParallelEvalComplete @@ -152,6 +154,7 @@ def __init__( self._trace_enabled: bool = trace # Whether to send trace data to backend self._scenario_name: str | None = None # Current scenario name (for submit) self._source_env_name: str | None = None # Source env name for remote lookups + self._task: Task | None = None # Task config (set by from_task) @classmethod def from_environment( @@ -227,6 +230,77 @@ def from_environment( return ctx + @classmethod + def from_task( + cls, + task: Task, + *, + trace_id: str | None = None, + api_key: str | None = None, + job_id: str | None = None, + group_id: str | None = None, + index: int = 0, + variants: dict[str, Any] | None = None, + code_snippet: str | None = None, + trace: bool = True, + quiet: bool = False, + ) -> EvalContext: + """Create an EvalContext from a Task config. + + Args: + task: Task config (env, scenario, args) + trace_id: Unique trace ID + api_key: API key for backend calls + job_id: Job ID to link to + group_id: Group ID for parallel evaluations + index: Index in parallel execution + variants: Variant assignment + code_snippet: Code being evaluated + trace: Whether to send traces to backend + quiet: Whether to suppress output + """ + from hud.eval.task import build_eval_name + + eval_name = build_eval_name(task.scenario, task.args) + + ctx = cls.from_environment( + env=task.env, + name=eval_name, + trace_id=trace_id, + api_key=api_key, + job_id=job_id, + group_id=group_id, + index=index, + variants=variants, + code_snippet=code_snippet, + trace=trace, + quiet=quiet, + ) + + # Store task info for scenario execution + ctx._task = task + + return ctx + + async def _run_task_scenario_setup(self) -> None: + """Run the task's scenario setup phase (if scenario provided).""" + if self._task is None or self._task.scenario is None: + return + + self._scenario_name = self._task.scenario + prompt = await self.run_scenario_setup(self._task.scenario, self._task.args) + if prompt: + self.prompt = prompt + + async def _run_task_scenario_evaluate(self) -> None: + """Run the task's scenario evaluate phase (if scenario provided).""" + if self._task is None or self._task.scenario is None: + return + + reward = await self.run_scenario_evaluate(self._task.scenario) + if reward is not None: + self.reward = reward + # ========================================================================= # Summary Context - Attribute Access Control # ========================================================================= @@ -306,12 +380,12 @@ def _get_eval_api_key(self) -> str | None: def _build_base_payload(self) -> EvalPayload: """Build the base payload for enter/exit.""" return EvalPayload( - job_name=self.eval_name, prompt=self.prompt, code_snippet=self.code_snippet, job_id=self.job_id, group_id=self.group_id, variants=self.variants if self.variants else None, + task_version_id=self._task.id if self._task else None, ) async def log(self, metrics: dict[str, Any]) -> None: @@ -420,6 +494,13 @@ async def __aenter__(self) -> Self: # Connect environment (MCP servers, tools) await super().__aenter__() + # Run task scenario setup (if created from_task with scenario) + await self._run_task_scenario_setup() + + # Notify backend and print link + await self._eval_enter() + self._print_eval_link() + return self async def __aexit__( @@ -436,6 +517,10 @@ async def __aexit__( 
self._completed_at = datetime.now(UTC) + # Run task scenario evaluate (if no error and has scenario) + if exc_type is None: + await self._run_task_scenario_evaluate() + # Track error error_msg: str | None = None if exc_type is not None: diff --git a/hud/eval/manager.py b/hud/eval/manager.py index c7e0e897..e5bb18dd 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -56,6 +56,40 @@ def _get_eval_name(tasks: list[Task] | None = None) -> str: return "eval" +def _send_job_enter( + job_id: str, + name: str, + variants: dict[str, Any] | None, + group: int, + api_key: str | None, +) -> None: + """Send job enter payload (sync request before traces start).""" + import httpx + + from hud.eval.types import JobEnterPayload + from hud.settings import settings + + api_key = api_key or settings.api_key + if not settings.telemetry_enabled or not api_key: + return + + payload = JobEnterPayload( + name=name, + variants=variants, + group=group, + ) + + try: + httpx.post( + f"{settings.hud_api_url}/trace/job/{job_id}/enter", + json=payload.model_dump(exclude_none=True), + headers={"Authorization": f"Bearer {api_key}"}, + timeout=10.0, + ) + except Exception as e: + logger.warning("Failed to send job enter: %s", e) + + @asynccontextmanager async def run_eval( source: Task | list[Task] | None = None, @@ -194,32 +228,32 @@ async def run_eval( from hud.eval.context import EvalContext if total_evals == 1: - # Simple case: single eval - always use Task for consistent flow if tasks: - single_task = tasks[0] + # Single task - use EvalContext.from_task() + ctx = EvalContext.from_task( + tasks[0], + api_key=api_key, + job_id=job_id, + variants=variant_combos[0], + code_snippet=code_snippet, + trace=trace, + quiet=quiet, + ) + async with ctx: + yield ctx else: - # Blank eval - single_task = Task( - env=None, - scenario=None, + # Blank eval - use EvalContext directly + ctx = EvalContext( + name="eval", api_key=api_key, job_id=job_id, variants=variant_combos[0], code_snippet=code_snippet, - _trace=trace, - _quiet=quiet, + trace=trace, + quiet=quiet, ) - - # Apply common settings - single_task.api_key = api_key - single_task.job_id = job_id - single_task.variants = variant_combos[0] - single_task.code_snippet = code_snippet - single_task._trace = trace - single_task._quiet = quiet - - async with single_task as ctx: - yield ctx + async with ctx: + yield ctx else: # Parallel execution: create implicit job to group traces @@ -227,6 +261,15 @@ async def run_eval( implicit_job_id = job_id or str(uuid.uuid4()) job_url = f"https://hud.ai/jobs/{implicit_job_id}" + # Send job enter (sync request before traces start) + _send_job_enter( + job_id=implicit_job_id, + name=eval_name, + variants=variants, + group=group, + api_key=api_key, + ) + # Print job URL (not individual trace URLs) if not quiet: print_link(job_url, f"🚀 {eval_name}") @@ -305,9 +348,6 @@ async def _run_parallel_eval( from hud.eval.parallel import log_eval_stats - # Lazy import to avoid circular dependency - from hud.eval.task import Task - # Find user code frame and extract the with block body caller_frame = find_user_frame() body_source, captured_locals, context_var = get_with_block_body(caller_frame) @@ -317,46 +357,42 @@ async def _run_parallel_eval( total_evals = base_count * len(variant_combos) * group resolved_group_ids = resolve_group_ids(group_ids, total_evals) - # Create Task objects for parallel execution - task_objects: list[Task] = [] + # Build list of (task_or_none, runtime_params) for each parallel eval + from hud.eval.context import 
EvalContext + + eval_configs: list[tuple[Task | None, dict[str, Any]]] = [] idx = 0 if tasks: - # Create Task for each (task, variant, run) combination for base_task in tasks: for variant in variant_combos: for _ in range(group): - task_copy = base_task.copy() - task_copy.api_key = api_key - task_copy.job_id = job_id - task_copy.group_id = resolved_group_ids[idx] - task_copy.index = idx - task_copy.variants = variant - task_copy.code_snippet = code_snippet - task_copy._suppress_link = True # Individual traces don't print links - task_copy._trace = trace - task_copy._quiet = quiet - task_objects.append(task_copy) + runtime_params = { + "api_key": api_key, + "job_id": job_id, + "group_id": resolved_group_ids[idx], + "index": idx, + "variants": variant, + "code_snippet": code_snippet, + "trace": trace, + "quiet": True, # Individual traces don't print links + } + eval_configs.append((base_task, runtime_params)) idx += 1 else: - # Blank tasks for each (variant, run) combination for variant in variant_combos: for _ in range(group): - blank_task = Task( - env=None, - scenario=None, - args={}, - api_key=api_key, - job_id=job_id, - group_id=resolved_group_ids[idx], - index=idx, - variants=variant, - code_snippet=code_snippet, - _suppress_link=True, - _trace=trace, - _quiet=quiet, - ) - task_objects.append(blank_task) + runtime_params = { + "api_key": api_key, + "job_id": job_id, + "group_id": resolved_group_ids[idx], + "index": idx, + "variants": variant, + "code_snippet": code_snippet, + "trace": trace, + "quiet": True, + } + eval_configs.append((None, runtime_params)) idx += 1 # Create runner function using the actual variable name from the 'as' clause @@ -369,33 +405,40 @@ async def _run_parallel_eval( # Create semaphore for concurrency control sem = asyncio.Semaphore(max_concurrent) if max_concurrent else None - async def run_one(task_obj: Task) -> EvalContext: - """Run a single Task and return its EvalContext.""" + async def run_one(config: tuple[Task | None, dict[str, Any]]) -> EvalContext: + """Run a single eval and return its EvalContext.""" + task, params = config + idx = params["index"] + + # Create context from task or blank + if task is not None: + ctx = EvalContext.from_task(task, **params) + else: + ctx = EvalContext(name="eval", **params) + try: if sem: - async with sem, task_obj as ctx: + async with sem, ctx: await runner(ctx) else: - async with task_obj as ctx: + async with ctx: await runner(ctx) return ctx except Exception as e: - logger.warning("Parallel eval %d failed: %s", task_obj.index, e) - # Create a failed context from the task - ctx = task_obj.to_eval_context() + logger.warning("Parallel eval %d failed: %s", idx, e) ctx.error = e return ctx # Run in parallel logger.info( - "Running %d tasks (%d base x %d variants x %d runs)%s", - len(task_objects), + "Running %d evals (%d base x %d variants x %d runs)%s", + len(eval_configs), base_count, len(variant_combos), group, f", max_concurrent={max_concurrent}" if max_concurrent else "", ) - completed = await asyncio.gather(*[run_one(t) for t in task_objects]) + completed = await asyncio.gather(*[run_one(cfg) for cfg in eval_configs]) # Log and print stats eval_name = completed[0].eval_name if completed else "eval" diff --git a/hud/eval/task.py b/hud/eval/task.py index 5cc55cc5..8b011e68 100644 --- a/hud/eval/task.py +++ b/hud/eval/task.py @@ -32,10 +32,8 @@ from hud.types import MCPToolCall if TYPE_CHECKING: - from types import TracebackType - from hud.environment import Environment - from hud.eval.context import EvalContext + 
from hud.environment.types import EnvConfig __all__ = ["Task", "build_eval_name"] @@ -133,29 +131,14 @@ class Task: task = Task.from_v4({"prompt": "...", "mcp_config": {...}, ...}) ``` """ - - # Core v5 task definition - id: str | None = None - env: Environment | None = None + # Required + env: Environment | EnvConfig | dict[str, Any] + # Optional scenario: str | None = None + id: str | None = None args: dict[str, Any] = field(default_factory=dict) validation: list[MCPToolCall] | None = None - # EvalContext creation params (set by hud.eval for parallel execution) - trace_id: str | None = field(default=None, repr=False) - api_key: str | None = field(default=None, repr=False) - job_id: str | None = field(default=None, repr=False) - group_id: str | None = field(default=None, repr=False) - index: int = field(default=0, repr=False) - variants: dict[str, Any] = field(default_factory=dict, repr=False) - code_snippet: str | None = field(default=None, repr=False) - _suppress_link: bool = field(default=False, repr=False) - _trace: bool = field(default=True, repr=False) - _quiet: bool = field(default=False, repr=False) - - # Runtime state - _ctx: EvalContext | None = field(default=None, repr=False) - def __post_init__(self) -> None: """Validate and normalize env and validation fields after initialization. @@ -165,9 +148,8 @@ def __post_init__(self) -> None: from hud.environment import Environment from hud.environment.types import EnvConfig - # Convert env field - if not isinstance(self.env, (Environment, type(None))): - # Convert dict to EnvConfig first (with validation) + # Convert env field (dict/EnvConfig -> Environment) + if not isinstance(self.env, Environment): if isinstance(self.env, dict): try: config = EnvConfig(**self.env) @@ -180,7 +162,7 @@ def __post_init__(self) -> None: config = self.env else: raise TypeError( - f"Task.env must be Environment, EnvConfig, dict, or None. " + f"Task.env must be Environment, EnvConfig, or dict. " f"Got {type(self.env).__name__}" ) @@ -312,124 +294,19 @@ def from_v4( ) return cls( - id=legacy_task.id, env=env, # Live Environment with mcp_config, setup_tool, evaluate_tool - scenario=None, # No scenario - uses prompt directly + scenario=None, # v4 tasks use prompt directly, not scenarios + id=legacy_task.id, args={}, validation=None, ) - # Backwards compat alias - def copy(self) -> Task: - """Create a copy of this Task for parallel execution.""" + """Create a copy of this Task config.""" return Task( + id=self.id, env=self.env, # Share reference - from_environment handles copying scenario=self.scenario, args=self.args.copy(), - trace_id=None, # Each copy gets unique trace_id - api_key=self.api_key, - job_id=self.job_id, - group_id=self.group_id, - index=self.index, - variants=self.variants.copy(), - code_snippet=self.code_snippet, - _suppress_link=self._suppress_link, - _trace=self._trace, - _quiet=self._quiet, + validation=self.validation, ) - - def to_eval_context(self) -> EvalContext: - """Convert this Task to an EvalContext. - - Creates an EvalContext from the environment (live or from config). - If env is EnvConfig or dict, creates Environment by connecting to the hub. 
- """ - from hud.environment import Environment - from hud.eval.context import EvalContext - - # Get environment (or create blank if None) - source_env = self.env if self.env is not None else Environment("eval") - - eval_name = build_eval_name(self.scenario, self.args) - - # Create EvalContext from environment - ctx = EvalContext.from_environment( - env=source_env, - name=eval_name, - trace_id=self.trace_id, - api_key=self.api_key, - job_id=self.job_id, - group_id=self.group_id, - index=self.index, - variants=self.variants, - code_snippet=self.code_snippet, - ) - ctx._suppress_link = self._suppress_link - ctx._trace_enabled = self._trace - - return ctx - - async def __aenter__(self) -> EvalContext: - """Enter eval context. - - Order of operations: - 1. Create EvalContext from environment config - 2. Connect environment (MCP servers, etc.) - 3. Run scenario setup (if scenario) → sets ctx.prompt - 4. Notify backend (with prompt now set) - 5. Print trace link - """ - self._ctx = self.to_eval_context() - await self._ctx.__aenter__() # Connect env, set trace headers - - # Run scenario setup (sets prompt) - if self.scenario: - await self._run_scenario_setup() - - # Notify backend with prompt included - await self._ctx._eval_enter() - self._ctx._print_eval_link() - - return self._ctx - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit eval context - run scenario evaluate and exit EvalContext.""" - if self._ctx is None: - return - - # If we have a scenario and no error, run its evaluate phase - if self.scenario and exc_type is None: - await self._run_scenario_evaluate() - - # Exit the EvalContext - await self._ctx.__aexit__(exc_type, exc_val, exc_tb) - self._ctx = None - - async def _run_scenario_setup(self) -> None: - """Run the scenario's setup phase (get prompt).""" - if self._ctx is None or self.scenario is None: - return - - # Store scenario name on context for ctx.submit() - self._ctx._scenario_name = self.scenario - - # Delegate to ScenarioMixin.run_scenario_setup - prompt = await self._ctx.run_scenario_setup(self.scenario, self.args) - if prompt: - self._ctx.prompt = prompt - - async def _run_scenario_evaluate(self) -> None: - """Run the scenario's evaluate phase (get reward).""" - if self._ctx is None or self.scenario is None: - return - - # Delegate to ScenarioMixin.run_scenario_evaluate - reward = await self._ctx.run_scenario_evaluate(self.scenario) - if reward is not None: - self._ctx.reward = reward diff --git a/hud/eval/types.py b/hud/eval/types.py index d844eb10..d3ececb0 100644 --- a/hud/eval/types.py +++ b/hud/eval/types.py @@ -32,10 +32,10 @@ class EvalPayload(BaseModel): prompt: str | None = None code_snippet: str | None = None - job_name: str | None = None job_id: str | None = None group_id: str | None = None variants: dict[str, Any] | None = None + task_version_id: str | None = None class EvalExitPayload(EvalPayload): @@ -46,8 +46,17 @@ class EvalExitPayload(EvalPayload): error_message: str | None = None +class JobEnterPayload(BaseModel): + """Payload for job/{job_id}/enter - sent once at job start.""" + + name: str | None = None + variants: dict[str, Any] | None = None # Full variant config + group: int | None = None + + __all__ = [ "EvalExitPayload", "EvalPayload", + "JobEnterPayload", "ParallelEvalComplete", ] diff --git a/hud/server/router.py b/hud/server/router.py index e9859dcb..987fb611 100644 --- a/hud/server/router.py +++ b/hud/server/router.py @@ -140,8 
+140,12 @@ async def _functions_catalogue() -> list[str]: self._resource_manager.add_resource(catalogue_resource) # Override _list_tools to hide internal tools when mounted - async def _list_tools(self) -> list[Tool]: - """Override _list_tools to hide internal tools when mounted.""" + async def _list_tools(self, context: Any = None) -> list[Tool]: + """Override _list_tools to hide internal tools when mounted. + + Args: + context: MiddlewareContext passed by FastMCP (optional for backwards compat) + """ return [ tool for key, tool in self._tool_manager._tools.items() diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index e28e395d..0cd2aa46 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -220,7 +220,7 @@ async def test_run_dataset_from_source_string(self): mock_ctx.results = None with ( - patch("hud.datasets.runner.load_dataset", return_value=mock_tasks) as mock_load, + patch("hud.datasets.loader.load_dataset", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) diff --git a/hud/tools/base.py b/hud/tools/base.py index 95e1fa4a..faa475de 100644 --- a/hud/tools/base.py +++ b/hud/tools/base.py @@ -416,8 +416,12 @@ def _update_dispatcher_description(self) -> None: } # Override _list_tools to hide internal tools when mounted - async def _list_tools(self) -> list[Tool]: - """Override _list_tools to hide internal tools when mounted.""" + async def _list_tools(self, context: Any = None) -> list[Tool]: + """Override _list_tools to hide internal tools when mounted. + + Args: + context: MiddlewareContext passed by FastMCP (optional for backwards compat) + """ return [ tool for key, tool in self._tool_manager._tools.items() From 19b09e1b8c828b42c47eda7cbf1e6f09a0b2a7ac Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 02:56:26 -0800 Subject: [PATCH 45/92] run task accepts old configs --- hud/agents/openai_chat.py | 12 +++- hud/datasets/__init__.py | 3 +- hud/datasets/runner.py | 130 ++++++++++++++++++++++++++++++++++++-- hud/datasets/utils.py | 23 +++++-- hud/eval/context.py | 14 ++-- hud/eval/manager.py | 13 +++- 6 files changed, 173 insertions(+), 22 deletions(-) diff --git a/hud/agents/openai_chat.py b/hud/agents/openai_chat.py index 486d9a53..92f51855 100644 --- a/hud/agents/openai_chat.py +++ b/hud/agents/openai_chat.py @@ -6,6 +6,7 @@ Key points: - Stateless, no special server-side conversation state is assumed. +- Defaults to HUD inference gateway (inference.hud.ai) when HUD_API_KEY is set - Accepts an :class:`openai.AsyncOpenAI` client, caller can supply their own base_url / api_key (e.g. 
llama.cpp, together.ai, …) - All HUD features (step_count, OTel spans, tool filtering, screenshots, …) @@ -24,6 +25,7 @@ from pydantic import ConfigDict, Field from hud import instrument +from hud.settings import settings from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult from hud.utils.hud_console import HUDConsole from hud.utils.types import with_signature @@ -73,10 +75,16 @@ def __init__(self, params: OpenAIChatCreateParams | None = None, **kwargs: Any) self.oai = self.config.openai_client elif self.config.api_key is not None or self.config.base_url is not None: self.oai = AsyncOpenAI(api_key=self.config.api_key, base_url=self.config.base_url) + elif settings.api_key: + # Default to HUD inference gateway + self.oai = AsyncOpenAI( + api_key=settings.api_key, + base_url=settings.hud_gateway_url, + ) else: raise ValueError( - "Either openai_client or api_key must be provided. " - "Set OPENAI_API_KEY environment variable or pass api_key explicitly." + "No API key found. Set HUD_API_KEY for HUD gateway, " + "or provide api_key/base_url/openai_client explicitly." ) self.completion_kwargs = dict(self.config.completion_kwargs) diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py index eb8d040e..15b8c19a 100644 --- a/hud/datasets/__init__.py +++ b/hud/datasets/__init__.py @@ -11,7 +11,7 @@ from hud.utils.tasks import save_tasks from .loader import load_dataset -from .runner import run_dataset, run_tasks +from .runner import run_dataset, run_single_task, run_tasks from .utils import ( BatchRequest, SingleTaskRequest, @@ -28,6 +28,7 @@ "display_results", "load_dataset", "run_dataset", + "run_single_task", "run_tasks", "save_tasks", "submit_rollouts", diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index fd8492d7..028e118b 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any import hud -from hud.types import AgentType +from hud.types import AgentType, Trace if TYPE_CHECKING: from hud.agents import MCPAgent @@ -59,7 +59,7 @@ async def run_tasks( async def run_dataset( - tasks: str | list[Task], + tasks: str | list[Task] | list[dict[str, Any]] | Task | dict[str, Any], agent: MCPAgent, *, max_steps: int = 10, @@ -71,8 +71,10 @@ async def run_dataset( This is the primary entry point for running evaluations programmatically. Args: - tasks: Either a source string (file path, API slug) or list of Task objects. - If a string, tasks are loaded via load_dataset(). + tasks: Tasks to run. Can be: + - A source string (file path, API slug) - loaded via load_dataset() + - A single Task object or dict (v4 or v5 format) + - A list of Task objects or dicts (v4 or v5 format) agent: The agent instance to run. max_steps: Maximum steps per task. max_concurrent: Maximum concurrent tasks (for parallel execution). 
@@ -98,10 +100,27 @@ async def run_dataset( print(f"Reward: {ctx.reward}") ``` """ - from hud.datasets.loader import load_dataset + from hud.datasets.loader import _task_from_dict, load_dataset + from hud.eval.task import Task - # Load tasks if string provided - task_list = load_dataset(tasks) if isinstance(tasks, str) else tasks + # Normalize tasks to list[Task] + if isinstance(tasks, str): + task_list = load_dataset(tasks) + elif isinstance(tasks, Task): + task_list = [tasks] + elif isinstance(tasks, dict): + task_list = [_task_from_dict(tasks)] + elif isinstance(tasks, list): + task_list = [] + for t in tasks: + if isinstance(t, Task): + task_list.append(t) + elif isinstance(t, dict): + task_list.append(_task_from_dict(t)) + else: + raise TypeError(f"Expected Task or dict, got {type(t)}") + else: + raise TypeError(f"Expected str, Task, dict, or list, got {type(tasks)}") if not task_list: raise ValueError("No tasks to run") @@ -120,3 +139,100 @@ async def run_dataset( return ctx.results return [ctx] + + +async def run_single_task( + task: Task | dict[str, Any], + *, + agent_type: AgentType, + agent_params: dict[str, Any] | None = None, + max_steps: int = 10, + job_id: str | None = None, + task_id: str | None = None, + group_id: str | None = None, + trace_name: str | None = None, + metadata: dict[str, Any] | None = None, + trace_id: str | None = None, +) -> Trace: + """Run a single task with full control over eval context parameters. + + This is the low-level entry point for running individual tasks with explicit + trace/job/group IDs. Useful for remote execution workers. + + Args: + task: Task to run. Can be a Task object or dict (v4 or v5 format). + agent_type: AgentType enum specifying the agent to use. + agent_params: Parameters passed to agent.create(). Should include + pre-configured model_client for inference gateway usage. + max_steps: Maximum steps allowed for the agent. + job_id: HUD job identifier for telemetry association. + task_id: Task identifier (used in trace name if trace_name not provided). + group_id: Optional group identifier for parallel runs. + trace_name: Name for the trace (defaults to task_id or task.id). + metadata: Additional metadata for the trace context. + trace_id: Pre-assigned trace ID (if provided by backend). + + Returns: + Trace result from the agent run. 
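The input-normalization block above makes run_dataset shape-agnostic: a dataset slug, a single Task or dict, or a list of either all reduce to the same list[Task] before hud.eval is invoked. A hedged sketch of the call shapes this enables; the agent, env name, scenario, and dataset slug are placeholders, not real fixtures:

```python
# Hedged sketch: each call below should normalize to list[Task] internally.
# `agent`, "browser", "checkout", and the dataset slug are placeholders.
from hud.datasets import run_dataset
from hud.eval.task import Task


async def demo(agent) -> None:
    item = {"env": {"name": "browser"}, "scenario": "checkout"}

    await run_dataset(item, agent)                 # single v5 dict
    await run_dataset(Task(**item), agent)         # single Task object
    await run_dataset([item, item], agent)         # list of dicts
    await run_dataset("my-org/my-evalset", agent)  # slug -> load_dataset()
```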
+ + Example: + ```python + from hud.datasets import run_single_task + from hud.types import AgentType + from openai import AsyncOpenAI + + # Configure agent with inference gateway + agent_params = { + "checkpoint_name": "gpt-4o", + "validate_api_key": False, + "model_client": AsyncOpenAI( + api_key=hud_api_key, + base_url=settings.hud_gateway_url, + ), + } + + result = await run_single_task( + task={"env": {"name": "browser"}, "scenario": "find_page"}, + agent_type=AgentType.OPENAI, + agent_params=agent_params, + max_steps=20, + job_id="job-123", + task_id="task-456", + ) + ``` + """ + from hud.datasets.loader import _task_from_dict + from hud.eval.task import Task as TaskCls + + # Normalize task to Task object + if isinstance(task, dict): + task_obj = _task_from_dict(task) + elif isinstance(task, TaskCls): + task_obj = task + else: + raise TypeError(f"Expected Task or dict, got {type(task)}") + + # Create agent + agent_cls = agent_type.cls + agent = agent_cls.create(**(agent_params or {})) + + # Determine trace name + effective_trace_name = trace_name or task_id or task_obj.id or "single_task" + + # Run with explicit eval context parameters + async with hud.eval( + task_obj, + name=effective_trace_name, + job_id=job_id, + group_id=group_id, + trace_id=trace_id, + ) as ctx: + # Store metadata if provided + if metadata: + for key, value in metadata.items(): + setattr(ctx, f"_meta_{key}", value) + + result = await agent.run(ctx, max_steps=max_steps) + ctx.reward = result.reward + + return result diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py index 1ac829dd..04260186 100644 --- a/hud/datasets/utils.py +++ b/hud/datasets/utils.py @@ -21,7 +21,7 @@ class SingleTaskRequest(BaseModel): """Request to run a single task remotely - mirrors run_single_task() args.""" task: dict[str, Any] = Field( - description="Task definition compatible with hud.types.LegacyTask.", + description="Task definition (v4 LegacyTask or v5 Task format).", ) agent_type: AgentType = Field(description="Agent type to execute the task.") agent_params: dict[str, Any] = Field( @@ -32,20 +32,29 @@ class SingleTaskRequest(BaseModel): ) max_steps: int = Field(default=10, description="Maximum steps allowed for the agent.") job_id: str = Field(description="HUD job identifier for telemetry association.") - task_id: str = Field(description="Task identifier.") - trace_name: str = Field(description="Trace name.") + task_id: str | None = Field(default=None, description="Task identifier.") + trace_name: str | None = Field(default=None, description="Trace name.") group_id: str | None = Field(default=None, description="Optional HUD group identifier.") metadata: dict[str, Any] = Field( default_factory=dict, description="Additional metadata to inject into the trace context.", ) + trace_id: str | None = Field(default=None, description="Pre-assigned trace ID.") @model_validator(mode="after") def _validate_task(self) -> SingleTaskRequest: - try: - LegacyTask(**self.task) - except Exception as exc: - raise ValueError(f"Invalid task payload: {exc}") from exc + """Validate task is either v4 LegacyTask or v5 Task format.""" + from hud.datasets.loader import _is_legacy_task_format + + # v4 format: prompt + mcp_config + if _is_legacy_task_format(self.task): + try: + LegacyTask(**self.task) + except Exception as exc: + raise ValueError(f"Invalid legacy task payload: {exc}") from exc + # v5 format: env required + elif "env" not in self.task: + raise ValueError("Task must have 'env' (v5) or 'prompt'+'mcp_config' (v4)") return self 
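With the relaxed validator above, SingleTaskRequest should accept either task generation: the v4 prompt+mcp_config shape or the v5 env shape, rejecting payloads that match neither. A hedged sketch of both request forms; the env name, prompt, URL, and job id are illustrative only:

```python
# Hedged sketch: both payload generations should pass _validate_task after this
# patch; a dict with neither "env" nor prompt+mcp_config raises ValueError.
from hud.datasets.utils import SingleTaskRequest
from hud.types import AgentType

v5_request = SingleTaskRequest(
    task={"env": {"name": "browser"}, "scenario": "checkout"},
    agent_type=AgentType.OPENAI,
    job_id="job-123",
)

v4_request = SingleTaskRequest(
    task={
        "prompt": "Find the pricing page",
        "mcp_config": {"hud": {"url": "https://mcp.example.com"}},
    },
    agent_type=AgentType.OPENAI,
    job_id="job-123",
)
```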
@field_validator("job_id") diff --git a/hud/eval/context.py b/hud/eval/context.py index 5163ef36..1775c18c 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -57,7 +57,7 @@ class EvalContext(Environment): variants: Variant assignment dict (for A/B testing) reward: Reward value (user-settable) error: Exception if failed - results: All eval results (for parallel execution) + results: All eval results (populated for parallel execution, empty for single) task: Task definition (if loaded from slug) Example: @@ -138,8 +138,8 @@ def __init__( # Error tracking self.error: BaseException | None = None - # Parallel results - self.results: list[EvalContext] | None = None + # Parallel results (empty list for single evals, populated for parallel) + self.results: list[EvalContext] = [] # Code snippet for reproducibility self.code_snippet: str | None = code_snippet @@ -235,6 +235,7 @@ def from_task( cls, task: Task, *, + name: str | None = None, trace_id: str | None = None, api_key: str | None = None, job_id: str | None = None, @@ -249,6 +250,7 @@ def from_task( Args: task: Task config (env, scenario, args) + name: Override for eval/trace name (defaults to task scenario/args) trace_id: Unique trace ID api_key: API key for backend calls job_id: Job ID to link to @@ -259,9 +261,13 @@ def from_task( trace: Whether to send traces to backend quiet: Whether to suppress output """ + from hud.environment import Environment from hud.eval.task import build_eval_name - eval_name = build_eval_name(task.scenario, task.args) + eval_name = name or build_eval_name(task.scenario, task.args) + + # task.env is guaranteed to be Environment after Task.__post_init__ + assert isinstance(task.env, Environment), "Task.env should be Environment" ctx = cls.from_environment( env=task.env, diff --git a/hud/eval/manager.py b/hud/eval/manager.py index e5bb18dd..34fa7690 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -94,10 +94,13 @@ def _send_job_enter( async def run_eval( source: Task | list[Task] | None = None, *, + name: str | None = None, variants: dict[str, Any] | None = None, group: int = 1, group_ids: list[str] | None = None, job_id: str | None = None, + group_id: str | None = None, + trace_id: str | None = None, api_key: str | None = None, max_concurrent: int | None = None, trace: bool = True, @@ -115,10 +118,13 @@ async def run_eval( - list[Task]: List of Task objects - LegacyTask: Single LegacyTask object (deprecated, use Task.from_v4()) - list[LegacyTask]: List of LegacyTask objects (deprecated) + name: Optional name for the eval (used in trace) variants: A/B test configuration (dict with list values expanded) group: Runs per variant for statistical significance group_ids: Optional list of group IDs job_id: Job ID to link to + group_id: Group ID for parallel evaluations + trace_id: Pre-assigned trace ID (auto-generated if not provided) api_key: API key for backend calls max_concurrent: Maximum concurrent evals (None = unlimited) trace: Whether to send trace data to backend (default True) @@ -232,8 +238,11 @@ async def run_eval( # Single task - use EvalContext.from_task() ctx = EvalContext.from_task( tasks[0], + name=name, + trace_id=trace_id, api_key=api_key, job_id=job_id, + group_id=group_id, variants=variant_combos[0], code_snippet=code_snippet, trace=trace, @@ -244,9 +253,11 @@ async def run_eval( else: # Blank eval - use EvalContext directly ctx = EvalContext( - name="eval", + name=name or "eval", + trace_id=trace_id, api_key=api_key, job_id=job_id, + group_id=group_id, 
variants=variant_combos[0], code_snippet=code_snippet, trace=trace, From a32964a6a6397a525e3ccfe3608a79bb09289e06 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 02:56:50 -0800 Subject: [PATCH 46/92] integartion test warning --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8333e6d9..e4bff6e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -248,3 +248,6 @@ asyncio_mode = "auto" testpaths = ["hud", "examples"] # Ignore the dev folder and other non-test directories addopts = "--ignore=dev --ignore=ref --ignore=test_env --ignore=environments" +markers = [ + "integration: marks tests as integration tests (require HUD_API_KEY, network access)", +] From a76e099a1af4a9839f421986cf5e55b6ebee89f7 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 03:55:42 -0800 Subject: [PATCH 47/92] task loading improvements --- hud/cli/dev.py | 23 +++++-- hud/environment/environment.py | 4 +- hud/eval/task.py | 120 ++++++++++++++++++--------------- 3 files changed, 86 insertions(+), 61 deletions(-) diff --git a/hud/cli/dev.py b/hud/cli/dev.py index a7564360..9a91db52 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -101,10 +101,23 @@ def show_dev_server_info( return cursor_deeplink +def _has_mcp_or_env(content: str) -> bool: + """Check if file content defines an mcp or env variable.""" + # Check for mcp = MCPServer(...) or mcp = FastMCP(...) + if "mcp" in content and ("= MCPServer" in content or "= FastMCP" in content): + return True + # Check for env = Environment(...) + if "env" in content and "= Environment" in content: + return True + return False + + def auto_detect_module() -> tuple[str, Path | None] | tuple[None, None]: """Auto-detect MCP module in current directory. - Looks for 'mcp' defined in either __init__.py or server.py. + Looks for 'mcp' or 'env' defined in either __init__.py or main.py. 
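The pyproject.toml hunk above registers an `integration` marker so network-dependent tests can be tagged and filtered. A hedged sketch of how a test might opt in; the HUD_API_KEY guard and test body are illustrative, not taken from the repo, and with `asyncio_mode = "auto"` already configured the async test needs no extra marker:

```python
# Hedged sketch: tag a live-API test with the newly registered marker.
# The HUD_API_KEY guard and the test body are illustrative placeholders.
import os

import pytest


@pytest.mark.integration
@pytest.mark.skipif("HUD_API_KEY" not in os.environ, reason="requires HUD_API_KEY")
async def test_round_trip_against_live_api() -> None:
    ...
```

Local runs can then deselect these with `pytest -m "not integration"`.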
+ - 'mcp' with MCPServer or FastMCP + - 'env' with Environment Returns: Tuple of (module_name, parent_dir_to_add_to_path) or (None, None) @@ -116,7 +129,7 @@ def auto_detect_module() -> tuple[str, Path | None] | tuple[None, None]: if init_file.exists(): try: content = init_file.read_text(encoding="utf-8") - if "mcp" in content and ("= MCPServer" in content or "= FastMCP" in content): + if _has_mcp_or_env(content): return (cwd.name, None) except Exception: # noqa: S110 pass @@ -126,7 +139,7 @@ def auto_detect_module() -> tuple[str, Path | None] | tuple[None, None]: if main_file.exists() and init_file.exists(): try: content = main_file.read_text(encoding="utf-8") - if "mcp" in content and ("= MCPServer" in content or "= FastMCP" in content): + if _has_mcp_or_env(content): # Need to import as package.main, add parent to sys.path return (f"{cwd.name}.main", cwd.parent) except Exception: # noqa: S110 @@ -899,11 +912,11 @@ def run_mcp_dev_server( if module is None: module, extra_path = auto_detect_module() if module is None: - hud_console.error("Could not auto-detect MCP module in current directory") + hud_console.error("Could not auto-detect module in current directory") hud_console.info("") hud_console.info("[bold cyan]Expected:[/bold cyan]") hud_console.info(" • __init__.py file in current directory") - hud_console.info(" • Module must define 'mcp' variable") + hud_console.info(" • Module must define 'mcp' or 'env' variable") hud_console.info("") hud_console.info("[bold cyan]Examples:[/bold cyan]") hud_console.info(" hud dev controller") diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 4d6804d7..1dbeb6fe 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -533,7 +533,7 @@ def __repr__(self) -> str: def __call__( self, - scenario: str, + scenario: str | None = None, **args: Any, ) -> Task: """Create a Task from this environment. @@ -541,7 +541,7 @@ def __call__( Returns a Task that can be passed to hud.eval() for orchestration. Args: - scenario: Scenario name to run (from @env.scenario) + scenario: Scenario name to run (from @env.scenario). Optional for v4 legacy. **args: Arguments for the scenario Returns: diff --git a/hud/eval/task.py b/hud/eval/task.py index 8b011e68..421cb075 100644 --- a/hud/eval/task.py +++ b/hud/eval/task.py @@ -1,4 +1,4 @@ -"""Task - A runnable evaluation unit (data class). +"""Task - A runnable evaluation unit (Pydantic model). A Task holds the configuration needed to run an evaluation: - Environment configuration (how to create/connect) @@ -26,9 +26,10 @@ from __future__ import annotations import logging -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any +from pydantic import BaseModel, ConfigDict, Field, field_validator + from hud.types import MCPToolCall if TYPE_CHECKING: @@ -85,9 +86,8 @@ def build_eval_name(scenario: str | None, args: dict[str, Any] | None) -> str: return scenario -@dataclass -class Task: - """A runnable evaluation unit (data class). +class Task(BaseModel): + """A runnable evaluation unit (Pydantic model). Simplified v5 Task format: - env: Environment instance OR EnvConfig with hub name + filters @@ -97,9 +97,9 @@ class Task: When entered as a context manager, creates an EvalContext. 
- Attributes: + Attributes: id: Optional task identifier for filtering/tracking - env: Environment instance (auto-created from dict/EnvConfig in __post_init__) + env: Environment instance (auto-created from dict/EnvConfig via validator) scenario: Scenario name to run (from @env.scenario) args: Scenario arguments validation: Optional list of MCPToolCall objects representing successful completion @@ -131,59 +131,67 @@ class Task: task = Task.from_v4({"prompt": "...", "mcp_config": {...}, ...}) ``` """ - # Required - env: Environment | EnvConfig | dict[str, Any] - # Optional + + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Fields - env accepts Environment | EnvConfig | dict, auto-converts to Environment + env: Any = Field(default=None) # Typed as Any for input flexibility, validated below scenario: str | None = None id: str | None = None - args: dict[str, Any] = field(default_factory=dict) + args: dict[str, Any] = Field(default_factory=dict) validation: list[MCPToolCall] | None = None - def __post_init__(self) -> None: - """Validate and normalize env and validation fields after initialization. - - Auto-converts dict or EnvConfig to Environment by connecting to the hub. - Auto-converts validation dicts to MCPToolCall objects. - """ + @field_validator("env", mode="before") + @classmethod + def convert_env( + cls, v: Environment | EnvConfig | dict[str, Any] | None + ) -> Environment | None: + """Auto-convert dict/EnvConfig to Environment.""" from hud.environment import Environment from hud.environment.types import EnvConfig - # Convert env field (dict/EnvConfig -> Environment) - if not isinstance(self.env, Environment): - if isinstance(self.env, dict): - try: - config = EnvConfig(**self.env) - except Exception as e: - raise ValueError( - f"Invalid env config: {e}. Expected fields: name (str), " - f"include (list[str] | None), exclude (list[str] | None)" - ) from e - elif isinstance(self.env, EnvConfig): - config = self.env + if v is None: + return None + if isinstance(v, Environment): + return v + if isinstance(v, dict): + try: + v = EnvConfig(**v) + except Exception as e: + raise ValueError( + f"Invalid env config: {e}. Expected fields: name (str), " + f"include (list[str] | None), exclude (list[str] | None)" + ) from e + if isinstance(v, EnvConfig): + env = Environment(v.name) + env.connect_hub(v.name, include=v.include, exclude=v.exclude) + return env + raise TypeError( + f"Task.env must be Environment, EnvConfig, or dict. Got {type(v).__name__}" + ) + + @field_validator("validation", mode="before") + @classmethod + def convert_validation( + cls, v: list[MCPToolCall | dict[str, Any]] | None + ) -> list[MCPToolCall] | None: + """Auto-convert validation dicts to MCPToolCall objects.""" + if v is None: + return None + if not isinstance(v, list): + raise TypeError(f"validation must be a list, got {type(v).__name__}") + + converted = [] + for item in v: + if isinstance(item, dict): + converted.append(MCPToolCall(**item)) + elif isinstance(item, MCPToolCall): + converted.append(item) else: raise TypeError( - f"Task.env must be Environment, EnvConfig, or dict. 
" - f"Got {type(self.env).__name__}" + f"validation items must be dict or MCPToolCall, got {type(item).__name__}" ) - - # Convert EnvConfig to Environment - env = Environment(config.name) - env.connect_hub(config.name, include=config.include, exclude=config.exclude) - self.env = env - - # Convert validation dicts to MCPToolCall objects - if self.validation and isinstance(self.validation, list): - converted_validation = [] - for item in self.validation: - if isinstance(item, dict): - converted_validation.append(MCPToolCall(**item)) - elif isinstance(item, MCPToolCall): - converted_validation.append(item) - else: - raise TypeError( - f"validation items must be dict or MCPToolCall, got {type(item).__name__}" - ) - self.validation = converted_validation + return converted @classmethod def from_v4( @@ -302,11 +310,15 @@ def from_v4( ) def copy(self) -> Task: - """Create a copy of this Task config.""" + """Create a copy of this Task config. + + Note: env is shared (not deep copied) since Environment instances + should be reused. Args and validation are deep copied. + """ return Task( id=self.id, - env=self.env, # Share reference - from_environment handles copying + env=self.env, # Share reference scenario=self.scenario, - args=self.args.copy(), - validation=self.validation, + args=self.args.copy() if self.args else {}, + validation=self.validation.copy() if self.validation else None, ) From 1972d48041e2bdc88dfb35f209b5d3f75beea079 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 03:57:54 -0800 Subject: [PATCH 48/92] eval test --- hud/eval/tests/test_eval.py | 122 ++---------------------------------- 1 file changed, 4 insertions(+), 118 deletions(-) diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index 2ee293fb..856a69d8 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -10,7 +10,7 @@ class TestTaskDataclass: - """Tests for Task as a data class.""" + """Tests for Task as a Pydantic model.""" def test_init_defaults(self) -> None: """Task initializes with sensible defaults.""" @@ -19,11 +19,9 @@ def test_init_defaults(self) -> None: assert task.env is None assert task.scenario is None assert task.args == {} - assert task.variants == {} - assert task.index == 0 def test_init_with_env_dict(self) -> None: - """Task auto-converts env dict to Environment in __post_init__.""" + """Task auto-converts env dict to Environment via validator.""" from hud.environment import Environment task = Task( @@ -43,126 +41,14 @@ def test_copy_creates_new_instance(self) -> None: env={"name": "test"}, scenario="checkout", args={"user_id": "alice"}, - variants={"model": "gpt-4o"}, ) copied = original.copy() assert copied is not original - assert copied.env == original.env + assert copied.env is original.env # Env reference is shared (intentional) assert copied.scenario == original.scenario assert copied.args == original.args - assert copied.args is not original.args # Deep copy - assert copied.variants == original.variants - assert copied.variants is not original.variants # Deep copy - - def test_copy_clears_trace_id(self) -> None: - """copy() clears trace_id for fresh instance.""" - original = Task(trace_id="original-trace") - copied = original.copy() - - assert copied.trace_id is None - - -class TestTaskToEvalContext: - """Tests for Task.to_eval_context().""" - - def test_creates_eval_context(self) -> None: - """to_eval_context() creates an EvalContext.""" - from hud.eval.context import EvalContext - - task = Task(scenario="checkout") - ctx = 
task.to_eval_context() - - assert isinstance(ctx, EvalContext) - assert ctx.eval_name == "checkout" - - def test_uses_eval_as_name_when_no_scenario(self) -> None: - """to_eval_context() uses 'eval' as name when no scenario.""" - task = Task() - ctx = task.to_eval_context() - - assert ctx.eval_name == "eval" - - def test_passes_through_properties(self) -> None: - """to_eval_context() passes through properties.""" - task = Task( - scenario="checkout", - trace_id="test-trace", - api_key="test-key", - job_id="test-job", - group_id="test-group", - index=5, - variants={"model": "gpt-4o"}, - ) - ctx = task.to_eval_context() - - assert ctx.trace_id == "test-trace" - assert ctx._eval_api_key == "test-key" - assert ctx.job_id == "test-job" - assert ctx.group_id == "test-group" - assert ctx.index == 5 - assert ctx.variants == {"model": "gpt-4o"} - - -class TestTaskContextManager: - """Tests for Task as async context manager.""" - - @pytest.mark.asyncio - async def test_aenter_returns_eval_context(self) -> None: - """__aenter__ returns an EvalContext.""" - from hud.eval.context import EvalContext - - task = Task() # No scenario to avoid scenario lookup - - with ( - patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), - patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), - patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), - patch.object(EvalContext, "_print_eval_link"), # Suppress link printing - ): - ctx = await task.__aenter__() - assert isinstance(ctx, EvalContext) - # Clean up manually since we patched __aexit__ - task._ctx = None - - @pytest.mark.asyncio - async def test_context_clears_on_exit(self) -> None: - """__aexit__ clears internal context reference.""" - from hud.eval.context import EvalContext - - task = Task() - - with ( - patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), - patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), - patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), - patch.object(EvalContext, "_print_eval_link"), # Suppress link printing - ): - await task.__aenter__() - assert task._ctx is not None - - # Manually call __aexit__ on Task (which will call mocked ctx.__aexit__) - await task.__aexit__(None, None, None) - assert task._ctx is None - - @pytest.mark.asyncio - async def test_reward_accessible_after_exit(self) -> None: - """Reward set in context is accessible after exit.""" - from hud.eval.context import EvalContext - - task = Task() - - with ( - patch.object(EvalContext, "_eval_enter", new_callable=AsyncMock), - patch.object(EvalContext, "_eval_exit", new_callable=AsyncMock), - patch.object(EvalContext, "__aexit__", new_callable=AsyncMock), - patch.object(EvalContext, "_print_eval_link"), # Suppress link printing - ): - ctx = await task.__aenter__() - ctx.reward = 0.95 - - await task.__aexit__(None, None, None) - # Context reference is cleared but reward was set on the actual context + assert copied.args is not original.args # Args are deep copied class TestEnvironmentCall: From 186c4b0698b5e41cc03158b6db9b5c93a939d901 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 07:04:02 -0800 Subject: [PATCH 49/92] Huge cleanup, new telemetry and backwards compatibility --- hud/__init__.py | 18 + hud/agents/__init__.py | 12 +- hud/agents/base.py | 15 - hud/agents/claude.py | 4 +- hud/agents/gemini.py | 2 +- hud/agents/gemini_cua.py | 4 +- hud/agents/openai.py | 6 +- hud/agents/openai_chat.py | 8 +- hud/agents/operator.py | 9 +- hud/agents/tests/test_claude.py | 36 +- 
hud/agents/tests/test_client.py | 19 +- hud/agents/tests/test_gemini.py | 23 +- .../tests/test_grounded_openai_agent.py | 1 - hud/agents/tests/test_openai.py | 18 +- hud/agents/tests/test_operator.py | 3 +- hud/agents/tests/test_run_eval.py | 6 +- hud/cli/flows/tasks.py | 49 +- hud/cli/rft.py | 10 +- hud/cli/tests/test_dev.py | 38 +- hud/clients/__init__.py | 10 +- hud/clients/base.py | 17 - hud/clients/tests/test_analyze_scenarios.py | 2 +- hud/datasets/__init__.py | 20 +- hud/datasets/loader.py | 185 +++--- hud/datasets/runner.py | 152 ++--- hud/datasets/tests/test_runner.py | 67 -- hud/datasets/tests/test_utils.py | 100 ++- hud/datasets/utils.py | 311 +++------- hud/environment/connectors/mcp_config.py | 6 + hud/environment/connectors/remote.py | 10 + hud/environment/environment.py | 126 +++- hud/environment/integrations/adk.py | 2 +- hud/eval/__init__.py | 6 + hud/eval/context.py | 99 ++- hud/eval/display.py | 215 ++++--- hud/eval/manager.py | 2 +- hud/eval/task.py | 247 ++++---- hud/eval/tests/test_context.py | 6 - hud/eval/types.py | 1 + hud/eval/utils.py | 178 ++++++ hud/otel/__init__.py | 51 -- hud/otel/collector.py | 142 ----- hud/otel/config.py | 183 ------ hud/otel/context.py | 572 ------------------ hud/otel/exporters.py | 543 ----------------- hud/otel/instrumentation.py | 147 ----- hud/otel/processors.py | 121 ---- hud/otel/tests/__init__.py | 0 hud/otel/tests/test_instrumentation.py | 207 ------- hud/otel/tests/test_processors.py | 197 ------ hud/server/server.py | 4 +- hud/telemetry/__init__.py | 32 +- hud/telemetry/async_context.py | 345 ----------- hud/telemetry/exporter.py | 204 +++++++ hud/telemetry/instrument.py | 127 ++-- hud/telemetry/job.py | 355 ----------- hud/telemetry/replay.py | 74 --- hud/telemetry/tests/test_async_context.py | 515 ---------------- hud/telemetry/tests/test_eval_telemetry.py | 354 +++++++++++ hud/telemetry/tests/test_exporter.py | 254 ++++++++ hud/telemetry/tests/test_job.py | 555 ----------------- hud/telemetry/tests/test_replay.py | 40 -- hud/telemetry/tests/test_trace.py | 241 -------- hud/telemetry/trace.py | 166 ----- hud/telemetry/utils.py | 42 -- hud/tests/test_datasets_extended.py | 8 +- hud/tests/test_types.py | 42 -- hud/tools/grounding/grounder.py | 33 +- hud/types.py | 46 +- hud/utils/mcp.py | 52 -- hud/utils/tasks.py | 128 +--- hud/utils/tests/test_mcp.py | 25 +- hud/utils/tests/test_tasks.py | 356 ----------- scripts/pre_release_check.py | 2 +- 74 files changed, 2043 insertions(+), 6163 deletions(-) delete mode 100644 hud/datasets/tests/test_runner.py create mode 100644 hud/eval/utils.py delete mode 100644 hud/otel/__init__.py delete mode 100644 hud/otel/collector.py delete mode 100644 hud/otel/config.py delete mode 100644 hud/otel/context.py delete mode 100644 hud/otel/exporters.py delete mode 100644 hud/otel/instrumentation.py delete mode 100644 hud/otel/processors.py delete mode 100644 hud/otel/tests/__init__.py delete mode 100644 hud/otel/tests/test_instrumentation.py delete mode 100644 hud/otel/tests/test_processors.py delete mode 100644 hud/telemetry/async_context.py create mode 100644 hud/telemetry/exporter.py delete mode 100644 hud/telemetry/job.py delete mode 100644 hud/telemetry/replay.py delete mode 100644 hud/telemetry/tests/test_async_context.py create mode 100644 hud/telemetry/tests/test_eval_telemetry.py create mode 100644 hud/telemetry/tests/test_exporter.py delete mode 100644 hud/telemetry/tests/test_job.py delete mode 100644 hud/telemetry/tests/test_replay.py delete mode 100644 hud/telemetry/tests/test_trace.py 
delete mode 100644 hud/telemetry/trace.py delete mode 100644 hud/telemetry/utils.py delete mode 100644 hud/utils/tests/test_tasks.py diff --git a/hud/__init__.py b/hud/__init__.py index be6e8ee9..cf88add5 100644 --- a/hud/__init__.py +++ b/hud/__init__.py @@ -5,6 +5,8 @@ from __future__ import annotations +import warnings + # Apply patches to third-party libraries early, before other imports from . import patches as _patches # noqa: F401 from .environment import Environment @@ -12,11 +14,27 @@ from .eval import run_eval as eval from .telemetry.instrument import instrument + +def trace(*args: object, **kwargs: object) -> EvalContext: + """Deprecated: Use hud.eval() instead. + + .. deprecated:: 0.5.0 + hud.trace() is deprecated. Use hud.eval() or env.eval() instead. + """ + warnings.warn( + "hud.trace() is deprecated. Use hud.eval() or env.eval() instead.", + DeprecationWarning, + stacklevel=2, + ) + return eval(*args, **kwargs) # type: ignore[arg-type] + + __all__ = [ "Environment", "EvalContext", "eval", "instrument", + "trace", # Deprecated alias for eval ] try: diff --git a/hud/agents/__init__.py b/hud/agents/__init__.py index edcd569c..547d876b 100644 --- a/hud/agents/__init__.py +++ b/hud/agents/__init__.py @@ -1,17 +1,17 @@ from __future__ import annotations from .base import MCPAgent -from .claude import ClaudeAgent -from .gemini import GeminiAgent -from .gemini_cua import GeminiCUAAgent from .openai import OpenAIAgent from .openai_chat import OpenAIChatAgent from .operator import OperatorAgent +# Note: These agents are not exported here to avoid requiring optional dependencies. +# Import directly if needed: +# from hud.agents.claude import ClaudeAgent # requires anthropic +# from hud.agents.gemini import GeminiAgent # requires google-genai +# from hud.agents.gemini_cua import GeminiCUAAgent # requires google-genai + __all__ = [ - "ClaudeAgent", - "GeminiAgent", - "GeminiCUAAgent", "MCPAgent", "OpenAIAgent", "OpenAIChatAgent", diff --git a/hud/agents/base.py b/hud/agents/base.py index 3189dde7..831c59ce 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -30,7 +30,6 @@ class BaseCreateParams(BaseModel): # Primary way to bind agent to execution context (v5) ctx: Any | None = None # EvalContext or Environment - agent uses this for tool calls - auto_trace: bool = True auto_respond: bool = False verbose: bool = False @@ -93,10 +92,6 @@ def __init__(self, params: BaseCreateParams | None = None, **kwargs: Any) -> Non self._tool_map: dict[str, types.Tool] = {} self._initialized: bool = False - # Trace - self._auto_trace = params.auto_trace - self._auto_trace_cm: Any | None = None - @classmethod def create(cls, **kwargs: Any) -> MCPAgent: """ @@ -484,16 +479,6 @@ async def _filter_messages( async def _cleanup(self) -> None: """Cleanup resources.""" - # Clean up auto-created trace if any - if self._auto_trace_cm: - try: - self._auto_trace_cm.__exit__(None, None, None) - self.console.debug("Closed auto-created trace") - except Exception as e: - self.console.warning_log(f"Failed to close auto-created trace: {e}") - finally: - self._auto_trace_cm = None - # Clear context reference self.ctx = None diff --git a/hud/agents/claude.py b/hud/agents/claude.py index 9ef7736b..39229693 100644 --- a/hud/agents/claude.py +++ b/hud/agents/claude.py @@ -96,11 +96,11 @@ def _on_tools_ready(self) -> None: """Build Claude-specific tool mappings after tools are discovered.""" self._convert_tools_for_claude() - async def get_system_messages(self) -> list[Any]: + async def get_system_messages(self) -> 
list[BetaMessageParam]: """No system messages for Claude because applied in get_response""" return [] - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: + async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[BetaMessageParam]: """Format messages for Claude.""" # Convert MCP content types to Anthropic content types anthropic_blocks: list[BetaContentBlockParam] = [] diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py index ceda15a4..c405f05a 100644 --- a/hud/agents/gemini.py +++ b/hud/agents/gemini.py @@ -88,7 +88,7 @@ def _on_tools_ready(self) -> None: """Build Gemini-specific tool mappings after tools are discovered.""" self._convert_tools_for_gemini() - async def get_system_messages(self) -> list[Any]: + async def get_system_messages(self) -> list[genai_types.Content]: """No system messages for Gemini because applied in get_response""" return [] diff --git a/hud/agents/gemini_cua.py b/hud/agents/gemini_cua.py index 8ad02b21..75d8da15 100644 --- a/hud/agents/gemini_cua.py +++ b/hud/agents/gemini_cua.py @@ -10,7 +10,7 @@ from pydantic import ConfigDict, Field from hud.tools.computer.settings import computer_settings -from hud.types import BaseAgentConfig, MCPToolCall, MCPToolResult +from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult from hud.utils.types import with_signature from .base import BaseCreateParams, MCPAgent @@ -126,7 +126,7 @@ def _to_gemini_tool(self, tool: types.Tool) -> genai_types.Tool | None: # For non-computer tools, use the parent implementation return super()._to_gemini_tool(tool) - async def get_response(self, messages: list[genai_types.Content]) -> Any: + async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse: """Get response from Gemini including any tool calls. Extends parent to trim old screenshots before making API call. 
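(Illustrative note, not part of the diff: the hud/agents/__init__.py hunk above stops re-exporting the optional-dependency agents. A minimal sketch of the resulting import pattern, exactly as the new module comment describes, assuming the anthropic / google-genai extras are installed where needed:)

    # Core agents are still importable from the package root:
    from hud.agents import MCPAgent, OpenAIAgent, OpenAIChatAgent, OperatorAgent

    # Optional-dependency agents are now imported from their modules directly:
    from hud.agents.claude import ClaudeAgent          # requires anthropic
    from hud.agents.gemini import GeminiAgent          # requires google-genai
    from hud.agents.gemini_cua import GeminiCUAAgent   # requires google-genai
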
diff --git a/hud/agents/openai.py b/hud/agents/openai.py index ac7f36fb..10b2ad12 100644 --- a/hud/agents/openai.py +++ b/hud/agents/openai.py @@ -201,7 +201,7 @@ async def get_system_messages(self) -> list[types.ContentBlock]: """System messages are provided via the `instructions` field.""" return [] - async def format_blocks(self, blocks: list[types.ContentBlock]) -> ResponseInputParam: + async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Message]: """Convert MCP content blocks into OpenAI user messages.""" content: ResponseInputMessageContentListParam = [] for block in blocks: @@ -288,9 +288,9 @@ async def get_response(self, messages: ResponseInputParam) -> AgentResponse: async def format_tool_results( self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> ResponseInputParam: + ) -> list[FunctionCallOutput]: """Convert MCP tool outputs into Responses input items.""" - formatted: ResponseInputParam = [] + formatted: list[FunctionCallOutput] = [] for call, result in zip(tool_calls, tool_results, strict=False): if not call.id: self.console.warning_log(f"Tool '{call.name}' missing call_id; skipping output.") diff --git a/hud/agents/openai_chat.py b/hud/agents/openai_chat.py index 810163ef..c5eeb14d 100644 --- a/hud/agents/openai_chat.py +++ b/hud/agents/openai_chat.py @@ -105,14 +105,14 @@ def _oai_to_mcp(tool_call: Any) -> MCPToolCall: # type: ignore[valid-type] arguments=args, ) - async def get_system_messages(self) -> list[Any]: + async def get_system_messages(self) -> list[dict[str, Any]]: """Get system messages for OpenAI.""" if self.system_prompt is not None: return [{"role": "system", "content": self.system_prompt}] else: return [] - async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: + async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str, Any]]: """Format blocks for OpenAI.""" content = [] for block in blocks: @@ -232,7 +232,7 @@ async def _invoke_chat_completion( record_args=False, record_result=True, ) - async def get_response(self, messages: list[Any]) -> AgentResponse: + async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: """Send chat request to OpenAI and convert the response.""" # Convert MCP tool schemas to OpenAI format @@ -305,7 +305,7 @@ async def format_tool_results( self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult], - ) -> list[Any]: + ) -> list[dict[str, Any]]: """Render MCP tool results as OpenAI messages. Note: OpenAI tool messages only support string content. 
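(Context for the note above, "OpenAI tool messages only support string content": in the Chat Completions API each tool result is a single message whose content is a plain string, so structured MCP results must be flattened to text before being appended. A hedged sketch of the message shape format_tool_results is expected to emit; field values are illustrative, not taken from the patch:)

    # One Chat Completions message per tool call; `content` must be a string.
    tool_message = {
        "role": "tool",
        "tool_call_id": "call_123",              # id from the assistant's tool_call
        "content": "Tool output as plain text",  # structured/image content flattened to text
    }
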
diff --git a/hud/agents/operator.py b/hud/agents/operator.py index eab72772..f16deeb6 100644 --- a/hud/agents/operator.py +++ b/hud/agents/operator.py @@ -15,6 +15,7 @@ ) from openai.types.responses.response_input_param import ( ComputerCallOutput, + FunctionCallOutput, ) from openai.types.shared_params.reasoning import Reasoning from pydantic import ConfigDict @@ -144,10 +145,10 @@ def _extract_tool_call(self, item: Any) -> MCPToolCall | None: async def format_tool_results( self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] - ) -> ResponseInputParam: + ) -> list[ComputerCallOutput | FunctionCallOutput]: remaining_calls: list[MCPToolCall] = [] remaining_results: list[MCPToolResult] = [] - computer_outputs: ResponseInputParam = [] + computer_outputs: list[ComputerCallOutput] = [] ordering: list[tuple[str, int]] = [] for call, result in zip(tool_calls, tool_results, strict=False): @@ -186,8 +187,8 @@ async def format_tool_results( remaining_results.append(result) ordering.append(("function", len(remaining_calls) - 1)) - formatted: ResponseInputParam = [] - function_outputs: ResponseInputParam = [] + formatted: list[ComputerCallOutput | FunctionCallOutput] = [] + function_outputs: list[FunctionCallOutput] = [] if remaining_calls: function_outputs = await super().format_tool_results(remaining_calls, remaining_results) diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py index 6ec9bd7c..125e3e7e 100644 --- a/hud/agents/tests/test_claude.py +++ b/hud/agents/tests/test_claude.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Generator from typing import TYPE_CHECKING, Any, cast from unittest.mock import AsyncMock, MagicMock, patch @@ -107,7 +108,7 @@ class TestClaudeAgent: """Test ClaudeAgent class.""" @pytest.fixture - def mock_anthropic(self) -> AsyncAnthropic: + def mock_anthropic(self) -> Generator[AsyncAnthropic, None, None]: # type: ignore[misc] """Create a stub Anthropic client.""" with ( patch("hud.agents.claude.AsyncAnthropic") as mock_class, @@ -119,7 +120,7 @@ def mock_anthropic(self) -> AsyncAnthropic: client = MagicMock(spec=AsyncAnthropic) client.api_key = "test-key" mock_class.return_value = client - yield client + yield client # type: ignore[misc] @pytest.mark.asyncio async def test_init_with_client(self, mock_anthropic: AsyncAnthropic) -> None: @@ -162,9 +163,11 @@ async def test_format_blocks_text_only(self, mock_anthropic: AsyncAnthropic) -> messages = await agent.format_blocks(blocks) assert len(messages) == 1 assert messages[0]["role"] == "user" - assert len(messages[0]["content"]) == 2 - assert messages[0]["content"][0]["type"] == "text" - assert messages[0]["content"][0]["text"] == "Hello, world!" + content = messages[0]["content"] + assert isinstance(content, list) + assert len(content) == 2 + assert content[0]["type"] == "text" # type: ignore[index] + assert content[0]["text"] == "Hello, world!" 
# type: ignore[index] @pytest.mark.asyncio async def test_format_blocks_with_image(self, mock_anthropic: AsyncAnthropic) -> None: @@ -181,8 +184,10 @@ async def test_format_blocks_with_image(self, mock_anthropic: AsyncAnthropic) -> messages = await agent.format_blocks(blocks) assert len(messages) == 1 - assert len(messages[0]["content"]) == 2 - assert messages[0]["content"][1]["type"] == "image" + content = messages[0]["content"] + assert isinstance(content, list) + assert len(content) == 2 + assert content[1]["type"] == "image" # type: ignore[index] @pytest.mark.asyncio async def test_format_tool_results_text(self, mock_anthropic: AsyncAnthropic) -> None: @@ -204,9 +209,10 @@ async def test_format_tool_results_text(self, mock_anthropic: AsyncAnthropic) -> assert len(messages) == 1 assert messages[0]["role"] == "user" content = messages[0]["content"] + assert isinstance(content, list) assert len(content) == 1 - assert content[0]["type"] == "tool_result" - assert content[0]["tool_use_id"] == "call_123" + assert content[0]["type"] == "tool_result" # type: ignore[index] + assert content[0]["tool_use_id"] == "call_123" # type: ignore[index] @pytest.mark.asyncio async def test_format_tool_results_with_error(self, mock_anthropic: AsyncAnthropic) -> None: @@ -228,7 +234,7 @@ async def test_format_tool_results_with_error(self, mock_anthropic: AsyncAnthrop assert len(messages) == 1 content = messages[0]["content"] # Error content should include "Error:" prefix - assert any("Error" in str(block) for block in content[0]["content"]) + assert any("Error" in str(block) for block in content[0]["content"]) # type: ignore[index] @pytest.mark.asyncio async def test_get_system_messages(self, mock_anthropic: AsyncAnthropic) -> None: @@ -305,7 +311,7 @@ async def test_convert_tools_for_claude(self, mock_anthropic: AsyncAnthropic) -> # Check that tools were converted assert len(agent.claude_tools) == 1 - assert agent.claude_tools[0]["name"] == "my_tool" + assert agent.claude_tools[0]["name"] == "my_tool" # type: ignore[typeddict-item] @pytest.mark.asyncio async def test_computer_tool_detection(self, mock_anthropic: AsyncAnthropic) -> None: @@ -432,7 +438,7 @@ async def test_get_response_bedrock_uses_create_not_stream( text_block.text = "Hello from Bedrock" mock_response.content = [text_block] - bedrock_client.beta.messages.create.return_value = mock_response + bedrock_client.beta.messages.create.return_value = mock_response # type: ignore[union-attr] messages = [ cast( @@ -447,12 +453,12 @@ async def test_get_response_bedrock_uses_create_not_stream( # Bedrock-specific behavior: uses create() and appends assistant message directly. assert not hasattr(bedrock_client.beta.messages, "stream") - bedrock_client.beta.messages.create.assert_awaited_once() + bedrock_client.beta.messages.create.assert_awaited_once() # type: ignore[union-attr] assert len(messages) == 2 assert messages[-1]["role"] == "assistant" # Ensure the Bedrock call shape is stable. 
- _, kwargs = bedrock_client.beta.messages.create.call_args + _, kwargs = bedrock_client.beta.messages.create.call_args # type: ignore[union-attr] assert kwargs["model"] == "test-model-arn" assert kwargs["tool_choice"] == {"type": "auto", "disable_parallel_tool_use": True} assert "fine-grained-tool-streaming-2025-05-14" in kwargs["betas"] @@ -470,7 +476,7 @@ async def test_get_response_bedrock_missing_boto3_raises_value_error( validate_api_key=False, ) - bedrock_client.beta.messages.create.side_effect = ModuleNotFoundError("boto3") + bedrock_client.beta.messages.create.side_effect = ModuleNotFoundError("boto3") # type: ignore[union-attr] messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}] with pytest.raises(ValueError, match=r"boto3 is required for AWS Bedrock"): diff --git a/hud/agents/tests/test_client.py b/hud/agents/tests/test_client.py index 11132506..c4f86fdf 100644 --- a/hud/agents/tests/test_client.py +++ b/hud/agents/tests/test_client.py @@ -15,7 +15,6 @@ logger = logging.getLogger(__name__) -@patch("hud.clients.base.setup_hud_telemetry") class TestMCPClient: """Test MCPClient class.""" @@ -34,7 +33,7 @@ def mock_mcp_use_client(self): yield mock_instance @pytest.mark.asyncio - async def test_connect_single_server(self, mock_telemetry, mock_mcp_use_client): + async def test_connect_single_server(self, mock_mcp_use_client): """Test connecting to a single server.""" config = {"test_server": {"command": "python", "args": ["-m", "test_server"]}} @@ -77,7 +76,7 @@ async def mock_list_tools(): assert names == {"tool1", "tool2"} @pytest.mark.asyncio - async def test_connect_multiple_servers(self, mock_telemetry, mock_mcp_use_client): + async def test_connect_multiple_servers(self, mock_mcp_use_client): """Test connecting to multiple servers.""" config = { "server1": {"command": "python", "args": ["-m", "server1"]}, @@ -129,7 +128,7 @@ async def mock_list_tools2(): assert names == {"server1_tool1", "server2_tool2"} @pytest.mark.asyncio - async def test_call_tool(self, mock_telemetry, mock_mcp_use_client): + async def test_call_tool(self, mock_mcp_use_client): """Test calling a tool.""" config = {"test": {"command": "test"}} client = MCPClient(mcp_config=config) @@ -180,7 +179,7 @@ async def mock_list_tools(): ) @pytest.mark.asyncio - async def test_call_tool_not_found(self, mock_telemetry, mock_mcp_use_client): + async def test_call_tool_not_found(self, mock_mcp_use_client): """Test calling a non-existent tool.""" config = {"test": {"command": "test"}} client = MCPClient(mcp_config=config) @@ -208,7 +207,7 @@ async def mock_list_tools(): assert "Tool 'nonexistent' not found" in text_content @pytest.mark.asyncio - async def test_get_telemetry_data(self, mock_telemetry, mock_mcp_use_client): + async def test_get_telemetry_data(self, mock_mcp_use_client): """Test getting telemetry data.""" config = {"test": {"command": "test"}} client = MCPClient(mcp_config=config) @@ -245,7 +244,7 @@ async def mock_list_tools(): assert isinstance(telemetry_data, dict) @pytest.mark.asyncio - async def test_close(self, mock_telemetry, mock_mcp_use_client): + async def test_close(self, mock_mcp_use_client): """Test closing client connections.""" config = {"test": {"command": "test"}} client = MCPClient(mcp_config=config) @@ -267,7 +266,7 @@ async def mock_list_tools(): mock_mcp_use_client.close_all_sessions.assert_called_once() @pytest.mark.asyncio - async def test_context_manager(self, mock_telemetry, mock_mcp_use_client): + async def test_context_manager(self, mock_mcp_use_client): 
"""Test using client as context manager.""" mock_session = MagicMock() mock_session.connector = MagicMock() @@ -291,7 +290,7 @@ async def mock_list_tools(): mock_mcp_use_client.close_all_sessions.assert_called_once() @pytest.mark.asyncio - async def test_get_available_tools(self, mock_telemetry, mock_mcp_use_client): + async def test_get_available_tools(self, mock_mcp_use_client): """Test getting available tools.""" config = {"test": {"command": "test"}} client = MCPClient(mcp_config=config) @@ -319,7 +318,7 @@ async def mock_list_tools(): assert names == {"tool1", "tool2"} @pytest.mark.asyncio - async def test_get_tool_map(self, mock_telemetry, mock_mcp_use_client): + async def test_get_tool_map(self, mock_mcp_use_client): """Test getting tool map.""" config = {"test": {"command": "test"}} client = MCPClient(mcp_config=config) diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py index b2c1f4b3..74593d0d 100644 --- a/hud/agents/tests/test_gemini.py +++ b/hud/agents/tests/test_gemini.py @@ -98,6 +98,7 @@ async def test_format_blocks_text_only(self, mock_gemini_client: genai.Client) - messages = await agent.format_blocks(blocks) assert len(messages) == 1 assert messages[0].role == "user" + assert messages[0].parts is not None assert len(messages[0].parts) == 2 @pytest.mark.asyncio @@ -118,6 +119,7 @@ async def test_format_blocks_with_image(self, mock_gemini_client: genai.Client) messages = await agent.format_blocks(blocks) assert len(messages) == 1 + assert messages[0].parts is not None assert len(messages[0].parts) == 2 @pytest.mark.asyncio @@ -181,7 +183,7 @@ async def test_get_response_text_only(self, mock_gemini_client: genai.Client) -> mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) - messages = [genai_types.Content(role="user", parts=[genai_types.Part.from_text("Status?")])] + messages = [genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Status?")])] response = await agent.get_response(messages) assert response.content == "Task completed successfully" @@ -221,7 +223,7 @@ async def test_get_response_with_thinking(self, mock_gemini_client: genai.Client mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) messages = [ - genai_types.Content(role="user", parts=[genai_types.Part.from_text("Hard question")]) + genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Hard question")]) ] response = await agent.get_response(messages) @@ -249,8 +251,11 @@ async def test_convert_tools_for_gemini(self, mock_gemini_client: genai.Client) # Check that tools were converted assert len(agent.gemini_tools) == 1 - # Gemini tools have function_declarations - assert agent.gemini_tools[0].function_declarations[0].name == "my_tool" + # Gemini tools have function_declarations - cast to genai Tool type + gemini_tool = agent.gemini_tools[0] + assert isinstance(gemini_tool, genai_types.Tool) + assert gemini_tool.function_declarations is not None + assert gemini_tool.function_declarations[0].name == "my_tool" class TestGeminiToolConversion: @@ -290,10 +295,12 @@ async def test_tool_with_properties(self, mock_gemini_client: genai.Client) -> N await agent._initialize_from_ctx(ctx) assert len(agent.gemini_tools) == 1 - tool = agent.gemini_tools[0] - # Gemini tools have function_declarations - assert tool.function_declarations[0].name == "search" - assert tool.function_declarations[0].parameters_json_schema is not None + gemini_tool = agent.gemini_tools[0] + # Gemini tools have 
function_declarations - cast to genai Tool type + assert isinstance(gemini_tool, genai_types.Tool) + assert gemini_tool.function_declarations is not None + assert gemini_tool.function_declarations[0].name == "search" + assert gemini_tool.function_declarations[0].parameters_json_schema is not None @pytest.mark.asyncio async def test_tool_without_schema(self, mock_gemini_client: genai.Client) -> None: diff --git a/hud/agents/tests/test_grounded_openai_agent.py b/hud/agents/tests/test_grounded_openai_agent.py index a88dc924..ff8b2bfe 100644 --- a/hud/agents/tests/test_grounded_openai_agent.py +++ b/hud/agents/tests/test_grounded_openai_agent.py @@ -130,7 +130,6 @@ async def test_get_response_with_reasoning() -> None: grounder_config=grounder_cfg, openai_client=fake_openai, checkpoint_name="gpt-4o-mini", - mcp_client=FakeMCPClient(), initial_screenshot=False, ) diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py index eff539b3..d65acf18 100644 --- a/hud/agents/tests/test_openai.py +++ b/hud/agents/tests/test_openai.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any +from collections.abc import Generator +from typing import Any, cast from unittest.mock import AsyncMock, patch import pytest @@ -47,14 +48,14 @@ class TestOpenAIAgent: """Test OpenAIAgent class.""" @pytest.fixture - def mock_openai(self) -> AsyncOpenAI: + def mock_openai(self) -> Generator[AsyncOpenAI, None, None]: # type: ignore[misc] """Create a stub OpenAI client.""" with patch("hud.agents.openai.AsyncOpenAI") as mock_class: client = AsyncOpenAI(api_key="test", base_url="http://localhost") client.chat.completions.create = AsyncMock() client.responses.create = AsyncMock() mock_class.return_value = client - yield client + yield client # type: ignore[misc] @pytest.mark.asyncio async def test_init_with_client(self, mock_openai: AsyncOpenAI) -> None: @@ -137,7 +138,7 @@ async def test_format_blocks_with_image(self, mock_openai: AsyncOpenAI) -> None: assert len(messages) == 1 assert len(messages[0]["content"]) == 2 assert messages[0]["content"][1]["type"] == "input_image" - assert messages[0]["content"][1]["image_url"] == "data:image/png;base64,base64data" + assert messages[0]["content"][1]["image_url"] == "data:image/png;base64,base64data" # type: ignore[typeddict-item] @pytest.mark.asyncio async def test_format_blocks_empty(self, mock_openai: AsyncOpenAI) -> None: @@ -176,7 +177,7 @@ async def test_format_tool_results_text(self, mock_openai: AsyncOpenAI) -> None: assert messages[0]["call_id"] == "call_123" # Output is a list of content items assert len(messages[0]["output"]) == 1 - assert messages[0]["output"][0]["text"] == "Tool output" + assert messages[0]["output"][0]["text"] == "Tool output" # type: ignore[index] @pytest.mark.asyncio async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> None: @@ -197,7 +198,8 @@ async def test_format_tool_results_with_error(self, mock_openai: AsyncOpenAI) -> messages = await agent.format_tool_results(tool_calls, tool_results) assert len(messages) == 1 # Output is a list; first item is error indicator, second is the message - output = messages[0]["output"] + msg = cast("dict[str, Any]", messages[0]) + output = cast("list[dict[str, Any]]", msg["output"]) assert any(item.get("text") == "[tool_error] true" for item in output) assert any(item.get("text") == "Error message" for item in output) @@ -357,13 +359,13 @@ class TestOpenAIToolConversion: """Tests for tool conversion to OpenAI format.""" @pytest.fixture - def 
mock_openai(self) -> AsyncOpenAI: + def mock_openai(self) -> Generator[AsyncOpenAI, None, None]: # type: ignore[misc] """Create a stub OpenAI client.""" with patch("hud.agents.openai.AsyncOpenAI") as mock_class: client = AsyncOpenAI(api_key="test", base_url="http://localhost") client.responses.create = AsyncMock() mock_class.return_value = client - yield client + yield client # type: ignore[misc] @pytest.mark.asyncio async def test_shell_tool_conversion(self, mock_openai: AsyncOpenAI) -> None: diff --git a/hud/agents/tests/test_operator.py b/hud/agents/tests/test_operator.py index 94861522..c4e79cd0 100644 --- a/hud/agents/tests/test_operator.py +++ b/hud/agents/tests/test_operator.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Generator from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, patch @@ -41,7 +42,7 @@ class TestOperatorAgent: """Test OperatorAgent class.""" @pytest.fixture - def mock_openai(self) -> AsyncOpenAI: + def mock_openai(self) -> Generator[AsyncOpenAI, None, None]: """Create a mock OpenAI client.""" client = AsyncOpenAI(api_key="test", base_url="http://localhost") client.responses.create = AsyncMock() diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py index 46eea596..d66b284f 100644 --- a/hud/agents/tests/test_run_eval.py +++ b/hud/agents/tests/test_run_eval.py @@ -48,7 +48,11 @@ async def get_system_messages(self) -> list[Any]: return [] async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - return [{"type": "text", "text": b.text} for b in blocks if hasattr(b, "text")] + return [ + {"type": "text", "text": getattr(b, "text")} + for b in blocks + if hasattr(b, "text") + ] class MockEvalContext(EvalContext): diff --git a/hud/cli/flows/tasks.py b/hud/cli/flows/tasks.py index a0921766..a46af389 100644 --- a/hud/cli/flows/tasks.py +++ b/hud/cli/flows/tasks.py @@ -4,7 +4,7 @@ import logging import re from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any import typer import yaml @@ -13,11 +13,8 @@ from hud.cli.utils.docker import require_docker_running from hud.cli.utils.env_check import find_environment_dir from hud.cli.utils.registry import extract_name_and_tag +from hud.datasets import load_dataset from hud.utils.hud_console import hud_console -from hud.utils.tasks import load_tasks - -if TYPE_CHECKING: - from hud.types import LegacyTask logger = logging.getLogger(__name__) @@ -29,7 +26,7 @@ def _is_remote_url(url: str) -> bool: return bool(re.match(r"^(https?:\/\/)?(www\.)?[a-zA-Z0-9\-\.]+\.[a-zA-Z]{2,}(\/\S*)?$", url)) -def _validate_tasks(tasks: list[LegacyTask]) -> bool: +def _validate_tasks(tasks: list[dict[str, Any]]) -> bool: """Validate the tasks file: return True if tasks already reference a remote MCP URL. 
A task is considered remote if any "url" field anywhere inside mcp_config @@ -50,7 +47,7 @@ def _has_remote_url(obj: Any) -> bool: return False for task in tasks: - cfg = task.mcp_config or {} + cfg = task.get("mcp_config") or {} if not _has_remote_url(cfg): return False return True @@ -115,7 +112,7 @@ def _derive_remote_image(lock_data: dict[str, Any]) -> str: raise typer.Exit(1) -def _extract_existing_images(tasks: list[LegacyTask]) -> set[str]: +def _extract_existing_images(tasks: list[dict[str, Any]]) -> set[str]: """Extract all Mcp-Image references from tasks.""" images = set() @@ -134,8 +131,9 @@ def _extract_from_obj(obj: Any) -> None: _extract_from_obj(item) for task in tasks: - if task.mcp_config: - _extract_from_obj(task.mcp_config) + mcp_config = task.get("mcp_config") + if mcp_config: + _extract_from_obj(mcp_config) return images @@ -267,11 +265,12 @@ def convert_tasks_to_remote(tasks_file: str) -> str: """ tasks_path = Path(tasks_file).resolve() - # Load validated tasks for decision-making (may resolve env vars) - tasks: list[LegacyTask] = load_tasks(str(tasks_path)) # type: ignore[assignment] + # Load raw tasks - we work with dicts directly to preserve placeholders + # when writing back to disk (e.g., ${HUD_API_KEY}) + raw_tasks: list[dict[str, Any]] = load_dataset(str(tasks_path), raw=True) # type: ignore[assignment] - # Load raw tasks to preserve placeholders when writing back to disk - raw_tasks: list[dict[str, Any]] = load_tasks(str(tasks_path), raw=True) # type: ignore[assignment] + # Use the same raw tasks for validation (they have mcp_config structure) + tasks = raw_tasks # Ensure HUD_API_KEY is available: prefer process env, else load from env_dir/.env from hud.settings import settings @@ -446,7 +445,7 @@ def _one(x: Any) -> dict[str, Any]: tasks_payload: list[dict[str, Any]] = [] for t in tasks: item: dict[str, Any] = { - "prompt": t.prompt, + "prompt": t.get("prompt"), "mcp_config": { "hud": { "url": settings.hud_mcp_url, @@ -462,16 +461,16 @@ def _one(x: Any) -> dict[str, Any]: item["mcp_config"]["hud"]["headers"].update(extra_api_key_headers) # Optional fields, omit Nones - if t.setup_tool is not None: - item["setup_tool"] = _simplify_tool_call(t.setup_tool) - if t.evaluate_tool is not None: - item["evaluate_tool"] = _simplify_tool_call(t.evaluate_tool) - if t.agent_config is not None: - item["agent_config"] = t.agent_config - if t.metadata: - item["metadata"] = t.metadata - if t.id is not None: - item["id"] = t.id + if t.get("setup_tool") is not None: + item["setup_tool"] = _simplify_tool_call(t["setup_tool"]) + if t.get("evaluate_tool") is not None: + item["evaluate_tool"] = _simplify_tool_call(t["evaluate_tool"]) + if t.get("agent_config") is not None: + item["agent_config"] = t["agent_config"] + if t.get("metadata"): + item["metadata"] = t["metadata"] + if t.get("id") is not None: + item["id"] = t["id"] tasks_payload.append(item) diff --git a/hud/cli/rft.py b/hud/cli/rft.py index 53d336ce..eccdb9d1 100644 --- a/hud/cli/rft.py +++ b/hud/cli/rft.py @@ -8,9 +8,9 @@ from rich.console import Console from rich.table import Table +from hud.datasets import load_dataset from hud.settings import settings from hud.utils.hud_console import HUDConsole -from hud.utils.tasks import load_tasks logger = logging.getLogger(__name__) console = Console() @@ -192,12 +192,8 @@ def rft_command( # Load and validate tasks try: - # Load tasks with env vars already resolved - from hud.types import LegacyTask # noqa: TC001 - - tasks_objects: list[LegacyTask] = load_tasks(tasks_file) # 
type: ignore[assignment] - # Convert to dicts for patching and serialization - tasks: list[dict[str, Any]] = [t.model_dump() for t in tasks_objects] + # Load tasks as raw dicts for patching and serialization + tasks: list[dict[str, Any]] = load_dataset(tasks_file, raw=True) # type: ignore[assignment] if not tasks: hud_console.error(f"No tasks found in {tasks_file}") raise typer.Exit(1) diff --git a/hud/cli/tests/test_dev.py b/hud/cli/tests/test_dev.py index d1027303..0cfcfc16 100644 --- a/hud/cli/tests/test_dev.py +++ b/hud/cli/tests/test_dev.py @@ -81,8 +81,42 @@ def test_detect_module_from_main_py(self, tmp_path, monkeypatch): assert module_name == f"{tmp_path.name}.main" assert extra_path == tmp_path.parent - def test_no_detection_without_mcp(self, tmp_path, monkeypatch): - """Test no detection when mcp not defined.""" + def test_detect_module_from_init_with_environment(self, tmp_path, monkeypatch): + """Test detection from __init__.py with Environment.""" + monkeypatch.chdir(tmp_path) + + init_file = tmp_path / "__init__.py" + init_file.write_text(""" +from hud import Environment +env = Environment(name='test') +""") + + module_name, extra_path = auto_detect_module() + + assert module_name == tmp_path.name + assert extra_path is None + + def test_detect_module_from_main_py_with_environment(self, tmp_path, monkeypatch): + """Test detection from main.py with Environment.""" + monkeypatch.chdir(tmp_path) + + # Need both __init__.py and main.py + init_file = tmp_path / "__init__.py" + init_file.write_text("") + + main_file = tmp_path / "main.py" + main_file.write_text(""" +from hud import Environment +env = Environment(name='test') +""") + + module_name, extra_path = auto_detect_module() + + assert module_name == f"{tmp_path.name}.main" + assert extra_path == tmp_path.parent + + def test_no_detection_without_mcp_or_env(self, tmp_path, monkeypatch): + """Test no detection when neither mcp nor env is defined.""" monkeypatch.chdir(tmp_path) init_file = tmp_path / "__init__.py" diff --git a/hud/clients/__init__.py b/hud/clients/__init__.py index 31692021..1f0f62c8 100644 --- a/hud/clients/__init__.py +++ b/hud/clients/__init__.py @@ -5,10 +5,13 @@ from .base import AgentMCPClient, BaseHUDClient from .environment import EnvironmentClient from .fastmcp import FastMCPHUDClient -from .mcp_use import MCPUseHUDClient -# Default to MCP-use for agents (has multi-server session support) -MCPClient = MCPUseHUDClient +# Default to FastMCP client (no optional dependencies) +MCPClient = FastMCPHUDClient + +# Note: MCPUseHUDClient requires mcp-use (optional dependency in [agents]). +# Import directly if needed: +# from hud.clients.mcp_use import MCPUseHUDClient __all__ = [ "AgentMCPClient", @@ -16,5 +19,4 @@ "EnvironmentClient", "FastMCPHUDClient", "MCPClient", - "MCPUseHUDClient", ] diff --git a/hud/clients/base.py b/hud/clients/base.py index b7ce86a8..b8304b0e 100644 --- a/hud/clients/base.py +++ b/hud/clients/base.py @@ -12,7 +12,6 @@ from hud.shared.exceptions import HudAuthenticationError, HudException from hud.types import MCPToolCall, MCPToolResult from hud.utils.hud_console import HUDConsole -from hud.utils.mcp import setup_hud_telemetry from hud.version import __version__ as hud_version if TYPE_CHECKING: @@ -86,7 +85,6 @@ def __init__( mcp_config: dict[str, dict[str, Any]] | None = None, verbose: bool = False, strict_validation: bool = False, - auto_trace: bool = True, ) -> None: """ Initialize base client. 
@@ -99,8 +97,6 @@ def __init__( self.verbose = verbose self._mcp_config = mcp_config self._strict_validation = strict_validation - self._auto_trace = auto_trace - self._auto_trace_cm: Any | None = None # Store auto-created trace context manager self._initialized = False self._telemetry_data = {} # Initialize telemetry data @@ -128,8 +124,6 @@ async def initialize(self, mcp_config: dict[str, dict[str, Any]] | None = None) "Either pass it to the constructor or call initialize with a configuration" ) - self._auto_trace_cm = setup_hud_telemetry(self._mcp_config, auto_trace=self._auto_trace) - hud_console.debug("Initializing MCP client...") try: @@ -158,17 +152,6 @@ async def initialize(self, mcp_config: dict[str, dict[str, Any]] | None = None) async def shutdown(self) -> None: """Disconnect from the MCP server.""" - # Clean up auto-created trace if any - if self._auto_trace_cm: - try: - self._auto_trace_cm.__exit__(None, None, None) - hud_console.info("Closed auto-created trace") - except Exception as e: - hud_console.warning(f"Failed to close auto-created trace: {e}") - finally: - self._auto_trace_cm = None - - # Disconnect from server if self._initialized: await self._disconnect() self._initialized = False diff --git a/hud/clients/tests/test_analyze_scenarios.py b/hud/clients/tests/test_analyze_scenarios.py index e19535b8..d30a4b16 100644 --- a/hud/clients/tests/test_analyze_scenarios.py +++ b/hud/clients/tests/test_analyze_scenarios.py @@ -24,7 +24,7 @@ def __init__( resources: list[types.Resource], ) -> None: super().__init__( - mcp_config={"test": {"url": "mock://test"}}, verbose=True, auto_trace=False + mcp_config={"test": {"url": "mock://test"}}, verbose=True ) self._mock_prompts = prompts self._mock_resources = resources diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py index 15b8c19a..ae2c0869 100644 --- a/hud/datasets/__init__.py +++ b/hud/datasets/__init__.py @@ -1,35 +1,35 @@ """HUD datasets module. -Provides data models, utilities, and execution functions for working with HUD datasets. +Provides unified dataset loading and execution for HUD evaluations. + +Key functions: +- load_dataset(): Load tasks from JSON, JSONL, HuggingFace, or HUD API +- run_dataset(): Run an agent on a dataset of tasks +- submit_rollouts(): Submit tasks for remote execution + +Supports both v4 (LegacyTask) and v5 (Task) formats with automatic conversion. 
""" -# Data models -# Execution functions from __future__ import annotations -from hud.types import LegacyTask +from hud.eval.display import display_results from hud.utils.tasks import save_tasks from .loader import load_dataset -from .runner import run_dataset, run_single_task, run_tasks +from .runner import run_dataset, run_single_task from .utils import ( BatchRequest, SingleTaskRequest, - calculate_group_stats, - display_results, submit_rollouts, ) __all__ = [ "BatchRequest", - "LegacyTask", "SingleTaskRequest", - "calculate_group_stats", "display_results", "load_dataset", "run_dataset", "run_single_task", - "run_tasks", "save_tasks", "submit_rollouts", ] diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index 57806f96..48c7dd31 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -3,6 +3,7 @@ Unified interface for loading evaluation datasets from: - HUD API (v5 format) - Local JSON/JSONL files (v4 LegacyTask format, auto-converted) +- HuggingFace datasets (v4 LegacyTask format, auto-converted) """ from __future__ import annotations @@ -10,7 +11,7 @@ import json import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload if TYPE_CHECKING: from hud.eval.task import Task @@ -20,52 +21,9 @@ __all__ = ["load_dataset"] -def _is_legacy_task_format(item: dict[str, Any]) -> bool: - """Check if a dict is in v4 LegacyTask format. - - LegacyTask has: prompt, mcp_config (required), setup_tool, evaluate_tool (optional) - v5 Task has: env, scenario, args - """ - # If it has prompt + mcp_config, it's legacy format - # If it has setup_tool or evaluate_tool, it's legacy - return ( - ("prompt" in item and "mcp_config" in item) - or "setup_tool" in item - or "evaluate_tool" in item - ) - - -def _task_from_dict(item: dict[str, Any]) -> Task: - """Convert a dict to Task, auto-detecting v4 vs v5 format.""" - from hud.eval.task import Task - from hud.types import MCPToolCall - - if _is_legacy_task_format(item): - # v4 LegacyTask format - convert via Task.from_v4() - return Task.from_v4(item) - else: - # v5 format - env is required, scenario is optional - env = item.get("env") - if env is None: - raise ValueError(f"Task missing required 'env' field: {item}") - - # Convert validation dicts to MCPToolCall objects - validation = None - if item.get("validation"): - validation = [MCPToolCall(**v) for v in item["validation"]] - - return Task( - env=env, # EnvConfig dict: {"name": "browser", "include": [...], ...} - scenario=item.get("scenario"), - id=item.get("id"), - args=item.get("args", {}), - validation=validation, - ) - - -def _load_from_file(path: Path) -> list[Task]: - """Load tasks from a local JSON or JSONL file.""" - tasks: list[Task] = [] +def _load_raw_from_file(path: Path) -> list[dict[str, Any]]: + """Load raw task dicts from a local JSON or JSONL file.""" + raw_items: list[dict[str, Any]] = [] if path.suffix == ".jsonl": # JSONL: one task per line @@ -77,9 +35,9 @@ def _load_from_file(path: Path) -> list[Task]: item = json.loads(line) # Handle case where line contains a list if isinstance(item, list): - tasks.extend(_task_from_dict(i) for i in item) + raw_items.extend(i for i in item if isinstance(i, dict)) elif isinstance(item, dict): - tasks.append(_task_from_dict(item)) + raw_items.append(item) else: raise ValueError( f"Invalid JSONL format: expected dict or list, got {type(item)}" @@ -90,17 +48,61 @@ def _load_from_file(path: Path) -> list[Task]: data = json.load(f) if isinstance(data, list): - tasks = 
[_task_from_dict(item) for item in data] + raw_items = [item for item in data if isinstance(item, dict)] elif isinstance(data, dict): - tasks = [_task_from_dict(data)] + raw_items = [data] else: raise ValueError(f"JSON file must contain an array or object, got {type(data)}") - return tasks + return raw_items -def _load_from_api(dataset_name: str) -> list[Task]: - """Load tasks from HUD API.""" +def _load_from_file(path: Path) -> list[Task]: + """Load tasks from a local JSON or JSONL file.""" + from hud.eval.task import Task + + raw_items = _load_raw_from_file(path) + return [Task(**item) for item in raw_items] + + +def _load_raw_from_huggingface(dataset_name: str) -> list[dict[str, Any]]: + """Load raw task dicts from HuggingFace dataset.""" + try: + from datasets import load_dataset as hf_load_dataset + except ImportError as e: + raise ImportError( + "Please install 'datasets' to load from HuggingFace: uv pip install datasets" + ) from e + + # Parse dataset name and optional split + if ":" in dataset_name: + name, split = dataset_name.split(":", 1) + else: + name = dataset_name + split = "train" # Default split + + logger.info("Loading from HuggingFace dataset: %s (split=%s)", name, split) + dataset = hf_load_dataset(name, split=split) + + raw_items: list[dict[str, Any]] = [] + for item in dataset: + if not isinstance(item, dict): + raise ValueError(f"Invalid HuggingFace dataset: expected dict, got {type(item)}") + raw_items.append(dict(item)) + + return raw_items + + +def _load_from_huggingface(dataset_name: str) -> list[Task]: + """Load tasks from HuggingFace dataset.""" + raw_items = _load_raw_from_huggingface(dataset_name) + from hud.eval.task import Task + + return [Task(**item) for item in raw_items] + + +def _load_raw_from_api(dataset_name: str) -> list[dict[str, Any]]: + """Load raw task dicts from HUD API.""" import httpx from hud.settings import settings @@ -121,21 +123,40 @@ def _load_from_api(dataset_name: str) -> list[Task]: # Extract tasks dict from response tasks_dict = data.get("tasks", {}) - tasks: list[Task] = [] + raw_items: list[dict[str, Any]] = [] for task_id, task_data in tasks_dict.items(): if task_data.get("id") is None: task_data["id"] = task_id - tasks.append(_task_from_dict(task_data)) + raw_items.append(task_data) + + return raw_items + - return tasks +def _load_from_api(dataset_name: str) -> list[Task]: + """Load tasks from HUD API.""" + from hud.eval.task import Task + + raw_items = _load_raw_from_api(dataset_name) + return [Task(**item) for item in raw_items] + + +@overload +def load_dataset(source: str, *, raw: bool = False) -> list[Task]: ... -def load_dataset(source: str) -> list[Task]: +@overload +def load_dataset(source: str, *, raw: bool = True) -> list[dict[str, Any]]: ... + + +def load_dataset( + source: str, *, raw: bool = False +) -> list[Task] | list[dict[str, Any]]: """Load tasks from a dataset source. Supports multiple sources with auto-detection: - Local file path (JSON or JSONL) - HUD API dataset slug (e.g., "hud-evals/SheetBench-50") + - HuggingFace dataset (e.g., "username/dataset" or "username/dataset:split") Automatically detects and converts v4 LegacyTask format to v5 Task. @@ -143,9 +164,13 @@ def load_dataset(source: str) -> list[Task]: source: Dataset source. Can be: - Path to a local JSON/JSONL file - HUD API dataset slug (e.g., "hud-evals/SheetBench-50") + - HuggingFace dataset name (e.g., "hud-evals/tasks" or "hud-evals/tasks:train") + raw: If True, return raw dicts without validation or env var substitution. 
+ Useful for preserving template strings like "${HUD_API_KEY}". Returns: - List of Task objects ready to use with hud.eval() + - If raw=False (default): list[Task] ready to use with hud.eval() + - If raw=True: list[dict] with raw task data Example: ```python @@ -158,8 +183,14 @@ def load_dataset(source: str) -> list[Task]: # Load from local file (v4 format auto-converted) tasks = load_dataset("./my-tasks.json") + # Load from HuggingFace + tasks = load_dataset("hud-evals/benchmark:test") + + # Load raw dicts (preserves env var placeholders) + raw_tasks = load_dataset("./tasks.json", raw=True) + # Run evaluation - async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: + async with hud.eval(tasks) as ctx: await agent.run(ctx) ``` @@ -170,15 +201,29 @@ def load_dataset(source: str) -> list[Task]: path = Path(source) if path.exists() and path.suffix in {".json", ".jsonl"}: logger.info("Loading tasks from file: %s", source) - tasks = _load_from_file(path) - logger.info("Loaded %d tasks from %s", len(tasks), source) - return tasks + items = _load_raw_from_file(path) if raw else _load_from_file(path) + logger.info("Loaded %d tasks from %s", len(items), source) + return items - # Otherwise, try HUD API - logger.info("Loading dataset from HUD API: %s", source) + # Try HUD API first + try: + logger.info("Trying HUD API: %s", source) + items = _load_raw_from_api(source) if raw else _load_from_api(source) + logger.info("Loaded %d tasks from HUD API: %s", len(items), source) + return items + except Exception as hud_error: + logger.debug("HUD API load failed (%s), trying HuggingFace", hud_error) + + # Try HuggingFace as fallback try: - tasks = _load_from_api(source) - logger.info("Loaded %d tasks from %s", len(tasks), source) - return tasks - except Exception as e: - raise ValueError(f"Failed to load dataset '{source}' from HUD API: {e}") from e + logger.info("Trying HuggingFace dataset: %s", source) + items = _load_raw_from_huggingface(source) if raw else _load_from_huggingface(source) + logger.info("Loaded %d tasks from HuggingFace: %s", len(items), source) + return items + except ImportError: + raise ValueError( + f"Failed to load dataset '{source}'. " + "Install 'datasets' package for HuggingFace support." + ) from None + except Exception as hf_error: + raise ValueError(f"Failed to load dataset '{source}': {hf_error}") from hf_error diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 028e118b..402671fe 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -9,59 +9,22 @@ from typing import TYPE_CHECKING, Any import hud -from hud.types import AgentType, Trace +from hud.types import AgentType, LegacyTask, TaskInput, Trace if TYPE_CHECKING: - from hud.agents import MCPAgent + from collections.abc import Sequence + from hud.eval.context import EvalContext from hud.eval.task import Task logger = logging.getLogger("hud.datasets") -async def run_tasks( - tasks: list[Task], - *, - agent_type: str, - agent_params: dict[str, Any] | None = None, - max_steps: int = 10, - max_concurrent: int = 30, - group_size: int = 1, -) -> list[EvalContext]: - """Run tasks with an agent created from type and parameters. - - This is a convenience wrapper around run_dataset that creates the agent - from a type string and parameters dictionary. - - Args: - tasks: List of Task objects to run. - agent_type: Type of agent to create (e.g., "claude", "openai", "gemini"). - agent_params: Parameters to pass to agent.create(). - max_steps: Maximum steps per task. 
- max_concurrent: Maximum concurrent tasks. - group_size: Number of times to run each task. - - Returns: - List of EvalContext results from each task execution. - """ - # Use AgentType enum to get the agent class (same pattern as CLI) - agent_type_enum = AgentType(agent_type) - agent_cls = agent_type_enum.cls - agent = agent_cls.create(**(agent_params or {})) - - return await run_dataset( - tasks, - agent, - max_steps=max_steps, - max_concurrent=max_concurrent, - group_size=group_size, - ) - - async def run_dataset( - tasks: str | list[Task] | list[dict[str, Any]] | Task | dict[str, Any], - agent: MCPAgent, + tasks: str | TaskInput | Sequence[TaskInput], + agent_type: str | AgentType, *, + agent_params: dict[str, Any] | None = None, max_steps: int = 10, max_concurrent: int = 30, group_size: int = 1, @@ -69,13 +32,15 @@ async def run_dataset( """Run an agent on a dataset of tasks. This is the primary entry point for running evaluations programmatically. + The agent is created fresh for each task context to ensure correct tool initialization. Args: tasks: Tasks to run. Can be: - A source string (file path, API slug) - loaded via load_dataset() - - A single Task object or dict (v4 or v5 format) - - A list of Task objects or dicts (v4 or v5 format) - agent: The agent instance to run. + - A single TaskInput (Task, LegacyTask, or dict) + - A list of TaskInput objects + agent_type: Type of agent to create (e.g., "claude", "openai", AgentType.CLAUDE). + agent_params: Parameters to pass to agent.create(). max_steps: Maximum steps per task. max_concurrent: Maximum concurrent tasks (for parallel execution). group_size: Number of times to run each task (for variance estimation). @@ -85,52 +50,55 @@ async def run_dataset( Example: ```python - from hud.agents import ClaudeAgent from hud.datasets import load_dataset, run_dataset - # Load tasks + # Load tasks and run tasks = load_dataset("my-tasks.json") + results = await run_dataset( + tasks, + agent_type="claude", + agent_params={"checkpoint_name": "claude-sonnet-4-20250514"}, + max_steps=50, + ) - # Create agent - agent = ClaudeAgent.create(checkpoint_name="claude-sonnet-4-20250514") - - # Run evaluation - results = await run_dataset(tasks, agent, max_steps=50) for ctx in results: print(f"Reward: {ctx.reward}") ``` """ - from hud.datasets.loader import _task_from_dict, load_dataset + from hud.datasets.loader import load_dataset from hud.eval.task import Task # Normalize tasks to list[Task] + task_list: list[Task] if isinstance(tasks, str): task_list = load_dataset(tasks) elif isinstance(tasks, Task): task_list = [tasks] - elif isinstance(tasks, dict): - task_list = [_task_from_dict(tasks)] - elif isinstance(tasks, list): - task_list = [] - for t in tasks: - if isinstance(t, Task): - task_list.append(t) - elif isinstance(t, dict): - task_list.append(_task_from_dict(t)) - else: - raise TypeError(f"Expected Task or dict, got {type(t)}") + elif isinstance(tasks, LegacyTask | dict): + # Single LegacyTask or dict - convert to Task + task_list = [Task.from_v4(tasks)] else: - raise TypeError(f"Expected str, Task, dict, or list, got {type(tasks)}") + # Sequence of TaskInput - convert each to Task + task_list = [ + t if isinstance(t, Task) else Task.from_v4(t) + for t in tasks + ] if not task_list: raise ValueError("No tasks to run") + # Resolve agent class + agent_type_enum = agent_type if isinstance(agent_type, AgentType) else AgentType(agent_type) + agent_cls = agent_type_enum.cls + # Use hud.eval() for both single and parallel execution async with hud.eval( 
task_list, group=group_size, max_concurrent=max_concurrent, ) as ctx: + # Create agent fresh for each context (ensures correct tool initialization) + agent = agent_cls.create(**(agent_params or {})) result = await agent.run(ctx, max_steps=max_steps) ctx.reward = result.reward @@ -142,7 +110,7 @@ async def run_dataset( async def run_single_task( - task: Task | dict[str, Any], + task: Task, *, agent_type: AgentType, agent_params: dict[str, Any] | None = None, @@ -153,14 +121,17 @@ async def run_single_task( trace_name: str | None = None, metadata: dict[str, Any] | None = None, trace_id: str | None = None, + api_key: str | None = None, + trace: bool = True, + quiet: bool = False, ) -> Trace: """Run a single task with full control over eval context parameters. This is the low-level entry point for running individual tasks with explicit - trace/job/group IDs. Useful for remote execution workers. + trace/job/group IDs. Used by remote execution workers. Args: - task: Task to run. Can be a Task object or dict (v4 or v5 format). + task: Task object to run. Use Task.from_v4() or load_dataset() to create. agent_type: AgentType enum specifying the agent to use. agent_params: Parameters passed to agent.create(). Should include pre-configured model_client for inference gateway usage. @@ -171,6 +142,9 @@ async def run_single_task( trace_name: Name for the trace (defaults to task_id or task.id). metadata: Additional metadata for the trace context. trace_id: Pre-assigned trace ID (if provided by backend). + api_key: API key override for telemetry and backend calls. + trace: Whether to send trace data to backend (default True). + quiet: Whether to suppress printing eval link (default False). Returns: Trace result from the agent run. @@ -178,9 +152,13 @@ async def run_single_task( Example: ```python from hud.datasets import run_single_task + from hud.eval.task import Task from hud.types import AgentType from openai import AsyncOpenAI + # Create task (from v4 dict or directly) + task = Task.from_v4({"prompt": "...", "mcp_config": {...}, "evaluate_tool": {...}}) + # Configure agent with inference gateway agent_params = { "checkpoint_name": "gpt-4o", @@ -192,7 +170,7 @@ async def run_single_task( } result = await run_single_task( - task={"env": {"name": "browser"}, "scenario": "find_page"}, + task=task, agent_type=AgentType.OPENAI, agent_params=agent_params, max_steps=20, @@ -201,36 +179,32 @@ async def run_single_task( ) ``` """ - from hud.datasets.loader import _task_from_dict - from hud.eval.task import Task as TaskCls - - # Normalize task to Task object - if isinstance(task, dict): - task_obj = _task_from_dict(task) - elif isinstance(task, TaskCls): - task_obj = task - else: - raise TypeError(f"Expected Task or dict, got {type(task)}") - - # Create agent - agent_cls = agent_type.cls - agent = agent_cls.create(**(agent_params or {})) - # Determine trace name - effective_trace_name = trace_name or task_id or task_obj.id or "single_task" + effective_trace_name = trace_name or task_id or task.id or "single_task" # Run with explicit eval context parameters async with hud.eval( - task_obj, + task, name=effective_trace_name, job_id=job_id, group_id=group_id, trace_id=trace_id, + api_key=api_key, + trace=trace, + quiet=quiet, ) as ctx: + # Build agent params - use system_prompt from ctx (set from task.agent_config) + final_agent_params = dict(agent_params or {}) + if ctx.system_prompt and "system_prompt" not in final_agent_params: + final_agent_params["system_prompt"] = ctx.system_prompt + + # Create agent inside 
ctx so it has access to context-derived values + agent_cls = agent_type.cls + agent = agent_cls.create(**final_agent_params) + # Store metadata if provided if metadata: - for key, value in metadata.items(): - setattr(ctx, f"_meta_{key}", value) + ctx.metadata.update(metadata) result = await agent.run(ctx, max_steps=max_steps) ctx.reward = result.reward diff --git a/hud/datasets/tests/test_runner.py b/hud/datasets/tests/test_runner.py deleted file mode 100644 index 31595839..00000000 --- a/hud/datasets/tests/test_runner.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -from unittest.mock import MagicMock, patch - -import pytest - -from hud.telemetry.utils import flush_telemetry - - -@pytest.mark.asyncio -async def test_flush_telemetry(): - """Test flush_telemetry function.""" - with ( - patch("hud.otel.config.is_telemetry_configured", return_value=True), - patch("hud.utils.hud_console.hud_console"), - patch("opentelemetry.trace.get_tracer_provider") as mock_get_provider, - ): - from opentelemetry.sdk.trace import TracerProvider - - mock_provider = MagicMock(spec=TracerProvider) - mock_provider.force_flush.return_value = True - mock_get_provider.return_value = mock_provider - - await flush_telemetry() - - mock_provider.force_flush.assert_called_once_with(timeout_millis=5000) - - -@pytest.mark.asyncio -async def test_flush_telemetry_not_configured(): - """Test flush_telemetry when telemetry is not configured.""" - with patch("hud.otel.config.is_telemetry_configured", return_value=False): - await flush_telemetry() - - -@pytest.mark.asyncio -async def test_flush_telemetry_exception(): - """Test flush_telemetry handles exceptions gracefully.""" - with ( - patch("hud.otel.config.is_telemetry_configured", return_value=True), - patch("hud.utils.hud_console.hud_console"), - patch("opentelemetry.trace.get_tracer_provider") as mock_get_provider, - ): - from opentelemetry.sdk.trace import TracerProvider - - mock_provider = MagicMock(spec=TracerProvider) - mock_provider.force_flush.side_effect = Exception("Flush failed") - mock_get_provider.return_value = mock_provider - - await flush_telemetry() - - -@pytest.mark.asyncio -async def test_flush_telemetry_timeout(): - """Test flush_telemetry when force_flush times out.""" - with ( - patch("hud.otel.config.is_telemetry_configured", return_value=True), - patch("hud.utils.hud_console.hud_console"), - patch("opentelemetry.trace.get_tracer_provider") as mock_get_provider, - ): - from opentelemetry.sdk.trace import TracerProvider - - mock_provider = MagicMock(spec=TracerProvider) - mock_provider.force_flush.return_value = False - mock_get_provider.return_value = mock_provider - - await flush_telemetry() diff --git a/hud/datasets/tests/test_utils.py b/hud/datasets/tests/test_utils.py index 79a69544..107f737a 100644 --- a/hud/datasets/tests/test_utils.py +++ b/hud/datasets/tests/test_utils.py @@ -9,13 +9,12 @@ from hud.datasets.utils import ( BatchRequest, SingleTaskRequest, - calculate_group_stats, cancel_all_jobs, cancel_job, cancel_task, - display_results, submit_rollouts, ) +from hud.eval.display import display_results from hud.types import AgentType, LegacyTask, Trace @@ -23,9 +22,9 @@ class TestSingleTaskRequest: """Tests for SingleTaskRequest schema.""" def test_valid_request(self): - """Test creating a valid SingleTaskRequest.""" + """Test creating a valid SingleTaskRequest with v5 task.""" request = SingleTaskRequest( - task={"prompt": "test", "mcp_config": {}}, + task={"env": {"name": "browser"}, "scenario": "checkout"}, 
agent_type=AgentType.CLAUDE, agent_params={"checkpoint_name": "claude-sonnet-4-5"}, max_steps=10, @@ -48,8 +47,8 @@ def test_empty_job_id_rejected(self): ) def test_invalid_task_rejected(self): - """Test that invalid task payload is rejected.""" - with pytest.raises(ValueError, match="Invalid task payload"): + """Test that invalid task payload is rejected (neither v4 nor v5).""" + with pytest.raises(ValueError, match="Task must have 'env'"): SingleTaskRequest( task={"invalid_field": "test"}, # Missing required fields agent_type=AgentType.CLAUDE, @@ -58,6 +57,43 @@ def test_invalid_task_rejected(self): trace_name="Test", ) + def test_incomplete_v4_task_rejected(self): + """Test that incomplete v4 task (missing evaluate_tool) is rejected.""" + with pytest.raises(ValueError, match="v4 task missing required fields"): + SingleTaskRequest( + task={"prompt": "test", "mcp_config": {}}, # Missing evaluate_tool + agent_type=AgentType.CLAUDE, + job_id="job-123", + task_id="task-1", + trace_name="Test", + ) + + def test_valid_v4_task_accepted(self): + """Test that complete v4 task is accepted.""" + request = SingleTaskRequest( + task={ + "prompt": "test", + "mcp_config": {"server": {"url": "http://localhost"}}, + "evaluate_tool": {"name": "check", "arguments": {}}, + }, + agent_type=AgentType.CLAUDE, + job_id="job-123", + task_id="task-1", + trace_name="Test", + ) + assert request.task_id == "task-1" + + def test_valid_v5_task_accepted(self): + """Test that v5 task with env is accepted.""" + request = SingleTaskRequest( + task={"env": {"name": "browser"}, "scenario": "login"}, + agent_type=AgentType.CLAUDE, + job_id="job-123", + task_id="task-1", + trace_name="Test", + ) + assert request.task_id == "task-1" + class TestBatchRequest: """Tests for BatchRequest schema.""" @@ -66,7 +102,7 @@ def test_valid_batch(self): """Test creating a valid batch request.""" requests = [ SingleTaskRequest( - task={"prompt": "test", "mcp_config": {}}, + task={"env": {"name": "browser"}, "scenario": "test"}, agent_type=AgentType.CLAUDE, job_id="job-123", task_id=f"task-{i}", @@ -155,56 +191,6 @@ async def test_cancel_all_jobs(self): assert result["total_tasks_cancelled"] == 10 -class TestCalculateGroupStats: - """Tests for calculate_group_stats function.""" - - def test_basic_stats(self): - """Test basic group statistics calculation.""" - tasks = [ - LegacyTask(prompt="Task 1", mcp_config={}), - LegacyTask(prompt="Task 2", mcp_config={}), - ] - traces: list[Trace | None] = [ - Trace(reward=0.8, done=True), - Trace(reward=0.9, done=True), - Trace(reward=0.6, done=True), - Trace(reward=0.7, done=True), - ] - group_ids = {0: "group-0", 1: "group-1"} - - stats = calculate_group_stats(tasks, traces, group_size=2, group_ids=group_ids) - - assert len(stats) == 2 - assert stats[0]["mean_reward"] == pytest.approx(0.85, rel=0.01) - assert stats[1]["mean_reward"] == pytest.approx(0.65, rel=0.01) - - def test_all_none_traces(self): - """Test when all traces are None.""" - tasks = [LegacyTask(prompt="Task 1", mcp_config={})] - traces: list[Trace | None] = [None, None] - group_ids = {0: "group-0"} - - stats = calculate_group_stats(tasks, traces, group_size=2, group_ids=group_ids) - - assert len(stats) == 1 - assert stats[0]["error_rate"] == 1.0 - assert stats[0]["mean_reward"] == 0.0 - - def test_mixed_success_failure(self): - """Test with mixed success and failure traces.""" - tasks = [LegacyTask(prompt="Task 1", mcp_config={})] - traces: list[Trace | None] = [ - Trace(reward=1.0, done=True), - Trace(reward=0.0, done=True, 
isError=True), - ] - group_ids = {0: "group-0"} - - stats = calculate_group_stats(tasks, traces, group_size=2, group_ids=group_ids) - - assert stats[0]["success_rate"] == 0.5 - assert stats[0]["error_rate"] == 0.5 - - class TestDisplayResults: """Tests for display_results function.""" diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py index 04260186..98b28289 100644 --- a/hud/datasets/utils.py +++ b/hud/datasets/utils.py @@ -3,19 +3,30 @@ from __future__ import annotations import logging -from statistics import mean, pstdev -from typing import Any +from typing import TYPE_CHECKING, Any import httpx from pydantic import BaseModel, Field, field_validator, model_validator from hud.settings import settings -from hud.types import AgentType, LegacyTask, Trace +from hud.types import AgentType, TaskInput from hud.utils.hud_console import HUDConsole +if TYPE_CHECKING: + from collections.abc import Sequence + logger = logging.getLogger(__name__) hud_console = HUDConsole() +__all__ = [ + "BatchRequest", + "SingleTaskRequest", + "cancel_all_jobs", + "cancel_job", + "cancel_task", + "submit_rollouts", +] + class SingleTaskRequest(BaseModel): """Request to run a single task remotely - mirrors run_single_task() args.""" @@ -44,18 +55,21 @@ class SingleTaskRequest(BaseModel): @model_validator(mode="after") def _validate_task(self) -> SingleTaskRequest: """Validate task is either v4 LegacyTask or v5 Task format.""" - from hud.datasets.loader import _is_legacy_task_format + from hud.eval.utils import is_v4_format, validate_v4_task + + # v4 format: looks like v4 (prompt + mcp_config)? + if is_v4_format(self.task): + # Validate completeness (requires evaluate_tool too) + validate_v4_task(self.task) + return self - # v4 format: prompt + mcp_config - if _is_legacy_task_format(self.task): - try: - LegacyTask(**self.task) - except Exception as exc: - raise ValueError(f"Invalid legacy task payload: {exc}") from exc # v5 format: env required - elif "env" not in self.task: - raise ValueError("Task must have 'env' (v5) or 'prompt'+'mcp_config' (v4)") - return self + if "env" in self.task: + return self + + # Neither v4 nor v5 + raise ValueError("Task must have 'env' (v5) or 'prompt'+'mcp_config'+'evaluate_tool' (v4)") + @field_validator("job_id") @classmethod @@ -75,8 +89,21 @@ class BatchRequest(BaseModel): ) +def _normalize_tasks(tasks: Sequence[TaskInput]) -> list[dict[str, Any]]: + """Convert tasks to list of dicts for remote API submission.""" + result = [] + for t in tasks: + if isinstance(t, dict): + result.append(t) + elif hasattr(t, "model_dump"): + result.append(t.model_dump(mode="json")) + else: + raise TypeError(f"Cannot convert {type(t).__name__} to dict") + return result + + async def submit_rollouts( - tasks: list[LegacyTask], + tasks: Sequence[TaskInput], job_id: str, agent_type: AgentType, agent_params: dict[str, Any] | None = None, @@ -88,60 +115,58 @@ async def submit_rollouts( """Submit rollouts to the HUD platform API for remote execution (fire-and-forget). Args: - tasks: List of Task objects to execute + tasks: List of tasks (v5 Task, v4 LegacyTask, or dicts) job_id: HUD job ID for telemetry grouping agent_type: Agent type to use for execution - agent_params: Parameters passed to agent.create(). Should include fields - from BaseCreateParams (auto_trace, auto_respond, verbose) plus - agent-specific config fields (e.g., checkpoint_name for ClaudeConfig). 
+ agent_params: Parameters passed to agent.create() max_steps: Maximum steps per rollout group_size: Number of rollouts per task (for variance estimation) batch_size: Number of rollouts per API batch request metadata: Additional metadata for each rollout """ + from hud.eval.utils import is_v4_format + if not settings.api_key: raise ValueError("HUD_API_KEY is required for remote execution") - # Validate tasks have remote-compatible mcp_config (URL-based, not command-based) - local_task_servers: list[tuple[int, str, str]] = [] # (task_idx, task_id, server_name) - affected_task_indices: set[int] = set() - for i, task in enumerate(tasks): - if task.mcp_config: - for server_name, server_cfg in task.mcp_config.items(): - if ( - isinstance(server_cfg, dict) - and "command" in server_cfg - and not server_cfg.get("url") - ): - local_task_servers.append((i, task.id or f"task_{i}", server_name)) - affected_task_indices.add(i) - - if local_task_servers: - task_details = ", ".join(f"{tid} ({srv})" for _, tid, srv in local_task_servers[:3]) - if len(local_task_servers) > 3: - task_details += f", ... and {len(local_task_servers) - 3} more" - raise ValueError( - f"Remote execution requires URL-based mcp_config, but " - f"{len(affected_task_indices)} task(s) use local Docker configs " - f"(command-based): {task_details}. " - "Convert to remote with: hud convert " - ) + # Convert to dicts once for uniform processing + task_dicts = _normalize_tasks(tasks) + + # Validate v4 tasks have remote-compatible mcp_config (URL-based, not command-based) + for i, td in enumerate(task_dicts): + if not is_v4_format(td): + continue # v5 tasks use env config, no mcp_config to check + mcp_config = td.get("mcp_config") or {} + for server_name, server_cfg in mcp_config.items(): + is_local = ( + isinstance(server_cfg, dict) + and "command" in server_cfg + and not server_cfg.get("url") + ) + if is_local: + raise ValueError( + f"Remote execution requires URL-based mcp_config. " + f"Task {td.get('id') or i} uses local Docker config for '{server_name}'. " + "Convert to remote with: hud convert " + ) # Build single task requests requests: list[SingleTaskRequest] = [] - for task_idx, task in enumerate(tasks): - base_task_id = task.id or f"task_{task_idx}" + for task_idx, td in enumerate(task_dicts): + base_task_id = td.get("id") or f"task_{task_idx}" + trace_name = td.get("prompt") or td.get("scenario") or base_task_id + for rollout_idx in range(group_size): task_id = f"{base_task_id}_r{rollout_idx}" if group_size > 1 else base_task_id requests.append( SingleTaskRequest( - task=task.model_dump(mode="json"), + task=td, agent_type=agent_type, agent_params=agent_params or {}, max_steps=max_steps, job_id=job_id, task_id=task_id, - trace_name=task.prompt or task_id, + trace_name=trace_name, group_id=base_task_id if group_size > 1 else None, metadata=metadata or {}, ) @@ -267,197 +292,3 @@ async def cancel_all_jobs() -> dict[str, Any]: return response.json() -def calculate_group_stats( - tasks: list[LegacyTask], - traces: list[Trace | None], - group_size: int, - group_ids: dict[int, str], -) -> list[dict[str, Any]]: - """Calculate statistics for each task group. 
- - Args: - tasks: List of Task objects - traces: List of Trace results (may contain None for failed tasks) - group_size: Number of runs per task - group_ids: Mapping from task index to group ID - - Returns: - List of statistics dicts, one per task, containing: - - task_id, prompt, group_id, group_size - - rewards: list of individual rewards - - mean_reward, std_reward, min_reward, max_reward - - success_rate, error_rate - - traces: list of Trace objects for this group - """ - stats = [] - - for task_idx, task in enumerate(tasks): - # Get traces for this task - start = task_idx * group_size - task_traces = [t for t in traces[start : start + group_size] if t is not None] - - if not task_traces: - stats.append( - { - "task_id": task.id or f"task_{task_idx}", - "prompt": task.prompt or "", - "group_id": group_ids[task_idx], - "group_size": group_size, - "rewards": [], - "mean_reward": 0.0, - "std_reward": 0.0, - "success_rate": 0.0, - "error_rate": 1.0, - } - ) - continue - - rewards = [t.reward for t in task_traces] - errors = [t for t in task_traces if t.isError] - - task_stats = { - "task_id": task.id or f"task_{task_idx}", - "prompt": task.prompt or "", - "group_id": group_ids[task_idx], - "group_size": group_size, - "rewards": rewards, - "mean_reward": mean(rewards), - "std_reward": pstdev(rewards) if len(rewards) > 1 else 0.0, - "min_reward": min(rewards), - "max_reward": max(rewards), - "success_rate": sum(1 for r in rewards if r > 0) / len(rewards), - "error_rate": len(errors) / len(task_traces), - "traces": task_traces, - } - stats.append(task_stats) - - return stats - - -def display_results( - results: list[Any], - *, - tasks: list[Any], - elapsed: float | None = None, - show_details: bool = True, -) -> None: - """Display evaluation results in a formatted table. 
- - Args: - results: List of EvalContext objects or grouped statistics dicts - tasks: List of Task or LegacyTask objects corresponding to results - elapsed: Optional elapsed time in seconds - show_details: Whether to show per-task details table - """ - from rich.console import Console - from rich.table import Table - - from hud.utils.hud_console import HUDConsole - - hud_console = HUDConsole() - console = Console() - - if not results: - hud_console.warning("No results to display") - return - - # Detect if this is grouped results (list of dicts with 'mean_reward') or traces - is_grouped = isinstance(results[0], dict) and "mean_reward" in results[0] - - if is_grouped: - # Grouped evaluation stats - all_means = [s["mean_reward"] for s in results] - overall_mean = mean(all_means) if all_means else 0.0 - overall_std = pstdev(all_means) if len(all_means) > 1 else 0.0 - group_size = results[0].get("group_size", 1) - total_episodes = sum(len(s.get("rewards", [])) for s in results) - - hud_console.success("\n📊 Evaluation Complete") - hud_console.info(f"Tasks: {len(results)} x {group_size} runs = {total_episodes} episodes") - if elapsed: - hud_console.info(f"Time: {elapsed:.1f}s ({total_episodes / elapsed:.1f} episodes/s)") - hud_console.info(f"Mean reward: {overall_mean:.3f} ± {overall_std:.3f}") - - if show_details and len(results) <= 50: - table = Table(title="\nPer-Task Performance") - table.add_column("#", style="dim", justify="right") - table.add_column("Task ID", style="cyan", no_wrap=True) - table.add_column("Prompt", style="dim", max_width=40) - table.add_column("Mean±Std", justify="right", style="green") - table.add_column("Min/Max", justify="right") - table.add_column("Success%", justify="right", style="yellow") - - for i, (stat, task) in enumerate(zip(results, tasks, strict=False)): - task_id = (task.id or "")[:20] - # Handle both v4 (prompt attr) and v5 (prompt in args) tasks - raw_prompt = ( - getattr(task, "prompt", None) - or (task.args.get("prompt") if hasattr(task, "args") else None) - or task.scenario - or "" - ) - prompt = raw_prompt[:40] - if len(raw_prompt) > 40: - prompt += "..." 
- table.add_row( - str(i + 1), - task_id, - prompt, - f"{stat.get('mean_reward', 0):.3f}±{stat.get('std_reward', 0):.3f}", - f"{stat.get('min_reward', 0):.2f}/{stat.get('max_reward', 0):.2f}", - f"{stat.get('success_rate', 0) * 100:.0f}%", - ) - console.print(table) - - high_var = [s for s in results if s.get("std_reward", 0) > 0.3] - if high_var: - hud_console.warning(f"\n⚠️ {len(high_var)} tasks show high variance (std > 0.3)") - - else: - # Single-run traces - valid_results = [r for r in results if r is not None] - rewards = [getattr(r, "reward", 0) for r in valid_results] - - if not rewards: - hud_console.warning("No valid results") - return - - mean_reward = sum(rewards) / len(rewards) - successful = sum(1 for r in rewards if r > 0.7) - success_rate = successful / len(results) - - hud_console.success("\n📊 Evaluation Complete") - hud_console.info(f"Tasks: {len(results)}") - if elapsed: - hud_console.info(f"Time: {elapsed:.1f}s ({len(results) / elapsed:.1f} tasks/s)") - hud_console.info(f"Mean reward: {mean_reward:.3f}") - hud_console.info(f"Success rate: {success_rate * 100:.1f}% ({successful}/{len(results)})") - - if show_details and len(results) <= 50: - table = Table(title="\nPer-Task Results") - table.add_column("#", style="dim", justify="right") - table.add_column("Task ID", style="cyan", no_wrap=True) - table.add_column("Prompt", style="dim", max_width=40) - table.add_column("Reward", justify="right", style="green") - table.add_column("Status", justify="center") - - for i, r in enumerate(results): - task = tasks[i] - task_id = (task.id or "")[:20] - # Handle both v4 (prompt attr) and v5 (prompt in args) tasks - raw_prompt = ( - getattr(task, "prompt", None) - or (task.args.get("prompt") if hasattr(task, "args") else None) - or getattr(task, "scenario", None) - or "" - ) - prompt = raw_prompt[:40] - if len(raw_prompt) > 40: - prompt += "..." 
- - if r is None: - table.add_row(str(i + 1), task_id, prompt, "—", "[red]Error[/red]") - else: - reward = getattr(r, "reward", 0) - status = "[green]✓[/green]" if reward > 0.7 else "[yellow]✗[/yellow]" - table.add_row(str(i + 1), task_id, prompt, f"{reward:.3f}", status) diff --git a/hud/environment/connectors/mcp_config.py b/hud/environment/connectors/mcp_config.py index ebfacee5..db5aa6af 100644 --- a/hud/environment/connectors/mcp_config.py +++ b/hud/environment/connectors/mcp_config.py @@ -98,6 +98,12 @@ def connect_mcp_config( await env.call_tool("search_repositories", query="mcp") ``` """ + # Store mcp_config for serialization (v4 format) + # Merge with existing if called multiple times + if not hasattr(self, "_mcp_config") or self._mcp_config is None: + self._mcp_config = {} + self._mcp_config.update(mcp_config) + for server_name, server_config in mcp_config.items(): self.connect_mcp({server_name: server_config}, alias=server_name, **kwargs) return self diff --git a/hud/environment/connectors/remote.py b/hud/environment/connectors/remote.py index 866b13a5..b5cdda0b 100644 --- a/hud/environment/connectors/remote.py +++ b/hud/environment/connectors/remote.py @@ -50,6 +50,16 @@ def connect_hub( logger.info("Connecting to hub environment: %s", slug) + # Store hub config for serialization (v5 format) + # Note: Only first hub is stored for serialization (task configs use single hub) + if not hasattr(self, "_hub_config") or self._hub_config is None: + hub_config: dict[str, Any] = {"name": slug} + if include: + hub_config["include"] = include + if exclude: + hub_config["exclude"] = exclude + self._hub_config = hub_config + # Create mcp_config with standard MCP URL and hub slug in headers mcp_config = { "hud": { diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 1dbeb6fe..74db86d3 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -138,8 +138,11 @@ def __init__( # Default prompt (EvalContext has per-run prompt) self.prompt: str | None = None - # Track which lifecycle tools we've warned about (only warn once per tool) - self._warned_lifecycle_tools: set[str] = set() + # Serialization support + # _hub_config: set by connect_hub() for v5 format {"name": "hub", "include": [...]} + # _mcp_config: set by connect_mcp_config() for v4 format {"server_name": {...}} + self._hub_config: dict[str, Any] | None = None + self._mcp_config: dict[str, dict[str, Any]] | None = None # Initialize mock state self._init_mock() @@ -173,26 +176,9 @@ async def call_tool(self, call: Any, /, **kwargs: Any) -> Any: # Parse the tool call (kwargs merged when call is string) parsed, fmt = parse_tool_call(call, **kwargs) - self._check_lifecycle_warning(parsed.name) result = await self._execute_tool(parsed.name, parsed.arguments or {}) return format_result(result, parsed, fmt) - def _check_lifecycle_warning(self, name: str) -> None: - """Warn once if calling a setup/evaluate tool manually.""" - if name in self._warned_lifecycle_tools: - return - setup = {n for n, _ in self._setup_calls} - evaluate = {n for n, _ in self._evaluate_calls} - if name not in setup and name not in evaluate: - return - self._warned_lifecycle_tools.add(name) - phase = "setup" if name in setup else "evaluate" - logger.warning( - "Tool '%s' is a %s tool (runs automatically). Manual call may duplicate.", - name, - phase, - ) - def _connections_with_tool(self, tool_name: str) -> set[str]: """Get connection names that have a specific tool. 
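A minimal usage sketch of the environment serialization split (hub-based v5 configs vs. mcp_config-based v4 configs), assuming the connect_hub, connect_mcp_config, evaluate_tool, is_serializable, and to_config helpers defined in this series; the hub slug, server URL, and the check_url evaluator are placeholders, not real fixtures:

```python
from hud.environment import Environment

# v5: a hub-connected environment serializes to a compact hub config.
env = Environment("my-env").connect_hub("browser", include=["navigate"])
assert env.is_serializable
env.to_config()  # -> {"name": "browser", "include": ["navigate"]}

# v4: an mcp_config-based environment also needs a prompt and an evaluate_tool
# before it can be serialized for remote submission.
legacy = Environment("legacy")
legacy.connect_mcp_config({"hud": {"url": "https://example.invalid/mcp"}})
legacy.prompt = "Navigate to example.com"
legacy.evaluate_tool("check_url", url="https://example.com")  # placeholder evaluator
legacy.to_config()
# -> {"prompt": "...", "mcp_config": {...}, "evaluate_tool": [{"name": "check_url", ...}]}
```

Local @env.tool functions and @env.scenario definitions deliberately make to_config() raise, since they cannot be reproduced on a remote worker.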
@@ -524,6 +510,108 @@ def local_connections(self) -> list[str]: """Names of local (non-parallelizable) connections.""" return [name for name, conn in self._connections.items() if conn.is_local] + # ========================================================================= + # Serialization + # ========================================================================= + + @property + def is_serializable(self) -> bool: + """True if environment can be serialized (no local tools/scenarios). + + For v5 format: requires hub config from connect_hub() + For v4 format: requires mcp_config, prompt, AND evaluate_tool + """ + # Check for local tools (registered via @env.tool) + if self._router._local_names: + return False + # Check for local scenarios (registered via @env.scenario) + if getattr(self, "_scenarios", {}): + return False + # v5 hub format + if self._hub_config is not None: + return True + # v4 format requires mcp_config + prompt + evaluate_tool + if self._mcp_config is not None: + return bool(self.prompt and self._evaluate_calls) + return False + + def to_config(self) -> dict[str, Any]: + """Serialize environment config for remote submission. + + Returns the config in either v5 format (hub-based) or v4 format (legacy). + For v4 format, automatically includes prompt, setup_tool, and evaluate_tool + from the Environment's state. + + Returns: + dict: Serializable config + + Raises: + ValueError: If environment has local tools/scenarios that can't be serialized + + Example: + ```python + # v5 hub-based + env = Environment("my").connect_hub("browser", include=["navigate"]) + env.to_config() # {"name": "browser", "include": ["navigate"]} + + # v4 legacy (from Task.from_v4()) + task = Task.from_v4(legacy_task) + task.env.to_config() # {"prompt": "...", "mcp_config": {...}, ...} + ``` + """ + if self._router._local_names: + raise ValueError( + f"Cannot serialize Environment with local tools: " + f"{list(self._router._local_names)}. " + "Local tools require local execution. For remote submission, " + "use dict config or connect to a remote hub." + ) + if getattr(self, "_scenarios", {}): + raise ValueError( + f"Cannot serialize Environment with local scenarios: " + f"{list(self._scenarios.keys())}. " + "Local scenarios require local execution. For remote submission, " + "define scenarios on the remote environment." + ) + + # v5 hub-based format + if self._hub_config is not None: + return self._hub_config.copy() + + # v4 legacy format - requires mcp_config, prompt, AND evaluate_tool + if self._mcp_config is not None: + # Validate required fields for v4 format + if not self.prompt: + raise ValueError( + "Cannot serialize v4 Environment without prompt. " + "Set env.prompt before serializing." + ) + if not self._evaluate_calls: + raise ValueError( + "Cannot serialize v4 Environment without evaluate_tool. " + "Use env.evaluate_tool() to define evaluation criteria." + ) + + config: dict[str, Any] = { + "prompt": self.prompt, + "mcp_config": self._mcp_config, + "evaluate_tool": [ + {"name": name, "arguments": args} + for name, args in self._evaluate_calls + ], + } + if self._setup_calls: + config["setup_tool"] = [ + {"name": name, "arguments": args} + for name, args in self._setup_calls + ] + return config + + raise ValueError( + "Cannot serialize Environment without config. " + "Use connect_hub() for v5 tasks or connect_mcp_config() for legacy tasks." 
+ ) + def __repr__(self) -> str: return f"Environment({self.name!r}, connections={list(self._connections.keys())})" diff --git a/hud/environment/integrations/adk.py b/hud/environment/integrations/adk.py index 93d0cf42..0498fd1a 100644 --- a/hud/environment/integrations/adk.py +++ b/hud/environment/integrations/adk.py @@ -52,7 +52,7 @@ def as_adk_tools(self) -> list[Any]: ``` """ try: - from google.adk.tools import FunctionTool + from google.adk.tools.function_tool import FunctionTool except ImportError as e: raise ImportError( "Google ADK not installed. Install with: pip install google-adk" diff --git a/hud/eval/__init__.py b/hud/eval/__init__.py index 45011413..0c659773 100644 --- a/hud/eval/__init__.py +++ b/hud/eval/__init__.py @@ -42,13 +42,19 @@ # Task is safe to import from hud.eval.task import Task +# Utils for v4 format handling +from hud.eval.utils import build_env_from_v4, is_v4_format, validate_v4_task + if TYPE_CHECKING: from hud.eval.context import EvalContext __all__ = [ "EvalContext", "Task", + "build_env_from_v4", + "is_v4_format", "run_eval", + "validate_v4_task", ] diff --git a/hud/eval/context.py b/hud/eval/context.py index 1775c18c..f84ba9ca 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -12,13 +12,12 @@ import contextvars import logging import uuid -from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Self from hud.environment import Environment from hud.settings import settings from hud.shared import make_request -from hud.telemetry.job import get_current_job +from hud.telemetry import flush, instrument if TYPE_CHECKING: from types import TracebackType @@ -35,12 +34,39 @@ "current_trace_headers", default=None ) +# Contextvar to store current api_key override (for telemetry exporter) +_current_api_key: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "current_api_key", default=None +) + def get_current_trace_headers() -> dict[str, str] | None: """Get the current trace headers from context.""" return _current_trace_headers.get() +def get_current_trace_id() -> str | None: + """Get the current trace ID (task_run_id) from context. + + Returns the Trace-Id if inside an eval context, None otherwise. + Used by @instrument decorator to know where to send telemetry. + """ + headers = _current_trace_headers.get() + if headers: + return headers.get("Trace-Id") + return None + + +def get_current_api_key() -> str | None: + """Get the current API key override from context. + + Returns the api_key if one was passed to hud.eval(), otherwise None. + Falls back to settings.api_key if not in an eval context. + Used by telemetry exporter for uploads. 
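+
+    Illustrative caller-side fallback (a sketch; actual call sites may differ):
+
+        api_key = get_current_api_key() or settings.api_key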
+ """ + return _current_api_key.get() + + # ============================================================================= # EvalContext # ============================================================================= @@ -118,11 +144,7 @@ def __init__( self.eval_name: str = name # Separate from self.name for clarity # Job linkage - if job_id is None: - current_job = get_current_job() - self.job_id: str | None = current_job.id if current_job else None - else: - self.job_id = job_id + self.job_id: str | None = job_id self.group_id: str | None = group_id self.index: int = index @@ -134,10 +156,14 @@ def __init__( self.prompt: str | None = None # From scenario setup or task self.reward: float | None = None self.answer: str | None = None # Agent's submitted answer + self.system_prompt: str | None = None # From task.agent_config, passed to agent # Error tracking self.error: BaseException | None = None + # User metadata (arbitrary key-value pairs) + self.metadata: dict[str, Any] = {} + # Parallel results (empty list for single evals, populated for parallel) self.results: list[EvalContext] = [] @@ -146,13 +172,11 @@ def __init__( # Private state for eval tracking self._eval_api_key = api_key - self._started_at: datetime | None = None - self._completed_at: datetime | None = None self._token: contextvars.Token[dict[str, str] | None] | None = None + self._api_key_token: contextvars.Token[str | None] | None = None self._is_summary: bool = False # True for summary contexts (skip trace) self._suppress_link: bool = quiet # True to suppress printing eval link self._trace_enabled: bool = trace # Whether to send trace data to backend - self._scenario_name: str | None = None # Current scenario name (for submit) self._source_env_name: str | None = None # Source env name for remote lookups self._task: Task | None = None # Task config (set by from_task) @@ -286,6 +310,10 @@ def from_task( # Store task info for scenario execution ctx._task = task + # Set system_prompt from task.agent_config + if task.agent_config and task.agent_config.system_prompt: + ctx.system_prompt = task.agent_config.system_prompt + return ctx async def _run_task_scenario_setup(self) -> None: @@ -293,7 +321,6 @@ async def _run_task_scenario_setup(self) -> None: if self._task is None or self._task.scenario is None: return - self._scenario_name = self._task.scenario prompt = await self.run_scenario_setup(self._task.scenario, self._task.args) if prompt: self.prompt = prompt @@ -358,24 +385,11 @@ def headers(self) -> dict[str, str]: """Headers for gateway integration.""" return {"Trace-Id": self.trace_id} - @property - def duration(self) -> float: - """Execution duration in seconds.""" - if self._started_at is None: - return 0.0 - end = self._completed_at or datetime.now(UTC) - return (end - self._started_at).total_seconds() - @property def success(self) -> bool: """True if no error occurred.""" return self.error is None - @property - def done(self) -> bool: - """True if execution completed.""" - return self._completed_at is not None - # ========================================================================= # Backend Integration # ========================================================================= @@ -392,6 +406,7 @@ def _build_base_payload(self) -> EvalPayload: group_id=self.group_id, variants=self.variants if self.variants else None, task_version_id=self._task.id if self._task else None, + metadata=self.metadata if self.metadata else None, ) async def log(self, metrics: dict[str, Any]) -> None: @@ -426,7 +441,7 @@ async def 
submit(self, answer: str) -> None: await ctx.submit(response) # On exit, scenario's evaluate phase receives the answer """ - if not self._scenario_name: + if not self._task or not self._task.scenario: logger.warning("submit() called but no scenario is running") return @@ -434,7 +449,7 @@ async def submit(self, answer: str) -> None: self.answer = answer # Delegate to Environment.submit() which handles storage + broadcast - await super().submit(self._scenario_name, answer) + await super().submit(self._task.scenario, answer) async def _eval_enter(self) -> None: """Notify backend that eval has started.""" @@ -494,8 +509,8 @@ async def __aenter__(self) -> Self: return self # Start tracking - self._started_at = datetime.now(UTC) self._token = _current_trace_headers.set(self.headers) + self._api_key_token = _current_api_key.set(self._eval_api_key) # Connect environment (MCP servers, tools) await super().__aenter__() @@ -521,8 +536,6 @@ async def __aexit__( if self._is_summary: return exc_type is ParallelEvalComplete - self._completed_at = datetime.now(UTC) - # Run task scenario evaluate (if no error and has scenario) if exc_type is None: await self._run_task_scenario_evaluate() @@ -533,18 +546,37 @@ async def __aexit__( self.error = exc_val error_msg = str(exc_val) if exc_val else "Unknown error" + # Flush any pending telemetry spans for this trace + flush(self.trace_id) + # Disconnect environment (parent class) await super().__aexit__(exc_type, exc_val, exc_tb) - # Reset context var + # Reset context vars if self._token is not None: _current_trace_headers.reset(self._token) self._token = None + if self._api_key_token is not None: + _current_api_key.reset(self._api_key_token) + self._api_key_token = None # Notify backend await self._eval_exit(error_msg) return False + # ========================================================================= + # Tool Call Instrumentation + # ========================================================================= + + @instrument(category="mcp") + async def call_tool(self, call: Any, /, **kwargs: Any) -> Any: + """Call a tool with automatic telemetry recording. + + Overrides Environment.call_tool to record MCP spans for the eval context. + Uses @instrument decorator for automatic span recording. + """ + return await super().call_tool(call, **kwargs) + def __repr__(self) -> str: return f"EvalContext({self.trace_id[:8]}..., name={self.eval_name!r}, reward={self.reward})" @@ -561,4 +593,9 @@ def _print_eval_link(self) -> None: # Re-export for backwards compatibility with trace module -__all__ = ["EvalContext", "get_current_trace_headers"] +__all__ = [ + "EvalContext", + "get_current_api_key", + "get_current_trace_headers", + "get_current_trace_id", +] diff --git a/hud/eval/display.py b/hud/eval/display.py index a7798504..1d23d494 100644 --- a/hud/eval/display.py +++ b/hud/eval/display.py @@ -1,34 +1,20 @@ -"""Display helpers for eval links and job URLs. - -Provides consistent, beautiful display for HUD URLs using rich. -""" +"""Display helpers for eval links, job URLs, and result statistics.""" from __future__ import annotations import contextlib import webbrowser from statistics import mean, pstdev -from typing import TYPE_CHECKING, Any +from typing import Any from hud.settings import settings -if TYPE_CHECKING: - from hud.eval.context import EvalContext - def print_link(url: str, title: str, *, open_browser: bool = True) -> None: - """Print a nicely formatted link with optional browser opening. 
- - Args: - url: The URL to display - title: Title for the panel (e.g., "🔗 Eval Started", "🚀 Job Started") - open_browser: Whether to open the URL in browser - """ - # Only print if telemetry is enabled and has API key + """Print a nicely formatted link with optional browser opening.""" if not (settings.telemetry_enabled and settings.api_key): return - # Open in browser if open_browser: with contextlib.suppress(Exception): webbrowser.open(url, new=2) @@ -39,14 +25,10 @@ def print_link(url: str, title: str, *, open_browser: bool = True) -> None: from rich.panel import Panel console = Console() - style = "bold underline rgb(108,113,196)" link_markup = f"[{style}][link={url}]{url}[/link][/{style}]" - - content = Align.center(link_markup) - panel = Panel( - content, + Align.center(link_markup), title=title, border_style="rgb(192,150,12)", padding=(0, 2), @@ -57,14 +39,7 @@ def print_link(url: str, title: str, *, open_browser: bool = True) -> None: def print_complete(url: str, name: str, *, error: bool = False) -> None: - """Print a completion message with link. - - Args: - url: The URL to display - name: Name of the eval/job - error: Whether an error occurred - """ - # Only print if telemetry is enabled and has API key + """Print a completion message with link.""" if not (settings.telemetry_enabled and settings.api_key): return @@ -72,7 +47,6 @@ def print_complete(url: str, name: str, *, error: bool = False) -> None: from rich.console import Console console = Console() - if error: console.print( f"\n[red]✗ '{name}' failed![/red] [dim]View details at:[/dim] " @@ -88,22 +62,25 @@ def print_complete(url: str, name: str, *, error: bool = False) -> None: print(f"\n{name} {status}: {url}\n") # noqa: T201 -def print_eval_stats( - completed: list[EvalContext], - name: str = "", +def display_results( + results: list[Any], *, + tasks: list[Any] | None = None, + name: str = "", elapsed: float | None = None, show_details: bool = True, ) -> None: - """Print statistics for completed evaluations. + """Display evaluation results in a formatted table. 
Args: - completed: List of completed EvalContext objects + results: List of EvalContext objects from hud.eval() + tasks: Optional list of Task objects (for task info in table) name: Optional name for the evaluation elapsed: Optional elapsed time in seconds show_details: Whether to show per-eval details table """ - if not completed: + if not results: + print("No results to display") # noqa: T201 return try: @@ -112,109 +89,157 @@ def print_eval_stats( console = Console() except ImportError: - # Fallback to basic printing - _print_eval_stats_basic(completed, name, elapsed) + _display_basic(results, name, elapsed) return - # Calculate aggregate stats - rewards = [ctx.reward for ctx in completed if ctx.reward is not None] - errors = [ctx for ctx in completed if ctx.error is not None] - durations = [ctx.duration for ctx in completed if ctx.duration > 0] + # Extract stats from results (EvalContext objects) + rewards = [getattr(r, "reward", 0) for r in results if r is not None] + errors = [r for r in results if r is not None and getattr(r, "error", None)] + durations = [getattr(r, "duration", 0) for r in results if getattr(r, "duration", 0) > 0] + + if not rewards: + console.print("[yellow]No valid results[/yellow]") + return mean_reward = mean(rewards) if rewards else 0.0 std_reward = pstdev(rewards) if len(rewards) > 1 else 0.0 - success_rate = (len(completed) - len(errors)) / len(completed) if completed else 0.0 + success_count = sum(1 for r in rewards if r > 0.7) + success_rate = success_count / len(results) if results else 0.0 # Print summary - title = f"📊 '{name}' Results" if name else "📊 Eval Results" + title = f"📊 '{name}' Results" if name else "📊 Evaluation Complete" console.print(f"\n[bold]{title}[/bold]") - console.print(f" [dim]Evals:[/dim] {len(completed)}") + console.print(f" [dim]Evals:[/dim] {len(results)}") if elapsed: - rate = len(completed) / elapsed if elapsed > 0 else 0 - console.print(f" [dim]Time:[/dim] {elapsed:.1f}s ({rate:.1f} evals/s)") + rate = len(results) / elapsed if elapsed > 0 else 0 + console.print(f" [dim]Time:[/dim] {elapsed:.1f}s ({rate:.1f}/s)") if durations: - mean_duration = mean(durations) - console.print(f" [dim]Avg duration:[/dim] {mean_duration:.2f}s") + console.print(f" [dim]Avg duration:[/dim] {mean(durations):.2f}s") console.print(f" [dim]Mean reward:[/dim] [green]{mean_reward:.3f}[/green] ± {std_reward:.3f}") console.print(f" [dim]Success rate:[/dim] [yellow]{success_rate * 100:.1f}%[/yellow]") if errors: console.print(f" [dim]Errors:[/dim] [red]{len(errors)}[/red]") - # Show details table if requested and not too many - if show_details and len(completed) <= 50: - table = Table(title="Per-Eval Details", show_header=True, header_style="bold") + # Details table + if show_details and len(results) <= 50: + table = Table(title="Details", show_header=True, header_style="bold") table.add_column("#", style="dim", justify="right", width=4) - table.add_column("Variants", style="cyan", max_width=35) - table.add_column("Answer", style="white", max_width=25) - table.add_column("Reward", justify="right", style="green", width=8) - table.add_column("Duration", justify="right", width=10) - table.add_column("Status", justify="center", width=8) - for ctx in completed: - idx_str = str(ctx.index) - variants_str = _format_variants(ctx.variants) if ctx.variants else "-" - answer_str = _truncate(ctx.answer, 30) if ctx.answer else "-" - reward_str = f"{ctx.reward:.3f}" if ctx.reward is not None else "-" - duration_str = f"{ctx.duration:.2f}s" if ctx.duration > 0 else "-" + 
# Check if we have variants (grouped parallel runs) + has_variants = any(getattr(r, "variants", None) for r in results if r) + has_answers = any(getattr(r, "answer", None) for r in results if r) + + if has_variants: + table.add_column("Variants", style="cyan", max_width=30) + elif tasks: + table.add_column("Task", style="cyan", max_width=30) + + if has_answers: + table.add_column("Answer", style="dim", max_width=35) - if ctx.error: + table.add_column("Reward", justify="right", style="green", width=8) + if durations: + table.add_column("Time", justify="right", width=8) + table.add_column("", justify="center", width=3) # Status icon + + for i, r in enumerate(results): + if r is None: + continue + + idx = getattr(r, "index", i) + reward = getattr(r, "reward", None) + error = getattr(r, "error", None) + duration = getattr(r, "duration", 0) + variants = getattr(r, "variants", None) + answer = getattr(r, "answer", None) + + # Status icon + if error: status = "[red]✗[/red]" - elif ctx.reward is not None and ctx.reward > 0.7: + elif reward is not None and reward > 0.7: status = "[green]✓[/green]" else: status = "[yellow]○[/yellow]" - table.add_row(idx_str, variants_str, answer_str, reward_str, duration_str, status) + row = [str(idx)] + + # Variant or task column + if has_variants: + row.append(_format_variants(variants)) + elif tasks and i < len(tasks): + task = tasks[i] + task_label = _get_task_label(task, i) + row.append(task_label[:30]) + + # Answer column + if has_answers: + row.append(_truncate(answer, 35)) + + # Reward + row.append(f"{reward:.3f}" if reward is not None else "—") + + # Duration + if durations: + row.append(f"{duration:.1f}s" if duration > 0 else "—") + + row.append(status) + table.add_row(*row) console.print(table) - # Warn about high variance + # Variance warning if std_reward > 0.3: - console.print(f"\n[yellow]⚠️ High variance detected (std={std_reward:.3f})[/yellow]") + console.print(f"\n[yellow]⚠️ High variance (std={std_reward:.3f})[/yellow]") console.print() -def _format_variants(variants: dict[str, Any]) -> str: +def _display_basic(results: list[Any], name: str, elapsed: float | None) -> None: + """Fallback display without rich.""" + rewards = [getattr(r, "reward", 0) for r in results if r is not None] + title = f"'{name}' Results" if name else "Eval Results" + print(f"\n{title}") # noqa: T201 + print(f" Evals: {len(results)}") # noqa: T201 + if elapsed: + print(f" Time: {elapsed:.1f}s") # noqa: T201 + if rewards: + print(f" Mean reward: {mean(rewards):.3f}") # noqa: T201 + print() # noqa: T201 + + +def _format_variants(variants: dict[str, Any] | None) -> str: """Format variants dict for display.""" if not variants: return "-" parts = [f"{k}={v}" for k, v in variants.items()] result = ", ".join(parts) - return result[:35] + "..." if len(result) > 35 else result + return result[:28] + ".." if len(result) > 30 else result def _truncate(text: str | None, max_len: int) -> str: """Truncate text to max length.""" if not text: return "-" - # Replace newlines with spaces for display text = text.replace("\n", " ").strip() - return text[:max_len] + "..." if len(text) > max_len else text + return text[: max_len - 2] + ".." 
if len(text) > max_len else text -def _print_eval_stats_basic( - completed: list[EvalContext], - name: str, - elapsed: float | None, -) -> None: - """Basic stats printing without rich.""" - rewards = [ctx.reward for ctx in completed if ctx.reward is not None] - errors = [ctx for ctx in completed if ctx.error is not None] +def _get_task_label(task: Any, index: int) -> str: + """Get a display label for a task.""" + if task is None: + return f"task_{index}" + if isinstance(task, dict): + return task.get("id") or task.get("prompt", "")[:25] or f"task_{index}" + task_id = getattr(task, "id", None) + if task_id: + return task_id + prompt = getattr(task, "prompt", None) or getattr(task, "scenario", None) + if prompt: + return prompt[:25] + return f"task_{index}" - mean_reward = mean(rewards) if rewards else 0.0 - success_rate = (len(completed) - len(errors)) / len(completed) if completed else 0.0 - - title = f"'{name}' Results" if name else "Eval Results" - print(f"\n{title}") # noqa: T201 - print(f" Evals: {len(completed)}") # noqa: T201 - if elapsed: - print(f" Time: {elapsed:.1f}s") # noqa: T201 - print(f" Mean reward: {mean_reward:.3f}") # noqa: T201 - print(f" Success rate: {success_rate * 100:.1f}%") # noqa: T201 - if errors: - print(f" Errors: {len(errors)}") # noqa: T201 - print() # noqa: T201 +# Backwards compatibility alias +print_eval_stats = display_results -__all__ = ["print_complete", "print_eval_stats", "print_link"] +__all__ = ["display_results", "print_complete", "print_eval_stats", "print_link"] diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 34fa7690..4cb893f7 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -454,7 +454,7 @@ async def run_one(config: tuple[Task | None, dict[str, Any]]) -> EvalContext: # Log and print stats eval_name = completed[0].eval_name if completed else "eval" log_eval_stats(completed) - print_eval_stats(completed, eval_name) + print_eval_stats(completed, name=eval_name) return list(completed) diff --git a/hud/eval/task.py b/hud/eval/task.py index 421cb075..6a181ddf 100644 --- a/hud/eval/task.py +++ b/hud/eval/task.py @@ -28,7 +28,14 @@ import logging from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_serializer, + model_validator, +) from hud.types import MCPToolCall @@ -36,35 +43,21 @@ from hud.environment import Environment from hud.environment.types import EnvConfig -__all__ = ["Task", "build_eval_name"] - -logger = logging.getLogger(__name__) +__all__ = ["Task", "TaskAgentConfig", "build_eval_name"] -def _warn_local_mcp(mcp_config: dict[str, Any] | None) -> None: - """Warn if mcp_config uses local MCP servers (command without url). +class TaskAgentConfig(BaseModel): + """Agent configuration for a Task. - Local MCP servers can cause port conflicts when running tasks concurrently. + Contains settings that should be passed to the agent when running this task. """ - if not mcp_config: - return - has_local = any( - isinstance(server_cfg, dict) and "command" in server_cfg and not server_cfg.get("url") - for server_cfg in mcp_config.values() - if isinstance(server_cfg, dict) + system_prompt: str | None = Field( + default=None, + description="Custom system prompt to pass to the agent", ) - if has_local: - import warnings - - warnings.warn( - "Task uses local MCP configuration (command without url). " - "This may cause port conflicts when running tasks concurrently. 
" - "Consider using remote MCP servers for parallel execution.", - UserWarning, - stacklevel=4, # Skip through from_v4 -> _warn_local_mcp -> warn - ) +logger = logging.getLogger(__name__) def build_eval_name(scenario: str | None, args: dict[str, Any] | None) -> str: @@ -141,12 +134,44 @@ class Task(BaseModel): args: dict[str, Any] = Field(default_factory=dict) validation: list[MCPToolCall] | None = None + # Agent config - settings passed to agent (system_prompt, etc.) + agent_config: TaskAgentConfig | None = None + + # Task metadata - for tracking/filtering, not used by agent + metadata: dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def detect_v4_format(cls, data: Any) -> Any: + """Auto-detect v4 LegacyTask format and convert to v5 Task format. + + If the input dict is a valid v4 format (has prompt, mcp_config, evaluate_tool), + it's converted using build_env_from_v4(). + + This allows Task(**v4_dict) to work seamlessly. + """ + from hud.eval.utils import build_env_from_v4, is_v4_format, validate_v4_task + + if not isinstance(data, dict): + return data + + if is_v4_format(data): + # Validate completeness before conversion + validate_v4_task(data) + # build_env_from_v4 returns a dict with all Task fields + return build_env_from_v4(data) + + return data + @field_validator("env", mode="before") @classmethod def convert_env( cls, v: Environment | EnvConfig | dict[str, Any] | None ) -> Environment | None: - """Auto-convert dict/EnvConfig to Environment.""" + """Auto-convert dict/EnvConfig to Environment. + + Format: {"name": "browser", "include": [...], "exclude": [...]} + """ from hud.environment import Environment from hud.environment.types import EnvConfig @@ -156,12 +181,15 @@ def convert_env( return v if isinstance(v, dict): try: - v = EnvConfig(**v) + config = EnvConfig(**v) except Exception as e: raise ValueError( f"Invalid env config: {e}. Expected fields: name (str), " f"include (list[str] | None), exclude (list[str] | None)" ) from e + env = Environment(config.name) + env.connect_hub(config.name, include=config.include, exclude=config.exclude) + return env if isinstance(v, EnvConfig): env = Environment(v.name) env.connect_hub(v.name, include=v.include, exclude=v.exclude) @@ -193,121 +221,82 @@ def convert_validation( ) return converted + @model_serializer(mode="wrap") + def _serialize_task( + self, handler: Any # SerializerFunctionWrapHandler + ) -> dict[str, Any]: + """Custom serializer that converts Environment to config dict. + + For v5 tasks: outputs {"env": {"name": "browser", ...}, "scenario": ...} + For v4 tasks: outputs {"prompt": ..., "mcp_config": ..., "evaluate_tool": ...} + + Raises ValueError if environment has local tools/scenarios. 
+ """ + from hud.environment import Environment + + # Get default serialization + data = handler(self) + + # Convert Environment to serializable config + if isinstance(self.env, Environment): + env_config = self.env.to_config() + + # Detect v4 format (has mcp_config) vs v5 format (has name) + if "mcp_config" in env_config: + # v4 format - merge env_config with Task fields + result = env_config.copy() + + # Map validation → integration_test_tool + if self.validation: + result["integration_test_tool"] = [ + {"name": v.name, "arguments": v.arguments or {}} + for v in self.validation + ] + + # Preserve agent_config (with system_prompt) + if self.agent_config and self.agent_config.system_prompt: + result["agent_config"] = {"system_prompt": self.agent_config.system_prompt} + + # Preserve metadata + if self.metadata: + result["metadata"] = self.metadata + + # Preserve id + if self.id: + result["id"] = self.id + + return result + else: + # v5 format - env config goes in env field + data["env"] = env_config + + return data + @classmethod - def from_v4( - cls, - source: Any, # LegacyTask | dict[str, Any] | str - ) -> Task: - """Convert a v4 LegacyTask to a v5 Task. + def from_v4(cls, source: Any) -> Task: + """Convert v4 LegacyTask format to v5 Task. - This is the recommended migration path for existing v4 code. The returned - Task automatically runs setup_tool at the start and evaluate_tool at the - end, matching the old LegacyTask behavior. + This is a convenience wrapper. You can also use Task(**dict) directly + since the model validator auto-detects v4 format. Args: - source: One of: - - LegacyTask object - - dict with LegacyTask fields (prompt, mcp_config, etc.) - - JSON string of LegacyTask fields + source: LegacyTask, dict, or JSON string with v4 fields Returns: - Task with Environment configured to mimic LegacyTask behavior. - - Example: - ```python - from hud.eval import Task - - # From existing LegacyTask - task = Task.from_v4(legacy_task) - - # From dict (e.g., loaded from JSON file) - task = Task.from_v4( - { - "prompt": "Navigate to google.com", - "mcp_config": {"hud": {...}}, - "setup_tool": {"name": "navigate", "arguments": {"url": "..."}}, - "evaluate_tool": {"name": "check_url", "arguments": {}}, - } - ) - - # Use with hud.eval() or as context manager - async with task as ctx: - result = await agent.run(ctx) - ``` - - Note: - For new code, prefer using @env.scenario() instead: - - setup_tool code goes BEFORE the first yield - - evaluate_tool code goes AFTER the first yield - See https://docs.hud.ai/migration for the full migration guide. 
+ Task configured for v4 behavior """ import json as json_module - from hud.environment import Environment - from hud.types import LegacyTask - - # Parse JSON string + # JSON string → dict if isinstance(source, str): - try: - source = json_module.loads(source) - except json_module.JSONDecodeError as e: - from hud.shared.exceptions import HudConfigError - - raise HudConfigError(f"Invalid JSON string for Task.from_v4: {e}") from e - - # Convert dict to LegacyTask (suppress the deprecation warning since we're migrating) - if isinstance(source, dict): - import warnings - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=DeprecationWarning) - legacy_task = LegacyTask(**source) - elif isinstance(source, LegacyTask): - legacy_task = source - else: - raise TypeError( - f"Task.from_v4() expects LegacyTask, dict, or JSON string, " - f"got {type(source).__name__}" - ) - - # Warn if using local MCP configs (command without url) - _warn_local_mcp(legacy_task.mcp_config) - - # Create Environment and connect via mcp_config - env = Environment(legacy_task.id or "v4-legacy") - env.connect_mcp_config(legacy_task.mcp_config) - - # Set the prompt - env.prompt = legacy_task.prompt - - # Add setup_tool calls (run after connection via Environment._setup_calls) - if legacy_task.setup_tool: - setup_calls = legacy_task.setup_tool - if not isinstance(setup_calls, list): - setup_calls = [setup_calls] - for call in setup_calls: - env.setup_tool(call.name, **(call.arguments or {})) - - # Add evaluate_tool calls (run before disconnection via Environment._evaluate_calls) - if legacy_task.evaluate_tool: - evaluate_calls = legacy_task.evaluate_tool - if not isinstance(evaluate_calls, list): - evaluate_calls = [evaluate_calls] - for call in evaluate_calls: - env.evaluate_tool(call.name, **(call.arguments or {})) - - logger.debug( - "Created Task from v4 LegacyTask: %s", - legacy_task.prompt[:50] if legacy_task.prompt else "no prompt", - ) + source = json_module.loads(source) - return cls( - env=env, # Live Environment with mcp_config, setup_tool, evaluate_tool - scenario=None, # v4 tasks use prompt directly, not scenarios - id=legacy_task.id, - args={}, - validation=None, - ) + # LegacyTask → dict (import only when needed) + if hasattr(source, "model_dump"): + source = source.model_dump() + + # Model validator handles v4 detection and conversion + return cls(**source) def copy(self) -> Task: """Create a copy of this Task config. 
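A minimal usage sketch of the v4 auto-detection added to hud/eval/task.py and hud/eval/utils.py above, for illustration only: the import paths and field names come from the patch itself, the `mcp_config` server URL is a placeholder, and the prompt and tool names mirror the docstring example removed in this hunk.

```python
from hud.eval import Task
from hud.eval.utils import is_v4_format

# A v4-shaped task dict. The remote server URL is illustrative only.
v4_dict = {
    "prompt": "Navigate to google.com",
    "mcp_config": {"hud": {"url": "https://example.com/mcp"}},
    "setup_tool": {"name": "navigate", "arguments": {"url": "https://google.com"}},
    "evaluate_tool": {"name": "check_url", "arguments": {}},
}

assert is_v4_format(v4_dict)  # True: both prompt and mcp_config are present

# The model validator detects the v4 shape, validates the required fields,
# and builds the Environment via build_env_from_v4(), so both entry points
# should produce equivalent v5 Tasks.
task_a = Task(**v4_dict)
task_b = Task.from_v4(v4_dict)

assert task_a.scenario is None and task_a.env is not None
```

If the source dict also carries `agent_config.system_prompt`, `metadata`, or `integration_test_tool`, `build_env_from_v4()` maps them onto `Task.agent_config`, `Task.metadata`, and `Task.validation`; the wrap serializer is then expected to emit v4-style fields (`prompt`, `mcp_config`, `evaluate_tool`, plus those extras) when the task's environment was built from an `mcp_config`.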
diff --git a/hud/eval/tests/test_context.py b/hud/eval/tests/test_context.py index 275d3fef..21f9fdb7 100644 --- a/hud/eval/tests/test_context.py +++ b/hud/eval/tests/test_context.py @@ -47,12 +47,6 @@ def test_success_false_when_error(self) -> None: assert ctx.success is False - def test_done_false_initially(self) -> None: - """done property returns False initially.""" - ctx = EvalContext(name="test-task", quiet=True) - - assert ctx.done is False - def test_variants_empty_by_default(self) -> None: """variants is empty dict by default.""" ctx = EvalContext(name="test-task", quiet=True) diff --git a/hud/eval/types.py b/hud/eval/types.py index d3ececb0..6a2df059 100644 --- a/hud/eval/types.py +++ b/hud/eval/types.py @@ -36,6 +36,7 @@ class EvalPayload(BaseModel): group_id: str | None = None variants: dict[str, Any] | None = None task_version_id: str | None = None + metadata: dict[str, Any] | None = None class EvalExitPayload(EvalPayload): diff --git a/hud/eval/utils.py b/hud/eval/utils.py new file mode 100644 index 00000000..b44b1875 --- /dev/null +++ b/hud/eval/utils.py @@ -0,0 +1,178 @@ +"""Utility functions for the eval module.""" + +from __future__ import annotations + +import logging +import warnings +from typing import Any + +__all__ = ["build_env_from_v4", "is_v4_format", "validate_v4_task"] + +logger = logging.getLogger(__name__) + + +def is_v4_format(data: dict[str, Any]) -> bool: + """Detect if dict looks like v4 LegacyTask format. + + Used for branching logic. Checks if data has the core v4 fields + (prompt AND mcp_config). Does NOT validate completeness. + + Args: + data: Dict to check + + Returns: + True if looks like v4 format, False otherwise + """ + if not isinstance(data, dict): + return False + + # Core v4 detection: prompt + mcp_config + return bool(data.get("prompt")) and bool(data.get("mcp_config")) + + +def validate_v4_task(data: dict[str, Any]) -> None: + """Validate v4 task has all required fields. + + A valid v4 task must have all three required fields: + - prompt: The task instruction + - mcp_config: MCP server configuration + - evaluate_tool: How to evaluate success + + Call this after is_v4_format() when you need to ensure completeness. + + Args: + data: Dict to validate + + Raises: + ValueError: If any required fields are missing + """ + missing = [] + if not data.get("prompt"): + missing.append("prompt") + if not data.get("mcp_config"): + missing.append("mcp_config") + if not data.get("evaluate_tool"): + missing.append("evaluate_tool") + + if missing: + raise ValueError(f"v4 task missing required fields: {', '.join(missing)}") + + +def build_env_from_v4(source: dict[str, Any] | Any) -> dict[str, Any]: + """Build Environment from v4 LegacyTask format. + + Creates an Environment configured with the legacy task's fields. + Returns a dict ready to be passed to Task() constructor. + + Args: + source: dict or LegacyTask with v4 fields (prompt, mcp_config, etc.) 
+ + Returns: + Dict with Task fields: env, id, scenario, args, validation, system_prompt, metadata + + Raises: + TypeError: If source is not a dict or LegacyTask + """ + from hud.environment import Environment + from hud.types import LegacyTask, MCPToolCall + + # Convert dict to LegacyTask if needed + if isinstance(source, dict): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + legacy = LegacyTask(**source) + elif isinstance(source, LegacyTask): + legacy = source + else: + raise TypeError(f"Expected dict or LegacyTask, got {type(source).__name__}") + + # Warn if using local MCP configs (command without url) + _warn_local_mcp(legacy.mcp_config) + + # Extract tool filters from agent_config (v4 style) + include_tools: list[str] | None = None + exclude_tools: list[str] | None = None + if legacy.agent_config: + include_tools = legacy.agent_config.allowed_tools + exclude_tools = legacy.agent_config.disallowed_tools + + # Create Environment - NO connections made here, just config stored + env = Environment(legacy.id or "v4-legacy") + env.connect_mcp_config( + legacy.mcp_config, + include=include_tools, + exclude=exclude_tools, + ) + + # Set the prompt + env.prompt = legacy.prompt + + # Add setup_tool calls (stored, not executed) + if legacy.setup_tool: + setup_calls = legacy.setup_tool + if not isinstance(setup_calls, list): + setup_calls = [setup_calls] + for call in setup_calls: + env.setup_tool(call.name, **(call.arguments or {})) + + # Add evaluate_tool calls (stored, not executed) + if legacy.evaluate_tool: + eval_calls = legacy.evaluate_tool + if not isinstance(eval_calls, list): + eval_calls = [eval_calls] + for call in eval_calls: + env.evaluate_tool(call.name, **(call.arguments or {})) + + # Build Task fields dict + result: dict[str, Any] = { + "env": env, + "id": legacy.id, + "scenario": None, # v4 uses prompt, not scenarios + "args": {}, + } + + # Map integration_test_tool → validation (same concept: tool calls to verify) + if legacy.integration_test_tool: + int_test = legacy.integration_test_tool + if not isinstance(int_test, list): + int_test = [int_test] + # Convert to MCPToolCall if needed + result["validation"] = [ + call if isinstance(call, MCPToolCall) else MCPToolCall(**call.model_dump()) + for call in int_test + ] + + # Extract agent_config (just system_prompt for now) + if legacy.agent_config and legacy.agent_config.system_prompt: + result["agent_config"] = {"system_prompt": legacy.agent_config.system_prompt} + + # Preserve metadata + if legacy.metadata: + result["metadata"] = legacy.metadata + + return result + + +def _warn_local_mcp(mcp_config: dict[str, Any] | None) -> None: + """Warn if mcp_config uses local MCP servers (command without url). + + Local MCP servers can cause port conflicts when running tasks concurrently. + """ + if not mcp_config: + return + + has_local = any( + isinstance(server_cfg, dict) and "command" in server_cfg and not server_cfg.get("url") + for server_cfg in mcp_config.values() + if isinstance(server_cfg, dict) + ) + + if has_local: + warnings.warn( + "Task uses local MCP configuration (command without url). " + "This may cause port conflicts when running tasks concurrently. " + "Consider using remote MCP servers for parallel execution.", + UserWarning, + stacklevel=4, + ) + diff --git a/hud/otel/__init__.py b/hud/otel/__init__.py deleted file mode 100644 index 855c7622..00000000 --- a/hud/otel/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -"""HUD OpenTelemetry integration. - -.. 
deprecated:: - The `hud.otel` module is deprecated and will be removed in a future version. - Use `env.trace()` from `hud.environment.Environment` instead. - - This module requires the [agents] extra: - pip install hud-python[agents] - -This package provides the internal OpenTelemetry implementation for HUD telemetry. - -Internal Components: -- config: OpenTelemetry configuration and setup -- context: Trace context management and utilities -- processors: Span enrichment with HUD context -- exporters: Sending spans to HUD backend -- collector: In-memory span collection for replay -- instrumentation: Auto-instrumentation for agents and MCP -""" - -from __future__ import annotations - -import warnings - -from .collector import enable_trace_collection -from .config import configure_telemetry, is_telemetry_configured, shutdown_telemetry -from .context import ( - get_current_task_run_id, - is_root_trace, - span_context, - trace, -) - -# Show deprecation warning when module is imported -warnings.warn( - "The hud.otel module is deprecated. Use env.trace() instead. " - "This module requires pip install hud-python[agents].", - DeprecationWarning, - stacklevel=2, -) - -__all__ = [ - "configure_telemetry", - "enable_trace_collection", - "get_current_task_run_id", - "is_root_trace", - "is_telemetry_configured", - "shutdown_telemetry", - "span_context", - "trace", -] diff --git a/hud/otel/collector.py b/hud/otel/collector.py deleted file mode 100644 index 310eb38f..00000000 --- a/hud/otel/collector.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Global span collector for building in-memory traces. - -This module provides a way to collect spans during execution -and retrieve them as a Trace object, enabling replay functionality -without modifying agent code. -""" - -from __future__ import annotations - -import logging -import threading -from contextvars import ContextVar -from typing import TYPE_CHECKING - -from opentelemetry import trace -from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult - -from hud.types import Trace - -if TYPE_CHECKING: - from opentelemetry.sdk.trace import ReadableSpan - -logger = logging.getLogger(__name__) - -# Global storage for collected spans by task_run_id -_TRACE_STORAGE: dict[str, TraceCollector] = {} -_LOCK = threading.Lock() - -# Context variable to track if collection is enabled -_collecting_enabled: ContextVar[bool] = ContextVar("collecting_enabled", default=False) - - -class TraceCollector: - """Collects spans for a single task run.""" - - def __init__(self, task_run_id: str) -> None: - self.task_run_id = task_run_id - self.spans: list[ReadableSpan] = [] - self._lock = threading.Lock() - - def add_span(self, span: ReadableSpan) -> None: - """Thread-safe span addition.""" - with self._lock: - self.spans.append(span) - - def to_trace(self) -> Trace: - """Convert collected spans to a Trace object.""" - from .exporters import HudSpan, _span_to_dict - - trace = Trace() - - # Convert spans to TraceSteps - for span in self.spans: - try: - # Use the same conversion logic as the exporter - span_dict = _span_to_dict(span) - hud_span = HudSpan.model_validate(span_dict) - - # The attributes field is already a TraceStep - step = hud_span.attributes - # Add timing from the span itself - step.start_timestamp = hud_span.start_time - step.end_timestamp = hud_span.end_time - trace.append(step) - - except Exception as e: - # Log but don't fail the whole trace - logger.debug("Failed to convert span: %s", e) - - return trace - - -class CollectingSpanExporter(SpanExporter): - 
"""A span exporter that collects spans in memory for replay.""" - - def export(self, spans: list[ReadableSpan]) -> SpanExportResult: - """Collect spans if collection is enabled.""" - if not _collecting_enabled.get(): - return SpanExportResult.SUCCESS - - for span in spans: - # Extract task_run_id from span - task_run_id = span.attributes.get("hud.task_run_id") if span.attributes else None - if not task_run_id or not isinstance(task_run_id, str): - continue - - # Get or create collector - with _LOCK: - if task_run_id not in _TRACE_STORAGE: - _TRACE_STORAGE[task_run_id] = TraceCollector(task_run_id) - collector = _TRACE_STORAGE[task_run_id] - - # Add span - collector.add_span(span) - - return SpanExportResult.SUCCESS - - def shutdown(self) -> None: - """Clean up resources.""" - with _LOCK: - _TRACE_STORAGE.clear() - - -def enable_trace_collection(enabled: bool = True) -> None: - """Enable or disable in-memory trace collection.""" - _collecting_enabled.set(enabled) - - -def get_trace(task_run_id: str) -> Trace | None: - """Retrieve collected trace for a task run ID. - - Returns None if no trace was collected or collection was disabled. - """ - with _LOCK: - collector = _TRACE_STORAGE.get(task_run_id) - if collector: - return collector.to_trace() - return None - - -def clear_trace(task_run_id: str) -> None: - """Clear collected trace for a task run ID.""" - with _LOCK: - _TRACE_STORAGE.pop(task_run_id, None) - - -def install_collector() -> None: - """Install the collecting span exporter. - - This should be called after configure_telemetry(). - """ - provider = trace.get_tracer_provider() - # Guard for SDK tracer providers only - if hasattr(provider, "add_span_processor"): - from opentelemetry.sdk.trace.export import SimpleSpanProcessor - - exporter = CollectingSpanExporter() - processor = SimpleSpanProcessor(exporter) - try: - provider.add_span_processor(processor) # type: ignore[attr-defined] - except Exception: - logger.warning("Failed to add span processor") diff --git a/hud/otel/config.py b/hud/otel/config.py deleted file mode 100644 index 9250f104..00000000 --- a/hud/otel/config.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Central configuration for OpenTelemetry inside HUD SDK. - -This file is responsible for -1. creating the global ``TracerProvider`` -2. attaching span processors (HUD enrichment, batch + exporter) -3. activating the community MCP instrumentation so that *every* MCP - request/response/notification is traced automatically. - -It is *idempotent*: calling :func:`configure_telemetry` more than once -returns the same provider and does nothing. 
-""" - -from __future__ import annotations - -import logging -from typing import Any - -from opentelemetry import trace -from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor - -from hud.settings import settings - -from .collector import enable_trace_collection, install_collector -from .exporters import HudSpanExporter -from .instrumentation import install_mcp_instrumentation -from .processors import HudEnrichmentProcessor - -logger = logging.getLogger(__name__) - -# Global singleton provider so multiple calls do not create duplicates -_TRACER_PROVIDER: TracerProvider | None = None - - -def is_telemetry_configured() -> bool: - """Check if telemetry has been configured.""" - return _TRACER_PROVIDER is not None - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -def configure_telemetry( - *, - service_name: str = "hud-sdk", - service_version: str | None = None, - environment: str | None = None, - extra_resource_attributes: dict[str, Any] | None = None, - enable_otlp: bool = False, - otlp_endpoint: str | None = None, - otlp_headers: dict[str, str] | None = None, - enable_collection: bool = True, -) -> TracerProvider: - """Initialise OpenTelemetry for the current Python process. - - It is safe to call this in every entry-point; the provider will only - be created once. - """ - global _TRACER_PROVIDER - - if _TRACER_PROVIDER is not None: - return _TRACER_PROVIDER - - # ------------------------------------------------------------------ - # 1. Resource (identity of this service) - # ------------------------------------------------------------------ - res_attrs: dict[str, Any] = { - "service.name": service_name, - "telemetry.sdk.name": "hud-otel", - "telemetry.sdk.language": "python", - } - if service_version: - res_attrs["service.version"] = service_version - if environment: - res_attrs["deployment.environment"] = environment - if extra_resource_attributes: - res_attrs.update(extra_resource_attributes) - - resource = Resource.create(res_attrs) - - # ------------------------------------------------------------------ - # 2. Provider - # ------------------------------------------------------------------ - provider = TracerProvider(resource=resource) - _TRACER_PROVIDER = provider - - # ------------------------------------------------------------------ - # 3. Processors / exporters - # ------------------------------------------------------------------ - provider.add_span_processor(HudEnrichmentProcessor()) - - # HUD exporter (only if enabled and API key is available) - if settings.telemetry_enabled and settings.api_key: - # Use the HudSpanExporter directly (it now handles async context internally) - exporter = HudSpanExporter( - telemetry_url=settings.hud_telemetry_url, api_key=settings.api_key - ) - - # Batch exports for efficiency while maintaining reasonable real-time visibility - provider.add_span_processor( - BatchSpanProcessor( - exporter, - schedule_delay_millis=1000, # Export every 5 seconds (less frequent) - max_queue_size=16384, # Larger queue for high-volume scenarios - max_export_batch_size=512, # Larger batches (fewer uploads) - export_timeout_millis=30000, - ) - ) - elif settings.telemetry_enabled and not settings.api_key and not enable_otlp: - # Error if no exporters are configured - raise ValueError( - "No telemetry backend configured. 
Either:\n" - "1. Set HUD_API_KEY environment variable for HUD telemetry (https://hud.ai)\n" - "2. Use enable_otlp=True with configure_telemetry() for alternative backends (e.g., Jaeger)\n" # noqa: E501 - ) - elif not settings.telemetry_enabled: - logger.info("HUD telemetry disabled via HUD_TELEMETRY_ENABLED=false") - - # OTLP exporter (optional - for standard OTel viewers) - if enable_otlp: - try: - from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - - otlp_config = {} - if otlp_endpoint: - otlp_config["endpoint"] = otlp_endpoint - # Default to HTTP endpoint if not specified - if not otlp_endpoint.startswith(("http://", "https://")): - otlp_config["endpoint"] = f"http://{otlp_endpoint}/v1/traces" - else: - # Default HTTP endpoint - otlp_config["endpoint"] = "http://localhost:4318/v1/traces" - - if otlp_headers: - otlp_config["headers"] = otlp_headers - - otlp_exporter = OTLPSpanExporter(**otlp_config) - provider.add_span_processor( - BatchSpanProcessor( - otlp_exporter, - schedule_delay_millis=1000, - max_queue_size=16384, - max_export_batch_size=512, - export_timeout_millis=30000, - ) - ) - logger.info("OTLP HTTP exporter enabled - endpoint: %s", otlp_config["endpoint"]) - except ImportError: - logger.warning( - "OTLP export requested but opentelemetry-exporter-otlp-proto-http not installed. " - "Install with: pip install 'hud-python[agent]'" - ) - - # ------------------------------------------------------------------ - # 4. Activate provider and instrumentation - # ------------------------------------------------------------------ - trace.set_tracer_provider(provider) - install_mcp_instrumentation(provider) - - # Install in-memory collector if requested - if enable_collection: - install_collector() - enable_trace_collection(True) - logger.debug("In-memory trace collection enabled") - - # Agent instrumentation now handled by @hud.instrument decorators - logger.debug("OpenTelemetry configuration completed") - - logger.debug("OpenTelemetry configured (provider id=%s)", id(provider)) - return provider - - -def shutdown_telemetry() -> None: - """Flush and shutdown the global provider (if configured).""" - global _TRACER_PROVIDER - if _TRACER_PROVIDER is None: - return - _TRACER_PROVIDER.shutdown() # type: ignore[arg-type] - _TRACER_PROVIDER = None - logger.debug("OpenTelemetry shutdown complete") diff --git a/hud/otel/context.py b/hud/otel/context.py deleted file mode 100644 index 756ff64a..00000000 --- a/hud/otel/context.py +++ /dev/null @@ -1,572 +0,0 @@ -"""OpenTelemetry context utilities for HUD telemetry. - -This module provides internal utilities for managing OpenTelemetry contexts. -User-facing APIs are in hud.telemetry. 
-""" - -from __future__ import annotations - -import contextlib -import contextvars -import logging -import traceback -from contextlib import contextmanager -from typing import TYPE_CHECKING, Any - -from opentelemetry import baggage, context -from opentelemetry import trace as otel_trace -from opentelemetry.trace import Status, StatusCode - -if TYPE_CHECKING: - from collections.abc import Generator - from types import TracebackType - -from hud.settings import settings -from hud.shared import make_request, make_request_sync - -logger = logging.getLogger(__name__) - -# Context variables for task tracking -current_task_run_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( - "current_task_run_id", default=None -) -is_root_trace_var: contextvars.ContextVar[bool] = contextvars.ContextVar( - "is_root_trace", default=False -) - -# Step counters for different types -current_base_mcp_steps: contextvars.ContextVar[int] = contextvars.ContextVar( - "current_base_mcp_steps", default=0 -) -current_mcp_tool_steps: contextvars.ContextVar[int] = contextvars.ContextVar( - "current_mcp_tool_steps", default=0 -) -current_agent_steps: contextvars.ContextVar[int] = contextvars.ContextVar( - "current_agent_steps", default=0 -) - -# Keys for OpenTelemetry baggage -TASK_RUN_ID_KEY = "hud.task_run_id" -IS_ROOT_TRACE_KEY = "hud.is_root_trace" -BASE_MCP_STEPS_KEY = "hud.base_mcp_steps" -MCP_TOOL_STEPS_KEY = "hud.mcp_tool_steps" -AGENT_STEPS_KEY = "hud.agent_steps" - - -def set_current_task_run_id(task_run_id: str | None) -> contextvars.Token: - """Set the current task run ID.""" - return current_task_run_id.set(task_run_id) - - -def get_current_task_run_id() -> str | None: - """Get current task_run_id from either contextvars or OTel baggage.""" - # First try OTel baggage - task_run_id = baggage.get_baggage(TASK_RUN_ID_KEY) - if task_run_id and isinstance(task_run_id, str): - return task_run_id - - # Fallback to contextvars - return current_task_run_id.get() - - -def is_root_trace() -> bool: - """Check if current context is a root trace.""" - # First try OTel baggage - is_root = baggage.get_baggage(IS_ROOT_TRACE_KEY) - if isinstance(is_root, str): - return is_root.lower() == "true" - - # Fallback to contextvars - return is_root_trace_var.get() - - -def get_base_mcp_steps() -> int: - """Get current base MCP step count from either contextvars or OTel baggage.""" - # First try OTel baggage - step_count = baggage.get_baggage(BASE_MCP_STEPS_KEY) - if step_count and isinstance(step_count, str): - try: - return int(step_count) - except ValueError: - pass - - # Fallback to contextvars - return current_base_mcp_steps.get() - - -def get_mcp_tool_steps() -> int: - """Get current MCP tool step count from either contextvars or OTel baggage.""" - # First try OTel baggage - step_count = baggage.get_baggage(MCP_TOOL_STEPS_KEY) - if step_count and isinstance(step_count, str): - try: - return int(step_count) - except ValueError: - pass - - # Fallback to contextvars - return current_mcp_tool_steps.get() - - -def get_agent_steps() -> int: - """Get current agent step count from either contextvars or OTel baggage.""" - # First try OTel baggage - step_count = baggage.get_baggage(AGENT_STEPS_KEY) - if step_count and isinstance(step_count, str): - try: - return int(step_count) - except ValueError: - pass - - # Fallback to contextvars - return current_agent_steps.get() - - -def increment_base_mcp_steps() -> int: - """Increment the base MCP step count and update baggage. 
- - Returns: - The new base MCP step count after incrementing - """ - current = get_base_mcp_steps() - new_count = current + 1 - - # Update contextvar - current_base_mcp_steps.set(new_count) - - # Update baggage for propagation - ctx = baggage.set_baggage(BASE_MCP_STEPS_KEY, str(new_count)) - context.attach(ctx) - - # Update current span if one exists - span = otel_trace.get_current_span() - if span and span.is_recording(): - span.set_attribute("hud.base_mcp_steps", new_count) - - return new_count - - -def increment_mcp_tool_steps() -> int: - """Increment the MCP tool step count and update baggage. - - Returns: - The new MCP tool step count after incrementing - """ - current = get_mcp_tool_steps() - new_count = current + 1 - - # Update contextvar - current_mcp_tool_steps.set(new_count) - - # Update baggage for propagation - ctx = baggage.set_baggage(MCP_TOOL_STEPS_KEY, str(new_count)) - context.attach(ctx) - - # Update current span if one exists - span = otel_trace.get_current_span() - if span and span.is_recording(): - span.set_attribute("hud.mcp_tool_steps", new_count) - - return new_count - - -def increment_agent_steps() -> int: - """Increment the agent step count and update baggage. - - Returns: - The new agent step count after incrementing - """ - current = get_agent_steps() - new_count = current + 1 - - # Update contextvar - current_agent_steps.set(new_count) - - # Update baggage for propagation - ctx = baggage.set_baggage(AGENT_STEPS_KEY, str(new_count)) - context.attach(ctx) - - # Update current span if one exists - span = otel_trace.get_current_span() - if span and span.is_recording(): - span.set_attribute("hud.agent_steps", new_count) - - return new_count - - -@contextmanager -def span_context( - name: str, - attributes: dict[str, Any] | None = None, - kind: otel_trace.SpanKind = otel_trace.SpanKind.INTERNAL, -) -> Generator[otel_trace.Span, None, None]: - """Create a child span within the current trace context. - - This is a simple wrapper around OpenTelemetry's span creation that - ensures the span inherits the current HUD context (task_run_id, etc). - - Args: - name: Name for the span - attributes: Additional attributes to add to the span - kind: OpenTelemetry span kind - - Example: - with span_context("process_data", {"items": 100}) as span: - # Process data... 
- span.set_attribute("processed", True) - """ - tracer = otel_trace.get_tracer("hud-sdk") - - # Current task_run_id will be added by HudEnrichmentProcessor - with tracer.start_as_current_span( - name, - attributes=attributes, - kind=kind, - ) as span: - yield span - - -async def _update_task_status_async( - task_run_id: str, - status: str, - job_id: str | None = None, - error_message: str | None = None, - trace_name: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - extra_metadata: dict[str, Any] | None = None, -) -> None: - """Async task status update.""" - if not settings.telemetry_enabled: - return - - try: - data: dict[str, Any] = {"status": status} - - # Resolve effective job_id from explicit param, OTel baggage, or current job context - effective_job_id: str | None = job_id - if not effective_job_id: - bj = baggage.get_baggage("hud.job_id") - if isinstance(bj, str) and bj: - effective_job_id = bj - if not effective_job_id: - try: - from hud.telemetry.job import get_current_job # Local import to avoid cycles - - current_job = get_current_job() - if current_job: - effective_job_id = current_job.id - except Exception: - effective_job_id = None - - if effective_job_id: - data["job_id"] = effective_job_id - if error_message: - data["error_message"] = error_message - - # Build metadata with trace name and step counts - metadata = {} - if trace_name: - metadata["trace_name"] = trace_name - - # Include all three step counts in metadata - metadata["base_mcp_steps"] = get_base_mcp_steps() - metadata["mcp_tool_steps"] = get_mcp_tool_steps() - metadata["agent_steps"] = get_agent_steps() - - # Merge any extra metadata provided by callers (e.g., task config summaries) - if extra_metadata: - with contextlib.suppress(Exception): - metadata.update(extra_metadata) - - if metadata: - data["metadata"] = metadata - - if task_id: - data["task_id"] = task_id - - if group_id: - data["group_id"] = group_id - - await make_request( - method="POST", - url=f"{settings.hud_telemetry_url}/trace/{task_run_id}/status", - json=data, - api_key=settings.api_key, - ) - logger.debug("Updated task %s status to %s", task_run_id, status) - except Exception as e: - # Suppress warnings about interpreter shutdown - if "interpreter shutdown" not in str(e): - logger.warning("Failed to update task status: %s", e) - - -def _update_task_status_sync( - task_run_id: str, - status: str, - job_id: str | None = None, - error_message: str | None = None, - trace_name: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - extra_metadata: dict[str, Any] | None = None, -) -> None: - """Synchronous task status update.""" - if not settings.telemetry_enabled: - return - - try: - data: dict[str, Any] = {"status": status} - - # Resolve effective job_id from explicit param, OTel baggage, or current job context - effective_job_id: str | None = job_id - if not effective_job_id: - bj = baggage.get_baggage("hud.job_id") - if isinstance(bj, str) and bj: - effective_job_id = bj - if not effective_job_id: - try: - from hud.telemetry.job import get_current_job # Local import to avoid cycles - - current_job = get_current_job() - if current_job: - effective_job_id = current_job.id - except Exception: - effective_job_id = None - - if effective_job_id: - data["job_id"] = effective_job_id - if error_message: - data["error_message"] = error_message - - # Build metadata with trace name and step counts - metadata = {} - if trace_name: - metadata["trace_name"] = trace_name - - # Include all three step 
counts in metadata - metadata["base_mcp_steps"] = get_base_mcp_steps() - metadata["mcp_tool_steps"] = get_mcp_tool_steps() - metadata["agent_steps"] = get_agent_steps() - - # Merge any extra metadata provided by callers - if extra_metadata: - with contextlib.suppress(Exception): - metadata.update(extra_metadata) - - if metadata: - data["metadata"] = metadata - - if task_id: - data["task_id"] = task_id - - if group_id: - data["group_id"] = group_id - - make_request_sync( - method="POST", - url=f"{settings.hud_telemetry_url}/trace/{task_run_id}/status", - json=data, - api_key=settings.api_key, - ) - logger.debug("Updated task %s status to %s", task_run_id, status) - except Exception as e: - # Suppress warnings about interpreter shutdown - if "interpreter shutdown" not in str(e): - logger.warning("Failed to update task status: %s", e) - - -def _print_trace_url(task_run_id: str) -> None: - """Print the trace URL in a colorful box.""" - # Only print HUD URL if HUD telemetry is enabled and has API key - if not (settings.telemetry_enabled and settings.api_key): - return - - url = f"https://hud.ai/trace/{task_run_id}" - header = "🚀 See your agent live at:" - - # ANSI color codes - DIM = "\033[90m" # Dim/Gray for border (visible on both light and dark terminals) - GOLD = "\033[33m" # Gold/Yellow for URL - RESET = "\033[0m" - BOLD = "\033[1m" - - # Calculate box width based on the longest line - box_width = max(len(url), len(header)) + 6 - - # Box drawing characters - top_border = "╔" + "═" * (box_width - 2) + "╗" - bottom_border = "╚" + "═" * (box_width - 2) + "╝" - divider = "╟" + "─" * (box_width - 2) + "╢" - - # Center the content - header_padding = (box_width - len(header) - 2) // 2 - url_padding = (box_width - len(url) - 2) // 2 - - # Print the box - print(f"\n{DIM}{top_border}{RESET}") # noqa: T201 - print( # noqa: T201 - f"{DIM}║{RESET}{' ' * header_padding}{header}{' ' * (box_width - len(header) - header_padding - 3)}{DIM}║{RESET}" # noqa: E501 - ) - print(f"{DIM}{divider}{RESET}") # noqa: T201 - print( # noqa: T201 - f"{DIM}║{RESET}{' ' * url_padding}{BOLD}{GOLD}{url}{RESET}{' ' * (box_width - len(url) - url_padding - 2)}{DIM}║{RESET}" # noqa: E501 - ) - print(f"{DIM}{bottom_border}{RESET}\n") # noqa: T201 - - -def _print_trace_complete_url(task_run_id: str, error_occurred: bool = False) -> None: - """Print the trace completion URL with appropriate messaging.""" - # Only print HUD URL if HUD telemetry is enabled and has API key - if not (settings.telemetry_enabled and settings.api_key): - return - - url = f"https://hud.ai/trace/{task_run_id}" - - # ANSI color codes - GREEN = "\033[92m" - RED = "\033[91m" - GOLD = "\033[33m" - RESET = "\033[0m" - DIM = "\033[2m" - BOLD = "\033[1m" - - if error_occurred: - print( # noqa: T201 - f"\n{RED}✗ Trace errored!{RESET} {DIM}More error details available at:{RESET} {BOLD}{GOLD}{url}{RESET}\n" # noqa: E501 - ) - else: - print(f"\n{GREEN}✓ Trace complete!{RESET} {DIM}View at:{RESET} {BOLD}{GOLD}{url}{RESET}\n") # noqa: T201 - - -class trace: - """Internal OpenTelemetry trace context manager. - - This is the sync implementation. For async code, use hud.async_trace() instead. 
- """ - - def __init__( - self, - task_run_id: str, - is_root: bool = True, - span_name: str = "hud.task", - attributes: dict[str, Any] | None = None, - job_id: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - ) -> None: - self.task_run_id = task_run_id - self.job_id = job_id - self.task_id = task_id - self.group_id = group_id - self.is_root = is_root - self.span_name = span_name - self.attributes = attributes or {} - self._span: otel_trace.Span | None = None - self._span_manager: Any | None = None - self._otel_token: object | None = None - self._task_run_token = None - self._root_token = None - - def __enter__(self) -> str: - """Enter the trace context and return the task_run_id.""" - # Set context variables - self._task_run_token = set_current_task_run_id(self.task_run_id) - self._root_token = is_root_trace_var.set(self.is_root) - - # Set OpenTelemetry baggage for propagation - ctx = baggage.set_baggage(TASK_RUN_ID_KEY, self.task_run_id) - ctx = baggage.set_baggage(IS_ROOT_TRACE_KEY, str(self.is_root), context=ctx) - if self.job_id: - ctx = baggage.set_baggage("hud.job_id", self.job_id, context=ctx) - if self.task_id: - ctx = baggage.set_baggage("hud.task_id", self.task_id, context=ctx) - if self.group_id: - ctx = baggage.set_baggage("hud.group_id", self.group_id, context=ctx) - self._otel_token = context.attach(ctx) - - # Start a span as current - tracer = otel_trace.get_tracer("hud-sdk") - span_attrs = { - "hud.task_run_id": self.task_run_id, - "hud.is_root_trace": self.is_root, - **self.attributes, - } - if self.job_id: - span_attrs["hud.job_id"] = self.job_id - if self.task_id: - span_attrs["hud.task_id"] = self.task_id - if self.group_id: - span_attrs["hud.group_id"] = self.group_id - - # Use start_as_current_span context manager - self._span_manager = tracer.start_as_current_span( - self.span_name, - attributes=span_attrs, - ) - self._span = self._span_manager.__enter__() - - # Update task status to running (sync call - blocking is expected) - if self.is_root and settings.telemetry_enabled and settings.api_key: - _update_task_status_sync( - self.task_run_id, - "running", - job_id=self.job_id, - trace_name=self.span_name, - task_id=self.task_id, - group_id=self.group_id, - ) - if not self.job_id: - _print_trace_url(self.task_run_id) - - logger.debug("Started HUD trace context for task_run_id=%s", self.task_run_id) - return self.task_run_id - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the trace context.""" - # Update task status (sync call - blocking is expected for sync context manager) - if self.is_root and settings.telemetry_enabled and settings.api_key: - status = "error" if exc_type else "completed" - error_msg = None - if exc_type is not None: - error_msg = "".join(traceback.format_exception(exc_type, exc_val, exc_tb)) - _update_task_status_sync( - self.task_run_id, - status, - job_id=self.job_id, - error_message=error_msg, - trace_name=self.span_name, - task_id=self.task_id, - group_id=self.group_id, - ) - if not self.job_id: - _print_trace_complete_url(self.task_run_id, error_occurred=bool(exc_type)) - - # End the span - if self._span and self._span_manager is not None: - if exc_type is not None and exc_val is not None: - self._span.record_exception(exc_val) - self._span.set_status(Status(StatusCode.ERROR, str(exc_val))) - else: - self._span.set_status(Status(StatusCode.OK)) - self._span_manager.__exit__(exc_type, exc_val, exc_tb) - - # 
Detach OpenTelemetry context - if self._otel_token is not None: - try: - context.detach(self._otel_token) # type: ignore[arg-type] - except Exception: - logger.warning("Failed to detach OpenTelemetry context") - - # Reset context variables - if self._task_run_token is not None: - current_task_run_id.reset(self._task_run_token) # type: ignore - if self._root_token is not None: - is_root_trace_var.reset(self._root_token) # type: ignore - - logger.debug("Ended HUD trace context for task_run_id=%s", self.task_run_id) diff --git a/hud/otel/exporters.py b/hud/otel/exporters.py deleted file mode 100644 index ee9558fa..00000000 --- a/hud/otel/exporters.py +++ /dev/null @@ -1,543 +0,0 @@ -"""Custom OpenTelemetry exporter for HUD telemetry backend. - -This exporter sends spans to the HUD telemetry HTTP endpoint, grouping them -by task_run_id for efficient batch uploads. - -Performance optimizations: -- Detects async contexts and runs exports in a thread pool to avoid blocking -- Uses persistent HTTP client with connection pooling for reduced overhead -- Tracks pending export futures to ensure completion during shutdown - -The exporter derives from SpanExporter (synchronous interface) but handles -async contexts intelligently to prevent event loop blocking during high-concurrency -workloads. -""" - -from __future__ import annotations - -import atexit -import concurrent.futures as cf -import contextlib -import json -import logging -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any - -from mcp.types import ClientRequest, ServerResult -from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult -from pydantic import BaseModel, ConfigDict, Field - -from hud.shared import make_request_sync -from hud.types import TraceStep as HudSpanAttributes - -if TYPE_CHECKING: - from opentelemetry.sdk.trace import ReadableSpan - -logger = logging.getLogger(__name__) - -# Global singleton thread pool for span exports -_export_executor: ThreadPoolExecutor | None = None - - -def get_export_executor() -> ThreadPoolExecutor: - """Get or create the global thread pool for span exports. - - Returns a singleton ThreadPoolExecutor used for running span exports - in a thread pool when called from async contexts, preventing event - loop blocking during high-concurrency workloads. - - The executor is automatically cleaned up on process exit via atexit. 
- - Returns: - ThreadPoolExecutor with 8 workers for high-throughput parallel uploads - """ - global _export_executor - if _export_executor is None: - # Use 8 workers to handle high-volume parallel uploads efficiently - _export_executor = ThreadPoolExecutor(max_workers=8, thread_name_prefix="span-export") - - def cleanup() -> None: - if _export_executor is not None: - _export_executor.shutdown(wait=True) - - atexit.register(cleanup) - return _export_executor - - -# --------------------------------------------------------------------------- -# Models -# --------------------------------------------------------------------------- - - -class HudSpan(BaseModel): - """A telemetry span ready for export.""" - - name: str - trace_id: str = Field(pattern=r"^[0-9a-fA-F]{32}$") - span_id: str = Field(pattern=r"^[0-9a-fA-F]{16}$") - parent_span_id: str | None = Field(None, pattern=r"^[0-9a-fA-F]{16}$") - - start_time: str # ISO format - end_time: str # ISO format - - status_code: str # "UNSET", "OK", "ERROR" - status_message: str | None = None - - attributes: HudSpanAttributes - exceptions: list[dict[str, Any]] | None = None - - model_config = ConfigDict(extra="forbid") - - -def extract_span_attributes( - attrs: dict[str, Any], method_name: str | None = None, span_name: str | None = None -) -> HudSpanAttributes: - """Extract and parse span attributes into typed model. - - This handles: - - Detecting span type (MCP vs Agent) - - Renaming verbose OpenTelemetry semantic conventions - - Parsing JSON strings to MCP types - """ - # Start with core attributes - map to TraceStep field names - result_attrs = { - "task_run_id": attrs.get( - "hud.task_run_id" - ), # TraceStep expects task_run_id, not hud.task_run_id - "job_id": attrs.get("hud.job_id"), # TraceStep expects job_id, not hud.job_id - "type": attrs.get("span.kind", "CLIENT"), # TraceStep expects type, not span.kind - } - - # Determine span type based on presence of agent or MCP attributes - # Note: The input attrs might already have "category" set - existing_category = attrs.get("category") - - if existing_category: - # Use the explicit category if provided - result_attrs["category"] = existing_category - elif span_name and span_name.startswith("agent."): - # Legacy support for spans named "agent.*" - result_attrs["category"] = "agent" - else: - result_attrs["category"] = "mcp" # Default to MCP - - # No special processing needed for different categories - # The backend will handle them based on the category field - - # Add method_name and request_id for MCP spans - if result_attrs["category"] == "mcp": - if method_name: - result_attrs["method_name"] = method_name - # Check for request_id with and without semconv_ai prefix - request_id = attrs.get("semconv_ai.mcp.request_id") or attrs.get("mcp.request.id") - if request_id: - result_attrs["request_id"] = request_id - - # Parse input/output - check both with and without semconv_ai prefix - input_str = attrs.get("semconv_ai.traceloop.entity.input") or attrs.get( - "traceloop.entity.input" - ) - output_str = attrs.get("semconv_ai.traceloop.entity.output") or attrs.get( - "traceloop.entity.output" - ) - - logger.debug( - "Category: %s, has input: %s, has output: %s", - result_attrs.get("category"), - bool(input_str), - bool(output_str), - ) - - # Check for direct request/result attributes first - if "request" in attrs and not result_attrs.get("request"): - req = attrs["request"] - if isinstance(req, str): - with contextlib.suppress(json.JSONDecodeError): - req = json.loads(req) - result_attrs["request"] = 
req - - if "result" in attrs and not result_attrs.get("result"): - res = attrs["result"] - if isinstance(res, str): - with contextlib.suppress(json.JSONDecodeError): - res = json.loads(res) - result_attrs["result"] = res - - # Process input/output from MCP instrumentation - if input_str and not result_attrs.get("request"): - try: - input_data = json.loads(input_str) if isinstance(input_str, str) else input_str - - # For MCP category, try to parse as ClientRequest to extract the root - if result_attrs["category"] == "mcp" and isinstance(input_data, dict): - try: - if "method" in input_data and "params" in input_data: - client_request = ClientRequest.model_validate(input_data) - result_attrs["request"] = client_request.root - else: - result_attrs["request"] = input_data - except Exception: - result_attrs["request"] = input_data - else: - # For all other categories, just store the data - result_attrs["request"] = input_data - except Exception as e: - logger.debug("Failed to parse request JSON: %s", e) - - if output_str and not result_attrs.get("result"): - try: - output_data = json.loads(output_str) if isinstance(output_str, str) else output_str - - # For MCP category, try to parse as ServerResult to extract the root - if result_attrs["category"] == "mcp" and isinstance(output_data, dict): - # Check for error - if "error" in output_data: - result_attrs["mcp_error"] = True - try: - server_result = ServerResult.model_validate(output_data) - result_attrs["result"] = server_result.root - # Check for isError in the result - if getattr(server_result.root, "isError", False): - result_attrs["mcp_error"] = True - except Exception: - result_attrs["result"] = output_data - else: - # For all other categories, just store the data - result_attrs["result"] = output_data - except Exception as e: - logger.debug("Failed to parse result JSON: %s", e) - - # Don't include the verbose attributes or ones we've already processed - exclude_keys = { - "hud.task_run_id", - "hud.job_id", - "span.kind", - "semconv_ai.mcp.method_name", - "mcp.method.name", # Also exclude non-prefixed version - "semconv_ai.mcp.request_id", - "mcp.request.id", # Also exclude non-prefixed version - "semconv_ai.traceloop.entity.input", - "semconv_ai.traceloop.entity.output", - "traceloop.entity.input", # Also exclude non-prefixed versions - "traceloop.entity.output", - "mcp_request", # Exclude to prevent overwriting parsed values - "mcp_result", # Exclude to prevent overwriting parsed values - "request", # Exclude to prevent overwriting parsed values - "result", # Exclude to prevent overwriting parsed values - "category", # Already handled above - } - - # Add any extra attributes - for key, value in attrs.items(): - if key not in exclude_keys: - result_attrs[key] = value # noqa: PERF403 - - logger.debug( - """Final result_attrs before creating HudSpanAttributes: - request=%s, - result=%s""", - result_attrs.get("request"), - result_attrs.get("result"), - ) - return HudSpanAttributes(**result_attrs) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _ts_ns_to_iso(ts_ns: int) -> str: - """Convert a ``Span`` timestamp (nanoseconds) to ISO-8601 string.""" - # OpenTelemetry times are epoch nanoseconds - dt = datetime.fromtimestamp(ts_ns / 1_000_000_000, tz=UTC) - return dt.isoformat().replace("+00:00", "Z") - - -def _span_to_dict(span: ReadableSpan) -> dict[str, Any]: - """Convert an OpenTelemetry span to a dict using 
typed models.""" - - attrs = dict(span.attributes or {}) - - # Extract method name from span name if not in attributes - # Check both with and without semconv_ai prefix - raw_method = attrs.get("semconv_ai.mcp.method_name") or attrs.get("mcp.method.name") - method_name: str | None = None - if isinstance(raw_method, str): - method_name = raw_method - if method_name is None and isinstance(span.name, str) and span.name.endswith(".mcp"): - method_name = span.name[:-4] # Remove .mcp suffix - - # Create typed attributes - typed_attrs = extract_span_attributes(attrs, method_name, str(span.name)) - - # Record span kind as extra attribute (TraceStep allows extras) - try: - typed_attrs.span_kind = span.kind.name # type: ignore[attr-defined] - except Exception: - logger.warning("Failed to set span kind attribute") - - # Build typed span - # Guard context/parent/timestamps - context = getattr(span, "context", None) - trace_id_hex = ( - format(context.trace_id, "032x") if context and hasattr(context, "trace_id") else "0" * 32 - ) - span_id_hex = ( - format(context.span_id, "016x") if context and hasattr(context, "span_id") else "0" * 16 - ) - parent = getattr(span, "parent", None) - parent_id_hex = ( - format(parent.span_id, "016x") if parent and hasattr(parent, "span_id") else None - ) - start_ns = span.start_time or 0 - end_ns = span.end_time or start_ns - - typed_span = HudSpan( - name=span.name, - trace_id=trace_id_hex, - span_id=span_id_hex, - parent_span_id=parent_id_hex, - start_time=_ts_ns_to_iso(int(start_ns)), - end_time=_ts_ns_to_iso(int(end_ns)), - status_code=span.status.status_code.name if span.status else "UNSET", - status_message=span.status.description if span.status else None, - attributes=typed_attrs, - exceptions=None, - ) - - # Add error information if present - if span.events: - exceptions = [] - exceptions = [ - { - "timestamp": _ts_ns_to_iso(event.timestamp), - "attributes": dict(event.attributes or {}), - } - for event in span.events - ] - if exceptions: - typed_span.exceptions = exceptions - - # Convert to dict for export - return typed_span.model_dump(mode="json", by_alias=True, exclude_none=True) - - -# --------------------------------------------------------------------------- -# Exporter -# --------------------------------------------------------------------------- - - -class HudSpanExporter(SpanExporter): - """OpenTelemetry span exporter for the HUD backend. - - This exporter groups spans by task_run_id and sends them to the HUD - telemetry endpoint. Performance optimizations include: - - - Auto-detects async contexts and runs exports in thread pool (non-blocking) - - Tracks pending export futures for proper shutdown coordination - - Handles high-concurrency scenarios (200+ parallel tasks) by offloading - synchronous HTTP operations to a thread pool when called from async - contexts, preventing event loop blocking. - """ - - def __init__(self, *, telemetry_url: str, api_key: str) -> None: - """Initialize the HUD span exporter. - - Args: - telemetry_url: Base URL for the HUD telemetry backend - api_key: API key for authentication - """ - super().__init__() - self._telemetry_url = telemetry_url.rstrip("/") - self._api_key = api_key - - # Track pending export futures for shutdown coordination - self._pending_futures: list[cf.Future[SpanExportResult]] = [] - - def export(self, spans: list[ReadableSpan]) -> SpanExportResult: # type: ignore[override] - """Export spans to HUD backend. 
- - Auto-detects async contexts: if called from an async event loop, runs - the export in a thread pool to avoid blocking. Otherwise runs synchronously. - - Args: - spans: List of ReadableSpan objects to export - - Returns: - SpanExportResult.SUCCESS (returns immediately in async contexts) - """ - if not spans: - return SpanExportResult.SUCCESS - - # Group spans by task_run_id for batched uploads - grouped: dict[str, list[ReadableSpan]] = defaultdict(list) - for span in spans: - run_id = span.attributes.get("hud.task_run_id") if span.attributes else None - if not run_id: - # Skip spans outside HUD traces - continue - grouped[str(run_id)].append(span) - - # Detect async context to avoid event loop blocking - import asyncio - - try: - loop = asyncio.get_running_loop() - # In async context - offload to thread pool - executor = get_export_executor() - - def _sync_export() -> SpanExportResult: - # Send each group synchronously (retry inside make_request_sync) - for run_id, span_batch in grouped.items(): - try: - url = f"{self._telemetry_url}/trace/{run_id}/telemetry-upload" - telemetry_spans = [_span_to_dict(s) for s in span_batch] - # Include current step count in metadata - metadata = {} - # Get the HIGHEST step count from the batch (most recent) - step_count = 0 - for span in span_batch: - if span.attributes and "hud.step_count" in span.attributes: - current_step = span.attributes["hud.step_count"] - if isinstance(current_step, int) and current_step > step_count: - step_count = current_step - - payload = { - "metadata": metadata, - "telemetry": telemetry_spans, - } - - # Only include step_count if we found any steps - if step_count > 0: - payload["step_count"] = step_count - - logger.debug("HUD exporter sending %d spans to %s", len(span_batch), url) - make_request_sync( - method="POST", - url=url, - json=payload, - api_key=self._api_key, - ) - except Exception as exc: - logger.exception( - "HUD exporter failed to send spans for task %s: %s", run_id, exc - ) - return SpanExportResult.FAILURE - return SpanExportResult.SUCCESS - - # Run in thread to avoid blocking event loop - future = loop.run_in_executor(executor, _sync_export) - # Track and cleanup when done - self._pending_futures.append(future) # type: ignore[list-item] - - def _cleanup_done(f: cf.Future[SpanExportResult]) -> None: - with contextlib.suppress(Exception): - # Consume exception to avoid "exception was never retrieved" - _ = f.exception() - # Remove from pending list - with contextlib.suppress(ValueError): - self._pending_futures.remove(f) - - future.add_done_callback(_cleanup_done) # type: ignore[arg-type] - # Don't wait for it - return immediately - return SpanExportResult.SUCCESS - - except RuntimeError: - # No event loop - run synchronously - # Send each group synchronously (retry inside make_request_sync) - for run_id, span_batch in grouped.items(): - try: - url = f"{self._telemetry_url}/trace/{run_id}/telemetry-upload" - telemetry_spans = [_span_to_dict(s) for s in span_batch] - # Include current step count in metadata - metadata = {} - # Get the HIGHEST step count from the batch (most recent) - step_count = 0 - for span in span_batch: - if span.attributes and "hud.step_count" in span.attributes: - current_step = span.attributes["hud.step_count"] - if isinstance(current_step, int) and current_step > step_count: - step_count = current_step - - payload = { - "metadata": metadata, - "telemetry": telemetry_spans, - } - - # Only include step_count if we found any steps - if step_count > 0: - payload["step_count"] = step_count 
- - logger.debug("HUD exporter sending %d spans to %s", len(span_batch), url) - make_request_sync( - method="POST", - url=url, - json=payload, - api_key=self._api_key, - ) - except Exception as exc: - logger.exception( - "HUD exporter failed to send spans for task %s: %s", run_id, exc - ) - # If *any* group fails we return FAILURE so the OTEL SDK can retry - return SpanExportResult.FAILURE - - return SpanExportResult.SUCCESS - - def shutdown(self) -> None: # type: ignore[override] - """Shutdown the exporter and wait for pending exports. - - Waits up to 10 seconds for any in-flight exports to complete. - """ - try: - if self._pending_futures: - with contextlib.suppress(Exception): - cf.wait(self._pending_futures, timeout=10.0) - finally: - self._pending_futures.clear() - - def force_flush(self, timeout_millis: int | None = None) -> bool: # type: ignore[override] - """Force flush all pending span exports. - - Waits for all pending export futures to complete before returning. - This is called by the OpenTelemetry SDK during shutdown to ensure - all telemetry is uploaded. - - Args: - timeout_millis: Maximum time to wait in milliseconds - - Returns: - True if all exports completed, False otherwise - """ - try: - if not self._pending_futures: - return True - - total_pending = len(self._pending_futures) - if total_pending > 10: - # Show progress for large batches - logger.info("Flushing %d pending telemetry uploads...", total_pending) - - timeout = (timeout_millis or 30000) / 1000.0 - done, not_done = cf.wait(self._pending_futures, timeout=timeout) - - # Consume exceptions to avoid "exception was never retrieved" warnings - for f in list(done): - with contextlib.suppress(Exception): - _ = f.exception() - - # Remove completed futures - for f in list(done): - with contextlib.suppress(ValueError): - self._pending_futures.remove(f) - - if total_pending > 10: - logger.info("Completed %d/%d telemetry uploads", len(done), total_pending) - - return len(not_done) == 0 - except Exception: - return False diff --git a/hud/otel/instrumentation.py b/hud/otel/instrumentation.py deleted file mode 100644 index 475ac3e1..00000000 --- a/hud/otel/instrumentation.py +++ /dev/null @@ -1,147 +0,0 @@ -"""MCP instrumentation support for HUD. - -This module provides functions to enable MCP OpenTelemetry instrumentation -for automatic tracing of MCP protocol communication. - -Note: This module requires the [agents] extra to be installed: - pip install hud-python[agents] -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Callable - - from opentelemetry.trace import TracerProvider - -logger = logging.getLogger(__name__) - -# Check if OpenTelemetry is available -_HAS_OPENTELEMETRY = False -try: - from opentelemetry import trace as _otel_trace # noqa: F401 - - _HAS_OPENTELEMETRY = True -except ImportError: - pass - - -def install_mcp_instrumentation(provider: TracerProvider | None = None) -> None: - """Enable community MCP OpenTelemetry instrumentation if present. 
- - Args: - provider: The TracerProvider to use for instrumentation - """ - import logging - - logger = logging.getLogger(__name__) - - try: - # First, patch the _instruments to use our fork - import opentelemetry.instrumentation.mcp.instrumentation as mcp_inst - - mcp_inst._instruments = ("hud-mcp-python-sdk >= 3.13.1",) - - from opentelemetry.instrumentation.mcp.instrumentation import ( - McpInstrumentor, - ) - - # Then, patch the instrumentation to handle 3-value transports correctly - _patch_mcp_instrumentation() - - McpInstrumentor().instrument(tracer_provider=provider) - logger.debug("MCP instrumentation installed with fastmcp compatibility patch") - except ImportError: - logger.debug("opentelemetry-instrumentation-mcp not available, skipping") - except Exception as exc: - logger.warning("Failed to install MCP instrumentation: %s", exc) - - -def _patch_mcp_instrumentation() -> None: - """Patch MCP instrumentation to handle 3-value transport yields correctly.""" - from contextlib import asynccontextmanager - - try: - from opentelemetry.instrumentation.mcp.instrumentation import McpInstrumentor - - # First, patch the get_error_type function to handle invalid HTTP status codes - _patch_get_error_type() - - def patched_transport_wrapper(self: Any, tracer: Any) -> Callable[..., Any]: - @asynccontextmanager - async def traced_method( - wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any - ) -> AsyncGenerator[Any, None]: - async with wrapped(*args, **kwargs) as result: - # Check if we got a tuple with 3 values - if isinstance(result, tuple) and len(result) == 3: - read_stream, write_stream, third_value = result - # Import here to avoid circular imports - from opentelemetry.instrumentation.mcp.instrumentation import ( - InstrumentedStreamReader, - InstrumentedStreamWriter, - ) - - yield ( - InstrumentedStreamReader(read_stream, tracer), - InstrumentedStreamWriter(write_stream, tracer), - third_value, - ) - else: - # Fall back to 2-value case - read_stream, write_stream = result - from opentelemetry.instrumentation.mcp.instrumentation import ( - InstrumentedStreamReader, - InstrumentedStreamWriter, - ) - - yield ( - InstrumentedStreamReader(read_stream, tracer), - InstrumentedStreamWriter(write_stream, tracer), - ) - - return traced_method - - # Apply the patch - McpInstrumentor._transport_wrapper = patched_transport_wrapper - - except Exception as e: - import logging - - logger = logging.getLogger(__name__) - logger.warning("Failed to patch MCP instrumentation: %s", e) - - -def _patch_get_error_type() -> None: - """Patch get_error_type to handle invalid HTTP status codes gracefully.""" - import re - from http import HTTPStatus - - try: - import opentelemetry.instrumentation.mcp.instrumentation as mcp_inst - - def patched_get_error_type(error_message: str) -> str | None: - """Extract HTTP status from error message, handling invalid codes.""" - if not isinstance(error_message, str): - return None - match = re.search(r"\b(4\d{2}|5\d{2})\b", error_message) - if match: - num = int(match.group()) - try: - # Only return if it's a valid HTTPStatus - if 400 <= num <= 599: - return HTTPStatus(num).name - except ValueError: - # Not a valid HTTP status code - logger.debug("Ignoring invalid HTTP status code: %s", num) - return None - - # Apply the patch - mcp_inst.get_error_type = patched_get_error_type - logger.debug("Patched get_error_type to handle invalid HTTP status codes") - - except Exception as e: - logger.warning("Failed to patch get_error_type: %s", e) diff --git 
a/hud/otel/processors.py b/hud/otel/processors.py deleted file mode 100644 index f41a1b0a..00000000 --- a/hud/otel/processors.py +++ /dev/null @@ -1,121 +0,0 @@ -from __future__ import annotations - -import logging -import time -from typing import Any - -from opentelemetry import baggage -from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor - -from .context import ( - get_agent_steps, - get_base_mcp_steps, - get_mcp_tool_steps, - increment_agent_steps, - increment_base_mcp_steps, - increment_mcp_tool_steps, -) - -logger = logging.getLogger(__name__) - - -class HudEnrichmentProcessor(SpanProcessor): - """Span processor that enriches every span with HUD-specific context. - - • Adds ``hud.task_run_id`` attribute if available. - • Adds ``hud.job_id`` attribute if available in baggage. - • Adds ``hud.step_count`` attribute if available in baggage. - """ - - def __init__(self) -> None: - # No state, everything comes from context vars - super().__init__() - - # --- callback hooks ------------------------------------------------- - def on_start(self, span: Span, parent_context: Any) -> None: # type: ignore[override] - try: - # Get task_run_id from baggage in parent context - run_id = baggage.get_baggage("hud.task_run_id", context=parent_context) - if run_id and span.is_recording(): - span.set_attribute("hud.task_run_id", str(run_id)) - - # Get job_id from baggage if available - job_id = baggage.get_baggage("hud.job_id", context=parent_context) - if job_id and span.is_recording(): - span.set_attribute("hud.job_id", str(job_id)) - - # Check what type of step this is and increment appropriate counters - if span.is_recording(): - step_type = self._get_step_type(span) - - if step_type == "agent": - # Increment agent steps - new_agent_count = increment_agent_steps() - span.set_attribute("hud.agent_steps", new_agent_count) - logger.debug("Incremented agent steps to %d", new_agent_count) - - elif step_type == "base_mcp": - # Increment base MCP steps - new_base_count = increment_base_mcp_steps() - span.set_attribute("hud.base_mcp_steps", new_base_count) - logger.debug("Incremented base MCP steps to %d", new_base_count) - - elif step_type == "mcp_tool": - # Increment both base MCP and MCP tool steps - new_base_count = increment_base_mcp_steps() - new_tool_count = increment_mcp_tool_steps() - span.set_attribute("hud.base_mcp_steps", new_base_count) - span.set_attribute("hud.mcp_tool_steps", new_tool_count) - logger.debug( - "Incremented MCP steps to base=%d, tool=%d", new_base_count, new_tool_count - ) - - # Always set all current step counts on the span - span.set_attribute("hud.base_mcp_steps", get_base_mcp_steps()) - span.set_attribute("hud.mcp_tool_steps", get_mcp_tool_steps()) - span.set_attribute("hud.agent_steps", get_agent_steps()) - - except Exception as exc: # defensive; never fail the tracer - logger.debug("HudEnrichmentProcessor.on_start error: %s", exc, exc_info=False) - - def _get_step_type(self, span: Span) -> str | None: - """Determine what type of step this span represents. 
- - Returns: - 'base_mcp' for any MCP span - 'mcp_tool' for MCP tool calls (tools/call.mcp) - 'agent' for agent spans - None if not a step - """ - # Check span attributes - attrs = span.attributes or {} - span_name = span.name - - # Check for agent steps (instrumented with span_type="agent") - if attrs.get("category") == "agent": - return "agent" - - # Check span name pattern for MCP calls - if span_name: - # tools/call.mcp is an mcp_tool step - if span_name == "tools/call.mcp": - return "mcp_tool" - - # Any other .mcp suffixed span is a base MCP step - elif span_name.endswith(".mcp"): - return "base_mcp" - - return None - - def on_end(self, span: ReadableSpan) -> None: - # Nothing to do enrichment is on_start only - pass - - # Required to fully implement abstract base, but we don't batch spans - def shutdown(self) -> None: # type: ignore[override] - pass - - def force_flush(self, timeout_millis: int | None = None) -> bool: # type: ignore[override] - if timeout_millis: - time.sleep(timeout_millis / 1000) - return True diff --git a/hud/otel/tests/__init__.py b/hud/otel/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/hud/otel/tests/test_instrumentation.py b/hud/otel/tests/test_instrumentation.py deleted file mode 100644 index cfee2873..00000000 --- a/hud/otel/tests/test_instrumentation.py +++ /dev/null @@ -1,207 +0,0 @@ -from __future__ import annotations - -from unittest.mock import MagicMock, patch - -import pytest - -from hud.otel.instrumentation import ( - _patch_get_error_type, - _patch_mcp_instrumentation, - install_mcp_instrumentation, -) - - -def test_install_mcp_instrumentation_success(): - """Test successful installation of MCP instrumentation.""" - mock_provider = MagicMock() - - with ( - patch("opentelemetry.instrumentation.mcp.instrumentation"), - patch( - "opentelemetry.instrumentation.mcp.instrumentation.McpInstrumentor" - ) as mock_instrumentor_class, - patch("hud.otel.instrumentation._patch_mcp_instrumentation"), - ): - mock_instrumentor = MagicMock() - mock_instrumentor_class.return_value = mock_instrumentor - - install_mcp_instrumentation(mock_provider) - - mock_instrumentor.instrument.assert_called_once_with(tracer_provider=mock_provider) - - -def test_install_mcp_instrumentation_import_error(): - """Test installation handles ImportError gracefully.""" - mock_provider = MagicMock() - - # Mock the import to raise ImportError - import sys - - with patch.dict(sys.modules, {"opentelemetry.instrumentation.mcp.instrumentation": None}): - # Should not raise - install_mcp_instrumentation(mock_provider) - - -def test_install_mcp_instrumentation_general_exception(): - """Test installation handles general exceptions gracefully.""" - mock_provider = MagicMock() - - with ( - patch("opentelemetry.instrumentation.mcp.instrumentation"), - patch( - "opentelemetry.instrumentation.mcp.instrumentation.McpInstrumentor" - ) as mock_instrumentor_class, - ): - mock_instrumentor_class.side_effect = Exception("Unexpected error") - - # Should not raise - install_mcp_instrumentation(mock_provider) - - -def test_patch_mcp_instrumentation_success(): - """Test successful patching of MCP instrumentation.""" - with ( - patch("opentelemetry.instrumentation.mcp.instrumentation.McpInstrumentor") as mock_class, - patch("hud.otel.instrumentation._patch_get_error_type"), - ): - mock_class._transport_wrapper = None - - _patch_mcp_instrumentation() - - # Should have set the _transport_wrapper - assert mock_class._transport_wrapper is not None - - -def 
test_patch_mcp_instrumentation_exception(): - """Test patching handles exceptions gracefully.""" - with patch( - "opentelemetry.instrumentation.mcp.instrumentation.McpInstrumentor", - side_effect=Exception("Error"), - ): - # Should not raise - _patch_mcp_instrumentation() - - -def test_patch_get_error_type_success(): - """Test successful patching of get_error_type.""" - with patch("opentelemetry.instrumentation.mcp.instrumentation") as mock_mcp_inst: - mock_mcp_inst.get_error_type = None - - _patch_get_error_type() - - # Should have set get_error_type - assert mock_mcp_inst.get_error_type is not None - - -def test_patch_get_error_type_exception(): - """Test patching get_error_type handles exceptions.""" - with patch( - "opentelemetry.instrumentation.mcp.instrumentation", side_effect=ImportError("Not found") - ): - # Should not raise - _patch_get_error_type() - - -def test_patched_get_error_type_valid_4xx(): - """Test patched get_error_type with valid 4xx status code.""" - with patch("opentelemetry.instrumentation.mcp.instrumentation") as mock_mcp_inst: - _patch_get_error_type() - - patched_func = mock_mcp_inst.get_error_type - - # Test with a valid 4xx error - result = patched_func("Error 404 not found") - assert result == "NOT_FOUND" - - -def test_patched_get_error_type_valid_5xx(): - """Test patched get_error_type with valid 5xx status code.""" - with patch("opentelemetry.instrumentation.mcp.instrumentation") as mock_mcp_inst: - _patch_get_error_type() - - patched_func = mock_mcp_inst.get_error_type - - # Test with a valid 5xx error - result = patched_func("Error 500 internal server error") - assert result == "INTERNAL_SERVER_ERROR" - - -def test_patched_get_error_type_invalid_status(): - """Test patched get_error_type with invalid status code.""" - with patch("opentelemetry.instrumentation.mcp.instrumentation") as mock_mcp_inst: - _patch_get_error_type() - - patched_func = mock_mcp_inst.get_error_type - - # Test with an invalid HTTP status code (e.g., 499 doesn't exist in HTTPStatus) - result = patched_func("Error 499 custom error") - # Should return the name even if it's not a standard HTTPStatus - assert result is None or isinstance(result, str) - - -def test_patched_get_error_type_no_status(): - """Test patched get_error_type with no status code.""" - with patch("opentelemetry.instrumentation.mcp.instrumentation") as mock_mcp_inst: - _patch_get_error_type() - - patched_func = mock_mcp_inst.get_error_type - - result = patched_func("Error message without status code") - assert result is None - - -def test_patched_get_error_type_non_string(): - """Test patched get_error_type with non-string input.""" - with patch("opentelemetry.instrumentation.mcp.instrumentation") as mock_mcp_inst: - _patch_get_error_type() - - patched_func = mock_mcp_inst.get_error_type - - result = patched_func(None) - assert result is None - - result = patched_func(123) - assert result is None - - -def test_patched_get_error_type_3xx_ignored(): - """Test patched get_error_type ignores 3xx codes.""" - with patch("opentelemetry.instrumentation.mcp.instrumentation") as mock_mcp_inst: - _patch_get_error_type() - - patched_func = mock_mcp_inst.get_error_type - - result = patched_func("Error 301 moved") - assert result is None - - -@pytest.mark.asyncio -async def test_transport_wrapper_three_values(): - """Test transport wrapper handles 3-value tuple.""" - with ( - patch("opentelemetry.instrumentation.mcp.instrumentation.McpInstrumentor") as mock_class, - patch("hud.otel.instrumentation._patch_get_error_type"), - ): - 
mock_class._transport_wrapper = None - - _patch_mcp_instrumentation() - - # Get the patched wrapper - wrapper_func = mock_class._transport_wrapper - assert wrapper_func is not None - - -@pytest.mark.asyncio -async def test_transport_wrapper_two_values(): - """Test transport wrapper handles 2-value tuple.""" - with ( - patch("opentelemetry.instrumentation.mcp.instrumentation.McpInstrumentor") as mock_class, - patch("hud.otel.instrumentation._patch_get_error_type"), - ): - mock_class._transport_wrapper = None - - _patch_mcp_instrumentation() - - # Get the patched wrapper - wrapper_func = mock_class._transport_wrapper - assert wrapper_func is not None diff --git a/hud/otel/tests/test_processors.py b/hud/otel/tests/test_processors.py deleted file mode 100644 index 50ea14d4..00000000 --- a/hud/otel/tests/test_processors.py +++ /dev/null @@ -1,197 +0,0 @@ -"""Tests for OpenTelemetry processors.""" - -from __future__ import annotations - -from unittest.mock import MagicMock, patch - -from hud.otel.processors import HudEnrichmentProcessor - - -class TestHudEnrichmentProcessor: - """Test HudEnrichmentProcessor.""" - - def test_on_start_with_run_id(self): - """Test on_start with current task run ID.""" - - processor = HudEnrichmentProcessor() - - # Mock span - span = MagicMock() - span.set_attribute = MagicMock() - span.is_recording.return_value = True - - # Mock baggage to return run ID - parent_context = {} - with patch("hud.otel.processors.baggage.get_baggage") as mock_get_baggage: - # Return run ID for task_run_id, None for job_id - mock_get_baggage.side_effect = ( - lambda key, context: "test-run-123" if key == "hud.task_run_id" else None - ) - processor.on_start(span, parent_context) - - # Verify attribute was set - span.set_attribute.assert_called_with("hud.task_run_id", "test-run-123") - - def test_on_start_no_run_id(self): - """Test on_start without current task run ID.""" - - processor = HudEnrichmentProcessor() - - # Mock span - span = MagicMock() - span.set_attribute = MagicMock() - span.is_recording.return_value = True - span.name = "test_span" - - # Set up attributes to return None (not matching any step type) - span.attributes = {} - - # Mock baggage to return None - parent_context = {} - with patch("hud.otel.processors.baggage.get_baggage", return_value=None): - processor.on_start(span, parent_context) - - # Verify only step count attributes were set (no run_id or job_id) - calls = span.set_attribute.call_args_list - set_attrs = {call[0][0] for call in calls} - - # Should have step counts but not run_id/job_id - assert "hud.task_run_id" not in set_attrs - assert "hud.job_id" not in set_attrs - assert "hud.base_mcp_steps" in set_attrs - assert "hud.mcp_tool_steps" in set_attrs - assert "hud.agent_steps" in set_attrs - - def test_on_end(self): - """Test on_end does nothing.""" - - processor = HudEnrichmentProcessor() - span = MagicMock() - - # Should not raise - processor.on_end(span) - - def test_shutdown(self): - """Test shutdown does nothing.""" - - processor = HudEnrichmentProcessor() - - # Should not raise - processor.shutdown() - - def test_force_flush(self): - """Test force_flush returns True.""" - - processor = HudEnrichmentProcessor() - - # Should return True - result = processor.force_flush() - assert result is True - - def test_on_start_with_job_id(self): - """Test on_start with job ID in baggage.""" - - processor = HudEnrichmentProcessor() - - # Mock span - span = MagicMock() - span.set_attribute = MagicMock() - span.is_recording.return_value = True - - # Mock baggage with 
job ID - parent_context = {} - with patch("hud.otel.processors.baggage.get_baggage") as mock_get_baggage: - # Return None for task_run_id, job-123 for job_id - mock_get_baggage.side_effect = ( - lambda key, context: "job-123" if key == "hud.job_id" else None - ) - processor.on_start(span, parent_context) - - # Verify job ID attribute was set - span.set_attribute.assert_called_with("hud.job_id", "job-123") - - def test_on_start_exception_handling(self): - """Test on_start handles exceptions gracefully.""" - - processor = HudEnrichmentProcessor() - - # Mock span that raises exception - span = MagicMock() - span.is_recording.side_effect = Exception("Test error") - - # Should not raise - processor.on_start(span, parent_context=None) - - def test_on_start_exception_handling_extended(self): - """Test that exceptions in on_start are caught and logged.""" - from hud.otel.processors import HudEnrichmentProcessor - - processor = HudEnrichmentProcessor() - - # Create a mock span that raises when setting attributes - mock_span = MagicMock() - mock_span.is_recording.return_value = True - mock_span.set_attribute.side_effect = RuntimeError("Attribute error") - - parent_context = {} - - # Patch logger and baggage to force an exception when setting attribute - with ( - patch("hud.otel.processors.logger") as mock_logger, - patch("hud.otel.processors.baggage.get_baggage", return_value="test-id"), - ): - # Should not raise, exception should be caught - processor.on_start(mock_span, parent_context) - - # Verify logger.debug was called with the exception - mock_logger.debug.assert_called_once() - args = mock_logger.debug.call_args[0] - assert "HudEnrichmentProcessor.on_start error" in args[0] - assert "Attribute error" in str(args[1]) - - def test_on_start_with_baggage_get_exception(self): - """Test exception handling when baggage.get_baggage fails for task_run_id.""" - processor = HudEnrichmentProcessor() - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - parent_context = {} - - # Make baggage.get_baggage raise an exception for task_run_id - with ( - patch( - "hud.otel.processors.baggage.get_baggage", - side_effect=ValueError("Context error"), - ), - patch("hud.otel.processors.logger") as mock_logger, - ): - # Should not raise - processor.on_start(mock_span, parent_context) - - # Verify logger.debug was called - mock_logger.debug.assert_called_once() - args = mock_logger.debug.call_args[0] - assert "Context error" in str(args[1]) - - def test_on_start_with_baggage_exception(self): - """Test exception handling when baggage.get_baggage fails.""" - processor = HudEnrichmentProcessor() - - mock_span = MagicMock() - mock_span.is_recording.return_value = True - - parent_context = {} - - # Make baggage.get_baggage raise an exception - with ( - patch("hud.otel.processors.baggage.get_baggage", side_effect=KeyError("Baggage error")), - patch("hud.otel.processors.logger") as mock_logger, - ): - # Should not raise - processor.on_start(mock_span, parent_context) - - # Verify logger.debug was called - mock_logger.debug.assert_called_once() - args = mock_logger.debug.call_args[0] - assert "Baggage error" in str(args[1]) diff --git a/hud/server/server.py b/hud/server/server.py index 9b6b4d68..b833617b 100644 --- a/hud/server/server.py +++ b/hud/server/server.py @@ -16,7 +16,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response -from hud.datasets import run_tasks +from hud.datasets import run_dataset from hud.eval.task import Task from hud.server.low_level 
import LowLevelServerWithInit from hud.types import LegacyTask @@ -765,7 +765,7 @@ async def run_eval(request: Request) -> Response: # Fire and forget - launch evaluation in background async def run_eval_background() -> None: - await run_tasks( + await run_dataset( [Task.from_v4(task) for task in task_objects], agent_type=agent_type, agent_params=agent_params, diff --git a/hud/telemetry/__init__.py b/hud/telemetry/__init__.py index a243af80..e237673b 100644 --- a/hud/telemetry/__init__.py +++ b/hud/telemetry/__init__.py @@ -1,19 +1,27 @@ -"""HUD Telemetry - Instrumentation for agent execution. +"""HUD Telemetry - Lightweight telemetry for HUD SDK. This module provides: -- instrument: Function instrumentation decorator +- @instrument decorator for recording function calls +- High-performance span export to HUD API -For other APIs, import directly from submodules: -- hud.telemetry.job: Job, job, create_job, get_current_job -- hud.telemetry.trace: Trace, trace -- hud.telemetry.async_context: async_job, async_trace -- hud.telemetry.replay: clear_trace, get_trace +Usage: + import hud -Recommended: Use hud.eval() or env.eval() instead. -""" + @hud.instrument + async def my_function(): + ... -from __future__ import annotations + # Within an eval context, calls are recorded + async with hud.eval(task) as ctx: + result = await my_function() +""" -from .instrument import instrument +from hud.telemetry.exporter import flush, queue_span, shutdown +from hud.telemetry.instrument import instrument -__all__ = ["instrument"] +__all__ = [ + "flush", + "instrument", + "queue_span", + "shutdown", +] diff --git a/hud/telemetry/async_context.py b/hud/telemetry/async_context.py deleted file mode 100644 index 90a00723..00000000 --- a/hud/telemetry/async_context.py +++ /dev/null @@ -1,345 +0,0 @@ -"""Async context managers for HUD telemetry. - -Provides async-native trace and job context managers for async code. - -Usage: - >>> import hud - >>> async with hud.async_trace("Task"): - ... await agent.run(task) - >>> async with hud.async_job("Evaluation") as job: - ... async with hud.async_trace("Task", job_id=job.id): - ... await agent.run(task) - -Telemetry is fully automatic - status updates are awaited and spans are -flushed on context exit. No manual cleanup required. -""" - -from __future__ import annotations - -import logging -import traceback -import uuid -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from types import TracebackType - -from hud.otel import configure_telemetry -from hud.otel.context import ( - _update_task_status_async, -) -from hud.otel.context import ( - trace as OtelTrace, -) -from hud.settings import settings -from hud.shared import make_request -from hud.telemetry.job import Job, _print_job_complete_url, _print_job_url -from hud.telemetry.trace import Trace - -logger = logging.getLogger(__name__) - -# Module exports -__all__ = ["AsyncJob", "AsyncTrace", "async_job", "async_trace"] - -# Global state for current job -_current_job: Job | None = None - - -class AsyncTrace: - """Async context manager for HUD trace tracking. - - This is the async equivalent of `hud.trace()`, designed for use in - high-concurrency async contexts. It tracks task execution with automatic - status updates. 
- - The context manager: - - Creates a unique task_run_id for telemetry correlation - - Sends and AWAITS status updates ("running" → "completed"/"error") - - Integrates with OpenTelemetry for span collection - - Ensures status is updated before exiting the context - - Use `async_trace()` helper function instead of instantiating directly. - """ - - def __init__( - self, - name: str = "Test task from hud", - *, - root: bool = True, - attrs: dict[str, Any] | None = None, - job_id: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - trace_id: str | None = None, - ) -> None: - self.name = name - self.root = root - self.attrs = attrs or {} - self.job_id = job_id - self.task_id = task_id - self.group_id = group_id - self.task_run_id = trace_id if trace_id else str(uuid.uuid4()) - self.trace_obj = Trace(self.task_run_id, name, job_id, task_id, group_id) - self._otel_trace = None - - async def __aenter__(self) -> Trace: - """Enter the async trace context.""" - # Ensure telemetry is configured - configure_telemetry() - - # Start the OpenTelemetry span - self._otel_trace = OtelTrace( - self.task_run_id, - is_root=self.root, - span_name=self.name, - attributes=self.attrs, - job_id=self.job_id, - task_id=self.task_id, - group_id=self.group_id, - ) - self._otel_trace.__enter__() - - # Update trace status to "running" - if self.root and settings.telemetry_enabled and settings.api_key: - await _update_task_status_async( - self.task_run_id, - "running", - job_id=self.job_id, - trace_name=self.name, - task_id=self.task_id, - group_id=self.group_id, - ) - - logger.debug("Started trace: %s (%s)", self.name, self.task_run_id) - return self.trace_obj - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the async trace context.""" - # Close the OpenTelemetry span - if self._otel_trace: - self._otel_trace.__exit__(exc_type, exc_val, exc_tb) - - # Update trace status to "completed" or "error" - if self.root and settings.telemetry_enabled and settings.api_key: - status = "error" if exc_type else "completed" - error_msg = None - if exc_type is not None: - error_msg = "".join(traceback.format_exception(exc_type, exc_val, exc_tb)) - - try: - await _update_task_status_async( - self.task_run_id, - status, - job_id=self.job_id, - error_message=error_msg, - trace_name=self.name, - task_id=self.task_id, - group_id=self.group_id, - ) - except Exception as e: - logger.warning("Failed to update trace status: %s", e) - - # Flush spans for standalone traces (not part of a job) - if not self.job_id and self.root: - from hud.telemetry.utils import flush_telemetry - - await flush_telemetry() - - logger.debug("Ended trace: %s (%s)", self.name, self.task_run_id) - - -class AsyncJob: - """Async context manager for HUD job tracking. - - This is the async equivalent of `hud.job()`, designed for grouping - related tasks in high-concurrency async contexts. - - The context manager: - - Creates or uses a provided job_id - - Sends and AWAITS status updates ("running" → "completed"/"failed") - - Associates all child traces with this job - - Ensures status is updated before exiting the context - - Use `async_job()` helper function instead of instantiating directly. 
- """ - - def __init__( - self, - name: str, - metadata: dict[str, Any] | None = None, - job_id: str | None = None, - dataset_link: str | None = None, - ) -> None: - self.job_id = job_id or str(uuid.uuid4()) - self.job = Job(self.job_id, name, metadata, dataset_link) - - async def __aenter__(self) -> Job: - """Enter the async job context.""" - global _current_job - - # Save previous job and set this as current - self._old_job = _current_job - _current_job = self.job - - # Update job status to "running" - if settings.telemetry_enabled: - payload = { - "name": self.job.name, - "status": "running", - "metadata": self.job.metadata, - } - if self.job.dataset_link: - payload["dataset_link"] = self.job.dataset_link - - try: - await make_request( - method="POST", - url=f"{settings.hud_telemetry_url}/jobs/{self.job.id}/status", - json=payload, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to update job status: %s", e) - - _print_job_url(self.job.id, self.job.name) - logger.debug("Started job: %s (%s)", self.job.name, self.job.id) - return self.job - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the async job context.""" - global _current_job - - # Flush all child trace spans before updating job status - from hud.telemetry.utils import flush_telemetry - - await flush_telemetry() - - # Update job status to "completed" or "failed" - if settings.telemetry_enabled: - status = "failed" if exc_type else "completed" - payload = { - "name": self.job.name, - "status": status, - "metadata": self.job.metadata, - } - if self.job.dataset_link: - payload["dataset_link"] = self.job.dataset_link - - try: - await make_request( - method="POST", - url=f"{settings.hud_telemetry_url}/jobs/{self.job.id}/status", - json=payload, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to update job status: %s", e) - - _print_job_complete_url(self.job.id, self.job.name, error_occurred=bool(exc_type)) - - # Restore previous job - _current_job = self._old_job - - logger.debug("Ended job: %s (%s)", self.job.name, self.job.id) - - -def async_trace( - name: str = "Test task from hud", - *, - root: bool = True, - attrs: dict[str, Any] | None = None, - job_id: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - trace_id: str | None = None, -) -> AsyncTrace: - """Create an async trace context for telemetry tracking. - - This is the async equivalent of `hud.trace()` for use in async contexts. - Status updates are automatically sent and awaited - the trace doesn't exit - until its status is confirmed on the server. - - Args: - name: Descriptive name for this trace/task - root: Whether this is a root trace (updates task status) - attrs: Additional attributes to attach to the trace - job_id: Optional job ID to associate with this trace - task_id: Optional task ID for custom task identifiers - group_id: Optional group ID to associate with this trace - trace_id: Optional trace ID (auto-generated if not provided) - - Returns: - AsyncTrace context manager - - Example: - >>> import hud - >>> # Single task - everything is automatic! - >>> async with hud.async_trace("My Task"): - ... result = await agent.run(task) - >>> # Status is "completed" and spans are flushed - all done! - >>> - >>> # Multiple tasks - each trace handles itself - >>> for task in tasks: - ... async with hud.async_trace(task.name): - ... 
await process(task) - >>> # All traces completed and flushed - nothing more needed! - - Note: - Use this async version in async code. For sync code, use `hud.trace()`. - Telemetry is fully automatic - no manual flushing required. - """ - return AsyncTrace( - name, - root=root, - attrs=attrs, - job_id=job_id, - task_id=task_id, - group_id=group_id, - trace_id=trace_id, - ) - - -def async_job( - name: str, - metadata: dict[str, Any] | None = None, - job_id: str | None = None, - dataset_link: str | None = None, -) -> AsyncJob: - """Create an async job context for grouping related tasks. - - This is the async equivalent of `hud.job()` for async contexts. - Status updates are automatically sent and awaited - the job doesn't exit - until its status is confirmed on the server. - - Args: - name: Human-readable job name - metadata: Optional metadata dictionary - job_id: Optional job ID (auto-generated if not provided) - dataset_link: Optional HuggingFace dataset identifier - - Returns: - AsyncJob context manager - - Example: - >>> import hud - >>> async with hud.async_job("Batch Processing") as job: - ... for item in items: - ... async with hud.async_trace(f"Task {item.id}", job_id=job.id): - ... await process(item) - >>> # Job exits - automatically flushes all child trace spans! - - Note: - Use this async version in async code. For sync code, use `hud.job()`. - Telemetry is fully automatic - no manual flushing required. - """ - return AsyncJob(name, metadata=metadata, job_id=job_id, dataset_link=dataset_link) diff --git a/hud/telemetry/exporter.py b/hud/telemetry/exporter.py new file mode 100644 index 00000000..3001437f --- /dev/null +++ b/hud/telemetry/exporter.py @@ -0,0 +1,204 @@ +"""High-performance span exporter for HUD telemetry backend. + +This module provides a lightweight span exporter that sends spans to the HUD +telemetry API immediately, using a thread pool to avoid blocking async code. + +No OpenTelemetry dependency required. 
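+
+Each span is routed by the task_run_id stored in its attributes and uploaded to
+the HUD endpoint /trace/{task_run_id}/telemetry-upload in a {"telemetry": [...]}
+payload.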
+""" + +from __future__ import annotations + +import atexit +import concurrent.futures as cf +import contextlib +import logging +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any + +from hud.shared import make_request_sync + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# Global singleton thread pool for span exports +_export_executor: ThreadPoolExecutor | None = None + +# Pending futures for shutdown coordination +_pending_futures: list[cf.Future[bool]] = [] + +# Spans waiting to be flushed at context exit (per task_run_id) +_pending_spans: dict[str, list[dict[str, Any]]] = defaultdict(list) + + +def _get_export_executor() -> ThreadPoolExecutor: + """Get or create the global thread pool for span exports.""" + global _export_executor + if _export_executor is None: + _export_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="span-export") + + def cleanup() -> None: + if _export_executor is not None: + _export_executor.shutdown(wait=True) + + atexit.register(cleanup) + return _export_executor + + +def _do_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, +) -> bool: + """Upload spans to HUD API (sync, runs in thread pool).""" + try: + url = f"{telemetry_url}/trace/{task_run_id}/telemetry-upload" + payload: dict[str, Any] = {"telemetry": spans} + + logger.debug("Uploading %d spans to %s", len(spans), url) + make_request_sync( + method="POST", + url=url, + json=payload, + api_key=api_key, + ) + return True + except Exception as e: + logger.debug("Failed to upload spans for task %s: %s", task_run_id, e) + return False + + +def _get_api_key() -> str | None: + """Get the API key - prefer context override, fallback to settings.""" + from hud.eval.context import get_current_api_key + from hud.settings import settings + + return get_current_api_key() or settings.api_key + + +def queue_span(span: dict[str, Any]) -> None: + """Queue a span and immediately upload it (non-blocking). + + Uses thread pool to upload without blocking the event loop. 
+ """ + from hud.settings import settings + + api_key = _get_api_key() + if not api_key or not settings.telemetry_enabled: + return + + task_run_id = span.get("attributes", {}).get("task_run_id") + if not task_run_id: + return + + # Store for potential re-flush at context exit + _pending_spans[task_run_id].append(span) + + # Capture api_key for upload closure (context may change) + upload_api_key = api_key + + # Upload immediately via thread pool + import asyncio + + try: + loop = asyncio.get_running_loop() + # In async context - use thread pool + executor = _get_export_executor() + + def _upload() -> bool: + return _do_upload(task_run_id, [span], settings.hud_telemetry_url, upload_api_key) + + future = loop.run_in_executor(executor, _upload) + _pending_futures.append(future) # type: ignore[arg-type] + + def _cleanup_done(f: cf.Future[bool]) -> None: + with contextlib.suppress(Exception): + _ = f.exception() + with contextlib.suppress(ValueError): + _pending_futures.remove(f) + # Remove from pending spans on success + if not f.exception(): + with contextlib.suppress(Exception): + if task_run_id in _pending_spans and span in _pending_spans[task_run_id]: + _pending_spans[task_run_id].remove(span) + + future.add_done_callback(_cleanup_done) # type: ignore[arg-type] + + except RuntimeError: + # No event loop - upload synchronously + if _do_upload(task_run_id, [span], settings.hud_telemetry_url, upload_api_key): + with contextlib.suppress(Exception): + if task_run_id in _pending_spans and span in _pending_spans[task_run_id]: + _pending_spans[task_run_id].remove(span) + + +def flush(task_run_id: str | None = None) -> None: + """Flush any pending spans (called at context exit). + + This ensures any spans that failed to upload are retried. + + Args: + task_run_id: Optional task run ID to flush. If None, flushes all. + """ + from hud.settings import settings + + api_key = _get_api_key() + if not api_key or not settings.telemetry_enabled: + _pending_spans.clear() + return + + if task_run_id: + # Flush specific task + spans = _pending_spans.pop(task_run_id, []) + if spans: + _do_upload(task_run_id, spans, settings.hud_telemetry_url, api_key) + else: + # Flush all + for tid, spans in list(_pending_spans.items()): + if spans: + _do_upload(tid, spans, settings.hud_telemetry_url, api_key) + _pending_spans.clear() + + +def shutdown(timeout: float = 10.0) -> bool: + """Shutdown and wait for pending exports. + + Args: + timeout: Maximum time to wait in seconds + + Returns: + True if all exports completed, False if timed out + """ + # Wait for pending async exports + if _pending_futures: + try: + done, not_done = cf.wait(_pending_futures, timeout=timeout) + for f in done: + with contextlib.suppress(Exception): + _ = f.exception() + _pending_futures.clear() + + # Flush any remaining spans synchronously + flush() + + return len(not_done) == 0 + except Exception: + return False + + # Flush any remaining spans + flush() + return True + + +# Register shutdown handler +atexit.register(lambda: shutdown(timeout=5.0)) + + +__all__ = [ + "flush", + "queue_span", + "shutdown", +] diff --git a/hud/telemetry/instrument.py b/hud/telemetry/instrument.py index 0f438b83..ce45e452 100644 --- a/hud/telemetry/instrument.py +++ b/hud/telemetry/instrument.py @@ -1,15 +1,15 @@ -"""Simple instrumentation decorator for HUD tracing. +"""Instrumentation decorator for HUD telemetry. This module provides a lightweight @instrument decorator that records -function calls within the context of env.trace(). No OpenTelemetry required. 
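For reference, a minimal sketch of driving the exporter above directly; the span fields and literal values are illustrative, inferred from _build_span and _do_upload rather than a documented schema:

    from hud.telemetry import flush, queue_span, shutdown

    # Requires an API key (context or settings) and telemetry enabled; otherwise queue_span is a no-op.
    span = {
        "name": "my_module.my_function",
        "trace_id": "3f2b8c1e9d4a4f6b8a2e5c7d9e0f1a2b",  # 32-char hex
        "span_id": "a1b2c3d4e5f60718",                   # 16-char hex
        "parent_span_id": None,
        "start_time": "2025-01-01T00:00:00Z",
        "end_time": "2025-01-01T00:00:01Z",
        "status_code": "OK",
        "status_message": None,
        "attributes": {"task_run_id": "run-123", "category": "function"},
        "exceptions": None,
    }
    queue_span(span)   # uploads via the thread pool when a loop is running
    flush("run-123")   # retries anything still pending for that run
    shutdown()         # waits for in-flight uploads before exit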
+function calls and sends them to the HUD telemetry backend. Usage: @hud.instrument async def my_function(arg1, arg2): ... - # Within a trace context, calls are recorded - async with env.eval("task") as tc: + # Within an eval context, calls are recorded and sent to HUD + async with env.eval("task") as ctx: result = await my_function("a", "b") """ @@ -27,6 +27,16 @@ async def my_function(arg1, arg2): import pydantic_core +from hud.telemetry.exporter import queue_span +from hud.types import TraceStep + + +def _get_trace_id() -> str | None: + """Lazy import to avoid circular dependency with eval.context.""" + from hud.eval.context import get_current_trace_id + + return get_current_trace_id() + if TYPE_CHECKING: from collections.abc import Awaitable, Callable from typing import ParamSpec @@ -54,13 +64,24 @@ def _serialize_value(value: Any, max_items: int = 10) -> Any: return f"<{type(value).__name__}>" +def _now_iso() -> str: + """Get current time as ISO-8601 string.""" + return datetime.now(UTC).isoformat().replace("+00:00", "Z") + + +def _normalize_trace_id(trace_id: str) -> str: + """Normalize trace_id to 32-character hex string.""" + clean = trace_id.replace("-", "") + return clean[:32].ljust(32, "0") + + @overload def instrument( func: None = None, *, name: str | None = None, category: str = "function", - span_type: str | None = None, # Alias for category + span_type: str | None = None, record_args: bool = True, record_result: bool = True, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ... @@ -72,7 +93,7 @@ def instrument( *, name: str | None = None, category: str = "function", - span_type: str | None = None, # Alias for category + span_type: str | None = None, record_args: bool = True, record_result: bool = True, ) -> Callable[P, R]: ... @@ -84,7 +105,7 @@ def instrument( *, name: str | None = None, category: str = "function", - span_type: str | None = None, # Alias for category + span_type: str | None = None, record_args: bool = True, record_result: bool = True, ) -> Callable[P, Awaitable[R]]: ... @@ -95,18 +116,18 @@ def instrument( *, name: str | None = None, category: str = "function", - span_type: str | None = None, # Alias for category + span_type: str | None = None, record_args: bool = True, record_result: bool = True, ) -> Callable[..., Any]: """Instrument a function to record spans within eval context. - This decorator records function calls as spans, compatible with env.eval(). + This decorator records function calls as spans and sends them to the HUD API. 
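A small worked example of the _normalize_trace_id helper added above; the UUID value is illustrative:

    # "-" is stripped and the first 32 hex chars are kept; shorter IDs are right-padded with "0" to 32 chars.
    _normalize_trace_id("3f2b8c1e-9d4a-4f6b-8a2e-5c7d9e0f1a2b")
    # -> "3f2b8c1e9d4a4f6b8a2e5c7d9e0f1a2b"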
Args: func: The function to instrument name: Custom span name (defaults to module.function) - category: Span category (e.g., "agent", "tool", "function") + category: Span category (e.g., "agent", "tool", "function", "mcp") span_type: Alias for category (deprecated, use category instead) record_args: Whether to record function arguments record_result: Whether to record function result @@ -123,8 +144,6 @@ async def process_data(items: list[str]) -> dict: async def call_model(messages: list) -> str: return await model.generate(messages) """ - - # span_type is an alias for category effective_category = span_type if span_type is not None else category def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @@ -142,24 +161,25 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]: sig = None def _build_span( - trace_id: str, + task_run_id: str, args: tuple[Any, ...], kwargs: dict[str, Any], start_time: str, end_time: str, - duration_ms: float, result: Any = None, error: str | None = None, ) -> dict[str, Any]: - """Build a span record.""" - attributes: dict[str, Any] = { - "category": effective_category, - "function": func_qualname, - "module": func_module, - "duration_ms": duration_ms, - } - - # Record arguments + """Build a HudSpan-compatible span record.""" + # Build attributes using TraceStep + attributes = TraceStep( + task_run_id=task_run_id, + category=effective_category, + type="CLIENT", + start_timestamp=start_time, + end_timestamp=end_time, + ) + + # Record arguments as request if record_args and sig: try: bound_args = sig.bind(*args, **kwargs) @@ -170,44 +190,37 @@ def _build_span( if k not in ("self", "cls") } if args_dict: - attributes["request"] = json.dumps(args_dict) + attributes.request = args_dict except Exception as e: logger.debug("Failed to serialize args: %s", e) # Record result if record_result and result is not None and error is None: try: - attributes["result"] = json.dumps(_serialize_value(result)) + attributes.result = _serialize_value(result) except Exception as e: logger.debug("Failed to serialize result: %s", e) - # Record error - if error: - attributes["error"] = error - - return { - "trace_id": trace_id, - "span_id": uuid.uuid4().hex[:16], + # Build span + span_id = uuid.uuid4().hex[:16] + span = { "name": span_name, + "trace_id": _normalize_trace_id(task_run_id), + "span_id": span_id, + "parent_span_id": None, "start_time": start_time, "end_time": end_time, "status_code": "ERROR" if error else "OK", - "attributes": attributes, + "status_message": error, + "attributes": attributes.model_dump(mode="json", exclude_none=True), + "exceptions": [{"message": error}] if error else None, } - - def _get_trace_id() -> str | None: - """Get trace_id from current eval context.""" - from hud.eval.context import get_current_trace_headers - - headers = get_current_trace_headers() - if headers: - return headers.get("Trace-Id") - return None + return span @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - trace_id = _get_trace_id() - start_time = datetime.now(UTC).isoformat() + task_run_id = _get_trace_id() + start_time = _now_iso() start_perf = time.perf_counter() error: str | None = None result: Any = None @@ -219,19 +232,20 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: error = f"{type(e).__name__}: {e}" raise finally: - end_time = datetime.now(UTC).isoformat() + end_time = _now_iso() duration_ms = (time.perf_counter() - start_perf) * 1000 - if trace_id: - _build_span( - trace_id, args, kwargs, start_time, end_time, 
duration_ms, result, error + if task_run_id: + span = _build_span( + task_run_id, args, kwargs, start_time, end_time, result, error ) + queue_span(span) logger.debug("Span: %s (%.2fms)", span_name, duration_ms) @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - trace_id = _get_trace_id() - start_time = datetime.now(UTC).isoformat() + task_run_id = _get_trace_id() + start_time = _now_iso() start_perf = time.perf_counter() error: str | None = None result: Any = None @@ -243,13 +257,14 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: error = f"{type(e).__name__}: {e}" raise finally: - end_time = datetime.now(UTC).isoformat() + end_time = _now_iso() duration_ms = (time.perf_counter() - start_perf) * 1000 - if trace_id: - _build_span( - trace_id, args, kwargs, start_time, end_time, duration_ms, result, error + if task_run_id: + span = _build_span( + task_run_id, args, kwargs, start_time, end_time, result, error ) + queue_span(span) logger.debug("Span: %s (%.2fms)", span_name, duration_ms) wrapper = async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper @@ -263,4 +278,6 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: return decorator(func) -__all__ = ["instrument"] +__all__ = [ + "instrument", +] diff --git a/hud/telemetry/job.py b/hud/telemetry/job.py deleted file mode 100644 index 86576570..00000000 --- a/hud/telemetry/job.py +++ /dev/null @@ -1,355 +0,0 @@ -"""Job management for HUD SDK. - -This module provides APIs for managing jobs - logical groupings of related tasks. -Jobs can be used to track experiments, batch processing, training runs, etc. -""" - -from __future__ import annotations - -import asyncio -import logging -import uuid -from contextlib import contextmanager -from datetime import UTC, datetime -from functools import wraps -from typing import TYPE_CHECKING, Any - -from hud.settings import settings -from hud.shared import make_request, make_request_sync - -if TYPE_CHECKING: - from collections.abc import Callable, Generator - -logger = logging.getLogger(__name__) - - -class Job: - """A job represents a collection of related tasks.""" - - def __init__( - self, - job_id: str, - name: str, - metadata: dict[str, Any] | None = None, - dataset_link: str | None = None, - ) -> None: - self.id = job_id - self.name = name - self.metadata = metadata or {} - self.dataset_link = dataset_link - self.status = "created" - self.created_at = datetime.now(UTC) - self.tasks: list[str] = [] - - def add_task(self, task_id: str) -> None: - """Associate a task with this job.""" - self.tasks.append(task_id) - - async def update_status(self, status: str) -> None: - """Update job status on the server.""" - self.status = status - if settings.telemetry_enabled: - try: - payload = { - "name": self.name, - "status": status, - "metadata": self.metadata, - } - if self.dataset_link: - payload["dataset_link"] = self.dataset_link - - await make_request( - method="POST", - url=f"{settings.hud_telemetry_url}/jobs/{self.id}/status", - json=payload, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to update job status: %s", e) - - def update_status_sync(self, status: str) -> None: - """Synchronously update job status on the server.""" - self.status = status - if settings.telemetry_enabled: - try: - payload = { - "name": self.name, - "status": status, - "metadata": self.metadata, - } - if self.dataset_link: - payload["dataset_link"] = self.dataset_link - - make_request_sync( - method="POST", - 
url=f"{settings.hud_telemetry_url}/jobs/{self.id}/status", - json=payload, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to update job status: %s", e) - - async def log(self, metrics: dict[str, Any]) -> None: - """Log metrics to the job. - - Args: - metrics: Dictionary of metric name to value pairs - - Example: - await job.log({"loss": 0.5, "accuracy": 0.95, "epoch": 1}) - """ - if settings.telemetry_enabled: - try: - await make_request( - method="POST", - url=f"{settings.hud_telemetry_url}/jobs/{self.id}/log", - json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()}, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to log metrics to job: %s", e) - - def log_sync(self, metrics: dict[str, Any]) -> None: - """Synchronously log metrics to the job. - - Args: - metrics: Dictionary of metric name to value pairs - - Example: - job.log_sync({"loss": 0.5, "accuracy": 0.95, "epoch": 1}) - """ - if settings.telemetry_enabled: - try: - make_request_sync( - method="POST", - url=f"{settings.hud_telemetry_url}/jobs/{self.id}/log", - json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()}, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to log metrics to job: %s", e) - - def __repr__(self) -> str: - return f"Job(id={self.id!r}, name={self.name!r}, status={self.status!r})" - - -# Global job registry for the decorator pattern -_current_job: Job | None = None - - -def _print_job_url(job_id: str, job_name: str) -> None: - """Print the job URL in a colorful box.""" - # Only print HUD URL if HUD telemetry is enabled and has API key - if not (settings.telemetry_enabled and settings.api_key): - return - - url = f"https://hud.ai/jobs/{job_id}" - header = f"🚀 Job '{job_name}' started:" - - # ANSI color codes - DIM = "\033[90m" # Dim/Gray for border - GOLD = "\033[33m" # Gold/Yellow for URL - RESET = "\033[0m" - BOLD = "\033[1m" - - # Calculate box width based on the longest line - box_width = max(len(url), len(header)) + 6 - - # Box drawing characters - top_border = "╔" + "═" * (box_width - 2) + "╗" - bottom_border = "╚" + "═" * (box_width - 2) + "╝" - divider = "╟" + "─" * (box_width - 2) + "╢" - - # Center the content - header_padding = (box_width - len(header) - 2) // 2 - url_padding = (box_width - len(url) - 2) // 2 - - # Print the box - print(f"\n{DIM}{top_border}{RESET}") # noqa: T201 - print( # noqa: T201 - f"{DIM}║{RESET}{' ' * header_padding}{header}{' ' * (box_width - len(header) - header_padding - 3)}{DIM}║{RESET}" # noqa: E501 - ) - print(f"{DIM}{divider}{RESET}") # noqa: T201 - print( # noqa: T201 - f"{DIM}║{RESET}{' ' * url_padding}{BOLD}{GOLD}{url}{RESET}{' ' * (box_width - len(url) - url_padding - 2)}{DIM}║{RESET}" # noqa: E501 - ) - print(f"{DIM}{bottom_border}{RESET}\n") # noqa: T201 - - -def _print_job_complete_url(job_id: str, job_name: str, error_occurred: bool = False) -> None: - """Print the job completion URL with appropriate messaging.""" - # Only print HUD URL if HUD telemetry is enabled and has API key - if not (settings.telemetry_enabled and settings.api_key): - return - - url = f"https://hud.ai/jobs/{job_id}" - - # ANSI color codes - GREEN = "\033[92m" - RED = "\033[91m" - GOLD = "\033[33m" - RESET = "\033[0m" - DIM = "\033[2m" - BOLD = "\033[1m" - - if error_occurred: - print( # noqa: T201 - f"\n{RED}✗ Job '{job_name}' failed!{RESET} {DIM}View details at:{RESET} {BOLD}{GOLD}{url}{RESET}\n" # noqa: E501 - ) - else: - print( # noqa: T201 - f"\n{GREEN}✓ Job 
'{job_name}' complete!{RESET} {DIM}View all results at:{RESET} {BOLD}{GOLD}{url}{RESET}\n" # noqa: E501 - ) - - -def get_current_job() -> Job | None: - """Get the currently active job, if any.""" - return _current_job - - -@contextmanager -def job( - name: str, - metadata: dict[str, Any] | None = None, - job_id: str | None = None, - dataset_link: str | None = None, -) -> Generator[Job, None, None]: - """Context manager for job tracking and organization. - - Groups related tasks together under a single job for tracking and visualization. - - Args: - name: Human-readable job name - metadata: Optional metadata dictionary - job_id: Optional job ID (auto-generated if not provided) - dataset_link: Optional HuggingFace dataset identifier (e.g. "hud-evals/SheetBench-50") - - Yields: - Job: The job object - - Example: - >>> import hud - >>> with hud.job("training_run", {"model": "gpt-4"}) as job: - ... for epoch in range(10): - ... with hud.trace(f"epoch_{epoch}", job_id=job.id): - ... train_epoch() - >>> # For async code, use async_job - >>> async with hud.async_job("batch_processing") as job: - ... async with hud.async_trace("task", job_id=job.id): - ... await process() - - Note: - This is a synchronous context manager that uses blocking HTTP calls. - For async code, use `hud.async_job()` instead. - """ - global _current_job - - if not job_id: - job_id = str(uuid.uuid4()) - - job_obj = Job(job_id, name, metadata, dataset_link) - - # Set as current job - old_job = _current_job - _current_job = job_obj - - try: - job_obj.update_status_sync("running") - _print_job_url(job_obj.id, job_obj.name) - yield job_obj - job_obj.update_status_sync("completed") - _print_job_complete_url(job_obj.id, job_obj.name, error_occurred=False) - except Exception: - job_obj.update_status_sync("failed") - _print_job_complete_url(job_obj.id, job_obj.name, error_occurred=True) - raise - finally: - _current_job = old_job - - -def create_job( - name: str, - metadata: dict[str, Any] | None = None, - dataset_link: str | None = None, - job_id: str | None = None, -) -> Job: - """Create a job without using context manager. - - Useful when you need explicit control over job lifecycle. - - Args: - name: Human-readable job name - metadata: Optional metadata dictionary - dataset_link: Optional HuggingFace dataset identifier (e.g. "hud-evals/SheetBench-50") - job_id: Optional job ID (auto-generated if not provided) - Returns: - Job: The created job object - - Example: - job = hud.create_job("data_processing") - try: - for item in items: - with hud.trace(f"process_{item.id}", job_id=job.id): - process(item) - finally: - await job.update_status("completed") - """ - job_id = job_id or str(uuid.uuid4()) - return Job(job_id, name, metadata, dataset_link) - - -def job_decorator(name: str | None = None, **metadata: Any) -> Callable: - """Decorator for functions that should be tracked as jobs. 
- - Args: - name: Job name (defaults to function name) - **metadata: Additional metadata for the job - - Example: - @hud.job_decorator("model_training", model="gpt-4", dataset="v2") - async def train_model(config): - # This entire function execution is tracked as a job - await model.train(config) - return model.evaluate() - """ - - def decorator(func: Callable) -> Callable: - job_name = name or func.__name__ - - @wraps(func) - async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - with job(job_name, metadata) as job_obj: - # Store job ID in function for access - func._current_job_id = job_obj.id - try: - return await func(*args, **kwargs) - finally: - delattr(func, "_current_job_id") - - @wraps(func) - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - with job(job_name, metadata) as job_obj: - # Store job ID in function for access - func._current_job_id = job_obj.id - try: - return func(*args, **kwargs) - finally: - delattr(func, "_current_job_id") - - # Return appropriate wrapper based on function type - if asyncio.iscoroutinefunction(func): - return async_wrapper - else: - return sync_wrapper - - return decorator - - -# Convenience exports -__all__ = [ - "Job", - "create_job", - "get_current_job", - "job", - "job_decorator", -] diff --git a/hud/telemetry/replay.py b/hud/telemetry/replay.py deleted file mode 100644 index 67d5bddc..00000000 --- a/hud/telemetry/replay.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Trace retrieval and replay functionality. - -This module provides APIs to retrieve collected traces for analysis, -debugging, and replay purposes. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from hud.otel.collector import clear_trace as _clear_trace -from hud.otel.collector import get_trace as _get_trace - -if TYPE_CHECKING: - from hud.types import Trace - -__all__ = ["clear_trace", "get_trace"] - - -def get_trace(task_run_id: str) -> Trace | None: - """Retrieve the collected trace for a task run. - - Returns None if trace collection was disabled or the trace doesn't exist. - - Args: - task_run_id: The task run ID to retrieve the trace for - - Returns: - Trace object containing all collected steps, or None if not found - - Usage: - import hud - - # Run agent with tracing - with hud.trace() as task_run_id: - agent = MyAgent() - result = await agent.run("solve task") - - # Get the trace for analysis - trace = hud.get_trace(task_run_id) - if trace: - print(f"Collected {len(trace.trace)} steps") - - # Analyze agent vs MCP steps - agent_steps = [s for s in trace.trace if s.category == "agent"] - mcp_steps = [s for s in trace.trace if s.category == "mcp"] - - print(f"Agent steps: {len(agent_steps)}") - print(f"MCP steps: {len(mcp_steps)}") - - # Replay or analyze individual steps - for step in trace.trace: - if step.category == "agent" and step.result: - print(f"Agent: {step.result.get('content') if isinstance(step.result, dict) else step.result}") - if step.category == "mcp" and step.request: - print(f"MCP: {step.request.method if hasattr(step.request, 'method') else step.request}") - """ # noqa: E501 - return _get_trace(task_run_id) - - -def clear_trace(task_run_id: str) -> None: - """Clear the collected trace for a task run ID. - - Useful for cleaning up memory after processing large traces. - - Args: - task_run_id: The task run ID to clear the trace for - - Usage: - trace = hud.get_trace(task_run_id) - # Process trace... 
- hud.clear_trace(task_run_id) # Free memory - """ - _clear_trace(task_run_id) diff --git a/hud/telemetry/tests/test_async_context.py b/hud/telemetry/tests/test_async_context.py deleted file mode 100644 index 47697cbc..00000000 --- a/hud/telemetry/tests/test_async_context.py +++ /dev/null @@ -1,515 +0,0 @@ -from __future__ import annotations - -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from hud.telemetry.async_context import async_job, async_trace - - -@pytest.mark.asyncio -async def test_async_trace_basic(): - """Test basic AsyncTrace usage.""" - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test Task") as trace_obj: - assert trace_obj.name == "Test Task" - assert trace_obj.id is not None - - -@pytest.mark.asyncio -async def test_async_trace_with_job_id(): - """Test AsyncTrace with job_id parameter.""" - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test", job_id="job-123") as trace_obj: - assert trace_obj.job_id == "job-123" - - -@pytest.mark.asyncio -async def test_async_trace_with_task_id(): - """Test AsyncTrace with task_id parameter.""" - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test", task_id="task-456") as trace_obj: - assert trace_obj.task_id == "task-456" - - -@pytest.mark.asyncio -async def test_async_trace_status_updates(): - """Test AsyncTrace sends and awaits status updates.""" - with ( - patch("hud.telemetry.async_context.settings") as mock_settings, - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch( - "hud.telemetry.async_context._update_task_status_async", - new_callable=AsyncMock, - ) as mock_update, - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test", job_id=None): - pass - - assert mock_update.call_count == 2 - - -@pytest.mark.asyncio -async def test_async_trace_with_exception(): - """Test AsyncTrace handles exceptions.""" - with ( - patch("hud.telemetry.async_context.settings") as mock_settings, - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch( - "hud.telemetry.async_context._update_task_status_async", - new_callable=AsyncMock, - ) as mock_update, - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - with pytest.raises(ValueError): - async with async_trace("Test"): - raise ValueError("Test error") - - assert mock_update.call_count == 2 - final_call = mock_update.call_args_list[1] - assert final_call[0][1] == "error" - - -@pytest.mark.asyncio -async def test_async_trace_non_root(): - """Test AsyncTrace with root=False.""" - 
with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch( - "hud.telemetry.async_context._update_task_status_async", - new_callable=AsyncMock, - ) as mock_update, - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test", root=False): - pass - - mock_update.assert_not_called() - - -@pytest.mark.asyncio -async def test_async_trace_flushes_when_standalone(): - """Test AsyncTrace flushes spans when not part of a job.""" - with ( - patch("hud.telemetry.async_context.settings") as mock_settings, - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch( - "hud.telemetry.async_context._update_task_status_async", - new_callable=AsyncMock, - ), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock) as mock_flush, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test", job_id=None): - pass - - mock_flush.assert_called_once() - - -@pytest.mark.asyncio -async def test_async_trace_no_flush_when_in_job(): - """Test AsyncTrace doesn't flush when part of a job.""" - with ( - patch("hud.telemetry.async_context.settings") as mock_settings, - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch( - "hud.telemetry.async_context._update_task_status_async", - new_callable=AsyncMock, - ), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock) as mock_flush, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test", job_id="job-123"): - pass - - mock_flush.assert_not_called() - - -@pytest.mark.asyncio -async def test_async_job_basic(): - """Test basic AsyncJob usage.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - async with async_job("Test Job") as job_obj: - assert job_obj.name == "Test Job" - assert job_obj.id is not None - - -@pytest.mark.asyncio -async def test_async_job_with_metadata(): - """Test AsyncJob with metadata.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - async with async_job("Test", metadata={"key": "value"}) as job_obj: - assert job_obj.metadata == {"key": "value"} - - -@pytest.mark.asyncio -async def test_async_job_with_dataset_link(): - """Test AsyncJob with dataset_link.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - async with async_job("Test", dataset_link="test/dataset") as job_obj: - assert job_obj.dataset_link == "test/dataset" - - -@pytest.mark.asyncio -async def test_async_job_with_custom_job_id(): - """Test AsyncJob with custom job_id.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - 
patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - async with async_job("Test", job_id="custom-id") as job_obj: - assert job_obj.id == "custom-id" - - -@pytest.mark.asyncio -async def test_async_job_with_exception(): - """Test AsyncJob handles exceptions.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url") as mock_print, - ): - with pytest.raises(ValueError): - async with async_job("Test"): - raise ValueError("Job error") - - mock_print.assert_called_once() - call_kwargs = mock_print.call_args[1] - assert call_kwargs["error_occurred"] is True - - -@pytest.mark.asyncio -async def test_async_job_status_updates(): - """Test AsyncJob sends status updates.""" - with ( - patch("hud.telemetry.async_context.settings") as mock_settings, - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock) as mock_request, - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_settings.hud_telemetry_url = "https://test.com" - - async with async_job("Test"): - pass - - assert mock_request.call_count == 2 - - -@pytest.mark.asyncio -async def test_async_job_flushes_on_exit(): - """Test AsyncJob flushes telemetry on exit.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock) as mock_flush, - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - async with async_job("Test"): - pass - - mock_flush.assert_called_once() - - -@pytest.mark.asyncio -async def test_async_trace_nested_contexts(): - """Test nested AsyncTrace contexts work correctly.""" - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Outer Task") as outer: - assert outer.name == "Outer Task" - - async with async_trace("Inner Task", root=False) as inner: - assert inner.name == "Inner Task" - assert inner.id != outer.id - - -@pytest.mark.asyncio -async def test_async_trace_concurrent_traces(): - """Test multiple concurrent AsyncTrace contexts.""" - import asyncio - - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async def run_trace(name: str): - async with async_trace(name) as trace_obj: - await asyncio.sleep(0.01) - return trace_obj.id - - # Run multiple traces concurrently - ids = await asyncio.gather( - run_trace("Trace 1"), - run_trace("Trace 2"), - run_trace("Trace 3"), - ) - - # All traces should have unique IDs - assert len(set(ids)) == 3 - - -@pytest.mark.asyncio -async def test_async_trace_with_attrs(): - """Test AsyncTrace with attrs parameter passed 
to OtelTrace.""" - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - attrs = {"key": "value", "count": 42} - async with async_trace("Test", attrs=attrs): - # attrs are passed to OtelTrace, not exposed on Trace object - mock_otel.assert_called_once() - call_kwargs = mock_otel.call_args[1] - assert call_kwargs["attributes"] == attrs - - -@pytest.mark.asyncio -async def test_async_trace_exception_types(): - """Test AsyncTrace handles different exception types correctly.""" - with ( - patch("hud.telemetry.async_context.settings") as mock_settings, - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch( - "hud.telemetry.async_context._update_task_status_async", - new_callable=AsyncMock, - ) as mock_update, - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - # Test KeyError - with pytest.raises(KeyError): - async with async_trace("Test"): - raise KeyError("Missing key") - - # Test RuntimeError - with pytest.raises(RuntimeError): - async with async_trace("Test"): - raise RuntimeError("Runtime issue") - - # Both should have resulted in error status - assert mock_update.call_count >= 4 # 2 calls per trace - - -@pytest.mark.asyncio -async def test_async_job_nested_with_trace(): - """Test AsyncJob with nested AsyncTrace contexts.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_job("Test Job") as job_obj: - async with async_trace("Task 1", job_id=job_obj.id) as trace1: - assert trace1.job_id == job_obj.id - - async with async_trace("Task 2", job_id=job_obj.id) as trace2: - assert trace2.job_id == job_obj.id - assert trace2.id != trace1.id - - -@pytest.mark.asyncio -async def test_async_job_concurrent_jobs(): - """Test multiple concurrent AsyncJob contexts.""" - import asyncio - - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - - async def run_job(name: str): - async with async_job(name) as job_obj: - await asyncio.sleep(0.01) - return job_obj.id - - # Run multiple jobs concurrently - ids = await asyncio.gather( - run_job("Job 1"), - run_job("Job 2"), - run_job("Job 3"), - ) - - # All jobs should have unique IDs - assert len(set(ids)) == 3 - - -@pytest.mark.asyncio -async def test_async_job_with_multiple_exceptions(): - """Test AsyncJob handles multiple exceptions in nested contexts.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - 
patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url") as mock_print, - ): - with pytest.raises(ValueError): - async with async_job("Test"): - try: - raise RuntimeError("First error") - except RuntimeError: - # Catch and re-raise different error - raise ValueError("Second error") - - mock_print.assert_called_once() - call_kwargs = mock_print.call_args[1] - assert call_kwargs["error_occurred"] is True - - -@pytest.mark.asyncio -async def test_async_trace_telemetry_disabled(): - """Test AsyncTrace behavior when telemetry is disabled.""" - with ( - patch("hud.telemetry.async_context.settings") as mock_settings, - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch( - "hud.telemetry.async_context._update_task_status_async", - new_callable=AsyncMock, - ), - ): - mock_settings.telemetry_enabled = False - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test"): - pass - - # Should still create trace but not send updates - mock_otel.assert_called_once() - # Status updates might still be called depending on implementation - - -@pytest.mark.asyncio -async def test_async_job_empty_metadata(): - """Test AsyncJob with empty metadata dict.""" - with ( - patch("hud.telemetry.async_context.make_request", new_callable=AsyncMock), - patch("hud.telemetry.utils.flush_telemetry", new_callable=AsyncMock), - patch("hud.telemetry.async_context._print_job_url"), - patch("hud.telemetry.async_context._print_job_complete_url"), - ): - async with async_job("Test", metadata={}) as job_obj: - assert job_obj.metadata == {} - - -@pytest.mark.asyncio -async def test_async_trace_with_all_parameters(): - """Test AsyncTrace with all parameters specified.""" - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace( - "Test", - job_id="job-123", - task_id="task-456", - group_id="group-789", - attrs={"key": "value"}, - root=True, - ) as trace_obj: - assert trace_obj.name == "Test" - assert trace_obj.job_id == "job-123" - assert trace_obj.task_id == "task-456" - assert trace_obj.group_id == "group-789" - # Verify attrs were passed to OtelTrace - call_kwargs = mock_otel.call_args[1] - assert call_kwargs["attributes"] == {"key": "value"} - - -@pytest.mark.asyncio -async def test_async_trace_with_group_id(): - """Test AsyncTrace with group_id parameter.""" - with ( - patch("hud.telemetry.async_context.OtelTrace") as mock_otel, - patch("hud.telemetry.async_context._update_task_status_async", new_callable=AsyncMock), - ): - mock_otel_instance = MagicMock() - mock_otel.return_value = mock_otel_instance - - async with async_trace("Test", group_id="group-999") as trace_obj: - assert trace_obj.group_id == "group-999" diff --git a/hud/telemetry/tests/test_eval_telemetry.py b/hud/telemetry/tests/test_eval_telemetry.py new file mode 100644 index 00000000..15a8760d --- /dev/null +++ b/hud/telemetry/tests/test_eval_telemetry.py @@ -0,0 +1,354 @@ +"""Tests for EvalContext telemetry integration with mock backend.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import patch + +import pytest + +import hud +from hud.environment import Environment +from hud.eval import Task +from hud.telemetry.exporter import _pending_spans + + 
+@pytest.fixture(autouse=True) +def clear_pending_spans(): + """Clear pending spans before and after each test.""" + _pending_spans.clear() + yield + _pending_spans.clear() + + +class TestEvalContextTelemetry: + """Tests for EvalContext telemetry integration.""" + + @pytest.mark.asyncio + async def test_call_tool_records_span(self): + """Test that call_tool records a span with correct format.""" + uploaded_spans: list[dict[str, Any]] = [] + + def capture_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded_spans.extend(spans) + return True + + # Create environment with a simple tool + env = Environment("test-env") + + @env.tool + async def greet(name: str) -> str: + """Say hello.""" + return f"Hello, {name}!" + + # Create task from environment + task = Task(env=env) + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=capture_upload), + patch("hud.eval.context.make_request"), # Don't send eval enter/exit + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + mock_settings.hud_api_url = "https://api.hud.ai" + + async with hud.eval(task) as ctx: + result = await ctx.call_tool("greet", name="World") + # call_tool returns MCPToolResult with formatted content + assert "Hello, World!" in str(result) + trace_id = ctx.trace_id + + # Wait for thread pool + await asyncio.sleep(0.2) + + # Verify span was recorded + assert len(uploaded_spans) >= 1 + span = uploaded_spans[0] + + # Check span structure + assert "name" in span + assert "trace_id" in span + assert "span_id" in span + assert "start_time" in span + assert "end_time" in span + assert "status_code" in span + assert "attributes" in span + + # Check attributes + attrs = span["attributes"] + assert attrs["task_run_id"] == trace_id + assert attrs["category"] == "mcp" + + @pytest.mark.asyncio + async def test_call_tool_records_error_span(self): + """Test that failed call_tool records error span.""" + uploaded_spans: list[dict[str, Any]] = [] + + def capture_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded_spans.extend(spans) + return True + + env = Environment("test-env") + + @env.tool + async def failing_tool() -> str: + """Always fails.""" + raise ValueError("Tool error") + + task = Task(env=env) + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=capture_upload), + patch("hud.eval.context.make_request"), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + mock_settings.hud_api_url = "https://api.hud.ai" + + async with hud.eval(task) as ctx: + # Tool errors are wrapped in ToolError + with pytest.raises(Exception, match="Tool error"): + await ctx.call_tool("failing_tool") + + await asyncio.sleep(0.2) + + # Should have recorded span with ERROR status + assert len(uploaded_spans) >= 1 + span = uploaded_spans[0] + assert span["status_code"] == "ERROR" + # Error message contains the original error + assert "Tool error" in (span.get("status_message") or "") + + @pytest.mark.asyncio + async def test_multiple_call_tools_record_spans(self): + """Test that multiple call_tool calls each record a span.""" + uploaded_spans: list[dict[str, Any]] = [] + + def capture_upload( + task_run_id: str, + spans: list[dict[str, 
Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded_spans.extend(spans) + return True + + env = Environment("test-env") + + @env.tool + async def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + @env.tool + async def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + task = Task(env=env) + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=capture_upload), + patch("hud.eval.context.make_request"), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + mock_settings.hud_api_url = "https://api.hud.ai" + + async with hud.eval(task) as ctx: + r1 = await ctx.call_tool("add", a=2, b=3) + r2 = await ctx.call_tool("multiply", a=4, b=5) + # Results are MCPToolResult objects + assert "5" in str(r1) + assert "20" in str(r2) + + await asyncio.sleep(0.2) + + # Should have 2 spans + assert len(uploaded_spans) >= 2 + + @pytest.mark.asyncio + async def test_flush_called_on_context_exit(self): + """Test that flush is called when context exits.""" + env = Environment("test-env") + + @env.tool + async def simple_tool() -> str: + return "done" + + task = Task(env=env) + + with ( + patch("hud.eval.context.flush") as mock_flush, + patch("hud.settings.settings") as mock_settings, + patch("hud.eval.context.make_request"), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_api_url = "https://api.hud.ai" + + async with hud.eval(task) as ctx: + await ctx.call_tool("simple_tool") + trace_id = ctx.trace_id + + # Verify flush was called with the trace_id + mock_flush.assert_called_once_with(trace_id) + + @pytest.mark.asyncio + async def test_telemetry_disabled_no_upload(self): + """Test that no upload happens when telemetry is disabled.""" + upload_called = False + + def should_not_be_called(*args: Any, **kwargs: Any) -> bool: + nonlocal upload_called + upload_called = True + return True + + env = Environment("test-env") + + @env.tool + async def test_tool() -> str: + return "ok" + + task = Task(env=env) + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=should_not_be_called), + patch("hud.eval.context.make_request"), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = False # Disabled! 
+ mock_settings.hud_telemetry_url = "https://api.hud.ai" + mock_settings.hud_api_url = "https://api.hud.ai" + + async with hud.eval(task) as ctx: + await ctx.call_tool("test_tool") + + await asyncio.sleep(0.1) + + assert upload_called is False + + +class TestSpanFormat: + """Tests for the format of recorded spans.""" + + @pytest.mark.asyncio + async def test_span_has_required_fields(self): + """Test that spans have all required HudSpan fields.""" + uploaded_spans: list[dict[str, Any]] = [] + + def capture_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded_spans.extend(spans) + return True + + env = Environment("test-env") + + @env.tool + async def echo(message: str) -> str: + return message + + task = Task(env=env) + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=capture_upload), + patch("hud.eval.context.make_request"), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + mock_settings.hud_api_url = "https://api.hud.ai" + + async with hud.eval(task) as ctx: + await ctx.call_tool("echo", message="test") + + await asyncio.sleep(0.2) + + assert len(uploaded_spans) >= 1 + span = uploaded_spans[0] + + # Required fields from HudSpan + assert "name" in span + assert "trace_id" in span + assert len(span["trace_id"]) == 32 # 32-char hex + assert "span_id" in span + assert len(span["span_id"]) == 16 # 16-char hex + assert "start_time" in span + assert "end_time" in span + assert "status_code" in span + assert span["status_code"] in ("OK", "ERROR", "UNSET") + + # Attributes + assert "attributes" in span + attrs = span["attributes"] + assert "task_run_id" in attrs + assert "category" in attrs + + @pytest.mark.asyncio + async def test_span_timestamps_are_iso(self): + """Test that span timestamps are in ISO format.""" + uploaded_spans: list[dict[str, Any]] = [] + + def capture_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded_spans.extend(spans) + return True + + env = Environment("test-env") + + @env.tool + async def noop() -> None: + pass + + task = Task(env=env) + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=capture_upload), + patch("hud.eval.context.make_request"), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + mock_settings.hud_api_url = "https://api.hud.ai" + + async with hud.eval(task) as ctx: + await ctx.call_tool("noop") + + await asyncio.sleep(0.2) + + span = uploaded_spans[0] + + # ISO format: YYYY-MM-DDTHH:MM:SS.ssssssZ + import re + + iso_pattern = r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}" + assert re.match(iso_pattern, span["start_time"]) + assert re.match(iso_pattern, span["end_time"]) diff --git a/hud/telemetry/tests/test_exporter.py b/hud/telemetry/tests/test_exporter.py new file mode 100644 index 00000000..3e74cfea --- /dev/null +++ b/hud/telemetry/tests/test_exporter.py @@ -0,0 +1,254 @@ +"""Tests for telemetry exporter with mock backend.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from hud.telemetry.exporter import ( + _do_upload, + _pending_spans, + flush, + queue_span, + shutdown, +) + + +@pytest.fixture(autouse=True) +def 
clear_pending_spans(): + """Clear pending spans before and after each test.""" + _pending_spans.clear() + yield + _pending_spans.clear() + + +class TestDoUpload: + """Tests for _do_upload function.""" + + def test_upload_success(self): + """Test successful upload.""" + with patch("hud.telemetry.exporter.make_request_sync") as mock_request: + result = _do_upload( + task_run_id="test-task-123", + spans=[{"name": "test.span", "attributes": {"task_run_id": "test-task-123"}}], + telemetry_url="https://api.hud.ai", + api_key="test-key", + ) + + assert result is True + mock_request.assert_called_once() + call_kwargs = mock_request.call_args.kwargs + assert call_kwargs["method"] == "POST" + assert "test-task-123" in call_kwargs["url"] + assert call_kwargs["api_key"] == "test-key" + assert "telemetry" in call_kwargs["json"] + + def test_upload_failure(self): + """Test upload failure handling.""" + with patch("hud.telemetry.exporter.make_request_sync") as mock_request: + mock_request.side_effect = Exception("Network error") + + result = _do_upload( + task_run_id="test-task-123", + spans=[{"name": "test.span"}], + telemetry_url="https://api.hud.ai", + api_key="test-key", + ) + + assert result is False + + +class TestQueueSpan: + """Tests for queue_span function.""" + + def test_queue_span_without_api_key(self): + """Test that spans are not queued without API key.""" + with patch("hud.settings.settings") as mock_settings: + mock_settings.api_key = None + mock_settings.telemetry_enabled = True + + queue_span({"name": "test", "attributes": {"task_run_id": "123"}}) + + assert len(_pending_spans) == 0 + + def test_queue_span_without_telemetry_enabled(self): + """Test that spans are not queued when telemetry disabled.""" + with patch("hud.settings.settings") as mock_settings: + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = False + + queue_span({"name": "test", "attributes": {"task_run_id": "123"}}) + + assert len(_pending_spans) == 0 + + def test_queue_span_without_task_run_id(self): + """Test that spans without task_run_id are ignored.""" + with patch("hud.settings.settings") as mock_settings: + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + + queue_span({"name": "test", "attributes": {}}) + + assert len(_pending_spans) == 0 + + def test_queue_span_adds_to_pending(self): + """Test that spans are added to pending list.""" + # Don't mock _do_upload so spans stay in pending + with patch("hud.settings.settings") as mock_settings: + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + + # Use a sync context (no event loop) so upload happens sync + # But we'll make it fail so span stays in pending + with patch("hud.telemetry.exporter._do_upload", return_value=False): + span = {"name": "test", "attributes": {"task_run_id": "task-123"}} + queue_span(span) + + # Span should be in pending (upload failed so not removed) + assert "task-123" in _pending_spans + assert span in _pending_spans["task-123"] + + @pytest.mark.asyncio + async def test_queue_span_uploads_async(self): + """Test that spans are uploaded via thread pool in async context.""" + uploaded_spans: list[dict[str, Any]] = [] + + def mock_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded_spans.extend(spans) + return True + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=mock_upload), 
+ ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + + span = {"name": "test.async", "attributes": {"task_run_id": "async-task"}} + queue_span(span) + + # Wait for thread pool to complete + await asyncio.sleep(0.1) + + assert len(uploaded_spans) == 1 + assert uploaded_spans[0]["name"] == "test.async" + + +class TestFlush: + """Tests for flush function.""" + + def test_flush_specific_task(self): + """Test flushing spans for specific task.""" + uploaded: list[tuple[str, list[dict[str, Any]]]] = [] + + def mock_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded.append((task_run_id, spans)) + return True + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=mock_upload), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + + # Add spans for two tasks + _pending_spans["task-1"].append({"name": "span1"}) + _pending_spans["task-2"].append({"name": "span2"}) + + # Flush only task-1 + flush("task-1") + + assert len(uploaded) == 1 + assert uploaded[0][0] == "task-1" + assert "task-1" not in _pending_spans + assert "task-2" in _pending_spans + + def test_flush_all_tasks(self): + """Test flushing all pending spans.""" + uploaded: list[tuple[str, list[dict[str, Any]]]] = [] + + def mock_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded.append((task_run_id, spans)) + return True + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=mock_upload), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + + _pending_spans["task-1"].append({"name": "span1"}) + _pending_spans["task-2"].append({"name": "span2"}) + + flush() + + assert len(uploaded) == 2 + assert len(_pending_spans) == 0 + + def test_flush_clears_without_api_key(self): + """Test that flush clears spans when no API key.""" + with patch("hud.settings.settings") as mock_settings: + mock_settings.api_key = None + mock_settings.telemetry_enabled = True + + _pending_spans["task-1"].append({"name": "span1"}) + + flush() + + assert len(_pending_spans) == 0 + + +class TestShutdown: + """Tests for shutdown function.""" + + def test_shutdown_flushes_pending(self): + """Test that shutdown flushes pending spans.""" + uploaded: list[str] = [] + + def mock_upload( + task_run_id: str, + spans: list[dict[str, Any]], + telemetry_url: str, + api_key: str, + ) -> bool: + uploaded.append(task_run_id) + return True + + with ( + patch("hud.settings.settings") as mock_settings, + patch("hud.telemetry.exporter._do_upload", side_effect=mock_upload), + ): + mock_settings.api_key = "test-key" + mock_settings.telemetry_enabled = True + mock_settings.hud_telemetry_url = "https://api.hud.ai" + + _pending_spans["shutdown-task"].append({"name": "final-span"}) + + result = shutdown(timeout=1.0) + + assert result is True + assert "shutdown-task" in uploaded diff --git a/hud/telemetry/tests/test_job.py b/hud/telemetry/tests/test_job.py deleted file mode 100644 index 4449da9f..00000000 --- a/hud/telemetry/tests/test_job.py +++ /dev/null @@ -1,555 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from unittest.mock import 
AsyncMock, patch - -import pytest - -from hud.telemetry.job import ( - Job, - _print_job_complete_url, - _print_job_url, - create_job, - get_current_job, - job, - job_decorator, -) - - -def test_job_initialization(): - """Test Job initialization with all parameters.""" - job_obj = Job( - job_id="test-id", - name="Test Job", - metadata={"key": "value"}, - dataset_link="test/dataset", - ) - - assert job_obj.id == "test-id" - assert job_obj.name == "Test Job" - assert job_obj.metadata == {"key": "value"} - assert job_obj.dataset_link == "test/dataset" - assert job_obj.status == "created" - assert isinstance(job_obj.created_at, datetime) - assert job_obj.tasks == [] - - -def test_job_initialization_minimal(): - """Test Job initialization with minimal parameters.""" - job_obj = Job(job_id="test-id", name="Test") - - assert job_obj.metadata == {} - assert job_obj.dataset_link is None - - -def test_job_add_task(): - """Test adding tasks to a job.""" - job_obj = Job(job_id="test-id", name="Test") - - job_obj.add_task("task1") - job_obj.add_task("task2") - - assert job_obj.tasks == ["task1", "task2"] - - -@pytest.mark.asyncio -async def test_job_update_status_async(): - """Test async status update.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - - await job_obj.update_status("running") - - assert job_obj.status == "running" - mock_request.assert_called_once() - call_kwargs = mock_request.call_args[1] - assert call_kwargs["method"] == "POST" - assert "test-id" in call_kwargs["url"] - assert call_kwargs["json"]["status"] == "running" - - -@pytest.mark.asyncio -async def test_job_update_status_async_with_dataset(): - """Test async status update includes dataset link.""" - job_obj = Job(job_id="test-id", name="Test", dataset_link="test/dataset") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - - await job_obj.update_status("running") - - call_kwargs = mock_request.call_args[1] - assert call_kwargs["json"]["dataset_link"] == "test/dataset" - - -@pytest.mark.asyncio -async def test_job_update_status_async_telemetry_disabled(): - """Test async status update when telemetry is disabled.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = False - - await job_obj.update_status("running") - - assert job_obj.status == "running" - mock_request.assert_not_called() - - -@pytest.mark.asyncio -async def test_job_update_status_async_error(): - """Test async status update handles errors gracefully.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - mock_request.side_effect = Exception("Network error") - - # Should not raise - 
await job_obj.update_status("running") - assert job_obj.status == "running" - - -def test_job_update_status_sync(): - """Test sync status update.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - - job_obj.update_status_sync("completed") - - assert job_obj.status == "completed" - mock_request.assert_called_once() - - -def test_job_update_status_sync_with_dataset(): - """Test sync status update includes dataset link.""" - job_obj = Job(job_id="test-id", name="Test", dataset_link="test/dataset") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - - job_obj.update_status_sync("completed") - - call_kwargs = mock_request.call_args[1] - assert call_kwargs["json"]["dataset_link"] == "test/dataset" - - -def test_job_update_status_sync_telemetry_disabled(): - """Test sync status update when telemetry is disabled.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = False - - job_obj.update_status_sync("completed") - - mock_request.assert_not_called() - - -def test_job_update_status_sync_error(): - """Test sync status update handles errors gracefully.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - mock_request.side_effect = Exception("Network error") - - # Should not raise - job_obj.update_status_sync("completed") - - -@pytest.mark.asyncio -async def test_job_log(): - """Test async log method.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - - await job_obj.log({"loss": 0.5, "accuracy": 0.95}) - - mock_request.assert_called_once() - call_kwargs = mock_request.call_args[1] - assert call_kwargs["json"]["metrics"] == {"loss": 0.5, "accuracy": 0.95} - assert "timestamp" in call_kwargs["json"] - - -@pytest.mark.asyncio -async def test_job_log_telemetry_disabled(): - """Test async log when telemetry is disabled.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = False - - await job_obj.log({"loss": 0.5}) - - mock_request.assert_not_called() - - -@pytest.mark.asyncio -async def test_job_log_error(): - """Test async log handles errors gracefully.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request", 
new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - mock_request.side_effect = Exception("Network error") - - # Should not raise - await job_obj.log({"loss": 0.5}) - - -def test_job_log_sync(): - """Test sync log method.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - - job_obj.log_sync({"loss": 0.5, "accuracy": 0.95}) - - mock_request.assert_called_once() - call_kwargs = mock_request.call_args[1] - assert call_kwargs["json"]["metrics"] == {"loss": 0.5, "accuracy": 0.95} - - -def test_job_log_sync_telemetry_disabled(): - """Test sync log when telemetry is disabled.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = False - - job_obj.log_sync({"loss": 0.5}) - - mock_request.assert_not_called() - - -def test_job_log_sync_error(): - """Test sync log handles errors gracefully.""" - job_obj = Job(job_id="test-id", name="Test") - - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("hud.telemetry.job.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - mock_settings.hud_telemetry_url = "https://test.com" - mock_request.side_effect = Exception("Network error") - - # Should not raise - job_obj.log_sync({"loss": 0.5}) - - -def test_job_repr(): - """Test Job __repr__.""" - job_obj = Job(job_id="test-id", name="Test Job") - job_obj.status = "running" - - repr_str = repr(job_obj) - assert "test-id" in repr_str - assert "Test Job" in repr_str - assert "running" in repr_str - - -def test_print_job_url_enabled(): - """Test _print_job_url when telemetry is enabled.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print") as mock_print, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - _print_job_url("job-123", "My Job") - - # Should print multiple lines (box) - assert mock_print.call_count > 0 - - -def test_print_job_url_disabled(): - """Test _print_job_url when telemetry is disabled.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print") as mock_print, - ): - mock_settings.telemetry_enabled = False - - _print_job_url("job-123", "My Job") - - mock_print.assert_not_called() - - -def test_print_job_url_no_api_key(): - """Test _print_job_url when no API key is set.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print") as mock_print, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = None - - _print_job_url("job-123", "My Job") - - mock_print.assert_not_called() - - -def test_print_job_complete_url_success(): - """Test _print_job_complete_url for successful completion.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print") as mock_print, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - _print_job_complete_url("job-123", "My Job", error_occurred=False) - - mock_print.assert_called_once() - call_str = 
str(mock_print.call_args) - assert "complete" in call_str.lower() or "✓" in call_str - - -def test_print_job_complete_url_failure(): - """Test _print_job_complete_url for failed completion.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print") as mock_print, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - _print_job_complete_url("job-123", "My Job", error_occurred=True) - - mock_print.assert_called_once() - call_str = str(mock_print.call_args) - assert "fail" in call_str.lower() or "✗" in call_str - - -def test_print_job_complete_url_disabled(): - """Test _print_job_complete_url when telemetry is disabled.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print") as mock_print, - ): - mock_settings.telemetry_enabled = False - - _print_job_complete_url("job-123", "My Job") - - mock_print.assert_not_called() - - -def test_get_current_job_none(): - """Test get_current_job when no job is active.""" - result = get_current_job() - assert result is None - - -def test_job_context_manager(): - """Test job context manager.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - with job("Test Job", {"key": "value"}) as job_obj: - assert job_obj.name == "Test Job" - assert job_obj.metadata == {"key": "value"} - assert get_current_job() == job_obj - - # After context, job should be cleared - assert get_current_job() is None - - -def test_job_context_manager_with_job_id(): - """Test job context manager with explicit job_id.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - with job("Test", job_id="my-custom-id") as job_obj: - assert job_obj.id == "my-custom-id" - - -def test_job_context_manager_with_dataset_link(): - """Test job context manager with dataset link.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - with job("Test", dataset_link="test/dataset") as job_obj: - assert job_obj.dataset_link == "test/dataset" - - -def test_job_context_manager_exception(): - """Test job context manager handles exceptions.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - with pytest.raises(ValueError), job("Test"): - raise ValueError("Test error") - - # Job should be cleared even after exception - assert get_current_job() is None - - -def test_create_job(): - """Test create_job function.""" - job_obj = create_job("Test Job", {"key": "value"}, dataset_link="test/dataset") - - assert job_obj.name == "Test Job" - assert job_obj.metadata == {"key": "value"} - assert job_obj.dataset_link == "test/dataset" - assert job_obj.id # Should have an auto-generated ID - - -def test_create_job_with_job_id(): - """Test create_job with explicit job_id.""" - job_obj = create_job("Test", job_id="custom-id") - - assert job_obj.id == "custom-id" - - -@pytest.mark.asyncio -async def test_job_decorator_async(): - """Test job_decorator on async function.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - 
mock_settings.api_key = "test_key" - - @job_decorator("test_job", model="gpt-4") - async def test_func(x: int) -> int: - return x * 2 - - result = await test_func(5) - assert result == 10 - - -def test_job_decorator_sync(): - """Test job_decorator on sync function.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - @job_decorator("test_job", model="gpt-4") - def test_func(x: int) -> int: - return x * 2 - - result = test_func(5) - assert result == 10 - - -@pytest.mark.asyncio -async def test_job_decorator_async_default_name(): - """Test job_decorator uses function name as default.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - @job_decorator() - async def my_function(): - return "success" - - result = await my_function() - assert result == "success" - - -def test_job_decorator_sync_default_name(): - """Test job_decorator sync uses function name as default.""" - with ( - patch("hud.telemetry.job.settings") as mock_settings, - patch("builtins.print"), - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test_key" - - @job_decorator() - def my_function(): - return "success" - - result = my_function() - assert result == "success" diff --git a/hud/telemetry/tests/test_replay.py b/hud/telemetry/tests/test_replay.py deleted file mode 100644 index 507c4e4a..00000000 --- a/hud/telemetry/tests/test_replay.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Tests for telemetry replay functionality.""" - -from __future__ import annotations - -from unittest.mock import patch - -from hud.telemetry.replay import clear_trace, get_trace - - -class TestReplayAPI: - """Tests for replay API functions.""" - - def test_get_trace_calls_internal(self): - """Test that get_trace calls the internal _get_trace function.""" - with patch("hud.telemetry.replay._get_trace") as mock_get: - mock_get.return_value = None - - result = get_trace("test-task-id") - - mock_get.assert_called_once_with("test-task-id") - assert result is None - - def test_clear_trace_calls_internal(self): - """Test that clear_trace calls the internal _clear_trace function.""" - with patch("hud.telemetry.replay._clear_trace") as mock_clear: - clear_trace("test-task-id") - - mock_clear.assert_called_once_with("test-task-id") - - def test_get_trace_with_data(self): - """Test get_trace with mock data.""" - mock_trace = {"trace": [{"step": 1}], "task_run_id": "test-123"} - - with patch("hud.telemetry.replay._get_trace") as mock_get: - mock_get.return_value = mock_trace - - result = get_trace("test-123") - - assert result == mock_trace - mock_get.assert_called_once_with("test-123") diff --git a/hud/telemetry/tests/test_trace.py b/hud/telemetry/tests/test_trace.py deleted file mode 100644 index b2835edf..00000000 --- a/hud/telemetry/tests/test_trace.py +++ /dev/null @@ -1,241 +0,0 @@ -"""Tests for telemetry trace functionality.""" - -from __future__ import annotations - -from unittest.mock import AsyncMock, patch - -import pytest - -from hud.telemetry.trace import Trace, trace - - -class TestTraceAPI: - """Tests for trace API function.""" - - def test_trace_with_disabled_telemetry_and_no_api_key(self): - """Test trace behavior when telemetry is disabled and no API key.""" - # Mock settings to disable telemetry and remove API key - mock_settings = type("Settings", (), {"telemetry_enabled": 
False, "api_key": None})() - - with ( - patch("hud.settings.get_settings", return_value=mock_settings), - patch("hud.telemetry.trace.OtelTrace") as mock_otel_trace, - ): - mock_otel_trace.return_value.__enter__.return_value = "1234567890" - - with trace("test-trace") as task_run_id: - # Should use placeholder ID for custom backends - assert len(task_run_id.id) == 36 - - def test_trace_with_enabled_telemetry_and_api_key(self): - """Test trace behavior when telemetry is enabled with API key.""" - mock_settings = type("Settings", (), {"telemetry_enabled": True, "api_key": "test-key"})() - - with ( - patch("hud.settings.get_settings", return_value=mock_settings), - patch("hud.telemetry.trace.OtelTrace") as mock_otel_trace, - patch("hud.telemetry.trace.uuid.uuid4") as mock_uuid, - ): - mock_uuid.return_value = "mock-uuid-123" - mock_otel_trace.return_value.__enter__.return_value = "mock-uuid-123" - - with trace("test-trace") as task_run_id: - # Should use generated UUID - assert task_run_id.id == "mock-uuid-123" - - def test_trace_with_no_api_key(self): - """Test trace behavior with no API key (custom backend scenario).""" - mock_settings = type( - "Settings", - (), - { - "telemetry_enabled": True, # Enabled but no API key - "api_key": None, - }, - )() - - with ( - patch("hud.settings.get_settings", return_value=mock_settings), - patch("hud.telemetry.trace.OtelTrace") as mock_otel_trace, - ): - mock_otel_trace.return_value.__enter__.return_value = "custom-otlp-trace" - - with trace("test-trace") as task_run_id: - # In absence of HUD API key, ID should still be a string - assert isinstance(task_run_id.id, str) - - def test_trace_with_job_id(self): - """Test trace with job_id parameter.""" - mock_settings = type("Settings", (), {"telemetry_enabled": True, "api_key": "test-key"})() - - with ( - patch("hud.settings.get_settings", return_value=mock_settings), - patch("hud.telemetry.trace.OtelTrace") as mock_otel_trace, - trace("test-trace", job_id="job-123") as trace_obj, - ): - assert trace_obj.job_id == "job-123" - - # Check OtelTrace was called with job_id - call_kwargs = mock_otel_trace.call_args[1] - assert call_kwargs["job_id"] == "job-123" - - def test_trace_with_task_id(self): - """Test trace with task_id parameter.""" - mock_settings = type("Settings", (), {"telemetry_enabled": True, "api_key": "test-key"})() - - with ( - patch("hud.settings.get_settings", return_value=mock_settings), - patch("hud.telemetry.trace.OtelTrace"), - trace("test-trace", task_id="task-456") as trace_obj, - ): - assert trace_obj.task_id == "task-456" - - def test_trace_with_attributes(self): - """Test trace with custom attributes.""" - mock_settings = type("Settings", (), {"telemetry_enabled": True, "api_key": "test-key"})() - - with ( - patch("hud.settings.get_settings", return_value=mock_settings), - patch("hud.telemetry.trace.OtelTrace") as mock_otel_trace, - trace("test-trace", attrs={"custom": "value"}), - ): - # Check OtelTrace was called with attributes - call_kwargs = mock_otel_trace.call_args[1] - assert call_kwargs["attributes"] == {"custom": "value"} - - def test_trace_non_root(self): - """Test trace with root=False.""" - mock_settings = type("Settings", (), {"telemetry_enabled": True, "api_key": "test-key"})() - - with ( - patch("hud.settings.get_settings", return_value=mock_settings), - patch("hud.telemetry.trace.OtelTrace") as mock_otel_trace, - trace("test-trace", root=False), - ): - # Check OtelTrace was called with is_root=False - call_kwargs = mock_otel_trace.call_args[1] - assert 
call_kwargs["is_root"] is False - - -class TestTraceClass: - """Tests for Trace class.""" - - def test_trace_initialization(self): - """Test Trace initialization.""" - trace_obj = Trace( - trace_id="test-id", - name="Test Trace", - job_id="job-123", - task_id="task-456", - ) - - assert trace_obj.id == "test-id" - assert trace_obj.name == "Test Trace" - assert trace_obj.job_id == "job-123" - assert trace_obj.task_id == "task-456" - assert trace_obj.created_at is not None - - @pytest.mark.asyncio - async def test_trace_log(self): - """Test Trace async log method.""" - trace_obj = Trace("test-id", "Test") - - with ( - patch("hud.telemetry.trace.settings") as mock_settings, - patch("hud.telemetry.trace.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_settings.hud_telemetry_url = "https://test.com" - - await trace_obj.log({"metric": 1.0}) - - mock_request.assert_called_once() - call_kwargs = mock_request.call_args[1] - assert call_kwargs["json"]["metrics"] == {"metric": 1.0} - - @pytest.mark.asyncio - async def test_trace_log_telemetry_disabled(self): - """Test Trace log when telemetry is disabled.""" - trace_obj = Trace("test-id", "Test") - - with ( - patch("hud.telemetry.trace.settings") as mock_settings, - patch("hud.telemetry.trace.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = False - - await trace_obj.log({"metric": 1.0}) - - mock_request.assert_not_called() - - @pytest.mark.asyncio - async def test_trace_log_error(self): - """Test Trace log handles errors gracefully.""" - trace_obj = Trace("test-id", "Test") - - with ( - patch("hud.telemetry.trace.settings") as mock_settings, - patch("hud.telemetry.trace.make_request", new_callable=AsyncMock) as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_settings.hud_telemetry_url = "https://test.com" - mock_request.side_effect = Exception("Network error") - - # Should not raise - await trace_obj.log({"metric": 1.0}) - - def test_trace_log_sync(self): - """Test Trace sync log method.""" - trace_obj = Trace("test-id", "Test") - - with ( - patch("hud.telemetry.trace.settings") as mock_settings, - patch("hud.telemetry.trace.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_settings.hud_telemetry_url = "https://test.com" - - trace_obj.log_sync({"metric": 1.0}) - - mock_request.assert_called_once() - - def test_trace_log_sync_telemetry_disabled(self): - """Test Trace sync log when telemetry is disabled.""" - trace_obj = Trace("test-id", "Test") - - with ( - patch("hud.telemetry.trace.settings") as mock_settings, - patch("hud.telemetry.trace.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = False - - trace_obj.log_sync({"metric": 1.0}) - - mock_request.assert_not_called() - - def test_trace_log_sync_error(self): - """Test Trace sync log handles errors gracefully.""" - trace_obj = Trace("test-id", "Test") - - with ( - patch("hud.telemetry.trace.settings") as mock_settings, - patch("hud.telemetry.trace.make_request_sync") as mock_request, - ): - mock_settings.telemetry_enabled = True - mock_settings.api_key = "test-key" - mock_settings.hud_telemetry_url = "https://test.com" - mock_request.side_effect = Exception("Network error") - - # Should not raise - trace_obj.log_sync({"metric": 1.0}) - - def test_trace_repr(self): - """Test Trace __repr__.""" - 
trace_obj = Trace("test-id", "Test Trace") - - repr_str = repr(trace_obj) - assert "test-id" in repr_str - assert "Test Trace" in repr_str diff --git a/hud/telemetry/trace.py b/hud/telemetry/trace.py deleted file mode 100644 index 2aa19080..00000000 --- a/hud/telemetry/trace.py +++ /dev/null @@ -1,166 +0,0 @@ -"""User-facing trace context manager for HUD telemetry. - -This module provides the simple trace() API that users interact with. -The actual OpenTelemetry implementation is in hud.otel. -""" - -from __future__ import annotations - -import logging -import uuid -from contextlib import contextmanager -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any - -from hud.otel import configure_telemetry -from hud.otel import trace as OtelTrace -from hud.settings import settings -from hud.shared import make_request, make_request_sync - -if TYPE_CHECKING: - from collections.abc import Generator - -logger = logging.getLogger(__name__) - -__all__ = ["Trace", "trace"] - - -class Trace: - """A trace represents a single task execution with telemetry.""" - - def __init__( - self, - trace_id: str, - name: str, - job_id: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - ) -> None: - self.id = trace_id - self.name = name - self.job_id = job_id - self.task_id = task_id - self.group_id = group_id - self.created_at = datetime.now(UTC) - - async def log(self, metrics: dict[str, Any]) -> None: - """Log metrics to this trace. - - Args: - metrics: Dictionary of metric name to value pairs - - Example: - await trace.log({"step": 1, "loss": 0.5, "accuracy": 0.92}) - """ - if settings.telemetry_enabled: - try: - await make_request( - method="POST", - url=f"{settings.hud_telemetry_url}/traces/{self.id}/log", - json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()}, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to log metrics to trace: %s", e) - - def log_sync(self, metrics: dict[str, Any]) -> None: - """Synchronously log metrics to this trace. - - Args: - metrics: Dictionary of metric name to value pairs - - Example: - trace.log_sync({"step": 1, "loss": 0.5, "accuracy": 0.92}) - """ - if settings.telemetry_enabled: - try: - make_request_sync( - method="POST", - url=f"{settings.hud_telemetry_url}/traces/{self.id}/log", - json={"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()}, - api_key=settings.api_key, - ) - except Exception as e: - logger.warning("Failed to log metrics to trace: %s", e) - - def __repr__(self) -> str: - return f"Trace(id={self.id!r}, name={self.name!r})" - - -@contextmanager -def trace( - name: str = "Test task from hud", - *, - root: bool = True, - attrs: dict[str, Any] | None = None, - job_id: str | None = None, - task_id: str | None = None, - group_id: str | None = None, - trace_id: str | None = None, -) -> Generator[Trace, None, None]: - """Start a HUD trace context for telemetry tracking. - - A unique task_run_id is automatically generated for each trace unless provided. 
- - Args: - name: Descriptive name for this trace/task - root: Whether this is a root trace (updates task status) - attrs: Additional attributes to attach to the trace - job_id: Optional job ID to associate with this trace - task_id: Optional task ID (for custom task identifiers) - group_id: Optional group ID to associate with this trace - trace_id: Optional trace ID (auto-generated if not provided) - - Yields: - Trace: The trace object with logging capabilities - - Example: - >>> import hud - >>> with hud.trace("My Task") as trace: - ... do_work() - ... trace.log_sync({"step": 1, "progress": 0.5}) - >>> # For async code, use async_trace - >>> async with hud.async_trace("Async Task") as trace: - ... await do_async_work() - ... await trace.log({"loss": 0.23}) - - Note: - This is a synchronous context manager that uses blocking HTTP calls. - For async code, use `hud.async_trace()` instead. - """ - # Ensure telemetry is configured - configure_telemetry() - - # Use provided trace_id or generate one - if trace_id: - task_run_id = trace_id - else: - # Only generate task_run_id if using HUD backend - # For custom OTLP backends, we don't need it - from hud.settings import get_settings - - settings = get_settings() - - if settings.telemetry_enabled and settings.api_key: - task_run_id = str(uuid.uuid4()) - else: - # Use a placeholder for custom backends - logger.warning( - "HUD API key is not set, using a placeholder for the task run ID. If this looks wrong, check your API key." # noqa: E501 - ) - task_run_id = str(uuid.uuid4()) - - # Create trace object - trace_obj = Trace(task_run_id, name, job_id, task_id, group_id) - - # Delegate to OpenTelemetry implementation - with OtelTrace( - task_run_id, - is_root=root, - span_name=name, - attributes=attrs or {}, - job_id=job_id, - task_id=task_id, - group_id=group_id, - ): - yield trace_obj diff --git a/hud/telemetry/utils.py b/hud/telemetry/utils.py deleted file mode 100644 index a63ecfaf..00000000 --- a/hud/telemetry/utils.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Telemetry utility functions for managing trace and span lifecycle.""" - -from __future__ import annotations - -import logging - -logger = logging.getLogger(__name__) - - -async def flush_telemetry() -> None: - """Flush OpenTelemetry span processor to export buffered spans immediately. - - Called automatically by async_trace (standalone) and async_job on exit. - - Example: - >>> # Custom evaluation loop - >>> for task in tasks: - ... async with hud.async_trace(task.name): - ... 
await process(task) - >>> # Spans already flushed by each async_trace - """ - from hud.otel.config import is_telemetry_configured - from hud.utils import hud_console - - logger.debug("Flushing telemetry spans...") - if not is_telemetry_configured(): - return - - try: - from opentelemetry import trace - from opentelemetry.sdk.trace import TracerProvider - - provider = trace.get_tracer_provider() - if isinstance(provider, TracerProvider): - success = provider.force_flush(timeout_millis=5000) - if success: - hud_console.info("✓ Telemetry uploaded successfully") - logger.debug("OpenTelemetry spans flushed successfully") - else: - logger.debug("OpenTelemetry flush timed out (will export on exit)") - except Exception as e: - logger.debug("Failed to flush OpenTelemetry: %s", e) diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 94ec4c51..58b013d6 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -7,10 +7,8 @@ import pytest -from hud.datasets import ( - LegacyTask, - run_dataset, -) +from hud.datasets import run_dataset +from hud.types import LegacyTask from hud.types import MCPToolCall from hud.utils.tasks import save_tasks @@ -214,7 +212,7 @@ async def test_run_dataset_from_source_string(self): mock_agent.run.return_value = Trace(reward=1.0, done=True) mock_env = {"name": "test"} - mock_tasks = [Task(env=mock_env, scenario="loaded")] + mock_tasks = [Task(env=mock_env, scenario="loaded")] # type: ignore[arg-type] mock_ctx = AsyncMock() mock_ctx.results = None diff --git a/hud/tests/test_types.py b/hud/tests/test_types.py index abd052f7..3c275ae1 100644 --- a/hud/tests/test_types.py +++ b/hud/tests/test_types.py @@ -249,45 +249,3 @@ def test_trace_num_messages(): """Test Trace num_messages property.""" trace = Trace(messages=[{"role": "user"}, {"role": "assistant"}]) assert trace.num_messages == 2 - - -def test_trace_populate_from_context(): - """Test Trace.populate_from_context with no context.""" - trace = Trace() - # Should not raise when no context - trace.populate_from_context() - assert len(trace.trace) == 0 - - -def test_trace_populate_from_context_with_context(): - """Test Trace.populate_from_context with active context.""" - with ( - patch("hud.otel.context.get_current_task_run_id") as mock_get_id, - patch("hud.telemetry.replay.get_trace") as mock_get_trace, - ): - mock_get_id.return_value = "test_run_id" - mock_trace = MagicMock() - mock_trace.trace = [TraceStep(category="mcp")] - mock_get_trace.return_value = mock_trace - - trace = Trace() - trace.populate_from_context() - - assert len(trace.trace) == 1 - mock_get_id.assert_called_once() - mock_get_trace.assert_called_once_with("test_run_id") - - -def test_trace_populate_from_context_no_trace(): - """Test Trace.populate_from_context when get_trace returns None.""" - with ( - patch("hud.otel.context.get_current_task_run_id") as mock_get_id, - patch("hud.telemetry.replay.get_trace") as mock_get_trace, - ): - mock_get_id.return_value = "test_run_id" - mock_get_trace.return_value = None - - trace = Trace() - trace.populate_from_context() - - assert len(trace.trace) == 0 diff --git a/hud/tools/grounding/grounder.py b/hud/tools/grounding/grounder.py index fa593831..31b7b2be 100644 --- a/hud/tools/grounding/grounder.py +++ b/hud/tools/grounding/grounder.py @@ -5,14 +5,16 @@ import base64 import io import json +import logging import re from openai import AsyncOpenAI -from opentelemetry import trace from hud import instrument from hud.tools.grounding.config 
import GrounderConfig # noqa: TC001 +logger = logging.getLogger(__name__) + class Grounder: """Grounder that uses AsyncOpenAI to call vLLM or other model endpoints for visual grounding. @@ -247,12 +249,7 @@ async def predict_click( # Extract response text response_text = response.choices[0].message.content - - # Manually record the raw response in the span - span = trace.get_current_span() - if span and span.is_recording(): - span.set_attribute("grounder.raw_response", json.dumps(response.model_dump())) - span.set_attribute("grounder.attempt", attempt + 1) + logger.debug("Grounder attempt %d response: %s", attempt + 1, response_text) # Parse coordinates from response if response_text is None: @@ -277,26 +274,16 @@ async def predict_click( y = max(0, min(y, original_size[1] - 1)) pixel_coords = (x, y) - # Record successful grounding in span - span = trace.get_current_span() - if span and span.is_recording(): - span.set_attribute("grounder.success", True) - span.set_attribute( - "grounder.final_coords", f"{pixel_coords[0]},{pixel_coords[1]}" - ) - span.set_attribute("grounder.total_attempts", attempt + 1) - + logger.debug( + "Grounder success: coords=%s after %d attempts", + pixel_coords, + attempt + 1, + ) return pixel_coords except Exception: if attempt < max_retries - 1: continue - # Record failure in span - span = trace.get_current_span() - if span and span.is_recording(): - span.set_attribute("grounder.success", False) - span.set_attribute("grounder.total_attempts", max_retries) - span.set_attribute("grounder.failure_reason", "All attempts exhausted") - + logger.debug("Grounder failed after %d attempts", max_retries) return None diff --git a/hud/types.py b/hud/types.py index f982b60a..73f90b54 100644 --- a/hud/types.py +++ b/hud/types.py @@ -31,7 +31,9 @@ class AgentType(str, Enum): @property def cls(self) -> type: - from hud.agents import ClaudeAgent, GeminiAgent, OpenAIAgent, OperatorAgent + from hud.agents import OpenAIAgent, OperatorAgent + from hud.agents.claude import ClaudeAgent + from hud.agents.gemini import GeminiAgent from hud.agents.gemini_cua import GeminiCUAAgent from hud.agents.openai_chat import OpenAIChatAgent @@ -346,6 +348,27 @@ class TraceStep(BaseModel): model_config = ConfigDict(populate_by_name=True, extra="allow") +class HudSpan(BaseModel): + """A telemetry span ready for export to HUD API.""" + + name: str + trace_id: str = Field(pattern=r"^[0-9a-fA-F]{32}$") + span_id: str = Field(pattern=r"^[0-9a-fA-F]{16}$") + parent_span_id: str | None = Field(default=None, pattern=r"^[0-9a-fA-F]{16}$") + + start_time: str # ISO format + end_time: str # ISO format + + status_code: str # "UNSET", "OK", "ERROR" + status_message: str | None = None + + attributes: TraceStep + exceptions: list[dict[str, Any]] | None = None + internal_type: str | None = None + + model_config = ConfigDict(extra="forbid") + + class Trace(BaseModel): """Unified result from agent execution (task or prompt). @@ -381,31 +404,22 @@ def num_messages(self) -> int: def append(self, step: TraceStep) -> None: self.trace.append(step) - def populate_from_context(self) -> None: - """Populate trace steps from the current trace context if available. - - This checks if we're executing within a hud.trace() context and - automatically populates the trace field with collected steps. 
- """ - from hud.otel.context import get_current_task_run_id - from hud.telemetry.replay import get_trace - - task_run_id = get_current_task_run_id() - if task_run_id: - collected_trace = get_trace(task_run_id) - if collected_trace: - self.trace = collected_trace.trace - # Re-export Task for backwards compatibility (after module defs to avoid circular import) from hud.eval.task import Task # noqa: E402 +# Type alias for functions that accept v5 Task, v4 LegacyTask, or raw dicts +TaskInput = Task | LegacyTask | dict[str, Any] + __all__ = [ "AgentResponse", "AgentType", + "HudSpan", + "LegacyTask", "MCPToolCall", "MCPToolResult", "Task", + "TaskInput", "Trace", "TraceStep", ] diff --git a/hud/utils/mcp.py b/hud/utils/mcp.py index e9335d54..859cb5b8 100644 --- a/hud/utils/mcp.py +++ b/hud/utils/mcp.py @@ -5,8 +5,6 @@ from pydantic import BaseModel, Field -from hud.settings import settings - logger = logging.getLogger(__name__) @@ -49,53 +47,3 @@ def patch_mcp_config(mcp_config: dict[str, dict[str, Any]], patch: MCPConfigPatc meta.setdefault(key, value) -def setup_hud_telemetry( - mcp_config: dict[str, dict[str, Any]], auto_trace: bool = True -) -> Any | None: - """Setup telemetry for hud servers. - - Returns: - The auto-created trace context manager if one was created, None otherwise. - Caller is responsible for exiting the context manager. - """ - if mcp_config is None: - raise ValueError("Please run initialize() before setting up client-side telemetry") - - # Check if there are any HUD servers to setup telemetry for - has_hud_servers = any( - _is_hud_server(server_cfg.get("url", "")) for server_cfg in mcp_config.values() - ) - - # If no HUD servers, no need for telemetry setup - if not has_hud_servers: - return None - - from hud.otel import get_current_task_run_id - from hud.telemetry.trace import trace - - run_id = get_current_task_run_id() - auto_trace_cm = None - - if not run_id and auto_trace: - # Start an auto trace and capture its ID for headers/metadata - auto_trace_cm = trace("My Trace") - _trace_obj = auto_trace_cm.__enter__() - try: - run_id = getattr(_trace_obj, "id", None) or str(_trace_obj) - except Exception: # pragma: no cover - fallback shouldn't fail lint - run_id = None - - # Patch HUD servers with run-id (works whether auto or user trace) - if run_id: - patch_mcp_config( - mcp_config, - MCPConfigPatch(headers={"Run-Id": run_id}, meta={"run_id": run_id}), - ) - - if settings.api_key: - patch_mcp_config( - mcp_config, - MCPConfigPatch(headers={"Authorization": f"Bearer {settings.api_key}"}), - ) - - return auto_trace_cm diff --git a/hud/utils/tasks.py b/hud/utils/tasks.py index bf44b798..ca4f4fab 100644 --- a/hud/utils/tasks.py +++ b/hud/utils/tasks.py @@ -1,133 +1,9 @@ from __future__ import annotations import json -from pathlib import Path from typing import Any from hud.types import LegacyTask -from hud.utils.hud_console import HUDConsole - -hud_console = HUDConsole() - - -def load_tasks( - tasks_input: str | list[dict], *, raw: bool = False -) -> list[LegacyTask] | list[dict]: - """Load tasks from various sources. 
- - Args: - tasks_input: Either: - - Path to a JSON file (array of tasks) - - Path to a JSONL file (one task per line) - - HuggingFace dataset name (format: "username/dataset" or "username/dataset:split") - - List of task dictionaries - raw: If True, return raw dicts without validation or env substitution - - Returns: - - If raw=False (default): list[LegacyTask] - - If raw=True: list[dict] - """ - tasks: list[LegacyTask] | list[dict] = [] - - if isinstance(tasks_input, list): - # Direct list of task dicts - hud_console.info(f"Loading {len(tasks_input)} tasks from provided list") - if raw: - return [item for item in tasks_input if isinstance(item, dict)] - for item in tasks_input: - task = LegacyTask(**item) - tasks.append(task) - - elif isinstance(tasks_input, str): - # Check if it's a file path - if Path(tasks_input).exists(): - file_path = Path(tasks_input) - - with open(file_path, encoding="utf-8") as f: - # Handle JSON files (array of tasks) - if file_path.suffix.lower() == ".json": - data = json.load(f) - if not isinstance(data, list): - raise ValueError( - f"JSON file must contain an array of tasks, got {type(data)}" - ) - if raw: - return [item for item in data if isinstance(item, dict)] - for item in data: - task = LegacyTask(**item) - tasks.append(task) - - # Handle JSONL files (one task per line) - else: - raw_items: list[dict] = [] - for line in f: - line = line.strip() - if not line: - continue - item = json.loads(line) - if isinstance(item, list): - raw_items.extend([it for it in item if isinstance(it, dict)]) - elif isinstance(item, dict): - raw_items.append(item) - else: - raise ValueError( - f"Invalid JSONL format: expected dict or list of dicts, got {type(item)}" # noqa: E501 - ) - if raw: - return raw_items - for it in raw_items: - task = LegacyTask(**it) - tasks.append(task) - - # Check if it's a HuggingFace dataset - elif "/" in tasks_input: - hud_console.info(f"Loading tasks from HuggingFace dataset: {tasks_input}") - try: - from datasets import load_dataset - - # Parse dataset name and optional split - if ":" in tasks_input: - dataset_name, split = tasks_input.split(":", 1) - else: - dataset_name = tasks_input - split = "train" # Default split - - dataset = load_dataset(dataset_name, split=split) - - # Convert dataset rows to Task objects - raw_rows: list[dict] = [] - for item in dataset: - if not isinstance(item, dict): - raise ValueError( - f"Invalid HuggingFace dataset: expected dict, got {type(item)}" - ) - if not item["mcp_config"] or not item["prompt"]: - raise ValueError( - f"Invalid HuggingFace dataset: expected mcp_config and prompt, got {item}" # noqa: E501 - ) - raw_rows.append(item) - if raw: - return raw_rows - for row in raw_rows: - task = LegacyTask(**row) - tasks.append(task) - - except ImportError as e: - raise ImportError( - "Please install 'datasets' to load from HuggingFace: uv pip install datasets" - ) from e - except Exception as e: - raise ValueError(f"Failed to load HuggingFace dataset '{tasks_input}': {e}") from e - - else: - raise ValueError( - f"Invalid tasks input: '{tasks_input}' is neither a file path nor a HuggingFace dataset" # noqa: E501 - ) - - else: - raise TypeError(f"tasks_input must be str or list, got {type(tasks_input)}") - - return tasks def save_tasks( @@ -136,8 +12,7 @@ def save_tasks( fields: list[str] | None = None, **kwargs: Any, ) -> None: - """ - Save data to a HuggingFace dataset with JSON string serialization. + """Save data to a HuggingFace dataset with JSON string serialization. 
Complex fields (dicts, lists) are serialized as JSON strings to keep schemas clean and avoid null-value pollution when uploaded to the Hub. @@ -148,7 +23,6 @@ def save_tasks( fields: Optional subset of fields to persist. Defaults to all keys per task. **kwargs: Extra kwargs forwarded to `Dataset.push_to_hub`. """ - if tasks and isinstance(tasks[0], LegacyTask): raise ValueError( "save_tasks expects dictionaries, not LegacyTask objects. " diff --git a/hud/utils/tests/test_mcp.py b/hud/utils/tests/test_mcp.py index 1af6daed..9be367c7 100644 --- a/hud/utils/tests/test_mcp.py +++ b/hud/utils/tests/test_mcp.py @@ -4,7 +4,7 @@ import pytest -from hud.utils.mcp import MCPConfigPatch, patch_mcp_config, setup_hud_telemetry +from hud.utils.mcp import MCPConfigPatch, patch_mcp_config class TestPatchMCPConfig: @@ -85,26 +85,3 @@ def test_patch_meta_preserves_existing(self): # Existing meta should be preserved, new one added assert mcp_config["test_server"]["meta"]["existing_key"] == "existing_value" assert mcp_config["test_server"]["meta"]["test_key"] == "test_value" - - -class TestSetupHUDTelemetry: - """Tests for setup_hud_telemetry function.""" - - def test_empty_config_returns_none(self): - """Test that empty config returns None (no servers to set up telemetry for).""" - result = setup_hud_telemetry({}) - assert result is None - - def test_none_config_raises_error(self): - """Test that None config raises ValueError.""" - with pytest.raises( - ValueError, match="Please run initialize\\(\\) before setting up client-side telemetry" - ): - setup_hud_telemetry(None) # type: ignore[arg-type] - - def test_valid_config_returns_none_when_no_hud_servers(self): - """Test that valid config with no HUD servers returns None.""" - mcp_config = {"test_server": {"url": "http://example.com"}} - - result = setup_hud_telemetry(mcp_config) - assert result is None diff --git a/hud/utils/tests/test_tasks.py b/hud/utils/tests/test_tasks.py deleted file mode 100644 index 18bc778c..00000000 --- a/hud/utils/tests/test_tasks.py +++ /dev/null @@ -1,356 +0,0 @@ -from __future__ import annotations - -import json -import tempfile -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest - -from hud.types import LegacyTask -from hud.utils.tasks import load_tasks, save_tasks - - -def test_load_tasks_from_list(): - """Test loading tasks from a list of dictionaries.""" - task_dicts = [ - {"id": "1", "prompt": "Test task 1", "mcp_config": {}}, - {"id": "2", "prompt": "Test task 2", "mcp_config": {}}, - ] - - tasks = load_tasks(task_dicts) - - assert len(tasks) == 2 - assert all(isinstance(t, LegacyTask) for t in tasks) - assert tasks[0].prompt == "Test task 1" # type: ignore - assert tasks[1].prompt == "Test task 2" # type: ignore - - -def test_load_tasks_from_list_raw(): - """Test loading tasks from a list in raw mode.""" - task_dicts = [ - {"id": "1", "prompt": "Test task 1", "mcp_config": {}}, - {"id": "2", "prompt": "Test task 2", "mcp_config": {}}, - ] - - tasks = load_tasks(task_dicts, raw=True) - - assert len(tasks) == 2 - assert all(isinstance(t, dict) for t in tasks) - assert tasks[0]["prompt"] == "Test task 1" # type: ignore - - -def test_load_tasks_from_json_file(): - """Test loading tasks from a JSON file.""" - task_dicts = [ - {"id": "1", "prompt": "Test task 1", "mcp_config": {}}, - {"id": "2", "prompt": "Test task 2", "mcp_config": {}}, - ] - - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f: - json.dump(task_dicts, f) - temp_path = f.name - - 
try: - tasks = load_tasks(temp_path) - - assert len(tasks) == 2 - assert all(isinstance(t, LegacyTask) for t in tasks) - assert tasks[0].prompt == "Test task 1" # type: ignore - finally: - Path(temp_path).unlink() - - -def test_load_tasks_from_json_file_raw(): - """Test loading tasks from a JSON file in raw mode.""" - task_dicts = [ - {"id": "1", "prompt": "Test task 1", "mcp_config": {}}, - {"id": "2", "prompt": "Test task 2", "mcp_config": {}}, - ] - - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f: - json.dump(task_dicts, f) - temp_path = f.name - - try: - tasks = load_tasks(temp_path, raw=True) - - assert len(tasks) == 2 - assert all(isinstance(t, dict) for t in tasks) - finally: - Path(temp_path).unlink() - - -def test_load_tasks_from_jsonl_file(): - """Test loading tasks from a JSONL file.""" - task_dicts = [ - {"id": "1", "prompt": "Test task 1", "mcp_config": {}}, - {"id": "2", "prompt": "Test task 2", "mcp_config": {}}, - ] - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".jsonl", delete=False, encoding="utf-8" - ) as f: - for task_dict in task_dicts: - f.write(json.dumps(task_dict) + "\n") - temp_path = f.name - - try: - tasks = load_tasks(temp_path) - - assert len(tasks) == 2 - assert all(isinstance(t, LegacyTask) for t in tasks) - assert tasks[0].prompt == "Test task 1" # type: ignore - finally: - Path(temp_path).unlink() - - -def test_load_tasks_from_jsonl_file_with_empty_lines(): - """Test loading tasks from a JSONL file with empty lines.""" - task_dicts = [ - {"id": "1", "prompt": "Test task 1", "mcp_config": {}}, - {"id": "2", "prompt": "Test task 2", "mcp_config": {}}, - ] - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".jsonl", delete=False, encoding="utf-8" - ) as f: - f.write(json.dumps(task_dicts[0]) + "\n") - f.write("\n") # Empty line - f.write(json.dumps(task_dicts[1]) + "\n") - temp_path = f.name - - try: - tasks = load_tasks(temp_path) - - assert len(tasks) == 2 - assert all(isinstance(t, LegacyTask) for t in tasks) - finally: - Path(temp_path).unlink() - - -def test_load_tasks_from_jsonl_file_with_list(): - """Test loading tasks from a JSONL file where a line contains a list.""" - task_dict = {"id": "1", "prompt": "Test task 1", "mcp_config": {}} - - with tempfile.NamedTemporaryFile( - mode="w", suffix=".jsonl", delete=False, encoding="utf-8" - ) as f: - f.write(json.dumps([task_dict, task_dict]) + "\n") - temp_path = f.name - - try: - tasks = load_tasks(temp_path) - - assert len(tasks) == 2 - assert all(isinstance(t, LegacyTask) for t in tasks) - finally: - Path(temp_path).unlink() - - -def test_load_tasks_json_not_array_error(): - """Test that loading from JSON file with non-array raises error.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f: - json.dump({"not": "an array"}, f) - temp_path = f.name - - try: - with pytest.raises(ValueError, match="JSON file must contain an array"): - load_tasks(temp_path) - finally: - Path(temp_path).unlink() - - -def test_load_tasks_invalid_jsonl_format(): - """Test that loading from JSONL with invalid format raises error.""" - with tempfile.NamedTemporaryFile( - mode="w", suffix=".jsonl", delete=False, encoding="utf-8" - ) as f: - f.write(json.dumps("invalid") + "\n") - temp_path = f.name - - try: - with pytest.raises(ValueError, match="Invalid JSONL format"): - load_tasks(temp_path) - finally: - Path(temp_path).unlink() - - -def test_load_tasks_invalid_input_type(): - """Test that invalid input type raises 
TypeError.""" - with pytest.raises(TypeError, match="tasks_input must be str or list"): - load_tasks(123) # type: ignore - - -def test_load_tasks_nonexistent_file(): - """Test that loading from nonexistent file raises error.""" - with pytest.raises(ValueError, match="neither a file path nor a HuggingFace dataset"): - load_tasks("nonexistent_file_without_slash") - - -def test_save_tasks_basic(): - """Test basic save_tasks functionality.""" - tasks = [ - {"id": "1", "prompt": "test", "mcp_config": {"key": "value"}}, - {"id": "2", "prompt": "test2", "mcp_config": {"key2": "value2"}}, - ] - - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - save_tasks(tasks, "test/repo") - - mock_dataset_class.from_list.assert_called_once() - call_args = mock_dataset_class.from_list.call_args[0][0] - assert len(call_args) == 2 - # Check that mcp_config was JSON serialized - assert isinstance(call_args[0]["mcp_config"], str) - mock_dataset.push_to_hub.assert_called_once_with("test/repo") - - -def test_save_tasks_with_specific_fields(): - """Test save_tasks with specific fields.""" - tasks = [ - {"id": "1", "prompt": "test", "mcp_config": {"key": "value"}, "extra": "data"}, - ] - - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - save_tasks(tasks, "test/repo", fields=["id", "prompt"]) - - call_args = mock_dataset_class.from_list.call_args[0][0] - assert "id" in call_args[0] - assert "prompt" in call_args[0] - assert "extra" not in call_args[0] - - -def test_save_tasks_with_list_field(): - """Test save_tasks serializes list fields.""" - tasks = [ - {"id": "1", "tags": ["tag1", "tag2"], "count": 5}, - ] - - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - save_tasks(tasks, "test/repo") - - call_args = mock_dataset_class.from_list.call_args[0][0] - # List should be JSON serialized - assert isinstance(call_args[0]["tags"], str) - assert '"tag1"' in call_args[0]["tags"] - - -def test_save_tasks_with_primitive_types(): - """Test save_tasks handles various primitive types.""" - tasks = [ - { - "string": "text", - "integer": 42, - "float": 3.14, - "boolean": True, - "none": None, - }, - ] - - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - save_tasks(tasks, "test/repo") - - call_args = mock_dataset_class.from_list.call_args[0][0] - assert call_args[0]["string"] == "text" - assert call_args[0]["integer"] == 42 - assert call_args[0]["float"] == 3.14 - assert call_args[0]["boolean"] is True - assert call_args[0]["none"] == "" # None becomes empty string - - -def test_save_tasks_with_other_type(): - """Test save_tasks converts other types to string.""" - - class CustomObj: - def __str__(self): - return "custom_value" - - tasks = [ - {"id": "1", "custom": CustomObj()}, - ] - - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - save_tasks(tasks, "test/repo") - - call_args = mock_dataset_class.from_list.call_args[0][0] - assert call_args[0]["custom"] == "custom_value" - - -def test_save_tasks_rejects_task_objects(): - """Test save_tasks raises error for LegacyTask objects.""" - task = LegacyTask(prompt="test", mcp_config={}) - - with 
pytest.raises(ValueError, match="expects dictionaries, not LegacyTask objects"): - save_tasks([task], "test/repo") # type: ignore - - -def test_save_tasks_rejects_task_objects_in_list(): - """Test save_tasks raises error when LegacyTask object is in the list.""" - tasks = [ - {"id": "1", "prompt": "test", "mcp_config": {}}, - LegacyTask(prompt="test2", mcp_config={}), # LegacyTask object - ] - - with pytest.raises(ValueError, match="Item 1 is a LegacyTask object"): - save_tasks(tasks, "test/repo") # type: ignore - - -def test_save_tasks_with_kwargs(): - """Test save_tasks passes kwargs to push_to_hub.""" - tasks = [{"id": "1", "prompt": "test"}] - - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - save_tasks(tasks, "test/repo", private=True, commit_message="Test commit") - - mock_dataset.push_to_hub.assert_called_once_with( - "test/repo", private=True, commit_message="Test commit" - ) - - -def test_save_tasks_field_not_in_dict(): - """Test save_tasks handles missing fields gracefully.""" - tasks = [ - {"id": "1", "prompt": "test"}, - ] - - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - # Request fields that don't exist - save_tasks(tasks, "test/repo", fields=["id", "missing_field"]) - - call_args = mock_dataset_class.from_list.call_args[0][0] - assert "id" in call_args[0] - assert "missing_field" not in call_args[0] - - -def test_save_tasks_empty_list(): - """Test save_tasks with empty list.""" - with patch("datasets.Dataset") as mock_dataset_class: - mock_dataset = MagicMock() - mock_dataset_class.from_list.return_value = mock_dataset - - save_tasks([], "test/repo") - - mock_dataset_class.from_list.assert_called_once_with([]) - mock_dataset.push_to_hub.assert_called_once() diff --git a/scripts/pre_release_check.py b/scripts/pre_release_check.py index f0ac6ddf..8ef65703 100644 --- a/scripts/pre_release_check.py +++ b/scripts/pre_release_check.py @@ -14,7 +14,7 @@ import sys from typing import Any -from hud.agents import ClaudeAgent +from hud.agents.claude import ClaudeAgent from hud.settings import settings # Configure logging From 79c7a18890cb0c36fae8b00e7343eb7b3949d40d Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 07:42:16 -0800 Subject: [PATCH 50/92] prelim small updates --- examples/00_agent_env.py | 59 +++++------- examples/01_agent_lifecycle.py | 71 ++++++-------- examples/02_claude_agent.py | 46 ++++----- examples/03_openai_compatible_agent.py | 77 ++++++--------- examples/04_grounded_agent.py | 125 ++++++++++++------------- examples/05_custom_agent.py | 37 +++----- examples/README.md | 122 +++++++++++++++++------- examples/integration_otel.py | 95 ------------------- examples/run_evaluation.py | 24 ++--- hud/datasets/tests/test_loader.py | 4 +- hud/datasets/tests/test_utils.py | 18 +++- hud/eval/task.py | 84 ++++++++++------- 12 files changed, 326 insertions(+), 436 deletions(-) delete mode 100644 examples/integration_otel.py diff --git a/examples/00_agent_env.py b/examples/00_agent_env.py index 85d4b153..41b935e8 100644 --- a/examples/00_agent_env.py +++ b/examples/00_agent_env.py @@ -1,67 +1,52 @@ """Tiny agent-environment demo in one file. 
┌───────────────┐ tool call (MCP) ┌───────────────┐ -│ Client │ ────────────────► │ Server │ -│ (agent side) │ JSON-RPC / stdio │ (environment) │ +│ Agent │ ────────────────► │ Environment │ +│ (client) │ hud.eval() │ (hud.Env) │ └───────────────┘ └───────────────┘ -Server = the *environment* -• Exposes one tool `sum(a, b)` using the FastMCP SDK. -• In real projects the server runs inside Docker so stdout is reserved for the - protocol and stderr for logs. +Environment = hud.Environment with @env.tool +• Exposes one tool `sum(a, b)` using the @env.tool decorator. +• In real projects this would be a Docker image or remote service. -Client = the *agent side* -• Uses `hud.client.MCPClient` to connect to **any** MCP environment – local - subprocess here, Docker or remote HUD in real scenarios. -• Sends a single tool call and prints the result. +Agent = the client side +• Uses `hud.eval(env())` to connect and call tools. +• The environment handles tool routing automatically. -Run `python examples/00_minimal_fastmcp.py` → prints `3 + 4 = 7`. +Run `python examples/00_agent_env.py` → prints `3 + 4 = 7`. """ from __future__ import annotations import asyncio -import sys -from pathlib import Path -from fastmcp import FastMCP -from hud.clients import MCPClient +import hud # ------------------------------------------------------------------ -# Environment (server) +# Environment (with local tools) # ------------------------------------------------------------------ -server = FastMCP("MiniServer") +env = hud.Environment("calculator") -@server.tool() +@env.tool() def sum(a: int, b: int) -> int: + """Add two numbers together.""" return a + b # ------------------------------------------------------------------ -# Agent (client) – spawns the same file with --server and calls the tool +# Agent (client) – connects to env and calls tools # ------------------------------------------------------------------ -THIS_FILE = Path(__file__).absolute() - -async def run_client() -> None: - cfg = { - "local": { - "command": sys.executable, - "args": [str(THIS_FILE), "--server"], - } - } - client = MCPClient(mcp_config=cfg) - await client.initialize() - result = await client.call_tool(name="sum", arguments={"a": 3, "b": 4}) - print("3 + 4 =", result) - await client.shutdown() +async def main() -> None: + """Connect to the environment and call the sum tool.""" + # Use hud.eval() with env() to create a task and run it + async with hud.eval(env(), trace=False) as ctx: + result = await ctx.call_tool(name="sum", arguments={"a": 3, "b": 4}) + print("3 + 4 =", result) if __name__ == "__main__": - if "--server" in sys.argv: - server.run() - else: - asyncio.run(run_client()) # The client will run itself with the --server flag + asyncio.run(main()) diff --git a/examples/01_agent_lifecycle.py b/examples/01_agent_lifecycle.py index 8d15e1ce..9de673b6 100644 --- a/examples/01_agent_lifecycle.py +++ b/examples/01_agent_lifecycle.py @@ -2,12 +2,11 @@ """ Complete Agent Lifecycle Example -This example demonstrates the full agent lifecycle using Task.from_v4(): -- Task definition with setup and evaluation tools (v4 LegacyTask format) -- Conversion to v5 Task using Task.from_v4() +This example demonstrates the full agent lifecycle using the v5 Task format: +- Task definition with Environment and scenario - hud.eval() context for connection and tracing - Agent initialization and execution -- Automatic setup/evaluate tool execution +- Automatic scenario setup/evaluation - Result collection For simpler usage, just use `await agent.run(ctx)` 
which handles everything. @@ -15,65 +14,49 @@ """ import asyncio + import hud -from hud.datasets import LegacyTask +from hud.agents.claude import ClaudeAgent from hud.eval.task import Task -from hud.agents import ClaudeAgent -async def main(): +async def main() -> None: print("🚀 Agent Lifecycle Example") print("=" * 50) - # Phase 1: Define task in v4 LegacyTask format - # This format includes setup_tool and evaluate_tool + # Phase 1: Define task using v5 Task format + # The Task holds environment config and scenario info print("📋 Defining task...") - legacy_task = LegacyTask( - prompt="Create a new todo item with the title 'Buy groceries' and description 'Milk, eggs, bread'", - mcp_config={ - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - "headers": { - "Authorization": "Bearer ${HUD_API_KEY}", # Auto-resolved from env - "Mcp-Image": "hudevals/hud-browser:latest", - }, - } - }, - setup_tool={"name": "launch_app", "arguments": {"app_name": "todo"}}, - evaluate_tool={ - "name": "evaluate", - "arguments": {"name": "todo_exists", "arguments": {"title": "Buy groceries"}}, - }, + task = Task( + # Environment config - connects to HUD browser hub + env={"name": "browser"}, + # Scenario to run (defined on the environment) + scenario="checkout", + # Scenario arguments + args={"product": "laptop", "quantity": 1}, + # Optional: agent configuration + agent_config={"system_prompt": "You are a helpful shopping assistant."}, ) - # Phase 2: Convert to v5 Task - # Task.from_v4() creates an Environment with: - # - mcp_config connection (connects on context entry) - # - setup_tool calls (run on context entry) - # - evaluate_tool calls (run on context exit) - print("🔄 Converting to v5 Task...") - task = Task.from_v4(legacy_task) - - # Phase 3: Create agent + # Phase 2: Create agent print("🤖 Creating Claude agent...") agent = ClaudeAgent.create( - checkpoint_name="claude-sonnet-4-5", + checkpoint_name="claude-sonnet-4-20250514", allowed_tools=["anthropic_computer"], initial_screenshot=True, ) - # Phase 4: Enter eval context and run agent + # Phase 3: Enter eval context and run agent # The context manager handles: # - Environment connection (MCP servers start) - # - Setup tools execution (launch_app) + # - Scenario setup execution # - Trace creation for telemetry print("🔧 Entering eval context...") - async with task as ctx: - print(f" ✅ Environment connected") - print(f" ✅ Setup tools executed") - print(f" 📝 Prompt: {ctx.prompt[:50]}...") + async with hud.eval(task, name="agent-lifecycle-demo") as ctx: + print(" ✅ Environment connected") + print(f" 📝 Prompt: {ctx.prompt[:50] if ctx.prompt else 'N/A'}...") - # Phase 5: Run the agent + # Phase 4: Run the agent # agent.run() handles the agentic loop: # - Gets system messages # - Sends prompt to model @@ -88,9 +71,9 @@ async def main(): if result.content: print(f" - Response: {result.content[:100]}...") - # Phase 6: After exit, evaluate_tool was automatically called + # Phase 5: After exit, scenario evaluation was automatically called # and ctx.reward is set based on the evaluation - print("\n📊 Evaluation complete (via evaluate_tool)") + print("\n📊 Evaluation complete") print(f" Reward: {ctx.reward}") print(f" Success: {ctx.success}") diff --git a/examples/02_claude_agent.py b/examples/02_claude_agent.py index 9cf5b1b8..bd4c0e62 100644 --- a/examples/02_claude_agent.py +++ b/examples/02_claude_agent.py @@ -13,18 +13,16 @@ """ import asyncio + import hud -from hud.agents import ClaudeAgent -from hud.datasets import LegacyTask +from hud.agents.claude import 
ClaudeAgent from hud.eval.task import Task -from hud.settings import settings -async def main(): - # For any environment, you can run : +async def main() -> None: + # For any environment, you can run: # hud debug to see the logs - # hud analyze to get a report about its capabilities (tools, resources, etc.) - # e.g. hud analyze hudpython/hud-remote-browser:latest + # hud analyze to get a report about its capabilities initial_url = "https://httpbin.org/forms/post" @@ -41,43 +39,31 @@ async def main(): 9. Verify the submission was successful """ - # Create LegacyTask with mcp_config and setup - legacy_task = LegacyTask( - prompt=prompt, - mcp_config={ - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - "headers": { - "Authorization": f"Bearer {settings.api_key}", - "Mcp-Image": "hudpython/hud-remote-browser:latest", - }, - } - }, - setup_tool={ - "name": "setup", - "arguments": {"name": "navigate_to_url", "arguments": {"url": initial_url}}, - }, + # Create v5 Task with Environment config + task = Task( + env={"name": "browser"}, # Connect to browser hub + scenario="form_fill", # Scenario name + args={"url": initial_url}, # Scenario args + agent_config={"system_prompt": prompt}, # Pass prompt via agent config ) - # Convert to v5 Task - task = Task.from_v4(legacy_task) - # Create Claude-specific agent agent = ClaudeAgent.create( - checkpoint_name="claude-sonnet-4-5", + checkpoint_name="claude-sonnet-4-20250514", allowed_tools=["anthropic_computer"], initial_screenshot=True, ) - print(f"📋 Task: Multi-step form interaction") - print(f"🚀 Running Claude agent...\n") + print("📋 Task: Multi-step form interaction") + print("🚀 Running Claude agent...\n") # Run with hud.eval() context - async with task as ctx: + async with hud.eval(task, name="claude-form-demo") as ctx: result = await agent.run(ctx, max_steps=15) print("\n✨ Claude agent demo complete!") print(f" Reward: {result.reward}") + print(f" Done: {result.done}") if __name__ == "__main__": diff --git a/examples/03_openai_compatible_agent.py b/examples/03_openai_compatible_agent.py index 51578398..d55db71f 100644 --- a/examples/03_openai_compatible_agent.py +++ b/examples/03_openai_compatible_agent.py @@ -3,8 +3,8 @@ OpenAI-compatible Chat Agent playing 2048 (text or browser). Usage: - python examples/openai_compatible_agent.py --mode text # default - python examples/openai_compatible_agent.py --mode browser + python examples/03_openai_compatible_agent.py --mode text # default + python examples/03_openai_compatible_agent.py --mode browser Requirements: - pip install openai @@ -24,7 +24,7 @@ import hud from hud.agents.openai_chat import OpenAIChatAgent -from hud.datasets import LegacyTask +from hud.eval.task import Task def _system_prompt(mode: Literal["text", "browser"]) -> str: @@ -46,7 +46,7 @@ def _system_prompt(mode: Literal["text", "browser"]) -> str: "- Continue until target or game ends; no confirmations needed.\n\n" "Strategy: keep highest tiles in a corner; maintain order; avoid random moves." ) - # text + # text mode return ( "You are an expert 2048 game player. 
Your goal is to reach the tile specified by the user.\n\n" "HOW 2048 WORKS:\n" @@ -66,46 +66,29 @@ def _system_prompt(mode: Literal["text", "browser"]) -> str: ) -def _task_for_mode(mode: Literal["text", "browser"], target: int) -> LegacyTask: +def _create_task(mode: Literal["text", "browser"], target: int) -> Task: + """Create a v5 Task for the 2048 game.""" if mode == "browser": - mcp_config = { - "local": { - "command": "docker", - "args": ["run", "--rm", "-i", "-p", "8080:8080", "hudevals/hud-browser:0.1.3"], - } - } - prompt = ( - "Play the browser-based 2048 game and try to reach the target tile. " - "Start by taking a screenshot, then make strategic moves using arrow keys." + # Use local Docker environment for browser mode + env = hud.Environment("2048-browser") + env.connect_image( + "hudevals/hud-browser:0.1.3", + docker_args=["-p", "8080:8080"], + ) + return Task( + env=env, + scenario="game_2048", + args={"target": target}, ) - setup_tool = {"name": "launch_app", "arguments": {"app_name": "2048"}} - evaluate_tool = { - "name": "evaluate", - "arguments": {"name": "game_2048_max_number", "arguments": {"target": target}}, - } else: - mcp_config = { - "local": { - "command": "docker", - "args": ["run", "--rm", "-i", "hudevals/hud-text-2048:0.1.6"], - } - } - prompt = f"Aim for the {target} tile (at least a score of 800!)" - setup_tool = { - "name": "setup", - "arguments": {"name": "board", "arguments": {"board_size": 4}}, - } - evaluate_tool = { - "name": "evaluate", - "arguments": {"name": "max_number", "arguments": {"target": target}}, - } - - return LegacyTask( - prompt=prompt, - mcp_config=mcp_config, - setup_tool=setup_tool, # type: ignore[arg-type] - evaluate_tool=evaluate_tool, # type: ignore[arg-type] - ) + # Use local Docker environment for text mode + env = hud.Environment("2048-text") + env.connect_image("hudevals/hud-text-2048:0.1.6") + return Task( + env=env, + scenario="max_number", + args={"target": target}, + ) async def run_example(mode: Literal["text", "browser"], target: int) -> None: @@ -118,10 +101,10 @@ async def run_example(mode: Literal["text", "browser"], target: int) -> None: api_key=api_key, ) - task = _task_for_mode(mode, target) + task = _create_task(mode, target) system_prompt = _system_prompt(mode) - checkpoint = "gpt-5-mini" # Replace with your model checkpoint + checkpoint = "gpt-4o-mini" # Replace with your model checkpoint # Allowed tools differ by mode allowed_tools = ["computer"] if mode == "browser" else ["move"] @@ -135,17 +118,13 @@ async def run_example(mode: Literal["text", "browser"], target: int) -> None: system_prompt=system_prompt, ) - title = "OpenAI 2048 Game (Browser)" if mode == "browser" else "OpenAI 2048 Game (Text)" print("🎮 Starting 2048 game with OpenAI-compatible agent...") print(f"🤖 Model: {agent.config.checkpoint_name}") print(f"🧩 Mode: {mode}") print("=" * 50) - # Use hud.eval() with Task.from_v4() for legacy task format - from hud.eval.task import Task - - v5_task = Task.from_v4(task) - async with hud.eval(v5_task, variants={"model": checkpoint, "mode": mode}) as ctx: + # Use hud.eval() for the task + async with hud.eval(task, variants={"model": checkpoint, "mode": mode}) as ctx: result = await agent.run(ctx, max_steps=100) print("=" * 50) diff --git a/examples/04_grounded_agent.py b/examples/04_grounded_agent.py index 636baaf6..32bd8f3d 100644 --- a/examples/04_grounded_agent.py +++ b/examples/04_grounded_agent.py @@ -17,79 +17,72 @@ import hud from hud.agents.grounded_openai import GroundedOpenAIChatAgent +from 
hud.eval.task import Task from hud.settings import settings from hud.tools.grounding import GrounderConfig from openai import AsyncOpenAI -async def main(): +async def main() -> None: """Run the grounded agent example.""" - with hud.trace("Grounded Agent Demo"): - # Configure the grounding model - grounder_config = GrounderConfig( - api_base="https://openrouter.ai/api/v1", # OpenRouter API - model="qwen/qwen-2.5-vl-7b-instruct", # Vision model for grounding - api_key=settings.openrouter_api_key, - ) - - # MCP configuration for environment - mcp_config = { - "local": { - "command": "docker", - "args": ["run", "--rm", "-i", "-p", "8080:8080", "hudevals/hud-browser:0.1.6"], - } - } - - # Create OpenAI client for planning - openai_client = AsyncOpenAI( - api_key=os.getenv("OPENAI_API_KEY", settings.openai_api_key) - ) # can use any OpenAI-compatible endpoint - - agent = GroundedOpenAIChatAgent.create( - grounder_config=grounder_config, - openai_client=openai_client, - checkpoint_name="gpt-4o-mini", # Planning model - ) - - try: - # Create a task with MCP config - from hud.datasets import LegacyTask - - form_url = "https://hb.cran.dev/forms/post" - - form_prompt = f""" - Fill out the form: - 1. Enter "Grounded Test" in the customer name field - 2. Enter "555-9876" in the telephone field - 3. Type "Testing grounded agent with separated vision and reasoning" in comments - 4. Select medium pizza size - 5. Choose mushroom as a topping - 6. Submit the form - """ - - legacy_task = LegacyTask( - prompt=form_prompt, - mcp_config=mcp_config, - setup_tool={ - "name": "playwright", - "arguments": {"action": "navigate", "url": form_url}, - }, - ) - - print(f"📋 Task: Form interaction") - print(f"🚀 Running grounded agent...\n") - - # Convert LegacyTask to Task and run with hud.eval() - from hud.eval.task import Task - - task = Task.from_v4(legacy_task) - async with task as ctx: - result = await agent.run(ctx, max_steps=10) - print(f"Result: {result.content}\n") - - except Exception as e: - print(f"Error during agent execution: {e}") + # Configure the grounding model + openrouter_key = os.getenv("OPENROUTER_API_KEY") or settings.openrouter_api_key + if not openrouter_key: + raise ValueError("OPENROUTER_API_KEY is required for grounding model") + + grounder_config = GrounderConfig( + api_base="https://openrouter.ai/api/v1", # OpenRouter API + model="qwen/qwen-2.5-vl-7b-instruct", # Vision model for grounding + api_key=openrouter_key, + ) + + # Create OpenAI client for planning + openai_client = AsyncOpenAI( + api_key=os.getenv("OPENAI_API_KEY", settings.openai_api_key) + ) # can use any OpenAI-compatible endpoint + + agent = GroundedOpenAIChatAgent.create( + grounder_config=grounder_config, + openai_client=openai_client, + checkpoint_name="gpt-4o-mini", # Planning model + ) + + form_url = "https://hb.cran.dev/forms/post" + + form_prompt = """ + Fill out the form: + 1. Enter "Grounded Test" in the customer name field + 2. Enter "555-9876" in the telephone field + 3. Type "Testing grounded agent with separated vision and reasoning" in comments + 4. Select medium pizza size + 5. Choose mushroom as a topping + 6. 
Submit the form + """ + + # Create v5 Task with local Docker environment + env = hud.Environment("browser-grounded") + env.connect_image( + "hudevals/hud-browser:0.1.6", + docker_args=["-p", "8080:8080"], + ) + + task = Task( + env=env, + scenario="form_fill", + args={"url": form_url}, + agent_config={"system_prompt": form_prompt}, + ) + + print("📋 Task: Form interaction with grounded agent") + print("🚀 Running grounded agent...\n") + + try: + async with hud.eval(task, name="grounded-form-demo") as ctx: + result = await agent.run(ctx, max_steps=10) + print(f"Result: {result.content}\n") + except Exception as e: + print(f"Error during agent execution: {e}") print("\n✨ Grounded agent demo complete!") diff --git a/examples/05_custom_agent.py b/examples/05_custom_agent.py index 094cb0f3..49a3a5e1 100644 --- a/examples/05_custom_agent.py +++ b/examples/05_custom_agent.py @@ -7,21 +7,21 @@ 3. Works with any model available via the gateway Usage: - HUD_API_KEY=sk-hud-... python examples/custom_gateway_agent.py + HUD_API_KEY=sk-hud-... python examples/05_custom_agent.py """ import asyncio import json -import os from typing import Any import mcp.types as types from openai import AsyncOpenAI -from hud import instrument +import hud from hud.agents.base import MCPAgent -from hud.datasets import LegacyTask +from hud.eval.task import Task from hud.settings import settings +from hud.telemetry.instrument import instrument from hud.types import AgentResponse, MCPToolCall, MCPToolResult @@ -100,7 +100,7 @@ async def get_response(self, messages: list[Any]) -> AgentResponse: response = await self.client.chat.completions.create( model=self.checkpoint_name, messages=messages, - tools=tools if tools else None, # type: ignore + tools=tools if tools else None, # type: ignore[arg-type] max_tokens=self.max_tokens, temperature=self.temperature, ) @@ -195,7 +195,7 @@ async def format_tool_results( return messages -async def main(): +async def main() -> None: """Example usage of MyAgent.""" # Create agent with Claude via Gateway @@ -206,28 +206,17 @@ async def main(): verbose=True, ) - # Define a task with HUD MCP environment - legacy_task = LegacyTask( - prompt="Go to example.com and tell me the page title", - mcp_config={ - "hud": { - "url": "https://mcp.hud.ai/v3/mcp", - "headers": { - "Authorization": f"Bearer {os.environ.get('HUD_API_KEY', '')}", - "Mcp-Image": "hudpython/hud-remote-browser:latest", - }, - } - }, + # Create v5 Task with HUD hub environment + task = Task( + env={"name": "browser"}, # Connect to browser hub + scenario="navigate", + args={"url": "https://example.com"}, + agent_config={"system_prompt": "Go to example.com and tell me the page title"}, ) - # Convert to v5 Task and run with context manager - from hud.eval.task import Task - - task = Task.from_v4(legacy_task) - # Run the agent - traces are automatically captured print("Running agent with HUD Gateway inference...") - async with task as ctx: + async with hud.eval(task, name="custom-agent-demo") as ctx: result = await agent.run(ctx, max_steps=5) print("\n=== Results ===") diff --git a/examples/README.md b/examples/README.md index 303ca16c..88b11a19 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,76 +2,126 @@ A collection of examples demonstrating HUD SDK usage patterns. -## Quick Start Examples +## Quick Start ### 00_agent_env.py -Minimal MCP server and client in one file. Shows the basic agent-environment communication pattern. +Minimal MCP server and client in one file. 
Shows the basic agent-environment communication pattern using `hud.eval()`. ```bash python examples/00_agent_env.py ``` -### 01_hello_2048.py -Complete agent evaluation on the 2048 environment using Claude. +### 01_agent_lifecycle.py +Complete agent lifecycle demonstrating: +- v5 Task format with Environment and scenario +- `hud.eval()` context for connection and tracing +- Agent initialization and execution +- Automatic scenario setup/evaluation ```bash -python examples/01_hello_2048.py +python examples/01_agent_lifecycle.py ``` -> | Requires Docker and `ANTHROPIC_API_KEY` environment variable. +> Requires `HUD_API_KEY` and `ANTHROPIC_API_KEY` environment variables. + +## Agent Examples -### 03_browser_agent_loop.py -Quick start for the browser environment (Claude). Supports multiple demo apps. +### 02_claude_agent.py +Claude agent with computer use capabilities for browser automation. ```bash -# 2048 (default) -python examples/03_browser_agent_loop.py +python examples/02_claude_agent.py +``` + +> Requires `HUD_API_KEY` and `ANTHROPIC_API_KEY`. -# Todo app -python examples/03_browser_agent_loop.py --app todo +### 03_openai_compatible_agent.py +OpenAI-compatible chat.completions agent with both text and browser 2048 environments. + +```bash +export OPENAI_API_KEY=your-key +# export OPENAI_BASE_URL=http://localhost:8000/v1 # for local servers (e.g., vllm) + +python examples/03_openai_compatible_agent.py --mode text # text environment +python examples/03_openai_compatible_agent.py --mode browser # browser environment ``` -> | Requires Docker (exposes port 8080) and `ANTHROPIC_API_KEY`. +> Requires Docker for local environment execution. -## Core Patterns +### 04_grounded_agent.py +Grounded agent that separates visual grounding (element detection) from high-level reasoning. -### 01_agent_lifecycle.py -Demonstrates the full agent lifecycle with telemetry and state management. -- Task creation using LegacyTask format -- Trace context for debugging -- Setup and evaluation tool calls +```bash +export OPENAI_API_KEY=your-key +export OPENROUTER_API_KEY=your-key + +python examples/04_grounded_agent.py +``` + +> Requires Docker and API keys for both OpenAI and OpenRouter. + +### 05_custom_agent.py +Build a custom MCPAgent using HUD Gateway for unified model access: +- No need for individual provider API keys +- Works with Anthropic, OpenAI, Gemini, OpenRouter models +- Automatic tracing with `@hud.instrument` + +```bash +HUD_API_KEY=sk-hud-... python examples/05_custom_agent.py +``` + +## Dataset Evaluation ### run_evaluation.py Generic dataset evaluation runner using the programmatic API. ```bash -# Run all tasks +# Run all tasks in a dataset python examples/run_evaluation.py hud-evals/SheetBench-50 # Run specific tasks by index python examples/run_evaluation.py hud-evals/SheetBench-50 --task-ids 0 1 2 -# Use different agent -python examples/run_evaluation.py hud-evals/OSWorld-Verified-Gold --agent operator +# Use different agent and concurrency +python examples/run_evaluation.py hud-evals/OSWorld-Verified-Gold --agent operator --max-concurrent 50 ``` -## Integration Examples +For production evaluations, prefer the CLI: `hud eval --help` -### claude_agent.py -Direct usage of Claude agent without environments. +## Key Concepts -### integration_mcp_use.py -Using the legacy `mcp_use` client for multi-server setups. +### v5 Task Format -### integration_otel.py -Custom OpenTelemetry backend integration (e.g., Jaeger). 
+The v5 Task format is the recommended way to define evaluation tasks: -### openai_compatible_agent.py -OpenAI-compatible chat.completions agent with both text and browser 2048 environments. +```python +from hud.eval.task import Task -```bash -export OPENAI_API_KEY=your-key # or dummy value for local servers -# export OPENAI_BASE_URL=http://localhost:8000/v1 # e.g., vllm -python examples/openai_compatible_agent.py --mode text # text environment -python examples/openai_compatible_agent.py --mode browser # browser environment +# Simple task with hub environment +task = Task( + env={"name": "browser"}, # Connect to browser hub + scenario="checkout", # Scenario to run + args={"user_id": "alice"}, # Scenario arguments +) + +# Task with local Docker environment +env = hud.Environment("my-env") +env.connect_local(command="docker", args=["run", "--rm", "-i", "my-image"]) +task = Task(env=env, scenario="test") +``` + +### Using hud.eval() + +All examples use `hud.eval()` as the primary entry point: + +```python +async with hud.eval(task, name="my-eval", variants={"model": "gpt-4o"}) as ctx: + result = await agent.run(ctx, max_steps=10) + print(f"Reward: {ctx.reward}") ``` + +The context manager handles: +- Environment connection (MCP servers start) +- Scenario setup execution +- Telemetry and tracing +- Automatic scenario evaluation on exit diff --git a/examples/integration_otel.py b/examples/integration_otel.py deleted file mode 100644 index 9644b67d..00000000 --- a/examples/integration_otel.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Example: Running HUD agents with Jaeger as the tracing backend. - -This example shows how to run a normal HUD agent (playing 2048 game) -but send all traces to Jaeger instead of HUD's backend. - -To run: -1. Build the 2048 game: - docker build -t hud-text-2048 ../environments/text_2048 - -2. Start Jaeger: - docker run -d --name jaeger \ - -e COLLECTOR_OTLP_ENABLED=true \ - -p 16686:16686 -p 4318:4318 \ - jaegertracing/all-in-one:latest - -3. Run this example: - python custom_otel_backend.py - -4. View traces at http://localhost:16686 - - Service: "hud-2048-jaeger" - - You'll see the agent's get_model_response and execute_tools spans - -5. Cleanup: - docker stop jaeger && docker rm jaeger -""" - -import asyncio - -# Configure telemetry BEFORE importing agents to use Jaeger -from hud.otel import configure_telemetry - -configure_telemetry( - service_name="hud-2048-jaeger", - enable_otlp=True, - otlp_endpoint="localhost:4318", # Jaeger's OTLP HTTP endpoint -) - -# Now import everything else -import hud -from hud.agents import ClaudeAgent -from hud.clients import MCPClient -from hud.datasets import LegacyTask - - -async def main(): - """Run 2048 game with Claude agent, traces go to Jaeger.""" - - task_dict = { - "prompt": "Play 2048 and try to get as high as possible. 
Do not stop even after 2048 is reached.", - "mcp_config": { - "local": {"command": "docker", "args": ["run", "--rm", "-i", "hud-text-2048"]} - }, - "setup_tool": { - "name": "setup", - "arguments": {"name": "board", "arguments": {"board_size": 4}}, - }, - "evaluate_tool": { - "name": "evaluate", - "arguments": {"name": "max_number"}, - }, - } - task = LegacyTask(**task_dict) - - # Create client and agent - # Create agent - its methods are already instrumented with @hud.instrument - agent = ClaudeAgent.create() - - # Convert to v5 Task and run with hud.eval() - from hud.eval.task import Task - - v5_task = Task.from_v4(task) - - # Run with hud.trace() and hud.eval() - this creates spans in Jaeger - with hud.trace("play_2048_game"): - print(f"🎮 Starting 2048 game") - - # Use Task as context manager to get EvalContext - async with v5_task as ctx: - # Agent will play the game with setup and evaluate phases - # Each call to get_model_response() and execute_tools() - # will create child spans in Jaeger automatically - result = await agent.run(ctx, max_steps=20) - - print(f"\n🏁 Game finished!") - print(f" Final reward: {result.reward}") - print(f" Success: {not result.isError}") - - print("\n✅ All traces sent to Jaeger!") - print("🔍 View at: http://localhost:16686") - print(" - Service: 'hud-2048-jaeger'") - print(" - You'll see the agent's reasoning and tool calls") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/run_evaluation.py b/examples/run_evaluation.py index d996f9e7..855e977b 100644 --- a/examples/run_evaluation.py +++ b/examples/run_evaluation.py @@ -28,7 +28,7 @@ async def main() -> None: args = parser.parse_args() # Import here to avoid import errors if agents not installed - from hud.datasets import load_dataset, run_dataset, display_results + from hud.datasets import load_dataset, run_dataset # Load dataset as Task objects print(f"Loading {args.dataset}...") @@ -40,25 +40,21 @@ async def main() -> None: tasks = [tasks[i] for i in indices if i < len(tasks)] print(f"Filtered to {len(tasks)} tasks at indices: {args.task_ids}") - # Create agent instance based on type + # Determine agent type and params if args.agent == "operator": - from hud.agents import OperatorAgent - - agent = OperatorAgent.create( - checkpoint_name=args.model or "computer-use-preview", - ) + agent_type = "operator" + agent_params = {"checkpoint_name": args.model or "computer-use-preview"} else: - from hud.agents import ClaudeAgent - - agent = ClaudeAgent.create( - checkpoint_name=args.model or "claude-sonnet-4-5", - ) + agent_type = "claude" + agent_params = {"checkpoint_name": args.model or "claude-sonnet-4-20250514"} - # Run evaluation + # Run evaluation using run_dataset + # Note: run_dataset creates agents fresh per task for proper tool initialization print(f"Running {len(tasks)} tasks with {args.agent} agent...") results = await run_dataset( tasks=tasks, - agent=agent, + agent_type=agent_type, + agent_params=agent_params, max_steps=args.max_steps, max_concurrent=args.max_concurrent, group_size=args.group_size, diff --git a/hud/datasets/tests/test_loader.py b/hud/datasets/tests/test_loader.py index 7ff31544..decf9019 100644 --- a/hud/datasets/tests/test_loader.py +++ b/hud/datasets/tests/test_loader.py @@ -57,7 +57,7 @@ def test_load_dataset_success( task_ids = {t.id for t in tasks} assert task_ids == {"task-1", "task-2"} mock_client.get.assert_called_once_with( - "https://api.hud.ai/evals/test-org/test-dataset", + "https://api.hud.ai/tasks/evalset/test-org/test-dataset", 
headers={"Authorization": "Bearer test_key"}, params={"all": "true"}, ) @@ -124,7 +124,7 @@ def test_load_dataset_no_api_key( assert len(tasks) == 0 mock_client.get.assert_called_once_with( - "https://api.hud.ai/evals/test-org/test-dataset", + "https://api.hud.ai/tasks/evalset/test-org/test-dataset", headers={}, params={"all": "true"}, ) diff --git a/hud/datasets/tests/test_utils.py b/hud/datasets/tests/test_utils.py index 107f737a..50218f2b 100644 --- a/hud/datasets/tests/test_utils.py +++ b/hud/datasets/tests/test_utils.py @@ -59,9 +59,15 @@ def test_invalid_task_rejected(self): def test_incomplete_v4_task_rejected(self): """Test that incomplete v4 task (missing evaluate_tool) is rejected.""" + # When prompt + mcp_config is present but evaluate_tool is missing, + # it's detected as v4 format but fails validation with pytest.raises(ValueError, match="v4 task missing required fields"): SingleTaskRequest( - task={"prompt": "test", "mcp_config": {}}, # Missing evaluate_tool + task={ + "prompt": "test", + "mcp_config": {"server": {"url": "http://localhost"}}, + # Missing evaluate_tool + }, agent_type=AgentType.CLAUDE, job_id="job-123", task_id="task-1", @@ -244,8 +250,10 @@ class TestSubmitRollouts: @pytest.mark.asyncio async def test_submit_single_task(self): - """Test submitting a single task.""" - tasks = [LegacyTask(id="task-1", prompt="Test prompt", mcp_config={})] + """Test submitting a single task (v5 format).""" + from hud.eval.task import Task + + tasks = [Task(env={"name": "browser"}, scenario="test", id="task-1")] with patch("hud.datasets.utils.httpx.AsyncClient") as mock_client_cls: mock_response = MagicMock() @@ -274,7 +282,9 @@ async def test_submit_single_task(self): @pytest.mark.asyncio async def test_submit_with_group_size(self): """Test submitting with group_size > 1 creates multiple requests per task.""" - tasks = [LegacyTask(id="task-1", prompt="Test prompt", mcp_config={})] + from hud.eval.task import Task + + tasks = [Task(env={"name": "browser"}, scenario="test", id="task-1")] with patch("hud.datasets.utils.httpx.AsyncClient") as mock_client_cls: mock_response = MagicMock() diff --git a/hud/eval/task.py b/hud/eval/task.py index 6a181ddf..706b95bc 100644 --- a/hud/eval/task.py +++ b/hud/eval/task.py @@ -32,6 +32,7 @@ BaseModel, ConfigDict, Field, + field_serializer, field_validator, model_serializer, model_validator, @@ -140,6 +141,22 @@ class Task(BaseModel): # Task metadata - for tracking/filtering, not used by agent metadata: dict[str, Any] = Field(default_factory=dict) + @field_validator("agent_config", mode="before") + @classmethod + def convert_agent_config( + cls, v: TaskAgentConfig | dict[str, Any] | None + ) -> TaskAgentConfig | None: + """Auto-convert dict to TaskAgentConfig.""" + if v is None: + return None + if isinstance(v, TaskAgentConfig): + return v + if isinstance(v, dict): + return TaskAgentConfig(**v) + raise TypeError( + f"Task.agent_config must be TaskAgentConfig or dict. 
Got {type(v).__name__}" + ) + @model_validator(mode="before") @classmethod def detect_v4_format(cls, data: Any) -> Any: @@ -221,54 +238,51 @@ def convert_validation( ) return converted + @field_serializer("env") + def serialize_env(self, env: Environment | None) -> dict[str, Any] | None: + """Serialize Environment to config dict via to_config().""" + if env is None: + return None + return env.to_config() + @model_serializer(mode="wrap") def _serialize_task( self, handler: Any # SerializerFunctionWrapHandler ) -> dict[str, Any]: - """Custom serializer that converts Environment to config dict. + """Custom serializer for v4 format flattening. - For v5 tasks: outputs {"env": {"name": "browser", ...}, "scenario": ...} - For v4 tasks: outputs {"prompt": ..., "mcp_config": ..., "evaluate_tool": ...} - - Raises ValueError if environment has local tools/scenarios. + For v5 tasks: uses default serialization (env field handled by field_serializer) + For v4 tasks: flattens {"prompt": ..., "mcp_config": ..., "evaluate_tool": ...} """ - from hud.environment import Environment - - # Get default serialization + # Get default serialization (env is already converted by field_serializer) data = handler(self) - # Convert Environment to serializable config - if isinstance(self.env, Environment): - env_config = self.env.to_config() - - # Detect v4 format (has mcp_config) vs v5 format (has name) - if "mcp_config" in env_config: - # v4 format - merge env_config with Task fields - result = env_config.copy() + # Check if this is a v4 task (env config has mcp_config) + env_config = data.get("env") + if env_config and isinstance(env_config, dict) and "mcp_config" in env_config: + # v4 format - flatten into top-level dict + result = env_config.copy() - # Map validation → integration_test_tool - if self.validation: - result["integration_test_tool"] = [ - {"name": v.name, "arguments": v.arguments or {}} - for v in self.validation - ] + # Map validation → integration_test_tool + if self.validation: + result["integration_test_tool"] = [ + {"name": v.name, "arguments": v.arguments or {}} + for v in self.validation + ] - # Preserve agent_config (with system_prompt) - if self.agent_config and self.agent_config.system_prompt: - result["agent_config"] = {"system_prompt": self.agent_config.system_prompt} + # Preserve agent_config + if data.get("agent_config"): + result["agent_config"] = data["agent_config"] - # Preserve metadata - if self.metadata: - result["metadata"] = self.metadata + # Preserve metadata + if data.get("metadata"): + result["metadata"] = data["metadata"] - # Preserve id - if self.id: - result["id"] = self.id + # Preserve id + if data.get("id"): + result["id"] = data["id"] - return result - else: - # v5 format - env config goes in env field - data["env"] = env_config + return result return data From 5e37ea84a25ad47031573837d94f762be6a9cc03 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 07:51:23 -0800 Subject: [PATCH 51/92] format and tests --- hud/agents/claude.py | 2 +- hud/agents/operator.py | 1 - hud/agents/tests/test_claude.py | 2 +- hud/agents/tests/test_gemini.py | 8 +- hud/agents/tests/test_run_eval.py | 6 +- hud/cli/flows/tasks.py | 1 - hud/cli/tests/test_cli_root.py | 1 + hud/clients/tests/test_analyze_scenarios.py | 4 +- hud/datasets/loader.py | 4 +- hud/datasets/runner.py | 5 +- hud/datasets/utils.py | 3 - hud/environment/environment.py | 6 +- hud/eval/task.py | 15 +- hud/eval/tests/test_eval.py | 2 - hud/eval/tests/test_task.py | 143 ++++++++++++++++++++ hud/eval/utils.py | 1 - 
hud/telemetry/exporter.py | 5 +- hud/telemetry/instrument.py | 1 + hud/telemetry/tests/test_exporter.py | 2 +- hud/tests/test_datasets_extended.py | 3 +- hud/tests/test_types.py | 2 +- hud/tools/grounding/grounder.py | 1 - hud/utils/mcp.py | 2 - hud/utils/tests/test_mcp.py | 2 - 24 files changed, 169 insertions(+), 53 deletions(-) create mode 100644 hud/eval/tests/test_task.py diff --git a/hud/agents/claude.py b/hud/agents/claude.py index 39229693..c24c754e 100644 --- a/hud/agents/claude.py +++ b/hud/agents/claude.py @@ -8,7 +8,7 @@ from typing import Any, ClassVar, Literal, cast import mcp.types as types -from anthropic import Anthropic, AsyncAnthropic, AsyncAnthropicBedrock, Omit +from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, Omit from anthropic.types import CacheControlEphemeralParam from anthropic.types.beta import ( BetaBase64ImageSourceParam, diff --git a/hud/agents/operator.py b/hud/agents/operator.py index f16deeb6..d9def5d9 100644 --- a/hud/agents/operator.py +++ b/hud/agents/operator.py @@ -11,7 +11,6 @@ FunctionShellToolParam, FunctionToolParam, ResponseComputerToolCallOutputScreenshotParam, - ResponseInputParam, ) from openai.types.responses.response_input_param import ( ComputerCallOutput, diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py index 125e3e7e..2b047438 100644 --- a/hud/agents/tests/test_claude.py +++ b/hud/agents/tests/test_claude.py @@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, BadRequestError +from anthropic import AsyncAnthropic, AsyncAnthropicBedrock from mcp import types from hud.agents.claude import ( diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py index 74593d0d..fb0f7c5c 100644 --- a/hud/agents/tests/test_gemini.py +++ b/hud/agents/tests/test_gemini.py @@ -183,7 +183,9 @@ async def test_get_response_text_only(self, mock_gemini_client: genai.Client) -> mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) - messages = [genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Status?")])] + messages = [ + genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Status?")]) + ] response = await agent.get_response(messages) assert response.content == "Task completed successfully" @@ -223,7 +225,9 @@ async def test_get_response_with_thinking(self, mock_gemini_client: genai.Client mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) messages = [ - genai_types.Content(role="user", parts=[genai_types.Part.from_text(text="Hard question")]) + genai_types.Content( + role="user", parts=[genai_types.Part.from_text(text="Hard question")] + ) ] response = await agent.get_response(messages) diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py index d66b284f..0a09b193 100644 --- a/hud/agents/tests/test_run_eval.py +++ b/hud/agents/tests/test_run_eval.py @@ -48,11 +48,7 @@ async def get_system_messages(self) -> list[Any]: return [] async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]: - return [ - {"type": "text", "text": getattr(b, "text")} - for b in blocks - if hasattr(b, "text") - ] + return [{"type": "text", "text": getattr(b, "text")} for b in blocks if hasattr(b, "text")] class MockEvalContext(EvalContext): diff --git a/hud/cli/flows/tasks.py b/hud/cli/flows/tasks.py index a46af389..9563c6e7 100644 --- a/hud/cli/flows/tasks.py +++ b/hud/cli/flows/tasks.py 
@@ -16,7 +16,6 @@ from hud.datasets import load_dataset from hud.utils.hud_console import hud_console - logger = logging.getLogger(__name__) diff --git a/hud/cli/tests/test_cli_root.py b/hud/cli/tests/test_cli_root.py index 62500268..cf3c2be3 100644 --- a/hud/cli/tests/test_cli_root.py +++ b/hud/cli/tests/test_cli_root.py @@ -9,6 +9,7 @@ # Import the function directly from the __init__ module to avoid namespace conflict with analyze.py import hud.cli.__init__ as cli_init + analyze_fn = cli_init.analyze if TYPE_CHECKING: diff --git a/hud/clients/tests/test_analyze_scenarios.py b/hud/clients/tests/test_analyze_scenarios.py index d30a4b16..67ae6a47 100644 --- a/hud/clients/tests/test_analyze_scenarios.py +++ b/hud/clients/tests/test_analyze_scenarios.py @@ -23,9 +23,7 @@ def __init__( prompts: list[types.Prompt], resources: list[types.Resource], ) -> None: - super().__init__( - mcp_config={"test": {"url": "mock://test"}}, verbose=True - ) + super().__init__(mcp_config={"test": {"url": "mock://test"}}, verbose=True) self._mock_prompts = prompts self._mock_resources = resources # Skip initialize() (which fetches telemetry); we just need analyze_environment(). diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index 48c7dd31..d5554e25 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -148,9 +148,7 @@ def load_dataset(source: str, *, raw: bool = False) -> list[Task]: ... def load_dataset(source: str, *, raw: bool = True) -> list[dict[str, Any]]: ... -def load_dataset( - source: str, *, raw: bool = False -) -> list[Task] | list[dict[str, Any]]: +def load_dataset(source: str, *, raw: bool = False) -> list[Task] | list[dict[str, Any]]: """Load tasks from a dataset source. Supports multiple sources with auto-detection: diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 402671fe..2805967b 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -79,10 +79,7 @@ async def run_dataset( task_list = [Task.from_v4(tasks)] else: # Sequence of TaskInput - convert each to Task - task_list = [ - t if isinstance(t, Task) else Task.from_v4(t) - for t in tasks - ] + task_list = [t if isinstance(t, Task) else Task.from_v4(t) for t in tasks] if not task_list: raise ValueError("No tasks to run") diff --git a/hud/datasets/utils.py b/hud/datasets/utils.py index 98b28289..fabcbfa9 100644 --- a/hud/datasets/utils.py +++ b/hud/datasets/utils.py @@ -69,7 +69,6 @@ def _validate_task(self) -> SingleTaskRequest: # Neither v4 nor v5 raise ValueError("Task must have 'env' (v5) or 'prompt'+'mcp_config'+'evaluate_tool' (v4)") - @field_validator("job_id") @classmethod @@ -290,5 +289,3 @@ async def cancel_all_jobs() -> dict[str, Any]: ) response.raise_for_status() return response.json() - - diff --git a/hud/environment/environment.py b/hud/environment/environment.py index 74db86d3..85d09972 100644 --- a/hud/environment/environment.py +++ b/hud/environment/environment.py @@ -596,14 +596,12 @@ def to_config(self) -> dict[str, Any]: "prompt": self.prompt, "mcp_config": self._mcp_config, "evaluate_tool": [ - {"name": name, "arguments": args} - for name, args in self._evaluate_calls + {"name": name, "arguments": args} for name, args in self._evaluate_calls ], } if self._setup_calls: config["setup_tool"] = [ - {"name": name, "arguments": args} - for name, args in self._setup_calls + {"name": name, "arguments": args} for name, args in self._setup_calls ] return config diff --git a/hud/eval/task.py b/hud/eval/task.py index 706b95bc..91bbfe44 100644 --- a/hud/eval/task.py +++ 
b/hud/eval/task.py @@ -58,6 +58,7 @@ class TaskAgentConfig(BaseModel): description="Custom system prompt to pass to the agent", ) + logger = logging.getLogger(__name__) @@ -182,9 +183,7 @@ def detect_v4_format(cls, data: Any) -> Any: @field_validator("env", mode="before") @classmethod - def convert_env( - cls, v: Environment | EnvConfig | dict[str, Any] | None - ) -> Environment | None: + def convert_env(cls, v: Environment | EnvConfig | dict[str, Any] | None) -> Environment | None: """Auto-convert dict/EnvConfig to Environment. Format: {"name": "browser", "include": [...], "exclude": [...]} @@ -211,9 +210,7 @@ def convert_env( env = Environment(v.name) env.connect_hub(v.name, include=v.include, exclude=v.exclude) return env - raise TypeError( - f"Task.env must be Environment, EnvConfig, or dict. Got {type(v).__name__}" - ) + raise TypeError(f"Task.env must be Environment, EnvConfig, or dict. Got {type(v).__name__}") @field_validator("validation", mode="before") @classmethod @@ -247,7 +244,8 @@ def serialize_env(self, env: Environment | None) -> dict[str, Any] | None: @model_serializer(mode="wrap") def _serialize_task( - self, handler: Any # SerializerFunctionWrapHandler + self, + handler: Any, # SerializerFunctionWrapHandler ) -> dict[str, Any]: """Custom serializer for v4 format flattening. @@ -266,8 +264,7 @@ def _serialize_task( # Map validation → integration_test_tool if self.validation: result["integration_test_tool"] = [ - {"name": v.name, "arguments": v.arguments or {}} - for v in self.validation + {"name": v.name, "arguments": v.arguments or {}} for v in self.validation ] # Preserve agent_config diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index 856a69d8..57e2b91d 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -2,8 +2,6 @@ from __future__ import annotations -from unittest.mock import AsyncMock, patch - import pytest from hud.eval.task import Task diff --git a/hud/eval/tests/test_task.py b/hud/eval/tests/test_task.py new file mode 100644 index 00000000..335c6754 --- /dev/null +++ b/hud/eval/tests/test_task.py @@ -0,0 +1,143 @@ +"""Tests for hud.eval.task module.""" + +from __future__ import annotations + +import pytest + +from hud.eval.task import Task, TaskAgentConfig + + +class TestTaskSerialization: + """Tests for Task serialization and roundtrip.""" + + def test_v5_task_roundtrip(self) -> None: + """v5 Task serializes and deserializes correctly.""" + task = Task( + env={"name": "browser", "include": ["navigate", "click"]}, + scenario="checkout", + id="task-1", + args={"user_id": "alice"}, + ) + + # Serialize + data = task.model_dump(mode="json") + + # Should have v5 format + assert "env" in data + assert data["env"]["name"] == "browser" + assert data["scenario"] == "checkout" + assert data["id"] == "task-1" + + # Recreate from serialized data + task2 = Task(**data) + + # Serialize again + data2 = task2.model_dump(mode="json") + + # Should be identical + assert data == data2 + + def test_v4_task_roundtrip(self) -> None: + """v4 Task serializes (flattens) and deserializes correctly.""" + v4_dict = { + "prompt": "Go to google.com and search for cats", + "mcp_config": { + "browser": {"url": "http://localhost:8080"}, + }, + "evaluate_tool": {"name": "check_url", "arguments": {"contains": "google"}}, + "setup_tool": {"name": "navigate", "arguments": {"url": "about:blank"}}, + "id": "v4-task-1", + "agent_config": {"system_prompt": "You are a helpful assistant"}, + "metadata": {"category": "navigation"}, + } + + # Create Task 
from v4 dict + task = Task.from_v4(v4_dict) + + # Serialize (should flatten to v4 format) + data = task.model_dump(mode="json") + + # Should have v4 format (flat, not nested env) + assert "prompt" in data + assert "mcp_config" in data + assert "evaluate_tool" in data + assert data["prompt"] == "Go to google.com and search for cats" + assert data["id"] == "v4-task-1" + + # Recreate from serialized data + task2 = Task(**data) + + # Serialize again + data2 = task2.model_dump(mode="json") + + # Should be identical + assert data == data2 + + def test_v4_preserves_agent_config(self) -> None: + """v4 Task preserves agent_config through roundtrip.""" + v4_dict = { + "prompt": "Test prompt", + "mcp_config": {"server": {"url": "http://localhost"}}, + "evaluate_tool": {"name": "check", "arguments": {}}, + "agent_config": {"system_prompt": "Custom system prompt"}, + } + + task = Task.from_v4(v4_dict) + data = task.model_dump(mode="json") + + assert data.get("agent_config") == {"system_prompt": "Custom system prompt"} + + # Roundtrip + task2 = Task(**data) + assert task2.agent_config is not None + assert task2.agent_config.system_prompt == "Custom system prompt" + + def test_v4_preserves_metadata(self) -> None: + """v4 Task preserves metadata through roundtrip.""" + v4_dict = { + "prompt": "Test prompt", + "mcp_config": {"server": {"url": "http://localhost"}}, + "evaluate_tool": {"name": "check", "arguments": {}}, + "metadata": {"key1": "value1", "key2": 42}, + } + + task = Task.from_v4(v4_dict) + data = task.model_dump(mode="json") + + assert data.get("metadata") == {"key1": "value1", "key2": 42} + + # Roundtrip + task2 = Task(**data) + assert task2.metadata == {"key1": "value1", "key2": 42} + + +class TestTaskValidation: + """Tests for Task validation.""" + + def test_v5_allows_none_env(self) -> None: + """v5 Task allows None env (for blank evals).""" + task = Task(scenario="test") # env=None is valid + assert task.env is None + assert task.scenario == "test" + + def test_v4_requires_evaluate_tool(self) -> None: + """v4 Task requires evaluate_tool for validation.""" + from hud.eval.utils import validate_v4_task + + with pytest.raises(ValueError, match="evaluate_tool"): + validate_v4_task({ + "prompt": "test", + "mcp_config": {"server": {}}, + # Missing evaluate_tool + }) + + def test_agent_config_accepts_dict(self) -> None: + """agent_config can be provided as dict and gets converted.""" + task = Task( + env={"name": "browser"}, + agent_config={"system_prompt": "Hello"}, + ) + + assert isinstance(task.agent_config, TaskAgentConfig) + assert task.agent_config.system_prompt == "Hello" + diff --git a/hud/eval/utils.py b/hud/eval/utils.py index b44b1875..b1c0132e 100644 --- a/hud/eval/utils.py +++ b/hud/eval/utils.py @@ -175,4 +175,3 @@ def _warn_local_mcp(mcp_config: dict[str, Any] | None) -> None: UserWarning, stacklevel=4, ) - diff --git a/hud/telemetry/exporter.py b/hud/telemetry/exporter.py index 3001437f..1b6abf08 100644 --- a/hud/telemetry/exporter.py +++ b/hud/telemetry/exporter.py @@ -14,13 +14,10 @@ import logging from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any +from typing import Any from hud.shared import make_request_sync -if TYPE_CHECKING: - pass - logger = logging.getLogger(__name__) # Global singleton thread pool for span exports diff --git a/hud/telemetry/instrument.py b/hud/telemetry/instrument.py index ce45e452..204f11bd 100644 --- a/hud/telemetry/instrument.py +++ b/hud/telemetry/instrument.py @@ -37,6 +37,7 
@@ def _get_trace_id() -> str | None: return get_current_trace_id() + if TYPE_CHECKING: from collections.abc import Awaitable, Callable from typing import ParamSpec diff --git a/hud/telemetry/tests/test_exporter.py b/hud/telemetry/tests/test_exporter.py index 3e74cfea..5231316c 100644 --- a/hud/telemetry/tests/test_exporter.py +++ b/hud/telemetry/tests/test_exporter.py @@ -4,7 +4,7 @@ import asyncio from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 58b013d6..10c90e54 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -8,8 +8,7 @@ import pytest from hud.datasets import run_dataset -from hud.types import LegacyTask -from hud.types import MCPToolCall +from hud.types import LegacyTask, MCPToolCall from hud.utils.tasks import save_tasks diff --git a/hud/tests/test_types.py b/hud/tests/test_types.py index 3c275ae1..127cca5c 100644 --- a/hud/tests/test_types.py +++ b/hud/tests/test_types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from mcp.types import ImageContent, TextContent diff --git a/hud/tools/grounding/grounder.py b/hud/tools/grounding/grounder.py index 31b7b2be..29a073e7 100644 --- a/hud/tools/grounding/grounder.py +++ b/hud/tools/grounding/grounder.py @@ -4,7 +4,6 @@ import base64 import io -import json import logging import re diff --git a/hud/utils/mcp.py b/hud/utils/mcp.py index 859cb5b8..882ac411 100644 --- a/hud/utils/mcp.py +++ b/hud/utils/mcp.py @@ -45,5 +45,3 @@ def patch_mcp_config(mcp_config: dict[str, dict[str, Any]], patch: MCPConfigPatc for key, value in patch.meta.items(): meta = server_cfg.setdefault("meta", {}) meta.setdefault(key, value) - - diff --git a/hud/utils/tests/test_mcp.py b/hud/utils/tests/test_mcp.py index 9be367c7..48b62675 100644 --- a/hud/utils/tests/test_mcp.py +++ b/hud/utils/tests/test_mcp.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from hud.utils.mcp import MCPConfigPatch, patch_mcp_config From 91b9546d5dd0361f144991204529b297a252a662 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 08:00:15 -0800 Subject: [PATCH 52/92] update tests --- hud/datasets/__init__.py | 2 - hud/eval/tests/test_eval.py | 13 ++++-- hud/telemetry/tests/test_exporter.py | 4 ++ hud/tests/test_datasets_extended.py | 27 ------------ hud/utils/tasks.py | 62 ---------------------------- 5 files changed, 13 insertions(+), 95 deletions(-) delete mode 100644 hud/utils/tasks.py diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py index ae2c0869..7198ddc1 100644 --- a/hud/datasets/__init__.py +++ b/hud/datasets/__init__.py @@ -13,7 +13,6 @@ from __future__ import annotations from hud.eval.display import display_results -from hud.utils.tasks import save_tasks from .loader import load_dataset from .runner import run_dataset, run_single_task @@ -30,6 +29,5 @@ "load_dataset", "run_dataset", "run_single_task", - "save_tasks", "submit_rollouts", ] diff --git a/hud/eval/tests/test_eval.py b/hud/eval/tests/test_eval.py index 57e2b91d..6d470808 100644 --- a/hud/eval/tests/test_eval.py +++ b/hud/eval/tests/test_eval.py @@ -111,6 +111,7 @@ def test_from_v4_with_legacy_task(self) -> None: legacy = LegacyTask( prompt="Navigate to google.com", mcp_config={"hud": {"url": "https://mcp.hud.ai"}}, + evaluate_tool={"name": "check", "arguments": {}}, ) task = 
Task.from_v4(legacy) @@ -126,6 +127,7 @@ def test_from_v4_with_dict(self) -> None: { "prompt": "Navigate to google.com", "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + "evaluate_tool": {"name": "check", "arguments": {}}, } ) @@ -140,6 +142,7 @@ def test_from_v4_with_json_string(self) -> None: data = { "prompt": "Navigate to google.com", "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, + "evaluate_tool": {"name": "check", "arguments": {}}, } task = Task.from_v4(json.dumps(data)) @@ -154,6 +157,7 @@ def test_from_v4_with_setup_tool(self) -> None: "prompt": "Check URL", "mcp_config": {"hud": {"url": "https://mcp.hud.ai"}}, "setup_tool": {"name": "navigate", "arguments": {"url": "https://google.com"}}, + "evaluate_tool": {"name": "check", "arguments": {}}, } ) @@ -177,14 +181,14 @@ def test_from_v4_with_evaluate_tool(self) -> None: def test_from_v4_with_invalid_type_raises(self) -> None: """Task.from_v4() raises TypeError for invalid input.""" - with pytest.raises(TypeError, match="expects LegacyTask, dict, or JSON string"): + with pytest.raises(TypeError): Task.from_v4(12345) # type: ignore[arg-type] def test_from_v4_with_invalid_json_raises(self) -> None: - """Task.from_v4() raises HudConfigError for invalid JSON.""" - from hud.shared.exceptions import HudConfigError + """Task.from_v4() raises JSONDecodeError for invalid JSON.""" + import json - with pytest.raises(HudConfigError, match="Invalid JSON string"): + with pytest.raises(json.JSONDecodeError): Task.from_v4("not valid json") def test_from_v4_does_not_warn_on_use(self) -> None: @@ -197,6 +201,7 @@ def test_from_v4_does_not_warn_on_use(self) -> None: { "prompt": "test", "mcp_config": {"hud": {}}, + "evaluate_tool": {"name": "check", "arguments": {}}, } ) diff --git a/hud/telemetry/tests/test_exporter.py b/hud/telemetry/tests/test_exporter.py index 5231316c..42c8499a 100644 --- a/hud/telemetry/tests/test_exporter.py +++ b/hud/telemetry/tests/test_exporter.py @@ -227,6 +227,9 @@ class TestShutdown: def test_shutdown_flushes_pending(self): """Test that shutdown flushes pending spans.""" + # Clear any leftover state from previous tests + _pending_spans.clear() + uploaded: list[str] = [] def mock_upload( @@ -241,6 +244,7 @@ def mock_upload( with ( patch("hud.settings.settings") as mock_settings, patch("hud.telemetry.exporter._do_upload", side_effect=mock_upload), + patch("hud.telemetry.exporter._get_api_key", return_value="test-key"), ): mock_settings.api_key = "test-key" mock_settings.telemetry_enabled = True diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 10c90e54..8a4ea7d8 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -9,7 +9,6 @@ from hud.datasets import run_dataset from hud.types import LegacyTask, MCPToolCall -from hud.utils.tasks import save_tasks class TestTaskExtended: @@ -122,32 +121,6 @@ def test_non_string_values_preserved(self): assert task.mcp_config["nested"]["dict"]["num"] == 123 -class TestDatasetOperations: - """Test dataset conversion and operations.""" - - def test_save_taskconfigs_empty_list(self): - """Test saving empty task list.""" - with patch("datasets.Dataset") as MockDataset: - mock_instance = MagicMock() - MockDataset.from_list.return_value = mock_instance - mock_instance.push_to_hub.return_value = None - - save_tasks([], "test-org/empty-dataset") - - MockDataset.from_list.assert_called_once_with([]) - mock_instance.push_to_hub.assert_called_once_with("test-org/empty-dataset") - - def 
test_save_taskconfigs_mixed_rejection(self): - """Test that mixing dicts and LegacyTask objects is rejected.""" - valid_dict = {"prompt": "Dict task", "mcp_config": {"test": True}} - - task_object = LegacyTask(prompt="Object task", mcp_config={"resolved": "${SOME_VAR}"}) - - # First item is dict, second is object - with pytest.raises(ValueError, match="Item 1 is a LegacyTask object"): - save_tasks([valid_dict, task_object], "test-org/mixed") # type: ignore - - class TestRunDatasetExtended: """Extended tests for run_dataset functionality.""" diff --git a/hud/utils/tasks.py b/hud/utils/tasks.py deleted file mode 100644 index ca4f4fab..00000000 --- a/hud/utils/tasks.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -import json -from typing import Any - -from hud.types import LegacyTask - - -def save_tasks( - tasks: list[dict[str, Any]], - repo_id: str, - fields: list[str] | None = None, - **kwargs: Any, -) -> None: - """Save data to a HuggingFace dataset with JSON string serialization. - - Complex fields (dicts, lists) are serialized as JSON strings to keep schemas clean - and avoid null-value pollution when uploaded to the Hub. - - Args: - tasks: List of dictionaries to save. - repo_id: HuggingFace repository ID (e.g., "hud-evals/my-tasks"). - fields: Optional subset of fields to persist. Defaults to all keys per task. - **kwargs: Extra kwargs forwarded to `Dataset.push_to_hub`. - """ - if tasks and isinstance(tasks[0], LegacyTask): - raise ValueError( - "save_tasks expects dictionaries, not LegacyTask objects. " - "LegacyTask objects have resolved environment variables which would expose secrets. " - "Please pass raw dictionaries with template strings like '${HUD_API_KEY}' preserved." - ) - - data: list[dict[str, Any]] = [] - for index, task_dict in enumerate(tasks): - if isinstance(task_dict, LegacyTask): - raise ValueError( - f"Item {index} is a LegacyTask object, not a dictionary. " - "This would expose resolved environment variables. " - "Please convert to dictionary format with template strings preserved." 
- ) - - row: dict[str, Any] = {} - fields_to_process = fields if fields is not None else list(task_dict.keys()) - - for field in fields_to_process: - if field not in task_dict: - continue - - value = task_dict[field] - if isinstance(value, (dict | list)): - row[field] = json.dumps(value) - elif isinstance(value, (str | int | float | bool | type(None))): - row[field] = value if value is not None else "" - else: - row[field] = str(value) - - data.append(row) - - from datasets import Dataset - - ds = Dataset.from_list(data) - ds.push_to_hub(repo_id, **kwargs) From c87380b4dd8d061fa09c84996f969be6d600745f Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 08:06:40 -0800 Subject: [PATCH 53/92] tests --- hud/agents/tests/test_claude.py | 3 ++- hud/agents/tests/test_openai.py | 6 ++++-- hud/agents/tests/test_operator.py | 6 ++++-- hud/cli/dev.py | 4 +--- hud/eval/manager.py | 4 ++++ hud/eval/tests/test_task.py | 15 ++++++++------- hud/telemetry/tests/test_eval_telemetry.py | 8 +++++--- hud/telemetry/tests/test_exporter.py | 10 +++++----- hud/tests/test_datasets_extended.py | 2 +- 9 files changed, 34 insertions(+), 24 deletions(-) diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py index 2b047438..4a89053e 100644 --- a/hud/agents/tests/test_claude.py +++ b/hud/agents/tests/test_claude.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections.abc import Generator from typing import TYPE_CHECKING, Any, cast from unittest.mock import AsyncMock, MagicMock, patch @@ -20,6 +19,8 @@ from hud.types import MCPToolCall, MCPToolResult if TYPE_CHECKING: + from collections.abc import Generator + from anthropic.types.beta import BetaImageBlockParam, BetaMessageParam, BetaTextBlockParam diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py index d65acf18..15f9cffc 100644 --- a/hud/agents/tests/test_openai.py +++ b/hud/agents/tests/test_openai.py @@ -2,8 +2,7 @@ from __future__ import annotations -from collections.abc import Generator -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from unittest.mock import AsyncMock, patch import pytest @@ -21,6 +20,9 @@ from hud.eval.context import EvalContext from hud.types import MCPToolCall, MCPToolResult +if TYPE_CHECKING: + from collections.abc import Generator + class MockEvalContext(EvalContext): """Mock EvalContext for testing.""" diff --git a/hud/agents/tests/test_operator.py b/hud/agents/tests/test_operator.py index c4e79cd0..d1995f14 100644 --- a/hud/agents/tests/test_operator.py +++ b/hud/agents/tests/test_operator.py @@ -2,8 +2,7 @@ from __future__ import annotations -from collections.abc import Generator -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -15,6 +14,9 @@ from hud.eval.context import EvalContext from hud.types import MCPToolCall, MCPToolResult +if TYPE_CHECKING: + from collections.abc import Generator + class MockEvalContext(EvalContext): """Mock EvalContext for testing.""" diff --git a/hud/cli/dev.py b/hud/cli/dev.py index 9a91db52..8555906e 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -107,9 +107,7 @@ def _has_mcp_or_env(content: str) -> bool: if "mcp" in content and ("= MCPServer" in content or "= FastMCP" in content): return True # Check for env = Environment(...) 
- if "env" in content and "= Environment" in content: - return True - return False + return "env" in content and "= Environment" in content def auto_detect_module() -> tuple[str, Path | None] | tuple[None, None]: diff --git a/hud/eval/manager.py b/hud/eval/manager.py index 4cb893f7..ab0fa91e 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -427,6 +427,10 @@ async def run_one(config: tuple[Task | None, dict[str, Any]]) -> EvalContext: else: ctx = EvalContext(name="eval", **params) + # Remove sensitive data from params after context creation to prevent + # accidental logging if an exception includes local variables + params.pop("api_key", None) + try: if sem: async with sem, ctx: diff --git a/hud/eval/tests/test_task.py b/hud/eval/tests/test_task.py index 335c6754..e1027b57 100644 --- a/hud/eval/tests/test_task.py +++ b/hud/eval/tests/test_task.py @@ -125,11 +125,13 @@ def test_v4_requires_evaluate_tool(self) -> None: from hud.eval.utils import validate_v4_task with pytest.raises(ValueError, match="evaluate_tool"): - validate_v4_task({ - "prompt": "test", - "mcp_config": {"server": {}}, - # Missing evaluate_tool - }) + validate_v4_task( + { + "prompt": "test", + "mcp_config": {"server": {}}, + # Missing evaluate_tool + } + ) def test_agent_config_accepts_dict(self) -> None: """agent_config can be provided as dict and gets converted.""" @@ -137,7 +139,6 @@ def test_agent_config_accepts_dict(self) -> None: env={"name": "browser"}, agent_config={"system_prompt": "Hello"}, ) - + assert isinstance(task.agent_config, TaskAgentConfig) assert task.agent_config.system_prompt == "Hello" - diff --git a/hud/telemetry/tests/test_eval_telemetry.py b/hud/telemetry/tests/test_eval_telemetry.py index 15a8760d..8849cd13 100644 --- a/hud/telemetry/tests/test_eval_telemetry.py +++ b/hud/telemetry/tests/test_eval_telemetry.py @@ -11,15 +11,17 @@ import hud from hud.environment import Environment from hud.eval import Task -from hud.telemetry.exporter import _pending_spans +from hud.telemetry.exporter import _pending_futures, _pending_spans @pytest.fixture(autouse=True) -def clear_pending_spans(): - """Clear pending spans before and after each test.""" +def clear_pending_state(): + """Clear pending spans and futures before and after each test.""" _pending_spans.clear() + _pending_futures.clear() yield _pending_spans.clear() + _pending_futures.clear() class TestEvalContextTelemetry: diff --git a/hud/telemetry/tests/test_exporter.py b/hud/telemetry/tests/test_exporter.py index 42c8499a..16c712d7 100644 --- a/hud/telemetry/tests/test_exporter.py +++ b/hud/telemetry/tests/test_exporter.py @@ -10,6 +10,7 @@ from hud.telemetry.exporter import ( _do_upload, + _pending_futures, _pending_spans, flush, queue_span, @@ -18,11 +19,13 @@ @pytest.fixture(autouse=True) -def clear_pending_spans(): - """Clear pending spans before and after each test.""" +def clear_pending_state(): + """Clear pending spans and futures before and after each test.""" _pending_spans.clear() + _pending_futures.clear() yield _pending_spans.clear() + _pending_futures.clear() class TestDoUpload: @@ -227,9 +230,6 @@ class TestShutdown: def test_shutdown_flushes_pending(self): """Test that shutdown flushes pending spans.""" - # Clear any leftover state from previous tests - _pending_spans.clear() - uploaded: list[str] = [] def mock_upload( diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 8a4ea7d8..b9bc9ef9 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ 
-3,7 +3,7 @@ from __future__ import annotations from typing import cast -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest From d6cde1f9895ae8d8082f61c71369ab397c3a6da4 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 08:38:42 -0800 Subject: [PATCH 54/92] naming changes --- examples/run_evaluation.py | 6 +- hud/agents/base.py | 2 +- hud/agents/claude.py | 6 +- hud/agents/gemini.py | 4 +- hud/agents/gemini_cua.py | 2 +- hud/agents/grounded_openai.py | 4 +- hud/agents/openai.py | 6 +- hud/agents/openai_chat.py | 4 +- hud/agents/operator.py | 2 +- hud/cli/eval.py | 4 +- hud/cli/flows/tasks.py | 4 +- hud/cli/rft.py | 4 +- hud/cli/tests/test_eval.py | 78 ++++++----- hud/datasets/__init__.py | 11 +- hud/datasets/loader.py | 128 +++++++++++++++--- hud/datasets/runner.py | 12 +- hud/datasets/tests/test_loader.py | 54 ++++---- hud/eval/manager.py | 16 +-- hud/eval/task.py | 3 +- hud/eval/tests/test_task.py | 1 + hud/tests/test_datasets_extended.py | 72 +++++----- hud/tests/test_init_module.py | 1 + .../grounding/tests/test_grounded_tool.py | 70 ++++------ hud/types.py | 11 +- 24 files changed, 307 insertions(+), 198 deletions(-) diff --git a/examples/run_evaluation.py b/examples/run_evaluation.py index 855e977b..d6f0d871 100644 --- a/examples/run_evaluation.py +++ b/examples/run_evaluation.py @@ -28,11 +28,11 @@ async def main() -> None: args = parser.parse_args() # Import here to avoid import errors if agents not installed - from hud.datasets import load_dataset, run_dataset + from hud.datasets import load_tasks, run_dataset - # Load dataset as Task objects + # Load tasks from file or API print(f"Loading {args.dataset}...") - tasks = load_dataset(args.dataset) + tasks = load_tasks(args.dataset) # Filter by index if specified if args.task_ids: diff --git a/hud/agents/base.py b/hud/agents/base.py index 831c59ce..5bb2ef54 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -78,7 +78,7 @@ def __init__(self, params: BaseCreateParams | None = None, **kwargs: Any) -> Non self.ctx: EvalContext | Environment | None = params.ctx self.model_name: str = getattr(params, "model_name", "MCPAgent") - self.checkpoint_name: str = getattr(params, "checkpoint_name", "unknown") + self.model: str = getattr(params, "model", None) or "unknown" self.auto_respond = params.auto_respond self.console = HUDConsole(logger=logger) diff --git a/hud/agents/claude.py b/hud/agents/claude.py index c24c754e..114eaf74 100644 --- a/hud/agents/claude.py +++ b/hud/agents/claude.py @@ -41,7 +41,7 @@ class ClaudeConfig(BaseAgentConfig): model_config = ConfigDict(arbitrary_types_allowed=True) model_name: str = "Claude" - checkpoint_name: str = "claude-sonnet-4-5" + model: str = "claude-sonnet-4-5" model_client: AsyncAnthropic | AsyncAnthropicBedrock | None = None max_tokens: int = 16384 use_computer_beta: bool = True @@ -153,7 +153,7 @@ async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse: if isinstance(self.anthropic_client, AsyncAnthropicBedrock): try: response = await self.anthropic_client.beta.messages.create( - model=self.config.checkpoint_name, + model=self.config.model, system=self.system_prompt if self.system_prompt is not None else Omit(), max_tokens=self.max_tokens, messages=messages_cached, @@ -169,7 +169,7 @@ async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse: else: # Regular Anthropic client supports .stream() async with self.anthropic_client.beta.messages.stream( - 
model=self.config.checkpoint_name, + model=self.config.model, system=self.system_prompt if self.system_prompt is not None else Omit(), max_tokens=self.max_tokens, messages=messages_cached, diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py index c405f05a..88eaa3ef 100644 --- a/hud/agents/gemini.py +++ b/hud/agents/gemini.py @@ -27,7 +27,7 @@ class GeminiConfig(BaseAgentConfig): model_config = ConfigDict(arbitrary_types_allowed=True) model_name: str = "Gemini" - checkpoint_name: str = "gemini-3-pro-preview" + model: str = "gemini-3-pro-preview" model_client: genai.Client | None = None temperature: float = 1.0 top_p: float = 0.95 @@ -135,7 +135,7 @@ async def get_response(self, messages: list[genai_types.Content]) -> AgentRespon # Make API call response = self.gemini_client.models.generate_content( - model=self.config.checkpoint_name, + model=self.config.model, contents=cast("Any", messages), config=generate_config, ) diff --git a/hud/agents/gemini_cua.py b/hud/agents/gemini_cua.py index 75d8da15..491fea45 100644 --- a/hud/agents/gemini_cua.py +++ b/hud/agents/gemini_cua.py @@ -62,7 +62,7 @@ class GeminiCUAConfig(GeminiConfig): model_config = ConfigDict(arbitrary_types_allowed=True) model_name: str = "GeminiCUA" - checkpoint_name: str = "gemini-2.5-computer-use-preview" + model: str = "gemini-2.5-computer-use-preview" excluded_predefined_functions: list[str] = Field(default_factory=list) diff --git a/hud/agents/grounded_openai.py b/hud/agents/grounded_openai.py index 7e147ddf..e86cb3de 100644 --- a/hud/agents/grounded_openai.py +++ b/hud/agents/grounded_openai.py @@ -38,7 +38,7 @@ class GroundedOpenAIConfig(OpenAIChatConfig): model_config = ConfigDict(arbitrary_types_allowed=True) grounder_config: GrounderConfig - checkpoint_name: str = "gpt-4o-mini" + model: str = "gpt-4o-mini" allowed_tools: list[str] | None = None # Default set in validator append_setup_output: bool = False system_prompt: str | None = DEFAULT_GROUNDED_PROMPT @@ -164,7 +164,7 @@ async def get_response(self, messages: Any) -> AgentResponse: extra = {k: v for k, v in (self.completion_kwargs or {}).items() if k not in protected_keys} response = await self.oai.chat.completions.create( # type: ignore - model=self.config.checkpoint_name, + model=self.config.model, messages=messages, tools=tool_schemas, parallel_tool_calls=False, diff --git a/hud/agents/openai.py b/hud/agents/openai.py index 10b2ad12..c4e4c04e 100644 --- a/hud/agents/openai.py +++ b/hud/agents/openai.py @@ -48,7 +48,7 @@ class OpenAIConfig(BaseAgentConfig): model_config = ConfigDict(arbitrary_types_allowed=True) model_name: str = "OpenAI" - checkpoint_name: str = "gpt-5.1" + model: str = "gpt-5.1" model_client: AsyncOpenAI | None = None max_output_tokens: int | None = None temperature: float | None = None @@ -92,7 +92,7 @@ def __init__(self, params: OpenAICreateParams | None = None, **kwargs: Any) -> N raise ValueError(f"OpenAI API key is invalid: {exc}") from exc self.openai_client = model_client - self.model = self.config.checkpoint_name + self._model = self.config.model self.max_output_tokens = self.config.max_output_tokens self.temperature = self.config.temperature self.reasoning = self.config.reasoning @@ -240,7 +240,7 @@ async def get_response(self, messages: ResponseInputParam) -> AgentResponse: return AgentResponse(content="", tool_calls=[], done=True) response = await self.openai_client.responses.create( - model=self.model, + model=self._model, input=new_items, instructions=self.system_prompt, max_output_tokens=self.max_output_tokens, diff --git 
a/hud/agents/openai_chat.py b/hud/agents/openai_chat.py index c5eeb14d..f041e4b2 100644 --- a/hud/agents/openai_chat.py +++ b/hud/agents/openai_chat.py @@ -45,7 +45,7 @@ class OpenAIChatConfig(BaseAgentConfig): model_config = ConfigDict(arbitrary_types_allowed=True) model_name: str = "OpenAI Chat" - checkpoint_name: str = "gpt-5-mini" + model: str = "gpt-5-mini" openai_client: AsyncOpenAI | None = None api_key: str | None = None base_url: str | None = None @@ -221,7 +221,7 @@ async def _invoke_chat_completion( raise ValueError("openai_client is required for OpenAIChatAgent") # default transport = OpenAI SDK return await self.oai.chat.completions.create( - model=self.config.checkpoint_name, + model=self.config.model, messages=messages, tools=tools, # type: ignore ready ChatCompletionToolParam-shaped **extra, diff --git a/hud/agents/operator.py b/hud/agents/operator.py index d9def5d9..d9e75f7d 100644 --- a/hud/agents/operator.py +++ b/hud/agents/operator.py @@ -56,7 +56,7 @@ class OperatorConfig(OpenAIConfig): model_config = ConfigDict(arbitrary_types_allowed=True) model_name: str = "Operator" - checkpoint_name: str = "computer-use-preview" + model: str = "computer-use-preview" environment: Literal["windows", "mac", "linux", "ubuntu", "browser"] = "linux" diff --git a/hud/cli/eval.py b/hud/cli/eval.py index 09fa3811..406c5b57 100644 --- a/hud/cli/eval.py +++ b/hud/cli/eval.py @@ -559,14 +559,14 @@ def display(self) -> None: async def _run_evaluation(cfg: EvalConfig) -> tuple[list[Any], list[Any]]: """Run evaluation with the given config using run_dataset().""" - from hud.datasets import load_dataset, run_dataset + from hud.datasets import load_tasks, run_dataset if cfg.source is None or cfg.agent_type is None: raise ValueError("source and agent_type must be set") # Load tasks using unified loader (handles v4→v5 conversion automatically) hud_console.info(f"📊 Loading tasks from: {cfg.source}…") - tasks = load_dataset(cfg.source) + tasks = load_tasks(cfg.source) if not tasks: hud_console.error(f"No tasks found in: {cfg.source}") diff --git a/hud/cli/flows/tasks.py b/hud/cli/flows/tasks.py index 9563c6e7..5cba7d86 100644 --- a/hud/cli/flows/tasks.py +++ b/hud/cli/flows/tasks.py @@ -13,7 +13,7 @@ from hud.cli.utils.docker import require_docker_running from hud.cli.utils.env_check import find_environment_dir from hud.cli.utils.registry import extract_name_and_tag -from hud.datasets import load_dataset +from hud.datasets import load_tasks from hud.utils.hud_console import hud_console logger = logging.getLogger(__name__) @@ -266,7 +266,7 @@ def convert_tasks_to_remote(tasks_file: str) -> str: # Load raw tasks - we work with dicts directly to preserve placeholders # when writing back to disk (e.g., ${HUD_API_KEY}) - raw_tasks: list[dict[str, Any]] = load_dataset(str(tasks_path), raw=True) # type: ignore[assignment] + raw_tasks: list[dict[str, Any]] = load_tasks(str(tasks_path), raw=True) # type: ignore[assignment] # Use the same raw tasks for validation (they have mcp_config structure) tasks = raw_tasks diff --git a/hud/cli/rft.py b/hud/cli/rft.py index eccdb9d1..1c005726 100644 --- a/hud/cli/rft.py +++ b/hud/cli/rft.py @@ -8,7 +8,7 @@ from rich.console import Console from rich.table import Table -from hud.datasets import load_dataset +from hud.datasets import load_tasks from hud.settings import settings from hud.utils.hud_console import HUDConsole @@ -193,7 +193,7 @@ def rft_command( # Load and validate tasks try: # Load tasks as raw dicts for patching and serialization - tasks: list[dict[str, Any]] 
= load_dataset(tasks_file, raw=True) # type: ignore[assignment] + tasks: list[dict[str, Any]] = load_tasks(tasks_file, raw=True) # type: ignore[assignment] if not tasks: hud_console.error(f"No tasks found in {tasks_file}") raise typer.Exit(1) diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index db89ad60..6ee1a9a4 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -9,7 +9,7 @@ from mcp import types from hud.eval.context import EvalContext -from hud.types import MCPToolResult +from hud.types import AgentType, MCPToolResult, Trace class MockEvalContext(EvalContext): @@ -39,17 +39,13 @@ async def submit(self, answer: str) -> None: self._submitted = answer -class MockAgent: - """Mock agent for testing run_dataset.""" - - def __init__(self) -> None: - self.run_count = 0 - - async def run(self, ctx: EvalContext, *, max_steps: int = 10) -> Any: - self.run_count += 1 - ctx.reward = 1.0 - # Return a mock Trace-like object - return MagicMock(reward=1.0, done=True, content="Done") +def _create_mock_agent_cls() -> tuple[MagicMock, MagicMock]: + """Create a mock agent class and instance for testing.""" + mock_agent_instance = MagicMock() + mock_agent_instance.run = AsyncMock(return_value=Trace(reward=1.0, done=True)) + mock_agent_cls = MagicMock() + mock_agent_cls.create.return_value = mock_agent_instance + return mock_agent_cls, mock_agent_instance class TestRunDataset: @@ -64,19 +60,22 @@ async def test_run_dataset_with_task_list(self) -> None: Task(env={"name": "test"}, id="task1", scenario="test"), Task(env={"name": "test"}, id="task2", scenario="test"), ] - agent = MockAgent() + mock_agent_cls, mock_agent_instance = _create_mock_agent_cls() # Mock hud.eval to return our mock context mock_ctx = MockEvalContext() - with patch("hud.datasets.runner.hud.eval") as mock_eval: + with ( + patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + ): # Set up the async context manager mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) from hud.datasets.runner import run_dataset - await run_dataset(tasks, agent, max_steps=5) # type: ignore[arg-type] + await run_dataset(tasks, agent_type="claude", max_steps=5) # Verify hud.eval was called with correct params mock_eval.assert_called_once() @@ -85,7 +84,7 @@ async def test_run_dataset_with_task_list(self) -> None: assert call_kwargs["max_concurrent"] == 30 # Agent should have run - assert agent.run_count == 1 + mock_agent_instance.run.assert_called_once() @pytest.mark.asyncio async def test_run_dataset_with_string_source(self) -> None: @@ -93,19 +92,20 @@ async def test_run_dataset_with_string_source(self) -> None: from hud.eval.task import Task mock_tasks = [Task(env={"name": "test"}, id="loaded_task", scenario="loaded")] - agent = MockAgent() + mock_agent_cls, _ = _create_mock_agent_cls() mock_ctx = MockEvalContext() with ( - patch("hud.datasets.loader.load_dataset", return_value=mock_tasks) as mock_load, + patch("hud.datasets.runner.load_tasks", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.OPENAI, "cls", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) from hud.datasets.runner import run_dataset - await run_dataset("my-tasks.json", agent) # type: ignore[arg-type] + await run_dataset("my-tasks.json", 
agent_type="openai") # Verify load_dataset was called mock_load.assert_called_once_with("my-tasks.json") @@ -113,13 +113,11 @@ async def test_run_dataset_with_string_source(self) -> None: @pytest.mark.asyncio async def test_run_dataset_empty_tasks_raises(self) -> None: """Test run_dataset raises ValueError for empty tasks.""" - agent = MockAgent() - with patch("hud.datasets.loader.load_dataset", return_value=[]): from hud.datasets.runner import run_dataset with pytest.raises(ValueError, match="No tasks to run"): - await run_dataset([], agent) # type: ignore[arg-type] + await run_dataset([], agent_type=AgentType.CLAUDE) @pytest.mark.asyncio async def test_run_dataset_with_group_size(self) -> None: @@ -127,16 +125,19 @@ async def test_run_dataset_with_group_size(self) -> None: from hud.eval.task import Task tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] - agent = MockAgent() + mock_agent_cls, _ = _create_mock_agent_cls() mock_ctx = MockEvalContext() - with patch("hud.datasets.runner.hud.eval") as mock_eval: + with ( + patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) from hud.datasets.runner import run_dataset - await run_dataset(tasks, agent, group_size=3) # type: ignore[arg-type] + await run_dataset(tasks, agent_type="claude", group_size=3) call_kwargs = mock_eval.call_args[1] assert call_kwargs["group"] == 3 @@ -147,16 +148,19 @@ async def test_run_dataset_with_max_concurrent(self) -> None: from hud.eval.task import Task tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] - agent = MockAgent() + mock_agent_cls, _ = _create_mock_agent_cls() mock_ctx = MockEvalContext() - with patch("hud.datasets.runner.hud.eval") as mock_eval: + with ( + patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) from hud.datasets.runner import run_dataset - await run_dataset(tasks, agent, max_concurrent=10) # type: ignore[arg-type] + await run_dataset(tasks, agent_type="claude", max_concurrent=10) call_kwargs = mock_eval.call_args[1] assert call_kwargs["max_concurrent"] == 10 @@ -167,16 +171,19 @@ async def test_run_dataset_returns_results(self) -> None: from hud.eval.task import Task tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] - agent = MockAgent() + mock_agent_cls, _ = _create_mock_agent_cls() mock_ctx = MockEvalContext() - with patch("hud.datasets.runner.hud.eval") as mock_eval: + with ( + patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) from hud.datasets.runner import run_dataset - results = await run_dataset(tasks, agent) # type: ignore[arg-type] + results = await run_dataset(tasks, agent_type="claude") # Should return list with the context assert len(results) == 1 @@ -188,7 +195,7 @@ async def test_run_dataset_parallel_results(self) -> None: from hud.eval.task import Task tasks = [Task(env={"name": "test"}, id="task1", scenario="test")] - agent = MockAgent() + mock_agent_cls, _ = _create_mock_agent_cls() # Create mock context with results (parallel execution) mock_result1 = 
MockEvalContext(prompt="result1") @@ -199,13 +206,16 @@ async def test_run_dataset_parallel_results(self) -> None: mock_ctx = MockEvalContext() mock_ctx.results = [mock_result1, mock_result2] - with patch("hud.datasets.runner.hud.eval") as mock_eval: + with ( + patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) from hud.datasets.runner import run_dataset - results = await run_dataset(tasks, agent) # type: ignore[arg-type] + results = await run_dataset(tasks, agent_type="claude") # Should return the parallel results assert len(results) == 2 diff --git a/hud/datasets/__init__.py b/hud/datasets/__init__.py index 7198ddc1..6bf88851 100644 --- a/hud/datasets/__init__.py +++ b/hud/datasets/__init__.py @@ -1,9 +1,10 @@ """HUD datasets module. -Provides unified dataset loading and execution for HUD evaluations. +Provides unified task loading, saving, and execution for HUD evaluations. Key functions: -- load_dataset(): Load tasks from JSON, JSONL, HuggingFace, or HUD API +- load_tasks(): Load tasks from JSON, JSONL, HuggingFace, or HUD API +- save_tasks(): Save tasks to the HUD API - run_dataset(): Run an agent on a dataset of tasks - submit_rollouts(): Submit tasks for remote execution @@ -14,7 +15,7 @@ from hud.eval.display import display_results -from .loader import load_dataset +from .loader import load_dataset, load_tasks, save_tasks from .runner import run_dataset, run_single_task from .utils import ( BatchRequest, @@ -26,8 +27,10 @@ "BatchRequest", "SingleTaskRequest", "display_results", - "load_dataset", + "load_dataset", # Deprecated alias + "load_tasks", "run_dataset", "run_single_task", + "save_tasks", "submit_rollouts", ] diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index d5554e25..ba2d348f 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -1,6 +1,6 @@ -"""Dataset loading utilities for HUD. +"""Task loading utilities for HUD. -Unified interface for loading evaluation datasets from: +Unified interface for loading evaluation tasks from: - HUD API (v5 format) - Local JSON/JSONL files (v4 LegacyTask format, auto-converted) - HuggingFace datasets (v4 LegacyTask format, auto-converted) @@ -10,6 +10,7 @@ import json import logging +import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, overload @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) -__all__ = ["load_dataset"] +__all__ = ["load_dataset", "load_tasks", "save_tasks"] def _load_raw_from_file(path: Path) -> list[dict[str, Any]]: @@ -141,15 +142,15 @@ def _load_from_api(dataset_name: str) -> list[Task]: @overload -def load_dataset(source: str, *, raw: bool = False) -> list[Task]: ... +def load_tasks(source: str, *, raw: bool = False) -> list[Task]: ... @overload -def load_dataset(source: str, *, raw: bool = True) -> list[dict[str, Any]]: ... +def load_tasks(source: str, *, raw: bool = True) -> list[dict[str, Any]]: ... -def load_dataset(source: str, *, raw: bool = False) -> list[Task] | list[dict[str, Any]]: - """Load tasks from a dataset source. +def load_tasks(source: str, *, raw: bool = False) -> list[Task] | list[dict[str, Any]]: + """Load tasks from a source. 
Supports multiple sources with auto-detection: - Local file path (JSON or JSONL) @@ -159,7 +160,7 @@ def load_dataset(source: str, *, raw: bool = False) -> list[Task] | list[dict[st Automatically detects and converts v4 LegacyTask format to v5 Task. Args: - source: Dataset source. Can be: + source: Task source. Can be: - Path to a local JSON/JSONL file - HUD API dataset slug (e.g., "hud-evals/SheetBench-50") - HuggingFace dataset name (e.g., "hud-evals/tasks" or "hud-evals/tasks:train") @@ -173,19 +174,19 @@ def load_dataset(source: str, *, raw: bool = False) -> list[Task] | list[dict[st Example: ```python import hud - from hud.datasets import load_dataset + from hud.datasets import load_tasks # Load from HUD API - tasks = load_dataset("hud-evals/SheetBench-50") + tasks = load_tasks("hud-evals/SheetBench-50") # Load from local file (v4 format auto-converted) - tasks = load_dataset("./my-tasks.json") + tasks = load_tasks("./my-tasks.json") # Load from HuggingFace - tasks = load_dataset("hud-evals/benchmark:test") + tasks = load_tasks("hud-evals/benchmark:test") # Load raw dicts (preserves env var placeholders) - raw_tasks = load_dataset("./tasks.json", raw=True) + raw_tasks = load_tasks("./tasks.json", raw=True) # Run evaluation async with hud.eval(tasks) as ctx: @@ -193,7 +194,7 @@ def load_dataset(source: str, *, raw: bool = False) -> list[Task] | list[dict[st ``` Raises: - ValueError: If dataset loading fails + ValueError: If task loading fails """ # Check if it's a local file path = Path(source) @@ -220,8 +221,103 @@ def load_dataset(source: str, *, raw: bool = False) -> list[Task] | list[dict[st return items except ImportError: raise ValueError( - f"Failed to load dataset '{source}'. " + f"Failed to load tasks from '{source}'. " "Install 'datasets' package for HuggingFace support." ) from None except Exception as hf_error: - raise ValueError(f"Failed to load dataset '{source}': {hf_error}") from hf_error + raise ValueError(f"Failed to load tasks from '{source}': {hf_error}") from hf_error + + +def save_tasks( + name: str, + tasks: list[Task], + *, + description: str | None = None, +) -> str: + """Save tasks to the HUD API. + + Creates or updates an evalset with the given tasks. + + Args: + name: Evalset name/slug (e.g., "my-evals/benchmark-v1"). + If no org prefix, uses user's default org. + tasks: List of Task objects (v5 format) to save. + description: Optional description for the evalset. + + Returns: + The evalset ID of the created/updated evalset. 
+ + Example: + ```python + from hud.datasets import save_tasks, load_tasks + from hud.eval.task import Task + from hud.environment import Environment + + # Create tasks + env = Environment("my-env") + tasks = [ + Task(env=env, scenario="checkout", args={"user": "alice"}), + Task(env=env, scenario="checkout", args={"user": "bob"}), + ] + + # Save to HUD API + evalset_id = save_tasks("my-evals/benchmark-v1", tasks) + + # Later, load them back + loaded = load_tasks("my-evals/benchmark-v1") + ``` + + Raises: + ValueError: If API key is not set or save fails + """ + import httpx + + from hud.settings import settings + + if not settings.api_key: + raise ValueError("HUD_API_KEY is required to save tasks") + + # Convert tasks to dicts (Task is a Pydantic model) + task_dicts = [task.model_dump(mode="json", exclude_none=True) for task in tasks] + + # Build request payload + payload: dict[str, Any] = { + "name": name, + "tasks": task_dicts, + } + if description: + payload["description"] = description + + headers = {"Authorization": f"Bearer {settings.api_key}"} + + try: + with httpx.Client(timeout=60) as client: + response = client.post( + f"{settings.hud_api_url}/tasks/evalset", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + evalset_id = data.get("evalset_id") or data.get("id") or name + logger.info("Saved %d tasks to evalset: %s", len(tasks), evalset_id) + return evalset_id + except httpx.HTTPStatusError as e: + raise ValueError(f"Failed to save tasks: {e.response.text}") from e + except Exception as e: + raise ValueError(f"Failed to save tasks: {e}") from e + + +# Deprecated alias for backwards compatibility +def load_dataset(source: str, *, raw: bool = False) -> list[Task] | list[dict[str, Any]]: + """Deprecated: Use load_tasks() instead. + + .. deprecated:: 0.6.0 + load_dataset() is deprecated. Use load_tasks() instead. + """ + warnings.warn( + "load_dataset() is deprecated. Use load_tasks() instead.", + DeprecationWarning, + stacklevel=2, + ) + return load_tasks(source, raw=raw) diff --git a/hud/datasets/runner.py b/hud/datasets/runner.py index 2805967b..448cb3a6 100644 --- a/hud/datasets/runner.py +++ b/hud/datasets/runner.py @@ -36,7 +36,7 @@ async def run_dataset( Args: tasks: Tasks to run. Can be: - - A source string (file path, API slug) - loaded via load_dataset() + - A source string (file path, API slug) - loaded via load_tasks() - A single TaskInput (Task, LegacyTask, or dict) - A list of TaskInput objects agent_type: Type of agent to create (e.g., "claude", "openai", AgentType.CLAUDE). @@ -50,10 +50,10 @@ async def run_dataset( Example: ```python - from hud.datasets import load_dataset, run_dataset + from hud.datasets import load_tasks, run_dataset # Load tasks and run - tasks = load_dataset("my-tasks.json") + tasks = load_tasks("my-tasks.json") results = await run_dataset( tasks, agent_type="claude", @@ -65,13 +65,13 @@ async def run_dataset( print(f"Reward: {ctx.reward}") ``` """ - from hud.datasets.loader import load_dataset + from hud.datasets.loader import load_tasks from hud.eval.task import Task # Normalize tasks to list[Task] task_list: list[Task] if isinstance(tasks, str): - task_list = load_dataset(tasks) + task_list = load_tasks(tasks) elif isinstance(tasks, Task): task_list = [tasks] elif isinstance(tasks, LegacyTask | dict): @@ -128,7 +128,7 @@ async def run_single_task( trace/job/group IDs. Used by remote execution workers. Args: - task: Task object to run. Use Task.from_v4() or load_dataset() to create. 
+ task: Task object to run. Use Task.from_v4() or load_tasks() to create. agent_type: AgentType enum specifying the agent to use. agent_params: Parameters passed to agent.create(). Should include pre-configured model_client for inference gateway usage. diff --git a/hud/datasets/tests/test_loader.py b/hud/datasets/tests/test_loader.py index decf9019..5a8ffe5a 100644 --- a/hud/datasets/tests/test_loader.py +++ b/hud/datasets/tests/test_loader.py @@ -6,18 +6,18 @@ import pytest -from hud.datasets.loader import load_dataset +from hud.datasets.loader import load_tasks -class TestLoadDataset: - """Tests for load_dataset() function.""" +class TestLoadTasks: + """Tests for load_tasks() function.""" @patch("httpx.Client") @patch("hud.settings.settings") - def test_load_dataset_success( + def test_load_tasks_success( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - """load_dataset() successfully loads tasks from API.""" + """load_tasks() successfully loads tasks from API.""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = "test_key" @@ -47,7 +47,7 @@ def test_load_dataset_success( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - tasks = load_dataset("test-org/test-dataset") + tasks = load_tasks("test-org/test-dataset") assert len(tasks) == 2 # Tasks are keyed by ID in dict, order may vary @@ -64,10 +64,10 @@ def test_load_dataset_success( @patch("httpx.Client") @patch("hud.settings.settings") - def test_load_dataset_single_task( + def test_load_tasks_single_task( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - """load_dataset() handles single task in EvalsetTasksResponse.""" + """load_tasks() handles single task in EvalsetTasksResponse.""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = "test_key" @@ -91,7 +91,7 @@ def test_load_dataset_single_task( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - tasks = load_dataset("test-org/test-dataset") + tasks = load_tasks("test-org/test-dataset") assert len(tasks) == 1 assert tasks[0].scenario == "checkout" @@ -99,10 +99,10 @@ def test_load_dataset_single_task( @patch("httpx.Client") @patch("hud.settings.settings") - def test_load_dataset_no_api_key( + def test_load_tasks_no_api_key( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - """load_dataset() works without API key.""" + """load_tasks() works without API key.""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = None @@ -120,7 +120,7 @@ def test_load_dataset_no_api_key( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - tasks = load_dataset("test-org/test-dataset") + tasks = load_tasks("test-org/test-dataset") assert len(tasks) == 0 mock_client.get.assert_called_once_with( @@ -131,10 +131,10 @@ def test_load_dataset_no_api_key( @patch("httpx.Client") @patch("hud.settings.settings") - def test_load_dataset_http_error( + def test_load_tasks_http_error( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - """load_dataset() raises ValueError on HTTP error.""" + """load_tasks() raises ValueError on HTTP error.""" import httpx mock_settings.hud_api_url = "https://api.hud.ai" @@ -146,15 +146,15 @@ def test_load_dataset_http_error( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - with pytest.raises(ValueError, match="Failed to load dataset"): - load_dataset("test-org/test-dataset") + with 
pytest.raises(ValueError, match="Failed to load tasks"): + load_tasks("test-org/test-dataset") @patch("httpx.Client") @patch("hud.settings.settings") - def test_load_dataset_json_error( + def test_load_tasks_json_error( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - """load_dataset() raises ValueError on JSON processing error.""" + """load_tasks() raises ValueError on JSON processing error.""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = "test_key" @@ -168,15 +168,13 @@ def test_load_dataset_json_error( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - with pytest.raises(ValueError, match="Failed to load dataset"): - load_dataset("test-org/test-dataset") + with pytest.raises(ValueError, match="Failed to load tasks"): + load_tasks("test-org/test-dataset") @patch("httpx.Client") @patch("hud.settings.settings") - def test_load_dataset_empty( - self, mock_settings: MagicMock, mock_client_class: MagicMock - ) -> None: - """load_dataset() handles empty dataset.""" + def test_load_tasks_empty(self, mock_settings: MagicMock, mock_client_class: MagicMock) -> None: + """load_tasks() handles empty dataset.""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = "test_key" @@ -190,16 +188,16 @@ def test_load_dataset_empty( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - tasks = load_dataset("test-org/test-dataset") + tasks = load_tasks("test-org/test-dataset") assert len(tasks) == 0 @patch("httpx.Client") @patch("hud.settings.settings") - def test_load_dataset_missing_fields( + def test_load_tasks_missing_fields( self, mock_settings: MagicMock, mock_client_class: MagicMock ) -> None: - """load_dataset() handles tasks with missing optional fields (but env is required).""" + """load_tasks() handles tasks with missing optional fields (but env is required).""" mock_settings.hud_api_url = "https://api.hud.ai" mock_settings.api_key = "test_key" @@ -215,7 +213,7 @@ def test_load_dataset_missing_fields( mock_client.__exit__.return_value = None mock_client_class.return_value = mock_client - tasks = load_dataset("test-org/test-dataset") + tasks = load_tasks("test-org/test-dataset") assert len(tasks) == 1 assert tasks[0].scenario == "test" diff --git a/hud/eval/manager.py b/hud/eval/manager.py index ab0fa91e..f6e7673e 100644 --- a/hud/eval/manager.py +++ b/hud/eval/manager.py @@ -109,12 +109,12 @@ async def run_eval( """Standalone eval context manager. Creates an EvalContext for evaluation using Task objects (or deprecated LegacyTask). - For loading tasks from datasets, use load_dataset() first. + For loading tasks from datasets, use load_tasks() first. Args: source: Task source. 
Can be: - None: Create blank eval context - - Task: Single Task object (from env() or load_dataset()) + - Task: Single Task object (from env() or load_tasks()) - list[Task]: List of Task objects - LegacyTask: Single LegacyTask object (deprecated, use Task.from_v4()) - list[LegacyTask]: List of LegacyTask objects (deprecated) @@ -135,7 +135,7 @@ async def run_eval( Example: ```python - from hud.datasets import load_dataset + from hud.datasets import load_tasks # Blank eval (for manual reward) async with hud.eval() as ctx: @@ -147,8 +147,8 @@ async def run_eval( async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx: await agent.run(ctx.prompt) - # Load tasks from dataset first - tasks = load_dataset("hud-evals/SheetBench-50") + # Load tasks from file or API + tasks = load_tasks("hud-evals/SheetBench-50") async with hud.eval(tasks) as ctx: await agent.run(ctx) @@ -196,19 +196,19 @@ async def run_eval( # LegacyTask no longer accepted - user must convert first raise TypeError( "LegacyTask is no longer accepted by hud.eval(). " - "Convert first with Task.from_v4(legacy_task), or use load_dataset()." + "Convert first with Task.from_v4(legacy_task), or use load_tasks()." ) elif isinstance(source, str): # String slugs no longer supported - use load_dataset() raise TypeError( f"String slugs are no longer supported in hud.eval(). " - f"Use load_dataset('{source}') first, then pass the tasks list." + f"Use load_tasks('{source}') first, then pass the tasks list." ) elif isinstance(source, list) and source and isinstance(source[0], str): # List of string slugs no longer supported raise TypeError( "String slugs are no longer supported in hud.eval(). " - "Use load_dataset() first, then pass the tasks list." + "Use load_tasks() first, then pass the tasks list." ) # Calculate total evaluations diff --git a/hud/eval/task.py b/hud/eval/task.py index 91bbfe44..84e270c1 100644 --- a/hud/eval/task.py +++ b/hud/eval/task.py @@ -137,7 +137,8 @@ class Task(BaseModel): validation: list[MCPToolCall] | None = None # Agent config - settings passed to agent (system_prompt, etc.) 
- agent_config: TaskAgentConfig | None = None + # Accepts TaskAgentConfig or dict (auto-converted via validator) + agent_config: TaskAgentConfig | dict[str, Any] | None = None # Task metadata - for tracking/filtering, not used by agent metadata: dict[str, Any] = Field(default_factory=dict) diff --git a/hud/eval/tests/test_task.py b/hud/eval/tests/test_task.py index e1027b57..b8be866d 100644 --- a/hud/eval/tests/test_task.py +++ b/hud/eval/tests/test_task.py @@ -90,6 +90,7 @@ def test_v4_preserves_agent_config(self) -> None: # Roundtrip task2 = Task(**data) assert task2.agent_config is not None + assert isinstance(task2.agent_config, TaskAgentConfig) assert task2.agent_config.system_prompt == "Custom system prompt" def test_v4_preserves_metadata(self) -> None: diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index b9bc9ef9..57d20b2b 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -127,27 +127,17 @@ class TestRunDatasetExtended: @pytest.mark.asyncio async def test_run_dataset_empty(self): """Test running empty dataset raises ValueError.""" - from hud.agents import MCPAgent - from hud.types import Trace - - # Create mock agent - mock_agent = AsyncMock(spec=MCPAgent) - mock_agent.run.return_value = Trace(reward=1.0, done=True) + from hud.types import AgentType # Empty task list should raise ValueError with pytest.raises(ValueError, match="No tasks to run"): - await run_dataset([], mock_agent) + await run_dataset([], agent_type=AgentType.CLAUDE) @pytest.mark.asyncio async def test_run_dataset_with_task_list(self): """Test run_dataset with Task objects.""" - from hud.agents import MCPAgent from hud.eval.task import Task - from hud.types import Trace - - # Create mock agent - mock_agent = AsyncMock(spec=MCPAgent) - mock_agent.run.return_value = Trace(reward=1.0, done=True) + from hud.types import AgentType, Trace # Create mock tasks with env as dict (to avoid real connections) mock_env = {"name": "test"} @@ -162,26 +152,30 @@ async def test_run_dataset_with_task_list(self): mock_ctx.results = None mock_ctx.reward = None - with patch("hud.datasets.runner.hud.eval") as mock_eval: + # Create mock agent class and instance + mock_agent_instance = AsyncMock() + mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) + mock_agent_cls = AsyncMock() + mock_agent_cls.create.return_value = mock_agent_instance + + with ( + patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - results = await run_dataset(tasks, mock_agent, max_steps=5) + results = await run_dataset(tasks, agent_type="claude", max_steps=5) # Should return list with ctx assert len(results) == 1 - mock_agent.run.assert_called_once() + mock_agent_instance.run.assert_called_once() @pytest.mark.asyncio async def test_run_dataset_from_source_string(self): - """Test run_dataset with source string calls load_dataset.""" - from hud.agents import MCPAgent + """Test run_dataset with source string calls load_tasks.""" from hud.eval.task import Task - from hud.types import Trace - - # Create mock agent - mock_agent = AsyncMock(spec=MCPAgent) - mock_agent.run.return_value = Trace(reward=1.0, done=True) + from hud.types import AgentType, Trace mock_env = {"name": "test"} mock_tasks = [Task(env=mock_env, scenario="loaded")] # type: ignore[arg-type] @@ -189,14 +183,21 @@ 
async def test_run_dataset_from_source_string(self): mock_ctx = AsyncMock() mock_ctx.results = None + # Create mock agent class and instance + mock_agent_instance = AsyncMock() + mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) + mock_agent_cls = AsyncMock() + mock_agent_cls.create.return_value = mock_agent_instance + with ( - patch("hud.datasets.loader.load_dataset", return_value=mock_tasks) as mock_load, + patch("hud.datasets.runner.load_tasks", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.OPENAI, "cls", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - await run_dataset("test-org/dataset", mock_agent) + await run_dataset("test-org/dataset", agent_type="openai") # Should call load_dataset with the source string mock_load.assert_called_once_with("test-org/dataset") @@ -204,12 +205,8 @@ async def test_run_dataset_from_source_string(self): @pytest.mark.asyncio async def test_run_dataset_passes_parameters(self): """Test that run_dataset passes parameters correctly to hud.eval.""" - from hud.agents import MCPAgent from hud.eval.task import Task - from hud.types import Trace - - mock_agent = AsyncMock(spec=MCPAgent) - mock_agent.run.return_value = Trace(reward=1.0, done=True) + from hud.types import AgentType, Trace mock_env = {"name": "test"} tasks = [Task(env=mock_env, scenario="test")] @@ -217,11 +214,22 @@ async def test_run_dataset_passes_parameters(self): mock_ctx = AsyncMock() mock_ctx.results = None - with patch("hud.datasets.runner.hud.eval") as mock_eval: + # Create mock agent class and instance + mock_agent_instance = AsyncMock() + mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) + mock_agent_cls = AsyncMock() + mock_agent_cls.create.return_value = mock_agent_instance + + with ( + patch("hud.datasets.runner.hud.eval") as mock_eval, + patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) - await run_dataset(tasks, mock_agent, max_steps=25, max_concurrent=10, group_size=3) + await run_dataset( + tasks, agent_type=AgentType.CLAUDE, max_steps=25, max_concurrent=10, group_size=3 + ) # Verify hud.eval was called with correct params mock_eval.assert_called_once_with( diff --git a/hud/tests/test_init_module.py b/hud/tests/test_init_module.py index 2fba8a0c..607dbfae 100644 --- a/hud/tests/test_init_module.py +++ b/hud/tests/test_init_module.py @@ -25,6 +25,7 @@ def test_all_exports(self): "EvalContext", "eval", "instrument", + "trace", # Deprecated alias for eval ] assert set(hud.__all__) == set(expected) diff --git a/hud/tools/grounding/tests/test_grounded_tool.py b/hud/tools/grounding/tests/test_grounded_tool.py index 8b625c27..28fd6d23 100644 --- a/hud/tools/grounding/tests/test_grounded_tool.py +++ b/hud/tools/grounding/tests/test_grounded_tool.py @@ -7,7 +7,7 @@ import pytest from hud.tools.grounding.grounded_tool import GroundedComputerTool -from hud.types import MCPToolCall, MCPToolResult +from hud.types import MCPToolResult @dataclass @@ -17,36 +17,18 @@ class FakeResult: structuredContent: dict | None = None -class FakeMCPClient: - """Fake MCP client that implements AgentMCPClient protocol.""" - - _initialized: bool +class FakeEnvironment: + """Fake Environment that implements the call_tool interface.""" def __init__(self) -> None: 
self.calls: list[tuple[str, dict[str, Any]]] = [] - self._initialized = False - - @property - def mcp_config(self) -> dict[str, dict[str, Any]]: - return {"test": {"command": "echo", "args": ["test"]}} - - @property - def is_connected(self) -> bool: - return self._initialized - async def initialize(self, mcp_config: dict[str, dict[str, Any]] | None = None) -> None: - self._initialized = True - - async def list_tools(self) -> list[types.Tool]: - return [types.Tool(name="computer", description="Test tool", inputSchema={})] - - async def call_tool(self, tool_call: MCPToolCall) -> MCPToolResult: - self.calls.append((tool_call.name, tool_call.arguments or {})) + async def call_tool(self, call: tuple[str, dict[str, Any]], /, **kwargs: Any) -> MCPToolResult: + """Record the tool call and return a fake result.""" + tool_name, tool_args = call + self.calls.append((tool_name, tool_args)) return MCPToolResult(content=[types.TextContent(text="ok", type="text")], isError=False) - async def shutdown(self) -> None: - self._initialized = False - class FakeGrounder: """Fake grounder that implements Grounder interface.""" @@ -72,9 +54,9 @@ def _png_b64() -> str: @pytest.mark.asyncio async def test_click_action_grounds_and_calls_mcp() -> None: - client = FakeMCPClient() + ctx = FakeEnvironment() grounder = FakeGrounder(coords=(123, 456)) - tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore + tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore blocks = await tool( action="click", @@ -87,14 +69,14 @@ async def test_click_action_grounds_and_calls_mcp() -> None: # Grounder called once assert len(grounder.calls) == 1 # MCP called with resolved coordinates - assert client.calls == [("computer", {"action": "click", "x": 123, "y": 456, "button": "left"})] + assert ctx.calls == [("computer", {"action": "click", "x": 123, "y": 456, "button": "left"})] @pytest.mark.asyncio async def test_move_and_scroll_require_element_description_and_screenshot() -> None: - client = FakeMCPClient() + ctx = FakeEnvironment() grounder = FakeGrounder(coords=(5, 6)) - tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore + tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore # Missing element_description with pytest.raises(Exception) as ei: @@ -109,9 +91,9 @@ async def test_move_and_scroll_require_element_description_and_screenshot() -> N @pytest.mark.asyncio async def test_drag_grounds_both_points_and_calls_mcp() -> None: - client = FakeMCPClient() + ctx = FakeEnvironment() grounder = FakeGrounder(coords=(10, 20)) - tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore + tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore await tool( action="drag", @@ -124,7 +106,7 @@ async def test_drag_grounds_both_points_and_calls_mcp() -> None: # Two grounding calls (start and end) assert len(grounder.calls) == 2 # Drag path contains two points, same coords from fake grounder - name, args = client.calls[0] + name, args = ctx.calls[0] assert name == "computer" assert args["action"] == "drag" assert args["button"] == "left" @@ -133,9 +115,9 @@ async def test_drag_grounds_both_points_and_calls_mcp() -> None: @pytest.mark.asyncio async def test_drag_requires_both_descriptions_and_screenshot() -> None: - client = FakeMCPClient() + ctx = FakeEnvironment() grounder = FakeGrounder() - tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore + tool = GroundedComputerTool(grounder=grounder, ctx=ctx) 
# type: ignore with pytest.raises(Exception) as ei: await tool(action="drag", start_element_description="a", screenshot_b64=_png_b64()) @@ -152,9 +134,9 @@ async def test_drag_requires_both_descriptions_and_screenshot() -> None: @pytest.mark.asyncio async def test_direct_actions_bypass_grounding_and_call_mcp() -> None: - client = FakeMCPClient() + ctx = FakeEnvironment() grounder = FakeGrounder() - tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore + tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore # Actions that bypass grounding for action, extra in [ @@ -166,19 +148,19 @@ async def test_direct_actions_bypass_grounding_and_call_mcp() -> None: ("get_dimensions", {}), ("get_environment", {}), ]: - client.calls.clear() + ctx.calls.clear() _ = await tool(action=action, **extra) - assert client.calls and client.calls[0][0] == "computer" - assert client.calls[0][1]["action"] == action + assert ctx.calls and ctx.calls[0][0] == "computer" + assert ctx.calls[0][1]["action"] == action # Grounder not invoked for these assert grounder.calls == [] @pytest.mark.asyncio async def test_unsupported_action_raises() -> None: - client = FakeMCPClient() + ctx = FakeEnvironment() grounder = FakeGrounder() - tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore + tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore with pytest.raises(Exception) as ei: await tool(action="zoom") @@ -187,9 +169,9 @@ async def test_unsupported_action_raises() -> None: @pytest.mark.asyncio async def test_grounding_failure_propagates_as_error() -> None: - client = FakeMCPClient() + ctx = FakeEnvironment() grounder = FakeGrounder(coords=None) - tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore + tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore with pytest.raises(Exception) as ei: await tool(action="click", element_description="x", screenshot_b64=_png_b64()) diff --git a/hud/types.py b/hud/types.py index 73f90b54..91b00cad 100644 --- a/hud/types.py +++ b/hud/types.py @@ -62,7 +62,11 @@ class BaseAgentConfig(BaseModel): at the agent level. These should be configured on the Environment/Task instead. 
""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", populate_by_name=True) + + # Model identifier - use 'model' (preferred) or 'checkpoint_name' (alias) + model: str | None = Field(default=None, validation_alias="checkpoint_name") + model_name: str = "Agent" # Human-readable display name # LLM-specific setting system_prompt: str | None = None @@ -73,6 +77,11 @@ class BaseAgentConfig(BaseModel): append_setup_output: bool = True initial_screenshot: bool = True + @property + def checkpoint_name(self) -> str | None: + """Alias for model (for backwards compatibility).""" + return self.model + class LegacyTask(BaseModel): """ From 2c57f0dab8e7b15c2e6ad6621ea87967c102a27f Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 09:15:52 -0800 Subject: [PATCH 55/92] test fixes --- hud/agents/tests/test_claude.py | 14 +++--- hud/agents/tests/test_gemini.py | 6 +-- .../tests/test_grounded_openai_agent.py | 4 +- hud/agents/tests/test_openai.py | 8 ++-- hud/agents/tests/test_operator.py | 4 +- hud/datasets/loader.py | 14 ++++-- hud/eval/context.py | 27 ++++++++++- hud/eval/display.py | 46 ++++++++++++++++++- hud/types.py | 6 ++- 9 files changed, 101 insertions(+), 28 deletions(-) diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py index 4a89053e..a0285ad3 100644 --- a/hud/agents/tests/test_claude.py +++ b/hud/agents/tests/test_claude.py @@ -128,12 +128,12 @@ async def test_init_with_client(self, mock_anthropic: AsyncAnthropic) -> None: """Test agent initialization with provided client.""" agent = ClaudeAgent.create( model_client=mock_anthropic, - checkpoint_name="claude-sonnet-4-20250514", + model="claude-sonnet-4-20250514", validate_api_key=False, ) assert agent.model_name == "Claude" - assert agent.config.checkpoint_name == "claude-sonnet-4-20250514" + assert agent.config.model == "claude-sonnet-4-20250514" assert agent.anthropic_client == mock_anthropic @pytest.mark.asyncio @@ -141,7 +141,7 @@ async def test_init_with_parameters(self, mock_anthropic: AsyncAnthropic) -> Non """Test agent initialization with various parameters.""" agent = ClaudeAgent.create( model_client=mock_anthropic, - checkpoint_name="claude-sonnet-4-20250514", + model="claude-sonnet-4-20250514", max_tokens=4096, validate_api_key=False, ) @@ -410,12 +410,12 @@ async def test_init(self, bedrock_client: AsyncAnthropicBedrock) -> None: """Test agent initialization.""" agent = ClaudeAgent.create( model_client=bedrock_client, - checkpoint_name="test-model-arn", + model="test-model-arn", validate_api_key=False, ) assert agent.model_name == "Claude" - assert agent.config.checkpoint_name == "test-model-arn" + assert agent.config.model == "test-model-arn" assert agent.anthropic_client == bedrock_client @pytest.mark.asyncio @@ -426,7 +426,7 @@ async def test_get_response_bedrock_uses_create_not_stream( with patch("hud.settings.settings.telemetry_enabled", False): agent = ClaudeAgent.create( model_client=bedrock_client, - checkpoint_name="test-model-arn", + model="test-model-arn", validate_api_key=False, ) @@ -473,7 +473,7 @@ async def test_get_response_bedrock_missing_boto3_raises_value_error( with patch("hud.settings.settings.telemetry_enabled", False): agent = ClaudeAgent.create( model_client=bedrock_client, - checkpoint_name="test-model-arn", + model="test-model-arn", validate_api_key=False, ) diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py index fb0f7c5c..a89c38ca 100644 
--- a/hud/agents/tests/test_gemini.py +++ b/hud/agents/tests/test_gemini.py @@ -54,12 +54,12 @@ async def test_init(self, mock_gemini_client: genai.Client) -> None: """Test agent initialization.""" agent = GeminiAgent.create( model_client=mock_gemini_client, - checkpoint_name="gemini-2.5-flash", + model="gemini-2.5-flash", validate_api_key=False, ) assert agent.model_name == "Gemini" - assert agent.config.checkpoint_name == "gemini-2.5-flash" + assert agent.config.model == "gemini-2.5-flash" assert agent.gemini_client == mock_gemini_client @pytest.mark.asyncio @@ -76,7 +76,7 @@ async def test_init_without_model_client(self) -> None: mock_client_class.return_value = mock_client agent = GeminiAgent.create( - checkpoint_name="gemini-2.5-flash", + model="gemini-2.5-flash", validate_api_key=False, ) diff --git a/hud/agents/tests/test_grounded_openai_agent.py b/hud/agents/tests/test_grounded_openai_agent.py index ff8b2bfe..6e748d6e 100644 --- a/hud/agents/tests/test_grounded_openai_agent.py +++ b/hud/agents/tests/test_grounded_openai_agent.py @@ -76,7 +76,7 @@ async def test_call_tools_injects_screenshot_and_delegates(monkeypatch: pytest.M agent = GroundedOpenAIChatAgent.create( grounder_config=grounder_cfg, openai_client=fake_openai, - checkpoint_name="gpt-4o-mini", + model="gpt-4o-mini", initial_screenshot=False, ) @@ -129,7 +129,7 @@ async def test_get_response_with_reasoning() -> None: agent = GroundedOpenAIChatAgent.create( grounder_config=grounder_cfg, openai_client=fake_openai, - checkpoint_name="gpt-4o-mini", + model="gpt-4o-mini", initial_screenshot=False, ) diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py index 15f9cffc..f36d519f 100644 --- a/hud/agents/tests/test_openai.py +++ b/hud/agents/tests/test_openai.py @@ -64,13 +64,13 @@ async def test_init_with_client(self, mock_openai: AsyncOpenAI) -> None: """Test agent initialization with provided client.""" agent = OpenAIAgent.create( model_client=mock_openai, - checkpoint_name="gpt-4o", + model="gpt-4o", validate_api_key=False, ) assert agent.model_name == "OpenAI" - assert agent.config.checkpoint_name == "gpt-4o" - assert agent.checkpoint_name == "gpt-4o" + assert agent.config.model == "gpt-4o" + assert agent.model == "gpt-4o" assert agent.openai_client == mock_openai assert agent.max_output_tokens is None assert agent.temperature is None @@ -80,7 +80,7 @@ async def test_init_with_parameters(self, mock_openai: AsyncOpenAI) -> None: """Test agent initialization with various parameters.""" agent = OpenAIAgent.create( model_client=mock_openai, - checkpoint_name="gpt-4o", + model="gpt-4o", max_output_tokens=2048, temperature=0.7, reasoning={"effort": "high"}, diff --git a/hud/agents/tests/test_operator.py b/hud/agents/tests/test_operator.py index d1995f14..3303855a 100644 --- a/hud/agents/tests/test_operator.py +++ b/hud/agents/tests/test_operator.py @@ -69,12 +69,12 @@ async def test_init(self, mock_openai: AsyncOpenAI) -> None: """Test agent initialization.""" agent = OperatorAgent.create( model_client=mock_openai, - checkpoint_name="gpt-4", + model="gpt-4", validate_api_key=False, ) assert agent.model_name == "Operator" - assert agent.config.checkpoint_name == "gpt-4" + assert agent.config.model == "gpt-4" assert agent.openai_client == mock_openai @pytest.mark.asyncio diff --git a/hud/datasets/loader.py b/hud/datasets/loader.py index ba2d348f..93acf11e 100644 --- a/hud/datasets/loader.py +++ b/hud/datasets/loader.py @@ -231,8 +231,6 @@ def load_tasks(source: str, *, raw: bool = False) -> list[Task] | 
list[dict[str, def save_tasks( name: str, tasks: list[Task], - *, - description: str | None = None, ) -> str: """Save tasks to the HUD API. @@ -242,7 +240,6 @@ def save_tasks( name: Evalset name/slug (e.g., "my-evals/benchmark-v1"). If no org prefix, uses user's default org. tasks: List of Task objects (v5 format) to save. - description: Optional description for the evalset. Returns: The evalset ID of the created/updated evalset. @@ -268,6 +265,7 @@ def save_tasks( ``` Raises: + TypeError: If any task is not a v5 Task object (must have 'scenario') ValueError: If API key is not set or save fails """ import httpx @@ -277,6 +275,14 @@ def save_tasks( if not settings.api_key: raise ValueError("HUD_API_KEY is required to save tasks") + # Validate all tasks are v5 format (must have 'scenario') + for i, task in enumerate(tasks): + if not hasattr(task, "scenario"): + raise TypeError( + f"Task at index {i} is missing 'scenario' - only v5 Task objects can be saved. " + "Use Task.from_v4(legacy_task) to convert from LegacyTask." + ) + # Convert tasks to dicts (Task is a Pydantic model) task_dicts = [task.model_dump(mode="json", exclude_none=True) for task in tasks] @@ -285,8 +291,6 @@ def save_tasks( "name": name, "tasks": task_dicts, } - if description: - payload["description"] = description headers = {"Authorization": f"Bearer {settings.api_key}"} diff --git a/hud/eval/context.py b/hud/eval/context.py index f84ba9ca..a86cec47 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -311,8 +311,12 @@ def from_task( ctx._task = task # Set system_prompt from task.agent_config - if task.agent_config and task.agent_config.system_prompt: - ctx.system_prompt = task.agent_config.system_prompt + if task.agent_config: + if isinstance(task.agent_config, dict): + if task.agent_config.get("system_prompt"): + ctx.system_prompt = task.agent_config["system_prompt"] + elif task.agent_config.system_prompt: + ctx.system_prompt = task.agent_config.system_prompt return ctx @@ -562,6 +566,10 @@ async def __aexit__( # Notify backend await self._eval_exit(error_msg) + + # Print single eval result summary (unless suppressed for parallel evals) + self._print_single_result(error_msg) + return False # ========================================================================= @@ -591,6 +599,21 @@ def _print_eval_link(self) -> None: trace_url = f"https://hud.ai/trace/{self.trace_id}" print_link(trace_url, "🔗 Eval Started") + def _print_single_result(self, error_msg: str | None) -> None: + """Print a single eval result summary.""" + # Skip if link printing is suppressed (e.g., parallel child traces) + if self._suppress_link: + return + + from hud.eval.display import print_single_result + + print_single_result( + trace_id=self.trace_id, + name=self.eval_name, + reward=self.reward, + error=error_msg, + ) + # Re-export for backwards compatibility with trace module __all__ = [ diff --git a/hud/eval/display.py b/hud/eval/display.py index 1d23d494..7600ac47 100644 --- a/hud/eval/display.py +++ b/hud/eval/display.py @@ -62,6 +62,44 @@ def print_complete(url: str, name: str, *, error: bool = False) -> None: print(f"\n{name} {status}: {url}\n") # noqa: T201 +def print_single_result( + trace_id: str, + name: str, + *, + reward: float | None = None, + error: str | None = None, +) -> None: + """Print a single eval result summary.""" + if not (settings.telemetry_enabled and settings.api_key): + return + + url = f"https://hud.ai/trace/{trace_id}" + + try: + from rich.console import Console + + console = Console() + + if error: + 
console.print( + f"\n[red]✗ '{name}' failed![/red]\n" + f" [dim]Error:[/dim] [red]{error[:80]}{'...' if len(error) > 80 else ''}[/red]\n" + f" [dim]View at:[/dim] [bold link={url}]{url}[/bold link]\n" + ) + else: + reward_str = f"{reward:.3f}" if reward is not None else "—" + reward_color = "green" if reward is not None and reward > 0.7 else "yellow" + console.print( + f"\n[green]✓ '{name}' complete![/green]\n" + f" [dim]Reward:[/dim] [{reward_color}]{reward_str}[/{reward_color}]\n" + f" [dim]View at:[/dim] [bold link={url}]{url}[/bold link]\n" + ) + except ImportError: + status = "failed" if error else "complete" + reward_str = f", reward={reward:.3f}" if reward is not None else "" + print(f"\n{name} {status}{reward_str}: {url}\n") # noqa: T201 + + def display_results( results: list[Any], *, @@ -242,4 +280,10 @@ def _get_task_label(task: Any, index: int) -> str: # Backwards compatibility alias print_eval_stats = display_results -__all__ = ["display_results", "print_complete", "print_eval_stats", "print_link"] +__all__ = [ + "display_results", + "print_complete", + "print_eval_stats", + "print_link", + "print_single_result", +] diff --git a/hud/types.py b/hud/types.py index 91b00cad..6502e7b4 100644 --- a/hud/types.py +++ b/hud/types.py @@ -8,7 +8,7 @@ import mcp.types as types from mcp.types import CallToolRequestParams, CallToolResult -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, field_validator from hud.settings import settings from hud.utils.env import resolve_env_vars as _resolve_env_vars @@ -65,7 +65,9 @@ class BaseAgentConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid", populate_by_name=True) # Model identifier - use 'model' (preferred) or 'checkpoint_name' (alias) - model: str | None = Field(default=None, validation_alias="checkpoint_name") + model: str | None = Field( + default=None, validation_alias=AliasChoices("model", "checkpoint_name") + ) model_name: str = "Agent" # Human-readable display name # LLM-specific setting From 357cd191ef22aeb6b6132067d8e3c4ab7e17fb83 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 09:41:35 -0800 Subject: [PATCH 56/92] update tests --- hud/agents/tests/test_base.py | 32 +++++++------- hud/agents/tests/test_base_runtime.py | 16 +++---- hud/agents/tests/test_claude.py | 8 +--- .../tests/test_grounded_openai_agent.py | 1 + hud/agents/tests/test_openai.py | 6 +-- hud/agents/tests/test_run_eval.py | 2 +- hud/cli/tests/test_convert.py | 42 ++++++++----------- hud/cli/tests/test_eval.py | 2 +- hud/cli/utils/interactive.py | 1 - .../utils/tests/test_interactive_module.py | 2 +- .../tests/test_mcp_server_integration.py | 22 +++++----- hud/server/tests/test_mcp_server_more.py | 2 +- hud/tests/test_datasets_extended.py | 2 +- 13 files changed, 62 insertions(+), 76 deletions(-) diff --git a/hud/agents/tests/test_base.py b/hud/agents/tests/test_base.py index 25fab1d8..c8e19a9f 100644 --- a/hud/agents/tests/test_base.py +++ b/hud/agents/tests/test_base.py @@ -97,14 +97,14 @@ class TestMCPAgentInit: def test_init_defaults(self) -> None: """Test agent initializes with default config.""" - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() assert agent.ctx is None assert agent._initialized is False assert agent.system_prompt is None def test_init_with_system_prompt(self) -> None: """Test agent with custom system prompt.""" - agent = MockMCPAgent(auto_trace=False, system_prompt="Custom prompt") + agent = 
MockMCPAgent(system_prompt="Custom prompt") assert agent.system_prompt == "Custom prompt" @@ -115,7 +115,7 @@ class TestMCPAgentRun: async def test_run_basic(self) -> None: """Test basic run flow with EvalContext.""" ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() result = await agent.run(ctx) @@ -127,7 +127,7 @@ async def test_run_basic(self) -> None: async def test_run_initializes_agent(self) -> None: """Test run() initializes the agent with context.""" ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() assert not agent._initialized await agent.run(ctx) @@ -141,7 +141,7 @@ async def test_run_discovers_tools(self) -> None: types.Tool(name="tool2", description="Tool 2", inputSchema={}), ] ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() # We need to check tools before cleanup # Store a reference to check @@ -163,7 +163,7 @@ async def capture_tools(*args: Any, **kwargs: Any) -> Any: @pytest.mark.asyncio async def test_run_requires_eval_context(self) -> None: """Test run() raises TypeError for non-EvalContext.""" - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() with pytest.raises(TypeError, match="must be EvalContext"): await agent.run("not a context") # type: ignore @@ -172,7 +172,7 @@ async def test_run_requires_eval_context(self) -> None: async def test_run_requires_prompt(self) -> None: """Test run() raises ValueError when prompt is empty.""" ctx = MockEvalContext(prompt="") - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() with pytest.raises(ValueError, match="prompt is not set"): await agent.run(ctx) @@ -181,7 +181,7 @@ async def test_run_requires_prompt(self) -> None: async def test_run_clears_context_after(self) -> None: """Test run() clears ctx after completion.""" ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() await agent.run(ctx) assert agent.ctx is None @@ -190,7 +190,7 @@ async def test_run_clears_context_after(self) -> None: async def test_run_no_submit_on_empty_content(self) -> None: """Test run() doesn't submit when content is empty.""" ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() agent.set_response(AgentResponse(content="", tool_calls=[], done=True)) await agent.run(ctx) @@ -204,7 +204,7 @@ class TestMCPAgentToolCalling: async def test_call_tools_uses_context(self) -> None: """Test call_tools routes through ctx.call_tool.""" ctx = MockEvalContext(prompt="Do something") - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() # Bind context manually agent.ctx = ctx @@ -220,7 +220,7 @@ async def test_call_tools_uses_context(self) -> None: @pytest.mark.asyncio async def test_call_tools_without_context_raises(self) -> None: """Test call_tools raises when no context bound.""" - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() with pytest.raises(ValueError, match="not bound to context"): await agent.call_tools(MCPToolCall(name="test_tool", arguments={})) @@ -237,7 +237,7 @@ class AgentWithRequiredTools(MockMCPAgent): required_tools: ClassVar[list[str]] = ["must_have_tool"] ctx = MockEvalContext(prompt="Do something", tools=[]) - agent = AgentWithRequiredTools(auto_trace=False) + agent = AgentWithRequiredTools() with pytest.raises(ValueError, match="Required tools are missing"): await agent.run(ctx) @@ 
-251,7 +251,7 @@ class AgentWithRequiredTools(MockMCPAgent): tools = [types.Tool(name="required_tool", description="Required", inputSchema={})] ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = AgentWithRequiredTools(auto_trace=False) + agent = AgentWithRequiredTools() result = await agent.run(ctx) assert result.done @@ -270,7 +270,7 @@ def _on_tools_ready(self) -> None: hook_called[0] = True ctx = MockEvalContext(prompt="Do something") - agent = AgentWithHook(auto_trace=False) + agent = AgentWithHook() await agent.run(ctx) assert hook_called[0] @@ -289,7 +289,7 @@ def _on_tools_ready(self) -> None: types.Tool(name="tool2", description="Tool 2", inputSchema={}), ] ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = AgentWithHook(auto_trace=False) + agent = AgentWithHook() await agent.run(ctx) @@ -311,7 +311,7 @@ async def test_get_tool_schemas(self) -> None: ) ] ctx = MockEvalContext(prompt="Do something", tools=tools) - agent = MockMCPAgent(auto_trace=False) + agent = MockMCPAgent() # Initialize agent agent.ctx = ctx diff --git a/hud/agents/tests/test_base_runtime.py b/hud/agents/tests/test_base_runtime.py index f066c8f7..c73f4ebf 100644 --- a/hud/agents/tests/test_base_runtime.py +++ b/hud/agents/tests/test_base_runtime.py @@ -98,7 +98,7 @@ def test_find_reward_and_content_extractors() -> None: def test_get_available_tools_before_run_raises() -> None: """Test that get_available_tools raises before initialization.""" - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() with pytest.raises(RuntimeError): agent.get_available_tools() @@ -106,7 +106,7 @@ def test_get_available_tools_before_run_raises() -> None: @pytest.mark.asyncio async def test_format_message_invalid_type_raises() -> None: """Test that format_message raises for invalid types.""" - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() with pytest.raises(ValueError): await agent.format_message({"oops": 1}) # type: ignore @@ -121,7 +121,7 @@ def test_text_to_blocks_shapes() -> None: async def test_run_with_eval_context() -> None: """Test basic run() with EvalContext.""" ctx = MockEvalContext(prompt="hello") - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() result = await agent.run(ctx, max_steps=1) assert result.done is True assert result.isError is False @@ -130,7 +130,7 @@ async def test_run_with_eval_context() -> None: @pytest.mark.asyncio async def test_run_requires_eval_context() -> None: """Test run() raises TypeError for non-EvalContext.""" - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() with pytest.raises(TypeError, match="must be EvalContext"): await agent.run("hello") # type: ignore @@ -139,7 +139,7 @@ async def test_run_requires_eval_context() -> None: async def test_run_requires_prompt() -> None: """Test run() raises ValueError when prompt is empty.""" ctx = MockEvalContext(prompt="") - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() with pytest.raises(ValueError, match="prompt is not set"): await agent.run(ctx) @@ -158,7 +158,7 @@ def handler(tool_call: MCPToolCall) -> MCPToolResult: ctx = MockEvalContext(prompt="test") ctx.set_call_tool_handler(handler) - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() # Initialize the agent with context agent.ctx = ctx @@ -180,7 +180,7 @@ def handler(tool_call: MCPToolCall) -> MCPToolResult: ctx = MockEvalContext(prompt="test") ctx.set_call_tool_handler(handler) - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() agent.ctx = ctx await agent._initialize_from_ctx(ctx) @@ 
-194,7 +194,7 @@ async def test_get_available_tools_after_run() -> None: """Test get_available_tools works after initialization.""" tools = [types.Tool(name="test_tool", description="Test", inputSchema={})] ctx = MockEvalContext(prompt="hello", tools=tools) - agent = DummyAgent(auto_trace=False) + agent = DummyAgent() # Run initializes the agent await agent.run(ctx, max_steps=1) diff --git a/hud/agents/tests/test_claude.py b/hud/agents/tests/test_claude.py index a0285ad3..e9a99bd5 100644 --- a/hud/agents/tests/test_claude.py +++ b/hud/agents/tests/test_claude.py @@ -111,13 +111,7 @@ class TestClaudeAgent: @pytest.fixture def mock_anthropic(self) -> Generator[AsyncAnthropic, None, None]: # type: ignore[misc] """Create a stub Anthropic client.""" - with ( - patch("hud.agents.claude.AsyncAnthropic") as mock_class, - patch("hud.agents.claude.Anthropic") as mock_sync, - ): - # Mock the sync client's models.list() for validation - mock_sync.return_value.models.list.return_value = [] - + with patch("hud.agents.claude.AsyncAnthropic") as mock_class: client = MagicMock(spec=AsyncAnthropic) client.api_key = "test-key" mock_class.return_value = client diff --git a/hud/agents/tests/test_grounded_openai_agent.py b/hud/agents/tests/test_grounded_openai_agent.py index 6e748d6e..6e0cdcd2 100644 --- a/hud/agents/tests/test_grounded_openai_agent.py +++ b/hud/agents/tests/test_grounded_openai_agent.py @@ -147,6 +147,7 @@ async def test_get_response_with_reasoning() -> None: mock_response.choices = [mock_choice] agent.oai.chat.completions.create = AsyncMock(return_value=mock_response) + agent._initialized = True # Mark as initialized to skip context initialization agent.conversation_history = [ {"role": "user", "content": [{"type": "text", "text": "Hard question"}]} diff --git a/hud/agents/tests/test_openai.py b/hud/agents/tests/test_openai.py index f36d519f..ebf5f4a0 100644 --- a/hud/agents/tests/test_openai.py +++ b/hud/agents/tests/test_openai.py @@ -352,9 +352,9 @@ async def test_get_response_with_reasoning(self, mock_openai: AsyncOpenAI) -> No agent._initialized = True response = await agent.get_response([]) - # Reasoning is prepended to content in OpenAI agent - assert "Thinking about it..." in response.content - assert "Answer!" in response.content + # Reasoning is stored separately from content + assert response.reasoning == "Thinking about it..." + assert response.content == "Answer!" 
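The updated assertions above reflect that the OpenAI agent now keeps model reasoning separate from the final answer instead of prepending it to `content`. A minimal sketch of that shape, using a hypothetical stand-in dataclass rather than the real `AgentResponse` (whose import path and full field set aren't shown in this patch):

```python
from dataclasses import dataclass, field


@dataclass
class AgentResponseSketch:
    """Stand-in for AgentResponse, only to illustrate the reasoning/content split."""

    content: str                      # user-visible answer text only
    reasoning: str | None = None      # model reasoning, kept separate (not prepended)
    tool_calls: list = field(default_factory=list)
    done: bool = False


resp = AgentResponseSketch(content="Answer!", reasoning="Thinking about it...", done=True)
assert resp.content == "Answer!"
assert resp.reasoning == "Thinking about it..."
```

Anything that consumes `response.content` downstream now sees only the answer text; callers that want the chain-of-thought have to read `response.reasoning` explicitly.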
class TestOpenAIToolConversion: diff --git a/hud/agents/tests/test_run_eval.py b/hud/agents/tests/test_run_eval.py index 0a09b193..d4455962 100644 --- a/hud/agents/tests/test_run_eval.py +++ b/hud/agents/tests/test_run_eval.py @@ -15,7 +15,7 @@ class MockConfig(BaseAgentConfig): model_name: str = "MockAgent" - checkpoint_name: str = "mock-model" + model: str = "mock-model" class MockCreateParams(BaseCreateParams, MockConfig): diff --git a/hud/cli/tests/test_convert.py b/hud/cli/tests/test_convert.py index 3dc5fb19..004b5b69 100644 --- a/hud/cli/tests/test_convert.py +++ b/hud/cli/tests/test_convert.py @@ -8,7 +8,6 @@ import typer from hud.cli.flows.tasks import convert_tasks_to_remote -from hud.types import LegacyTask class TestConvertCommand: @@ -84,12 +83,6 @@ def test_convert_tasks_basic( # Mock derive remote image mock_derive_remote.return_value = "registry.hud.ai/test-org/test-env:v1.0.0" - task = LegacyTask( - prompt="Test task", - mcp_config={ - "local": {"command": "docker", "args": ["run", "--rm", "-i", "test-image:latest"]} - }, - ) raw_task = { "prompt": "Test task", "mcp_config": { @@ -97,7 +90,7 @@ def test_convert_tasks_basic( }, } - mock_load_tasks.side_effect = [[task], [raw_task]] + mock_load_tasks.return_value = [raw_task] # Run conversion result_path = convert_tasks_to_remote(str(temp_tasks_file)) @@ -132,18 +125,18 @@ def test_convert_already_remote( mock_settings.api_key = "test-api-key" mock_find_env.return_value = None # No env dir needed for remote tasks - # Create task that's already remote - task = LegacyTask( - prompt="Test task", - mcp_config={ + # Create task that's already remote (as raw dict) + raw_task = { + "prompt": "Test task", + "mcp_config": { "remote": { "url": "https://mcp.hud.ai", "headers": {"Mcp-Image": "registry.hud.ai/test/image:v1"}, } }, - ) + } - mock_load_tasks.return_value = [task] + mock_load_tasks.return_value = [raw_task] # Should return original path without modification result_path = convert_tasks_to_remote(str(temp_tasks_file)) @@ -159,14 +152,14 @@ def test_convert_no_environment( mock_settings.api_key = "test-api-key" mock_find_env.return_value = None - task = LegacyTask( - prompt="Test task", - mcp_config={ + raw_task = { + "prompt": "Test task", + "mcp_config": { "local": {"command": "docker", "args": ["run", "--rm", "-i", "test-image:latest"]} }, - ) + } - mock_load_tasks.return_value = [task] + mock_load_tasks.return_value = [raw_task] with pytest.raises(typer.Exit): convert_tasks_to_remote(str(temp_tasks_file)) @@ -209,18 +202,17 @@ def test_convert_with_env_vars( env_file = mock_env_dir / ".env" env_file.write_text("OPENAI_API_KEY=sk-test123\nANTHROPIC_API_KEY=sk-ant456") - task = LegacyTask( - prompt="Test task", - mcp_config={ + raw_task = { + "prompt": "Test task", + "mcp_config": { "local": { "command": "docker", "args": ["run", "--rm", "-i", "-e", "OPENAI_API_KEY", "test-image:latest"], } }, - ) - raw_task = task.model_dump() + } - mock_load_tasks.side_effect = [[task], [raw_task]] + mock_load_tasks.return_value = [raw_task] # Run conversion result_path = convert_tasks_to_remote(str(temp_tasks_file)) diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index 6ee1a9a4..1aa6d3cb 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -96,7 +96,7 @@ async def test_run_dataset_with_string_source(self) -> None: mock_ctx = MockEvalContext() with ( - patch("hud.datasets.runner.load_tasks", return_value=mock_tasks) as mock_load, + patch("hud.datasets.loader.load_tasks", return_value=mock_tasks) 
as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, patch.object(AgentType.OPENAI, "cls", mock_agent_cls), ): diff --git a/hud/cli/utils/interactive.py b/hud/cli/utils/interactive.py index 1f8d1da7..53a2ab34 100644 --- a/hud/cli/utils/interactive.py +++ b/hud/cli/utils/interactive.py @@ -50,7 +50,6 @@ async def connect(self) -> bool: self.client = MCPClient( mcp_config=config, verbose=self.verbose, - auto_trace=False, # Disable telemetry for interactive testing ) await self.client.initialize() diff --git a/hud/cli/utils/tests/test_interactive_module.py b/hud/cli/utils/tests/test_interactive_module.py index e234abd9..993233dc 100644 --- a/hud/cli/utils/tests/test_interactive_module.py +++ b/hud/cli/utils/tests/test_interactive_module.py @@ -9,7 +9,7 @@ @pytest.mark.asyncio -@patch("hud.cli.utils.interactive.MCPClient") +@patch("hud.clients.MCPClient") async def test_connect_and_disconnect(MockClient): client = AsyncMock() client.initialize.return_value = None diff --git a/hud/server/tests/test_mcp_server_integration.py b/hud/server/tests/test_mcp_server_integration.py index 10bc2c33..b575c04a 100644 --- a/hud/server/tests/test_mcp_server_integration.py +++ b/hud/server/tests/test_mcp_server_integration.py @@ -84,7 +84,7 @@ async def echo(text: str = "ok") -> str: # type: ignore[override] async def connect_and_check() -> None: cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - client = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + client = MCPClient(mcp_config=cfg, verbose=False) await client.initialize() tools = await client.list_tools() names = sorted(t.name for t in tools) @@ -123,7 +123,7 @@ async def _on_shutdown() -> None: try: # sanity connect so lifespan actually ran cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() await c.shutdown() finally: @@ -140,7 +140,7 @@ async def _on_shutdown() -> None: server_task2 = await _start_http_server(mcp, port=port2) try: cfg = {"srv": {"url": f"http://127.0.0.1:{port2}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() await c.shutdown() @@ -170,7 +170,7 @@ async def _init(_ctx) -> None: server_task = await _start_http_server(mcp, port) cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - client = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + client = MCPClient(mcp_config=cfg, verbose=False) try: with pytest.raises(Exception): @@ -211,7 +211,7 @@ async def _init(_ctx) -> None: async def connect_and_check() -> None: cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() tools = await c.list_tools() names = sorted(t.name for t in tools) @@ -244,7 +244,7 @@ async def echo(text: str = "ok") -> str: # type: ignore[override] server_task = await _start_http_server(mcp, port) try: cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() # Call with no args → default should kick in res = await c.call_tool(name="echo", arguments={}) @@ -273,7 +273,7 @@ async def _on_shutdown() -> None: try: # Ensure lifespan started cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, 
verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() await c.shutdown() @@ -315,7 +315,7 @@ async def _init(ctx) -> None: # type: ignore[override] server_task = await _start_http_server(mcp, port) try: cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() await c.shutdown() finally: @@ -344,7 +344,7 @@ async def _init(_ctx) -> None: try: cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() await c.shutdown() finally: @@ -373,11 +373,11 @@ async def _init(_ctx) -> None: server_task = await _start_http_server(mcp, port) try: cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c1 = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c1 = MCPClient(mcp_config=cfg, verbose=False) await c1.initialize() await c1.shutdown() - c2 = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c2 = MCPClient(mcp_config=cfg, verbose=False) await c2.initialize() await c2.shutdown() finally: diff --git a/hud/server/tests/test_mcp_server_more.py b/hud/server/tests/test_mcp_server_more.py index 875424ba..d364837a 100644 --- a/hud/server/tests/test_mcp_server_more.py +++ b/hud/server/tests/test_mcp_server_more.py @@ -142,7 +142,7 @@ async def echo(text: str = "ok") -> str: # type: ignore[override] try: cfg = {"srv": {"url": f"http://127.0.0.1:{port}/mcp"}} - c = MCPClient(mcp_config=cfg, auto_trace=False, verbose=False) + c = MCPClient(mcp_config=cfg, verbose=False) await c.initialize() # Call a tool to ensure init didn't break anything diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 57d20b2b..b9a97291 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -190,7 +190,7 @@ async def test_run_dataset_from_source_string(self): mock_agent_cls.create.return_value = mock_agent_instance with ( - patch("hud.datasets.runner.load_tasks", return_value=mock_tasks) as mock_load, + patch("hud.datasets.loader.load_tasks", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, patch.object(AgentType.OPENAI, "cls", mock_agent_cls), ): From 3b1cda08140688e61f4f6c4f9296eb1356fbea66 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 09:42:09 -0800 Subject: [PATCH 57/92] docs --- docs/index.mdx | 2 +- docs/quick-links/deploy.mdx | 2 +- pyproject.toml | 12 +++++------- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/index.mdx b/docs/index.mdx index 8841bbea..6bd837f1 100644 --- a/docs/index.mdx +++ b/docs/index.mdx @@ -95,7 +95,7 @@ Push your environment to GitHub, connect it on [hud.ai](https://hud.ai), and run ```bash hud init # Scaffold environment git push # Push to GitHub -# Connect on hud.ai → New Environment +# Connect on hud.ai → New → Environment hud eval my-org/my-eval --model gpt-4o --group-size 100 ``` diff --git a/docs/quick-links/deploy.mdx b/docs/quick-links/deploy.mdx index dd55b2c1..5553c7d8 100644 --- a/docs/quick-links/deploy.mdx +++ b/docs/quick-links/deploy.mdx @@ -10,7 +10,7 @@ You've built an environment with tools and scripts. Deploy it to the platform an Start with `hud init` ([see Environments](/quick-links/environments)) to scaffold locally. When ready: -1. Go to [hud.ai](https://hud.ai) → **New Environment** +1. 
Go to [hud.ai](https://hud.ai) → **New** → **Environment** 2. Connect your GitHub repo and name your environment 3. Push changes and it rebuilds automatically, like Vercel diff --git a/pyproject.toml b/pyproject.toml index 3fdf05f0..1176875d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ dependencies = [ # MCP dependencies "mcp>1.21.1,<1.23", "fastmcp==2.13.3", + # For all inference agents + "openai>=2.8.1", # CLI dependencies "typer>=0.9.0", "rich>=13.0.0", @@ -112,16 +114,10 @@ agents = [ "langchain>=1.1.0", # Required by mcp-use # AI providers "anthropic>=0.75", - "openai>=2.8.1", "google-genai", "openai-agents", # Dataset loading (HuggingFace) "datasets>=2.14.0", - # Telemetry / OpenTelemetry tracing - "opentelemetry-instrumentation-mcp==0.47.0", - "opentelemetry-api>=1.34.1", - "opentelemetry-sdk>=1.34.1", - "opentelemetry-exporter-otlp-proto-http>=1.34.1", # Image processing for screenshots/grounding "pillow>=11.1.0", # Jupyter kernel support @@ -152,6 +148,9 @@ dev = [ # Automation and computer control "playwright", "pyautogui>=0.9.54", + # Optional integrations (for type checking) + "llama-index-core", + "google-adk", ] # Alias for backwards compatibility @@ -212,7 +211,6 @@ exclude = [ "**/node_modules", "**/__pycache__", "**/venv", - "hud/misc/claude_plays_pokemon.py", ] pythonVersion = "3.11" typeCheckingMode = "basic" From 0b6557d18201018f0500af7163f032d9ed2459c7 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 09:48:35 -0800 Subject: [PATCH 58/92] test fixes --- hud/agents/tests/test_base.py | 2 +- hud/agents/tests/test_base_runtime.py | 2 +- hud/cli/tests/test_eval.py | 12 ++++++------ hud/tests/test_datasets_extended.py | 10 +++++----- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/hud/agents/tests/test_base.py b/hud/agents/tests/test_base.py index c8e19a9f..092f1b66 100644 --- a/hud/agents/tests/test_base.py +++ b/hud/agents/tests/test_base.py @@ -15,7 +15,7 @@ class MockConfig(BaseAgentConfig): model_name: str = "MockAgent" - checkpoint_name: str = "mock-model" + model: str = "mock-model" class MockCreateParams(BaseCreateParams, MockConfig): diff --git a/hud/agents/tests/test_base_runtime.py b/hud/agents/tests/test_base_runtime.py index c73f4ebf..0be1a3db 100644 --- a/hud/agents/tests/test_base_runtime.py +++ b/hud/agents/tests/test_base_runtime.py @@ -14,7 +14,7 @@ class DummyConfig(BaseAgentConfig): model_name: str = "DummyAgent" - checkpoint_name: str = "dummy-model" + model: str = "dummy-model" class DummyCreateParams(BaseCreateParams, DummyConfig): diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index 1aa6d3cb..257158a9 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -67,7 +67,7 @@ async def test_run_dataset_with_task_list(self) -> None: with ( patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + patch("hud.agents.claude.ClaudeAgent", mock_agent_cls), ): # Set up the async context manager mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) @@ -98,7 +98,7 @@ async def test_run_dataset_with_string_source(self) -> None: with ( patch("hud.datasets.loader.load_tasks", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.OPENAI, "cls", mock_agent_cls), + patch("hud.agents.openai.OpenAIAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = 
AsyncMock(return_value=None) @@ -130,7 +130,7 @@ async def test_run_dataset_with_group_size(self) -> None: with ( patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + patch("hud.agents.claude.ClaudeAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) @@ -153,7 +153,7 @@ async def test_run_dataset_with_max_concurrent(self) -> None: with ( patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + patch("hud.agents.claude.ClaudeAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) @@ -176,7 +176,7 @@ async def test_run_dataset_returns_results(self) -> None: with ( patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + patch("hud.agents.claude.ClaudeAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) @@ -208,7 +208,7 @@ async def test_run_dataset_parallel_results(self) -> None: with ( patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + patch("hud.agents.claude.ClaudeAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index b9a97291..6d45ff03 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -137,7 +137,7 @@ async def test_run_dataset_empty(self): async def test_run_dataset_with_task_list(self): """Test run_dataset with Task objects.""" from hud.eval.task import Task - from hud.types import AgentType, Trace + from hud.types import Trace # Create mock tasks with env as dict (to avoid real connections) mock_env = {"name": "test"} @@ -160,7 +160,7 @@ async def test_run_dataset_with_task_list(self): with ( patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), + patch("hud.agents.claude.ClaudeAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) @@ -175,7 +175,7 @@ async def test_run_dataset_with_task_list(self): async def test_run_dataset_from_source_string(self): """Test run_dataset with source string calls load_tasks.""" from hud.eval.task import Task - from hud.types import AgentType, Trace + from hud.types import Trace mock_env = {"name": "test"} mock_tasks = [Task(env=mock_env, scenario="loaded")] # type: ignore[arg-type] @@ -192,7 +192,7 @@ async def test_run_dataset_from_source_string(self): with ( patch("hud.datasets.loader.load_tasks", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.OPENAI, "cls", mock_agent_cls), + patch("hud.agents.openai.OpenAIAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) @@ -222,7 +222,7 @@ async def test_run_dataset_passes_parameters(self): with ( patch("hud.datasets.runner.hud.eval") as mock_eval, - patch.object(AgentType.CLAUDE, "cls", mock_agent_cls), 
+ patch("hud.agents.claude.ClaudeAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) From 8d568babf24c6beb87a99f13dd738fc45c6dbfe9 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 11:49:32 -0800 Subject: [PATCH 59/92] test fixes --- docs/migration.mdx | 2 +- hud/agents/tests/test_grounded_openai_agent.py | 13 ++++++++++++- hud/cli/tests/test_eval.py | 2 +- hud/tests/test_datasets_extended.py | 16 ++++++++-------- hud/types.py | 4 ++-- 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/docs/migration.mdx b/docs/migration.mdx index 81640469..f8f76ce0 100644 --- a/docs/migration.mdx +++ b/docs/migration.mdx @@ -7,7 +7,7 @@ icon: "arrow-right-arrow-left" v4 separated environments (Docker containers) from evaluation logic (Task objects). v5 unifies everything in the `Environment` class—tools, setup, and scoring live together. -**Deprecation Notice**: `LegacyTask`, `setup_tool`, and `evaluate_tool` are deprecated in v0.5.0 and will be removed in v0.6.0 (no earlier than March 1st, 2025). Use `Task.from_v4()` for quick migration or `@env.scenario()` for new code. +**Deprecation Notice**: `LegacyTask`, `setup_tool`, and `evaluate_tool` are deprecated in v0.5.0 and will be removed in v0.6.0 (no earlier than March 1st, 2026). Use `Task.from_v4()` for quick migration or `@env.scenario()` for new code. ## Good News: Your Code Still Works diff --git a/hud/agents/tests/test_grounded_openai_agent.py b/hud/agents/tests/test_grounded_openai_agent.py index 6e0cdcd2..04bab667 100644 --- a/hud/agents/tests/test_grounded_openai_agent.py +++ b/hud/agents/tests/test_grounded_openai_agent.py @@ -149,8 +149,19 @@ async def test_get_response_with_reasoning() -> None: agent.oai.chat.completions.create = AsyncMock(return_value=mock_response) agent._initialized = True # Mark as initialized to skip context initialization + # Include an image so get_response doesn't try to take a screenshot via ctx + png_b64 = ( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGMAAQAABQAB" + "J2n0mQAAAABJRU5ErkJggg==" + ) agent.conversation_history = [ - {"role": "user", "content": [{"type": "text", "text": "Hard question"}]} + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{png_b64}"}}, + {"type": "text", "text": "Hard question"}, + ], + } ] response = await agent.get_response(agent.conversation_history) diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index 257158a9..eb9d11a1 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -98,7 +98,7 @@ async def test_run_dataset_with_string_source(self) -> None: with ( patch("hud.datasets.loader.load_tasks", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, - patch("hud.agents.openai.OpenAIAgent", mock_agent_cls), + patch("hud.agents.OpenAIAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) diff --git a/hud/tests/test_datasets_extended.py b/hud/tests/test_datasets_extended.py index 6d45ff03..fab7bcb0 100644 --- a/hud/tests/test_datasets_extended.py +++ b/hud/tests/test_datasets_extended.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import cast -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -152,10 +152,10 @@ async def 
test_run_dataset_with_task_list(self): mock_ctx.results = None mock_ctx.reward = None - # Create mock agent class and instance + # Create mock agent class and instance (use MagicMock since create() is sync) mock_agent_instance = AsyncMock() mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) - mock_agent_cls = AsyncMock() + mock_agent_cls = MagicMock() mock_agent_cls.create.return_value = mock_agent_instance with ( @@ -183,16 +183,16 @@ async def test_run_dataset_from_source_string(self): mock_ctx = AsyncMock() mock_ctx.results = None - # Create mock agent class and instance + # Create mock agent class and instance (use MagicMock since create() is sync) mock_agent_instance = AsyncMock() mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) - mock_agent_cls = AsyncMock() + mock_agent_cls = MagicMock() mock_agent_cls.create.return_value = mock_agent_instance with ( patch("hud.datasets.loader.load_tasks", return_value=mock_tasks) as mock_load, patch("hud.datasets.runner.hud.eval") as mock_eval, - patch("hud.agents.openai.OpenAIAgent", mock_agent_cls), + patch("hud.agents.OpenAIAgent", mock_agent_cls), ): mock_eval.return_value.__aenter__ = AsyncMock(return_value=mock_ctx) mock_eval.return_value.__aexit__ = AsyncMock(return_value=None) @@ -214,10 +214,10 @@ async def test_run_dataset_passes_parameters(self): mock_ctx = AsyncMock() mock_ctx.results = None - # Create mock agent class and instance + # Create mock agent class and instance (use MagicMock since create() is sync) mock_agent_instance = AsyncMock() mock_agent_instance.run.return_value = Trace(reward=1.0, done=True) - mock_agent_cls = AsyncMock() + mock_agent_cls = MagicMock() mock_agent_cls.create.return_value = mock_agent_instance with ( diff --git a/hud/types.py b/hud/types.py index 6502e7b4..5525ac29 100644 --- a/hud/types.py +++ b/hud/types.py @@ -96,7 +96,7 @@ class LegacyTask(BaseModel): .. deprecated:: 0.5.0 LegacyTask is deprecated in v0.5.0 and will be removed in v0.6.0 - (no earlier than March 1st, 2025). + (no earlier than March 1st, 2026). Use one of these migration paths: @@ -133,7 +133,7 @@ def __init__(self, **data: Any) -> None: warnings.warn( "LegacyTask is deprecated in v0.5.0 and will be removed in v0.6.0 " - "(no earlier than March 1st, 2025). " + "(no earlier than March 1st, 2026). " "Use Task.from_v4() for quick conversion, or migrate to @env.scenario(). " "See https://docs.hud.ai/migration for details.", DeprecationWarning, From 28e290f9f5fa661e23c2b85cecb71243d2c915e9 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 12:30:46 -0800 Subject: [PATCH 60/92] adjustments to instrumentation --- hud/agents/base.py | 2 -- hud/agents/claude.py | 6 ------ hud/agents/gemini.py | 6 ------ hud/agents/grounded_openai.py | 6 ------ hud/agents/openai.py | 6 ------ hud/agents/openai_chat.py | 6 ------ hud/eval/display.py | 9 +++++++++ hud/tools/grounding/grounder.py | 7 ------- 8 files changed, 9 insertions(+), 39 deletions(-) diff --git a/hud/agents/base.py b/hud/agents/base.py index 5bb2ef54..4e65f944 100644 --- a/hud/agents/base.py +++ b/hud/agents/base.py @@ -376,8 +376,6 @@ async def get_response(self, messages: list[Any]) -> AgentResponse: """ Get response from the model including any tool calls. 
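The test fixes above switch from `patch.object(AgentType.X, "cls", ...)` to patching the agent classes directly, and the inline comments note that `create()` is synchronous. A minimal sketch of that mocking pattern, with the return value of `run()` standing in for a real `Trace`:

```python
# Sketch of the pattern used in these test fixes: the agent class is a MagicMock
# because create() is sync, while the instance's run() stays awaitable.
from unittest.mock import AsyncMock, MagicMock, patch

mock_agent_instance = AsyncMock()
mock_agent_instance.run.return_value = {"reward": 1.0}  # stand-in for Trace(reward=1.0, done=True)
mock_agent_cls = MagicMock()
mock_agent_cls.create.return_value = mock_agent_instance

with patch("hud.agents.claude.ClaudeAgent", mock_agent_cls):
    ...  # code under test resolves the agent class to the mock
```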
- NOTE: Subclasses should decorate this method with: - @hud.instrument(span_type="agent", record_args=False, record_result=True) Args: messages: Current conversation messages diff --git a/hud/agents/claude.py b/hud/agents/claude.py index 114eaf74..f3d803c2 100644 --- a/hud/agents/claude.py +++ b/hud/agents/claude.py @@ -25,7 +25,6 @@ ) from pydantic import ConfigDict -import hud from hud.settings import settings from hud.tools.computer.settings import computer_settings from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult @@ -134,11 +133,6 @@ async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[BetaMess return [BetaMessageParam(role="user", content=anthropic_blocks)] - @hud.instrument( - span_type="agent", - record_args=False, # Messages can be large - record_result=True, - ) async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse: """Get response from Claude including any tool calls.""" diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py index 88eaa3ef..9231a1eb 100644 --- a/hud/agents/gemini.py +++ b/hud/agents/gemini.py @@ -10,7 +10,6 @@ from google.genai import types as genai_types from pydantic import ConfigDict -import hud from hud.settings import settings from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult from hud.utils.hud_console import HUDConsole @@ -115,11 +114,6 @@ async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_ty return [genai_types.Content(role="user", parts=gemini_parts)] - @hud.instrument( - span_type="agent", - record_args=False, # Messages can be large - record_result=True, - ) async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse: """Get response from Gemini including any tool calls.""" diff --git a/hud/agents/grounded_openai.py b/hud/agents/grounded_openai.py index e86cb3de..441bbda9 100644 --- a/hud/agents/grounded_openai.py +++ b/hud/agents/grounded_openai.py @@ -7,7 +7,6 @@ from pydantic import ConfigDict, field_validator -from hud import instrument from hud.tools.grounding import GroundedComputerTool, Grounder, GrounderConfig from hud.types import AgentResponse, MCPToolCall, MCPToolResult from hud.utils.types import with_signature @@ -104,11 +103,6 @@ def get_tool_schemas(self) -> list[Any]: return [] return [self.grounded_tool.get_openai_tool_schema()] - @instrument( - span_type="agent", - record_args=False, - record_result=True, - ) async def get_response(self, messages: Any) -> AgentResponse: """Get response from the planning model and handle grounded tool calls. 
diff --git a/hud/agents/openai.py b/hud/agents/openai.py index c4e4c04e..229af849 100644 --- a/hud/agents/openai.py +++ b/hud/agents/openai.py @@ -31,7 +31,6 @@ from openai.types.shared_params.reasoning import Reasoning # noqa: TC002 from pydantic import ConfigDict -import hud from hud.settings import settings from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult, Trace from hud.utils.strict_schema import ensure_strict_json_schema @@ -220,11 +219,6 @@ async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Message] content.append(ResponseInputTextParam(type="input_text", text="")) return [Message(role="user", content=content)] - @hud.instrument( - span_type="agent", - record_args=False, - record_result=True, - ) async def get_response(self, messages: ResponseInputParam) -> AgentResponse: """Send the latest input items to OpenAI's Responses API.""" new_items: ResponseInputParam = messages[self._message_cursor :] diff --git a/hud/agents/openai_chat.py b/hud/agents/openai_chat.py index f041e4b2..e4e61b05 100644 --- a/hud/agents/openai_chat.py +++ b/hud/agents/openai_chat.py @@ -24,7 +24,6 @@ from openai import AsyncOpenAI from pydantic import ConfigDict, Field -from hud import instrument from hud.settings import settings from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult from hud.utils.hud_console import HUDConsole @@ -227,11 +226,6 @@ async def _invoke_chat_completion( **extra, ) # type: ignore - @instrument( - span_type="agent", - record_args=False, - record_result=True, - ) async def get_response(self, messages: list[dict[str, Any]]) -> AgentResponse: """Send chat request to OpenAI and convert the response.""" diff --git a/hud/eval/display.py b/hud/eval/display.py index 7600ac47..e759eaa9 100644 --- a/hud/eval/display.py +++ b/hud/eval/display.py @@ -165,6 +165,7 @@ def display_results( # Check if we have variants (grouped parallel runs) has_variants = any(getattr(r, "variants", None) for r in results if r) + has_prompts = any(getattr(r, "prompt", None) for r in results if r) has_answers = any(getattr(r, "answer", None) for r in results if r) if has_variants: @@ -172,6 +173,9 @@ def display_results( elif tasks: table.add_column("Task", style="cyan", max_width=30) + if has_prompts: + table.add_column("Prompt", style="dim", max_width=35) + if has_answers: table.add_column("Answer", style="dim", max_width=35) @@ -189,6 +193,7 @@ def display_results( error = getattr(r, "error", None) duration = getattr(r, "duration", 0) variants = getattr(r, "variants", None) + prompt = getattr(r, "prompt", None) answer = getattr(r, "answer", None) # Status icon @@ -209,6 +214,10 @@ def display_results( task_label = _get_task_label(task, i) row.append(task_label[:30]) + # Prompt column + if has_prompts: + row.append(_truncate(prompt, 35)) + # Answer column if has_answers: row.append(_truncate(answer, 35)) diff --git a/hud/tools/grounding/grounder.py b/hud/tools/grounding/grounder.py index 29a073e7..862432d0 100644 --- a/hud/tools/grounding/grounder.py +++ b/hud/tools/grounding/grounder.py @@ -9,7 +9,6 @@ from openai import AsyncOpenAI -from hud import instrument from hud.tools.grounding.config import GrounderConfig # noqa: TC001 logger = logging.getLogger(__name__) @@ -182,12 +181,6 @@ def _convert_coordinates( return (final_x, final_y) - @instrument( - name="Grounding.predict_click", - span_type="agent", - record_args=True, - record_result=True, - ) async def predict_click( self, *, image_b64: str, instruction: str, max_retries: int = 3 
) -> tuple[int, int] | None: From cfab31f498ef13d01f54f94f29cd94db24b5974c Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 12:54:08 -0800 Subject: [PATCH 61/92] type fix --- hud/agents/tests/test_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hud/agents/tests/test_operator.py b/hud/agents/tests/test_operator.py index 3303855a..a7cb7264 100644 --- a/hud/agents/tests/test_operator.py +++ b/hud/agents/tests/test_operator.py @@ -223,7 +223,7 @@ async def test_get_model_response( mock_openai.responses.create = AsyncMock(return_value=mock_response) messages = [{"prompt": "What's on the screen?", "screenshot": None}] - response = await agent.get_response(messages) + response = await agent.get_response(messages) # type: ignore[arg-type] assert response.done is True assert response.tool_calls == [] @@ -251,7 +251,7 @@ async def test_handle_empty_response( mock_openai.responses.create = AsyncMock(return_value=mock_response) messages = [{"prompt": "Hi", "screenshot": None}] - response = await agent.get_response(messages) + response = await agent.get_response(messages) # type: ignore[arg-type] assert response.content == "" assert response.tool_calls == [] From 5a4930851ef313151aed9c53cec319f26868dc00 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 16:15:04 -0800 Subject: [PATCH 62/92] add prompt schema --- hud/cli/build.py | 2 +- hud/cli/debug.py | 10 ++++----- hud/clients/__init__.py | 4 ---- hud/clients/base.py | 17 +++++++++++---- hud/environment/scenarios.py | 41 +++++++++++++++++++++++++++++------- 5 files changed, 52 insertions(+), 22 deletions(-) diff --git a/hud/cli/build.py b/hud/cli/build.py index ba00dcb7..53f949dc 100644 --- a/hud/cli/build.py +++ b/hud/cli/build.py @@ -454,7 +454,7 @@ async def analyze_mcp_environment( from hud.clients.fastmcp import FastMCPHUDClient start_time = time.time() - client = FastMCPHUDClient(mcp_config=mcp_config, verbose=verbose, auto_trace=False) + client = FastMCPHUDClient(mcp_config=mcp_config, verbose=verbose) initialized = False try: diff --git a/hud/cli/debug.py b/hud/cli/debug.py index 252546e0..07c7924a 100644 --- a/hud/cli/debug.py +++ b/hud/cli/debug.py @@ -246,9 +246,9 @@ def read_stderr() -> None: logger.info("Creating MCP client via hud...") # Lazy import to avoid loading mcp_use on simple CLI commands - from hud.clients import MCPClient + from hud.clients.fastmcp import FastMCPHUDClient - client = MCPClient(mcp_config=mcp_config, verbose=False, auto_trace=False) + client = FastMCPHUDClient(mcp_config=mcp_config, verbose=False) await client.initialize() # Wait for initialization @@ -353,7 +353,7 @@ def read_stderr() -> None: logger.info("Creating 3 concurrent MCP clients...") # Lazy import to avoid loading mcp_use on simple CLI commands - from hud.clients import MCPClient + from hud.clients.fastmcp import FastMCPHUDClient for i in range(3): client_config = { @@ -363,8 +363,8 @@ def read_stderr() -> None: } } - concurrent_client = MCPClient( - mcp_config=client_config, verbose=False, auto_trace=False + concurrent_client = FastMCPHUDClient( + mcp_config=client_config, verbose=False ) await concurrent_client.initialize() concurrent_clients.append(concurrent_client) diff --git a/hud/clients/__init__.py b/hud/clients/__init__.py index 1f0f62c8..bec6841b 100644 --- a/hud/clients/__init__.py +++ b/hud/clients/__init__.py @@ -9,10 +9,6 @@ # Default to FastMCP client (no optional dependencies) MCPClient = FastMCPHUDClient -# Note: MCPUseHUDClient requires mcp-use (optional dependency in 
[agents]). -# Import directly if needed: -# from hud.clients.mcp_use import MCPUseHUDClient - __all__ = [ "AgentMCPClient", "BaseHUDClient", diff --git a/hud/clients/base.py b/hud/clients/base.py index b8304b0e..1e735cce 100644 --- a/hud/clients/base.py +++ b/hud/clients/base.py @@ -417,10 +417,13 @@ async def analyze_environment(self) -> dict[str, Any]: "description": prompt.description, "arguments": args, } - # Include meta field if present (contains scenario source code) + # Include meta field if present (contains scenario source code and argumentsSchema) meta = getattr(prompt, "meta", None) if meta: prompt_info["meta"] = meta + # Extract argumentsSchema to top level for easier access + if isinstance(meta, dict) and "argumentsSchema" in meta: + prompt_info["argumentsSchema"] = meta["argumentsSchema"] analysis["prompts"].append(prompt_info) except Exception as e: if self.verbose: @@ -450,10 +453,16 @@ async def analyze_environment(self) -> dict[str, Any]: "has_setup_prompt": True, "has_evaluate_resource": False, } - # Extract code from meta field if present + # Extract code and argumentsSchema from meta field if present meta = p.get("meta") - if meta and isinstance(meta, dict) and "code" in meta: - scenario_info["code"] = meta["code"] + if meta and isinstance(meta, dict): + if "code" in meta: + scenario_info["code"] = meta["code"] + if "argumentsSchema" in meta: + scenario_info["argumentsSchema"] = meta["argumentsSchema"] + # Also check top-level argumentsSchema (extracted earlier) + if p.get("argumentsSchema"): + scenario_info["argumentsSchema"] = p["argumentsSchema"] scenarios_by_id[scenario_id] = scenario_info for r in analysis.get("resources", []): diff --git a/hud/environment/scenarios.py b/hud/environment/scenarios.py index ea87102c..92399918 100644 --- a/hud/environment/scenarios.py +++ b/hud/environment/scenarios.py @@ -304,12 +304,33 @@ def decorator( # Store the generator function self._scenarios[scenario_name] = fn - # Get function signature for prompt arguments + # Get function signature for prompt arguments with type info sig = inspect.signature(fn) - prompt_args = [ - {"name": p.name, "required": p.default is inspect.Parameter.empty} - for p in sig.parameters.values() - ] + prompt_args = [] + arguments_schema: dict[str, Any] = { + "type": "object", + "properties": {}, + "required": [], + } + for p in sig.parameters.values(): + is_required = p.default is inspect.Parameter.empty + prompt_args.append({"name": p.name, "required": is_required}) + if is_required: + arguments_schema["required"].append(p.name) + # Extract type annotation for schema + if p.annotation is not inspect.Parameter.empty: + try: + # Use pydantic to convert annotation to JSON schema + from pydantic import TypeAdapter + + adapter = TypeAdapter(p.annotation) + param_schema = adapter.json_schema() + arguments_schema["properties"][p.name] = param_schema + except Exception: + # Fallback: just use string type + arguments_schema["properties"][p.name] = {"type": "string"} + else: + arguments_schema["properties"][p.name] = {"type": "string"} # Register PROMPT - runs setup, returns prompt messages # We need a reference to self and the outer variables @@ -344,8 +365,12 @@ async def prompt_handler(**handler_args: Any) -> list[str]: # to bypass the **kwargs validation in from_function() from fastmcp.prompts.prompt import FunctionPrompt, PromptArgument - # Build meta with source code - scenario_meta = {"code": source_code} if source_code else None + # Build meta with source code and arguments schema + scenario_meta: 
dict[str, Any] = {} + if source_code: + scenario_meta["code"] = source_code + if arguments_schema["properties"]: + scenario_meta["argumentsSchema"] = arguments_schema prompt = FunctionPrompt( name=scenario_id, @@ -355,7 +380,7 @@ async def prompt_handler(**handler_args: Any) -> list[str]: for arg in prompt_args ], fn=prompt_handler, - meta=scenario_meta, + meta=scenario_meta if scenario_meta else None, ) self._prompt_manager.add_prompt(prompt) From 3a5d3b22818160fe929a353909a3b3d65b8c5486 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 16:16:20 -0800 Subject: [PATCH 63/92] format --- hud/cli/debug.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hud/cli/debug.py b/hud/cli/debug.py index 07c7924a..dbb22630 100644 --- a/hud/cli/debug.py +++ b/hud/cli/debug.py @@ -363,9 +363,7 @@ def read_stderr() -> None: } } - concurrent_client = FastMCPHUDClient( - mcp_config=client_config, verbose=False - ) + concurrent_client = FastMCPHUDClient(mcp_config=client_config, verbose=False) await concurrent_client.initialize() concurrent_clients.append(concurrent_client) logger.info(f"Client {i + 1} connected") From 1c718d09b605ded5f7a4805cda05e1aee654512b Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 16:37:38 -0800 Subject: [PATCH 64/92] scenarios --- hud/clients/base.py | 30 ++++++++++++++----------- hud/environment/scenarios.py | 43 ++++++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/hud/clients/base.py b/hud/clients/base.py index 1e735cce..227b0379 100644 --- a/hud/clients/base.py +++ b/hud/clients/base.py @@ -417,13 +417,23 @@ async def analyze_environment(self) -> dict[str, Any]: "description": prompt.description, "arguments": args, } - # Include meta field if present (contains scenario source code and argumentsSchema) + # Include meta field if present (contains scenario source code and arguments with types) meta = getattr(prompt, "meta", None) if meta: prompt_info["meta"] = meta - # Extract argumentsSchema to top level for easier access - if isinstance(meta, dict) and "argumentsSchema" in meta: - prompt_info["argumentsSchema"] = meta["argumentsSchema"] + # Merge type/default info from meta.arguments into the arguments array + if isinstance(meta, dict) and "arguments" in meta: + meta_args = {a["name"]: a for a in meta["arguments"] if "name" in a} + for arg in args: + arg_name = arg.get("name") + if arg_name and arg_name in meta_args: + meta_arg = meta_args[arg_name] + if "default" in meta_arg: + arg["default"] = meta_arg["default"] + if "type" in meta_arg: + arg["type"] = meta_arg["type"] + if "schema" in meta_arg: + arg["schema"] = meta_arg["schema"] analysis["prompts"].append(prompt_info) except Exception as e: if self.verbose: @@ -453,16 +463,10 @@ async def analyze_environment(self) -> dict[str, Any]: "has_setup_prompt": True, "has_evaluate_resource": False, } - # Extract code and argumentsSchema from meta field if present + # Extract code from meta field if present meta = p.get("meta") - if meta and isinstance(meta, dict): - if "code" in meta: - scenario_info["code"] = meta["code"] - if "argumentsSchema" in meta: - scenario_info["argumentsSchema"] = meta["argumentsSchema"] - # Also check top-level argumentsSchema (extracted earlier) - if p.get("argumentsSchema"): - scenario_info["argumentsSchema"] = p["argumentsSchema"] + if meta and isinstance(meta, dict) and "code" in meta: + scenario_info["code"] = meta["code"] scenarios_by_id[scenario_id] = scenario_info for r in analysis.get("resources", 
[]): diff --git a/hud/environment/scenarios.py b/hud/environment/scenarios.py index 92399918..3c982f41 100644 --- a/hud/environment/scenarios.py +++ b/hud/environment/scenarios.py @@ -306,18 +306,21 @@ def decorator( # Get function signature for prompt arguments with type info sig = inspect.signature(fn) - prompt_args = [] - arguments_schema: dict[str, Any] = { - "type": "object", - "properties": {}, - "required": [], - } + prompt_args: list[dict[str, Any]] = [] for p in sig.parameters.values(): is_required = p.default is inspect.Parameter.empty - prompt_args.append({"name": p.name, "required": is_required}) - if is_required: - arguments_schema["required"].append(p.name) - # Extract type annotation for schema + arg_info: dict[str, Any] = {"name": p.name, "required": is_required} + + # Include default value if present + if not is_required: + # Only include JSON-serializable defaults + default_val = p.default + if default_val is None or isinstance( + default_val, (str, int, float, bool, list, dict) + ): + arg_info["default"] = default_val + + # Extract type annotation if p.annotation is not inspect.Parameter.empty: try: # Use pydantic to convert annotation to JSON schema @@ -325,12 +328,18 @@ def decorator( adapter = TypeAdapter(p.annotation) param_schema = adapter.json_schema() - arguments_schema["properties"][p.name] = param_schema + # Extract type from schema (could be "string", "integer", etc.) + if "type" in param_schema: + arg_info["type"] = param_schema["type"] + elif "$ref" in param_schema or "anyOf" in param_schema: + # Complex type - store the full schema + arg_info["schema"] = param_schema except Exception: - # Fallback: just use string type - arguments_schema["properties"][p.name] = {"type": "string"} + arg_info["type"] = "string" else: - arguments_schema["properties"][p.name] = {"type": "string"} + arg_info["type"] = "string" + + prompt_args.append(arg_info) # Register PROMPT - runs setup, returns prompt messages # We need a reference to self and the outer variables @@ -365,12 +374,12 @@ async def prompt_handler(**handler_args: Any) -> list[str]: # to bypass the **kwargs validation in from_function() from fastmcp.prompts.prompt import FunctionPrompt, PromptArgument - # Build meta with source code and arguments schema + # Build meta with source code and full arguments info (with types/defaults) scenario_meta: dict[str, Any] = {} if source_code: scenario_meta["code"] = source_code - if arguments_schema["properties"]: - scenario_meta["argumentsSchema"] = arguments_schema + if prompt_args: + scenario_meta["arguments"] = prompt_args prompt = FunctionPrompt( name=scenario_id, From 6cb83e9e2aa41baef9750fa34382683e744d23c6 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 17:18:41 -0800 Subject: [PATCH 65/92] switch var names --- hud/clients/base.py | 4 ++-- hud/environment/scenarios.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hud/clients/base.py b/hud/clients/base.py index 227b0379..a1b8f04e 100644 --- a/hud/clients/base.py +++ b/hud/clients/base.py @@ -432,8 +432,8 @@ async def analyze_environment(self) -> dict[str, Any]: arg["default"] = meta_arg["default"] if "type" in meta_arg: arg["type"] = meta_arg["type"] - if "schema" in meta_arg: - arg["schema"] = meta_arg["schema"] + if "inputSchema" in meta_arg: + arg["inputSchema"] = meta_arg["inputSchema"] analysis["prompts"].append(prompt_info) except Exception as e: if self.verbose: diff --git a/hud/environment/scenarios.py b/hud/environment/scenarios.py index 3c982f41..1369ea37 100644 --- 
a/hud/environment/scenarios.py +++ b/hud/environment/scenarios.py @@ -333,7 +333,7 @@ def decorator( arg_info["type"] = param_schema["type"] elif "$ref" in param_schema or "anyOf" in param_schema: # Complex type - store the full schema - arg_info["schema"] = param_schema + arg_info["inputSchema"] = param_schema except Exception: arg_info["type"] = "string" else: From 80deb860e2e10dc5ae15dcb036da4fcebe99ba09 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 19:47:40 -0800 Subject: [PATCH 66/92] changes to dev and tools --- hud/cli/__init__.py | 37 ++++++++++++++++----------- hud/cli/dev.py | 48 ++++++++++++++++++++++++++++-------- hud/cli/flows/dev.py | 4 +-- hud/tools/computer/hud.py | 11 +++++---- hud/tools/computer/openai.py | 13 ++-------- hud/tools/types.py | 12 +++++++++ 6 files changed, 82 insertions(+), 43 deletions(-) diff --git a/hud/cli/__init__.py b/hud/cli/__init__.py index 61868bef..a5b3c595 100644 --- a/hud/cli/__init__.py +++ b/hud/cli/__init__.py @@ -453,10 +453,11 @@ def dev( interactive: bool = typer.Option( False, "--interactive", help="Launch interactive testing mode (HTTP mode only)" ), - watch: list[str] = typer.Option( # noqa: B008 - None, + watch: list[str] = typer.Option( + [], "--watch", - help="Additional directories to watch for changes (default: current directory)", + "-w", + help="Paths to watch for hot-reload (repeatable: -w tools -w env.py)", ), new: bool = typer.Option( False, @@ -470,30 +471,36 @@ def dev( 1. Python Module: hud dev # Auto-detects module - hud dev server.main # Explicit module + hud dev env:env # Explicit module:attribute + hud dev -w . # Watch current directory - 2. Docker with Volume Mounts (Complex environments like 'browser'): - hud dev --docker # Auto-detects image from hud.lock.yaml - hud dev --docker -p 8080:8080 # With extra Docker args + 2. Docker (Complex environments): + hud dev # Auto-detects Dockerfile, no hot-reload + hud dev -w tools -w env.py # Mount & watch specific paths + hud dev -w tools # Just watch tools folder - The server must define 'mcp' in its __init__.py or main.py. + For Docker mode, use --watch to specify which folders to mount and watch. + Paths not in --watch stay in the built image (no hot-reload). 
Examples: - hud dev # Auto-detect in current directory + hud dev # Auto-detect mode hud dev --new # Create live dev trace on hud.ai - hud dev controller # Run specific module + hud dev env:env # Run specific module hud dev --inspector # Launch MCP Inspector hud dev --interactive # Launch interactive testing mode - hud dev --stdio # Use stdio transport - hud dev --watch ../shared # Watch additional directories + hud dev -w 'tools env.py' # Docker: hot-reload tools/ and env.py - For environment backend servers, use uvicorn directly: - uvicorn server:app --reload[/not dim] + Local development pattern (Docker + local scenarios): + Terminal 1: hud dev -w 'tools env.py' --port 8000 + Terminal 2: python local_test.py # Uses connect_url()[/not dim] """ # Extract module from params if provided (first param when not --docker) module = params[0] if params and not docker else None docker_args = params if docker else [] + # Convert empty list to None for run_mcp_dev_server + watch_paths = watch if watch else None + run_mcp_dev_server( module, stdio, @@ -501,7 +508,7 @@ def dev( verbose, inspector, interactive, - watch, + watch_paths, docker=docker, docker_args=docker_args, new_trace=new, diff --git a/hud/cli/dev.py b/hud/cli/dev.py index 8555906e..b0370bb7 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -556,9 +556,21 @@ def run_docker_dev_server( inspector: bool, interactive: bool, docker_args: list[str], + watch_paths: list[str] | None = None, new_trace: bool = False, ) -> None: - """Run MCP server in Docker with volume mounts, expose via local HTTP proxy.""" + """Run MCP server in Docker with volume mounts, expose via local HTTP proxy. + + Args: + port: HTTP port to expose + verbose: Show detailed logs + inspector: Launch MCP Inspector + interactive: Launch interactive testing mode + docker_args: Extra Docker run arguments + watch_paths: Folders/files to mount for hot-reload (e.g., ["tools", "env.py"]). + If None, no hot-reload mounts are added. 
+ new_trace: Create a new dev trace on hud.ai + """ import atexit import signal @@ -691,10 +703,6 @@ def signal_handler(signum: int, frame: Any) -> None: "--rm", # Automatically remove container when it stops "--name", container_name, - "-v", - f"{env_dir.absolute()}/server:/app/server:rw", - "-v", - f"{env_dir.absolute()}/environment:/app/environment:rw", "-e", "PYTHONPATH=/app", "-e", @@ -703,6 +711,22 @@ def signal_handler(signum: int, frame: Any) -> None: "HUD_DEV=1", ] + # Add volume mounts for watch paths (hot-reload) + if watch_paths: + hud_console.info(f"Hot-reload enabled for: {', '.join(watch_paths)}") + for path in watch_paths: + # Resolve the local path + local_path = env_dir.absolute() / path + if local_path.exists(): + # Mount to /app/ in container + container_path = f"/app/{path}" + base_args.extend(["-v", f"{local_path}:{container_path}:rw"]) + else: + hud_console.warning(f"Watch path not found: {path}") + else: + hud_console.info("No --watch paths specified, running without hot-reload") + hud_console.dim_info("Tip", "Use -w to enable hot-reload (e.g., -w tools -w env.py)") + # Add debugging port mappings if available if debugging_ports: hud_console.info(f"Exposing debugging ports: {', '.join(map(str, debugging_ports))}") @@ -778,8 +802,8 @@ def signal_handler(signum: int, frame: Any) -> None: ) hud_console.dim_info( "", - "Container restarts on file changes (mounted volumes), " - "if changing tools run hud dev again", + "Container restarts on file changes in watched folders (-w), " + "rebuild with 'hud dev' if changing other files", ) hud_console.info("") @@ -893,15 +917,19 @@ def run_mcp_dev_server( # Auto-detect Docker mode if Dockerfile present and no module specified if not docker and module is None and should_use_docker_mode(cwd): - hud_console.note("Detected Dockerfile - using Docker mode with volume mounts") + hud_console.note("Detected Dockerfile - using Docker mode") hud_console.dim_info("Tip", "Use 'hud dev --help' to see all options") hud_console.info("") - run_docker_dev_server(port, verbose, inspector, interactive, docker_args, new_trace) + run_docker_dev_server( + port, verbose, inspector, interactive, docker_args, watch, new_trace + ) return # Route to Docker mode if explicitly requested if docker: - run_docker_dev_server(port, verbose, inspector, interactive, docker_args, new_trace) + run_docker_dev_server( + port, verbose, inspector, interactive, docker_args, watch, new_trace + ) return transport = "stdio" if stdio else "http" diff --git a/hud/cli/flows/dev.py b/hud/cli/flows/dev.py index 8072cf01..a0da6a7f 100644 --- a/hud/cli/flows/dev.py +++ b/hud/cli/flows/dev.py @@ -143,8 +143,8 @@ def show_dev_ui( if is_docker: hud_console.dim_info( "", - "Container restarts on file changes (mounted volumes), " - "if changing tools run hud dev again", + "Container restarts on file changes in watched folders (-w), " + "rebuild with 'hud dev' if changing other files", ) hud_console.info("") diff --git a/hud/tools/computer/hud.py b/hud/tools/computer/hud.py index 6c3dc9c2..2bade98a 100644 --- a/hud/tools/computer/hud.py +++ b/hud/tools/computer/hud.py @@ -13,7 +13,7 @@ from hud.tools.executors.base import BaseExecutor from hud.tools.executors.pyautogui import PyAutoGUIExecutor from hud.tools.executors.xdo import XDOExecutor -from hud.tools.types import ContentResult, ToolError +from hud.tools.types import ContentResult, Coordinate, ToolError from .settings import computer_settings @@ -270,8 +270,8 @@ async def __call__( offset_x: int | None = Field(None, description="X 
offset for relative move"), offset_y: int | None = Field(None, description="Y offset for relative move"), # Drag parameters - path: list[tuple[int, int]] | None = Field( - None, description="Path for drag actions as list of (x, y) coordinates" + path: list[Coordinate] | None = Field( + None, description="Path for drag actions as list of {x, y} coordinates" ), # Wait parameter time: int | None = Field(None, description="Time in milliseconds for wait action"), @@ -348,8 +348,9 @@ async def __call__( elif action == "drag": if path is None: raise ToolError("path parameter is required for drag") - # Scale path from client space to screen space - scaled_path = self._scale_path(path) + # Convert Coordinate objects to tuples and scale from client space to screen space + path_tuples = [(point.x, point.y) for point in path] + scaled_path = self._scale_path(path_tuples) result = await self.executor.drag( path=scaled_path, pattern=pattern, hold_keys=hold_keys ) diff --git a/hud/tools/computer/openai.py b/hud/tools/computer/openai.py index 53806550..576cc618 100644 --- a/hud/tools/computer/openai.py +++ b/hud/tools/computer/openai.py @@ -6,10 +6,10 @@ from mcp import ErrorData, McpError from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, ContentBlock, TextContent -from pydantic import BaseModel, Field +from pydantic import Field from hud.tools.computer.settings import computer_settings -from hud.tools.types import ContentResult +from hud.tools.types import ContentResult, Coordinate from .hud import HudComputerTool @@ -19,15 +19,6 @@ logger = logging.getLogger(__name__) -class Coordinate(BaseModel): - """A coordinate point with x and y values.""" - - model_config = {"extra": "forbid"} # Ensures additionalProperties: false in JSON schema - - x: int = Field(..., description="X coordinate") - y: int = Field(..., description="Y coordinate") - - # Map OpenAI key names to CLA standard keys OPENAI_TO_CLA_KEYS = { # Common variations diff --git a/hud/tools/types.py b/hud/tools/types.py index f3285258..282e4de2 100644 --- a/hud/tools/types.py +++ b/hud/tools/types.py @@ -6,6 +6,18 @@ from pydantic import BaseModel, ConfigDict, Field +class Coordinate(BaseModel): + """A coordinate point with x and y values. + + Used for path-based actions like drag operations. + """ + + model_config = ConfigDict(extra="forbid") + + x: int = Field(..., description="X coordinate") + y: int = Field(..., description="Y coordinate") + + class EvaluationResult(BaseModel): """Standard evaluation result format.""" From 521ef172c3d0b38e5533f3c96ca63d2c466624c3 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 19:50:57 -0800 Subject: [PATCH 67/92] save tool result --- hud/eval/context.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/hud/eval/context.py b/hud/eval/context.py index a86cec47..2d21d942 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -23,6 +23,7 @@ from types import TracebackType from hud.eval.task import Task + from hud.types import MCPToolResult from hud.eval.types import EvalExitPayload, EvalPayload, ParallelEvalComplete @@ -577,13 +578,15 @@ async def __aexit__( # ========================================================================= @instrument(category="mcp") - async def call_tool(self, call: Any, /, **kwargs: Any) -> Any: - """Call a tool with automatic telemetry recording. + async def _execute_tool( + self, name: str, arguments: dict[str, Any] + ) -> MCPToolResult: + """Execute a tool with automatic telemetry recording. 
- Overrides Environment.call_tool to record MCP spans for the eval context. - Uses @instrument decorator for automatic span recording. + Overrides Environment._execute_tool to record MCP spans for the eval context. + The decorator records name, arguments, and result automatically. """ - return await super().call_tool(call, **kwargs) + return await super()._execute_tool(name, arguments) def __repr__(self) -> str: return f"EvalContext({self.trace_id[:8]}..., name={self.eval_name!r}, reward={self.reward})" From 8a8e80605d2131c3ec1fee504f8e8feb9d2f8553 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 20:00:56 -0800 Subject: [PATCH 68/92] adjust tools for generic spec --- hud/tools/computer/anthropic.py | 4 ++-- hud/tools/computer/qwen.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hud/tools/computer/anthropic.py b/hud/tools/computer/anthropic.py index 310fbbf3..59854215 100644 --- a/hud/tools/computer/anthropic.py +++ b/hud/tools/computer/anthropic.py @@ -141,13 +141,13 @@ def _map_anthropic_key_to_cla(self, key: str) -> str: async def __call__( self, action: str = Field(..., description="The action to perform on the computer"), - coordinate: list[int] | tuple[int, int] | None = Field( + coordinate: list[int] | None = Field( None, description="The coordinate to interact with on the computer [x, y]" ), text: str | None = Field( None, description="The text to type on the computer or key to press" ), - start_coordinate: list[int] | tuple[int, int] | None = Field( + start_coordinate: list[int] | None = Field( None, description="The starting coordinate for drag actions [x, y]" ), scroll_direction: str | None = Field( diff --git a/hud/tools/computer/qwen.py b/hud/tools/computer/qwen.py index 71da53fb..6f3db5cc 100644 --- a/hud/tools/computer/qwen.py +++ b/hud/tools/computer/qwen.py @@ -194,7 +194,7 @@ async def __call__( action: str = Field(..., description="The action to perform on the computer"), keys: list[str] | None = Field(None, description="Keys for key action"), text: str | None = Field(None, description="Text to type"), - coordinate: list[int] | tuple[int, int] | None = Field( + coordinate: list[int] | None = Field( None, description="The coordinate to interact with on the computer [x, y]" ), pixels: int | None = Field(None, description="Pixels to scroll"), From b4d0cabe2a8221becf69d83b0564c7193c3a7f0b Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 20:20:54 -0800 Subject: [PATCH 69/92] udpate schema resolution --- hud/cli/dev.py | 8 +-- hud/environment/integrations/openai.py | 68 ++++++++++++++++++----- hud/environment/utils/schema.py | 76 +++++++++++++++++++++++++- hud/eval/context.py | 4 +- 4 files changed, 133 insertions(+), 23 deletions(-) diff --git a/hud/cli/dev.py b/hud/cli/dev.py index b0370bb7..cf0ab918 100644 --- a/hud/cli/dev.py +++ b/hud/cli/dev.py @@ -920,16 +920,12 @@ def run_mcp_dev_server( hud_console.note("Detected Dockerfile - using Docker mode") hud_console.dim_info("Tip", "Use 'hud dev --help' to see all options") hud_console.info("") - run_docker_dev_server( - port, verbose, inspector, interactive, docker_args, watch, new_trace - ) + run_docker_dev_server(port, verbose, inspector, interactive, docker_args, watch, new_trace) return # Route to Docker mode if explicitly requested if docker: - run_docker_dev_server( - port, verbose, inspector, interactive, docker_args, watch, new_trace - ) + run_docker_dev_server(port, verbose, inspector, interactive, docker_args, watch, new_trace) return transport = "stdio" if stdio 
else "http" diff --git a/hud/environment/integrations/openai.py b/hud/environment/integrations/openai.py index 9d553ac0..0bad7782 100644 --- a/hud/environment/integrations/openai.py +++ b/hud/environment/integrations/openai.py @@ -3,9 +3,10 @@ from __future__ import annotations import json +import logging from typing import TYPE_CHECKING, Any, cast -from hud.environment.utils.schema import ensure_strict_schema +from hud.environment.utils.schema import ensure_strict_schema, validate_openai_schema if TYPE_CHECKING: import mcp.types as mcp_types @@ -13,6 +14,8 @@ __all__ = ["OpenAIMixin"] +logger = logging.getLogger(__name__) + class OpenAIMixin: """Mixin providing OpenAI format conversion and Agents SDK integration. @@ -43,11 +46,14 @@ async def call_tool(self, name: str, arguments: dict[str, Any]) -> Any: # Format Conversion (no external deps) # ========================================================================= - def as_openai_chat_tools(self, *, strict: bool = False) -> list[ChatCompletionToolUnionParam]: + def as_openai_chat_tools( + self, *, strict: bool = False, validate: bool = True + ) -> list[ChatCompletionToolUnionParam]: """Convert to OpenAI Chat Completions tool format. Args: strict: Enable strict mode for structured outputs + validate: Validate schemas and skip incompatible tools with warnings Returns: List of tool definitions for OpenAI Chat Completions API. @@ -72,6 +78,14 @@ def as_openai_chat_tools(self, *, strict: bool = False) -> list[ChatCompletionTo for t in self.as_tools(): schema = dict(t.inputSchema) if t.inputSchema else {"type": "object", "properties": {}} + # Validate schema for OpenAI compatibility + if validate: + errors = validate_openai_schema(schema, t.name) + if errors: + for error in errors: + logger.warning("Skipping tool: %s", error) + continue + if strict: schema = ensure_strict_schema(schema) @@ -91,12 +105,15 @@ def as_openai_chat_tools(self, *, strict: bool = False) -> list[ChatCompletionTo ) return tools - def as_openai_responses_tools(self) -> list[dict[str, Any]]: + def as_openai_responses_tools(self, *, validate: bool = True) -> list[dict[str, Any]]: """Convert to OpenAI Responses API tool format. Note: Like Chat Completions, you must execute tools yourself. OpenAI only auto-executes their built-in tools (code_interpreter, etc). + Args: + validate: Validate schemas and skip incompatible tools with warnings + Returns: List of tool definitions for OpenAI Responses API. 
@@ -117,21 +134,33 @@ def as_openai_responses_tools(self) -> list[dict[str, Any]]: result = await env.call_tool(item.name, **item.arguments) ``` """ - return [ - { - "type": "function", - "name": t.name, - "description": t.description or "", - "parameters": t.inputSchema or {"type": "object", "properties": {}}, - } - for t in self.as_tools() - ] + tools = [] + for t in self.as_tools(): + schema = dict(t.inputSchema) if t.inputSchema else {"type": "object", "properties": {}} + + # Validate schema for OpenAI compatibility + if validate: + errors = validate_openai_schema(schema, t.name) + if errors: + for error in errors: + logger.warning("Skipping tool: %s", error) + continue + + tools.append( + { + "type": "function", + "name": t.name, + "description": t.description or "", + "parameters": schema, + } + ) + return tools # ========================================================================= # Agents SDK Integration (requires openai-agents) # ========================================================================= - def as_openai_agent_tools(self) -> list[Any]: + def as_openai_agent_tools(self, *, validate: bool = True) -> list[Any]: """Convert to OpenAI Agents SDK FunctionTool objects. This creates FunctionTool objects that automatically execute against @@ -145,6 +174,9 @@ def as_openai_agent_tools(self) -> list[Any]: Requires: pip install openai-agents + Args: + validate: Validate schemas and skip incompatible tools with warnings + Returns: List of FunctionTool objects for OpenAI Agents SDK. @@ -171,6 +203,16 @@ def as_openai_agent_tools(self) -> list[Any]: tools = [] for t in self.as_tools(): + schema = dict(t.inputSchema) if t.inputSchema else {"type": "object", "properties": {}} + + # Validate schema for OpenAI compatibility + if validate: + errors = validate_openai_schema(schema, t.name) + if errors: + for error in errors: + logger.warning("Skipping tool: %s", error) + continue + tool = _create_function_tool(self, t, FunctionTool) tools.append(tool) return tools diff --git a/hud/environment/utils/schema.py b/hud/environment/utils/schema.py index 346ff2ce..a17c2984 100644 --- a/hud/environment/utils/schema.py +++ b/hud/environment/utils/schema.py @@ -2,9 +2,17 @@ from __future__ import annotations +import logging from typing import Any -__all__ = ["ensure_strict_schema", "json_type_to_python", "schema_to_pydantic"] +__all__ = [ + "ensure_strict_schema", + "json_type_to_python", + "schema_to_pydantic", + "validate_openai_schema", +] + +logger = logging.getLogger(__name__) def ensure_strict_schema(schema: dict[str, Any]) -> dict[str, Any]: @@ -95,3 +103,69 @@ def json_type_to_python(json_type: str) -> type: "object": dict, } return mapping.get(json_type, str) + + +def validate_openai_schema( + schema: dict[str, Any], + tool_name: str = "unknown", + path: str = "", +) -> list[str]: + """Validate a JSON schema for OpenAI API compatibility. + + OpenAI's API has specific requirements for tool schemas: + - Arrays must have 'items' (not 'prefixItems' which tuples generate) + - Certain schema features like 'prefixItems' are not supported + + Args: + schema: JSON schema to validate. + tool_name: Name of the tool (for error messages). + path: Current path in schema (for error context). + + Returns: + List of validation error messages. Empty if valid. 
+ """ + errors: list[str] = [] + + if not isinstance(schema, dict): + return errors + + # Check for prefixItems (generated by tuple types) + if "prefixItems" in schema: + errors.append( + f"Tool '{tool_name}' has 'prefixItems' at {path or 'root'} " + "(likely from tuple type). Use list[Model] instead of tuple." + ) + + # Check arrays have 'items' + if schema.get("type") == "array" and "items" not in schema and "prefixItems" not in schema: + errors.append( + f"Tool '{tool_name}' has array at {path or 'root'} without 'items'. " + "OpenAI requires 'items' for array schemas." + ) + + # Recursively check nested schemas + # Check properties + if "properties" in schema: + for prop_name, prop_schema in schema["properties"].items(): + prop_path = f"{path}.{prop_name}" if path else prop_name + errors.extend(validate_openai_schema(prop_schema, tool_name, prop_path)) + + # Check items + if "items" in schema and isinstance(schema["items"], dict): + items_path = f"{path}[items]" if path else "[items]" + errors.extend(validate_openai_schema(schema["items"], tool_name, items_path)) + + # Check anyOf/oneOf/allOf + for key in ("anyOf", "oneOf", "allOf"): + if key in schema: + for i, sub_schema in enumerate(schema[key]): + sub_path = f"{path}.{key}[{i}]" if path else f"{key}[{i}]" + errors.extend(validate_openai_schema(sub_schema, tool_name, sub_path)) + + # Check $defs (definitions) + if "$defs" in schema: + for def_name, def_schema in schema["$defs"].items(): + def_path = f"$defs.{def_name}" + errors.extend(validate_openai_schema(def_schema, tool_name, def_path)) + + return errors diff --git a/hud/eval/context.py b/hud/eval/context.py index 2d21d942..2f2e968a 100644 --- a/hud/eval/context.py +++ b/hud/eval/context.py @@ -578,9 +578,7 @@ async def __aexit__( # ========================================================================= @instrument(category="mcp") - async def _execute_tool( - self, name: str, arguments: dict[str, Any] - ) -> MCPToolResult: + async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolResult: """Execute a tool with automatic telemetry recording. Overrides Environment._execute_tool to record MCP spans for the eval context. 
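
The `validate_openai_schema` helper added in this patch is what the OpenAI mixins call before handing tool schemas to the API; it flags exactly the tuple-derived `prefixItems` schemas that patch 68 removes from the computer tools. Below is a minimal sketch of how the validator behaves, assuming the import path shown in the diff; the tuple-style schema is a hypothetical example of what Pydantic generates for a `tuple[int, int]` field:

```python
from hud.environment.utils.schema import validate_openai_schema

# Hypothetical schema for a tool with a `coordinate: tuple[int, int]` parameter.
tuple_schema = {
    "type": "object",
    "properties": {
        "coordinate": {
            "type": "array",
            "prefixItems": [{"type": "integer"}, {"type": "integer"}],
            "minItems": 2,
            "maxItems": 2,
        }
    },
}

errors = validate_openai_schema(tuple_schema, tool_name="computer")
# One error: prefixItems at 'coordinate', with a hint to use list[...] instead of tuple.
assert len(errors) == 1

# The same parameter declared as list[int] produces a plain array with 'items' and passes.
list_schema = {
    "type": "object",
    "properties": {"coordinate": {"type": "array", "items": {"type": "integer"}}},
}
assert validate_openai_schema(list_schema, tool_name="computer") == []
```

With `validate=True` (the default), the `as_openai_*_tools()` converters log these messages and skip the offending tool rather than letting the OpenAI request fail.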
From aecb97b6cc6d6a6e77b1953e40c295cbbc743178 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 20:21:37 -0800 Subject: [PATCH 70/92] update test --- hud/tools/tests/test_computer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hud/tools/tests/test_computer.py b/hud/tools/tests/test_computer.py index 77aa2ded..4b7ef827 100644 --- a/hud/tools/tests/test_computer.py +++ b/hud/tools/tests/test_computer.py @@ -9,6 +9,7 @@ from hud.tools.computer.hud import HudComputerTool from hud.tools.computer.openai import OpenAIComputerTool from hud.tools.executors.base import BaseExecutor +from hud.types import Coordinate @pytest.mark.asyncio @@ -193,7 +194,7 @@ async def test_move_action(self, base_executor): async def test_drag_action(self, base_executor): """Test drag action with BaseExecutor.""" tool = HudComputerTool(executor=base_executor) - result = await tool(action="drag", path=[(100, 100), (200, 200)]) + result = await tool(action="drag", path=[Coordinate(x=100, y=100), Coordinate(x=200, y=200)]) assert result assert any("Drag" in content.text for content in result if isinstance(content, TextContent)) From 0acfa42282b4a4e7025070ad284553be91052816 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 20:22:08 -0800 Subject: [PATCH 71/92] docs and tests --- docs/advanced/testing-environments.mdx | 105 ++++++++++++ docs/guides/best-practices.mdx | 142 +++++++++++++++ docs/guides/sandboxing.mdx | 228 +++++++++++-------------- hud/tools/tests/test_computer.py | 2 +- 4 files changed, 343 insertions(+), 134 deletions(-) create mode 100644 docs/advanced/testing-environments.mdx create mode 100644 docs/guides/best-practices.mdx diff --git a/docs/advanced/testing-environments.mdx b/docs/advanced/testing-environments.mdx new file mode 100644 index 00000000..99ece815 --- /dev/null +++ b/docs/advanced/testing-environments.mdx @@ -0,0 +1,105 @@ +--- +title: "Testing Environments" +description: "Test scenarios, tools, and environment logic locally" +icon: "flask-vial" +--- + +Before deploying, test locally. See [Sandboxing](/guides/sandboxing) for Docker vs no-Docker patterns. + +## Local Testing + +| Environment | `local_test.py` | +|-------------|-----------------| +| No Docker | `from env import env` | +| Docker | `env.connect_url("http://localhost:8765/mcp")` | + +Both use the same API after setup: + +```python +async with env: + tools = env.as_tools() # List available tools + result = await env.call_tool("my_tool", arg="val") # Call a tool +``` + +## Testing Scenarios Directly + +Scenarios are async generators. `hud.eval()` drives them automatically, but you can test the logic directly—this is exactly what runs at the start and end of `hud.eval()`: + +```python +async def checkout(user_id: str, amount: int = 100): + # Setup + prompt (first yield) — runs at hud.eval() start + answer = yield f"Complete checkout for {user_id}, ${amount}" + + # Evaluation (second yield) — runs after agent submits + yield 1.0 if "success" in answer.lower() else 0.0 + +async def test(): + gen = checkout("alice", 50) + prompt = await anext(gen) # What hud.eval() does at start + reward = await gen.asend("Success!") # What hud.eval() does after submit + assert reward == 1.0 +``` + +If your scenario tests pass, `hud.eval()` will behave identically. 
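+
+A sketch of the same check wrapped in a test file, assuming `pytest` and `pytest-asyncio` are available (the `checkout` generator is the one defined above):
+
+```python
+import pytest
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("answer", "expected"),
+    [("Checkout success!", 1.0), ("Payment declined", 0.0)],
+)
+async def test_checkout_scenario(answer: str, expected: float):
+    gen = checkout("alice", 50)
+    prompt = await anext(gen)         # setup + prompt, as hud.eval() does at start
+    assert "alice" in prompt
+    reward = await gen.asend(answer)  # grading, as hud.eval() does after submit
+    assert reward == expected
+```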
+ +## Mocking + +`env.mock()` intercepts at the tool layer—agents only see tools: + +```python +env.mock() # All tools return fake responses +env.mock_tool("send_email", {"status": "sent"}) + +# Check mock state +assert env.is_mock == True +``` + +## Hot-Reload + +For Docker environments, `hud dev -w path` reloads Python on save: + +```bash +hud dev -w scenarios -w tools --port 8765 +``` + +System services (postgres, VNC, browsers) persist across reloads. + +## Debugging Build Failures + +`hud build` runs the exact same pipeline as **New → Environment** on [hud.ai](https://hud.ai)—so if it passes locally, it'll work in production. If the build fails or the container crashes on startup, use `hud debug` to run a 5-phase compliance test: + +```bash +hud debug my-env:latest +``` + +Output shows exactly which phase failed: +``` +✓ Phase 1: Docker image exists +✓ Phase 2: MCP server responds to initialize +✗ Phase 3: Tool discovery failed + → Error: Connection refused on port 8005 + → Hint: Backend service may not be starting +``` + +You can also debug a directory (builds first) or stop at a specific phase: + +```bash +hud debug . # Build and debug current directory +hud debug . --max-phase 3 # Stop after phase 3 +hud debug --config mcp.json # Debug from config file +``` + +## Useful Environment Properties + +```python +# Check parallelization (for running multiple evals) +env.is_parallelizable # True if all connections are remote + +# List what's connected +env.connections # Dict of connection names → connectors +env.is_connected # True if in async context + +# Resources and prompts (beyond tools) +await env.list_resources() # MCP resources +await env.list_prompts() # MCP prompts +``` diff --git a/docs/guides/best-practices.mdx b/docs/guides/best-practices.mdx new file mode 100644 index 00000000..662cbab9 --- /dev/null +++ b/docs/guides/best-practices.mdx @@ -0,0 +1,142 @@ +--- +title: "Best Practices" +description: "Design effective environments, evals, and grading logic" +icon: "star" +--- + +Building good agent evaluations requires thoughtful design at every layer—the environment, the prompts, and the grading logic. This guide covers patterns that lead to useful, reliable signal. + +## Good Environments + +A good environment gives agents what they need to succeed—and gives you what you need to evaluate them. + +### Observable State + +Agents need access to the right information. If they can't see the data they need, they can't complete the task. Design tools that expose useful state: + +```python +# ❌ Bad: Agent can't see what was created +@env.tool() +def create_user(name: str) -> str: + db.insert("users", name=name) + return "User created" + +# ✅ Good: Agent gets actionable data back +@env.tool() +def create_user(name: str) -> dict: + user_id = db.insert("users", name=name) + return {"id": user_id, "name": name, "created": True} +``` + +For grading, you also need to observe what happened. If the agent creates a database row, you need to query that database. If it uploads a file, you need to read that file. Be cognizant of what you can and cannot observe—only ask agents to do things you can verify. + +### Deterministic Setup + +Each eval should seed the state it needs. HUD handles container isolation—you handle making sure your scenario sets up the right data before the agent runs. 
+ +```python +# ❌ Bad: Depends on whatever state exists +@env.scenario("find-user") +async def find_user(name: str): + answer = yield f"Find the user named {name}" + yield 1.0 if name in answer else 0.0 + +# ✅ Good: Seeds known state before eval +@env.scenario("find-user") +async def find_user(name: str): + await db.clear() + await db.insert("users", name=name, email=f"{name}@example.com") + + answer = yield f"Find the user named {name}" + yield 1.0 if name in answer else 0.0 +``` + +### Isolated Execution + +HUD sandboxes each eval—containers don't share state. But if your environment connects to external services, think about stateful vs stateless. + +**Stateless services** are fine. Multiple agents can hit the same read-only API without interference. + +**Stateful services** need care. If 100 agents all hit the same database endpoint that modifies data, they'll step on each other. Use per-eval instances, transaction isolation, or target different records. + +## Good Evals + +An eval combines a prompt (the first `yield`) with grading logic (everything after). The prompt tells agents what to do—write short-to-medium length instructions that ask for an unambiguous change you can verify. + +### Be Specific + +Ambiguous prompts lead to ambiguous grading. Say exactly what you want: + +``` +❌ "Update the user settings" +✅ "Change the email for user alice@example.com to alice.new@example.com" +``` + +Real-world example: *"Add a column to the Portfolio snapshot with the 'Phase' of the engagement. C-11X should be 'Phase 2', all else are 'Phase 1'."* + +### Only Ask for Testable Things + +If you can't observe the result, you can't grade it. Don't ask an agent to "think about" something—ask it to do something you can verify. + +``` +❌ "Consider the best approach to optimize the query" +✅ "Rewrite the query to use an index on the email column" +``` + +### Create Variations + +Evals are easier to write when you have a specific failure mode in mind. If you've observed agents struggling with something, incorporate that into future evals. + +Create different versions with more or less explicit instructions—step-by-step guidance vs. high-level goals. Use [variants](/quick-links/ab-testing) to test these systematically. Variations make it easier to tune difficulty later. + +## Good Graders + +The grading logic after the first `yield` determines the grade. Fair grading means useful signal. + +### Match the Prompt + +If the prompt says "create a document with a Japanese car brand", check for any Japanese car brand—not just "Toyota". But don't accept any document either. Exactly as strict as the prompt implies. + +```python +# ❌ Bad: Too strict—only accepts one answer +@env.scenario("add-car") +async def add_car(): + answer = yield "Add a Japanese car brand to the document" + yield 1.0 if answer == "Toyota" else 0.0 + +# ✅ Good: Accepts any valid answer +@env.scenario("add-car") +async def add_car(): + answer = yield "Add a Japanese car brand to the document" + japanese_brands = ["toyota", "honda", "nissan", "mazda", "subaru"] + yield 1.0 if any(brand in answer.lower() for brand in japanese_brands) else 0.0 +``` + +### Use Partial Credit + +Partial grades help you see where agents fail. Did they add to cart but not checkout? That's useful signal. 
Break complex grading into sub-checks with weighted grades: + +```python +@env.scenario("checkout") +async def checkout(product: str): + answer = yield f"Add {product} to cart and checkout" + + score = 0.0 + if await product_in_cart(product): + score += 0.3 # Partial credit for first step + if await order_completed(product): + score += 0.7 # Most credit for completion + yield score +``` + +### Sanity Check + +At minimum, verify two cases: unchanged state → 0.0, correct completion → 1.0. For grading logic you'll reuse across many evals, write unit tests. Load a known state snapshot, verify the grade matches what you expect. + +## Finding the Right Difficulty + +A good eval set has range—target 20-30% average success rate. You want high variance: some runs should grade 0.0, others 1.0. If every run grades the same, there's no signal to learn from. Having both positive and negative examples on the same eval is what makes improvement possible. + +**Iterate.** Create an eval, test it manually, run it at scale, check the difficulty. If it's too easy or too hard, adjust the prompt or grading. Use your best evals as templates for more. + +**Train.** Every eval generates data—prompts, tool calls, grades. Use successful runs for fine-tuning. The loop: eval → analyze → train → eval again. diff --git a/docs/guides/sandboxing.mdx b/docs/guides/sandboxing.mdx index f3eef18a..dbebcb3d 100644 --- a/docs/guides/sandboxing.mdx +++ b/docs/guides/sandboxing.mdx @@ -6,194 +6,156 @@ icon: "shield" You have a production stack. You want an agent on it. But you can't just point an agent at production—it'll make real changes, hit real APIs, affect real users. And you can't test at scale against a single live instance with shared state. -HUD lets you mock your production environment so agents can run against it safely. Connect your services in a few lines. Write evals that tell agents what to do and grade how well they did it. HUD handles the sandboxing, the parallelization, the state extraction, the tracing. You get a reliable test bed where thousands of agents can run in parallel—each isolated, each reproducible, each generating useful data. +HUD lets you mock your production environment so agents can run against it safely. Connect your services in a few lines, mock external dependencies, and run thousands of agents in parallel—each isolated, each reproducible, each generating useful data. ## Connecting Your Stack -HUD wraps your existing infrastructure. Your code stays where it is—you connect it: +HUD wraps your existing infrastructure without rewriting it: ```python from hud import Environment env = Environment("my-env") -# Your FastAPI app → all routes become tools -env.connect_fastapi(app) - -# Your MCP servers -env.connect_server(mcp_server) - -# Any REST API with an OpenAPI spec -env.connect_openapi("https://api.example.com/openapi.json") +# Connect what you already have +env.connect_fastapi(app) # FastAPI → tools +env.connect_openapi("https://api.example.com/openapi.json") # OpenAPI spec → tools +env.connect_hub("hud-evals/browser") # HUD Hub environments +env.connect_image("my-service:v1") # Docker images ``` -Docker images work with `env.connect_image("my-service:v1")`. Other HUD environments compose with `env.connect_hub("my-org/other-env")`. See the full list in the [Environment Reference](/reference/environments). - -Run `hud init` to scaffold an environment in an existing project—it adds the HUD files without touching your code. Once connected, deploy and run evals at scale. 
- -### Making It Safe - -HUD runs each eval in its own container—isolated, reproducible, safe. But your environment might connect to external services. Here's how to handle them: +## Making Databases Safe -**Databases.** Each agent needs its own sandbox. Use in-memory SQLite (fast, resets per eval), transaction rollback, or seed fresh data at start: +Agents need isolated state. Three patterns work: +**In-memory SQLite** — fastest, resets automatically: ```python +import sqlite3 +db = sqlite3.connect(":memory:") # Fresh per eval + @env.scenario("update-order") async def update_order(order_id: str): - await db.seed_from("fixtures/orders.sql") - - answer = yield f"Update order {order_id} status to 'shipped'" - - order = await db.query("SELECT status FROM orders WHERE id = ?", order_id) - yield 1.0 if order and order["status"] == "shipped" else 0.0 + db.executescript(Path("fixtures/orders.sql").read_text()) # Seed + answer = yield f"Update order {order_id} to shipped" + row = db.execute("SELECT status FROM orders WHERE id=?", (order_id,)).fetchone() + yield 1.0 if row and row[0] == "shipped" else 0.0 ``` -**Third-party APIs.** Use mock mode to return fake responses without hitting real services: - +**Transaction rollback** — use your real DB, undo changes: ```python -env.mock() # All tools return fake responses based on schemas -env.mock_tool("send_email", {"status": "sent", "id": "mock-123"}) # Override specific tools +@env.scenario("process-refund") +async def process_refund(order_id: str): + conn = await asyncpg.connect(DATABASE_URL) + tx = conn.transaction() + await tx.start() + try: + answer = yield f"Process refund for order {order_id}" + # Check result... + yield reward + finally: + await tx.rollback() # Always undo + await conn.close() ``` -**Credentials.** If you need a live service, use staging keys. Point evals at staging, not production. - -## Good Environments - -A good environment gives agents what they need to succeed—and gives you what you need to evaluate them. +**Fixture seeding** — deterministic starting state: +```python +await db.execute("TRUNCATE orders, users CASCADE") +await db.executemany("INSERT INTO users ...", fixtures["users"]) +``` -### Observable State +## Mocking External Services -Agents need access to the right information. If they can't see the data they need, they can't complete the task. Design tools that expose useful state: +`env.mock()` intercepts at the tool layer. Agents only see tools, so this is usually all you need: ```python -# ❌ Bad: Agent can't see what was created -@env.tool() -def create_user(name: str) -> str: - db.insert("users", name=name) - return "User created" - -# ✅ Good: Agent gets actionable data back -@env.tool() -def create_user(name: str) -> dict: - user_id = db.insert("users", name=name) - return {"id": user_id, "name": name, "created": True} +env.mock() # All tools return schema-based fake responses +env.mock_tool("send_email", {"status": "sent", "id": "mock-123"}) +env.mock_tool("charge_card", {"success": True, "transaction_id": "tx-mock"}) ``` -For grading, you also need to observe what happened. If the agent creates a database row, you need to query that database. If it uploads a file, you need to read that file. Be cognizant of what you can and cannot observe—only ask agents to do things you can verify. - -### Deterministic Setup - -Each eval should seed the state it needs. HUD handles container isolation—you handle making sure your scenario sets up the right data before the agent runs. 
+For stateful mocking (tracking what happened for assertions): ```python -# ❌ Bad: Depends on whatever state exists -@env.scenario("find-user") -async def find_user(name: str): - answer = yield f"Find the user named {name}" - yield 1.0 if name in answer else 0.0 - -# ✅ Good: Seeds known state before eval -@env.scenario("find-user") -async def find_user(name: str): - await db.clear() - await db.insert("users", name=name, email=f"{name}@example.com") +class MockPaymentService: + def __init__(self): + self.charges = [] - answer = yield f"Find the user named {name}" - yield 1.0 if name in answer else 0.0 -``` - -### Isolated Execution + async def charge(self, amount: int, card_token: str) -> dict: + self.charges.append({"amount": amount, "token": card_token}) + return {"success": True, "id": f"ch-{len(self.charges)}"} -HUD sandboxes each eval—containers don't share state. But if your environment connects to external services, think about stateful vs stateless. +payments = MockPaymentService() -**Stateless services** are fine. Multiple agents can hit the same read-only API without interference. +@env.scenario("checkout") +async def checkout(cart_total: int): + _ = yield f"Complete checkout for ${cart_total}" + yield 1.0 if any(c["amount"] == cart_total for c in payments.charges) else 0.0 +``` -**Stateful services** need care. If 100 agents all hit the same database endpoint that modifies data, they'll step on each other. Use per-eval instances, transaction isolation, or target different records. +## Docker vs No Docker -## Good Evals +| Pattern | When to Use | Examples | +|---------|-------------|----------| +| **No Docker** | Pure Python, API integrations | Web research, LLM grading | +| **Docker** | System dependencies, persistent services | VNC, PostgreSQL, browsers | -An eval combines a prompt (the first `yield`) with grading logic (everything after). The prompt tells agents what to do—write short-to-medium length instructions that ask for an unambiguous change you can verify. +### Pattern 1: No Docker -### Be Specific +Import and test directly: -Ambiguous prompts lead to ambiguous grading. Say exactly what you want: +```python +# local_test.py +from env import env -``` -❌ "Update the user settings" -✅ "Change the email for user alice@example.com to alice.new@example.com" +async def test(): + async with env: + result = await env.call_tool("search", query="test") ``` -Real-world example: *"Add a column to the Portfolio snapshot with the 'Phase' of the engagement. C-11X should be 'Phase 2', all else are 'Phase 1'."* +### Pattern 2: Docker -### Only Ask for Testable Things +Connect to the running container instead of importing. Same API, different transport—because your tools now run inside the container where dependencies live: -If you can't observe the result, you can't grade it. Don't ask an agent to "think about" something—ask it to do something you can verify. +```python +# local_test.py +env = Environment("browser-env") +env.connect_url("http://localhost:8765/mcp") # Connect instead of import +async def test(): + async with env: # Same API from here + result = await env.call_tool("navigate", url="https://example.com") ``` -❌ "Consider the best approach to optimize the query" -✅ "Rewrite the query to use an index on the email column" -``` - -### Create Variations -Evals are easier to write when you have a specific failure mode in mind. If you've observed agents struggling with something, incorporate that into future evals. 
+```bash +hud build # Build image +hud dev -w scenarios -w tools --port 8765 # Start with hot-reload +python local_test.py # Connects to container +``` -Create different versions with more or less explicit instructions—step-by-step guidance vs. high-level goals. Use [variants](/quick-links/ab-testing) to test these systematically. Variations make it easier to tune difficulty later. +### Hot-Reload -## Good Graders +`hud dev -w path` reloads Python on save. System services (postgres, VNC) persist. -The grading logic after the first `yield` determines the grade. Fair grading means useful signal. +**Rebuild** (`hud build`) when: Dockerfile, system packages, or dependencies change. -### Match the Prompt +## Environment Structure -If the prompt says "create a document with a Japanese car brand", check for any Japanese car brand—not just "Toyota". But don't accept any document either. Exactly as strict as the prompt implies. +Start simple, add structure as needed: -```python -# ❌ Bad: Too strict—only accepts one answer -@env.scenario("add-car") -async def add_car(): - answer = yield "Add a Japanese car brand to the document" - yield 1.0 if answer == "Toyota" else 0.0 - -# ✅ Good: Accepts any valid answer -@env.scenario("add-car") -async def add_car(): - answer = yield "Add a Japanese car brand to the document" - japanese_brands = ["toyota", "honda", "nissan", "mazda", "subaru"] - yield 1.0 if any(brand in answer.lower() for brand in japanese_brands) else 0.0 ``` - -### Use Partial Credit - -Partial grades help you see where agents fail. Did they add to cart but not checkout? That's useful signal. Break complex grading into sub-checks with weighted grades: - -```python -@env.scenario("checkout") -async def checkout(product: str): - answer = yield f"Add {product} to cart and checkout" - - score = 0.0 - if await product_in_cart(product): - score += 0.3 # Partial credit for first step - if await order_completed(product): - score += 0.7 # Most credit for completion - yield score +# Simple # Organized +my-env/ my-env/ +├── env.py ├── env.py +├── local_test.py ├── scenarios/ +└── Dockerfile.hud ├── setup/ + ├── evaluate/ + └── Dockerfile.hud ``` -### Sanity Check - -At minimum, verify two cases: unchanged state → 0.0, correct completion → 1.0. For grading logic you'll reuse across many evals, write unit tests. Load a known state snapshot, verify the grade matches what you expect. +Most environments fall somewhere between. Split when files get hard to navigate. ## What's Next -Once your environment is connected and your evals are written, you're ready to run at scale. - -**Deploy.** Push to GitHub, connect on [hud.ai](https://hud.ai), and your environment goes live. See [Deploy](/quick-links/deploy). - -**Run with any agent.** Use [Integrations](/guides/integrations) to connect OpenAI, Anthropic, LangChain, or your own agent loop. - -**Find the right difficulty.** A good eval set has range—target 20-30% average success rate. You want high variance: some runs should grade 0.0, others 1.0. If every run grades the same, there's no signal to learn from. Having both positive and negative examples on the same eval is what makes improvement possible. - -**Iterate.** Create an eval, test it manually, run it at scale, check the difficulty. If it's too easy or too hard, adjust the prompt or grading. Use your best evals as templates for more. +**Test locally.** See [Testing Environments](/advanced/testing-environments) for debugging and scenario testing. 
-**Train.** Every eval generates data—prompts, tool calls, grades. Use successful runs for fine-tuning. The loop: eval → analyze → train → eval again. +**Deploy.** Push to GitHub, connect on [hud.ai](https://hud.ai). See [Deploy](/quick-links/deploy). diff --git a/hud/tools/tests/test_computer.py b/hud/tools/tests/test_computer.py index 4b7ef827..bdbb01d8 100644 --- a/hud/tools/tests/test_computer.py +++ b/hud/tools/tests/test_computer.py @@ -9,7 +9,7 @@ from hud.tools.computer.hud import HudComputerTool from hud.tools.computer.openai import OpenAIComputerTool from hud.tools.executors.base import BaseExecutor -from hud.types import Coordinate +from hud.tools.types import Coordinate @pytest.mark.asyncio From c707e5bfcd1e72e7d11face5e57fadba9995e3c1 Mon Sep 17 00:00:00 2001 From: lorenss-m Date: Sun, 14 Dec 2025 21:20:30 -0800 Subject: [PATCH 72/92] Remove environments folder - now in separate repos --- environments/README.md | 956 --- environments/blank/.env.example | 7 - environments/blank/Dockerfile | 22 - environments/blank/README.md | 128 - environments/blank/environment/README.md | 16 - environments/blank/environment/__init__.py | 1 - environments/blank/environment/pyproject.toml | 16 - environments/blank/environment/server.py | 40 - environments/blank/server/README.md | 21 - environments/blank/server/__init__.py | 1 - environments/blank/server/main.py | 43 - environments/blank/server/pyproject.toml | 19 - environments/blank/server/shared.py | 15 - environments/blank/server/tools.py | 35 - environments/blank/tasks.json | 44 - environments/blank/test_task.py | 52 - environments/browser/.dockerignore | 101 - environments/browser/.gitignore | 100 - environments/browser/Dockerfile | 60 - environments/browser/Dockerfile.local | 72 - environments/browser/README.md | 191 - environments/browser/browser-base/Dockerfile | 50 - environments/browser/browser-base/README.md | 58 - .../browser/environment/2048/README.md | 103 - .../browser/environment/2048/backend/game.py | 241 - .../browser/environment/2048/backend/main.py | 246 - .../environment/2048/backend/pyproject.toml | 9 - .../environment/2048/frontend/app/globals.css | 3 - .../environment/2048/frontend/app/layout.tsx | 22 - .../environment/2048/frontend/app/page.tsx | 190 - .../2048/frontend/components/GameBoard.tsx | 31 - .../2048/frontend/components/GameControls.tsx | 104 - .../2048/frontend/components/GameTile.tsx | 53 - .../environment/2048/frontend/next.config.js | 6 - .../environment/2048/frontend/package.json | 28 - .../2048/frontend/postcss.config.js | 6 - .../2048/frontend/tailwind.config.js | 12 - .../environment/2048/frontend/tsconfig.json | 27 - .../browser/environment/2048/launch.py | 284 - environments/browser/environment/README.md | 135 - environments/browser/environment/__init__.py | 3 - .../browser/environment/pyproject.toml | 23 - environments/browser/environment/server.py | 503 -- .../browser/environment/todo/README.md | 85 - .../browser/environment/todo/backend/main.py | 391 - .../environment/todo/backend/pyproject.toml | 15 - .../environment/todo/frontend/app/globals.css | 3 - .../environment/todo/frontend/app/layout.tsx | 22 - .../environment/todo/frontend/app/page.tsx | 289 - .../environment/todo/frontend/next.config.js | 13 - .../todo/frontend/package-lock.json | 6532 ----------------- .../environment/todo/frontend/package.json | 28 - .../todo/frontend/postcss.config.js | 6 - .../todo/frontend/tailwind.config.js | 12 - .../environment/todo/frontend/tsconfig.json | 26 - .../browser/environment/todo/launch.py | 286 - 
environments/browser/hud.lock.yaml | 503 -- environments/browser/pyproject.toml | 22 - environments/browser/server/__init__.py | 1 - .../browser/server/evaluate/__init__.py | 15 - .../browser/server/evaluate/game_2048.py | 220 - environments/browser/server/evaluate/todo.py | 233 - environments/browser/server/main.py | 84 - environments/browser/server/pyproject.toml | 21 - environments/browser/server/resources.py | 25 - environments/browser/server/setup/__init__.py | 15 - .../browser/server/setup/game_2048.py | 150 - environments/browser/server/setup/todo.py | 131 - environments/browser/server/shared.py | 48 - environments/browser/server/tools.py | 89 - environments/browser/tasks.json | 37 - environments/deepresearch/.gitignore | 13 - environments/deepresearch/Dockerfile | 24 - environments/deepresearch/README.md | 165 - .../deepresearch/environment/__init__.py | 1 - .../deepresearch/environment/pyproject.toml | 17 - .../deepresearch/environment/server.py | 340 - environments/deepresearch/pyproject.toml | 19 - environments/deepresearch/remote_tasks.json | 340 - environments/deepresearch/server/__init__.py | 1 - environments/deepresearch/server/main.py | 78 - .../deepresearch/server/pyproject.toml | 19 - environments/deepresearch/tasks.json | 366 - environments/jupyter/.gitignore | 2 - environments/jupyter/Dockerfile | 41 - environments/jupyter/README.md | 68 - environments/jupyter/server/__init__.py | 0 environments/jupyter/server/config.py | 4 - .../jupyter/server/evaluate/__init__.py | 11 - .../jupyter/server/evaluate/compare.py | 186 - .../jupyter/server/evaluate/eval_all.py | 146 - .../jupyter/server/evaluate/generalize.py | 83 - environments/jupyter/server/main.py | 60 - environments/jupyter/server/pyproject.toml | 34 - environments/jupyter/server/setup/__init__.py | 10 - environments/jupyter/server/tools/__init__.py | 1 - environments/jupyter/server/tools/jupyter.py | 24 - environments/jupyter/test_task.json | 27 - environments/online_mind2web/.gitignore | 2 - environments/online_mind2web/Dockerfile | 36 - environments/online_mind2web/README.md | 36 - environments/online_mind2web/pyproject.toml | 22 - .../src/hud_controller/__init__.py | 3 - .../src/hud_controller/context.py | 139 - .../src/hud_controller/evaluate/__init__.py | 11 - .../evaluate/autonomous_eval.py | 170 - .../hud_controller/evaluate/overall_judge.py | 48 - .../src/hud_controller/evaluate/webjudge.py | 502 -- .../src/hud_controller/providers/README.md | 110 - .../src/hud_controller/providers/__init__.py | 33 - .../hud_controller/providers/anchorbrowser.py | 183 - .../src/hud_controller/providers/base.py | 96 - .../hud_controller/providers/browserbase.py | 176 - .../providers/helper/__init__.py | 5 - .../hud_controller/providers/helper/proxy.py | 86 - .../hud_controller/providers/hyperbrowser.py | 244 - .../src/hud_controller/providers/kernel.py | 13 - .../src/hud_controller/providers/steel.py | 203 - .../src/hud_controller/server.py | 358 - .../src/hud_controller/setup/__init__.py | 16 - .../src/hud_controller/setup/navigate.py | 41 - .../src/hud_controller/tools/__init__.py | 13 - .../src/hud_controller/tools/anthropic.py | 265 - .../src/hud_controller/tools/executor.py | 384 - .../src/hud_controller/tools/openai.py | 266 - .../src/hud_controller/tools/playwright.py | 604 -- environments/online_mind2web/test_task.json | 55 - environments/remote_browser/.gitignore | 2 - environments/remote_browser/Dockerfile | 36 - environments/remote_browser/README.md | 225 - environments/remote_browser/pyproject.toml | 22 - 
.../src/hud_controller/__init__.py | 3 - .../src/hud_controller/context.py | 139 - .../src/hud_controller/evaluate/__init__.py | 24 - .../hud_controller/evaluate/cookie_exists.py | 69 - .../hud_controller/evaluate/cookie_match.py | 82 - .../hud_controller/evaluate/element_exists.py | 61 - .../hud_controller/evaluate/history_length.py | 83 - .../hud_controller/evaluate/page_contains.py | 95 - .../evaluate/raw_last_action_is.py | 93 - .../evaluate/selector_history.py | 80 - .../hud_controller/evaluate/sheet_contains.py | 171 - .../evaluate/sheets_cell_values.py | 349 - .../src/hud_controller/evaluate/url_match.py | 62 - .../evaluate/verify_type_action.py | 128 - .../src/hud_controller/problems/__init__.py | 14 - .../problems/element_interaction.py | 41 - .../problems/form_interaction.py | 28 - .../problems/navigate_and_verify.py | 28 - .../src/hud_controller/problems/registry.py | 91 - .../problems/search_interaction.py | 19 - .../src/hud_controller/providers/README.md | 110 - .../src/hud_controller/providers/__init__.py | 33 - .../hud_controller/providers/anchorbrowser.py | 170 - .../src/hud_controller/providers/base.py | 96 - .../hud_controller/providers/browserbase.py | 176 - .../providers/helper/__init__.py | 5 - .../hud_controller/providers/helper/proxy.py | 86 - .../hud_controller/providers/hyperbrowser.py | 244 - .../src/hud_controller/providers/kernel.py | 13 - .../src/hud_controller/providers/steel.py | 203 - .../src/hud_controller/server.py | 350 - .../src/hud_controller/setup/__init__.py | 16 - .../src/hud_controller/setup/cookies.py | 69 - .../src/hud_controller/setup/interact.py | 105 - .../src/hud_controller/setup/load_html.py | 44 - .../src/hud_controller/setup/navigate.py | 41 - .../src/hud_controller/setup/sheets.py | 345 - .../src/hud_controller/tools/__init__.py | 9 - .../src/hud_controller/tools/executor.py | 379 - .../src/hud_controller/tools/playwright.py | 189 - environments/remote_browser/test_task.json | 34 - environments/rubrics/.env.example | 5 - environments/rubrics/.gitignore | 13 - environments/rubrics/Dockerfile | 24 - environments/rubrics/README.md | 239 - environments/rubrics/environment/__init__.py | 1 - .../rubrics/environment/edgar_utils.py | 126 - environments/rubrics/environment/exa_utils.py | 57 - .../rubrics/environment/pyproject.toml | 19 - environments/rubrics/environment/server.py | 596 -- environments/rubrics/pyproject.toml | 19 - environments/rubrics/remote_tasks.json | 901 --- environments/rubrics/server/__init__.py | 1 - environments/rubrics/server/main.py | 154 - environments/rubrics/server/pyproject.toml | 19 - environments/rubrics/tasks.json | 912 --- environments/text_2048/2048_taskconfigs.json | 542 -- environments/text_2048/Dockerfile | 27 - environments/text_2048/README.md | 102 - environments/text_2048/pyproject.toml | 22 - .../text_2048/src/hud_controller/__init__.py | 1 - .../text_2048/src/hud_controller/context.py | 21 - .../src/hud_controller/evaluate/__init__.py | 12 - .../src/hud_controller/evaluate/efficiency.py | 26 - .../src/hud_controller/evaluate/max_number.py | 33 - .../text_2048/src/hud_controller/game.py | 204 - .../text_2048/src/hud_controller/server.py | 69 - .../src/hud_controller/setup/__init__.py | 16 - .../src/hud_controller/setup/board.py | 21 - .../src/hud_controller/tools/__init__.py | 5 - .../src/hud_controller/tools/move.py | 69 - 202 files changed, 28497 deletions(-) delete mode 100644 environments/README.md delete mode 100644 environments/blank/.env.example delete mode 100644 environments/blank/Dockerfile 
delete mode 100644 environments/blank/README.md delete mode 100644 environments/blank/environment/README.md delete mode 100644 environments/blank/environment/__init__.py delete mode 100644 environments/blank/environment/pyproject.toml delete mode 100644 environments/blank/environment/server.py delete mode 100644 environments/blank/server/README.md delete mode 100644 environments/blank/server/__init__.py delete mode 100644 environments/blank/server/main.py delete mode 100644 environments/blank/server/pyproject.toml delete mode 100644 environments/blank/server/shared.py delete mode 100644 environments/blank/server/tools.py delete mode 100644 environments/blank/tasks.json delete mode 100644 environments/blank/test_task.py delete mode 100644 environments/browser/.dockerignore delete mode 100644 environments/browser/.gitignore delete mode 100644 environments/browser/Dockerfile delete mode 100644 environments/browser/Dockerfile.local delete mode 100644 environments/browser/README.md delete mode 100644 environments/browser/browser-base/Dockerfile delete mode 100644 environments/browser/browser-base/README.md delete mode 100644 environments/browser/environment/2048/README.md delete mode 100644 environments/browser/environment/2048/backend/game.py delete mode 100644 environments/browser/environment/2048/backend/main.py delete mode 100644 environments/browser/environment/2048/backend/pyproject.toml delete mode 100644 environments/browser/environment/2048/frontend/app/globals.css delete mode 100644 environments/browser/environment/2048/frontend/app/layout.tsx delete mode 100644 environments/browser/environment/2048/frontend/app/page.tsx delete mode 100644 environments/browser/environment/2048/frontend/components/GameBoard.tsx delete mode 100644 environments/browser/environment/2048/frontend/components/GameControls.tsx delete mode 100644 environments/browser/environment/2048/frontend/components/GameTile.tsx delete mode 100644 environments/browser/environment/2048/frontend/next.config.js delete mode 100644 environments/browser/environment/2048/frontend/package.json delete mode 100644 environments/browser/environment/2048/frontend/postcss.config.js delete mode 100644 environments/browser/environment/2048/frontend/tailwind.config.js delete mode 100644 environments/browser/environment/2048/frontend/tsconfig.json delete mode 100644 environments/browser/environment/2048/launch.py delete mode 100644 environments/browser/environment/README.md delete mode 100644 environments/browser/environment/__init__.py delete mode 100644 environments/browser/environment/pyproject.toml delete mode 100644 environments/browser/environment/server.py delete mode 100644 environments/browser/environment/todo/README.md delete mode 100644 environments/browser/environment/todo/backend/main.py delete mode 100644 environments/browser/environment/todo/backend/pyproject.toml delete mode 100644 environments/browser/environment/todo/frontend/app/globals.css delete mode 100644 environments/browser/environment/todo/frontend/app/layout.tsx delete mode 100644 environments/browser/environment/todo/frontend/app/page.tsx delete mode 100644 environments/browser/environment/todo/frontend/next.config.js delete mode 100644 environments/browser/environment/todo/frontend/package-lock.json delete mode 100644 environments/browser/environment/todo/frontend/package.json delete mode 100644 environments/browser/environment/todo/frontend/postcss.config.js delete mode 100644 environments/browser/environment/todo/frontend/tailwind.config.js delete mode 100644 
environments/browser/environment/todo/frontend/tsconfig.json delete mode 100644 environments/browser/environment/todo/launch.py delete mode 100644 environments/browser/hud.lock.yaml delete mode 100644 environments/browser/pyproject.toml delete mode 100644 environments/browser/server/__init__.py delete mode 100644 environments/browser/server/evaluate/__init__.py delete mode 100644 environments/browser/server/evaluate/game_2048.py delete mode 100644 environments/browser/server/evaluate/todo.py delete mode 100644 environments/browser/server/main.py delete mode 100644 environments/browser/server/pyproject.toml delete mode 100644 environments/browser/server/resources.py delete mode 100644 environments/browser/server/setup/__init__.py delete mode 100644 environments/browser/server/setup/game_2048.py delete mode 100644 environments/browser/server/setup/todo.py delete mode 100644 environments/browser/server/shared.py delete mode 100644 environments/browser/server/tools.py delete mode 100644 environments/browser/tasks.json delete mode 100644 environments/deepresearch/.gitignore delete mode 100644 environments/deepresearch/Dockerfile delete mode 100644 environments/deepresearch/README.md delete mode 100644 environments/deepresearch/environment/__init__.py delete mode 100644 environments/deepresearch/environment/pyproject.toml delete mode 100644 environments/deepresearch/environment/server.py delete mode 100644 environments/deepresearch/pyproject.toml delete mode 100644 environments/deepresearch/remote_tasks.json delete mode 100644 environments/deepresearch/server/__init__.py delete mode 100644 environments/deepresearch/server/main.py delete mode 100644 environments/deepresearch/server/pyproject.toml delete mode 100644 environments/deepresearch/tasks.json delete mode 100644 environments/jupyter/.gitignore delete mode 100644 environments/jupyter/Dockerfile delete mode 100644 environments/jupyter/README.md delete mode 100644 environments/jupyter/server/__init__.py delete mode 100644 environments/jupyter/server/config.py delete mode 100644 environments/jupyter/server/evaluate/__init__.py delete mode 100644 environments/jupyter/server/evaluate/compare.py delete mode 100644 environments/jupyter/server/evaluate/eval_all.py delete mode 100644 environments/jupyter/server/evaluate/generalize.py delete mode 100644 environments/jupyter/server/main.py delete mode 100644 environments/jupyter/server/pyproject.toml delete mode 100644 environments/jupyter/server/setup/__init__.py delete mode 100644 environments/jupyter/server/tools/__init__.py delete mode 100644 environments/jupyter/server/tools/jupyter.py delete mode 100644 environments/jupyter/test_task.json delete mode 100644 environments/online_mind2web/.gitignore delete mode 100644 environments/online_mind2web/Dockerfile delete mode 100644 environments/online_mind2web/README.md delete mode 100644 environments/online_mind2web/pyproject.toml delete mode 100644 environments/online_mind2web/src/hud_controller/__init__.py delete mode 100644 environments/online_mind2web/src/hud_controller/context.py delete mode 100644 environments/online_mind2web/src/hud_controller/evaluate/__init__.py delete mode 100644 environments/online_mind2web/src/hud_controller/evaluate/autonomous_eval.py delete mode 100644 environments/online_mind2web/src/hud_controller/evaluate/overall_judge.py delete mode 100644 environments/online_mind2web/src/hud_controller/evaluate/webjudge.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/README.md delete mode 100644 
environments/online_mind2web/src/hud_controller/providers/__init__.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/anchorbrowser.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/base.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/browserbase.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/helper/__init__.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/helper/proxy.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/hyperbrowser.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/kernel.py delete mode 100644 environments/online_mind2web/src/hud_controller/providers/steel.py delete mode 100644 environments/online_mind2web/src/hud_controller/server.py delete mode 100644 environments/online_mind2web/src/hud_controller/setup/__init__.py delete mode 100644 environments/online_mind2web/src/hud_controller/setup/navigate.py delete mode 100644 environments/online_mind2web/src/hud_controller/tools/__init__.py delete mode 100644 environments/online_mind2web/src/hud_controller/tools/anthropic.py delete mode 100644 environments/online_mind2web/src/hud_controller/tools/executor.py delete mode 100644 environments/online_mind2web/src/hud_controller/tools/openai.py delete mode 100644 environments/online_mind2web/src/hud_controller/tools/playwright.py delete mode 100644 environments/online_mind2web/test_task.json delete mode 100644 environments/remote_browser/.gitignore delete mode 100644 environments/remote_browser/Dockerfile delete mode 100644 environments/remote_browser/README.md delete mode 100644 environments/remote_browser/pyproject.toml delete mode 100644 environments/remote_browser/src/hud_controller/__init__.py delete mode 100644 environments/remote_browser/src/hud_controller/context.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/__init__.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/cookie_exists.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/cookie_match.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/element_exists.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/history_length.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/page_contains.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/raw_last_action_is.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/selector_history.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/sheet_contains.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/sheets_cell_values.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/url_match.py delete mode 100644 environments/remote_browser/src/hud_controller/evaluate/verify_type_action.py delete mode 100644 environments/remote_browser/src/hud_controller/problems/__init__.py delete mode 100644 environments/remote_browser/src/hud_controller/problems/element_interaction.py delete mode 100644 environments/remote_browser/src/hud_controller/problems/form_interaction.py delete mode 100644 environments/remote_browser/src/hud_controller/problems/navigate_and_verify.py delete mode 100644 environments/remote_browser/src/hud_controller/problems/registry.py delete mode 100644 
environments/remote_browser/src/hud_controller/problems/search_interaction.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/README.md delete mode 100644 environments/remote_browser/src/hud_controller/providers/__init__.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/anchorbrowser.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/base.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/browserbase.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/helper/__init__.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/helper/proxy.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/hyperbrowser.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/kernel.py delete mode 100644 environments/remote_browser/src/hud_controller/providers/steel.py delete mode 100644 environments/remote_browser/src/hud_controller/server.py delete mode 100644 environments/remote_browser/src/hud_controller/setup/__init__.py delete mode 100644 environments/remote_browser/src/hud_controller/setup/cookies.py delete mode 100644 environments/remote_browser/src/hud_controller/setup/interact.py delete mode 100644 environments/remote_browser/src/hud_controller/setup/load_html.py delete mode 100644 environments/remote_browser/src/hud_controller/setup/navigate.py delete mode 100644 environments/remote_browser/src/hud_controller/setup/sheets.py delete mode 100644 environments/remote_browser/src/hud_controller/tools/__init__.py delete mode 100644 environments/remote_browser/src/hud_controller/tools/executor.py delete mode 100644 environments/remote_browser/src/hud_controller/tools/playwright.py delete mode 100644 environments/remote_browser/test_task.json delete mode 100644 environments/rubrics/.env.example delete mode 100644 environments/rubrics/.gitignore delete mode 100644 environments/rubrics/Dockerfile delete mode 100644 environments/rubrics/README.md delete mode 100644 environments/rubrics/environment/__init__.py delete mode 100644 environments/rubrics/environment/edgar_utils.py delete mode 100644 environments/rubrics/environment/exa_utils.py delete mode 100644 environments/rubrics/environment/pyproject.toml delete mode 100644 environments/rubrics/environment/server.py delete mode 100644 environments/rubrics/pyproject.toml delete mode 100644 environments/rubrics/remote_tasks.json delete mode 100644 environments/rubrics/server/__init__.py delete mode 100644 environments/rubrics/server/main.py delete mode 100644 environments/rubrics/server/pyproject.toml delete mode 100644 environments/rubrics/tasks.json delete mode 100644 environments/text_2048/2048_taskconfigs.json delete mode 100644 environments/text_2048/Dockerfile delete mode 100644 environments/text_2048/README.md delete mode 100644 environments/text_2048/pyproject.toml delete mode 100644 environments/text_2048/src/hud_controller/__init__.py delete mode 100644 environments/text_2048/src/hud_controller/context.py delete mode 100644 environments/text_2048/src/hud_controller/evaluate/__init__.py delete mode 100644 environments/text_2048/src/hud_controller/evaluate/efficiency.py delete mode 100644 environments/text_2048/src/hud_controller/evaluate/max_number.py delete mode 100644 environments/text_2048/src/hud_controller/game.py delete mode 100644 environments/text_2048/src/hud_controller/server.py delete mode 100644 
environments/text_2048/src/hud_controller/setup/__init__.py delete mode 100644 environments/text_2048/src/hud_controller/setup/board.py delete mode 100644 environments/text_2048/src/hud_controller/tools/__init__.py delete mode 100644 environments/text_2048/src/hud_controller/tools/move.py diff --git a/environments/README.md b/environments/README.md deleted file mode 100644 index 40cba300..00000000 --- a/environments/README.md +++ /dev/null @@ -1,956 +0,0 @@ -# How to Build HUD-Compatible MCP Environments - -This document is a step-by-step guide for turning *any* piece of software that can run in a Docker container into a **Model Context Protocol (MCP)** environment that the HUD SDK can evaluate or control. We’ll move through six short phases, each with a clear checkpoint. - -> **Big picture** -> • An *agent* (LLM) wants to solve tasks inside a *software environment*. -> • Your job: give that environment a clean, programmable surface – a set of -> *tools* the agent can invoke. -> • MCP is simply the wire-format we use to move those tool calls back and forth -> (like gRPC or HTTP but JSON-RPC over stdio/Docker). -> • FastMCP is the underlying SDK; HUD provides **MCPServer** – a thin wrapper that -> adds SIGTERM handling, `@initialize` / `@shutdown` decorators, and easier -> tool registration while remaining 100 % compatible with FastMCP. -> -> The picture: -> ```text -> LLM Agent ──JSON-RPC──► FastMCP server (your code) ──► real app / game / browser -> ``` -> Your job is to wrap *any* app in an MCP server so agents can control it reproducibly & safely. - ---- - -## Phase Overview - -| Phase | Goal | -|-------|------| -| 1 | A Docker image that *starts* and prints to **stderr** | -| 2 | A minimal MCP server that responds to `initialize` over **stdio** | -| 3 | Working `setup`, `evaluate`, and **interaction** tools | -| 4 | Image launches remotely on the HUD platform & exposes live telemetry | -| 5 | Fast local iteration with `hud dev` hot-reload | - -Take the phases one at a time; do **not** jump ahead. Each stage's checkpoint is the foundation for the next. - -## Reference Implementations - -This repository includes two complete MCP environment implementations that demonstrate different levels of complexity: - -### 1. `text_2048` - Simple Game Environment -A minimalist ASCII-based 2048 game that showcases: -- Basic hub pattern with setup/evaluate tools -- Custom interaction tools (move command) -- Clean separation of game logic and MCP server -- Minimal dependencies (Python only) -- Perfect for learning the core concepts - -### 2. `remote_browser` - Advanced Browser Automation -A sophisticated browser automation environment featuring: -- Multiple cloud browser provider integrations (AnchorBrowser, Steel, BrowserBase, HyperBrowser, Kernel) -- Both Playwright and computer tools for interaction -- Extensive setup/evaluate capabilities (navigation, cookies, sheets, element checks) -- Live telemetry with browser viewing URLs -- Production-ready error handling and cleanup - -💡 **Follow along with text_2048** as you work through each phase - it demonstrates all the core patterns with minimal complexity. 
- -### Installing the HUD CLI - -The HUD SDK includes a powerful CLI for debugging and analyzing MCP environments: - -```bash -# Install HUD CLI globally with uv (recommended) -uv tool install hud-python@latest --python 3.12 - -# Or use without installing -uvx --from hud-python hud --help - -# Verify installation -hud --help -``` - -Common commands: -```bash -# Debug your Docker image (runs 5-phase test) -hud debug my-mcp-server:latest - -# Analyze available tools and resources -hud analyze my-mcp-server:latest --format json - -# Debug any command-based MCP server -hud debug --command "python my_server.py" -``` -While you move through the phases it's handy to run the **interactive checker** to make sure nothing broke: - -```bash -# First build your Docker image -docker build -t my-environment environments/my-environment - -# Then debug it -hud debug my-environment -``` - -**What's the difference?** -- **`hud debug`** - Tests your environment in 5 phases, checking startup, MCP protocol, tools, and readiness. Use this first! -- **`hud analyze`** - Explores the environment to discover all tools, resources, and capabilities. Only works after debug passes phase 3. - -The script walks the *same* checklist and prints coloured, human-friendly hints whenever something fails. - -| What it validates | Phase | -|-------------------|-------| -| Container starts & logs to **stderr** | 1 | -| MCP server responds to an `initialize` request | 2 | -| Discovers `setup`, `evaluate`, and interaction tools | 3 | -| Calls `setup` / `evaluate`, checks telemetry & startup time | 4 | -| Spawns three concurrent clients to stress-test resources | 5 | - -💡 **Run it after finishing each phase.** If the checker exits with a red ❌, scroll up for the gold-coloured *hint* block – it usually points directly to the root cause. - ---- - -## Phase 1 – Write a Dockerfile - -**Goal →** Create a container that can run your MCP server with proper Python packaging. - -Key principles: -- **stdout** is reserved for MCP protocol (JSON-RPC) -- **stderr** is for all logs and debug output -- Use proper Python packaging with `pyproject.toml` -- Run as a module for clean imports - -### Dockerfile Template - -```dockerfile -FROM python:3.11-slim - -# Prevent Python from buffering output (important for logs) -ENV PYTHONUNBUFFERED=1 \ - PYTHONDONTWRITEBYTECODE=1 - -WORKDIR /app - -# Copy package files -COPY pyproject.toml ./ -COPY src/ ./src/ - -# Install in editable mode for development flexibility -RUN pip install --no-cache-dir -e . - -# Run as a module to ensure proper package imports -CMD ["python", "-m", "my_module.server"] -``` - -### Build & Test - -```bash -docker build -t my-environment . - -# Test Phase 1: Container should start without errors -docker run --rm -i my-environment -``` - -### Recommended Environment Structure - -For Python-based MCP environments, use this standard structure: - -``` -my-environment/ -├── Dockerfile -├── README.md -├── server/ # MCP server package -│ ├── pyproject.toml # MCP dependencies (hud-python, etc.) 
-│ ├── __init__.py # Empty package marker -│ ├── main.py # mcp = MCPServer() + lifecycle hooks -│ ├── tools.py # router = MCPRouter() + @router.tool decorators -│ ├── setup/ # Setup router (modular approach) -│ │ ├── __init__.py -│ │ ├── basic.py # Basic setup functions -│ │ └── advanced.py # Advanced setup functions -│ └── evaluate/ # Evaluate router (modular approach) -│ ├── __init__.py -│ ├── checks.py # Basic evaluation checks -│ └── metrics.py # Advanced metrics evaluators -└── environment/ # Backend service package - ├── pyproject.toml # Backend dependencies (fastapi, uvicorn) - ├── __init__.py - └── server.py # FastAPI app with /health, /act, /reset, /state -``` - -This structure enables: -- Clean separation of concerns (environment logic, tools, setup, evaluation) -- Easy volume mounting for development (Phase 5) -- Standard Python packaging with `pip install -e .` -- Modular organization - each setup/evaluator in its own file for clarity - -• **One Dockerfile only** – no docker-compose. -• If you're building a GUI environment, start from `hudpython/novnc-base:latest` instead and leave VNC configuration for later phases. - -Checkpoint reached? Congratulations – move on. - -👉 Quick sanity check: `hud debug my-environment` (verifies Phase 1 automatically) - -Need inspiration? Check out our reference implementations: -• [`text_2048/Dockerfile`](./text_2048/Dockerfile) - Minimal Python setup, perfect for simple environments -• [`remote_browser/Dockerfile`](./remote_browser/Dockerfile) - Uses pre-built base image with browser dependencies -• [`browser/Dockerfile`](./browser/Dockerfile) - Multi-stage build with full GUI support - ---- - -## Phase 2 – Create the MCP Server - -**Goal →** a Python process that: -1. Speaks MCP over **stdio**. -2. Responds correctly to the `initialize` request. -3. Logs everything to **stderr**. - -The MCP lifecycle is *initialize → operate → shutdown* (see spec link above). - -### Skeleton server (MCPServer) - -```python -import sys -import logging -from hud.server import MCPServer - -# 1️⃣ Always log to stderr – stdout is reserved for JSON-RPC -logging.basicConfig( - stream=sys.stderr, - level=logging.INFO, - format='[%(levelname)s] %(asctime)s | %(name)s | %(message)s' -) - -# Create the server early so decorators can reference it -mcp = MCPServer(name="My Environment") - -# Run heavy one-time setup during MCP initialize -@mcp.initialize -async def initialize_environment(session=None, progress_token=None): - """Heavy one-time setup – start databases, launch background apps, etc.""" - logging.info("starting core services…") - await start_services() # your coroutine - logging.info("services ready") - -if __name__ == "__main__": - mcp.run() -``` - -*(Replace `start_services()` with whatever takes noticeable startup time – browsers, DBs, X servers, …)* - -### Adapt Dockerfile - -At the end of your Dockerfile, you must launch the MCP server as the container's main process, ensuring it communicates over stdio (stdin/stdout). This is typically done by setting the `CMD` or `ENTRYPOINT` to run your server module directly, for example: - - -```dockerfile -FROM python:3.11-slim - -WORKDIR /app -COPY . . 
- -# Optional: install requirements -# RUN pip install -r requirements.txt - -CMD ["python", "-m", "your_module_name"] # Replace 'your_module_name' with your actual entrypoint module -``` - -### Three validation steps (run them **in order**) - -| # | What you do | Why it matters | -|---|-------------|----------------| -| 1 | **Direct stdio test** – pipe the JSON below into your script | Proves the Python code handles `initialize` without any client or Docker noise | -| 2 | **MCP Inspector** – `npx @modelcontextprotocol/inspector python -m my_package.server` | Lets you click around: view capabilities, tools, resources | -| 3 | **Inside Docker** – rebuild the image and run it | This is *exactly* how HUD will execute the server | -| 4 | **Run `hud debug`** – `hud debug my-environment` | Combines the above checks & points out common mistakes | - -#### JSON for step 1 - -```jsonc -{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { - "protocolVersion": "2024-11-05", - "capabilities": {"roots": {"listChanged": true}}, - "clientInfo": {"name": "DevClient", "title": "Dev", "version": "0.0.0"} -}} -``` - -Pipe it: - -```bash -echo '' | python -m my_package.server -``` - -If all three validations succeed, you have a real MCP server – time to make it useful. - ---- - -## Phase 3 – Add Setup / Evaluate / Interaction Tools - -**Goal →** tools are discoverable in the Inspector *and* callable from the HUD SDK. - -👉 After wiring in the tools, confirm with `hud debug my-environment --max-phase 3` – it now checks for their presence and basic execution. - -🔍 Once debug passes phase 3, you can analyze the environment: -```bash -hud analyze my-environment # Interactive view of tools and resources -hud analyze my-environment --format json # JSON output for scripts -hud analyze my-environment --format markdown # Generate documentation -``` - -1. Write **`setup`** and **`evaluate`** tools first – they are *lifecycle* tools and never shown to the LLM. -2. Register at least one **interaction** tool (`computer`, `playwright`, or your own). - -### Approach 1: Simple Direct Implementation - -For simple environments with just a few setup/evaluate functions, you can use direct tool decorators with **MCPServer**: - -```python -from hud.server import MCPServer -from hud.tools import HudComputerTool - -mcp = MCPServer(name="my-environment") - -@mcp.tool() -async def setup(config: dict) -> dict: - ... # prepare environment - -@mcp.tool() -async def evaluate(config: dict) -> dict: - ... # return {"reward": <0-1>, "done": bool} - -@mcp.initialize -async def initialize_environment(session=None, progress_token=None): - custom_tool = HudComputerTool() - mcp.add_tool(custom_tool.mcp) - - # Any other initialization -``` - -### Approach 2: Hub Pattern (Recommended for Complex Environments) - -The BaseHub pattern provides a clean way to organize multiple setup/evaluate functions with automatic discovery and registration. **A BaseHub is fundamentally another MCP server (it's a subclass of FastMCP)** that you mount to your main server, providing namespace separation and modular organization. All hub functions are exposed through one tool named after the hub, and a resource that can list all of its tools. 
- -When mounted, the hub's tools are accessible through a single tool that dispatches to the appropriate function: -```json -{ - "name": "setup", - "arguments": { - "name": "reset", // Which function in the hub to call - "arguments": {"param": "value"} // Additional parameters - } -} -``` - -```python -# In setup/__init__.py -from hud.tools.base import BaseHub - -# Create the setup hub (a sub-server) -setup = BaseHub("setup") - -# Import all setup modules to register their tools -from . import basic, advanced # This registers all @setup.tool() decorated functions - -# In setup/basic.py -from . import setup -from mcp.types import TextContent - -@setup.tool() -async def reset(**kwargs): - """Reset the environment to its initial state. - - Args: - **kwargs: Additional parameters - - Returns: - TextContent - """ - # Access environment from the hub - env = setup.env - await env.reset_state() - return TextContent( - text="Environment reset to initial state", - type="text" - ) - -@setup.tool() -async def seed_data(num_items: int = 5): - """Seed the environment with test data. - - Args: - num_items: Number of items to create - - Returns: - TextContent - """ - # Access environment from the hub - env = setup.env - items = await env.create_items(num_items) - return TextContent( - text=f"Created {len(items)} items", - type="text" - ) - -# In evaluate/__init__.py -from hud.tools.base import BaseHub - -# Create the evaluate hub (another sub-server) -evaluate = BaseHub("evaluate") - -# Import all evaluator modules -from . import checks, metrics - -# In evaluate/checks.py -from . import evaluate -from hud.tools.types import EvaluationResult - -@evaluate.tool() -async def task_complete(expected_count: int): - """Check if the expected number of tasks are completed. - - Args: - expected_count: Expected number of completed tasks - - Returns: - EvaluationResult - """ - # Access environment from the hub - env = evaluate.env - completed = await env.count_completed() - return EvaluationResult( - reward=min(completed / expected_count, 1.0), - done=completed >= expected_count, - content=f"Completed {completed}/{expected_count} tasks", - info={"completed": completed, "expected": expected_count} - ) - -# In server.py -from .setup import setup as setup_hub -from .evaluate import evaluate as evaluate_hub - -# Create MCP server -mcp = MCPServer(name="my-environment") - -@mcp.initialize -async def initialize_environment(ctx): - """Initialize the environment with progress notifications.""" - # Extract progress token from context - progress_token = getattr(ctx.meta, "progressToken", None) if ctx.meta else None - # Send progress updates if available - async def send_progress(progress: int, message: str): - if progress_token: - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=100, - message=message, - ) - - await send_progress(10, "Starting environment initialization...") - - # Initialize your environment state/context - env = await create_environment_context() - await send_progress(50, "Environment created...") - - # Set environment on hubs - setup_hub.env = env - evaluate_hub.env = env - - # Mount hubs to MCP server - mcp.mount(setup_hub) - mcp.mount(evaluate_hub) - await send_progress(80, "Tools registered...") - - # Register any custom interaction tools - if hasattr(env, 'custom_tool'): - mcp.add_tool(env.custom_tool.mcp) - - await send_progress(100, "Environment ready!") -``` - -The BaseHub pattern provides: -- **Namespace isolation**: Tools are grouped under 
the hub's name (e.g., "setup", "evaluate") -- **Modular organization**: Each hub can be developed and tested independently -- **Type safety**: Full type hints preserved for parameters and returns - -When you call a hub's tool, you specify which function to execute: -```python -# Calling the "reset" function in the setup hub -await client.call_tool("setup", {"name": "reset"}) - -# Calling the "task_complete" function in the evaluate hub -await client.call_tool("evaluate", {"name": "task_complete", "expected_count": 5}) -``` - -### Test workflow - -1. **Inspector first** – restart the server, refresh the *Tools* tab, confirm the new tools appear. -2. **Run `hud debug my-environment`** – this validates initialization, tool discovery and basic calls automatically. -3. **Rebuild the image** – `docker build -t my-environment .`. -4. **HUD SDK script test** – run a short script like the one below. GUI environments built from `hudpython/novnc-base` still expose a VNC viewer on – keep it open while testing. - -```python -import asyncio -import hud -from hud.datasets import Task -from hud.agents import ClaudeAgent -from hud.clients import MCPClient - -async def main(): - # `trace` captures *everything* that happens and sends it to hud.ai - async with hud.async_trace("local_test"): - task = Task( - prompt="Complete the task", - mcp_config={ - "local": { - "command": "docker", - "args": ["run", "--rm", "-i", "my-environment:latest"] - } - }, - setup_tool={"name": "setup", "arguments": {"name": "todo_seed", "num_items": 5}}, - evaluate_tool={"name": "evaluate", "arguments": {"name": "todo_completed", "expected_count": 2}} - ) - client = MCPClient(mcp_config=task.mcp_config) - - agent = ClaudeAgent( - mcp_client=client, - model="claude-3-7-sonnet-20250219", - allowed_tools=["computer"] # or ["move"] for text_2048 - ) - - result = await agent.run(task) - print(result) - - await client.close() - -asyncio.run(main()) -``` - -The `trace` context manager sends a full timeline of agent actions, tool calls, and rewards to hud.ai – perfect for debugging. - -See `examples/01_hello_2048.py` and `examples/task_with_setup_eval.py` for larger end-to-end demos. - ---- - -## Phase 4 – Remote Deployment & HUD Runner - -**Goal →** the exact same image runs in parallel on hundreds of instances, and exposes more telemetry so the hud.ai can visualise the whole lifecycle. - -### 1. Publish your image - -Log in to Docker Hub (or any registry HUD can pull from) and push a tagged build: - -```bash -docker tag my-environment yourdockerhubuser/my-environment:latest -docker push yourdockerhubuser/my-environment:latest -``` - -*(If you’re using a private registry, make sure the HUD worker has pull credentials.)* - -### 2. Launch it remotely (gmail_remote pattern) - -Here's how to configure a remote MCP server that runs **the same Docker image**: - -```python -from hud import settings -from hud.clients import MCPClient - -# Your image is in a registry, now tell HUD to pull & run it on demand -config = { - "hud": { - "url": settings.hud_mcp_url, - "headers": { - "Authorization": f"Bearer {settings.api_key}", - "Mcp-Image": "yourdockerhubuser/my-environment:latest", # which image to launch - }, - } -} - -client = MCPClient(mcp_config=config) -``` - -_Steps 3 and 4 below are **optional but highly recommended** once the image boots successfully._ - -Spin up **many** agents in parallel by just launching multiple tasks – HUD will queue and start as many containers as resources allow. - -### 3. 
Progress updates during `initialize` (Optional) - -At remote scale it can take 10-30 s for heavy services to boot. Use the new -`@mcp.initialize` decorator plus the `session` / `progress_token` parameters to -stream status messages: - -```python -@mcp.initialize -async def initialize_environment(session=None, progress_token=None): - async def send(p, msg): - if session and progress_token: - await session.send_progress_notification( - progress_token=progress_token, - progress=p, - total=100, - message=msg - ) - await send(10, "Starting X11...") - await start_x11() - await send(50, "Launching browser…") - await launch_browser() - await send(100, "ready") -``` - -Those messages are displayed live on hud.ai alongside resource graphs – perfect feedback while you wait. - -### 4. Live telemetry (`telemetry://live`) (Optional) - -Expose a resource named `telemetry://live` exactly like in `environments/browser/src/hud_controller/server.py` to return live url to be displayed on hud.ai. - -Once all of the above works you can unleash *hundreds* of concurrent agents on your new environment. - ---- - -## Phase 5 – Hot-Reload Development - -For rapid local development, run the controller and environment servers separately. This enables instant code updates without Docker rebuilds. - -### Development Setup - -You'll need **two terminal windows** for local development: - -#### Terminal 1: MCP Server -```bash -cd environments/my-environment/server -hud dev # Auto-detects and runs with hot-reload - -# Optional flags: -hud dev --inspector # Launch MCP Inspector -hud dev --interactive # Launch interactive testing mode -hud dev --stdio # Use stdio transport (default: HTTP) -hud dev --watch ../shared # Watch additional directories -``` - -The `hud dev` command: -- Auto-detects the MCP module in the current directory -- Watches for file changes and reloads automatically -- Runs on HTTP by default (http://localhost:8765/mcp) -- Can launch MCP Inspector for testing tools -- Can launch interactive mode for manual testing - -#### Terminal 2: Environment Server (Backend) -```bash -cd environments/my-environment/environment -uvicorn server:app --reload # Standard uvicorn with hot-reload -``` - -For the backend, we simply use `uvicorn` directly since it already provides excellent hot-reload capabilities. - -### Development Workflow - -1. Start both servers in separate terminals -2. Edit code in either `server/` or `environment/` - changes reload automatically -3. Test changes immediately without rebuilding Docker images -4. Use MCP Inspector or interactive mode to test tools -5. When ready, build the complete Docker image: `hud build` - -### Quick Cursor Setup - -Add to `.cursor/mcp.json` (or use the deeplink from `hud dev` output): - -```json -{ - "mcpServers": { - "my-environment-dev": { - "url": "http://localhost:8765/mcp" - } - } -} -``` - -**Note**: Make sure both MCP server and environment backend are running when using with Cursor or agents. - -### Process Separation for Stateful Environments - -**Important Architecture Pattern**: For environments that maintain state (browsers, databases, running applications), you should separate the MCP server process from the actual environment process. This separation is critical for effective hot-reload development. - -#### Why Process Separation? 
- -When `hud dev` restarts your MCP server for code changes, you don't want to lose: -- Open browser windows and navigation state -- Database connections and data -- Running application state -- X11/VNC sessions -- Any expensive initialization - -#### Architecture Pattern - -``` -┌─────────────────┐ ┌──────────────────────┐ -│ MCP Server │────▶│ Environment Process │ -│ (Restartable) │ │ (Persistent) │ -└─────────────────┘ └──────────────────────┘ - ▲ │ - │ │ - └─── Communication ────────┘ - (Socket, API, gRPC) -``` - -#### Implementation Example - -1. **Create a Context Server** (`context_server.py`): -```python -from hud.server.context import run_context_server - -class PersistentEnvironmentContext: - def __init__(self): - self.state = {} - self.resources = None - - def startup(self): - # One-time expensive initialization - self.resources = initialize_expensive_resources() - - def get_state(self): - return self.state - -if __name__ == "__main__": - context = PersistentEnvironmentContext() - context.startup() - # Run on Unix socket - asyncio.run(run_context_server(context, "/tmp/my_env_ctx.sock")) -``` - -2. **Connect from MCP Server** (`server.py`): -```python -from hud.server.context import attach_context - -@mcp.initialize -async def initialize_environment(ctx): - # Connect to persistent context - persistent_ctx = attach_context("/tmp/my_env_ctx.sock") - - # Use existing state without reinitializing - state = persistent_ctx.get_state() - resources = persistent_ctx.get_resources() -``` - -3. **Update Dockerfile** to run both processes: -```dockerfile -# Start context server in background -CMD ["sh", "-c", "python -m hud_controller.context_server & python -m hud_controller.server"] -``` - -#### Communication Options - -- **Unix Sockets** (recommended for local): Fast, simple, no network overhead -- **TCP/HTTP API**: Good for distributed systems -- **gRPC**: Type-safe, efficient for complex APIs -- **Shared Memory**: Ultra-fast for large data - -See the `browser` environment for a complete production example of this pattern. - -### 4. Cursor rules – paste this once - -Inside `.cursor/rules/mcp_environment_iteration.mdc` add (or verify) the following so the agent always knows the expected iteration loop: - -```mdc ---- -description: Improve an MCP environment -alwaysApply: false ---- -Setup -1. Make sure the user has set up the mcp config for the environment by seeing if you have access to the tools by the given name (i.e. my-environment-dev), and make sure the title is in dev mode. If not, ask the user to make a dev version! -2. Make sure you can find the source folder for this environment. Explore its contents and README. -3. Clarify the objectives and ask follow up questions on the initial query to determine precise implementation details. - -Iteration -1. Use the exposed tools by the environment to interact with it. This means navigating around with a computer, editing, launching commands, whatever means accessible to you. If there are any exposed resources, try to access them to determine the structure of the calls. -2. Based on the objectives, test and verify the functionality of different tools and parts of the environment. If any tool call responds with an error, note it down. If any interaction with the environment is wrong, unexpected, incomplete, or parts of the environment are not developed fully, note it down. If any new problem sets up wrong or evaluation does not match the expected outcome, note it down. All of these inconsistencies you should note down in your TODOs. -3. 
Then, based on the TODOs, view the source folder and find the places where those errors would occur. Think about the system and how to fix it. Then fix it. -4. After you've fixed your TODO items, go back to step 2 and test them. Test through all of your available tools, and use feedback (such as screenshots) to determine your progress. If they now work as expected, mark them as complete. If not, continue the loop from step 2. Be extremely careful, scrupulous and attentive to all details. Never assume something is working unless you've tested it fully for all of its edge cases. -5. The only time you can exit this iteration loop is if there is no feasible way to create input conditions to test something. In this case, ask the user for help and recap your progress. If you're simply changing tools, changing code, and still have more realistic TODOs, the restart_server tool automatically refreshes the environment and you should continue working. In *all* other cases, you must continue this iteration loop until you can come up with no more TODOs. You must not halt.``` - -### 5. Prompt the agent - -```txt -Context: In the my-environment folder, I have a browser app environment. I've built a tool to interact with it called my-environment-dev. -Interaction: There are multiple tools to setup and evaluate the environment. There are also interaction tools for you to be able to move around it, and a screenshot tool to see the state. Use all of the available tools. -Objective: Please test if all setup, evaluation functions are working. This means you should come up with new problem definitions to test all functionality on. Be creative in how you pick edge cases to test on. -Rules: @mcp_environment_iteration.mdc -``` - ---- - -## Phase 6 – Optional Polish & Extensions - -### Deeper dive into registries - -An environment often needs *structured knowledge* about tasks, evaluation logic, or problem definitions. The browser examples keep these in three explicit registries: - -| Registry | Purpose | Example resource URI | -|----------|---------|----------------------| -| **Setup** | How to seed the environment before the agent starts | `setup://registry` & `setup://{env}` | -| **Evaluators** | Functions that decide success & reward | `evaluators://registry` | -| **Problems** | Bundled benchmarks / tasks with their own setup & evaluate pairs | `problems://registry` | - -Each registry is just a dictionary mapping a *name* to a *class*. Use a **decorator** to register classes: - -```python -from .registry import setup, evaluator, problem - -@setup("todo_seed") -class TodoSeed: - ... - -@evaluator("todo_completed") -class TodoCompleted: - ... - -@problem("todo_basic", description="Complete two todo items", difficulty="easy") -class TodoBasic: - def get_setup(self): - return {"name": "todo_seed", "arguments": {"num_items": 5}} - def get_evaluation(self): - return {"name": "todo_completed", "arguments": {"expected_count": 2}} -``` - -Decorators keep registration *next to the implementation* and avoid manual bookkeeping. The server simply exposes the combined metadata through an MCP **resource**. Follow `environments/browser/src/hud_controller/problems/registry.py` as a template and expose the JSON with `@mcp.resource("problems://registry")`. - -### Other finishing touches - -* **Performance** – lazy-load heavy resources, pool DB connections, cache expensive calls. -* **Security** – sandbox untrusted code, keep secrets in env vars, audit-log every tool call.
-* **Creative ideas** – API simulators, network test-beds, game worlds… if it fits in Docker it can be an MCP environment. - ---- - -## Contributing to Existing Environments - -When improving existing environments, follow these guidelines: - -### 1. Understanding the Environment - -Before making changes: -- Read the environment's README and any documentation -- Run `hud debug ` to test the environment -- Run `hud analyze ` (after debug passes phase 3) to explore capabilities -- Explore the folder structure and identify key components -- Test existing setup/evaluate functions to understand behavior - -### 2. Making Improvements - -**Adding New Setup Functions** -```python -# In setup/my_new_setup.py -from . import setup -from hud.tools import BaseSetup, TextContent - -@setup("my_new_setup", description="Clear description of what this does") -class MyNewSetup(BaseSetup): - async def __call__(self, context, param1: str, param2: int = 10) -> TextContent: - # Implementation - return TextContent(...) -``` - -**Adding New Evaluators** -```python -# In evaluate/my_evaluator.py -from . import evaluator -from hud.tools import BaseEvaluator, EvaluationResult - -@evaluator("my_check", description="What this evaluates") -class MyCheckEvaluator(BaseEvaluator): - async def __call__(self, context, threshold: float) -> EvaluationResult: - score = await context.calculate_score() - return { - "reward": min(score / 100, 1.0), - "done": score >= threshold, - "info": {"score": score, "threshold": threshold} - } -``` - -### 3. Testing Your Changes - -**Use `hud dev` for Hot-Reload Development** -```bash -# Navigate to the environment directory -cd environments/my-environment - -# Start development server with hot-reload -hud dev --build - -# In another terminal, test your changes -hud analyze hud-my-environment:dev - -# Or use interactive mode to test tools directly -hud dev --build --interactive -``` - -The `hud dev` command automatically: -- Mounts your `src/` directory for live code updates -- Handles container lifecycle and restarts -- Provides an HTTP endpoint for testing -- Shows logs for debugging - -## Testing Your Environment - -Once your environment is working, create comprehensive tests to ensure it stays that way: - -### Creating Test Files - -Each environment should have a test file following this pattern: -- `environments//test__mcp.py` - -The test file should include: -1. **Docker Build Test**: Ensure the image builds successfully -2. **MCP Initialization Tests**: Verify phases 1-3 using `hud debug` -3. **Tool-Specific Tests**: Test your environment's unique tools -4. **Integration Tests**: Test complete workflows - -Example test structure: -```python -class TestMyEnvironment: - IMAGE_NAME = "my-environment-test:latest" - - @classmethod - def setup_class(cls): - """Build Docker image before tests""" - # Build the image - - def test_phase1_basic_startup(self): - """Test container starts""" - - @pytest.mark.asyncio - async def test_phase2_3_mcp_initialize_and_tools(self): - """Test MCP init and tool discovery""" - - @pytest.mark.asyncio - async def test_environment_specific_tools(self): - """Test your custom tools""" -``` - -### Running Tests - -You can run tests directly with pytest: - -```bash -# Run all tests for an environment -cd environments/text_2048 -pytest test_text_2048_mcp.py -v -``` - -### Test Dependencies - -Add pytest to your environment's `pyproject.toml`: - -```toml -[project.optional-dependencies] -test = ["pytest>=7.0", "pytest-asyncio>=0.20"] -``` - -## Summary - -1. 
Start with a *plain* Dockerfile – verify it runs. -2. Add a minimal FastMCP server – verify with stdio, Inspector, Docker. -3. Implement tools – verify discovery + execution. -4. Run the same image remotely – verify telemetry. -5. Automate the loop with cursor-mcp. -6. **Write comprehensive tests** – ensure reliability. -7. Polish and extend as inspiration strikes. - -Happy building – and remember: **stderr is your friend, stdout belongs to MCP.** 🚀 diff --git a/environments/blank/.env.example b/environments/blank/.env.example deleted file mode 100644 index 86f9a702..00000000 --- a/environments/blank/.env.example +++ /dev/null @@ -1,7 +0,0 @@ -# HUD API Configuration -# Get your API key from https://hud.ai/account -HUD_API_KEY="" - -# Anthropic API Configuration (optional) -# Required for using Claude agents - get from https://console.anthropic.com/ -ANTHROPIC_API_KEY="" diff --git a/environments/blank/Dockerfile b/environments/blank/Dockerfile deleted file mode 100644 index fd2639bd..00000000 --- a/environments/blank/Dockerfile +++ /dev/null @@ -1,22 +0,0 @@ -FROM public.ecr.aws/docker/library/python:3.11-bookworm - -WORKDIR /app - -RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* - -# Copy and install MCP server dependencies -COPY server/pyproject.toml ./server/ -RUN pip install --no-cache-dir ./server - -# Copy and install environment dependencies -COPY environment/pyproject.toml ./environment/ -RUN pip install --no-cache-dir ./environment - -# Copy source code after dependencies -COPY server/ ./server/ -COPY environment/ ./environment/ - -ENV ENV_SERVER_PORT=8005 - -# Start environment server in background, then run MCP server with hot-reload -CMD ["sh", "-c", "uvicorn environment.server:app --host 0.0.0.0 --port $ENV_SERVER_PORT --log-level warning --reload >&2 & sleep 0.5 && hud dev server.main --stdio"] diff --git a/environments/blank/README.md b/environments/blank/README.md deleted file mode 100644 index e62c47e4..00000000 --- a/environments/blank/README.md +++ /dev/null @@ -1,128 +0,0 @@ -# Blank Environment - -Minimal starter template for building HUD environments. -See [docs](https://docs.hud.ai/build-environments) for the complete environment design workflow. - -## Architecture - -**`environment/`** - Produces structured data - -- Owns all state (game logic, browser sessions, databases, etc.) -- Exposes HTTP endpoints `/health`, `/act`, `/reset`, `/state` that return structured information about the environment state - -**`server/`** - Wraps data in MCP tools - -- Calls environment endpoints to get structured data for the agent, and environment setup/evaluation -- Agents and tasks interact only with these tools! - -**Why separate?** Edit tools for the agent or tasks without restarting the heavy environment backend. - -## Development - -```bash -# Terminal 1 - Environment backend -cd environment -uv run uvicorn server:app --reload - -# Terminal 2 - MCP server -cd server -uv run hud dev -``` - -Uncomment the `setup` tool in `server/tools.py`, save, and watch it reload. -Visit http://localhost:8765/docs to see the new tool appear instantly. - -In general, we recommend starting work on the environment backend first, then developing the MCP server to expose the right things to the agent. - -For complex environments that require many dependencies, we recommend running `hud dev` in the environment root: - -```bash -cd .. 
-hud dev -``` - -## Tasks & Evaluation - -```bash -# Build first in the global folder with the Dockerfile (creates blank:0.1.0) -hud build -``` - -Your `tasks.json` uses `docker run` to launch the environment: - -```json -{ - "prompt": "Your task prompt", - "mcp_config": { - "local": { - "command": "docker", - "args": ["run", "--rm", "-i", "blank:0.1.0"] - } - } -} -``` - -**Commands:** - -```bash -# Build first -hud build - -# Test task locally -hud eval tasks.json - -# Push environment for remote running -hud push - -# Production RL training -hud rl tasks.json # Auto-converts docker→remote, builds & pushes if needed -``` - -## Publishing Your Environment - -Once your environment is ready, you can share it with the community: - -### 1. Push to Registry - -```bash -# Build and push your environment (requires docker hub login and hud api key) -hud build -hud push -``` - -### 2. Create a Dataset - -Create a dataset on HuggingFace with your tasks: - -**Option A: Upload manually** - -1. Upload your `tasks.json` to HuggingFace -2. Make sure it's **public** to appear on leaderboards - -**Option B: Use the SDK** - -```python -from hud.datasets import save_tasks -import json - -# Load your tasks -with open("tasks.json") as f: - tasks = json.load(f) - -# Push to HuggingFace -save_tasks(tasks, repo_id="your-org/your-dataset") -``` - -### 3. Run and Track Performance - -```bash -# Run Claude on your benchmark -hud eval "your-org/your-dataset" claude - -# View results at: -# hud.ai/leaderboards/your-org/your-dataset -``` - -**Note**: Only public HuggingFace datasets appear as leaderboards! - -📚 Learn more: [Creating Benchmarks](https://docs.hud.ai/evaluate-agents/create-benchmarks) | [Leaderboards](https://docs.hud.ai/evaluate-agents/leaderboards) diff --git a/environments/blank/environment/README.md b/environments/blank/environment/README.md deleted file mode 100644 index b902ec25..00000000 --- a/environments/blank/environment/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Environment - -Backend service: owns state and exposes HTTP APIs the controller calls. - -Endpoints (FastAPI) -- `GET /health` → {status: ok} -- `POST /act` → increments counter and returns {count} -- `POST /reset` → resets counter -- `GET /state` → returns {count} - -Run (dev) -```bash -uv run uvicorn server:app --reload --port 8005 -``` - -Principle: treat like a backend. Keep long‑lived state here; add endpoints as tools need them. 
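Before wiring any MCP tools on top, it can help to smoke-test these endpoints directly. A minimal sketch using `httpx` against the endpoints listed above – it assumes the backend is running locally on port 8005 as shown (adjust if you changed `ENV_SERVER_PORT`):

```python
import asyncio

import httpx


async def smoke_test() -> None:
    # Talk to the blank environment backend started with uvicorn above.
    async with httpx.AsyncClient(base_url="http://localhost:8005", timeout=10.0) as client:
        health = (await client.get("/health")).json()
        print("health:", health)  # expect {"status": "ok"}

        await client.post("/reset")  # counter back to 0
        for _ in range(3):
            await client.post("/act")  # increment three times

        state = (await client.get("/state")).json()
        print("count after 3 actions:", state["count"])  # expect 3


asyncio.run(smoke_test())
```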
diff --git a/environments/blank/environment/__init__.py b/environments/blank/environment/__init__.py deleted file mode 100644 index d9cd6199..00000000 --- a/environments/blank/environment/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Blank environment package.""" diff --git a/environments/blank/environment/pyproject.toml b/environments/blank/environment/pyproject.toml deleted file mode 100644 index 8256f97e..00000000 --- a/environments/blank/environment/pyproject.toml +++ /dev/null @@ -1,16 +0,0 @@ -[project] -name = "blank-environment" -version = "0.1.0" -description = "Backend service for blank environment" -requires-python = ">=3.11" -dependencies = [ - "fastapi", - "uvicorn[standard]", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["."] diff --git a/environments/blank/environment/server.py b/environments/blank/environment/server.py deleted file mode 100644 index 7a382599..00000000 --- a/environments/blank/environment/server.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Minimal FastAPI environment server (HTTP-based).""" - -from fastapi import FastAPI - -import logging -import sys - -logging.basicConfig( - stream=sys.stderr, - level=logging.INFO, - format="[%(levelname)s] %(asctime)s | %(name)s | %(message)s", -) - -app = FastAPI(title="Blank Environment API") - -_count = 0 - - -@app.get("/health") -def health(): - return {"status": "ok"} - - -@app.post("/act") -def act(): - global _count - _count += 1 - return {"count": _count} - - -@app.post("/reset") -def reset(): - global _count - _count = 0 - return {"ok": True} - - -@app.get("/state") -def state(): - return {"count": _count} diff --git a/environments/blank/server/README.md b/environments/blank/server/README.md deleted file mode 100644 index 19fc7068..00000000 --- a/environments/blank/server/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# MCP Server - -MCP layer that wraps environment data in tools for agent interaction. 
- -## Structure - -- `main.py` - Server initialization, imports routers -- `tools.py` - MCP tools that call environment HTTP endpoints - -## Development - -```bash -# Start MCP server with hot-reload -uv run hud dev -``` - -## Key Principles - -- Keep tools thin - call environment HTTP endpoints -- Use routers for organization -- All long-lived state lives in `environment/`, not here \ No newline at end of file diff --git a/environments/blank/server/__init__.py b/environments/blank/server/__init__.py deleted file mode 100644 index 219d9cdd..00000000 --- a/environments/blank/server/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""MCP server package.""" diff --git a/environments/blank/server/main.py b/environments/blank/server/main.py deleted file mode 100644 index bbe98d13..00000000 --- a/environments/blank/server/main.py +++ /dev/null @@ -1,43 +0,0 @@ -import sys -import logging -from hud.server import MCPServer -from server.shared import http_client - -# Configure logging to stderr -logging.basicConfig( - stream=sys.stderr, - level=logging.INFO, - format="[%(levelname)s] %(asctime)s | %(name)s | %(message)s", - force=True, -) -for logger_name in ["httpx", "httpcore"]: - logging.getLogger(logger_name).setLevel(logging.WARNING) - -# Create main MCP server -mcp = MCPServer(name="blank-environment") - -# Include routers -from server.tools import router as tools_router - -mcp.include_router(tools_router) - - -# Lifecycle hooks -@mcp.initialize -async def init(): - """Check if the environment is healthy""" - if http_client: - await http_client.get("/health") - else: - raise ValueError("http_client is not set") - - -@mcp.shutdown -async def cleanup(): - """Close the HTTP client""" - if http_client: - await http_client.aclose() - - -if __name__ == "__main__": - mcp.run(transport="stdio") diff --git a/environments/blank/server/pyproject.toml b/environments/blank/server/pyproject.toml deleted file mode 100644 index 403f92c0..00000000 --- a/environments/blank/server/pyproject.toml +++ /dev/null @@ -1,19 +0,0 @@ -[project] -name = "blank-server" -version = "0.1.0" -description = "MCP server for blank environment" -requires-python = ">=3.11" -dependencies = [ - "hud-python>=0.4.54", - "httpx>=0.28.1", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.metadata] -allow-direct-references = true - -[tool.hatch.build.targets.wheel] -packages = ["."] diff --git a/environments/blank/server/shared.py b/environments/blank/server/shared.py deleted file mode 100644 index ad81fac5..00000000 --- a/environments/blank/server/shared.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import annotations - -import os -import httpx - -# Environment port (as string to simplify formatting) -ENV_SERVER_PORT = os.getenv("ENV_SERVER_PORT", "8005") - -# Shared HTTP client for talking to the environment backend -http_client = httpx.AsyncClient( - base_url=f"http://localhost:{ENV_SERVER_PORT}", - timeout=10.0, -) - -__all__ = ["ENV_SERVER_PORT", "http_client"] diff --git a/environments/blank/server/tools.py b/environments/blank/server/tools.py deleted file mode 100644 index 32f3c414..00000000 --- a/environments/blank/server/tools.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Tools router for environment interaction.""" - -from hud.server import MCPRouter -from hud.tools.types import EvaluationResult -from server.shared import http_client - -router = MCPRouter() - - -@router.tool -async def act() -> str: - """Perform one action step in the environment (increment the counter).""" - resp = await 
http_client.post("/act") - data = resp.json() - return f"Action #{data.get('count', 0)} performed. Current count: {data.get('count', 0)}" - - -@router.tool -async def setup() -> str: - """Initialize or reset the environment to its starting state.""" - await http_client.post("/reset") - return "Counter reset to 0" - - -@router.tool -async def evaluate(target: int = 10) -> EvaluationResult: - """Evaluate progress toward the target count and return a reward and done flag.""" - resp = await http_client.get("/state") - current_count = resp.json().get("count", 0) - delta = target - current_count - reward = max(0.0, 1.0 - abs(delta) / target) if target > 0 else current_count - done = current_count >= target - return EvaluationResult( - reward=reward, done=done, content=f"Counter at {current_count}/{target}" - ) diff --git a/environments/blank/tasks.json b/environments/blank/tasks.json deleted file mode 100644 index f24e7b63..00000000 --- a/environments/blank/tasks.json +++ /dev/null @@ -1,44 +0,0 @@ -[ - { - "prompt": "Increment the counter to reach 3", - "mcp_config": { - "local": { - "command": "docker", - "args": [ - "run", - "--rm", - "-i", - "blank:latest" - ] - } - }, - "agent_config": { - "allowed_tools": ["act"], - "append_setup_output": true - }, - "setup_tool": { - "name": "setup", - "arguments": {} - }, - "integration_test_tool": [ - { - "name": "act", - "arguments": {} - }, - { - "name": "act", - "arguments": {} - }, - { - "name": "act", - "arguments": {} - } - ], - "evaluate_tool": { - "name": "evaluate", - "arguments": { - "target": 3 - } - } - } -] diff --git a/environments/blank/test_task.py b/environments/blank/test_task.py deleted file mode 100644 index 0f46690a..00000000 --- a/environments/blank/test_task.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python -""" -Simple example of running tasks from tasks.json. Make sure to have run hud build. -""" - -from __future__ import annotations - -import asyncio -import json - -from hud.clients import MCPClient -from hud.datasets import Task - - -async def run_task(task_data: dict): - task = Task(**task_data) - client = MCPClient(mcp_config=task.mcp_config) - - try: - print("Initializing client...") - await client.initialize() - - result = await client.call_tool(task.setup_tool) # type: ignore - print(f"✅ Setup: {result.content}") - - print("\n🔄 Performing actions:") - for _ in range(10): - result = await client.call_tool(name="act", arguments={}) - print(f" {result.content}") - - result = await client.call_tool(task.evaluate_tool) # type: ignore - print(f"\n📊 Evaluation: {result.content}") - - return result.content - except Exception as e: - if "connection" in str(e).lower(): - print( - "❌ Could not connect. Make sure 'hud dev --build' is running in another terminal." 
- ) - else: - raise e - finally: - await client.shutdown() - - -async def main(): - for task_data in json.load(open("tasks.json")): - await run_task(task_data) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/environments/browser/.dockerignore b/environments/browser/.dockerignore deleted file mode 100644 index f91da037..00000000 --- a/environments/browser/.dockerignore +++ /dev/null @@ -1,101 +0,0 @@ -# Git -.git -.gitignore - -# Node -environment/*/frontend/node_modules -environment/*/frontend/.next -environment/*/frontend/build -environment/*/frontend/dist -environment/*/frontend/.turbo -environment/*/frontend/.vercel -environment/*/frontend/next-env.d.ts -environment/*/frontend/package-lock.json -# General Node/Next artifacts anywhere -node_modules -**/node_modules -**/.next -**/.turbo -**/.vercel -*.log - -# Python -__pycache__ -**/__pycache__ -*.pyc -*.pyo -*.pyd -.Python -*.egg-info -.pytest_cache -.mypy_cache -.coverage -.venv -venv -env -environment/*/backend/.venv -environment/*/backend/venv -environment/*/backend/__pycache__ - -# Database - exclude ALL database files -*.db -*.sqlite -*.db-journal -*.db-wal -*.db-shm -**/*.db -**/*.sqlite -**/*.db-journal -**/*.db-wal -**/*.db-shm - -# IDE -.vscode -.idea -*.swp -*.swo - -# OS -.DS_Store -Thumbs.db - -# Documentation -*.md -!app/README.md -!launch/README.md - -# Unix sockets, locks, pids (can break Docker context on Windows) -**/*.sock -**/*.socket -**/*.pipe -**/*.pid -**/*.lock -**/*.ipc - -# Symlinks and special files -**/*.lnk -**/symlink* -**/.venv -**/.env -**/venv -**/env - -# Temporary and cache files -*.tmp -*.temp -*.cache -**/*.tmp -**/*.temp -**/*.cache -**/tmp/ -**/temp/ -**/cache/ - -# Lock files that might have special permissions -yarn.lock -poetry.lock -Pipfile.lock -**/yarn.lock -**/*.lock -environment/uv.lock -controller/uv.lock \ No newline at end of file diff --git a/environments/browser/.gitignore b/environments/browser/.gitignore deleted file mode 100644 index 5397595a..00000000 --- a/environments/browser/.gitignore +++ /dev/null @@ -1,100 +0,0 @@ -# Dependencies -node_modules/ -.pnp -.pnp.js - -# Testing -coverage/ -.coverage -.pytest_cache/ -htmlcov/ - -# Next.js -.next/ -out/ -build/ -*.tsbuildinfo -next-env.d.ts - -# Production -dist/ - -# Misc -.DS_Store -*.pem -Thumbs.db - -# Debug -npm-debug.log* -yarn-debug.log* -yarn-error.log* -.pnpm-debug.log* - -# Local env files -.env -.env.local -.env.development.local -.env.test.local -.env.production.local - -# Vercel -.vercel - -# TypeScript -*.tsbuildinfo - -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so -.Python -env/ -venv/ -.venv/ -ENV/ -env.bak/ -venv.bak/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# uv -.venv/ -uv.lock - -# Database -*.db -*.sqlite -*.sqlite3 -app.db - -# IDEs -.vscode/ -.idea/ -*.swp -*.swo -*~ -.project -.classpath -.c9/ -*.launch -.settings/ -*.sublime-workspace - -# OS -.DS_Store -.DS_Store? -._* -.Spotlight-V100 -.Trashes -ehthumbs.db -Thumbs.db - -# Logs -logs/ -*.log - -# Docker -.dockerignore.local \ No newline at end of file diff --git a/environments/browser/Dockerfile b/environments/browser/Dockerfile deleted file mode 100644 index e25a71f9..00000000 --- a/environments/browser/Dockerfile +++ /dev/null @@ -1,60 +0,0 @@ -# syntax=docker/dockerfile:1 -FROM hudevals/hud-browser-base:latest AS setup - -WORKDIR /app - -# Layer 1: Install server dependencies -COPY server/pyproject.toml /app/server/ -RUN cd /app/server && uv pip install --system --break-system-packages . 
- -# Layer 2: Install environment dependencies -COPY environment/pyproject.toml /app/environment/ -RUN cd /app/environment && uv pip install --system --break-system-packages . - -# Layer 3: Copy source code (changes here don't invalidate dependency layers) -COPY server/ /app/server/ -COPY environment/ /app/environment/ - -# Auto-discover and install/build all frontend apps -RUN set -e; \ - for pkg in $(find /app/environment -type f -path '*/frontend/package.json'); do \ - app_dir=$(dirname "$pkg"); \ - echo "Installing dependencies in $app_dir"; \ - if [ -f "$app_dir/package-lock.json" ]; then \ - (cd "$app_dir" && npm ci --no-audit --no-fund); \ - else \ - (cd "$app_dir" && npm install --no-audit --no-fund); \ - fi; \ - done && \ - for pkg in $(find /app/environment -type f -path '*/frontend/package.json'); do \ - app_dir=$(dirname "$pkg"); \ - if [ -f "$app_dir/next.config.js" ]; then \ - echo "Building Next.js app in $app_dir"; \ - (cd "$app_dir" && npm run build); \ - fi; \ - done - -# Make scripts executable -RUN find /app/environment -name "*.py" -type f -exec chmod +x {} \; - -# Environment configuration -ENV MCP_TRANSPORT="stdio" -ENV HUD_LOG_STREAM="stderr" -ENV PYTHONUNBUFFERED="1" -ENV PYTHONWARNINGS="ignore::SyntaxWarning:pyautogui" -ENV DISPLAY=":1" -ENV PYTHONPATH=/app - -# Expose ports -EXPOSE 8000 8080 3000-3200 5000-5200 - -# Simple startup: HUD_DEV=1 enables hot-reload; otherwise run production -CMD ["sh", "-c", "\ - if [ \"${HUD_DEV:-0}\" = \"1\" ]; then \ - uvicorn environment.server:app --host 0.0.0.0 --port 8000 --reload --log-level warning >&2 & \ - sleep 5 && cd /app/server && exec hud dev server.main --stdio; \ - else \ - uvicorn environment.server:app --host 0.0.0.0 --port 8000 --log-level warning >&2 & \ - sleep 5 && cd /app/server && exec python3 -m server.main; \ - fi\ -"] \ No newline at end of file diff --git a/environments/browser/Dockerfile.local b/environments/browser/Dockerfile.local deleted file mode 100644 index c5262633..00000000 --- a/environments/browser/Dockerfile.local +++ /dev/null @@ -1,72 +0,0 @@ -# syntax=docker/dockerfile:1 -# Local development Dockerfile that uses local hud-python -FROM hudevals/hud-browser-base:latest AS setup - -WORKDIR /app - -# Layer 0: Install local hud-python -# Copy local hud-python source (build context is repo root) -COPY hud /app/hud-python/hud/ -COPY pyproject.toml /app/hud-python/ -COPY README.md /app/hud-python/ -COPY LICENSE /app/hud-python/ - -# Install local hud-python -RUN cd /app/hud-python && uv pip install --system --break-system-packages -e . - -# Layer 1: Install server dependencies -COPY environments/browser/server/pyproject.toml /app/server/ -RUN cd /app/server && uv pip install --system --break-system-packages . - -# Layer 2: Install environment dependencies -COPY environments/browser/environment/pyproject.toml /app/environment/ -RUN cd /app/environment && uv pip install --system --break-system-packages . 
- -# Layer 3: Copy source code (changes here don't invalidate dependency layers) -COPY environments/browser/server/ /app/server/ -COPY environments/browser/environment/ /app/environment/ - -# Auto-discover and install/build all frontend apps -RUN set -e; \ - for pkg in $(find /app/environment -type f -path '*/frontend/package.json'); do \ - app_dir=$(dirname "$pkg"); \ - echo "Installing dependencies in $app_dir"; \ - if [ -f "$app_dir/package-lock.json" ]; then \ - (cd "$app_dir" && npm ci --no-audit --no-fund); \ - else \ - (cd "$app_dir" && npm install --no-audit --no-fund); \ - fi; \ - done && \ - for pkg in $(find /app/environment -type f -path '*/frontend/package.json'); do \ - app_dir=$(dirname "$pkg"); \ - if [ -f "$app_dir/next.config.js" ]; then \ - echo "Building Next.js app in $app_dir"; \ - (cd "$app_dir" && npm run build); \ - fi; \ - done - -# Make scripts executable -RUN find /app/environment -name "*.py" -type f -exec chmod +x {} \; - -# Environment configuration -ENV MCP_TRANSPORT="stdio" -ENV HUD_LOG_STREAM="stderr" -ENV PYTHONUNBUFFERED="1" -ENV PYTHONWARNINGS="ignore::SyntaxWarning:pyautogui" -ENV DISPLAY=":1" -ENV PYTHONPATH=/app - -# Expose ports -EXPOSE 8000 8080 3000-3200 5000-5200 - -# Simple startup: HUD_DEV=1 enables hot-reload; otherwise run production -CMD ["sh", "-c", "\ - if [ \"${HUD_DEV:-0}\" = \"1\" ]; then \ - uvicorn environment.server:app --host 0.0.0.0 --port 8000 --reload --log-level warning >&2 & \ - sleep 5 && cd /app/server && exec hud dev server.main --stdio; \ - else \ - uvicorn environment.server:app --host 0.0.0.0 --port 8000 --log-level warning >&2 & \ - sleep 5 && cd /app/server && exec python3 -m server.main; \ - fi\ -"] - diff --git a/environments/browser/README.md b/environments/browser/README.md deleted file mode 100644 index 005e1333..00000000 --- a/environments/browser/README.md +++ /dev/null @@ -1,191 +0,0 @@ -# Browser Environment - -Browser automation environment with GUI access for testing web applications. Includes sample apps (2048, Todo) and supports hot-reload development. - -## Architecture - -**`environment/`** - Produces structured data -- FastAPI backend with X11/VNC services (Linux-only) -- Launches and manages web apps (Next.js frontends + Python backends) -- Exposes HTTP endpoints for app control and state - -**`server/`** - Wraps data in MCP tools -- Browser automation tools (Playwright, computer vision) -- Setup tools (launch apps, seed data) -- Evaluation tools (check game state, todo completion) - -**Why separate?** The environment backend requires X11/VNC/Chromium (Docker-only). The MCP server tools can be edited with hot-reload, while the heavy environment stays running. - -## Development - -This environment **requires Docker** due to X11/VNC dependencies. - -```bash -# Build first (creates hud-browser:0.1.0) -hud build - -# Start with hot-reload -hud dev -``` - -When you run `hud dev` in an environment with a Dockerfile, it automatically: -- Detects Docker mode is needed -- Mounts `server/` and `environment/` as volumes -- Enables hot-reload for both layers - -Edit files in `server/` or `environment/` and they reload inside the container! - -## Publishing Your Environment - -Once your environment is ready, you can share it with the community: - -### 1. Push to Registry -```bash -# Build and push your environment (requires docker hub login and hud api key) -hud build -hud push -``` - -### 2. Create a Dataset - -Create a dataset on HuggingFace with your tasks: - -**Option A: Upload manually** -1. 
Upload your `tasks.json` to HuggingFace -2. Make sure it's **public** to appear on leaderboards - -**Option B: Use the SDK** -```python -from hud.datasets import save_tasks -import json - -# Load your tasks -with open("tasks.json") as f: - tasks = json.load(f) - -# Push to HuggingFace -save_tasks(tasks, repo_id="your-org/your-dataset") -``` - -### 3. Run and Track Performance - -```bash -# Run Claude on your benchmark -hud eval "your-org/your-dataset" --agent claude - -# View results at: -# hud.ai/leaderboards/your-org/your-dataset -``` - -**Note**: Only public HuggingFace datasets appear as leaderboards! - -📚 Learn more: [Creating Benchmarks](https://docs.hud.ai/evaluate-agents/create-benchmarks) | [Leaderboards](https://docs.hud.ai/evaluate-agents/leaderboards) - -## Architecture Overview - -The browser environment uses a two-process architecture: - -1. **Context Server** (`context.py`): Long-running process that maintains persistent state -2. **MCP Server** (`server.py`): Hot-reloadable process that handles tool requests - -### Key Components - -- **BrowserContext**: Stores persistent state (running apps, ports, playwright instance) -- **ServiceManager**: Manages X11, VNC, and app processes -- **BaseHub Tools**: Setup and evaluate tools organized by app (2048, todo) -- **Multiprocessing Proxy**: Enables state sharing between processes - -### 1. Tool Implementation Pattern - -All setup and evaluate tools should follow this pattern: - -```python -@setup.tool("tool_name") -async def tool_name(param1: type, param2: type): - """Tool description.""" - try: - # Get persistent context - persistent_ctx = setup.env # or evaluate.env - - # Get app ports - backend_port = persistent_ctx.get_app_backend_port("app_name") - - # Make HTTP request - url = f"http://localhost:{backend_port}/api/endpoint" - async with httpx.AsyncClient() as client: - response = await client.method(url, json=data) - response.raise_for_status() - result = response.json() - - # Return result - return TextContent( - text=f"Success message", - type="text" - ) - except Exception as e: - logger.error(f"tool_name failed: {e}") - return TextContent( - text=f"Failed: {str(e)}", - type="text" - ) -``` - -### 2. App Launch Pattern - -When launching apps, ensure ports are stored in the persistent context: - -```python -# In launch_app tool -app_info = await service_manager.launch_app(app_name) - -# Store ports in persistent context for later access -try: - backend_port = service_manager.get_app_port(app_name) - frontend_port = service_manager.get_app_frontend_port(app_name) - persistent_ctx.set_app_ports(app_name, frontend_port, backend_port) -except Exception as e: - logger.error(f"Failed to store ports: {e}") - -# Track app in persistent context -persistent_ctx.add_running_app(app_name) -``` - -### 3. Import Organization - -Keep imports at module level: - -```python -# At top of file -import logging -import httpx -from mcp.types import TextContent -from . import setup - -# Not inside functions -``` - -## Development Workflow - -1. **Start the environment**: `hud dev` -2. **Make changes**: Edit tools in `src/hud_controller/` -3. **Test immediately**: The MCP server hot-reloads automatically -4. **Check logs**: Look for serialization or proxy errors - -## Adding New Apps - -1. Create app directory in `apps/` -2. Add setup tools in `src/hud_controller/setup/app_name.py` -3. Add evaluate tools in `src/hud_controller/evaluate/app_name.py` -4. Follow the HTTP pattern - no `call_app_api` usage -5. 
Store app ports in persistent context when launching - -## Key Files - -- `context.py`: Persistent state management -- `server.py`: MCP server and tool definitions -- `services.py`: Process management for X11, VNC, apps -- `setup/`: Setup tools organized by app -- `evaluate/`: Evaluation tools organized by app - -Remember: When in doubt, make direct HTTP calls and store state in the persistent context! - diff --git a/environments/browser/browser-base/Dockerfile b/environments/browser/browser-base/Dockerfile deleted file mode 100644 index 57eb9132..00000000 --- a/environments/browser/browser-base/Dockerfile +++ /dev/null @@ -1,50 +0,0 @@ -# syntax=docker/dockerfile:1 -FROM ubuntu:24.04 AS setup - -# Update and install core dependencies (including working Chromium browser) -RUN apt-get update -y \ - && apt-get install -y --no-install-recommends \ - vim \ - openssl \ - ca-certificates \ - curl \ - wget \ - sudo \ - bash \ - net-tools \ - novnc \ - x11vnc \ - xvfb \ - xfce4 \ - locales \ - libpq5 \ - sqlite3 \ - dbus-x11 \ - xfce4-terminal \ - xfonts-base \ - xdotool \ - psmisc \ - scrot \ - pm-utils \ - build-essential \ - unzip \ - xauth \ - gnupg \ - gpg \ - jq \ - git \ - build-essential \ - nodejs \ - npm - -RUN update-ca-certificates - -RUN curl -LsSf https://astral.sh/uv/install.sh | sh -ENV PATH="/root/.local/bin:$PATH" - -# Install git for dependency installation -RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* - -# Install Playwright -RUN uv pip install --system --break-system-packages playwright -RUN python3 -m playwright install chromium --with-deps \ No newline at end of file diff --git a/environments/browser/browser-base/README.md b/environments/browser/browser-base/README.md deleted file mode 100644 index 21999fec..00000000 --- a/environments/browser/browser-base/README.md +++ /dev/null @@ -1,58 +0,0 @@ -# Browser Base Image - -Base Docker image for browser environments with Playwright, Chromium, and VNC support. - -## Build - -```bash -docker build -t browser-base:latest . -``` - -## Test with VNC Access - -### 1. Start the container - -```bash -docker run -it --rm \ - -p 6080:6080 \ - -p 5900:5900 \ - -e DISPLAY=:1 \ - browser-base:latest \ - bash -``` - -### 2. Inside the container, start display servers - -```bash -Xvfb :1 -screen 0 1920x1080x24 > /dev/null 2>&1 & -x11vnc -display :1 -nopw -listen 0.0.0.0 -forever > /dev/null 2>&1 & -/usr/share/novnc/utils/novnc_proxy --vnc localhost:5900 --listen 6080 > /dev/null 2>&1 & -``` - -### 3. Test Playwright - -```bash -python3 -c " -from playwright.sync_api import sync_playwright -with sync_playwright() as p: - browser = p.chromium.launch(headless=False) - page = browser.new_page() - page.goto('https://example.com') - print('Title:', page.title()) - input('Press Enter to close...') - browser.close() -" -``` - -### 4. View in browser - -Open `http://localhost:6080/vnc.html` to see Chromium running. - -## What's Included - -- Ubuntu 24.04 -- Desktop environment (Xvfb, x11vnc, noVNC, xfce4) -- Node.js & npm -- Python 3 with uv package manager -- Playwright with Chromium -- Development tools (git, curl, wget, etc.) 
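The display servers started in step 2 can take a few seconds to come up before the Playwright test or the noVNC page will connect. Below is a minimal readiness check you can run on the host, assuming the ports published by the `docker run` example above (5900 for x11vnc, 6080 for noVNC); the `wait_for_port` helper is an illustrative sketch, not something shipped in the image, and mirrors the socket-polling pattern used elsewhere in this environment:

```python
import socket
import sys
import time


def wait_for_port(port: int, host: str = "localhost", timeout: float = 30.0) -> bool:
    """Poll until a TCP port accepts connections or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(0.2)
            if sock.connect_ex((host, port)) == 0:
                return True
        time.sleep(0.2)
    return False


if __name__ == "__main__":
    # Ports as published in the docker run example: 5900 (x11vnc), 6080 (noVNC).
    for name, port in [("x11vnc", 5900), ("noVNC", 6080)]:
        if wait_for_port(port):
            print(f"{name} is ready on port {port}")
        else:
            print(f"{name} did not come up on port {port} within 30s")
            sys.exit(1)
```

Once both ports report ready, the noVNC page at `http://localhost:6080/vnc.html` and the Playwright test above should connect without retries.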
\ No newline at end of file diff --git a/environments/browser/environment/2048/README.md b/environments/browser/environment/2048/README.md deleted file mode 100644 index 474b0c6d..00000000 --- a/environments/browser/environment/2048/README.md +++ /dev/null @@ -1,103 +0,0 @@ -# 2048 Game for Browser Environment - -A browser-based implementation of the 2048 game with configurable target tiles and reward system for RL evaluation. - -## Features - -- **Configurable Target Tile**: Set any power of 2 as target (64, 128, 256, 512, 1024, 2048, etc.) -- **Logarithmic Reward Scaling**: Smooth reward progression using `log(highest_tile) / log(target)` -- **Efficiency Tracking**: Monitor score-to-moves ratio -- **Flexible Board Size**: Support for 3x3 to 6x6 grids -- **Full Evaluation API**: Compatible with RL evaluation system - -## Architecture - -### Backend (FastAPI) -- Core game logic in `game.py` -- RESTful API endpoints for game control -- Evaluation endpoints for RL agents -- SQLite persistence (optional) - -### Frontend (Next.js + React) -- Responsive game board with smooth animations -- Keyboard and touch controls -- Real-time score and progress tracking -- Customizable game parameters - -## Running the Game - -### Standalone -```bash -python launch.py --frontend-port 3001 --backend-port 5001 -``` - -### With Browser Environment -The game integrates with the browser environment's setup and evaluation system. - -## API Endpoints - -### Core Game -- `POST /api/game/new` - Start new game -- `GET /api/game/state` - Get current state -- `POST /api/game/move` - Make a move -- `POST /api/game/set_target` - Set target tile - -### Evaluation -- `GET /api/eval/stats` - Get comprehensive stats -- `GET /api/eval/max_number` - Get highest tile -- `GET /api/eval/efficiency` - Get efficiency ratio -- `POST /api/eval/set_board` - Set specific board -- `POST /api/eval/reset` - Reset game - -## Evaluators - -- `game_2048_max_number` - Check if target tile reached (logarithmic reward) -- `game_2048_efficiency` - Evaluate score/moves ratio -- `game_2048_score_reached` - Check if target score reached -- `game_2048_game_won` - Check if game is won -- `game_2048_game_over` - Check if game is over -- `game_2048_moves_made` - Check minimum moves made - -## Setup Tools - -- `game_2048_board` - Initialize game with size and target -- `game_2048_set_board` - Set specific board state -- `game_2048_near_win` - Set board near winning -- `game_2048_navigate` - Navigate to game URL -- `game_2048_reset` - Reset to initial state - -## Reward System - -The reward system matches the text-2048 environment: - -1. **Max Number Reward**: `min(1.0, log(highest_tile) / log(target))` - - Logarithmic scaling for smooth progression - - Reaches 1.0 when target tile is achieved - -2. 
**Efficiency Reward**: `min(1.0, ratio / min_ratio)` - - Linear scaling based on score/moves ratio - - Encourages efficient gameplay - -## Development - -### Backend Requirements -- Python 3.8+ -- FastAPI -- NumPy -- uvicorn - -### Frontend Requirements -- Node.js 16+ -- Next.js 14 -- React 18 -- Tailwind CSS - -## Testing - -The game can be tested with the browser environment's evaluation system: - -```python -# Example evaluation -ctx = Context() -result = await game_2048_max_number(ctx, target=2048) -``` \ No newline at end of file diff --git a/environments/browser/environment/2048/backend/game.py b/environments/browser/environment/2048/backend/game.py deleted file mode 100644 index e13f3b38..00000000 --- a/environments/browser/environment/2048/backend/game.py +++ /dev/null @@ -1,241 +0,0 @@ -"""2048 Game Logic for Browser Environment""" - -import random -import numpy as np -from typing import Tuple, Optional, List - - -class Game2048: - """Browser-based 2048 game implementation with configurable target""" - - def __init__(self, size: int = 4, target_tile: int = 2048): - self.size = size - self.target_tile = target_tile - self.board = np.zeros((size, size), dtype=int) - self.score = 0 - self.game_over = False - self.moves_made = 0 - self.won = False - - # Start with 2 random tiles - self.add_random_tile() - self.add_random_tile() - - # Track initial highest tile for reward calculation - self.initial_highest_tile = int(self.board.max()) - - def add_random_tile(self) -> bool: - """Add a random 2 or 4 tile to an empty position""" - empty_cells = [ - (i, j) for i in range(self.size) for j in range(self.size) if self.board[i, j] == 0 - ] - - if not empty_cells: - return False - - i, j = random.choice(empty_cells) - # 90% chance of 2, 10% chance of 4 - self.board[i, j] = 2 if random.random() < 0.9 else 4 - return True - - def compress(self, row: np.ndarray) -> Tuple[np.ndarray, int]: - """Compress a row by moving all non-zero elements to the left and merging""" - new_row = np.zeros_like(row) - pos = 0 - score = 0 - - # Move all non-zero elements to the left - for num in row: - if num != 0: - new_row[pos] = num - pos += 1 - - # Merge adjacent equal elements - i = 0 - while i < len(new_row) - 1: - if new_row[i] != 0 and new_row[i] == new_row[i + 1]: - new_row[i] *= 2 - score += new_row[i] - new_row[i + 1] = 0 - i += 2 - else: - i += 1 - - # Compress again after merging - final_row = np.zeros_like(row) - pos = 0 - for num in new_row: - if num != 0: - final_row[pos] = num - pos += 1 - - return final_row, score - - def move(self, direction: str) -> bool: - """Make a move in the specified direction""" - if self.game_over: - return False - - direction = direction.lower() - if direction not in ["up", "down", "left", "right"]: - return False - - original_board = self.board.copy() - move_score = 0 - - if direction == "left": - for i in range(self.size): - self.board[i], row_score = self.compress(self.board[i]) - move_score += row_score - - elif direction == "right": - for i in range(self.size): - reversed_row = self.board[i][::-1] - compressed, row_score = self.compress(reversed_row) - self.board[i] = compressed[::-1] - move_score += row_score - - elif direction == "up": - for j in range(self.size): - column = self.board[:, j] - compressed, col_score = self.compress(column) - self.board[:, j] = compressed - move_score += col_score - - elif direction == "down": - for j in range(self.size): - column = self.board[:, j][::-1] - compressed, col_score = self.compress(column) - self.board[:, j] = 
compressed[::-1] - move_score += col_score - - # Check if the board changed - if not np.array_equal(original_board, self.board): - self.score += move_score - self.moves_made += 1 - self.add_random_tile() - self.check_game_status() - return True - - return False - - def check_game_status(self): - """Check if the game is won or over""" - # Check if target tile is reached - if not self.won and self.board.max() >= self.target_tile: - self.won = True - - # Check if game is over (no valid moves) - # Check for empty cells - if 0 in self.board: - self.game_over = False - return - - # Check for possible merges - for i in range(self.size): - for j in range(self.size): - current = self.board[i, j] - # Check right neighbor - if j < self.size - 1 and current == self.board[i, j + 1]: - self.game_over = False - return - # Check bottom neighbor - if i < self.size - 1 and current == self.board[i + 1, j]: - self.game_over = False - return - - self.game_over = True - - def get_state(self) -> dict: - """Get the current game state as a dictionary""" - return { - "board": self.board.tolist(), - "score": int(self.score), - "moves": int(self.moves_made), - "game_over": bool(self.game_over), - "won": bool(self.won), - "highest_tile": int(self.board.max()), - "initial_highest_tile": int(self.initial_highest_tile), - "target_tile": self.target_tile, - "board_size": self.size, - } - - def set_board(self, board: List[List[int]], score: int = 0, moves: int = 0): - """Set a specific board configuration (for testing)""" - self.board = np.array(board, dtype=int) - self.score = score - self.moves_made = moves - self.check_game_status() - - def reset(self, size: Optional[int] = None, target_tile: Optional[int] = None): - """Reset the game to initial state - - Args: - size: Optional new board size - target_tile: Optional new target tile - """ - if size is not None: - self.size = size - if target_tile is not None: - self.target_tile = target_tile - - self.board = np.zeros((self.size, self.size), dtype=int) - self.score = 0 - self.game_over = False - self.won = False - self.moves_made = 0 - self.add_random_tile() - self.add_random_tile() - - # Track initial highest tile after reset - self.initial_highest_tile = int(self.board.max()) - - def can_move(self) -> dict: - """Check which moves are valid""" - valid_moves = {"up": False, "down": False, "left": False, "right": False} - - if self.game_over: - return valid_moves - - # Test each direction without modifying the actual board - original_board = self.board.copy() - - for direction in ["up", "down", "left", "right"]: - test_board = original_board.copy() - self.board = test_board - - # Try the move - if direction == "left": - for i in range(self.size): - compressed, _ = self.compress(self.board[i]) - if not np.array_equal(self.board[i], compressed): - valid_moves[direction] = True - break - - elif direction == "right": - for i in range(self.size): - reversed_row = self.board[i][::-1] - compressed, _ = self.compress(reversed_row) - if not np.array_equal(reversed_row, compressed): - valid_moves[direction] = True - break - - elif direction == "up": - for j in range(self.size): - column = self.board[:, j] - compressed, _ = self.compress(column) - if not np.array_equal(column, compressed): - valid_moves[direction] = True - break - - elif direction == "down": - for j in range(self.size): - column = self.board[:, j][::-1] - compressed, _ = self.compress(column) - if not np.array_equal(column, compressed): - valid_moves[direction] = True - break - - # Restore original board - self.board 
= original_board - return valid_moves diff --git a/environments/browser/environment/2048/backend/main.py b/environments/browser/environment/2048/backend/main.py deleted file mode 100644 index 8cfba5ce..00000000 --- a/environments/browser/environment/2048/backend/main.py +++ /dev/null @@ -1,246 +0,0 @@ -"""FastAPI backend for 2048 game""" - -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from typing import List, Optional -from datetime import datetime -import sqlite3 -import json -from game import Game2048 - -app = FastAPI(title="2048 Game API", version="1.0.0") - -# Configure CORS -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:3001"], # Different port from todo app - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Global game instance (in production, would use sessions/database) -game = Game2048() - - -# Pydantic models -class NewGameRequest(BaseModel): - board_size: int = 4 - target_tile: int = 2048 - - -class MoveRequest(BaseModel): - direction: str # up, down, left, right - - -class SetBoardRequest(BaseModel): - board: List[List[int]] - score: Optional[int] = 0 - moves: Optional[int] = 0 - - -class SetTargetRequest(BaseModel): - target_tile: int - - -class GameState(BaseModel): - board: List[List[int]] - score: int - moves: int - game_over: bool - won: bool - highest_tile: int - initial_highest_tile: int - target_tile: int - board_size: int - - -class EvaluationStats(BaseModel): - board: List[List[int]] - score: int - moves: int - highest_tile: int - target_tile: int - efficiency: float - game_over: bool - won: bool - valid_moves: dict - - -# === CORE GAME API ROUTES === - - -@app.get("/api/status") -def status(): - """Health check endpoint""" - return {"status": "ok", "timestamp": datetime.now().isoformat()} - - -@app.post("/api/game/new", response_model=GameState) -def new_game(request: NewGameRequest): - """Start a new game with specified parameters""" - global game - game = Game2048(size=request.board_size, target_tile=request.target_tile) - return game.get_state() - - -@app.get("/api/game/state", response_model=GameState) -def get_game_state(): - """Get current game state""" - return game.get_state() - - -@app.post("/api/game/move", response_model=GameState) -def make_move(request: MoveRequest): - """Make a move in the specified direction""" - valid = game.move(request.direction) - if not valid and not game.game_over: - raise HTTPException(status_code=400, detail="Invalid move") - return game.get_state() - - -@app.post("/api/game/set_target", response_model=GameState) -def set_target(request: SetTargetRequest): - """Set the target tile for the game""" - game.target_tile = request.target_tile - game.check_game_status() # Re-check win condition - return game.get_state() - - -@app.get("/api/game/valid_moves") -def get_valid_moves(): - """Get which moves are currently valid""" - return game.can_move() - - -# === EVALUATION API ROUTES === - - -@app.get("/api/eval/health") -def eval_health(): - """Health check endpoint for evaluation system""" - return { - "status": "healthy", - "game_active": not game.game_over, - "highest_tile": int(game.board.max()), - "target_tile": game.target_tile, - "timestamp": datetime.now().isoformat(), - } - - -@app.get("/api/eval/stats", response_model=EvaluationStats) -def get_evaluation_stats(): - """Comprehensive evaluation statistics for the game""" - state = game.get_state() - efficiency = state["score"] / 
state["moves"] if state["moves"] > 0 else 0.0 - - return EvaluationStats( - board=state["board"], - score=state["score"], - moves=state["moves"], - highest_tile=state["highest_tile"], - target_tile=state["target_tile"], - efficiency=efficiency, - game_over=state["game_over"], - won=state["won"], - valid_moves=game.can_move(), - ) - - -@app.get("/api/eval/max_number") -def get_max_number(): - """Get the highest tile value for evaluation""" - state = game.get_state() - return { - "highest_tile": state["highest_tile"], - "target_tile": state["target_tile"], - "progress": state["highest_tile"] / state["target_tile"] if state["target_tile"] > 0 else 0, - "timestamp": datetime.now().isoformat(), - } - - -@app.get("/api/eval/efficiency") -def get_efficiency(): - """Get the game efficiency (score/moves ratio)""" - state = game.get_state() - efficiency = state["score"] / state["moves"] if state["moves"] > 0 else 0.0 - - return { - "score": state["score"], - "moves": state["moves"], - "efficiency": efficiency, - "timestamp": datetime.now().isoformat(), - } - - -@app.get("/api/eval/board") -def get_board(): - """Get current board state for evaluation""" - state = game.get_state() - return { - "board": state["board"], - "board_size": state["board_size"], - "empty_cells": sum(1 for row in state["board"] for cell in row if cell == 0), - "timestamp": datetime.now().isoformat(), - } - - -@app.post("/api/eval/set_board", response_model=GameState) -def set_board(request: SetBoardRequest): - """Set a specific board configuration for testing""" - try: - game.set_board(request.board, request.score, request.moves) - return game.get_state() - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@app.post("/api/eval/reset", response_model=GameState) -def reset_game(): - """Reset game to initial state""" - game.reset() - return game.get_state() - - -@app.post("/api/eval/seed") -def seed_test_board(): - """Seed the board with a test configuration""" - # Create a board that's close to winning - test_board = [[1024, 512, 256, 128], [64, 32, 16, 8], [4, 2, 0, 0], [0, 0, 0, 0]] - game.set_board(test_board, score=10000, moves=100) - - return { - "message": "Test board seeded successfully", - "highest_tile": 1024, - "timestamp": datetime.now().isoformat(), - } - - -@app.post("/api/eval/seed_custom") -def seed_custom_board(board: List[List[int]]): - """Seed the board with a custom configuration""" - try: - game.set_board(board) - state = game.get_state() - return { - "message": "Custom board seeded successfully", - "highest_tile": state["highest_tile"], - "timestamp": datetime.now().isoformat(), - } - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@app.get("/api/eval/can_move") -def can_move(): - """Check if any moves are available""" - valid_moves = game.can_move() - has_moves = any(valid_moves.values()) - - return { - "can_move": has_moves, - "valid_moves": valid_moves, - "game_over": game.game_over, - "timestamp": datetime.now().isoformat(), - } diff --git a/environments/browser/environment/2048/backend/pyproject.toml b/environments/browser/environment/2048/backend/pyproject.toml deleted file mode 100644 index d3c16ae0..00000000 --- a/environments/browser/environment/2048/backend/pyproject.toml +++ /dev/null @@ -1,9 +0,0 @@ -[project] -name = "game-2048-backend" -version = "1.0.0" -dependencies = [ - "fastapi", - "uvicorn", - "numpy", - "pydantic" -] \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/app/globals.css 
b/environments/browser/environment/2048/frontend/app/globals.css deleted file mode 100644 index bd6213e1..00000000 --- a/environments/browser/environment/2048/frontend/app/globals.css +++ /dev/null @@ -1,3 +0,0 @@ -@tailwind base; -@tailwind components; -@tailwind utilities; \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/app/layout.tsx b/environments/browser/environment/2048/frontend/app/layout.tsx deleted file mode 100644 index bcb24f69..00000000 --- a/environments/browser/environment/2048/frontend/app/layout.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import type { Metadata } from 'next' -import { Inter } from 'next/font/google' -import './globals.css' - -const inter = Inter({ subsets: ['latin'] }) - -export const metadata: Metadata = { - title: '2048 Game', - description: 'A browser-based 2048 game with configurable targets', -} - -export default function RootLayout({ - children, -}: { - children: React.ReactNode -}) { - return ( - - {children} - - ) -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/app/page.tsx b/environments/browser/environment/2048/frontend/app/page.tsx deleted file mode 100644 index 3b56cede..00000000 --- a/environments/browser/environment/2048/frontend/app/page.tsx +++ /dev/null @@ -1,190 +0,0 @@ -'use client'; - -import { useState, useEffect, useCallback } from 'react'; -import GameBoard from '../components/GameBoard'; -import GameControls from '../components/GameControls'; - -// Dynamically determine API URL based on current port -// Backend is always on frontend_port + 1 -const getApiUrl = () => { - if (typeof window !== 'undefined') { - const currentPort = parseInt(window.location.port) || 3000; - return `http://localhost:${currentPort + 1}`; - } - return process.env.NEXT_PUBLIC_API_URL || 'http://localhost:5001'; -}; - -const API_URL = getApiUrl(); - -interface GameState { - board: number[][]; - score: number; - moves: number; - game_over: boolean; - won: boolean; - highest_tile: number; - target_tile: number; - board_size: number; -} - -export default function Game2048() { - const [gameState, setGameState] = useState(null); - const [loading, setLoading] = useState(false); - const [message, setMessage] = useState(''); - - // Load initial game state - useEffect(() => { - fetchGameState(); - }, []); - - // Handle keyboard input - useEffect(() => { - const handleKeyPress = (e: KeyboardEvent) => { - if (gameState?.game_over) return; - - const keyMap: { [key: string]: string } = { - 'ArrowUp': 'up', - 'ArrowDown': 'down', - 'ArrowLeft': 'left', - 'ArrowRight': 'right', - }; - - const direction = keyMap[e.key]; - if (direction) { - e.preventDefault(); - makeMove(direction); - } - }; - - window.addEventListener('keydown', handleKeyPress); - return () => window.removeEventListener('keydown', handleKeyPress); - }, [gameState]); - - const fetchGameState = async () => { - try { - const response = await fetch(`${API_URL}/api/game/state`); - const data = await response.json(); - setGameState(data); - } catch (error) { - console.error('Error fetching game state:', error); - setMessage('Error loading game'); - } - }; - - const makeMove = async (direction: string) => { - if (loading) return; - setLoading(true); - - try { - const response = await fetch(`${API_URL}/api/game/move`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ direction }), - }); - - if (response.ok) { - const data = await response.json(); - setGameState(data); - - if (data.won && 
!gameState?.won) { - setMessage(`🎉 You reached ${data.target_tile}!`); - } else if (data.game_over) { - setMessage('Game Over! No more moves available.'); - } - } else { - // Invalid move, just ignore - } - } catch (error) { - console.error('Error making move:', error); - } finally { - setLoading(false); - } - }; - - const newGame = async (boardSize: number = 4, targetTile: number = 2048) => { - setLoading(true); - setMessage(''); - - try { - const response = await fetch(`${API_URL}/api/game/new`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ board_size: boardSize, target_tile: targetTile }), - }); - - const data = await response.json(); - setGameState(data); - } catch (error) { - console.error('Error starting new game:', error); - setMessage('Error starting new game'); - } finally { - setLoading(false); - } - }; - - // Touch/swipe handling - const [touchStart, setTouchStart] = useState<{ x: number; y: number } | null>(null); - - const handleTouchStart = (e: React.TouchEvent) => { - const touch = e.touches[0]; - setTouchStart({ x: touch.clientX, y: touch.clientY }); - }; - - const handleTouchEnd = (e: React.TouchEvent) => { - if (!touchStart) return; - - const touch = e.changedTouches[0]; - const deltaX = touch.clientX - touchStart.x; - const deltaY = touch.clientY - touchStart.y; - const minSwipeDistance = 50; - - if (Math.abs(deltaX) > Math.abs(deltaY)) { - // Horizontal swipe - if (Math.abs(deltaX) > minSwipeDistance) { - makeMove(deltaX > 0 ? 'right' : 'left'); - } - } else { - // Vertical swipe - if (Math.abs(deltaY) > minSwipeDistance) { - makeMove(deltaY > 0 ? 'down' : 'up'); - } - } - - setTouchStart(null); - }; - - if (!gameState) { - return ( -
-
Loading game...
-
- ); - } - - return ( -
-
-

2048

- - - -
- -
- -
-

Use arrow keys to play

-

Combine tiles to reach {gameState.target_tile}!

-
-
-
- ); -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/components/GameBoard.tsx b/environments/browser/environment/2048/frontend/components/GameBoard.tsx deleted file mode 100644 index d5678e41..00000000 --- a/environments/browser/environment/2048/frontend/components/GameBoard.tsx +++ /dev/null @@ -1,31 +0,0 @@ -import React from 'react'; -import GameTile from './GameTile'; - -interface GameBoardProps { - board: number[][]; -} - -export default function GameBoard({ board }: GameBoardProps) { - const boardSize = board.length; - - return ( -
-
- {board.map((row, i) => - row.map((value, j) => ( - - )) - )} -
-
- ); -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/components/GameControls.tsx b/environments/browser/environment/2048/frontend/components/GameControls.tsx deleted file mode 100644 index b89b3613..00000000 --- a/environments/browser/environment/2048/frontend/components/GameControls.tsx +++ /dev/null @@ -1,104 +0,0 @@ -import React, { useState } from 'react'; - -interface GameState { - score: number; - moves: number; - game_over: boolean; - won: boolean; - highest_tile: number; - target_tile: number; -} - -interface GameControlsProps { - gameState: GameState; - onNewGame: (boardSize: number, targetTile: number) => void; - message: string; -} - -export default function GameControls({ gameState, onNewGame, message }: GameControlsProps) { - const [targetTile, setTargetTile] = useState(gameState.target_tile); - const [boardSize, setBoardSize] = useState(4); - - const efficiency = gameState.moves > 0 - ? (gameState.score / gameState.moves).toFixed(1) - : '0.0'; - - return ( -
- {/* Score and Stats */} -
-
-
Score
-
{gameState.score}
-
-
-
Moves
-
{gameState.moves}
-
-
-
Highest
-
{gameState.highest_tile}
-
-
-
Efficiency
-
{efficiency}
-
-
- - {/* Game Controls */} -
-
-
- - -
- -
- - -
- - -
-
- - {/* Status Message */} - {message && ( -
- {message} -
- )} -
- ); -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/components/GameTile.tsx b/environments/browser/environment/2048/frontend/components/GameTile.tsx deleted file mode 100644 index e3b4bdfc..00000000 --- a/environments/browser/environment/2048/frontend/components/GameTile.tsx +++ /dev/null @@ -1,53 +0,0 @@ -import React from 'react'; - -interface GameTileProps { - value: number; - position: { row: number; col: number }; -} - -export default function GameTile({ value }: GameTileProps) { - const getTileColor = (val: number): string => { - const colors: { [key: number]: string } = { - 0: 'bg-gray-200', - 2: 'bg-yellow-100', - 4: 'bg-yellow-200', - 8: 'bg-orange-300', - 16: 'bg-orange-400', - 32: 'bg-orange-500', - 64: 'bg-red-400', - 128: 'bg-yellow-300', - 256: 'bg-yellow-400', - 512: 'bg-yellow-500', - 1024: 'bg-yellow-600', - 2048: 'bg-yellow-700', - 4096: 'bg-purple-600', - 8192: 'bg-purple-700', - }; - return colors[val] || 'bg-purple-800'; - }; - - const getTextSize = (val: number): string => { - if (val === 0) return ''; - if (val < 100) return 'text-3xl'; - if (val < 1000) return 'text-2xl'; - return 'text-xl'; - }; - - const getTextColor = (val: number): string => { - return val > 4 ? 'text-white' : 'text-gray-800'; - }; - - return ( -
- {value > 0 && value} -
- ); -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/next.config.js b/environments/browser/environment/2048/frontend/next.config.js deleted file mode 100644 index cf97dc63..00000000 --- a/environments/browser/environment/2048/frontend/next.config.js +++ /dev/null @@ -1,6 +0,0 @@ -/** @type {import('next').NextConfig} */ -const nextConfig = { - reactStrictMode: true, -} - -module.exports = nextConfig \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/package.json b/environments/browser/environment/2048/frontend/package.json deleted file mode 100644 index 7a7e412c..00000000 --- a/environments/browser/environment/2048/frontend/package.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "name": "game-2048-frontend", - "version": "1.0.0", - "private": true, - "scripts": { - "dev": "next dev", - "build": "next build", - "start": "next start", - "lint": "next lint" - }, - "dependencies": { - "next": "14.1.0", - "react": "^18", - "react-dom": "^18", - "swr": "^2.2.4" - }, - "devDependencies": { - "@types/node": "^20", - "@types/react": "^18", - "@types/react-dom": "^18", - "autoprefixer": "^10.0.1", - "eslint": "^8", - "eslint-config-next": "14.1.0", - "postcss": "^8", - "tailwindcss": "^3.3.0", - "typescript": "^5" - } -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/postcss.config.js b/environments/browser/environment/2048/frontend/postcss.config.js deleted file mode 100644 index 96bb01e7..00000000 --- a/environments/browser/environment/2048/frontend/postcss.config.js +++ /dev/null @@ -1,6 +0,0 @@ -module.exports = { - plugins: { - tailwindcss: {}, - autoprefixer: {}, - }, -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/tailwind.config.js b/environments/browser/environment/2048/frontend/tailwind.config.js deleted file mode 100644 index 47bc0bad..00000000 --- a/environments/browser/environment/2048/frontend/tailwind.config.js +++ /dev/null @@ -1,12 +0,0 @@ -/** @type {import('tailwindcss').Config} */ -module.exports = { - content: [ - './pages/**/*.{js,ts,jsx,tsx,mdx}', - './components/**/*.{js,ts,jsx,tsx,mdx}', - './app/**/*.{js,ts,jsx,tsx,mdx}', - ], - theme: { - extend: {}, - }, - plugins: [], -} \ No newline at end of file diff --git a/environments/browser/environment/2048/frontend/tsconfig.json b/environments/browser/environment/2048/frontend/tsconfig.json deleted file mode 100644 index 9b9948d5..00000000 --- a/environments/browser/environment/2048/frontend/tsconfig.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "compilerOptions": { - "target": "es5", - "lib": ["dom", "dom.iterable", "esnext"], - "allowJs": true, - "skipLibCheck": true, - "strict": true, - "noEmit": true, - "esModuleInterop": true, - "module": "esnext", - "moduleResolution": "bundler", - "resolveJsonModule": true, - "isolatedModules": true, - "jsx": "preserve", - "incremental": true, - "plugins": [ - { - "name": "next" - } - ], - "paths": { - "@/*": ["./*"] - } - }, - "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], - "exclude": ["node_modules"] -} \ No newline at end of file diff --git a/environments/browser/environment/2048/launch.py b/environments/browser/environment/2048/launch.py deleted file mode 100644 index a5645668..00000000 --- a/environments/browser/environment/2048/launch.py +++ /dev/null @@ -1,284 +0,0 @@ -#!/usr/bin/env python3 -"""2048 game launcher script.""" - -import subprocess -import time -import signal -import sys -import argparse 
-import logging -import os -import socket -from pathlib import Path -from typing import Optional - -# Configure logging to stderr to avoid stdio contamination -logging.basicConfig(level=logging.INFO, format="[%(asctime)s] 2048: %(message)s", stream=sys.stderr) - -# Global variables to track processes -frontend_process: Optional[subprocess.Popen] = None -backend_process: Optional[subprocess.Popen] = None - - -def cleanup_processes(): - """Clean up running processes.""" - global frontend_process, backend_process - logging.info("Shutting down services...") - - for proc in [frontend_process, backend_process]: - if proc and proc.poll() is None: - proc.terminate() - try: - proc.wait(timeout=5) - except subprocess.TimeoutExpired: - proc.kill() - - -def signal_handler(sig, frame): - """Handle shutdown signals.""" - cleanup_processes() - sys.exit(0) - - -def check_port_available(port: int) -> bool: - """Check if a port is available.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(1) - try: - result = sock.connect_ex(("localhost", port)) - sock.close() - return result != 0 # Port is available if connection fails - except: - return True - - -def launch_app(frontend_port: int = 3001, backend_port: int = 5001): - """Launch the 2048 game with frontend and backend.""" - global frontend_process, backend_process - - # Set up signal handlers - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - - try: - # Get current directory - app_dir = Path(__file__).parent - frontend_dir = app_dir / "frontend" - backend_dir = app_dir / "backend" - - logging.info( - f"Starting 2048 game - Frontend port: {frontend_port}, Backend port: {backend_port}" - ) - - # Check if ports are available - if not check_port_available(backend_port): - logging.warning(f"Backend port {backend_port} is already in use") - if not check_port_available(frontend_port): - logging.warning(f"Frontend port {frontend_port} is already in use") - - # Prepare backend command - backend_env = { - "PORT": str(backend_port), - "PYTHONPATH": str(backend_dir), - **dict(os.environ), - } - - # Check if we can use uv, otherwise fall back to system python - try: - subprocess.run(["uv", "--version"], check=True, capture_output=True) - backend_cmd = [ - "uv", - "run", - "uvicorn", - "main:app", - "--host", - "0.0.0.0", - "--port", - str(backend_port), - ] - logging.info("Using uv for backend") - except (subprocess.CalledProcessError, FileNotFoundError): - # Fall back to system python with uvicorn - logging.info("uv not available, using system python for backend") - backend_cmd = [ - "python3", - "-m", - "uvicorn", - "main:app", - "--host", - "0.0.0.0", - "--port", - str(backend_port), - ] - - # Prepare frontend command - frontend_env = { - "NEXT_PUBLIC_API_URL": f"http://localhost:{backend_port}", - "PORT": str(frontend_port), - **dict(os.environ), - } - - # Check if dependencies are installed - if frontend_dir.exists(): - node_modules = frontend_dir / "node_modules" - if not node_modules.exists(): - logging.info("Installing frontend dependencies...") - npm_install = subprocess.run( - ["npm", "install"], cwd=frontend_dir, capture_output=True - ) - if npm_install.returncode != 0: - logging.error( - f"Failed to install npm dependencies: {npm_install.stderr.decode()}" - ) - cleanup_processes() - raise RuntimeError("npm install failed") - - # Check if we have a production build - if (frontend_dir / ".next").exists(): - logging.info("Running in production mode (pre-built)...") - frontend_cmd = [ - 
"npm", - "run", - "start", - "--", - "--port", - str(frontend_port), - "--hostname", - "0.0.0.0", - ] - else: - logging.info("Running in development mode...") - frontend_cmd = [ - "npm", - "run", - "dev", - "--", - "--port", - str(frontend_port), - "--hostname", - "0.0.0.0", - ] - - # 🚀 START BOTH PROCESSES IN PARALLEL - logging.info("Starting backend and frontend in parallel...") - - # Start backend - UPDATE GLOBAL VARIABLE - backend_process = subprocess.Popen( - backend_cmd, - cwd=backend_dir, - env=backend_env, - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, # Don't capture stdout - reserved for MCP - stderr=subprocess.DEVNULL, # Don't capture stderr - reserved for MCP - ) - - # Start frontend immediately (in parallel) - UPDATE GLOBAL VARIABLE - if frontend_dir.exists(): - frontend_process = subprocess.Popen( - frontend_cmd, - cwd=frontend_dir, - env=frontend_env, - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, # Don't capture stdout - reserved for MCP - stderr=subprocess.DEVNULL, # Don't capture stderr - reserved for MCP - ) - - # 🚀 WAIT FOR BOTH IN PARALLEL WITH FAST POLLING - backend_ready = False - frontend_ready = False - - # Use faster polling (every 200ms instead of 1s) - max_attempts_backend = 150 # 30 seconds at 200ms intervals - max_attempts_frontend = 600 # 120 seconds at 200ms intervals - - for attempt in range(max(max_attempts_backend, max_attempts_frontend)): - # Check if processes are still alive - if backend_process and backend_process.poll() is not None: - logging.error(f"Backend process died with exit code {backend_process.returncode}") - cleanup_processes() - raise RuntimeError("Backend failed to start") - - if frontend_process and frontend_process.poll() is not None: - logging.error(f"Frontend process died with exit code {frontend_process.returncode}") - cleanup_processes() - raise RuntimeError("Frontend failed to start") - - # Check backend readiness - if not backend_ready and attempt < max_attempts_backend: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(0.1) - try: - result = sock.connect_ex(("localhost", backend_port)) - sock.close() - if result == 0: - backend_ready = True - logging.info(f"Backend is ready (attempt {attempt + 1})") - except: - pass - - # Check frontend readiness - if not frontend_ready and attempt < max_attempts_frontend: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(0.1) - try: - result = sock.connect_ex(("localhost", frontend_port)) - sock.close() - if result == 0: - frontend_ready = True - logging.info(f"Frontend is ready (attempt {attempt + 1})") - except: - pass - - # Exit early if both are ready - if backend_ready and frontend_ready: - break - - time.sleep(0.2) # 200ms intervals instead of 1s - - # Check final status - if not backend_ready: - logging.error("Backend did not start within 30 seconds") - cleanup_processes() - raise RuntimeError("Backend startup timeout") - - if not frontend_ready: - logging.error("Frontend did not start within 2 minutes") - cleanup_processes() - raise RuntimeError("Frontend startup timeout") - - # Log startup information - logging.info("2048 game started successfully!") - logging.info(f"Frontend: http://localhost:{frontend_port}") - logging.info(f"Backend API: http://localhost:{backend_port}/docs") - logging.info("Press Ctrl+C to stop") - - # Wait for processes to finish - while True: - time.sleep(1) - if backend_process and backend_process.poll() is not None: - logging.error("Backend process died unexpectedly") - break - if 
frontend_process and frontend_process.poll() is not None: - logging.error("Frontend process died unexpectedly") - break - - except Exception as e: - logging.error(f"Error launching app: {e}") - cleanup_processes() - raise - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Launch 2048 Game") - parser.add_argument("--frontend-port", type=int, default=3001, help="Frontend port") - parser.add_argument("--backend-port", type=int, default=5001, help="Backend port") - - args = parser.parse_args() - - try: - launch_app(args.frontend_port, args.backend_port) - except KeyboardInterrupt: - logging.info("App interrupted by user") - except Exception as e: - logging.error(f"Failed to launch app: {e}") - sys.exit(1) diff --git a/environments/browser/environment/README.md b/environments/browser/environment/README.md deleted file mode 100644 index 2c86019e..00000000 --- a/environments/browser/environment/README.md +++ /dev/null @@ -1,135 +0,0 @@ -# Apps Directory - -Launchable web applications for the HUD browser environment. Each app is a self-contained service that can be dynamically launched. - -## App Specification - -Each app must implement: - -### Required Files -- `launch.py` - Entry point script with standardized arguments -- `backend/` - Backend service (required) -- `frontend/` - Frontend service (optional) - -### Launch Script Interface - -```python -# launch.py -import argparse - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--frontend-port", type=int) - parser.add_argument("--backend-port", type=int, required=True) - args = parser.parse_args() - - # Start your services on the provided ports - # Backend must run on args.backend_port - # Frontend (if present) should run on args.frontend_port - -if __name__ == "__main__": - main() -``` - -### Service Requirements - -**Backend** -- Must bind to the provided `--backend-port` -- Should implement health check endpoint (`/health`) -- Must handle graceful shutdown -- Should use production-ready server (uvicorn, gunicorn, etc.) - -**Frontend** (Optional) -- Must bind to the provided `--frontend-port` -- Should be a static build or development server -- Common frameworks: Next.js, React, Vue, etc. - -## App Lifecycle - -1. **Discovery** - Apps are discovered by scanning subdirectories -2. **Launch** - Controller calls `python launch.py --backend-port=5000 --frontend-port=3000` -3. **Registration** - Ports are registered for API access -4. **Operation** - App services run independently -5. 
**Cleanup** - Processes terminated when environment shuts down - -## Integration Patterns - -### Basic Web App -```python -# Minimal FastAPI backend -from fastapi import FastAPI -import uvicorn - -app = FastAPI() - -@app.get("/health") -def health(): - return {"status": "healthy"} - -if __name__ == "__main__": - import sys - port = int(sys.argv[sys.argv.index("--backend-port") + 1]) - uvicorn.run(app, host="0.0.0.0", port=port) -``` - -### Full-Stack App -```python -# launch.py for app with both frontend and backend -import subprocess -import sys - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--frontend-port", type=int) - parser.add_argument("--backend-port", type=int, required=True) - args = parser.parse_args() - - # Start backend - backend_proc = subprocess.Popen([ - "uvicorn", "backend.main:app", - "--host", "0.0.0.0", - "--port", str(args.backend_port) - ]) - - # Start frontend (if port provided) - if args.frontend_port: - frontend_proc = subprocess.Popen([ - "npm", "run", "dev", "--", "--port", str(args.frontend_port) - ], cwd="frontend") - - # Wait for processes - try: - backend_proc.wait() - except KeyboardInterrupt: - backend_proc.terminate() - if args.frontend_port: - frontend_proc.terminate() -``` - -## Optional Integrations - -### Evaluation APIs -Apps can optionally provide evaluation endpoints for testing: -- `GET /api/eval/health` - Health check -- `GET /api/eval/stats` - Application statistics -- Additional endpoints as needed - -### Environment Access -Apps can access the browser environment through: -- Shared network (communicate with controller) -- File system (shared volumes) -- Environment variables - -## Development Guidelines - -- **Port Binding** - Always use provided ports, never hardcode -- **Health Checks** - Implement basic health endpoints -- **Logging** - Use structured logging for debugging -- **Dependencies** - Manage dependencies with lockfiles -- **Graceful Shutdown** - Handle SIGTERM properly -- **Error Handling** - Return meaningful error responses - -## Examples - -- `todo/` - Full-stack Next.js + FastAPI application with evaluation integration -- See individual app READMEs for specific implementation details \ No newline at end of file diff --git a/environments/browser/environment/__init__.py b/environments/browser/environment/__init__.py deleted file mode 100644 index 36902690..00000000 --- a/environments/browser/environment/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Browser environment server package.""" - -__version__ = "0.1.0" diff --git a/environments/browser/environment/pyproject.toml b/environments/browser/environment/pyproject.toml deleted file mode 100644 index f6f853f8..00000000 --- a/environments/browser/environment/pyproject.toml +++ /dev/null @@ -1,23 +0,0 @@ -[project] -name = "hud-browser-environment" -version = "0.1.0" -description = "HUD Browser Environment Backend" -requires-python = ">=3.11,<3.14" -dependencies = [ - "fastapi>=0.104.1", - "uvicorn[standard]>=0.24.0", - "python-multipart>=0.0.6", - "pydantic>=2.6,<3", - "pydantic-settings>=2.2,<3", - "httpx", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.metadata] -allow-direct-references = true - -[tool.hatch.build.targets.wheel] -packages = ["environment"] diff --git a/environments/browser/environment/server.py b/environments/browser/environment/server.py deleted file mode 100644 index bd1297c7..00000000 --- a/environments/browser/environment/server.py +++ /dev/null @@ -1,503 +0,0 @@ -""" -FastAPI 
server for browser environment. -Exposes API endpoints to interact with the environment and its subcomponents. -""" - -import asyncio -import subprocess -import os -import logging -from pathlib import Path -from typing import Optional, Dict, List, Any, Set -import socket -from contextlib import asynccontextmanager -import shutil -import httpx - -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel - -# Configure logging -logging.basicConfig( - level=logging.INFO, format="[%(levelname)s] %(asctime)s | %(name)s | %(message)s" -) -logger = logging.getLogger(__name__) - - -class AppInfo(BaseModel): - """Information about a launched app.""" - - name: str - frontend_port: int - backend_port: int - url: str - status: str - - -class ServiceStatus(BaseModel): - """Status of environment services.""" - - x11: bool - vnc: bool - websockify: bool - apps: List[AppInfo] - - -class LaunchAppRequest(BaseModel): - """Request to launch an app.""" - - app_name: str - - -class LaunchAppResponse(BaseModel): - """Response after launching an app.""" - - name: str - url: str - frontend_port: int - backend_port: int - - -class ServiceManager: - """Manages environment services (X11, VNC, apps).""" - - def __init__(self): - self.x11_proc: Optional[subprocess.Popen] = None - self.vnc_proc: Optional[subprocess.Popen] = None - self.websockify_proc: Optional[subprocess.Popen] = None - self.chrome_proc: Optional[subprocess.Popen] = None - self.cdp_port: Optional[int] = None - self._launched_apps: Dict[str, AppInfo] = {} - self._playwright = None - self._browser = None - self._app_processes: Dict[str, subprocess.Popen] = {} - self._allocated_ports: Set[int] = set() - - async def start_core_services(self): - """Start X11, VNC, and websockify services.""" - # Check if X11 is already running - if Path("/tmp/.X11-unix/X1").exists(): - logger.info("X11 display :1 already running") - else: - # Start Xvfb if not already running - self.x11_proc = subprocess.Popen( - ["Xvfb", ":1", "-screen", "0", "1920x1080x24"], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.PIPE, - ) - logger.info("Started Xvfb on display :1") - - # Wait for X11 - await self._wait_for_x11() - - # Start VNC and websockify - await self._start_vnc_services() - - async def _wait_for_x11(self): - """Wait for X11 display to be ready.""" - for i in range(100): # 10 seconds max - if Path("/tmp/.X11-unix/X1").exists(): - logger.info("X11 display :1 is ready") - os.environ["DISPLAY"] = ":1" - return - await asyncio.sleep(0.1) - raise TimeoutError("X11 failed to start") - - async def _start_vnc_services(self): - """Start VNC and websockify services.""" - # Start x11vnc - self.vnc_proc = subprocess.Popen( - ["x11vnc", "-display", ":1", "-forever", "-shared", "-nopw"], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.PIPE, - env={**os.environ, "DISPLAY": ":1"}, - ) - logger.info("Started x11vnc") - - # Start websockify - self.websockify_proc = subprocess.Popen( - ["websockify", "--web", "/usr/share/novnc", "8080", "localhost:5900"], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.PIPE, - ) - logger.info("Started websockify on port 8080") - - # Wait for both services - await asyncio.gather( - self._wait_for_port(5900, "VNC"), self._wait_for_port(8080, "websockify") - ) - logger.info("noVNC available at: http://localhost:8080/vnc.html") - - # Start Playwright's Chromium browser - logger.info("Starting Playwright's Chromium browser") - try: - from playwright.async_api 
import async_playwright - - self._playwright = await async_playwright().start() - # Get a free port for CDP - self.cdp_port = self._get_next_port() - - self._browser = await self._playwright.chromium.launch( - headless=False, - args=[ - f"--remote-debugging-port={self.cdp_port}", - "--no-sandbox", - "--disable-dev-shm-usage", - "--disable-gpu", - "--disable-web-security", - "--disable-features=IsolateOrigins,site-per-process", - "--display=:1", - "--start-maximized", - ], - env={**os.environ, "DISPLAY": ":1"}, - ) - - logger.info(f"Started Playwright Chromium with CDP on port {self.cdp_port}") - - # Wait for CDP to be ready - await self._wait_for_port(self.cdp_port, "CDP", timeout=30) - - # Open a default page so the browser window is visible - default_context = await self._browser.new_context( - viewport={"width": 1920, "height": 1080}, no_viewport=False - ) - default_page = await default_context.new_page() - await default_page.goto("about:blank") - logger.info("Opened default browser page") - - except ImportError: - logger.error("Playwright not installed") - raise RuntimeError("Playwright is required. The Docker image should have installed it.") - except Exception as e: - logger.error(f"Failed to start Playwright browser: {e}") - raise - - async def launch_app(self, app_name: str) -> LaunchAppResponse: - """Launch a specific app dynamically.""" - # Check if app is already running - if app_name in self._launched_apps: - app_info = self._launched_apps[app_name] - if app_info.status == "running": - return LaunchAppResponse( - name=app_info.name, - url=app_info.url, - frontend_port=app_info.frontend_port, - backend_port=app_info.backend_port, - ) - - app_path = Path(f"/app/environment/{app_name}") - if not app_path.exists(): - raise ValueError(f"App '{app_name}' not found at {app_path}") - - # Check if app has a launch script - launch_script = app_path / "launch.py" - if not launch_script.exists(): - raise ValueError(f"App '{app_name}' missing launch.py") - - # Get unique ports for frontend and backend - frontend_port = self._get_next_port() - backend_port = self._get_next_port() - - # Launch the app - proc = subprocess.Popen( - [ - "python3", - str(launch_script), - "--frontend-port", - str(frontend_port), - "--backend-port", - str(backend_port), - ], - cwd=app_path, - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - env={**os.environ, "DISPLAY": ":1"}, - ) - - self._app_processes[app_name] = proc - - try: - # Wait for both ports - await asyncio.gather( - self._wait_for_port(frontend_port, f"app '{app_name}' frontend", timeout=60), - self._wait_for_port(backend_port, f"app '{app_name}' backend", timeout=60), - ) - - logger.info( - f"Launched app '{app_name}' - Frontend: {frontend_port}, Backend: {backend_port}" - ) - - # Store app information - app_info = AppInfo( - name=app_name, - frontend_port=frontend_port, - backend_port=backend_port, - url=f"http://localhost:{frontend_port}", - status="running", - ) - self._launched_apps[app_name] = app_info - - return LaunchAppResponse( - name=app_name, - url=app_info.url, - frontend_port=frontend_port, - backend_port=backend_port, - ) - - except TimeoutError: - # Check if process is still running - if proc.poll() is not None: - logger.error(f"App '{app_name}' process exited with code {proc.returncode}") - else: - logger.error(f"App '{app_name}' failed to become ready within timeout") - raise - - def get_service_status(self) -> ServiceStatus: - """Get status of all services.""" - # Update app statuses - for 
app_name, proc in self._app_processes.items(): - if app_name in self._launched_apps: - if proc.poll() is None: - self._launched_apps[app_name].status = "running" - else: - self._launched_apps[app_name].status = "stopped" - - return ServiceStatus( - x11=self.x11_proc is not None and self.x11_proc.poll() is None - if self.x11_proc - else Path("/tmp/.X11-unix/X1").exists(), - vnc=self.vnc_proc is not None and self.vnc_proc.poll() is None - if self.vnc_proc - else self._is_port_open(5900), - websockify=self.websockify_proc is not None and self.websockify_proc.poll() is None - if self.websockify_proc - else self._is_port_open(8080), - apps=list(self._launched_apps.values()), - ) - - def get_app_info(self, app_name: str) -> AppInfo: - """Get information about a specific app.""" - if app_name not in self._launched_apps: - raise ValueError(f"App '{app_name}' not found") - return self._launched_apps[app_name] - - async def shutdown(self): - """Shutdown all services gracefully.""" - # Stop app processes - for name, proc in self._app_processes.items(): - if proc.poll() is None: - proc.terminate() - await asyncio.sleep(1) - if proc.poll() is None: - proc.kill() - logger.info(f"Terminated app '{name}'") - - # Clear app tracking - self._app_processes.clear() - self._launched_apps.clear() - self._allocated_ports.clear() - - # Close Playwright browser - if self._browser: - try: - await self._browser.close() - logger.info("Closed Playwright browser") - except Exception as e: - logger.error(f"Error closing browser: {e}") - - if self._playwright: - try: - await self._playwright.stop() - logger.info("Stopped Playwright") - except Exception as e: - logger.error(f"Error stopping playwright: {e}") - - # Stop services in reverse order - for proc, name in [ - (self.websockify_proc, "websockify"), - (self.vnc_proc, "x11vnc"), - (self.x11_proc, "Xvfb"), - ]: - if proc and proc.poll() is None: - proc.terminate() - await asyncio.sleep(0.5) - if proc.poll() is None: - proc.kill() - logger.info(f"Stopped {name}") - - def _is_port_open(self, port: int) -> bool: - """Check if a port is open.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(0.1) - try: - result = sock.connect_ex(("localhost", port)) - sock.close() - return result == 0 - except: - return False - - def _get_next_port(self) -> int: - """Get next available port for apps.""" - base_port = 3000 - for offset in range(200): # Support up to 200 ports - port = base_port + offset - if not self._is_port_open(port) and port not in self._allocated_ports: - self._allocated_ports.add(port) - return port - raise RuntimeError("No available ports") - - async def _wait_for_port(self, port: int, service_name: str = "service", timeout: int = 30): - """Wait for a port to become available.""" - for _ in range(timeout * 5): # Check every 200ms - if self._is_port_open(port): - logger.info(f"{service_name} is ready on port {port}") - return - await asyncio.sleep(0.2) - raise TimeoutError(f"Port {port} did not become available for {service_name}") - - async def get_cdp_websocket_url(self) -> str | None: - """Discover the actual CDP WebSocket URL from Chrome's /json/version endpoint.""" - if not self.cdp_port: - return None - - try: - async with httpx.AsyncClient() as client: - response = await client.get( - f"http://localhost:{self.cdp_port}/json/version", timeout=5.0 - ) - if response.status_code == 200: - data = response.json() - # Chrome returns webSocketDebuggerUrl in /json/version response - websocket_url = data.get("webSocketDebuggerUrl") - if 
websocket_url: - return websocket_url - - # Fallback: try /json/list to find a browser target - response = await client.get( - f"http://localhost:{self.cdp_port}/json/list", timeout=5.0 - ) - if response.status_code == 200: - targets = response.json() - # Look for a browser target (type 'page' or title containing 'about:blank') - for target in targets: - if target.get("type") == "page" or "about:blank" in target.get("url", ""): - websocket_url = target.get("webSocketDebuggerUrl") - if websocket_url: - return websocket_url - - except Exception as e: - logger.warning(f"Failed to discover CDP WebSocket URL: {e}") - - # Final fallback to generic path (may not work) - return f"ws://localhost:{self.cdp_port}/devtools/browser" - - -# Global service manager instance -service_manager = ServiceManager() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Manage application lifecycle.""" - # Startup - logger.info("Starting browser environment server...") - await service_manager.start_core_services() - logger.info("Browser environment server ready") - - yield - - # Shutdown - logger.info("Shutting down browser environment server...") - await service_manager.shutdown() - - -# Create FastAPI app -app = FastAPI( - title="Browser Environment API", - description="API for managing browser environment services and applications", - version="1.0.0", - lifespan=lifespan, -) - - -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - return {"status": "healthy"} - - -@app.get("/status", response_model=ServiceStatus) -async def get_status(): - """Get status of all environment services.""" - return service_manager.get_service_status() - - -@app.post("/apps/launch", response_model=LaunchAppResponse) -async def launch_app(request: LaunchAppRequest): - """Launch a specific application.""" - try: - return await service_manager.launch_app(request.app_name) - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/apps/{app_name}", response_model=AppInfo) -async def get_app_info(app_name: str): - """Get information about a specific app.""" - try: - return service_manager.get_app_info(app_name) - except ValueError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@app.get("/vnc/url") -async def get_vnc_url(): - """Get the VNC viewer URL.""" - return {"url": "http://localhost:8080/vnc.html"} - - -@app.get("/display") -async def get_display(): - """Get the X11 display information.""" - return { - "display": os.environ.get("DISPLAY", ":1"), - "x11_running": Path("/tmp/.X11-unix/X1").exists(), - } - - -@app.get("/cdp") -async def get_cdp(): - """Return the CDP websocket URL for connecting Playwright/Chromium clients.""" - if service_manager.cdp_port is None: - raise HTTPException(status_code=503, detail="CDP not available") - - # Discover the actual CDP WebSocket URL from Chrome - websocket_url = await service_manager.get_cdp_websocket_url() - if not websocket_url: - raise HTTPException(status_code=503, detail="CDP WebSocket URL not available") - - return {"ws": websocket_url} - - -@app.post("/shutdown") -async def shutdown_env(): - """Gracefully stop services and request server shutdown.""" - try: - await service_manager.shutdown() - except Exception as e: - logger.warning(f"Error during environment shutdown: {e}") - # Signal uvicorn to exit via lifespan shutdown - # FastAPI/uvicorn doesn't expose server here; we rely on process signal from caller. 
- return {"status": "shutting_down"} - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/environments/browser/environment/todo/README.md b/environments/browser/environment/todo/README.md deleted file mode 100644 index 7d2460e9..00000000 --- a/environments/browser/environment/todo/README.md +++ /dev/null @@ -1,85 +0,0 @@ -# Todo App - -Simple todo list application with Next.js frontend and FastAPI backend, fully integrated with the HUD evaluation system. - -## Tech Stack - -- **Frontend**: Next.js, TypeScript, Tailwind CSS -- **Backend**: FastAPI, SQLite, uv for dependency management -- **Evaluation**: Comprehensive API endpoints for testing - -## Development - -```bash -# Backend -cd backend && uv run uvicorn main:app --reload - -# Frontend -cd frontend && npm install && npm run dev -``` - -## Launching - -```python -await client.call_tool("launch_app", {"app_name": "todo"}) -``` - -## Evaluation Integration - -### Backend API Endpoints -- `GET /api/eval/health` - Health check -- `GET /api/eval/stats` - Comprehensive statistics -- `GET /api/eval/has_todo?text=` - Check if todo exists -- `GET /api/eval/completion_rate` - Completion percentage -- `POST /api/eval/seed` - Seed test data -- `DELETE /api/eval/reset` - Reset database - -### Controller Components -- **Evaluators**: `TodoCompletedEvaluator`, `TodoExistsEvaluator`, `CompositeEvaluator` -- **Setup Tools**: `TodoSeedSetup`, `TodoResetSetup`, `TodoCustomSeedSetup` -- **Problems**: `TodoBasicUsageProblem`, `TodoCompositeWeightedProblem` - -### Usage Examples - -```python -# Complete problem execution -await setup({"name": "todo_basic_usage"}) -await evaluate({"name": "todo_basic_usage"}) - -# Direct function calls -await setup({"name": "todo_reset", "arguments": {}}) -await evaluate({"name": "todo_completion_rate", "arguments": {"min_rate": 0.5}}) - -# MCP resource discovery -todo_evaluators = await client.read_resource("evaluators://todo") -``` - -## Database Schema - -```sql -CREATE TABLE items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - title TEXT NOT NULL, - description TEXT, - completed BOOLEAN DEFAULT FALSE, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -); -``` - -## Testing - -### Manual -1. Launch app: `await launch_app("todo")` -2. Access at http://localhost:3000 -3. 
Run evaluations - -### Automated -```bash -# Test APIs -curl http://localhost:5000/api/eval/health -curl http://localhost:5000/api/eval/stats - -# Test MCP tools -await setup({"name": "todo_basic_usage"}) -await evaluate({"name": "todo_basic_usage"}) -``` \ No newline at end of file diff --git a/environments/browser/environment/todo/backend/main.py b/environments/browser/environment/todo/backend/main.py deleted file mode 100644 index 5839fa85..00000000 --- a/environments/browser/environment/todo/backend/main.py +++ /dev/null @@ -1,391 +0,0 @@ -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from typing import List, Optional -from datetime import datetime -import sqlite3 -import json - -app = FastAPI(title="Todo API with Evaluation", version="0.2.0") - -# Configure CORS -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:3000"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -# Pydantic models -class Item(BaseModel): - id: Optional[int] = None - title: str - description: str - completed: bool = False - created_at: Optional[datetime] = None - - -class ItemCreate(BaseModel): - title: str - description: str - completed: bool = False - - -class BulkUpdateRequest(BaseModel): - item_ids: List[int] - completed: Optional[bool] = None - - -class EvaluationStats(BaseModel): - total_items: int - completed_items: int - pending_items: int - completion_rate: float - items: List[Item] - timestamps: dict - - -# Database setup -def init_db(): - conn = sqlite3.connect("app.db") - c = conn.cursor() - c.execute(""" - CREATE TABLE IF NOT EXISTS items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - title TEXT NOT NULL, - description TEXT, - completed BOOLEAN NOT NULL DEFAULT 0, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - conn.commit() - conn.close() - - -init_db() - - -# === CORE TODO API ROUTES === - - -@app.get("/api/status") -def status(): - return {"status": "ok", "timestamp": datetime.now().isoformat()} - - -@app.get("/api/items", response_model=List[Item]) -def get_items(): - conn = sqlite3.connect("app.db") - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM items ORDER BY created_at DESC") - items = [dict(row) for row in c.fetchall()] - conn.close() - return items - - -@app.post("/api/items", response_model=Item) -def create_item(item: ItemCreate): - conn = sqlite3.connect("app.db") - c = conn.cursor() - c.execute( - "INSERT INTO items (title, description, completed) VALUES (?, ?, ?)", - (item.title, item.description, item.completed), - ) - item_id = c.lastrowid - conn.commit() - conn.close() - - return get_item(item_id) - - -@app.get("/api/items/{item_id}", response_model=Item) -def get_item(item_id: int): - conn = sqlite3.connect("app.db") - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute("SELECT * FROM items WHERE id = ?", (item_id,)) - item = c.fetchone() - conn.close() - - if not item: - raise HTTPException(status_code=404, detail="Item not found") - - return dict(item) - - -@app.put("/api/items/{item_id}", response_model=Item) -def update_item(item_id: int, item: ItemCreate): - conn = sqlite3.connect("app.db") - c = conn.cursor() - c.execute( - "UPDATE items SET title = ?, description = ?, completed = ? 
WHERE id = ?", - (item.title, item.description, item.completed, item_id), - ) - conn.commit() - - if c.rowcount == 0: - conn.close() - raise HTTPException(status_code=404, detail="Item not found") - - conn.close() - return get_item(item_id) - - -@app.delete("/api/items/{item_id}") -def delete_item(item_id: int): - conn = sqlite3.connect("app.db") - c = conn.cursor() - c.execute("DELETE FROM items WHERE id = ?", (item_id,)) - conn.commit() - - if c.rowcount == 0: - conn.close() - raise HTTPException(status_code=404, detail="Item not found") - - conn.close() - return {"message": "Item deleted successfully"} - - -# === EVALUATION API ROUTES === - - -@app.get("/api/eval/health") -def eval_health(): - """Health check endpoint for evaluation system.""" - try: - conn = sqlite3.connect("app.db") - c = conn.cursor() - c.execute("SELECT COUNT(*) FROM items") - count = c.fetchone()[0] - conn.close() - - return { - "status": "healthy", - "database_accessible": True, - "total_items": count, - "timestamp": datetime.now().isoformat(), - } - except Exception as e: - return {"status": "unhealthy", "error": str(e), "timestamp": datetime.now().isoformat()} - - -@app.get("/api/eval/stats", response_model=EvaluationStats) -def get_evaluation_stats(): - """Comprehensive evaluation statistics for the todo app.""" - conn = sqlite3.connect("app.db") - conn.row_factory = sqlite3.Row - c = conn.cursor() - - # Get total counts - c.execute("SELECT COUNT(*) as total FROM items") - total = c.fetchone()[0] - - c.execute("SELECT COUNT(*) as completed FROM items WHERE completed = 1") - completed = c.fetchone()[0] - - # Get all items with details - c.execute("SELECT * FROM items ORDER BY created_at DESC") - items = [dict(row) for row in c.fetchall()] - - # Get timing information - c.execute(""" - SELECT created_at - FROM items - ORDER BY created_at DESC - LIMIT 1 - """) - last_created_row = c.fetchone() - last_created = last_created_row[0] if last_created_row else None - - c.execute(""" - SELECT created_at - FROM items - WHERE completed = 1 - ORDER BY created_at DESC - LIMIT 1 - """) - last_completed_row = c.fetchone() - last_completed = last_completed_row[0] if last_completed_row else None - - conn.close() - - return EvaluationStats( - total_items=total, - completed_items=completed, - pending_items=total - completed, - completion_rate=completed / total if total > 0 else 0.0, - items=items, - timestamps={"last_created": last_created, "last_completed": last_completed}, - ) - - -@app.get("/api/eval/todos", response_model=List[Item]) -def get_todos_for_evaluation(): - """Get all todos for evaluation purposes (alias for /api/items).""" - return get_items() - - -@app.get("/api/eval/has_todo") -def check_todo_exists(text: str): - """Check if a todo item exists with specific text in title or description.""" - conn = sqlite3.connect("app.db") - conn.row_factory = sqlite3.Row - c = conn.cursor() - c.execute( - """ - SELECT * FROM items - WHERE title LIKE ? OR description LIKE ? 
- ORDER BY created_at DESC - """, - (f"%{text}%", f"%{text}%"), - ) - - items = [dict(row) for row in c.fetchall()] - conn.close() - - return { - "exists": len(items) > 0, - "count": len(items), - "search_text": text, - "matches": items, - "timestamp": datetime.now().isoformat(), - } - - -@app.post("/api/eval/bulk_update") -def bulk_update_items(request: BulkUpdateRequest): - """Update multiple items at once for evaluation purposes.""" - conn = sqlite3.connect("app.db") - c = conn.cursor() - - updated_count = 0 - if request.completed is not None: - for item_id in request.item_ids: - c.execute("UPDATE items SET completed = ? WHERE id = ?", (request.completed, item_id)) - if c.rowcount > 0: - updated_count += 1 - - conn.commit() - conn.close() - - return { - "message": f"Updated {updated_count} items", - "updated_count": updated_count, - "requested_ids": request.item_ids, - "timestamp": datetime.now().isoformat(), - } - - -@app.get("/api/eval/completion_rate") -def get_completion_rate(): - """Get the current completion rate as a percentage.""" - conn = sqlite3.connect("app.db") - c = conn.cursor() - - c.execute("SELECT COUNT(*) as total FROM items") - total = c.fetchone()[0] - - c.execute("SELECT COUNT(*) as completed FROM items WHERE completed = 1") - completed = c.fetchone()[0] - - conn.close() - - rate = completed / total if total > 0 else 0.0 - - return { - "completion_rate": rate, - "completion_percentage": rate * 100, - "completed_items": completed, - "total_items": total, - "timestamp": datetime.now().isoformat(), - } - - -# === EVALUATION UTILITY ROUTES === - - -@app.post("/api/eval/seed") -def seed_test_data(): - """Seed the database with test data for evaluation purposes.""" - test_items = [ - {"title": "Buy groceries", "description": "Get milk, eggs, and bread", "completed": True}, - { - "title": "Walk the dog", - "description": "Take Max for a 30-minute walk", - "completed": True, - }, - { - "title": "Finish project", - "description": "Complete the Q4 presentation", - "completed": False, - }, - {"title": "Call mom", "description": "Weekly check-in call", "completed": False}, - { - "title": "Schedule dentist", - "description": "Book appointment for cleaning", - "completed": False, - }, - ] - - conn = sqlite3.connect("app.db") - c = conn.cursor() - - for item in test_items: - c.execute( - """ - INSERT INTO items (title, description, completed) - VALUES (?, ?, ?) - """, - (item["title"], item["description"], item["completed"]), - ) - - conn.commit() - conn.close() - - return { - "message": "Test data seeded successfully", - "items_added": len(test_items), - "timestamp": datetime.now().isoformat(), - } - - -@app.post("/api/eval/seed_custom") -def seed_custom_data(items: List[ItemCreate]): - """Seed the database with custom test data for evaluation purposes.""" - conn = sqlite3.connect("app.db") - c = conn.cursor() - - items_added = 0 - for item in items: - c.execute( - """ - INSERT INTO items (title, description, completed) - VALUES (?, ?, ?) 
- """, - (item.title, item.description if hasattr(item, "description") else "", item.completed), - ) - items_added += 1 - - conn.commit() - conn.close() - - return { - "message": "Custom test data seeded successfully", - "items_added": items_added, - "timestamp": datetime.now().isoformat(), - } - - -@app.delete("/api/eval/reset") -def reset_database(): - """Reset the database to empty state for clean evaluation.""" - conn = sqlite3.connect("app.db") - c = conn.cursor() - c.execute("DELETE FROM items") - conn.commit() - conn.close() - - return {"message": "Database reset successfully", "timestamp": datetime.now().isoformat()} diff --git a/environments/browser/environment/todo/backend/pyproject.toml b/environments/browser/environment/todo/backend/pyproject.toml deleted file mode 100644 index 493627d5..00000000 --- a/environments/browser/environment/todo/backend/pyproject.toml +++ /dev/null @@ -1,15 +0,0 @@ -[project] -name = "sample-backend" -version = "0.1.0" -description = "FastAPI backend for sample app" -requires-python = ">=3.10" -dependencies = [ - "fastapi==0.109.0", - "uvicorn[standard]==0.27.0", - "sqlalchemy==2.0.25", - "pydantic==2.5.3", - "python-multipart==0.0.6", -] - -[tool.uv] -dev-dependencies = [] \ No newline at end of file diff --git a/environments/browser/environment/todo/frontend/app/globals.css b/environments/browser/environment/todo/frontend/app/globals.css deleted file mode 100644 index de4d11a2..00000000 --- a/environments/browser/environment/todo/frontend/app/globals.css +++ /dev/null @@ -1,3 +0,0 @@ -@tailwind base; -@tailwind components; -@tailwind utilities; \ No newline at end of file diff --git a/environments/browser/environment/todo/frontend/app/layout.tsx b/environments/browser/environment/todo/frontend/app/layout.tsx deleted file mode 100644 index 0acab9a4..00000000 --- a/environments/browser/environment/todo/frontend/app/layout.tsx +++ /dev/null @@ -1,22 +0,0 @@ -import type { Metadata } from 'next' -import { Inter } from 'next/font/google' -import './globals.css' - -const inter = Inter({ subsets: ['latin'] }) - -export const metadata: Metadata = { - title: 'Sample App', - description: 'A sample Next.js app with FastAPI backend', -} - -export default function RootLayout({ - children, -}: { - children: React.ReactNode -}) { - return ( - - {children} - - ) -} \ No newline at end of file diff --git a/environments/browser/environment/todo/frontend/app/page.tsx b/environments/browser/environment/todo/frontend/app/page.tsx deleted file mode 100644 index c5de6422..00000000 --- a/environments/browser/environment/todo/frontend/app/page.tsx +++ /dev/null @@ -1,289 +0,0 @@ -'use client' - -import { useState, useEffect } from 'react' - -interface Item { - id: number - title: string - description: string - completed: boolean - created_at: string -} - -type FilterType = 'all' | 'active' | 'completed' - -// Dynamically determine API URL based on current port -// Backend is always on frontend_port + 1 -const getApiUrl = () => { - if (typeof window !== 'undefined') { - const currentPort = parseInt(window.location.port) || 3000; - return `http://localhost:${currentPort + 1}`; - } - return process.env.NEXT_PUBLIC_API_URL || 'http://localhost:5000'; -}; - -const API_URL = getApiUrl(); - -export default function Home() { - const [items, setItems] = useState([]) - const [newTitle, setNewTitle] = useState('') - const [newDescription, setNewDescription] = useState('') - const [loading, setLoading] = useState(true) - const [filter, setFilter] = useState('all') - const [searchTerm, 
setSearchTerm] = useState('') - - useEffect(() => { - fetchItems() - }, []) - - const fetchItems = async () => { - try { - const response = await fetch(`${API_URL}/api/items`) - const data = await response.json() - setItems(data) - } catch (error) { - console.error('Error fetching items:', error) - } finally { - setLoading(false) - } - } - - const createItem = async (e: React.FormEvent) => { - e.preventDefault() - if (!newTitle.trim()) return - - try { - const response = await fetch(`${API_URL}/api/items`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - title: newTitle, - description: newDescription, - completed: false - }) - }) - - if (response.ok) { - setNewTitle('') - setNewDescription('') - fetchItems() - } - } catch (error) { - console.error('Error creating item:', error) - } - } - - const toggleItem = async (id: number, item: Item) => { - try { - const response = await fetch(`${API_URL}/api/items/${id}`, { - method: 'PUT', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - ...item, - completed: !item.completed - }) - }) - - if (response.ok) { - fetchItems() - } - } catch (error) { - console.error('Error updating item:', error) - } - } - - const deleteItem = async (id: number) => { - try { - const response = await fetch(`${API_URL}/api/items/${id}`, { - method: 'DELETE' - }) - - if (response.ok) { - fetchItems() - } - } catch (error) { - console.error('Error deleting item:', error) - } - } - - const markAllComplete = async () => { - const activeItems = items.filter(item => !item.completed) - for (const item of activeItems) { - await toggleItem(item.id, item) - } - } - - const deleteCompleted = async () => { - const completedItems = items.filter(item => item.completed) - for (const item of completedItems) { - await deleteItem(item.id) - } - } - - // Filter and search logic - const filteredItems = items - .filter(item => { - if (filter === 'active') return !item.completed - if (filter === 'completed') return item.completed - return true - }) - .filter(item => { - if (!searchTerm) return true - const term = searchTerm.toLowerCase() - return item.title.toLowerCase().includes(term) || - item.description.toLowerCase().includes(term) - }) - - const stats = { - total: items.length, - active: items.filter(i => !i.completed).length, - completed: items.filter(i => i.completed).length - } - - return ( -
-
-
Todo App
- - {/* Stats Bar */} -
-
-
- - Total: {stats.total} - - - Active: {stats.active} - - - Completed: {stats.completed} - -
-
- - -
-
-
- - {/* Add Item Form */} -
-
Add New Item
-
- setNewTitle(e.target.value)} - className="w-full px-4 py-2 border border-gray-300 rounded-md focus:outline-none focus:ring-2 focus:ring-blue-500" - /> -