Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support websocket.http.response ASGI extension #1907

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c77cfde
create test for websocket responses
kristjanvalur Mar 19, 2023
33a1a6e
update websockets protocol
kristjanvalur Mar 19, 2023
309a675
update wsproto
kristjanvalur Mar 19, 2023
101c2bf
add multi-body response test
kristjanvalur Mar 19, 2023
fae3cd3
Add missing close()
kristjanvalur Mar 20, 2023
673cd4b
Move access log to response.start message
kristjanvalur Mar 20, 2023
4207373
Fix imports
kristjanvalur Mar 20, 2023
4db75f4
Lint
kristjanvalur Mar 20, 2023
9a2850e
Update uvicorn/protocols/websockets/websockets_impl.py
kristjanvalur Mar 20, 2023
c93dabf
Update tests/protocols/test_websocket.py
kristjanvalur Mar 20, 2023
d684230
fix tests
kristjanvalur Mar 20, 2023
030e2b6
Log a warning if an unknown status code is received
kristjanvalur Mar 20, 2023
fce8f1d
fix formatting for consistency
kristjanvalur Mar 20, 2023
1674575
fix mypy problems
kristjanvalur Mar 20, 2023
c4acadb
Simply fail if application passes an invalid status code
kristjanvalur Mar 20, 2023
d717e81
Add a test for invalid message order
kristjanvalur Mar 20, 2023
19e5691
Update uvicorn/protocols/websockets/websockets_impl.py
kristjanvalur Mar 22, 2023
a010a03
Fix initial response error handling and unit-test
kristjanvalur Mar 22, 2023
1edcc35
Add a similar missing-body test
kristjanvalur Mar 22, 2023
d3b8d21
Use httpx to check rejection response body
kristjanvalur Mar 27, 2023
f4563ce
Only set the "response_started" flag once the data has been written.
kristjanvalur Mar 27, 2023
fea901c
Add test showing how content-length/transfer-encoding is automaticall…
kristjanvalur Mar 29, 2023
ba53520
Merge remote-tracking branch 'origin/master' into kristjan/ext
kristjanvalur Mar 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 92 additions & 3 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from tests.protocols.test_http import HTTP_PROTOCOLS
from tests.response import Response
from tests.utils import run_server
from uvicorn.config import Config
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
Expand Down Expand Up @@ -923,11 +924,99 @@ async def app(scope, receive, send):
assert message["type"] == "websocket.disconnect"

async def websocket_session(url):
try:
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 403

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_response(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

# Reject the connection with a response
response = Response(b"goodbye", status_code=400)
await response(scope, receive, send)
message = await receive()
assert message["type"] == "websocket.disconnect"

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 400
# Websockets module currently does not read the response body from the socket.

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

kristjanvalur marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_multibody_response(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"
message = {
"type": "websocket.http.response.start",
"status": 400,
"headers": [(b"Content-Length", b"20"), (b"Content-Type", b"text/plain")],
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"x" * 10,
"more_body": True,
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"y" * 10,
}
await send(message)
message = await receive()
assert message["type"] == "websocket.disconnect"
Kludex marked this conversation as resolved.
Show resolved Hide resolved

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
except Exception:
pass
assert exc_info.value.status_code == 400
# Websockets module currently does not read the response body from the socket.

config = Config(
app=app,
Expand Down
5 changes: 3 additions & 2 deletions tests/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@ def __init__(self, content, status_code=200, headers=None, media_type=None):
self.set_content_length()

async def __call__(self, scope, receive, send) -> None:
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
"type": "http.response.start",
"type": prefix + "http.response.start",
"status": self.status_code,
"headers": [
[key.encode(), value.encode()]
for key, value in self.headers.items()
],
}
)
await send({"type": "http.response.body", "body": self.body})
await send({"type": prefix + "http.response.body", "body": self.body})

def render(self, content) -> bytes:
if isinstance(content, bytes):
Expand Down
111 changes: 77 additions & 34 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
WebSocketConnectEvent,
WebSocketDisconnectEvent,
WebSocketReceiveEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
Expand Down Expand Up @@ -102,6 +104,7 @@ def __init__(
self.connect_sent = False
self.lost_connection_before_handshake = False
self.accepted_subprotocol: Optional[Subprotocol] = None
self.response_body: Optional[List[bytes]] = None

self.ws_server: Server = Server() # type: ignore[assignment]

Expand Down Expand Up @@ -203,6 +206,7 @@ async def process_request(
"headers": asgi_headers,
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
task = self.loop.create_task(self.run_asgi())
task.add_done_callback(self.on_task_complete)
Expand Down Expand Up @@ -278,43 +282,82 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
message_type = message["type"]

if not self.handshake_started_event.is_set():
if message_type == "websocket.accept":
message = cast("WebSocketAcceptEvent", message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.initial_response = None
self.accepted_subprotocol = cast(
Optional[Subprotocol], message.get("subprotocol")
)
if "headers" in message:
self.extra_headers.extend(
# ASGI spec requires bytes
# But for compatibility we need to convert it to strings
if self.response_body is None:
Kludex marked this conversation as resolved.
Show resolved Hide resolved
if message_type == "websocket.accept":
message = cast("WebSocketAcceptEvent", message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.initial_response = None
self.accepted_subprotocol = cast(
Optional[Subprotocol], message.get("subprotocol")
)
if "headers" in message:
self.extra_headers.extend(
# ASGI spec requires bytes
# But for compatibility we need to convert it to strings
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in message["headers"]
)
self.handshake_started_event.set()

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
self.handshake_started_event.set()
self.closed_event.set()

elif message_type == "websocket.http.response.start":
message = cast("WebSocketResponseStartEvent", message)
Kludex marked this conversation as resolved.
Show resolved Hide resolved
self.logger.info(
'%s - "WebSocket %s" %s',
kristjanvalur marked this conversation as resolved.
Show resolved Hide resolved
self.scope["client"],
get_path_with_query_string(self.scope),
message["status"],
)
try:
status = http.HTTPStatus(message["status"])
except AttributeError:
status = http.HTTPStatus.FORBIDDEN
Kludex marked this conversation as resolved.
Show resolved Hide resolved
headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in message["headers"]
for name, value in message.get("headers", [])
]
self.initial_response = (status, headers, b"")
Kludex marked this conversation as resolved.
Show resolved Hide resolved
self.response_body = []

else:
msg = (
"Expected ASGI message 'websocket.accept', 'websocket.close', "
"or 'websocket.http.response.start' "
"but got '%s'."
)
self.handshake_started_event.set()

elif message_type == "websocket.close":
message = cast("WebSocketCloseEvent", message)
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
self.handshake_started_event.set()
self.closed_event.set()

raise RuntimeError(msg % message_type)
else:
msg = (
"Expected ASGI message 'websocket.accept' or 'websocket.close', "
"but got '%s'."
)
raise RuntimeError(msg % message_type)
if message_type == "websocket.http.response.body":
message = cast("WebSocketResponseBodyEvent", message)
self.response_body.append(message["body"])

if not message.get("more_body", False):
self.initial_response = self.initial_response[:2] + (
b"".join(self.response_body),
)
self.response_body = None
kristjanvalur marked this conversation as resolved.
Show resolved Hide resolved
self.handshake_started_event.set()
self.closed_event.set()
else:
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)

elif not self.closed_event.is_set():
await self.handshake_completed_event.wait()
Expand Down
Loading