diff --git a/README.md b/README.md index 26f43cfd9..b197ca888 100644 --- a/README.md +++ b/README.md @@ -30,10 +30,12 @@ - [Prompts](#prompts) - [Images](#images) - [Context](#context) + - [Authentication](#authentication) - [Running Your Server](#running-your-server) - [Development Mode](#development-mode) - [Claude Desktop Integration](#claude-desktop-integration) - [Direct Execution](#direct-execution) + - [Streamable HTTP Transport](#streamable-http-transport) - [Mounting to an Existing ASGI Server](#mounting-to-an-existing-asgi-server) - [Examples](#examples) - [Echo Server](#echo-server) @@ -243,6 +245,19 @@ async def fetch_weather(city: str) -> str: async with httpx.AsyncClient() as client: response = await client.get(f"https://api.weather.com/{city}") return response.text + + +tool = mcp._tool_manager.get_tool("fetch_weather") + + +async def disable_tool(): + # Disable the tool temporarily + await tool.disable(mcp.get_context()) + + +async def enable_tool(): + # Re-enable the tool when needed + await tool.enable(mcp.get_context()) ``` ### Prompts diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a430533b3..a77dc7a1e 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -154,7 +154,6 @@ async def __aexit__( for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) - @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 01fedcdc9..87d105dc7 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -9,7 +9,7 @@ from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata -from mcp.types import ToolAnnotations +from mcp.types import ServerNotification, ToolAnnotations, ToolListChangedNotification if TYPE_CHECKING: from mcp.server.fastmcp.server import Context @@ -35,6 +35,7 @@ class Tool(BaseModel): annotations: ToolAnnotations | None = Field( None, description="Optional annotations for the tool" ) + enabled: bool = Field(default=True, description="Whether the tool is enabled") @classmethod def from_function( @@ -100,6 +101,32 @@ async def run( except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e + async def enable( + self, context: Context[ServerSessionT, LifespanContextT] | None = None + ) -> None: + """Enable the tool and notify clients.""" + if not self.enabled: + self.enabled = True + if context and context.session: + notification = ToolListChangedNotification( + method="notifications/tools/list_changed" + ) + server_notification = ServerNotification.model_validate(notification) + await context.session.send_notification(server_notification) + + async def disable( + self, context: Context[ServerSessionT, LifespanContextT] | None = None + ) -> None: + """Disable the tool and notify clients.""" + if self.enabled: + self.enabled = False + if context and context.session: + notification = ToolListChangedNotification( + method="notifications/tools/list_changed" + ) + server_notification = ServerNotification.model_validate(notification) + await context.session.send_notification(server_notification) + def _is_async_callable(obj: Any) -> bool: while isinstance(obj, functools.partial): diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 6ec4fd151..0df5015cc 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -39,8 +39,8 @@ def get_tool(self, name: str) -> Tool | None: return self._tools.get(name) def list_tools(self) -> list[Tool]: - """List all registered tools.""" - return list(self._tools.values()) + """List all enabled registered tools.""" + return [tool for tool in self._tools.values() if tool.enabled] def add_tool( self, @@ -72,4 +72,7 @@ async def call_tool( if not tool: raise ToolError(f"Unknown tool: {name}") + if not tool.enabled: + raise ToolError(f"Tool is disabled: {name}") + return await tool.run(arguments, context=context) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 203a7172b..25b69f4db 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -453,3 +453,114 @@ def echo(message: str) -> str: assert tools[0].annotations is not None assert tools[0].annotations.title == "Echo Tool" assert tools[0].annotations.readOnlyHint is True + + +class TestToolEnableDisable: + """Test enabling and disabling tools.""" + + @pytest.mark.anyio + async def test_enable_disable_tool(self): + """Test enabling and disabling a tool.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + tool = manager.add_tool(add) + + # Tool should be enabled by default + assert tool.enabled is True + + # Disable the tool + await tool.disable() + assert tool.enabled is False + + # Enable the tool + await tool.enable() + assert tool.enabled is True + + @pytest.mark.anyio + async def test_enable_disable_no_change(self): + """Test enabling and disabling a tool when there's no state change.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + tool = manager.add_tool(add) + + # Enable an already enabled tool (should not change state) + await tool.enable() + assert tool.enabled is True + + # Disable the tool + await tool.disable() + assert tool.enabled is False + + # Disable an already disabled tool (should not change state) + await tool.disable() + assert tool.enabled is False + + @pytest.mark.anyio + async def test_list_tools_filters_disabled(self): + """Test that list_tools only returns enabled tools.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def subtract(a: int, b: int) -> int: + """Subtract two numbers.""" + return a - b + + manager = ToolManager() + tool1 = manager.add_tool(add) + tool2 = manager.add_tool(subtract) + + # Both tools should be listed initially + tools = manager.list_tools() + assert len(tools) == 2 + assert tool1 in tools + assert tool2 in tools + + # Disable one tool + await tool1.disable() + + # Only enabled tool should be listed + tools = manager.list_tools() + assert len(tools) == 1 + assert tool1 not in tools + assert tool2 in tools + + # Re-enable the tool + await tool1.enable() + + # Both tools should be listed again + tools = manager.list_tools() + assert len(tools) == 2 + assert tool1 in tools + assert tool2 in tools + + @pytest.mark.anyio + async def test_call_disabled_tool_raises_error(self): + """Test that calling a disabled tool raises an error.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + tool = manager.add_tool(add) + + # Tool should work normally when enabled + result = await manager.call_tool("add", {"a": 1, "b": 2}) + assert result == 3 + + # Disable the tool + await tool.disable() + + # Calling disabled tool should raise error + with pytest.raises(ToolError, match="Tool is disabled: add"): + await manager.call_tool("add", {"a": 1, "b": 2})