Skip to content

Adding Client Credentials to Auth #882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,9 @@ async def main():
callback_handler=lambda: ("auth_code", None),
)

# For machine-to-machine scenarios, use ClientCredentialsProvider
# instead of OAuthClientProvider.

# Use with streamable HTTP client
async with streamablehttp_client(
"https://api.example.com/mcp", auth=oauth_auth
Expand Down
18 changes: 18 additions & 0 deletions examples/servers/simple-auth/mcp_simple_auth/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,24 @@ async def exchange_refresh_token(
"""Exchange refresh token"""
raise NotImplementedError("Not supported")

async def exchange_client_credentials(
self, client: OAuthClientInformationFull, scopes: list[str]
) -> OAuthToken:
"""Exchange client credentials for an access token."""
token = f"mcp_{secrets.token_hex(32)}"
self.tokens[token] = AccessToken(
token=token,
client_id=client.client_id,
scopes=scopes,
expires_at=int(time.time()) + 3600,
)
return OAuthToken(
access_token=token,
token_type="bearer",
expires_in=3600,
scope=" ".join(scopes),
)

async def revoke_token(
self, token: str, token_type_hint: str | None = None
) -> None:
Expand Down
283 changes: 231 additions & 52 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,56 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
...


def _get_authorization_base_url(server_url: str) -> str:
"""
Return the authorization base URL for ``server_url``.

Per MCP spec 2.3.2, the path component must be discarded so that
``https://api.example.com/v1/mcp`` becomes ``https://api.example.com``.
"""
from urllib.parse import urlparse, urlunparse

parsed = urlparse(server_url)
# Remove path component
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))


async def _discover_oauth_metadata(server_url: str) -> OAuthMetadata | None:
"""
Discover OAuth metadata from the server's well-known endpoint.
"""

# Extract base URL per MCP spec
auth_base_url = _get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}

async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code == 404:
return None
response.raise_for_status()
metadata_json = response.json()
logger.debug(f"OAuth metadata discovered: {metadata_json}")
return OAuthMetadata.model_validate(metadata_json)
except Exception:
# Retry without MCP header for CORS compatibility
try:
response = await client.get(url)
if response.status_code == 404:
return None
response.raise_for_status()
metadata_json = response.json()
logger.debug(
f"OAuth metadata discovered (no MCP header): {metadata_json}"
)
return OAuthMetadata.model_validate(metadata_json)
except Exception:
logger.exception("Failed to discover OAuth metadata")
return None


class OAuthClientProvider(httpx.Auth):
"""
Authentication for httpx using anyio.
Expand Down Expand Up @@ -110,52 +160,6 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
digest = hashlib.sha256(code_verifier.encode()).digest()
return base64.urlsafe_b64encode(digest).decode().rstrip("=")

def _get_authorization_base_url(self, server_url: str) -> str:
"""
Extract base URL by removing path component.

Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com
"""
from urllib.parse import urlparse, urlunparse

parsed = urlparse(server_url)
# Remove path component
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))

async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
"""
Discover OAuth metadata from server's well-known endpoint.
"""
# Extract base URL per MCP spec
auth_base_url = self._get_authorization_base_url(server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}

async with httpx.AsyncClient() as client:
try:
response = await client.get(url, headers=headers)
if response.status_code == 404:
return None
response.raise_for_status()
metadata_json = response.json()
logger.debug(f"OAuth metadata discovered: {metadata_json}")
return OAuthMetadata.model_validate(metadata_json)
except Exception:
# Retry without MCP header for CORS compatibility
try:
response = await client.get(url)
if response.status_code == 404:
return None
response.raise_for_status()
metadata_json = response.json()
logger.debug(
f"OAuth metadata discovered (no MCP header): {metadata_json}"
)
return OAuthMetadata.model_validate(metadata_json)
except Exception:
logger.exception("Failed to discover OAuth metadata")
return None

async def _register_oauth_client(
self,
server_url: str,
Expand All @@ -166,13 +170,13 @@ async def _register_oauth_client(
Register OAuth client with server.
"""
if not metadata:
metadata = await self._discover_oauth_metadata(server_url)
metadata = await _discover_oauth_metadata(server_url)

if metadata and metadata.registration_endpoint:
registration_url = str(metadata.registration_endpoint)
else:
# Use fallback registration endpoint
auth_base_url = self._get_authorization_base_url(server_url)
auth_base_url = _get_authorization_base_url(server_url)
registration_url = urljoin(auth_base_url, "/register")

# Handle default scope
Expand Down Expand Up @@ -321,7 +325,7 @@ async def _perform_oauth_flow(self) -> None:

# Discover OAuth metadata
if not self._metadata:
self._metadata = await self._discover_oauth_metadata(self.server_url)
self._metadata = await _discover_oauth_metadata(self.server_url)

# Ensure client registration
client_info = await self._get_or_register_client()
Expand All @@ -335,7 +339,7 @@ async def _perform_oauth_flow(self) -> None:
auth_url_base = str(self._metadata.authorization_endpoint)
else:
# Use fallback authorization endpoint
auth_base_url = self._get_authorization_base_url(self.server_url)
auth_base_url = _get_authorization_base_url(self.server_url)
auth_url_base = urljoin(auth_base_url, "/authorize")

