Skip to content

Commit e5b4d5f

Browse files
Fix "aiq info" problems after adding AsyncExitStack
Signed-off-by: Anuradha Karuppiah <[email protected]>
1 parent 1f1adc3 commit e5b4d5f

File tree

2 files changed

+37
-27
lines changed

2 files changed

+37
-27
lines changed

src/aiq/cli/commands/info/list_mcp.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,23 @@ async def list_tools_and_schemas(command, url, tool_name=None, client_type='sse'
6060
args = []
6161
from aiq.tool.mcp.mcp_client import MCPSSEClient
6262
from aiq.tool.mcp.mcp_client import MCPStdioClient
63+
from aiq.tool.mcp.mcp_client import MCPStreamableHTTPClient
6364

6465
try:
6566
if client_type == 'stdio':
6667
client = MCPStdioClient(command=command, args=args, env=env)
67-
else:
68+
elif client_type == 'streamable-http':
69+
client = MCPStreamableHTTPClient(url=url)
70+
else: # sse
6871
client = MCPSSEClient(url=url)
6972

70-
if tool_name:
71-
tool = await client.get_tool(tool_name)
72-
return [format_tool(tool)]
73-
else:
74-
tools = await client.get_tools()
75-
return [format_tool(tool) for tool in tools.values()]
73+
async with client:
74+
if tool_name:
75+
tool = await client.get_tool(tool_name)
76+
return [format_tool(tool)]
77+
else:
78+
tools = await client.get_tools()
79+
return [format_tool(tool) for tool in tools.values()]
7680
except Exception as e:
7781
click.echo(f"[ERROR] Failed to fetch tools via MCP client: {e}", err=True)
7882
return []
@@ -85,6 +89,7 @@ async def list_tools_direct(command, url, tool_name=None, client_type='sse', arg
8589
from mcp.client.sse import sse_client
8690
from mcp.client.stdio import StdioServerParameters
8791
from mcp.client.stdio import stdio_client
92+
from mcp.client.streamable_http import streamablehttp_client
8893

8994
try:
9095
if client_type == 'stdio':
@@ -93,7 +98,13 @@ def get_stdio_client():
9398
return stdio_client(server=StdioServerParameters(command=command, args=args, env=env))
9499

95100
client = get_stdio_client
96-
else:
101+
elif client_type == 'streamable-http':
102+
103+
def get_streamable_http_client():
104+
return streamablehttp_client(url=url)
105+
106+
client = get_streamable_http_client
107+
else: # sse
97108

98109
def get_sse_client():
99110
return sse_client(url=url)
@@ -125,9 +136,9 @@ def get_sse_client():
125136
@click.option('--url',
126137
default='http://localhost:9901/sse',
127138
show_default=True,
128-
help='For SSE: MCP server URL (e.g. http://localhost:8080/sse)')
139+
help='For SSE/StreamableHTTP: MCP server URL (e.g. http://localhost:8080/sse)')
129140
@click.option('--client-type',
130-
type=click.Choice(['sse', 'stdio']),
141+
type=click.Choice(['sse', 'stdio', 'streamable-http']),
131142
default='sse',
132143
show_default=True,
133144
help='Type of client to use')
@@ -151,12 +162,14 @@ def list_mcp(ctx, direct, url, client_type, command, args, env, tool, detail, js
151162
click.echo("[ERROR] --command is required when using stdio client type", err=True)
152163
return
153164

154-
if client_type == 'sse':
165+
if client_type in ['sse', 'streamable-http']:
155166
if not url:
156-
click.echo("[ERROR] --url is required when using sse client type", err=True)
167+
click.echo("[ERROR] --url is required when using sse or streamable-http client type", err=True)
157168
return
158169
if command or args or env:
159-
click.echo("[ERROR] --command, --args, and --env are not allowed when using sse client type", err=True)
170+
click.echo(
171+
"[ERROR] --command, --args, and --env are not allowed when using sse or streamable-http client type",
172+
err=True)
160173
return
161174

162175
stdio_args = args.split() if args else []

src/aiq/tool/mcp/mcp_tool.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# limitations under the License.
1515

1616
import logging
17+
from typing import Literal
1718

18-
from pydantic import BaseModel
1919
from pydantic import Field
2020
from pydantic import HttpUrl
2121

@@ -35,7 +35,8 @@ class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
3535
# Add your custom configuration parameters here
3636
url: HttpUrl | None = Field(default=None, description="The URL of the MCP server (for SSE mode)")
3737
mcp_tool_name: str = Field(description="The name of the tool served by the MCP Server that you want to use")
38-
client_type: str = Field(default="sse", description="The type of client to use ('sse' or 'stdio')")
38+
client_type: Literal["sse", "stdio", "streamable-http"] = Field(default="sse",
39+
description="The type of transport to use")
3940
command: str | None = Field(default=None,
4041
description="The command to run for stdio mode (e.g. 'docker' or 'python')")
4142
args: list[str] | None = Field(default=None, description="Additional arguments for the stdio command")
@@ -52,15 +53,15 @@ class MCPToolConfig(FunctionBaseConfig, name="mcp_tool_wrapper"):
5253
""")
5354

5455
def model_post_init(self, __context):
55-
"""Validate that stdio and SSE properties are mutually exclusive."""
56+
"""Validate that stdio and SSE/Streamable HTTP properties are mutually exclusive."""
5657
super().model_post_init(__context)
5758

5859
if self.client_type == 'stdio':
5960
if self.url is not None:
6061
raise ValueError("url should not be set when using stdio client type")
6162
if not self.command:
6263
raise ValueError("command is required when using stdio client type")
63-
else:
64+
elif self.client_type == 'streamable-http' or self.client_type == 'sse':
6465
if self.command is not None or self.args is not None or self.env is not None:
6566
raise ValueError("command, args, and env should not be set when using sse client type")
6667
if not self.url:
@@ -83,17 +84,20 @@ async def mcp_tool(config: MCPToolConfig, builder: Builder): # pylint: disable=
8384
if not config.command:
8485
raise ValueError("command is required when using stdio client type")
8586

87+
source = f"{config.command} {' '.join(config.args) if config.args else ''}"
8688
client = MCPStdioClient(command=config.command, args=config.args, env=config.env)
8789
elif config.client_type == 'streamable-http':
8890
if not config.url:
8991
raise ValueError("url is required when using streamable-http client type")
9092

91-
client = MCPStreamableHTTPClient(url=str(config.url))
93+
source = str(config.url)
94+
client = MCPStreamableHTTPClient(url=source)
9295
elif config.client_type == 'sse':
9396
if not config.url:
9497
raise ValueError("url is required when using sse client type")
9598

96-
client = MCPSSEClient(url=str(config.url))
99+
source = str(config.url)
100+
client = MCPSSEClient(url=source)
97101
else:
98102
raise ValueError(f"Invalid client type: {config.client_type}")
99103

@@ -103,19 +107,12 @@ async def mcp_tool(config: MCPToolConfig, builder: Builder): # pylint: disable=
103107
if config.description:
104108
tool.set_description(description=config.description)
105109

106-
if config.client_type == "sse" or config.client_type == "streamable-http":
107-
source = config.url
108-
elif config.client_type == "stdio":
109-
source = f"{config.command} {' '.join(config.args) if config.args else ''}"
110-
else:
111-
raise ValueError(f"Invalid client type: {config.client_type}")
112-
113110
logger.info("Configured to use tool: %s from MCP server at %s", tool.name, source)
114111

115112
def _convert_from_str(input_str: str) -> tool.input_schema:
116113
return tool.input_schema.model_validate_json(input_str)
117114

118-
async def _response_fn(tool_input: tool.input_schema) -> str:
115+
async def _response_fn(tool_input: tool.input_schema | None = None, **kwargs) -> str:
119116
# Run the tool, catching any errors and sending to agent for correction
120117
try:
121118
if tool_input:

0 commit comments

Comments
 (0)