From 748eef93fffbf74049c90a10647fadf83f0c9e25 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 7 Feb 2025 19:23:25 +0100 Subject: [PATCH 1/2] enforce stricter types for `H2StreamStateMachine` --- src/h2/stream.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/h2/stream.py b/src/h2/stream.py index 7d4a12e35..a3c99e351 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -46,7 +46,7 @@ from .windows import WindowManager if TYPE_CHECKING: # pragma: no cover - from collections.abc import Generator, Iterable + from collections.abc import Callable, Generator, Iterable from hpack.hpack import Encoder from hpack.struct import Header, HeaderWeaklyTyped @@ -131,7 +131,7 @@ def __init__(self, stream_id: int) -> None: # How the stream was closed. One of StreamClosedBy. self.stream_closed_by: StreamClosedBy | None = None - def process_input(self, input_: StreamInputs) -> Any: + def process_input(self, input_: StreamInputs) -> list[Event]: """ Process a specific input in the state machine. """ @@ -315,21 +315,23 @@ def recv_push_promise(self, previous_state: StreamState) -> list[Event]: event.parent_stream_id = self.stream_id return [event] - def send_end_stream(self, previous_state: StreamState) -> None: + def send_end_stream(self, previous_state: StreamState) -> list[Event]: """ Called when an attempt is made to send END_STREAM in the HALF_CLOSED_REMOTE state. """ self.stream_closed_by = StreamClosedBy.SEND_END_STREAM + return [] - def send_reset_stream(self, previous_state: StreamState) -> None: + def send_reset_stream(self, previous_state: StreamState) -> list[Event]: """ Called when an attempt is made to send RST_STREAM in a non-closed stream state. """ self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM + return [] - def reset_stream_on_error(self, previous_state: StreamState) -> None: + def reset_stream_on_error(self, previous_state: StreamState) -> list[Event]: """ Called when we need to forcefully emit another RST_STREAM frame on behalf of the state machine. @@ -350,7 +352,7 @@ def reset_stream_on_error(self, previous_state: StreamState) -> None: error._events = [event] raise error - def recv_on_closed_stream(self, previous_state: StreamState) -> None: + def recv_on_closed_stream(self, previous_state: StreamState) -> list[Event]: """ Called when an unexpected frame is received on an already-closed stream. @@ -362,7 +364,7 @@ def recv_on_closed_stream(self, previous_state: StreamState) -> None: """ raise StreamClosedError(self.stream_id) - def send_on_closed_stream(self, previous_state: StreamState) -> None: + def send_on_closed_stream(self, previous_state: StreamState) -> list[Event]: """ Called when an attempt is made to send data on an already-closed stream. @@ -374,7 +376,7 @@ def send_on_closed_stream(self, previous_state: StreamState) -> None: """ raise StreamClosedError(self.stream_id) - def recv_push_on_closed_stream(self, previous_state: StreamState) -> None: + def recv_push_on_closed_stream(self, previous_state: StreamState) -> list[Event]: """ Called when a PUSH_PROMISE frame is received on a full stop stream. @@ -393,7 +395,7 @@ def recv_push_on_closed_stream(self, previous_state: StreamState) -> None: msg = "Attempted to push on closed stream." raise ProtocolError(msg) - def send_push_on_closed_stream(self, previous_state: StreamState) -> None: + def send_push_on_closed_stream(self, previous_state: StreamState) -> list[Event]: """ Called when an attempt is made to push on an already-closed stream. @@ -473,7 +475,7 @@ def recv_alt_svc(self, previous_state: StreamState) -> list[Event]: # the event and let it get populated. return [AlternativeServiceAvailable()] - def send_alt_svc(self, previous_state: StreamState) -> None: + def send_alt_svc(self, previous_state: StreamState) -> list[Event]: """ Called when sending an ALTSVC frame on this stream. @@ -489,6 +491,7 @@ def send_alt_svc(self, previous_state: StreamState) -> None: if self.headers_sent: msg = "Cannot send ALTSVC after sending response headers." raise ProtocolError(msg) + return [] @@ -561,7 +564,10 @@ def send_alt_svc(self, previous_state: StreamState) -> None: # (state, input) to tuples of (side_effect_function, end_state). This # map contains all allowed transitions: anything not in this map is # invalid and immediately causes a transition to ``closed``. -_transitions = { +_transitions: dict[ + tuple[StreamState, StreamInputs], + tuple[Callable[[H2StreamStateMachine, StreamState], list[Event]] | None, StreamState], +] = { # State: idle (StreamState.IDLE, StreamInputs.SEND_HEADERS): (H2StreamStateMachine.request_sent, StreamState.OPEN), From 6901794827338c14ffdea4ad608a4af1d8078af8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 7 Feb 2025 23:22:10 +0100 Subject: [PATCH 2/2] fix up type errors with stricter state machine --- src/h2/events.py | 14 +++++++------- src/h2/stream.py | 34 ++++++++++++++++++++++------------ 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/src/h2/events.py b/src/h2/events.py index b81fd1a63..7a22f152c 100644 --- a/src/h2/events.py +++ b/src/h2/events.py @@ -16,7 +16,7 @@ from .settings import ChangedSetting, SettingCodes, Settings, _setting_code_from_int if TYPE_CHECKING: # pragma: no cover - from hpack import HeaderTuple + from hpack.struct import Header from hyperframe.frame import Frame from .errors import ErrorCodes @@ -52,7 +52,7 @@ def __init__(self) -> None: self.stream_id: int | None = None #: The request headers. - self.headers: list[HeaderTuple] | None = None + self.headers: list[Header] | None = None #: If this request also ended the stream, the associated #: :class:`StreamEnded ` event will be available @@ -91,7 +91,7 @@ def __init__(self) -> None: self.stream_id: int | None = None #: The response headers. - self.headers: list[HeaderTuple] | None = None + self.headers: list[Header] | None = None #: If this response also ended the stream, the associated #: :class:`StreamEnded ` event will be available @@ -133,7 +133,7 @@ def __init__(self) -> None: self.stream_id: int | None = None #: The trailers themselves. - self.headers: list[HeaderTuple] | None = None + self.headers: list[Header] | None = None #: Trailers always end streams. This property has the associated #: :class:`StreamEnded ` in it. @@ -237,7 +237,7 @@ def __init__(self) -> None: self.stream_id: int | None = None #: The headers for this informational response. - self.headers: list[HeaderTuple] | None = None + self.headers: list[Header] | None = None #: If this response also had associated priority information, the #: associated :class:`PriorityUpdated ` @@ -436,7 +436,7 @@ def __init__(self) -> None: #: The error code given. Either one of :class:`ErrorCodes #: ` or ``int`` - self.error_code: ErrorCodes | None = None + self.error_code: ErrorCodes | int | None = None #: Whether the remote peer sent a RST_STREAM or we did. self.remote_reset = True @@ -460,7 +460,7 @@ def __init__(self) -> None: self.parent_stream_id: int | None = None #: The request headers, sent by the remote party in the push. - self.headers: list[HeaderTuple] | None = None + self.headers: list[Header] | None = None def __repr__(self) -> str: return ( diff --git a/src/h2/stream.py b/src/h2/stream.py index a3c99e351..3f6c97cd1 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -7,7 +7,7 @@ from __future__ import annotations from enum import Enum, IntEnum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union, cast from hpack import HeaderTuple from hyperframe.frame import AltSvcFrame, ContinuationFrame, DataFrame, Frame, HeadersFrame, PushPromiseFrame, RstStreamFrame, WindowUpdateFrame @@ -1046,10 +1046,11 @@ def receive_push_promise_in_band(self, events = self.state_machine.process_input( StreamInputs.RECV_PUSH_PROMISE, ) - events[0].pushed_stream_id = promised_stream_id + push_event = cast(PushedStreamReceived, events[0]) + push_event.pushed_stream_id = promised_stream_id hdr_validation_flags = self._build_hdr_validation_flags(events) - events[0].headers = self._process_received_headers( + push_event.headers = self._process_received_headers( headers, hdr_validation_flags, header_encoding, ) return [], events @@ -1083,22 +1084,30 @@ def receive_headers(self, input_ = StreamInputs.RECV_HEADERS events = self.state_machine.process_input(input_) + headers_event = cast( + Union[RequestReceived, ResponseReceived, TrailersReceived, InformationalResponseReceived], + events[0], + ) if end_stream: es_events = self.state_machine.process_input( StreamInputs.RECV_END_STREAM, ) - events[0].stream_ended = es_events[0] + # We ensured it's not an information response at the beginning of the method. + cast( + Union[RequestReceived, ResponseReceived, TrailersReceived], + headers_event, + ).stream_ended = cast(StreamEnded, es_events[0]) events += es_events self._initialize_content_length(headers) - if isinstance(events[0], TrailersReceived) and not end_stream: + if isinstance(headers_event, TrailersReceived) and not end_stream: msg = "Trailers must have END_STREAM set" raise ProtocolError(msg) hdr_validation_flags = self._build_hdr_validation_flags(events) - events[0].headers = self._process_received_headers( + headers_event.headers = self._process_received_headers( headers, hdr_validation_flags, header_encoding, ) return [], events @@ -1112,6 +1121,7 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) -> "set to %d", self, end_stream, flow_control_len, ) events = self.state_machine.process_input(StreamInputs.RECV_DATA) + data_event = cast(DataReceived, events[0]) self._inbound_window_manager.window_consumed(flow_control_len) self._track_content_length(len(data), end_stream) @@ -1119,11 +1129,11 @@ def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) -> es_events = self.state_machine.process_input( StreamInputs.RECV_END_STREAM, ) - events[0].stream_ended = es_events[0] + data_event.stream_ended = cast(StreamEnded, es_events[0]) events.extend(es_events) - events[0].data = data - events[0].flow_controlled_length = flow_control_len + data_event.data = data + data_event.flow_controlled_length = flow_control_len return [], events def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event]]: @@ -1143,7 +1153,7 @@ def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event # this should be treated as a *stream* error, not a *connection* error. # That means we need to catch the error and forcibly close the stream. if events: - events[0].delta = increment + cast(WindowUpdated, events[0]).delta = increment try: self.outbound_flow_control_window = guard_increment_window( self.outbound_flow_control_window, @@ -1226,7 +1236,7 @@ def stream_reset(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]] if events: # We don't fire an event if this stream is already closed. - events[0].error_code = _error_code_from_int(frame.error_code) + cast(StreamReset, events[0]).error_code = _error_code_from_int(frame.error_code) return [], events @@ -1328,7 +1338,7 @@ def _build_headers_frames(self, def _process_received_headers(self, headers: Iterable[Header], header_validation_flags: HeaderValidationFlags, - header_encoding: bool | str | None) -> Iterable[Header]: + header_encoding: bool | str | None) -> list[Header]: """ When headers have been received from the remote peer, run a processing pipeline on them to transform them into the appropriate form for