Skip to content

Commit e7be10c

Browse files
committed
Use ty for typing
1 parent de7d395 commit e7be10c

File tree

7 files changed

+136
-138
lines changed

7 files changed

+136
-138
lines changed

.vscode/settings.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
{
2-
"git.ignoreLimitWarning": true
2+
"git.ignoreLimitWarning": true,
3+
"python.testing.pytestArgs": [
4+
"tests"
5+
],
6+
"python.testing.unittestEnabled": false,
7+
"python.testing.pytestEnabled": true
38
}

Makefile

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
test:
2-
uv run ruff format
3-
uv run ruff check
4-
uv run pytest -v --exitfirst
2+
uvx ruff format
3+
uvx ruff check --fix .
4+
uvx ty check
5+
uvx pytest -v --exitfirst

babylab/api.py

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@
66

77
from collections import OrderedDict
88
from dataclasses import dataclass, field
9-
from datetime import datetime
9+
from datetime import datetime, timezone
1010
from functools import singledispatch
1111
from json import dump, dumps, loads
12-
from os import environ, getenv
12+
from os import environ, getenv, walk
13+
from os.path import join
1314
from pathlib import Path
15+
from typing import Sequence
1416
from warnings import warn
1517
from zipfile import ZIP_DEFLATED, ZipFile
1618

1719
import polars as pl
20+
import pytz
1821
import requests
1922
from dateutil.relativedelta import relativedelta as rdelta
2023
from dotenv import find_dotenv, load_dotenv
21-
from pytz import UTC as utc
2224

2325
from babylab.globals import COLNAMES, FIELDS_TO_RENAME, INT_FIELDS, SCHEMA, URI
2426

@@ -43,6 +45,17 @@ class BadAgeFormat(Exception):
4345
"""If age does not follow the right format"""
4446

4547

48+
@dataclass
49+
class RecordList:
50+
"""List of REDCap records."""
51+
52+
records: dict = field(default_factory=dict)
53+
kind: str | None = None
54+
55+
def __len__(self) -> int:
56+
return len(self.records)
57+
58+
4659
@dataclass
4760
class Record:
4861
ppt_id: str
@@ -51,8 +64,8 @@ class Record:
5164

5265
@dataclass
5366
class Participant(Record):
54-
appointments: list = field(default_factory=list)
55-
questionnaires: list = field(default_factory=list)
67+
appointments: RecordList = field(default_factory=list)
68+
questionnaires: RecordList = field(default_factory=list)
5669

5770

5871
@dataclass
@@ -72,22 +85,11 @@ def __post_init__(self):
7285
self.isestimated = self.data["isestimated"]
7386

7487

