Skip to content

Commit 1f1adc3

Browse files
mdemoret-nvAnuradhaKaruppiah
authored andcommitted
Improving the design of the MCP client
Signed-off-by: Michael Demoret <[email protected]>
1 parent fceca93 commit 1f1adc3

File tree

5 files changed

+2424
-2349
lines changed

5 files changed

+2424
-2349
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"httpx~=0.27",
2727
"jinja2~=3.1",
2828
"jsonpath-ng~=1.7",
29-
"mcp>=1.0.0",
29+
"mcp~=1.8",
3030
"networkx~=3.4",
3131
"numpy~=1.26",
3232
"openinference-semantic-conventions~=0.1.14",

src/aiq/tool/mcp/mcp_client.py

Lines changed: 101 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
import logging
1919
from abc import ABC
2020
from abc import abstractmethod
21-
from collections.abc import Callable
22-
from contextlib import AbstractAsyncContextManager
21+
from contextlib import AsyncExitStack
2322
from contextlib import asynccontextmanager
2423
from enum import Enum
2524
from typing import Any
@@ -28,6 +27,7 @@
2827
from mcp.client.sse import sse_client
2928
from mcp.client.stdio import StdioServerParameters
3029
from mcp.client.stdio import stdio_client
30+
from mcp.client.streamable_http import streamablehttp_client
3131
from mcp.types import TextContent
3232
from pydantic import BaseModel
3333
from pydantic import Field
@@ -97,8 +97,35 @@ class MCPBaseClient(ABC):
9797
def __init__(self, client_type: str = 'sse'):
9898
self._tools = None
9999
self._client_type = client_type.lower()
100-
if self._client_type not in ['sse', 'stdio']:
101-
raise ValueError("client_type must be either 'sse' or 'stdio'")
100+
if self._client_type not in ['sse', 'stdio', 'streamable-http']:
101+
raise ValueError("client_type must be either 'sse', 'stdio' or 'streamable-http'")
102+
103+
self._exit_stack: AsyncExitStack | None = None
104+
105+
self._session: ClientSession | None = None
106+
107+
@property
108+
def client_type(self) -> str:
109+
return self._client_type
110+
111+
async def __aenter__(self):
112+
if self._exit_stack:
113+
raise RuntimeError("MCPBaseClient already initialized. Use async with to initialize.")
114+
115+
self._exit_stack = AsyncExitStack()
116+
117+
self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
118+
119+
return self
120+
121+
async def __aexit__(self, exc_type, exc_value, traceback):
122+
123+
if not self._exit_stack:
124+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
125+
126+
await self._exit_stack.aclose()
127+
self._session = None
128+
self._exit_stack = None
102129

103130
@abstractmethod
104131
@asynccontextmanager
@@ -112,12 +139,15 @@ async def get_tools(self):
112139
"""
113140
Retrieve a dictionary of all tools served by the MCP server.
114141
"""
115-
async with self.connect_to_server() as session:
116-
response = await session.list_tools()
142+
143+
if not self._session:
144+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
145+
146+
response = await self._session.list_tools()
117147

