diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 7ccbdef3..71e55aa0 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -4,6 +4,7 @@ from mcp.server.fastmcp.prompts.base import Message, Prompt from mcp.server.fastmcp.utilities.logging import get_logger +from mcp.types import AnyFunction logger = get_logger(__name__) @@ -25,9 +26,18 @@ def list_prompts(self) -> list[Prompt]: def add_prompt( self, - prompt: Prompt, + fn: AnyFunction, + name: str | None = None, + description: str | None = None, ) -> Prompt: - """Add a prompt to the manager.""" + """Add a prompt to the manager. + + Args: + fn: Function to create a prompt from + name: Optional name for the prompt + description: Optional description of the prompt + """ + prompt = Prompt.from_function(fn, name=name, description=description) # Check for duplicates existing = self._prompts.get(prompt.name) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c31f29d4..68803e32 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -35,7 +35,7 @@ AuthSettings, ) from mcp.server.fastmcp.exceptions import ResourceError -from mcp.server.fastmcp.prompts import Prompt, PromptManager +from mcp.server.fastmcp.prompts import PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager from mcp.server.fastmcp.tools import ToolManager from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger @@ -138,8 +138,9 @@ def __init__( self, name: str | None = None, instructions: str | None = None, - auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - | None = None, + auth_server_provider: ( + OAuthAuthorizationServerProvider[Any, Any, Any] | None + ) = None, event_store: EventStore | None = None, **settings: Any, ): @@ -148,9 +149,11 @@ def __init__( self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, - lifespan=lifespan_wrapper(self, self.settings.lifespan) - if self.settings.lifespan - else default_lifespan, + lifespan=( + lifespan_wrapper(self, self.settings.lifespan) + if self.settings.lifespan + else default_lifespan + ), ) self._tool_manager = ToolManager( warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools @@ -481,13 +484,20 @@ def decorator(fn: AnyFunction) -> AnyFunction: return decorator - def add_prompt(self, prompt: Prompt) -> None: + def add_prompt( + self, + fn: AnyFunction, + name: str | None = None, + description: str | None = None, + ) -> None: """Add a prompt to the server. Args: - prompt: A Prompt instance to add + fn: Function to create a prompt from + name: Optional name for the prompt + description: Optional description of the prompt """ - self._prompt_manager.add_prompt(prompt) + self._prompt_manager.add_prompt(fn, name=name, description=description) def prompt( self, name: str | None = None, description: str | None = None @@ -533,8 +543,7 @@ async def analyze_file(path: str) -> list[Message]: ) def decorator(func: AnyFunction) -> AnyFunction: - prompt = Prompt.from_function(func, name=name, description=description) - self.add_prompt(prompt) + self.add_prompt(func, name=name, description=description) return func return decorator diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py index c64a4a56..59d0e553 100644 --- a/tests/server/fastmcp/prompts/test_manager.py +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -12,10 +12,10 @@ def fn() -> str: return "Hello, world!" manager = PromptManager() - prompt = Prompt.from_function(fn) - added = manager.add_prompt(prompt) - assert added == prompt - assert manager.get_prompt("fn") == prompt + added = manager.add_prompt(fn) + assert isinstance(added, Prompt) + assert added.name == "fn" + assert manager.get_prompt("fn") == added def test_add_duplicate_prompt(self, caplog): """Test adding the same prompt twice.""" @@ -24,9 +24,8 @@ def fn() -> str: return "Hello, world!" manager = PromptManager() - prompt = Prompt.from_function(fn) - first = manager.add_prompt(prompt) - second = manager.add_prompt(prompt) + first = manager.add_prompt(fn) + second = manager.add_prompt(fn) assert first == second assert "Prompt already exists" in caplog.text @@ -37,9 +36,8 @@ def fn() -> str: return "Hello, world!" manager = PromptManager(warn_on_duplicate_prompts=False) - prompt = Prompt.from_function(fn) - first = manager.add_prompt(prompt) - second = manager.add_prompt(prompt) + first = manager.add_prompt(fn) + second = manager.add_prompt(fn) assert first == second assert "Prompt already exists" not in caplog.text @@ -53,10 +51,8 @@ def fn2() -> str: return "Goodbye, world!" manager = PromptManager() - prompt1 = Prompt.from_function(fn1) - prompt2 = Prompt.from_function(fn2) - manager.add_prompt(prompt1) - manager.add_prompt(prompt2) + prompt1 = manager.add_prompt(fn1) + prompt2 = manager.add_prompt(fn2) prompts = manager.list_prompts() assert len(prompts) == 2 assert prompts == [prompt1, prompt2] @@ -69,8 +65,7 @@ def fn() -> str: return "Hello, world!" manager = PromptManager() - prompt = Prompt.from_function(fn) - manager.add_prompt(prompt) + manager.add_prompt(fn) messages = await manager.render_prompt("fn") assert messages == [ UserMessage(content=TextContent(type="text", text="Hello, world!")) @@ -84,8 +79,7 @@ def fn(name: str) -> str: return f"Hello, {name}!" manager = PromptManager() - prompt = Prompt.from_function(fn) - manager.add_prompt(prompt) + manager.add_prompt(fn) messages = await manager.render_prompt("fn", arguments={"name": "World"}) assert messages == [ UserMessage(content=TextContent(type="text", text="Hello, World!")) @@ -106,7 +100,6 @@ def fn(name: str) -> str: return f"Hello, {name}!" manager = PromptManager() - prompt = Prompt.from_function(fn) - manager.add_prompt(prompt) + manager.add_prompt(fn) with pytest.raises(ValueError, match="Missing required arguments"): await manager.render_prompt("fn")