Skip to content

Move Prompt object instantiation from server to prompt manager #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/mcp/server/fastmcp/prompts/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from mcp.server.fastmcp.prompts.base import Message, Prompt
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.types import AnyFunction

logger = get_logger(__name__)

Expand All @@ -25,9 +26,18 @@ def list_prompts(self) -> list[Prompt]:

def add_prompt(
self,
prompt: Prompt,
fn: AnyFunction,
name: str | None = None,
description: str | None = None,
) -> Prompt:
"""Add a prompt to the manager."""
"""Add a prompt to the manager.

Args:
fn: Function to create a prompt from
name: Optional name for the prompt
description: Optional description of the prompt
"""
prompt = Prompt.from_function(fn, name=name, description=description)

# Check for duplicates
existing = self._prompts.get(prompt.name)
Expand Down
31 changes: 20 additions & 11 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
AuthSettings,
)
from mcp.server.fastmcp.exceptions import ResourceError
from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.prompts import PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
from mcp.server.fastmcp.tools import ToolManager
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
Expand Down Expand Up @@ -138,8 +138,9 @@ def __init__(
self,
name: str | None = None,
instructions: str | None = None,
auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any]
| None = None,
auth_server_provider: (
OAuthAuthorizationServerProvider[Any, Any, Any] | None
) = None,
event_store: EventStore | None = None,
**settings: Any,
):
Expand All @@ -148,9 +149,11 @@ def __init__(
self._mcp_server = MCPServer(
name=name or "FastMCP",
instructions=instructions,
lifespan=lifespan_wrapper(self, self.settings.lifespan)
if self.settings.lifespan
else default_lifespan,
lifespan=(
lifespan_wrapper(self, self.settings.lifespan)
if self.settings.lifespan
else default_lifespan
),
)
self._tool_manager = ToolManager(
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools
Expand Down Expand Up @@ -481,13 +484,20 @@ def decorator(fn: AnyFunction) -> AnyFunction:

return decorator

def add_prompt(self, prompt: Prompt) -> None:
def add_prompt(
self,
fn: AnyFunction,
name: str | None = None,
description: str | None = None,
) -> None:
"""Add a prompt to the server.

Args:
prompt: A Prompt instance to add
fn: Function to create a prompt from
name: Optional name for the prompt
description: Optional description of the prompt
"""
self._prompt_manager.add_prompt(prompt)
self._prompt_manager.add_prompt(fn, name=name, description=description)

def prompt(
self, name: str | None = None, description: str | None = None
Expand Down Expand Up @@ -533,8 +543,7 @@ async def analyze_file(path: str) -> list[Message]:
)

def decorator(func: AnyFunction) -> AnyFunction:
prompt = Prompt.from_function(func, name=name, description=description)
self.add_prompt(prompt)
self.add_prompt(func, name=name, description=description)
return func

return decorator
Expand Down
33 changes: 13 additions & 20 deletions tests/server/fastmcp/prompts/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def fn() -> str:
return "Hello, world!"

manager = PromptManager()
prompt = Prompt.from_function(fn)
added = manager.add_prompt(prompt)
assert added == prompt
assert manager.get_prompt("fn") == prompt
added = manager.add_prompt(fn)
assert isinstance(added, Prompt)
assert added.name == "fn"
assert manager.get_prompt("fn") == added

def test_add_duplicate_prompt(self, caplog):
"""Test adding the same prompt twice."""
Expand All @@ -24,9 +24,8 @@ def fn() -> str:
return "Hello, world!"

manager = PromptManager()
prompt = Prompt.from_function(fn)
first = manager.add_prompt(prompt)
second = manager.add_prompt(prompt)
first = manager.add_prompt(fn)
second = manager.add_prompt(fn)
assert first == second
assert "Prompt already exists" in caplog.text

Expand All @@ -37,9 +36,8 @@ def fn() -> str:
return "Hello, world!"

manager = PromptManager(warn_on_duplicate_prompts=False)
prompt = Prompt.from_function(fn)
first = manager.add_prompt(prompt)
second = manager.add_prompt(prompt)
first = manager.add_prompt(fn)
second = manager.add_prompt(fn)
assert first == second
assert "Prompt already exists" not in caplog.text

Expand All @@ -53,10 +51,8 @@ def fn2() -> str:
return "Goodbye, world!"

manager = PromptManager()
prompt1 = Prompt.from_function(fn1)
prompt2 = Prompt.from_function(fn2)
manager.add_prompt(prompt1)
manager.add_prompt(prompt2)
prompt1 = manager.add_prompt(fn1)
prompt2 = manager.add_prompt(fn2)
prompts = manager.list_prompts()
assert len(prompts) == 2
assert prompts == [prompt1, prompt2]
Expand All @@ -69,8 +65,7 @@ def fn() -> str:
return "Hello, world!"

manager = PromptManager()
prompt = Prompt.from_function(fn)
manager.add_prompt(prompt)
manager.add_prompt(fn)
messages = await manager.render_prompt("fn")
assert messages == [
UserMessage(content=TextContent(type="text", text="Hello, world!"))
Expand All @@ -84,8 +79,7 @@ def fn(name: str) -> str:
return f"Hello, {name}!"

manager = PromptManager()
prompt = Prompt.from_function(fn)
manager.add_prompt(prompt)
manager.add_prompt(fn)
messages = await manager.render_prompt("fn", arguments={"name": "World"})
assert messages == [
UserMessage(content=TextContent(type="text", text="Hello, World!"))
Expand All @@ -106,7 +100,6 @@ def fn(name: str) -> str:
return f"Hello, {name}!"

manager = PromptManager()
prompt = Prompt.from_function(fn)
manager.add_prompt(prompt)
manager.add_prompt(fn)
with pytest.raises(ValueError, match="Missing required arguments"):
await manager.render_prompt("fn")
Loading