diff --git a/jupyter_server/base/websocket.py b/jupyter_server/base/websocket.py index 8780d7afc4..5fd971f1b0 100644 --- a/jupyter_server/base/websocket.py +++ b/jupyter_server/base/websocket.py @@ -1,4 +1,7 @@ """Base websocket classes.""" + +from __future__ import annotations + import re import warnings from typing import Optional, no_type_check @@ -164,3 +167,10 @@ def send_ping(self): def on_pong(self, data): """Handle a pong message.""" self.last_pong = ioloop.IOLoop.current().time() + + def select_subprotocol(self, subprotocols: list[str]) -> str | None: + # default subprotocol + # some clients (Chrome) + # require selected subprotocol to match one of the requested subprotocols + # otherwise connection is rejected + return "v1.token.websocket.jupyter.org" diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index ce580048f2..7672c4aeda 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -14,6 +14,7 @@ from jupyter_server.auth.decorator import authorized, ws_authenticated from jupyter_server.base.handlers import JupyterHandler +from jupyter_server.base.websocket import WebSocketMixin from ...base.handlers import APIHandler @@ -21,6 +22,7 @@ class SubscribeWebsocket( + WebSocketMixin, JupyterHandler, websocket.WebSocketHandler, ): diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 374df76f3e..c5682fca7c 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -90,6 +90,12 @@ def select_subprotocol(self, subprotocols): preferred_protocol = "v1.kernel.websocket.jupyter.org" elif preferred_protocol == "": preferred_protocol = None - selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None + + # super() subprotocol enables token authentication via subprotocol + selected_subprotocol = ( + preferred_protocol + if preferred_protocol in subprotocols + else super().select_subprotocol(subprotocols) + ) # None is the default, "legacy" protocol return selected_subprotocol diff --git a/tests/base/test_websocket.py b/tests/base/test_websocket.py index c888c8601b..690cd763c0 100644 --- a/tests/base/test_websocket.py +++ b/tests/base/test_websocket.py @@ -149,10 +149,11 @@ async def test_websocket_token_subprotocol_auth(jp_serverapp, jp_ws_fetch): "ws", headers={ "Authorization": "", - "Sec-WebSocket-Protocol": "v1.kernel.websocket.jupyter.org, v1.token.websocket.jupyter.org." + "Sec-WebSocket-Protocol": "v1.kernel.websocket.jupyter.org, v1.token.websocket.jupyter.org, v1.token.websocket.jupyter.org." + token, }, ) + assert ws.protocol.selected_subprotocol == "v1.token.websocket.jupyter.org" ws.close()