Skip to content

Commit 226fe7c

Browse files
sararobcopybara-github
authored andcommitted
chore: Add session and auth helpers for Agent Platform MCP support
PiperOrigin-RevId: 906436921
1 parent 04bf0b8 commit 226fe7c

2 files changed

Lines changed: 159 additions & 0 deletions

File tree

google/genai/_mcp_utils.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,41 @@
1414
#
1515

1616
"""Utils for working with MCP tools."""
17+
import contextlib
18+
import httpx
1719

1820
from importlib.metadata import PackageNotFoundError, version
1921
import typing
2022
from typing import Any
2123

24+
import google.auth
25+
from google.auth.transport.requests import Request
26+
2227
from . import _common
2328
from . import types
29+
from ._api_client import _MULTI_REGIONAL_LOCATIONS
2430

2531
if typing.TYPE_CHECKING:
2632
from mcp.types import Tool as McpTool
2733
from mcp import ClientSession as McpClientSession
34+
from mcp.client.streamable_http import streamable_http_client
35+
from mcp.shared._httpx_utils import create_mcp_http_client
2836
else:
2937
McpClientSession: typing.Type = Any
3038
McpTool: typing.Type = Any
39+
streamable_http_client: Any = None
40+
create_mcp_http_client: Any = None
41+
3142
try:
3243
from mcp.types import Tool as McpTool
3344
from mcp import ClientSession as McpClientSession
45+
from mcp.client.streamable_http import streamable_http_client
46+
from mcp.shared._httpx_utils import create_mcp_http_client
3447
except ImportError:
3548
McpTool = None
3649
McpClientSession = None
50+
streamable_http_client = None
51+
create_mcp_http_client = None
3752

3853

