Skip to content

Feature/add enable disable methods tools #728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
- [Prompts](#prompts)
- [Images](#images)
- [Context](#context)
- [Authentication](#authentication)
- [Running Your Server](#running-your-server)
- [Development Mode](#development-mode)
- [Claude Desktop Integration](#claude-desktop-integration)
- [Direct Execution](#direct-execution)
- [Streamable HTTP Transport](#streamable-http-transport)
- [Mounting to an Existing ASGI Server](#mounting-to-an-existing-asgi-server)
- [Examples](#examples)
- [Echo Server](#echo-server)
Expand Down Expand Up @@ -243,6 +245,19 @@ async def fetch_weather(city: str) -> str:
async with httpx.AsyncClient() as client:
response = await client.get(f"https://api.weather.com/{city}")
return response.text


tool = mcp._tool_manager.get_tool("fetch_weather")


async def disable_tool():
# Disable the tool temporarily
await tool.disable(mcp.get_context())


async def enable_tool():
# Re-enable the tool when needed
await tool.enable(mcp.get_context())
```

### Prompts
Expand Down
1 change: 0 additions & 1 deletion src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ async def __aexit__(
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)


@property
def sessions(self) -> list[mcp.ClientSession]:
"""Returns the list of sessions being managed."""
Expand Down
29 changes: 28 additions & 1 deletion src/mcp/server/fastmcp/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.types import ToolAnnotations
from mcp.types import ServerNotification, ToolAnnotations, ToolListChangedNotification

if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
Expand All @@ -35,6 +35,7 @@ class Tool(BaseModel):
annotations: ToolAnnotations | None = Field(
None, description="Optional annotations for the tool"
)
enabled: bool = Field(default=True, description="Whether the tool is enabled")

@classmethod
def from_function(
Expand Down Expand Up @@ -100,6 +101,32 @@ async def run(
except Exception as e:
raise ToolError(f"Error executing tool {self.name}: {e}") from e

async def enable(
self, context: Context[ServerSessionT, LifespanContextT] | None = None
) -> None:
"""Enable the tool and notify clients."""
if not self.enabled:
self.enabled = True
if context and context.session:
notification = ToolListChangedNotification(
method="notifications/tools/list_changed"
)
server_notification = ServerNotification.model_validate(notification)
await context.session.send_notification(server_notification)

async def disable(
self, context: Context[ServerSessionT, LifespanContextT] | None = None
) -> None:
"""Disable the tool and notify clients."""
if self.enabled:
self.enabled = False
if context and context.session:
notification = ToolListChangedNotification(
method="notifications/tools/list_changed"
)
server_notification = ServerNotification.model_validate(notification)
await context.session.send_notification(server_notification)


def _is_async_callable(obj: Any) -> bool:
while isinstance(obj, functools.partial):
Expand Down
7 changes: 5 additions & 2 deletions src/mcp/server/fastmcp/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def get_tool(self, name: str) -> Tool | None:
return self._tools.get(name)

def list_tools(self) -> list[Tool]:
"""List all registered tools."""
return list(self._tools.values())
"""List all enabled registered tools."""
return [tool for tool in self._tools.values() if tool.enabled]

def add_tool(
self,
Expand Down Expand Up @@ -72,4 +72,7 @@ async def call_tool(
if not tool:
raise ToolError(f"Unknown tool: {name}")

if not tool.enabled:
raise ToolError(f"Tool is disabled: {name}")

return await tool.run(arguments, context=context)
111 changes: 111 additions & 0 deletions tests/server/fastmcp/test_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,114 @@ def echo(message: str) -> str:
assert tools[0].annotations is not None
assert tools[0].annotations.title == "Echo Tool"
assert tools[0].annotations.readOnlyHint is True


class TestToolEnableDisable:
"""Test enabling and disabling tools."""

@pytest.mark.anyio
async def test_enable_disable_tool(self):
"""Test enabling and disabling a tool."""

def add(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

manager = ToolManager()
tool = manager.add_tool(add)

# Tool should be enabled by default
assert tool.enabled is True

# Disable the tool
await tool.disable()
assert tool.enabled is False

# Enable the tool
await tool.enable()
assert tool.enabled is True

@pytest.mark.anyio
async def test_enable_disable_no_change(self):
"""Test enabling and disabling a tool when there's no state change."""

def add(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

manager = ToolManager()
tool = manager.add_tool(add)

# Enable an already enabled tool (should not change state)
await tool.enable()
assert tool.enabled is True

# Disable the tool
await tool.disable()
assert tool.enabled is False

# Disable an already disabled tool (should not change state)
await tool.disable()
assert tool.enabled is False

@pytest.mark.anyio
async def test_list_tools_filters_disabled(self):
"""Test that list_tools only returns enabled tools."""

def add(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

def subtract(a: int, b: int) -> int:
"""Subtract two numbers."""
return a - b

manager = ToolManager()
tool1 = manager.add_tool(add)
tool2 = manager.add_tool(subtract)

# Both tools should be listed initially
tools = manager.list_tools()
assert len(tools) == 2
assert tool1 in tools
assert tool2 in tools

# Disable one tool
await tool1.disable()

# Only enabled tool should be listed
tools = manager.list_tools()
assert len(tools) == 1
assert tool1 not in tools
assert tool2 in tools

# Re-enable the tool
await tool1.enable()

# Both tools should be listed again
tools = manager.list_tools()
assert len(tools) == 2
assert tool1 in tools
assert tool2 in tools

@pytest.mark.anyio
async def test_call_disabled_tool_raises_error(self):
"""Test that calling a disabled tool raises an error."""

def add(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

manager = ToolManager()
tool = manager.add_tool(add)

# Tool should work normally when enabled
result = await manager.call_tool("add", {"a": 1, "b": 2})
assert result == 3

# Disable the tool
await tool.disable()

# Calling disabled tool should raise error
with pytest.raises(ToolError, match="Tool is disabled: add"):
await manager.call_tool("add", {"a": 1, "b": 2})