diff --git a/google/genai/_mcp_utils.py b/google/genai/_mcp_utils.py index 24fec0b03..aa9f0bfe0 100644 --- a/google/genai/_mcp_utils.py +++ b/google/genai/_mcp_utils.py @@ -14,26 +14,41 @@ # """Utils for working with MCP tools.""" +import contextlib +import httpx from importlib.metadata import PackageNotFoundError, version import typing from typing import Any +import google.auth +from google.auth.transport.requests import Request + from . import _common from . import types +from ._api_client import _MULTI_REGIONAL_LOCATIONS if typing.TYPE_CHECKING: from mcp.types import Tool as McpTool from mcp import ClientSession as McpClientSession + from mcp.client.streamable_http import streamable_http_client + from mcp.shared._httpx_utils import create_mcp_http_client else: McpClientSession: typing.Type = Any McpTool: typing.Type = Any + streamable_http_client: Any = None + create_mcp_http_client: Any = None + try: from mcp.types import Tool as McpTool from mcp import ClientSession as McpClientSession + from mcp.client.streamable_http import streamable_http_client + from mcp.shared._httpx_utils import create_mcp_http_client except ImportError: McpTool = None McpClientSession = None + streamable_http_client = None + create_mcp_http_client = None def mcp_to_gemini_tool(tool: McpTool) -> types.Tool: @@ -144,3 +159,81 @@ def _filter_to_supported_schema( return filtered_schema + +@contextlib.asynccontextmanager +async def _connect_agent_platform_mcp(api_client: Any, toolset_name: str) -> typing.AsyncIterator[Any]: + """Internal helper to manage the Agent Platform MCP lifecycle per request.""" + if streamable_http_client is None: + raise ImportError( + "The 'mcp' package is required to use Agent Platform MCP servers." + ) + + base_url = None + if hasattr(api_client, '_http_options') and hasattr(api_client._http_options, 'base_url'): + base_url = api_client._http_options.base_url + + if base_url: + if base_url.endswith("/"): + base_url = base_url[:-1] + mcp_url = f"{base_url}/mcp/{toolset_name}" + else: + location = getattr(api_client, "location", "global") + if location == "global": + mcp_url = f"https://aiplatform.googleapis.com/mcp/{toolset_name}" + elif location in _MULTI_REGIONAL_LOCATIONS: + mcp_url = f"https://aiplatform.{location}.rep.googleapis.com/mcp/{toolset_name}" + else: + mcp_url = f"https://{location}-aiplatform.googleapis.com/mcp/{toolset_name}" + + token = await api_client._async_access_token() + project = getattr(api_client, "project", None) + + headers = {} + if hasattr(api_client, "_http_options") and api_client._http_options and api_client._http_options.headers: + headers = dict(api_client._http_options.headers) + + headers["Authorization"] = f"Bearer {token}" + if project: + headers["X-Goog-User-Project"] = project + + set_mcp_usage_header(headers) + + timeout = httpx.Timeout(30.0, read=300.0) + http_client = httpx.AsyncClient(headers=headers, timeout=timeout) + + try: + async with http_client: + async with streamable_http_client( + url=mcp_url, http_client=http_client + ) as streams: + read_stream, write_stream, _ = streams + async with McpClientSession(read_stream, write_stream) as session: + await session.initialize() + try: + yield session + except GeneratorExit: + return + + except BaseException as eg: + + error_messages = [] + + def _extract_errors(exc: Any) -> None: + # Handle potentially nested ExceptionGroups + if hasattr(exc, "exceptions"): + for e in exc.exceptions: + _extract_errors(e) + else: + msg = f"{type(exc).__name__}: {str(exc)}" + if hasattr(exc, "response") and exc.response is not None: + msg += f" (HTTP {exc.response.status_code}: {exc.response.text})" + error_messages.append(msg) + + if type(eg).__name__ in ("ExceptionGroup", "BaseExceptionGroup") or hasattr( + eg, "exceptions" + ): + _extract_errors(eg) + raise ValueError( + f"Failed to connect to Agent Platform MCP Server at {mcp_url}.\n" + f"Underlying errors: {error_messages}" + ) from eg diff --git a/google/genai/tests/mcp/test_mcp_to_gemini_tools.py b/google/genai/tests/mcp/test_mcp_to_gemini_tools.py index af07aa0a9..61d64102c 100644 --- a/google/genai/tests/mcp/test_mcp_to_gemini_tools.py +++ b/google/genai/tests/mcp/test_mcp_to_gemini_tools.py @@ -13,8 +13,15 @@ # limitations under the License. # +import contextlib +from unittest import mock + +import pytest + from ... import _mcp_utils from ... import types +from ..._api_client import BaseApiClient + try: from mcp import types as mcp_types @@ -301,3 +308,64 @@ def test_agent_platform_preserves_unknown_fields(): # Verify the entire schema is passed through intact, including the unknown field assert 'some_new_future_field' in schema assert schema['some_new_future_field'] == 'value' + + +@pytest.mark.asyncio +@mock.patch('httpx.AsyncClient') +@mock.patch.object(_mcp_utils, 'streamable_http_client') +@mock.patch.object(_mcp_utils, 'McpClientSession') +@mock.patch('google.auth.default') +async def test_connect_agent_platform_mcp_url_and_headers( + mock_auth_default, mock_session_cls, mock_streamable, mock_create_http +): + """Tests that _mcp_utils._connect_agent_platform_mcp builds the correct + regional URL and injects auth headers. + """ + + mock_creds = mock.Mock() + mock_creds.token = 'fake-oauth-token' + mock_auth_default.return_value = (mock_creds, 'fake-project') + + @contextlib.asynccontextmanager + async def mock_streamable_ctx(*args, **kwargs): + yield (mock.Mock(), mock.Mock(), mock.Mock()) + + mock_streamable.side_effect = mock_streamable_ctx + + @contextlib.asynccontextmanager + async def mock_session_ctx(*args, **kwargs): + session_instance = mock.AsyncMock() + yield session_instance + + mock_session_cls.side_effect = mock_session_ctx + + mock_http_client_instance = mock.AsyncMock() + mock_http_client_instance.__aenter__.return_value = ( + mock_http_client_instance + ) + mock_http_client_instance.__aexit__.return_value = None + mock_create_http.return_value = mock_http_client_instance + + api_client = BaseApiClient( + vertexai=True, + project='test-project-123', + location='europe-west4' + ) + + async with _mcp_utils._connect_agent_platform_mcp( + api_client, 'endpoints' + ) as session: + session.initialize.assert_awaited_once() + + mock_streamable.assert_called_once() + assert ( + mock_streamable.call_args.kwargs['url'] + == 'https://europe-west4-aiplatform.googleapis.com/mcp/endpoints' + ) + + mock_create_http.assert_called_once() + called_headers = mock_create_http.call_args.kwargs['headers'] + + assert called_headers['Authorization'] == 'Bearer fake-oauth-token' + assert called_headers['X-Goog-User-Project'] == 'test-project-123' + assert 'mcp_used' in called_headers.get('x-goog-api-client', '')