Skip to content

feat: implement MCP-Protocol-Version header requirement for HTTP transport #898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
import anyio
import httpx

from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
)
from mcp.types import LATEST_PROTOCOL_VERSION

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -121,7 +127,7 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non
# Extract base URL per MCP spec
auth_base_url = self._get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}

async with httpx.AsyncClient() as client:
try:
Expand Down
55 changes: 46 additions & 9 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
Expand All @@ -39,6 +40,7 @@
GetSessionIdCallback = Callable[[], str | None]

MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "Accept"
Expand Down Expand Up @@ -97,17 +99,20 @@ def __init__(
)
self.auth = auth
self.session_id = None
self.protocol_version = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
**self.headers,
}

def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID and protocol version if available."""
headers = base_headers.copy()
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
if self.protocol_version:
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
return headers

def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
Expand All @@ -128,19 +133,39 @@ def _maybe_extract_session_id_from_response(
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")

def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result:
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something unsettling about this 😅 technically it will work, but can we handle None?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait, actually it will not work, as now we'll have "None" and will not set it to the default

Copy link
Contributor Author

@felixweinberger felixweinberger Jun 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protocolVersion can only be str | int:

class InitializeResult(Result):
    """After receiving an initialize request from the client, the server sends this."""

    protocolVersion: str | int
    """The version of the Model Context Protocol that the server wants to use."""
    capabilities: ServerCapabilities
    serverInfo: Implementation
    instructions: str | None = None
    """Instructions describing how to use the server and its features."""

Do you mean handling when message.root.result === None implying a broken server? if that happens the try-catch ensures we don't crash out, self.protocol_version just stays unset (this is on the client).

Copy link
Contributor Author

@felixweinberger felixweinberger Jun 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also do something like this, where we set the self.protocol_version = DEFAULT_NEGOTIATED_VERSION but I would argue against this because we're on the client here and that might mask server issues:

try:
    # Parse the result as InitializeResult for type safety
    init_result = InitializeResult.model_validate(message.root.result)
    self.protocol_version = str(init_result.protocolVersion)
    logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception as exc:
    # Assume the default version if parsing fails for any reason
    self.protocol_version = DEFAULT_NEGOTIATED_VERSION
    logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
    logger.warning(f"Raw result: {message.root.result}")

If the server is returning incorrect responses I feel like client shouldn't be "hiding" that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I mean just remove str(init_result.protocolVersion) and instead do

self.protocol_version = str(init_result.protocolVersion) if init_result.protocolVersion is not None else None

logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception as exc:
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
logger.warning(f"Raw result: {message.root.result}")

async def _handle_sse_event(
self,
sse: ServerSentEvent,
read_stream_writer: StreamWriter,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
is_initialization: bool = False,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"SSE message: {message}")

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)

# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
Expand Down Expand Up @@ -174,7 +199,7 @@ async def handle_get_stream(
if not self.session_id:
return

headers = self._update_headers_with_session(self.request_headers)
headers = self._prepare_request_headers(self.request_headers)

async with aconnect_sse(
client,
Expand All @@ -194,7 +219,7 @@ async def handle_get_stream(

async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
headers = self._prepare_request_headers(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
Expand Down Expand Up @@ -227,7 +252,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
headers = self._prepare_request_headers(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

Expand Down Expand Up @@ -256,9 +281,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
content_type = response.headers.get(CONTENT_TYPE, "").lower()

if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer)
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx)
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type(
content_type,
Expand All @@ -269,18 +294,29 @@ async def _handle_json_response(
self,
response: httpx.Response,
read_stream_writer: StreamWriter,
is_initialization: bool = False,
) -> None:
"""Handle JSON response from the server."""
try:
content = await response.aread()
message = JSONRPCMessage.model_validate_json(content)

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc:
logger.error(f"Error parsing JSON response: {exc}")
await read_stream_writer.send(exc)

async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
async def _handle_sse_response(
self,
response: httpx.Response,
ctx: RequestContext,
is_initialization: bool = False,
) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
Expand All @@ -289,6 +325,7 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
Expand Down Expand Up @@ -385,7 +422,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
return

try:
headers = self._update_headers_with_session(self.request_headers)
headers = self._prepare_request_headers(self.request_headers)
response = await client.delete(self.url, headers=headers)

if response.status_code == 405:
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.auth import OAuthMetadata


Expand Down Expand Up @@ -55,7 +56,7 @@ def cors_middleware(
app=request_response(handler),
allow_origins="*",
allow_methods=allow_methods,
allow_headers=["mcp-protocol-version"],
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
)
return cors_app

Expand Down
42 changes: 37 additions & 5 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from starlette.types import Receive, Scope, Send

from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
from mcp.types import (
DEFAULT_NEGOTIATED_VERSION,
INTERNAL_ERROR,
INVALID_PARAMS,
INVALID_REQUEST,
Expand All @@ -45,6 +47,7 @@

# Header names
MCP_SESSION_ID_HEADER = "mcp-session-id"
MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
LAST_EVENT_ID_HEADER = "last-event-id"

# Content types
Expand Down Expand Up @@ -293,7 +296,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
has_json, has_sse = self._check_accept_headers(request)
if not (has_json and has_sse):
response = self._create_error_response(
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
("Not Acceptable: Client must accept both application/json and text/event-stream"),
HTTPStatus.NOT_ACCEPTABLE,
)
await response(scope, receive, send)
Expand Down Expand Up @@ -353,8 +356,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
)
await response(scope, receive, send)
return
# For non-initialization requests, validate the session
elif not await self._validate_session(request, send):
elif not await self._validate_request_headers(request, send):
return

# For notifications and responses only, return 202 Accepted
Expand Down Expand Up @@ -513,8 +515,9 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
await response(request.scope, request.receive, send)
return

if not await self._validate_session(request, send):
if not await self._validate_request_headers(request, send):
return

# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
Expand Down Expand Up @@ -593,7 +596,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
await response(request.scope, request.receive, send)
return

if not await self._validate_session(request, send):
if not await self._validate_request_headers(request, send):
return

await self._terminate_session()
Expand Down Expand Up @@ -653,6 +656,13 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non
)
await response(request.scope, request.receive, send)

async def _validate_request_headers(self, request: Request, send: Send) -> bool:
if not await self._validate_session(request, send):
return False
if not await self._validate_protocol_version(request, send):
return False
return True

async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id:
Expand Down Expand Up @@ -682,6 +692,28 @@ async def _validate_session(self, request: Request, send: Send) -> bool:

return True

async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
"""Validate the protocol version header in the request."""
# Get the protocol version from the request headers
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, assume default version
if protocol_version is None:
protocol_version = DEFAULT_NEGOTIATED_VERSION

# Check if the protocol version is supported
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
response = self._create_error_response(
f"Bad Request: Unsupported protocol version: {protocol_version}. "
+ f"Supported versions: {supported_versions}",
HTTPStatus.BAD_REQUEST,
)
await response(request.scope, request.receive, send)
return False

return True

async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
"""
Replays events that would have been sent after the specified event ID.
Expand Down
8 changes: 8 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@

LATEST_PROTOCOL_VERSION = "2025-03-26"

"""
The default negotiated version of the Model Context Protocol when no version is specified.
We need this to satisfy the MCP specification, which requires the server to assume a
specific version if none is provided by the client. See section "Protocol Version Header" at
https://modelcontextprotocol.io/specification
"""
DEFAULT_NEGOTIATED_VERSION = "2025-03-26"

ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
Expand Down
Loading
Loading