Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise Disconnect on send() when client disconnected #2218

Merged
merged 5 commits into from
Jan 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -253,8 +253,12 @@ def unused_tcp_port() -> int:
marks=pytest.mark.skipif(
not importlib.util.find_spec("wsproto"), reason="wsproto not installed."
),
id="wsproto",
),
pytest.param(
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
id="websockets",
),
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
]
)
def ws_protocol_cls(request: pytest.FixtureRequest):
@@ -269,8 +273,9 @@ def ws_protocol_cls(request: pytest.FixtureRequest):
not importlib.util.find_spec("httptools"),
reason="httptools not installed.",
),
id="httptools",
),
"uvicorn.protocols.http.h11_impl:H11Protocol",
pytest.param("uvicorn.protocols.http.h11_impl:H11Protocol", id="h11"),
]
)
def http_protocol_cls(request: pytest.FixtureRequest):
36 changes: 36 additions & 0 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
@@ -762,6 +762,42 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
assert got_disconnect_event_before_shutdown is True


@pytest.mark.anyio
async def test_client_connection_lost_on_send(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
unused_tcp_port: int,
):
disconnect = asyncio.Event()
got_disconnect_event = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal got_disconnect_event
message = await receive()
if message["type"] == "websocket.connect":
await send({"type": "websocket.accept"})
try:
await disconnect.wait()
await send({"type": "websocket.send", "text": "123"})
except IOError:
got_disconnect_event = True

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
url = f"ws://127.0.0.1:{unused_tcp_port}"
async with websockets.client.connect(url):
await asyncio.sleep(0.1)
disconnect.set()

assert got_disconnect_event is True


@pytest.mark.anyio
async def test_connection_lost_before_handshake_complete(
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
4 changes: 4 additions & 0 deletions uvicorn/protocols/utils.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,10 @@
from uvicorn._types import WWWScope


class ClientDisconnected(IOError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IOError is an alias for OSError on python 3

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ouch I think we put IOError in the spec...

...


def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
socket_info = transport.get_extra_info("socket")
if socket_info is not None:
45 changes: 26 additions & 19 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
@@ -38,6 +38,7 @@
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
@@ -252,6 +253,9 @@ async def run_asgi(self) -> None:
"""
try:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
except ClientDisconnected:
self.closed_event.set()
self.transport.close()
except BaseException as exc:
self.closed_event.set()
msg = "Exception in ASGI application\n"
@@ -336,26 +340,29 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
elif not self.closed_event.is_set() and self.initial_response is None:
await self.handshake_completed_event.wait()

if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data) # type: ignore[arg-type]

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()
try:
if message_type == "websocket.send":
message = cast("WebSocketSendEvent", message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
await self.send(data) # type: ignore[arg-type]

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
await self.close(code, reason)
self.closed_event.set()

else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
except ConnectionClosed as exc:
raise ClientDisconnected from exc

elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
68 changes: 37 additions & 31 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
from wsproto import ConnectionType, events
from wsproto.connection import ConnectionState
from wsproto.extensions import Extension, PerMessageDeflate
from wsproto.utilities import RemoteProtocolError
from wsproto.utilities import LocalProtocolError, RemoteProtocolError

from uvicorn._types import (
ASGISendEvent,
@@ -25,6 +25,7 @@
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import (
ClientDisconnected,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
@@ -236,6 +237,8 @@ def send_500_response(self) -> None:
async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
except ClientDisconnected:
self.transport.close()
except BaseException:
self.logger.exception("Exception in ASGI application\n")
self.send_500_response()
@@ -325,36 +328,39 @@ async def send(self, message: ASGISendEvent) -> None:
raise RuntimeError(msg % message_type)

elif not self.close_sent and not self.response_started:
if message_type == "websocket.send":
message = typing.cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
output = self.conn.send(
wsproto.events.Message(data=data) # type: ignore[type-var]
)
if not self.transport.is_closing():
self.transport.write(output)

elif message_type == "websocket.close":
message = typing.cast(WebSocketCloseEvent, message)
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
output = self.conn.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
if not self.transport.is_closing():
self.transport.write(output)
self.transport.close()

else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
try:
if message_type == "websocket.send":
message = typing.cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
text_data = message.get("text")
data = text_data if bytes_data is None else bytes_data
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
if not self.transport.is_closing():
self.transport.write(output)

elif message_type == "websocket.close":
message = typing.cast(WebSocketCloseEvent, message)
self.close_sent = True
code = message.get("code", 1000)
reason = message.get("reason", "") or ""
self.queue.put_nowait(
{"type": "websocket.disconnect", "code": code}
)
output = self.conn.send(
wsproto.events.CloseConnection(code=code, reason=reason)
)
if not self.transport.is_closing():
self.transport.write(output)
self.transport.close()

else:
msg = (
"Expected ASGI message 'websocket.send' or 'websocket.close',"
" but got '%s'."
)
raise RuntimeError(msg % message_type)
except LocalProtocolError as exc:
raise ClientDisconnected from exc
elif self.response_started:
if message_type == "websocket.http.response.body":
message = typing.cast("WebSocketResponseBodyEvent", message)