From 3fb2b0c4a4d8921c190a76b5c8fe49b63adb9775 Mon Sep 17 00:00:00 2001
From: Rob Brackett
Date: Fri, 15 Dec 2023 17:24:38 -0800
Subject: [PATCH] A horrifying implementation that passes tests

We can refine from here. Lots of bad things that need cleanup, and, well,
maybe just too much stuff in general.
---
 wayback/__init__.py          |   4 +
 wayback/_client.py           | 469 ++++++++++++++++++++++++++++++-----
 wayback/_utils.py            |  29 +--
 wayback/tests/test_client.py | 209 +++++++++++-----
 4 files changed, 571 insertions(+), 140 deletions(-)

diff --git a/wayback/__init__.py b/wayback/__init__.py
index 2761e2a..b25f60d 100644
--- a/wayback/__init__.py
+++ b/wayback/__init__.py
@@ -2,6 +2,10 @@
 __version__ = get_versions()['version']
 del get_versions
 
+# XXX: Just for testing! Must remove before merge.
+import logging  # noqa
+logging.getLogger("urllib3").setLevel(logging.DEBUG)
+
 
 from ._utils import memento_url_data, RateLimit  # noqa
 from ._models import (  # noqa
diff --git a/wayback/_client.py b/wayback/_client.py
index df8cf7c..32c1680 100644
--- a/wayback/_client.py
+++ b/wayback/_client.py
@@ -27,10 +27,15 @@
 #                                 RetryError,
 #                                 Timeout)
 import time
+from typing import Generator, Optional
 from urllib.parse import urljoin, urlparse
+from urllib3 import PoolManager, HTTPResponse, Timeout as Urllib3Timeout
 from urllib3.connectionpool import HTTPConnectionPool
-from urllib3.exceptions import (ConnectTimeoutError,
+from urllib3.exceptions import (ClosedPoolError,
+                                ConnectTimeoutError,
+                                DecodeError,
                                 MaxRetryError,
+                                ProtocolError,
                                 ReadTimeoutError,
                                 ProxyError,
                                 TimeoutError,
@@ -45,7 +50,8 @@
                          MementoPlaybackError,
                          NoMementoError,
                          WaybackRetryError,
-                         RateLimitError)
+                         RateLimitError,
+                         SessionClosedError)
 
 logger = logging.getLogger(__name__)
 
@@ -161,20 +167,6 @@ def cdx_hash(content):
     return b32encode(hashlib.sha1(content).digest()).decode()
 
 
-# XXX: see how requests reads the body:
-# https://github.com/psf/requests/blob/a25fde6989f8df5c3d823bc9f2e2fc24aa71f375/src/requests/models.py#L794-L839
-def read_and_close(response):
-    # Read content so it gets cached and close the response so
-    # we can release the connection for reuse. See:
-    # https://github.com/psf/requests/blob/eedd67462819f8dbf8c1c32e77f9070606605231/requests/sessions.py#L160-L163
-    try:
-        response.content
-    except (ChunkedEncodingError, ContentDecodingError, RuntimeError):
-        response.raw.read(decode_content=False)
-    finally:
-        response.close()
-
-
 REDIRECT_PAGE_PATTERN = re.compile(r'Got an? HTTP 3\d\d response at crawl time', re.IGNORECASE)
 
 
@@ -341,7 +333,311 @@ def _new_header_init(self, headers=None, **kwargs):
 #####################################################################
 
 
-class WaybackSession(_utils.DisableAfterCloseSession):
+def iter_byte_slices(data: bytes, size: int) -> Generator[bytes, None, None]:
+    """
+    Iterate over groups of N bytes from some original bytes. In Python 3.12+,
+    this can be done with ``itertools.batched()``.
+    """
+    index = 0
+    if size <= 0:
+        size = len(data)
+    while index < len(data):
+        yield data[index:index + size]
+        index += size
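+
+
+# As an illustration (not part of this patch's behavior), the slicing above
+# works like these hypothetical examples:
+#
+#     >>> list(iter_byte_slices(b'abcdefg', 3))
+#     [b'abc', b'def', b'g']
+#     >>> list(iter_byte_slices(b'abcdefg', 0))  # size <= 0 yields one chunk
+#     [b'abcdefg']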
+
+
+# XXX: pretty much wholesale taken from requests. May need adjustment.
+def parse_header_links(value):
+    """Return a list of parsed link headers.
+
+    i.e. Link: <http:/.../front.jpeg>; rel=front; type="image/jpeg",<http://.../back.jpeg>; rel=back;type="image/jpeg"
+
+    :rtype: list
+    """
+
+    links = []
+
+    replace_chars = " '\""
+
+    value = value.strip(replace_chars)
+    if not value:
+        return links
+
+    for val in re.split(", *<", value):
+        try:
+            url, params = val.split(";", 1)
+        except ValueError:
+            url, params = val, ""
+
+        link = {"url": url.strip("<> '\"")}
+
+        for param in params.split(";"):
+            try:
+                key, value = param.split("=")
+            except ValueError:
+                break
+
+            link[key.strip(replace_chars)] = value.strip(replace_chars)
+
+        links.append(link)
+
+    return links
+
+
+from urllib.parse import urlencode
+# XXX: pretty much wholesale taken from requests. May need adjustment.
+# https://github.com/psf/requests/blob/147c8511ddbfa5e8f71bbf5c18ede0c4ceb3bba4/requests/models.py#L107-L134
+def serialize_querystring(data):
+    """Encode parameters in a piece of data.
+
+    Will successfully encode parameters when passed as a dict or a list of
+    2-tuples. Order is retained if data is a list of 2-tuples, but arbitrary
+    if parameters are supplied as a dict.
+    """
+    if data is None:
+        return None
+    if isinstance(data, (str, bytes)):
+        return data
+    elif hasattr(data, "read"):
+        return data
+    elif hasattr(data, "__iter__"):
+        result = []
+        # NOTE: requests uses `to_key_val_list()` here; calling `.items()`
+        # directly would break the documented list-of-2-tuples case, so fall
+        # back to iterating the data itself when there is no `.items()`.
+        items = data.items() if hasattr(data, "items") else data
+        for k, vs in items:
+            if isinstance(vs, str) or not hasattr(vs, "__iter__"):
+                vs = [vs]
+            for v in vs:
+                if v is not None:
+                    result.append(
+                        (
+                            k.encode("utf-8") if isinstance(k, str) else k,
+                            v.encode("utf-8") if isinstance(v, str) else v,
+                        )
+                    )
+        return urlencode(result, doseq=True)
+    else:
+        return data
+
+
+# XXX: pretty much wholesale taken from requests. May need adjustment.
+# We have some similar code in `test/support.py`, and we should probably figure
+# out how to merge these.
+def _parse_content_type_header(header):
+    """Returns content type and parameters from given header
+
+    :param header: string
+    :return: tuple containing content type and dictionary of
+        parameters
+    """
+
+    tokens = header.split(";")
+    content_type, params = tokens[0].strip(), tokens[1:]
+    params_dict = {}
+    items_to_strip = "\"' "
+
+    for param in params:
+        param = param.strip()
+        if param:
+            key, value = param, True
+            index_of_equals = param.find("=")
+            if index_of_equals != -1:
+                key = param[:index_of_equals].strip(items_to_strip)
+                value = param[index_of_equals + 1:].strip(items_to_strip)
+            params_dict[key.lower()] = value
+    return content_type, params_dict
+
+
+# XXX: pretty much wholesale taken from requests. May need adjustment.
+def get_encoding_from_headers(headers):
+    """Returns encodings from given HTTP Header Dict.
+
+    :param headers: dictionary to extract encoding from.
+    :rtype: str
+    """
+
+    content_type = headers.get("content-type")
+
+    if not content_type:
+        return None
+
+    content_type, params = _parse_content_type_header(content_type)
+
+    if "charset" in params:
+        return params["charset"].strip("'\"")
+
+    # XXX: Browsers today actually use Windows-1252 as the standard default
+    # (some TLDs have a different default), per WHATWG.
+    # ISO-8859-1 comes from requests; maybe we should change it? It makes sense
+    # for us to generally act more like a browser than a generic HTTP tool, but
+    # it's also probably not a big deal.
+    if "text" in content_type:
+        return "ISO-8859-1"
+
+    if "application/json" in content_type:
+        # Assume UTF-8 based on RFC 4627: https://www.ietf.org/rfc/rfc4627.txt
+        # since the charset was unset.
+        return "utf-8"
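+
+
+# A quick sketch of how the two helpers above combine (illustrative only;
+# values assumed, not an API guarantee):
+#
+#     >>> _parse_content_type_header('text/html; charset=UTF-8')
+#     ('text/html', {'charset': 'UTF-8'})
+#     >>> get_encoding_from_headers({'content-type': 'text/html; charset=UTF-8'})
+#     'UTF-8'
+#     >>> get_encoding_from_headers({'content-type': 'text/html'})
+#     'ISO-8859-1'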
+
+
+# XXX: Everything that lazily calculates an underscore-prefixed property here
+# needs a lock, or needs to precalculate its value in the constructor or some
+# sort of builder function.
+class InternalHttpResponse:
+    """
+    Internal wrapper class for HTTP responses. THIS SHOULD NEVER BE EXPOSED TO
+    USER CODE. This makes some things from urllib3 a little easier to deal with,
+    like parsing special headers, caching body content, etc.
+
+    This is *similar* to response objects from httpx and requests, although it
+    lacks facilities from those libraries that we don't need or use, and takes
+    shortcuts that are specific to our use cases.
+    """
+    raw: HTTPResponse
+    status_code: int
+    headers: Urllib3HTTPHeaderDict
+    encoding: Optional[str] = None
+    url: str
+    _content: Optional[bytes] = None
+    _text: Optional[str] = None
+    _redirect_url: Optional[str] = None
+
+    def __init__(self, raw: HTTPResponse, request_url: str) -> None:
+        self.raw = raw
+        self.status_code = raw.status
+        self.headers = raw.headers
+        self.url = getattr(raw, 'url', request_url)
+        self.encoding = get_encoding_from_headers(self.headers)
+
+    # XXX: shortcut to essentially what requests does in `iter_content()`.
+    # Requests has a weird thing where it uses `raw.stream()` if present, but
+    # always passes `decode_content=True` to it when it does the opposite for
+    # `raw.read()` (when `stream()` is not present). This is confusing!
+    # https://github.com/psf/requests/blob/147c8511ddbfa5e8f71bbf5c18ede0c4ceb3bba4/requests/models.py#L812-L833
+    #
+    # - `stream()` has been around since urllib3 v1.10.3 (released 2015-04-21).
+    #   Seems like you could just depend on it being there. Two theories:
+    #   A) requests just has a lot of old code hanging around, or
+    #   B) VCR or some mocking libraries don't implement `stream()`, and just
+    #      give requests a file-like.
+    #   If (B), we ought to see problems in tests.
+    #
+    # - Looking at urllib3, `stream()` should just call `read()`, so I wouldn't
+    #   think you'd want to pass a different value for `decode_content`!
+    #   https://github.com/urllib3/urllib3/blob/90c30f5fdca56a54248614dc86570bf2692a4caa/src/urllib3/response.py#L1001-L1026
+    #   Theory: this is actually about compression (via the content-encoding
+    #   header), not text encoding. The differing values still seem like a bug,
+    #   but assuming we always wind up using `stream()`, then it makes sense
+    #   to always set this to `True` (always decompress).
+    def stream(self, chunk_size: int = 10 * 1024) -> Generator[bytes, None, None]:
+        # If content was preloaded, it'll be in `._body`, but some mocking
+        # tools are missing the attribute altogether.
+        body = getattr(self.raw, '_body', None)
+        if body:
+            yield from iter_byte_slices(body, chunk_size)
+        else:
+            yield from self.raw.stream(chunk_size, decode_content=True)
+        self._release_conn()
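+
+    # How the lazy body caching below is expected to behave (sketch, assuming
+    # a response whose body is b'abc'):
+    #
+    #     resp.stream()  -> yields b'abc', then releases the connection
+    #     resp.content   -> b'abc' (joins stream() once, then caches)
+    #     resp.content   -> b'abc' (cached; no further reads from the socket)
+    #     resp.text      -> 'abc'  (decoded with `encoding` or a UTF-8 fallback)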
+
+    @property
+    def content(self) -> bytes:
+        if self._content is None:
+            logger.warning('Getting content!!!')
+            self._content = b"".join(self.stream()) or b""
+            logger.warning(f'Getting content DONE: "{self._content}"')
+
+        return self._content
+
+    @property
+    def text(self) -> str:
+        if self._text is None:
+            encoding = self.encoding or self.sniff_encoding() or 'utf-8'
+            try:
+                self._text = str(self.content, encoding, errors="replace")
+            except (LookupError, TypeError):
+                self._text = str(self.content, errors="replace")
+
+        return self._text
+
+    def sniff_encoding(self) -> Optional[str]:
+        # XXX: requests uses chardet here. Consider what we want to use.
+        ...
+
+    @property
+    def links(self) -> dict:
+        """Returns the parsed header links of the response, if any."""
+
+        header = self.headers.get("link")
+
+        resolved_links = {}
+
+        if header:
+            links = parse_header_links(header)
+
+            for link in links:
+                key = link.get("rel") or link.get("url")
+                resolved_links[key] = link
+
+        return resolved_links
+
+    @property
+    def redirect_url(self) -> str:
+        """
+        The URL this response redirects to. If the response is not a redirect,
+        this returns an empty string.
+        """
+        if self._redirect_url is None:
+            url = ''
+            if self.status_code >= 300 and self.status_code < 400:
+                location = self.headers.get('location')
+                if location:
+                    url = urljoin(self.url, location)
+            self._redirect_url = url
+        return self._redirect_url
+
+    @property
+    def ok(self) -> bool:
+        return self.status_code >= 200 and self.status_code < 300
+
+    # XXX: This and _release_conn probably need wrapping with RLock!
+    def close(self, cache: bool = True) -> None:
+        """
+        Read the rest of the response off the wire and release the connection.
+        If the full response is not read, the connection can hang and your
+        program will leak memory (and cause a bad time for the server as well).
+
+        Parameters
+        ----------
+        cache : bool, default: True
+            Whether to cache the response body so it can still be used via the
+            ``content`` and ``text`` properties.
+        """
+        if self.raw:
+            try:
+                if cache:
+                    # Inspired by requests:
+                    # https://github.com/psf/requests/blob/eedd67462819f8dbf8c1c32e77f9070606605231/requests/sessions.py#L160-L163
+                    try:
+                        self.content
+                    except (DecodeError, ProtocolError, RuntimeError):
+                        self.raw.drain_conn()
+                else:
+                    self.raw.drain_conn()
+            finally:
+                self._release_conn()
+
+    def _release_conn(self) -> None:
+        "Release the connection. Make sure to drain it first!"
+        if self.raw:
+            # Some mocks (e.g. VCR) are missing `.release_conn`
+            release_conn = getattr(self.raw, 'release_conn', None)
+            if release_conn is None:
+                # self.raw.close()
+                ...
+            else:
+                release_conn()
+            # Let go of the raw urllib3 response so we can't accidentally read
+            # it later when its connection might be re-used.
+            self.raw = None
+
+
+class WaybackSession:
     """
     Manages HTTP requests to Wayback Machine servers, handling things like
     retries, rate limiting, connection pooling, timeouts, etc.
@@ -436,6 +732,25 @@ def __init__(self, retries=6, backoff=2, timeout=60, user_agent=None,
             # The memento limit is actually a generic Wayback limit.
             '/': _utils.RateLimit.make_limit(memento_calls_per_second),
         }
+        # XXX: These parameters are the same as requests, but we have had at
+        # least one user reach in and change the adapters we used with requests
+        # to modify these. We should consider whether different values are
+        # appropriate (e.g. block=True) or if these need to be exposed somehow.
+        #
+        # XXX: Consider using a HTTPSConnectionPool instead of a PoolManager.
+        # We can make some code simpler if we are always assuming the same host.
+        # (At present, we only use one host, so this is feasible.)
+        #
+        # XXX: Do we need a cookie jar? urllib3 doesn't do any cookie management
+        # for us, and the Wayback Machine may set some cookies we should retain
+        # in subsequent requests. (In practice, it doesn't appear the CDX,
+        # Memento, or Timemap APIs do by default, but not sure what happens if
+        # you send S3-style credentials or use other endpoints.)
+        self._pool_manager = PoolManager(
+            num_pools=10,
+            maxsize=10,
+            block=False,
+        )
         # NOTE: the nice way to accomplish retry/backoff is with a urllib3:
         #     adapter = requests.adapters.HTTPAdapter(
         #         max_retries=Retry(total=5, backoff_factor=2,
         # with Wayback's APIs, but urllib3 logs a warning on every retry:
         # https://github.com/urllib3/urllib3/blob/5b047b645f5f93900d5e2fc31230848c25eb1f5f/src/urllib3/connectionpool.py#L730-L737
 
-    # Customize the built-in `send()` with retryability and rate limiting.
-    def send(self, request: requests.PreparedRequest, **kwargs):
+    def send(self, method, url, *, params=None, allow_redirects=True, timeout=-1) -> InternalHttpResponse:
+        if not self._pool_manager:
+            raise SessionClosedError('This session has already been closed '
+                                     'and cannot send new HTTP requests.')
+
         start_time = time.time()
         maximum = self.retries
         retries = 0
 
-        url = urlparse(request.url)
+        timeout = self.timeout if timeout == -1 else timeout
+        # XXX: grabbed from requests. Clean up for our use case.
+        if isinstance(timeout, tuple):
+            try:
+                connect, read = timeout
+                timeout = Urllib3Timeout(connect=connect, read=read)
+            except ValueError:
+                raise ValueError(
+                    f"Invalid timeout {timeout}. Pass a (connect, read) timeout tuple, "
+                    f"or a single float to set both timeouts to the same value."
+                )
+        elif isinstance(timeout, Urllib3Timeout):
+            pass
+        else:
+            timeout = Urllib3Timeout(connect=timeout, read=timeout)
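+
+        # For illustration, the normalization above maps (values assumed):
+        #     timeout=-1        -> fall back to self.timeout, then normalize
+        #     timeout=(3.5, 30) -> Urllib3Timeout(connect=3.5, read=30)
+        #     timeout=10        -> Urllib3Timeout(connect=10, read=10)
+        #     timeout=None      -> Urllib3Timeout(connect=None, read=None)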
+
+        parsed = urlparse(url)
         for path, limit in self.rate_limts.items():
-            if url.path.startswith(path):
+            if parsed.path.startswith(path):
                 rate_limit = limit
                 break
         else:
             rate_limit = DEFAULT_MEMENTO_RATE_LIMIT
 
+        # Do our own querystring work since urllib3 serializes lists poorly.
+        if params:
+            serialized = serialize_querystring(params)
+            if parsed.query:
+                url += f'&{serialized}'
+            else:
+                url += f'?{serialized}'
+
         while True:
             retry_delay = 0
             try:
-                logger.debug('sending HTTP request %s "%s", %s', request.method, request.url, kwargs)
+                # XXX: should be `debug()`. Set to warning for testing.
+                logger.warning('sending HTTP request %s "%s", %s', method, url, params)
                 rate_limit.wait()
-                response = super().send(request, **kwargs)
+                response = InternalHttpResponse(self._pool_manager.request(
+                    method=method,
+                    url=url,
+                    # fields=serialize_querystring(params),
+                    headers=self.headers,
+                    # XXX: is allow_redirects safe for preload_content == False?
+                    # XXX: it is, BUT THAT SKIPS OUR RATE LIMITING, which also
+                    # is obviously already a problem today, but we ought to get
+                    # it fixed now. Leaving this on for the moment, but it
+                    # must be addressed before merging.
+                    redirect=allow_redirects,
+                    preload_content=False,
+                    timeout=timeout
+                ), url)
+
                 retry_delay = self.get_retry_delay(retries, response)
 
                 if retries >= maximum or not self.should_retry(response):
                     if response.status_code == 429:
-                        read_and_close(response)
+                        response.close()
                         raise RateLimitError(response, retry_delay)
                     return response
                 else:
                     logger.debug('Received error response (status: %s), will retry',
                                  response.status_code)
-                    read_and_close(response)
+                    response.close(cache=False)
+            # XXX: urllib3's MaxRetryError can wrap all the other errors, so
+            # we should probably be checking `error.reason` on it. See how
+            # requests handles this: https://github.com/psf/requests/blob/a25fde6989f8df5c3d823bc9f2e2fc24aa71f375/src/requests/adapters.py#L502-L537
             except WaybackSession.handleable_errors as error:
                 response = getattr(error, 'response', None)
                 if response is not None:
-                    read_and_close(response)
+                    response.close()
 
                 if retries >= maximum:
                     raise WaybackRetryError(retries, time.time() - start_time,
                                             error) from error
@@ -511,20 +868,13 @@
     # We can't do this in `send` because `request` always passes a `timeout`
     # keyword to `send`. Inside `send`, we can't tell the difference between a
    # user explicitly requesting no timeout and not setting one at all.
-    def request(self, method, url, **kwargs):
+    def request(self, method, url, *, params=None, allow_redirects=True, timeout=-1) -> InternalHttpResponse:
         """
-        Perform an HTTP request using this session. For arguments and return
-        values, see:
-        https://requests.readthedocs.io/en/latest/api/#requests.Session.request
-
-        If the ``timeout`` keyword argument is not set, it will default to the
-        session's ``timeout`` attribute.
+        Perform an HTTP request using this session.
         """
-        if 'timeout' not in kwargs:
-            kwargs['timeout'] = self.timeout
-        return super().request(method, url, **kwargs)
+        return self.send(method, url, params=params, allow_redirects=allow_redirects, timeout=timeout)
 
-    def should_retry(self, response):
+    def should_retry(self, response: InternalHttpResponse):
         # A memento may actually be a capture of an error, so don't retry it :P
         if is_memento_response(response):
             return False
@@ -554,7 +904,7 @@ def should_retry_error(self, error):
 
         return False
 
-    def get_retry_delay(self, retries, response=None):
+    def get_retry_delay(self, retries, response: InternalHttpResponse = None):
         delay = 0
 
         # As of 2023-11-27, the Wayback Machine does not set a `Retry-After`
@@ -577,14 +927,12 @@
    #             proxy.clear()
     def reset(self):
         "Reset any network connections the session is using."
-        # Close really just closes all the adapters in `self.adapters`. We
-        # could do that directly, but `self.adapters` is not documented/public,
-        # so might be somewhat risky.
-        self.close(disable=False)
-        # Re-build the standard adapters. See:
-        # https://github.com/kennethreitz/requests/blob/v2.22.0/requests/sessions.py#L415-L418
-        self.mount('https://', requests.adapters.HTTPAdapter())
-        self.mount('http://', requests.adapters.HTTPAdapter())
+        self._pool_manager.clear()
+
+    def close(self) -> None:
+        if self._pool_manager:
+            self._pool_manager.clear()
+            self._pool_manager = None
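+
+    # Sketch of the intended lifecycle (illustrative; `close()` is one-way,
+    # unlike the old DisableAfterCloseSession behavior):
+    #
+    #     session = WaybackSession()
+    #     session.request('GET', 'https://web.archive.org/...')
+    #     session.close()
+    #     session.request(...)  # now raises SessionClosedError in send()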
 
 
 # TODO: add retry, backoff, cross_thread_backoff, and rate_limit options that
@@ -603,7 +951,7 @@ class WaybackClient(_utils.DepthCountedContext):
     session : WaybackSession, optional
     """
     def __init__(self, session=None):
-        self.session = session or WaybackSession()
+        self.session: WaybackSession = session or WaybackSession()
 
     def __exit_all__(self, type, value, traceback):
         self.close()
@@ -830,17 +1178,18 @@ def search(self, url, *, match_type=None, limit=1000, offset=None,
                 # to worry about it. If we don't raise here, we still want to
                 # close the connection so it doesn't leak when we move onto
                 # the next page of results or when this iterator ends.
-                read_and_close(response)
+                response.close()
 
-                if response.status >= 400:
+                if response.status_code >= 400:
                     if 'AdministrativeAccessControlException' in response.text:
                         raise BlockedSiteError(query['url'])
                     elif 'RobotAccessControlException' in response.text:
                         raise BlockedByRobotsError(query['url'])
                     else:
-                        raise WaybackException(f'HTTP {response.status} error for CDX search: "{query}"')
+                        raise WaybackException(f'HTTP {response.status_code} error for CDX search: "{query}"')
 
                 lines = iter(response.content.splitlines())
+                logger.warning(f'Unparsed CDX lines: {response.content.splitlines()}')
 
                 for line in lines:
                     text = line.decode()
@@ -1060,8 +1409,10 @@ def get_memento(self, url, timestamp=None, mode=Mode.original, *,
             redirect_url = detect_view_mode_redirect(response, current_date)
             if redirect_url:
                 # Fix up response properties to be like other modes.
-                redirect = requests.Request('GET', redirect_url)
-                response._next = self.session.prepare_request(redirect)
+                # redirect = requests.Request('GET', redirect_url)
+                # response._next = self.session.prepare_request(redirect)
+                # XXX: make this publicly settable?
+                response._redirect_url = redirect_url
                 response.headers['Memento-Datetime'] = current_date.strftime(
                     '%a, %d %b %Y %H:%M:%S %Z'
                 )
@@ -1100,11 +1451,11 @@ def get_memento(self, url, timestamp=None, mode=Mode.original, *,
                 # rarely have been captured at the same time as the
                 # redirect itself. (See 2b)
                 playable = False
-                if response.next and (
+                if response.redirect_url and (
                     (len(history) == 0 and not exact)
                     or (len(history) > 0 and (previous_was_memento or not exact_redirects))
                 ):
-                    target_url, target_date, _ = _utils.memento_url_data(response.next.url)
+                    target_url, target_date, _ = _utils.memento_url_data(response.redirect_url)
                     # A non-memento redirect is generally taking us to the
                     # closest-in-time capture of the same URL. Note that is
                     # NOT the next capture -- i.e. the one that would have
@@ -1133,7 +1484,7 @@
                 playable = True
 
             if not playable:
-                read_and_close(response)
+                response.close()
                 message = response.headers.get('X-Archive-Wayback-Runtime-Error', '')
                 if (
                     ('AdministrativeAccessControlException' in message) or
@@ -1162,13 +1513,13 @@
                     raise MementoPlaybackError(f'{response.status_code} error while loading '
                                                f'memento at {url}')
 
-            if response.next:
+            if response.redirect_url:
                 previous_was_memento = is_memento
-                read_and_close(response)
+                response.close()
 
                 # Wayback sometimes has circular memento redirects ¯\_(ツ)_/¯
                 urls.add(response.url)
-                if response.next.url in urls:
+                if response.redirect_url in urls:
                     raise MementoPlaybackError(f'Memento at {url} is circular')
 
                 # All requests are included in `debug_history`, but
@@ -1176,7 +1527,7 @@
                 debug_history.append(response.url)
                 if is_memento:
                     history.append(memento)
-                response = self.session.send(response.next, allow_redirects=False)
+                response = self.session.request('GET', response.redirect_url, allow_redirects=False)
 
             else:
                 break
diff --git a/wayback/_utils.py b/wayback/_utils.py
index aab15e8..183a618 100644
--- a/wayback/_utils.py
+++ b/wayback/_utils.py
@@ -8,7 +8,6 @@
 import time
 from typing import Union
 import urllib.parse
-from .exceptions import SessionClosedError
 
 logger = logging.getLogger(__name__)
 
@@ -268,6 +267,10 @@ def wait(self) -> None:
 
         self._last_call_time = time.time()
 
+    def reset(self) -> None:
+        with self._lock:
+            self._last_call_time = 0
+
     def __enter__(self) -> None:
         self.wait()
 
@@ -320,30 +323,6 @@ def __exit_all__(self, type, value, traceback):
         pass
 
 
-class DisableAfterCloseSession:
-    """
-    A custom session object raises a :class:`SessionClosedError` if you try to
-    use it after closing it, to help identify and avoid potentially dangerous
-    code patterns. (Standard session objects continue to be usable after
-    closing, even if they may not work exactly as expected.)
-    """
-    _closed: bool = False
-
-    def close(self, disable: bool = True) -> None:
-        super().close()
-        if disable:
-            self._closed = True
-
-    # XXX: this no longer works correctly, we probably need some sort of
-    # decorator or something
-    def send(self, *args, **kwargs):
-        if self._closed:
-            raise SessionClosedError('This session has already been closed '
-                                     'and cannot send new HTTP requests.')
-
-        return super().send(*args, **kwargs)
-
-
 class CaseInsensitiveDict(MutableMapping):
     """
     A case-insensitive ``dict`` subclass.
diff --git a/wayback/tests/test_client.py b/wayback/tests/test_client.py
index f236a76..644eef8 100644
--- a/wayback/tests/test_client.py
+++ b/wayback/tests/test_client.py
@@ -3,18 +3,20 @@
 from pathlib import Path
 import time
 import pytest
-import requests
 from unittest import mock
 from .support import create_vcr
-from .._utils import SessionClosedError
 from .._client import (CdxRecord,
                        Mode,
                        WaybackSession,
-                       WaybackClient)
+                       WaybackClient,
+                       DEFAULT_CDX_RATE_LIMIT,
+                       DEFAULT_MEMENTO_RATE_LIMIT,
+                       DEFAULT_TIMEMAP_RATE_LIMIT)
 from ..exceptions import (BlockedSiteError,
                           MementoPlaybackError,
                           NoMementoError,
-                          RateLimitError)
+                          RateLimitError,
+                          SessionClosedError)
 
 ia_vcr = create_vcr()
 
@@ -51,6 +53,13 @@ def get_file(filepath):
         return file.read()
 
 
+@pytest.fixture(autouse=True)
+def reset_default_rate_limits():
+    DEFAULT_CDX_RATE_LIMIT.reset()
+    DEFAULT_MEMENTO_RATE_LIMIT.reset()
+    DEFAULT_TIMEMAP_RATE_LIMIT.reset()
+
+
 @ia_vcr.use_cassette()
 def test_search():
     with WaybackClient() as client:
@@ -208,7 +217,84 @@ def test_search_with_filter_tuple():
         assert all(('feature' in v.url for v in versions))
 
 
-def test_search_removes_malformed_entries(requests_mock):
+from io import BytesIO
+from urllib.parse import urlparse, ParseResult, parse_qs
+from urllib3 import HTTPConnectionPool, HTTPResponse, HTTPHeaderDict
+import logging
+
+
+class Urllib3MockManager:
+    def __init__(self) -> None:
+        self.responses = []
+
+    def get(self, url, responses) -> None:
+        url_info = urlparse(url)
+        if url_info.path == '':
+            url_info = url_info._replace(path='/')
+        for index, response in enumerate(responses):
+            repeat = index == len(responses) - 1
+            self.responses.append(('GET', url_info, response, repeat))
+
+    def _compare_querystrings(self, actual, candidate):
+        for k, v in candidate.items():
+            if k not in actual or actual[k] != v:
+                return False
+        return True
+
+    def urlopen(self, pool: HTTPConnectionPool, method, url, *args, preload_content: bool = True, **kwargs):
+        opened_url = urlparse(url)
+        opened_path = opened_url.path or '/'
+        opened_query = parse_qs(opened_url.query)
+        for index, candidate in enumerate(self.responses):
+            candidate_url: ParseResult = candidate[1]
+            if (
+                method == candidate[0]
+                and (not candidate_url.scheme or candidate_url.scheme == pool.scheme)
+                and (not candidate_url.hostname or candidate_url.hostname == pool.host)
+                and (not candidate_url.port or candidate_url.port == pool.port)
+                and candidate_url.path == opened_path
+                # This is cheap; ideally we'd parse the querystrings.
+                # and parse_qs(candidate_url.query) == opened_query
+                and self._compare_querystrings(opened_query, parse_qs(candidate_url.query))
+            ):
+                if not candidate[3]:
+                    self.responses.pop(index)
+
+                data = candidate[2]
+                if data.get('exc'):
+                    raise data['exc']()
+
+                content = data.get('content')
+                if content is None:
+                    content = data.get('text', '').encode()
+
+                return HTTPResponse(
+                    body=BytesIO(content),
+                    headers=HTTPHeaderDict(data.get('headers', {})),
+                    status=data.get('status_code', 200),
+                    decode_content=False,
+                    preload_content=preload_content,
+                )
+
+        # No matches!
+        raise RuntimeError(
+            f"No HTTP mocks matched {method} {pool.scheme}://{pool.host}{url}"
+        )
+
+
+@pytest.fixture
+def urllib3_mock(monkeypatch):
+    manager = Urllib3MockManager()
+
+    def urlopen_mock(self, method, url, *args, preload_content: bool = True, **kwargs):
+        return manager.urlopen(self, method, url, *args, preload_content=preload_content, **kwargs)
+
+    monkeypatch.setattr(
+        "urllib3.connectionpool.HTTPConnectionPool.urlopen", urlopen_mock
+    )
+
+    return manager
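+
+
+# Illustrative use of the fixture above (hypothetical URL and body):
+#
+#     def test_something(urllib3_mock):
+#         urllib3_mock.get('http://test.com', [{'status_code': 200, 'text': 'ok'}])
+#         session = WaybackSession()
+#         assert session.request('GET', 'http://test.com').text == 'ok'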
+
+
+def test_search_removes_malformed_entries(urllib3_mock):
     """
     The CDX index contains many lines for things that can't actually be
     archived and will have no corresponding memento, like `mailto:` and `data:`
@@ -223,11 +309,12 @@
         bad_cdx_data = f.read()
 
     with WaybackClient() as client:
-        requests_mock.get('https://web.archive.org/cdx/search/cdx'
-                          '?url=https%3A%2F%2Fepa.gov%2F%2A'
-                          '&from=20200418000000&to=20200419000000'
-                          '&showResumeKey=true&resolveRevisits=true',
-                          [{'status_code': 200, 'text': bad_cdx_data}])
+        urllib3_mock.get('https://web.archive.org/cdx/search/cdx'
+                         '?url=https%3A%2F%2Fepa.gov%2F%2A'
+                         '&limit=1000'
+                         '&from=20200418000000&to=20200419000000'
+                         '&showResumeKey=true&resolveRevisits=true',
+                         [{'status_code': 200, 'text': bad_cdx_data}])
         records = client.search('https://epa.gov/*',
                                 from_date=datetime(2020, 4, 18),
                                 to_date=datetime(2020, 4, 19))
 
         assert 2 == len(list(records))
 
 
-def test_search_handles_no_length_cdx_records(requests_mock):
+def test_search_handles_no_length_cdx_records(urllib3_mock):
     """
     The CDX index can contain a "-" in lieu of an actual length, which can't
     be parsed into an int. We should handle this.
@@ -247,11 +334,11 @@
         bad_cdx_data = f.read()
 
     with WaybackClient() as client:
-        requests_mock.get('https://web.archive.org/cdx/search/cdx'
-                          '?url=www.cnn.com%2F%2A'
-                          '&matchType=domain&filter=statuscode%3A200'
-                          '&showResumeKey=true&resolveRevisits=true',
-                          [{'status_code': 200, 'text': bad_cdx_data}])
+        urllib3_mock.get('https://web.archive.org/cdx/search/cdx'
+                         '?url=www.cnn.com%2F%2A'
+                         '&matchType=domain&filter=statuscode%3A200'
+                         '&showResumeKey=true&resolveRevisits=true',
+                         [{'status_code': 200, 'text': bad_cdx_data}])
         records = client.search('www.cnn.com/*',
                                 match_type="domain",
                                 filter_field="statuscode:200")
@@ -263,7 +350,7 @@
 
         record_list = list(records)
         assert record_list[-1].length is None
 
 
-def test_search_handles_bad_timestamp_cdx_records(requests_mock):
+def test_search_handles_bad_timestamp_cdx_records(urllib3_mock):
     """
     The CDX index can contain a timestamp with an invalid day "00", which
     can't be parsed into a timestamp. We should handle this.
@@ -275,11 +362,12 @@
         bad_cdx_data = f.read()
 
     with WaybackClient() as client:
-        requests_mock.get('https://web.archive.org/cdx/search/cdx'
-                          '?url=www.usatoday.com%2F%2A'
-                          '&matchType=domain&filter=statuscode%3A200'
-                          '&showResumeKey=true&resolveRevisits=true',
-                          [{'status_code': 200, 'text': bad_cdx_data}])
+        urllib3_mock.get('https://web.archive.org/cdx/search/cdx'
+                         '?url=www.usatoday.com%2F%2A'
+                         '&limit=1000'
+                         '&matchType=domain&filter=statuscode%3A200'
+                         '&showResumeKey=true&resolveRevisits=true',
+                         [{'status_code': 200, 'text': bad_cdx_data}])
         records = client.search('www.usatoday.com/*',
                                 match_type="domain",
                                 filter_field="statuscode:200")
@@ -671,96 +759,105 @@ def test_get_memento_returns_memento_with_accurate_url():
         assert memento.url == 'https://www.fws.gov/'
 
 
-def return_timeout(self, *args, **kwargs) -> requests.Response:
+def return_timeout(self, *args, **kwargs) -> HTTPResponse:
     """
-    Patch requests.Session.send with this in order to return a response with
-    the provided timeout value as the response body.
+    Patch urllib3.HTTPConnectionPool.urlopen with this in order to return a
+    response with the provided timeout value as the response body.
 
     Usage:
-    >>> @mock.patch('requests.Session.send', side_effect=return_timeout)
+    >>> @mock.patch('urllib3.HTTPConnectionPool.urlopen', side_effect=return_timeout)
     >>> def test_timeout(self, mock_class):
-    >>>     assert requests.get('http://test.com', timeout=5).text == '5'
+    >>>     assert urllib3.request('GET', 'http://test.com', timeout=5).data == b'5'
     """
-    res = requests.Response()
-    res.status_code = 200
-    res._content = str(kwargs.get('timeout', None)).encode()
+    logging.warning(f'Called with args={args}, kwargs={kwargs}')
+    res = HTTPResponse(
+        body=str(kwargs.get('timeout', None)).encode(),
+        headers=HTTPHeaderDict(),
+        status=200
+    )
     return res
 
 
+from urllib3 import Timeout as Urllib3Timeout
+
+
 class TestWaybackSession:
-    def test_request_retries(self, requests_mock):
-        requests_mock.get('http://test.com', [{'text': 'bad1', 'status_code': 503},
-                                              {'text': 'bad2', 'status_code': 503},
-                                              {'text': 'good', 'status_code': 200}])
+    def test_request_retries(self, urllib3_mock):
+        urllib3_mock.get('http://test.com', [{'text': 'bad1', 'status_code': 503},
+                                             {'text': 'bad2', 'status_code': 503},
+                                             {'text': 'good', 'status_code': 200}])
         session = WaybackSession(retries=2, backoff=0.1)
         response = session.request('GET', 'http://test.com')
         assert response.ok
 
         session.close()
 
-    def test_stops_after_given_retries(self, requests_mock):
-        requests_mock.get('http://test.com', [{'text': 'bad1', 'status_code': 503},
-                                              {'text': 'bad2', 'status_code': 503},
-                                              {'text': 'good', 'status_code': 200}])
+    def test_stops_after_given_retries(self, urllib3_mock):
+        urllib3_mock.get('http://test.com', [{'text': 'bad1', 'status_code': 503},
+                                             {'text': 'bad2', 'status_code': 503},
+                                             {'text': 'good', 'status_code': 200}])
         session = WaybackSession(retries=1, backoff=0.1)
         response = session.request('GET', 'http://test.com')
         assert response.status_code == 503
         assert response.text == 'bad2'
 
-    def test_only_retries_some_errors(self, requests_mock):
-        requests_mock.get('http://test.com', [{'text': 'bad1', 'status_code': 400},
-                                              {'text': 'good', 'status_code': 200}])
+    def test_only_retries_some_errors(self, urllib3_mock):
+        urllib3_mock.get('http://test.com', [{'text': 'bad1', 'status_code': 400},
+                                             {'text': 'good', 'status_code': 200}])
         session = WaybackSession(retries=1, backoff=0.1)
        response = session.request('GET',
'http://test.com') assert response.status_code == 400 - def test_raises_rate_limit_error(self, requests_mock): - requests_mock.get('http://test.com', [WAYBACK_RATE_LIMIT_ERROR]) + def test_raises_rate_limit_error(self, urllib3_mock): + urllib3_mock.get('http://test.com', [WAYBACK_RATE_LIMIT_ERROR]) with pytest.raises(RateLimitError): session = WaybackSession(retries=0) session.request('GET', 'http://test.com') - def test_rate_limit_error_includes_retry_after(self, requests_mock): - requests_mock.get('http://test.com', [WAYBACK_RATE_LIMIT_ERROR]) + def test_rate_limit_error_includes_retry_after(self, urllib3_mock): + urllib3_mock.get('http://test.com', [WAYBACK_RATE_LIMIT_ERROR]) with pytest.raises(RateLimitError) as excinfo: session = WaybackSession(retries=0) session.request('GET', 'http://test.com') assert excinfo.value.retry_after == 10 - @mock.patch('requests.Session.send', side_effect=return_timeout) + @mock.patch('urllib3.HTTPConnectionPool.urlopen', side_effect=return_timeout) def test_timeout_applied_session(self, mock_class): # Is the timeout applied through the WaybackSession session = WaybackSession(timeout=1) res = session.request('GET', 'http://test.com') - assert res.text == '1' + assert res.text == str(Urllib3Timeout(connect=1, read=1)) # Overwriting the default in the requests method res = session.request('GET', 'http://test.com', timeout=None) - assert res.text == 'None' + assert res.text == str(Urllib3Timeout(connect=None, read=None)) res = session.request('GET', 'http://test.com', timeout=2) - assert res.text == '2' + assert res.text == str(Urllib3Timeout(connect=2, read=2)) - @mock.patch('requests.Session.send', side_effect=return_timeout) + # XXX: We should probably change this test. What we really want to test is + # that the default when unspecified in both the session and the request + # is not None. + @mock.patch('urllib3.HTTPConnectionPool.urlopen', side_effect=return_timeout) def test_timeout_applied_request(self, mock_class): # Using the default value session = WaybackSession() res = session.request('GET', 'http://test.com') - assert res.text == '60' + assert res.text == str(Urllib3Timeout(connect=60, read=60)) # Overwriting the default res = session.request('GET', 'http://test.com', timeout=None) - assert res.text == 'None' + assert res.text == str(Urllib3Timeout(connect=None, read=None)) res = session.request('GET', 'http://test.com', timeout=2) - assert res.text == '2' + assert res.text == str(Urllib3Timeout(connect=2, read=2)) - @mock.patch('requests.Session.send', side_effect=return_timeout) + @mock.patch('urllib3.HTTPConnectionPool.urlopen', side_effect=return_timeout) def test_timeout_empty(self, mock_class): # Disabling default timeout session = WaybackSession(timeout=None) res = session.request('GET', 'http://test.com') - assert res.text == 'None' + assert res.text == str(Urllib3Timeout(connect=None, read=None)) # Overwriting the default res = session.request('GET', 'http://test.com', timeout=1) - assert res.text == '1' + assert res.text == str(Urllib3Timeout(connect=1, read=1)) @ia_vcr.use_cassette() def test_search_rate_limits(self):