diff --git a/changelog.d/17708.feature b/changelog.d/17708.feature new file mode 100644 index 00000000000..90ec810f50d --- /dev/null +++ b/changelog.d/17708.feature @@ -0,0 +1 @@ +Added the `display_name_claim` option to the JWT configuration. This option allows specifying the claim key that contains the user's display name in the JWT payload. \ No newline at end of file diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 29f3528c7e1..1de2f688656 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -3722,6 +3722,8 @@ Additional sub-options for this setting include: Required if `enabled` is set to true. * `subject_claim`: Name of the claim containing a unique identifier for the user. Optional, defaults to `sub`. +* `display_name_claim`: Name of the claim containing the display name for the user. Optional. + If provided, the display name will be set to the value of this claim upon first login. * `issuer`: The issuer to validate the "iss" claim against. Optional. If provided the "iss" claim will be required and validated for all JSON web tokens. * `audiences`: A list of audiences to validate the "aud" claim against. Optional. @@ -3736,6 +3738,7 @@ jwt_config: secret: "provided-by-your-issuer" algorithm: "provided-by-your-issuer" subject_claim: "name_of_claim" + display_name_claim: "name_of_claim" issuer: "provided-by-your-issuer" audiences: - "provided-by-your-issuer" diff --git a/synapse/config/jwt.py b/synapse/config/jwt.py index b41f2dc08f3..5c76551f334 100644 --- a/synapse/config/jwt.py +++ b/synapse/config/jwt.py @@ -38,6 +38,7 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.jwt_algorithm = jwt_config["algorithm"] self.jwt_subject_claim = jwt_config.get("subject_claim", "sub") + self.jwt_display_name_claim = jwt_config.get("display_name_claim") # The issuer and audiences are optional, if provided, it is asserted # that the claims exist on the JWT. @@ -49,5 +50,6 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.jwt_secret = None self.jwt_algorithm = None self.jwt_subject_claim = None + self.jwt_display_name_claim = None self.jwt_issuer = None self.jwt_audiences = None diff --git a/synapse/handlers/jwt.py b/synapse/handlers/jwt.py index 5fa7a305add..400f3a59aa1 100644 --- a/synapse/handlers/jwt.py +++ b/synapse/handlers/jwt.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from authlib.jose import JsonWebToken, JWTClaims from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError @@ -36,11 +36,12 @@ def __init__(self, hs: "HomeServer"): self.jwt_secret = hs.config.jwt.jwt_secret self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim + self.jwt_display_name_claim = hs.config.jwt.jwt_display_name_claim self.jwt_algorithm = hs.config.jwt.jwt_algorithm self.jwt_issuer = hs.config.jwt.jwt_issuer self.jwt_audiences = hs.config.jwt.jwt_audiences - def validate_login(self, login_submission: JsonDict) -> str: + def validate_login(self, login_submission: JsonDict) -> Tuple[str, Optional[str]]: """ Authenticates the user for the /login API @@ -49,7 +50,8 @@ def validate_login(self, login_submission: JsonDict) -> str: (including 'type' and other relevant fields) Returns: - The user ID that is logging in. + A tuple of (user_id, display_name) of the user that is logging in. + If the JWT does not contain a display name, the second element of the tuple will be None. Raises: LoginError if there was an authentication problem. @@ -109,4 +111,10 @@ def validate_login(self, login_submission: JsonDict) -> str: if user is None: raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) - return UserID(user, self.hs.hostname).to_string() + default_display_name = None + if self.jwt_display_name_claim: + display_name_claim = claims.get(self.jwt_display_name_claim) + if display_name_claim is not None: + default_display_name = display_name_claim + + return UserID(user, self.hs.hostname).to_string(), default_display_name diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 03b1e7edc49..3271b02d40e 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -363,6 +363,7 @@ async def _complete_login( login_submission: JsonDict, callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, create_non_existent_users: bool = False, + default_display_name: Optional[str] = None, ratelimit: bool = True, auth_provider_id: Optional[str] = None, should_issue_refresh_token: bool = False, @@ -410,7 +411,8 @@ async def _complete_login( canonical_uid = await self.auth_handler.check_user_exists(user_id) if not canonical_uid: canonical_uid = await self.registration_handler.register_user( - localpart=UserID.from_string(user_id).localpart + localpart=UserID.from_string(user_id).localpart, + default_display_name=default_display_name, ) user_id = canonical_uid @@ -546,11 +548,14 @@ async def _do_jwt_login( Returns: The body of the JSON response. """ - user_id = self.hs.get_jwt_handler().validate_login(login_submission) + user_id, default_display_name = self.hs.get_jwt_handler().validate_login( + login_submission + ) return await self._complete_login( user_id, login_submission, create_non_existent_users=True, + default_display_name=default_display_name, should_issue_refresh_token=should_issue_refresh_token, request_info=request_info, ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 2b1e44381b6..cbd6d8d4bf8 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1047,6 +1047,7 @@ class JWTTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, + profile.register_servlets, ] jwt_secret = "secret" @@ -1202,6 +1203,30 @@ def test_login_custom_sub(self) -> None: self.assertEqual(channel.code, 200, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") + @override_config( + {"jwt_config": {**base_config, "display_name_claim": "display_name"}} + ) + def test_login_custom_display_name(self) -> None: + """Test setting a custom display name.""" + localpart = "pinkie" + user_id = f"@{localpart}:test" + display_name = "Pinkie Pie" + + # Perform the login, specifying a custom display name. + channel = self.jwt_login({"sub": localpart, "display_name": display_name}) + self.assertEqual(channel.code, 200, msg=channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + # Fetch the user's display name and check that it was set correctly. + access_token = channel.json_body["access_token"] + channel = self.make_request( + "GET", + f"/_matrix/client/v3/profile/{user_id}/displayname", + access_token=access_token, + ) + self.assertEqual(channel.code, 200, msg=channel.result) + self.assertEqual(channel.json_body["displayname"], display_name) + def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params)