Skip to content

Commit

Permalink
Fix initial response error handling and unit-test
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Mar 22, 2023
1 parent 19e5691 commit a010a03
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
13 changes: 4 additions & 9 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,9 +1090,6 @@ async def websocket_session(url):
async def test_server_reject_connection_with_invalid_msg(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
if ws_protocol_cls is WSProtocol:
pytest.skip("Cannot supporess asynchronously raised errors")

async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]
Expand All @@ -1107,12 +1104,8 @@ async def app(scope, receive, send):
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
# send invalid message
try:
await send(message)
except Exception:
# swallow the invalid message error
pass
# send invalid message. This will raise an exception here
await send(message)

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
Expand All @@ -1124,6 +1117,8 @@ async def websocket_session(url):
# undo that, we will get the initial 404 response
assert exc_info.value.status_code == 404
else:
# websockets protocol sends its response in one chunk
# and can override the already started response with a 500
assert exc_info.value.status_code == 500

config = Config(
Expand Down
5 changes: 3 additions & 2 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,15 @@ async def run_asgi(self) -> None:
result = await self.app(self.scope, self.receive, self.send)
except BaseException:
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_complete:
if not self.response_started:
self.send_500_response()
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.send_500_response()
if not self.response_started:
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
Expand Down

0 comments on commit a010a03

Please sign in to comment.