Skip to content

Commit ce129ff

Browse files
authored
chore: improve type hints (#2638)
1 parent fa3d9d2 commit ce129ff

File tree

5 files changed

+24
-24
lines changed

5 files changed

+24
-24
lines changed

tests/middleware/test_logging.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import logging
55
import socket
66
import sys
7-
import typing
7+
from collections.abc import Iterator
8+
from typing import TYPE_CHECKING
89

910
import httpx
1011
import pytest
@@ -15,7 +16,7 @@
1516
from uvicorn import Config
1617
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
1718

18-
if typing.TYPE_CHECKING:
19+
if TYPE_CHECKING:
1920
import sys
2021

2122
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
@@ -32,7 +33,7 @@
3233

3334

3435
@contextlib.contextmanager
35-
def caplog_for_logger(caplog: pytest.LogCaptureFixture, logger_name: str) -> typing.Iterator[pytest.LogCaptureFixture]:
36+
def caplog_for_logger(caplog: pytest.LogCaptureFixture, logger_name: str) -> Iterator[pytest.LogCaptureFixture]:
3637
logger = logging.getLogger(logger_name)
3738
logger.propagate, old_propagate = False, logger.propagate
3839
logger.addHandler(caplog.handler)

tests/protocols/test_websocket.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import typing
54
from copy import deepcopy
5+
from typing import TYPE_CHECKING, Any, TypedDict
66

77
import httpx
88
import pytest
@@ -35,7 +35,7 @@
3535
except ModuleNotFoundError: # pragma: no cover
3636
skip_if_no_wsproto = pytest.mark.skipif(True, reason="wsproto is not installed.")
3737

38-
if typing.TYPE_CHECKING:
38+
if TYPE_CHECKING:
3939
import sys
4040

4141
from uvicorn.protocols.http.h11_impl import H11Protocol
@@ -776,7 +776,7 @@ async def websocket_session(url: str):
776776
assert disconnected_message == {"type": "websocket.disconnect", "code": 1006}
777777

778778

779-
class EmptyDict(typing.TypedDict): ...
779+
class EmptyDict(TypedDict): ...
780780

781781

782782
async def test_server_reject_connection_with_response(
@@ -1142,12 +1142,12 @@ async def open_connection(url: str):
11421142

11431143

11441144
async def test_lifespan_state(ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int):
1145-
expected_states: list[dict[str, typing.Any]] = [
1145+
expected_states: list[dict[str, Any]] = [
11461146
{"a": 123, "b": [1]},
11471147
{"a": 123, "b": [1, 2]},
11481148
]
11491149

1150-
actual_states: list[dict[str, typing.Any]] = []
1150+
actual_states: list[dict[str, Any]] = []
11511151

11521152
async def lifespan_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
11531153
message = await receive()

tests/test_config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import os
88
import socket
99
import sys
10-
import typing
10+
from collections.abc import Iterator
1111
from pathlib import Path
12-
from typing import Any, Literal
12+
from typing import IO, Any, Callable, Literal
1313
from unittest.mock import MagicMock
1414

1515
import pytest
@@ -291,7 +291,7 @@ def test_ssl_config_combined(tls_certificate_key_and_chain_path: str) -> None:
291291
assert config.is_ssl is True
292292

293293

294-
def asgi2_app(scope: Scope) -> typing.Callable:
294+
def asgi2_app(scope: Scope) -> Callable:
295295
async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: # pragma: nocover
296296
pass
297297

@@ -374,7 +374,7 @@ def test_log_config_yaml(
374374
@pytest.mark.parametrize("config_file", ["log_config.ini", configparser.ConfigParser(), io.StringIO()])
375375
def test_log_config_file(
376376
mocked_logging_config_module: MagicMock,
377-
config_file: str | configparser.RawConfigParser | typing.IO[Any],
377+
config_file: str | configparser.RawConfigParser | IO[Any],
378378
) -> None:
379379
"""
380380
Test that one can load a configparser config from disk.
@@ -386,14 +386,14 @@ def test_log_config_file(
386386

387387

388388
@pytest.fixture(params=[0, 1])
389-
def web_concurrency(request: pytest.FixtureRequest) -> typing.Iterator[int]:
389+
def web_concurrency(request: pytest.FixtureRequest) -> Iterator[int]:
390390
yield request.param
391391
if os.getenv("WEB_CONCURRENCY"):
392392
del os.environ["WEB_CONCURRENCY"]
393393

394394

395395
@pytest.fixture(params=["127.0.0.1", "127.0.0.2"])
396-
def forwarded_allow_ips(request: pytest.FixtureRequest) -> typing.Iterator[str]:
396+
def forwarded_allow_ips(request: pytest.FixtureRequest) -> Iterator[str]:
397397
yield request.param
398398
if os.getenv("FORWARDED_ALLOW_IPS"):
399399
del os.environ["FORWARDED_ALLOW_IPS"]

uvicorn/protocols/websockets/auto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import typing
4+
from typing import Callable
55

6-
AutoWebSocketsProtocol: typing.Callable[..., asyncio.Protocol] | None
6+
AutoWebSocketsProtocol: Callable[..., asyncio.Protocol] | None
77
try:
88
import websockets # noqa
99
except ImportError: # pragma: no cover

uvicorn/protocols/websockets/wsproto_impl.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import asyncio
44
import logging
5-
import typing
6-
from typing import Literal, cast
5+
from typing import Any, Literal, cast
76
from urllib.parse import unquote
87

98
import wsproto
@@ -41,7 +40,7 @@ def __init__(
4140
self,
4241
config: Config,
4342
server_state: ServerState,
44-
app_state: dict[str, typing.Any],
43+
app_state: dict[str, Any],
4544
_loop: asyncio.AbstractEventLoop | None = None,
4645
) -> None:
4746
if not config.loaded:
@@ -256,7 +255,7 @@ async def send(self, message: ASGISendEvent) -> None:
256255

257256
if not self.handshake_complete:
258257
if message_type == "websocket.accept":
259-
message = typing.cast(WebSocketAcceptEvent, message)
258+
message = cast(WebSocketAcceptEvent, message)
260259
self.logger.info(
261260
'%s - "WebSocket %s" [accepted]',
262261
get_client_addr(self.scope),
@@ -293,7 +292,7 @@ async def send(self, message: ASGISendEvent) -> None:
293292
self.transport.close()
294293

295294
elif message_type == "websocket.http.response.start":
296-
message = typing.cast(WebSocketResponseStartEvent, message)
295+
message = cast(WebSocketResponseStartEvent, message)
297296
# ensure status code is in the valid range
298297
if not (100 <= message["status"] < 600):
299298
msg = "Invalid HTTP status code '%d' in response."
@@ -325,7 +324,7 @@ async def send(self, message: ASGISendEvent) -> None:
325324
elif not self.close_sent and not self.response_started:
326325
try:
327326
if message_type == "websocket.send":
328-
message = typing.cast(WebSocketSendEvent, message)
327+
message = cast(WebSocketSendEvent, message)
329328
bytes_data = message.get("bytes")
330329
text_data = message.get("text")
331330
data = text_data if bytes_data is None else bytes_data
@@ -334,7 +333,7 @@ async def send(self, message: ASGISendEvent) -> None:
334333
self.transport.write(output)
335334

336335
elif message_type == "websocket.close":
337-
message = typing.cast(WebSocketCloseEvent, message)
336+
message = cast(WebSocketCloseEvent, message)
338337
self.close_sent = True
339338
code = message.get("code", 1000)
340339
reason = message.get("reason", "") or ""
@@ -351,7 +350,7 @@ async def send(self, message: ASGISendEvent) -> None:
351350
raise ClientDisconnected from exc
352351
elif self.response_started:
353352
if message_type == "websocket.http.response.body":
354-
message = typing.cast("WebSocketResponseBodyEvent", message)
353+
message = cast("WebSocketResponseBodyEvent", message)
355354
body_finished = not message.get("more_body", False)
356355
reject_data = events.RejectData(data=message["body"], body_finished=body_finished)
357356
output = self.conn.send(reject_data)

0 commit comments

Comments
 (0)