Skip to content

add tool name filtering to mcp server implementation #1220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class MCPServer(ABC):

is_running: bool = False

allowed_tools: list[str] | None = None

_client: ClientSession
_read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
_write_stream: MemoryObjectSendStream[JSONRPCMessage]
Expand Down Expand Up @@ -66,6 +68,7 @@ async def list_tools(self) -> list[ToolDefinition]:
parameters_json_schema=tool.inputSchema,
)
for tool in tools.tools
if self.allowed_tools is None or tool.name in self.allowed_tools
]

async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
Expand All @@ -78,6 +81,9 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallTool
Returns:
The result of the tool call.
"""
if self.allowed_tools is not None and tool_name not in self.allowed_tools:
raise ValueError(f'Tool {tool_name} is not in the list of allowed_tools')

return await self._client.call_tool(tool_name, arguments)

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -139,6 +145,9 @@ async def main():
If you want to inherit the environment variables from the parent process, use `env=os.environ`.
"""

allowed_tools: list[str] | None = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to repeat this attribute definition on the subclasses

"""A list of tool names that can be called on this mcp server"""

@asynccontextmanager
async def client_streams(
self,
Expand Down Expand Up @@ -188,6 +197,9 @@ async def main():
For example for a server running locally, this might be `http://localhost:3001/sse`.
"""

allowed_tools: list[str] | None = None
"""A list of tool names that can be called on this mcp server"""

@asynccontextmanager
async def client_streams(
self,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ async def test_stdio_server():
assert result.content == snapshot([TextContent(type='text', text='32.0')])


async def test_stdio_server_tool_filtering():
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], allowed_tools=['celsius_to_fahrenheit'])
async with server:
tools = await server.list_tools()
assert len(tools) == 1
assert tools[0].name == 'celsius_to_fahrenheit'
assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')

# Test calling the temperature conversion tool
result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 0})
assert result.content == snapshot([TextContent(type='text', text='32.0')])

# Test setting allowed tools to empty list
server.allowed_tools = []
tools = await server.list_tools()
assert len(tools) == 0

# Test calling the temperature conversion tool when its not allowed
with pytest.raises(ValueError):
result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 0})


def test_sse_server():
sse_server = MCPServerHTTP(url='http://localhost:8000/sse')
assert sse_server.url == 'http://localhost:8000/sse'
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.