From 5f41c9c054d863b57d128a8437637f72d6787e77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Thu, 30 Mar 2023 21:30:44 +0000 Subject: [PATCH] Do not send the response start until the first response body is received --- tests/protocols/test_websocket.py | 15 ++------- uvicorn/protocols/websockets/wsproto_impl.py | 32 ++++++++++++++------ 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index a8fb57c0b..73b2e9e9d 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1170,15 +1170,7 @@ async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover - if ws_protocol_cls == WSProtocol: - # ws protocol has started to send the response when it - # fails with the subsequent invalid message so it cannot - # undo that, we will get the initial 404 response - assert exc_info.value.status_code == 404 - else: - # websockets protocol sends its response in one chunk - # and can override the already started response with a 500 - assert exc_info.value.status_code == 500 + assert exc_info.value.status_code == 500 config = Config( app=app, @@ -1217,10 +1209,7 @@ async def websocket_session(url): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover - if ws_protocol_cls == WSProtocol: - assert exc_info.value.status_code == 404 - else: - assert exc_info.value.status_code == 500 + assert exc_info.value.status_code == 500 config = Config( app=app, diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index a6fab3b1e..f3e57177f 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -79,7 +79,10 @@ def __init__( self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() self.handshake_complete = False self.close_sent = False - self.response_started = False + + # Rejection state + self.reject_event: typing.Optional[typing.Any] = None + self.response_started: bool = False # we have sent response start self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER) @@ -233,6 +236,8 @@ def handle_ping(self, event: events.Ping) -> None: self.transport.write(self.conn.send(event.response())) def send_500_response(self) -> None: + if self.response_started or self.handshake_complete: + return # we cannot send responses anymore headers = [ (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), @@ -252,15 +257,13 @@ async def run_asgi(self) -> None: result = await self.app(self.scope, self.receive, self.send) except BaseException: self.logger.exception("Exception in ASGI application\n") - if not self.response_started: - self.send_500_response() + self.send_500_response() self.transport.close() else: if not self.handshake_complete: msg = "ASGI callable returned without completing handshake." self.logger.error(msg) - if not self.response_started: - self.send_500_response() + self.send_500_response() self.transport.close() elif result is not None: msg = "ASGI callable should return None, but returned '%s'." @@ -273,7 +276,8 @@ async def send(self, message: "ASGISendEvent") -> None: message_type = message["type"] if not self.handshake_complete: - if not self.response_started: + if not (self.response_started or self.reject_event): + # a rejection event has not been sent yet if message_type == "websocket.accept": message = typing.cast("WebSocketAcceptEvent", message) self.logger.info( @@ -328,9 +332,10 @@ async def send(self, message: "ASGISendEvent") -> None: headers=list(message["headers"]), has_body=True, ) - output = self.conn.send(event) - self.transport.write(output) - self.response_started = True + # Create the event here but do not send it, the ASGI spec + # suggest that we wait for the body event before sending. + # https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event + self.reject_event = event else: msg = ( @@ -340,14 +345,23 @@ async def send(self, message: "ASGISendEvent") -> None: ) raise RuntimeError(msg % message_type) else: + # we have started a rejection process with http.response.start if message_type == "websocket.http.response.body": message = typing.cast("WebSocketResponseBodyEvent", message) body_finished = not message.get("more_body", False) reject_data = events.RejectData( data=message["body"], body_finished=body_finished ) + if self.reject_event is not None: + # Prepend with the reject event now that we have a body event. + output = self.conn.send(self.reject_event) + self.transport.write(output) + self.reject_event = None + self.response_started = True + output = self.conn.send(reject_data) self.transport.write(output) + if body_finished: self.queue.put_nowait( {"type": "websocket.disconnect", "code": 1006}