-
Notifications
You must be signed in to change notification settings - Fork 1.8k
feat: implement MCP-Protocol-Version header requirement for HTTP transport #898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2314b5a
c6cca13
4ce7e4c
370b993
6832394
b9a0c96
d14666b
bb4eaab
56c1b0a
5a5e7a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
from mcp.shared.message import ClientMessageMetadata, SessionMessage | ||
from mcp.types import ( | ||
ErrorData, | ||
InitializeResult, | ||
JSONRPCError, | ||
JSONRPCMessage, | ||
JSONRPCNotification, | ||
|
@@ -39,6 +40,7 @@ | |
GetSessionIdCallback = Callable[[], str | None] | ||
|
||
MCP_SESSION_ID = "mcp-session-id" | ||
MCP_PROTOCOL_VERSION = "mcp-protocol-version" | ||
LAST_EVENT_ID = "last-event-id" | ||
CONTENT_TYPE = "content-type" | ||
ACCEPT = "Accept" | ||
|
@@ -97,17 +99,20 @@ def __init__( | |
) | ||
self.auth = auth | ||
self.session_id = None | ||
self.protocol_version = None | ||
self.request_headers = { | ||
ACCEPT: f"{JSON}, {SSE}", | ||
CONTENT_TYPE: JSON, | ||
**self.headers, | ||
} | ||
|
||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: | ||
"""Update headers with session ID if available.""" | ||
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: | ||
"""Update headers with session ID and protocol version if available.""" | ||
headers = base_headers.copy() | ||
if self.session_id: | ||
headers[MCP_SESSION_ID] = self.session_id | ||
if self.protocol_version: | ||
headers[MCP_PROTOCOL_VERSION] = self.protocol_version | ||
return headers | ||
|
||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool: | ||
|
@@ -128,19 +133,39 @@ def _maybe_extract_session_id_from_response( | |
self.session_id = new_session_id | ||
logger.info(f"Received session ID: {self.session_id}") | ||
|
||
def _maybe_extract_protocol_version_from_message( | ||
self, | ||
message: JSONRPCMessage, | ||
) -> None: | ||
"""Extract protocol version from initialization response message.""" | ||
if isinstance(message.root, JSONRPCResponse) and message.root.result: | ||
try: | ||
# Parse the result as InitializeResult for type safety | ||
init_result = InitializeResult.model_validate(message.root.result) | ||
self.protocol_version = str(init_result.protocolVersion) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. something unsettling about this 😅 technically it will work, but can we handle None? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wait, actually it will not work, as now we'll have "None" and will not set it to the default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
class InitializeResult(Result):
"""After receiving an initialize request from the client, the server sends this."""
protocolVersion: str | int
"""The version of the Model Context Protocol that the server wants to use."""
capabilities: ServerCapabilities
serverInfo: Implementation
instructions: str | None = None
"""Instructions describing how to use the server and its features.""" Do you mean handling when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could also do something like this, where we set the try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception as exc:
# Assume the default version if parsing fails for any reason
self.protocol_version = DEFAULT_NEGOTIATED_VERSION
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
logger.warning(f"Raw result: {message.root.result}") If the server is returning incorrect responses I feel like client shouldn't be "hiding" that? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, I mean just remove
|
||
logger.info(f"Negotiated protocol version: {self.protocol_version}") | ||
except Exception as exc: | ||
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}") | ||
logger.warning(f"Raw result: {message.root.result}") | ||
|
||
async def _handle_sse_event( | ||
self, | ||
sse: ServerSentEvent, | ||
read_stream_writer: StreamWriter, | ||
original_request_id: RequestId | None = None, | ||
resumption_callback: Callable[[str], Awaitable[None]] | None = None, | ||
is_initialization: bool = False, | ||
) -> bool: | ||
"""Handle an SSE event, returning True if the response is complete.""" | ||
if sse.event == "message": | ||
try: | ||
message = JSONRPCMessage.model_validate_json(sse.data) | ||
logger.debug(f"SSE message: {message}") | ||
|
||
# Extract protocol version from initialization response | ||
if is_initialization: | ||
self._maybe_extract_protocol_version_from_message(message) | ||
|
||
# If this is a response and we have original_request_id, replace it | ||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): | ||
message.root.id = original_request_id | ||
|
@@ -174,7 +199,7 @@ async def handle_get_stream( | |
if not self.session_id: | ||
return | ||
|
||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = self._prepare_request_headers(self.request_headers) | ||
|
||
async with aconnect_sse( | ||
client, | ||
|
@@ -194,7 +219,7 @@ async def handle_get_stream( | |
|
||
async def _handle_resumption_request(self, ctx: RequestContext) -> None: | ||
"""Handle a resumption request using GET with SSE.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = self._prepare_request_headers(ctx.headers) | ||
if ctx.metadata and ctx.metadata.resumption_token: | ||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token | ||
else: | ||
|
@@ -227,7 +252,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: | |
|
||
async def _handle_post_request(self, ctx: RequestContext) -> None: | ||
"""Handle a POST request with response processing.""" | ||
headers = self._update_headers_with_session(ctx.headers) | ||
headers = self._prepare_request_headers(ctx.headers) | ||
message = ctx.session_message.message | ||
is_initialization = self._is_initialization_request(message) | ||
|
||
|
@@ -256,9 +281,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: | |
content_type = response.headers.get(CONTENT_TYPE, "").lower() | ||
|
||
if content_type.startswith(JSON): | ||
await self._handle_json_response(response, ctx.read_stream_writer) | ||
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) | ||
elif content_type.startswith(SSE): | ||
await self._handle_sse_response(response, ctx) | ||
await self._handle_sse_response(response, ctx, is_initialization) | ||
else: | ||
await self._handle_unexpected_content_type( | ||
content_type, | ||
|
@@ -269,18 +294,29 @@ async def _handle_json_response( | |
self, | ||
response: httpx.Response, | ||
read_stream_writer: StreamWriter, | ||
is_initialization: bool = False, | ||
) -> None: | ||
"""Handle JSON response from the server.""" | ||
try: | ||
content = await response.aread() | ||
message = JSONRPCMessage.model_validate_json(content) | ||
|
||
# Extract protocol version from initialization response | ||
if is_initialization: | ||
self._maybe_extract_protocol_version_from_message(message) | ||
|
||
session_message = SessionMessage(message) | ||
await read_stream_writer.send(session_message) | ||
except Exception as exc: | ||
logger.error(f"Error parsing JSON response: {exc}") | ||
await read_stream_writer.send(exc) | ||
|
||
async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: | ||
async def _handle_sse_response( | ||
self, | ||
response: httpx.Response, | ||
ctx: RequestContext, | ||
is_initialization: bool = False, | ||
) -> None: | ||
"""Handle SSE response from the server.""" | ||
try: | ||
event_source = EventSource(response) | ||
|
@@ -289,6 +325,7 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte | |
sse, | ||
ctx.read_stream_writer, | ||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), | ||
is_initialization=is_initialization, | ||
) | ||
# If the SSE event indicates completion, like returning respose/error | ||
# break the loop | ||
|
@@ -385,7 +422,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: | |
return | ||
|
||
try: | ||
headers = self._update_headers_with_session(self.request_headers) | ||
headers = self._prepare_request_headers(self.request_headers) | ||
response = await client.delete(self.url, headers=headers) | ||
|
||
if response.status_code == 405: | ||
|
Uh oh!
There was an error while loading. Please reload this page.