3954
def mcp_to_gemini_tool(tool: McpTool) -> types.Tool:
@@ -146,3 +161,79 @@ def _filter_to_supported_schema(
146161

147162
return filtered_schema
148163

164+
165+
@contextlib.asynccontextmanager
166+
async def _connect_agent_platform_mcp(api_client: Any, toolset_name: str) -> typing.AsyncIterator[Any]:
167+
"""Internal helper to manage the Agent Platform MCP lifecycle per request."""
168+
if streamable_http_client is None:
169+
raise ImportError(
170+
"The 'mcp' package is required to use Agent Platform MCP servers."
171+
)
172+
173+
base_url = None
174+
if hasattr(api_client, '_http_options') and hasattr(api_client._http_options, 'base_url'):
175+
base_url = api_client._http_options.base_url
176+
177+
if base_url:
178+
if base_url.endswith("/"):
179+
base_url = base_url[:-1]
180+
mcp_url = f"{base_url}/mcp/{toolset_name}"
181+
else:
182+
location = getattr(api_client, "location", "global")
183+
if location == "global":
184+
mcp_url = f"https://aiplatform.googleapis.com/mcp/{toolset_name}"
185+
elif location in _MULTI_REGIONAL_LOCATIONS:
186+
mcp_url = f"https://aiplatform.{location}.rep.googleapis.com/mcp/{toolset_name}"
187+
else:
188+
mcp_url = f"https://{location}-aiplatform.googleapis.com/mcp/{toolset_name}"
189+
190+
token = await api_client._async_access_token()
191+
project = getattr(api_client, "project", None)
192+
193+
headers = {
194+
"Authorization": f"Bearer {token}",
195+
}
196+
if project:
197+
headers["X-Goog-User-Project"] = project
198+
199+
set_mcp_usage_header(headers)
200+
201+
timeout = httpx.Timeout(30.0, read=300.0)
202+
http_client = httpx.AsyncClient(headers=headers, timeout=timeout)
203+
204+
try:
205+
async with http_client:
206+
async with streamable_http_client(
207+
url=mcp_url, http_client=http_client
208+
) as streams:
209+
read_stream, write_stream, _ = streams
210+
async with McpClientSession(read_stream, write_stream) as session:
211+
await session.initialize()
212+
try:
213+
yield session
214+
except GeneratorExit:
215+
return
216+
217+
except BaseException as eg:
218+
219+
error_messages = []
220+
221+
def _extract_errors(exc: Any) -> None:
222+
# Handle potentially nested ExceptionGroups
223+
if hasattr(exc, "exceptions"):
224+
for e in exc.exceptions:
225+
_extract_errors(e)
226+
else:
227+
msg = f"{type(exc).__name__}: {str(exc)}"
228+
if hasattr(exc, "response") and exc.response is not None:
229+
msg += f" (HTTP {exc.response.status_code}: {exc.response.text})"
230+
error_messages.append(msg)
231+
232+
if type(eg).__name__ in ("ExceptionGroup", "BaseExceptionGroup") or hasattr(
233+
eg, "exceptions"
234+
):
235+
_extract_errors(eg)
236+
raise ValueError(
237+
f"Failed to connect to Agent Platform MCP Server at {mcp_url}.\n"
238+
f"Underlying errors: {error_messages}"
239+
) from eg

google/genai/tests/mcp/test_mcp_to_gemini_tools.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@
1313
# limitations under the License.
1414
#
1515

16+
import contextlib
17+
from unittest import mock
18+
19+
import pytest
20+
1621
from ... import _mcp_utils
1722
from ... import types
23+
from ..._api_client import BaseApiClient
24+
1825

1926
try:
2027
from mcp import types as mcp_types
@@ -276,3 +283,64 @@ def test_update_endpoint_labels_conversion():
276283
labels_schema = schema['properties']['endpoint']['properties']['labels']
277284

278285
assert 'additionalProperties' in labels_schema
286+
287+
288+
@pytest.mark.asyncio
289+
@mock.patch('httpx.AsyncClient')
290+
@mock.patch.object(_mcp_utils, 'streamable_http_client')
291+
@mock.patch.object(_mcp_utils, 'McpClientSession')
292+
@mock.patch('google.auth.default')
293+
async def test_connect_agent_platform_mcp_url_and_headers(
294+
mock_auth_default, mock_session_cls, mock_streamable, mock_create_http
295+
):
296+
"""Tests that _mcp_utils._connect_agent_platform_mcp builds the correct
297+
regional URL and injects auth headers.
298+
"""
299+
300+
mock_creds = mock.Mock()
301+
mock_creds.token = 'fake-oauth-token'
302+
mock_auth_default.return_value = (mock_creds, 'fake-project')
303+
304+
@contextlib.asynccontextmanager
305+
async def mock_streamable_ctx(*args, **kwargs):
306+
yield (mock.Mock(), mock.Mock(), mock.Mock())
307+
308+
mock_streamable.side_effect = mock_streamable_ctx
309+
310+
@contextlib.asynccontextmanager
311+
async def mock_session_ctx(*args, **kwargs):
312+
session_instance = mock.AsyncMock()
313+
yield session_instance
314+
315+
mock_session_cls.side_effect = mock_session_ctx
316+
317+
mock_http_client_instance = mock.AsyncMock()
318+
mock_http_client_instance.__aenter__.return_value = (
319+
mock_http_client_instance
320+
)
321+
mock_http_client_instance.__aexit__.return_value = None
322+
mock_create_http.return_value = mock_http_client_instance
323+
324+
api_client = BaseApiClient(
325+
vertexai=True,
326+
project='test-project-123',
327+
location='europe-west4'
328+
)
329+
330+
async with _mcp_utils._connect_agent_platform_mcp(
331+
api_client, 'endpoints'
332+
) as session:
333+
session.initialize.assert_awaited_once()
334+
335+
mock_streamable.assert_called_once()
336+
assert (
337+
mock_streamable.call_args.kwargs['url']
338+
== 'https://europe-west4-aiplatform.googleapis.com/mcp/endpoints'
339+
)
340+
341+
mock_create_http.assert_called_once()
342+
called_headers = mock_create_http.call_args.kwargs['headers']
343+
344+
assert called_headers['Authorization'] == 'Bearer fake-oauth-token'
345+
assert called_headers['X-Goog-User-Project'] == 'test-project-123'
346+
assert 'mcp_used' in called_headers.get('x-goog-api-client', '')

0 commit comments

Comments
 (0)