diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7bb8821f..8d38cd16 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -140,6 +140,12 @@ async def initialize(self) -> types.InitializeResult: ) ), types.InitializeResult, + # TODO should set a request_read_timeout_seconds as per + # guidance from BaseSession.send_request not obvious + # what subsequent process should be, refer the following + # specification for more details + # https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation + cancellable=False, ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: @@ -259,6 +265,7 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + cancellable: bool = True, ) -> types.CallToolResult: """Send a tools/call request.""" @@ -271,6 +278,7 @@ async def call_tool( ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, + cancellable=cancellable, ) async def list_prompts(self) -> types.ListPromptsResult: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 4b97b33d..ffe1bf0c 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,7 @@ async def main(): from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.session import RequestId, RequestResponder logger = logging.getLogger(__name__) @@ -427,7 +427,7 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[[str | int, float, float | None], Awaitable[None]], + func: Callable[[types.ProgressToken, float, float | None], Awaitable[None]], ): logger.debug("Registering handler for ProgressNotification") @@ -441,6 +441,20 @@ async def handler(req: types.ProgressNotification): return decorator + def cancel_notification(self): + def decorator( + func: Callable[[RequestId, str | None], Awaitable[None]], + ): + logger.debug("Registering handler for CancelledNotification") + + async def handler(req: types.CancelledNotification): + await func(req.params.requestId, req.params.reason) + + self.notification_handlers[types.CancelledNotification] = handler + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cce8b118..3e203ca3 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -14,7 +14,9 @@ from mcp.shared.exceptions import McpError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( + REQUEST_CANCELLED, CancelledNotification, + CancelledNotificationParams, ClientNotification, ClientRequest, ClientResult, @@ -33,6 +35,12 @@ SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) +SendNotificationInternalT = TypeVar( + "SendNotificationInternalT", + CancelledNotification, + ClientNotification, + ServerNotification, +) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar( @@ -214,12 +222,25 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, + cancellable: bool = True, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. + If cancellable is set to False then the request will wait + request_read_timeout_seconds to complete and ignore any attempt to + cancel via the anyio.CancelScope within which this method was called. + + If cancellable is set to True (default) if the anyio.CancelScope within + which this method was called is cancelled it will generate a + CancelationNotfication and send this to the server which should then abort + the task however the server is is not guaranteed to honour this request. + + For further information on the CancelNotification flow refer to + https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation + Do not use this method to emit notifications! Use send_notification() instead. """ @@ -254,20 +275,38 @@ async def send_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." - ), + with anyio.CancelScope(shield=not cancellable): + try: + with anyio.fail_after(timeout) as scope: + response_or_error = await response_stream_reader.receive() + + if scope.cancel_called: + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams( + requestId=request_id, reason="cancelled" + ), + ) + await self._send_notification_internal( + notification, request_id + ) + raise McpError( + ErrorData( + code=REQUEST_CANCELLED, message="Request cancelled" + ) + ) + + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{timeout} seconds." + ), + ) ) - ) if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) @@ -288,6 +327,16 @@ async def send_notification( Emits a notification, which is a one-way message that does not expect a response. """ + await self._send_notification_internal(notification, related_request_id) + + # this method is required as SendNotificationT type checking prevents + # internal use for sending cancelation - typechecking sorcery may be + # required + async def _send_notification_internal( + self, + notification: SendNotificationInternalT, + related_request_id: RequestId | None = None, + ) -> None: # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. jsonrpc_notification = JSONRPCNotification( diff --git a/src/mcp/types.py b/src/mcp/types.py index 6ab7fba5..5c077ca9 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -146,6 +146,7 @@ class JSONRPCResponse(BaseModel): METHOD_NOT_FOUND = -32601 INVALID_PARAMS = -32602 INTERNAL_ERROR = -32603 +REQUEST_CANCELLED = -32604 class ErrorData(BaseModel): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 59cb30c8..c0f60da8 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from datetime import timedelta import anyio import pytest @@ -9,10 +10,6 @@ from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ( - CancelledNotification, - CancelledNotificationParams, - ClientNotification, - ClientRequest, EmptyResult, ) @@ -46,11 +43,11 @@ async def test_in_flight_requests_cleared_after_completion( @pytest.mark.anyio async def test_request_cancellation(): """Test that requests can be cancelled while in-flight.""" - # The tool is already registered in the fixture ev_tool_called = anyio.Event() + ev_tool_cancelled = anyio.Event() ev_cancelled = anyio.Event() - request_id = None + ev_cancel_notified = anyio.Event() # Start the request in a separate task so we can cancel it def make_server() -> Server: @@ -59,14 +56,24 @@ def make_server() -> Server: # Register the tool handler @server.call_tool() async def handle_call_tool(name: str, arguments: dict | None) -> list: - nonlocal request_id, ev_tool_called + nonlocal ev_tool_called, ev_tool_cancelled if name == "slow_tool": - request_id = server.request_context.request_id ev_tool_called.set() - await anyio.sleep(10) # Long enough to ensure we can cancel - return [] + with anyio.CancelScope(): + try: + await anyio.sleep(10) # Long enough to ensure we can cancel + return [] + except anyio.get_cancelled_exc_class() as err: + ev_tool_cancelled.set() + raise err + raise ValueError(f"Unknown tool: {name}") + @server.cancel_notification() + async def handle_cancel(requestId: str | int, reason: str | None): + nonlocal ev_cancel_notified + ev_cancel_notified.set() + # Register the tool so it shows up in list_tools @server.list_tools() async def handle_list_tools() -> list[types.Tool]: @@ -80,20 +87,10 @@ async def handle_list_tools() -> list[types.Tool]: return server - async def make_request(client_session): + async def make_request(client_session: ClientSession): nonlocal ev_cancelled try: - await client_session.send_request( - ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="slow_tool", arguments={} - ), - ) - ), - types.CallToolResult, - ) + await client_session.call_tool("slow_tool") pytest.fail("Request should have been cancelled") except McpError as e: # Expected - request was cancelled @@ -110,17 +107,87 @@ async def make_request(client_session): with anyio.fail_after(1): # Timeout after 1 second await ev_tool_called.wait() - # Send cancellation notification - assert request_id is not None - await client_session.send_notification( - ClientNotification( - CancelledNotification( - method="notifications/cancelled", - params=CancelledNotificationParams(requestId=request_id), - ) - ) - ) + # Cancel the task via task group + tg.cancel_scope.cancel() # Give cancellation time to process with anyio.fail_after(1): await ev_cancelled.wait() + + # Check server cancel notification received + with anyio.fail_after(1): + await ev_cancel_notified.wait() + + # Give cancellation time to process on server + with anyio.fail_after(1): + await ev_tool_cancelled.wait() + + +@pytest.mark.anyio +async def test_request_cancellation_uncancellable(): + """Test that asserts a call with cancellable=False is not cancelled on + server when cancel scope on client is set.""" + + ev_tool_called = anyio.Event() + ev_tool_commplete = anyio.Event() + ev_cancelled = anyio.Event() + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal ev_tool_called, ev_tool_commplete + if name == "slow_tool": + ev_tool_called.set() + with anyio.CancelScope(): + with anyio.fail_after(10): # Long enough to ensure we can cancel + await ev_cancelled.wait() + ev_tool_commplete.set() + return [] + + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + inputSchema={}, + ) + ] + + return server + + async def make_request(client_session: ClientSession): + nonlocal ev_cancelled + try: + await client_session.call_tool( + "slow_tool", + cancellable=False, + read_timeout_seconds=timedelta(seconds=10), + ) + except McpError: + pytest.fail("Request should not have been cancelled") + + async with create_connected_server_and_client_session( + make_server() + ) as client_session: + async with anyio.create_task_group() as tg: + tg.start_soon(make_request, client_session) + + # Wait for the request to be in-flight + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + + # Cancel the task via task group + tg.cancel_scope.cancel() + ev_cancelled.set() + + # Check server completed regardless + with anyio.fail_after(1): + await ev_tool_commplete.wait()