diff --git a/CHANGES/11012.breaking.rst b/CHANGES/11012.breaking.rst new file mode 100644 index 00000000000..e5248a55c61 --- /dev/null +++ b/CHANGES/11012.breaking.rst @@ -0,0 +1,9 @@ +Refactored ``ClientRequest`` class. This simplifies a lot of code and improves our type +checking accuracy. It also better aligns public/private attributes with what we expect +developers to access safely from a client middleware. + +If code subclasses ``ClientRequest``, it is likely that the subclass will need tweaking +to be compatible with the new version. Similarly, subclasses of ``ClientResponse`` may +need to adjust ``__init__`` parameters. + +-- by :user:`Dreamsorcerer`. diff --git a/aiohttp/client.py b/aiohttp/client.py index 7a4ad715362..dcbdd23dfd4 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -684,7 +684,7 @@ async def _connect_and_send_request( max_field_size=max_field_size, ) try: - resp = await req.send(conn) + resp = await req._send(conn) try: await resp.start(conn) except BaseException: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 050d3a259e1..0d6f435b6e5 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -7,22 +7,23 @@ import sys import traceback import warnings -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Callable, Iterable, Sequence from hashlib import md5, sha1, sha256 -from http.cookies import Morsel, SimpleCookie +from http.cookies import BaseCookie, SimpleCookie from types import MappingProxyType, TracebackType -from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, TypedDict from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy from yarl import URL -from . import hdrs, helpers, http, multipart, payload +from . import hdrs, multipart, payload from ._cookie_helpers import ( parse_cookie_header, parse_set_cookie_headers, preserve_morsel_with_coded_value, ) from .abc import AbstractStreamWriter +from .base_protocol import BaseProtocol from .client_exceptions import ( ClientConnectionError, ClientOSError, @@ -33,7 +34,6 @@ ) from .compression_utils import HAS_BROTLI, HAS_ZSTD from .formdata import FormData -from .hdrs import CONTENT_TYPE from .helpers import ( _SENTINEL, BaseTimerContext, @@ -50,20 +50,14 @@ ) from .http import ( SERVER_SOFTWARE, + HttpProcessingError, HttpVersion, HttpVersion10, HttpVersion11, StreamWriter, ) from .streams import StreamReader -from .typedefs import ( - DEFAULT_JSON_DECODER, - JSONDecoder, - LooseCookies, - LooseHeaders, - Query, - RawHeaders, -) +from .typedefs import DEFAULT_JSON_DECODER, JSONDecoder, Query, RawHeaders if TYPE_CHECKING: import ssl @@ -172,6 +166,7 @@ def check(self, transport: asyncio.Transport) -> None: SSL_ALLOWED_TYPES = (bool,) # type: ignore[unreachable] +_CONNECTION_CLOSED_EXCEPTION = ClientConnectionError("Connection closed") _SSL_SCHEMES = frozenset(("https", "wss")) @@ -190,25 +185,6 @@ class ConnectionKey(NamedTuple): proxy_headers_hash: int | None # hash(CIMultiDict) -def _warn_if_unclosed_payload(payload: payload.Payload, stacklevel: int = 2) -> None: - """Warn if the payload is not closed. - - Callers must check that the body is a Payload before calling this method. - - Args: - payload: The payload to check - stacklevel: Stack level for the warning (default 2 for direct callers) - """ - if not payload.autoclose and not payload.consumed: - warnings.warn( - "The previous request body contains unclosed resources. " - "Use await request.update_body() instead of setting request.body " - "directly to properly close resources and avoid leaks.", - ResourceWarning, - stacklevel=stacklevel, - ) - - class ClientResponse(HeadersMixin): # Some of these attributes are None when created, # but will be set by the start() method. @@ -224,12 +200,12 @@ class ClientResponse(HeadersMixin): _history: tuple["ClientResponse", ...] = () _raw_headers: RawHeaders = None # type: ignore[assignment] - _connection: Optional["Connection"] = None # current connection + _connection: "Connection | None" = None # current connection _cookies: SimpleCookie | None = None _raw_cookie_headers: tuple[str, ...] | None = None - _continue: Optional["asyncio.Future[bool]"] = None + _continue: asyncio.Future[bool] | None = None _source_traceback: traceback.StackSummary | None = None - _session: Optional["ClientSession"] = None + _session: "ClientSession | None" = None # set up by ClientRequest after ClientResponse object creation # post-init stage allows to not change ctor signature _closed = True # to allow __del__ for non-initialized properly response @@ -238,21 +214,28 @@ class ClientResponse(HeadersMixin): _resolve_charset: Callable[["ClientResponse", bytes], str] = lambda *_: "utf-8" - __writer: Optional["asyncio.Task[None]"] = None + __writer: asyncio.Task[None] | None = None def __init__( self, method: str, url: URL, *, - writer: "asyncio.Task[None] | None", - continue100: Optional["asyncio.Future[bool]"], + writer: asyncio.Task[None] | None, + continue100: asyncio.Future[bool] | None, timer: BaseTimerContext | None, - request_info: RequestInfo, - traces: list["Trace"], + traces: Sequence["Trace"], loop: asyncio.AbstractEventLoop, - session: "ClientSession", + session: "ClientSession | None", + request_headers: CIMultiDict[str], + original_url: URL, + **kwargs: object, ) -> None: + # kwargs exists so authors of subclasses should expect to pass through unknown + # arguments. This allows us to safely add new arguments in future releases. + # But, we should never receive unknown arguments here in the parent class, this + # would indicate an argument has been named wrong or similar in the subclass. + assert not kwargs, "Unexpected arguments to ClientResponse" # URL forbids subclasses, so a simple type check is enough. assert type(url) is URL @@ -264,14 +247,14 @@ def __init__( self._writer = writer if continue100 is not None: self._continue = continue100 - self._request_info = request_info + self._request_headers = request_headers + self._original_url = original_url self._timer = timer if timer is not None else TimerNoop() self._cache: dict[str, Any] = {} self._traces = traces self._loop = loop # Save reference to _resolve_charset, so that get_encoding() will still # work after the response has finished reading the body. - # TODO: Fix session=None in tests (see ClientRequest.__init__). if session is not None: # store a reference to session #1985 self._session = session @@ -283,7 +266,7 @@ def __reset_writer(self, _: object = None) -> None: self.__writer = None @property - def _writer(self) -> Optional["asyncio.Task[None]"]: + def _writer(self) -> asyncio.Task[None] | None: """The writer task for streaming data. _writer is only provided for backwards compatibility @@ -292,7 +275,7 @@ def _writer(self) -> Optional["asyncio.Task[None]"]: return self.__writer @_writer.setter - def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + def _writer(self, writer: asyncio.Task[None] | None) -> None: """Set the writer task for streaming data.""" if self.__writer is not None: self.__writer.remove_done_callback(self.__reset_writer) @@ -353,7 +336,11 @@ def raw_headers(self) -> RawHeaders: @reify def request_info(self) -> RequestInfo: - return self._request_info + # Build RequestInfo lazily from components + headers = CIMultiDictProxy(self._request_headers) + return tuple.__new__( + RequestInfo, (self._url, self.method, headers, self._original_url) + ) @reify def content_disposition(self) -> ContentDisposition | None: @@ -399,7 +386,7 @@ def __repr__(self) -> str: return out.getvalue() @property - def connection(self) -> Optional["Connection"]: + def connection(self) -> "Connection | None": return self._connection @reify @@ -453,7 +440,7 @@ async def start(self, connection: "Connection") -> "ClientResponse": try: protocol = self._protocol message, payload = await protocol.read() # type: ignore[union-attr] - except http.HttpProcessingError as exc: + except HttpProcessingError as exc: raise ClientResponseError( self.request_info, self.history, @@ -624,7 +611,7 @@ async def read(self) -> bytes: def get_encoding(self) -> str: ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() - mimetype = helpers.parse_mimetype(ctype) + mimetype = parse_mimetype(ctype) encoding = mimetype.parameters.get("charset") if encoding: @@ -700,35 +687,26 @@ async def __aexit__( await self.wait_for_close() -class ClientRequest: - GET_METHODS = { - hdrs.METH_GET, - hdrs.METH_HEAD, - hdrs.METH_OPTIONS, - hdrs.METH_TRACE, - } - POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT} - ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE}) +class ClientRequestBase: + """An internal class for proxy requests.""" - DEFAULT_HEADERS = { - hdrs.ACCEPT: "*/*", - hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), - } + POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT} - # Type of body depends on PAYLOAD_REGISTRY, which is dynamic. - _body: None | payload.Payload = None auth = None - response = None + proxy: URL | None = None + response_class = ClientResponse + server_hostname: str | None = None # Needed in connector.py + version = HttpVersion11 + _response = None # These class defaults help create_autospec() work correctly. # If autospec is improved in future, maybe these can be removed. url = URL() method = "GET" - __writer: Optional["asyncio.Task[None]"] = None # async task for streaming data - _continue = None # waiter future for '100 Continue' response + _writer_task: asyncio.Task[None] | None = None # async task for streaming data - _skip_auto_headers: Optional["CIMultiDict[None]"] = None + _skip_auto_headers: "CIMultiDict[None] | None" = None # N.B. # Adding __del__ method with self._writer closing doesn't make sense @@ -740,27 +718,11 @@ def __init__( method: str, url: URL, *, - params: Query = None, - headers: LooseHeaders | None = None, - skip_auto_headers: Iterable[str] | None = None, - data: Any = None, - cookies: LooseCookies | None = None, - auth: BasicAuth | None = None, - version: http.HttpVersion = http.HttpVersion11, - compress: str | bool = False, - chunked: bool | None = None, - expect100: bool = False, + headers: CIMultiDict[str], + auth: BasicAuth | None, loop: asyncio.AbstractEventLoop, - response_class: type["ClientResponse"] | None = None, - proxy: URL | None = None, - proxy_auth: BasicAuth | None = None, - timer: BaseTimerContext | None = None, - session: Optional["ClientSession"] = None, - ssl: SSLContext | bool | Fingerprint = True, - proxy_headers: LooseHeaders | None = None, - traces: list["Trace"] | None = None, + ssl: SSLContext | bool | Fingerprint, trust_env: bool = False, - server_hostname: str | None = None, ): if match := _CONTAINS_CONTROL_CHAR_RE.search(method): raise ValueError( @@ -769,50 +731,21 @@ def __init__( ) # URL forbids subclasses, so a simple type check is enough. assert type(url) is URL, url - if proxy is not None: - assert type(proxy) is URL, proxy - # FIXME: session is None in tests only, need to fix tests - # assert session is not None - if TYPE_CHECKING: - assert session is not None - self._session = session - if params: - url = url.extend_query(params) self.original_url = url self.url = url.with_fragment(None) if url.raw_fragment else url self.method = method.upper() - self.chunked = chunked self.loop = loop - self.length = None - if response_class is None: - real_response_class = ClientResponse - else: - real_response_class = response_class - self.response_class: type[ClientResponse] = real_response_class - self._timer = timer if timer is not None else TimerNoop() self._ssl = ssl - self.server_hostname = server_hostname if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) - self.update_version(version) - self.update_host(url) - self.update_headers(headers) - self.update_auto_headers(skip_auto_headers) - self.update_cookies(cookies) - self.update_content_encoding(data, compress) - self.update_auth(auth, trust_env) - self.update_proxy(proxy, proxy_auth, proxy_headers) - - self.update_body_from_data(data) - if data is not None or self.method not in self.GET_METHODS: - self.update_transfer_encoding() - self.update_expect_continue(expect100) - self._traces = [] if traces is None else traces + self._update_host(url) + self._update_headers(headers) + self._update_auth(auth, trust_env) - def __reset_writer(self, _: object = None) -> None: - self.__writer = None + def _reset_writer(self, _: object = None) -> None: + self._writer_task = None def _get_content_length(self) -> int | None: """Extract and validate Content-Length header value. @@ -832,33 +765,25 @@ def _get_content_length(self) -> int | None: ) from None @property - def skip_auto_headers(self) -> CIMultiDict[None]: - return self._skip_auto_headers or CIMultiDict() - - @property - def _writer(self) -> Optional["asyncio.Task[None]"]: - return self.__writer + def _writer(self) -> asyncio.Task[None] | None: + return self._writer_task @_writer.setter - def _writer(self, writer: "asyncio.Task[None]") -> None: - if self.__writer is not None: - self.__writer.remove_done_callback(self.__reset_writer) - self.__writer = writer - writer.add_done_callback(self.__reset_writer) + def _writer(self, writer: asyncio.Task[None]) -> None: + if self._writer_task is not None: + self._writer_task.remove_done_callback(self._reset_writer) + self._writer_task = writer + writer.add_done_callback(self._reset_writer) def is_ssl(self) -> bool: return self.url.scheme in _SSL_SCHEMES @property - def ssl(self) -> Union["SSLContext", bool, Fingerprint]: + def ssl(self) -> "SSLContext | bool | Fingerprint": return self._ssl @property def connection_key(self) -> ConnectionKey: - if proxy_headers := self.proxy_headers: - h: int | None = hash(tuple(proxy_headers.items())) - else: - h = None url = self.url return tuple.__new__( ConnectionKey, @@ -867,74 +792,25 @@ def connection_key(self) -> ConnectionKey: url.port, url.scheme in _SSL_SCHEMES, self._ssl, - self.proxy, - self.proxy_auth, - h, + None, + None, + None, ), ) - @property - def host(self) -> str: - ret = self.url.raw_host - assert ret is not None - return ret - - @property - def port(self) -> int | None: - return self.url.port - - @property - def body(self) -> payload.Payload | Literal[b""]: - """Request body.""" - # empty body is represented as bytes for backwards compatibility - return self._body or b"" - - @body.setter - def body(self, value: Any) -> None: - """Set request body with warning for non-autoclose payloads. - - WARNING: This setter must be called from within an event loop and is not - thread-safe. Setting body outside of an event loop may raise RuntimeError - when closing file-based payloads. - - DEPRECATED: Direct assignment to body is deprecated and will be removed - in a future version. Use await update_body() instead for proper resource - management. - """ - # Close existing payload if present - if self._body is not None: - # Warn if the payload needs manual closing - # stacklevel=3: user code -> body setter -> _warn_if_unclosed_payload - _warn_if_unclosed_payload(self._body, stacklevel=3) - # NOTE: In the future, when we remove sync close support, - # this setter will need to be removed and only the async - # update_body() method will be available. For now, we call - # _close() for backwards compatibility. - self._body._close() - self._update_body(value) - - @property - def request_info(self) -> RequestInfo: - headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers) - # These are created on every request, so we use a NamedTuple - # for performance reasons. We don't use the RequestInfo.__new__ - # method because it has a different signature which is provided - # for backwards compatibility only. - return tuple.__new__( - RequestInfo, (self.url, self.method, headers, self.original_url) - ) + def _update_auth(self, auth: BasicAuth | None, trust_env: bool = False) -> None: + """Set basic auth.""" + if auth is None: + auth = self.auth + if auth is None: + return - @property - def session(self) -> "ClientSession": - """Return the ClientSession instance. + if not isinstance(auth, BasicAuth): + raise TypeError("BasicAuth() tuple is required instead") - This property provides access to the ClientSession that initiated - this request, allowing middleware to make additional requests - using the same session. - """ - return self._session + self.headers[hdrs.AUTHORIZATION] = auth.encode() - def update_host(self, url: URL) -> None: + def _update_host(self, url: URL) -> None: """Update destination host, port and connection type (ssl).""" # get host/port if not url.raw_host: @@ -942,24 +818,9 @@ def update_host(self, url: URL) -> None: # basic auth info if url.raw_user or url.raw_password: - self.auth = helpers.BasicAuth(url.user or "", url.password or "") - - def update_version(self, version: http.HttpVersion | str) -> None: - """Convert request version to two elements tuple. - - parser HTTP version '1.1' => (1, 1) - """ - if isinstance(version, str): - v = [part.strip() for part in version.split(".", 1)] - try: - version = http.HttpVersion(int(v[0]), int(v[1])) - except ValueError: - raise ValueError( - f"Can not parse http version number: {version}" - ) from None - self.version = version + self.auth = BasicAuth(url.user or "", url.password or "") - def update_headers(self, headers: LooseHeaders | None) -> None: + def _update_headers(self, headers: CIMultiDict[str]) -> None: """Update request headers.""" self.headers: CIMultiDict[str] = CIMultiDict() @@ -969,22 +830,244 @@ def update_headers(self, headers: LooseHeaders | None) -> None: # host_port_subcomponent is None when the URL is a relative URL. # but we know we do not have a relative URL here. assert host is not None - self.headers[hdrs.HOST] = host + self.headers[hdrs.HOST] = headers.pop(hdrs.HOST, host) + self.headers.extend(headers) - if not headers: - return + def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse: + return self.response_class( + self.method, + self.original_url, + writer=task, + continue100=None, + timer=TimerNoop(), + traces=(), + loop=self.loop, + session=None, + request_headers=self.headers, + original_url=self.original_url, + ) + + def _create_writer(self, protocol: BaseProtocol) -> StreamWriter: + return StreamWriter(protocol, self.loop) + + def _should_write(self, protocol: BaseProtocol) -> bool: + return protocol.writing_paused + + async def _send(self, conn: "Connection") -> ClientResponse: + # Specify request target: + # - CONNECT request must send authority form URI + # - not CONNECT proxy must send absolute form URI + # - most common is origin form URI + if self.method == hdrs.METH_CONNECT: + connect_host = self.url.host_subcomponent + assert connect_host is not None + path = f"{connect_host}:{self.url.port}" + elif self.proxy and not self.is_ssl(): + path = str(self.url) + else: + path = self.url.raw_path_qs - if isinstance(headers, (dict, MultiDictProxy, MultiDict)): - headers = headers.items() + protocol = conn.protocol + assert protocol is not None + writer = self._create_writer(protocol) - for key, value in headers: # type: ignore[misc] - # A special case for Host header - if key in hdrs.HOST_ALL: - self.headers[key] = value + # set default content-type + if ( + self.method in self.POST_METHODS + and ( + self._skip_auto_headers is None + or hdrs.CONTENT_TYPE not in self._skip_auto_headers + ) + and hdrs.CONTENT_TYPE not in self.headers + ): + self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" + + v = self.version + if hdrs.CONNECTION not in self.headers: + if conn._connector.force_close: + if v == HttpVersion11: + self.headers[hdrs.CONNECTION] = "close" + elif v == HttpVersion10: + self.headers[hdrs.CONNECTION] = "keep-alive" + + # status + headers + status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" + + # Buffer headers for potential coalescing with body + await writer.write_headers(status_line, self.headers) + + task: asyncio.Task[None] | None + if self._should_write(protocol): + coro = self._write_bytes(writer, conn, self._get_content_length()) + if sys.version_info >= (3, 12): + # Optimization for Python 3.12, try to write + # bytes immediately to avoid having to schedule + # the task on the event loop. + task = asyncio.Task(coro, loop=self.loop, eager_start=True) else: - self.headers.add(key, value) + task = self.loop.create_task(coro) + if task.done(): + task = None + else: + self._writer = task + else: + # We have nothing to write because + # - there is no body + # - the protocol does not have writing paused + # - we are not waiting for a 100-continue response + protocol.start_timeout() + writer.set_eof() + task = None + self._response = self._create_response(task) + return self._response + + async def _write_bytes( + self, + writer: AbstractStreamWriter, + conn: "Connection", + content_length: int | None, + ) -> None: + # Base class never has a body, this will never be run. + assert False + + +class ClientRequestArgs(TypedDict, total=False): + params: Query + headers: CIMultiDict[str] + skip_auto_headers: Iterable[str] | None + data: Any + cookies: BaseCookie[str] + auth: BasicAuth | None + version: HttpVersion + compress: str | bool + chunked: bool | None + expect100: bool + loop: asyncio.AbstractEventLoop + response_class: type[ClientResponse] + proxy: URL | None + proxy_auth: BasicAuth | None + timer: BaseTimerContext + session: "ClientSession" + ssl: SSLContext | bool | Fingerprint + proxy_headers: CIMultiDict[str] | None + traces: list["Trace"] + trust_env: bool + server_hostname: str | None + + +class ClientRequest(ClientRequestBase): + _EMPTY_BODY = payload.PAYLOAD_REGISTRY.get(b"", disposition=None) + _body = _EMPTY_BODY + _continue = None # waiter future for '100 Continue' response + + GET_METHODS = { + hdrs.METH_GET, + hdrs.METH_HEAD, + hdrs.METH_OPTIONS, + hdrs.METH_TRACE, + } + DEFAULT_HEADERS = { + hdrs.ACCEPT: "*/*", + hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), + } + + def __init__( + self, + method: str, + url: URL, + *, + params: Query, + headers: CIMultiDict[str], + skip_auto_headers: Iterable[str] | None, + data: Any, + cookies: BaseCookie[str], + auth: BasicAuth | None, + version: HttpVersion, + compress: str | bool, + chunked: bool | None, + expect100: bool, + loop: asyncio.AbstractEventLoop, + response_class: type[ClientResponse], + proxy: URL | None, + proxy_auth: BasicAuth | None, + timer: BaseTimerContext, + session: "ClientSession", + ssl: SSLContext | bool | Fingerprint, + proxy_headers: CIMultiDict[str] | None, + traces: list["Trace"], + trust_env: bool, + server_hostname: str | None, + **kwargs: object, + ): + # kwargs exists so authors of subclasses should expect to pass through unknown + # arguments. This allows us to safely add new arguments in future releases. + # But, we should never receive unknown arguments here in the parent class, this + # would indicate an argument has been named wrong or similar in the subclass. + assert not kwargs, "Unexpected arguments to ClientRequest" + + if params: + url = url.extend_query(params) + super().__init__(method, url, headers=headers, auth=auth, loop=loop, ssl=ssl) + + if proxy is not None: + assert type(proxy) is URL, proxy + self._session = session + self.chunked = chunked + self.response_class = response_class + self._timer = timer + self.server_hostname = server_hostname + self.version = version - def update_auto_headers(self, skip_auto_headers: Iterable[str] | None) -> None: + self._update_auto_headers(skip_auto_headers) + self._update_cookies(cookies) + self._update_content_encoding(data, compress) + self._update_proxy(proxy, proxy_auth, proxy_headers) + + self._update_body_from_data(data) + if data is not None or self.method not in self.GET_METHODS: + self._update_transfer_encoding() + self._update_expect_continue(expect100) + self._traces = traces + + @property + def body(self) -> payload.Payload: + return self._body + + @property + def skip_auto_headers(self) -> CIMultiDict[None]: + return self._skip_auto_headers or CIMultiDict() + + @property + def connection_key(self) -> ConnectionKey: + if proxy_headers := self.proxy_headers: + h: int | None = hash(tuple(proxy_headers.items())) + else: + h = None + url = self.url + return tuple.__new__( + ConnectionKey, + ( + url.raw_host or "", + url.port, + url.scheme in _SSL_SCHEMES, + self._ssl, + self.proxy, + self.proxy_auth, + h, + ), + ) + + @property + def session(self) -> "ClientSession": + """Return the ClientSession instance. + + This property provides access to the ClientSession that initiated + this request, allowing middleware to make additional requests + using the same session. + """ + return self._session + + def _update_auto_headers(self, skip_auto_headers: Iterable[str] | None) -> None: if skip_auto_headers is not None: self._skip_auto_headers = CIMultiDict( (hdr, None) for hdr in sorted(skip_auto_headers) @@ -1003,7 +1086,7 @@ def update_auto_headers(self, skip_auto_headers: Iterable[str] | None) -> None: if hdrs.USER_AGENT not in used_headers: self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE - def update_cookies(self, cookies: LooseCookies | None) -> None: + def _update_cookies(self, cookies: BaseCookie[str]) -> None: """Update request cookies header.""" if not cookies: return @@ -1014,20 +1097,13 @@ def update_cookies(self, cookies: LooseCookies | None) -> None: c.update(parse_cookie_header(self.headers.get(hdrs.COOKIE, ""))) del self.headers[hdrs.COOKIE] - if isinstance(cookies, Mapping): - iter_cookies = cookies.items() - else: - iter_cookies = cookies # type: ignore[assignment] - for name, value in iter_cookies: - if isinstance(value, Morsel): - # Use helper to preserve coded_value exactly as sent by server - c[name] = preserve_morsel_with_coded_value(value) - else: - c[name] = value # type: ignore[assignment] + for name, value in cookies.items(): + # Use helper to preserve coded_value exactly as sent by server + c[name] = preserve_morsel_with_coded_value(value) self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() - def update_content_encoding(self, data: Any, compress: bool | str) -> None: + def _update_content_encoding(self, data: Any, compress: bool | str) -> None: """Set request content encoding.""" self.compress = None if not data: @@ -1043,7 +1119,7 @@ def update_content_encoding(self, data: Any, compress: bool | str) -> None: self.headers[hdrs.CONTENT_ENCODING] = self.compress self.chunked = True # enable chunked, no need to deal with length - def update_transfer_encoding(self) -> None: + def _update_transfer_encoding(self) -> None: """Analyze transfer-encoding header.""" te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower() @@ -1062,25 +1138,10 @@ def update_transfer_encoding(self) -> None: self.headers[hdrs.TRANSFER_ENCODING] = "chunked" - def update_auth(self, auth: BasicAuth | None, trust_env: bool = False) -> None: - """Set basic auth.""" - if auth is None: - auth = self.auth - if auth is None: - return - - if not isinstance(auth, helpers.BasicAuth): - raise TypeError("BasicAuth() tuple is required instead") - - self.headers[hdrs.AUTHORIZATION] = auth.encode() - - def update_body_from_data(self, body: Any, _stacklevel: int = 3) -> None: + def _update_body_from_data(self, body: Any) -> None: """Update request body from data.""" - if self._body is not None: - _warn_if_unclosed_payload(self._body, stacklevel=_stacklevel) - if body is None: - self._body = None + self._body = self._EMPTY_BODY # Set Content-Length to 0 when body is None for methods that expect a body if ( self.method not in self.GET_METHODS @@ -1091,31 +1152,33 @@ def update_body_from_data(self, body: Any, _stacklevel: int = 3) -> None: return # FormData - maybe_payload = body() if isinstance(body, FormData) else body + if isinstance(body, FormData): + body = body() + else: + try: + body = payload.PAYLOAD_REGISTRY.get(body, disposition=None) + except payload.LookupError: + boundary = None + if hdrs.CONTENT_TYPE in self.headers: + boundary = parse_mimetype( + self.headers[hdrs.CONTENT_TYPE] + ).parameters.get("boundary") + body = FormData(body, boundary=boundary)() - try: - body_payload = payload.PAYLOAD_REGISTRY.get(maybe_payload, disposition=None) - except payload.LookupError: - boundary: str | None = None - if CONTENT_TYPE in self.headers: - boundary = parse_mimetype(self.headers[CONTENT_TYPE]).parameters.get( - "boundary" - ) - body_payload = FormData(maybe_payload, boundary=boundary)() # type: ignore[arg-type] + self._body = body - self._body = body_payload # enable chunked encoding if needed if not self.chunked and hdrs.CONTENT_LENGTH not in self.headers: - if (size := body_payload.size) is not None: + if (size := body.size) is not None: self.headers[hdrs.CONTENT_LENGTH] = str(size) else: self.chunked = True # copy payload headers - assert body_payload.headers + assert body.headers headers = self.headers skip_headers = self._skip_auto_headers - for key, value in body_payload.headers.items(): + for key, value in body.headers.items(): if key in headers or (skip_headers is not None and key in skip_headers): continue headers[key] = value @@ -1131,12 +1194,11 @@ def _update_body(self, body: Any) -> None: del self.headers[hdrs.TRANSFER_ENCODING] # Now update the body using the existing method - # Called from _update_body, add 1 to stacklevel from caller - self.update_body_from_data(body, _stacklevel=4) + self._update_body_from_data(body) # Update transfer encoding headers if needed (same logic as __init__) if body is not None or self.method not in self.GET_METHODS: - self.update_transfer_encoding() + self._update_transfer_encoding() async def update_body(self, body: Any) -> None: """ @@ -1200,7 +1262,7 @@ async def update_body(self, body: Any) -> None: await self._body.close() self._update_body(body) - def update_expect_continue(self, expect: bool = False) -> None: + def _update_expect_continue(self, expect: bool = False) -> None: if expect: self.headers[hdrs.EXPECT] = "100-continue" elif ( @@ -1212,11 +1274,11 @@ def update_expect_continue(self, expect: bool = False) -> None: if expect: self._continue = self.loop.create_future() - def update_proxy( + def _update_proxy( self, proxy: URL | None, proxy_auth: BasicAuth | None, - proxy_headers: LooseHeaders | None, + proxy_headers: CIMultiDict[str] | None, ) -> None: self.proxy = proxy if proxy is None: @@ -1224,17 +1286,54 @@ def update_proxy( self.proxy_headers = None return - if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth): + if proxy_auth and not isinstance(proxy_auth, BasicAuth): raise ValueError("proxy_auth must be None or BasicAuth() tuple") self.proxy_auth = proxy_auth - - if proxy_headers is not None and not isinstance( - proxy_headers, (MultiDict, MultiDictProxy) - ): - proxy_headers = CIMultiDict(proxy_headers) self.proxy_headers = proxy_headers - async def write_bytes( + def _create_response(self, task: asyncio.Task[None] | None) -> ClientResponse: + return self.response_class( + self.method, + self.original_url, + writer=task, + continue100=self._continue, + timer=self._timer, + traces=self._traces, + loop=self.loop, + session=self._session, + request_headers=self.headers, + original_url=self.original_url, + ) + + def _create_writer(self, protocol: BaseProtocol) -> StreamWriter: + writer = StreamWriter( + protocol, + self.loop, + on_chunk_sent=( + functools.partial(self._on_chunk_request_sent, self.method, self.url) + if self._traces + else None + ), + on_headers_sent=( + functools.partial(self._on_headers_request_sent, self.method, self.url) + if self._traces + else None + ), + ) + + if self.compress: + writer.enable_compression(self.compress) + + if self.chunked is not None: + writer.enable_chunking() + return writer + + def _should_write(self, protocol: BaseProtocol) -> bool: + return ( + self.body.size != 0 or self._continue is not None or protocol.writing_paused + ) + + async def _write_bytes( self, writer: AbstractStreamWriter, conn: "Connection", @@ -1276,14 +1375,7 @@ async def write_bytes( protocol = conn.protocol assert protocol is not None try: - # This should be a rare case but the - # self._body can be set to None while - # the task is being started or we wait above - # for the 100-continue response. - # The more likely case is we have an empty - # payload, but 100-continue is still expected. - if self._body is not None: - await self._body.write_with_length(writer, content_length) + await self._body.write_with_length(writer, content_length) except OSError as underlying_exc: reraised_exc = underlying_exc @@ -1316,109 +1408,10 @@ async def write_bytes( await writer.write_eof() protocol.start_timeout() - async def send(self, conn: "Connection") -> "ClientResponse": - # Specify request target: - # - CONNECT request must send authority form URI - # - not CONNECT proxy must send absolute form URI - # - most common is origin form URI - if self.method == hdrs.METH_CONNECT: - connect_host = self.url.host_subcomponent - assert connect_host is not None - path = f"{connect_host}:{self.url.port}" - elif self.proxy and not self.is_ssl(): - path = str(self.url) - else: - path = self.url.raw_path_qs - - protocol = conn.protocol - assert protocol is not None - writer = StreamWriter( - protocol, - self.loop, - on_chunk_sent=( - functools.partial(self._on_chunk_request_sent, self.method, self.url) - if self._traces - else None - ), - on_headers_sent=( - functools.partial(self._on_headers_request_sent, self.method, self.url) - if self._traces - else None - ), - ) - - if self.compress: - writer.enable_compression(self.compress) - - if self.chunked is not None: - writer.enable_chunking() - - # set default content-type - if ( - self.method in self.POST_METHODS - and ( - self._skip_auto_headers is None - or hdrs.CONTENT_TYPE not in self._skip_auto_headers - ) - and hdrs.CONTENT_TYPE not in self.headers - ): - self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" - - v = self.version - if hdrs.CONNECTION not in self.headers: - if conn._connector.force_close: - if v == HttpVersion11: - self.headers[hdrs.CONNECTION] = "close" - elif v == HttpVersion10: - self.headers[hdrs.CONNECTION] = "keep-alive" - - # status + headers - status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" - - # Buffer headers for potential coalescing with body - await writer.write_headers(status_line, self.headers) - - task: asyncio.Task[None] | None - if self._body or self._continue is not None or protocol.writing_paused: - coro = self.write_bytes(writer, conn, self._get_content_length()) - if sys.version_info >= (3, 12): - # Optimization for Python 3.12, try to write - # bytes immediately to avoid having to schedule - # the task on the event loop. - task = asyncio.Task(coro, loop=self.loop, eager_start=True) - else: - task = self.loop.create_task(coro) - if task.done(): - task = None - else: - self._writer = task - else: - # We have nothing to write because - # - there is no body - # - the protocol does not have writing paused - # - we are not waiting for a 100-continue response - protocol.start_timeout() - writer.set_eof() - task = None - response_class = self.response_class - assert response_class is not None - self.response = response_class( - self.method, - self.original_url, - writer=task, - continue100=self._continue, - timer=self._timer, - request_info=self.request_info, - traces=self._traces, - loop=self.loop, - session=self._session, - ) - return self.response - - async def close(self) -> None: - if self.__writer is not None: + async def _close(self) -> None: + if self._writer_task is not None: try: - await self.__writer + await self._writer_task except asyncio.CancelledError: if ( sys.version_info >= (3, 11) @@ -1427,12 +1420,12 @@ async def close(self) -> None: ): raise - def terminate(self) -> None: - if self.__writer is not None: + def _terminate(self) -> None: + if self._writer_task is not None: if not self.loop.is_closed(): - self.__writer.cancel() - self.__writer.remove_done_callback(self.__reset_writer) - self.__writer = None + self._writer_task.cancel() + self._writer_task.remove_done_callback(self._reset_writer) + self._writer_task = None async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None: for trace in self._traces: diff --git a/aiohttp/connector.py b/aiohttp/connector.py index a6eec4d3d15..0b6081f8e08 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -12,10 +12,11 @@ from itertools import chain, cycle, islice from time import monotonic from types import TracebackType -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Literal, cast import aiohappyeyeballs from aiohappyeyeballs import AddrInfoType, SocketFactoryType +from multidict import CIMultiDict from . import hdrs, helpers from .abc import AbstractResolver, ResolveResult @@ -33,7 +34,12 @@ ssl_errors, ) from .client_proto import ResponseHandler -from .client_reqrep import SSL_ALLOWED_TYPES, ClientRequest, Fingerprint +from .client_reqrep import ( + SSL_ALLOWED_TYPES, + ClientRequest, + ClientRequestBase, + Fingerprint, +) from .helpers import ( _SENTINEL, ceil_timeout, @@ -48,7 +54,7 @@ if sys.version_info >= (3, 12): from collections.abc import Buffer else: - Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + Buffer = "bytes | bytearray | memoryview[int] | memoryview[bytes]" if TYPE_CHECKING: import ssl @@ -1145,7 +1151,7 @@ async def _create_connection( return proto - def _get_ssl_context(self, req: ClientRequest) -> SSLContext | None: + def _get_ssl_context(self, req: ClientRequestBase) -> SSLContext | None: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -1178,7 +1184,7 @@ def _get_ssl_context(self, req: ClientRequest) -> SSLContext | None: return _SSL_CONTEXT_UNVERIFIED return _SSL_CONTEXT_VERIFIED - def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: + def _get_fingerprint(self, req: ClientRequestBase) -> "Fingerprint | None": ret = req.ssl if isinstance(ret, Fingerprint): return ret @@ -1191,7 +1197,7 @@ async def _wrap_create_connection( self, *args: Any, addr_infos: list[AddrInfoType], - req: ClientRequest, + req: ClientRequestBase, timeout: "ClientTimeout", client_error: type[Exception] = ClientConnectorError, **kwargs: Any, @@ -1231,7 +1237,7 @@ def _warn_about_tls_in_tls( req: ClientRequest, ) -> None: """Issue a warning if the requested URL has HTTPS scheme.""" - if req.request_info.url.scheme != "https": + if req.url.scheme != "https": return # Check if uvloop is being used, which supports TLS in TLS, @@ -1294,7 +1300,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.server_hostname or req.host, + server_hostname=req.server_hostname or req.url.raw_host, ssl_handshake_timeout=timeout.total, ssl_shutdown_timeout=self._ssl_shutdown_timeout, ) @@ -1303,7 +1309,7 @@ async def _start_tls_connection( underlying_transport, tls_proto, sslcontext, - server_hostname=req.server_hostname or req.host, + server_hostname=req.server_hostname or req.url.raw_host, ssl_handshake_timeout=timeout.total, ) except BaseException: @@ -1340,7 +1346,7 @@ async def _start_tls_connection( raise ClientConnectionError( "Cannot initialize a TLS-in-TLS connection to host " - f"{req.host!s}:{req.port:d} through an underlying connection " + f"{req.url.host!s}:{req.url.port:d} through an underlying connection " f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} " f"[{type_err!s}]" ) from type_err @@ -1377,7 +1383,7 @@ def _convert_hosts_to_addr_infos( async def _create_direct_connection( self, - req: ClientRequest, + req: ClientRequestBase, traces: list["Trace"], timeout: "ClientTimeout", *, @@ -1393,7 +1399,7 @@ async def _create_direct_connection( # See https://github.com/aio-libs/aiohttp/pull/7364. if host.endswith(".."): host = host.rstrip(".") + "." - port = req.port + port = req.url.port assert port is not None try: # Cancelling this lookup should not cancel the underlying lookup @@ -1452,14 +1458,12 @@ async def _create_direct_connection( async def _create_proxy_connection( self, req: ClientRequest, traces: list["Trace"], timeout: "ClientTimeout" ) -> tuple[asyncio.BaseTransport, ResponseHandler]: - headers: dict[str, str] = {} - if req.proxy_headers is not None: - headers = req.proxy_headers # type: ignore[assignment] + headers = CIMultiDict[str]() if req.proxy_headers is None else req.proxy_headers headers[hdrs.HOST] = req.headers[hdrs.HOST] url = req.proxy assert url is not None - proxy_req = ClientRequest( + proxy_req = ClientRequestBase( hdrs.METH_GET, url, headers=headers, @@ -1498,7 +1502,7 @@ async def _create_proxy_connection( proxy=None, proxy_auth=None, proxy_headers_hash=None ) conn = _ConnectTunnelConnection(self, key, proto, self._loop) - proxy_resp = await proxy_req.send(conn) + proxy_resp = await proxy_req._send(conn) try: protocol = conn._protocol assert protocol is not None diff --git a/tests/conftest.py b/tests/conftest.py index 6833d2c1653..e5dc79cad4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,14 @@ +from __future__ import annotations # TODO(PY311): Remove + import asyncio import base64 import os import socket import ssl import sys -from collections.abc import AsyncIterator, Callable, Generator, Iterator +from collections.abc import AsyncIterator, Callable, Iterator from hashlib import md5, sha1, sha256 +from http.cookies import BaseCookie from pathlib import Path from tempfile import TemporaryDirectory from typing import Any @@ -13,6 +16,8 @@ from uuid import uuid4 import pytest +from multidict import CIMultiDict +from yarl import URL try: from blockbuster import blockbuster_ctx @@ -22,9 +27,12 @@ HAS_BLOCKBUSTER = False from aiohttp import payload +from aiohttp.client import ClientSession from aiohttp.client_proto import ResponseHandler +from aiohttp.client_reqrep import ClientRequest, ClientRequestArgs, ClientResponse from aiohttp.compression_utils import ZLibBackend, ZLibBackendProtocol, set_zlib_backend -from aiohttp.http import WS_KEY +from aiohttp.helpers import TimerNoop +from aiohttp.http import WS_KEY, HttpVersion11 from aiohttp.test_utils import get_unused_port_socket, loop_context try: @@ -46,6 +54,11 @@ except ImportError: uvloop = None # type: ignore[assignment] +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing import Any as Unpack + pytest_plugins = ("aiohttp.pytest_plugin", "pytester") @@ -343,7 +356,7 @@ def ws_key(key: bytes) -> str: @pytest.fixture -def enable_cleanup_closed() -> Generator[None, None, None]: +def enable_cleanup_closed() -> Iterator[None]: """Fixture to override the NEEDS_CLEANUP_CLOSED flag. On Python 3.12.7+ and 3.13.1+ enable_cleanup_closed is not needed, @@ -354,7 +367,7 @@ def enable_cleanup_closed() -> Generator[None, None, None]: @pytest.fixture -def unused_port_socket() -> Generator[socket.socket, None, None]: +def unused_port_socket() -> Iterator[socket.socket]: """Return a socket that is unused on the current host. Unlike aiohttp_used_port, the socket is yielded so there is no @@ -371,7 +384,7 @@ def unused_port_socket() -> Generator[socket.socket, None, None]: @pytest.fixture(params=["zlib", "zlib_ng.zlib_ng", "isal.isal_zlib"]) def parametrize_zlib_backend( request: pytest.FixtureRequest, -) -> Generator[None, None, None]: +) -> Iterator[None]: original_backend: ZLibBackendProtocol = ZLibBackend._zlib_backend backend = pytest.importorskip(request.param) set_zlib_backend(backend) @@ -391,3 +404,49 @@ async def cleanup_payload_pending_file_closes( loop_futures = [f for f in payload._CLOSE_FUTURES if f.get_loop() is loop] if loop_futures: await asyncio.gather(*loop_futures, return_exceptions=True) + + +@pytest.fixture +async def make_client_request( + loop: asyncio.AbstractEventLoop, +) -> AsyncIterator[Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest]]: + """Fixture to help creating test ClientRequest objects with defaults.""" + request = session = None + + def maker( + method: str, url: URL, **kwargs: Unpack[ClientRequestArgs] + ) -> ClientRequest: + nonlocal request, session + session = ClientSession() + default_args: ClientRequestArgs = { + "loop": loop, + "params": {}, + "headers": CIMultiDict[str](), + "skip_auto_headers": None, + "data": None, + "cookies": BaseCookie[str](), + "auth": None, + "version": HttpVersion11, + "compress": False, + "chunked": None, + "expect100": False, + "response_class": ClientResponse, + "proxy": None, + "proxy_auth": None, + "timer": TimerNoop(), + "session": session, + "ssl": True, + "proxy_headers": None, + "traces": [], + "trust_env": False, + "server_hostname": None, + } + request = ClientRequest(method, url, **(default_args | kwargs)) + return request + + yield maker + + if request is not None: + await request._close() + assert session is not None + await session.close() diff --git a/tests/test_benchmarks_client_request.py b/tests/test_benchmarks_client_request.py index ea5e1c28a48..0a37087380f 100644 --- a/tests/test_benchmarks_client_request.py +++ b/tests/test_benchmarks_client_request.py @@ -1,24 +1,35 @@ """codspeed benchmarks for client requests.""" import asyncio +import sys +from collections.abc import Callable from http.cookies import BaseCookie +from typing import Any from multidict import CIMultiDict from pytest_codspeed import BenchmarkFixture from yarl import URL -from aiohttp.client_reqrep import ClientRequest, ClientResponse +from aiohttp.client_reqrep import ClientRequest, ClientRequestArgs, ClientResponse from aiohttp.cookiejar import CookieJar from aiohttp.helpers import TimerNoop from aiohttp.http_writer import HttpVersion11 from aiohttp.tracing import Trace +if sys.version_info >= (3, 11): + from typing import Unpack -def test_client_request_update_cookies( - loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture + _RequestMaker = Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest] +else: + _RequestMaker = Any + + +async def test_client_request_update_cookies( + benchmark: BenchmarkFixture, + make_client_request: _RequestMaker, ) -> None: url = URL("http://python.org") - req = ClientRequest("get", url, loop=loop) + req = make_client_request("get", url) cookie_jar = CookieJar() cookie_jar.update_cookies({"string": "Another string"}) cookies = cookie_jar.filter_cookies(url) @@ -26,11 +37,12 @@ def test_client_request_update_cookies( @benchmark def _run() -> None: - req.update_cookies(cookies=cookies) + req._update_cookies(cookies=cookies) def test_create_client_request_with_cookies( - loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture + loop: asyncio.AbstractEventLoop, + benchmark: BenchmarkFixture, ) -> None: url = URL("http://python.org") cookie_jar = CookieJar() @@ -54,7 +66,7 @@ def _run() -> None: proxy_auth=None, proxy_headers=None, timer=timer, - session=None, + session=None, # type: ignore[arg-type] ssl=True, traces=traces, trust_env=False, @@ -71,7 +83,8 @@ def _run() -> None: def test_create_client_request_with_headers( - loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture + loop: asyncio.AbstractEventLoop, + benchmark: BenchmarkFixture, ) -> None: url = URL("http://python.org") timer = TimerNoop() @@ -92,7 +105,7 @@ def _run() -> None: proxy_auth=None, proxy_headers=None, timer=timer, - session=None, + session=None, # type: ignore[arg-type] ssl=True, traces=traces, trust_env=False, @@ -109,10 +122,17 @@ def _run() -> None: def test_send_client_request_one_hundred( - loop: asyncio.AbstractEventLoop, benchmark: BenchmarkFixture + loop: asyncio.AbstractEventLoop, + benchmark: BenchmarkFixture, + make_client_request: _RequestMaker, ) -> None: url = URL("http://python.org") - req = ClientRequest("get", url, loop=loop) + + async def make_req() -> ClientRequest: + """Need async context.""" + return make_client_request("get", url) + + req = loop.run_until_complete(make_req()) class MockTransport(asyncio.Transport): """Mock transport for testing that do no real I/O.""" @@ -154,7 +174,7 @@ def __init__(self) -> None: async def send_requests() -> None: for _ in range(100): - await req.send(conn) # type: ignore[arg-type] + await req._send(conn) # type: ignore[arg-type] @benchmark def _run() -> None: diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 731878d7c1b..e4242c6d117 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -38,8 +38,6 @@ TooManyRedirects, ) from aiohttp.client_reqrep import ClientRequest -from aiohttp.connector import Connection -from aiohttp.http_writer import StreamWriter from aiohttp.payload import ( AsyncIterablePayload, BufferedReaderPayload, @@ -1720,39 +1718,9 @@ async def test_GET_DEFLATE(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.json_response({"ok": True}) - write_mock = None - writelines_mock = None - original_write_bytes = ClientRequest.write_bytes - - async def write_bytes( - self: ClientRequest, - writer: StreamWriter, - conn: Connection, - content_length: int | None = None, - ) -> None: - nonlocal write_mock, writelines_mock - original_write = writer._write - original_writelines = writer._writelines - - with ( - mock.patch.object( - writer, - "_write", - autospec=True, - spec_set=True, - side_effect=original_write, - ) as write_mock, - mock.patch.object( - writer, - "_writelines", - autospec=True, - spec_set=True, - side_effect=original_writelines, - ) as writelines_mock, - ): - await original_write_bytes(self, writer, conn, content_length) - - with mock.patch.object(ClientRequest, "write_bytes", write_bytes): + with mock.patch.object( + ClientRequest, "_write_bytes", autospec=True, spec_set=True + ) as m: app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -1762,27 +1730,15 @@ async def write_bytes( content = await resp.json() assert content == {"ok": True} - # With packet coalescing, headers are buffered and may be written - # during write_bytes if there's an empty body to process. - # The test should verify no body chunks are written, but headers - # may be written as part of the coalescing optimization. - # If _write was called, it should only be for headers ending with \r\n\r\n - # and not any body content - for call in write_mock.call_args_list: # type: ignore[union-attr] - data = call[0][0] - assert data.endswith( - b"\r\n\r\n" - ), "Only headers should be written, not body chunks" - - # No body data should be written via writelines either - writelines_mock.assert_not_called() # type: ignore[union-attr] + # With an empty body, _write_bytes() should not be called at all. + m.assert_not_called() async def test_GET_DEFLATE_no_body(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.Response: return web.json_response({"ok": True}) - with mock.patch.object(ClientRequest, "write_bytes") as mock_write_bytes: + with mock.patch.object(ClientRequest, "_write_bytes") as mock_write_bytes: app = web.Application() app.router.add_get("/", handler) client = await aiohttp_client(app) @@ -2673,6 +2629,23 @@ async def handler(request: web.Request) -> web.Response: assert 200 == resp.status +async def test_cookies_is_quoted_with_special_characters( + aiohttp_client: AiohttpClient, +) -> None: + async def handler(request: web.Request) -> web.Response: + assert 'cookie1="val/one"' == request.headers["Cookie"] + assert "cookie1" in request.cookies + assert request.cookies["cookie1"] == "val/one" + return web.Response() + + app = web.Application() + app.router.add_get("/", handler) + client = await aiohttp_client(app) + + async with client.get("/", cookies={"cookie1": "val/one"}) as resp: + assert resp.status == 200 + + async def test_morsel_with_attributes(aiohttp_client: AiohttpClient) -> None: # A comment from original test: # diff --git a/tests/test_client_middleware.py b/tests/test_client_middleware.py index da5bcece6e8..222e912d3a9 100644 --- a/tests/test_client_middleware.py +++ b/tests/test_client_middleware.py @@ -1083,10 +1083,7 @@ def __init__(self, secretkey: str) -> None: self.secretkey = secretkey def get_hash(self, request: ClientRequest) -> str: - if request.body: - data = request.body.decode("utf-8") - else: - data = "{}" + data = request.body.decode("utf-8") or "{}" # Simulate authentication hash without using real crypto return f"SIGNATURE-{self.secretkey}-{len(data)}-{data[:10]}" diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 0764d26221d..49a81c8dbb3 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -1,6 +1,7 @@ import asyncio from unittest import mock +from multidict import CIMultiDict from pytest_mock import MockerFixture from yarl import URL @@ -107,17 +108,19 @@ async def test_multiple_responses_one_byte_at_a_time( proto.data_received(messages[i : i + 1]) expected = [b"ab", b"cd", b"ef"] + url = URL("http://def-cl-resp.org") for payload in expected: response = ClientResponse( "get", - URL("http://def-cl-resp.org"), + url, writer=mock.Mock(), continue100=None, timer=TimerNoop(), - request_info=mock.Mock(), traces=[], loop=loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) await response.start(conn) await response.read() == payload @@ -138,16 +141,18 @@ class PatchableHttpResponseParser(http.HttpResponseParser): conn = mock.Mock(protocol=proto) proto.set_response_params(read_until_eof=True) proto.data_received(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nab") + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), + url, writer=mock.Mock(), continue100=None, timer=TimerNoop(), - request_info=mock.Mock(), traces=[], loop=loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) await response.start(conn) await response.read() == b"ab" @@ -166,16 +171,18 @@ async def test_client_protocol_readuntil_eof(loop: asyncio.AbstractEventLoop) -> proto.data_received(b"HTTP/1.1 200 Ok\r\n\r\n") + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), + url, writer=mock.Mock(), continue100=None, timer=TimerNoop(), - request_info=mock.Mock(), traces=[], loop=loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) proto.set_response_params(read_until_eof=True) await response.start(conn) diff --git a/tests/test_client_request.py b/tests/test_client_request.py index e05b3198a79..74670cbc9f7 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -3,10 +3,9 @@ import io import pathlib import sys -import warnings -from collections.abc import AsyncIterator, Callable, Iterable, Iterator -from http.cookies import BaseCookie, Morsel, SimpleCookie -from typing import Any, Protocol +from collections.abc import AsyncIterator, Callable, Iterable +from http.cookies import BaseCookie, SimpleCookie +from typing import Any from unittest import mock import pytest @@ -20,19 +19,24 @@ from aiohttp.client_exceptions import ClientConnectionError from aiohttp.client_reqrep import ( ClientRequest, + ClientRequestArgs, ClientResponse, Fingerprint, _gen_default_accept_encoding, ) from aiohttp.compression_utils import ZLibBackend from aiohttp.connector import Connection +from aiohttp.hdrs import METH_DELETE +from aiohttp.helpers import TimerNoop from aiohttp.http import HttpVersion10, HttpVersion11, StreamWriter from aiohttp.multipart import MultipartWriter -from aiohttp.typedefs import LooseCookies +if sys.version_info >= (3, 11): + from typing import Unpack -class _RequestMaker(Protocol): - def __call__(self, method: str, url: str, **kwargs: Any) -> ClientRequest: ... + _RequestMaker = Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest] +else: + _RequestMaker = Any class WriterMock(mock.AsyncMock): @@ -43,18 +47,9 @@ def remove_done_callback(self, cb: Callable[[], None]) -> None: """Dummy method.""" -@pytest.fixture -def make_request(loop: asyncio.AbstractEventLoop) -> Iterator[_RequestMaker]: - request = None - - def maker(method: str, url: str, **kwargs: Any) -> ClientRequest: - nonlocal request - request = ClientRequest(method, URL(url), loop=loop, **kwargs) - return request - - yield maker - if request is not None: - loop.run_until_complete(request.close()) +ALL_METHODS = frozenset( + (*ClientRequest.GET_METHODS, *ClientRequest.POST_METHODS, METH_DELETE) +) @pytest.fixture @@ -96,47 +91,73 @@ def conn(transport: asyncio.Transport, protocol: BaseProtocol) -> Connection: return mock.Mock(transport=transport, protocol=protocol) -def test_method1(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/") +async def test_method1(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/")) assert req.method == "GET" -def test_method2(make_request: _RequestMaker) -> None: - req = make_request("head", "http://python.org/") +async def test_method2(make_client_request: _RequestMaker) -> None: + req = make_client_request("head", URL("http://python.org/")) assert req.method == "HEAD" -def test_method3(make_request: _RequestMaker) -> None: - req = make_request("HEAD", "http://python.org/") +async def test_method3(make_client_request: _RequestMaker) -> None: + req = make_client_request("HEAD", URL("http://python.org/")) assert req.method == "HEAD" -def test_method_invalid(make_request: _RequestMaker) -> None: +async def test_method_invalid(make_client_request: _RequestMaker) -> None: with pytest.raises(ValueError, match="Method cannot contain non-token characters"): - make_request("METHOD WITH\nWHITESPACES", "http://python.org/") + make_client_request("METHOD WITH\nWHITESPACES", URL("http://python.org/")) -def test_version_1_0(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/", version="1.0") +async def test_version_1_0(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/"), version=HttpVersion10) assert req.version == (1, 0) -def test_version_default(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/") +async def test_version_default(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/")) assert req.version == (1, 1) -def test_request_info(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/") +async def test_request_info(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/")) url = URL("http://python.org/") h = CIMultiDictProxy(req.headers) - assert req.request_info == aiohttp.RequestInfo(url, "GET", h, url) + # Create a response to test request_info + resp = req.response_class( + "GET", + url, + writer=None, + continue100=None, + timer=TimerNoop(), + traces=[], + loop=req.loop, + session=None, + request_headers=req.headers, + original_url=url, + ) + assert resp.request_info == aiohttp.RequestInfo(url, "GET", h, url) -def test_request_info_with_fragment(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/#urlfragment") +async def test_request_info_with_fragment(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/#urlfragment")) h = CIMultiDictProxy(req.headers) - assert req.request_info == aiohttp.RequestInfo( + # Create a response to test request_info + resp = req.response_class( + "GET", + URL("http://python.org/"), + writer=None, + continue100=None, + timer=TimerNoop(), + traces=[], + loop=req.loop, + session=None, + request_headers=req.headers, + original_url=URL("http://python.org/#urlfragment"), + ) + assert resp.request_info == aiohttp.RequestInfo( URL("http://python.org/"), "GET", h, @@ -144,206 +165,233 @@ def test_request_info_with_fragment(make_request: _RequestMaker) -> None: ) -def test_version_err(make_request: _RequestMaker) -> None: - with pytest.raises(ValueError): - make_request("get", "http://python.org/", version="1.c") - - -def test_host_port_default_http(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/") - assert req.host == "python.org" - assert req.port == 80 +async def test_host_port_default_http(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/")) + assert req.url.host == "python.org" + assert req.url.port == 80 assert not req.is_ssl() -def test_host_port_default_https(make_request: _RequestMaker) -> None: - req = make_request("get", "https://python.org/") - assert req.host == "python.org" - assert req.port == 443 +async def test_host_port_default_https(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("https://python.org/")) + assert req.url.host == "python.org" + assert req.url.port == 443 assert req.is_ssl() -def test_host_port_nondefault_http(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org:960/") - assert req.host == "python.org" - assert req.port == 960 +async def test_host_port_nondefault_http(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org:960/")) + assert req.url.host == "python.org" + assert req.url.port == 960 assert not req.is_ssl() -def test_host_port_nondefault_https(make_request: _RequestMaker) -> None: - req = make_request("get", "https://python.org:960/") - assert req.host == "python.org" - assert req.port == 960 +async def test_host_port_nondefault_https(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("https://python.org:960/")) + assert req.url.host == "python.org" + assert req.url.port == 960 assert req.is_ssl() -def test_host_port_default_ws(make_request: _RequestMaker) -> None: - req = make_request("get", "ws://python.org/") - assert req.host == "python.org" - assert req.port == 80 +async def test_host_port_default_ws(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("ws://python.org/")) + assert req.url.host == "python.org" + assert req.url.port == 80 assert not req.is_ssl() -def test_host_port_default_wss(make_request: _RequestMaker) -> None: - req = make_request("get", "wss://python.org/") - assert req.host == "python.org" - assert req.port == 443 +async def test_host_port_default_wss(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("wss://python.org/")) + assert req.url.host == "python.org" + assert req.url.port == 443 assert req.is_ssl() -def test_host_port_nondefault_ws(make_request: _RequestMaker) -> None: - req = make_request("get", "ws://python.org:960/") - assert req.host == "python.org" - assert req.port == 960 +async def test_host_port_nondefault_ws(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("ws://python.org:960/")) + assert req.url.host == "python.org" + assert req.url.port == 960 assert not req.is_ssl() -def test_host_port_nondefault_wss(make_request: _RequestMaker) -> None: - req = make_request("get", "wss://python.org:960/") - assert req.host == "python.org" - assert req.port == 960 +async def test_host_port_nondefault_wss(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("wss://python.org:960/")) + assert req.url.host == "python.org" + assert req.url.port == 960 assert req.is_ssl() -def test_host_port_none_port(make_request: _RequestMaker) -> None: - req = make_request("get", "unix://localhost/path") +async def test_host_port_none_port(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("unix://localhost/path")) assert req.headers[hdrs.HOST] == "localhost" -def test_host_port_err(make_request: _RequestMaker) -> None: +async def test_host_port_err(make_client_request: _RequestMaker) -> None: with pytest.raises(ValueError): - make_request("get", "http://python.org:123e/") + make_client_request("get", URL("http://python.org:123e/")) -def test_hostname_err(make_request: _RequestMaker) -> None: +async def test_hostname_err(make_client_request: _RequestMaker) -> None: with pytest.raises(ValueError): - make_request("get", "http://:8080/") + make_client_request("get", URL("http://:8080/")) -def test_host_header_host_first(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/") +async def test_host_header_host_first(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/")) assert list(req.headers)[0] == hdrs.HOST -def test_host_header_host_without_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/") +async def test_host_header_host_without_port( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request("get", URL("http://python.org/")) assert req.headers[hdrs.HOST] == "python.org" -def test_host_header_host_with_default_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org:80/") +async def test_host_header_host_with_default_port( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request("get", URL("http://python.org:80/")) assert req.headers[hdrs.HOST] == "python.org" -def test_host_header_host_with_nondefault_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org:99/") +async def test_host_header_host_with_nondefault_port( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request("get", URL("http://python.org:99/")) assert req.headers["HOST"] == "python.org:99" -def test_host_header_host_idna_encode(make_request: _RequestMaker) -> None: - req = make_request("get", "http://xn--9caa.com") +async def test_host_header_host_idna_encode(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://xn--9caa.com")) assert req.headers["HOST"] == "xn--9caa.com" -def test_host_header_host_unicode(make_request: _RequestMaker) -> None: - req = make_request("get", "http://éé.com") +async def test_host_header_host_unicode(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://éé.com")) assert req.headers["HOST"] == "xn--9caa.com" -def test_host_header_explicit_host(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/", headers={"host": "example.com"}) +async def test_host_header_explicit_host(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "get", URL("http://python.org/"), headers=CIMultiDict({"host": "example.com"}) + ) assert req.headers["HOST"] == "example.com" -def test_host_header_explicit_host_with_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/", headers={"host": "example.com:99"}) +async def test_host_header_explicit_host_with_port( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( + "get", + URL("http://python.org/"), + headers=CIMultiDict({"host": "example.com:99"}), + ) assert req.headers["HOST"] == "example.com:99" -def test_host_header_ipv4(make_request: _RequestMaker) -> None: - req = make_request("get", "http://127.0.0.2") +async def test_host_header_ipv4(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://127.0.0.2")) assert req.headers["HOST"] == "127.0.0.2" -def test_host_header_ipv6(make_request: _RequestMaker) -> None: - req = make_request("get", "http://[::2]") +async def test_host_header_ipv6(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://[::2]")) assert req.headers["HOST"] == "[::2]" -def test_host_header_ipv4_with_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://127.0.0.2:99") +async def test_host_header_ipv4_with_port(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://127.0.0.2:99")) assert req.headers["HOST"] == "127.0.0.2:99" -def test_host_header_ipv6_with_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://[::2]:99") +async def test_host_header_ipv6_with_port(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://[::2]:99")) assert req.headers["HOST"] == "[::2]:99" @pytest.mark.parametrize( ("url", "headers", "expected"), ( - pytest.param("http://localhost.", None, "localhost", id="dot only at the end"), - pytest.param("http://python.org.", None, "python.org", id="single dot"), pytest.param( - "http://python.org.:99", None, "python.org:99", id="single dot with port" + "http://localhost.", CIMultiDict(), "localhost", id="dot only at the end" + ), + pytest.param( + "http://python.org.", CIMultiDict(), "python.org", id="single dot" + ), + pytest.param( + "http://python.org.:99", + CIMultiDict(), + "python.org:99", + id="single dot with port", ), pytest.param( "http://python.org...:99", - None, + CIMultiDict(), "python.org:99", id="multiple dots with port", ), pytest.param( "http://python.org.:99", - {"host": "example.com.:99"}, + CIMultiDict({"host": "example.com.:99"}), "example.com.:99", id="explicit host header", ), - pytest.param("https://python.org.", None, "python.org", id="https"), - pytest.param("https://...", None, "", id="only dots"), + pytest.param("https://python.org.", CIMultiDict(), "python.org", id="https"), + pytest.param("https://...", CIMultiDict(), "", id="only dots"), pytest.param( "http://príklad.example.org.:99", - None, + CIMultiDict(), "xn--prklad-4va.example.org:99", id="single dot with port idna", ), ), ) -def test_host_header_fqdn( - make_request: _RequestMaker, url: str, headers: dict[str, str], expected: str +async def test_host_header_fqdn( # type: ignore[misc] + make_client_request: _RequestMaker, + url: str, + headers: CIMultiDict[str], + expected: str, ) -> None: - req = make_request("get", url, headers=headers) + req = make_client_request("get", URL(url), headers=headers) assert req.headers["HOST"] == expected -def test_default_headers_useragent(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org/") +async def test_default_headers_useragent(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org/")) assert "SERVER" not in req.headers assert "USER-AGENT" in req.headers -def test_default_headers_useragent_custom(make_request: _RequestMaker) -> None: - req = make_request( - "get", "http://python.org/", headers={"user-agent": "my custom agent"} +async def test_default_headers_useragent_custom( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( + "get", + URL("http://python.org/"), + headers=CIMultiDict({"user-agent": "my custom agent"}), ) assert "USER-Agent" in req.headers assert "my custom agent" == req.headers["User-Agent"] -def test_skip_default_useragent_header(make_request: _RequestMaker) -> None: - req = make_request( - "get", "http://python.org/", skip_auto_headers={istr("user-agent")} +async def test_skip_default_useragent_header( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( + "get", URL("http://python.org/"), skip_auto_headers={istr("user-agent")} ) assert "User-Agent" not in req.headers -def test_headers(make_request: _RequestMaker) -> None: - req = make_request( - "post", "http://python.org/", headers={hdrs.CONTENT_TYPE: "text/plain"} +async def test_headers(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "post", + URL("http://python.org/"), + headers=CIMultiDict({hdrs.CONTENT_TYPE: "text/plain"}), ) assert hdrs.CONTENT_TYPE in req.headers @@ -351,217 +399,228 @@ def test_headers(make_request: _RequestMaker) -> None: assert "gzip" in req.headers[hdrs.ACCEPT_ENCODING] -def test_headers_list(make_request: _RequestMaker) -> None: - req = make_request( - "post", "http://python.org/", headers=[("Content-Type", "text/plain")] +async def test_headers_list(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "post", + URL("http://python.org/"), + headers=CIMultiDict((("Content-Type", "text/plain"),)), ) assert "CONTENT-TYPE" in req.headers assert req.headers["CONTENT-TYPE"] == "text/plain" -def test_headers_default(make_request: _RequestMaker) -> None: - req = make_request( - "get", "http://python.org/", headers={"ACCEPT-ENCODING": "deflate"} +async def test_headers_default(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "get", + URL("http://python.org/"), + headers=CIMultiDict({"ACCEPT-ENCODING": "deflate"}), ) assert req.headers["ACCEPT-ENCODING"] == "deflate" -def test_invalid_url(make_request: _RequestMaker) -> None: +async def test_invalid_url(make_client_request: _RequestMaker) -> None: with pytest.raises(aiohttp.InvalidURL): - make_request("get", "hiwpefhipowhefopw") + make_client_request("get", URL("hiwpefhipowhefopw")) -def test_no_path(make_request: _RequestMaker) -> None: - req = make_request("get", "http://python.org") +async def test_no_path(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org")) assert "/" == req.url.path -def test_ipv6_default_http_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://[2001:db8::1]/") - assert req.host == "2001:db8::1" - assert req.port == 80 +async def test_ipv6_default_http_port(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://[2001:db8::1]/")) + assert req.url.host == "2001:db8::1" + assert req.url.port == 80 assert not req.is_ssl() -def test_ipv6_default_https_port(make_request: _RequestMaker) -> None: - req = make_request("get", "https://[2001:db8::1]/") - assert req.host == "2001:db8::1" - assert req.port == 443 +async def test_ipv6_default_https_port(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("https://[2001:db8::1]/")) + assert req.url.host == "2001:db8::1" + assert req.url.port == 443 assert req.is_ssl() -def test_ipv6_nondefault_http_port(make_request: _RequestMaker) -> None: - req = make_request("get", "http://[2001:db8::1]:960/") - assert req.host == "2001:db8::1" - assert req.port == 960 +async def test_ipv6_nondefault_http_port(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://[2001:db8::1]:960/")) + assert req.url.host == "2001:db8::1" + assert req.url.port == 960 assert not req.is_ssl() -def test_ipv6_nondefault_https_port(make_request: _RequestMaker) -> None: - req = make_request("get", "https://[2001:db8::1]:960/") - assert req.host == "2001:db8::1" - assert req.port == 960 +async def test_ipv6_nondefault_https_port(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("https://[2001:db8::1]:960/")) + assert req.url.host == "2001:db8::1" + assert req.url.port == 960 assert req.is_ssl() -def test_basic_auth(make_request: _RequestMaker) -> None: - req = make_request( - "get", "http://python.org", auth=aiohttp.BasicAuth("nkim", "1234") +async def test_basic_auth(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "get", URL("http://python.org"), auth=aiohttp.BasicAuth("nkim", "1234") ) assert "AUTHORIZATION" in req.headers assert "Basic bmtpbToxMjM0" == req.headers["AUTHORIZATION"] -def test_basic_auth_utf8(make_request: _RequestMaker) -> None: - req = make_request( - "get", "http://python.org", auth=aiohttp.BasicAuth("nkim", "секрет", "utf-8") +async def test_basic_auth_utf8(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "get", + URL("http://python.org"), + auth=aiohttp.BasicAuth("nkim", "секрет", "utf-8"), ) assert "AUTHORIZATION" in req.headers assert "Basic bmtpbTrRgdC10LrRgNC10YI=" == req.headers["AUTHORIZATION"] -def test_basic_auth_tuple_forbidden(make_request: _RequestMaker) -> None: +async def test_basic_auth_tuple_forbidden(make_client_request: _RequestMaker) -> None: with pytest.raises(TypeError): - make_request("get", "http://python.org", auth=("nkim", "1234")) + make_client_request("get", URL("http://python.org"), auth=("nkim", "1234")) # type: ignore[arg-type] -def test_basic_auth_from_url(make_request: _RequestMaker) -> None: - req = make_request("get", "http://nkim:1234@python.org") +async def test_basic_auth_from_url(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://nkim:1234@python.org")) assert "AUTHORIZATION" in req.headers assert "Basic bmtpbToxMjM0" == req.headers["AUTHORIZATION"] - assert "python.org" == req.host + assert "python.org" == req.url.host -def test_basic_auth_no_user_from_url(make_request: _RequestMaker) -> None: - req = make_request("get", "http://:1234@python.org") +async def test_basic_auth_no_user_from_url(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://:1234@python.org")) assert "AUTHORIZATION" in req.headers assert "Basic OjEyMzQ=" == req.headers["AUTHORIZATION"] - assert "python.org" == req.host + assert "python.org" == req.url.host -def test_basic_auth_from_url_overridden(make_request: _RequestMaker) -> None: - req = make_request( - "get", "http://garbage@python.org", auth=aiohttp.BasicAuth("nkim", "1234") +async def test_basic_auth_from_url_overridden( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( + "get", URL("http://garbage@python.org"), auth=aiohttp.BasicAuth("nkim", "1234") ) assert "AUTHORIZATION" in req.headers assert "Basic bmtpbToxMjM0" == req.headers["AUTHORIZATION"] - assert "python.org" == req.host + assert "python.org" == req.url.host -def test_path_is_not_double_encoded1(make_request: _RequestMaker) -> None: - req = make_request("get", "http://0.0.0.0/get/test case") +async def test_path_is_not_double_encoded1(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://0.0.0.0/get/test case")) assert req.url.raw_path == "/get/test%20case" -def test_path_is_not_double_encoded2(make_request: _RequestMaker) -> None: - req = make_request("get", "http://0.0.0.0/get/test%2fcase") +async def test_path_is_not_double_encoded2(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://0.0.0.0/get/test%2fcase")) assert req.url.raw_path == "/get/test%2Fcase" -def test_path_is_not_double_encoded3(make_request: _RequestMaker) -> None: - req = make_request("get", "http://0.0.0.0/get/test%20case") +async def test_path_is_not_double_encoded3(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://0.0.0.0/get/test%20case")) assert req.url.raw_path == "/get/test%20case" -def test_path_safe_chars_preserved(make_request: _RequestMaker) -> None: - req = make_request("get", "http://0.0.0.0/get/:=+/%2B/") +async def test_path_safe_chars_preserved(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://0.0.0.0/get/:=+/%2B/")) assert req.url.path == "/get/:=+/+/" -def test_params_are_added_before_fragment1(make_request: _RequestMaker) -> None: - req = make_request("GET", "http://example.com/path#fragment", params={"a": "b"}) +async def test_params_are_added_before_fragment1( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( + "GET", URL("http://example.com/path#fragment"), params={"a": "b"} + ) assert str(req.url) == "http://example.com/path?a=b" -def test_params_are_added_before_fragment2(make_request: _RequestMaker) -> None: - req = make_request( - "GET", "http://example.com/path?key=value#fragment", params={"a": "b"} +async def test_params_are_added_before_fragment2( + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( + "GET", URL("http://example.com/path?key=value#fragment"), params={"a": "b"} ) assert str(req.url) == "http://example.com/path?key=value&a=b" -def test_path_not_contain_fragment1(make_request: _RequestMaker) -> None: - req = make_request("GET", "http://example.com/path#fragment") +async def test_path_not_contain_fragment1(make_client_request: _RequestMaker) -> None: + req = make_client_request("GET", URL("http://example.com/path#fragment")) assert req.url.path == "/path" -def test_path_not_contain_fragment2(make_request: _RequestMaker) -> None: - req = make_request("GET", "http://example.com/path?key=value#fragment") +async def test_path_not_contain_fragment2(make_client_request: _RequestMaker) -> None: + req = make_client_request("GET", URL("http://example.com/path?key=value#fragment")) assert str(req.url) == "http://example.com/path?key=value" -def test_cookies(make_request: _RequestMaker) -> None: - req = make_request("get", "http://test.com/path", cookies={"cookie1": "val1"}) +async def test_cookies(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "get", URL("http://test.com/path"), cookies=BaseCookie({"cookie1": "val1"}) + ) assert "COOKIE" in req.headers assert "cookie1=val1" == req.headers["COOKIE"] -def test_cookies_is_quoted_with_special_characters(make_request: _RequestMaker) -> None: - req = make_request("get", "http://test.com/path", cookies={"cookie1": "val/one"}) - - assert "COOKIE" in req.headers - assert 'cookie1="val/one"' == req.headers["COOKIE"] - - -def test_cookies_merge_with_headers(make_request: _RequestMaker) -> None: - req = make_request( +async def test_cookies_merge_with_headers(make_client_request: _RequestMaker) -> None: + req = make_client_request( "get", - "http://test.com/path", - headers={"cookie": "cookie1=val1"}, - cookies={"cookie2": "val2"}, + URL("http://test.com/path"), + headers=CIMultiDict({"cookie": "cookie1=val1"}), + cookies=BaseCookie({"cookie2": "val2"}), ) assert "cookie1=val1; cookie2=val2" == req.headers["COOKIE"] -def test_query_multivalued_param(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: - req = make_request( - meth, "http://python.org", params=(("test", "foo"), ("test", "baz")) +async def test_query_multivalued_param(make_client_request: _RequestMaker) -> None: + for meth in ALL_METHODS: + req = make_client_request( + meth, URL("http://python.org"), params=(("test", "foo"), ("test", "baz")) ) assert str(req.url) == "http://python.org/?test=foo&test=baz" -def test_query_str_param(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: - req = make_request(meth, "http://python.org", params="test=foo") +async def test_query_str_param(make_client_request: _RequestMaker) -> None: + for meth in ALL_METHODS: + req = make_client_request(meth, URL("http://python.org"), params="test=foo") assert str(req.url) == "http://python.org/?test=foo" -def test_query_bytes_param_raises(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: +async def test_query_bytes_param_raises(make_client_request: _RequestMaker) -> None: + for meth in ALL_METHODS: with pytest.raises(TypeError): - make_request(meth, "http://python.org", params=b"test=foo") + make_client_request(meth, URL("http://python.org"), params=b"test=foo") # type: ignore[arg-type] -def test_query_str_param_is_not_encoded(make_request: _RequestMaker) -> None: - for meth in ClientRequest.ALL_METHODS: - req = make_request(meth, "http://python.org", params="test=f+oo") +async def test_query_str_param_is_not_encoded( + make_client_request: _RequestMaker, +) -> None: + for meth in ALL_METHODS: + req = make_client_request(meth, URL("http://python.org"), params="test=f+oo") assert str(req.url) == "http://python.org/?test=f+oo" -def test_params_update_path_and_url(make_request: _RequestMaker) -> None: - req = make_request( - "get", "http://python.org", params=(("test", "foo"), ("test", "baz")) +async def test_params_update_path_and_url(make_client_request: _RequestMaker) -> None: + req = make_client_request( + "get", URL("http://python.org"), params=(("test", "foo"), ("test", "baz")) ) assert str(req.url) == "http://python.org/?test=foo&test=baz" -def test_params_empty_path_and_url(make_request: _RequestMaker) -> None: - req_empty = make_request("get", "http://python.org", params={}) +async def test_params_empty_path_and_url(make_client_request: _RequestMaker) -> None: + req_empty = make_client_request("get", URL("http://python.org"), params={}) assert str(req_empty.url) == "http://python.org" - req_none = make_request("get", "http://python.org") + req_none = make_client_request("get", URL("http://python.org")) assert str(req_none.url) == "http://python.org" -def test_gen_netloc_all(make_request: _RequestMaker) -> None: - req = make_request( +async def test_gen_netloc_all(make_client_request: _RequestMaker) -> None: + req = make_client_request( "get", - "https://aiohttp:pwpwpw@" - + "12345678901234567890123456789" - + "012345678901234567890:8080", + URL( + "https://aiohttp:pwpwpw@12345678901234567890123456789012345678901234567890:8080" + ), ) assert ( req.headers["HOST"] @@ -569,40 +628,46 @@ def test_gen_netloc_all(make_request: _RequestMaker) -> None: ) -def test_gen_netloc_no_port(make_request: _RequestMaker) -> None: - req = make_request( +async def test_gen_netloc_no_port(make_client_request: _RequestMaker) -> None: + req = make_client_request( "get", - "https://aiohttp:pwpwpw@" - + "12345678901234567890123456789" - + "012345678901234567890/", + URL( + "https://aiohttp:pwpwpw@12345678901234567890123456789012345678901234567890/" + ), ) assert ( req.headers["HOST"] == "12345678901234567890123456789" + "012345678901234567890" ) -def test_cookie_coded_value_preserved(loop: asyncio.AbstractEventLoop) -> None: +async def test_cookie_coded_value_preserved( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: """Verify the coded value of a cookie is preserved.""" # https://github.com/aio-libs/aiohttp/pull/1453 - req = ClientRequest("get", URL("http://python.org"), loop=loop) - req.update_cookies(cookies=SimpleCookie('ip-cookie="second"; Domain=127.0.0.1;')) + req = make_client_request("get", URL("http://python.org"), loop=loop) + req._update_cookies(cookies=SimpleCookie('ip-cookie="second"; Domain=127.0.0.1;')) assert req.headers["COOKIE"] == 'ip-cookie="second"' -def test_update_cookies_with_special_chars_in_existing_header( +async def test_update_cookies_with_special_chars_in_existing_header( loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that update_cookies handles existing cookies with special characters.""" # Create request with a cookie that has special characters (real-world example) - req = ClientRequest( + req = make_client_request( "get", URL("http://python.org"), - headers={"Cookie": "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}=value1"}, + headers=CIMultiDict( + {"Cookie": "ISAWPLB{A7F52349-3531-4DA9-8776-F74BC6F4F1BB}=value1"} + ), loop=loop, ) # Update with another cookie - req.update_cookies(cookies={"normal_cookie": "value2"}) + req._update_cookies(cookies=BaseCookie({"normal_cookie": "value2"})) # Both cookies should be preserved in the exact order assert ( @@ -611,20 +676,21 @@ def test_update_cookies_with_special_chars_in_existing_header( ) -def test_update_cookies_with_quoted_existing_header( +async def test_update_cookies_with_quoted_existing_header( loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that update_cookies handles existing cookies with quoted values.""" # Create request with cookies that have quoted values - req = ClientRequest( + req = make_client_request( "get", URL("http://python.org"), - headers={"Cookie": 'session="value;with;semicolon"; token=abc123'}, + headers=CIMultiDict({"Cookie": 'session="value;with;semicolon"; token=abc123'}), loop=loop, ) # Update with another cookie - req.update_cookies(cookies={"new_cookie": "new_value"}) + req._update_cookies(cookies=BaseCookie({"new_cookie": "new_value"})) # All cookies should be preserved with their original coded values # The quoted value should be preserved as-is @@ -635,90 +701,106 @@ def test_update_cookies_with_quoted_existing_header( async def test_connection_header( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("get", URL("http://python.org"), loop=loop) + req = make_client_request("get", URL("http://python.org"), loop=loop) req.headers.clear() req.version = HttpVersion11 req.headers.clear() with mock.patch.object(conn._connector, "force_close", False): - await req.send(conn) + await req._send(conn) assert req.headers.get("CONNECTION") is None req.version = HttpVersion10 req.headers.clear() with mock.patch.object(conn._connector, "force_close", False): - await req.send(conn) + await req._send(conn) assert req.headers.get("CONNECTION") == "keep-alive" req.version = HttpVersion11 req.headers.clear() with mock.patch.object(conn._connector, "force_close", True): - await req.send(conn) + await req._send(conn) assert req.headers.get("CONNECTION") == "close" req.version = HttpVersion10 req.headers.clear() with mock.patch.object(conn._connector, "force_close", True): - await req.send(conn) + await req._send(conn) assert not req.headers.get("CONNECTION") async def test_no_content_length( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("get", URL("http://python.org"), loop=loop) - resp = await req.send(conn) + req = make_client_request("get", URL("http://python.org"), loop=loop) + resp = await req._send(conn) assert req.headers.get("CONTENT-LENGTH") is None - await req.close() + await req._close() resp.close() async def test_no_content_length_head( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("head", URL("http://python.org"), loop=loop) - resp = await req.send(conn) + req = make_client_request("head", URL("http://python.org"), loop=loop) + resp = await req._send(conn) assert req.headers.get("CONTENT-LENGTH") is None - await req.close() + await req._close() resp.close() async def test_content_type_auto_header_get( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("get", URL("http://python.org"), loop=loop) - resp = await req.send(conn) + req = make_client_request("get", URL("http://python.org"), loop=loop) + resp = await req._send(conn) assert "CONTENT-TYPE" not in req.headers resp.close() - await req.close() + await req._close() async def test_content_type_auto_header_form( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org"), data={"hey": "you"}, loop=loop ) - resp = await req.send(conn) + resp = await req._send(conn) assert "application/x-www-form-urlencoded" == req.headers.get("CONTENT-TYPE") resp.close() async def test_content_type_auto_header_bytes( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("post", URL("http://python.org"), data=b"hey you", loop=loop) - resp = await req.send(conn) + req = make_client_request( + "post", URL("http://python.org"), data=b"hey you", loop=loop + ) + resp = await req._send(conn) assert "application/octet-stream" == req.headers.get("CONTENT-TYPE") resp.close() async def test_content_type_skip_auto_header_bytes( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org"), data=b"hey you", @@ -726,52 +808,58 @@ async def test_content_type_skip_auto_header_bytes( loop=loop, ) assert req.skip_auto_headers == CIMultiDict({"CONTENT-TYPE": None}) - resp = await req.send(conn) + resp = await req._send(conn) assert "CONTENT-TYPE" not in req.headers resp.close() async def test_content_type_skip_auto_header_form( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org"), data={"hey": "you"}, loop=loop, skip_auto_headers={"Content-Type"}, ) - resp = await req.send(conn) + resp = await req._send(conn) assert "CONTENT-TYPE" not in req.headers resp.close() async def test_content_type_auto_header_content_length_no_skip( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: with io.BytesIO(b"hey") as file_handle: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org"), data=file_handle, skip_auto_headers={"Content-Length"}, loop=loop, ) - resp = await req.send(conn) + resp = await req._send(conn) assert req.headers.get("CONTENT-LENGTH") == "3" resp.close() async def test_urlencoded_formdata_charset( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org"), data=aiohttp.FormData({"hey": "you"}, charset="koi8-r"), loop=loop, ) - async with await req.send(conn): + async with await req._send(conn): await asyncio.sleep(0) assert "application/x-www-form-urlencoded; charset=koi8-r" == req.headers.get( "CONTENT-TYPE" @@ -779,53 +867,66 @@ async def test_urlencoded_formdata_charset( async def test_formdata_boundary_from_headers( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: boundary = "some_boundary" file_path = pathlib.Path(__file__).parent / "aiohttp.png" with file_path.open("rb") as f: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org"), data={"aiohttp.png": f}, - headers={"Content-Type": f"multipart/form-data; boundary={boundary}"}, + headers=CIMultiDict( + {"Content-Type": f"multipart/form-data; boundary={boundary}"} + ), loop=loop, ) - async with await req.send(conn): + async with await req._send(conn): await asyncio.sleep(0) assert isinstance(req.body, MultipartWriter) assert req.body._boundary == boundary.encode() -async def test_post_data(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: +async def test_post_data( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: for meth in ClientRequest.POST_METHODS: - req = ClientRequest( + req = make_client_request( meth, URL("http://python.org/"), data={"life": "42"}, loop=loop ) - resp = await req.send(conn) + resp = await req._send(conn) assert "/" == req.url.path assert isinstance(req.body, payload.Payload) assert b"life=42" == req.body._value assert "application/x-www-form-urlencoded" == req.headers["CONTENT-TYPE"] - await req.close() + await req._close() resp.close() -async def test_pass_falsy_data(loop: asyncio.AbstractEventLoop) -> None: - with mock.patch("aiohttp.client_reqrep.ClientRequest.update_body_from_data") as m: - req = ClientRequest("post", URL("http://python.org/"), data={}, loop=loop) +async def test_pass_falsy_data( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: + with mock.patch("aiohttp.client_reqrep.ClientRequest._update_body_from_data") as m: + req = make_client_request("post", URL("http://python.org/"), data={}, loop=loop) m.assert_called_once_with({}) - await req.close() + await req._close() async def test_pass_falsy_data_file( - loop: asyncio.AbstractEventLoop, tmp_path: pathlib.Path + loop: asyncio.AbstractEventLoop, + tmp_path: pathlib.Path, + make_client_request: _RequestMaker, ) -> None: testfile = (tmp_path / "tmpfile").open("w+b") testfile.write(b"data") testfile.seek(0) skip = frozenset([hdrs.CONTENT_TYPE]) - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), data=testfile, @@ -833,225 +934,262 @@ async def test_pass_falsy_data_file( loop=loop, ) assert req.headers.get("CONTENT-LENGTH", None) is not None - await req.close() + await req._close() testfile.close() # Elasticsearch API requires to send request body with GET-requests -async def test_get_with_data(loop: asyncio.AbstractEventLoop) -> None: +async def test_get_with_data( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: for meth in ClientRequest.GET_METHODS: - req = ClientRequest( + req = make_client_request( meth, URL("http://python.org/"), data={"life": "42"}, loop=loop ) assert "/" == req.url.path assert isinstance(req.body, payload.Payload) assert b"life=42" == req.body._value - await req.close() + await req._close() -async def test_bytes_data(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: +async def test_bytes_data( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: for meth in ClientRequest.POST_METHODS: - req = ClientRequest( + req = make_client_request( meth, URL("http://python.org/"), data=b"binary data", loop=loop ) - resp = await req.send(conn) + resp = await req._send(conn) assert "/" == req.url.path assert isinstance(req.body, payload.BytesPayload) assert b"binary data" == req.body._value assert "application/octet-stream" == req.headers["CONTENT-TYPE"] - await req.close() + await req._close() resp.close() @pytest.mark.usefixtures("parametrize_zlib_backend") -async def test_content_encoding( +async def test_content_encoding( # type: ignore[misc] loop: asyncio.AbstractEventLoop, conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), data="foo", compress="deflate", loop=loop ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: m_writer.return_value.write_headers = mock.AsyncMock() - resp = await req.send(conn) + resp = await req._send(conn) assert req.headers["TRANSFER-ENCODING"] == "chunked" assert req.headers["CONTENT-ENCODING"] == "deflate" m_writer.return_value.enable_compression.assert_called_with("deflate") - await req.close() + await req._close() resp.close() async def test_content_encoding_dont_set_headers_if_no_body( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), compress="deflate", loop=loop ) - with mock.patch("aiohttp.client_reqrep.http"): - resp = await req.send(conn) + resp = await req._send(conn) assert "TRANSFER-ENCODING" not in req.headers assert "CONTENT-ENCODING" not in req.headers - await req.close() + await req._close() resp.close() @pytest.mark.usefixtures("parametrize_zlib_backend") -async def test_content_encoding_header( +async def test_content_encoding_header( # type: ignore[misc] loop: asyncio.AbstractEventLoop, conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), data="foo", - headers={"Content-Encoding": "deflate"}, + headers=CIMultiDict({"Content-Encoding": "deflate"}), loop=loop, ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: m_writer.return_value.write_headers = mock.AsyncMock() - resp = await req.send(conn) + resp = await req._send(conn) assert not m_writer.return_value.enable_compression.called assert not m_writer.return_value.enable_chunking.called - await req.close() + await req._close() resp.close() async def test_compress_and_content_encoding( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: with pytest.raises(ValueError): - ClientRequest( + make_client_request( "post", URL("http://python.org/"), data="foo", - headers={"content-encoding": "deflate"}, + headers=CIMultiDict({"content-encoding": "deflate"}), compress="deflate", loop=loop, ) -async def test_chunked(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: - req = ClientRequest( +async def test_chunked( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( "post", URL("http://python.org/"), - headers={"TRANSFER-ENCODING": "gzip"}, + headers=CIMultiDict({"TRANSFER-ENCODING": "gzip"}), loop=loop, ) - resp = await req.send(conn) + resp = await req._send(conn) assert "gzip" == req.headers["TRANSFER-ENCODING"] - await req.close() + await req._close() resp.close() -async def test_chunked2(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: - req = ClientRequest( +async def test_chunked2( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( "post", URL("http://python.org/"), - headers={"Transfer-encoding": "chunked"}, + headers=CIMultiDict({"Transfer-encoding": "chunked"}), loop=loop, ) - resp = await req.send(conn) + resp = await req._send(conn) assert "chunked" == req.headers["TRANSFER-ENCODING"] - await req.close() + await req._close() resp.close() async def test_chunked_empty_body( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: """Ensure write_bytes is called even if the body is empty.""" - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), chunked=True, loop=loop, data=b"", ) - with mock.patch.object(req, "write_bytes") as write_bytes: - resp = await req.send(conn) + with mock.patch.object(req, "_write_bytes") as write_bytes: + resp = await req._send(conn) assert "chunked" == req.headers["TRANSFER-ENCODING"] assert write_bytes.called - await req.close() + await req._close() resp.close() async def test_chunked_explicit( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("post", URL("http://python.org/"), chunked=True, loop=loop) + req = make_client_request( + "post", URL("http://python.org/"), chunked=True, loop=loop + ) with mock.patch("aiohttp.client_reqrep.StreamWriter") as m_writer: m_writer.return_value.write_headers = mock.AsyncMock() m_writer.return_value.write_eof = mock.AsyncMock() - resp = await req.send(conn) + resp = await req._send(conn) assert "chunked" == req.headers["TRANSFER-ENCODING"] m_writer.return_value.enable_chunking.assert_called_with() - await req.close() + await req._close() resp.close() -async def test_chunked_length(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: +async def test_chunked_length( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: with pytest.raises(ValueError): - ClientRequest( + make_client_request( "post", URL("http://python.org/"), - headers={"CONTENT-LENGTH": "1000"}, + headers=CIMultiDict({"CONTENT-LENGTH": "1000"}), chunked=True, loop=loop, ) async def test_chunked_transfer_encoding( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: with pytest.raises(ValueError): - ClientRequest( + make_client_request( "post", URL("http://python.org/"), - headers={"TRANSFER-ENCODING": "chunked"}, + headers=CIMultiDict({"TRANSFER-ENCODING": "chunked"}), chunked=True, loop=loop, ) -async def test_file_upload_not_chunked(loop: asyncio.AbstractEventLoop) -> None: +async def test_file_upload_not_chunked( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: file_path = pathlib.Path(__file__).parent / "aiohttp.png" with file_path.open("rb") as f: - req = ClientRequest("post", URL("http://python.org/"), data=f, loop=loop) + req = make_client_request("post", URL("http://python.org/"), data=f, loop=loop) assert not req.chunked assert req.headers["CONTENT-LENGTH"] == str(file_path.stat().st_size) - await req.close() + await req._close() @pytest.mark.usefixtures("parametrize_zlib_backend") -async def test_precompressed_data_stays_intact( +async def test_precompressed_data_stays_intact( # type: ignore[misc] loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: data = ZLibBackend.compress(b"foobar") - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), data=data, - headers={"CONTENT-ENCODING": "deflate"}, + headers=CIMultiDict({"CONTENT-ENCODING": "deflate"}), compress=False, loop=loop, ) assert not req.compress assert not req.chunked assert req.headers["CONTENT-ENCODING"] == "deflate" - await req.close() + await req._close() async def test_body_with_size_sets_content_length( loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that when body has a size and no Content-Length header is set, it gets added.""" # Create a BytesPayload which has a size property data = b"test data" # Create request with data that will create a BytesPayload - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), data=data, @@ -1063,11 +1201,12 @@ async def test_body_with_size_sets_content_length( assert req.body is not None assert req._body is not None # When _body is set, body returns it assert req._body.size == len(data) - await req.close() + await req._close() async def test_body_payload_with_size_no_content_length( loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that when a body payload is set via update_body, Content-Length is added.""" # Create a payload with a known size @@ -1075,14 +1214,12 @@ async def test_body_payload_with_size_no_content_length( bytes_payload = payload.BytesPayload(data) # Create request with no data initially - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), loop=loop, ) - # Initially no body should be set - assert req._body is None # POST method with None body should have Content-Length: 0 assert req.headers[hdrs.CONTENT_LENGTH] == "0" @@ -1093,7 +1230,6 @@ async def test_body_payload_with_size_no_content_length( assert req.headers[hdrs.CONTENT_LENGTH] == str(len(data)) assert req.body is bytes_payload assert req._body is bytes_payload # Access _body which is the Payload - assert req._body is not None # type: ignore[unreachable] assert req._body.size == len(data) # Set body back to None @@ -1101,64 +1237,83 @@ async def test_body_payload_with_size_no_content_length( # Verify Content-Length is back to 0 for POST with None body assert req.headers[hdrs.CONTENT_LENGTH] == "0" - assert req._body is None - await req.close() + await req._close() -async def test_file_upload_not_chunked_seek(loop: asyncio.AbstractEventLoop) -> None: +async def test_file_upload_not_chunked_seek( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: file_path = pathlib.Path(__file__).parent / "aiohttp.png" with file_path.open("rb") as f: f.seek(100) - req = ClientRequest("post", URL("http://python.org/"), data=f, loop=loop) + req = make_client_request("post", URL("http://python.org/"), data=f, loop=loop) assert req.headers["CONTENT-LENGTH"] == str(file_path.stat().st_size - 100) - await req.close() + await req._close() -async def test_file_upload_force_chunked(loop: asyncio.AbstractEventLoop) -> None: +async def test_file_upload_force_chunked( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: file_path = pathlib.Path(__file__).parent / "aiohttp.png" with file_path.open("rb") as f: - req = ClientRequest( + req = make_client_request( "post", URL("http://python.org/"), data=f, chunked=True, loop=loop ) assert req.chunked assert "CONTENT-LENGTH" not in req.headers - await req.close() + await req._close() -async def test_expect100(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: - req = ClientRequest("get", URL("http://python.org/"), expect100=True, loop=loop) - resp = await req.send(conn) +async def test_expect100( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( + "get", URL("http://python.org/"), expect100=True, loop=loop + ) + resp = await req._send(conn) assert "100-continue" == req.headers["EXPECT"] assert req._continue is not None - req.terminate() + req._terminate() resp.close() async def test_expect_100_continue_header( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( - "get", URL("http://python.org/"), headers={"expect": "100-continue"}, loop=loop + req = make_client_request( + "get", + URL("http://python.org/"), + headers=CIMultiDict({"expect": "100-continue"}), + loop=loop, ) - resp = await req.send(conn) + resp = await req._send(conn) assert "100-continue" == req.headers["EXPECT"] assert req._continue is not None - req.terminate() + req._terminate() resp.close() async def test_data_stream( - loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: async def gen() -> AsyncIterator[bytes]: yield b"binary data" yield b" result" - req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) + req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop) assert req.chunked assert req.headers["TRANSFER-ENCODING"] == "chunked" - original_write_bytes = req.write_bytes + original_write_bytes = req._write_bytes async def _mock_write_bytes( writer: AbstractStreamWriter, conn: mock.Mock, content_length: int | None @@ -1167,22 +1322,25 @@ async def _mock_write_bytes( await asyncio.sleep(0) await original_write_bytes(writer, conn, content_length) - with mock.patch.object(req, "write_bytes", _mock_write_bytes): - resp = await req.send(conn) + with mock.patch.object(req, "_write_bytes", _mock_write_bytes): + resp = await req._send(conn) assert asyncio.isfuture(req._writer) await resp.wait_for_close() assert req._writer is None assert ( # type: ignore[unreachable] buf.split(b"\r\n\r\n", 1)[1] == b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n" ) - await req.close() + await req._close() async def test_data_file( - loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: with io.BufferedReader(io.BytesIO(b"*" * 2)) as file_handle: - req = ClientRequest( + req = make_client_request( "POST", URL("http://python.org/"), data=file_handle, @@ -1192,17 +1350,19 @@ async def test_data_file( assert isinstance(req.body, payload.BufferedReaderPayload) assert req.headers["TRANSFER-ENCODING"] == "chunked" - resp = await req.send(conn) + resp = await req._send(conn) assert asyncio.isfuture(req._writer) await resp.wait_for_close() assert req._writer is None assert buf.split(b"\r\n\r\n", 1)[1] == b"2\r\n" + b"*" * 2 + b"\r\n0\r\n\r\n" # type: ignore[unreachable] - await req.close() + await req._close() async def test_data_stream_exc( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: fut = loop.create_future() @@ -1210,7 +1370,7 @@ async def gen() -> AsyncIterator[bytes]: yield b"binary data" await fut - req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) + req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop) assert req.chunked assert req.headers["TRANSFER-ENCODING"] == "chunked" @@ -1220,18 +1380,20 @@ async def throw_exc() -> None: t = loop.create_task(throw_exc()) - async with await req.send(conn): + async with await req._send(conn): assert req._writer is not None await req._writer await t # assert conn.close.called assert conn.protocol is not None assert conn.protocol.set_exception.called - await req.close() + await req._close() async def test_data_stream_exc_chain( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: fut = loop.create_future() @@ -1240,7 +1402,7 @@ async def gen() -> AsyncIterator[None]: assert False yield # type: ignore[unreachable] # pragma: no cover - req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) + req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop) inner_exc = ValueError() @@ -1250,7 +1412,7 @@ async def throw_exc() -> None: t = loop.create_task(throw_exc()) - async with await req.send(conn): + async with await req._send(conn): assert req._writer is not None await req._writer await t @@ -1259,17 +1421,20 @@ async def throw_exc() -> None: outer_exc = conn.protocol.set_exception.call_args[0][0] assert isinstance(outer_exc, ClientConnectionError) assert outer_exc.__cause__ is inner_exc - await req.close() + await req._close() async def test_data_stream_continue( - loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: async def gen() -> AsyncIterator[bytes]: yield b"binary data" yield b" result" - req = ClientRequest( + req = make_client_request( "POST", URL("http://python.org/"), data=gen(), expect100=True, loop=loop ) assert req.chunked @@ -1281,21 +1446,24 @@ async def coro() -> None: t = loop.create_task(coro()) - resp = await req.send(conn) + resp = await req._send(conn) assert req._writer is not None await req._writer await t assert ( buf.split(b"\r\n\r\n", 1)[1] == b"b\r\nbinary data\r\n7\r\n result\r\n0\r\n\r\n" ) - await req.close() + await req._close() resp.close() async def test_data_continue( - loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + req = make_client_request( "POST", URL("http://python.org/"), data=b"data", expect100=True, loop=loop ) @@ -1306,70 +1474,81 @@ async def coro() -> None: t = loop.create_task(coro()) - resp = await req.send(conn) + resp = await req._send(conn) assert req._writer is not None await req._writer await t assert buf.split(b"\r\n\r\n", 1)[1] == b"data" - await req.close() + await req._close() resp.close() async def test_close( - loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: async def gen() -> AsyncIterator[bytes]: await asyncio.sleep(0.00001) yield b"result" - req = ClientRequest("POST", URL("http://python.org/"), data=gen(), loop=loop) - resp = await req.send(conn) - await req.close() + req = make_client_request("POST", URL("http://python.org/"), data=gen(), loop=loop) + resp = await req._send(conn) + await req._close() assert buf.split(b"\r\n\r\n", 1)[1] == b"6\r\nresult\r\n0\r\n\r\n" - await req.close() + await req._close() resp.close() -async def test_bad_version(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: - req = ClientRequest( +async def test_bad_version( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( "GET", URL("http://python.org"), loop=loop, - headers={"Connection": "Close"}, + headers=CIMultiDict({"Connection": "Close"}), version=("1", "1\r\nInjected-Header: not allowed"), # type: ignore[arg-type] ) with pytest.raises(AttributeError): - await req.send(conn) + await req._send(conn) async def test_custom_response_class( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: class CustomResponse(ClientResponse): async def read(self) -> bytes: return b"customized!" - req = ClientRequest( + req = make_client_request( "GET", URL("http://python.org/"), response_class=CustomResponse, loop=loop ) - resp = await req.send(conn) + resp = await req._send(conn) assert await resp.read() == b"customized!" - await req.close() + await req._close() resp.close() async def test_oserror_on_write_bytes( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("POST", URL("http://python.org/"), loop=loop) - req.body = b"test data" + req = make_client_request("POST", URL("http://python.org/"), loop=loop) + await req.update_body(b"test data") writer = WriterMock() writer.write.side_effect = OSError - await req.write_bytes(writer, conn, None) + await req._write_bytes(writer, conn, None) assert conn.protocol.set_exception.called exc = conn.protocol.set_exception.call_args[0][0] @@ -1377,11 +1556,15 @@ async def test_oserror_on_write_bytes( @pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()") -async def test_cancel_close(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: - req = ClientRequest("get", URL("http://python.org"), loop=loop) +async def test_cancel_close( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: + req = make_client_request("get", URL("http://python.org"), loop=loop) req._writer = asyncio.Future() # type: ignore[assignment] - t = asyncio.create_task(req.close()) + t = asyncio.create_task(req._close()) # Start waiting on _writer await asyncio.sleep(0) @@ -1392,15 +1575,19 @@ async def test_cancel_close(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> await t -async def test_terminate(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None: - req = ClientRequest("get", URL("http://python.org"), loop=loop) +async def test_terminate( + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, + make_client_request: _RequestMaker, +) -> None: + req = make_client_request("get", URL("http://python.org"), loop=loop) async def _mock_write_bytes(*args: object, **kwargs: object) -> None: # Ensure the task is scheduled await asyncio.sleep(0) - with mock.patch.object(req, "write_bytes", _mock_write_bytes): - resp = await req.send(conn) + with mock.patch.object(req, "_write_bytes", _mock_write_bytes): + resp = await req._send(conn) assert req._writer is not None assert resp._writer is not None @@ -1413,27 +1600,53 @@ async def _mock_write_bytes(*args: object, **kwargs: object) -> None: assert req._writer is not None assert resp._writer is not None - req.terminate() + req._terminate() writer.cancel.assert_called_with() writer.done.assert_called_with() resp.close() def test_terminate_with_closed_loop( - loop: asyncio.AbstractEventLoop, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + conn: mock.Mock, ) -> None: req = resp = writer = None async def go() -> None: nonlocal req, resp, writer - req = ClientRequest("get", URL("http://python.org"), loop=loop) + # Can't use make_client_request here, due to closing the loop mid-test. + req = ClientRequest( + "get", + URL("http://python.org"), + loop=loop, + params={}, + headers=CIMultiDict[str](), + skip_auto_headers=None, + data=None, + cookies=BaseCookie[str](), + auth=None, + version=HttpVersion11, + compress=False, + chunked=None, + expect100=False, + response_class=ClientResponse, + proxy=None, + proxy_auth=None, + timer=TimerNoop(), + session=None, # type: ignore[arg-type] + ssl=True, + proxy_headers=None, + traces=[], + trust_env=False, + server_hostname=None, + ) async def _mock_write_bytes(*args: object, **kwargs: object) -> None: # Ensure the task is scheduled await asyncio.sleep(0) - with mock.patch.object(req, "write_bytes", _mock_write_bytes): - resp = await req.send(conn) + with mock.patch.object(req, "_write_bytes", _mock_write_bytes): + resp = await req._send(conn) assert req._writer is not None writer = WriterMock() @@ -1447,7 +1660,7 @@ async def _mock_write_bytes(*args: object, **kwargs: object) -> None: loop.close() assert req is not None - req.terminate() + req._terminate() assert req._writer is None assert writer is not None assert not writer.cancel.called @@ -1455,11 +1668,11 @@ async def _mock_write_bytes(*args: object, **kwargs: object) -> None: resp.close() -def test_terminate_without_writer(loop: asyncio.AbstractEventLoop) -> None: - req = ClientRequest("get", URL("http://python.org"), loop=loop) +async def test_terminate_without_writer(make_client_request: _RequestMaker) -> None: + req = make_client_request("get", URL("http://python.org")) assert req._writer is None - req.terminate() + req._terminate() assert req._writer is None @@ -1481,17 +1694,18 @@ async def start(self, connection: Connection) -> ClientResponse: called = False class CustomRequest(ClientRequest): - async def send(self, conn: Connection) -> ClientResponse: + async def _send(self, conn: Connection) -> ClientResponse: resp = self.response_class( self.method, self.url, writer=self._writer, continue100=self._continue, timer=self._timer, - request_info=self.request_info, traces=self._traces, loop=self.loop, session=self._session, + request_headers=self.headers, + original_url=self.original_url, ) self.response = resp nonlocal called @@ -1536,27 +1750,6 @@ def test_insecure_fingerprint_sha1(loop: asyncio.AbstractEventLoop) -> None: Fingerprint(hashlib.sha1(b"foo").digest()) -def test_loose_cookies_types(loop: asyncio.AbstractEventLoop) -> None: - req = ClientRequest("get", URL("http://python.org"), loop=loop) - morsel: Morsel[str] = Morsel() - morsel.set(key="string", val="Another string", coded_val="really") - - accepted_types: list[LooseCookies] = [ - [("str", BaseCookie())], - [("str", morsel)], - [ - ("str", "str"), - ], - {"str": BaseCookie()}, - {"str": morsel}, - {"str": "str"}, - SimpleCookie(), - ] - - for loose_cookies_type in accepted_types: - req.update_cookies(cookies=loose_cookies_type) - - @pytest.mark.parametrize( "has_brotli,has_zstd,expected", [ @@ -1580,11 +1773,11 @@ def test_gen_default_accept_encoding( indirect=("netrc_contents",), ) @pytest.mark.usefixtures("netrc_contents") -def test_basicauth_from_netrc_present_untrusted_env( - make_request: _RequestMaker, +async def test_basicauth_from_netrc_present_untrusted_env( # type: ignore[misc] + make_client_request: _RequestMaker, ) -> None: """Test no authorization header is sent via netrc if trust_env is False""" - req = make_request("get", "http://example.com", trust_env=False) + req = make_client_request("get", URL("http://example.com"), trust_env=False) assert hdrs.AUTHORIZATION not in req.headers @@ -1594,39 +1787,43 @@ def test_basicauth_from_netrc_present_untrusted_env( indirect=("netrc_contents",), ) @pytest.mark.usefixtures("netrc_contents") -def test_basicauth_from_empty_netrc( - make_request: _RequestMaker, +async def test_basicauth_from_empty_netrc( # type: ignore[misc] + make_client_request: _RequestMaker, ) -> None: """Test that no Authorization header is sent when netrc is empty""" - req = make_request("get", "http://example.com", trust_env=True) + req = make_client_request("get", URL("http://example.com"), trust_env=True) assert hdrs.AUTHORIZATION not in req.headers -async def test_connection_key_with_proxy() -> None: +async def test_connection_key_with_proxy( + make_client_request: _RequestMaker, +) -> None: """Verify the proxy headers are included in the ConnectionKey when a proxy is used.""" proxy = URL("http://proxy.example.com") - req = ClientRequest( + req = make_client_request( "GET", URL("http://example.com"), proxy=proxy, - proxy_headers={"X-Proxy": "true"}, + proxy_headers=CIMultiDict({"X-Proxy": "true"}), loop=asyncio.get_running_loop(), ) assert req.connection_key.proxy_headers_hash is not None - await req.close() + await req._close() -async def test_connection_key_without_proxy() -> None: +async def test_connection_key_without_proxy( + make_client_request: _RequestMaker, +) -> None: """Verify the proxy headers are not included in the ConnectionKey when a proxy is used.""" # If proxy is unspecified, proxy_headers should be ignored - req = ClientRequest( + req = make_client_request( "GET", URL("http://example.com"), - proxy_headers={"X-Proxy": "true"}, + proxy_headers=CIMultiDict({"X-Proxy": "true"}), loop=asyncio.get_running_loop(), ) assert req.connection_key.proxy_headers_hash is None - await req.close() + await req._close() def test_request_info_back_compat() -> None: @@ -1672,9 +1869,9 @@ def test_request_info_tuple_new() -> None: ) -def test_get_content_length(make_request: _RequestMaker) -> None: +async def test_get_content_length(make_client_request: _RequestMaker) -> None: """Test _get_content_length method extracts Content-Length correctly.""" - req = make_request("get", "http://python.org/") + req = make_client_request("get", URL("http://python.org/")) # No Content-Length header assert req._get_content_length() is None @@ -1690,22 +1887,25 @@ def test_get_content_length(make_request: _RequestMaker) -> None: async def test_write_bytes_with_content_length_limit( - loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: """Test that write_bytes respects content_length limit for different body types.""" # Test with bytes data data = b"Hello World" - req = ClientRequest("post", URL("http://python.org/"), loop=loop) + req = make_client_request("post", URL("http://python.org/"), loop=loop) - req.body = data + await req.update_body(data) writer = StreamWriter(protocol=conn.protocol, loop=loop) # Use content_length=5 to truncate data - await req.write_bytes(writer, conn, 5) + await req._write_bytes(writer, conn, 5) # Verify only the first 5 bytes were written assert buf == b"Hello" - await req.close() + await req._close() @pytest.mark.parametrize( @@ -1715,15 +1915,16 @@ async def test_write_bytes_with_content_length_limit( b"Part1Part2Part3", ], ) -async def test_write_bytes_with_iterable_content_length_limit( +async def test_write_bytes_with_iterable_content_length_limit( # type: ignore[misc] loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock, data: list[bytes] | bytes, + make_client_request: _RequestMaker, ) -> None: """Test that write_bytes respects content_length limit for iterable data.""" # Test with iterable data - req = ClientRequest("post", URL("http://python.org/"), loop=loop) + req = make_client_request("post", URL("http://python.org/"), loop=loop) # Convert list to async generator if needed if isinstance(data, list): @@ -1732,191 +1933,48 @@ async def gen() -> AsyncIterator[bytes]: for chunk in data: yield chunk - req.body = gen() + await req.update_body(gen()) else: - req.body = data + await req.update_body(data) writer = StreamWriter(protocol=conn.protocol, loop=loop) # Use content_length=7 to truncate at the middle of Part2 - await req.write_bytes(writer, conn, 7) + await req._write_bytes(writer, conn, 7) assert len(buf) == 7 assert buf == b"Part1Pa" - await req.close() + await req._close() async def test_write_bytes_empty_iterable_with_content_length( - loop: asyncio.AbstractEventLoop, buf: bytearray, conn: mock.Mock + loop: asyncio.AbstractEventLoop, + buf: bytearray, + conn: mock.Mock, + make_client_request: _RequestMaker, ) -> None: """Test that write_bytes handles empty iterable body with content_length.""" - req = ClientRequest("post", URL("http://python.org/"), loop=loop) + req = make_client_request("post", URL("http://python.org/"), loop=loop) # Create an empty async generator async def gen() -> AsyncIterator[bytes]: return yield # pragma: no cover # This makes it a generator but never executes - req.body = gen() + await req.update_body(gen()) writer = StreamWriter(protocol=conn.protocol, loop=loop) # Use content_length=10 with empty body - await req.write_bytes(writer, conn, 10) + await req._write_bytes(writer, conn, 10) # Verify nothing was written assert len(buf) == 0 - await req.close() - - -async def test_warn_if_unclosed_payload_via_body_setter( - make_request: _RequestMaker, -) -> None: - """Test that _warn_if_unclosed_payload is called when setting body with unclosed payload.""" - req = make_request("POST", "http://python.org/") - - # First set a payload that needs manual closing (autoclose=False) - file_payload = payload.BufferedReaderPayload( - io.BufferedReader(io.BytesIO(b"test data")), - encoding="utf-8", - ) - req.body = file_payload - - # Setting body again should trigger the warning for the previous payload - with pytest.warns( - ResourceWarning, - match="The previous request body contains unclosed resources", - ): - req.body = b"new data" - - await req.close() - - -async def test_no_warn_for_autoclose_payload_via_body_setter( - make_request: _RequestMaker, -) -> None: - """Test that no warning is issued for payloads with autoclose=True.""" - req = make_request("POST", "http://python.org/") - - # First set BytesIOPayload which has autoclose=True - bytes_payload = payload.BytesIOPayload(io.BytesIO(b"test data")) - req.body = bytes_payload - - # Setting body again should not trigger warning since previous payload has autoclose=True - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always") - req.body = b"new data" - - # Filter out any non-ResourceWarning warnings - resource_warnings = [ - w for w in warning_list if issubclass(w.category, ResourceWarning) - ] - assert len(resource_warnings) == 0 - - await req.close() - - -async def test_no_warn_for_consumed_payload_via_body_setter( - make_request: _RequestMaker, -) -> None: - """Test that no warning is issued for already consumed payloads.""" - req = make_request("POST", "http://python.org/") - - # Create a payload that needs manual closing - file_payload = payload.BufferedReaderPayload( - io.BufferedReader(io.BytesIO(b"test data")), - encoding="utf-8", - ) - req.body = file_payload - - # Properly close the payload to mark it as consumed - await file_payload.close() - - # Setting body again should not trigger warning since previous payload is consumed - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always") - req.body = b"new data" - - # Filter out any non-ResourceWarning warnings - resource_warnings = [ - w for w in warning_list if issubclass(w.category, ResourceWarning) - ] - assert len(resource_warnings) == 0 - - await req.close() - - -async def test_warn_if_unclosed_payload_via_update_body_from_data( - make_request: _RequestMaker, -) -> None: - """Test that _warn_if_unclosed_payload is called via update_body_from_data.""" - req = make_request("POST", "http://python.org/") - - # First set a payload that needs manual closing - file_payload = payload.BufferedReaderPayload( - io.BufferedReader(io.BytesIO(b"initial data")), - encoding="utf-8", - ) - req.update_body_from_data(file_payload) - - # Create FormData for second update - form = aiohttp.FormData() - form.add_field("test", "value") - - # update_body_from_data should trigger the warning for the previous payload - with pytest.warns( - ResourceWarning, - match="The previous request body contains unclosed resources", - ): - req.update_body_from_data(form) - - await req.close() - - -async def test_warn_via_update_with_file_payload( - make_request: _RequestMaker, -) -> None: - """Test warning via update_body_from_data with file-like object.""" - req = make_request("POST", "http://python.org/") - - # First create a file-like object that results in BufferedReaderPayload - buffered1 = io.BufferedReader(io.BytesIO(b"file content 1")) - req.update_body_from_data(buffered1) - - # Second update should warn about the first payload - buffered2 = io.BufferedReader(io.BytesIO(b"file content 2")) - - with pytest.warns( - ResourceWarning, - match="The previous request body contains unclosed resources", - ): - req.update_body_from_data(buffered2) - - await req.close() - - -async def test_no_warn_for_simple_data_via_update_body_from_data( - make_request: _RequestMaker, -) -> None: - """Test that no warning is issued for simple data types.""" - req = make_request("POST", "http://python.org/") - - # Simple bytes data should not trigger warning - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always") - req.update_body_from_data(b"simple data") - - # Filter out any non-ResourceWarning warnings - resource_warnings = [ - w for w in warning_list if issubclass(w.category, ResourceWarning) - ] - assert len(resource_warnings) == 0 - - await req.close() + await req._close() async def test_update_body_closes_previous_payload( - make_request: _RequestMaker, + make_client_request: _RequestMaker, ) -> None: """Test that update_body properly closes the previous payload.""" - req = make_request("POST", "http://python.org/") + req = make_client_request("POST", URL("http://python.org/")) # Create a mock payload that tracks if it was closed mock_payload = mock.Mock(spec=payload.Payload) @@ -1934,39 +1992,14 @@ async def test_update_body_closes_previous_payload( # Verify new body is set (it's a BytesPayload now) assert isinstance(req.body, payload.BytesPayload) - await req.close() - - -async def test_body_setter_closes_previous_payload( - make_request: _RequestMaker, -) -> None: - """Test that body setter properly closes the previous payload.""" - req = make_request("POST", "http://python.org/") - - # Create a mock payload that tracks if it was closed - # We need to use create_autospec to ensure all methods are available - mock_payload = mock.create_autospec(payload.Payload, instance=True) - - # Set initial payload - req._body = mock_payload - - # Update body with new data using setter - req.body = b"new body data" - - # Verify the previous payload was closed using _close - mock_payload._close.assert_called_once() - - # Verify new body is set (it's a BytesPayload now) - assert isinstance(req.body, payload.BytesPayload) - - await req.close() + await req._close() async def test_update_body_with_different_types( - make_request: _RequestMaker, + make_client_request: _RequestMaker, ) -> None: """Test update_body with various data types.""" - req = make_request("POST", "http://python.org/") + req = make_client_request("POST", URL("http://python.org/")) # Test with bytes await req.update_body(b"bytes data") @@ -1978,17 +2011,17 @@ async def test_update_body_with_different_types( # Test with None (clears body) await req.update_body(None) - assert req.body == b"" # type: ignore[comparison-overlap] # empty body is represented as b"" + assert req.body._value == b"" - await req.close() + await req._close() async def test_update_body_with_chunked_encoding( - make_request: _RequestMaker, + make_client_request: _RequestMaker, ) -> None: """Test that update_body properly handles chunked transfer encoding.""" # Create request with chunked=True - req = make_request("POST", "http://python.org/", chunked=True) + req = make_client_request("POST", URL("http://python.org/"), chunked=True) # Verify Transfer-Encoding header is set assert req.headers["Transfer-Encoding"] == "chunked" @@ -2010,15 +2043,15 @@ async def test_update_body_with_chunked_encoding( assert req.headers["Transfer-Encoding"] == "chunked" assert "Content-Length" not in req.headers - await req.close() + await req._close() async def test_update_body_get_method_with_none_body( - make_request: _RequestMaker, + make_client_request: _RequestMaker, ) -> None: """Test that update_body with GET method and None body doesn't call update_transfer_encoding.""" # Create GET request - req = make_request("GET", "http://python.org/") + req = make_client_request("GET", URL("http://python.org/")) # GET requests shouldn't have Transfer-Encoding or Content-Length initially assert "Transfer-Encoding" not in req.headers @@ -2032,14 +2065,14 @@ async def test_update_body_get_method_with_none_body( assert "Transfer-Encoding" not in req.headers assert "Content-Length" not in req.headers - await req.close() + await req._close() async def test_update_body_updates_content_length( - make_request: _RequestMaker, + make_client_request: _RequestMaker, ) -> None: """Test that update_body properly updates Content-Length header when body size changes.""" - req = make_request("POST", "http://python.org/") + req = make_client_request("POST", URL("http://python.org/")) # Set initial body with known size await req.update_body(b"initial data") @@ -2060,98 +2093,28 @@ async def test_update_body_updates_content_length( # For None body with POST method, Content-Length should be set to 0 assert req.headers[hdrs.CONTENT_LENGTH] == "0" - await req.close() + await req._close() -async def test_warn_stacklevel_points_to_user_code( - make_request: _RequestMaker, +async def test_expect100_with_body_becomes_empty( + make_client_request: _RequestMaker, ) -> None: - """Test that the warning stacklevel correctly points to user code.""" - req = make_request("POST", "http://python.org/") - - # First set a payload that needs manual closing (autoclose=False) - file_payload = payload.BufferedReaderPayload( - io.BufferedReader(io.BytesIO(b"test data")), - encoding="utf-8", - ) - req.body = file_payload - - # Capture warnings with their details - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always", ResourceWarning) - # This line should be reported as the warning source - req.body = b"new data" - - # Find the ResourceWarning - resource_warnings = [ - w for w in warning_list if issubclass(w.category, ResourceWarning) - ] - assert len(resource_warnings) == 1 - - warning = resource_warnings[0] - # The warning should point to the line where we set req.body, not inside the library - # Call chain: user code -> body setter -> _warn_if_unclosed_payload - # stacklevel=3 is used in body setter to skip the setter and _warn_if_unclosed_payload - assert warning.filename == __file__ - # The line number should be the line with "req.body = b'new data'" - # We can't hardcode the line number, but we can verify it's not pointing - # to client_reqrep.py (the library code) - assert "client_reqrep.py" not in warning.filename - - await req.close() - - -async def test_warn_stacklevel_update_body_from_data( - make_request: _RequestMaker, -) -> None: - """Test that warning stacklevel is correct when called from update_body_from_data.""" - req = make_request("POST", "http://python.org/") - - # First set a payload that needs manual closing (autoclose=False) - file_payload = payload.BufferedReaderPayload( - io.BufferedReader(io.BytesIO(b"test data")), - encoding="utf-8", - ) - req.update_body_from_data(file_payload) - - # Capture warnings with their details - with warnings.catch_warnings(record=True) as warning_list: - warnings.simplefilter("always", ResourceWarning) - # This line should be reported as the warning source - req.update_body_from_data(b"new data") # LINE TO BE REPORTED - - # Find the ResourceWarning - resource_warnings = [ - w for w in warning_list if issubclass(w.category, ResourceWarning) - ] - assert len(resource_warnings) == 1 - - warning = resource_warnings[0] - # For update_body_from_data, stacklevel=3 points to this test file - # Call chain: user code -> update_body_from_data -> _warn_if_unclosed_payload - assert warning.filename == __file__ - assert "client_reqrep.py" not in warning.filename - - await req.close() - - -async def test_expect100_with_body_becomes_none() -> None: - """Test that write_bytes handles body becoming None after expect100 handling.""" + """Test that write_bytes handles body becoming empty after expect100 handling.""" # Create a mock writer and connection mock_writer = mock.AsyncMock() mock_conn = mock.Mock() # Create a request - req = ClientRequest( + req = make_client_request( "POST", URL("http://test.example.com/"), loop=asyncio.get_event_loop() ) req._body = mock.Mock() # Start with a body - # Now set body to None to simulate a race condition + # Now set body to empty payload to simulate a race condition # where req._body is set to None after expect100 handling - req._body = None + req._body = payload.PAYLOAD_REGISTRY.get(b"", disposition=None) - await req.write_bytes(mock_writer, mock_conn, None) + await req._write_bytes(mock_writer, mock_conn, None) @pytest.mark.parametrize( @@ -2180,14 +2143,15 @@ async def test_expect100_with_body_becomes_none() -> None: ("DELETE", b"x", "1"), ], ) -def test_content_length_for_methods( +async def test_content_length_for_methods( # type: ignore[misc] method: str, data: bytes | None, expected_content_length: str | None, loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that Content-Length header is set correctly for all HTTP methods.""" - req = ClientRequest(method, URL("http://python.org/"), data=data, loop=loop) + req = make_client_request(method, URL("http://python.org/"), data=data, loop=loop) actual_content_length = req.headers.get(hdrs.CONTENT_LENGTH) assert actual_content_length == expected_content_length @@ -2205,58 +2169,73 @@ def test_non_get_methods_classification(method: str) -> None: assert method not in ClientRequest.GET_METHODS -async def test_content_length_with_string_data(loop: asyncio.AbstractEventLoop) -> None: +async def test_content_length_with_string_data( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: """Test Content-Length when data is a string.""" data = "Hello, World!" - req = ClientRequest("POST", URL("http://python.org/"), data=data, loop=loop) + req = make_client_request("POST", URL("http://python.org/"), data=data, loop=loop) # String should be encoded to bytes, default encoding is utf-8 assert req.headers[hdrs.CONTENT_LENGTH] == str(len(data.encode("utf-8"))) - await req.close() + await req._close() async def test_content_length_with_async_iterable( loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that async iterables use chunked encoding, not Content-Length.""" async def data_gen() -> AsyncIterator[bytes]: yield b"chunk1" # pragma: no cover - req = ClientRequest("POST", URL("http://python.org/"), data=data_gen(), loop=loop) + req = make_client_request( + "POST", URL("http://python.org/"), data=data_gen(), loop=loop + ) assert hdrs.CONTENT_LENGTH not in req.headers assert req.chunked assert req.headers[hdrs.TRANSFER_ENCODING] == "chunked" - await req.close() + await req._close() -async def test_content_length_not_overridden(loop: asyncio.AbstractEventLoop) -> None: +async def test_content_length_not_overridden( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: """Test that explicitly set Content-Length is not overridden.""" - req = ClientRequest( + req = make_client_request( "POST", URL("http://python.org/"), data=b"test", - headers={hdrs.CONTENT_LENGTH: "100"}, + headers=CIMultiDict({hdrs.CONTENT_LENGTH: "100"}), loop=loop, ) # Should keep the explicitly set value assert req.headers[hdrs.CONTENT_LENGTH] == "100" - await req.close() + await req._close() -async def test_content_length_with_formdata(loop: asyncio.AbstractEventLoop) -> None: +async def test_content_length_with_formdata( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: """Test Content-Length with FormData.""" form = aiohttp.FormData() form.add_field("field", "value") - req = ClientRequest("POST", URL("http://python.org/"), data=form, loop=loop) + req = make_client_request("POST", URL("http://python.org/"), data=form, loop=loop) # FormData with known size should set Content-Length assert hdrs.CONTENT_LENGTH in req.headers - await req.close() + await req._close() -async def test_no_content_length_with_chunked(loop: asyncio.AbstractEventLoop) -> None: +async def test_no_content_length_with_chunked( + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: """Test that chunked encoding prevents Content-Length header.""" - req = ClientRequest( + req = make_client_request( "POST", URL("http://python.org/"), data=b"test", @@ -2265,36 +2244,84 @@ async def test_no_content_length_with_chunked(loop: asyncio.AbstractEventLoop) - ) assert hdrs.CONTENT_LENGTH not in req.headers assert req.headers[hdrs.TRANSFER_ENCODING] == "chunked" - await req.close() + await req._close() @pytest.mark.parametrize("method", ["POST", "PUT", "PATCH", "DELETE"]) -async def test_update_body_none_sets_content_length_zero( - method: str, loop: asyncio.AbstractEventLoop +async def test_update_body_none_sets_content_length_zero( # type: ignore[misc] + method: str, + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that updating body to None sets Content-Length: 0 for POST-like methods.""" # Create request with initial body - req = ClientRequest(method, URL("http://python.org/"), data=b"initial", loop=loop) + req = make_client_request( + method, URL("http://python.org/"), data=b"initial", loop=loop + ) assert req.headers[hdrs.CONTENT_LENGTH] == "7" # Update body to None await req.update_body(None) assert req.headers[hdrs.CONTENT_LENGTH] == "0" - assert req._body is None - await req.close() + await req._close() @pytest.mark.parametrize("method", ["GET", "HEAD", "OPTIONS", "TRACE"]) -async def test_update_body_none_no_content_length_for_get_methods( - method: str, loop: asyncio.AbstractEventLoop +async def test_update_body_none_no_content_length_for_get_methods( # type: ignore[misc] + method: str, + loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: """Test that updating body to None doesn't set Content-Length for GET-like methods.""" # Create request with initial body - req = ClientRequest(method, URL("http://python.org/"), data=b"initial", loop=loop) + req = make_client_request( + method, URL("http://python.org/"), data=b"initial", loop=loop + ) assert req.headers[hdrs.CONTENT_LENGTH] == "7" # Update body to None await req.update_body(None) assert hdrs.CONTENT_LENGTH not in req.headers - assert req._body is None - await req.close() + await req._close() + + +async def test_multiple_requests_share_empty_body_safely( + make_client_request: _RequestMaker, +) -> None: + """Test that multiple ClientRequest objects safely share the empty body payload.""" + requests: list[ClientRequest] = [] + for i in range(5): + req = make_client_request("GET", URL(f"http://example.com/path{i}")) + requests.append(req) + + empty_body = ClientRequest._EMPTY_BODY + for i, req in enumerate(requests): + assert req.body is empty_body, f"Request {i} has different empty body" + assert req.body.size == 0 + assert req.body.consumed is False + + assert empty_body.consumed is False + assert empty_body.size == 0 + + +async def test_empty_body_isolation_after_update( + make_client_request: _RequestMaker, +) -> None: + """Test that updating one request's body doesn't affect other requests.""" + req1 = make_client_request("POST", URL("http://example.com/1")) + req2 = make_client_request("POST", URL("http://example.com/2")) + + assert req1.body is ClientRequest._EMPTY_BODY + assert req2.body is ClientRequest._EMPTY_BODY + + await req1.update_body(b"new data") + + assert req1.body is not ClientRequest._EMPTY_BODY + assert req1.body.size == 8 + + assert req2.body is ClientRequest._EMPTY_BODY + assert req2.body.size == 0 + assert req2.body.consumed is False + + assert ClientRequest._EMPTY_BODY.consumed is False + assert ClientRequest._EMPTY_BODY.size == 0 diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 23ff170e8dd..e9f7da127ee 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -14,7 +14,7 @@ import aiohttp from aiohttp import ClientSession, hdrs, http -from aiohttp.client_reqrep import ClientResponse, RequestInfo +from aiohttp.client_reqrep import ClientResponse from aiohttp.connector import Connection from aiohttp.helpers import TimerNoop from aiohttp.multipart import BadContentDispositionHeader @@ -32,17 +32,18 @@ def session() -> mock.Mock: async def test_http_processing_error(session: ClientSession) -> None: loop = mock.Mock() - request_info = mock.Mock() + url = URL("http://del-cl-resp.org") response = ClientResponse( "get", - URL("http://del-cl-resp.org"), - request_info=request_info, + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) loop.get_debug = mock.Mock() loop.get_debug.return_value = True @@ -54,22 +55,24 @@ async def test_http_processing_error(session: ClientSession) -> None: with pytest.raises(aiohttp.ClientResponseError) as info: await response.start(connection) - assert info.value.request_info is request_info + assert info.value.request_info.url is url response.close() def test_del(session: ClientSession) -> None: loop = mock.Mock() + url = URL("http://del-cl-resp.org") response = ClientResponse( "get", - URL("http://del-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) loop.get_debug = mock.Mock() loop.get_debug.return_value = True @@ -87,16 +90,18 @@ def test_del(session: ClientSession) -> None: def test_close(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) response._closed = False response._connection = mock.Mock() @@ -109,16 +114,18 @@ def test_close(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: def test_wait_for_100_1( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://python.org") response = ClientResponse( "get", - URL("http://python.org"), + url, continue100=loop.create_future(), - request_info=mock.Mock(), writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) assert response._continue is not None response.close() @@ -127,32 +134,36 @@ def test_wait_for_100_1( def test_wait_for_100_2( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://python.org") response = ClientResponse( "get", - URL("http://python.org"), - request_info=mock.Mock(), + url, continue100=None, writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) assert response._continue is None response.close() def test_repr(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) response.status = 200 response.reason = "Ok" @@ -160,31 +171,35 @@ def test_repr(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: def test_repr_non_ascii_url() -> None: + url = URL("http://fake-host.org/\u03bb") response = ClientResponse( "get", - URL("http://fake-host.org/\u03bb"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) assert "" in repr(response) def test_repr_non_ascii_reason() -> None: + url = URL("http://fake-host.org/path") response = ClientResponse( "get", - URL("http://fake-host.org/path"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response.reason = "\u03bb" assert "" in repr( @@ -195,16 +210,18 @@ def test_repr_non_ascii_reason() -> None: async def test_read_and_release_connection( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -223,16 +240,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_read_and_release_connection_with_error( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) content = response.content = mock.Mock() content.read.return_value = loop.create_future() @@ -244,16 +263,18 @@ async def test_read_and_release_connection_with_error( async def test_release(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) fut = loop.create_future() fut.set_result(b"") @@ -275,16 +296,18 @@ async def test_release_on_del( connection.protocol.upgraded = False def run(conn: Connection) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) response._closed = False response._connection = conn @@ -297,16 +320,18 @@ def run(conn: Connection) -> None: async def test_response_eof( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) response._closed = False conn = response._connection = mock.Mock() @@ -320,16 +345,18 @@ async def test_response_eof( async def test_response_eof_upgraded( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) conn = response._connection = mock.Mock() @@ -343,16 +370,18 @@ async def test_response_eof_upgraded( async def test_response_eof_after_connection_detach( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) response._closed = False conn = response._connection = mock.Mock() @@ -364,16 +393,18 @@ async def test_response_eof_after_connection_detach( async def test_text(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -394,16 +425,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_text_bad_encoding( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -428,16 +461,18 @@ async def test_text_badly_encoded_encoding_header( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: session._resolve_charset = lambda *_: "utf-8" + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -459,16 +494,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_text_custom_encoding( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -492,16 +529,18 @@ async def test_text_charset_resolver( content_type: str, loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: session._resolve_charset = lambda r, b: "cp1251" + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -524,16 +563,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_get_encoding_body_none( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "text/html"} @@ -552,16 +593,18 @@ async def test_get_encoding_body_none( async def test_text_after_read( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -580,16 +623,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_json(loop: asyncio.AbstractEventLoop, session: ClientSession) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -610,16 +655,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_json_extended_content_type( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -640,16 +687,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_json_custom_content_type( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -670,16 +719,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": async def test_json_custom_loader( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -695,16 +746,18 @@ def custom(content: str) -> str: async def test_json_invalid_content_type( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "data/octet-stream"} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -721,16 +774,18 @@ async def test_json_invalid_content_type( async def test_json_no_content( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "application/json"} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -743,16 +798,18 @@ async def test_json_no_content( async def test_json_override_encoding( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -774,16 +831,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": def test_get_encoding_unknown( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "application/json"} @@ -792,16 +851,18 @@ def test_get_encoding_unknown( def test_raise_for_status_2xx() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response.status = 200 response.reason = "OK" @@ -809,16 +870,18 @@ def test_raise_for_status_2xx() -> None: def test_raise_for_status_4xx() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response.status = 409 response.reason = "CONFLICT" @@ -830,16 +893,18 @@ def test_raise_for_status_4xx() -> None: def test_raise_for_status_4xx_without_reason() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response.status = 404 response.reason = "" @@ -851,31 +916,35 @@ def test_raise_for_status_4xx_without_reason() -> None: def test_resp_host() -> None: + url = URL("http://del-cl-resp.org") response = ClientResponse( "get", - URL("http://del-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) assert "del-cl-resp.org" == response.host def test_content_type() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -884,16 +953,18 @@ def test_content_type() -> None: def test_content_type_no_header() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) @@ -901,16 +972,18 @@ def test_content_type_no_header() -> None: def test_charset() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "application/json;charset=cp1251"} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -919,16 +992,18 @@ def test_charset() -> None: def test_charset_no_header() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) @@ -936,16 +1011,18 @@ def test_charset_no_header() -> None: def test_charset_no_charset() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Type": "application/json"} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -954,16 +1031,18 @@ def test_charset_no_charset() -> None: def test_content_disposition_full() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Disposition": 'attachment; filename="archive.tar.gz"; foo=bar'} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -977,16 +1056,18 @@ def test_content_disposition_full() -> None: def test_content_disposition_no_parameters() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Disposition": "attachment"} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -1005,16 +1086,18 @@ def test_content_disposition_no_parameters() -> None: ), ) def test_content_disposition_empty_parts(content_disposition: str) -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) h = {"Content-Disposition": content_disposition} response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -1026,16 +1109,18 @@ def test_content_disposition_empty_parts(content_disposition: str) -> None: def test_content_disposition_no_header() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) @@ -1043,16 +1128,18 @@ def test_content_disposition_no_header() -> None: def test_default_encoding_is_utf8() -> None: + url = URL("http://def-cl-resp.org") response = ClientResponse( "get", - URL("http://def-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), - session=None, # type: ignore[arg-type] + session=None, + request_headers=CIMultiDict[str](), + original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict({})) response._body = b"" @@ -1063,17 +1150,18 @@ def test_default_encoding_is_utf8() -> None: def test_response_request_info() -> None: url = URL("http://def-cl-resp.org") h = {"Content-Type": "application/json;charset=cp1251"} - headers = CIMultiDictProxy(CIMultiDict(h)) + headers = CIMultiDict(h) response = ClientResponse( "get", url, - request_info=RequestInfo(url, "get", headers, url), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=headers, + original_url=url, ) assert url == response.request_info.url assert "get" == response.request_info.method @@ -1083,17 +1171,18 @@ def test_response_request_info() -> None: def test_request_info_in_exception() -> None: url = URL("http://def-cl-resp.org") h = {"Content-Type": "application/json;charset=cp1251"} - headers = CIMultiDictProxy(CIMultiDict(h)) + headers = CIMultiDict(h) response = ClientResponse( "get", url, - request_info=RequestInfo(url, "get", headers, url), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=headers, + original_url=url, ) response.status = 409 response.reason = "CONFLICT" @@ -1105,17 +1194,18 @@ def test_request_info_in_exception() -> None: def test_no_redirect_history_in_exception() -> None: url = URL("http://def-cl-resp.org") h = {"Content-Type": "application/json;charset=cp1251"} - headers = CIMultiDictProxy(CIMultiDict(h)) + headers = CIMultiDict(h) response = ClientResponse( "get", url, - request_info=RequestInfo(url, "get", headers, url), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=headers, + original_url=url, ) response.status = 409 response.reason = "CONFLICT" @@ -1130,17 +1220,18 @@ def test_redirect_history_in_exception() -> None: url = URL(u) hist_headers = {"Content-Type": "application/json;charset=cp1251", "Location": u} h = {"Content-Type": "application/json;charset=cp1251"} - headers = CIMultiDictProxy(CIMultiDict(h)) + headers = CIMultiDict(h) response = ClientResponse( "get", url, - request_info=RequestInfo(url, "get", headers, url), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=headers, + original_url=url, ) response.status = 409 response.reason = "CONFLICT" @@ -1148,13 +1239,14 @@ def test_redirect_history_in_exception() -> None: hist_response = ClientResponse( "get", hist_url, - request_info=RequestInfo(url, "get", headers, url), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=headers, + original_url=hist_url, ) hist_response._headers = CIMultiDictProxy(CIMultiDict(hist_headers)) @@ -1179,13 +1271,14 @@ async def test_response_read_triggers_callback( response = ClientResponse( response_method, response_url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), loop=loop, session=session, traces=[trace], + request_headers=CIMultiDict[str](), + original_url=response_url, ) def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": @@ -1211,16 +1304,18 @@ def side_effect(*args: object, **kwargs: object) -> "asyncio.Future[bytes]": def test_response_cookies( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: + url = URL("http://python.org") response = ClientResponse( "get", - URL("http://python.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) cookies = response.cookies # Ensure the same cookies object is returned each time @@ -1234,13 +1329,14 @@ def test_response_real_url( response = ClientResponse( "get", url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) assert response.url == url.with_fragment(None) assert response.real_url == url @@ -1253,13 +1349,14 @@ def test_response_links_comma_separated( response = ClientResponse( "get", url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = ( ( @@ -1284,13 +1381,14 @@ def test_response_links_multiple_headers( response = ClientResponse( "get", url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = ( ("Link", "; rel=next"), @@ -1310,13 +1408,14 @@ def test_response_links_no_rel( response = ClientResponse( "get", url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = (("Link", ""),) response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -1332,13 +1431,14 @@ def test_response_links_quoted( response = ClientResponse( "get", url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = (("Link", '; rel="home-page"'),) response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -1354,13 +1454,14 @@ def test_response_links_relative( response = ClientResponse( "get", url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) h = (("Link", "; rel=rel"),) response._headers = CIMultiDictProxy(CIMultiDict(h)) @@ -1376,29 +1477,32 @@ def test_response_links_empty( response = ClientResponse( "get", url, - request_info=mock.Mock(), writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) response._headers = CIMultiDictProxy(CIMultiDict()) assert response.links == {} def test_response_not_closed_after_get_ok(mocker: MockerFixture) -> None: + url = URL("http://del-cl-resp.org") response = ClientResponse( "get", - URL("http://del-cl-resp.org"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=mock.Mock(), session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) response.status = 400 response.reason = "Bad Request" @@ -1423,16 +1527,18 @@ def test_response_duplicate_cookie_names( - response.headers.getall('Set-Cookie') for raw headers - The session's cookie jar correctly stores all cookies """ + url = URL("http://example.com") response = ClientResponse( "get", - URL("http://example.com"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) # Set headers with duplicate cookie names but different domains @@ -1462,16 +1568,18 @@ def test_response_raw_cookie_headers_preserved( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: """Test that raw Set-Cookie headers are preserved in _raw_cookie_headers.""" + url = URL("http://example.com") response = ClientResponse( "get", - URL("http://example.com"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) # Set headers with multiple cookies @@ -1502,16 +1610,18 @@ def test_response_cookies_setter_updates_raw_headers( loop: asyncio.AbstractEventLoop, session: ClientSession ) -> None: """Test that setting cookies property updates _raw_cookie_headers.""" + url = URL("http://example.com") response = ClientResponse( "get", - URL("http://example.com"), - request_info=mock.Mock(), + url, writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], loop=loop, session=session, + request_headers=CIMultiDict[str](), + original_url=url, ) # Create a SimpleCookie with some cookies diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 84a417f9219..21057d3fbb5 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -555,9 +555,9 @@ async def test_reraise_os_error( create_mocked_conn: Callable[[], ResponseHandler], ) -> None: err = OSError(1, "permission error") - req = mock.Mock() + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) - req.send = mock.AsyncMock(side_effect=err) + req._send = mock.AsyncMock(side_effect=err) req._body = mock.Mock() req._body.close = mock.AsyncMock() session = await create_session(request_class=req_factory) @@ -587,9 +587,9 @@ class UnexpectedException(BaseException): pass err = UnexpectedException("permission error") - req = mock.Mock() + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) - req.send = mock.AsyncMock(side_effect=err) + req._send = mock.AsyncMock(side_effect=err) req._body = mock.Mock() req._body.close = mock.AsyncMock() session = await create_session(request_class=req_factory) @@ -651,7 +651,7 @@ async def test_ws_connect_allowed_protocols( # type: ignore[misc] req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) - req.send = mock.AsyncMock(return_value=resp) + req._send = mock.AsyncMock(return_value=resp) # BaseConnector allows all high level protocols by default connector = BaseConnector() @@ -714,7 +714,7 @@ async def test_ws_connect_unix_socket_allowed_protocols( # type: ignore[misc] req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req._body = None # No body for WebSocket upgrade requests req_factory = mock.Mock(return_value=req) - req.send = mock.AsyncMock(return_value=resp) + req._send = mock.AsyncMock(return_value=resp) # UnixConnector allows all high level protocols by default and unix sockets session = await create_session( connector=UnixConnector(path=""), request_class=req_factory @@ -954,7 +954,7 @@ async def on_request_headers_sent( context: object, params: tracing.TraceRequestHeadersSentParams, ) -> None: - gathered_req_headers.extend(**params.headers) + gathered_req_headers.extend(params.headers) trace_config = aiohttp.TraceConfig( trace_config_ctx_factory=mock.Mock(return_value=trace_config_ctx) diff --git a/tests/test_connector.py b/tests/test_connector.py index 91796280b27..1d5ed0c01a0 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -12,7 +12,7 @@ from collections.abc import Awaitable, Callable, Iterator, Sequence from concurrent import futures from contextlib import closing, suppress -from typing import Literal, NoReturn +from typing import Any, Literal, NoReturn from unittest import mock import pytest @@ -29,7 +29,7 @@ ) from aiohttp.abc import ResolveResult from aiohttp.client_proto import ResponseHandler -from aiohttp.client_reqrep import ConnectionKey +from aiohttp.client_reqrep import ClientRequestArgs, ConnectionKey from aiohttp.connector import ( _SSL_CONTEXT_UNVERIFIED, _SSL_CONTEXT_VERIFIED, @@ -43,6 +43,13 @@ from aiohttp.test_utils import unused_port from aiohttp.tracing import Trace +if sys.version_info >= (3, 11): + from typing import Unpack + + _RequestMaker = Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest] +else: + _RequestMaker = Any + @pytest.fixture def key() -> ConnectionKey: @@ -648,9 +655,11 @@ async def test__release_acquired_per_host3( async def test_tcp_connector_certificate_error( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + req = make_client_request("GET", URL("https://127.0.0.1:443"), loop=loop) conn = aiohttp.TCPConnector() with mock.patch.object( @@ -671,7 +680,9 @@ async def test_tcp_connector_certificate_error( async def test_tcp_connector_server_hostname_default( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: conn = aiohttp.TCPConnector() @@ -680,7 +691,7 @@ async def test_tcp_connector_server_hostname_default( ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() - req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + req = make_client_request("GET", URL("https://127.0.0.1:443"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert create_connection.call_args.kwargs["server_hostname"] == "127.0.0.1" @@ -689,7 +700,9 @@ async def test_tcp_connector_server_hostname_default( async def test_tcp_connector_server_hostname_override( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: conn = aiohttp.TCPConnector() @@ -698,7 +711,7 @@ async def test_tcp_connector_server_hostname_override( ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() - req = ClientRequest( + req = make_client_request( "GET", URL("https://127.0.0.1:443"), loop=loop, server_hostname="localhost" ) @@ -709,7 +722,7 @@ async def test_tcp_connector_server_hostname_override( async def test_tcp_connector_multiple_hosts_errors( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: conn = aiohttp.TCPConnector() @@ -724,7 +737,7 @@ async def test_tcp_connector_multiple_hosts_errors( fingerprint = hashlib.sha256(b"foo").digest() - req = ClientRequest( + req = make_client_request( "GET", URL("https://mocked.host"), ssl=aiohttp.Fingerprint(fingerprint), @@ -877,8 +890,10 @@ def get_extra_info(param: str) -> object: ("happy_eyeballs_delay"), [0.1, 0.25, None], ) -async def test_tcp_connector_happy_eyeballs( - loop: asyncio.AbstractEventLoop, happy_eyeballs_delay: float | None +async def test_tcp_connector_happy_eyeballs( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, + happy_eyeballs_delay: float | None, + make_client_request: _RequestMaker, ) -> None: conn = aiohttp.TCPConnector(happy_eyeballs_delay=happy_eyeballs_delay) @@ -887,7 +902,7 @@ async def test_tcp_connector_happy_eyeballs( ips = [ip1, ip2] addrs_tried = [] - req = ClientRequest( + req = make_client_request( "GET", URL("https://mocked.host"), loop=loop, @@ -964,7 +979,9 @@ async def create_connection( await conn.close() -async def test_tcp_connector_interleave(loop: asyncio.AbstractEventLoop) -> None: +async def test_tcp_connector_interleave( + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: conn = aiohttp.TCPConnector(interleave=2) ip1 = "192.168.1.1" @@ -976,7 +993,7 @@ async def test_tcp_connector_interleave(loop: asyncio.AbstractEventLoop) -> None success_ips = [] interleave_val = None - req = ClientRequest( + req = make_client_request( "GET", URL("https://mocked.host"), loop=loop, @@ -1057,7 +1074,7 @@ async def create_connection( async def test_tcp_connector_family_is_respected( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: conn = aiohttp.TCPConnector(family=socket.AF_INET) @@ -1066,7 +1083,7 @@ async def test_tcp_connector_family_is_respected( ips = [ip1, ip2] addrs_tried = [] - req = ClientRequest( + req = make_client_request( "GET", URL("https://mocked.host"), loop=loop, @@ -1142,9 +1159,10 @@ async def create_connection( ("https://mocked.host"), ], ) -async def test_tcp_connector_multiple_hosts_one_timeout( +async def test_tcp_connector_multiple_hosts_one_timeout( # type: ignore[misc] loop: asyncio.AbstractEventLoop, request_url: str, + make_client_request: _RequestMaker, ) -> None: conn = aiohttp.TCPConnector() @@ -1156,7 +1174,7 @@ async def test_tcp_connector_multiple_hosts_one_timeout( timeout_error = False connected = False - req = ClientRequest( + req = make_client_request( "GET", URL(request_url), loop=loop, @@ -1409,6 +1427,7 @@ async def coro() -> NoReturn: async def test_tcp_connector_cancel_dns_error_captured( loop: asyncio.AbstractEventLoop, dns_response_error: Callable[[], Awaitable[NoReturn]], + make_client_request: _RequestMaker, ) -> None: exception_handler_called = False @@ -1419,9 +1438,7 @@ def exception_handler(loop: asyncio.AbstractEventLoop, context: object) -> None: loop.set_exception_handler(mock.Mock(side_effect=exception_handler)) with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: - req = ClientRequest( - method="GET", url=URL("http://temporary-failure:80"), loop=loop - ) + req = make_client_request("GET", URL("http://temporary-failure:80"), loop=loop) conn = aiohttp.TCPConnector( use_dns_cache=False, ) @@ -1592,7 +1609,9 @@ async def test_tcp_connector_close_resolver() -> None: m_resolver.close.assert_awaited_once() -async def test_dns_error(loop: asyncio.AbstractEventLoop) -> None: +async def test_dns_error( + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: connector = aiohttp.TCPConnector() with mock.patch.object( connector, @@ -1601,7 +1620,7 @@ async def test_dns_error(loop: asyncio.AbstractEventLoop) -> None: spec_set=True, side_effect=OSError("dont take it serious"), ): - req = ClientRequest("GET", URL("http://www.python.org"), loop=loop) + req = make_client_request("GET", URL("http://www.python.org"), loop=loop) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) @@ -1679,11 +1698,15 @@ async def test_release_not_opened( await conn.close() -async def test_connect(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: +async def test_connect( + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, +) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://localhost:80"), loop=loop) + req = make_client_request("GET", URL("http://localhost:80"), loop=loop) conn = aiohttp.BaseConnector() conn._conns[key] = deque([(proto, loop.time())]) @@ -1701,7 +1724,9 @@ async def test_connect(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> N await conn.close() -async def test_connect_tracing(loop: asyncio.AbstractEventLoop) -> None: +async def test_connect_tracing( + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() on_connection_create_start = mock.AsyncMock() @@ -1718,7 +1743,7 @@ async def test_connect_tracing(loop: asyncio.AbstractEventLoop) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector() with mock.patch.object( @@ -1742,8 +1767,8 @@ async def test_connect_tracing(loop: asyncio.AbstractEventLoop) -> None: "on_connection_create_end", ], ) -async def test_exception_during_connetion_create_tracing( - loop: asyncio.AbstractEventLoop, signal: str +async def test_exception_during_connetion_create_tracing( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, signal: str, make_client_request: _RequestMaker ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -1758,7 +1783,7 @@ async def test_exception_during_connetion_create_tracing( proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector() assert not conn._acquired @@ -1777,7 +1802,7 @@ async def test_exception_during_connetion_create_tracing( async def test_exception_during_connection_queued_tracing( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -1792,7 +1817,7 @@ async def test_exception_during_connection_queued_tracing( proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector(limit=1) assert not conn._acquired @@ -1818,7 +1843,7 @@ async def test_exception_during_connection_queued_tracing( async def test_exception_during_connection_reuse_tracing( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -1833,7 +1858,7 @@ async def test_exception_during_connection_reuse_tracing( proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector() assert not conn._acquired @@ -1860,6 +1885,7 @@ async def test_exception_during_connection_reuse_tracing( async def test_cancellation_during_waiting_for_free_connection( loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -1878,7 +1904,7 @@ async def on_connection_queued_start(*args: object, **kwargs: object) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) key = req.connection_key conn = aiohttp.BaseConnector(limit=1) assert not conn._acquired @@ -1902,12 +1928,14 @@ async def on_connection_queued_start(*args: object, **kwargs: object) -> None: assert key not in conn._acquired_per_host -async def test_close_during_connect(loop: asyncio.AbstractEventLoop) -> None: +async def test_close_during_connect( + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True fut = loop.create_future() - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector() with mock.patch.object(conn, "_create_connection", lambda *args: fut): @@ -2146,8 +2174,10 @@ async def test_tcp_connector_ssl_shutdown_timeout_pre_311( @pytest.mark.skipif( sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+" ) -async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock +async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: # Test that ssl_shutdown_timeout is passed to create_connection for SSL connections with pytest.warns( @@ -2160,7 +2190,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() - req = ClientRequest("GET", URL("https://example.com"), loop=loop) + req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert create_connection.call_args.kwargs["ssl_shutdown_timeout"] == 2.5 @@ -2178,7 +2208,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() - req = ClientRequest("GET", URL("https://example.com"), loop=loop) + req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # When ssl_shutdown_timeout is None, it should not be in kwargs @@ -2197,7 +2227,7 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( ) as create_connection: create_connection.return_value = mock.Mock(), mock.Mock() - req = ClientRequest("GET", URL("http://example.com"), loop=loop) + req = make_client_request("GET", URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # For non-SSL connections, ssl_shutdown_timeout should not be passed @@ -2207,8 +2237,10 @@ async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( @pytest.mark.skipif(sys.version_info >= (3, 11), reason="Test for Python < 3.11") -async def test_tcp_connector_ssl_shutdown_timeout_not_passed_pre_311( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock +async def test_tcp_connector_ssl_shutdown_timeout_not_passed_pre_311( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: # Test that ssl_shutdown_timeout is NOT passed to create_connection on Python < 3.11 with warnings.catch_warnings(record=True) as w: @@ -2225,12 +2257,12 @@ async def test_tcp_connector_ssl_shutdown_timeout_not_passed_pre_311( create_connection.return_value = mock.Mock(), mock.Mock() # Test with HTTPS - req = ClientRequest("GET", URL("https://example.com"), loop=loop) + req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs # Test with HTTP - req = ClientRequest("GET", URL("http://example.com"), loop=loop) + req = make_client_request("GET", URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs @@ -2376,7 +2408,9 @@ async def test_tcp_connector_ssl_shutdown_timeout_sentinel_no_warning_pre_311( async def test_tcp_connector_ssl_shutdown_timeout_zero_not_passed( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: """Test that ssl_shutdown_timeout=0 is NOT passed to create_connection.""" with pytest.warns( @@ -2390,13 +2424,13 @@ async def test_tcp_connector_ssl_shutdown_timeout_zero_not_passed( create_connection.return_value = mock.Mock(), mock.Mock() # Test with HTTPS - req = ClientRequest("GET", URL("https://example.com"), loop=loop) + req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # Verify ssl_shutdown_timeout was NOT passed assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs # Test with HTTP (should not have ssl_shutdown_timeout anyway) - req = ClientRequest("GET", URL("http://example.com"), loop=loop) + req = make_client_request("GET", URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs @@ -2406,8 +2440,10 @@ async def test_tcp_connector_ssl_shutdown_timeout_zero_not_passed( @pytest.mark.skipif( sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+" ) -async def test_tcp_connector_ssl_shutdown_timeout_nonzero_passed( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock +async def test_tcp_connector_ssl_shutdown_timeout_nonzero_passed( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: """Test that non-zero ssl_shutdown_timeout IS passed to create_connection on Python 3.11+.""" with pytest.warns( @@ -2421,13 +2457,13 @@ async def test_tcp_connector_ssl_shutdown_timeout_nonzero_passed( create_connection.return_value = mock.Mock(), mock.Mock() # Test with HTTPS - req = ClientRequest("GET", URL("https://example.com"), loop=loop) + req = make_client_request("GET", URL("https://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): # Verify ssl_shutdown_timeout WAS passed assert create_connection.call_args.kwargs["ssl_shutdown_timeout"] == 5.0 # Test with HTTP (should not have ssl_shutdown_timeout) - req = ClientRequest("GET", URL("http://example.com"), loop=loop) + req = make_client_request("GET", URL("http://example.com"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs @@ -2747,7 +2783,9 @@ async def test_close_cancels_cleanup_handle( assert conn._cleanup_handle is None -async def test_close_cancels_resolve_host(loop: asyncio.AbstractEventLoop) -> None: +async def test_close_cancels_resolve_host( + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: cancelled = False async def delay_resolve(*args: object, **kwargs: object) -> None: @@ -2760,7 +2798,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> None: raise conn = aiohttp.TCPConnector() - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) with mock.patch.object(conn._resolver, "resolve", delay_resolve): @@ -2781,7 +2819,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> None: async def test_multiple_dns_resolution_requests_success( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: """Verify that multiple DNS resolution requests are handled correctly.""" @@ -2801,7 +2839,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: ] conn = aiohttp.TCPConnector(force_close=True) - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) with ( @@ -2843,7 +2881,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: async def test_multiple_dns_resolution_requests_failure( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: """Verify that DNS resolution failure for multiple requests is handled correctly.""" @@ -2854,7 +2892,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: raise OSError(None, "DNS Resolution mock failure") conn = aiohttp.TCPConnector(force_close=True) - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) with ( @@ -2896,7 +2934,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: async def test_multiple_dns_resolution_requests_cancelled( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: """Verify that DNS resolution cancellation does not affect other tasks.""" @@ -2907,7 +2945,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: raise OSError(None, "DNS Resolution mock failure") conn = aiohttp.TCPConnector(force_close=True) - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) with ( @@ -2948,7 +2986,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: async def test_multiple_dns_resolution_requests_first_cancelled( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: """Verify that first DNS resolution cancellation does not make other resolutions fail.""" @@ -2968,7 +3006,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: ] conn = aiohttp.TCPConnector(force_close=True) - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) with ( @@ -3011,7 +3049,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: async def test_multiple_dns_resolution_requests_first_fails_second_successful( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: """Verify that first DNS resolution fails the first time and is successful the second time.""" attempt = 0 @@ -3036,7 +3074,7 @@ async def delay_resolve(*args: object, **kwargs: object) -> list[ResolveResult]: ] conn = aiohttp.TCPConnector(force_close=True) - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) with ( @@ -3133,12 +3171,14 @@ async def test_base_connector_allows_high_level_protocols( async def test_connect_with_limit( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) @@ -3177,7 +3217,9 @@ async def f() -> None: async def test_connect_queued_operation_tracing( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -3195,7 +3237,7 @@ async def test_connect_queued_operation_tracing( proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost1:80"), loop=loop, response_class=mock.Mock() ) @@ -3224,7 +3266,9 @@ async def f() -> None: async def test_connect_reuseconn_tracing( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: session = mock.Mock() trace_config_ctx = mock.Mock() @@ -3240,7 +3284,7 @@ async def test_connect_reuseconn_tracing( proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:80"), loop=loop, response_class=mock.Mock() ) @@ -3256,12 +3300,14 @@ async def test_connect_reuseconn_tracing( async def test_connect_with_limit_and_limit_per_host( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://localhost:80"), loop=loop) + req = make_client_request("GET", URL("http://localhost:80"), loop=loop) conn = aiohttp.BaseConnector(limit=1000, limit_per_host=1) conn._conns[key] = deque([(proto, loop.time())]) @@ -3291,12 +3337,14 @@ async def f() -> None: async def test_connect_with_no_limit_and_limit_per_host( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://localhost1:80"), loop=loop) + req = make_client_request("GET", URL("http://localhost1:80"), loop=loop) conn = aiohttp.BaseConnector(limit=0, limit_per_host=1) conn._conns[key] = deque([(proto, loop.time())]) @@ -3324,12 +3372,14 @@ async def f() -> None: async def test_connect_with_no_limits( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://localhost:80"), loop=loop) + req = make_client_request("GET", URL("http://localhost:80"), loop=loop) conn = aiohttp.BaseConnector(limit=0, limit_per_host=0) conn._conns[key] = deque([(proto, loop.time())]) @@ -3357,12 +3407,14 @@ async def f() -> None: async def test_connect_with_limit_cancelled( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector(limit=1) conn._conns[key] = deque([(proto, loop.time())]) @@ -3404,12 +3456,14 @@ async def check_with_exc(err: Exception) -> None: await check_with_exc(asyncio.TimeoutError()) -async def test_connect_with_limit_concurrent(loop: asyncio.AbstractEventLoop) -> None: +async def test_connect_with_limit_concurrent( + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: proto = create_mocked_conn(loop) proto.should_close = False proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) max_connections = 2 num_connections = 0 @@ -3464,11 +3518,13 @@ async def f(start: bool = True) -> None: assert max_connections == num_connections -async def test_connect_waiters_cleanup(loop: asyncio.AbstractEventLoop) -> None: +async def test_connect_waiters_cleanup( + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector(limit=1) with mock.patch.object(conn, "_available_connections", return_value=0): @@ -3485,12 +3541,12 @@ async def test_connect_waiters_cleanup(loop: asyncio.AbstractEventLoop) -> None: async def test_connect_waiters_cleanup_key_error( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector(limit=1, limit_per_host=10) with mock.patch.object( @@ -3513,12 +3569,14 @@ async def test_connect_waiters_cleanup_key_error( async def test_close_with_acquired_connection( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("http://host:80"), loop=loop) + req = make_client_request("GET", URL("http://host:80"), loop=loop) conn = aiohttp.BaseConnector(limit=1) conn._conns[key] = deque([(proto, loop.time())]) @@ -3722,21 +3780,25 @@ async def handler(request: web.Request) -> web.Response: @pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") -async def test_unix_connector_not_found(loop: asyncio.AbstractEventLoop) -> None: +async def test_unix_connector_not_found( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex) - req = ClientRequest("GET", URL("http://www.python.org"), loop=loop) + req = make_client_request("GET", URL("http://www.python.org"), loop=loop) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) @pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") -async def test_unix_connector_permission(loop: asyncio.AbstractEventLoop) -> None: +async def test_unix_connector_permission( # type: ignore[misc] + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: m = mock.AsyncMock(side_effect=PermissionError()) with mock.patch.object(loop, "create_unix_connection", m): connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex) - req = ClientRequest("GET", URL("http://www.python.org"), loop=loop) + req = make_client_request("GET", URL("http://www.python.org"), loop=loop) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) @@ -3754,13 +3816,15 @@ async def test_named_pipe_connector_wrong_loop( @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) -async def test_named_pipe_connector_not_found( - proactor_loop: asyncio.AbstractEventLoop, pipe_name: str +async def test_named_pipe_connector_not_found( # type: ignore[misc] + proactor_loop: asyncio.AbstractEventLoop, + pipe_name: str, + make_client_request: _RequestMaker, ) -> None: asyncio.set_event_loop(proactor_loop) connector = aiohttp.NamedPipeConnector(pipe_name) - req = ClientRequest("GET", URL("http://www.python.org"), loop=proactor_loop) + req = make_client_request("GET", URL("http://www.python.org"), loop=proactor_loop) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) @@ -3768,15 +3832,19 @@ async def test_named_pipe_connector_not_found( @pytest.mark.skipif( platform.system() != "Windows", reason="Proactor Event loop present only in Windows" ) -async def test_named_pipe_connector_permission( - proactor_loop: asyncio.AbstractEventLoop, pipe_name: str +async def test_named_pipe_connector_permission( # type: ignore[misc] + proactor_loop: asyncio.AbstractEventLoop, + pipe_name: str, + make_client_request: _RequestMaker, ) -> None: m = mock.AsyncMock(side_effect=PermissionError()) with mock.patch.object(proactor_loop, "create_pipe_connection", m): asyncio.set_event_loop(proactor_loop) connector = aiohttp.NamedPipeConnector(pipe_name) - req = ClientRequest("GET", URL("http://www.python.org"), loop=proactor_loop) + req = make_client_request( + "GET", URL("http://www.python.org"), loop=proactor_loop + ) with pytest.raises(aiohttp.ClientConnectorError): await connector.connect(req, [], ClientTimeout()) @@ -3789,12 +3857,12 @@ async def test_default_use_dns_cache() -> None: async def test_resolver_not_called_with_address_is_ip( - loop: asyncio.AbstractEventLoop, + loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker ) -> None: resolver = mock.MagicMock() connector = aiohttp.TCPConnector(resolver=resolver) - req = ClientRequest( + req = make_client_request( "GET", URL(f"http://127.0.0.1:{unused_port()}"), loop=loop, @@ -4230,12 +4298,14 @@ async def resolve_response() -> list[ResolveResult]: async def test_connector_does_not_remove_needed_waiters( - loop: asyncio.AbstractEventLoop, key: ConnectionKey + loop: asyncio.AbstractEventLoop, + key: ConnectionKey, + make_client_request: _RequestMaker, ) -> None: proto = create_mocked_conn(loop) proto.is_connected.return_value = True - req = ClientRequest("GET", URL("https://localhost:80"), loop=loop) + req = make_client_request("GET", URL("https://localhost:80"), loop=loop) connection_key = req.connection_key async def await_connection_and_check_waiters() -> None: @@ -4281,13 +4351,13 @@ async def allow_connection_and_add_dummy_waiter() -> None: await connector.close() -def test_connector_multiple_event_loop() -> None: +def test_connector_multiple_event_loop(make_client_request: _RequestMaker) -> None: """Test the connector with multiple event loops.""" async def async_connect() -> Literal[True]: conn = aiohttp.TCPConnector() loop = asyncio.get_running_loop() - req = ClientRequest("GET", URL("https://127.0.0.1"), loop=loop) + req = make_client_request("GET", URL("https://127.0.0.1"), loop=loop) with suppress(aiohttp.ClientConnectorError): with mock.patch.object( conn._loop, @@ -4314,7 +4384,9 @@ def test_connect() -> Literal[True]: async def test_tcp_connector_socket_factory( - loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock + loop: asyncio.AbstractEventLoop, + start_connection: mock.AsyncMock, + make_client_request: _RequestMaker, ) -> None: """Check that socket factory is called""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -4340,7 +4412,7 @@ async def test_tcp_connector_socket_factory( ): host = "127.0.0.1" port = 443 - req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop) + req = make_client_request("GET", URL(f"https://{host}:{port}"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): pass await conn.close() diff --git a/tests/test_payload.py b/tests/test_payload.py index cb38cb5a6d0..d5c2a9a0246 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -1352,3 +1352,28 @@ def tell(self) -> int: # For unseekable files that can't tell() or seek(), # they are marked as consumed after the first write assert p.consumed is True + + +async def test_empty_bytes_payload_is_reusable() -> None: + """Test that empty BytesPayload can be safely reused across requests.""" + empty_payload = payload.PAYLOAD_REGISTRY.get(b"", disposition=None) + + assert isinstance(empty_payload, payload.BytesPayload) + assert empty_payload.size == 0 + assert empty_payload.consumed is False + assert empty_payload.autoclose is True + + initial_headers = dict(empty_payload.headers) + + for i in range(3): + writer = BufferWriter() + await empty_payload.write_with_length(writer, None) + + assert writer.buffer == b"" + assert empty_payload.consumed is False, f"consumed flag changed on write {i+1}" + assert ( + dict(empty_payload.headers) == initial_headers + ), f"headers mutated on write {i+1}" + assert empty_payload.size == 0, f"size changed on write {i+1}" + + assert empty_payload.headers == CIMultiDict(initial_headers) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index fd22c3a9910..147b5998b8e 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -1,41 +1,58 @@ import asyncio import socket import ssl +import sys +from typing import Callable from unittest import mock import pytest +from multidict import CIMultiDict from yarl import URL import aiohttp -from aiohttp.client_reqrep import ClientRequest, ClientResponse, Fingerprint +from aiohttp.client_reqrep import ( + ClientRequest, + ClientRequestArgs, + ClientRequestBase, + ClientResponse, + Fingerprint, +) from aiohttp.connector import _SSL_CONTEXT_VERIFIED from aiohttp.helpers import TimerNoop +if sys.version_info >= (3, 11): + from typing import Unpack -@mock.patch("aiohttp.connector.ClientRequest") + _RequestMaker = Callable[[str, URL, Unpack[ClientRequestArgs]], ClientRequest] +else: + from typing import Any + + _RequestMaker = Any + + +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_connect( # type: ignore[misc] +async def test_connect( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + event_loop = asyncio.get_running_loop() + req = make_client_request( "GET", URL("http://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, + ssl=True, + headers=CIMultiDict({}), ) assert str(req.proxy) == "http://proxy.example.com" - # mock all the things! - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -56,9 +73,7 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=(proto.transport, proto), ): - conn = event_loop.run_until_complete( - connector.connect(req, [], aiohttp.ClientTimeout()) - ) + conn = await connector.connect(req, [], aiohttp.ClientTimeout()) assert req.url == URL("http://www.python.org") assert conn._protocol is proto assert conn.transport is proto.transport @@ -73,34 +88,33 @@ async def make_conn() -> aiohttp.TCPConnector: ) conn.close() - event_loop.run_until_complete(connector.close()) + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_proxy_headers( # type: ignore[misc] +async def test_proxy_headers( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - req = ClientRequest( + event_loop = asyncio.get_running_loop() + req = make_client_request( "GET", URL("http://www.python.org"), proxy=URL("http://proxy.example.com"), - proxy_headers={"Foo": "Bar"}, + proxy_headers=CIMultiDict({"Foo": "Bar"}), loop=event_loop, + ssl=True, + headers=CIMultiDict({}), ) assert str(req.proxy) == "http://proxy.example.com" - # mock all the things! - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -121,9 +135,7 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=(proto.transport, proto), ): - conn = event_loop.run_until_complete( - connector.connect(req, [], aiohttp.ClientTimeout()) - ) + conn = await connector.connect(req, [], aiohttp.ClientTimeout()) assert req.url == URL("http://www.python.org") assert conn._protocol is proto assert conn.transport is proto.transport @@ -138,7 +150,7 @@ async def make_conn() -> aiohttp.TCPConnector: ) conn.close() - event_loop.run_until_complete(connector.close()) + await connector.close() @mock.patch( @@ -146,10 +158,13 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, spec_set=True, ) -def test_proxy_auth(start_connection: mock.Mock) -> None: # type: ignore[misc] +async def test_proxy_auth( # type: ignore[misc] + start_connection: mock.Mock, + make_client_request: _RequestMaker, +) -> None: msg = r"proxy_auth must be None or BasicAuth\(\) tuple" with pytest.raises(ValueError, match=msg): - ClientRequest( + make_client_request( "GET", URL("http://python.org"), proxy=URL("http://proxy.example.com"), @@ -163,34 +178,29 @@ def test_proxy_auth(start_connection: mock.Mock) -> None: # type: ignore[misc] autospec=True, spec_set=True, ) -def test_proxy_dns_error( # type: ignore[misc] +async def test_proxy_dns_error( # type: ignore[misc] start_connection: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() with mock.patch.object( connector, "_resolve_host", autospec=True, side_effect=OSError("dont take it serious"), ): - req = ClientRequest( + req = make_client_request( "GET", URL("http://www.python.org"), proxy=URL("http://proxy.example.com"), - loop=event_loop, + loop=asyncio.get_running_loop(), ) expected_headers = dict(req.headers) with pytest.raises(aiohttp.ClientConnectorError): - event_loop.run_until_complete( - connector.connect(req, [], aiohttp.ClientTimeout()) - ) + await connector.connect(req, [], aiohttp.ClientTimeout()) assert req.url.path == "/" assert dict(req.headers) == expected_headers - event_loop.run_until_complete(connector.close()) + await connector.close() @mock.patch( @@ -199,14 +209,11 @@ async def make_conn() -> aiohttp.TCPConnector: spec_set=True, return_value=mock.create_autospec(socket.socket, spec_set=True, instance=True), ) -def test_proxy_connection_error( # type: ignore[misc] +async def test_proxy_connection_error( # type: ignore[misc] start_connection: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "www.python.org", "host": "127.0.0.1", @@ -222,52 +229,56 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, side_effect=OSError("dont take it serious"), ): - req = ClientRequest( + req = make_client_request( "GET", URL("http://www.python.org"), proxy=URL("http://proxy.example.com"), - loop=event_loop, ) with pytest.raises(aiohttp.ClientProxyConnectionError): - event_loop.run_until_complete( - connector.connect(req, [], aiohttp.ClientTimeout()) - ) - event_loop.run_until_complete(connector.close()) + await connector.connect(req, [], aiohttp.ClientTimeout()) + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_proxy_server_hostname_default( # type: ignore[misc] +async def test_proxy_server_hostname_default( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -292,16 +303,14 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=mock.Mock(), ) as tls_m: - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) assert ( @@ -309,49 +318,51 @@ async def make_conn() -> aiohttp.TCPConnector: == "www.python.org" ) - event_loop.run_until_complete(proxy_req.close()) proxy_resp.close() - event_loop.run_until_complete(req.close()) - event_loop.run_until_complete(connector.close()) + await req._close() + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_proxy_server_hostname_override( # type: ignore[misc] +async def test_proxy_server_hostname_override( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest( + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( "GET", URL("http://proxy.example.com"), + auth=None, loop=event_loop, + ssl=True, + headers=CIMultiDict({}), ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -376,17 +387,14 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=mock.Mock(), ) as tls_m: - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), server_hostname="server-hostname.example.com", - loop=event_loop, ) - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) assert ( @@ -394,13 +402,12 @@ async def make_conn() -> aiohttp.TCPConnector: == "server-hostname.example.com" ) - event_loop.run_until_complete(proxy_req.close()) proxy_resp.close() - event_loop.run_until_complete(req.close()) - event_loop.run_until_complete(connector.close()) + await req._close() + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, @@ -408,32 +415,39 @@ async def make_conn() -> aiohttp.TCPConnector: ) @pytest.mark.usefixtures("enable_cleanup_closed") @pytest.mark.parametrize("cleanup", (True, False)) -def test_https_connect_fingerprint_mismatch( # type: ignore[misc] +async def test_https_connect_fingerprint_mismatch( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, cleanup: bool, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector(enable_cleanup_closed=cleanup) - - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req class TransportMock(asyncio.Transport): def close(self) -> None: pass + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=mock.Mock(), continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) fingerprint_mock = mock.Mock(spec=Fingerprint, auto_spec=True) fingerprint_mock.check.side_effect = aiohttp.ServerFingerprintMismatch( @@ -442,7 +456,7 @@ def close(self) -> None: with ( mock.patch.object( proxy_req, - "send", + "_send", autospec=True, spec_set=True, return_value=proxy_resp, @@ -455,7 +469,7 @@ def close(self) -> None: return_value=mock.Mock(status=200), ), ): - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector(enable_cleanup_closed=cleanup) host = [ { "hostname": "hostname", @@ -496,51 +510,56 @@ def close(self) -> None: return_value=TransportMock(), ), ): - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) with pytest.raises(aiohttp.ServerFingerprintMismatch): - event_loop.run_until_complete( - connector._create_connection(req, [], aiohttp.ClientTimeout()) - ) + await connector._create_connection(req, [], aiohttp.ClientTimeout()) -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_https_connect( # type: ignore[misc] +async def test_https_connect( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -565,61 +584,65 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=mock.Mock(), ): - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) assert req.url.path == "/" assert proxy_req.method == "CONNECT" assert proxy_req.url == URL("https://www.python.org") - event_loop.run_until_complete(proxy_req.close()) proxy_resp.close() - event_loop.run_until_complete(req.close()) - event_loop.run_until_complete(connector.close()) + await req._close() + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_https_connect_certificate_error( # type: ignore[misc] +async def test_https_connect_certificate_error( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -646,54 +669,59 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, side_effect=ssl.CertificateError, ): - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) with pytest.raises(aiohttp.ClientConnectorCertificateError): - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) - event_loop.run_until_complete(connector.close()) + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_https_connect_ssl_error( # type: ignore[misc] +async def test_https_connect_ssl_error( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -720,55 +748,60 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, side_effect=ssl.SSLError, ): - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) with pytest.raises(aiohttp.ClientConnectorSSLError): - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) - event_loop.run_until_complete(connector.close()) + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_https_connect_http_proxy_error( # type: ignore[misc] +async def test_https_connect_http_proxy_error( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 400 m.return_value.reason = "bad request" - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -789,7 +822,7 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=(tr, proto), ): - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), @@ -798,52 +831,55 @@ async def make_conn() -> aiohttp.TCPConnector: with pytest.raises( aiohttp.ClientHttpProxyError, match="400, message='bad request'" ): - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) - event_loop.run_until_complete(proxy_req.close()) proxy_resp.close() - event_loop.run_until_complete(req.close()) - event_loop.run_until_complete(connector.close()) + await req._close() + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_https_connect_resp_start_error( # type: ignore[misc] +async def test_https_connect_resp_start_error( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object( proxy_resp, "start", autospec=True, side_effect=OSError("error message") ): - - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -864,19 +900,17 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=(tr, proto), ): - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) with pytest.raises(OSError, match="error message"): - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) - event_loop.run_until_complete(connector.close()) + await connector.close() @mock.patch("aiohttp.connector.ClientRequest") @@ -885,18 +919,18 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, spec_set=True, ) -def test_request_port( # type: ignore[misc] +async def test_request_port( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = make_client_request( + "GET", URL("http://proxy.example.com"), loop=event_loop + ) ClientRequestMock.return_value = proxy_req - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -912,21 +946,21 @@ async def make_conn() -> aiohttp.TCPConnector: with mock.patch.object( event_loop, "create_connection", autospec=True, return_value=(tr, proto) ): - req = ClientRequest( + req = make_client_request( "GET", URL("http://localhost:1234/path"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) - event_loop.run_until_complete( - connector._create_connection(req, [], aiohttp.ClientTimeout()) - ) + await connector._create_connection(req, [], aiohttp.ClientTimeout()) assert req.url == URL("http://localhost:1234/path") - event_loop.run_until_complete(connector.close()) + await connector.close() -def test_proxy_auth_property(event_loop: asyncio.AbstractEventLoop) -> None: - req = aiohttp.ClientRequest( +async def test_proxy_auth_property( + event_loop: asyncio.AbstractEventLoop, make_client_request: _RequestMaker +) -> None: + req = make_client_request( "GET", URL("http://localhost:1234/path"), proxy=URL("http://proxy.example.com"), @@ -936,8 +970,11 @@ def test_proxy_auth_property(event_loop: asyncio.AbstractEventLoop) -> None: assert ("user", "pass", "latin1") == req.proxy_auth -def test_proxy_auth_property_default(event_loop: asyncio.AbstractEventLoop) -> None: - req = aiohttp.ClientRequest( +async def test_proxy_auth_property_default( + event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, +) -> None: + req = make_client_request( "GET", URL("http://localhost:1234/path"), proxy=URL("http://proxy.example.com"), @@ -946,39 +983,46 @@ def test_proxy_auth_property_default(event_loop: asyncio.AbstractEventLoop) -> N assert req.proxy_auth is None -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_https_connect_pass_ssl_context( # type: ignore[misc] +async def test_https_connect_pass_ssl_context( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest("GET", URL("http://proxy.example.com"), loop=event_loop) + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( + "GET", + URL("http://proxy.example.com"), + auth=None, + loop=event_loop, + ssl=True, + headers=CIMultiDict({}), + ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -1003,16 +1047,14 @@ async def make_conn() -> aiohttp.TCPConnector: autospec=True, return_value=mock.Mock(), ) as tls_m: - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), loop=event_loop, ) - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) # ssl_shutdown_timeout=0 is not passed to start_tls @@ -1028,50 +1070,51 @@ async def make_conn() -> aiohttp.TCPConnector: assert proxy_req.method == "CONNECT" assert proxy_req.url == URL("https://www.python.org") - event_loop.run_until_complete(proxy_req.close()) proxy_resp.close() - event_loop.run_until_complete(req.close()) - event_loop.run_until_complete(connector.close()) + await req._close() + await connector.close() -@mock.patch("aiohttp.connector.ClientRequest") +@mock.patch("aiohttp.connector.ClientRequestBase") @mock.patch( "aiohttp.connector.aiohappyeyeballs.start_connection", autospec=True, spec_set=True, ) -def test_https_auth( # type: ignore[misc] +async def test_https_auth( # type: ignore[misc] start_connection: mock.Mock, ClientRequestMock: mock.Mock, - event_loop: asyncio.AbstractEventLoop, + make_client_request: _RequestMaker, ) -> None: - proxy_req = ClientRequest( + event_loop = asyncio.get_running_loop() + proxy_req = ClientRequestBase( "GET", URL("http://proxy.example.com"), auth=aiohttp.helpers.BasicAuth("user", "pass"), loop=event_loop, + ssl=True, + headers=CIMultiDict({}), ) ClientRequestMock.return_value = proxy_req + url = URL("http://proxy.example.com") proxy_resp = ClientResponse( "get", - URL("http://proxy.example.com"), - request_info=mock.Mock(), + url, writer=None, continue100=None, timer=TimerNoop(), traces=[], loop=event_loop, session=mock.Mock(), + request_headers=CIMultiDict[str](), + original_url=url, ) - with mock.patch.object(proxy_req, "send", autospec=True, return_value=proxy_resp): + with mock.patch.object(proxy_req, "_send", autospec=True, return_value=proxy_resp): with mock.patch.object(proxy_resp, "start", autospec=True) as m: m.return_value.status = 200 - async def make_conn() -> aiohttp.TCPConnector: - return aiohttp.TCPConnector() - - connector = event_loop.run_until_complete(make_conn()) + connector = aiohttp.TCPConnector() r = { "hostname": "hostname", "host": "127.0.0.1", @@ -1099,7 +1142,7 @@ async def make_conn() -> aiohttp.TCPConnector: assert "AUTHORIZATION" in proxy_req.headers assert "PROXY-AUTHORIZATION" not in proxy_req.headers - req = ClientRequest( + req = make_client_request( "GET", URL("https://www.python.org"), proxy=URL("http://proxy.example.com"), @@ -1107,10 +1150,8 @@ async def make_conn() -> aiohttp.TCPConnector: ) assert "AUTHORIZATION" not in req.headers assert "PROXY-AUTHORIZATION" not in req.headers - event_loop.run_until_complete( - connector._create_connection( - req, [], aiohttp.ClientTimeout() - ) + await connector._create_connection( + req, [], aiohttp.ClientTimeout() ) assert req.url.path == "/" @@ -1123,7 +1164,6 @@ async def make_conn() -> aiohttp.TCPConnector: "proxy.example.com", 80, traces=mock.ANY ) - event_loop.run_until_complete(proxy_req.close()) proxy_resp.close() - event_loop.run_until_complete(req.close()) - event_loop.run_until_complete(connector.close()) + await req._close() + await connector.close() diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 5ff56cc2dab..db51a7cf06b 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -2230,7 +2230,7 @@ async def handler(request: web.Request) -> web.Response: app = web.Application() app.router.add_get("/", handler) - client = await aiohttp_client(app, version="1.1") + client = await aiohttp_client(app, version=HttpVersion11) resp = await client.get("/") assert CONTENT_LENGTH not in resp.headers assert TRANSFER_ENCODING not in resp.headers