|
14 | 14 | # |
15 | 15 |
|
16 | 16 | """Utils for working with MCP tools.""" |
| 17 | +import contextlib |
| 18 | +import httpx |
17 | 19 |
|
18 | 20 | from importlib.metadata import PackageNotFoundError, version |
19 | 21 | import typing |
20 | 22 | from typing import Any |
21 | 23 |
|
| 24 | +import google.auth |
| 25 | +from google.auth.transport.requests import Request |
| 26 | + |
22 | 27 | from . import _common |
23 | 28 | from . import types |
| 29 | +from ._api_client import _MULTI_REGIONAL_LOCATIONS |
24 | 30 |
|
25 | 31 | if typing.TYPE_CHECKING: |
26 | 32 | from mcp.types import Tool as McpTool |
27 | 33 | from mcp import ClientSession as McpClientSession |
| 34 | + from mcp.client.streamable_http import streamable_http_client |
| 35 | + from mcp.shared._httpx_utils import create_mcp_http_client |
28 | 36 | else: |
29 | 37 | McpClientSession: typing.Type = Any |
30 | 38 | McpTool: typing.Type = Any |
| 39 | + streamable_http_client: Any = None |
| 40 | + create_mcp_http_client: Any = None |
| 41 | + |
31 | 42 | try: |
32 | 43 | from mcp.types import Tool as McpTool |
33 | 44 | from mcp import ClientSession as McpClientSession |
| 45 | + from mcp.client.streamable_http import streamable_http_client |
| 46 | + from mcp.shared._httpx_utils import create_mcp_http_client |
34 | 47 | except ImportError: |
35 | 48 | McpTool = None |
36 | 49 | McpClientSession = None |
| 50 | + streamable_http_client = None |
| 51 | + create_mcp_http_client = None |
37 | 52 |
|
38 | 53 |
|
39 | 54 | def mcp_to_gemini_tool(tool: McpTool) -> types.Tool: |
@@ -146,3 +161,79 @@ def _filter_to_supported_schema( |
146 | 161 |
|
147 | 162 | return filtered_schema |
148 | 163 |
|
| 164 | + |
| 165 | +@contextlib.asynccontextmanager |
| 166 | +async def _connect_agent_platform_mcp(api_client: Any, toolset_name: str) -> typing.AsyncIterator[Any]: |
| 167 | + """Internal helper to manage the Agent Platform MCP lifecycle per request.""" |
| 168 | + if streamable_http_client is None: |
| 169 | + raise ImportError( |
| 170 | + "The 'mcp' package is required to use Agent Platform MCP servers." |
| 171 | + ) |
| 172 | + |
| 173 | + base_url = None |
| 174 | + if hasattr(api_client, '_http_options') and hasattr(api_client._http_options, 'base_url'): |
| 175 | + base_url = api_client._http_options.base_url |
| 176 | + |
| 177 | + if base_url: |
| 178 | + if base_url.endswith("/"): |
| 179 | + base_url = base_url[:-1] |
| 180 | + mcp_url = f"{base_url}/mcp/{toolset_name}" |
| 181 | + else: |
| 182 | + location = getattr(api_client, "location", "global") |
| 183 | + if location == "global": |
| 184 | + mcp_url = f"https://aiplatform.googleapis.com/mcp/{toolset_name}" |
| 185 | + elif location in _MULTI_REGIONAL_LOCATIONS: |
| 186 | + mcp_url = f"https://aiplatform.{location}.rep.googleapis.com/mcp/{toolset_name}" |
| 187 | + else: |
| 188 | + mcp_url = f"https://{location}-aiplatform.googleapis.com/mcp/{toolset_name}" |
| 189 | + |
| 190 | + token = await api_client._async_access_token() |
| 191 | + project = getattr(api_client, "project", None) |
| 192 | + |
| 193 | + headers = { |
| 194 | + "Authorization": f"Bearer {token}", |
| 195 | + } |
| 196 | + if project: |
| 197 | + headers["X-Goog-User-Project"] = project |
| 198 | + |
| 199 | + set_mcp_usage_header(headers) |
| 200 | + |
| 201 | + timeout = httpx.Timeout(30.0, read=300.0) |
| 202 | + http_client = httpx.AsyncClient(headers=headers, timeout=timeout) |
| 203 | + |
| 204 | + try: |
| 205 | + async with http_client: |
| 206 | + async with streamable_http_client( |
| 207 | + url=mcp_url, http_client=http_client |
| 208 | + ) as streams: |
| 209 | + read_stream, write_stream, _ = streams |
| 210 | + async with McpClientSession(read_stream, write_stream) as session: |
| 211 | + await session.initialize() |
| 212 | + try: |
| 213 | + yield session |
| 214 | + except GeneratorExit: |
| 215 | + return |
| 216 | + |
| 217 | + except BaseException as eg: |
| 218 | + |
| 219 | + error_messages = [] |
| 220 | + |
| 221 | + def _extract_errors(exc: Any) -> None: |
| 222 | + # Handle potentially nested ExceptionGroups |
| 223 | + if hasattr(exc, "exceptions"): |
| 224 | + for e in exc.exceptions: |
| 225 | + _extract_errors(e) |
| 226 | + else: |
| 227 | + msg = f"{type(exc).__name__}: {str(exc)}" |
| 228 | + if hasattr(exc, "response") and exc.response is not None: |
| 229 | + msg += f" (HTTP {exc.response.status_code}: {exc.response.text})" |
| 230 | + error_messages.append(msg) |
| 231 | + |
| 232 | + if type(eg).__name__ in ("ExceptionGroup", "BaseExceptionGroup") or hasattr( |
| 233 | + eg, "exceptions" |
| 234 | + ): |
| 235 | + _extract_errors(eg) |
| 236 | + raise ValueError( |
| 237 | + f"Failed to connect to Agent Platform MCP Server at {mcp_url}.\n" |
| 238 | + f"Underlying errors: {error_messages}" |
| 239 | + ) from eg |
0 commit comments