18
18
import logging
19
19
from abc import ABC
20
20
from abc import abstractmethod
21
- from collections .abc import Callable
22
- from contextlib import AbstractAsyncContextManager
21
+ from contextlib import AsyncExitStack
23
22
from contextlib import asynccontextmanager
24
23
from enum import Enum
25
24
from typing import Any
28
27
from mcp .client .sse import sse_client
29
28
from mcp .client .stdio import StdioServerParameters
30
29
from mcp .client .stdio import stdio_client
30
+ from mcp .client .streamable_http import streamablehttp_client
31
31
from mcp .types import TextContent
32
32
from pydantic import BaseModel
33
33
from pydantic import Field
@@ -97,8 +97,35 @@ class MCPBaseClient(ABC):
97
97
def __init__ (self , client_type : str = 'sse' ):
98
98
self ._tools = None
99
99
self ._client_type = client_type .lower ()
100
- if self ._client_type not in ['sse' , 'stdio' ]:
101
- raise ValueError ("client_type must be either 'sse' or 'stdio'" )
100
+ if self ._client_type not in ['sse' , 'stdio' , 'streamable-http' ]:
101
+ raise ValueError ("client_type must be either 'sse', 'stdio' or 'streamable-http'" )
102
+
103
+ self ._exit_stack : AsyncExitStack | None = None
104
+
105
+ self ._session : ClientSession | None = None
106
+
107
+ @property
108
+ def client_type (self ) -> str :
109
+ return self ._client_type
110
+
111
+ async def __aenter__ (self ):
112
+ if self ._exit_stack :
113
+ raise RuntimeError ("MCPBaseClient already initialized. Use async with to initialize." )
114
+
115
+ self ._exit_stack = AsyncExitStack ()
116
+
117
+ self ._session = await self ._exit_stack .enter_async_context (self .connect_to_server ())
118
+
119
+ return self
120
+
121
+ async def __aexit__ (self , exc_type , exc_value , traceback ):
122
+
123
+ if not self ._exit_stack :
124
+ raise RuntimeError ("MCPBaseClient not initialized. Use async with to initialize." )
125
+
126
+ await self ._exit_stack .aclose ()
127
+ self ._session = None
128
+ self ._exit_stack = None
102
129
103
130
@abstractmethod
104
131
@asynccontextmanager
@@ -112,12 +139,15 @@ async def get_tools(self):
112
139
"""
113
140
Retrieve a dictionary of all tools served by the MCP server.
114
141
"""
115
- async with self .connect_to_server () as session :
116
- response = await session .list_tools ()
142
+
143
+ if not self ._session :
144
+ raise RuntimeError ("MCPBaseClient not initialized. Use async with to initialize." )
145
+
146
+ response = await self ._session .list_tools ()
117
147
118
148
return {
119
149
tool .name :
120
- MCPToolClient (connect_fn = self .connect_to_server ,
150
+ MCPToolClient (session = self ._session ,
121
151
tool_name = tool .name ,
122
152
tool_description = tool .description ,
123
153
tool_input_schema = tool .inputSchema )
@@ -137,6 +167,9 @@ async def get_tool(self, tool_name: str) -> MCPToolClient:
137
167
Raise:
138
168
ValueError if no tool is available with that name.
139
169
"""
170
+ if not self ._exit_stack :
171
+ raise RuntimeError ("MCPBaseClient not initialized. Use async with to initialize." )
172
+
140
173
if not self ._tools :
141
174
self ._tools = await self .get_tools ()
142
175
@@ -146,9 +179,11 @@ async def get_tool(self, tool_name: str) -> MCPToolClient:
146
179
return tool
147
180
148
181
async def call_tool (self , tool_name : str , tool_args : dict | None ):
149
- async with self .connect_to_server () as session :
150
- result = await session .call_tool (tool_name , tool_args )
151
- return result
182
+ if not self ._session :
183
+ raise RuntimeError ("MCPBaseClient not initialized. Use async with to initialize." )
184
+
185
+ result = await self ._session .call_tool (tool_name , tool_args )
186
+ return result
152
187
153
188
154
189
class MCPSSEClient (MCPBaseClient ):
@@ -164,6 +199,10 @@ def __init__(self, url: str, client_type: str = 'sse'):
164
199
super ().__init__ (client_type )
165
200
self ._url = url
166
201
202
+ @property
203
+ def url (self ) -> str :
204
+ return self ._url
205
+
167
206
@asynccontextmanager
168
207
async def connect_to_server (self ):
169
208
"""
@@ -195,39 +234,55 @@ def __init__(self,
195
234
self ._command = command
196
235
self ._args = args
197
236
self ._env = env
198
- self ._session = None # hold session if persistent
199
- self ._session_cm = None
200
-
201
- async def start_persistent_session (self ):
202
- """Starts and holds a persistent session."""
203
- server_params = StdioServerParameters (command = self ._command , args = self ._args , env = self ._env )
204
- self ._session_cm = stdio_client (server_params )
205
- read , write = await self ._session_cm .__aenter__ ()
206
- self ._session = ClientSession (read , write )
207
- await self ._session .initialize ()
208
-
209
- async def stop_persistent_session (self ):
210
- """Ends the persistent session."""
211
- if self ._session :
212
- await self ._session .__aexit__ (None , None , None )
213
- self ._session = None
214
- if self ._session_cm :
215
- await self ._session_cm .__aexit__ (None , None , None )
216
- self ._session_cm = None
237
+
238
+ @property
239
+ def command (self ) -> str :
240
+ return self ._command
241
+
242
+ @property
243
+ def args (self ) -> list [str ] | None :
244
+ return self ._args
245
+
246
+ @property
247
+ def env (self ) -> dict [str , str ] | None :
248
+ return self ._env
217
249
218
250
@asynccontextmanager
219
251
async def connect_to_server (self ):
220
252
"""
221
253
Establish a session with an MCP server via stdio within an async context
222
254
"""
223
- if self ._session :
224
- yield self ._session
225
- else :
226
- server_params = StdioServerParameters (command = self ._command , args = self ._args or [], env = self ._env )
227
- async with stdio_client (server_params ) as (read , write ):
228
- async with ClientSession (read , write ) as session :
229
- await session .initialize ()
230
- yield session
255
+
256
+ server_params = StdioServerParameters (command = self ._command , args = self ._args or [], env = self ._env )
257
+ async with stdio_client (server_params ) as (read , write ):
258
+ async with ClientSession (read , write ) as session :
259
+ await session .initialize ()
260
+ yield session
261
+
262
+
263
+ class MCPStreamableHTTPClient (MCPBaseClient ):
264
+ """
265
+ Client for creating a session and connecting to an MCP server using streamable-http
266
+ """
267
+
268
+ def __init__ (self , url : str , client_type : str = 'streamable-http' ):
269
+ super ().__init__ (client_type )
270
+
271
+ self ._url = url
272
+
273
+ @property
274
+ def url (self ) -> str :
275
+ return self ._url
276
+
277
+ @asynccontextmanager
278
+ async def connect_to_server (self ):
279
+ """
280
+ Establish a session with an MCP server via streamable-http within an async context
281
+ """
282
+ async with streamablehttp_client (url = self ._url ) as (read , write , get_session_id ):
283
+ async with ClientSession (read , write ) as session :
284
+ await session .initialize ()
285
+ yield session
231
286
232
287
233
288
class MCPToolClient :
@@ -242,11 +297,11 @@ class MCPToolClient:
242
297
"""
243
298
244
299
def __init__ (self ,
245
- connect_fn : Callable [[], AbstractAsyncContextManager [ ClientSession ]] ,
300
+ session : ClientSession ,
246
301
tool_name : str ,
247
302
tool_description : str | None ,
248
303
tool_input_schema : dict | None = None ):
249
- self ._connect_fn = connect_fn
304
+ self ._session = session
250
305
self ._tool_name = tool_name
251
306
self ._tool_description = tool_description
252
307
self ._input_schema = (model_from_mcp_schema (self ._tool_name , tool_input_schema ) if tool_input_schema else None )
@@ -285,8 +340,7 @@ async def acall(self, tool_args: dict) -> str:
285
340
Args:
286
341
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
287
342
"""
288
- async with self ._connect_fn () as session :
289
- result = await session .call_tool (self ._tool_name , tool_args )
343
+ result = await self ._session .call_tool (self ._tool_name , tool_args )
290
344
291
345
output = []
292
346
@@ -296,4 +350,9 @@ async def acall(self, tool_args: dict) -> str:
296
350
else :
297
351
# Log non-text content for now
298
352
logger .warning ("Got not-text output from %s of type %s" , self .name , type (res ))
299
- return "\n " .join (output )
353
+ result_str = "\n " .join (output )
354
+
355
+ if result .isError :
356
+ raise RuntimeError (result_str )
357
+
358
+ return result_str
0 commit comments