Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support websocket.http.response ASGI extension #1907

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c77cfde
create test for websocket responses
kristjanvalur Mar 19, 2023
33a1a6e
update websockets protocol
kristjanvalur Mar 19, 2023
309a675
update wsproto
kristjanvalur Mar 19, 2023
101c2bf
add multi-body response test
kristjanvalur Mar 19, 2023
fae3cd3
Add missing close()
kristjanvalur Mar 20, 2023
673cd4b
Move access log to response.start message
kristjanvalur Mar 20, 2023
4207373
Fix imports
kristjanvalur Mar 20, 2023
4db75f4
Lint
kristjanvalur Mar 20, 2023
9a2850e
Update uvicorn/protocols/websockets/websockets_impl.py
kristjanvalur Mar 20, 2023
c93dabf
Update tests/protocols/test_websocket.py
kristjanvalur Mar 20, 2023
d684230
fix tests
kristjanvalur Mar 20, 2023
030e2b6
Log a warning if an unknown status code is received
kristjanvalur Mar 20, 2023
fce8f1d
fix formatting for consistency
kristjanvalur Mar 20, 2023
1674575
fix mypy problems
kristjanvalur Mar 20, 2023
c4acadb
Simply fail if application passes an invalid status code
kristjanvalur Mar 20, 2023
d717e81
Add a test for invalid message order
kristjanvalur Mar 20, 2023
19e5691
Update uvicorn/protocols/websockets/websockets_impl.py
kristjanvalur Mar 22, 2023
a010a03
Fix initial response error handling and unit-test
kristjanvalur Mar 22, 2023
1edcc35
Add a similar missing-body test
kristjanvalur Mar 22, 2023
d3b8d21
Use httpx to check rejection response body
kristjanvalur Mar 27, 2023
f4563ce
Only set the "response_started" flag once the data has been written.
kristjanvalur Mar 27, 2023
fea901c
Add test showing how content-length/transfer-encoding is automaticall…
kristjanvalur Mar 29, 2023
ba53520
Merge remote-tracking branch 'origin/master' into kristjan/ext
kristjanvalur Mar 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 297 additions & 4 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from tests.protocols.test_http import HTTP_PROTOCOLS
from tests.response import Response
from tests.utils import run_server
from uvicorn.config import Config
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
Expand Down Expand Up @@ -50,6 +51,21 @@ async def asgi(self):
break


async def wsresponse(url):
"""
A simple websocket connection request and response helper
"""
url = url.replace("ws:", "http:")
headers = {
"connection": "upgrade",
"upgrade": "websocket",
"Sec-WebSocket-Key": "x3JJHMbDL1EzLkh9GBhXDw==",
"Sec-WebSocket-Version": "13",
}
async with httpx.AsyncClient() as client:
return await client.get(url, headers=headers)


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
Expand Down Expand Up @@ -906,7 +922,10 @@ async def send_text(url):
async def test_server_reject_connection(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
disconnected_message = {}

async def app(scope, receive, send):
nonlocal disconnected_message
assert scope["type"] == "websocket"

# Pull up first recv message.
Expand All @@ -919,15 +938,289 @@ async def app(scope, receive, send):

# This doesn't raise `TypeError`:
# See https://github.com/encode/uvicorn/issues/244
disconnected_message = await receive()

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
assert exc_info.value.status_code == 403

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_response(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
disconnected_message = {}

async def app(scope, receive, send):
nonlocal disconnected_message
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.disconnect"
assert message["type"] == "websocket.connect"

# Reject the connection with a response
response = Response(b"goodbye", status_code=400)
await response(scope, receive, send)
disconnected_message = await receive()

async def websocket_session(url):
try:
response = await wsresponse(url)
assert response.status_code == 400
assert response.content == b"goodbye"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_multibody_response(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
disconnected_message = {}

async def app(scope, receive, send):
nonlocal disconnected_message
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"
message = {
"type": "websocket.http.response.start",
"status": 400,
"headers": [(b"Content-Length", b"20"), (b"Content-Type", b"text/plain")],
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"x" * 10,
"more_body": True,
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"y" * 10,
}
await send(message)
disconnected_message = await receive()

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 400
assert response.content == (b"x" * 10) + (b"y" * 10)

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")

assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_invalid_status(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 700, # invalid status code
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"",
}
await send(message)

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 500
assert response.content == b"Internal Server Error"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_body_nolength(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
# test that the server can send a response with a body but no content-length
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 403,
"headers": [],
}
await send(message)
message = {
"type": "websocket.http.response.body",
"body": b"hardbody",
}
await send(message)

async def websocket_session(url):
response = await wsresponse(url)
assert response.status_code == 403
assert response.content == b"hardbody"
if ws_protocol_cls == WSProtocol:
# wsproto automatically makes the message chunked
assert response.headers["transfer-encoding"] == "chunked"
else:
# websockets automatically adds a content-length
assert response.headers["content-length"] == "8"

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_invalid_msg(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
# send invalid message. This will raise an exception here
await send(message)

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
except Exception:
pass
if ws_protocol_cls == WSProtocol:
# ws protocol has started to send the response when it
# fails with the subsequent invalid message so it cannot
# undo that, we will get the initial 404 response
assert exc_info.value.status_code == 404
else:
# websockets protocol sends its response in one chunk
# and can override the already started response with a 500
assert exc_info.value.status_code == 500

config = Config(
app=app,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="off",
port=unused_tcp_port,
)
async with run_server(config):
await websocket_session(f"ws://127.0.0.1:{unused_tcp_port}")


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_server_reject_connection_with_missing_body(
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
):
async def app(scope, receive, send):
assert scope["type"] == "websocket"
assert "websocket.http.response" in scope["extensions"]

# Pull up first recv message.
message = await receive()
assert message["type"] == "websocket.connect"

message = {
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")],
}
await send(message)
# no further message

async def websocket_session(url):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass # pragma: no cover
if ws_protocol_cls == WSProtocol:
assert exc_info.value.status_code == 404
else:
assert exc_info.value.status_code == 500

config = Config(
app=app,
Expand Down
5 changes: 3 additions & 2 deletions tests/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@ def __init__(self, content, status_code=200, headers=None, media_type=None):
self.set_content_length()

async def __call__(self, scope, receive, send) -> None:
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
"type": "http.response.start",
"type": prefix + "http.response.start",
"status": self.status_code,
"headers": [
[key.encode(), value.encode()]
for key, value in self.headers.items()
],
}
)
await send({"type": "http.response.body", "body": self.body})
await send({"type": prefix + "http.response.body", "body": self.body})

def render(self, content) -> bytes:
if isinstance(content, bytes):
Expand Down
Loading