75-
@dataclass
76-
class RecordList:
77-
"""List of REDCap records."""
78-
79-
records: dict = field(default_factory=dict)
80-
kind: str | None = None
81-
82-
def __len__(self) -> int:
83-
return len(self.records)
84-
85-
86-
def get_api_key(path: Path | str = None, name: str = "API_KEY") -> str:
88+
def get_api_key(path: Path | str | None = None, name: str = "API_KEY") -> str:
8789
"""Retrieve API credentials.
8890
8991
Args:
90-
path (Path | str, optional): Path to the .env file with global variables. Defaults to ``Path.home()``.
92+
path (Path | str | None, optional): Path to the .env file with global variables. Defaults to ``Path.home()``.
9193
name (str, optional): Name of the variable to import. Defaults to "API_KEY".
9294
9395
Returns:
@@ -118,19 +120,19 @@ def get_api_key(path: Path | str = None, name: str = "API_KEY") -> str:
118120
return token
119121

120122

121-
def post_request(fields: dict, timeout: list[int] = (5, 10)) -> dict:
123+
def post_request(fields: dict, timeout: Sequence[int] = (5, 10)) -> requests.Response:
122124
"""Make a POST request to the REDCap database.
123125
124126
Args:
125127
fields (dict): Fields to retrieve.
126-
timeout (list[int], optional): Timeout of HTTP request in seconds. Defaults to 10.
128+
timeout (Sequence[int], optional): Timeout of HTTP request in seconds. Defaults to 10.
127129
128130
Raises:
129131
requests.exceptions.HTTPError: If HTTP request fails.
130132
BadToken: If API token contains non-alphanumeric characters.
131133
132134
Returns:
133-
dict: HTTP request response in JSON format.
135+
requests.Response: HTTP request response in JSON format.
134136
"""
135137
t = get_api_key()
136138

@@ -404,12 +406,12 @@ def prepare_data(x: dict, kind: str = "ppt") -> dict:
404406
return fmt_labels(x)
405407

406408

407-
def make_id(ppt_id: str, repeat_id: str = None) -> str:
409+
def make_id(ppt_id: str | int, repeat_id: str | int | None = None) -> str:
408410
"""Make a record ID.
409411
410412
Args:
411-
ppt_id (str): Participant ID.
412-
repeat_id (str, optional): Appointment or Questionnaire ID, or ``redcap_repeated_id``. Defaults to None.
413+
ppt_id (str | int): Participant ID.
414+
repeat_id (str | int | None, optional): Appointment or Questionnaire ID, or ``redcap_repeated_id``. Defaults to None.
413415
414416
Returns:
415417
str: Record ID.
@@ -439,7 +441,7 @@ def get_records(record_id: str | list | None = None) -> dict:
439441
record_id (str): ID of record to retrieve. Defaults to None.
440442
441443
Returns:
442-
dict: REDCap records in JSON format.
444+
list[dict[str, str]]: REDCap records in JSON format.
443445
"""
444446
fields = {"content": "record", "format": "json", "type": "flat"}
445447

@@ -449,9 +451,7 @@ def get_records(record_id: str | list | None = None) -> dict:
449451
for r in record_id:
450452
fields[f"records[{record_id}]"] = r
451453

452-
records = post_request(fields=fields).json()
453-
454-
return [str_to_dt(r) for r in records]
454+
return post_request(fields=fields).json()
455455

456456

457457
def get_participant(ppt_id: str) -> Participant:
@@ -687,29 +687,23 @@ def warn_missing_record(r: requests.models.Response):
687687
warn("Record does not exist!", stacklevel=2)
688688

689689

690-
def redcap_backup(path: Path | str = None) -> dict:
690+
def redcap_backup(path: Path | str = Path("tmp")) -> Path:
691691
"""Download a backup of the REDCap database
692692
693693
Args:
694694
path (Path | str, optional): Output directory. Defaults to ``Path("tmp")``.
695695
696696
Returns:
697-
dict: A dictionary with the key data and metadata of the project.
697+
Path: Path to the generated file with data and metadata of the project.
698698
"""
699-
if path is None:
700-
path = Path("tmp")
701-
702-
if isinstance(path, str):
703-
path = Path(path)
704-
705-
if not path.exists():
706-
path.mkdir(exist_ok=True)
699+
path = Path(path)
700+
path.mkdir(exist_ok=True)
707701

708702
p = {}
709703
for k in ["project", "metadata", "instrument"]:
710704
p[k] = {"format": "json", "returnFormat": "json", "content": k}
711705

712-
d = {k: loads(post_request(v).text) for k, v in pl.items()}
706+
d = {k: loads(post_request(v).text) for k, v in p.items()}
713707

714708
with open(path / "records.csv", "w+", encoding="utf-8") as f:
715709
fields = {
@@ -736,19 +730,20 @@ def redcap_backup(path: Path | str = None) -> dict:
736730
timestamp = datetime.strftime(datetime.now(), "%Y-%m-%d-%H-%M")
737731
file = path / ("backup_" + timestamp + ".zip")
738732

739-
for root, _, files in path.walk(top_down=False):
733+
for root, _, files in walk(str(path), topdown=False):
740734
with ZipFile(file, "w", ZIP_DEFLATED) as z:
741735
for f in files:
742-
z.write(root / f)
736+
z.write(join(root, f))
743737

744738
return file
745739

746740

747741
class Records:
748742
"""REDCap records"""
749743

750-
def __init__(self, record_id: str | list = None):
744+
def __init__(self, record_id: str | list | None = None):
751745
records = get_records(record_id)
746+
records = [str_to_dt(r) for r in records]
752747
ppt, apt, que = {}, {}, {}
753748

754749
for r in records:
@@ -823,11 +818,11 @@ def parse_age(age: tuple) -> tuple[int, int]:
823818
raise BadAgeFormat("age must be in (months, age) format") from e
824819

825820

826-
def parse_str_date(x: str) -> datetime:
821+
def parse_str_date(x: str | datetime) -> datetime:
827822
"""Parse string data to datetime.
828823
829824
Args:
830-
x (str): String date to parse.
825+
x (str | datetime): String date to parse.
831826
832827
Returns:
833828
datetime: Parsed datetime.
@@ -844,25 +839,23 @@ def parse_str_date(x: str) -> datetime:
844839
return datetime.strptime(x, "%Y-%m-%d %H:%M")
845840

846841

847-
def get_age(age: str | tuple, ts: datetime | str, ts_new: datetime = None):
842+
def get_age(
843+
age: tuple, ts: datetime | str, ts_new: datetime | None = None, tz: str = "UTC"
844+
):
848845
"""Calculate the age of a person in months and days at a new timestamp.
849846
850847
Args:
851848
age (tuple): Age in months and days as a tuple of type (months, days).
852849
ts (datetime | str): Birth date as ``datetime.datetime`` type.
853-
ts_new (datetime.datetime, optional): Time for which the age is calculated. Defaults to current date (``datetime.datetime.now()``).
850+
ts_new (datetime.datetime | None, optional): Time for which the age is calculated. Defaults to current date (``datetime.datetime.now()``).
854851
855852
Returns:
856853
tuple: Age in at ``new_timestamp``.
857854
"""
858-
ts = parse_str_date(ts)
859-
ts_new = datetime.now(utc) if ts_new is None else ts_new
860-
861-
if ts.tzinfo is None or ts.tzinfo.utcoffset(ts) is None:
862-
ts = utc.localize(ts, True)
863-
864-
if ts_new.tzinfo is None or ts_new.tzinfo.utcoffset(ts_new) is None:
865-
ts_new = utc.localize(ts_new, True)
855+
tz = pytz.timezone(tz)
856+
ts = tz.localize(parse_str_date(ts))
857+
ts_new = datetime.now() if ts_new is None else ts_new
858+
ts_new = tz.localize(ts_new)
866859

867860
tdiff = rdelta(ts_new, ts)
868861
months, days = parse_age(age)

0 commit comments

Comments
 (0)