-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Sending cancellation notification to server based on client anyio.CancelScope status #628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f9598ad
dfb3686
abda067
24553c6
fe49931
efd0ffd
1364b7a
17ae44c
06f4b3c
07e1a52
92f806b
e46c693
2a24e0c
87722f8
f0782d2
a0164f9
bf220d5
45ac52a
8b7f1cd
ac4b822
d86b4a5
6f4ae44
11d2e52
235df35
f96aaa5
bd73448
3817fe2
2fd27a2
2e86d32
1039b99
b926970
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,9 @@ | |
from mcp.shared.exceptions import McpError | ||
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage | ||
from mcp.types import ( | ||
REQUEST_CANCELLED, | ||
CancelledNotification, | ||
CancelledNotificationParams, | ||
ClientNotification, | ||
ClientRequest, | ||
ClientResult, | ||
|
@@ -33,6 +35,12 @@ | |
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) | ||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) | ||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) | ||
SendNotificationInternalT = TypeVar( | ||
"SendNotificationInternalT", | ||
CancelledNotification, | ||
ClientNotification, | ||
ServerNotification, | ||
) | ||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) | ||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) | ||
ReceiveNotificationT = TypeVar( | ||
|
@@ -214,12 +222,25 @@ async def send_request( | |
result_type: type[ReceiveResultT], | ||
request_read_timeout_seconds: timedelta | None = None, | ||
metadata: MessageMetadata = None, | ||
cancellable: bool = True, | ||
) -> ReceiveResultT: | ||
""" | ||
Sends a request and wait for a response. Raises an McpError if the | ||
response contains an error. If a request read timeout is provided, it | ||
will take precedence over the session read timeout. | ||
|
||
If cancellable is set to False then the request will wait | ||
request_read_timeout_seconds to complete and ignore any attempt to | ||
cancel via the anyio.CancelScope within which this method was called. | ||
|
||
If cancellable is set to True (default) if the anyio.CancelScope within | ||
which this method was called is cancelled it will generate a | ||
CancelationNotfication and send this to the server which should then abort | ||
the task however the server is is not guaranteed to honour this request. | ||
|
||
For further information on the CancelNotification flow refer to | ||
https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation | ||
|
||
Do not use this method to emit notifications! Use send_notification() | ||
instead. | ||
""" | ||
|
@@ -254,20 +275,38 @@ async def send_request( | |
elif self._session_read_timeout_seconds is not None: | ||
timeout = self._session_read_timeout_seconds.total_seconds() | ||
|
||
try: | ||
with anyio.fail_after(timeout): | ||
response_or_error = await response_stream_reader.receive() | ||
except TimeoutError: | ||
raise McpError( | ||
ErrorData( | ||
code=httpx.codes.REQUEST_TIMEOUT, | ||
message=( | ||
f"Timed out while waiting for response to " | ||
f"{request.__class__.__name__}. Waited " | ||
f"{timeout} seconds." | ||
), | ||
with anyio.CancelScope(shield=not cancellable): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternative to this might be to separate client cancelation and server cancelation, e.g. client can be cancelled and server cancellation event is only set if a flag such as 'propagate_client_cancelation_to_server' (shorter names available) is set on the request. |
||
try: | ||
with anyio.fail_after(timeout) as scope: | ||
response_or_error = await response_stream_reader.receive() | ||
|
||
if scope.cancel_called: | ||
notification = CancelledNotification( | ||
method="notifications/cancelled", | ||
params=CancelledNotificationParams( | ||
requestId=request_id, reason="cancelled" | ||
), | ||
) | ||
await self._send_notification_internal( | ||
notification, request_id | ||
) | ||
raise McpError( | ||
ErrorData( | ||
code=REQUEST_CANCELLED, message="Request cancelled" | ||
) | ||
) | ||
|
||
except TimeoutError: | ||
raise McpError( | ||
ErrorData( | ||
code=httpx.codes.REQUEST_TIMEOUT, | ||
message=( | ||
f"Timed out while waiting for response to " | ||
f"{request.__class__.__name__}. Waited " | ||
f"{timeout} seconds." | ||
), | ||
) | ||
) | ||
) | ||
|
||
if isinstance(response_or_error, JSONRPCError): | ||
raise McpError(response_or_error.error) | ||
|
@@ -288,6 +327,16 @@ async def send_notification( | |
Emits a notification, which is a one-way message that does not expect | ||
a response. | ||
""" | ||
await self._send_notification_internal(notification, related_request_id) | ||
|
||
# this method is required as SendNotificationT type checking prevents | ||
# internal use for sending cancelation - typechecking sorcery may be | ||
# required | ||
async def _send_notification_internal( | ||
self, | ||
notification: SendNotificationInternalT, | ||
related_request_id: RequestId | None = None, | ||
) -> None: | ||
# Some transport implementations may need to set the related_request_id | ||
# to attribute to the notifications to the request that triggered them. | ||
jsonrpc_notification = JSONRPCNotification( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from collections.abc import AsyncGenerator | ||
from datetime import timedelta | ||
|
||
import anyio | ||
import pytest | ||
|
@@ -9,10 +10,6 @@ | |
from mcp.shared.exceptions import McpError | ||
from mcp.shared.memory import create_connected_server_and_client_session | ||
from mcp.types import ( | ||
CancelledNotification, | ||
CancelledNotificationParams, | ||
ClientNotification, | ||
ClientRequest, | ||
EmptyResult, | ||
) | ||
|
||
|
@@ -46,11 +43,11 @@ async def test_in_flight_requests_cleared_after_completion( | |
@pytest.mark.anyio | ||
async def test_request_cancellation(): | ||
"""Test that requests can be cancelled while in-flight.""" | ||
# The tool is already registered in the fixture | ||
|
||
ev_tool_called = anyio.Event() | ||
ev_tool_cancelled = anyio.Event() | ||
ev_cancelled = anyio.Event() | ||
request_id = None | ||
ev_cancel_notified = anyio.Event() | ||
|
||
# Start the request in a separate task so we can cancel it | ||
def make_server() -> Server: | ||
|
@@ -59,14 +56,24 @@ def make_server() -> Server: | |
# Register the tool handler | ||
@server.call_tool() | ||
async def handle_call_tool(name: str, arguments: dict | None) -> list: | ||
nonlocal request_id, ev_tool_called | ||
nonlocal ev_tool_called, ev_tool_cancelled | ||
if name == "slow_tool": | ||
request_id = server.request_context.request_id | ||
ev_tool_called.set() | ||
await anyio.sleep(10) # Long enough to ensure we can cancel | ||
return [] | ||
with anyio.CancelScope(): | ||
try: | ||
await anyio.sleep(10) # Long enough to ensure we can cancel | ||
return [] | ||
except anyio.get_cancelled_exc_class() as err: | ||
ev_tool_cancelled.set() | ||
raise err | ||
|
||
raise ValueError(f"Unknown tool: {name}") | ||
|
||
@server.cancel_notification() | ||
async def handle_cancel(requestId: str | int, reason: str | None): | ||
nonlocal ev_cancel_notified | ||
ev_cancel_notified.set() | ||
|
||
# Register the tool so it shows up in list_tools | ||
@server.list_tools() | ||
async def handle_list_tools() -> list[types.Tool]: | ||
|
@@ -80,20 +87,10 @@ async def handle_list_tools() -> list[types.Tool]: | |
|
||
return server | ||
|
||
async def make_request(client_session): | ||
async def make_request(client_session: ClientSession): | ||
nonlocal ev_cancelled | ||
try: | ||
await client_session.send_request( | ||
ClientRequest( | ||
types.CallToolRequest( | ||
method="tools/call", | ||
params=types.CallToolRequestParams( | ||
name="slow_tool", arguments={} | ||
), | ||
) | ||
), | ||
types.CallToolResult, | ||
) | ||
await client_session.call_tool("slow_tool") | ||
pytest.fail("Request should have been cancelled") | ||
except McpError as e: | ||
# Expected - request was cancelled | ||
|
@@ -110,17 +107,87 @@ async def make_request(client_session): | |
with anyio.fail_after(1): # Timeout after 1 second | ||
await ev_tool_called.wait() | ||
|
||
# Send cancellation notification | ||
assert request_id is not None | ||
await client_session.send_notification( | ||
ClientNotification( | ||
CancelledNotification( | ||
method="notifications/cancelled", | ||
params=CancelledNotificationParams(requestId=request_id), | ||
) | ||
) | ||
) | ||
# Cancel the task via task group | ||
tg.cancel_scope.cancel() | ||
|
||
# Give cancellation time to process | ||
with anyio.fail_after(1): | ||
await ev_cancelled.wait() | ||
|
||
# Check server cancel notification received | ||
with anyio.fail_after(1): | ||
await ev_cancel_notified.wait() | ||
|
||
# Give cancellation time to process on server | ||
with anyio.fail_after(1): | ||
await ev_tool_cancelled.wait() | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_request_cancellation_uncancellable(): | ||
"""Test that asserts a call with cancellable=False is not cancelled on | ||
server when cancel scope on client is set.""" | ||
|
||
ev_tool_called = anyio.Event() | ||
ev_tool_commplete = anyio.Event() | ||
ev_cancelled = anyio.Event() | ||
|
||
# Start the request in a separate task so we can cancel it | ||
def make_server() -> Server: | ||
server = Server(name="TestSessionServer") | ||
|
||
# Register the tool handler | ||
@server.call_tool() | ||
async def handle_call_tool(name: str, arguments: dict | None) -> list: | ||
nonlocal ev_tool_called, ev_tool_commplete | ||
if name == "slow_tool": | ||
ev_tool_called.set() | ||
with anyio.CancelScope(): | ||
with anyio.fail_after(10): # Long enough to ensure we can cancel | ||
await ev_cancelled.wait() | ||
ev_tool_commplete.set() | ||
return [] | ||
|
||
raise ValueError(f"Unknown tool: {name}") | ||
|
||
# Register the tool so it shows up in list_tools | ||
@server.list_tools() | ||
async def handle_list_tools() -> list[types.Tool]: | ||
return [ | ||
types.Tool( | ||
name="slow_tool", | ||
description="A slow tool that takes 10 seconds to complete", | ||
inputSchema={}, | ||
) | ||
] | ||
|
||
return server | ||
|
||
async def make_request(client_session: ClientSession): | ||
nonlocal ev_cancelled | ||
try: | ||
await client_session.call_tool( | ||
"slow_tool", | ||
cancellable=False, | ||
read_timeout_seconds=timedelta(seconds=10), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A test that validates a short timeout with an uncancelable call would be sensible |
||
) | ||
except McpError: | ||
pytest.fail("Request should not have been cancelled") | ||
|
||
async with create_connected_server_and_client_session( | ||
make_server() | ||
) as client_session: | ||
async with anyio.create_task_group() as tg: | ||
tg.start_soon(make_request, client_session) | ||
|
||
# Wait for the request to be in-flight | ||
with anyio.fail_after(1): # Timeout after 1 second | ||
await ev_tool_called.wait() | ||
|
||
# Cancel the task via task group | ||
tg.cancel_scope.cancel() | ||
ev_cancelled.set() | ||
|
||
# Check server completed regardless | ||
with anyio.fail_after(1): | ||
await ev_tool_commplete.wait() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This code is ugly appears to need some python typechecking sorcery to tidy up though