diff --git a/src/fastmcp/apps/app.py b/src/fastmcp/apps/app.py index b800ff5a34..db5b6102c9 100644 --- a/src/fastmcp/apps/app.py +++ b/src/fastmcp/apps/app.py @@ -27,7 +27,6 @@ def save_contact(name: str, email: str) -> str: from __future__ import annotations -import inspect from collections.abc import AsyncIterator, Callable, Sequence from contextlib import asynccontextmanager, suppress from typing import Any, Literal, TypeVar, overload @@ -38,6 +37,7 @@ def save_contact(name: str, email: str) -> str: from fastmcp.server.providers.base import Provider from fastmcp.server.providers.local_provider import LocalProvider from fastmcp.tools.base import Tool +from fastmcp.utilities.callable_utils import is_callable_object from fastmcp.utilities.logging import get_logger logger = get_logger(__name__) @@ -100,7 +100,7 @@ def _dispatch_decorator( decorator_name: str, ) -> Any: """Shared dispatch logic for @app.tool() and @app.ui() calling patterns.""" - if inspect.isroutine(name_or_fn): + if is_callable_object(name_or_fn): return register(name_or_fn, name) if isinstance(name_or_fn, str): diff --git a/src/fastmcp/decorators.py b/src/fastmcp/decorators.py index 75dff25ace..913ad1203a 100644 --- a/src/fastmcp/decorators.py +++ b/src/fastmcp/decorators.py @@ -39,3 +39,13 @@ def get_fastmcp_meta(fn: Any) -> Any | None: except ValueError: pass return None + + +def set_fastmcp_meta(fn: Any, metadata: Any) -> None: + """Attach FastMCP metadata to a function, handling bound methods. + + For bound methods and staticmethods, the metadata is attached to the + underlying ``__func__`` so that ``get_fastmcp_meta`` can find it. + """ + target = fn.__func__ if hasattr(fn, "__func__") else fn + target.__fastmcp__ = metadata diff --git a/src/fastmcp/prompts/function_prompt.py b/src/fastmcp/prompts/function_prompt.py index 63545cb90d..8bc20136b5 100644 --- a/src/fastmcp/prompts/function_prompt.py +++ b/src/fastmcp/prompts/function_prompt.py @@ -2,7 +2,6 @@ from __future__ import annotations -import functools import inspect import json import warnings @@ -23,7 +22,7 @@ from pydantic.json_schema import SkipJsonSchema import fastmcp -from fastmcp.decorators import resolve_task_config +from fastmcp.decorators import resolve_task_config, set_fastmcp_meta from fastmcp.exceptions import FastMCPDeprecationWarning, PromptError from fastmcp.prompts.base import Prompt, PromptArgument, PromptResult from fastmcp.server.auth.authorization import AuthCheck @@ -36,6 +35,11 @@ call_sync_fn_in_threadpool, is_coroutine_function, ) +from fastmcp.utilities.callable_utils import ( + get_callable_name, + is_callable_object, + prepare_callable, +) from fastmcp.utilities.json_schema import compress_schema from fastmcp.utilities.logging import get_logger from fastmcp.utilities.types import get_cached_typeadapter @@ -137,9 +141,7 @@ def from_function( auth=auth, ) - func_name = ( - metadata.name or getattr(fn, "__name__", None) or fn.__class__.__name__ - ) + func_name = metadata.name or get_callable_name(fn) if func_name == "": raise ValueError("You must provide a name for lambda functions") @@ -158,22 +160,10 @@ def from_function( else inspect.getdoc(fn) ) - # Normalize task to TaskConfig and validate - task_value = metadata.task - if task_value is None: - task_config = TaskConfig(mode="forbidden") - elif isinstance(task_value, bool): - task_config = TaskConfig.from_bool(task_value) - else: - task_config = task_value + task_config = TaskConfig.normalize(metadata.task) task_config.validate_function(fn, func_name) - # if the fn is a callable class, we need to get the __call__ method from here out - if not inspect.isroutine(fn) and not isinstance(fn, functools.partial): - fn = fn.__call__ - # if the fn is a staticmethod, we need to work with the underlying function - if isinstance(fn, staticmethod): - fn = fn.__func__ + fn = prepare_callable(fn) # Transform Context type annotations to Depends() for unified DI fn = transform_context_annotations(fn) @@ -452,8 +442,7 @@ def attach_metadata(fn: F, prompt_name: str | None) -> F: task=task, auth=auth, ) - target = fn.__func__ if hasattr(fn, "__func__") else fn - target.__fastmcp__ = metadata + set_fastmcp_meta(fn, metadata) return fn def decorator(fn: F, prompt_name: str | None) -> F: @@ -467,7 +456,7 @@ def decorator(fn: F, prompt_name: str | None) -> F: return create_prompt(fn, prompt_name) # type: ignore[return-value] # ty:ignore[invalid-return-type] return attach_metadata(fn, prompt_name) - if inspect.isroutine(name_or_fn): + if is_callable_object(name_or_fn): return decorator(name_or_fn, name) elif isinstance(name_or_fn, str): if name is not None: diff --git a/src/fastmcp/resources/function_resource.py b/src/fastmcp/resources/function_resource.py index df542fccef..80b8846dcf 100644 --- a/src/fastmcp/resources/function_resource.py +++ b/src/fastmcp/resources/function_resource.py @@ -2,7 +2,6 @@ from __future__ import annotations -import functools import inspect import warnings from collections.abc import Callable @@ -14,7 +13,7 @@ from pydantic.json_schema import SkipJsonSchema import fastmcp -from fastmcp.decorators import resolve_task_config +from fastmcp.decorators import resolve_task_config, set_fastmcp_meta from fastmcp.exceptions import FastMCPDeprecationWarning from fastmcp.resources.base import Resource, ResourceResult from fastmcp.server.auth.authorization import AuthCheck @@ -27,6 +26,11 @@ call_sync_fn_in_threadpool, is_coroutine_function, ) +from fastmcp.utilities.callable_utils import ( + get_callable_name, + is_callable_object, + prepare_callable, +) from fastmcp.utilities.mime import resolve_ui_mime_type if TYPE_CHECKING: @@ -159,27 +163,12 @@ def from_function( uri_obj = AnyUrl(metadata.uri) - # Get function name - use class name for callable objects - func_name = ( - metadata.name or getattr(fn, "__name__", None) or fn.__class__.__name__ - ) + func_name = metadata.name or get_callable_name(fn) - # Normalize task to TaskConfig and validate - task_value = metadata.task - if task_value is None: - task_config = TaskConfig(mode="forbidden") - elif isinstance(task_value, bool): - task_config = TaskConfig.from_bool(task_value) - else: - task_config = task_value + task_config = TaskConfig.normalize(metadata.task) task_config.validate_function(fn, func_name) - # if the fn is a callable class, we need to get the __call__ method from here out - if not inspect.isroutine(fn) and not isinstance(fn, functools.partial): - fn = fn.__call__ - # if the fn is a staticmethod, we need to work with the underlying function - if isinstance(fn, staticmethod): - fn = fn.__func__ + fn = prepare_callable(fn) # Transform Context type annotations to Depends() for unified DI fn = transform_context_annotations(fn) @@ -259,7 +248,7 @@ def resource( if isinstance(annotations, dict): annotations = Annotations(**annotations) - if inspect.isroutine(uri): + if is_callable_object(uri): raise TypeError( "The @resource decorator requires a URI. " "Use @resource('uri') instead of @resource" @@ -325,8 +314,7 @@ def attach_metadata(fn: F) -> F: task=task, auth=auth, ) - target = fn.__func__ if hasattr(fn, "__func__") else fn - target.__fastmcp__ = metadata + set_fastmcp_meta(fn, metadata) return fn def decorator(fn: F) -> F: diff --git a/src/fastmcp/resources/template.py b/src/fastmcp/resources/template.py index 265e88f913..7df6ff09ea 100644 --- a/src/fastmcp/resources/template.py +++ b/src/fastmcp/resources/template.py @@ -2,7 +2,6 @@ from __future__ import annotations -import functools import inspect import re from collections.abc import Callable @@ -30,6 +29,7 @@ without_injected_parameters, ) from fastmcp.server.tasks.config import TaskConfig, TaskMeta +from fastmcp.utilities.callable_utils import get_callable_name, prepare_callable from fastmcp.utilities.components import FastMCPComponent from fastmcp.utilities.json_schema import compress_schema from fastmcp.utilities.mime import resolve_ui_mime_type @@ -488,7 +488,7 @@ def from_function( ) -> FunctionResourceTemplate: """Create a template from a function.""" - func_name = name or getattr(fn, "__name__", None) or fn.__class__.__name__ + func_name = name or get_callable_name(fn) if func_name == "": raise ValueError("You must provide a name for lambda functions") @@ -555,21 +555,10 @@ def from_function( description = description if description is not None else inspect.getdoc(fn) - # Normalize task to TaskConfig and validate - if task is None: - task_config = TaskConfig(mode="forbidden") - elif isinstance(task, bool): - task_config = TaskConfig.from_bool(task) - else: - task_config = task + task_config = TaskConfig.normalize(task) task_config.validate_function(fn, func_name) - # if the fn is a callable class, we need to get the __call__ method from here out - if not inspect.isroutine(fn) and not isinstance(fn, functools.partial): - fn = fn.__call__ - # if the fn is a staticmethod, we need to work with the underlying function - if isinstance(fn, staticmethod): - fn = fn.__func__ + fn = prepare_callable(fn) # Transform Context type annotations to Depends() for unified DI fn = transform_context_annotations(fn) diff --git a/src/fastmcp/server/providers/local_provider/decorators/prompts.py b/src/fastmcp/server/providers/local_provider/decorators/prompts.py index 583aed5633..70cafcae35 100644 --- a/src/fastmcp/server/providers/local_provider/decorators/prompts.py +++ b/src/fastmcp/server/providers/local_provider/decorators/prompts.py @@ -15,10 +15,12 @@ from mcp.types import AnyFunction import fastmcp +from fastmcp.decorators import set_fastmcp_meta from fastmcp.prompts.base import Prompt from fastmcp.prompts.function_prompt import FunctionPrompt from fastmcp.server.auth.authorization import AuthCheck from fastmcp.server.tasks.config import TaskConfig +from fastmcp.utilities.callable_utils import is_callable_object if TYPE_CHECKING: from fastmcp.server.providers.local_provider import LocalProvider @@ -223,12 +225,11 @@ def decorate_and_register( auth=auth, enabled=enabled, ) - target = fn.__func__ if hasattr(fn, "__func__") else fn - target.__fastmcp__ = metadata # type: ignore[attr-defined] # ty:ignore[unresolved-attribute] + set_fastmcp_meta(fn, metadata) self.add_prompt(fn) return fn - if inspect.isroutine(name_or_fn): + if is_callable_object(name_or_fn): return decorate_and_register(name_or_fn, name) elif isinstance(name_or_fn, str): diff --git a/src/fastmcp/server/providers/local_provider/decorators/resources.py b/src/fastmcp/server/providers/local_provider/decorators/resources.py index 41043a461f..c302d5d88d 100644 --- a/src/fastmcp/server/providers/local_provider/decorators/resources.py +++ b/src/fastmcp/server/providers/local_provider/decorators/resources.py @@ -14,11 +14,13 @@ from mcp.types import Annotations, AnyFunction import fastmcp +from fastmcp.decorators import set_fastmcp_meta from fastmcp.resources.base import Resource from fastmcp.resources.function_resource import resource as standalone_resource from fastmcp.resources.template import ResourceTemplate from fastmcp.server.auth.authorization import AuthCheck from fastmcp.server.tasks.config import TaskConfig +from fastmcp.utilities.callable_utils import is_callable_object if TYPE_CHECKING: from fastmcp.server.providers.local_provider import LocalProvider @@ -159,7 +161,7 @@ def get_weather(city: str) -> str: if isinstance(annotations, dict): annotations = Annotations(**annotations) - if inspect.isroutine(uri): + if is_callable_object(uri): raise TypeError( "The @resource decorator was used incorrectly. " "It requires a URI as the first argument. " @@ -234,8 +236,7 @@ def decorator(fn: AnyFunction) -> Any: auth=auth, enabled=enabled, ) - target = fn.__func__ if hasattr(fn, "__func__") else fn - target.__fastmcp__ = metadata # type: ignore[attr-defined] # ty:ignore[unresolved-attribute] + set_fastmcp_meta(fn, metadata) self.add_resource(fn) return fn diff --git a/src/fastmcp/server/providers/local_provider/decorators/tools.py b/src/fastmcp/server/providers/local_provider/decorators/tools.py index c3dfd2fdd8..55b37da60d 100644 --- a/src/fastmcp/server/providers/local_provider/decorators/tools.py +++ b/src/fastmcp/server/providers/local_provider/decorators/tools.py @@ -27,11 +27,13 @@ from mcp.types import AnyFunction, ToolAnnotations import fastmcp +from fastmcp.decorators import set_fastmcp_meta from fastmcp.exceptions import FastMCPDeprecationWarning from fastmcp.server.auth.authorization import AuthCheck from fastmcp.server.tasks.config import TaskConfig from fastmcp.tools.base import Tool from fastmcp.tools.function_tool import FunctionTool +from fastmcp.utilities.callable_utils import is_callable_object from fastmcp.utilities.types import NotSet, NotSetT try: @@ -396,12 +398,11 @@ def decorate_and_register( auth=auth, enabled=enabled, ) - target = fn.__func__ if hasattr(fn, "__func__") else fn - target.__fastmcp__ = metadata # type: ignore[attr-defined] # ty:ignore[unresolved-attribute] + set_fastmcp_meta(fn, metadata) tool_obj = self.add_tool(fn) return fn - if inspect.isroutine(name_or_fn): + if is_callable_object(name_or_fn): return decorate_and_register(name_or_fn, name) elif isinstance(name_or_fn, str): diff --git a/src/fastmcp/server/tasks/config.py b/src/fastmcp/server/tasks/config.py index 1d5befa2a0..d0ddd421a8 100644 --- a/src/fastmcp/server/tasks/config.py +++ b/src/fastmcp/server/tasks/config.py @@ -6,14 +6,13 @@ from __future__ import annotations -import functools -import inspect from collections.abc import Callable from dataclasses import dataclass from datetime import timedelta from typing import Any, Literal from fastmcp.utilities.async_utils import is_coroutine_function +from fastmcp.utilities.callable_utils import prepare_callable # Task execution modes per SEP-1686 / MCP ToolExecution.taskSupport TaskMode = Literal["forbidden", "optional", "required"] @@ -90,6 +89,23 @@ def from_bool(cls, value: bool) -> TaskConfig: """ return cls(mode="optional" if value else "forbidden") + @classmethod + def normalize(cls, task: bool | TaskConfig | None) -> TaskConfig: + """Convert a task parameter to a TaskConfig. + + Args: + task: True/False for simple enable/disable, TaskConfig for full + control, or None for the default (forbidden). + + Returns: + A TaskConfig instance. + """ + if task is None: + return cls(mode="forbidden") + if isinstance(task, bool): + return cls.from_bool(task) + return task + def supports_tasks(self) -> bool: """Check if this component supports task execution. @@ -126,15 +142,7 @@ def validate_function(self, fn: Callable[..., Any], name: str) -> None: require_docket(f"`task=True` on function '{name}'") # Unwrap callable classes and staticmethods - fn_to_check = fn - if ( - not inspect.isroutine(fn) - and not isinstance(fn, functools.partial) - and callable(fn) - ): - fn_to_check = fn.__call__ - if isinstance(fn_to_check, staticmethod): - fn_to_check = fn_to_check.__func__ + fn_to_check = prepare_callable(fn) if not is_coroutine_function(fn_to_check): raise ValueError( diff --git a/src/fastmcp/tools/function_parsing.py b/src/fastmcp/tools/function_parsing.py index 21d10cb795..126fed47e8 100644 --- a/src/fastmcp/tools/function_parsing.py +++ b/src/fastmcp/tools/function_parsing.py @@ -2,7 +2,6 @@ from __future__ import annotations -import functools import inspect import types from collections.abc import Callable @@ -18,6 +17,7 @@ without_injected_parameters, ) from fastmcp.tools.base import ToolResult +from fastmcp.utilities.callable_utils import get_callable_name, prepare_callable from fastmcp.utilities.json_schema import compress_schema from fastmcp.utilities.logging import get_logger from fastmcp.utilities.types import ( @@ -165,15 +165,10 @@ def from_function( ) # collect name and doc before we potentially modify the function - fn_name = getattr(fn, "__name__", None) or fn.__class__.__name__ + fn_name = get_callable_name(fn) fn_doc = inspect.getdoc(fn) - # if the fn is a callable class, we need to get the __call__ method from here out - if not inspect.isroutine(fn) and not isinstance(fn, functools.partial): - fn = fn.__call__ - # if the fn is a staticmethod, we need to work with the underlying function - if isinstance(fn, staticmethod): - fn = fn.__func__ + fn = prepare_callable(fn) # Transform Context type annotations to Depends() for unified DI fn = transform_context_annotations(fn) diff --git a/src/fastmcp/tools/function_tool.py b/src/fastmcp/tools/function_tool.py index 94e7a0b3ea..fdc137bef1 100644 --- a/src/fastmcp/tools/function_tool.py +++ b/src/fastmcp/tools/function_tool.py @@ -24,7 +24,7 @@ from pydantic.json_schema import SkipJsonSchema import fastmcp -from fastmcp.decorators import resolve_task_config +from fastmcp.decorators import resolve_task_config, set_fastmcp_meta from fastmcp.exceptions import FastMCPDeprecationWarning from fastmcp.server.auth.authorization import AuthCheck from fastmcp.server.dependencies import without_injected_parameters @@ -39,6 +39,7 @@ call_sync_fn_in_threadpool, is_coroutine_function, ) +from fastmcp.utilities.callable_utils import is_callable_object from fastmcp.utilities.logging import get_logger from fastmcp.utilities.types import ( NotSet, @@ -193,14 +194,7 @@ def from_function( if func_name == "": raise ValueError("You must provide a name for lambda functions") - # Normalize task to TaskConfig - task_value = metadata.task - if task_value is None: - task_config = TaskConfig(mode="forbidden") - elif isinstance(task_value, bool): - task_config = TaskConfig.from_bool(task_value) - else: - task_config = task_value + task_config = TaskConfig.normalize(metadata.task) task_config.validate_function(fn, func_name) # Handle output_schema @@ -447,8 +441,7 @@ def attach_metadata(fn: F, tool_name: str | None) -> F: timeout=timeout, auth=auth, ) - target = fn.__func__ if hasattr(fn, "__func__") else fn - target.__fastmcp__ = metadata + set_fastmcp_meta(fn, metadata) return fn def decorator(fn: F, tool_name: str | None) -> F: @@ -462,7 +455,7 @@ def decorator(fn: F, tool_name: str | None) -> F: return create_tool(fn, tool_name) # type: ignore[return-value] # ty:ignore[invalid-return-type] return attach_metadata(fn, tool_name) - if inspect.isroutine(name_or_fn): + if is_callable_object(name_or_fn): return decorator(name_or_fn, name) elif isinstance(name_or_fn, str): if name is not None: diff --git a/src/fastmcp/utilities/callable_utils.py b/src/fastmcp/utilities/callable_utils.py new file mode 100644 index 0000000000..c572dc6ee2 --- /dev/null +++ b/src/fastmcp/utilities/callable_utils.py @@ -0,0 +1,79 @@ +"""Utilities for handling callables, including functools.partial objects. + +Provides centralized helpers for the shared steps in the tool/prompt/resource +``from_function`` pipelines, avoiding duplicated ``isinstance`` checks, name +extraction logic, and callable unwrapping across the codebase. +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Callable +from typing import Any, TypeGuard + + +def is_callable_object(obj: Any) -> TypeGuard[Callable[..., Any]]: + """Check if an object is a callable suitable for use as a tool, resource, or prompt. + + Returns True for functions, methods, builtins, and functools.partial objects. + This is a broader check than ``inspect.isroutine`` which returns False for + functools.partial. + """ + return inspect.isroutine(obj) or isinstance(obj, functools.partial) + + +def get_callable_name(fn: Any) -> str: + """Extract a human-readable name from a callable. + + Handles functions, callable classes, and functools.partial: + + - Regular functions: returns ``fn.__name__`` (e.g. ``"add"``) + - Callable classes: returns the class name (e.g. ``"MyTool"``) + - Partial with ``update_wrapper``: returns the wrapped name (e.g. ``"add"``) + - Partial without ``update_wrapper``: returns the underlying function name + (e.g. ``"add"`` instead of ``"partial"``) + """ + name = getattr(fn, "__name__", None) + if name is not None: + return name + # functools.partial without update_wrapper — use the underlying function's name + if isinstance(fn, functools.partial): + return getattr(fn.func, "__name__", None) or fn.__class__.__name__ + return fn.__class__.__name__ + + +def prepare_callable(fn: Callable[..., Any]) -> Callable[..., Any]: + """Prepare a callable for introspection by ``inspect.signature()`` and Pydantic. + + This handles three cases that would otherwise require special-casing in every + ``from_function`` method: + + 1. **functools.partial with __wrapped__**: ``functools.update_wrapper`` sets + ``__wrapped__`` which causes ``inspect.signature()`` and Pydantic to follow + it back to the original function, ignoring the partial's bound arguments. + We strip ``__wrapped__`` by reconstructing the partial. + + 2. **Callable classes**: Non-routine callables (classes with ``__call__``) need + to be unwrapped to their ``__call__`` method so ``inspect.signature()`` sees + the right parameters. + + 3. **staticmethod**: Needs unwrapping to the underlying function. + + Call this AFTER extracting name/doc from the original callable, since this + may change what ``__name__`` and ``__doc__`` return. + """ + # Strip __wrapped__ from partials so Pydantic sees the partial's own + # signature with bound args removed, not the original function's signature. + if isinstance(fn, functools.partial) and hasattr(fn, "__wrapped__"): + fn = functools.partial(fn.func, *fn.args, **fn.keywords) + + # Callable classes (not routines, not partials) → unwrap to __call__ + if not inspect.isroutine(fn) and not isinstance(fn, functools.partial): + fn = fn.__call__ + + # staticmethod → unwrap to underlying function + if isinstance(fn, staticmethod): + fn = fn.__func__ + + return fn diff --git a/tests/tools/tool/test_partial.py b/tests/tools/tool/test_partial.py new file mode 100644 index 0000000000..a86bed4665 --- /dev/null +++ b/tests/tools/tool/test_partial.py @@ -0,0 +1,129 @@ +"""Tests for functools.partial support as tools, prompts, and resources. + +See https://github.com/PrefectHQ/fastmcp/issues/3266 +""" + +import functools + +from mcp.types import TextContent + +from fastmcp import Client, FastMCP +from fastmcp.tools.function_tool import FunctionTool as Tool + + +class TestPartialTool: + """Test tools created from functools.partial objects.""" + + async def test_partial_sync(self): + def add(x: int, y: int) -> int: + return x + y + + partial_add = functools.partial(add, y=10) + functools.update_wrapper(partial_add, add) + + tool = Tool.from_function(partial_add) + result = await tool.run({"x": 5}) + assert result.content == [TextContent(type="text", text="15")] + + async def test_partial_async(self): + async def multiply(x: int, factor: int) -> int: + return x * factor + + partial_mul = functools.partial(multiply, factor=3) + functools.update_wrapper(partial_mul, multiply) + + tool = Tool.from_function(partial_mul) + result = await tool.run({"x": 7}) + assert result.content == [TextContent(type="text", text="21")] + + async def test_partial_preserves_name(self): + def greet(name: str, greeting: str = "Hello") -> str: + """Greet someone.""" + return f"{greeting}, {name}!" + + partial_greet = functools.partial(greet, greeting="Hi") + functools.update_wrapper(partial_greet, greet) + + tool = Tool.from_function(partial_greet) + assert tool.name == "greet" + assert tool.description == "Greet someone." + + async def test_partial_without_update_wrapper(self): + def add(x: int, y: int) -> int: + return x + y + + partial_add = functools.partial(add, y=10) + + tool = Tool.from_function(partial_add, name="add_ten") + result = await tool.run({"x": 5}) + assert result.content == [TextContent(type="text", text="15")] + + async def test_partial_with_add_tool(self): + mcp = FastMCP("test") + + def greet(name: str, greeting: str = "Hello") -> str: + return f"{greeting}, {name}!" + + partial_greet = functools.partial(greet, greeting="Hey") + functools.update_wrapper(partial_greet, greet) + + mcp.add_tool(partial_greet) + + result = await mcp.call_tool("greet", {"name": "World"}) + assert result.content == [TextContent(type="text", text="Hey, World!")] + + async def test_partial_with_server_tool_decorator(self): + mcp = FastMCP("test") + + def add(x: int, y: int) -> int: + return x + y + + partial_add = functools.partial(add, y=100) + functools.update_wrapper(partial_add, add) + + mcp.tool(partial_add) + + result = await mcp.call_tool("add", {"x": 5}) + assert result.content == [TextContent(type="text", text="105")] + + +class TestPartialPrompt: + """Test prompts created from functools.partial objects.""" + + async def test_partial_prompt_with_decorator(self): + """Partial can be registered via @mcp.prompt() decorator.""" + mcp = FastMCP("test") + + def greet_prompt(name: str, lang: str) -> str: + return f"Say hello to {name} in {lang}." + + partial_greet = functools.partial(greet_prompt, lang="French") + functools.update_wrapper(partial_greet, greet_prompt) + + mcp.prompt(partial_greet) + + async with Client(mcp) as client: + result = await client.get_prompt("greet_prompt", {"name": "Alice"}) + assert "Alice" in str(result.messages[0]) + assert "French" in str(result.messages[0]) + + +class TestPartialResource: + """Test resources created from functools.partial objects.""" + + async def test_partial_resource_with_decorator(self): + """Partial can be registered via @mcp.resource() decorator.""" + mcp = FastMCP("test") + + def get_data(key: str, fmt: str = "text") -> str: + return f"{key} in {fmt} format" + + partial_data = functools.partial(get_data, fmt="json") + functools.update_wrapper(partial_data, get_data) + + mcp.resource("data://{key}")(partial_data) + + async with Client(mcp) as client: + content = await client.read_resource("data://users") + assert "users" in str(content) + assert "json" in str(content) diff --git a/tests/utilities/test_callable_utils.py b/tests/utilities/test_callable_utils.py new file mode 100644 index 0000000000..89202bb57c --- /dev/null +++ b/tests/utilities/test_callable_utils.py @@ -0,0 +1,121 @@ +"""Tests for callable utility functions.""" + +import functools + +from fastmcp.utilities.callable_utils import ( + get_callable_name, + is_callable_object, + prepare_callable, +) + + +class TestIsCallableObject: + def test_function(self): + def fn(): + pass + + assert is_callable_object(fn) is True + + def test_async_function(self): + async def fn(): + pass + + assert is_callable_object(fn) is True + + def test_partial(self): + def fn(x, y): + return x + y + + assert is_callable_object(functools.partial(fn, y=1)) is True + + def test_callable_class(self): + class MyCallable: + def __call__(self): + pass + + assert is_callable_object(MyCallable()) is False + + def test_string(self): + assert is_callable_object("not a callable") is False + + def test_none(self): + assert is_callable_object(None) is False + + +class TestGetCallableName: + def test_function(self): + def my_function(): + pass + + assert get_callable_name(my_function) == "my_function" + + def test_lambda(self): + assert get_callable_name(lambda: None) == "" + + def test_partial_with_update_wrapper(self): + def add(x, y): + return x + y + + p = functools.partial(add, y=10) + functools.update_wrapper(p, add) + assert get_callable_name(p) == "add" + + def test_partial_without_update_wrapper(self): + def add(x, y): + return x + y + + p = functools.partial(add, y=10) + assert get_callable_name(p) == "add" + + def test_callable_class(self): + class MyTool: + def __call__(self): + pass + + assert get_callable_name(MyTool()) == "MyTool" + + +class TestPrepareCallable: + def test_regular_function_unchanged(self): + def fn(x): + return x + + assert prepare_callable(fn) is fn + + def test_strips_wrapped_from_partial(self): + def add(x, y): + return x + y + + p = functools.partial(add, y=10) + functools.update_wrapper(p, add) + assert hasattr(p, "__wrapped__") + + prepared = prepare_callable(p) + assert isinstance(prepared, functools.partial) + assert not hasattr(prepared, "__wrapped__") + assert prepared.keywords == {"y": 10} + + def test_partial_without_wrapper_unchanged(self): + def add(x, y): + return x + y + + p = functools.partial(add, y=10) + prepared = prepare_callable(p) + assert isinstance(prepared, functools.partial) + assert prepared.func is add + + def test_callable_class_unwrapped(self): + class MyCallable: + def __call__(self, x): + return x + + obj = MyCallable() + prepared = prepare_callable(obj) + assert prepared == obj.__call__ + + def test_staticmethod_unwrapped(self): + def fn(x): + return x + + sm = staticmethod(fn) + assert prepare_callable(sm) is fn