Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
from typing import Union

import anyio
import httpx
from pydantic import BaseModel
from pydantic import ConfigDict

try:
from mcp import ClientSession
Expand Down Expand Up @@ -99,13 +101,16 @@ class StreamableHTTPConnectionParams(BaseModel):
Streamable HTTP server.
terminate_on_close: Whether to terminate the MCP Streamable HTTP server
when the connection is closed.
httpx_client: httpx.AsyncClient to use for the connection.
"""

url: str
headers: dict[str, Any] | None = None
timeout: float = 5.0
sse_read_timeout: float = 60 * 5.0
terminate_on_close: bool = True
httpx_client: httpx.AsyncClient | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)


def retry_on_closed_resource(func):
Expand Down Expand Up @@ -277,15 +282,19 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
sse_read_timeout=self._connection_params.sse_read_timeout,
)
elif isinstance(self._connection_params, StreamableHTTPConnectionParams):
client = streamablehttp_client(
url=self._connection_params.url,
headers=merged_headers,
timeout=timedelta(seconds=self._connection_params.timeout),
sse_read_timeout=timedelta(
kwargs = {
'url': self._connection_params.url,
'headers': merged_headers,
'timeout': timedelta(seconds=self._connection_params.timeout),
'sse_read_timeout': timedelta(
seconds=self._connection_params.sse_read_timeout
),
terminate_on_close=self._connection_params.terminate_on_close,
)
'terminate_on_close': self._connection_params.terminate_on_close,
}
if self._connection_params.httpx_client:
kwargs['httpx_client_factory'] = self._connection_params.httpx_client
client = streamablehttp_client(**kwargs)

else:
raise ValueError(
'Unable to initialize connection. Connection should be'
Expand Down
13 changes: 13 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import Mock
from unittest.mock import patch

import httpx
import pytest

# Skip all tests in this module if Python version is less than 3.10
Expand Down Expand Up @@ -143,6 +144,18 @@ def test_init_with_streamable_http_params(self):
manager = MCPSessionManager(http_params)

assert manager._connection_params == http_params
assert manager._connection_params.httpx_client is None

def test_init_with_streamable_http_params_with_httpx_client(self):
"""Test initialization with StreamableHTTPConnectionParams with httpx client."""
client = httpx.AsyncClient()
http_params = StreamableHTTPConnectionParams(
url="https://example.com/mcp", timeout=15.0, httpx_client=client
)
manager = MCPSessionManager(http_params)

assert manager._connection_params == http_params
assert manager._connection_params.httpx_client == client

def test_generate_session_key_stdio(self):
"""Test session key generation for stdio connections."""
Expand Down