From 5841683a50760334986d6ddc49a6943350b5dbec Mon Sep 17 00:00:00 2001 From: Burak Yigit Kaya Date: Thu, 14 Nov 2024 10:22:22 +0000 Subject: [PATCH] fix: No more BytesWarnings Fixes #1236. This patch makes all header operations operate on `bytes` and converts all headers and values to bytes before operation. With a follow up patch to `hpack` it should also increase efficiency as currently, `hpack` casts everything to a `str` first before converting back to bytes: https://github.com/python-hyper/hpack/blob/02afcab28ca56eb5259904fd414baa89e9f50266/src/hpack/hpack.py#L150-L151 --- src/h2/connection.py | 3 +- src/h2/stream.py | 7 +- src/h2/utilities.py | 183 ++++++++++++++++----------------- test/test_invalid_headers.py | 52 ++-------- test/test_utility_functions.py | 15 +-- tox.ini | 2 +- 6 files changed, 111 insertions(+), 151 deletions(-) diff --git a/src/h2/connection.py b/src/h2/connection.py index ca2b38329..086112eec 100644 --- a/src/h2/connection.py +++ b/src/h2/connection.py @@ -33,7 +33,7 @@ from .frame_buffer import FrameBuffer from .settings import Settings, SettingCodes from .stream import H2Stream, StreamClosedBy -from .utilities import SizeLimitDict, guard_increment_window +from .utilities import SizeLimitDict, guard_increment_window, utf8_encode_headers from .windows import WindowManager @@ -975,6 +975,7 @@ def push_stream(self, stream_id, promised_stream_id, request_headers): ) self.streams[promised_stream_id] = new_stream + request_headers = utf8_encode_headers(request_headers) frames = stream.push_stream_in_band( promised_stream_id, request_headers, self.encoder ) diff --git a/src/h2/stream.py b/src/h2/stream.py index 1c34dcd3e..629bbe548 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -25,7 +25,8 @@ from .utilities import ( guard_increment_window, is_informational_response, authority_from_headers, validate_headers, validate_outbound_headers, normalize_outbound_headers, - HeaderValidationFlags, extract_method_header, normalize_inbound_headers + HeaderValidationFlags, extract_method_header, normalize_inbound_headers, + utf8_encode_headers ) from .windows import WindowManager @@ -851,6 +852,8 @@ def send_headers(self, headers, encoder, end_stream=False): # we need to scan the header block to see if this is an informational # response. input_ = StreamInputs.SEND_HEADERS + + headers = utf8_encode_headers(headers) if ((not self.state_machine.client) and is_informational_response(headers)): if end_stream: @@ -1319,7 +1322,7 @@ def _initialize_content_length(self, headers): self._expected_content_length = int(v, 10) except ValueError: raise ProtocolError( - "Invalid content-length header: %s" % v + f"Invalid content-length header: {repr(v)}" ) return diff --git a/src/h2/utilities.py b/src/h2/utilities.py index 3a7bf6e07..54cd6f210 100644 --- a/src/h2/utilities.py +++ b/src/h2/utilities.py @@ -13,51 +13,55 @@ from .exceptions import ProtocolError, FlowControlError + UPPER_RE = re.compile(b"[A-Z]") +SIGIL = ord(b':') +INFORMATIONAL_START = ord(b'1') + # A set of headers that are hop-by-hop or connection-specific and thus # forbidden in HTTP/2. This list comes from RFC 7540 § 8.1.2.2. CONNECTION_HEADERS = frozenset([ - b'connection', u'connection', - b'proxy-connection', u'proxy-connection', - b'keep-alive', u'keep-alive', - b'transfer-encoding', u'transfer-encoding', - b'upgrade', u'upgrade', + b'connection', + b'proxy-connection', + b'keep-alive', + b'transfer-encoding', + b'upgrade', ]) _ALLOWED_PSEUDO_HEADER_FIELDS = frozenset([ - b':method', u':method', - b':scheme', u':scheme', - b':authority', u':authority', - b':path', u':path', - b':status', u':status', - b':protocol', u':protocol', + b':method', + b':scheme', + b':authority', + b':path', + b':status', + b':protocol', ]) _SECURE_HEADERS = frozenset([ # May have basic credentials which are vulnerable to dictionary attacks. - b'authorization', u'authorization', - b'proxy-authorization', u'proxy-authorization', + b'authorization', + b'proxy-authorization', ]) _REQUEST_ONLY_HEADERS = frozenset([ - b':scheme', u':scheme', - b':path', u':path', - b':authority', u':authority', - b':method', u':method', - b':protocol', u':protocol', + b':scheme', + b':path', + b':authority', + b':method', + b':protocol', ]) -_RESPONSE_ONLY_HEADERS = frozenset([b':status', u':status']) +_RESPONSE_ONLY_HEADERS = frozenset([b':status']) # A Set of pseudo headers that are only valid if the method is # CONNECT, see RFC 8441 § 5 -_CONNECT_REQUEST_ONLY_HEADERS = frozenset([b':protocol', u':protocol']) +_CONNECT_REQUEST_ONLY_HEADERS = frozenset([b':protocol']) _WHITESPACE = frozenset(map(ord, whitespace)) @@ -84,7 +88,7 @@ def _secure_headers(headers, hdr_validation_flags): for header in headers: if header[0] in _SECURE_HEADERS: yield NeverIndexedHeaderTuple(*header) - elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + elif header[0] == b'cookie' and len(header[1]) < 20: yield NeverIndexedHeaderTuple(*header) else: yield header @@ -95,11 +99,8 @@ def extract_method_header(headers): Extracts the request method from the headers list. """ for k, v in headers: - if k in (b':method', u':method'): - if not isinstance(v, bytes): - return v.encode('utf-8') - else: - return v + if k == b':method': + return v def is_informational_response(headers): @@ -114,25 +115,17 @@ def is_informational_response(headers): :returns: A boolean indicating if this is an informational response. """ for n, v in headers: - if isinstance(n, bytes): - sigil = b':' - status = b':status' - informational_start = b'1' - else: - sigil = u':' - status = u':status' - informational_start = u'1' - # If we find a non-special header, we're done here: stop looping. - if not n.startswith(sigil): + + if n and n[0] != SIGIL: return False # This isn't the status header, bail. - if n != status: + if n != b':status': continue # If the first digit is a 1, we've got informational headers. - return v.startswith(informational_start) + return v[0] == INFORMATIONAL_START def guard_increment_window(current, increment): @@ -164,7 +157,7 @@ def authority_from_headers(headers): Given a header set, searches for the authority header and returns the value. - Note that this doesn't terminate early, so should only be called if the + Note that this doesn't use indexing, so should only be called if the headers are for a client request. Otherwise, will loop over the entire header set, which is potentially unwise. @@ -173,11 +166,8 @@ def authority_from_headers(headers): :rtype: ``bytes`` or ``None``. """ for n, v in headers: - # This gets run against headers that come both from HPACK and from the - # user, so we may have unicode floating around in here. We only want - # bytes. - if n in (b':authority', u':authority'): - return v.encode('utf-8') if not isinstance(v, bytes) else v + if n == b':authority': + return v return None @@ -253,7 +243,8 @@ def _reject_uppercase_header_fields(headers, hdr_validation_flags): for header in headers: if UPPER_RE.search(header[0]): raise ProtocolError( - "Received uppercase header name %s." % header[0]) + f"Received uppercase header name {repr(header[0])}." + ) yield header @@ -285,11 +276,10 @@ def _reject_te(headers, hdr_validation_flags): its value is anything other than "trailers". """ for header in headers: - if header[0] in (b'te', u'te'): - if header[1].lower() not in (b'trailers', u'trailers'): + if header[0] == b'te': + if header[1].lower() != b'trailers': raise ProtocolError( - "Invalid value for TE header: %s" % - header[1] + f"Invalid value for TE header: {repr(header[1])}" ) yield header @@ -303,32 +293,21 @@ def _reject_connection_header(headers, hdr_validation_flags): for header in headers: if header[0] in CONNECTION_HEADERS: raise ProtocolError( - "Connection-specific header field present: %s." % header[0] + f"Connection-specific header field present: {repr(header[0])}." ) yield header -def _custom_startswith(test_string, bytes_prefix, unicode_prefix): - """ - Given a string that might be a bytestring or a Unicode string, - return True if it starts with the appropriate prefix. - """ - if isinstance(test_string, bytes): - return test_string.startswith(bytes_prefix) - else: - return test_string.startswith(unicode_prefix) - - -def _assert_header_in_set(string_header, bytes_header, header_set): +def _assert_header_in_set(bytes_header, header_set): """ Given a set of header names, checks whether the string or byte version of the header name is present. Raises a Protocol error with the appropriate error if it's missing. """ - if not (string_header in header_set or bytes_header in header_set): + if bytes_header not in header_set: raise ProtocolError( - "Header block missing mandatory %s header" % string_header + f"Header block missing mandatory {repr(bytes_header)} header" ) @@ -345,30 +324,26 @@ def _reject_pseudo_header_fields(headers, hdr_validation_flags): method = None for header in headers: - if _custom_startswith(header[0], b':', u':'): + if header[0][0] == SIGIL: if header[0] in seen_pseudo_header_fields: raise ProtocolError( - "Received duplicate pseudo-header field %s" % header[0] + f"Received duplicate pseudo-header field {repr(header[0])}" ) seen_pseudo_header_fields.add(header[0]) if seen_regular_header: raise ProtocolError( - "Received pseudo-header field out of sequence: %s" % - header[0] + f"Received pseudo-header field out of sequence: {repr(header[0])}" ) if header[0] not in _ALLOWED_PSEUDO_HEADER_FIELDS: raise ProtocolError( - "Received custom pseudo-header field %s" % header[0] + f"Received custom pseudo-header field {repr(header[0])}" ) - if header[0] in (b':method', u':method'): - if not isinstance(header[1], bytes): - method = header[1].encode('utf-8') - else: - method = header[1] + if header[0] in b':method': + method = header[1] else: seen_regular_header = True @@ -401,7 +376,7 @@ def _check_pseudo_header_field_acceptability(pseudo_headers, # Relevant RFC section: RFC 7540 § 8.1.2.4 # https://tools.ietf.org/html/rfc7540#section-8.1.2.4 if hdr_validation_flags.is_response_header: - _assert_header_in_set(u':status', b':status', pseudo_headers) + _assert_header_in_set(b':status', pseudo_headers) invalid_response_headers = pseudo_headers & _REQUEST_ONLY_HEADERS if invalid_response_headers: raise ProtocolError( @@ -412,9 +387,9 @@ def _check_pseudo_header_field_acceptability(pseudo_headers, not hdr_validation_flags.is_trailer): # This is a request, so we need to have seen :path, :method, and # :scheme. - _assert_header_in_set(u':path', b':path', pseudo_headers) - _assert_header_in_set(u':method', b':method', pseudo_headers) - _assert_header_in_set(u':scheme', b':scheme', pseudo_headers) + _assert_header_in_set(b':path', pseudo_headers) + _assert_header_in_set(b':method', pseudo_headers) + _assert_header_in_set(b':scheme', pseudo_headers) invalid_request_headers = pseudo_headers & _RESPONSE_ONLY_HEADERS if invalid_request_headers: raise ProtocolError( @@ -425,8 +400,7 @@ def _check_pseudo_header_field_acceptability(pseudo_headers, invalid_headers = pseudo_headers & _CONNECT_REQUEST_ONLY_HEADERS if invalid_headers: raise ProtocolError( - "Encountered connect-request-only headers %s" % - invalid_headers + f"Encountered connect-request-only headers {repr(invalid_headers)}" ) @@ -451,9 +425,9 @@ def _validate_host_authority_header(headers): host_header_val = None for header in headers: - if header[0] in (b':authority', u':authority'): + if header[0] == b':authority': authority_header_val = header[1] - elif header[0] in (b'host', u'host'): + elif header[0] == b'host': host_header_val = header[1] yield header @@ -506,7 +480,7 @@ def _check_path_header(headers, hdr_validation_flags): """ def inner(): for header in headers: - if header[0] in (b':path', u':path'): + if header[0] == b':path': if not header[1]: raise ProtocolError("An empty :path header is forbidden") @@ -525,6 +499,32 @@ def inner(): return inner() +def _to_bytes(v): + """ + Given an assumed `str` (or anything that supports `.encode()`), + encodes it using utf-8 into bytes. Returns the unmodified object + if it is already a `bytes` object. + """ + return v if isinstance(v, bytes) else v.encode('utf-8') + + +def utf8_encode_headers(headers): + """ + Given an iterable of header two-tuples, rebuilds that as a list with the + header names and values encoded as utf-8 bytes. This function produces + tuples that preserve the original type of the header tuple for tuple and + any ``HeaderTuple``. + """ + return [ + ( + header.__class__(_to_bytes(header[0]), _to_bytes(header[1])) + if isinstance(header, HeaderTuple) + else (_to_bytes(header[0]), _to_bytes(header[1])) + ) + for header in headers + ] + + def _lowercase_header_names(headers, hdr_validation_flags): """ Given an iterable of header two-tuples, rebuilds that iterable with the @@ -612,17 +612,12 @@ def _split_outbound_cookie_fields(headers, hdr_validation_flags): inbound. """ for header in headers: - if header[0] in (b'cookie', 'cookie'): - needle = b'; ' if isinstance(header[0], bytes) else '; ' - - if needle in header[1]: - for cookie_val in header[1].split(needle): - if isinstance(header, HeaderTuple): - yield header.__class__(header[0], cookie_val) - else: - yield header[0], cookie_val - else: - yield header + if header[0] == b'cookie': + for cookie_val in header[1].split(b'; '): + if isinstance(header, HeaderTuple): + yield header.__class__(header[0], cookie_val) + else: + yield header[0], cookie_val else: yield header diff --git a/test/test_invalid_headers.py b/test/test_invalid_headers.py index 165183e28..2690d3140 100644 --- a/test/test_invalid_headers.py +++ b/test/test_invalid_headers.py @@ -296,7 +296,9 @@ def test_headers_event_skipping_validation(self, frame_factory, headers): c.send_headers(1, headers) # Ensure headers are still normalized. - norm_headers = h2.utilities.normalize_outbound_headers(headers, None, False) + norm_headers = h2.utilities.normalize_outbound_headers( + h2.utilities.utf8_encode_headers(headers), None, False + ) f = frame_factory.build_headers_frame(norm_headers) assert c.data_to_send() == f.serialize() @@ -322,7 +324,9 @@ def test_push_promise_skipping_validation(self, frame_factory, headers): # Create push promise frame with normalized headers. frame_factory.refresh_encoder() - norm_headers = h2.utilities.normalize_outbound_headers(headers, None, False) + norm_headers = h2.utilities.normalize_outbound_headers( + h2.utilities.utf8_encode_headers(headers), None, False + ) pp_frame = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, headers=norm_headers ) @@ -467,43 +471,10 @@ class TestFilter(object): (b':path', b''), ), ) - invalid_request_header_blocks_unicode = ( - # First, missing :method - ( - (':authority', 'google.com'), - (':path', '/'), - (':scheme', 'https'), - ), - # Next, missing :path - ( - (':authority', 'google.com'), - (':method', 'GET'), - (':scheme', 'https'), - ), - # Next, missing :scheme - ( - (':authority', 'google.com'), - (':method', 'GET'), - (':path', '/'), - ), - # Finally, path present but empty. - ( - (':authority', 'google.com'), - (':method', 'GET'), - (':scheme', 'https'), - (':path', ''), - ), - ) # All headers that are forbidden from either request or response blocks. forbidden_request_headers_bytes = (b':status',) - forbidden_request_headers_unicode = (':status',) - forbidden_response_headers_bytes = ( - b':path', b':scheme', b':authority', b':method' - ) - forbidden_response_headers_unicode = ( - ':path', ':scheme', ':authority', ':method' - ) + forbidden_response_headers_bytes = (b':path', b':scheme', b':authority', b':method') @pytest.mark.parametrize('validation_function', validation_functions) @pytest.mark.parametrize('hdr_validation_flags', hdr_validation_combos) @@ -563,10 +534,7 @@ def test_response_header_without_status(self, hdr_validation_flags): ) @pytest.mark.parametrize( 'header_block', - ( - invalid_request_header_blocks_bytes + - invalid_request_header_blocks_unicode - ) + (invalid_request_header_blocks_bytes), ) def test_outbound_req_header_missing_pseudo_headers(self, hdr_validation_flags, @@ -599,7 +567,7 @@ def test_inbound_req_header_missing_pseudo_headers(self, ) @pytest.mark.parametrize( 'invalid_header', - forbidden_request_headers_bytes + forbidden_request_headers_unicode + forbidden_request_headers_bytes, ) def test_outbound_req_header_extra_pseudo_headers(self, hdr_validation_flags, @@ -651,7 +619,7 @@ def test_inbound_req_header_extra_pseudo_headers(self, ) @pytest.mark.parametrize( 'invalid_header', - forbidden_response_headers_bytes + forbidden_response_headers_unicode + forbidden_response_headers_bytes, ) def test_outbound_resp_header_extra_pseudo_headers(self, hdr_validation_flags, diff --git a/test/test_utility_functions.py b/test/test_utility_functions.py index c6578df35..3aa0a2452 100644 --- a/test/test_utility_functions.py +++ b/test/test_utility_functions.py @@ -152,12 +152,6 @@ def test_does_not_increment_without_stream_send(self): class TestExtractHeader(object): - example_request_headers = [ - (u':authority', u'example.com'), - (u':path', u'/'), - (u':scheme', u'https'), - (u':method', u'GET'), - ] example_headers_with_bytes = [ (b':authority', b'example.com'), (b':path', b'/'), @@ -165,11 +159,10 @@ class TestExtractHeader(object): (b':method', b'GET'), ] - @pytest.mark.parametrize( - 'headers', [example_request_headers, example_headers_with_bytes] - ) - def test_extract_header_method(self, headers): - assert extract_method_header(headers) == b'GET' + def test_extract_header_method(self): + assert extract_method_header( + self.example_headers_with_bytes + ) == b'GET' def test_size_limit_dict_limit(): diff --git a/tox.ini b/tox.ini index 16e786fb4..74e0c2191 100644 --- a/tox.ini +++ b/tox.ini @@ -19,7 +19,7 @@ deps = pytest-xdist>=2.0.0,<3 hypothesis>=5.5,<7 commands = - pytest --cov-report=xml --cov-report=term --cov=h2 {posargs} + python -bb -m pytest --cov-report=xml --cov-report=term --cov=h2 {posargs} [testenv:pypy3] # temporarily disable coverage testing on PyPy due to performance problems