From c77cfdeeda4f913a9e88c09490824cb8820d4b45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 19 Mar 2023 19:53:14 +0000 Subject: [PATCH 01/22] create test for websocket responses --- tests/protocols/test_websocket.py | 43 ++++++++++++++++++++++++++++--- tests/response.py | 5 ++-- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 22afeffad..0923599ce 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -923,11 +923,48 @@ async def app(scope, receive, send): assert message["type"] == "websocket.disconnect" async def websocket_session(url): - try: + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass # pragma: no cover + assert exc_info.value.status_code == 403 + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + +from tests.response import Response + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_reject_connection_with_response( + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int +): + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + # Reject the connection with a response + response = Response(b"goodbye", status_code=400) + await response(scope, receive, send) + message = await receive() + assert message["type"] == "websocket.disconnect" + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover - except Exception: - pass + assert exc_info.value.status_code == 400 config = Config( app=app, diff --git a/tests/response.py b/tests/response.py index 774dee6fa..55766d3f1 100644 --- a/tests/response.py +++ b/tests/response.py @@ -10,9 +10,10 @@ def __init__(self, content, status_code=200, headers=None, media_type=None): self.set_content_length() async def __call__(self, scope, receive, send) -> None: + prefix = "websocket." if scope["type"] == "websocket" else "" await send( { - "type": "http.response.start", + "type": prefix + "http.response.start", "status": self.status_code, "headers": [ [key.encode(), value.encode()] @@ -20,7 +21,7 @@ async def __call__(self, scope, receive, send) -> None: ], } ) - await send({"type": "http.response.body", "body": self.body}) + await send({"type": prefix + "http.response.body", "body": self.body}) def render(self, content) -> bytes: if isinstance(content, bytes): From 33a1a6e0cee7ca6833ad9c3166faa9a798fdad51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 19 Mar 2023 20:01:18 +0000 Subject: [PATCH 02/22] update websockets protocol --- .../protocols/websockets/websockets_impl.py | 112 ++++++++++++------ 1 file changed, 78 insertions(+), 34 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 3ee3086dd..74aedcc2a 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -46,6 +46,8 @@ WebSocketConnectEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, WebSocketScope, WebSocketSendEvent, ) @@ -102,6 +104,7 @@ def __init__( self.connect_sent = False self.lost_connection_before_handshake = False self.accepted_subprotocol: Optional[Subprotocol] = None + self.response_body: Optional[List[bytes]] = None self.ws_server: Server = Server() # type: ignore[assignment] @@ -203,6 +206,7 @@ async def process_request( "headers": asgi_headers, "subprotocols": subprotocols, "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, } task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) @@ -278,43 +282,83 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: message_type = message["type"] if not self.handshake_started_event.is_set(): - if message_type == "websocket.accept": - message = cast("WebSocketAcceptEvent", message) - self.logger.info( - '%s - "WebSocket %s" [accepted]', - self.scope["client"], - get_path_with_query_string(self.scope), - ) - self.initial_response = None - self.accepted_subprotocol = cast( - Optional[Subprotocol], message.get("subprotocol") - ) - if "headers" in message: - self.extra_headers.extend( - # ASGI spec requires bytes - # But for compatibility we need to convert it to strings + if self.response_body is None: + if message_type == "websocket.accept": + message = cast("WebSocketAcceptEvent", message) + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + self.initial_response = None + self.accepted_subprotocol = cast( + Optional[Subprotocol], message.get("subprotocol") + ) + if "headers" in message: + self.extra_headers.extend( + # ASGI spec requires bytes + # But for compatibility we need to convert it to strings + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in message["headers"] + ) + self.handshake_started_event.set() + + elif message_type == "websocket.close": + message = cast("WebSocketCloseEvent", message) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"") + self.handshake_started_event.set() + self.closed_event.set() + + elif message_type == "websocket.http.response.start": + message = cast("WebSocketResponseStartEvent", message) + try: + status = http.HTTPStatus(message["status"]) + except AttributeError: + status = http.HTTPStatus.FORBIDDEN + headers = [ (name.decode("latin-1"), value.decode("latin-1")) - for name, value in message["headers"] + for name, value in message.get("headers", []) + ] + self.initial_response = (status, headers, b"") + self.response_body = [] + + else: + msg = ( + "Expected ASGI message 'websocket.accept', 'websocket.close', " + "or 'websocket.http.response.start' " + "but got '%s'." ) - self.handshake_started_event.set() - - elif message_type == "websocket.close": - message = cast("WebSocketCloseEvent", message) - self.logger.info( - '%s - "WebSocket %s" 403', - self.scope["client"], - get_path_with_query_string(self.scope), - ) - self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"") - self.handshake_started_event.set() - self.closed_event.set() - + raise RuntimeError(msg % message_type) else: - msg = ( - "Expected ASGI message 'websocket.accept' or 'websocket.close', " - "but got '%s'." - ) - raise RuntimeError(msg % message_type) + if message_type == "websocket.http.response.body": + message = cast("WebSocketResponseBodyEvent", message) + self.response_body.append(message["body"]) + + if not message.get("more_body", False): + self.logger.info( + '%s - "WebSocket %s" %i', + self.scope["client"], + get_path_with_query_string(self.scope), + self.initial_response[0], + ) + + self.initial_response = self.initial_response[:2] + ( + b"".join(self.response_body), + ) + self.response_body = None + self.handshake_started_event.set() + self.closed_event.set() + else: + msg = ( + "Expected ASGI message 'websocket.http.response.body' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) elif not self.closed_event.is_set(): await self.handshake_completed_event.wait() From 309a6752c5709f77504923d2a84ad099462adef9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 19 Mar 2023 22:29:53 +0000 Subject: [PATCH 03/22] update wsproto --- uvicorn/protocols/websockets/wsproto_impl.py | 123 +++++++++++++------ 1 file changed, 86 insertions(+), 37 deletions(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 206eb6c30..b28db7417 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -28,6 +28,8 @@ WebSocketConnectEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, WebSocketScope, WebSocketSendEvent, ) @@ -77,6 +79,7 @@ def __init__( self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() self.handshake_complete = False self.close_sent = False + self.response_started = False self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER) @@ -187,6 +190,7 @@ def handle_connect(self, event: events.Request) -> None: "subprotocols": event.subprotocols, "extensions": None, "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, } self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) @@ -269,49 +273,94 @@ async def send(self, message: "ASGISendEvent") -> None: message_type = message["type"] if not self.handshake_complete: - if message_type == "websocket.accept": - message = typing.cast("WebSocketAcceptEvent", message) - self.logger.info( - '%s - "WebSocket %s" [accepted]', - self.scope["client"], - get_path_with_query_string(self.scope), - ) - subprotocol = message.get("subprotocol") - extra_headers = self.default_headers + list(message.get("headers", [])) - extensions: typing.List[Extension] = [] - if self.config.ws_per_message_deflate: - extensions.append(PerMessageDeflate()) - if not self.transport.is_closing(): - self.handshake_complete = True - output = self.conn.send( - wsproto.events.AcceptConnection( - subprotocol=subprotocol, - extensions=extensions, - extra_headers=extra_headers, + if not self.response_started: + if message_type == "websocket.accept": + message = typing.cast("WebSocketAcceptEvent", message) + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + subprotocol = message.get("subprotocol") + extra_headers = self.default_headers + list( + message.get("headers", []) + ) + extensions: typing.List[Extension] = [] + if self.config.ws_per_message_deflate: + extensions.append(PerMessageDeflate()) + if not self.transport.is_closing(): + self.handshake_complete = True + output = self.conn.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extensions=extensions, + extra_headers=extra_headers, + ) ) + self.transport.write(output) + + elif message_type == "websocket.close": + self.queue.put_nowait( + {"type": "websocket.disconnect", "code": 1006} + ) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), ) + self.handshake_complete = True + self.close_sent = True + event = events.RejectConnection(status_code=403, headers=[]) + output = self.conn.send(event) self.transport.write(output) + self.transport.close() - elif message_type == "websocket.close": - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) - self.logger.info( - '%s - "WebSocket %s" 403', - self.scope["client"], - get_path_with_query_string(self.scope), - ) - self.handshake_complete = True - self.close_sent = True - event = events.RejectConnection(status_code=403, headers=[]) - output = self.conn.send(event) - self.transport.write(output) - self.transport.close() + elif message_type == "websocket.http.response.start": + self.response_started = True + message = typing.cast("WebSocketResponseStartEvent", message) + self.logger.info( + '%s - "WebSocket %s" %i', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + event = events.RejectConnection( + status_code=message["status"], + headers=message["headers"], + has_body=True, + ) + output = self.conn.send(event) + self.transport.write(output) + else: + msg = ( + "Expected ASGI message 'websocket.accept', 'websocket.close' " + "or 'websocket.http.response.start' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) else: - msg = ( - "Expected ASGI message 'websocket.accept' or 'websocket.close', " - "but got '%s'." - ) - raise RuntimeError(msg % message_type) + if message_type == "websocket.http.response.body": + message = typing.cast("WebSocketResponseBodyEvent", message) + body_finished = not message.get("more_body", False) + event = events.RejectData( + data=message["body"], body_finished=body_finished + ) + output = self.conn.send(event) + self.transport.write(output) + if body_finished: + self.queue.put_nowait( + {"type": "websocket.disconnect", "code": 1006} + ) + self.handshake_complete = True + self.close_sent = True + + else: + msg = ( + "Expected ASGI message 'websocket.http.response.body' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) elif not self.close_sent: if message_type == "websocket.send": From 101c2bff57a43bf1e4252ba47d06ce3b0596f344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 19 Mar 2023 22:58:59 +0000 Subject: [PATCH 04/22] add multi-body response test --- tests/protocols/test_websocket.py | 54 +++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 0923599ce..22b51db8e 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -938,8 +938,10 @@ async def websocket_session(url): async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + from tests.response import Response + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @@ -965,6 +967,58 @@ async def websocket_session(url): async with websockets.client.connect(url): pass # pragma: no cover assert exc_info.value.status_code == 400 + # Websockets module currently does not read the response body from the socket. + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_reject_connection_with_multibody_response( + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int +): + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + message = { + "type": "websocket.http.response.start", + "status": 400, + "headers": [(b"Content-Length", b"20"), (b"Content-Type", b"text/plain")], + } + await send(message) + message = { + "type": "websocket.http.response.body", + "body": b"x" * 10, + "more_body": True, + } + await send(message) + message = { + "type": "websocket.http.response.body", + "body": b"y" * 10, + } + await send(message) + message = await receive() + assert message["type"] == "websocket.disconnect" + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass # pragma: no cover + assert exc_info.value.status_code == 400 + # Websockets module currently does not read the response body from the socket. config = Config( app=app, From fae3cd3f321f235a9d8e5d5a23ab56479032dec5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:08:07 +0000 Subject: [PATCH 05/22] Add missing close() --- uvicorn/protocols/websockets/wsproto_impl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index b28db7417..d83fe38a2 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -354,6 +354,7 @@ async def send(self, message: "ASGISendEvent") -> None: ) self.handshake_complete = True self.close_sent = True + self.transport.close() else: msg = ( From 673cd4b3d450b485040db71a66fca8bca1ef8ed3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:08:24 +0000 Subject: [PATCH 06/22] Move access log to response.start message --- uvicorn/protocols/websockets/websockets_impl.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 74aedcc2a..2b825fd3e 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -316,6 +316,12 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: elif message_type == "websocket.http.response.start": message = cast("WebSocketResponseStartEvent", message) + self.logger.info( + '%s - "WebSocket %s" %s', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) try: status = http.HTTPStatus(message["status"]) except AttributeError: @@ -340,13 +346,6 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: self.response_body.append(message["body"]) if not message.get("more_body", False): - self.logger.info( - '%s - "WebSocket %s" %i', - self.scope["client"], - get_path_with_query_string(self.scope), - self.initial_response[0], - ) - self.initial_response = self.initial_response[:2] + ( b"".join(self.response_body), ) From 4207373c6f7324fa867dca6bd788cb7b2ba1ea7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:13:34 +0000 Subject: [PATCH 07/22] Fix imports --- tests/protocols/test_websocket.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 22b51db8e..ec681d74f 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -6,6 +6,7 @@ import pytest from tests.protocols.test_http import HTTP_PROTOCOLS +from tests.response import Response from tests.utils import run_server from uvicorn.config import Config from uvicorn.protocols.websockets.wsproto_impl import WSProtocol @@ -939,9 +940,6 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") -from tests.response import Response - - @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) From 4db75f43bcbbad451b44f12cfb16135223bb11be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:14:35 +0000 Subject: [PATCH 08/22] Lint --- uvicorn/protocols/websockets/wsproto_impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index d83fe38a2..bc0029111 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -188,7 +188,6 @@ def handle_connect(self, event: events.Request) -> None: "query_string": query_string.encode("ascii"), "headers": headers, "subprotocols": event.subprotocols, - "extensions": None, "state": self.app_state.copy(), "extensions": {"websocket.http.response": {}}, } From 9a2850e22d4fa4659990af8afc12db76b4a4a018 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:49:50 +0000 Subject: [PATCH 09/22] Update uvicorn/protocols/websockets/websockets_impl.py Co-authored-by: Marcelo Trylesinski --- uvicorn/protocols/websockets/websockets_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 2b825fd3e..a8fa0c738 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -317,7 +317,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: elif message_type == "websocket.http.response.start": message = cast("WebSocketResponseStartEvent", message) self.logger.info( - '%s - "WebSocket %s" %s', + '%s - "WebSocket %s" %d', self.scope["client"], get_path_with_query_string(self.scope), message["status"], From c93dabfcf51657e92814cf9f8e12730cbe07ba52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:59:33 +0000 Subject: [PATCH 10/22] Update tests/protocols/test_websocket.py Co-authored-by: Marcelo Trylesinski --- tests/protocols/test_websocket.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index ec681d74f..bc65e91ec 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -946,7 +946,10 @@ async def websocket_session(url): async def test_server_reject_connection_with_response( ws_protocol_cls, http_protocol_cls, unused_tcp_port: int ): + disconnected_message = {} + async def app(scope, receive, send): + nonlocal disconnected_message assert scope["type"] == "websocket" assert "websocket.http.response" in scope["extensions"] @@ -957,8 +960,7 @@ async def app(scope, receive, send): # Reject the connection with a response response = Response(b"goodbye", status_code=400) await response(scope, receive, send) - message = await receive() - assert message["type"] == "websocket.disconnect" + disconnected_message = await receive() async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: @@ -977,6 +979,7 @@ async def websocket_session(url): async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + assert disconnected_message == {"type": "websocket.disconnect", "code": 1006"} @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) From d684230da9a52ad9849b3e896ae04b5fe6fd414e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 13:10:55 +0000 Subject: [PATCH 11/22] fix tests --- tests/protocols/test_websocket.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index bc65e91ec..01757cba2 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -907,7 +907,10 @@ async def send_text(url): async def test_server_reject_connection( ws_protocol_cls, http_protocol_cls, unused_tcp_port: int ): + disconnected_message = {} + async def app(scope, receive, send): + nonlocal disconnected_message assert scope["type"] == "websocket" # Pull up first recv message. @@ -920,8 +923,7 @@ async def app(scope, receive, send): # This doesn't raise `TypeError`: # See https://github.com/encode/uvicorn/issues/244 - message = await receive() - assert message["type"] == "websocket.disconnect" + disconnected_message = await receive() async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: @@ -939,6 +941,8 @@ async def websocket_session(url): async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @@ -979,7 +983,8 @@ async def websocket_session(url): async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") - assert disconnected_message == {"type": "websocket.disconnect", "code": 1006"} + assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @@ -987,7 +992,10 @@ async def websocket_session(url): async def test_server_reject_connection_with_multibody_response( ws_protocol_cls, http_protocol_cls, unused_tcp_port: int ): + disconnected_message = {} + async def app(scope, receive, send): + nonlocal disconnected_message assert scope["type"] == "websocket" assert "websocket.http.response" in scope["extensions"] @@ -1011,8 +1019,7 @@ async def app(scope, receive, send): "body": b"y" * 10, } await send(message) - message = await receive() - assert message["type"] == "websocket.disconnect" + disconnected_message = await receive() async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: @@ -1031,6 +1038,8 @@ async def websocket_session(url): async with run_server(config): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) From 030e2b6873685bc352b1ca40cf70766c888c2755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:48:23 +0000 Subject: [PATCH 12/22] Log a warning if an unknown status code is received --- uvicorn/protocols/websockets/websockets_impl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index a8fa0c738..9e8795327 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -322,10 +322,15 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), message["status"], ) + # websockets requires the status to be an enum. look it up. try: status = http.HTTPStatus(message["status"]) except AttributeError: status = http.HTTPStatus.FORBIDDEN + self.logger.info( + "status code %d unknown, replaced with 403", + message["status"], + ) headers = [ (name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", []) From fce8f1d1d9dabbedd5566529826db5367c0a7014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 12:52:49 +0000 Subject: [PATCH 13/22] fix formatting for consistency --- uvicorn/protocols/websockets/wsproto_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index bc0029111..9eb644723 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -318,7 +318,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.response_started = True message = typing.cast("WebSocketResponseStartEvent", message) self.logger.info( - '%s - "WebSocket %s" %i', + '%s - "WebSocket %s" %d', self.scope["client"], get_path_with_query_string(self.scope), message["status"], From 1674575a72430cc96abd3ea7dc82668dea82c586 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 13:17:53 +0000 Subject: [PATCH 14/22] fix mypy problems --- uvicorn/protocols/websockets/websockets_impl.py | 1 + uvicorn/protocols/websockets/wsproto_impl.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 9e8795327..6ca2bc26f 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -351,6 +351,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: self.response_body.append(message["body"]) if not message.get("more_body", False): + assert self.initial_response is not None self.initial_response = self.initial_response[:2] + ( b"".join(self.response_body), ) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 9eb644723..52e22ff29 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -325,7 +325,7 @@ async def send(self, message: "ASGISendEvent") -> None: ) event = events.RejectConnection( status_code=message["status"], - headers=message["headers"], + headers=list(message["headers"]), has_body=True, ) output = self.conn.send(event) @@ -342,10 +342,10 @@ async def send(self, message: "ASGISendEvent") -> None: if message_type == "websocket.http.response.body": message = typing.cast("WebSocketResponseBodyEvent", message) body_finished = not message.get("more_body", False) - event = events.RejectData( + reject_data = events.RejectData( data=message["body"], body_finished=body_finished ) - output = self.conn.send(event) + output = self.conn.send(reject_data) self.transport.write(output) if body_finished: self.queue.put_nowait( From c4acadb752bf4142945931524258f69483f709e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 14:22:47 +0000 Subject: [PATCH 15/22] Simply fail if application passes an invalid status code --- tests/protocols/test_websocket.py | 43 +++++++++++++++++++ .../protocols/websockets/websockets_impl.py | 9 +--- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 01757cba2..d8a31aea9 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1041,6 +1041,49 @@ async def websocket_session(url): assert disconnected_message == {"type": "websocket.disconnect", "code": 1006} +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_reject_connection_with_invalid_status( + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int +): + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + message = { + "type": "websocket.http.response.start", + "status": 700, # invalid status code + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + await send(message) + message = { + "type": "websocket.http.response.body", + "body": b"", + } + await send(message) + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass # pragma: no cover + assert exc_info.value.status_code == 500 + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 6ca2bc26f..7aa0408c8 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -323,14 +323,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: message["status"], ) # websockets requires the status to be an enum. look it up. - try: - status = http.HTTPStatus(message["status"]) - except AttributeError: - status = http.HTTPStatus.FORBIDDEN - self.logger.info( - "status code %d unknown, replaced with 403", - message["status"], - ) + status = http.HTTPStatus(message["status"]) headers = [ (name.decode("latin-1"), value.decode("latin-1")) for name, value in message.get("headers", []) From d717e813f70182883632c0a8dd74069ef10c1dd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 20 Mar 2023 14:32:46 +0000 Subject: [PATCH 16/22] Add a test for invalid message order --- tests/protocols/test_websocket.py | 53 +++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index d8a31aea9..12e879236 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1084,6 +1084,59 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_reject_connection_with_invalid_msg( + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int +): + if ws_protocol_cls is WSProtocol: + pytest.skip("Cannot supporess asynchronously raised errors") + + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + message = { + "type": "websocket.http.response.start", + "status": 404, + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + await send(message) + # send invalid message + try: + await send(message) + except Exception: + # swallow the invalid message error + pass + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass # pragma: no cover + if ws_protocol_cls == WSProtocol: + # ws protocol has started to send the response when it + # fails with the subsequent invalid message so it cannot + # undo that, we will get the initial 404 response + assert exc_info.value.status_code == 404 + else: + assert exc_info.value.status_code == 500 + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) From 19e569160836843b363bb1464b7e168a814fe431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 22 Mar 2023 11:33:05 +0000 Subject: [PATCH 17/22] Update uvicorn/protocols/websockets/websockets_impl.py Co-authored-by: Marcelo Trylesinski --- uvicorn/protocols/websockets/websockets_impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 7aa0408c8..4204414fb 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -348,7 +348,6 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: self.initial_response = self.initial_response[:2] + ( b"".join(self.response_body), ) - self.response_body = None self.handshake_started_event.set() self.closed_event.set() else: From a010a03c99bbaf32f6b08a097231a55579dbc32d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 22 Mar 2023 12:51:23 +0000 Subject: [PATCH 18/22] Fix initial response error handling and unit-test --- tests/protocols/test_websocket.py | 13 ++++--------- uvicorn/protocols/websockets/wsproto_impl.py | 5 +++-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 12e879236..fce6d8440 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1090,9 +1090,6 @@ async def websocket_session(url): async def test_server_reject_connection_with_invalid_msg( ws_protocol_cls, http_protocol_cls, unused_tcp_port: int ): - if ws_protocol_cls is WSProtocol: - pytest.skip("Cannot supporess asynchronously raised errors") - async def app(scope, receive, send): assert scope["type"] == "websocket" assert "websocket.http.response" in scope["extensions"] @@ -1107,12 +1104,8 @@ async def app(scope, receive, send): "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], } await send(message) - # send invalid message - try: - await send(message) - except Exception: - # swallow the invalid message error - pass + # send invalid message. This will raise an exception here + await send(message) async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: @@ -1124,6 +1117,8 @@ async def websocket_session(url): # undo that, we will get the initial 404 response assert exc_info.value.status_code == 404 else: + # websockets protocol sends its response in one chunk + # and can override the already started response with a 500 assert exc_info.value.status_code == 500 config = Config( diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 52e22ff29..25ba14212 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -252,14 +252,15 @@ async def run_asgi(self) -> None: result = await self.app(self.scope, self.receive, self.send) except BaseException: self.logger.exception("Exception in ASGI application\n") - if not self.handshake_complete: + if not self.response_started: self.send_500_response() self.transport.close() else: if not self.handshake_complete: msg = "ASGI callable returned without completing handshake." self.logger.error(msg) - self.send_500_response() + if not self.response_started: + self.send_500_response() self.transport.close() elif result is not None: msg = "ASGI callable should return None, but returned '%s'." From 1edcc35f05da095e2979dbbd62e02e2689381d40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 22 Mar 2023 13:16:36 +0000 Subject: [PATCH 19/22] Add a similar missing-body test --- tests/protocols/test_websocket.py | 42 +++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index fce6d8440..9fb9bf062 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1132,6 +1132,48 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_reject_connection_with_missing_body( + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int +): + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + message = { + "type": "websocket.http.response.start", + "status": 404, + "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], + } + await send(message) + # no further message + + async def websocket_session(url): + with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: + async with websockets.client.connect(url): + pass # pragma: no cover + if ws_protocol_cls == WSProtocol: + assert exc_info.value.status_code == 404 + else: + assert exc_info.value.status_code == 500 + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) From d3b8d210bf175743ca63e04f2c05bb34aa68f5aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 27 Mar 2023 22:44:08 +0000 Subject: [PATCH 20/22] Use httpx to check rejection response body --- tests/protocols/test_websocket.py | 38 +++++++++++++++++++------------ 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 9fb9bf062..87bb604f8 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -51,6 +51,21 @@ async def asgi(self): break +async def wsresponse(url): + """ + A simple websocket connection request and response helper + """ + url = url.replace("ws:", "http:") + headers = { + "connection": "upgrade", + "upgrade": "websocket", + "Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==", + "Sec-WebSocket-Version": "13", + } + async with httpx.AsyncClient() as client: + return await client.get(url, headers=headers) + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @@ -967,11 +982,9 @@ async def app(scope, receive, send): disconnected_message = await receive() async def websocket_session(url): - with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: - async with websockets.client.connect(url): - pass # pragma: no cover - assert exc_info.value.status_code == 400 - # Websockets module currently does not read the response body from the socket. + response = await wsresponse(url) + assert response.status_code == 400 + assert response.content == b"goodbye" config = Config( app=app, @@ -1022,11 +1035,9 @@ async def app(scope, receive, send): disconnected_message = await receive() async def websocket_session(url): - with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: - async with websockets.client.connect(url): - pass # pragma: no cover - assert exc_info.value.status_code == 400 - # Websockets module currently does not read the response body from the socket. + response = await wsresponse(url) + assert response.status_code == 400 + assert response.content == (b"x" * 10) + (b"y" * 10) config = Config( app=app, @@ -1068,10 +1079,9 @@ async def app(scope, receive, send): await send(message) async def websocket_session(url): - with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: - async with websockets.client.connect(url): - pass # pragma: no cover - assert exc_info.value.status_code == 500 + response = await wsresponse(url) + assert response.status_code == 500 + assert response.content == b"Internal Server Error" config = Config( app=app, From f4563ced5f4504b20bfcc81c5ddc6cbfc67815c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 27 Mar 2023 22:44:29 +0000 Subject: [PATCH 21/22] Only set the "response_started" flag once the data has been written. --- uvicorn/protocols/websockets/wsproto_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 25ba14212..a6fab3b1e 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -316,7 +316,6 @@ async def send(self, message: "ASGISendEvent") -> None: self.transport.close() elif message_type == "websocket.http.response.start": - self.response_started = True message = typing.cast("WebSocketResponseStartEvent", message) self.logger.info( '%s - "WebSocket %s" %d', @@ -331,6 +330,7 @@ async def send(self, message: "ASGISendEvent") -> None: ) output = self.conn.send(event) self.transport.write(output) + self.response_started = True else: msg = ( From fea901cbe95f2881c446547520af3b3e207bea7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 29 Mar 2023 16:12:35 +0000 Subject: [PATCH 22/22] Add test showing how content-length/transfer-encoding is automatically handled. --- tests/protocols/test_websocket.py | 49 +++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 87bb604f8..a8fb57c0b 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1094,6 +1094,55 @@ async def websocket_session(url): await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_server_reject_connection_with_body_nolength( + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int +): + # test that the server can send a response with a body but no content-length + async def app(scope, receive, send): + assert scope["type"] == "websocket" + assert "websocket.http.response" in scope["extensions"] + + # Pull up first recv message. + message = await receive() + assert message["type"] == "websocket.connect" + + message = { + "type": "websocket.http.response.start", + "status": 403, + "headers": [], + } + await send(message) + message = { + "type": "websocket.http.response.body", + "body": b"hardbody", + } + await send(message) + + async def websocket_session(url): + response = await wsresponse(url) + assert response.status_code == 403 + assert response.content == b"hardbody" + if ws_protocol_cls == WSProtocol: + # wsproto automatically makes the message chunked + assert response.headers["transfer-encoding"] == "chunked" + else: + # websockets automatically adds a content-length + assert response.headers["content-length"] == "8" + + config = Config( + app=app, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}") + + @pytest.mark.anyio @pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)