Skip to content

Commit 8dddac7

Browse files
committed
Add Gemini sub pr from hud-evals#169
1 parent 9cd353b commit 8dddac7

File tree

10 files changed

+1248
-4
lines changed

10 files changed

+1248
-4
lines changed

hud/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
from .base import MCPAgent
44
from .claude import ClaudeAgent
5+
from .gemini import GeminiAgent
56
from .openai import OperatorAgent
67
from .openai_chat_generic import GenericOpenAIChatAgent
78

89
__all__ = [
910
"ClaudeAgent",
11+
"GeminiAgent",
1012
"GenericOpenAIChatAgent",
1113
"MCPAgent",
1214
"OperatorAgent",

hud/agents/gemini.py

Lines changed: 485 additions & 0 deletions
Large diffs are not rendered by default.

hud/agents/tests/test_gemini.py

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
"""Tests for Gemini MCP Agent implementation."""
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
from unittest.mock import AsyncMock, MagicMock, patch
7+
8+
import pytest
9+
from google.genai import types as genai_types
10+
from mcp import types
11+
12+
from hud.agents.gemini import GeminiAgent
13+
from hud.types import MCPToolCall, MCPToolResult
14+
15+
16+
class TestGeminiAgent:
    """Unit tests for ``GeminiAgent`` driven entirely by mocked MCP and Gemini clients."""

    @staticmethod
    def _make_agent(mcp_client, model_client, **extra):
        """Build a ``GeminiAgent`` around the given clients with API-key validation off.

        Any additional keyword arguments are forwarded unchanged to the constructor.
        """
        return GeminiAgent(
            mcp_client=mcp_client,
            model_client=model_client,
            validate_api_key=False,  # Skip validation in tests
            **extra,
        )

    @staticmethod
    def _stub_generate_content(model_client, parts=None):
        """Make ``model_client.models.generate_content`` return a canned response.

        When ``parts`` is given, the response holds a single candidate whose content
        carries those parts. When it is omitted, the response has no candidates.
        """
        if parts is None:
            candidates = []
        else:
            candidates = [MagicMock(content=MagicMock(parts=parts))]
        reply = MagicMock(candidates=candidates)
        model_client.models = MagicMock(generate_content=MagicMock(return_value=reply))

    @pytest.fixture
    def mock_mcp_client(self):
        """Provide an async MCP client mock that lists one ``gemini_computer`` tool."""
        computer_tool = types.Tool(
            name="gemini_computer",
            description="Gemini computer use tool",
            inputSchema={},
        )
        client = AsyncMock()
        # mcp_config must be a plain dict rather than an auto-generated mock attribute.
        client.mcp_config = {"test_server": {"url": "http://test"}}
        client.list_tools = AsyncMock(return_value=[computer_tool])
        client.initialize = AsyncMock()
        return client

    @pytest.fixture
    def mock_gemini_client(self):
        """Provide a Gemini client mock whose ``models.list()`` yields nothing."""
        models = MagicMock()
        # models.list is mocked for validation.
        models.list = MagicMock(return_value=iter([]))
        return MagicMock(api_key="test_key", models=models)

    @pytest.mark.asyncio
    async def test_init(self, mock_mcp_client, mock_gemini_client):
        """The constructor records the requested model and the injected client."""
        requested_model = "gemini-2.5-computer-use-preview-10-2025"
        agent = self._make_agent(mock_mcp_client, mock_gemini_client, model=requested_model)

        assert agent.model_name == requested_model
        assert agent.model == requested_model
        assert agent.gemini_client == mock_gemini_client

    @pytest.mark.asyncio
    async def test_init_without_model_client(self, mock_mcp_client):
        """Without an injected client, the agent still ends up with a Gemini client."""
        key_patch = patch("hud.settings.settings.gemini_api_key", "test_key")
        client_patch = patch("hud.agents.gemini.genai.Client")
        with key_patch, client_patch as client_factory:
            stub_models = MagicMock()
            stub_models.list = MagicMock(return_value=iter([]))
            client_factory.return_value = MagicMock(api_key="test_key", models=stub_models)

            agent = GeminiAgent(
                mcp_client=mock_mcp_client,
                model="gemini-2.5-computer-use-preview-10-2025",
                validate_api_key=False,
            )

            assert agent.model_name == "gemini-2.5-computer-use-preview-10-2025"
            assert agent.gemini_client is not None

    @pytest.mark.asyncio
    async def test_format_blocks(self, mock_mcp_client, mock_gemini_client):
        """Content blocks are folded into a single user message, one part per block."""
        agent = self._make_agent(mock_mcp_client, mock_gemini_client)

        # Case 1: a lone text block.
        text_only: list[types.ContentBlock] = [
            types.TextContent(type="text", text="Hello, Gemini!"),
        ]
        formatted = await agent.format_blocks(text_only)
        assert len(formatted) == 1
        only = formatted[0]
        assert only.role == "user"
        assert len(only.parts) == 1
        assert only.parts[0].text == "Hello, Gemini!"

        # Case 2: text followed by a base64-encoded screenshot.
        encoded_image = base64.b64encode(b"fakeimage").decode("utf-8")
        with_image: list[types.ContentBlock] = [
            types.TextContent(type="text", text="Look at this"),
            types.ImageContent(type="image", data=encoded_image, mimeType="image/png"),
        ]
        formatted = await agent.format_blocks(with_image)
        assert len(formatted) == 1
        only = formatted[0]
        assert only.role == "user"
        assert len(only.parts) == 2
        text_part, image_part = only.parts
        assert text_part.text == "Look at this"
        # The image must arrive as inline data.
        assert image_part.inline_data is not None

    @pytest.mark.asyncio
    async def test_format_tool_results(self, mock_mcp_client, mock_gemini_client):
        """A successful tool result becomes one function-response part in a user message."""
        agent = self._make_agent(mock_mcp_client, mock_gemini_client)

        call = MCPToolCall(
            name="gemini_computer",
            arguments={"action": "click_at", "x": 100, "y": 200},
            id="call_1",  # type: ignore
            gemini_name="click_at",  # type: ignore
        )
        result = MCPToolResult(
            content=[
                types.TextContent(type="text", text="Clicked successfully"),
                types.ImageContent(
                    type="image",
                    data=base64.b64encode(b"screenshot").decode("utf-8"),
                    mimeType="image/png",
                ),
            ],
            isError=False,
        )

        formatted = await agent.format_tool_results([call], [result])

        # Expect exactly one user message holding exactly one function response.
        assert len(formatted) == 1
        assert formatted[0].role == "user"
        assert len(formatted[0].parts) == 1
        fn_response = formatted[0].parts[0].function_response
        assert fn_response is not None
        assert fn_response.name == "click_at"
        assert fn_response.response.get("success") is True

    @pytest.mark.asyncio
    async def test_format_tool_results_with_error(self, mock_mcp_client, mock_gemini_client):
        """A failed tool result surfaces an ``error`` key in the function response."""
        agent = self._make_agent(mock_mcp_client, mock_gemini_client)

        call = MCPToolCall(
            name="gemini_computer",
            arguments={"action": "invalid"},
            id="call_error",  # type: ignore
            gemini_name="invalid_action",  # type: ignore
        )
        result = MCPToolResult(
            content=[types.TextContent(type="text", text="Action failed: invalid action")],
            isError=True,
        )

        formatted = await agent.format_tool_results([call], [result])

        assert len(formatted) == 1
        assert formatted[0].role == "user"
        fn_response = formatted[0].parts[0].function_response
        assert fn_response is not None
        assert "error" in fn_response.response

    @pytest.mark.asyncio
    async def test_get_response(self, mock_mcp_client, mock_gemini_client):
        """A reply mixing text and a function call yields content plus one tool call."""
        # Telemetry is switched off for the duration of the test.
        with patch("hud.settings.settings.telemetry_enabled", False):
            agent = self._make_agent(mock_mcp_client, mock_gemini_client)
            agent._available_tools = [
                types.Tool(name="gemini_computer", description="Computer tool", inputSchema={})
            ]

            # Plain-text part of the reply.
            narration = MagicMock(text="I will click at coordinates", function_call=None)

            # Function-call part of the reply. ``name`` is assigned after construction
            # because the Mock constructor reserves that keyword for the mock's own name.
            click = MagicMock(text=None)
            click.function_call = MagicMock()
            click.function_call.name = "click_at"
            click.function_call.args = {"x": 100, "y": 200}

            self._stub_generate_content(mock_gemini_client, [narration, click])

            prompt = [genai_types.Content(role="user", parts=[genai_types.Part(text="Click")])]
            response = await agent.get_response(prompt)

            assert response.content == "I will click at coordinates"
            assert len(response.tool_calls) == 1
            assert response.tool_calls[0].arguments == {"action": "click_at", "x": 100, "y": 200}
            assert response.done is False

    @pytest.mark.asyncio
    async def test_get_response_text_only(self, mock_mcp_client, mock_gemini_client):
        """A reply with only text produces no tool calls and is marked done."""
        # Telemetry is switched off for the duration of the test.
        with patch("hud.settings.settings.telemetry_enabled", False):
            agent = self._make_agent(mock_mcp_client, mock_gemini_client)

            summary = MagicMock(text="Task completed successfully", function_call=None)
            self._stub_generate_content(mock_gemini_client, [summary])

            prompt = [genai_types.Content(role="user", parts=[genai_types.Part(text="Status?")])]
            response = await agent.get_response(prompt)

            assert response.content == "Task completed successfully"
            assert response.tool_calls == []
            assert response.done is True

    @pytest.mark.asyncio
    async def test_convert_tools_for_gemini(self, mock_mcp_client, mock_gemini_client):
        """``gemini_computer`` maps to computer use; other tools become declarations."""
        agent = self._make_agent(mock_mcp_client, mock_gemini_client)

        computer = types.Tool(
            name="gemini_computer",
            description="Computer tool",
            inputSchema={"type": "object"},
        )
        calculator = types.Tool(
            name="calculator",
            description="Calculator tool",
            inputSchema={
                "type": "object",
                "properties": {"operation": {"type": "string"}},
            },
        )
        agent._available_tools = [computer, calculator]

        converted = agent._convert_tools_for_gemini()

        # One converted entry per MCP tool: computer use first, then the calculator.
        assert len(converted) == 2
        computer_entry, calculator_entry = converted

        assert computer_entry.computer_use is not None
        assert (
            computer_entry.computer_use.environment == genai_types.Environment.ENVIRONMENT_BROWSER
        )

        assert calculator_entry.function_declarations is not None
        assert len(calculator_entry.function_declarations) == 1
        assert calculator_entry.function_declarations[0].name == "calculator"

    @pytest.mark.asyncio
    async def test_create_user_message(self, mock_mcp_client, mock_gemini_client):
        """A plain string becomes a user message with a single text part."""
        agent = self._make_agent(mock_mcp_client, mock_gemini_client)

        user_message = await agent.create_user_message("Hello Gemini")

        assert user_message.role == "user"
        assert len(user_message.parts) == 1
        assert user_message.parts[0].text == "Hello Gemini"

    @pytest.mark.asyncio
    async def test_handle_empty_response(self, mock_mcp_client, mock_gemini_client):
        """A reply with no candidates yields empty content, no tool calls, and done."""
        with patch("hud.settings.settings.telemetry_enabled", False):
            agent = self._make_agent(mock_mcp_client, mock_gemini_client)

            # No parts supplied, so the stubbed response carries an empty candidate list.
            self._stub_generate_content(mock_gemini_client)

            prompt = [genai_types.Content(role="user", parts=[genai_types.Part(text="Hi")])]
            response = await agent.get_response(prompt)

            assert response.content == ""
            assert response.tool_calls == []
            assert response.done is True

hud/settings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,12 @@ def settings_customise_sources(
9494
validation_alias="OPENAI_API_KEY",
9595
)
9696

97+
gemini_api_key: str | None = Field(
98+
default=None,
99+
description="API key for Google Gemini models",
100+
validation_alias="GEMINI_API_KEY",
101+
)
102+
97103
openrouter_api_key: str | None = Field(
98104
default=None,
99105
description="API key for OpenRouter models",

0 commit comments

Comments
 (0)