diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a7..0aef8c15 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -14,8 +14,38 @@ logger = logging.getLogger(__name__) +# TODO: move these to utils/url_utils.py +def get_origin(url: str) -> str: + parsed_url = urlparse(url) + return f"{parsed_url.scheme}://{parsed_url.netloc}" + + +def get_relative_path(url: str, remove_params: bool = False) -> str: + parsed_url = urlparse(url) + if remove_params: + return parsed_url.path + relative_path = parsed_url.path + if parsed_url.query: + relative_path += f"?{parsed_url.query}" + if parsed_url.fragment: + relative_path += f"#{parsed_url.fragment}" + return relative_path + + +def get_endpoint_url( + base_url: str, sse_relative_url: str, server_mount_path: str = "" +) -> str: + endpoint_url = urljoin(base_url, sse_relative_url) + if server_mount_path: + origin, path = get_origin(endpoint_url), get_relative_path(endpoint_url) + endpoint_url = urljoin( + f"{origin}/{server_mount_path.strip('/')}/", path.lstrip("/") + ) + return endpoint_url + + def remove_request_params(url: str) -> str: - return urljoin(url, urlparse(url).path) + return urljoin(url, get_relative_path(url, remove_params=True)) @asynccontextmanager @@ -24,12 +54,16 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + server_mount_path: str = "", ): """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + + `server_mount_path` provides the relative mount path of the MCP server + (used if it is mounted relatively on another ASGI server). """ read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] @@ -61,18 +95,15 @@ async def sse_reader( logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": - endpoint_url = urljoin(url, sse.data) + endpoint_url = get_endpoint_url( + base_url=url, + sse_relative_url=sse.data, + server_mount_path=server_mount_path, + ) logger.info( f"Received endpoint URL: {endpoint_url}" ) - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme - != endpoint_parsed.scheme - ): + if get_origin(url) != get_origin(endpoint_url): error_msg = ( "Endpoint origin does not match " f"connection origin: {endpoint_url}"