diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index a94cc2834..fd3c0d532 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -483,7 +483,7 @@ async def sse_writer(): ): break except Exception as e: - logger.exception(f"Error in SSE writer: {e}") + logger.warning(f"Error in SSE writer: {e}", exc_info=True) finally: logger.debug("Closing SSE writer") await self._clean_up_memory_streams(request_id) @@ -517,13 +517,13 @@ async def sse_writer(): session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) except Exception: - logger.exception("SSE response error") + logger.warning("SSE response error", exc_info=True) await sse_stream_writer.aclose() await sse_stream_reader.aclose() await self._clean_up_memory_streams(request_id) except Exception as err: - logger.exception("Error handling POST request") + logger.warning("Error handling POST request", exc_info=True) response = self._create_error_response( f"Error handling POST request: {err}", HTTPStatus.INTERNAL_SERVER_ERROR, @@ -610,7 +610,7 @@ async def standalone_sse_writer(): event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) except Exception as e: - logger.exception(f"Error in standalone SSE writer: {e}") + logger.warning(f"Error in standalone SSE writer: {e}", exc_info=True) finally: logger.debug("Closing standalone SSE writer") await self._clean_up_memory_streams(GET_STREAM_KEY) @@ -626,7 +626,7 @@ async def standalone_sse_writer(): # This will send headers immediately and establish the SSE connection await response(request.scope, request.receive, send) except Exception as e: - logger.exception(f"Error in standalone SSE response: {e}") + logger.warning(f"Error in standalone SSE response: {e}", exc_info=True) await sse_stream_writer.aclose() await sse_stream_reader.aclose() await self._clean_up_memory_streams(GET_STREAM_KEY) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 8188c2f3b..c3808ea85 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -51,7 +51,6 @@ class StreamableHTTPSessionManager: json_response: Whether to use JSON responses instead of SSE streams stateless: If True, creates a completely fresh transport for each request with no session tracking or state persistence between requests. - """ def __init__( @@ -171,12 +170,15 @@ async def run_stateless_server( async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=True, - ) + try: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=True, + ) + except Exception as e: + logger.warning(f"Stateless session crashed: {e}", exc_info=True) # Assert task group is not None for type checking assert self._task_group is not None @@ -235,12 +237,37 @@ async def run_server( async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=False, # Stateful mode - ) + try: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, # Stateful mode + ) + except Exception as e: + logger.warning( + f"Session {http_transport.mcp_session_id} crashed: {e}", + exc_info=True, + ) + finally: + # Only remove from instances if not terminated + if ( + http_transport.mcp_session_id + and http_transport.mcp_session_id + in self._server_instances + and not ( + hasattr(http_transport, "_terminated") + and http_transport._terminated # pyright: ignore + ) + ): + logger.info( + "Cleaning up crashed session " + f"{http_transport.mcp_session_id} from " + "active instances." + ) + del self._server_instances[ + http_transport.mcp_session_id + ] # Assert task group is not None for type checking assert self._task_group is not None diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 32782e458..12850240c 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,9 +1,12 @@ """Tests for StreamableHTTPSessionManager.""" +from unittest.mock import AsyncMock + import anyio import pytest from mcp.server.lowlevel import Server +from mcp.server.streamable_http import MCP_SESSION_ID_HEADER from mcp.server.streamable_http_manager import StreamableHTTPSessionManager @@ -79,3 +82,128 @@ async def send(message): assert "Task group is not initialized. Make sure to use run()." in str( excinfo.value ) + + +class TestException(Exception): + __test__ = False # Prevent pytest from collecting this as a test class + pass + + +@pytest.fixture +async def running_manager(): + app = Server("test-cleanup-server") + # It's important that the app instance used by the manager is the one we can patch + manager = StreamableHTTPSessionManager(app=app) + async with manager.run(): + # Patch app.run here if it's simpler, or patch it within the test + yield manager, app + + +@pytest.mark.anyio +async def test_stateful_session_cleanup_on_graceful_exit(running_manager): + manager, app = running_manager + + mock_mcp_run = AsyncMock(return_value=None) + # This will be called by StreamableHTTPSessionManager's run_server -> self.app.run + app.run = mock_mcp_run + + sent_messages = [] + + async def mock_send(message): + sent_messages.append(message) + + scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []} + + async def mock_receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + # Trigger session creation + await manager.handle_request(scope, mock_receive, mock_send) + + # Extract session ID from response headers + session_id = None + for msg in sent_messages: + if msg["type"] == "http.response.start": + for header_name, header_value in msg.get("headers", []): + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # Break outer loop if session_id is found + break + + assert session_id is not None, "Session ID not found in response headers" + + # Ensure MCPServer.run was called + mock_mcp_run.assert_called_once() + + # At this point, mock_mcp_run has completed, and the finally block in + # StreamableHTTPSessionManager's run_server should have executed. + + # To ensure the task spawned by handle_request finishes and cleanup occurs: + # Give other tasks a chance to run. This is important for the finally block. + await anyio.sleep(0.01) + + assert ( + session_id not in manager._server_instances + ), "Session ID should be removed from _server_instances after graceful exit" + assert ( + not manager._server_instances + ), "No sessions should be tracked after the only session exits gracefully" + + +@pytest.mark.anyio +async def test_stateful_session_cleanup_on_exception(running_manager): + manager, app = running_manager + + mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash")) + app.run = mock_mcp_run + + sent_messages = [] + + async def mock_send(message): + sent_messages.append(message) + # If an exception occurs, the transport might try to send an error response + # For this test, we mostly care that the session is established enough + # to get an ID + if message["type"] == "http.response.start" and message["status"] >= 500: + pass # Expected if TestException propagates that far up the transport + + scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []} + + async def mock_receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + # It's possible handle_request itself might raise an error if the TestException + # isn't caught by the transport layer before propagating. + # The key is that the session manager's internal task for MCPServer.run + # encounters the exception. + try: + await manager.handle_request(scope, mock_receive, mock_send) + except TestException: + # This might be caught here if not handled by StreamableHTTPServerTransport's + # error handling + pass + + session_id = None + for msg in sent_messages: + if msg["type"] == "http.response.start": + for header_name, header_value in msg.get("headers", []): + if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower(): + session_id = header_value.decode() + break + if session_id: # Break outer loop if session_id is found + break + + assert session_id is not None, "Session ID not found in response headers" + + mock_mcp_run.assert_called_once() + + # Give other tasks a chance to run to ensure the finally block executes + await anyio.sleep(0.01) + + assert ( + session_id not in manager._server_instances + ), "Session ID should be removed from _server_instances after an exception" + assert ( + not manager._server_instances + ), "No sessions should be tracked after the only session crashes"