Skip to content

Commit

Permalink
Do not send the response start until the first response body is received
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed May 6, 2023
1 parent b10d31a commit 5f41c9c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 22 deletions.
15 changes: 2 additions & 13 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,15 +1170,7 @@ async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
if ws_protocol_cls == WSProtocol:
# ws protocol has started to send the response when it
# fails with the subsequent invalid message so it cannot
# undo that, we will get the initial 404 response
assert exc_info.value.status_code == 404
else:
# websockets protocol sends its response in one chunk
# and can override the already started response with a 500
assert exc_info.value.status_code == 500
assert exc_info.value.status_code == 500

config = Config(
app=app,
Expand Down Expand Up @@ -1217,10 +1209,7 @@ async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
if ws_protocol_cls == WSProtocol:
assert exc_info.value.status_code == 404
else:
assert exc_info.value.status_code == 500
assert exc_info.value.status_code == 500

config = Config(
app=app,
Expand Down
32 changes: 23 additions & 9 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def __init__(
self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue()
self.handshake_complete = False
self.close_sent = False
self.response_started = False

# Rejection state
self.reject_event: typing.Optional[typing.Any] = None
self.response_started: bool = False # we have sent response start

self.conn = wsproto.WSConnection(connection_type=ConnectionType.SERVER)

Expand Down Expand Up @@ -233,6 +236,8 @@ def handle_ping(self, event: events.Ping) -> None:
self.transport.write(self.conn.send(event.response()))

def send_500_response(self) -> None:
if self.response_started or self.handshake_complete:
return # we cannot send responses anymore
headers = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
Expand All @@ -252,15 +257,13 @@ async def run_asgi(self) -> None:
result = await self.app(self.scope, self.receive, self.send)
except BaseException:
self.logger.exception("Exception in ASGI application\n")
if not self.response_started:
self.send_500_response()
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
if not self.response_started:
self.send_500_response()
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
Expand All @@ -273,7 +276,8 @@ async def send(self, message: "ASGISendEvent") -> None:
message_type = message["type"]

if not self.handshake_complete:
if not self.response_started:
if not (self.response_started or self.reject_event):
# a rejection event has not been sent yet
if message_type == "websocket.accept":
message = typing.cast("WebSocketAcceptEvent", message)
self.logger.info(
Expand Down Expand Up @@ -328,9 +332,10 @@ async def send(self, message: "ASGISendEvent") -> None:
headers=list(message["headers"]),
has_body=True,
)
output = self.conn.send(event)
self.transport.write(output)
self.response_started = True
# Create the event here but do not send it, the ASGI spec
# suggest that we wait for the body event before sending.
# https://asgi.readthedocs.io/en/latest/specs/www.html#response-start-send-event
self.reject_event = event

else:
msg = (
Expand All @@ -340,14 +345,23 @@ async def send(self, message: "ASGISendEvent") -> None:
)
raise RuntimeError(msg % message_type)
else:
# we have started a rejection process with http.response.start
if message_type == "websocket.http.response.body":
message = typing.cast("WebSocketResponseBodyEvent", message)
body_finished = not message.get("more_body", False)
reject_data = events.RejectData(
data=message["body"], body_finished=body_finished
)
if self.reject_event is not None:
# Prepend with the reject event now that we have a body event.
output = self.conn.send(self.reject_event)
self.transport.write(output)
self.reject_event = None
self.response_started = True

output = self.conn.send(reject_data)
self.transport.write(output)

if body_finished:
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": 1006}
Expand Down

0 comments on commit 5f41c9c

Please sign in to comment.