118148
return {
119149
tool.name:
120-
MCPToolClient(connect_fn=self.connect_to_server,
150+
MCPToolClient(session=self._session,
121151
tool_name=tool.name,
122152
tool_description=tool.description,
123153
tool_input_schema=tool.inputSchema)
@@ -137,6 +167,9 @@ async def get_tool(self, tool_name: str) -> MCPToolClient:
137167
Raise:
138168
ValueError if no tool is available with that name.
139169
"""
170+
if not self._exit_stack:
171+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
172+
140173
if not self._tools:
141174
self._tools = await self.get_tools()
142175

@@ -146,9 +179,11 @@ async def get_tool(self, tool_name: str) -> MCPToolClient:
146179
return tool
147180

148181
async def call_tool(self, tool_name: str, tool_args: dict | None):
149-
async with self.connect_to_server() as session:
150-
result = await session.call_tool(tool_name, tool_args)
151-
return result
182+
if not self._session:
183+
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
184+
185+
result = await self._session.call_tool(tool_name, tool_args)
186+
return result
152187

153188

154189
class MCPSSEClient(MCPBaseClient):
@@ -164,6 +199,10 @@ def __init__(self, url: str, client_type: str = 'sse'):
164199
super().__init__(client_type)
165200
self._url = url
166201

202+
@property
203+
def url(self) -> str:
204+
return self._url
205+
167206
@asynccontextmanager
168207
async def connect_to_server(self):
169208
"""
@@ -195,39 +234,55 @@ def __init__(self,
195234
self._command = command
196235
self._args = args
197236
self._env = env
198-
self._session = None # hold session if persistent
199-
self._session_cm = None
200-
201-
async def start_persistent_session(self):
202-
"""Starts and holds a persistent session."""
203-
server_params = StdioServerParameters(command=self._command, args=self._args, env=self._env)
204-
self._session_cm = stdio_client(server_params)
205-
read, write = await self._session_cm.__aenter__()
206-
self._session = ClientSession(read, write)
207-
await self._session.initialize()
208-
209-
async def stop_persistent_session(self):
210-
"""Ends the persistent session."""
211-
if self._session:
212-
await self._session.__aexit__(None, None, None)
213-
self._session = None
214-
if self._session_cm:
215-
await self._session_cm.__aexit__(None, None, None)
216-
self._session_cm = None
237+
238+
@property
239+
def command(self) -> str:
240+
return self._command
241+
242+
@property
243+
def args(self) -> list[str] | None:
244+
return self._args
245+
246+
@property
247+
def env(self) -> dict[str, str] | None:
248+
return self._env
217249

218250
@asynccontextmanager
219251
async def connect_to_server(self):
220252
"""
221253
Establish a session with an MCP server via stdio within an async context
222254
"""
223-
if self._session:
224-
yield self._session
225-
else:
226-
server_params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
227-
async with stdio_client(server_params) as (read, write):
228-
async with ClientSession(read, write) as session:
229-
await session.initialize()
230-
yield session
255+
256+
server_params = StdioServerParameters(command=self._command, args=self._args or [], env=self._env)
257+
async with stdio_client(server_params) as (read, write):
258+
async with ClientSession(read, write) as session:
259+
await session.initialize()
260+
yield session
261+
262+
263+
class MCPStreamableHTTPClient(MCPBaseClient):
264+
"""
265+
Client for creating a session and connecting to an MCP server using streamable-http
266+
"""
267+
268+
def __init__(self, url: str, client_type: str = 'streamable-http'):
269+
super().__init__(client_type)
270+
271+
self._url = url
272+
273+
@property
274+
def url(self) -> str:
275+
return self._url
276+
277+
@asynccontextmanager
278+
async def connect_to_server(self):
279+
"""
280+
Establish a session with an MCP server via streamable-http within an async context
281+
"""
282+
async with streamablehttp_client(url=self._url) as (read, write, get_session_id):
283+
async with ClientSession(read, write) as session:
284+
await session.initialize()
285+
yield session
231286

232287

233288
class MCPToolClient:
@@ -242,11 +297,11 @@ class MCPToolClient:
242297
"""
243298

244299
def __init__(self,
245-
connect_fn: Callable[[], AbstractAsyncContextManager[ClientSession]],
300+
session: ClientSession,
246301
tool_name: str,
247302
tool_description: str | None,
248303
tool_input_schema: dict | None = None):
249-
self._connect_fn = connect_fn
304+
self._session = session
250305
self._tool_name = tool_name
251306
self._tool_description = tool_description
252307
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
@@ -285,8 +340,7 @@ async def acall(self, tool_args: dict) -> str:
285340
Args:
286341
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
287342
"""
288-
async with self._connect_fn() as session:
289-
result = await session.call_tool(self._tool_name, tool_args)
343+
result = await self._session.call_tool(self._tool_name, tool_args)
290344

291345
output = []
292346

@@ -296,4 +350,9 @@ async def acall(self, tool_args: dict) -> str:
296350
else:
297351
# Log non-text content for now
298352
logger.warning("Got not-text output from %s of type %s", self.name, type(res))
299-
return "\n".join(output)
353+
result_str = "\n".join(output)
354+
355+
if result.isError:
356+
raise RuntimeError(result_str)
357+
358+
return result_str

src/aiq/tool/mcp/mcp_tool.py

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
4040
description="The command to run for stdio mode (e.g. 'docker' or 'python')")
4141
args: list[str] | None = Field(default=None, description="Additional arguments for the stdio command")
4242
env: dict[str, str] | None = Field(default=None, description="Environment variables to set for the stdio process")
43-
persistent: bool = Field(
44-
default=False,
45-
description="If true, keeps the MCP stdio subprocess open across multiple calls. Only applies to stdio mode.")
4643
description: str | None = Field(default=None,
4744
description="""
4845
Description for the tool that will override the description provided by the MCP server. Should only be used if
@@ -78,55 +75,69 @@ async def mcp_tool(config: MCPToolConfig, builder: Builder): # pylint: disable=
7875

7976
from aiq.tool.mcp.mcp_client import MCPSSEClient
8077
from aiq.tool.mcp.mcp_client import MCPStdioClient
78+
from aiq.tool.mcp.mcp_client import MCPStreamableHTTPClient
8179
from aiq.tool.mcp.mcp_client import MCPToolClient
8280

8381
# Initialize the client
8482
if config.client_type == 'stdio':
83+
if not config.command:
84+
raise ValueError("command is required when using stdio client type")
85+
8586
client = MCPStdioClient(command=config.command, args=config.args, env=config.env)
86-
if config.persistent:
87-
await client.start_persistent_session()
88-
else:
89-
client = MCPSSEClient(url=str(config.url))
87+
elif config.client_type == 'streamable-http':
88+
if not config.url:
89+
raise ValueError("url is required when using streamable-http client type")
9090

91-
# If the tool is found create a MCPToolClient object and set the description if provided
92-
tool: MCPToolClient = await client.get_tool(config.mcp_tool_name)
93-
if config.description:
94-
tool.set_description(description=config.description)
91+
client = MCPStreamableHTTPClient(url=str(config.url))
92+
elif config.client_type == 'sse':
93+
if not config.url:
94+
raise ValueError("url is required when using sse client type")
9595

96-
if config.client_type == "sse":
97-
source = config.url
96+
client = MCPSSEClient(url=str(config.url))
9897
else:
99-
source = f"{config.command} {' '.join(config.args) if config.args else ''}"
100-
logger.info("Configured to use tool: %s from MCP server at %s", tool.name, source)
101-
if config.client_type == "stdio" and not config.persistent:
102-
logger.info("MCP stdio tool will launch a fresh subprocess per call (persistent=False).")
103-
104-
def _convert_from_str(input_str: str) -> tool.input_schema:
105-
return tool.input_schema.model_validate_json(input_str)
106-
107-
async def _response_fn(tool_input: BaseModel | None = None, **kwargs) -> str:
108-
# Run the tool, catching any errors and sending to agent for correction
109-
try:
110-
if tool_input:
111-
args = tool_input.model_dump()
112-
return await tool.acall(args)
113-
114-
_ = tool.input_schema.model_validate(kwargs)
115-
return await tool.acall(kwargs)
116-
except Exception as e:
117-
if config.return_exception:
98+
raise ValueError(f"Invalid client type: {config.client_type}")
99+
100+
async with client:
101+
# If the tool is found create a MCPToolClient object and set the description if provided
102+
tool: MCPToolClient = await client.get_tool(config.mcp_tool_name)
103+
if config.description:
104+
tool.set_description(description=config.description)
105+
106+
if config.client_type == "sse" or config.client_type == "streamable-http":
107+
source = config.url
108+
elif config.client_type == "stdio":
109+
source = f"{config.command} {' '.join(config.args) if config.args else ''}"
110+
else:
111+
raise ValueError(f"Invalid client type: {config.client_type}")
112+
113+
logger.info("Configured to use tool: %s from MCP server at %s", tool.name, source)
114+
115+
def _convert_from_str(input_str: str) -> tool.input_schema:
116+
return tool.input_schema.model_validate_json(input_str)
117+
118+
async def _response_fn(tool_input: tool.input_schema) -> str:
119+
# Run the tool, catching any errors and sending to agent for correction
120+
try:
118121
if tool_input:
119-
logger.warning("Error calling tool %s with serialized input: %s",
120-
tool.name,
121-
tool_input.model_dump(),
122-
exc_info=True)
123-
else:
124-
logger.warning("Error calling tool %s with input: %s", tool.name, kwargs, exc_info=True)
125-
return str(e)
126-
# If the tool call fails, raise the exception.
127-
raise
128-
129-
yield FunctionInfo.create(single_fn=_response_fn,
130-
description=tool.description,
131-
input_schema=tool.input_schema,
132-
converters=[_convert_from_str])
122+
args = tool_input.model_dump()
123+
return await tool.acall(args)
124+
125+
_ = tool.input_schema.model_validate(kwargs)
126+
return await tool.acall(kwargs)
127+
except Exception as e:
128+
if config.return_exception:
129+
if tool_input:
130+
logger.warning("Error calling tool %s with serialized input: %s",
131+
tool.name,
132+
tool_input.model_dump(),
133+
exc_info=True)
134+
else:
135+
logger.warning("Error calling tool %s with input: %s", tool.name, kwargs, exc_info=True)
136+
return str(e)
137+
# If the tool call fails, raise the exception.
138+
raise
139+
140+
yield FunctionInfo.create(single_fn=_response_fn,
141+
description=tool.description,
142+
input_schema=tool.input_schema,
143+
converters=[_convert_from_str])

0 commit comments

Comments
 (0)