# Build authorization URL
Expand Down Expand Up @@ -386,7 +390,7 @@ async def _exchange_code_for_token(
token_url = str(self._metadata.token_endpoint)
else:
# Use fallback token endpoint
auth_base_url = self._get_authorization_base_url(self.server_url)
auth_base_url = _get_authorization_base_url(self.server_url)
token_url = urljoin(auth_base_url, "/token")

token_data = {
Expand Down Expand Up @@ -453,7 +457,7 @@ async def _refresh_access_token(self) -> bool:
token_url = str(self._metadata.token_endpoint)
else:
# Use fallback token endpoint
auth_base_url = self._get_authorization_base_url(self.server_url)
auth_base_url = _get_authorization_base_url(self.server_url)
token_url = urljoin(auth_base_url, "/token")

refresh_data = {
Expand Down Expand Up @@ -499,3 +503,178 @@ async def _refresh_access_token(self) -> bool:
except Exception:
logger.exception("Token refresh failed")
return False


class ClientCredentialsProvider(httpx.Auth):
"""HTTPX auth using the OAuth2 client credentials grant."""

def __init__(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
timeout: float = 300.0,
):
self.server_url = server_url
self.client_metadata = client_metadata
self.storage = storage
self.timeout = timeout

self._current_tokens: OAuthToken | None = None
self._metadata: OAuthMetadata | None = None
self._client_info: OAuthClientInformationFull | None = None
self._token_expiry_time: float | None = None

self._token_lock = anyio.Lock()

async def _register_oauth_client(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
metadata: OAuthMetadata | None = None,
) -> OAuthClientInformationFull:
if not metadata:
metadata = await _discover_oauth_metadata(server_url)

if metadata and metadata.registration_endpoint:
registration_url = str(metadata.registration_endpoint)
else:
auth_base_url = _get_authorization_base_url(server_url)
registration_url = urljoin(auth_base_url, "/register")

if (
client_metadata.scope is None
and metadata
and metadata.scopes_supported is not None
):
client_metadata.scope = " ".join(metadata.scopes_supported)

registration_data = client_metadata.model_dump(
by_alias=True, mode="json", exclude_none=True
)

async with httpx.AsyncClient() as client:
response = await client.post(
registration_url,
json=registration_data,
headers={"Content-Type": "application/json"},
)

if response.status_code not in (200, 201):
raise httpx.HTTPStatusError(
f"Registration failed: {response.status_code}",
request=response.request,
response=response,
)

return OAuthClientInformationFull.model_validate(response.json())

def _has_valid_token(self) -> bool:
if not self._current_tokens or not self._current_tokens.access_token:
return False

if self._token_expiry_time and time.time() > self._token_expiry_time:
return False
return True

async def _validate_token_scopes(self, token_response: OAuthToken) -> None:
if not token_response.scope:
return

requested_scopes: set[str] = set()
if self.client_metadata.scope:
requested_scopes = set(self.client_metadata.scope.split())
returned_scopes = set(token_response.scope.split())
unauthorized_scopes = returned_scopes - requested_scopes
if unauthorized_scopes:
raise Exception(
f"Server granted unauthorized scopes: {unauthorized_scopes}."
)
else:
granted = set(token_response.scope.split())
logger.debug(
"No explicit scopes requested, accepting server-granted scopes: %s",
granted,
)

async def initialize(self) -> None:
self._current_tokens = await self.storage.get_tokens()
self._client_info = await self.storage.get_client_info()

async def _get_or_register_client(self) -> OAuthClientInformationFull:
if not self._client_info:
self._client_info = await self._register_oauth_client(
self.server_url, self.client_metadata, self._metadata
)
await self.storage.set_client_info(self._client_info)
return self._client_info

async def _request_token(self) -> None:
if not self._metadata:
self._metadata = await _discover_oauth_metadata(self.server_url)

client_info = await self._get_or_register_client()

if self._metadata and self._metadata.token_endpoint:
token_url = str(self._metadata.token_endpoint)
else:
auth_base_url = _get_authorization_base_url(self.server_url)
token_url = urljoin(auth_base_url, "/token")

token_data = {
"grant_type": "client_credentials",
"client_id": client_info.client_id,
}

if client_info.client_secret:
token_data["client_secret"] = client_info.client_secret

if self.client_metadata.scope:
token_data["scope"] = self.client_metadata.scope

async with httpx.AsyncClient() as client:
response = await client.post(
token_url,
data=token_data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
timeout=30.0,
)

if response.status_code != 200:
raise Exception(
f"Token request failed: {response.status_code} {response.text}"
)

token_response = OAuthToken.model_validate(response.json())
await self._validate_token_scopes(token_response)

if token_response.expires_in:
self._token_expiry_time = time.time() + token_response.expires_in
else:
self._token_expiry_time = None

await self.storage.set_tokens(token_response)
self._current_tokens = token_response

async def ensure_token(self) -> None:
async with self._token_lock:
if self._has_valid_token():
return
await self._request_token()

async def async_auth_flow(
self, request: httpx.Request
) -> AsyncGenerator[httpx.Request, httpx.Response]:
if not self._has_valid_token():
await self.initialize()
await self.ensure_token()

if self._current_tokens and self._current_tokens.access_token:
request.headers["Authorization"] = (
f"Bearer {self._current_tokens.access_token}"
)

response = yield request

if response.status_code == 401:
self._current_tokens = None
10 changes: 8 additions & 2 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,14 @@ async def _handle_sse_event(
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)

# Call resumption token callback if we have an ID
if sse.id and resumption_callback:
# Call resumption token callback if we have an ID. Only update
# the resumption token on notifications to avoid overwriting it
# with the token from the final response.
if (
sse.id
and resumption_callback
and not isinstance(message.root, JSONRPCResponse | JSONRPCError)
):
await resumption_callback(sse.id)

# If this is a response or error return True indicating completion
Expand Down
Loading
Loading