From 3a07c98b106154784ac12125ccaf57d8fef167c8 Mon Sep 17 00:00:00 2001 From: Antonis Christofides Date: Tue, 4 Nov 2025 16:17:08 +0200 Subject: [PATCH 1/4] Add textbisect --- docs/index.rst | 3 +- docs/textbisect.rst | 37 ++++ pyproject.toml | 1 - src/textbisect/__init__.py | 84 +++++++++ tests/textbisect/__init__.py | 0 tests/textbisect/test_textbisect.py | 269 ++++++++++++++++++++++++++++ 6 files changed, 392 insertions(+), 2 deletions(-) create mode 100644 docs/textbisect.rst create mode 100644 src/textbisect/__init__.py create mode 100644 tests/textbisect/__init__.py create mode 100644 tests/textbisect/test_textbisect.py diff --git a/docs/index.rst b/docs/index.rst index 8725138..acf6135 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,9 +12,10 @@ pthelma - Utilities for hydrological and meteorological time series processing .. toctree:: - :maxdepth: 2 + :maxdepth: 1 htimeseries + textbisect .. toctree:: diff --git a/docs/textbisect.rst b/docs/textbisect.rst new file mode 100644 index 0000000..84cf046 --- /dev/null +++ b/docs/textbisect.rst @@ -0,0 +1,37 @@ +========================================= +textbisect - Binary search in a text file +========================================= + +This module provides functionality to search inside sorted text +files. The lines of the files need not be all of the same length. The +module contains the following functions: + +.. function:: text_bisect_left(a, x, lo=0, hi=None, key=lambda x: x) + + Locates the insertion point for line ``x`` in seekable filelike + object ``a`` consisting of a number of lines; ``x`` must be specified + without a trailing newline. ``a`` must use ``\n`` as the newline + character and must not perform any line endings translation (use + ``open(..., newline='\n')``). The parameters ``lo`` and ``hi``, if + specified, must be absolute positions within object ``a``, and + specify which part of ``a`` to search; the default is to search the + entire ``a``. The character pointed to by ``hi`` (or the last + character of the object, if ``hi`` is unspecified) must be a newline. + ``key`` is a function that is used to compare each line of ``a`` with + ``x``; line endings are removed from the lines of ``a`` before + comparison. ``a`` must be sorted or the result will be undefined. If + ``x`` compares equal to a line in ``a``, the returned insertion point + is the beginning of that line. The initial position of ``a`` is + discarded. The function returns the insertion point, which is an + integer between ``lo`` and ``hi+1``, pointing to the beginning of a + line; when it exits, ``a`` is positioned there. + +.. function:: text_bisect_right(a, x, lo=0, hi=None, key=lambda x: x) + + The same as :func:`text_bisect_left`, except that if ``x`` compares + equal to a line in ``a``, the returned insertion point is the + beginning of the next line. + +.. function:: text_bisect(a, x, lo=0, hi=None, key=lambda x: x) + + Same as :func:`text_bisect_right`. diff --git a/pyproject.toml b/pyproject.toml index d02b359..a9db66b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,6 @@ dependencies = [ # that it be <2. "numpy<2", "iso8601>=2.1,<3", - "textbisect>=0.1,<1", "tzdata", "Click>=7.0,<8.2", "simpletail>=1,<2", diff --git a/src/textbisect/__init__.py b/src/textbisect/__init__.py new file mode 100644 index 0000000..12f1b30 --- /dev/null +++ b/src/textbisect/__init__.py @@ -0,0 +1,84 @@ +from io import SEEK_END + + +class TextBisector: + def __init__(self, a, x, on_same, key): + self.a = a + self.x = x + self.on_same = on_same + self.key = key + self.ref = key(x) + + def get_hi(self, hi): + """Return hi, or the end position of the file if hi is None.""" + if hi is None: + self.a.seek(0, SEEK_END) + hi = self.a.tell() - 1 + return hi + + def get_beginning_of_line(self, pos, lo, hi): + """Return the beginning of the line containing the position pos. + On return the file is positioned at the return value.""" + while True: + if pos == lo: + return self.a.seek(pos) + self.a.seek(pos - 1) + char = self.a.read(1) + if char == "\n": + return pos + pos -= 1 + + def get_end_of_line(self, pos, lo, hi): + """Return the end of the line (the line feed) containing the position + pos. On return the file is positioned at the return value.""" + self.a.seek(pos) + while True: + if pos > hi: + raise EOFError("File must end in line feed") + char = self.a.read(1) + if char == "\n": + return self.a.seek(pos) + pos += 1 + + def get_line(self, pos, lo, hi): + """Return a tuple (line, start, end), where line is the line containing + the position pos (without the ending line feed), and start and end are + the positions of the start of the line and of the line feed. On return + the file is positioned at the beginning of the next line + (i.e. end + 1).""" + end = self.get_end_of_line(pos, lo, hi) + start = self.get_beginning_of_line(pos, lo, hi) + line = self.a.read(end - start) + return (line, start, end) + + def bisect(self, lo, hi): + # This recursive function ends when hi == lo - 1 + if hi == lo - 1: + self.a.seek(lo) + return lo + + # Otherwise, hi must not be less than lo + assert hi >= lo + + # Bisect the space and decide which way to continue + (line, start, end) = self.get_line((lo + hi) // 2, lo, hi) + val = self.key(line) + if (self.ref < val) or (self.ref == val and self.on_same == "left"): + return self.bisect(lo, start - 1) + else: + return self.bisect(end + 1, hi) + + +def text_bisect_left(a, x, lo=0, hi=None, key=lambda x: x): + bisector = TextBisector(a, x, "left", key) + hi = bisector.get_hi(hi) + return bisector.bisect(lo, hi) + + +def text_bisect_right(a, x, lo=0, hi=None, key=lambda x: x): + bisector = TextBisector(a, x, "right", key) + hi = bisector.get_hi(hi) + return bisector.bisect(lo, hi) + + +text_bisect = text_bisect_right diff --git a/tests/textbisect/__init__.py b/tests/textbisect/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/textbisect/test_textbisect.py b/tests/textbisect/test_textbisect.py new file mode 100644 index 0000000..3d0b3fc --- /dev/null +++ b/tests/textbisect/test_textbisect.py @@ -0,0 +1,269 @@ +import textwrap +from io import StringIO +from unittest import TestCase + +from textbisect import text_bisect, text_bisect_left, text_bisect_right + +testtext = textwrap.dedent( + """\ + alpha + bravo + charlie + delta + echo + foxtrot + golf + hotel + india + juillet + kilo + lima + mike + november + oscar + papa + quebec + romeo + sierra + tango + uniform + victor + whiskey + x-ray + yankee + zulu + """ +) + +# These are the positions in which each of the line of the above string start: +# +# alpha 0 +# bravo 6 +# charlie 12 +# delta 20 +# echo 26 +# foxtrot 31 +# golf 39 +# hotel 44 +# india 50 +# juillet 56 +# kilo 64 +# lima 69 +# mike 74 +# november 79 +# oscar 88 +# papa 94 +# quebec 99 +# romeo 106 +# sierra 112 +# tango 119 +# uniform 125 +# victor 133 +# whiskey 140 +# x-ray 148 +# yankee 154 +# zulu 161 +# 166 (length of file, or position of next character to be appended) + + +testtext2 = textwrap.dedent( + """\ + 1 + 003 + fivey + seven07 + ninenine9 + eleven00011 + """ +) + +# These are the positions in which each of the line of the above string start: +# +# 1 0 +# 003 2 +# fivey 6 +# seven07 12 +# ninenine9 20 +# eleven00011 30 +# 42 (length of file, or position of next character to be appended) + + +class TextBisectTestCaseBase(TestCase): + def _do_test(self, search_term, expected_result, direction="", lo=0, hi=None): + function = { + "left": text_bisect_left, + "right": text_bisect_right, + "": text_bisect, + }[direction] + pos = function(self.f, search_term, lo=lo, hi=hi, key=self.__class__.KEY) + self.assertEqual(pos, expected_result) + self.assertEqual(pos, self.f.tell()) + + +class TextBisectWithoutKeyTestCase(TextBisectTestCaseBase): + def KEY(x): + return x + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.f = StringIO(testtext) + + def test_beginning_of_file(self): + self._do_test("alice", 0) + + def test_end_of_file(self): + self._do_test("zuzu", 166) + + def test_somewhere_in_file(self): + self._do_test("somewhere", 119) + + def test_bisect_left(self): + self._do_test("lima", 69, direction="left") + + def test_bisect_right(self): + self._do_test("lima", 74, direction="right") + + def test_in_file_part_for_something_in_the_beginning_of_that_part(self): + self._do_test("bob", 106, lo=106) + + def test_when_file_part_starts_in_middle_of_line_and_result_starts_before(self): + self._do_test("bob", 108, lo=108) + + def test_when_file_part_starts_in_middle_of_line(self): + self._do_test("nick", 112, lo=108) + + def test_when_file_part_starts_in_end_of_line(self): + self._do_test("bob", 112, lo=111) + + def test_when_file_part_starts_at_end_of_file(self): + self._do_test("bob", 166, lo=166) + + def test_when_file_part_starts_at_last_character_of_file(self): + self._do_test("bob", 166, lo=165) + + def test_in_file_part_for_something_after_end_of_that_part(self): + self._do_test("nick", 74, hi=73) + + def test_when_file_part_ends_in_middle_of_line(self): + with self.assertRaises(EOFError): + self._do_test("nick", "irrelevant", hi=71) + + def test_searching_in_file_part_specified_by_both_lo_and_hi(self): + self._do_test("nick", 79, lo=64, hi=93) + + def test_beginning_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("george", 64, lo=64, hi=93) + + def test_end_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("tango", 94, lo=64, hi=93) + + def test_bisect_left_at_beginning_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("kilo", 64, direction="left", lo=64, hi=93) + + def test_bisect_right_at_beginning_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("kilo", 69, direction="right", lo=64, hi=93) + + def test_bisect_left_at_middle_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("oscar", 88, direction="left", lo=64, hi=93) + + def test_bisect_right_at_middle_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("oscar", 94, direction="right", lo=64, hi=93) + + def test_bisect_right_at_end_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("papa", 94, direction="right", lo=64, hi=93) + + +class TextBisectWithKeyTestCase(TextBisectTestCaseBase): + KEY = len + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.f = StringIO(testtext2) + + def test_beginning_of_file(self): + self._do_test("", 0) + + def test_end_of_file(self): + self._do_test("twelvetwelve", 42) + + def test_somewhere_in_file(self): + self._do_test("sixsix", 12) + + def test_bisect_left(self): + self._do_test("fiver", 6, direction="left") + + def test_bisect_right(self): + self._do_test("fiver", 12, direction="right") + + def test_in_file_part_for_something_in_the_beginning_of_that_part(self): + self._do_test("02", 12, lo=12) + + def test_when_file_part_starts_in_middle_of_line_and_result_starts_before(self): + self._do_test("02", 8, lo=8) + + def test_when_file_part_starts_in_middle_of_line(self): + self._do_test("four", 12, lo=8) + + def test_when_file_part_starts_in_end_of_line(self): + self._do_test("02", 6, lo=5) + + def test_when_file_part_starts_at_end_of_file(self): + self._do_test("any", 42, lo=41) + + def test_when_file_part_starts_at_last_character_of_file(self): + self._do_test("any", 42, lo=40) + + def test_in_file_part_for_something_after_end_of_that_part(self): + self._do_test("ten=ten=10", 12, hi=11) + + def test_when_file_part_ends_in_middle_of_line(self): + with self.assertRaises(EOFError): + self._do_test("eleven=0011", "irrelevant", hi=32) + + def test_beginning_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("four", 6, lo=6, hi=29) + + def test_searching_in_file_part_specified_by_both_lo_and_hi(self): + self._do_test("eight008", 20, lo=6, hi=29) + + def test_end_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("twelvetwelve", 30, lo=6, hi=29) + + def test_bisect_left_at_beginning_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("fiver", 6, direction="left", lo=6, hi=29) + + def test_bisect_right_at_beginning_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("fiver", 12, direction="right", lo=6, hi=29) + + def test_bisect_left_at_middle_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("nine=nine", 20, direction="left", lo=6, hi=29) + + def test_bisect_right_at_middle_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("nine=nine", 30, direction="right", lo=6, hi=29) + + def test_bisect_right_at_end_of_file_part_specified_by_both_lo_and_hi(self): + self._do_test("twelvetwelve", 30, direction="right", lo=6, hi=29) + + +class TextBisectOnlyOneLineTestCase(TextBisectTestCaseBase): + def KEY(x): + return x + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.f = StringIO("bravo\n") + + def test_before(self): + self._do_test("alpha", 0) + + def test_after(self): + self._do_test("charlie", 6) + + def test_left(self): + self._do_test("bravo", 0, direction="left") + + def test_right(self): + self._do_test("bravo", 6, direction="right") From 630a7863046bcd7b0bf5c438f1db9dd725a6db2b Mon Sep 17 00:00:00 2001 From: Antonis Christofides Date: Wed, 5 Nov 2025 13:48:59 +0200 Subject: [PATCH 2/4] Add type hinting (fixes DEV-127) --- .github/workflows/run-tests.yml | 5 +- MANIFEST.in | 2 + docs/enhydris_cache/api.rst | 7 +- pyproject.toml | 6 + src/enhydris_api_client/__init__.py | 113 ++++--- src/enhydris_api_client/py.typed | 0 src/enhydris_cache/cli.py | 77 +++-- src/enhydris_cache/enhydris_cache.py | 67 ++-- src/enhydris_cache/py.typed | 0 src/evaporation/cli.py | 7 +- src/evaporation/evaporation.py | 214 ++++++++----- src/haggregate/cli.py | 15 +- src/haggregate/haggregate.py | 87 ++++-- src/haggregate/py.typed | 0 src/haggregate/regularize.pyi | 14 + src/hspatial/cli.py | 30 +- src/hspatial/hspatial.py | 167 ++++++---- src/hspatial/test.py | 10 +- src/htimeseries/htimeseries.py | 289 +++++++++++------- src/htimeseries/timezone_utils.py | 13 +- src/rocc/__init__.py | 25 +- src/textbisect/__init__.py | 30 +- tests/enhydris_api_client/__init__.py | 60 ++-- tests/enhydris_api_client/test_e2e.py | 31 +- tests/enhydris_api_client/test_misc.py | 100 +++--- tests/enhydris_api_client/test_station.py | 130 ++++---- tests/enhydris_api_client/test_timeseries.py | 71 +++-- .../test_timeseriesgroup.py | 130 ++++---- tests/enhydris_api_client/test_tsdata.py | 66 ++-- tests/enhydris_cache/test_cli.py | 56 ++-- tests/enhydris_cache/test_enhydris_cache.py | 31 +- tests/evaporation/test_cli.py | 172 +++++++---- tests/evaporation/test_evaporation.py | 44 ++- tests/haggregate/test_cli.py | 80 +++-- tests/haggregate/test_haggregate.py | 34 ++- tests/haggregate/test_regularize.py | 26 +- tests/hspatial/test_cli.py | 75 ++--- tests/hspatial/test_hspatial.py | 169 +++++----- tests/htimeseries/test_htimeseries.py | 271 ++++++++-------- tests/htimeseries/test_timezone_utils.py | 14 +- tests/rocc/test_rocc.py | 136 +++++---- tests/textbisect/test_textbisect.py | 34 ++- 42 files changed, 1787 insertions(+), 1121 deletions(-) create mode 100644 src/enhydris_api_client/py.typed create mode 100644 src/enhydris_cache/py.typed create mode 100644 src/haggregate/py.typed create mode 100644 src/haggregate/regularize.pyi diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 11a87bf..985b7f5 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -45,15 +45,16 @@ jobs: # numpy<2 is needed for gdal to contain support for gdal array pip install 'numpy<2' CPLUS_INCLUDE_PATH=/usr/include/gdal C_INCLUDE_PATH=/usr/include/gdal pip install --no-build-isolation 'gdal==3.8.4' - pip install coverage isort flake8 'black<25' twine setuptools build + pip install coverage isort flake8 'black<25' twine setuptools build pyright pip install -e . - name: Run Tests run: | source ~/.venv/bin/activate black --check . - flake8 --max-line-length=88 . + flake8 --extend-ignore=E501 . isort --check-only --diff --profile=black *.py . + pyright . python -m build twine check dist/* coverage run --include="./*" --omit="docs/","*/tests/*","_version.py","*.pyx" -m unittest -v diff --git a/MANIFEST.in b/MANIFEST.in index dd1c9d0..5d846af 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,5 +5,7 @@ include README.rst recursive-include tests * recursive-exclude * __pycache__ recursive-exclude * *.py[co] +recursive-include src *.pyi +recursive-include src py.typed recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif diff --git a/docs/enhydris_cache/api.rst b/docs/enhydris_cache/api.rst index 07de5ea..a9d64e4 100644 --- a/docs/enhydris_cache/api.rst +++ b/docs/enhydris_cache/api.rst @@ -12,9 +12,10 @@ enhydris-cache API is downloaded from Enhydris using the Enhydris web service API. *timeseries_groups* is a list; each item is a dictionary representing an Enhydris time series; its keys are *base_url*, - *auth_token*, *id*, and *file*; the latter is the filename of - the file to which the time series will be cached (absolute or - relative to the current working directory). + *auth_token*, *station_id*, *timeseries_group_id*, *timeseries_id*, + and *file*; the latter is the filename of the file to which the + time series will be cached (absolute or relative to the current + working directory). .. method:: update() diff --git a/pyproject.toml b/pyproject.toml index a9db66b..6c6f12c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,9 @@ max-line-length = 88 [tool.setuptools] package-dir = {"" = "src"} +[tool.setuptools.package-data] +"*" = ["*.pyi", "py.typed"] + [tool.setuptools_scm] write_to = "src/pthelma/_version.py" @@ -74,3 +77,6 @@ skip = "pp* cp313-*" [tool.isort] skip = ["_version.py"] extra_standard_library = ["cpython", "libc"] + +[tool.pyright] +ignore = ["osgeo.*", "django.contrib.gis.*"] diff --git a/src/enhydris_api_client/__init__.py b/src/enhydris_api_client/__init__.py index aff0651..712c0b8 100644 --- a/src/enhydris_api_client/__init__.py +++ b/src/enhydris_api_client/__init__.py @@ -1,10 +1,17 @@ +from __future__ import annotations + +import datetime as dt from copy import copy from io import StringIO +from typing import Any, Dict, Generator, Iterable, Optional from urllib.parse import urljoin from zoneinfo import ZoneInfo import iso8601 +import pandas as pd import requests +from requests import Response, Session +from typing import cast from htimeseries import HTimeseries @@ -13,11 +20,15 @@ class MalformedResponseError(Exception): pass +JSONDict = Dict[str, Any] + + class EnhydrisApiClient: - def __init__(self, base_url, token=None): + def __init__(self, base_url: str, token: Optional[str] = None) -> None: self.base_url = base_url self.token = token - self.session = requests.Session() + self.session: Session = requests.Session() + self.response: Response | None = None if token is not None: self.session.headers.update({"Authorization": f"token {self.token}"}) @@ -31,14 +42,16 @@ def __init__(self, base_url, token=None): {"Content-Type": "application/x-www-form-urlencoded"} ) - def __enter__(self): + def __enter__(self) -> "EnhydrisApiClient": self.session.__enter__() return self - def __exit__(self, *args): + def __exit__(self, *args: object) -> None: self.session.__exit__(*args) - def check_response(self, expected_status_code=None): + def check_response(self, expected_status_code: Optional[int] = None) -> None: + if self.response is None: + raise RuntimeError("No response has been recorded") try: self._raise_HTTPError_on_error(expected_status_code=expected_status_code) except requests.HTTPError as e: @@ -47,21 +60,27 @@ def check_response(self, expected_status_code=None): f"{str(e)}. Server response: {self.response.text}" ) - def _raise_HTTPError_on_error(self, expected_status_code): + def _raise_HTTPError_on_error(self, expected_status_code: Optional[int]) -> None: self._check_status_code_is_nonerror() self._check_status_code_is_the_one_expected(expected_status_code) - def _check_status_code_is_nonerror(self): + def _check_status_code_is_nonerror(self) -> None: + if self.response is None: + raise RuntimeError("No response has been recorded") self.response.raise_for_status() - def _check_status_code_is_the_one_expected(self, expected_status_code): + def _check_status_code_is_the_one_expected( + self, expected_status_code: Optional[int] + ) -> None: + if self.response is None: + raise RuntimeError("No response has been recorded") if expected_status_code and self.response.status_code != expected_status_code: raise requests.HTTPError( f"Expected status code {expected_status_code}; " f"got {self.response.status_code} instead" ) - def get_token(self, username, password): + def get_token(self, username: str, password: str) -> Optional[str]: if not username: return @@ -74,7 +93,7 @@ def get_token(self, username, password): self.session.headers.update({"Authorization": f"token {key}"}) return key - def list_stations(self): + def list_stations(self) -> Generator[JSONDict, None, None]: url = urljoin(self.base_url, "api/stations/") while url: try: @@ -89,37 +108,37 @@ def list_stations(self): f"Malformed response from server: {str(e)}" ) - def get_station(self, station_id): + def get_station(self, station_id: int) -> JSONDict: url = urljoin(self.base_url, f"api/stations/{station_id}/") self.response = self.session.get(url) self.check_response() return self.response.json() - def post_station(self, data): + def post_station(self, data: JSONDict) -> int: self.response = self.session.post( urljoin(self.base_url, "api/stations/"), data=data ) self.check_response() return self.response.json()["id"] - def put_station(self, station_id, data): + def put_station(self, station_id: int, data: JSONDict) -> None: self.response = self.session.put( urljoin(self.base_url, f"api/stations/{station_id}/"), data=data ) self.check_response() - def patch_station(self, station_id, data): + def patch_station(self, station_id: int, data: JSONDict) -> None: self.response = self.session.patch( urljoin(self.base_url, f"api/stations/{station_id}/"), data=data ) self.check_response() - def delete_station(self, station_id): + def delete_station(self, station_id: int) -> None: url = urljoin(self.base_url, f"api/stations/{station_id}/") self.response = self.session.delete(url) self.check_response(expected_status_code=204) - def list_timeseries_groups(self, station_id): + def list_timeseries_groups(self, station_id: int) -> Generator[JSONDict, None, None]: url = urljoin(self.base_url, f"api/stations/{station_id}/timeseriesgroups/") while url: try: @@ -134,7 +153,7 @@ def list_timeseries_groups(self, station_id): f"Malformed response from server: {str(e)}" ) - def get_timeseries_group(self, station_id, timeseries_group_id): + def get_timeseries_group(self, station_id: int, timeseries_group_id: int) -> JSONDict: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/", @@ -143,13 +162,15 @@ def get_timeseries_group(self, station_id, timeseries_group_id): self.check_response() return self.response.json() - def post_timeseries_group(self, station_id, data): + def post_timeseries_group(self, station_id: int, data: JSONDict) -> int: url = urljoin(self.base_url, f"api/stations/{station_id}/timeseriesgroups/") self.response = self.session.post(url, data=data) self.check_response() return self.response.json()["id"] - def put_timeseries_group(self, station_id, timeseries_group_id, data): + def put_timeseries_group( + self, station_id: int, timeseries_group_id: int, data: JSONDict + ) -> int: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/", @@ -158,7 +179,9 @@ def put_timeseries_group(self, station_id, timeseries_group_id, data): self.check_response() return self.response.json()["id"] - def patch_timeseries_group(self, station_id, timeseries_group_id, data): + def patch_timeseries_group( + self, station_id: int, timeseries_group_id: int, data: JSONDict + ) -> None: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/", @@ -166,7 +189,7 @@ def patch_timeseries_group(self, station_id, timeseries_group_id, data): self.response = self.session.patch(url, data=data) self.check_response() - def delete_timeseries_group(self, station_id, timeseries_group_id): + def delete_timeseries_group(self, station_id: int, timeseries_group_id: int) -> None: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/", @@ -174,7 +197,7 @@ def delete_timeseries_group(self, station_id, timeseries_group_id): self.response = self.session.delete(url) self.check_response(expected_status_code=204) - def list_timeseries(self, station_id, timeseries_group_id): + def list_timeseries(self, station_id: int, timeseries_group_id: int) -> Iterable[JSONDict]: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/" @@ -184,7 +207,9 @@ def list_timeseries(self, station_id, timeseries_group_id): self.check_response() return self.response.json()["results"] - def get_timeseries(self, station_id, timeseries_group_id, timeseries_id): + def get_timeseries( + self, station_id: int, timeseries_group_id: int, timeseries_id: int + ) -> JSONDict: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/" @@ -194,7 +219,9 @@ def get_timeseries(self, station_id, timeseries_group_id, timeseries_id): self.check_response() return self.response.json() - def post_timeseries(self, station_id, timeseries_group_id, data): + def post_timeseries( + self, station_id: int, timeseries_group_id: int, data: JSONDict + ) -> int: self.response = self.session.post( urljoin( self.base_url, @@ -206,7 +233,9 @@ def post_timeseries(self, station_id, timeseries_group_id, data): self.check_response() return self.response.json()["id"] - def delete_timeseries(self, station_id, timeseries_group_id, timeseries_id): + def delete_timeseries( + self, station_id: int, timeseries_group_id: int, timeseries_id: int + ) -> None: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/" @@ -217,19 +246,19 @@ def delete_timeseries(self, station_id, timeseries_group_id, timeseries_id): def read_tsdata( self, - station_id, - timeseries_group_id, - timeseries_id, - start_date=None, - end_date=None, - timezone=None, - ): + station_id: int, + timeseries_group_id: int, + timeseries_id: int, + start_date: Optional[dt.datetime] = None, + end_date: Optional[dt.datetime] = None, + timezone: Optional[str] = None, + ) -> HTimeseries: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/" f"timeseries/{timeseries_id}/data/", ) - params = {"fmt": "hts"} + params: Dict[str, Any] = {"fmt": "hts"} tzinfo = ZoneInfo(timezone) if timezone else None dates_are_aware = (start_date is None or start_date.tzinfo is not None) and ( end_date is None or end_date.tzinfo is not None @@ -246,11 +275,17 @@ def read_tsdata( else: return HTimeseries() - def post_tsdata(self, station_id, timeseries_group_id, timeseries_id, ts): + def post_tsdata( + self, + station_id: int, + timeseries_group_id: int, + timeseries_id: int, + ts: HTimeseries, + ) -> str: f = StringIO() data = copy(ts.data) try: - data.index = data.index.tz_convert("UTC") + data.index = cast(pd.DatetimeIndex, data.index).tz_convert("UTC") except AttributeError: assert data.empty data.to_csv(f, header=False) @@ -266,8 +301,12 @@ def post_tsdata(self, station_id, timeseries_group_id, timeseries_id, ts): return self.response.text def get_ts_end_date( - self, station_id, timeseries_group_id, timeseries_id, timezone=None - ): + self, + station_id: int, + timeseries_group_id: int, + timeseries_id: int, + timezone: Optional[str] = None, + ) -> Optional[dt.datetime]: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/" diff --git a/src/enhydris_api_client/py.typed b/src/enhydris_api_client/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/enhydris_cache/cli.py b/src/enhydris_cache/cli.py index 9cf1f1e..901b125 100644 --- a/src/enhydris_cache/cli.py +++ b/src/enhydris_cache/cli.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import configparser import datetime as dt import logging import os +from typing import Any, Dict, Sequence import traceback import click -from enhydris_cache import TimeseriesCache +from enhydris_cache import TimeseriesCache, TimeseriesGroup from pthelma._version import __version__ @@ -15,16 +18,19 @@ class WrongValueError(configparser.Error): class App: - def __init__(self, configfilename): + def __init__(self, configfilename: str) -> None: self.configfilename = configfilename + self.logger: logging.Logger = logging.getLogger("spatialize") + self.config: AppConfig | None = None + self.cache: TimeseriesCache | None = None - def run(self): + def run(self) -> None: self.config = AppConfig(self.configfilename) self.config.read() self._setup_logger() self._execute_with_error_handling() - def _execute_with_error_handling(self): + def _execute_with_error_handling(self) -> None: self.logger.info("Starting enhydris-cache, " + dt.datetime.today().isoformat()) try: self._execute() @@ -35,24 +41,27 @@ def _execute_with_error_handling(self): "enhydris-cache terminated with error, " + dt.datetime.today().isoformat() ) - raise click.ClickException(str(e)) + raise click.ClickException(str(e)) from e else: self.logger.info( "Finished enhydris-cache, " + dt.datetime.today().isoformat() ) - def _setup_logger(self): - self.logger = logging.getLogger("spatialize") - self._set_logger_handler() + def _setup_logger(self) -> None: + if self.config is None: + raise RuntimeError("Configuration has not been loaded") + self._set_logger_handler(self.config) self.logger.setLevel(self.config.loglevel.upper()) - def _set_logger_handler(self): - if getattr(self.config, "logfile", None): - self.logger.addHandler(logging.FileHandler(self.config.logfile)) + def _set_logger_handler(self, config: AppConfig) -> None: + if getattr(config, "logfile", None): + self.logger.addHandler(logging.FileHandler(config.logfile)) else: self.logger.addHandler(logging.StreamHandler()) - def _execute(self): + def _execute(self) -> None: + if self.config is None: + raise RuntimeError("Configuration has not been loaded") os.chdir(self.config.cache_dir) self.cache = TimeseriesCache(self.config.timeseries_group) self.cache.update() @@ -73,50 +82,62 @@ class AppConfig: "file": {}, } - def __init__(self, configfilename): + def __init__(self, configfilename: str) -> None: self.configfilename = configfilename + self.config: configparser.ConfigParser | None = None + self.logfile: str = "" + self.loglevel: str = "WARNING" + self.cache_dir: str = os.getcwd() + self.timeseries_group: Sequence[TimeseriesGroup] = [] - def read(self): + def read(self) -> None: try: self._parse_config() except (OSError, configparser.Error) as e: - raise click.ClickException(str(e)) + raise click.ClickException(str(e)) from e - def _parse_config(self): + def _parse_config(self) -> None: self._read_config_file() self._parse_general_section() self._parse_timeseries_sections() - def _read_config_file(self): - self.config = configparser.ConfigParser(interpolation=None) + def _read_config_file(self) -> None: + config = configparser.ConfigParser(interpolation=None) with open(self.configfilename) as f: - self.config.read_file(f) + config.read_file(f) + self.config = config - def _parse_general_section(self): + def _parse_general_section(self) -> None: + if self.config is None: + raise RuntimeError("Configuration file has not been read") options = { - opt: self.config.get("General", opt, **kwargs) + opt: self.config.get("General", opt, **kwargs) # type: ignore[arg-type] for opt, kwargs in self.config_file_general_options.items() } for key, value in options.items(): setattr(self, key, value) self._parse_log_level() - def _parse_log_level(self): + def _parse_log_level(self) -> None: log_levels = ("ERROR", "WARNING", "INFO", "DEBUG") self.loglevel = self.loglevel.upper() if self.loglevel not in log_levels: raise WrongValueError("loglevel must be one of " + ", ".join(log_levels)) - def _parse_timeseries_sections(self): + def _parse_timeseries_sections(self) -> None: + if self.config is None: + raise RuntimeError("Configuration file has not been read") self.timeseries_group = [] for section in self.config: if section in ("General", "DEFAULT"): continue item = self._read_section(section) - self.timeseries_group.append(item) + self.timeseries_group.append(item) # type: ignore[arg-type] - def _read_section(self, section): - options = { + def _read_section(self, section: str) -> Dict[str, Any]: + if self.config is None: + raise RuntimeError("Configuration file has not been read") + options: Dict[str, Any] = { opt: self.config.get(section, opt, **kwargs) for opt, kwargs in self.config_file_timeseries_options.items() } @@ -132,7 +153,7 @@ def _read_section(self, section): options["timeseries_group_id"], options["timeseries_id"], ) - ) + ) from None return options @@ -141,7 +162,7 @@ def _read_section(self, section): @click.version_option( version=__version__, message="%(prog)s from pthelma v.%(version)s" ) -def main(configfile): +def main(configfile: str) -> None: """Spatial integration""" app = App(configfile) diff --git a/src/enhydris_cache/enhydris_cache.py b/src/enhydris_cache/enhydris_cache.py index 289f2c8..fc08659 100644 --- a/src/enhydris_cache/enhydris_cache.py +++ b/src/enhydris_cache/enhydris_cache.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import datetime as dt +from typing import Mapping, Sequence, TypedDict, cast import pandas as pd @@ -6,47 +9,73 @@ from htimeseries import HTimeseries +class TimeseriesGroup(TypedDict): + base_url: str + station_id: int + timeseries_group_id: int + timeseries_id: int + auth_token: str | None + file: str + class TimeseriesCache(object): - def __init__(self, timeseries_group): + def __init__(self, timeseries_group: Sequence[TimeseriesGroup]) -> None: self.timeseries_group = timeseries_group + self.base_url: str | None = None + self.station_id: int | None = None + self.timeseries_group_id: int | None = None + self.timeseries_id: int | None = None + self.auth_token: str | None = None + self.filename: str | None = None - def update(self): + def update(self) -> None: for item in self.timeseries_group: - self.base_url = item["base_url"] - if self.base_url[-1] != "/": - self.base_url += "/" - self.station_id = item["station_id"] - self.timeseries_group_id = item["timeseries_group_id"] - self.timeseries_id = item["timeseries_id"] - self.auth_token = item["auth_token"] - self.filename = item["file"] + base_url = str(item["base_url"]) + if base_url[-1] != "/": + base_url += "/" + self.base_url = base_url + self.station_id = int(item["station_id"]) + self.timeseries_group_id = int(item["timeseries_group_id"]) + self.timeseries_id = int(item["timeseries_id"]) + self.auth_token = cast(str | None, item.get("auth_token")) + self.filename = str(item["file"]) self._update_for_one_timeseries() - def _update_for_one_timeseries(self): - cached_ts = self._read_timeseries_from_cache_file() + def _update_for_one_timeseries(self) -> None: + if self.filename is None: + raise ValueError("Cache filename has not been initialised") + cached_ts = self._read_timeseries_from_cache_file(self.filename) end_date = self._get_timeseries_end_date(cached_ts) start_date = end_date + dt.timedelta(minutes=1) new_ts = self._append_newer_timeseries(start_date, cached_ts) with open(self.filename, "w", encoding="utf-8") as f: new_ts.write(f, format=HTimeseries.FILE) - def _read_timeseries_from_cache_file(self): + def _read_timeseries_from_cache_file(self, filename: str) -> HTimeseries: try: - with open(self.filename, newline="\n") as f: + with open(filename, newline="\n") as f: return HTimeseries(f) except (FileNotFoundError, ValueError): # If file is corrupted or nonexistent, continue with empty time series return HTimeseries() - def _get_timeseries_end_date(self, timeseries): + def _get_timeseries_end_date(self, timeseries: HTimeseries) -> dt.datetime: try: end_date = timeseries.data.index[-1] except IndexError: # Timeseries is totally empty; no start and end date end_date = dt.datetime(1, 1, 1, 0, 0, tzinfo=dt.timezone.utc) + assert isinstance(end_date, dt.datetime) return end_date - def _append_newer_timeseries(self, start_date, old_ts): + def _append_newer_timeseries( + self, start_date: dt.datetime, old_ts: HTimeseries + ) -> HTimeseries: + if self.base_url is None: + raise ValueError("API base URL has not been set") + if self.station_id is None or self.timeseries_group_id is None: + raise ValueError("Timeseries identifiers have not been set") + if self.timeseries_id is None: + raise ValueError("Timeseries id has not been set") with EnhydrisApiClient(self.base_url, token=self.auth_token) as api_client: ts = api_client.read_tsdata( self.station_id, @@ -58,10 +87,12 @@ def _append_newer_timeseries(self, start_date, old_ts): # For appending to work properly, both time series need to have the same # tz. + oindex = cast(pd.DatetimeIndex, old_ts.data.index) + nindex = cast(pd.DatetimeIndex, new_data.index) if len(old_ts.data): - new_data.index = new_data.index.tz_convert(old_ts.data.index.tz) + new_data.index = nindex.tz_convert(oindex.tz) else: - old_ts.data.index = old_ts.data.index.tz_convert(new_data.index.tz) + old_ts.data.index = oindex.tz_convert(nindex.tz) ts.data = pd.concat( [old_ts.data, new_data], verify_integrity=True, sort=False diff --git a/src/enhydris_cache/py.typed b/src/enhydris_cache/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/evaporation/cli.py b/src/evaporation/cli.py index 78531b1..eb5e702 100644 --- a/src/evaporation/cli.py +++ b/src/evaporation/cli.py @@ -54,7 +54,7 @@ def _setup_logger(self): def _set_logger_handler(self): if getattr(self.config, "logfile", None): - self.logger.addHandler(logging.FileHandler(self.config.logfile)) + self.logger.addHandler(logging.FileHandler(self.config.logfile)) # type: ignore[arg-type] else: self.logger.addHandler(logging.StreamHandler()) @@ -281,6 +281,7 @@ def _check_tif_hts_consistency(self, has_tif, has_hts): class ProcessAtPoint: + timezone: str | None def __init__(self, config): self.config = config @@ -327,6 +328,7 @@ def _get_input_timeseries_for_var(self, var): self.input_timeseries[var] = HTimeseries(f) def _check_all_timeseries_are_in_same_location_and_timezone(self): + reference_hts = None for i, (name, hts) in enumerate(self.input_timeseries.items()): if i == 0: reference_hts = hts @@ -408,7 +410,6 @@ def _prepare_resulting_htimeseries_object(self): self.pet = HTimeseries(default_tzinfo=tzinfo) self.pet.time_step = self.config.time_step self.pet.unit = "mm" - self.pet.timezone = self.timezone self.pet.variable = "Potential Evapotranspiration" self.pet.precision = 2 if self.config.time_step == "h" else 1 self.pet.location = self.location @@ -431,6 +432,8 @@ def _determine_variables_to_use_in_calculation(self): else "sunshine_duration" ), ) + else: + assert False, "Unreachable" self.input_vars = vars def _calculate_evaporation(self): diff --git a/src/evaporation/evaporation.py b/src/evaporation/evaporation.py index 8894ef7..3e585a6 100644 --- a/src/evaporation/evaporation.py +++ b/src/evaporation/evaporation.py @@ -1,10 +1,18 @@ +from __future__ import annotations + import datetime as dt import math import warnings from math import cos, pi, sin, tan +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union, cast import numpy as np +Numeric = Union[float, np.ndarray, np.ma.MaskedArray] +Converter = Callable[[Numeric], Numeric] +ExtraterrestrialRadiation = Union[Numeric, Tuple[Numeric, Numeric]] +OptionalNumeric = Optional[Numeric] + # Note about RuntimeWarning # # When numpy makes calculations with masked arrays, it sometimes emits spurious @@ -28,23 +36,25 @@ class PenmanMonteith(object): def __init__( self, - albedo, - elevation, - latitude, - time_step, - longitude=None, - nighttime_solar_radiation_ratio=None, - unit_converters={}, - ): + albedo: Union[float, np.ndarray, Sequence[float], Sequence[np.ndarray]], + elevation: float | np.ndarray, + latitude: float | np.ndarray, + time_step: str, + longitude: Optional[float | np.ndarray] = None, + nighttime_solar_radiation_ratio: Optional[Numeric] = None, + unit_converters: Optional[Mapping[str, Converter]] = None, + ) -> None: self.albedo = albedo self.nighttime_solar_radiation_ratio = nighttime_solar_radiation_ratio self.elevation = elevation self.latitude = latitude self.longitude = longitude self.time_step = time_step - self.unit_converters = unit_converters + self.unit_converters = ( + dict(unit_converters) if unit_converters is not None else {} + ) - def calculate(self, **kwargs): + def calculate(self, **kwargs: Any) -> Numeric: if self.time_step == "h": return self.calculate_hourly(**kwargs) elif self.time_step == "D": @@ -57,16 +67,16 @@ def calculate(self, **kwargs): def calculate_daily( self, - temperature_max, - temperature_min, - humidity_max, - humidity_min, - wind_speed, - adatetime, - sunshine_duration=None, - pressure=None, - solar_radiation=None, - ): + temperature_max: Numeric, + temperature_min: Numeric, + humidity_max: Numeric, + humidity_min: Numeric, + wind_speed: Numeric, + adatetime: dt.datetime, + sunshine_duration: Optional[Numeric] = None, + pressure: Optional[Numeric] = None, + solar_radiation: Optional[Numeric] = None, + ) -> Numeric: if pressure is None: # Eq. 7 p. 31 pressure = 101.3 * ((293 - 0.0065 * self.elevation) / 293) ** 5.26 @@ -80,45 +90,53 @@ def calculate_daily( pressure=pressure, ) + temperature_max_c = cast(Numeric, variables["temperature_max"]) + temperature_min_c = cast(Numeric, variables["temperature_min"]) + humidity_max_c = cast(Numeric, variables["humidity_max"]) + humidity_min_c = cast(Numeric, variables["humidity_min"]) + wind_speed_c = cast(Numeric, variables["wind_speed"]) + sunshine_duration_c = cast(Numeric, variables["sunshine_duration"]) + # Radiation - r_a, N = self.get_extraterrestrial_radiation(adatetime) + extraterrestrial_radiation = self.get_extraterrestrial_radiation(adatetime) + r_a, N = cast(Tuple[Numeric, Numeric], extraterrestrial_radiation) if solar_radiation is None: solar_radiation = ( - 0.25 + 0.50 * variables["sunshine_duration"] / N + 0.25 + 0.50 * sunshine_duration_c / N ) * r_a # Eq.35 p. 50 r_so = r_a * (0.75 + 2e-5 * self.elevation) # Eq. 37, p. 51 variables.update(self.convert_units(solar_radiation=solar_radiation)) + solar_radiation_c = cast(Numeric, variables["solar_radiation"]) with warnings.catch_warnings(): # See comment about RuntimeWarning on top of the file warnings.simplefilter("ignore", RuntimeWarning) - temperature_mean = ( - variables["temperature_max"] + variables["temperature_min"] - ) / 2 + temperature_mean = (temperature_max_c + temperature_min_c) / 2 variables["temperature_mean"] = temperature_mean - gamma = self.get_psychrometric_constant(temperature_mean, variables["pressure"]) + pressure_c = cast(Numeric, variables["pressure"]) + gamma = self.get_psychrometric_constant(temperature_mean, pressure_c) return self.penman_monteith_daily( - incoming_solar_radiation=variables["solar_radiation"], + incoming_solar_radiation=solar_radiation_c, clear_sky_solar_radiation=r_so, psychrometric_constant=gamma, - mean_wind_speed=variables["wind_speed"], - temperature_max=variables["temperature_max"], - temperature_min=variables["temperature_min"], - temperature_mean=variables["temperature_mean"], - humidity_max=variables["humidity_max"], - humidity_min=variables["humidity_min"], + mean_wind_speed=wind_speed_c, + temperature_max=temperature_max_c, + temperature_min=temperature_min_c, + temperature_mean=temperature_mean, + humidity_max=humidity_max_c, + humidity_min=humidity_min_c, adate=adatetime, ) def calculate_hourly( self, - temperature, - humidity, - wind_speed, - solar_radiation, - adatetime, - pressure=None, - ): + temperature: Numeric, + humidity: Numeric, + wind_speed: Numeric, + solar_radiation: Numeric, + adatetime: dt.datetime, + pressure: Optional[Numeric] = None, + ) -> Numeric: if pressure is None: # Eq. 7 p. 31 pressure = 101.3 * ((293 - 0.0065 * self.elevation) / 293) ** 5.26 @@ -129,36 +147,48 @@ def calculate_hourly( pressure=pressure, solar_radiation=solar_radiation, ) + temperature_c = cast(Numeric, variables["temperature"]) + humidity_c = cast(Numeric, variables["humidity"]) + wind_speed_c = cast(Numeric, variables["wind_speed"]) + pressure_c = cast(Numeric, variables["pressure"]) + solar_radiation_c = cast(Numeric, variables["solar_radiation"]) gamma = self.get_psychrometric_constant( - variables["temperature"], variables["pressure"] + temperature_c, pressure_c ) - r_so = self.get_extraterrestrial_radiation(adatetime) * ( + extraterrestrial_radiation = self.get_extraterrestrial_radiation(adatetime) + r_so = cast(Numeric, extraterrestrial_radiation) * ( 0.75 + 2e-5 * self.elevation ) # Eq. 37, p. 51 return self.penman_monteith_hourly( - incoming_solar_radiation=variables["solar_radiation"], + incoming_solar_radiation=solar_radiation_c, clear_sky_solar_radiation=r_so, psychrometric_constant=gamma, - mean_wind_speed=variables["wind_speed"], - mean_temperature=variables["temperature"], - mean_relative_humidity=variables["humidity"], + mean_wind_speed=wind_speed_c, + mean_temperature=temperature_c, + mean_relative_humidity=humidity_c, adatetime=adatetime, ) - def convert_units(self, **kwargs): - result = {} + def convert_units(self, **kwargs: OptionalNumeric) -> Dict[str, OptionalNumeric]: + result: Dict[str, OptionalNumeric] = {} for item in kwargs: varname = item if item.endswith("_max") or item.endswith("_min"): varname = item[:-4] + value = kwargs[item] + if value is None: + result[item] = None + continue converter = self.unit_converters.get(varname, lambda x: x) with warnings.catch_warnings(): # See comment about RuntimeWarning on top of the file warnings.simplefilter("ignore", RuntimeWarning) - result[item] = converter(kwargs[item]) + result[item] = converter(value) return result - def get_extraterrestrial_radiation(self, adatetime): + def get_extraterrestrial_radiation( + self, adatetime: dt.datetime | dt.date + ) -> ExtraterrestrialRadiation: """ Calculates the solar radiation we would receive if there were no atmosphere. This is a function of date, time and location. @@ -193,12 +223,17 @@ def get_extraterrestrial_radiation(self, adatetime): n = 24 / pi * omega_s # Eq. 34 p. 48 return r_a, n + # We continue with hourly + assert isinstance(adatetime, dt.datetime) + assert self.longitude is not None + # Seasonal correction for solar time, eq. 32, p. 48. b = 2 * pi * (j - 81) / 364 sc = 0.1645 * sin(2 * b) - 0.1255 * cos(b) - 0.025 * sin(b) # Longitude at the centre of the local time zone utc_offset = adatetime.utcoffset() + assert utc_offset is not None utc_offset_hours = utc_offset.days * 24 + utc_offset.seconds / 3600.0 lz = -utc_offset_hours * 15 @@ -230,7 +265,9 @@ def get_extraterrestrial_radiation(self, adatetime): ) ) - def get_psychrometric_constant(self, temperature, pressure): + def get_psychrometric_constant( + self, temperature: Numeric, pressure: Numeric + ) -> Numeric: """ Allen et al. (1998), eq. 8, p. 32. @@ -245,17 +282,17 @@ def get_psychrometric_constant(self, temperature, pressure): def penman_monteith_daily( self, - incoming_solar_radiation, - clear_sky_solar_radiation, - psychrometric_constant, - mean_wind_speed, - temperature_max, - temperature_min, - temperature_mean, - humidity_max, - humidity_min, - adate, - ): + incoming_solar_radiation: Numeric, + clear_sky_solar_radiation: Numeric, + psychrometric_constant: Numeric, + mean_wind_speed: Numeric, + temperature_max: Numeric, + temperature_min: Numeric, + temperature_mean: Numeric, + humidity_max: Numeric, + humidity_min: Numeric, + adate: dt.date, + ) -> Numeric: """ Calculates and returns the reference evapotranspiration according to Allen et al. (1998), eq. 6, p. 24 & 65. @@ -275,7 +312,7 @@ def penman_monteith_daily( # Net incoming radiation; p. 51, eq. 38 albedo = ( self.albedo[adate.month - 1] - if self.albedo.__class__.__name__ in ("tuple", "list") + if isinstance(self.albedo, Sequence) else self.albedo ) rns = (1.0 - albedo) * incoming_solar_radiation @@ -309,14 +346,14 @@ def penman_monteith_daily( def penman_monteith_hourly( self, - incoming_solar_radiation, - clear_sky_solar_radiation, - psychrometric_constant, - mean_wind_speed, - mean_temperature, - mean_relative_humidity, - adatetime, - ): + incoming_solar_radiation: Numeric, + clear_sky_solar_radiation: Numeric, + psychrometric_constant: Numeric, + mean_wind_speed: Numeric, + mean_temperature: Numeric, + mean_relative_humidity: Numeric, + adatetime: dt.datetime, + ) -> Numeric: """ Calculates and returns the reference evapotranspiration according to Allen et al. (1998), eq. 53, p. 74. @@ -336,7 +373,7 @@ def penman_monteith_hourly( # Net incoming radiation; p. 51, eq. 38 albedo = ( self.albedo[adatetime.month - 1] - if self.albedo.__class__.__name__ in ("tuple", "list") + if isinstance(self.albedo, Sequence) else self.albedo ) rns = (1.0 - albedo) * incoming_solar_radiation @@ -373,17 +410,17 @@ def penman_monteith_hourly( def get_net_outgoing_radiation( self, - temperature, - incoming_solar_radiation, - clear_sky_solar_radiation, - mean_actual_vapour_pressure, - ): + temperature: Union[Numeric, Tuple[Numeric, Numeric]], + incoming_solar_radiation: Numeric, + clear_sky_solar_radiation: Numeric, + mean_actual_vapour_pressure: Numeric, + ) -> Numeric: """ Allen et al. (1998), p. 52, eq. 39. Temperature can be a tuple (a pair) of min and max, or a single value. If it is a single value, the equation is modified according to end of page 74. """ - if temperature.__class__.__name__ in ("tuple", "list"): + if isinstance(temperature, Sequence): with warnings.catch_warnings(): # See comment about RuntimeWarning on top of the file warnings.simplefilter("ignore", RuntimeWarning) @@ -406,7 +443,7 @@ def get_net_outgoing_radiation( solar_radiation_ratio = np.where( clear_sky_solar_radiation > 0.05, incoming_solar_radiation / clear_sky_solar_radiation, - self.nighttime_solar_radiation_ratio, + self.nighttime_solar_radiation_ratio, # type: ignore ) solar_radiation_ratio = np.where( np.isnan(clear_sky_solar_radiation), float("nan"), solar_radiation_ratio @@ -421,19 +458,21 @@ def get_net_outgoing_radiation( result = np.array(result, dtype=float) return result - def get_saturation_vapour_pressure(self, temperature): + def get_saturation_vapour_pressure(self, temperature: Numeric) -> Numeric: "Allen et al. (1998), p. 36, eq. 11." with warnings.catch_warnings(): # See comment about RuntimeWarning on top of the file warnings.simplefilter("ignore") return 0.6108 * math.e ** (17.27 * temperature / (237.3 + temperature)) - def get_soil_heat_flux_density(self, incoming_solar_radiation, rn): + def get_soil_heat_flux_density( + self, incoming_solar_radiation: Numeric, rn: Numeric + ) -> Numeric: "Allen et al. (1998), p. 55, eq. 45 & 46." coefficient = np.where(incoming_solar_radiation > 0.05, 0.1, 0.5) return coefficient * rn - def get_saturation_vapour_pressure_curve_slope(self, temperature): + def get_saturation_vapour_pressure_curve_slope(self, temperature: Numeric) -> Numeric: "Allen et al. (1998), p. 37, eq. 13." numerator = 4098 * self.get_saturation_vapour_pressure(temperature) with warnings.catch_warnings(): @@ -443,7 +482,12 @@ def get_saturation_vapour_pressure_curve_slope(self, temperature): return numerator / denominator -def cloud2radiation(cloud_cover, latitude, longitude, date): +def cloud2radiation( + cloud_cover: Numeric, + latitude: float, + longitude: float, + date: dt.date, +) -> Numeric: a_s = 0.25 b_s = 0.50 dummy = 0.5 # Values not being used by get_extraterrestial_radiation @@ -454,6 +498,8 @@ def cloud2radiation(cloud_cover, latitude, longitude, date): longitude=longitude, time_step="D", ) - etrad = pm.get_extraterrestrial_radiation(date)[0] + r = pm.get_extraterrestrial_radiation(date) + assert isinstance(r, Sequence) + etrad = r[0] etrad *= 1e6 / 86400 # convert from MJ/m/day to W/s return (a_s + b_s * (1 - cloud_cover)) * etrad diff --git a/src/haggregate/cli.py b/src/haggregate/cli.py index 33f2fd8..c8a8045 100644 --- a/src/haggregate/cli.py +++ b/src/haggregate/cli.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import configparser import datetime as dt import logging import os import sys +from typing import Optional import traceback import click @@ -14,7 +17,7 @@ @click.command() @click.argument("configfile") -def main(configfile): +def main(configfile: str) -> None: """Create lower-step timeseries from higher-step ones""" # Start by setting logger to stdout; later we will switch it according to config @@ -34,7 +37,7 @@ def main(configfile): target_step = config.get("General", "target_step") min_count = config.getint("General", "min_count") missing_flag = config.get("General", "missing_flag") - target_timestamp_offset = config.get( + target_timestamp_offset: Optional[str] = config.get( "General", "target_timestamp_offset", fallback=None ) @@ -55,8 +58,11 @@ def main(configfile): # Read each section and do the work for it for section_name in config.sections(): section = config[section_name] - source_filename = os.path.join(base_dir, section.get("source_file")) - target_filename = os.path.join(base_dir, section.get("target_file")) + source_file = section.get("source_file") + target_file = section.get("target_file") + assert source_file is not None and target_file is not None + source_filename = os.path.join(base_dir, source_file) + target_filename = os.path.join(base_dir, target_file) method = section.get("method") with open(source_filename, newline="\n") as f: ts = HTimeseries( @@ -67,6 +73,7 @@ def main(configfile): else: regularization_mode = RegularizationMode.INTERVAL regts = regularize(ts, new_date_flag="DATEINSERT", mode=regularization_mode) + assert method is not None aggts = aggregate( regts, target_step, diff --git a/src/haggregate/haggregate.py b/src/haggregate/haggregate.py index 30206d6..8741850 100644 --- a/src/haggregate/haggregate.py +++ b/src/haggregate/haggregate.py @@ -1,4 +1,8 @@ +from __future__ import annotations + +import datetime as dt from enum import Enum +from typing import Any, Callable, Dict, Optional import numpy as np import pandas as pd @@ -18,13 +22,13 @@ class AggregateError(Exception): def aggregate( - hts, - target_step, - method, - min_count=1, - missing_flag="MISS", - target_timestamp_offset=None, -): + hts: HTimeseries, + target_step: str, + method: str, + min_count: int = 1, + missing_flag: str = "MISS", + target_timestamp_offset: Optional[str] = None, +) -> HTimeseries: aggregation = Aggregation( source_timeseries=hts, target_step=target_step, @@ -38,14 +42,28 @@ def aggregate( class Aggregation: - def __init__(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) + def __init__( + self, + *, + source_timeseries: HTimeseries, + target_step: str, + method: str, + min_count: int, + missing_flag: str, + target_timestamp_offset: Optional[str], + ) -> None: + self.source_timeseries = source_timeseries + self.target_step = target_step + self.method = method + self.min_count = min_count + self.missing_flag = missing_flag + self.target_timestamp_offset = target_timestamp_offset self.source = SourceTimeseries(self.source_timeseries) self.result = AggregatedTimeseries() self.result.time_step = self.target_step + self.resampler: Any = None - def execute(self): + def execute(self) -> None: self.result.set_metadata(self.source_timeseries) try: self.source.normalize(self.target_step) @@ -55,24 +73,32 @@ def execute(self): self.result.remove_leading_and_trailing_nans() self.result.add_timestamp_offset(self.target_timestamp_offset) - def do_aggregation(self): + def do_aggregation(self) -> None: self.create_resampler() self.get_result_values() self.get_result_flags() - def create_resampler(self): + def create_resampler(self) -> None: self.resampler = self.source.data["value"].resample( self.result.time_step, closed="right", label="right" ) - def get_result_values(self): + def get_result_values(self) -> None: + if self.resampler is None: + raise RuntimeError("Resampler has not been initialised") result_values = self.resampler.agg(methods[self.method]) values_count = self.resampler.count() result_values[values_count < self.min_count] = np.nan self.result.data["value"] = result_values - def get_result_flags(self): - max_count = int(pd.Timedelta(self.result.time_step) / self.source.freq) + def get_result_flags(self) -> None: + if self.resampler is None: + raise RuntimeError("Resampler has not been initialised") + assert self.source.freq is not None + source_time_step = dt.timedelta(microseconds=self.source.freq.nanos / 1000) + result_time_step = pd.Timedelta(self.result.time_step) + assert isinstance(result_time_step, dt.timedelta) + max_count = int(result_time_step / source_time_step) values_count = self.resampler.count() self.result.data["flags"] = (max_count - values_count).apply( lambda x: self.missing_flag.format(x) if x else "" @@ -83,16 +109,23 @@ class CannotInferFrequency(Exception): pass -attrs = ("unit", "timezone", "interval_type", "variable", "precision", "location") +attrs: tuple[str, ...] = ( + "unit", + "timezone", + "interval_type", + "variable", + "precision", + "location", +) class SourceTimeseries(HTimeseries): - def __init__(self, s): + def __init__(self, s: HTimeseries) -> None: for attr in attrs: setattr(self, attr, getattr(s, attr, None)) self.data = s.data - def normalize(self, target_step): + def normalize(self, target_step: str) -> None: """Reindex so that it has no missing records but has NaNs instead, starting from one before and ending in one after. """ @@ -106,17 +139,19 @@ def normalize(self, target_step): except ValueError: raise CannotInferFrequency() first_timestamp = (current_range[0] - pd.Timedelta("1s")).floor(target_step) - end_timestamp = current_range[-1].ceil(target_step) + end_timestamp = current_range[-1].ceil(target_step) # type: ignore[assignment] new_range = pd.date_range(first_timestamp, end_timestamp, freq=self.freq) self.data = self.data.reindex(new_range) class AggregatedTimeseries(HTimeseries): - def set_metadata(self, source_timeseries): + def set_metadata(self, source_timeseries: HTimeseries) -> None: for attr in attrs: setattr(self, attr, getattr(source_timeseries, attr, None)) try: - if pd.Timedelta(self.time_step) <= pd.Timedelta(0): + time_step = pd.Timedelta(self.time_step) + assert isinstance(time_step, dt.timedelta) + if time_step <= pd.Timedelta(0): raise ValueError("Non-positive time step") except ValueError as e: raise AggregateError( @@ -130,13 +165,13 @@ def set_metadata(self, source_timeseries): + source_timeseries.comment ) - def remove_leading_and_trailing_nans(self): + def remove_leading_and_trailing_nans(self) -> None: while len(self.data.index) > 0 and pd.isnull(self.data["value"]).iloc[0]: - self.data = self.data.drop(self.data.index[0]) + self.data = self.data.drop(self.data.index[0]) # type: ignore[index] while len(self.data.index) > 0 and pd.isnull(self.data["value"]).iloc[-1]: - self.data = self.data.drop(self.data.index[-1]) + self.data = self.data.drop(self.data.index[-1]) # type: ignore[index] - def add_timestamp_offset(self, target_timestamp_offset): + def add_timestamp_offset(self, target_timestamp_offset: Optional[str]) -> None: if target_timestamp_offset: periods = target_timestamp_offset.startswith("-") and 1 or -1 freq = target_timestamp_offset.lstrip("-") diff --git a/src/haggregate/py.typed b/src/haggregate/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/haggregate/regularize.pyi b/src/haggregate/regularize.pyi new file mode 100644 index 0000000..64e2443 --- /dev/null +++ b/src/haggregate/regularize.pyi @@ -0,0 +1,14 @@ +from __future__ import annotations + +from htimeseries import HTimeseries + +from .haggregate import RegularizationMode + +class RegularizeError(Exception): + ... + +def regularize( + ts: HTimeseries, + new_date_flag: str = ..., + mode: RegularizationMode = ..., +) -> HTimeseries: ... diff --git a/src/hspatial/cli.py b/src/hspatial/cli.py index 0b0d1d4..f853fc0 100644 --- a/src/hspatial/cli.py +++ b/src/hspatial/cli.py @@ -53,7 +53,7 @@ def _setup_logger(self): def _set_logger_handler(self): if getattr(self.config, "logfile", None): - self.logger.addHandler(logging.FileHandler(self.config.logfile)) + self.logger.addHandler(logging.FileHandler(self.config.logfile)) # type: ignore[arg-type] else: self.logger.addHandler(logging.StreamHandler()) @@ -64,14 +64,18 @@ def _get_last_dates(self, filename, n): series is too small). 'filename' is used in error messages. """ # Get the time zone + line = None with open(filename) as fp: for line in fp: if line.startswith("Timezone") or (line and line[0] in "0123456789"): break + assert line is not None if not line.startswith("Timezone"): raise click.ClickException("{} does not contain Timezone".format(filename)) zonestr = line.partition("=")[2].strip() - timezone = TzinfoFromString(zonestr) + utcoffset = TzinfoFromString(zonestr).utcoffset(None) + assert utcoffset is not None + timezone = dt.timezone(offset=utcoffset) result = [] previous_line_was_empty = False @@ -145,15 +149,17 @@ def _date_fmt(self): """ Determine date_fmt based on time series time step. """ - if self._time_step.endswith("min") or self._time_step.endswith("h"): + time_step = self._time_step + assert time_step is not None + if time_step.endswith("min") or time_step.endswith("h"): return "%Y-%m-%d %H:%M%z" - elif self._time_step.endswith("D"): + elif time_step.endswith("D"): return "%Y-%m-%d" - elif self._time_step.endswith("M"): + elif time_step.endswith("M"): return "%Y-%m" - elif self._time_step.endswith("Y"): + elif time_step.endswith("Y"): return "%Y" - raise click.ClickException("Can't use time step " + str(self._time_step)) + raise click.ClickException(f"Can't use time step {time_step}") def _delete_obsolete_files(self): """ @@ -161,7 +167,7 @@ def _delete_obsolete_files(self): where N is the 'number_of_output_files' configuration option. """ pattern = os.path.join( - self.config.output_dir, "{}-*.tif".format(self.config.filename_prefix) + self.config.output_dir, "{}-*.tif".format(self.config.filename_prefix) # type: ignore[str-format] ) files = glob(pattern) files.sort() @@ -176,10 +182,10 @@ def _execute(self): ) # Get mask - mask = gdal.Open(self.config.mask) + mask = gdal.Open(self.config.mask) # type: ignore[arg-type] # Setup integration method - if self.config.method == "idw": + if self.config.method == "idw": # type: ignore[comparison-overlap] funct = idw kwargs = {"alpha": self.config.alpha} else: @@ -191,7 +197,7 @@ def _execute(self): mask, stations_layer, date, - os.path.join(self.config.output_dir, self.config.filename_prefix), + os.path.join(self.config.output_dir, self.config.filename_prefix), # type: ignore[str-format] self._date_fmt, funct, kwargs, @@ -269,7 +275,7 @@ def _parse_files(self): def _check_method(self): # Check method - if self.method != "idw": + if self.method != "idw": # type: ignore[comparison-overlap] raise WrongValueError('Option "method" can currently only be idw') # Check alpha try: diff --git a/src/hspatial/hspatial.py b/src/hspatial/hspatial.py index f85bee7..1d2cca0 100644 --- a/src/hspatial/hspatial.py +++ b/src/hspatial/hspatial.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import datetime as dt import os import struct from glob import glob from math import isnan +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union import iso8601 import numpy as np @@ -18,8 +21,12 @@ NODATAVALUE = -(2.0**127) +GeometryLike = Union[GeoDjangoPoint, ogr.Geometry] +KwargsMapping = Mapping[str, Any] +InterpolationFunction = Callable[..., float] + -def coordinates2point(x, y, srid=4326): +def coordinates2point(x: float, y: float, srid: int = 4326) -> ogr.Geometry: point = ogr.Geometry(ogr.wkbPoint) sr = osr.SpatialReference() sr.ImportFromEPSG(srid) @@ -30,7 +37,7 @@ def coordinates2point(x, y, srid=4326): return point -def idw(point, data_layer, alpha=1): +def idw(point: ogr.Geometry, data_layer: ogr.Layer, alpha: float = 1) -> float: data_layer.ResetReading() features = [f for f in data_layer if not isnan(f.GetField("value"))] distances = np.array([point.Distance(f.GetGeometryRef()) for f in features]) @@ -44,7 +51,14 @@ def idw(point, data_layer, alpha=1): return (weights * values).sum() -def integrate(dataset, data_layer, target_band, funct, kwargs={}): +def integrate( + dataset: Any, + data_layer: ogr.Layer, + target_band: Any, + funct: InterpolationFunction, + kwargs: Optional[KwargsMapping] = None, +) -> None: + call_kwargs: Dict[str, Any] = dict(kwargs or {}) try: mask = dataset.GetRasterBand(1).ReadAsArray() != 0 x_left, x_step, d1, y_top, d2, y_step = dataset.GetGeoTransform() @@ -60,12 +74,12 @@ def integrate(dataset, data_layer, target_band, funct, kwargs={}): xarray, yarray = np.meshgrid(xcoords, ycoords) # Create a ufunc that makes the interpolation given the above arrays - def interpolate_one_point(x, y, mask): - if not mask: + def interpolate_one_point(x: float, y: float, mask_value: bool) -> float: + if not mask_value: return np.nan point = ogr.Geometry(ogr.wkbPoint) point.AddPoint(x, y) - return funct(point, data_layer, **kwargs) + return funct(point, data_layer, **call_kwargs) interpolate = np.vectorize(interpolate_one_point, otypes=[np.float32]) @@ -80,7 +94,9 @@ def interpolate_one_point(x, y, mask): target_band.data(data=result) -def create_ogr_layer_from_timeseries(filenames, epsg, data_source): +def create_ogr_layer_from_timeseries( + filenames: Iterable[str], epsg: int, data_source: ogr.DataSource # type: ignore[type-alias] +) -> ogr.Layer: # Prepare the co-ordinate transformation from WGS84 to epsg source_sr = osr.SpatialReference() source_sr.ImportFromEPSG(4326) @@ -99,6 +115,7 @@ def create_ogr_layer_from_timeseries(filenames, epsg, data_source): # we only use the location. ts = HTimeseries(f, default_tzinfo=dt.timezone.utc) point = ogr.Geometry(ogr.wkbPoint) + assert ts.location is not None point.AddPoint(ts.location["abscissa"], ts.location["ordinate"]) point.Transform(transform) f = ogr.Feature(layer.GetLayerDefn()) @@ -108,7 +125,9 @@ def create_ogr_layer_from_timeseries(filenames, epsg, data_source): return layer -def _needs_calculation(output_filename, date, stations_layer): +def _needs_calculation( + output_filename: str, date: dt.datetime, stations_layer: ogr.Layer +) -> bool: """ Used by h_integrate to check whether the output file needs to be calculated or not. It does not need to be calculated if it already exists and has been @@ -151,7 +170,7 @@ def _needs_calculation(output_filename, date, stations_layer): with open(filename, newline="\n") as f: t = HTimeseries(f) try: - value = t.data.loc[date, "value"] + value = t.data.loc[date, "value"] # type: ignore[index] if not isnan(value): return True except KeyError: @@ -162,8 +181,14 @@ def _needs_calculation(output_filename, date, stations_layer): def h_integrate( - mask, stations_layer, date, output_filename_prefix, date_fmt, funct, kwargs -): + mask: gdal.Dataset, + stations_layer: ogr.Layer, + date: dt.datetime, + output_filename_prefix: str, + date_fmt: str, + funct: InterpolationFunction, + kwargs: Optional[KwargsMapping], +) -> None: date_fmt_for_filename = date.strftime(date_fmt).replace(" ", "-").replace(":", "-") output_filename = "{}-{}.tif".format( output_filename_prefix, date.strftime(date_fmt_for_filename) @@ -184,7 +209,7 @@ def h_integrate( if unit_of_measurement is None and hasattr(t, "unit"): unit_of_measurement = t.unit try: - value = t.data.loc[date, "value"] + value = t.data.loc[date, "value"] # type: ignore[index] except KeyError: value = np.nan station.SetField("value", value) @@ -217,32 +242,30 @@ def h_integrate( class PassepartoutPoint: """Uniform interface for GeoDjango Point and OGR Point.""" - def __init__(self, point): - self.point = point + def __init__(self, point: GeometryLike) -> None: + self.point: GeometryLike = point - def transform_to(self, target_srs_wkt): + def transform_to(self, target_srs_wkt: str) -> "PassepartoutPoint": point = self.clone(self.point) if isinstance(self.point, GeoDjangoPoint): - source_srs = point.srs or SpatialReference(4326) - try: - source_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) - except AttributeError: - pass + source_srs = getattr(point, "srs") or SpatialReference("4326") + if hasattr(source_srs, "SetAxisMappingStrategy"): + source_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) # type: ignore[attr-defined] ct = CoordTransform(source_srs, SpatialReference(target_srs_wkt)) - point.transform(ct) + point.transform(ct) # type: ignore[attr-defined] return PassepartoutPoint(point) else: - point_sr = point.GetSpatialReference() + point_sr = point.GetSpatialReference() # type: ignore[attr-defined] raster_sr = osr.SpatialReference() raster_sr.ImportFromWkt(target_srs_wkt) if int(gdal.__version__.split(".")[0]) > 2: point_sr.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) raster_sr.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER) transform = osr.CoordinateTransformation(point_sr, raster_sr) - point.Transform(transform) + point.Transform(transform) # type: ignore[attr-defined] return PassepartoutPoint(point) - def clone(self, original_point): + def clone(self, original_point: GeometryLike) -> GeometryLike: if isinstance(original_point, GeoDjangoPoint): return GeoDjangoPoint( original_point.x, original_point.y, original_point.srid @@ -254,21 +277,23 @@ def clone(self, original_point): return point @property - def x(self): - try: + def x(self) -> float: + if isinstance(self.point, GeoDjangoPoint): return self.point.x - except AttributeError: + else: return self.point.GetX() @property - def y(self): - try: + def y(self) -> float: + if isinstance(self.point, GeoDjangoPoint): return self.point.y - except AttributeError: + else: return self.point.GetY() -def extract_point_from_raster(point, data_source, band_number=1): +def extract_point_from_raster( + point: GeometryLike, data_source: Any, band_number: int = 1 +) -> float: """Return floating-point value that corresponds to given point.""" pppoint = PassepartoutPoint(point) @@ -291,7 +316,7 @@ def extract_point_from_raster(point, data_source, band_number=1): except AttributeError: forward_transform = Affine.from_gdal(*data_source.geotransform) reverse_transform = ~forward_transform - px, py = reverse_transform * (pppoint.x, pppoint.y) + px, py = reverse_transform * (pppoint.x, pppoint.y) # type: ignore[operator] px, py = int(px), int(py) # Extract pixel value @@ -314,35 +339,45 @@ def extract_point_from_raster(point, data_source, band_number=1): class PointTimeseries: - def __init__(self, point, **kwargs): - self.point = point - filenames = kwargs.pop("filenames", None) - self.prefix = kwargs.pop("prefix", None) + def __init__( + self, + point: GeometryLike, + *, + filenames: Optional[Iterable[str]] = None, + prefix: Optional[str] = None, + date_fmt: Optional[str] = None, + start_date: Optional[dt.datetime] = None, + end_date: Optional[dt.datetime] = None, + default_time: dt.time = dt.time(0, 0, tzinfo=dt.timezone.utc), + ) -> None: + self.point: GeometryLike = point + self.prefix = prefix + self.filename_format: Optional[FilenameWithDateFormat] = None assert filenames is None or self.prefix is None assert filenames is not None or self.prefix is not None - self.date_fmt = kwargs.pop("date_fmt", None) - self.start_date = kwargs.pop("start_date", None) - self.end_date = kwargs.pop("end_date", None) - self.default_time = kwargs.pop( - "default_time", dt.time(0, 0, tzinfo=dt.timezone.utc) - ) + self.date_fmt = date_fmt + self.start_date = start_date + self.end_date = end_date + self.default_time = default_time if self.default_time.tzinfo is None: raise TypeError("default_time must be aware") if self.start_date and self.start_date.tzinfo is None: self.start_date = self.start_date.replace(tzinfo=self.default_time.tzinfo) if self.end_date and self.end_date.tzinfo is None: self.end_date = self.end_date.replace(tzinfo=self.default_time.tzinfo) - self.filenames = self._get_filenames(filenames) + self.filenames: List[str] = self._get_filenames(filenames) - def _get_filenames(self, filenames): + def _get_filenames(self, filenames: Optional[Iterable[str]]) -> List[str]: if self.prefix is None: - return filenames - filenames = glob(self.prefix + "-*.tif") + assert filenames is not None + return list(filenames) + filenames_list = glob(self.prefix + "-*.tif") + assert self.default_time.tzinfo is not None self.filename_format = FilenameWithDateFormat( self.prefix, date_fmt=self.date_fmt, tzinfo=self.default_time.tzinfo ) - result = [] - for filename in filenames: + result: List[str] = [] + for filename in filenames_list: date = self.filename_format.get_date(filename) is_after_start_date = (self.start_date is None) or (date >= self.start_date) is_before_end_date = (self.end_date is None) or (date <= self.end_date) @@ -350,7 +385,7 @@ def _get_filenames(self, filenames): result.append(filename) return result - def get(self): + def get(self) -> HTimeseries: result = HTimeseries(default_tzinfo=self.default_time.tzinfo) for filename in self.filenames: f = gdal.Open(filename) @@ -365,8 +400,10 @@ def get(self): result.data = result.data.sort_index() return result - def _get_timestamp(self, f): + def _get_timestamp(self, f: Any) -> dt.datetime: isostring = f.GetMetadata()["TIMESTAMP"] + assert self.default_time.tzinfo is not None + assert isinstance(self.default_time.tzinfo, dt.timezone) timestamp = iso8601.parse_date( isostring, default_timezone=self.default_time.tzinfo ) @@ -374,14 +411,16 @@ def _get_timestamp(self, f): timestamp = dt.datetime.combine(timestamp.date(), self.default_time) return timestamp - def _get_unit_of_measurement(self, f, ahtimeseries): + def _get_unit_of_measurement(self, f: Any, ahtimeseries: HTimeseries) -> None: if hasattr(ahtimeseries, "unit"): return unit = f.GetMetadataItem("UNIT") if unit is not None: ahtimeseries.unit = unit - def get_cached(self, dest, force=False, version=4): + def get_cached( + self, dest: str, force: bool = False, version: int = 4 + ) -> HTimeseries: assert self.prefix ts = self._get_saved_timeseries_if_updated_else_none(dest, force) if ts is None: @@ -390,33 +429,45 @@ def get_cached(self, dest, force=False, version=4): ts.write(f, format=HTimeseries.FILE, version=version) return ts - def _get_saved_timeseries_if_updated_else_none(self, dest, force): + def _get_saved_timeseries_if_updated_else_none( + self, dest: str, force: bool + ) -> Optional[HTimeseries]: if force or not os.path.exists(dest): return None else: return self._get_timeseries_if_file_is_up_to_date_else_none(dest) - def _get_timeseries_if_file_is_up_to_date_else_none(self, dest): + def _get_timeseries_if_file_is_up_to_date_else_none( + self, dest: str + ) -> Optional[HTimeseries]: with open(dest, "r", newline="") as f: ts = HTimeseries(f, default_tzinfo=self.default_time.tzinfo) for filename in self.filenames: + assert self.filename_format is not None if not self.filename_format.get_date(filename) in ts.data.index: return None return ts class FilenameWithDateFormat: - def __init__(self, prefix, *, date_fmt=None, tzinfo): + def __init__( + self, + prefix: str, + *, + date_fmt: Optional[str] = None, + tzinfo: dt.tzinfo, + ) -> None: self.prefix = prefix self.date_fmt = date_fmt self.tzinfo = tzinfo - def get_date(self, filename): + def get_date(self, filename: str) -> dt.datetime: datestr = self._extract_datestr(filename) self._ensure_we_have_date_fmt(datestr) + assert self.date_fmt is not None return dt.datetime.strptime(datestr, self.date_fmt).replace(tzinfo=self.tzinfo) - def _ensure_we_have_date_fmt(self, datestr): + def _ensure_we_have_date_fmt(self, datestr: str) -> None: if self.date_fmt is not None: pass elif datestr.count("-") == 4: @@ -426,7 +477,7 @@ def _ensure_we_have_date_fmt(self, datestr): else: raise ValueError("Invalid date " + datestr) - def _extract_datestr(self, filename): + def _extract_datestr(self, filename: str) -> str: assert filename.startswith(self.prefix + "-") assert filename.endswith(".tif") startpos = len(self.prefix) + 1 diff --git a/src/hspatial/test.py b/src/hspatial/test.py index 5c882f8..11944bd 100644 --- a/src/hspatial/test.py +++ b/src/hspatial/test.py @@ -1,8 +1,14 @@ +import datetime as dt import numpy as np from osgeo import gdal, osr - -def setup_test_raster(filename, value, timestamp=None, srid=4326, unit=None): +def setup_test_raster( + filename: str, + value: np.ndarray[np.float64, np.dtype[np.float64]], + timestamp: dt.datetime | dt.date | None = None, + srid: int = 4326, + unit: str | None = None, +): """Save value, which is an np array, to a GeoTIFF file.""" nodata = 1e8 value[np.isnan(value)] = nodata diff --git a/src/htimeseries/htimeseries.py b/src/htimeseries/htimeseries.py index 225305a..7ecba0d 100644 --- a/src/htimeseries/htimeseries.py +++ b/src/htimeseries/htimeseries.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import csv import datetime as dt from configparser import ParsingError from io import StringIO +from typing import IO, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast import numpy as np import pandas as pd @@ -10,14 +13,17 @@ from .timezone_utils import TzinfoFromString +TextFileLike = IO[str] +FileLike = Union[IO[str], IO[bytes], "_BacktrackableFile"] + -class _BacktrackableFile(object): - def __init__(self, fp): - self.fp = fp - self.line_number = 0 - self.next_line = None +class _BacktrackableFile: + def __init__(self, fp: FileLike) -> None: + self.fp: FileLike = fp + self.line_number: int = 0 + self.next_line: Optional[Union[str, bytes]] = None - def readline(self): + def readline(self) -> Union[str, bytes]: if self.next_line is None: self.line_number += 1 result = self.fp.readline() @@ -26,17 +32,23 @@ def readline(self): self.next_line = None return result - def backtrack(self, line): + def backtrack(self, line: Union[str, bytes]) -> None: self.next_line = line - def read(self, size=None): + def read(self, size: Optional[int] = None) -> Union[str, bytes]: return self.fp.read() if size is None else self.fp.read(size) - def __getattr__(self, name): + def __iter__(self) -> "_BacktrackableFile": + return self + + def __next__(self) -> Union[str, bytes]: + return self.fp.__next__() + + def __getattr__(self, name: str) -> Any: return getattr(self.fp, name) -class _FilePart(object): +class _FilePart: """A wrapper that views only a subset of the wrapped csv filelike object. When it is created, three mandatory parameters are passed: a filelike object, @@ -44,42 +56,42 @@ class _FilePart(object): object, which views the part of the wrapped object between start_date and end_date. """ - def __init__(self, stream, start_date, end_date): - self.stream = stream + def __init__(self, stream: TextFileLike, start_date: str, end_date: str) -> None: + self.stream: TextFileLike = stream self.start_date = start_date self.end_date = end_date lo = stream.tell() - key = lambda x: x.split(",")[0] # NOQA - self.startpos = text_bisect_left(stream, start_date, lo=lo, key=key) - if self.stream.tell() < self.startpos: - self.stream.seek(self.startpos) + key: Callable[[str], str] = lambda x: x.split(",")[0] # NOQA + self.startpos: int = text_bisect_left(stream, start_date, lo=lo, key=key) + self.endpos: int = text_bisect_left(stream, end_date, lo=self.startpos, key=key) + self.stream.seek(self.startpos) - def readline(self, size=-1): + def readline(self, size: int = -1) -> str: max_available_size = self.endpos + 1 - self.stream.tell() size = min(size, max_available_size) return self.stream.readline(size) - def __iter__(self): + def __iter__(self) -> "_FilePart": return self - def __next__(self): + def __next__(self) -> str: result = self.stream.__next__() if result[:16] > self.end_date: raise StopIteration return result - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self.stream, name) class MetadataWriter: - def __init__(self, f, htimeseries, version): + def __init__(self, f: IO[str], htimeseries: "HTimeseries", version: int) -> None: self.version = version self.htimeseries = htimeseries self.f = f - def write_meta(self): + def write_meta(self) -> None: if self.version == 2: self.f.write("Version=2\r\n") self.write_simple("unit") @@ -94,30 +106,34 @@ def write_meta(self): self.write_location() self.write_altitude() - def write_simple(self, parm): + def write_simple(self, parm: str) -> None: value = getattr(self.htimeseries, parm, None) if value is not None: self.f.write("{}={}\r\n".format(parm.capitalize(), value)) - def write_count(self): + def write_count(self) -> None: self.f.write("Count={}\r\n".format(len(self.htimeseries.data))) - def write_comment(self): + def write_comment(self) -> None: if hasattr(self.htimeseries, "comment"): for line in self.htimeseries.comment.splitlines(): self.f.write("Comment={}\r\n".format(line)) - def write_timezone(self): - offset = self.htimeseries.data.index.tz.utcoffset(None) + def write_timezone(self) -> None: + index = cast(pd.DatetimeIndex, self.htimeseries.data.index) + assert index.tz is not None + offset = index.tz.utcoffset(None) + assert offset is not None sign = "-+"[offset >= dt.timedelta(0)] offset = abs(offset) hours = offset.seconds // 3600 minutes = offset.seconds % 3600 // 60 self.f.write(f"Timezone={sign}{hours:02}{minutes:02}\r\n") - def write_location(self): + def write_location(self) -> None: if self.version <= 2 or not getattr(self.htimeseries, "location", None): return + assert self.htimeseries.location is not None self.f.write( "Location={:.6f} {:.6f} {}\r\n".format( *[ @@ -127,13 +143,12 @@ def write_location(self): ) ) - def write_altitude(self): - no_altitude = ( - (self.version <= 2) - or not getattr(self.htimeseries, "location", None) - or (self.htimeseries.location.get("altitude") is None) - ) - if no_altitude: + def write_altitude(self) -> None: + if self.version <= 2 or not getattr(self.htimeseries, "location", None): + return + assert self.htimeseries.location is not None + + if self.htimeseries.location.get("altitude") is None: return altitude = self.htimeseries.location["altitude"] asrid = ( @@ -148,28 +163,28 @@ def write_altitude(self): ) self.f.write(fmt.format(altitude=altitude, asrid=asrid)) - def write_time_step(self): + def write_time_step(self) -> None: if getattr(self.htimeseries, "time_step", ""): self._write_nonempty_time_step() - def _write_nonempty_time_step(self): - if self.version is None or self.version >= 5: - self.f.write("Time_step={}\r\n".format(self.htimeseries.time_step)) + def _write_nonempty_time_step(self) -> None: + if self.version >= 5: + self.f.write(f"Time_step={self.htimeseries.time_step}\r\n") else: self._write_old_time_step() - def _write_old_time_step(self): + def _write_old_time_step(self) -> None: try: old_time_step = self._get_old_time_step_in_minutes() except ValueError: old_time_step = self._get_old_time_step_in_months() self.f.write("Time_step={}\r\n".format(old_time_step)) - def _get_old_time_step_in_minutes(self): - td = pd.to_timedelta(to_offset(self.htimeseries.time_step)) + def _get_old_time_step_in_minutes(self) -> str: + td = cast(pd.Timedelta, pd.to_timedelta(to_offset(self.htimeseries.time_step))) # type: ignore return str(int(td.total_seconds() / 60)) + ",0" - def _get_old_time_step_in_months(self): + def _get_old_time_step_in_months(self) -> str: time_step = self.htimeseries.time_step try: value, unit = self._split_time_step_string(time_step) @@ -182,16 +197,17 @@ def _get_old_time_step_in_months(self): pass raise ValueError('Cannot format time step "{}"'.format(time_step)) - def _split_time_step_string(self, time_step_string): + def _split_time_step_string(self, time_step_string: str) -> Tuple[str, str]: value = "" for i, char in enumerate(time_step_string): if not char.isdigit(): return value, time_step_string[i:] value += char + assert False, "Unreachable" class MetadataReader: - def __init__(self, f): + def __init__(self, f: FileLike) -> None: f = _BacktrackableFile(f) # Check if file contains headers @@ -199,14 +215,15 @@ def __init__(self, f): f.backtrack(first_line) if isinstance(first_line, bytes): first_line = first_line.decode("utf-8-sig") + assert isinstance(first_line, str) has_headers = not first_line[0].isdigit() # Read file, with its headers if needed - self.meta = {} + self.meta: Dict[str, Any] = {} if has_headers: self.read_meta(f) - def read_meta(self, f): + def read_meta(self, f: FileLike) -> None: """Read the headers of a file in file format and place them in the self.meta dictionary. """ @@ -227,23 +244,23 @@ def read_meta(self, f): e.args = e.args + (f.line_number,) raise - def get_unit(self, name, value): + def get_unit(self, name: str, value: str) -> None: self.meta[name] = value get_title = get_unit get_variable = get_unit - def get_time_step(self, name, value): + def get_time_step(self, name: str, value: str) -> None: if value and "," in value: minutes, months = self.read_minutes_months(value) self.meta[name] = self._time_step_from_minutes_months(minutes, months) else: self.meta[name] = value - def get_timezone(self, name, value): + def get_timezone(self, name: str, value: str) -> None: self.meta["_timezone"] = value - def _time_step_from_minutes_months(self, minutes, months): + def _time_step_from_minutes_months(self, minutes: int, months: int) -> str: if minutes != 0 and months != 0: raise ParsingError("Invalid time step") elif minutes != 0: @@ -251,26 +268,26 @@ def _time_step_from_minutes_months(self, minutes, months): else: return str(months) + "M" - def get_interval_type(self, name, value): + def get_interval_type(self, name: str, value: str) -> None: value = value.lower() if value not in ("sum", "average", "maximum", "minimum", "vector_average"): raise ParsingError(("Invalid interval type")) self.meta[name] = value - def get_precision(self, name, value): + def get_precision(self, name: str, value: str) -> None: try: self.meta[name] = int(value) except ValueError as e: - raise ParsingError(e.args) + raise ParsingError(str(e)) - def get_comment(self, name, value): + def get_comment(self, name: str, value: str) -> None: if "comment" in self.meta: self.meta["comment"] += "\n" else: self.meta["comment"] = "" self.meta["comment"] += value - def get_location(self, name, value): + def get_location(self, name: str, value: str) -> None: self._ensure_location_attribute_exists() try: items = value.split() @@ -280,11 +297,11 @@ def get_location(self, name, value): except (IndexError, ValueError): raise ParsingError("Invalid location") - def _ensure_location_attribute_exists(self): + def _ensure_location_attribute_exists(self) -> None: if "location" not in self.meta: self.meta["location"] = {} - def get_altitude(self, name, value): + def get_altitude(self, name: str, value: str) -> None: self._ensure_location_attribute_exists() try: items = value.split() @@ -293,7 +310,7 @@ def get_altitude(self, name, value): except (IndexError, ValueError): raise ParsingError("Invalid altitude") - def read_minutes_months(self, s): + def read_minutes_months(self, s: str) -> Tuple[int, int]: """Return a (minutes, months) tuple after parsing a "M,N" string.""" try: (minutes, months) = [int(x.strip()) for x in s.split(",")] @@ -301,7 +318,7 @@ def read_minutes_months(self, s): except Exception: raise ParsingError(('Value should be "minutes, months"')) - def read_meta_line(self, f): + def read_meta_line(self, f: FileLike) -> Tuple[str, str]: """Read one line from a file format header and return a (name, value) tuple, where name is lowercased. Returns ('', '') if the next line is blank. Raises ParsingError if next line in f is not a valid header @@ -309,6 +326,7 @@ def read_meta_line(self, f): line = f.readline() if isinstance(line, bytes): line = line.decode("utf-8-sig") + assert isinstance(line, str) name, value = "", "" if line.isspace(): return (name, value) @@ -325,14 +343,26 @@ def read_meta_line(self, f): class HTimeseries: TEXT = "TEXT" FILE = "FILE" - args = { + args: Dict[str, Optional[Any]] = { "format": None, "start_date": None, "end_date": None, "default_tzinfo": None, } - - def __init__(self, data=None, **kwargs): + comment: str + location: Dict[str, Any] | None + time_step: str + _timezone: str | None + precision: int | None + unit: str + title: str + variable: str + + def __init__( + self, + data: Optional[Union[pd.DataFrame, FileLike]] = None, + **kwargs: Any, + ) -> None: extra_parms = set(kwargs) - set(self.args) if extra_parms: raise TypeError( @@ -351,12 +381,18 @@ def __init__(self, data=None, **kwargs): else: self._read_filelike(data, **kwargs) - def _check_dataframe(self, data): - if data.index.tz is None: + def _check_dataframe(self, data: pd.DataFrame) -> None: + if cast(pd.DatetimeIndex, data.index).tz is None: raise TypeError("data.index.tz must exist") - def _read_filelike(self, *args, **kwargs): - reader = TimeseriesStreamReader(*args, **kwargs) + def _read_filelike(self, f: FileLike, **kwargs: Any) -> None: + reader = TimeseriesStreamReader( + f, + format=kwargs["format"], + start_date=kwargs["start_date"], + end_date=kwargs["end_date"], + default_tzinfo=kwargs["default_tzinfo"], + ) self.__dict__.update(reader.get_metadata()) try: tzinfo = TzinfoFromString(self._timezone) @@ -369,46 +405,56 @@ def _read_filelike(self, *args, **kwargs): "specified" ) - def write(self, f, format=TEXT, version=5): + def write(self, f: IO[str], format: str = TEXT, version: int = 5) -> None: writer = TimeseriesStreamWriter(self, f, format=format, version=version) writer.write() class TimeseriesStreamReader: - def __init__(self, f, **kwargs): - self.f = f - self.specified_format = kwargs["format"] - self.start_date = kwargs["start_date"] - self.end_date = kwargs["end_date"] - self.default_tzinfo = kwargs["default_tzinfo"] + def __init__( + self, + f: FileLike, + *, + format: Optional[str], + start_date: Optional[Union[str, dt.datetime]], + end_date: Optional[Union[str, dt.datetime]], + default_tzinfo: Optional[dt.tzinfo], + ) -> None: + self.f: FileLike = f + self.specified_format = format + self.start_date = start_date + self.end_date = end_date + self.default_tzinfo = default_tzinfo - def get_metadata(self): + def get_metadata(self) -> Dict[str, Any]: if self.format == HTimeseries.FILE: return MetadataReader(self.f).meta else: return {} @property - def format(self): + def format(self) -> str: if self.specified_format is None: return self.autodetected_format else: return self.specified_format @property - def autodetected_format(self): + def autodetected_format(self) -> str: if not hasattr(self, "_stored_autodetected_format"): self._stored_autodetected_format = FormatAutoDetector(self.f).detect() return self._stored_autodetected_format - def get_data(self, tzinfo): + def get_data(self, tzinfo: Optional[dt.tzinfo]) -> pd.DataFrame: return TimeseriesRecordsReader( self.f, self.start_date, self.end_date, tzinfo=tzinfo ).read() -def _check_timeseries_index_has_no_duplicates(data, error_message_prefix): - duplicate_dates = data.index[data.index.duplicated()].tolist() +def _check_timeseries_index_has_no_duplicates( + data: pd.DataFrame, error_message_prefix: str +) -> None: + duplicate_dates = data.index[data.index.duplicated()].tolist() # type: ignore if duplicate_dates: dates_str = ", ".join([str(x) for x in duplicate_dates]) raise ValueError( @@ -418,20 +464,26 @@ def _check_timeseries_index_has_no_duplicates(data, error_message_prefix): class TimeseriesRecordsReader: - def __init__(self, f, start_date, end_date, tzinfo): + def __init__( + self, + f: FileLike, + start_date: Optional[Union[str, dt.datetime]], + end_date: Optional[Union[str, dt.datetime]], + tzinfo: Optional[dt.tzinfo], + ) -> None: self.f = f self.start_date = start_date self.end_date = end_date self.tzinfo = tzinfo - def read(self): + def read(self) -> pd.DataFrame: start_date, end_date = self._get_bounding_dates_as_strings() - f2 = _FilePart(self.f, start_date, end_date) + f2 = _FilePart(cast(TextFileLike, self.f), start_date, end_date) data = self._read_data_from_stream(f2) self._check_there_are_no_duplicates(data) return data - def _get_bounding_dates_as_strings(self): + def _get_bounding_dates_as_strings(self) -> Tuple[str, str]: start_date = "0001-01-01 00:00" if self.start_date is None else self.start_date end_date = "9999-12-31 00:00" if self.end_date is None else self.end_date if isinstance(start_date, dt.datetime): @@ -440,7 +492,7 @@ def _get_bounding_dates_as_strings(self): end_date = end_date.strftime("%Y-%m-%d %H:%M") return start_date, end_date - def _read_data_from_stream(self, f): + def _read_data_from_stream(self, f: _FilePart) -> pd.DataFrame: dates, values, flags = self._read_csv(f) dates = self._localize_dates(dates) result = pd.DataFrame( @@ -453,22 +505,26 @@ def _read_data_from_stream(self, f): result.index.name = "date" return result - def _localize_dates(self, dates): + def _localize_dates(self, dates: Sequence[str]) -> pd.DatetimeIndex: try: - result = pd.to_datetime(dates) + result: pd.DatetimeIndex = pd.to_datetime(dates) # type: ignore except ValueError: raise ValueError( "Could not parse timestamps correctly. Maybe the CSV contains mixed " "aware and naive timestamps." ) - if len(result) == 0 or (len(result) > 0 and result[0].tzinfo is None): - result = pd.to_datetime(dates).tz_localize( + assert isinstance(result, pd.DatetimeIndex) + if len(result) == 0 or (len(result) > 0 and result.tz is None): + result = pd.to_datetime(dates).tz_localize( # type: ignore self.tzinfo, ambiguous=len(dates) * [True] ) + assert isinstance(result, pd.DatetimeIndex) return result - def _read_csv(self, f): - dates, values, flags = [], [], [] + def _read_csv(self, f: _FilePart) -> Tuple[List[str], List[str], List[str]]: + dates: List[str] = [] + values: List[str] = [] + flags: List[str] = [] for row in csv.reader(f): # We don't use pd.read_csv() because it's much slower if not len(row): continue @@ -480,90 +536,103 @@ def _read_csv(self, f): flags.append(row[2] if len(row) > 2 else "") return dates, values, flags - def _check_there_are_no_duplicates(self, data): + def _check_there_are_no_duplicates(self, data: pd.DataFrame) -> None: _check_timeseries_index_has_no_duplicates( data, error_message_prefix="Can't read time series" ) class FormatAutoDetector: - def __init__(self, f): + def __init__(self, f: FileLike) -> None: self.f = f - def detect(self): + def detect(self) -> str: original_position = self.f.tell() result = self._guess_format_from_first_nonempty_line() self.f.seek(original_position) return result - def _guess_format_from_first_nonempty_line(self): + def _guess_format_from_first_nonempty_line(self) -> str: line = self._get_first_nonempty_line() if line and not line[0].isdigit(): return HTimeseries.FILE else: return HTimeseries.TEXT - def _get_first_nonempty_line(self): + def _get_first_nonempty_line(self) -> str: for line in self.f: if line.strip(): - return line + if isinstance(line, bytes): + return line.decode("utf-8-sig") + elif isinstance(line, str): + return line + assert False, "Unreachable" return "" class TimeseriesStreamWriter: - def __init__(self, htimeseries, f, *, format, version): + def __init__( + self, + htimeseries: HTimeseries, + f: IO[str], + *, + format: str, + version: Optional[int], + ) -> None: self.htimeseries = htimeseries self.f = f self.format = format self.version = version - def write(self): + def write(self) -> None: self._write_metadata() self._write_records() - def _write_metadata(self): + def _write_metadata(self) -> None: if self.format == HTimeseries.FILE: - MetadataWriter(self.f, self.htimeseries, version=self.version).write_meta() + version = self.version or 5 + MetadataWriter(self.f, self.htimeseries, version=version).write_meta() self.f.write("\r\n") - def _write_records(self): + def _write_records(self) -> None: TimeseriesRecordsWriter(self.htimeseries, self.f).write() class TimeseriesRecordsWriter: - def __init__(self, htimeseries, f): + def __init__(self, htimeseries: HTimeseries, f: IO[str]) -> None: self.htimeseries = htimeseries self.f = f + self.float_format: str = "%f" - def write(self): + def write(self) -> None: if self.htimeseries.data.empty: return self._check_there_are_no_duplicates() self._setup_precision() self._write_records() - def _check_there_are_no_duplicates(self): + def _check_there_are_no_duplicates(self) -> None: _check_timeseries_index_has_no_duplicates( self.htimeseries.data, error_message_prefix="Can't write time series" ) - def _setup_precision(self): + def _setup_precision(self) -> None: precision = getattr(self.htimeseries, "precision", None) if precision is None: self.float_format = "%f" - elif self.htimeseries.precision >= 0: + elif precision >= 0: self.float_format = "%.{}f".format(self.htimeseries.precision) else: self.float_format = "%.0f" self._prepare_records_for_negative_precision(precision) - def _prepare_records_for_negative_precision(self, precision): + def _prepare_records_for_negative_precision(self, precision: int) -> None: assert precision < 0 datacol = self.htimeseries.data.columns[0] - m = 10 ** (-self.htimeseries.precision) + m = 10 ** (-precision) self.htimeseries.data[datacol] = np.rint(self.htimeseries.data[datacol] / m) * m - def _write_records(self): + def _write_records(self) -> None: self.htimeseries.data.to_csv( self.f, float_format=self.float_format, diff --git a/src/htimeseries/timezone_utils.py b/src/htimeseries/timezone_utils.py index 8aad36b..4fda8e1 100644 --- a/src/htimeseries/timezone_utils.py +++ b/src/htimeseries/timezone_utils.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import datetime as dt +from typing import Optional class TzinfoFromString(dt.tzinfo): - def __init__(self, string): - self.offset = None + def __init__(self, string: Optional[str]) -> None: + self.offset: Optional[dt.timedelta] = None self.name = "" if not string: return @@ -34,11 +37,11 @@ def __init__(self, string): self.offset = sign * dt.timedelta(hours=hours, minutes=minutes) - def utcoffset(self, adatetime): + def utcoffset(self, adatetime: Optional[dt.datetime]) -> Optional[dt.timedelta]: return self.offset - def dst(self, adatetime): + def dst(self, adatetime: Optional[dt.datetime]) -> dt.timedelta: return dt.timedelta(0) - def tzname(self, adatetime): + def tzname(self, adatetime: Optional[dt.datetime]) -> str: return self.name diff --git a/src/rocc/__init__.py b/src/rocc/__init__.py index fb6cf09..35862c6 100644 --- a/src/rocc/__init__.py +++ b/src/rocc/__init__.py @@ -1,18 +1,27 @@ -from collections import namedtuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Iterable, NamedTuple from .calculation import Rocc -Threshold = namedtuple("Threshold", ["delta_t", "allowed_diff"]) + +if TYPE_CHECKING: # pragma: no cover - used for type checkers only + from htimeseries import HTimeseries + + +class Threshold(NamedTuple): + delta_t: str + allowed_diff: float def rocc( *, - timeseries, - thresholds, - symmetric=False, - flag="TEMPORAL", - progress_callback=lambda x: None, -): + timeseries: HTimeseries, + thresholds: Iterable[Threshold], + symmetric: bool = False, + flag: str | None = "TEMPORAL", + progress_callback: Callable[[float], None] = lambda x: None, +) -> list[str]: return Rocc( timeseries, thresholds, diff --git a/src/textbisect/__init__.py b/src/textbisect/__init__.py index 12f1b30..446c2db 100644 --- a/src/textbisect/__init__.py +++ b/src/textbisect/__init__.py @@ -1,22 +1,24 @@ from io import SEEK_END +from typing import Callable, IO, Union + class TextBisector: - def __init__(self, a, x, on_same, key): + def __init__(self, a: IO[str], x: str, on_same: str, key: Callable[[str], str]): self.a = a self.x = x self.on_same = on_same self.key = key self.ref = key(x) - def get_hi(self, hi): + def get_hi(self, hi: Union[int, None]) -> int: """Return hi, or the end position of the file if hi is None.""" if hi is None: self.a.seek(0, SEEK_END) hi = self.a.tell() - 1 return hi - def get_beginning_of_line(self, pos, lo, hi): + def get_beginning_of_line(self, pos: int, lo: int, hi: int) -> int: """Return the beginning of the line containing the position pos. On return the file is positioned at the return value.""" while True: @@ -28,7 +30,7 @@ def get_beginning_of_line(self, pos, lo, hi): return pos pos -= 1 - def get_end_of_line(self, pos, lo, hi): + def get_end_of_line(self, pos: int, lo: int, hi: int) -> int: """Return the end of the line (the line feed) containing the position pos. On return the file is positioned at the return value.""" self.a.seek(pos) @@ -40,7 +42,7 @@ def get_end_of_line(self, pos, lo, hi): return self.a.seek(pos) pos += 1 - def get_line(self, pos, lo, hi): + def get_line(self, pos: int, lo: int, hi: int) -> tuple[str, int, int]: """Return a tuple (line, start, end), where line is the line containing the position pos (without the ending line feed), and start and end are the positions of the start of the line and of the line feed. On return @@ -51,7 +53,7 @@ def get_line(self, pos, lo, hi): line = self.a.read(end - start) return (line, start, end) - def bisect(self, lo, hi): + def bisect(self, lo: int, hi: int) -> int: # This recursive function ends when hi == lo - 1 if hi == lo - 1: self.a.seek(lo) @@ -69,13 +71,25 @@ def bisect(self, lo, hi): return self.bisect(end + 1, hi) -def text_bisect_left(a, x, lo=0, hi=None, key=lambda x: x): +def text_bisect_left( + a: IO[str], + x: str, + lo: int = 0, + hi: Union[int, None] = None, + key: Callable[[str], str] = lambda x: x, +) -> int: bisector = TextBisector(a, x, "left", key) hi = bisector.get_hi(hi) return bisector.bisect(lo, hi) -def text_bisect_right(a, x, lo=0, hi=None, key=lambda x: x): +def text_bisect_right( + a: IO[str], + x: str, + lo: int = 0, + hi: Union[int, None] = None, + key: Callable[[str], str] = lambda x: x, +) -> int: bisector = TextBisector(a, x, "right", key) hi = bisector.get_hi(hi) return bisector.bisect(lo, hi) diff --git a/tests/enhydris_api_client/__init__.py b/tests/enhydris_api_client/__init__.py index 7808bae..043197a 100644 --- a/tests/enhydris_api_client/__init__.py +++ b/tests/enhydris_api_client/__init__.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import datetime as dt import textwrap from copy import copy from io import StringIO +from typing import Any, Callable, Dict, Optional, cast from unittest import mock import pandas as pd @@ -9,7 +12,7 @@ from htimeseries import HTimeseries -test_timeseries_csv = textwrap.dedent( +test_timeseries_csv: str = textwrap.dedent( """\ 2014-01-01 08:00,11.0, 2014-01-02 08:00,12.0, @@ -18,18 +21,20 @@ 2014-01-05 08:00,15.0, """ ) -test_timeseries_htimeseries = HTimeseries( +test_timeseries_htimeseries: HTimeseries = HTimeseries( StringIO(test_timeseries_csv), default_tzinfo=dt.timezone(dt.timedelta(hours=2)) ) -test_timeseries_csv_top = "".join(test_timeseries_csv.splitlines(keepends=True)[:-1]) -test_timeseries_csv_bottom = test_timeseries_csv.splitlines(keepends=True)[-1] +test_timeseries_csv_top: str = "".join( + test_timeseries_csv.splitlines(keepends=True)[:-1] +) +test_timeseries_csv_bottom: str = test_timeseries_csv.splitlines(keepends=True)[-1] -def mock_session(**kwargs): +def mock_session(**kwargs: Any) -> mock._patch: """Mock requests.Session. Returns - @mock.patch("requests.Session", modified_kwargs) + mock.patch("requests.Session", modified_kwargs) However, it first tampers with kwargs in order to achieve the following: - It adds a leading "return_value." to the kwargs; so you don't need to specify, @@ -39,27 +44,46 @@ def mock_session(**kwargs): - If "get.return_value.status_code" is not between 200 and 399, then raise_for_status() will raise HTTPError. Likewise for the other methods. """ + patch_kwargs: Dict[str, Any] = dict(kwargs) for method in ("get", "post", "put", "patch", "delete"): default_value = 204 if method == "delete" else 200 - c = kwargs.setdefault(method + ".return_value.status_code", default_value) - if c < 200 or c >= 400: - method_side_effect = method + ".return_value.raise_for_status.side_effect" - kwargs[method_side_effect] = requests.HTTPError - for old_key in list(kwargs.keys()): - kwargs["return_value." + old_key] = kwargs.pop(old_key) - return mock.patch("requests.Session", **kwargs) + status_code_key = f"{method}.return_value.status_code" + status_code = patch_kwargs.setdefault(status_code_key, default_value) + if isinstance(status_code, bool) or not isinstance(status_code, int): + raise TypeError( + "status_code overrides must be integers " + f"(got {status_code!r} for {method})" + ) + if status_code < 200 or status_code >= 400: + method_side_effect = ( + f"{method}.return_value.raise_for_status.side_effect" + ) + patch_kwargs[method_side_effect] = requests.HTTPError + for old_key in list(patch_kwargs.keys()): + patch_kwargs[f"return_value.{old_key}"] = patch_kwargs.pop(old_key) + return mock.patch("requests.Session", **patch_kwargs) class AssertFrameEqualMixin: - def assert_frame_equal(self, actual, expected): + assertEqual: Callable[..., None] + + def assert_frame_equal(self, actual: pd.DataFrame, expected: pd.DataFrame) -> None: + actual_index = cast(pd.DatetimeIndex, actual.index) + expected_index = cast(pd.DatetimeIndex, expected.index) + assert actual_index.tz is not None + assert expected_index.tz is not None self.assertEqual( - actual.index.tz.utcoffset(None), expected.index.tz.utcoffset(None) + actual_index.tz.utcoffset(None), expected_index.tz.utcoffset(None) ) pd.testing.assert_frame_equal(actual, expected, check_index_type=False) - def assert_frame_loosely_equal(self, actual, expected): + def assert_frame_loosely_equal( + self, actual: pd.DataFrame, expected: pd.DataFrame + ) -> None: actual = copy(actual) expected = copy(expected) - actual.index = actual.index.tz_convert("UTC") - expected.index = expected.index.tz_convert("UTC") + actual_index = cast(pd.DatetimeIndex, actual.index) + expected_index = cast(pd.DatetimeIndex, expected.index) + actual.index = actual_index.tz_convert("UTC") + expected.index = expected_index.tz_convert("UTC") pd.testing.assert_frame_equal(actual, expected) diff --git a/tests/enhydris_api_client/test_e2e.py b/tests/enhydris_api_client/test_e2e.py index a54d559..0cf5ee3 100644 --- a/tests/enhydris_api_client/test_e2e.py +++ b/tests/enhydris_api_client/test_e2e.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import datetime as dt import json import os import textwrap from io import StringIO +from typing import Any, Dict from unittest import TestCase, skipUnless from zoneinfo import ZoneInfo @@ -44,8 +47,10 @@ class EndToEndTestCase(AssertFrameEqualMixin, TestCase): as in enhydris_cache. """ - def setUp(self): - v = json.loads(os.getenv("PTHELMA_TEST_ENHYDRIS_SERVER")) + def setUp(self) -> None: + raw_config = os.getenv("PTHELMA_TEST_ENHYDRIS_SERVER") + assert raw_config is not None + v: Dict[str, Any] = json.loads(raw_config) self.token = v["token"] self.client = EnhydrisApiClient(v["base_url"], token=self.token) self.client.__enter__() @@ -53,10 +58,10 @@ def setUp(self): self.variable_id = v["variable_id"] self.unit_of_measurement_id = v["unit_of_measurement_id"] - def tearDown(self): + def tearDown(self) -> None: self.client.__exit__() - def test_e2e(self): + def test_e2e(self) -> None: # Verify we're authenticated token = self.client.session.headers.get("Authorization") self.assertEqual(token, f"token {self.token}") @@ -159,14 +164,6 @@ def test_e2e(self): hts = self.client.read_tsdata( tmp_station_id, self.timeseries_group_id, self.timeseries_id ) - try: - # Compatibility with older Python or pandas versions (such as Python 3.7 - # with pandas 0.23): comparison may fail if tzinfo, although practically the - # same thing, is a different object - if hts.data.index.tz.offset == dt.timedelta(0): - hts.data.index = hts.data.index.tz_convert(dt.timezone.utc) - except AttributeError: - pass self.assert_frame_loosely_equal(hts.data, test_timeseries_htimeseries.data) # The other attributes should have been set too. @@ -192,15 +189,7 @@ def test_e2e(self): ), default_tzinfo=ZoneInfo("Etc/GMT-1"), ) - try: - # Compatibility with older Python or pandas versions (such as Python 3.7 - # with pandas 0.23): comparison may fail if tzinfo, although practically the - # same thing, is a different object - if hts.data.index.tz.offset == dt.timedelta(minutes=60): - hts.data.index = hts.data.index.tz_convert(ZoneInfo("Etc/GMT-1")) - except AttributeError: - pass - pd.testing.assert_frame_equal(hts.data, expected_result.data) + self.assert_frame_loosely_equal(hts.data, expected_result.data) # Delete the time series and verify self.client.delete_timeseries( diff --git a/tests/enhydris_api_client/test_misc.py b/tests/enhydris_api_client/test_misc.py index 57876f6..31f7d2e 100644 --- a/tests/enhydris_api_client/test_misc.py +++ b/tests/enhydris_api_client/test_misc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase, mock import requests @@ -15,13 +17,10 @@ class GetTokenTestCase(TestCase): "post.return_value.cookies": {"acookie": "a cookie value"}, } ) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def test_makes_post_request(self, m: mock.MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.get_token("admin", "topsecret") - - def test_makes_post_request(self): - self.mock_requests_session.return_value.post.assert_called_once_with( + m.return_value.post.assert_called_once_with( "https://mydomain.com/api/auth/login/", data="username=admin&password=topsecret", allow_redirects=False, @@ -30,42 +29,50 @@ def test_makes_post_request(self): class GetTokenFailTestCase(TestCase): @mock_session(**{"post.return_value.status_code": 404}) - def test_raises_exception_on_post_failure(self, mock_requests_session): + def test_raises_exception_on_post_failure( + self, mock_requests_session: mock.MagicMock + ) -> None: self.client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): self.client.get_token("admin", "topsecret") class GetTokenEmptyUsernameTestCase(TestCase): - @mock_session() - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session() + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.client.get_token("", "useless_password") - def test_does_not_make_get_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_does_not_make_get_request(self) -> None: self.mock_requests_session.get.assert_not_called() - def test_does_not_make_post_request(self): + def test_does_not_make_post_request(self) -> None: self.mock_requests_session.post.assert_not_called() class UseAsContextManagerTestCase(TestCase): - @mock_session() - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session() + self.mock_requests_session = self.session_patcher.start() with EnhydrisApiClient("https://mydomain.com/") as api_client: api_client.get_station(42) - def test_called_enter(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_called_enter(self) -> None: self.mock_requests_session.return_value.__enter__.assert_called_once_with() - def test_called_exit(self): + def test_called_exit(self) -> None: self.assertEqual( len(self.mock_requests_session.return_value.__exit__.mock_calls), 1 ) - def test_makes_request(self): + def test_makes_request(self) -> None: self.mock_requests_session.return_value.get.assert_called_once_with( "https://mydomain.com/api/stations/42/" ) @@ -74,75 +81,76 @@ def test_makes_request(self): class Error400TestCase(TestCase): msg = "hello world" - @mock_session( - **{ - "get.return_value.status_code": 400, - "get.return_value.text": "hello world", - "post.return_value.status_code": 400, - "post.return_value.text": "hello world", - "put.return_value.status_code": 400, - "put.return_value.text": "hello world", - "patch.return_value.status_code": 400, - "patch.return_value.text": "hello world", - "delete.return_value.status_code": 400, - "delete.return_value.text": "hello world", - } - ) - def setUp(self, m): - self.client = EnhydrisApiClient("https://mydomain.com") + def setUp(self) -> None: # type: ignore[misc] + m = mock_session( + **{ + "get.return_value.status_code": 400, + "get.return_value.text": "hello world", + "post.return_value.status_code": 400, + "post.return_value.text": "hello world", + "put.return_value.status_code": 400, + "put.return_value.text": "hello world", + "patch.return_value.status_code": 400, + "patch.return_value.text": "hello world", + "delete.return_value.status_code": 400, + "delete.return_value.text": "hello world", + } + ) + with m: + self.client = EnhydrisApiClient("https://mydomain.com") - def test_get_token(self): + def test_get_token(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.get_token("john", "topsecret") - def test_get_station(self): + def test_get_station(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.get_station(42) - def test_post_station(self): + def test_post_station(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.post_station({}) - def test_put_station(self): + def test_put_station(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.put_station(42, {}) - def test_patch_station(self): + def test_patch_station(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.patch_station(42, {}) - def test_delete_station(self): + def test_delete_station(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.delete_station(42) - def test_get_timeseries(self): + def test_get_timeseries(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.get_timeseries(41, 42, 43) - def test_post_timeseries(self): + def test_post_timeseries(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.post_timeseries(42, 43, {}) - def test_delete_timeseries(self): + def test_delete_timeseries(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.delete_timeseries(41, 42, 43) - def test_read_tsdata(self): + def test_read_tsdata(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.read_tsdata(41, 42, 43) - def test_post_tsdata(self): + def test_post_tsdata(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.post_tsdata(41, 42, 43, HTimeseries()) - def test_get_ts_end_date(self): + def test_get_ts_end_date(self) -> None: with self.assertRaisesRegex(requests.HTTPError, self.msg): self.client.get_ts_end_date(41, 42, 43) class EnhydrisApiClientTestCase(TestCase): @mock.patch("requests.Session") - def test_client_with_token(self, mock_requests_session): + def test_client_with_token(self, mock_requests_session: mock.MagicMock) -> None: EnhydrisApiClient("https://mydomain.com/", token="test-token") mock_requests_session.return_value.headers.update.assert_any_call( {"Authorization": "token test-token"} diff --git a/tests/enhydris_api_client/test_station.py b/tests/enhydris_api_client/test_station.py index c8f417a..8438684 100644 --- a/tests/enhydris_api_client/test_station.py +++ b/tests/enhydris_api_client/test_station.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from unittest import TestCase -from unittest.mock import call +from unittest.mock import MagicMock, call import requests @@ -9,24 +11,30 @@ class ListStationsSinglePageTestCase(TestCase): - @mock_session() - def setUp(self, m): - m.return_value.get.return_value.json.return_value = { - "count": 2, - "next": None, - "previous": None, - "results": [{"name": "Hobbiton"}, {"name": "Rivendell"}], - } - self.mock_session = m + def setUp(self) -> None: # type: ignore[misc] + self.session_patcher = mock_session( + **{ + "get.return_value.json.return_value": { + "count": 2, + "next": None, + "previous": None, + "results": [{"name": "Hobbiton"}, {"name": "Rivendell"}], + } + } + ) + self.mock_session = self.session_patcher.start() client = EnhydrisApiClient("https://mydomain.com") self.result = client.list_stations() - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: next(self.result) # Ensure the request is actually made m = self.mock_session m.return_value.get.assert_called_once_with("https://mydomain.com/api/stations/") - def test_result(self): + def test_result(self) -> None: self.assertEqual( list(self.result), [{"name": "Hobbiton"}, {"name": "Rivendell"}], @@ -34,27 +42,33 @@ def test_result(self): class ListStationsMultiPageTestCase(TestCase): - @mock_session() - def setUp(self, m): - m.return_value.get.return_value.json.side_effect = [ - { - "count": 3, - "next": "https://mydomain.com/api/stations/?page=2", - "previous": None, - "results": [{"name": "Hobbiton"}, {"name": "Rivendell"}], - }, - { - "count": 3, - "next": None, - "previous": "https://mydomain.com/api/stations/", - "results": [{"name": "Mordor"}], - }, - ] - self.mock_session = m + def setUp(self) -> None: # type: ignore[misc] + self.session_patcher = mock_session( + **{ + "get.return_value.json.side_effect": [ + { + "count": 3, + "next": "https://mydomain.com/api/stations/?page=2", + "previous": None, + "results": [{"name": "Hobbiton"}, {"name": "Rivendell"}], + }, + { + "count": 3, + "next": None, + "previous": "https://mydomain.com/api/stations/", + "results": [{"name": "Mordor"}], + }, + ] + } + ) + self.mock_session = self.session_patcher.start() client = EnhydrisApiClient("https://mydomain.com") self.result = client.list_stations() - def test_requests(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_requests(self) -> None: list(self.result) # Ensure all requests are made self.assertEqual( self.mock_session.return_value.get.call_args_list, @@ -64,7 +78,7 @@ def test_requests(self): ], ) - def test_result(self): + def test_result(self) -> None: self.assertEqual( list(self.result), [{"name": "Hobbiton"}, {"name": "Rivendell"}, {"name": "Mordor"}], @@ -73,85 +87,89 @@ def test_result(self): class ListStationsErrorTestCase(TestCase): @mock_session(**{"get.return_value.status_code": 500}) - def test_raises_exception_on_error(self, m): + def test_raises_exception_on_error(self, m: MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): next(client.list_stations()) @mock_session(**{"get.return_value.json.return_value": "not a dict"}) - def test_raises_exception_on_non_json_response(self, m): + def test_raises_exception_on_non_json_response(self, m: MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(MalformedResponseError): next(client.list_stations()) @mock_session(**{"get.return_value.json.return_value": {"no": "expected"}}) - def test_raises_exception_on_unexpected_json(self, m): + def test_raises_exception_on_unexpected_json(self, m: MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(MalformedResponseError): next(client.list_stations()) class GetStationTestCase(TestCase): - @mock_session(**{"get.return_value.json.return_value": {"hello": "world"}}) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"get.return_value.json.return_value": {"hello": "world"}} + ) + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.get_station(42) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: self.mock_requests_session.return_value.get.assert_called_once_with( "https://mydomain.com/api/stations/42/" ) - def test_returns_data(self): + def test_returns_data(self) -> None: self.assertEqual(self.data, {"hello": "world"}) class PostStationTestCase(TestCase): - @mock_session(**{"post.return_value.json.return_value": {"id": 42}}) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"post.return_value.json.return_value": {"id": 42}} + ) + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.post_station(data={"location": "Syria"}) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: self.mock_requests_session.return_value.post.assert_called_once_with( "https://mydomain.com/api/stations/", data={"location": "Syria"} ) - def test_returns_id(self): + def test_returns_id(self) -> None: self.assertEqual(self.data, 42) class PutStationTestCase(TestCase): @mock_session() - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def test_makes_request(self,m: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.put_station(42, data={"location": "Syria"}) - - def test_makes_request(self): - self.mock_requests_session.return_value.put.assert_called_once_with( + m.return_value.put.assert_called_once_with( "https://mydomain.com/api/stations/42/", data={"location": "Syria"} ) class PatchStationTestCase(TestCase): @mock_session() - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def test_makes_request(self,m: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.patch_station(42, data={"location": "Syria"}) - - def test_makes_request(self): - self.mock_requests_session.return_value.patch.assert_called_once_with( + m.return_value.patch.assert_called_once_with( "https://mydomain.com/api/stations/42/", data={"location": "Syria"} ) class DeleteStationTestCase(TestCase): @mock_session() - def test_makes_request(self, mock_requests_session): + def test_makes_request(self, mock_requests_session: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.delete_station(42) mock_requests_session.return_value.delete.assert_called_once_with( @@ -159,7 +177,7 @@ def test_makes_request(self, mock_requests_session): ) @mock_session(**{"delete.return_value.status_code": 404}) - def test_raises_exception_on_error(self, mock_requests_delete): + def test_raises_exception_on_error(self, mock_requests_delete: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): self.client.delete_station(42) diff --git a/tests/enhydris_api_client/test_timeseries.py b/tests/enhydris_api_client/test_timeseries.py index 6bd5ff6..8cce977 100644 --- a/tests/enhydris_api_client/test_timeseries.py +++ b/tests/enhydris_api_client/test_timeseries.py @@ -1,4 +1,6 @@ -from unittest import TestCase +from __future__ import annotations + +from unittest import TestCase, mock import requests @@ -8,83 +10,94 @@ class ListTimeseriesTestCase(TestCase): - @mock_session( - **{"get.return_value.json.return_value": {"results": [{"hello": "world"}]}} - ) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"get.return_value.json.return_value": {"results": [{"hello": "world"}]}} + ) + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.list_timeseries(41, 42) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: self.mock_requests_session.return_value.get.assert_called_once_with( "https://mydomain.com/api/stations/41/timeseriesgroups/42/timeseries/" ) - def test_returns_data(self): + def test_returns_data(self) -> None: self.assertEqual(self.data, [{"hello": "world"}]) class GetTimeseriesTestCase(TestCase): - @mock_session(**{"get.return_value.json.return_value": {"hello": "world"}}) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"get.return_value.json.return_value": {"hello": "world"}} + ) + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.get_timeseries(41, 42, 43) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: self.mock_requests_session.return_value.get.assert_called_once_with( "https://mydomain.com/api/stations/41/timeseriesgroups/42/timeseries/43/" ) - def test_returns_data(self): + def test_returns_data(self) -> None: self.assertEqual(self.data, {"hello": "world"}) class GetStationOrTimeseriesErrorTestCase(TestCase): - @mock_session(**{"get.return_value.status_code": 404}) - def setUp(self, mock_requests_session): - self.client = EnhydrisApiClient("https://mydomain.com") + def setUp(self) -> None: + with mock_session(**{"get.return_value.status_code": 404}): + self.client = EnhydrisApiClient("https://mydomain.com") - def test_raises_exception_on_get_station_error(self): + def test_raises_exception_on_get_station_error(self) -> None: with self.assertRaises(requests.HTTPError): self.data = self.client.get_station(42) - def test_raises_exception_on_get_timeseries_error(self): + def test_raises_exception_on_get_timeseries_error(self) -> None: with self.assertRaises(requests.HTTPError): self.data = self.client.get_timeseries(41, 42, 43) class PostTimeseriesTestCase(TestCase): - @mock_session(**{"post.return_value.json.return_value": {"id": 43}}) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"post.return_value.json.return_value": {"id": 43}} + ) + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.post_timeseries(41, 42, data={"location": "Syria"}) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: self.mock_requests_session.return_value.post.assert_called_once_with( "https://mydomain.com/api/stations/41/timeseriesgroups/42/timeseries/", data={"location": "Syria"}, ) - def test_returns_id(self): + def test_returns_id(self) -> None: self.assertEqual(self.data, 43) class FailedPostTimeseriesTestCase(TestCase): @mock_session(**{"post.return_value.status_code": 404}) - def setUp(self, mock_requests_session): + def test_raises_exception_on_error(self, m: mock.MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") - - def test_raises_exception_on_error(self): with self.assertRaises(requests.HTTPError): self.client.post_timeseries(41, 42, data={"location": "Syria"}) class DeleteTimeseriesTestCase(TestCase): @mock_session() - def test_makes_request(self, mock_requests_session): + def test_makes_request(self, mock_requests_session: mock.MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.delete_timeseries(41, 42, 43) mock_requests_session.return_value.delete.assert_called_once_with( @@ -92,7 +105,9 @@ def test_makes_request(self, mock_requests_session): ) @mock_session(**{"delete.return_value.status_code": 404}) - def test_raises_exception_on_error(self, mock_requests_delete): + def test_raises_exception_on_error( + self, mock_requests_delete: mock.MagicMock + ) -> None: self.client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): self.client.delete_timeseries(41, 42, 43) diff --git a/tests/enhydris_api_client/test_timeseriesgroup.py b/tests/enhydris_api_client/test_timeseriesgroup.py index 5f9558a..39623d7 100644 --- a/tests/enhydris_api_client/test_timeseriesgroup.py +++ b/tests/enhydris_api_client/test_timeseriesgroup.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from unittest import TestCase -from unittest.mock import call +from unittest.mock import MagicMock, call import requests @@ -9,26 +11,32 @@ class ListTimeseriesGroupsSinglePageTestCase(TestCase): - @mock_session() - def setUp(self, m): - m.return_value.get.return_value.json.return_value = { - "count": 2, - "next": None, - "previous": None, - "results": [{"name": "Temperature"}, {"name": "Humidity"}], - } - self.mock_session = m + def setUp(self) -> None: + self.session_patcher = mock_session( + **{ + "get.return_value.json.return_value": { + "count": 2, + "next": None, + "previous": None, + "results": [{"name": "Temperature"}, {"name": "Humidity"}], + } + } + ) + self.mock_session = self.session_patcher.start() client = EnhydrisApiClient("https://mydomain.com") self.result = client.list_timeseries_groups(station_id=42) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: next(self.result) # Ensure the request is actually made m = self.mock_session m.return_value.get.assert_called_once_with( "https://mydomain.com/api/stations/42/timeseriesgroups/" ) - def test_result(self): + def test_result(self) -> None: self.assertEqual( list(self.result), [{"name": "Temperature"}, {"name": "Humidity"}], @@ -36,27 +44,33 @@ def test_result(self): class ListTimeseriesGroupsMultiPageTestCase(TestCase): - @mock_session() - def setUp(self, m): - m.return_value.get.return_value.json.side_effect = [ - { - "count": 3, - "next": "https://mydomain.com/api/stations/42/timeseriesgroups/?page=2", - "previous": None, - "results": [{"name": "Temperature"}, {"name": "Humidity"}], - }, - { - "count": 3, - "next": None, - "previous": "https://mydomain.com/api/stations/42/timeseriesgroups/", - "results": [{"name": "Pressure"}], - }, - ] - self.mock_session = m + def setUp(self) -> None: + self.session_patcher = mock_session( + **{ + "get.return_value.json.side_effect": [ + { + "count": 3, + "next": "https://mydomain.com/api/stations/42/timeseriesgroups/?page=2", + "previous": None, + "results": [{"name": "Temperature"}, {"name": "Humidity"}], + }, + { + "count": 3, + "next": None, + "previous": "https://mydomain.com/api/stations/42/timeseriesgroups/", + "results": [{"name": "Pressure"}], + }, + ] + } + ) + self.mock_session = self.session_patcher.start() client = EnhydrisApiClient("https://mydomain.com") self.result = client.list_timeseries_groups(station_id=42) - def test_requests(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_requests(self) -> None: list(self.result) # Ensure all requests are made self.assertEqual( self.mock_session.return_value.get.call_args_list, @@ -66,7 +80,7 @@ def test_requests(self): ], ) - def test_result(self): + def test_result(self) -> None: self.assertEqual( list(self.result), [{"name": "Temperature"}, {"name": "Humidity"}, {"name": "Pressure"}], @@ -75,66 +89,73 @@ def test_result(self): class ListTimeseriesGroupsErrorTestCase(TestCase): @mock_session(**{"get.return_value.status_code": 500}) - def test_raises_exception_on_error(self, m): + def test_raises_exception_on_error(self, m: MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): next(client.list_timeseries_groups(station_id=42)) @mock_session(**{"get.return_value.json.return_value": "not a dict"}) - def test_raises_exception_on_non_json_response(self, m): + def test_raises_exception_on_non_json_response(self, m: MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(MalformedResponseError): next(client.list_timeseries_groups(station_id=42)) @mock_session(**{"get.return_value.json.return_value": {"no": "expected"}}) - def test_raises_exception_on_unexpected_json(self, m): + def test_raises_exception_on_unexpected_json(self, m: MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(MalformedResponseError): next(client.list_timeseries_groups(station_id=42)) class GetTimeseriesGroupTestCase(TestCase): - @mock_session(**{"get.return_value.json.return_value": {"hello": "world"}}) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"get.return_value.json.return_value": {"hello": "world"}} + ) + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.get_timeseries_group(42, 43) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: self.mock_requests_session.return_value.get.assert_called_once_with( "https://mydomain.com/api/stations/42/timeseriesgroups/43/" ) - def test_returns_data(self): + def test_returns_data(self) -> None: self.assertEqual(self.data, {"hello": "world"}) class PostTimeseriesGroupTestCase(TestCase): - @mock_session(**{"post.return_value.json.return_value": {"id": 43}}) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"post.return_value.json.return_value": {"id": 43}} + ) + self.mock_requests_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.post_timeseries_group(42, data={"precision": 2}) - def test_makes_request(self): + def tearDown(self) -> None: + self.session_patcher.stop() + + def test_makes_request(self) -> None: self.mock_requests_session.return_value.post.assert_called_once_with( "https://mydomain.com/api/stations/42/timeseriesgroups/", data={"precision": 2}, ) - def test_returns_id(self): + def test_returns_id(self) -> None: self.assertEqual(self.data, 43) class PutTimeseriesGroupTestCase(TestCase): @mock_session() - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def test_makes_request(self, m: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.put_timeseries_group(42, 43, data={"precision": 2}) - - def test_makes_request(self): - self.mock_requests_session.return_value.put.assert_called_once_with( + m.return_value.put.assert_called_once_with( "https://mydomain.com/api/stations/42/timeseriesgroups/43/", data={"precision": 2}, ) @@ -142,13 +163,10 @@ def test_makes_request(self): class PatchTimeseriesGroupTestCase(TestCase): @mock_session() - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def test_makes_request(self, m: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.patch_timeseries_group(42, 43, data={"precision": 2}) - - def test_makes_request(self): - self.mock_requests_session.return_value.patch.assert_called_once_with( + m.return_value.patch.assert_called_once_with( "https://mydomain.com/api/stations/42/timeseriesgroups/43/", data={"precision": 2}, ) @@ -156,7 +174,7 @@ def test_makes_request(self): class DeleteTimeseriesGroupTestCase(TestCase): @mock_session() - def test_makes_request(self, mock_requests_session): + def test_makes_request(self, mock_requests_session: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.delete_timeseries_group(42, 43) mock_requests_session.return_value.delete.assert_called_once_with( @@ -164,7 +182,7 @@ def test_makes_request(self, mock_requests_session): ) @mock_session(**{"delete.return_value.status_code": 404}) - def test_raises_exception_on_error(self, mock_requests_delete): + def test_raises_exception_on_error(self, mock_requests_delete: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): self.client.delete_timeseries_group(42, 43) diff --git a/tests/enhydris_api_client/test_tsdata.py b/tests/enhydris_api_client/test_tsdata.py index fc013b0..445f32a 100644 --- a/tests/enhydris_api_client/test_tsdata.py +++ b/tests/enhydris_api_client/test_tsdata.py @@ -1,8 +1,12 @@ +from __future__ import annotations + import datetime as dt from copy import copy from io import StringIO -from unittest import TestCase +from typing import Any, Optional, cast +from unittest import TestCase, mock +import pandas as pd import requests from enhydris_api_client import EnhydrisApiClient @@ -23,11 +27,11 @@ class ReadTsDataTestCase(TestCase, AssertFrameEqualMixin): url = "http://example.com/api/stations/41/timeseriesgroups/42/timeseries/43/data/" - def _read_tsdata(self, **extra_args): + def _read_tsdata(self, **extra_args: Any) -> HTimeseries: self.client = EnhydrisApiClient("http://example.com") return self.client.read_tsdata(41, 42, 43, **extra_args) - def test_makes_request(self, m): + def test_makes_request(self, m: mock.MagicMock) -> None: self._read_tsdata() m.return_value.get.assert_called_once_with( self.url, @@ -39,11 +43,11 @@ def test_makes_request(self, m): }, ) - def test_returns_data(self, m): + def test_returns_data(self, m: mock.MagicMock) -> None: ahts = self._read_tsdata() self.assert_frame_equal(ahts.data, test_timeseries_htimeseries.data) - def test_uses_timezone(self, m): + def test_uses_timezone(self, m: mock.MagicMock) -> None: self._read_tsdata(timezone="UTC") m.return_value.get.assert_called_once_with( self.url, @@ -57,12 +61,19 @@ def test_uses_timezone(self, m): class ReadTsDataWithStartAndEndDateTestCase(TestCase, AssertFrameEqualMixin): - @mock_session(**{"get.return_value.text": TEST_TIMESERIES_HTS}) - def setUp(self, mock_requests_session): - self.mock_requests_session = mock_requests_session + def setUp(self) -> None: + self.session_patcher = mock_session( + **{"get.return_value.text": TEST_TIMESERIES_HTS} + ) + self.mock_session = self.session_patcher.start() self.client = EnhydrisApiClient("https://mydomain.com") - def _make_request(self, start_tzinfo, end_tzinfo): + def tearDown(self) -> None: + self.session_patcher.stop() + + def _make_request( + self, start_tzinfo: Optional[dt.tzinfo], end_tzinfo: Optional[dt.tzinfo] + ) -> None: self.data = self.client.read_tsdata( 41, 42, @@ -71,9 +82,9 @@ def _make_request(self, start_tzinfo, end_tzinfo): end_date=dt.datetime(2019, 6, 13, 15, 25, tzinfo=end_tzinfo), ) - def test_makes_request(self): + def test_makes_request(self) -> None: self._make_request(dt.timezone.utc, dt.timezone.utc) - self.mock_requests_session.return_value.get.assert_called_once_with( + self.mock_session.return_value.get.assert_called_once_with( "https://mydomain.com/api/stations/41/timeseriesgroups/42/timeseries/43/" "data/", params={ @@ -84,22 +95,22 @@ def test_makes_request(self): }, ) - def test_returns_data(self): + def test_returns_data(self) -> None: self._make_request(dt.timezone.utc, dt.timezone.utc) self.assert_frame_equal(self.data.data, test_timeseries_htimeseries.data) - def test_checks_that_start_date_is_aware(self): + def test_checks_that_start_date_is_aware(self) -> None: with self.assertRaises(ValueError): self._make_request(None, dt.timezone.utc) - def test_checks_that_end_date_is_aware(self): + def test_checks_that_end_date_is_aware(self) -> None: with self.assertRaises(ValueError): self._make_request(dt.timezone.utc, None) class ReadEmptyTsDataTestCase(TestCase, AssertFrameEqualMixin): @mock_session(**{"get.return_value.text": ""}) - def test_returns_data(self, mock_requests_session): + def test_returns_data(self, mock_requests_session: mock.MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.data = self.client.read_tsdata(41, 42, 43) self.assert_frame_equal(self.data.data, HTimeseries().data) @@ -107,7 +118,9 @@ def test_returns_data(self, mock_requests_session): class ReadTsDataErrorTestCase(TestCase): @mock_session(**{"get.return_value.status_code": 404}) - def test_raises_exception_on_error(self, mock_requests_session): + def test_raises_exception_on_error( + self, mock_requests_session: mock.MagicMock + ) -> None: self.client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): self.client.read_tsdata(41, 42, 43) @@ -115,12 +128,13 @@ def test_raises_exception_on_error(self, mock_requests_session): class PostTsDataTestCase(TestCase): @mock_session() - def test_makes_request(self, mock_requests_session): + def test_makes_request(self, mock_requests_session: mock.MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") client.post_tsdata(41, 42, 43, test_timeseries_htimeseries) f = StringIO() data = copy(test_timeseries_htimeseries.data) - data.index = data.index.tz_convert("UTC") + index = cast(pd.DatetimeIndex, data.index) + data.index = index.tz_convert("UTC") data.to_csv(f, header=False) mock_requests_session.return_value.post.assert_called_once_with( "https://mydomain.com/api/stations/41/timeseriesgroups/42/timeseries/43/" @@ -129,7 +143,9 @@ def test_makes_request(self, mock_requests_session): ) @mock_session(**{"post.return_value.status_code": 404}) - def test_raises_exception_on_error(self, mock_requests_session): + def test_raises_exception_on_error( + self, mock_requests_session: mock.MagicMock + ) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): client.post_tsdata(41, 42, 43, test_timeseries_htimeseries) @@ -139,19 +155,19 @@ def test_raises_exception_on_error(self, mock_requests_session): class GetTsEndDateTestCase(TestCase): url = "http://mydom.com/api/stations/41/timeseriesgroups/42/timeseries/43/bottom/" - def _get_ts_end_date(self, **extra_args): + def _get_ts_end_date(self, **extra_args: Any) -> Optional[dt.datetime]: self.client = EnhydrisApiClient("http://mydom.com") return self.client.get_ts_end_date(41, 42, 43, **extra_args) - def test_makes_request(self, m): + def test_makes_request(self, m: mock.MagicMock) -> None: self._get_ts_end_date() m.return_value.get.assert_called_once_with(self.url, params={"timezone": None}) - def test_returns_date(self, m): + def test_returns_date(self, m: mock.MagicMock) -> None: result = self._get_ts_end_date() self.assertEqual(result, dt.datetime(2014, 1, 5, 8, 0)) - def test_uses_timezone(self, m): + def test_uses_timezone(self, m: mock.MagicMock) -> None: self._get_ts_end_date(timezone="Etc/GMT-2") m.return_value.get.assert_called_once_with( self.url, params={"timezone": "Etc/GMT-2"} @@ -160,7 +176,7 @@ def test_uses_timezone(self, m): class GetTsEndDateErrorTestCase(TestCase): @mock_session(**{"get.return_value.status_code": 404}) - def test_checks_response_code(self, mock_requests_session): + def test_checks_response_code(self, mock_requests_session: mock.MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") with self.assertRaises(requests.HTTPError): client.get_ts_end_date(41, 42, 43) @@ -168,7 +184,7 @@ def test_checks_response_code(self, mock_requests_session): class GetTsEndDateEmptyTestCase(TestCase): @mock_session(**{"get.return_value.text": ""}) - def test_returns_date(self, mock_requests_session): + def test_returns_date(self, mock_requests_session: mock.MagicMock) -> None: client = EnhydrisApiClient("https://mydomain.com") date = client.get_ts_end_date(41, 42, 43) self.assertIsNone(date) diff --git a/tests/enhydris_cache/test_cli.py b/tests/enhydris_cache/test_cli.py index 5d0bc07..b28e2a4 100644 --- a/tests/enhydris_cache/test_cli.py +++ b/tests/enhydris_cache/test_cli.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime as dt import json import os @@ -6,11 +8,12 @@ import tempfile import textwrap from io import StringIO +from typing import Any, Dict, List from unittest import TestCase, skipUnless -from unittest.mock import patch +from unittest.mock import MagicMock, patch import click -from click.testing import CliRunner +from click.testing import CliRunner, Result from enhydris_api_client import EnhydrisApiClient from enhydris_cache import cli @@ -18,28 +21,28 @@ class NonExistentConfigFileTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: runner = CliRunner(mix_stderr=False) - self.result = runner.invoke(cli.main, ["nonexistent.conf"]) + self.result: Result = runner.invoke(cli.main, ["nonexistent.conf"]) - def test_exit_status(self): + def test_exit_status(self) -> None: self.assertEqual(self.result.exit_code, 1) - def test_error_message(self): + def test_error_message(self) -> None: self.assertIn( "No such file or directory: 'nonexistent.conf'", self.result.stderr ) class ConfigurationTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.configfilename = os.path.join(self.tempdir, "enhydris-cache.conf") - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test_nonexistent_log_level_raises_error(self): + def test_nonexistent_log_level_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -54,7 +57,7 @@ def test_nonexistent_log_level_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_base_url_parameter_raises_error(self): + def test_missing_base_url_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -75,7 +78,7 @@ def test_missing_base_url_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_station_id_parameter_raises_error(self): + def test_missing_station_id_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -96,7 +99,7 @@ def test_missing_station_id_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_timeseries_group_id_parameter_raises_error(self): + def test_missing_timeseries_group_id_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -117,7 +120,7 @@ def test_missing_timeseries_group_id_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_timeseries_id_parameter_raises_error(self): + def test_missing_timeseries_id_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -138,7 +141,7 @@ def test_missing_timeseries_id_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_file_parameter_raises_error(self): + def test_missing_file_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -159,7 +162,7 @@ def test_missing_file_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_wrong_station_id_parameter_raises_error(self): + def test_wrong_station_id_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -180,7 +183,7 @@ def test_wrong_station_id_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, "not a valid integer"): cli.App(self.configfilename).run() - def test_wrong_timeseries_group_id_parameter_raises_error(self): + def test_wrong_timeseries_group_id_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -201,7 +204,7 @@ def test_wrong_timeseries_group_id_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, "not a valid integer"): cli.App(self.configfilename).run() - def test_wrong_timeseries_id_parameter_raises_error(self): + def test_wrong_timeseries_id_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -223,7 +226,7 @@ def test_wrong_timeseries_id_parameter_raises_error(self): cli.App(self.configfilename).run() @patch("enhydris_cache.cli.App._execute") - def test_correct_configuration_executes(self, m): + def test_correct_configuration_executes(self, m: MagicMock) -> None: with open(self.configfilename, "w") as f: f.write( f"""\ @@ -243,7 +246,7 @@ def test_correct_configuration_executes(self, m): m.assert_called_once_with() @patch("enhydris_cache.cli.App._execute") - def test_missing_auth_token_makes_it_none(self, m): + def test_missing_auth_token_makes_it_none(self, m: MagicMock) -> None: with open(self.configfilename, "w") as f: f.write( f"""\ @@ -261,10 +264,11 @@ def test_missing_auth_token_makes_it_none(self, m): ) app = cli.App(self.configfilename) app.run() + assert app.config is not None self.assertIsNone(app.config.timeseries_group[0]["auth_token"]) @patch("enhydris_cache.cli.App._execute") - def test_creates_log_file(self, *args): + def test_creates_log_file(self, *mock_objects: MagicMock) -> None: logfilename = os.path.join(self.tempdir, "enhydris_cache.log") with open(self.configfilename, "w") as f: f.write( @@ -321,15 +325,17 @@ class EnhydrisCacheE2eTestCase(TestCase): timeseries1_bottom = test_timeseries1.splitlines(True)[-1] timeseries2_bottom = test_timeseries2.splitlines(True)[-1] - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.config_file = os.path.join(self.tempdir, "enhydris_cache.conf") - self.saved_argv = sys.argv + self.saved_argv: List[str] = sys.argv sys.argv = ["enhydris_cache", "--traceback", self.config_file] self.savedcwd = os.getcwd() # Create two stations, each one with a time series - self.parms = json.loads(os.getenv("PTHELMA_TEST_ENHYDRIS_SERVER")) + raw_parameters = os.getenv("PTHELMA_TEST_ENHYDRIS_SERVER") + assert raw_parameters is not None + self.parms: Dict[str, Any] = json.loads(raw_parameters) self.station_id = self.parms["station_id"] self.timeseries_group_id = self.parms["timeseries_group_id"] self.api_client = EnhydrisApiClient( @@ -398,7 +404,7 @@ def setUp(self): ) ) - def tearDown(self): + def tearDown(self) -> None: self.api_client.delete_timeseries( self.station_id, self.timeseries_group_id, self.timeseries2_id ) @@ -410,7 +416,7 @@ def tearDown(self): sys.argv = self.saved_argv self.api_client.__exit__() - def test_execute(self): + def test_execute(self) -> None: application = cli.App(self.config_file) # Check that the two files don't exist yet diff --git a/tests/enhydris_cache/test_enhydris_cache.py b/tests/enhydris_cache/test_enhydris_cache.py index 508f98e..8eff721 100644 --- a/tests/enhydris_cache/test_enhydris_cache.py +++ b/tests/enhydris_cache/test_enhydris_cache.py @@ -1,16 +1,19 @@ +from __future__ import annotations + import datetime as dt import os import shutil import tempfile import textwrap from io import StringIO +from typing import Dict, Optional from unittest import TestCase, mock from enhydris_api_client import EnhydrisApiClient -from enhydris_cache import TimeseriesCache +from enhydris_cache import TimeseriesCache, TimeseriesGroup from htimeseries import HTimeseries -test_timeseries = { +test_timeseries: Dict[str, str] = { "42_all": textwrap.dedent( """\ 2014-01-01 08:00,11, @@ -39,14 +42,20 @@ def mock_read_tsdata( - station_id, timeseries_group_id, timeseries_id, start_date=None, end_date=None -): + station_id: int, + timeseries_group_id: int, + timeseries_id: int, + start_date: Optional[dt.datetime] = None, + end_date: Optional[dt.datetime] = None, +) -> HTimeseries: result = _get_hts_object(timeseries_id, start_date) _set_hts_attributes(result, timeseries_id) return result -def _get_hts_object(timeseries_id, start_date): +def _get_hts_object( + timeseries_id: int, start_date: Optional[dt.datetime] +) -> HTimeseries: timeseries_top = HTimeseries( StringIO(test_timeseries[f"{timeseries_id}_top"]), default_tzinfo=UTC_PLUS_2 ) @@ -59,14 +68,14 @@ def _get_hts_object(timeseries_id, start_date): return result -def _set_hts_attributes(hts, timeseries_id): +def _set_hts_attributes(hts: HTimeseries, timeseries_id: int) -> None: hts.time_step = "D" hts.precision = 0 if timeseries_id == 42 else 2 hts.comment = "Très importante" class TimeseriesCacheTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.api_client = EnhydrisApiClient("https://mydomain.com") # Temporary directory for cache files @@ -74,21 +83,21 @@ def setUp(self): self.savedcwd = os.getcwd() os.chdir(self.tempdir) - def tearDown(self): + def tearDown(self) -> None: os.chdir(self.savedcwd) shutil.rmtree(self.tempdir) @mock.patch( "enhydris_cache.enhydris_cache.EnhydrisApiClient", - **{ + **{ # type: ignore[misc] "return_value.__enter__.return_value.read_tsdata.side_effect": ( mock_read_tsdata ) }, ) - def test_update(self, mock_api_client): + def test_update(self, mock_api_client: mock.MagicMock) -> None: - two_timeseries = [ + two_timeseries: list[TimeseriesGroup] = [ { "base_url": "https://mydomain.com", "station_id": 2, diff --git a/tests/evaporation/test_cli.py b/tests/evaporation/test_cli.py index cc64c70..e1bbad6 100644 --- a/tests/evaporation/test_cli.py +++ b/tests/evaporation/test_cli.py @@ -6,13 +6,15 @@ import textwrap from io import StringIO from pathlib import Path +from typing import Any, TextIO, cast from unittest import TestCase -from unittest.mock import patch +from unittest.mock import MagicMock, patch import click import numpy as np import pandas as pd -from click.testing import CliRunner +from click.testing import CliRunner, Result +from numpy.typing import NDArray from osgeo import gdal, osr from evaporation import cli @@ -23,7 +25,7 @@ gdal.UseExceptions() -def create_geotiff_file(filename, value): +def create_geotiff_file(filename: str, value: NDArray[np.float_]) -> None: geo_transform = (-16.25, 1.0, 0, 16.217, 0, 1.0) wgs84 = osr.SpatialReference() wgs84.ImportFromEPSG(4326) @@ -38,30 +40,35 @@ def create_geotiff_file(filename, value): class NonExistentConfigFileTestCase(TestCase): - def setUp(self): + result: Result + + def setUp(self) -> None: runner = CliRunner(mix_stderr=False) self.result = runner.invoke(cli.main, ["nonexistent.conf"]) - def test_exit_status(self): + def test_exit_status(self) -> None: self.assertEqual(self.result.exit_code, 1) - def test_error_message(self): + def test_error_message(self) -> None: self.assertIn( "No such file or directory: 'nonexistent.conf'", self.result.stderr ) class WrongPointConfigurationTestCase(TestCase): - def setUp(self): + tempdir: str + configfilename: str + + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.configfilename = os.path.join(self.tempdir, "evaporation.conf") htsfilename = os.path.join(self.tempdir, "wind_speed.hts") Path(htsfilename).touch() - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test_nonexistent_log_level_raises_error(self): + def test_nonexistent_log_level_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -82,7 +89,7 @@ def test_nonexistent_log_level_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_step_raises_error(self): + def test_missing_step_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -98,7 +105,7 @@ def test_missing_step_raises_error(self): with self.assertRaisesRegex(click.ClickException, "time_step"): cli.App(self.configfilename).run() - def test_missing_albedo_raises_error(self): + def test_missing_albedo_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -116,16 +123,19 @@ def test_missing_albedo_raises_error(self): class WrongSpatialConfigurationTestCase(TestCase): - def setUp(self): + tempdir: str + configfilename: str + + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.configfilename = os.path.join(self.tempdir, "evaporation.conf") htsfilename = os.path.join(self.tempdir, "wind_speed.tif") Path(htsfilename).touch() - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test_missing_elevation_raises_error(self): + def test_missing_elevation_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -141,7 +151,7 @@ def test_missing_elevation_raises_error(self): with self.assertRaisesRegex(click.ClickException, "elevation"): cli.App(self.configfilename).run() - def test_single_albedo_with_wrong_domain_float_inputs(self): + def test_single_albedo_with_wrong_domain_float_inputs(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -158,7 +168,7 @@ def test_single_albedo_with_wrong_domain_float_inputs(self): with self.assertRaisesRegex(ValueError, "Albedo must be between 0.0 and 1.0"): cli.App(self.configfilename).run() - def test_seasonal_albedo_configuration_with_not_enough_arguments(self): + def test_seasonal_albedo_configuration_with_not_enough_arguments(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -176,7 +186,7 @@ def test_seasonal_albedo_configuration_with_not_enough_arguments(self): with self.assertRaisesRegex(ValueError, msg): cli.App(self.configfilename).run() - def test_seasonal_albedo_with_wrong_domain_mixin_inputs(self): + def test_seasonal_albedo_with_wrong_domain_mixin_inputs(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -196,17 +206,20 @@ def test_seasonal_albedo_with_wrong_domain_mixin_inputs(self): class CorrectPointConfigurationTestCase(TestCase): - def setUp(self): + tempdir: str + configfilename: str + + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.configfilename = os.path.join(self.tempdir, "evaporation.conf") htsfilename = os.path.join(self.tempdir, "wind_speed.hts") Path(htsfilename).touch() - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) @patch("evaporation.cli.ProcessAtPoint") - def test_executes(self, m): + def test_executes(self, m: MagicMock) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -225,7 +238,7 @@ def test_executes(self, m): m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessAtPoint") - def test_albedo_configuration_as_one_number(self, m): + def test_albedo_configuration_as_one_number(self, m: MagicMock) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -243,14 +256,17 @@ def test_albedo_configuration_as_one_number(self, m): class CorrectSpatialConfigurationTestCase(TestCase): - def setUp(self): + tempdir: str + configfilename: str + + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.configfilename = os.path.join(self.tempdir, "evaporation.conf") htsfilename = os.path.join(self.tempdir, "wind_speed.tif") Path(htsfilename).touch() self._create_albedo_files() - def _create_albedo_files(self): + def _create_albedo_files(self) -> None: items = ( "00", "01", @@ -270,11 +286,11 @@ def _create_albedo_files(self): filename = os.path.join(self.tempdir, "a{}.tif".format(item)) create_geotiff_file(filename, np.array([[0.23, 0.44]])) - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) @patch("evaporation.cli.ProcessSpatial") - def test_albedo_configuration_as_one_grid(self, m): + def test_albedo_configuration_as_one_grid(self, m: MagicMock) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -292,7 +308,9 @@ def test_albedo_configuration_as_one_grid(self, m): m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_seasonal_albedo_configuration_as_12_grids(self, m): + def test_seasonal_albedo_configuration_as_12_grids( + self, m: MagicMock + ) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -311,7 +329,9 @@ def test_seasonal_albedo_configuration_as_12_grids(self, m): m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_seasonal_albedo_configuration_as_mix_numbers_and_grids(self, m): + def test_seasonal_albedo_configuration_as_mix_numbers_and_grids( + self, m: MagicMock + ) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -330,7 +350,9 @@ def test_seasonal_albedo_configuration_as_mix_numbers_and_grids(self, m): m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_run_app_seasonal_albedo_with_float_sample_inputs(self, m): + def test_run_app_seasonal_albedo_with_float_sample_inputs( + self, m: MagicMock + ) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -349,7 +371,9 @@ def test_run_app_seasonal_albedo_with_float_sample_inputs(self, m): m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_run_app_with_seasonal_albedo_with_grid_sample_inputs(self, m): + def test_run_app_with_seasonal_albedo_with_grid_sample_inputs( + self, m: MagicMock + ) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -368,7 +392,9 @@ def test_run_app_with_seasonal_albedo_with_grid_sample_inputs(self, m): m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_run_app_with_seasonal_albedo_with_mix_sample_inputs(self, m): + def test_run_app_with_seasonal_albedo_with_mix_sample_inputs( + self, m: MagicMock + ) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -387,7 +413,9 @@ def test_run_app_with_seasonal_albedo_with_mix_sample_inputs(self, m): m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_seasonal_albedo_configuration_as_12_numbers(self, m): + def test_seasonal_albedo_configuration_as_12_numbers( + self, m: MagicMock + ) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -408,7 +436,7 @@ def test_seasonal_albedo_configuration_as_12_numbers(self, m): class CorrectConfigurationWithLogFileTestCase(TestCase): @patch("evaporation.cli.ProcessAtPoint") - def test_creates_log_file(self, *args): + def test_creates_log_file(self, *args: MagicMock) -> None: with tempfile.TemporaryDirectory() as dirname: configfilename = os.path.join(dirname, "evaporation.conf") logfilename = os.path.join(dirname, "vaporize.log") @@ -433,16 +461,27 @@ def test_creates_log_file(self, *args): class HtsTestCase(TestCase): - def setUp(self): + tempdir: str + config_file: str + savedcwd: str + saved_stderr: TextIO + + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.config_file = os.path.join(self.tempdir, "vaporize.conf") self.savedcwd = os.getcwd() - def tearDown(self): + def tearDown(self) -> None: os.chdir(self.savedcwd) shutil.rmtree(self.tempdir) - def setup_input_file(self, step, basename, value, missing=None): + def setup_input_file( + self, + step: str, + basename: str, + value: Any, + missing: str | None = None, + ) -> None: filename = os.path.join(self.tempdir, basename + ".hts") timestamp = step == "hourly" and "2014-10-01 15:00" or "2014-07-06" with open(filename, "w") as f: @@ -454,14 +493,14 @@ def setup_input_file(self, step, basename, value, missing=None): f.write("Altitude={}\n".format(step == "hourly" and "8" or "100")) f.write("\n{},{},\n".format(timestamp, value)) - def setup_hourly_input_files(self, missing=None): + def setup_hourly_input_files(self, missing: str | None = None) -> None: self.setup_input_file("hourly", "temperature", 38, missing=missing) self.setup_input_file("hourly", "humidity", 52, missing=missing) self.setup_input_file("hourly", "wind_speed", 3.3, missing=missing) self.setup_input_file("hourly", "pressure", 1013, missing=missing) self.setup_input_file("hourly", "solar_radiation", 681, missing=missing) - def setup_daily_input_files(self): + def setup_daily_input_files(self) -> None: self.setup_input_file("daily", "temperature_max", 21.5) self.setup_input_file("daily", "temperature_min", 12.3) self.setup_input_file("daily", "humidity_max", 84) @@ -469,7 +508,7 @@ def setup_daily_input_files(self): self.setup_input_file("daily", "wind_speed", 2.078) self.setup_input_file("daily", "sunshine_duration", 9.25) - def setup_config_file(self, time_step): + def setup_config_file(self, time_step: str) -> None: with open(self.config_file, "w") as f: f.write( textwrap.dedent( @@ -492,7 +531,7 @@ def setup_config_file(self, time_step): ) ) - def test_hourly(self): + def test_hourly(self) -> None: self.setup_hourly_input_files() self.setup_config_file("h") @@ -506,16 +545,18 @@ def test_hourly(self): # Check that it has created a file and that the file is correct with open(result_filename) as f: t = HTimeseries(f) + tz = cast(pd.DatetimeIndex, t.data.index).tz expected_result = pd.DataFrame( data={"value": [0.63], "flags": [""]}, - columns=("value", "flags"), - index=[dt.datetime(2014, 10, 1, 15, 0, tzinfo=t.data.index.tz)], + columns=("value", "flags"), # type: ignore + index=[dt.datetime(2014, 10, 1, 15, 0, tzinfo=tz)], # type: ignore ) expected_result.index.name = "date" pd.testing.assert_frame_equal(t.data, expected_result, atol=1e-3) - self.assertEqual(t.data.index.tz.offset, dt.timedelta(hours=-1)) + tz = cast(pd.DatetimeIndex, t.data.index).tz + self.assertEqual(tz.offset, dt.timedelta(hours=-1)) # type: ignore - def test_daily(self): + def test_daily(self) -> None: self.setup_daily_input_files() self.setup_config_file("D") @@ -529,16 +570,18 @@ def test_daily(self): # Check that it has created a file and that the file is correct with open(result_filename) as f: t = HTimeseries(f) + tz = cast(pd.DatetimeIndex, t.data.index).tz expected_result = pd.DataFrame( data={"value": [3.9], "flags": [""]}, - columns=["value", "flags"], - index=[dt.datetime(2014, 7, 6, tzinfo=t.data.index.tz)], + columns=["value", "flags"], # type: ignore + index=[dt.datetime(2014, 7, 6, tzinfo=tz)], # type: ignore ) expected_result.index.name = "date" pd.testing.assert_frame_equal(t.data, expected_result, atol=1e-3) - self.assertEqual(t.data.index.tz.offset, dt.timedelta(hours=-1)) + tz = cast(pd.DatetimeIndex, t.data.index).tz + self.assertEqual(tz.offset, dt.timedelta(hours=-1)) # type: ignore - def test_missing_location(self): + def test_missing_location(self) -> None: self.setup_hourly_input_files(missing="location") self.setup_config_file("h") msg = ( @@ -548,7 +591,7 @@ def test_missing_location(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.config_file).run() - def test_missing_altitude(self): + def test_missing_altitude(self) -> None: self.setup_hourly_input_files(missing="altitude") self.setup_config_file("h") msg = ( @@ -560,7 +603,18 @@ def test_missing_altitude(self): class SpatialTestCase(TestCase): - def setup_input_file(self, variable, value, with_date=True): + tempdir: str + config_file: str + savedcwd: str + saved_stderr: TextIO + timestamp: dt.date | dt.datetime + + def setup_input_file( + self, + variable: str, + value: Any, + with_date: bool = True, + ) -> None: """ Saves value, which is an np array, to a GeoTIFF file whose name is based on variable. @@ -590,7 +644,7 @@ def setup_input_file(self, variable, value, with_date=True): # Close the dataset f = None - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.config_file = os.path.join(self.tempdir, "vaporize.conf") self.savedcwd = os.getcwd() @@ -607,11 +661,11 @@ def setUp(self): # Save standard error (some tests change it) self.saved_stderr = sys.stderr - def tearDown(self): + def tearDown(self) -> None: os.chdir(self.savedcwd) shutil.rmtree(self.tempdir) - def test_execute_notz(self): + def test_execute_notz(self) -> None: # Prepare input files without time zone self.timestamp = dt.datetime(2014, 10, 1, 15, 0) self.setup_input_file("temperature-notz", np.array([[38.0, 28.0]])) @@ -655,7 +709,7 @@ def test_execute_notz(self): # Verify the output file still doesn't exist self.assertFalse(os.path.exists(result_filename)) - def test_execute_daily(self): + def test_execute_daily(self) -> None: # Prepare input files self.timestamp = dt.date(2014, 7, 6) self.setup_input_file("temperature_max", np.array([[21.5, 28]])) @@ -719,7 +773,7 @@ def test_execute_daily(self): ) fp = None - def test_execute_daily_with_radiation(self): + def test_execute_daily_with_radiation(self) -> None: """Same as test_execute_daily, except that we use solar radiation instead of sunshine duration.""" # Prepare input files @@ -785,7 +839,7 @@ def test_execute_daily_with_radiation(self): ) fp = None - def test_execute_hourly(self): + def test_execute_hourly(self) -> None: # Prepare input files self.timestamp = dt.datetime(2014, 10, 1, 15, 0, tzinfo=senegal_tzinfo) self.setup_input_file("temperature", np.array([[38.0, 28.0]])) @@ -853,7 +907,7 @@ def test_execute_hourly(self): ) fp = None - def test_execute_hourly_no_pressure(self): + def test_execute_hourly_no_pressure(self) -> None: """Same as test_execute_hourly, but does not have pressure an input; therefore, it will calculate pressure itself.""" # Prepare input files @@ -921,7 +975,7 @@ def test_execute_hourly_no_pressure(self): ) fp = None - def test_execute_hourly_without_sun(self): + def test_execute_hourly_without_sun(self) -> None: # Prepare input files, without solar radiation self.timestamp = dt.datetime(2014, 10, 1, 15, 0, tzinfo=senegal_tzinfo) self.setup_input_file("temperature", np.array([[38.0, 28.0]])) @@ -951,7 +1005,7 @@ def test_execute_hourly_without_sun(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.config_file).run() - def test_execute_with_dem(self): + def test_execute_with_dem(self) -> None: """This is essentially the same as test_execute, but uses a GeoTIFF with a DEM instead of a constant elevation. The numbers are the same, however (all DEM gridpoints have the same value).""" @@ -1011,7 +1065,7 @@ def test_execute_with_dem(self): ) fp = None - def test_execute_with_nodata(self): + def test_execute_with_nodata(self) -> None: """This is essentially the same as test_execute, but the gdal rasters contain cells with nodata.""" diff --git a/tests/evaporation/test_evaporation.py b/tests/evaporation/test_evaporation.py index 8f4d0c6..7c1284a 100644 --- a/tests/evaporation/test_evaporation.py +++ b/tests/evaporation/test_evaporation.py @@ -1,5 +1,6 @@ import datetime as dt import math +from typing import Any, Dict from unittest import TestCase import numpy as np @@ -17,42 +18,48 @@ class SenegalTzinfo(dt.tzinfo): assumption as example 19, in order to get the same result. """ - def utcoffset(self, adate): + def utcoffset(self, adate: dt.datetime | None) -> dt.timedelta: return -dt.timedelta(hours=1) - def dst(self, adate): + def dst(self, adate: dt.datetime | None) -> dt.timedelta: return dt.timedelta(0) -senegal_tzinfo = SenegalTzinfo() +senegal_tzinfo: dt.tzinfo = SenegalTzinfo() class PenmanMonteithTestCase(TestCase): - def test_daily_plain(self): + pmclassvars: Dict[str, Any] + pmvars: Dict[str, Any] + + def test_daily_plain(self) -> None: # Apply Allen et al. (1998) Example 18 page 72. self._get_daily_vars() result = PenmanMonteith(**self.pmclassvars).calculate(**self.pmvars) + assert isinstance(result, float) self.assertAlmostEqual(result, 3.9, places=1) - def test_daily_with_solar_radiation(self): + def test_daily_with_solar_radiation(self) -> None: # Same as above, but instead of sunshine duration we provide the solar radiation # directly. Should get the same result. self._get_daily_vars() del self.pmvars["sunshine_duration"] self.pmvars["solar_radiation"] = 22.07 result = PenmanMonteith(**self.pmclassvars).calculate(**self.pmvars) + assert isinstance(result, float) self.assertAlmostEqual(result, 3.9, places=1) - def test_daily_with_pressure(self): + def test_daily_with_pressure(self) -> None: # Same as above, but instead of letting it calculate pressure we provide it # directly. Should get the same result. self._get_daily_vars() self.pmclassvars["unit_converters"]["pressure"] = lambda x: x / 10 self.pmvars["pressure"] = 1001 result = PenmanMonteith(**self.pmclassvars).calculate(**self.pmvars) + assert isinstance(result, float) self.assertAlmostEqual(result, 3.9, places=1) - def _get_daily_vars(self): + def _get_daily_vars(self) -> None: unit_converters = { # Eq. 47 p. 56 "wind_speed": lambda x: (x * 4.87 / math.log(67.8 * 10 - 5.42)) @@ -74,7 +81,7 @@ def _get_daily_vars(self): "adatetime": dt.date(2014, 7, 6), } - def test_daily_grid(self): + def test_daily_grid(self) -> None: # We use a 1x3 grid, where point (1, 1) is the same as Example 18, # point (1, 2) has some different values, and the elevation at point # (1, 3) is NaN to signify a nodata point. @@ -116,7 +123,7 @@ def test_daily_grid(self): ) np.testing.assert_allclose(result, np.array([3.9, 4.8, float("nan")]), atol=0.1) - def test_hourly(self): + def test_hourly(self) -> None: # Apply Allen et al. (1998) Example 19 page 75. pm = PenmanMonteith( albedo=0.23, @@ -135,6 +142,7 @@ def test_hourly(self): solar_radiation=2.450, adatetime=dt.datetime(2014, 10, 1, 15, 0, tzinfo=senegal_tzinfo), ) + assert isinstance(result, float) self.assertAlmostEqual(result, 0.63, places=2) result = pm.calculate( temperature=28, @@ -144,6 +152,7 @@ def test_hourly(self): solar_radiation=0, adatetime=dt.datetime(2014, 10, 1, 2, 30, tzinfo=senegal_tzinfo), ) + assert isinstance(result, float) self.assertAlmostEqual(result, 0.0, places=2) # Same thing, but let it calculate pressure itself @@ -154,6 +163,7 @@ def test_hourly(self): solar_radiation=2.450, adatetime=dt.datetime(2014, 10, 1, 15, 0, tzinfo=senegal_tzinfo), ) + assert isinstance(result, float) self.assertAlmostEqual(result, 0.63, places=2) result = pm.calculate( temperature=28, @@ -162,9 +172,10 @@ def test_hourly(self): solar_radiation=0, adatetime=dt.datetime(2014, 10, 1, 2, 30, tzinfo=senegal_tzinfo), ) + assert isinstance(result, float) self.assertAlmostEqual(result, 0.0, places=2) - def test_hourly_grid(self): + def test_hourly_grid(self) -> None: # We use a 2x1 grid, where point 1, 1 is the same as Example 19, and # point 1, 2 has some different values. pm = PenmanMonteith( @@ -185,7 +196,7 @@ def test_hourly_grid(self): ) np.testing.assert_almost_equal(result, np.array([0.63, 0.36]), decimal=2) - def test_hourly_with_albedo_grid(self): + def test_hourly_with_albedo_grid(self) -> None: # Apply Allen et al. (1998) Example 19 page 75. pm = PenmanMonteith( albedo=np.array([0.23]), @@ -204,6 +215,8 @@ def test_hourly_with_albedo_grid(self): solar_radiation=2.450, adatetime=dt.datetime(2014, 10, 1, 15, 0, tzinfo=senegal_tzinfo), ) + + assert isinstance(result, np.ndarray) # The following two lines could be written more simply like this: # self.assertAlmostEqual(result, 0.63, places=2) # However, it does not work properly on Python 3 because of a numpy @@ -211,7 +224,7 @@ def test_hourly_with_albedo_grid(self): self.assertEqual(result.size, 1) self.assertAlmostEqual(result[0], 0.63, places=2) - def test_hourly_array_with_seasonal_albedo_grid(self): + def test_hourly_array_with_seasonal_albedo_grid(self) -> None: # We use a 2x1 grid, where point 1, 1 is the same as Example 19, and # point 1, 2 has some different values. pm = PenmanMonteith( @@ -232,7 +245,7 @@ def test_hourly_array_with_seasonal_albedo_grid(self): ) np.testing.assert_almost_equal(result, np.array([0.63, 0.36]), decimal=2) - def test_hourly_with_seasonal_albedo(self): + def test_hourly_with_seasonal_albedo(self) -> None: # Apply Allen et al. (1998) Example 19 page 75. pm = PenmanMonteith( @@ -265,6 +278,7 @@ def test_hourly_with_seasonal_albedo(self): solar_radiation=2.450, adatetime=dt.datetime(2014, 1, 1, 15, 0, tzinfo=senegal_tzinfo), ) + assert isinstance(result, float) self.assertAlmostEqual(result, 0.69, places=2) result = pm.calculate( @@ -275,6 +289,7 @@ def test_hourly_with_seasonal_albedo(self): solar_radiation=2.450, adatetime=dt.datetime(2014, 12, 1, 15, 0, tzinfo=senegal_tzinfo), ) + assert isinstance(result, float) self.assertAlmostEqual(result, 0.56, places=2) result = pm.calculate( @@ -285,11 +300,12 @@ def test_hourly_with_seasonal_albedo(self): solar_radiation=2.450, adatetime=dt.datetime(2014, 10, 1, 15, 0, tzinfo=senegal_tzinfo), ) + assert isinstance(result, float) self.assertAlmostEqual(result, 0.63, places=2) class Cloud2RadiationTestCase(TestCase): - def test_daily(self): + def test_daily(self) -> None: # We test using the example at the bottom of FAO56 p. 50, except that we # replace n/N with (1 - cloud_cover). cloud_cover = 1 - 7.1 / 10.9 diff --git a/tests/haggregate/test_cli.py b/tests/haggregate/test_cli.py index bce31ac..01e52b2 100644 --- a/tests/haggregate/test_cli.py +++ b/tests/haggregate/test_cli.py @@ -1,39 +1,46 @@ import datetime as dt import textwrap +from typing import ClassVar from unittest import TestCase -from unittest.mock import patch +from unittest.mock import MagicMock, patch -from click.testing import CliRunner +from click.testing import CliRunner, Result from haggregate import RegularizationMode, cli class CliUsageErrorTestCase(TestCase): - def setUp(self): + result: Result + + def setUp(self) -> None: runner = CliRunner() self.result = runner.invoke(cli.main, []) - def test_exit_code(self): + def test_exit_code(self) -> None: self.assertTrue(self.result.exit_code > 0) - def test_error_message(self): + def test_error_message(self) -> None: self.assertIn("Usage: main [OPTIONS] CONFIGFILE", self.result.output) class CliConfigFileNotFoundTestCase(TestCase): - def setUp(self): + result: Result + + def setUp(self) -> None: runner = CliRunner() self.result = runner.invoke(cli.main, ["/nonexistent/nonexistent"]) - def test_exit_code(self): + def test_exit_code(self) -> None: self.assertTrue(self.result.exit_code > 0) - def test_error_message(self): + def test_error_message(self) -> None: self.assertIn("No such file or directory", self.result.output) class CliNoTimeSeriesErrorTestCase(TestCase): - def setUp(self): + result: Result + + def setUp(self) -> None: runner = CliRunner() with runner.isolated_filesystem(): with open("config.ini", "w") as f: @@ -49,18 +56,29 @@ def setUp(self): ) self.result = runner.invoke(cli.main, ["config.ini"]) - def test_exit_code(self): + def test_exit_code(self) -> None: self.assertTrue(self.result.exit_code > 0) - def test_error_messages(self): + def test_error_messages(self) -> None: self.assertIn("No time series have been specified", self.result.output) class CliMixin: - @patch("haggregate.cli.HTimeseries", **{"return_value": "my timeseries"}) + configuration: str + result: Result + mock_aggregate: MagicMock + mock_regularize: MagicMock + mock_htimeseries: MagicMock + + @patch("haggregate.cli.HTimeseries", return_value="my timeseries") @patch("haggregate.cli.regularize", return_value="regularized timeseries") @patch("haggregate.cli.aggregate") - def _execute(self, mock_aggregate, mock_regularize, mock_htimeseries): + def _execute( + self, + mock_aggregate: MagicMock, + mock_regularize: MagicMock, + mock_htimeseries: MagicMock, + ) -> None: self.mock_aggregate = mock_aggregate self.mock_regularize = mock_regularize self.mock_htimeseries = mock_htimeseries @@ -87,30 +105,30 @@ class CliTestCase(CliMixin, TestCase): """ ) - def setUp(self): + def setUp(self) -> None: self._execute() - def test_exit_code(self): + def test_exit_code(self) -> None: self.assertEqual(self.result.exit_code, 0) - def test_read_source_file(self): + def test_read_source_file(self) -> None: self.assertEqual(self.mock_htimeseries.call_count, 1) - def test_htimeseries_called_correctly(self): + def test_htimeseries_called_correctly(self) -> None: self.mock_htimeseries.assert_called_once() self.assertEqual( self.mock_htimeseries.call_args[1], {"format": self.mock_htimeseries.FILE, "default_tzinfo": dt.timezone.utc}, ) - def test_regularize_called_correctly(self): + def test_regularize_called_correctly(self) -> None: self.mock_regularize.assert_called_once_with( "my timeseries", new_date_flag="DATEINSERT", mode=RegularizationMode.INTERVAL, ) - def test_aggregate_called_correctly(self): + def test_aggregate_called_correctly(self) -> None: self.mock_aggregate.assert_called_once_with( "regularized timeseries", "D", @@ -120,12 +138,12 @@ def test_aggregate_called_correctly(self): target_timestamp_offset=None, ) - def test_wrote_target_file(self): + def test_wrote_target_file(self) -> None: self.assertEqual(self.mock_aggregate.return_value.write.call_count, 1) class RegularizationModeTestCase(CliMixin, TestCase): - def _run(self, method): + def _run(self, method: str) -> None: self.configuration = textwrap.dedent( f"""\ [General] @@ -141,7 +159,7 @@ def _run(self, method): ) self._execute() - def test_regularize_called_correctly_when_sum(self): + def test_regularize_called_correctly_when_sum(self) -> None: self._run("sum") self.mock_regularize.assert_called_once_with( "my timeseries", @@ -149,7 +167,7 @@ def test_regularize_called_correctly_when_sum(self): mode=RegularizationMode.INTERVAL, ) - def test_regularize_called_correctly_when_mean(self): + def test_regularize_called_correctly_when_mean(self) -> None: self._run("mean") self.mock_regularize.assert_called_once_with( "my timeseries", @@ -157,7 +175,7 @@ def test_regularize_called_correctly_when_mean(self): mode=RegularizationMode.INSTANTANEOUS, ) - def test_regularize_called_correctly_when_max(self): + def test_regularize_called_correctly_when_max(self) -> None: self._run("max") self.mock_regularize.assert_called_once_with( "my timeseries", @@ -165,7 +183,7 @@ def test_regularize_called_correctly_when_max(self): mode=RegularizationMode.INTERVAL, ) - def test_regularize_called_correctly_when_min(self): + def test_regularize_called_correctly_when_min(self) -> None: self._run("min") self.mock_regularize.assert_called_once_with( "my timeseries", @@ -190,23 +208,23 @@ class CliWithTargetTimestampOffsetTestCase(CliMixin, TestCase): """ ) - def setUp(self): + def setUp(self) -> None: self._execute() - def test_exit_code(self): + def test_exit_code(self) -> None: self.assertEqual(self.result.exit_code, 0) - def test_read_source_file(self): + def test_read_source_file(self) -> None: self.assertEqual(self.mock_htimeseries.call_count, 1) - def test_regularize_called_correctly(self): + def test_regularize_called_correctly(self) -> None: self.mock_regularize.assert_called_once_with( "my timeseries", new_date_flag="DATEINSERT", mode=RegularizationMode.INTERVAL, ) - def test_aggregate_called_correctly(self): + def test_aggregate_called_correctly(self) -> None: self.mock_aggregate.assert_called_once_with( "regularized timeseries", "D", @@ -216,5 +234,5 @@ def test_aggregate_called_correctly(self): target_timestamp_offset="1min", ) - def test_wrote_target_file(self): + def test_wrote_target_file(self) -> None: self.assertEqual(self.mock_aggregate.return_value.write.call_count, 1) diff --git a/tests/haggregate/test_haggregate.py b/tests/haggregate/test_haggregate.py index cd36a2c..11d49e1 100644 --- a/tests/haggregate/test_haggregate.py +++ b/tests/haggregate/test_haggregate.py @@ -70,15 +70,19 @@ def test_length(self): self.assertEqual(len(self.result.data), 4) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 10:00"].value, 31.25) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 11:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 11:00"].value, 65.47) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 12:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 12:00"].value, 69.29) def test_value_4(self): + assert isinstance(self.result.data.loc["2008-02-07 13:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 13:00"].value, 72.77) @@ -152,12 +156,15 @@ def test_length(self): self.assertEqual(len(self.result.data), 3) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 11:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 11:00"].value, 65.47) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 12:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 12:00"].value, 69.29) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 13:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 13:00"].value, 72.77) @@ -172,21 +179,25 @@ def test_length(self): self.assertEqual(len(self.result.data), 4) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:00"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:00"].value, 10.4166667 ) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 11:00"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 11:00"].value, 10.9116667 ) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 12:00"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 12:00"].value, 11.5483333 ) def test_value_4(self): + assert isinstance(self.result.data.loc["2008-02-07 13:00"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 13:00"].value, 12.1283333 ) @@ -210,21 +221,25 @@ def test_length(self): self.assertEqual(len(self.result.data), 4) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 09:59"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 09:59"].value, 10.4166667 ) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 10:59"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:59"].value, 10.9116667 ) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 11:59"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 11:59"].value, 11.5483333 ) def test_value_4(self): + assert isinstance(self.result.data.loc["2008-02-07 12:59"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 12:59"].value, 12.1283333 ) @@ -248,21 +263,25 @@ def test_length(self): self.assertEqual(len(self.result.data), 4) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:01"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:01"].value, 10.4166667 ) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 11:01"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 11:01"].value, 10.9116667 ) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 12:01"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 12:01"].value, 11.5483333 ) def test_value_4(self): + assert isinstance(self.result.data.loc["2008-02-07 13:01"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 13:01"].value, 12.1283333 ) @@ -279,15 +298,19 @@ def test_length(self): self.assertEqual(len(self.result.data), 4) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 10:00"].value, 10.51) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 11:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 11:00"].value, 11.23) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 12:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 12:00"].value, 11.8) def test_value_4(self): + assert isinstance(self.result.data.loc["2008-02-07 13:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 13:00"].value, 12.24) @@ -302,15 +325,19 @@ def test_length(self): self.assertEqual(len(self.result.data), 4) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 10:00"].value, 10.32) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 11:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 11:00"].value, 10.54) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 12:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 12:00"].value, 11.41) def test_value_4(self): + assert isinstance(self.result.data.loc["2008-02-07 13:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 13:00"].value, 11.91) @@ -327,9 +354,11 @@ def test_length(self): self.assertEqual(len(self.result.data), 2) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 12:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 12:00"].value, 166.01) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 15:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2008-02-07 15:00"].value, 85.08) def test_flags_1(self): @@ -374,6 +403,7 @@ def test_length(self): self.assertEqual(len(self.result.data), 1) def test_value_1(self): + assert isinstance(self.result.data.loc["2005-05-01 02:00"].value, float) self.assertAlmostEqual(self.result.data.loc["2005-05-01 02:00"].value, 3) @@ -385,7 +415,6 @@ def setUp(self): self.ts.title = "hello" self.ts.precision = 1 self.ts.comment = "world" - self.ts.timezone = "EET (+0200)" self.result = aggregate(self.ts, "1h", "sum", min_count=3, missing_flag="MISS") def test_sets_title(self): @@ -402,6 +431,3 @@ def test_sets_comment(self): def test_sets_time_step(self): self.assertEqual(self.result.time_step, "1h") - - def test_sets_timezone(self): - self.assertEqual(self.result.timezone, "EET (+0200)") diff --git a/tests/haggregate/test_regularize.py b/tests/haggregate/test_regularize.py index 85d4744..e3cdfea 100644 --- a/tests/haggregate/test_regularize.py +++ b/tests/haggregate/test_regularize.py @@ -2,6 +2,7 @@ import math import textwrap from io import StringIO +from typing import cast from unittest import TestCase from zoneinfo import ZoneInfo @@ -56,19 +57,23 @@ def test_length(self): self.assertEqual(len(self.result.data), 3) def test_timestamps_are_aware(self): - self.assertEqual(self.result.data.index[0].utcoffset(), dt.timedelta(hours=2)) + timestamp = cast(dt.datetime, self.result.data.index[0]) + self.assertEqual(timestamp.utcoffset(), dt.timedelta(hours=2)) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:30:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:30:00+0200"].value, 10.71 ) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 10:40:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:40:00+0200"].value, 10.93 ) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 10:50:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:50:00+0200"].value, 11.10 ) @@ -105,16 +110,19 @@ def test_length(self): self.assertEqual(len(self.result.data), 3) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:30:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:30:00+0200"].value, 10.71 ) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 10:40:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:40:00+0200"].value, 10.93 ) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 10:50:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:50:00+0200"].value, 11.10 ) @@ -137,16 +145,19 @@ def test_length(self): self.assertEqual(len(self.result.data), 3) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:30:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:30:00+0200"].value, 10.71 ) def test_value_2(self): + assert isinstance(self.result.data.loc["2008-02-07 10:40:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:40:00+0200"].value, 10.93 ) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 10:50:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:50:00+0200"].value, 11.10 ) @@ -169,6 +180,7 @@ def test_length(self): self.assertEqual(len(self.result.data), 4) def test_value_1(self): + assert isinstance(self.result.data.loc["2008-02-07 10:30:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:30:00+0200"].value, 10.71 ) @@ -179,11 +191,13 @@ def test_value_2(self): ) def test_value_3(self): + assert isinstance(self.result.data.loc["2008-02-07 10:50:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 10:50:00+0200"].value, 11.10 ) def test_value_4(self): + assert isinstance(self.result.data.loc["2008-02-07 11:00:00+0200"].value, float) self.assertAlmostEqual( self.result.data.loc["2008-02-07 11:00:00+0200"].value, 10.93 ) @@ -214,10 +228,12 @@ def setUp(self): def test_interval(self): result = regularize(self.ts, mode=RegularizationMode.INTERVAL) + assert isinstance(result.data.loc["2008-02-07 10:40"].value, float) self.assertTrue(math.isnan(result.data.loc["2008-02-07 10:40"].value)) def test_instantaneous(self): result = regularize(self.ts, mode=RegularizationMode.INSTANTANEOUS) + assert isinstance(result.data.loc["2008-02-07 10:40"].value, float) self.assertAlmostEqual(result.data.loc["2008-02-07 10:40"].value, 10.93) @@ -236,10 +252,12 @@ def setUp(self): def test_interval(self): result = regularize(self.ts, mode=RegularizationMode.INTERVAL) + assert isinstance(result.data.loc["2008-02-07 10:40"].value, float) self.assertTrue(math.isnan(result.data.loc["2008-02-07 10:40"].value)) def test_instantaneous(self): result = regularize(self.ts, mode=RegularizationMode.INSTANTANEOUS) + assert isinstance(result.data.loc["2008-02-07 10:40"].value, float) self.assertAlmostEqual(result.data.loc["2008-02-07 10:40"].value, 10.93) @@ -257,7 +275,6 @@ def setUp(self): self.ts.title = "hello" self.ts.precision = 1 self.ts.comment = "world" - self.ts.timezone = "EET (+0200)" self.result = regularize(self.ts, mode=RegularizationMode.INTERVAL) def test_sets_title(self): @@ -274,7 +291,4 @@ def test_sets_comment(self): ) def test_sets_time_step(self): - self.assertEqual(self.result.time_step, "10min") - - def test_sets_timezone(self): - self.assertEqual(self.result.timezone, "EET (+0200)") + self.assertEqual(self.result.time_step, "10min") \ No newline at end of file diff --git a/tests/hspatial/test_cli.py b/tests/hspatial/test_cli.py index 1830698..10150ba 100644 --- a/tests/hspatial/test_cli.py +++ b/tests/hspatial/test_cli.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import datetime as dt import os import shutil import tempfile import textwrap +from collections.abc import Sequence from pathlib import Path from unittest import TestCase from unittest.mock import patch @@ -17,7 +20,9 @@ gdal.UseExceptions() -def create_geotiff_file(filename, value): +def create_geotiff_file( + filename: str | os.PathLike[str], value: Sequence[Sequence[float]] +) -> None: geo_transform = (-16.25, 1.0, 0, 16.217, 0, 1.0) wgs84 = osr.SpatialReference() wgs84.ImportFromEPSG(4326) @@ -32,28 +37,28 @@ def create_geotiff_file(filename, value): class NonExistentConfigFileTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: runner = CliRunner(mix_stderr=False) self.result = runner.invoke(cli.main, ["nonexistent.conf"]) - def test_exit_status(self): + def test_exit_status(self) -> None: self.assertEqual(self.result.exit_code, 1) - def test_error_message(self): + def test_error_message(self) -> None: self.assertIn( "No such file or directory: 'nonexistent.conf'", self.result.stderr ) class ConfigurationTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.configfilename = os.path.join(self.tempdir, "hspatial.conf") - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test_missing_mask_parameter_raises_error(self): + def test_missing_mask_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -74,7 +79,7 @@ def test_missing_mask_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_nonexistent_log_level_raises_error(self): + def test_nonexistent_log_level_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -97,7 +102,7 @@ def test_nonexistent_log_level_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_epsg_parameter_raises_error(self): + def test_missing_epsg_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -118,7 +123,7 @@ def test_missing_epsg_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_wrong_epsg_parameter_raises_error(self): + def test_wrong_epsg_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -140,7 +145,7 @@ def test_wrong_epsg_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_filename_prefix_parameter_raises_error(self): + def test_missing_filename_prefix_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -161,7 +166,7 @@ def test_missing_filename_prefix_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_output_dir_parameter_raises_error(self): + def test_missing_output_dir_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -182,7 +187,7 @@ def test_missing_output_dir_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_number_of_output_files_parameter_raises_error(self): + def test_missing_number_of_output_files_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -203,7 +208,7 @@ def test_missing_number_of_output_files_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_wrong_number_of_output_files_parameter_raises_error(self): + def test_wrong_number_of_output_files_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -225,7 +230,7 @@ def test_wrong_number_of_output_files_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_method_raises_error(self): + def test_missing_method_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -246,7 +251,7 @@ def test_missing_method_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_missing_files_parameter_raises_error(self): + def test_missing_files_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -265,7 +270,7 @@ def test_missing_files_parameter_raises_error(self): with self.assertRaisesRegex(click.ClickException, msg): cli.App(self.configfilename).run() - def test_wrong_alpha_parameter_raises_error(self): + def test_wrong_alpha_parameter_raises_error(self) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -289,7 +294,7 @@ def test_wrong_alpha_parameter_raises_error(self): cli.App(self.configfilename).run() @patch("hspatial.cli.App._execute") - def test_correct_configuration_executes(self, m): + def test_correct_configuration_executes(self, m) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -311,7 +316,7 @@ def test_correct_configuration_executes(self, m): m.assert_called_once_with() @patch("hspatial.cli.App._execute") - def test_creates_log_file(self, *args): + def test_creates_log_file(self, *args) -> None: logfilename = os.path.join(self.tempdir, "hspatial.log") with open(self.configfilename, "w") as f: f.write( @@ -334,7 +339,9 @@ def test_creates_log_file(self, *args): self.assertTrue(os.path.exists(logfilename)) -def _create_test_data(filename1, filename2): +def _create_test_data( + filename1: str | os.PathLike[str], filename2: str | os.PathLike[str] +) -> None: with open(filename1, "w") as f: f.write( textwrap.dedent( @@ -368,7 +375,7 @@ def _create_test_data(filename1, filename2): class AppTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.output_dir = os.path.join(self.tempdir, "output") self.config_file = os.path.join(self.tempdir, "spatialize.conf") @@ -378,7 +385,7 @@ def setUp(self): _create_test_data(self.filenames[0], self.filenames[1]) self._prepare_config_file() - def _prepare_config_file(self, number_of_output_files=24): + def _prepare_config_file(self, number_of_output_files: int = 24) -> None: with open(self.config_file, "w") as f: f.write( textwrap.dedent( @@ -396,7 +403,7 @@ def _prepare_config_file(self, number_of_output_files=24): ) ) - def _create_mask_file(self): + def _create_mask_file(self) -> None: mask_filename = os.path.join(self.tempdir, "mask.tif") mask = gdal.GetDriverByName("GTiff").Create( mask_filename, 640, 480, 1, gdal.GDT_Float32 @@ -411,10 +418,10 @@ def _create_mask_file(self): ) mask = None - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test_get_last_dates(self): + def test_get_last_dates(self) -> None: application = cli.App(self.config_file) tzinfo = TzinfoFromString("+0200") self.assertEqual( @@ -435,7 +442,7 @@ def test_get_last_dates(self): ) @patch("hspatial.cli.App._execute") - def test_dates_to_calculate(self, *args): + def test_dates_to_calculate(self, *args) -> None: application = cli.App(self.config_file) application.run() tzinfo = TzinfoFromString("+0200") @@ -483,7 +490,7 @@ def test_dates_to_calculate(self, *args): ], ) - def test_dates_to_calculate_error1(self): + def test_dates_to_calculate_error1(self) -> None: self._create_mask_file() application = cli.App(self.config_file) with open(self.filenames[0], "a") as f: @@ -505,7 +512,7 @@ def test_dates_to_calculate_error1(self): application.run() @patch("hspatial.cli.App._execute") - def test_date_fmt(self, m): + def test_date_fmt(self, m) -> None: application = cli.App(self.config_file) application.run() @@ -553,12 +560,12 @@ def test_date_fmt(self, m): application._date_fmt @patch("hspatial.cli.App._execute") - def test_delete_obsolete_files(self, m): + def test_delete_obsolete_files(self, m) -> None: application = cli.App(self.config_file) application.run() # Create three files - prefix = application.config.filename_prefix + prefix = application.config.filename_prefix # type: ignore[attr-defined] filename1 = os.path.join(self.output_dir, "{}-1.tif".format(prefix)) filename2 = os.path.join(self.output_dir, "{}-2.tif".format(prefix)) filename3 = os.path.join(self.output_dir, "{}-3.tif".format(prefix)) @@ -584,7 +591,7 @@ def test_delete_obsolete_files(self, m): self.assertTrue(os.path.exists(filename2)) self.assertTrue(os.path.exists(filename3)) - def test_run(self): + def test_run(self) -> None: application = cli.App(self.config_file) # Create a mask @@ -595,7 +602,7 @@ def test_run(self): application.run() # Check that it has created three files - full_prefix = os.path.join(self.output_dir, application.config.filename_prefix) + full_prefix = os.path.join(self.output_dir, application.config.filename_prefix) # type: ignore[attr-defined] self.assertTrue(os.path.exists(full_prefix + "-2014-04-30-15-00+0200.tif")) self.assertTrue(os.path.exists(full_prefix + "-2014-04-30-14-00+0200.tif")) self.assertTrue(os.path.exists(full_prefix + "-2014-04-30-13-00+0200.tif")) @@ -610,7 +617,7 @@ def test_run(self): # unit-tested lower level functions in detail, the above is reasonably # sufficient for us to know that it works. - def test_no_timezone(self): + def test_no_timezone(self) -> None: self._remove_timezone_from_file(self.filenames[0]) self._remove_timezone_from_file(self.filenames[1]) application = cli.App(self.config_file) @@ -620,7 +627,7 @@ def test_no_timezone(self): with self.assertRaisesRegex(click.ClickException, msg): application.run() - def _remove_timezone_from_file(self, filename): + def _remove_timezone_from_file(self, filename: str | os.PathLike[str]) -> None: with open(filename, "r") as f: lines = f.readlines() with open(filename, "w") as f: diff --git a/tests/hspatial/test_hspatial.py b/tests/hspatial/test_hspatial.py index a88438c..d762d1d 100644 --- a/tests/hspatial/test_hspatial.py +++ b/tests/hspatial/test_hspatial.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime as dt import math import os @@ -6,11 +8,12 @@ import textwrap from stat import S_IREAD, S_IRGRP, S_IROTH from time import sleep +from typing import Callable from unittest import TestCase import numpy as np import pandas as pd -from django.contrib.gis.gdal import GDALRaster +from django.contrib.gis.gdal import GDALRaster # type: ignore from django.contrib.gis.geos import Point as GeoDjangoPoint from osgeo import gdal, ogr, osr @@ -23,7 +26,7 @@ UTC_PLUS_2 = dt.timezone(dt.timedelta(hours=2)) -def add_point_to_layer(layer, x, y, value): +def add_point_to_layer(layer: ogr.Layer, x: float, y: float, value: float) -> None: p = ogr.Geometry(ogr.wkbPoint) p.AddPoint(x, y) f = ogr.Feature(layer.GetLayerDefn()) @@ -33,24 +36,30 @@ def add_point_to_layer(layer, x, y, value): class IdwTestCase(TestCase): - def setUp(self): + data_layer: ogr.Layer | None + point: ogr.Geometry | None + + def setUp(self) -> None: self.point = ogr.Geometry(ogr.wkbPoint) self.point.AddPoint(5.1, 2.5) self.data_source = ogr.GetDriverByName("memory").CreateDataSource("tmp") self.data_layer = self.data_source.CreateLayer("test") + assert self.data_layer is not None self.data_layer.CreateField(ogr.FieldDefn("value", ogr.OFTReal)) - def tearDown(self): + def tearDown(self) -> None: self.data_layer = None self.data_source = None self.point = None - def test_idw_single_point(self): + def test_idw_single_point(self) -> None: + assert self.data_layer is not None and self.point is not None add_point_to_layer(self.data_layer, 5.3, 6.4, 42.8) self.assertAlmostEqual(hspatial.idw(self.point, self.data_layer), 42.8) - def test_idw_three_points(self): + def test_idw_three_points(self) -> None: + assert self.data_layer is not None and self.point is not None add_point_to_layer(self.data_layer, 6.4, 7.8, 33.0) add_point_to_layer(self.data_layer, 9.5, 7.4, 94.0) add_point_to_layer(self.data_layer, 7.1, 4.9, 67.7) @@ -61,7 +70,8 @@ def test_idw_three_points(self): hspatial.idw(self.point, self.data_layer, alpha=2.0), 64.188, places=3 ) - def test_idw_point_with_nan(self): + def test_idw_point_with_nan(self) -> None: + assert self.data_layer is not None and self.point is not None add_point_to_layer(self.data_layer, 6.4, 7.8, 33.0) add_point_to_layer(self.data_layer, 9.5, 7.4, 94.0) add_point_to_layer(self.data_layer, 7.1, 4.9, 67.7) @@ -78,7 +88,7 @@ class IntegrateTestCase(TestCase): # The calculations for this test have been made manually in # data/spatial_calculations.ods, tab test_integrate_idw. - def setUp(self): + def setUp(self) -> None: # We will test on a 7x15 grid self.mask = np.zeros((7, 15), np.int8) self.mask[3, 3] = 1 @@ -104,10 +114,10 @@ def setUp(self): add_point_to_layer(self.data_layer, 125.7, 19.0, 24.0) add_point_to_layer(self.data_layer, 9.8, 57.1, 95.4) - def tearDown(self): + def tearDown(self) -> None: self.data_source = None - def test_integrate_idw(self): + def test_integrate_idw(self) -> None: hspatial.integrate( self.dataset, self.data_layer, self.target_band, hspatial.idw ) @@ -127,7 +137,7 @@ class IntegrateWithGeoDjangoObjectsTestCase(IntegrateTestCase): """This is exactly the same as IntegrateTestCase, except that instead of using gdal objects for the mask and target_band, it uses django.contrib.gis.gdal objects.""" - def setUp(self): + def setUp(self) -> None: # We will test on a 7x15 grid self.mask = np.zeros((7, 15), np.float32) self.mask[3, 3] = 1 @@ -156,10 +166,10 @@ def setUp(self): add_point_to_layer(self.data_layer, 125.7, 19.0, 24.0) add_point_to_layer(self.data_layer, 9.8, 57.1, 95.4) - def tearDown(self): + def tearDown(self) -> None: self.data_source = None - def test_integrate_idw(self): + def test_integrate_idw(self) -> None: hspatial.integrate( self.dataset, self.data_layer, self.target_band, hspatial.idw ) @@ -170,13 +180,13 @@ def test_integrate_idw(self): # (^ is bitwise xor in Python) self.assertTrue(((result == nodatavalue) ^ (self.mask != 0)).all()) - self.assertAlmostEqual(result[3, 3], 62.971, places=3) - self.assertAlmostEqual(result[6, 14], 34.838, places=3) - self.assertAlmostEqual(result[4, 13], 30.737, places=3) + self.assertAlmostEqual(result[3, 3], 62.971, places=3) # type: ignore + self.assertAlmostEqual(result[6, 14], 34.838, places=3) # type: ignore + self.assertAlmostEqual(result[4, 13], 30.737, places=3) # type: ignore class CreateOgrLayerFromTimeseriesTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() # Create two time series @@ -199,10 +209,10 @@ def setUp(self): ) ) - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test_create_ogr_layer_from_timeseries(self): + def test_create_ogr_layer_from_timeseries(self) -> None: data_source = ogr.GetDriverByName("memory").CreateDataSource("tmp") filenames = [os.path.join(self.tempdir, x) for x in ("ts1", "ts2")] layer = hspatial.create_ogr_layer_from_timeseries(filenames, 2100, data_source) @@ -244,14 +254,14 @@ class HIntegrateTestCase(TestCase): data/spatial_calculations.ods, tab test_h_integrate. """ - def create_mask(self): + def create_mask(self) -> None: mask_array = np.ones((3, 4), np.int8) mask_array[0, 2] = 0 self.mask = gdal.GetDriverByName("mem").Create("mask", 4, 3, 1, gdal.GDT_Byte) self.mask.SetGeoTransform((0, 10000, 0, 30000, 0, -10000)) self.mask.GetRasterBand(1).WriteArray(mask_array) - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self.filenames = [ @@ -305,10 +315,10 @@ def setUp(self): self.filenames, 2100, self.stations ) - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def test_h_integrate(self): + def test_h_integrate(self) -> None: output_filename_prefix = os.path.join(self.tempdir, "test") result_filename = output_filename_prefix + "-2014-04-22-13-00+0200.tif" hspatial.h_integrate( @@ -404,12 +414,12 @@ def test_h_integrate(self): class ExtractPointFromRasterTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self._setup_test_raster() self.fp = gdal.Open(self.filename) - def _setup_test_raster(self): + def _setup_test_raster(self) -> None: self.filename = os.path.join(self.tempdir, "test_raster") nan = float("nan") setup_test_raster( @@ -418,27 +428,27 @@ def _setup_test_raster(self): dt.datetime(2014, 11, 21, 16, 1), ) - def tearDown(self): + def tearDown(self) -> None: self.fp = None shutil.rmtree(self.tempdir) - def test_top_left_point(self): + def test_top_left_point(self) -> None: point = hspatial.coordinates2point(22.005, 37.995) self.assertAlmostEqual( hspatial.extract_point_from_raster(point, self.fp), 1.1, places=2 ) - def test_top_left_point_as_geodjango(self): + def test_top_left_point_as_geodjango(self) -> None: point = GeoDjangoPoint(22.005, 37.995) self.assertAlmostEqual( hspatial.extract_point_from_raster(point, self.fp), 1.1, places=2 ) - def test_top_middle_point(self): + def test_top_middle_point(self) -> None: point = hspatial.coordinates2point(22.015, 37.995) self.assertTrue(math.isnan(hspatial.extract_point_from_raster(point, self.fp))) - def test_middle_point(self): + def test_middle_point(self) -> None: # We use co-ordinates almost to the common corner of the four lower left points, # only a little bit towards the center. point = hspatial.coordinates2point(22.01001, 37.98001) @@ -446,7 +456,7 @@ def test_middle_point(self): hspatial.extract_point_from_raster(point, self.fp), 2.2, places=2 ) - def test_middle_point_with_GDALRaster(self): + def test_middle_point_with_GDALRaster(self) -> None: # Same as test_middle_point(), but uses GDALRaster object instead of a gdal # data source. point = hspatial.coordinates2point(22.01001, 37.98001) @@ -455,7 +465,7 @@ def test_middle_point_with_GDALRaster(self): hspatial.extract_point_from_raster(point, gdal_raster_object), 2.2, places=2 ) - def test_bottom_left_point(self): + def test_bottom_left_point(self) -> None: # Use almost exactly same point as test_middle_point(), only slightly altered # so that we get bottom left point instead. point = hspatial.coordinates2point(22.00999, 37.97999) @@ -463,7 +473,7 @@ def test_bottom_left_point(self): hspatial.extract_point_from_raster(point, self.fp), 3.1, places=2 ) - def test_middle_point_with_GRS80(self): + def test_middle_point_with_GRS80(self) -> None: # Same as test_middle_point(), but with a different reference system, GRS80; the # result should be the same. point = hspatial.coordinates2point(325077, 4205177, srid=2100) @@ -471,7 +481,7 @@ def test_middle_point_with_GRS80(self): hspatial.extract_point_from_raster(point, self.fp), 2.2, places=2 ) - def test_does_not_modify_srid_of_point(self): + def test_does_not_modify_srid_of_point(self) -> None: point = hspatial.coordinates2point(325077, 4205177, srid=2100) original_spatial_reference = point.GetSpatialReference().ExportToWkt() hspatial.extract_point_from_raster(point, self.fp) @@ -479,7 +489,7 @@ def test_does_not_modify_srid_of_point(self): point.GetSpatialReference().ExportToWkt(), original_spatial_reference ) - def test_bottom_left_point_with_GRS80(self): + def test_bottom_left_point_with_GRS80(self) -> None: # Same as test_bottom_left_point(), but with a different reference system, # GRS80; the result should be the same. point = hspatial.coordinates2point(324076, 4205176, srid=2100) @@ -487,19 +497,19 @@ def test_bottom_left_point_with_GRS80(self): hspatial.extract_point_from_raster(point, self.fp), 3.1, places=2 ) - def test_point_outside_raster(self): + def test_point_outside_raster(self) -> None: point = hspatial.coordinates2point(21.0, 38.0) with self.assertRaises(RuntimeError): hspatial.extract_point_from_raster(point, self.fp) class ExtractPointFromRasterWhenOutsideCRSLimitsTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self._setup_test_raster() self.fp = gdal.Open(self.filename) - def _setup_test_raster(self): + def _setup_test_raster(self) -> None: self.filename = os.path.join(self.tempdir, "test_raster") nan = float("nan") setup_test_raster( @@ -509,32 +519,34 @@ def _setup_test_raster(self): srid=2100, ) - def tearDown(self): + def tearDown(self) -> None: self.fp = None shutil.rmtree(self.tempdir) - def test_fails_gracefully_when_osr_point_is_really_outside_crs_limits(self): + def test_fails_gracefully_when_osr_point_is_really_outside_crs_limits(self) -> None: point = hspatial.coordinates2point(125.0, 85.0) with self.assertRaises(RuntimeError): hspatial.extract_point_from_raster(point, self.fp) - def test_fails_gracefully_when_geos_point_is_really_outside_crs_limits(self): + def test_fails_gracefully_when_geos_point_is_really_outside_crs_limits(self) -> None: point = GeoDjangoPoint(125.0, 85.0) with self.assertRaises(RuntimeError): hspatial.extract_point_from_raster(point, self.fp) class SetupTestRastersMixin: + assertEqual: Callable[[object, object], None] + include_time = True - def setUp(self): + def setUp(self) -> None: self.tempdir = tempfile.mkdtemp() self._setup_test_rasters() - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(self.tempdir) - def _setup_test_rasters(self): + def _setup_test_rasters(self) -> None: self._setup_raster( dt.date(2014, 11, 21), np.array([[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3]]), @@ -550,38 +562,41 @@ def _setup_test_rasters(self): ), ) - def _setup_raster(self, date, value): + def _setup_raster(self, date: dt.date, value: np.ndarray) -> None: filename = self._create_filename(date) timestamp = self._create_timestamp(date) setup_test_raster(filename, value, timestamp, unit="microkernels") - def _create_filename(self, date): + def _create_filename(self, date: dt.date) -> str: result = date.strftime("test-%Y-%m-%d") if self.include_time: result += "-16-1" result += ".tif" return os.path.join(self.tempdir, result) - def _create_timestamp(self, date): + def _create_timestamp(self, date: dt.date) -> dt.datetime | dt.date: if self.include_time: return dt.datetime.combine(date, dt.time(16, 1, tzinfo=UTC_PLUS_2)) else: return date - def _check_against_expected(self, ts): + def _check_against_expected(self, ts: HTimeseries) -> None: expected = pd.DataFrame( data={"value": [2.2, 22.1, 220.1], "flags": ["", "", ""]}, - index=self.expected_index, - columns=["value", "flags"], + index=self.expected_index, # type: ignore + columns=["value", "flags"], # type: ignore ) expected.index.name = "date" + expected_index = pd.DatetimeIndex(expected.index) + actual_index = pd.DatetimeIndex(ts.data.index) + assert expected_index.tz is not None and actual_index.tz is not None self.assertEqual( - ts.data.index.tz.utcoffset(None), expected.index.tz.utcoffset(None) + actual_index.tz.utcoffset(None), expected_index.tz.utcoffset(None) ) pd.testing.assert_frame_equal(ts.data, expected, check_index_type=False) @property - def expected_index(self): + def expected_index(self) -> list[dt.datetime]: hour, minute = self.include_time and (16, 1) or (23, 58) return [ dt.datetime(2014, 11, 21, hour, minute, tzinfo=UTC_PLUS_2), @@ -591,7 +606,7 @@ def expected_index(self): class PointTimeseriesGetTestCase(SetupTestRastersMixin, TestCase): - def test_with_list_of_files(self): + def test_with_list_of_files(self) -> None: # Use co-ordinates almost to the common corner of the four lower left points, # and only a little bit towards the center. point = hspatial.coordinates2point(22.01001, 37.98001) @@ -605,7 +620,7 @@ def test_with_list_of_files(self): ).get() self._check_against_expected(ts) - def test_raises_when_no_timezone(self): + def test_raises_when_no_timezone(self) -> None: point = hspatial.coordinates2point(22.01001, 37.98001) filenames = [os.path.join(self.tempdir, "test-2014-11-22-16-1.tif")] with self.assertRaises(TypeError): @@ -613,7 +628,7 @@ def test_raises_when_no_timezone(self): point, filenames=filenames, default_time=dt.time(0, 0) ) - def test_with_prefix(self): + def test_with_prefix(self) -> None: # Same as test_with_list_of_files(), but with prefix. point = hspatial.coordinates2point(22.01001, 37.98001) prefix = os.path.join(self.tempdir, "test") @@ -622,7 +637,7 @@ def test_with_prefix(self): ).get() self._check_against_expected(ts) - def test_with_prefix_and_geodjango(self): + def test_with_prefix_and_geodjango(self) -> None: point = hspatial.coordinates2point(22.01001, 37.98001) prefix = os.path.join(self.tempdir, "test") ts = hspatial.PointTimeseries( @@ -630,7 +645,7 @@ def test_with_prefix_and_geodjango(self): ).get() self._check_against_expected(ts) - def test_unit_of_measurement(self): + def test_unit_of_measurement(self) -> None: point = hspatial.coordinates2point(22.01001, 37.98001) prefix = os.path.join(self.tempdir, "test") ts = hspatial.PointTimeseries( @@ -642,7 +657,7 @@ def test_unit_of_measurement(self): class PointTimeseriesGetDailyTestCase(SetupTestRastersMixin, TestCase): include_time = False - def test_with_list_of_files(self): + def test_with_list_of_files(self) -> None: # Use co-ordinates almost to the center of the four lower left points, and only # a little bit towards the center. point = hspatial.coordinates2point(22.01001, 37.98001) @@ -656,7 +671,7 @@ def test_with_list_of_files(self): ).get() self._check_against_expected(ts) - def test_with_prefix(self): + def test_with_prefix(self) -> None: # Same as test_with_list_of_files(), but with prefix. point = hspatial.coordinates2point(22.01001, 37.98001) prefix = os.path.join(self.tempdir, "test") @@ -665,7 +680,7 @@ def test_with_prefix(self): ).get() self._check_against_expected(ts) - def test_with_prefix_and_geodjango(self): + def test_with_prefix_and_geodjango(self) -> None: point = GeoDjangoPoint(22.01001, 37.98001) prefix = os.path.join(self.tempdir, "test") ts = hspatial.PointTimeseries( @@ -675,13 +690,13 @@ def test_with_prefix_and_geodjango(self): class PointTimeseriesGetCachedTestCase(SetupTestRastersMixin, TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.point = hspatial.coordinates2point(22.01001, 37.98001) self.prefix = os.path.join(self.tempdir, "test") self.dest = os.path.join(self.tempdir, "dest.hts") - def test_result(self): + def test_result(self) -> None: result = hspatial.PointTimeseries( self.point, prefix=self.prefix, @@ -689,7 +704,7 @@ def test_result(self): ).get_cached(self.dest) self._check_against_expected(result) - def test_file(self): + def test_file(self) -> None: hspatial.PointTimeseries( self.point, prefix=self.prefix, @@ -698,7 +713,7 @@ def test_file(self): with open(self.dest, "r", newline="\n") as f: self._check_against_expected(HTimeseries(f, default_tzinfo=UTC_PLUS_2)) - def test_version(self): + def test_version(self) -> None: hspatial.PointTimeseries( self.point, prefix=self.prefix, @@ -708,7 +723,7 @@ def test_version(self): first_line = f.readline() self.assertEqual(first_line, "Version=2\n") - def test_file_is_not_recreated(self): + def test_file_is_not_recreated(self) -> None: hspatial.PointTimeseries( self.point, prefix=self.prefix, @@ -727,7 +742,7 @@ def test_file_is_not_recreated(self): with open(self.dest, "r", newline="\n") as f: self._check_against_expected(HTimeseries(f, default_tzinfo=UTC_PLUS_2)) - def test_file_is_recreated_when_out_of_date(self): + def test_file_is_recreated_when_out_of_date(self) -> None: hspatial.PointTimeseries( self.point, prefix=self.prefix, @@ -746,7 +761,7 @@ def test_file_is_recreated_when_out_of_date(self): default_time=dt.time(0, 0, tzinfo=UTC_PLUS_2), ).get_cached(self.dest) - def _setup_additional_raster(self): + def _setup_additional_raster(self) -> None: filename = os.path.join(self.tempdir, "test-2014-11-24-16-1.tif") setup_test_raster( filename, @@ -756,7 +771,7 @@ def _setup_additional_raster(self): dt.datetime(2014, 11, 24, 16, 1), ) - def test_start_date(self): + def test_start_date(self) -> None: start_date = dt.datetime(2014, 11, 22, 16, 1) result = hspatial.PointTimeseries( self.point, @@ -766,7 +781,7 @@ def test_start_date(self): ).get_cached(self.dest) self.assertEqual(result.data.index[0], start_date.replace(tzinfo=UTC_PLUS_2)) - def test_end_date(self): + def test_end_date(self) -> None: end_date = dt.datetime(2014, 11, 22, 16, 1) result = hspatial.PointTimeseries( self.point, @@ -778,7 +793,7 @@ def test_end_date(self): class FilenameWithDateFormatTestCase(TestCase): - def test_with_given_datetime_format(self): + def test_with_given_datetime_format(self) -> None: format = hspatial.FilenameWithDateFormat( "myprefix", date_fmt="%d-%m-%Y-%H-%M", tzinfo=UTC_PLUS_2 ) @@ -787,7 +802,7 @@ def test_with_given_datetime_format(self): dt.datetime(2019, 8, 4, 10, 41, tzinfo=UTC_PLUS_2), ) - def test_with_given_date_format(self): + def test_with_given_date_format(self) -> None: format = hspatial.FilenameWithDateFormat( "myprefix", date_fmt="%d-%m-%Y", tzinfo=UTC_PLUS_2 ) @@ -796,14 +811,14 @@ def test_with_given_date_format(self): dt.datetime(2019, 8, 4, tzinfo=UTC_PLUS_2), ) - def test_datetime_with_auto_format(self): + def test_datetime_with_auto_format(self) -> None: format = hspatial.FilenameWithDateFormat("myprefix", tzinfo=UTC_PLUS_2) self.assertEqual( format.get_date("myprefix-2019-8-4-10-41.tif"), dt.datetime(2019, 8, 4, 10, 41, tzinfo=UTC_PLUS_2), ) - def test_date_with_auto_format(self): + def test_date_with_auto_format(self) -> None: format = hspatial.FilenameWithDateFormat("myprefix", tzinfo=UTC_PLUS_2) self.assertEqual( format.get_date("myprefix-2019-8-4.tif"), @@ -812,10 +827,11 @@ def test_date_with_auto_format(self): class PassepartoutPointTestCase(TestCase): - def test_transform_does_not_modify_srid_of_gdal_point(self): + def test_transform_does_not_modify_srid_of_gdal_point(self) -> None: pppoint = hspatial.PassepartoutPoint( hspatial.coordinates2point(324651, 4205742, srid=2100) ) + assert isinstance(pppoint.point, ogr.Geometry) original_spatial_reference = pppoint.point.GetSpatialReference().ExportToWkt() sr = osr.SpatialReference() sr.ImportFromEPSG(4326) @@ -825,7 +841,8 @@ def test_transform_does_not_modify_srid_of_gdal_point(self): original_spatial_reference, ) - def test_transform_does_not_modify_srid_of_geodjango_point(self): + def test_transform_does_not_modify_srid_of_geodjango_point(self) -> None: pppoint = hspatial.PassepartoutPoint(GeoDjangoPoint(324651, 4205742, srid=2100)) - pppoint.transform_to(4326) + pppoint.transform_to("4326") + assert isinstance(pppoint.point, GeoDjangoPoint) self.assertEqual(pppoint.point.srid, 2100) diff --git a/tests/htimeseries/test_htimeseries.py b/tests/htimeseries/test_htimeseries.py index 6a2e9d7..9ef5b2e 100644 --- a/tests/htimeseries/test_htimeseries.py +++ b/tests/htimeseries/test_htimeseries.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import datetime as dt import re import textwrap from configparser import ParsingError from copy import copy from io import StringIO +from typing import Any, Callable, cast from unittest import TestCase from zoneinfo import ZoneInfo @@ -222,22 +225,22 @@ standard_empty_dataframe = pd.DataFrame( data={"value": np.array([], dtype=np.float64), "flags": np.array([], dtype=str)}, index=pd.DatetimeIndex([], tz=dt.timezone(dt.timedelta(hours=2))), - columns=["value", "flags"], + columns=["value", "flags"], # type: ignore[misc] ) standard_empty_dataframe.index.name = "date" class HTimeseriesArgumentsTestCase(TestCase): - def test_raises_on_invalid_argument(self): + def test_raises_on_invalid_argument(self) -> None: msg = r"HTimeseries.__init__\(\) got an unexpected keyword argument 'invalid'" with self.assertRaisesRegex(TypeError, msg): HTimeseries(invalid=42) - def test_raises_if_timezone_unspecified(self): + def test_raises_if_timezone_unspecified(self) -> None: with self.assertRaises(TypeError): HTimeseries(StringIO(tenmin_test_timeseries)) - def test_raises_if_dataframe_naive(self): + def test_raises_if_dataframe_naive(self) -> None: df = copy(standard_empty_dataframe) df.index = pd.DatetimeIndex([]) # Replace with a naive index with self.assertRaises(TypeError): @@ -245,30 +248,30 @@ def test_raises_if_dataframe_naive(self): class HTimeseriesEmptyTestCase(TestCase): - def test_read_empty(self): + def test_read_empty(self) -> None: s = StringIO() ts = HTimeseries(s, default_tzinfo=dt.timezone(dt.timedelta(hours=2))) pd.testing.assert_frame_equal(ts.data, standard_empty_dataframe) - def test_write_empty(self): + def test_write_empty(self) -> None: ts = HTimeseries(default_tzinfo=dt.timezone(dt.timedelta(hours=2))) s = StringIO() ts.write(s) self.assertEqual(s.getvalue(), "") - def test_create_empty(self): + def test_create_empty(self) -> None: pd.testing.assert_frame_equal( HTimeseries(default_tzinfo=dt.timezone(dt.timedelta(hours=2))).data, standard_empty_dataframe, ) - def test_unspecified_default_tzinfo(self): + def test_unspecified_default_tzinfo(self) -> None: ts = HTimeseries() - self.assertEqual(ts.data.index.tz, dt.timezone.utc) + self.assertEqual(cast(pd.DatetimeIndex, ts.data.index).tz, dt.timezone.utc) class HTimeseriesWriteSimpleTestCase(TestCase): - def test_write(self): + def test_write(self) -> None: anp = np.array( [ [parse_date("2005-08-23 18:53"), 93, ""], @@ -276,9 +279,10 @@ def test_write(self): [parse_date("2005-08-25 23:59"), 28.3, "HEARTS SPADES"], [parse_date("2005-08-26 00:02"), float("NaN"), ""], [parse_date("2005-08-27 00:02"), float("NaN"), "DIAMONDS"], - ] + ], + dtype=object, ) - data = pd.DataFrame(anp[:, [1, 2]], index=anp[:, 0], columns=("value", "flags")) + data = pd.DataFrame(anp[:, [1, 2]], index=anp[:, 0], columns=("value", "flags")) # type: ignore[misc] ts = HTimeseries(data=data) s = StringIO() ts.write(s) @@ -297,17 +301,17 @@ def test_write(self): class HTimeseriesWriteFileTestCase(TestCase): - def setUp(self): - data = pd.read_csv( + def setUp(self) -> None: + data = pd.read_csv( # type: ignore[misc] StringIO(tenmin_test_timeseries), parse_dates=[0], - usecols=["date", "value", "flags"], + usecols=["date", "value", "flags"], # type: ignore[misc] index_col=0, header=None, names=("date", "value", "flags"), dtype={"value": np.float64, "flags": str}, ).asfreq("10min") - data.index = data.index.tz_localize(dt.timezone(dt.timedelta(hours=2))) + data.index = cast(pd.DatetimeIndex, data.index).tz_localize(dt.timezone(dt.timedelta(hours=2))) self.reference_ts = HTimeseries(data=data) self.reference_ts.unit = "°C" self.reference_ts.title = "A test 10-min time series" @@ -328,22 +332,22 @@ def setUp(self): "asrid": None, } - def test_version_2(self): + def test_version_2(self) -> None: outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=2) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_version_2) - def test_version_3(self): + def test_version_3(self) -> None: outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=3) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_version_3) - def test_version_4(self): + def test_version_4(self) -> None: outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_version_4) - def test_version_5(self): + def test_version_5(self) -> None: outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=5) self.assertEqual( @@ -353,7 +357,7 @@ def test_version_5(self): ), ) - def test_version_latest(self): + def test_version_latest(self) -> None: outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE) self.assertEqual( @@ -363,49 +367,52 @@ def test_version_latest(self): ), ) - def test_altitude_none(self): + def test_altitude_none(self) -> None: + assert self.reference_ts.location is not None self.reference_ts.location["altitude"] = None outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_no_altitude) - def test_no_altitude(self): + def test_no_altitude(self) -> None: + assert self.reference_ts.location is not None del self.reference_ts.location["altitude"] outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_no_altitude) - def test_altitude_zero(self): + def test_altitude_zero(self) -> None: + assert self.reference_ts.location is not None self.reference_ts.location["altitude"] = 0 outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertIn("Altitude=0", outstring.getvalue()) - def test_location_none(self): + def test_location_none(self) -> None: self.reference_ts.location = None outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_no_location) - def test_no_location(self): + def test_no_location(self) -> None: delattr(self.reference_ts, "location") outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_no_location) - def test_precision_none(self): + def test_precision_none(self) -> None: self.reference_ts.precision = None outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_no_precision) - def test_no_precision(self): + def test_no_precision(self) -> None: delattr(self.reference_ts, "precision") outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertEqual(outstring.getvalue(), tenmin_test_timeseries_file_no_precision) - def test_precision_zero(self): + def test_precision_zero(self) -> None: self.reference_ts.precision = 0 outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) @@ -413,7 +420,7 @@ def test_precision_zero(self): outstring.getvalue(), tenmin_test_timeseries_file_zero_precision ) - def test_negative_precision(self): + def test_negative_precision(self) -> None: self.reference_ts.precision = -1 outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) @@ -421,20 +428,20 @@ def test_negative_precision(self): outstring.getvalue(), tenmin_test_timeseries_file_negative_precision ) - def test_timezone_utc(self): + def test_timezone_utc(self) -> None: self.reference_ts.data = self.reference_ts.data.tz_convert(dt.timezone.utc) outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertIn("Timezone=+0000\r\n", outstring.getvalue()) - def test_timezone_positive(self): + def test_timezone_positive(self) -> None: tz = dt.timezone(dt.timedelta(hours=2, minutes=30)) self.reference_ts.data = self.reference_ts.data.tz_convert(tz) outstring = StringIO() self.reference_ts.write(outstring, format=HTimeseries.FILE, version=4) self.assertIn("Timezone=+0230\r\n", outstring.getvalue()) - def test_timezone_negative(self): + def test_timezone_negative(self) -> None: tz = dt.timezone(-dt.timedelta(hours=3, minutes=15)) self.reference_ts.data = self.reference_ts.data.tz_convert(tz) outstring = StringIO() @@ -443,109 +450,112 @@ def test_timezone_negative(self): class ReadFilelikeTestCaseBase: - def test_length(self): + ts: HTimeseries + assertEqual: Callable[[Any, Any], None] + + def test_length(self) -> None: self.assertEqual(len(self.ts.data), 5) - def test_dates(self): + def test_dates(self) -> None: np.testing.assert_array_equal( self.ts.data.index, pd.date_range("2008-02-07 11:20+0200", periods=5, freq="10min"), ) - def test_values(self): + def test_values(self) -> None: expected = np.array( [1141.00, 1142.01, 1154.02, float("NaN"), 1180.04], dtype=float ) np.testing.assert_allclose(self.ts.data.values[:, 0].astype(float), expected) - def test_flags(self): + def test_flags(self) -> None: expected = np.array(["", "MISS", "", "", ""]) np.testing.assert_array_equal(self.ts.data.values[:, 1], expected) - def test_tz(self): - self.assertEqual( - self.ts.data.index.tz.utcoffset(dt.datetime(2000, 1, 1)), - dt.timedelta(hours=2), - ) + def test_tz(self) -> None: + tz = cast(pd.DatetimeIndex, self.ts.data.index).tz + assert tz is not None + self.assertEqual(tz.utcoffset(dt.datetime(2000, 1, 1)), dt.timedelta(hours=2)) class HTimeseriesReadTwoColumnsTestCase(ReadFilelikeTestCaseBase, TestCase): - def setUp(self): + def setUp(self) -> None: string = self._remove_flags_column(tenmin_test_timeseries) s = StringIO(string) s.seek(0) self.ts = HTimeseries(s, default_tzinfo=ZoneInfo("Etc/GMT-2")) - def _remove_flags_column(self, s): + def _remove_flags_column(self, s: str) -> str: return re.sub(r",[^,]*$", "", s, flags=re.MULTILINE) + "\n" - def test_flags(self): + def test_flags(self) -> None: expected = np.array(["", "", "", "", ""]) np.testing.assert_array_equal(self.ts.data.values[:, 1], expected) - def test_tz(self): - self.assertEqual( - self.ts.data.index.tz.utcoffset(dt.datetime(2000, 1, 1)), - dt.timedelta(hours=2), - ) + def test_tz(self) -> None: + tz = cast(pd.DatetimeIndex, self.ts.data.index).tz + assert tz is not None + self.assertEqual(tz.utcoffset(dt.datetime(2000, 1, 1)), dt.timedelta(hours=2)) class HTimeseriesReadMixOf2And3ColumnsTestCase(ReadFilelikeTestCaseBase, TestCase): - def setUp(self): + def setUp(self) -> None: string = self._remove_empty_flags_column(tenmin_test_timeseries) s = StringIO(string) s.seek(0) self.ts = HTimeseries(s, default_tzinfo=ZoneInfo("Etc/GMT-2")) - def _remove_empty_flags_column(self, s): + def _remove_empty_flags_column(self, s: str) -> str: return re.sub(r",$", "", s, flags=re.MULTILINE) + "\n" class HTimeseriesReadOneColumnTestCase(TestCase): - def test_one_column(self): + def test_one_column(self) -> None: s = StringIO("2023-12-19 15:17\n") s.seek(0) self.ts = HTimeseries(s, default_tzinfo=ZoneInfo("Etc/GMT-2")) expected = pd.DataFrame( {"value": [np.nan], "flags": [""]}, - index=[dt.datetime(2023, 12, 19, 15, 17, tzinfo=ZoneInfo("Etc/GMT-2"))], + index=[dt.datetime(2023, 12, 19, 15, 17, tzinfo=ZoneInfo("Etc/GMT-2"))], # type: ignore[misc] ) expected.index.name = "date" pd.testing.assert_frame_equal(self.ts.data, expected) class HTimeseriesReadFilelikeMetadataOnlyTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: s = StringIO(tenmin_test_timeseries_file_version_4) s.seek(0) self.ts = HTimeseries( s, start_date="1971-01-01 00:00", end_date="1970-01-01 00:00" ) - def test_data_is_empty(self): + def test_data_is_empty(self) -> None: self.assertEqual(len(self.ts.data), 0) - def test_metadata_was_read(self): + def test_metadata_was_read(self) -> None: self.assertEqual(self.ts.unit, "°C") class HTimeseriesReadFilelikeWithMissingLocationButPresentAltitudeTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: s = StringIO("Altitude=55\n\n") self.ts = HTimeseries(s, default_tzinfo=dt.timezone(dt.timedelta(hours=2))) - def test_data_is_empty(self): + def test_data_is_empty(self) -> None: pd.testing.assert_frame_equal(self.ts.data, standard_empty_dataframe) - def test_has_altitude(self): + def test_has_altitude(self) -> None: + assert self.ts.location is not None self.assertEqual(self.ts.location["altitude"], 55) - def test_has_no_abscissa(self): + def test_has_no_abscissa(self) -> None: + assert self.ts.location is not None self.assertFalse("abscissa" in self.ts.location) class HTimeseriesReadWithStartDateAndEndDateTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: s = StringIO(tenmin_test_timeseries) s.seek(0) self.ts = HTimeseries( @@ -555,10 +565,10 @@ def setUp(self): default_tzinfo=dt.timezone.utc, ) - def test_length(self): + def test_length(self) -> None: self.assertEqual(len(self.ts.data), 3) - def test_dates(self): + def test_dates(self) -> None: np.testing.assert_array_equal( self.ts.data.index, pd.date_range( @@ -566,20 +576,20 @@ def test_dates(self): ), ) - def test_values(self): + def test_values(self) -> None: np.testing.assert_allclose( self.ts.data.values[:, 0].astype(float), np.array([1142.01, 1154.02, float("NaN")]), ) - def test_flags(self): + def test_flags(self) -> None: np.testing.assert_array_equal( self.ts.data.values[:, 1], np.array(["MISS", "", ""]) ) class HTimeseriesReadWithStartDateTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: s = StringIO(tenmin_test_timeseries) s.seek(0) self.ts = HTimeseries( @@ -588,10 +598,10 @@ def setUp(self): default_tzinfo=dt.timezone.utc, ) - def test_length(self): + def test_length(self) -> None: self.assertEqual(len(self.ts.data), 2) - def test_dates(self): + def test_dates(self) -> None: np.testing.assert_array_equal( self.ts.data.index, pd.date_range( @@ -599,17 +609,17 @@ def test_dates(self): ), ) - def test_values(self): + def test_values(self) -> None: np.testing.assert_allclose( self.ts.data.values[:, 0].astype(float), np.array([float("NaN"), 1180.04]) ) - def test_flags(self): + def test_flags(self) -> None: np.testing.assert_array_equal(self.ts.data.values[:, 1], np.array(["", ""])) class HTimeseriesReadWithEndDateTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: s = StringIO(tenmin_test_timeseries) s.seek(0) self.ts = HTimeseries( @@ -618,10 +628,10 @@ def setUp(self): default_tzinfo=dt.timezone.utc, ) - def test_length(self): + def test_length(self) -> None: self.assertEqual(len(self.ts.data), 4) - def test_dates(self): + def test_dates(self) -> None: np.testing.assert_array_equal( self.ts.data.index, pd.date_range( @@ -629,31 +639,31 @@ def test_dates(self): ), ) - def test_values(self): + def test_values(self) -> None: np.testing.assert_allclose( self.ts.data.values[:, 0].astype(float), np.array([1141.00, 1142.01, 1154.02, float("NaN")]), ) - def test_flags(self): + def test_flags(self) -> None: np.testing.assert_array_equal( self.ts.data.values[:, 1], np.array(["", "MISS", "", ""]) ) class HTimeseriesReadFileFormatTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: s = StringIO(tenmin_test_timeseries_file_version_4) s.seek(0) self.ts = HTimeseries(s) - def test_unit(self): + def test_unit(self) -> None: self.assertEqual(self.ts.unit, "°C") - def test_title(self): + def test_title(self) -> None: self.assertEqual(self.ts.title, "A test 10-min time series") - def test_comment(self): + def test_comment(self) -> None: self.assertEqual( self.ts.comment, textwrap.dedent( @@ -666,37 +676,44 @@ def test_comment(self): ), ) - def test_timezone(self): - self.assertEqual(self.ts.data.index.tz.utcoffset(None), dt.timedelta(hours=2)) + def test_timezone(self) -> None: + tz = cast(pd.DatetimeIndex, self.ts.data.index).tz + assert tz is not None + self.assertEqual(tz.utcoffset(None), dt.timedelta(hours=2)) - def test_time_step(self): + def test_time_step(self) -> None: self.assertEqual(self.ts.time_step, "10min") - def test_variable(self): + def test_variable(self) -> None: self.assertEqual(self.ts.variable, "temperature") - def test_precision(self): + def test_precision(self) -> None: self.assertEqual(self.ts.precision, 1) - def test_abscissa(self): + def test_abscissa(self) -> None: + assert self.ts.location is not None self.assertAlmostEqual(self.ts.location["abscissa"], 24.678900, places=6) - def test_ordinate(self): + def test_ordinate(self) -> None: + assert self.ts.location is not None self.assertAlmostEqual(self.ts.location["ordinate"], 38.123450, places=6) - def test_srid(self): + def test_srid(self) -> None: + assert self.ts.location is not None self.assertEqual(self.ts.location["srid"], 4326) - def test_altitude(self): + def test_altitude(self) -> None: + assert self.ts.location is not None self.assertAlmostEqual(self.ts.location["altitude"], 219.22, places=2) - def test_asrid(self): + def test_asrid(self) -> None: + assert self.ts.location is not None self.assertTrue(self.ts.location["asrid"] is None) - def test_length(self): + def test_length(self) -> None: self.assertEqual(len(self.ts.data), 5) - def test_dates(self): + def test_dates(self) -> None: np.testing.assert_array_equal( self.ts.data.index, pd.date_range( @@ -704,24 +721,24 @@ def test_dates(self): ), ) - def test_values(self): + def test_values(self) -> None: np.testing.assert_allclose( self.ts.data.values[:, 0].astype(float), np.array([1141.0, 1142.0, 1154.0, float("NaN"), 1180.0]), ) - def test_flags(self): + def test_flags(self) -> None: np.testing.assert_array_equal( self.ts.data.values[:, 1], np.array(["", "MISS", "", "", ""]) ) class FormatAutoDetectorTestCase(TestCase): - def test_auto_detect_text_format(self): + def test_auto_detect_text_format(self) -> None: detected_format = FormatAutoDetector(StringIO(tenmin_test_timeseries)).detect() self.assertEqual(detected_format, HTimeseries.TEXT) - def test_auto_detect_file_format(self): + def test_auto_detect_file_format(self) -> None: detected_format = FormatAutoDetector( StringIO(tenmin_test_timeseries_file_version_4) ).detect() @@ -729,70 +746,70 @@ def test_auto_detect_file_format(self): class WriteOldTimeStepTestCase(TestCase): - def get_value(self, time_step): + def get_value(self, time_step: str) -> str: self.f = StringIO() self.htimeseries = HTimeseries() self.htimeseries.time_step = time_step MetadataWriter(self.f, self.htimeseries, version=2).write_time_step() return self.f.getvalue() - def test_empty(self): + def test_empty(self) -> None: self.assertEqual(self.get_value(""), "") - def test_min(self): + def test_min(self) -> None: self.assertEqual(self.get_value("27min"), "Time_step=27,0\r\n") - def test_min_without_number(self): + def test_min_without_number(self) -> None: self.assertEqual(self.get_value("min"), "Time_step=1,0\r\n") - def test_hour(self): + def test_hour(self) -> None: self.assertEqual(self.get_value("3h"), "Time_step=180,0\r\n") - def test_hour_without_number(self): + def test_hour_without_number(self) -> None: self.assertEqual(self.get_value("h"), "Time_step=60,0\r\n") - def test_day(self): + def test_day(self) -> None: self.assertEqual(self.get_value("3D"), "Time_step=4320,0\r\n") - def test_day_without_number(self): + def test_day_without_number(self) -> None: self.assertEqual(self.get_value("D"), "Time_step=1440,0\r\n") - def test_month(self): + def test_month(self) -> None: self.assertEqual(self.get_value("3ME"), "Time_step=0,3\r\n") - def test_month_without_number(self): + def test_month_without_number(self) -> None: self.assertEqual(self.get_value("ME"), "Time_step=0,1\r\n") - def test_year(self): + def test_year(self) -> None: self.assertEqual(self.get_value("3YE"), "Time_step=0,36\r\n") - def test_year_without_number(self): + def test_year_without_number(self) -> None: self.assertEqual(self.get_value("YE"), "Time_step=0,12\r\n") - def test_garbage(self): + def test_garbage(self) -> None: with self.assertRaisesRegex(ValueError, 'Cannot format time step "hello"'): self.get_value("hello") - def test_wrong_number(self): + def test_wrong_number(self) -> None: with self.assertRaisesRegex(ValueError, 'Cannot format time step "FM"'): self.get_value("FM") class GetTimeStepTestCase(TestCase): - def get_value(self, time_step): + def get_value(self, time_step: str) -> str: f = StringIO("Time_step={}\r\n\r\n".format(time_step)) return MetadataReader(f).meta["time_step"] - def test_min(self): + def test_min(self) -> None: self.assertEqual(self.get_value("1min"), "1min") - def test_minutes(self): + def test_minutes(self) -> None: self.assertEqual(self.get_value("250,0"), "250min") - def test_months(self): + def test_months(self) -> None: self.assertEqual(self.get_value("0,25"), "25M") - def test_both_nonzero(self): + def test_both_nonzero(self) -> None: with self.assertRaisesRegex(ParsingError, "Invalid time step"): self.get_value("5,5") @@ -809,7 +826,7 @@ class HTimeseriesReadWithDuplicateDatesTestCase(TestCase): """ ) - def test_read_csv_with_duplicates_raises_error(self): + def test_read_csv_with_duplicates_raises_error(self) -> None: s = StringIO(self.csv_with_duplicates) s.seek(0) msg = ( @@ -819,17 +836,17 @@ def test_read_csv_with_duplicates_raises_error(self): with self.assertRaisesRegex(ValueError, msg): HTimeseries(s) - def test_write_csv_with_duplicates_raises_error(self): - data = pd.read_csv( + def test_write_csv_with_duplicates_raises_error(self) -> None: + data = pd.read_csv( # type: ignore[misc] StringIO(self.csv_with_duplicates), parse_dates=[0], - usecols=["date", "value", "flags"], + usecols=["date", "value", "flags"], # type: ignore[misc] index_col=0, header=None, names=("date", "value", "flags"), dtype={"value": np.float64, "flags": str}, ) - data.index = data.index.tz_localize(dt.timezone.utc) + data.index = cast(pd.DatetimeIndex, data.index).tz_localize(dt.timezone.utc) msg = ( "Can't write time series: the following timestamps appear more than once: " r"2020-02-23 12:00:00\+00:00, 2020-02-23 13:00:00\+00:00" @@ -857,12 +874,12 @@ class HTimeseriesTimeChangeTestCase(TestCase): """ ) - def setUp(self): + def setUp(self) -> None: s = StringIO(self.time_change_test_timeseries) s.seek(0) self.ts = HTimeseries(s, default_tzinfo=ZoneInfo("Europe/Athens")) - def test_dates(self): + def test_dates(self) -> None: expected = np.array( [ dt.datetime(2023, 10, 28, 23, 30, 0, tzinfo=dt.timezone.utc), @@ -872,10 +889,8 @@ def test_dates(self): dt.datetime(2023, 10, 29, 2, 30, 0, tzinfo=dt.timezone.utc), ] ) - np.testing.assert_array_equal( - self.ts.data.index.tz_convert(dt.timezone.utc), - expected, - ) + index = cast(pd.DatetimeIndex, self.ts.data.index) + np.testing.assert_array_equal(index.tz_convert(dt.timezone.utc), expected) class HTimeseriesCsvWithAwareTimestampsTestCase(TestCase): @@ -886,7 +901,7 @@ class HTimeseriesCsvWithAwareTimestampsTestCase(TestCase): 19:53+00:00. We check that we handle this correctly. """ - def test_csv_with_aware_timestamps(self): + def test_csv_with_aware_timestamps(self) -> None: s = StringIO("2023-12-22 19:53+00:00,42.0,\n") s.seek(0) self.ts = HTimeseries(s, default_tzinfo=ZoneInfo("Etc/GMT-2")) @@ -895,7 +910,7 @@ def test_csv_with_aware_timestamps(self): dt.datetime(2023, 12, 22, 19, 53, tzinfo=dt.timezone.utc), ) - def test_csv_with_mixed_timestamps(self): + def test_csv_with_mixed_timestamps(self) -> None: s = StringIO("2023-12-22 19:53+00:00,42.0,\n2023-12-22 20:53,43.0,\n") s.seek(0) msg = "Maybe the CSV contains mixed aware and naive timestamps" diff --git a/tests/htimeseries/test_timezone_utils.py b/tests/htimeseries/test_timezone_utils.py index 6f75af6..717967c 100644 --- a/tests/htimeseries/test_timezone_utils.py +++ b/tests/htimeseries/test_timezone_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime as dt from unittest import TestCase @@ -5,26 +7,26 @@ class TzinfoFromStringTestCase(TestCase): - def test_simple(self): + def test_simple(self) -> None: atzinfo = TzinfoFromString("+0130") self.assertEqual(atzinfo.offset, dt.timedelta(hours=1, minutes=30)) - def test_brackets(self): + def test_brackets(self) -> None: atzinfo = TzinfoFromString("DUMMY (+0240)") self.assertEqual(atzinfo.offset, dt.timedelta(hours=2, minutes=40)) - def test_brackets_with_utc(self): + def test_brackets_with_utc(self) -> None: atzinfo = TzinfoFromString("DUMMY (UTC+0350)") self.assertEqual(atzinfo.offset, dt.timedelta(hours=3, minutes=50)) - def test_negative(self): + def test_negative(self) -> None: atzinfo = TzinfoFromString("DUMMY (UTC-0420)") self.assertEqual(atzinfo.offset, -dt.timedelta(hours=4, minutes=20)) - def test_zero(self): + def test_zero(self) -> None: atzinfo = TzinfoFromString("DUMMY (UTC-0000)") self.assertEqual(atzinfo.offset, dt.timedelta(hours=0, minutes=0)) - def test_wrong_input(self): + def test_wrong_input(self) -> None: for s in ("DUMMY (GMT+0350)", "0150", "+01500"): self.assertRaises(ValueError, TzinfoFromString, s) diff --git a/tests/rocc/test_rocc.py b/tests/rocc/test_rocc.py index 1fe16db..c8be31a 100644 --- a/tests/rocc/test_rocc.py +++ b/tests/rocc/test_rocc.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import datetime as dt import textwrap from io import StringIO +from typing import Any, Iterable from unittest import TestCase from zoneinfo import ZoneInfo @@ -9,7 +12,7 @@ class RoccTestCase(TestCase): - test_data = textwrap.dedent( + test_data: str = textwrap.dedent( """\ 2020-10-06 14:30,24.0, 2020-10-06 14:40,25.0, @@ -22,13 +25,14 @@ class RoccTestCase(TestCase): """ ) - def setUp(self): - self.ahtimeseries = HTimeseries( + def setUp(self) -> None: + self.ahtimeseries: HTimeseries = HTimeseries( StringIO(self.test_data), default_tzinfo=ZoneInfo("Etc/GMT-2") ) self.ahtimeseries.precision = 1 + self.return_value: list[str] = [] - def _run_rocc(self, flag): + def _run_rocc(self, flag: str | None) -> None: self.return_value = rocc( timeseries=self.ahtimeseries, thresholds=( @@ -39,11 +43,11 @@ def _run_rocc(self, flag): flag=flag, ) - def test_calculation(self): + def test_calculation(self) -> None: self._run_rocc(flag="TEMPORAL") - result = StringIO() - self.ahtimeseries.write(result) - result = result.getvalue().replace("\r\n", "\n") + f: StringIO = StringIO() + self.ahtimeseries.write(f) + result = f.getvalue().replace("\r\n", "\n") self.assertEqual( result, textwrap.dedent( @@ -60,7 +64,7 @@ def test_calculation(self): ), ) - def test_return_value(self): + def test_return_value(self) -> None: self._run_rocc(flag="TEMPORAL") self.assertEqual(len(self.return_value), 2) self.assertEqual( @@ -70,26 +74,26 @@ def test_return_value(self): self.return_value[1], "2020-10-06T15:41 +20.0 in 20min (> 15.0)" ) - def test_value_dtype(self): + def test_value_dtype(self) -> None: self._run_rocc(flag="TEMPORAL") - expected_dtype = HTimeseries().data["value"].dtype + expected_dtype: Any = HTimeseries().data["value"].dtype self.assertEqual(self.ahtimeseries.data["value"].dtype, expected_dtype) - def test_flags_dtype(self): + def test_flags_dtype(self) -> None: self._run_rocc(flag="TEMPORAL") - expected_dtype = HTimeseries().data["flags"].dtype + expected_dtype: Any = HTimeseries().data["flags"].dtype self.assertEqual(self.ahtimeseries.data["flags"].dtype, expected_dtype) - def test_empty_flag(self): + def test_empty_flag(self) -> None: self._run_rocc(flag=None) - result = StringIO() - self.ahtimeseries.write(result) - result = result.getvalue().replace("\r\n", "\n") + f: StringIO = StringIO() + self.ahtimeseries.write(f) + result = f.getvalue().replace("\r\n", "\n") self.assertEqual(result, self.test_data) class RoccNegativeTestCase(TestCase): - test_data = textwrap.dedent( + test_data: str = textwrap.dedent( """\ 2020-10-06 14:30,24.0, 2020-10-06 14:40,25.0, @@ -102,13 +106,14 @@ class RoccNegativeTestCase(TestCase): """ ) - def setUp(self): - self.ahtimeseries = HTimeseries( + def setUp(self) -> None: + self.ahtimeseries: HTimeseries = HTimeseries( StringIO(self.test_data), default_tzinfo=ZoneInfo("Etc/GMT-2") ) self.ahtimeseries.precision = 1 + self.return_value: list[str] = [] - def _run_rocc(self, flag): + def _run_rocc(self, flag: str | None) -> None: self.return_value = rocc( timeseries=self.ahtimeseries, thresholds=( @@ -118,11 +123,11 @@ def _run_rocc(self, flag): flag=flag, ) - def test_calculation(self): + def test_calculation(self) -> None: self._run_rocc(flag="TEMPORAL") - result = StringIO() - self.ahtimeseries.write(result) - result = result.getvalue().replace("\r\n", "\n") + f = StringIO() + self.ahtimeseries.write(f) + result = f.getvalue().replace("\r\n", "\n") self.assertEqual( result, textwrap.dedent( @@ -139,7 +144,7 @@ def test_calculation(self): ), ) - def test_return_value(self): + def test_return_value(self) -> None: self._run_rocc(flag="TEMPORAL") self.assertEqual(len(self.return_value), 2) self.assertEqual( @@ -151,7 +156,7 @@ def test_return_value(self): class RoccSymmetricTestCase(TestCase): - test_data = textwrap.dedent( + test_data: str = textwrap.dedent( """\ 2020-10-06 14:30,76.0, 2020-10-06 14:40,75.0,SOMEFLAG @@ -164,13 +169,13 @@ class RoccSymmetricTestCase(TestCase): """ ) - def setUp(self): - self.ahtimeseries = HTimeseries( + def setUp(self) -> None: + self.ahtimeseries: HTimeseries = HTimeseries( StringIO(self.test_data), default_tzinfo=ZoneInfo("Etc/GMT-2") ) self.ahtimeseries.precision = 1 - def test_without_symmetric(self): + def test_without_symmetric(self) -> None: rocc( timeseries=self.ahtimeseries, thresholds=( @@ -179,9 +184,9 @@ def test_without_symmetric(self): Threshold("h", 40), ), ) - result = StringIO() - self.ahtimeseries.write(result) - result = result.getvalue().replace("\r\n", "\n") + f = StringIO() + self.ahtimeseries.write(f) + result = f.getvalue().replace("\r\n", "\n") self.assertEqual( result, textwrap.dedent( @@ -198,8 +203,8 @@ def test_without_symmetric(self): ), ) - def test_with_symmetric(self): - output = rocc( + def test_with_symmetric(self) -> None: + output: list[str] = rocc( timeseries=self.ahtimeseries, thresholds=( # Keep in strange order to test for possible errors Threshold("20min", -15), @@ -208,9 +213,9 @@ def test_with_symmetric(self): ), symmetric=True, ) - result = StringIO() - self.ahtimeseries.write(result) - result = result.getvalue().replace("\r\n", "\n") + f = StringIO() + self.ahtimeseries.write(f) + result = f.getvalue().replace("\r\n", "\n") self.assertEqual( result, textwrap.dedent( @@ -234,8 +239,8 @@ def test_with_symmetric(self): ], ) - def test_symmetric_return_value(self): - return_value = rocc( + def test_symmetric_return_value(self) -> None: + return_value: list[str] = rocc( timeseries=self.ahtimeseries, thresholds=( Threshold("10min", 10), @@ -250,29 +255,32 @@ def test_symmetric_return_value(self): class RoccEmptyCase(TestCase): - def test_with_empty(self): - ahtimeseries = HTimeseries() + def test_with_empty(self) -> None: + ahtimeseries: HTimeseries = HTimeseries() rocc(timeseries=ahtimeseries, thresholds=[Threshold("10min", 10)]) self.assertTrue(ahtimeseries.data.empty) class RoccImpliedThresholdsTestCase(TestCase): - def _create_htimeseries(self, test_data): - self.ahtimeseries = HTimeseries( + def _create_htimeseries(self, test_data: str) -> None: + self.ahtimeseries: HTimeseries = HTimeseries( StringIO(test_data), default_tzinfo=dt.timezone.utc ) self.ahtimeseries.precision = 2 + self.return_value: list[str] = [] - def _run_rocc(self, thresholds, symmetric=False): + def _run_rocc( + self, thresholds: Iterable[Threshold], symmetric: bool = False + ) -> str: self.return_value = rocc( timeseries=self.ahtimeseries, thresholds=thresholds, symmetric=symmetric ) - result = StringIO() - self.ahtimeseries.write(result) - result = result.getvalue().replace("\r\n", "\n") + f = StringIO() + self.ahtimeseries.write(f) + result = f.getvalue().replace("\r\n", "\n") return result - def test_positive_ok(self): + def test_positive_ok(self) -> None: test_data = textwrap.dedent( """\ 2020-10-06 14:30,25.00, @@ -280,7 +288,7 @@ def test_positive_ok(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=( Threshold("10min", 10), Threshold("20min", 15), @@ -296,7 +304,7 @@ def test_positive_ok(self): ), ) - def test_positive_not_ok(self): + def test_positive_not_ok(self) -> None: test_data = textwrap.dedent( """\ 2020-10-06 14:30,25.00, @@ -304,7 +312,7 @@ def test_positive_not_ok(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=( Threshold("10min", 10), Threshold("20min", 15), @@ -320,7 +328,7 @@ def test_positive_not_ok(self): ), ) - def test_negative_ok(self): + def test_negative_ok(self) -> None: test_data = textwrap.dedent( """\ 2020-10-06 14:30,75.00, @@ -328,7 +336,7 @@ def test_negative_ok(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=( Threshold("10min", -10), Threshold("20min", -15), @@ -344,7 +352,7 @@ def test_negative_ok(self): ), ) - def test_negative_not_ok(self): + def test_negative_not_ok(self) -> None: test_data = textwrap.dedent( """\ 2020-10-06 14:30,75.00, @@ -352,7 +360,7 @@ def test_negative_not_ok(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=( Threshold("10min", -10), Threshold("20min", -15), @@ -368,7 +376,7 @@ def test_negative_not_ok(self): ), ) - def test_symmetric_ok(self): + def test_symmetric_ok(self) -> None: test_data = textwrap.dedent( """\ 2020-10-06 14:30,50.00, @@ -376,7 +384,7 @@ def test_symmetric_ok(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=( Threshold("10min", 10), Threshold("20min", -15), @@ -393,7 +401,7 @@ def test_symmetric_ok(self): ), ) - def test_symmetric_not_ok(self): + def test_symmetric_not_ok(self) -> None: test_data = textwrap.dedent( """\ 2020-10-06 14:30,75.00, @@ -401,7 +409,7 @@ def test_symmetric_not_ok(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=( Threshold("10min", 10), Threshold("20min", -15), @@ -418,7 +426,7 @@ def test_symmetric_not_ok(self): ), ) - def test_compare_to_previous_not_null_record(self): + def test_compare_to_previous_not_null_record(self) -> None: test_data = textwrap.dedent( """\ 2023-10-13 08:40,12.3, @@ -427,7 +435,7 @@ def test_compare_to_previous_not_null_record(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=(Threshold("10min", 1),), symmetric=True, ) @@ -442,7 +450,7 @@ def test_compare_to_previous_not_null_record(self): ), ) - def test_check_delta_remainder_with_symmetric_and_negative_threshold(self): + def test_check_delta_remainder_with_symmetric_and_negative_threshold(self) -> None: test_data = textwrap.dedent( """\ 2023-10-13 08:40,12.30, @@ -450,7 +458,7 @@ def test_check_delta_remainder_with_symmetric_and_negative_threshold(self): """ ) self._create_htimeseries(test_data) - result = self._run_rocc( + result: str = self._run_rocc( thresholds=(Threshold("10min", -1),), symmetric=True, ) diff --git a/tests/textbisect/test_textbisect.py b/tests/textbisect/test_textbisect.py index 3d0b3fc..d9e113d 100644 --- a/tests/textbisect/test_textbisect.py +++ b/tests/textbisect/test_textbisect.py @@ -2,6 +2,8 @@ from io import StringIO from unittest import TestCase +from typing import Any, Union + from textbisect import text_bisect, text_bisect_left, text_bisect_right testtext = textwrap.dedent( @@ -89,7 +91,20 @@ class TextBisectTestCaseBase(TestCase): - def _do_test(self, search_term, expected_result, direction="", lo=0, hi=None): + f: StringIO + + @staticmethod + def KEY(x: str) -> str: + return x + + def _do_test( + self, + search_term: str, + expected_result: int, + direction: str = "", + lo: int = 0, + hi: Union[int, None] = None, + ): function = { "left": text_bisect_left, "right": text_bisect_right, @@ -100,8 +115,12 @@ def _do_test(self, search_term, expected_result, direction="", lo=0, hi=None): self.assertEqual(pos, self.f.tell()) +IRRELEVANT = -1 + + class TextBisectWithoutKeyTestCase(TextBisectTestCaseBase): - def KEY(x): + @staticmethod + def KEY(x: str) -> Any: return x @classmethod @@ -147,7 +166,7 @@ def test_in_file_part_for_something_after_end_of_that_part(self): def test_when_file_part_ends_in_middle_of_line(self): with self.assertRaises(EOFError): - self._do_test("nick", "irrelevant", hi=71) + self._do_test("nick", IRRELEVANT, hi=71) def test_searching_in_file_part_specified_by_both_lo_and_hi(self): self._do_test("nick", 79, lo=64, hi=93) @@ -175,7 +194,9 @@ def test_bisect_right_at_end_of_file_part_specified_by_both_lo_and_hi(self): class TextBisectWithKeyTestCase(TextBisectTestCaseBase): - KEY = len + @staticmethod + def KEY(x: str) -> Any: + return len(x) @classmethod def setUpClass(cls): @@ -220,7 +241,7 @@ def test_in_file_part_for_something_after_end_of_that_part(self): def test_when_file_part_ends_in_middle_of_line(self): with self.assertRaises(EOFError): - self._do_test("eleven=0011", "irrelevant", hi=32) + self._do_test("eleven=0011", IRRELEVANT, hi=32) def test_beginning_of_file_part_specified_by_both_lo_and_hi(self): self._do_test("four", 6, lo=6, hi=29) @@ -248,7 +269,8 @@ def test_bisect_right_at_end_of_file_part_specified_by_both_lo_and_hi(self): class TextBisectOnlyOneLineTestCase(TextBisectTestCaseBase): - def KEY(x): + @staticmethod + def KEY(x: str) -> str: return x @classmethod From 1f34acd4d9759eade846cda881e3093892286a89 Mon Sep 17 00:00:00 2001 From: Antonis Christofides Date: Wed, 5 Nov 2025 13:49:22 +0200 Subject: [PATCH 3/4] Add recent changes to the changelog --- CHANGELOG.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bb066d3..387b683 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,12 @@ Changelog ========= +DEV +=== + +* Added textbisect module. +* Added type hints. + 2.6.0 (2025-10-01) ================== From f82fe16e8de76f170a741857badc90609d331e2a Mon Sep 17 00:00:00 2001 From: Antonis Christofides Date: Wed, 5 Nov 2025 14:05:41 +0200 Subject: [PATCH 4/4] Lint --- .github/workflows/run-tests.yml | 2 +- src/enhydris_api_client/__init__.py | 21 ++++++++++++++------- src/enhydris_cache/cli.py | 2 +- src/enhydris_cache/enhydris_cache.py | 3 ++- src/evaporation/cli.py | 1 + src/evaporation/evaporation.py | 8 ++++---- src/haggregate/cli.py | 2 +- src/haggregate/haggregate.py | 2 +- src/haggregate/regularize.pyi | 3 +-- src/hspatial/test.py | 2 ++ src/htimeseries/htimeseries.py | 3 ++- src/rocc/__init__.py | 1 - src/textbisect/__init__.py | 3 +-- tests/enhydris_api_client/__init__.py | 6 ++---- tests/enhydris_api_client/test_e2e.py | 1 - tests/enhydris_api_client/test_station.py | 4 ++-- tests/evaporation/test_cli.py | 8 ++------ tests/haggregate/test_cli.py | 1 - tests/haggregate/test_regularize.py | 2 +- tests/hspatial/test_hspatial.py | 4 +++- tests/htimeseries/test_htimeseries.py | 4 +++- tests/textbisect/test_textbisect.py | 5 ++--- 22 files changed, 46 insertions(+), 42 deletions(-) diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 985b7f5..434d817 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -45,7 +45,7 @@ jobs: # numpy<2 is needed for gdal to contain support for gdal array pip install 'numpy<2' CPLUS_INCLUDE_PATH=/usr/include/gdal C_INCLUDE_PATH=/usr/include/gdal pip install --no-build-isolation 'gdal==3.8.4' - pip install coverage isort flake8 'black<25' twine setuptools build pyright + pip install coverage isort flake8 'black<25' twine setuptools build pyright setuptools_scm cython pip install -e . - name: Run Tests diff --git a/src/enhydris_api_client/__init__.py b/src/enhydris_api_client/__init__.py index 712c0b8..56dce1b 100644 --- a/src/enhydris_api_client/__init__.py +++ b/src/enhydris_api_client/__init__.py @@ -3,7 +3,7 @@ import datetime as dt from copy import copy from io import StringIO -from typing import Any, Dict, Generator, Iterable, Optional +from typing import Any, Dict, Generator, Iterable, Optional, cast from urllib.parse import urljoin from zoneinfo import ZoneInfo @@ -11,7 +11,6 @@ import pandas as pd import requests from requests import Response, Session -from typing import cast from htimeseries import HTimeseries @@ -82,7 +81,7 @@ def _check_status_code_is_the_one_expected( def get_token(self, username: str, password: str) -> Optional[str]: if not username: - return + return None # Get a csrftoken login_url = urljoin(self.base_url, "api/auth/login/") @@ -138,7 +137,9 @@ def delete_station(self, station_id: int) -> None: self.response = self.session.delete(url) self.check_response(expected_status_code=204) - def list_timeseries_groups(self, station_id: int) -> Generator[JSONDict, None, None]: + def list_timeseries_groups( + self, station_id: int + ) -> Generator[JSONDict, None, None]: url = urljoin(self.base_url, f"api/stations/{station_id}/timeseriesgroups/") while url: try: @@ -153,7 +154,9 @@ def list_timeseries_groups(self, station_id: int) -> Generator[JSONDict, None, N f"Malformed response from server: {str(e)}" ) - def get_timeseries_group(self, station_id: int, timeseries_group_id: int) -> JSONDict: + def get_timeseries_group( + self, station_id: int, timeseries_group_id: int + ) -> JSONDict: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/", @@ -189,7 +192,9 @@ def patch_timeseries_group( self.response = self.session.patch(url, data=data) self.check_response() - def delete_timeseries_group(self, station_id: int, timeseries_group_id: int) -> None: + def delete_timeseries_group( + self, station_id: int, timeseries_group_id: int + ) -> None: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/", @@ -197,7 +202,9 @@ def delete_timeseries_group(self, station_id: int, timeseries_group_id: int) -> self.response = self.session.delete(url) self.check_response(expected_status_code=204) - def list_timeseries(self, station_id: int, timeseries_group_id: int) -> Iterable[JSONDict]: + def list_timeseries( + self, station_id: int, timeseries_group_id: int + ) -> Iterable[JSONDict]: url = urljoin( self.base_url, f"api/stations/{station_id}/timeseriesgroups/{timeseries_group_id}/" diff --git a/src/enhydris_cache/cli.py b/src/enhydris_cache/cli.py index 901b125..ad2477e 100644 --- a/src/enhydris_cache/cli.py +++ b/src/enhydris_cache/cli.py @@ -4,8 +4,8 @@ import datetime as dt import logging import os -from typing import Any, Dict, Sequence import traceback +from typing import Any, Dict, Sequence import click diff --git a/src/enhydris_cache/enhydris_cache.py b/src/enhydris_cache/enhydris_cache.py index fc08659..64f3364 100644 --- a/src/enhydris_cache/enhydris_cache.py +++ b/src/enhydris_cache/enhydris_cache.py @@ -1,7 +1,7 @@ from __future__ import annotations import datetime as dt -from typing import Mapping, Sequence, TypedDict, cast +from typing import Sequence, TypedDict, cast import pandas as pd @@ -17,6 +17,7 @@ class TimeseriesGroup(TypedDict): auth_token: str | None file: str + class TimeseriesCache(object): def __init__(self, timeseries_group: Sequence[TimeseriesGroup]) -> None: self.timeseries_group = timeseries_group diff --git a/src/evaporation/cli.py b/src/evaporation/cli.py index eb5e702..5653f46 100644 --- a/src/evaporation/cli.py +++ b/src/evaporation/cli.py @@ -282,6 +282,7 @@ def _check_tif_hts_consistency(self, has_tif, has_hts): class ProcessAtPoint: timezone: str | None + def __init__(self, config): self.config = config diff --git a/src/evaporation/evaporation.py b/src/evaporation/evaporation.py index 3e585a6..528339d 100644 --- a/src/evaporation/evaporation.py +++ b/src/evaporation/evaporation.py @@ -152,9 +152,7 @@ def calculate_hourly( wind_speed_c = cast(Numeric, variables["wind_speed"]) pressure_c = cast(Numeric, variables["pressure"]) solar_radiation_c = cast(Numeric, variables["solar_radiation"]) - gamma = self.get_psychrometric_constant( - temperature_c, pressure_c - ) + gamma = self.get_psychrometric_constant(temperature_c, pressure_c) extraterrestrial_radiation = self.get_extraterrestrial_radiation(adatetime) r_so = cast(Numeric, extraterrestrial_radiation) * ( 0.75 + 2e-5 * self.elevation @@ -472,7 +470,9 @@ def get_soil_heat_flux_density( coefficient = np.where(incoming_solar_radiation > 0.05, 0.1, 0.5) return coefficient * rn - def get_saturation_vapour_pressure_curve_slope(self, temperature: Numeric) -> Numeric: + def get_saturation_vapour_pressure_curve_slope( + self, temperature: Numeric + ) -> Numeric: "Allen et al. (1998), p. 37, eq. 13." numerator = 4098 * self.get_saturation_vapour_pressure(temperature) with warnings.catch_warnings(): diff --git a/src/haggregate/cli.py b/src/haggregate/cli.py index c8a8045..20552ea 100644 --- a/src/haggregate/cli.py +++ b/src/haggregate/cli.py @@ -5,8 +5,8 @@ import logging import os import sys -from typing import Optional import traceback +from typing import Optional import click diff --git a/src/haggregate/haggregate.py b/src/haggregate/haggregate.py index 8741850..e4a5307 100644 --- a/src/haggregate/haggregate.py +++ b/src/haggregate/haggregate.py @@ -2,7 +2,7 @@ import datetime as dt from enum import Enum -from typing import Any, Callable, Dict, Optional +from typing import Any, Optional import numpy as np import pandas as pd diff --git a/src/haggregate/regularize.pyi b/src/haggregate/regularize.pyi index 64e2443..276f586 100644 --- a/src/haggregate/regularize.pyi +++ b/src/haggregate/regularize.pyi @@ -4,8 +4,7 @@ from htimeseries import HTimeseries from .haggregate import RegularizationMode -class RegularizeError(Exception): - ... +class RegularizeError(Exception): ... def regularize( ts: HTimeseries, diff --git a/src/hspatial/test.py b/src/hspatial/test.py index 11944bd..35ad030 100644 --- a/src/hspatial/test.py +++ b/src/hspatial/test.py @@ -1,7 +1,9 @@ import datetime as dt + import numpy as np from osgeo import gdal, osr + def setup_test_raster( filename: str, value: np.ndarray[np.float64, np.dtype[np.float64]], diff --git a/src/htimeseries/htimeseries.py b/src/htimeseries/htimeseries.py index 7ecba0d..a846acc 100644 --- a/src/htimeseries/htimeseries.py +++ b/src/htimeseries/htimeseries.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd from pandas.tseries.frequencies import to_offset + from textbisect import text_bisect_left from .timezone_utils import TzinfoFromString @@ -147,7 +148,7 @@ def write_altitude(self) -> None: if self.version <= 2 or not getattr(self.htimeseries, "location", None): return assert self.htimeseries.location is not None - + if self.htimeseries.location.get("altitude") is None: return altitude = self.htimeseries.location["altitude"] diff --git a/src/rocc/__init__.py b/src/rocc/__init__.py index 35862c6..8b9be78 100644 --- a/src/rocc/__init__.py +++ b/src/rocc/__init__.py @@ -4,7 +4,6 @@ from .calculation import Rocc - if TYPE_CHECKING: # pragma: no cover - used for type checkers only from htimeseries import HTimeseries diff --git a/src/textbisect/__init__.py b/src/textbisect/__init__.py index 446c2db..773ac8a 100644 --- a/src/textbisect/__init__.py +++ b/src/textbisect/__init__.py @@ -1,6 +1,5 @@ from io import SEEK_END - -from typing import Callable, IO, Union +from typing import IO, Callable, Union class TextBisector: diff --git a/tests/enhydris_api_client/__init__.py b/tests/enhydris_api_client/__init__.py index 043197a..94457f3 100644 --- a/tests/enhydris_api_client/__init__.py +++ b/tests/enhydris_api_client/__init__.py @@ -4,7 +4,7 @@ import textwrap from copy import copy from io import StringIO -from typing import Any, Callable, Dict, Optional, cast +from typing import Any, Callable, Dict, cast from unittest import mock import pandas as pd @@ -55,9 +55,7 @@ def mock_session(**kwargs: Any) -> mock._patch: f"(got {status_code!r} for {method})" ) if status_code < 200 or status_code >= 400: - method_side_effect = ( - f"{method}.return_value.raise_for_status.side_effect" - ) + method_side_effect = f"{method}.return_value.raise_for_status.side_effect" patch_kwargs[method_side_effect] = requests.HTTPError for old_key in list(patch_kwargs.keys()): patch_kwargs[f"return_value.{old_key}"] = patch_kwargs.pop(old_key) diff --git a/tests/enhydris_api_client/test_e2e.py b/tests/enhydris_api_client/test_e2e.py index 0cf5ee3..9bb21cf 100644 --- a/tests/enhydris_api_client/test_e2e.py +++ b/tests/enhydris_api_client/test_e2e.py @@ -9,7 +9,6 @@ from unittest import TestCase, skipUnless from zoneinfo import ZoneInfo -import pandas as pd import requests from enhydris_api_client import EnhydrisApiClient diff --git a/tests/enhydris_api_client/test_station.py b/tests/enhydris_api_client/test_station.py index 8438684..882bfc0 100644 --- a/tests/enhydris_api_client/test_station.py +++ b/tests/enhydris_api_client/test_station.py @@ -149,7 +149,7 @@ def test_returns_id(self) -> None: class PutStationTestCase(TestCase): @mock_session() - def test_makes_request(self,m: MagicMock) -> None: + def test_makes_request(self, m: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.put_station(42, data={"location": "Syria"}) m.return_value.put.assert_called_once_with( @@ -159,7 +159,7 @@ def test_makes_request(self,m: MagicMock) -> None: class PatchStationTestCase(TestCase): @mock_session() - def test_makes_request(self,m: MagicMock) -> None: + def test_makes_request(self, m: MagicMock) -> None: self.client = EnhydrisApiClient("https://mydomain.com") self.client.patch_station(42, data={"location": "Syria"}) m.return_value.patch.assert_called_once_with( diff --git a/tests/evaporation/test_cli.py b/tests/evaporation/test_cli.py index e1bbad6..7c2b42e 100644 --- a/tests/evaporation/test_cli.py +++ b/tests/evaporation/test_cli.py @@ -308,9 +308,7 @@ def test_albedo_configuration_as_one_grid(self, m: MagicMock) -> None: m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_seasonal_albedo_configuration_as_12_grids( - self, m: MagicMock - ) -> None: + def test_seasonal_albedo_configuration_as_12_grids(self, m: MagicMock) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( @@ -413,9 +411,7 @@ def test_run_app_with_seasonal_albedo_with_mix_sample_inputs( m.return_value.execute.assert_called_once_with() @patch("evaporation.cli.ProcessSpatial") - def test_seasonal_albedo_configuration_as_12_numbers( - self, m: MagicMock - ) -> None: + def test_seasonal_albedo_configuration_as_12_numbers(self, m: MagicMock) -> None: with open(self.configfilename, "w") as f: f.write( textwrap.dedent( diff --git a/tests/haggregate/test_cli.py b/tests/haggregate/test_cli.py index 01e52b2..9aa7c49 100644 --- a/tests/haggregate/test_cli.py +++ b/tests/haggregate/test_cli.py @@ -1,6 +1,5 @@ import datetime as dt import textwrap -from typing import ClassVar from unittest import TestCase from unittest.mock import MagicMock, patch diff --git a/tests/haggregate/test_regularize.py b/tests/haggregate/test_regularize.py index e3cdfea..53c3bbe 100644 --- a/tests/haggregate/test_regularize.py +++ b/tests/haggregate/test_regularize.py @@ -291,4 +291,4 @@ def test_sets_comment(self): ) def test_sets_time_step(self): - self.assertEqual(self.result.time_step, "10min") \ No newline at end of file + self.assertEqual(self.result.time_step, "10min") diff --git a/tests/hspatial/test_hspatial.py b/tests/hspatial/test_hspatial.py index d762d1d..2f82830 100644 --- a/tests/hspatial/test_hspatial.py +++ b/tests/hspatial/test_hspatial.py @@ -528,7 +528,9 @@ def test_fails_gracefully_when_osr_point_is_really_outside_crs_limits(self) -> N with self.assertRaises(RuntimeError): hspatial.extract_point_from_raster(point, self.fp) - def test_fails_gracefully_when_geos_point_is_really_outside_crs_limits(self) -> None: + def test_fails_gracefully_when_geos_point_is_really_outside_crs_limits( + self, + ) -> None: point = GeoDjangoPoint(125.0, 85.0) with self.assertRaises(RuntimeError): hspatial.extract_point_from_raster(point, self.fp) diff --git a/tests/htimeseries/test_htimeseries.py b/tests/htimeseries/test_htimeseries.py index 9ef5b2e..2eb75ec 100644 --- a/tests/htimeseries/test_htimeseries.py +++ b/tests/htimeseries/test_htimeseries.py @@ -311,7 +311,9 @@ def setUp(self) -> None: names=("date", "value", "flags"), dtype={"value": np.float64, "flags": str}, ).asfreq("10min") - data.index = cast(pd.DatetimeIndex, data.index).tz_localize(dt.timezone(dt.timedelta(hours=2))) + data.index = cast(pd.DatetimeIndex, data.index).tz_localize( + dt.timezone(dt.timedelta(hours=2)) + ) self.reference_ts = HTimeseries(data=data) self.reference_ts.unit = "°C" self.reference_ts.title = "A test 10-min time series" diff --git a/tests/textbisect/test_textbisect.py b/tests/textbisect/test_textbisect.py index d9e113d..66d8c3b 100644 --- a/tests/textbisect/test_textbisect.py +++ b/tests/textbisect/test_textbisect.py @@ -1,8 +1,7 @@ import textwrap from io import StringIO -from unittest import TestCase - from typing import Any, Union +from unittest import TestCase from textbisect import text_bisect, text_bisect_left, text_bisect_right @@ -96,7 +95,7 @@ class TextBisectTestCaseBase(TestCase): @staticmethod def KEY(x: str) -> str: return x - + def _do_test( self, search_term: str,