diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ab6e124d4..ade77ef29 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -35,6 +35,8 @@ class MCPServer(ABC): is_running: bool = False + allowed_tools: list[str] | None = None + _client: ClientSession _read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] _write_stream: MemoryObjectSendStream[JSONRPCMessage] @@ -66,6 +68,7 @@ async def list_tools(self) -> list[ToolDefinition]: parameters_json_schema=tool.inputSchema, ) for tool in tools.tools + if self.allowed_tools is None or tool.name in self.allowed_tools ] async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: @@ -78,6 +81,9 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallTool Returns: The result of the tool call. """ + if self.allowed_tools is not None and tool_name not in self.allowed_tools: + raise ValueError(f'Tool {tool_name} is not in the list of allowed_tools') + return await self._client.call_tool(tool_name, arguments) async def __aenter__(self) -> Self: @@ -139,6 +145,9 @@ async def main(): If you want to inherit the environment variables from the parent process, use `env=os.environ`. """ + allowed_tools: list[str] | None = None + """A list of tool names that can be called on this mcp server""" + @asynccontextmanager async def client_streams( self, @@ -188,6 +197,9 @@ async def main(): For example for a server running locally, this might be `http://localhost:3001/sse`. """ + allowed_tools: list[str] | None = None + """A list of tool names that can be called on this mcp server""" + @asynccontextmanager async def client_streams( self, diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 8f83153fa..e7bc3926c 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -38,6 +38,28 @@ async def test_stdio_server(): assert result.content == snapshot([TextContent(type='text', text='32.0')]) +async def test_stdio_server_tool_filtering(): + server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], allowed_tools=['celsius_to_fahrenheit']) + async with server: + tools = await server.list_tools() + assert len(tools) == 1 + assert tools[0].name == 'celsius_to_fahrenheit' + assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') + + # Test calling the temperature conversion tool + result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 0}) + assert result.content == snapshot([TextContent(type='text', text='32.0')]) + + # Test setting allowed tools to empty list + server.allowed_tools = [] + tools = await server.list_tools() + assert len(tools) == 0 + + # Test calling the temperature conversion tool when its not allowed + with pytest.raises(ValueError): + result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 0}) + + def test_sse_server(): sse_server = MCPServerHTTP(url='http://localhost:8000/sse') assert sse_server.url == 'http://localhost:8000/sse' diff --git a/uv.lock b/uv.lock index 71b883e19..feb998fda 100644 --- a/uv.lock +++ b/uv.lock @@ -2964,7 +2964,7 @@ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=2.3" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.4.1" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, - { name = "openai", marker = "extra == 'openai'", specifier = ">=1.66.0" }, + { name = "openai", marker = "extra == 'openai'", specifier = ">=1.67.0" }, { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "prompt-toolkit", marker = "extra == 'cli'", specifier = ">=3" }, { name = "pydantic", specifier = ">=2.10" },