66
77from collections import OrderedDict
88from dataclasses import dataclass , field
9- from datetime import datetime
9+ from datetime import datetime , timezone
1010from functools import singledispatch
1111from json import dump , dumps , loads
12- from os import environ , getenv
12+ from os import environ , getenv , walk
13+ from os .path import join
1314from pathlib import Path
15+ from typing import Sequence
1416from warnings import warn
1517from zipfile import ZIP_DEFLATED , ZipFile
1618
1719import polars as pl
20+ import pytz
1821import requests
1922from dateutil .relativedelta import relativedelta as rdelta
2023from dotenv import find_dotenv , load_dotenv
21- from pytz import UTC as utc
2224
2325from babylab .globals import COLNAMES , FIELDS_TO_RENAME , INT_FIELDS , SCHEMA , URI
2426
@@ -43,6 +45,17 @@ class BadAgeFormat(Exception):
4345 """If age does not follow the right format"""
4446
4547
48+ @dataclass
49+ class RecordList :
50+ """List of REDCap records."""
51+
52+ records : dict = field (default_factory = dict )
53+ kind : str | None = None
54+
55+ def __len__ (self ) -> int :
56+ return len (self .records )
57+
58+
4659@dataclass
4760class Record :
4861 ppt_id : str
@@ -51,8 +64,8 @@ class Record:
5164
5265@dataclass
5366class Participant (Record ):
54- appointments : list = field (default_factory = list )
55- questionnaires : list = field (default_factory = list )
67+ appointments : RecordList = field (default_factory = list )
68+ questionnaires : RecordList = field (default_factory = list )
5669
5770
5871@dataclass
@@ -72,22 +85,11 @@ def __post_init__(self):
7285 self .isestimated = self .data ["isestimated" ]
7386
7487
75- @dataclass
76- class RecordList :
77- """List of REDCap records."""
78-
79- records : dict = field (default_factory = dict )
80- kind : str | None = None
81-
82- def __len__ (self ) -> int :
83- return len (self .records )
84-
85-
86- def get_api_key (path : Path | str = None , name : str = "API_KEY" ) -> str :
88+ def get_api_key (path : Path | str | None = None , name : str = "API_KEY" ) -> str :
8789 """Retrieve API credentials.
8890
8991 Args:
90- path (Path | str, optional): Path to the .env file with global variables. Defaults to ``Path.home()``.
92+ path (Path | str | None , optional): Path to the .env file with global variables. Defaults to ``Path.home()``.
9193 name (str, optional): Name of the variable to import. Defaults to "API_KEY".
9294
9395 Returns:
@@ -118,19 +120,19 @@ def get_api_key(path: Path | str = None, name: str = "API_KEY") -> str:
118120 return token
119121
120122
121- def post_request (fields : dict , timeout : list [int ] = (5 , 10 )) -> dict :
123+ def post_request (fields : dict , timeout : Sequence [int ] = (5 , 10 )) -> requests . Response :
122124 """Make a POST request to the REDCap database.
123125
124126 Args:
125127 fields (dict): Fields to retrieve.
126- timeout (list [int], optional): Timeout of HTTP request in seconds. Defaults to 10.
128+ timeout (Sequence [int], optional): Timeout of HTTP request in seconds. Defaults to 10.
127129
128130 Raises:
129131 requests.exceptions.HTTPError: If HTTP request fails.
130132 BadToken: If API token contains non-alphanumeric characters.
131133
132134 Returns:
133- dict : HTTP request response in JSON format.
135+ requests.Response : HTTP request response in JSON format.
134136 """
135137 t = get_api_key ()
136138
@@ -404,12 +406,12 @@ def prepare_data(x: dict, kind: str = "ppt") -> dict:
404406 return fmt_labels (x )
405407
406408
407- def make_id (ppt_id : str , repeat_id : str = None ) -> str :
409+ def make_id (ppt_id : str | int , repeat_id : str | int | None = None ) -> str :
408410 """Make a record ID.
409411
410412 Args:
411- ppt_id (str): Participant ID.
412- repeat_id (str, optional): Appointment or Questionnaire ID, or ``redcap_repeated_id``. Defaults to None.
413+ ppt_id (str | int ): Participant ID.
414+ repeat_id (str | int | None , optional): Appointment or Questionnaire ID, or ``redcap_repeated_id``. Defaults to None.
413415
414416 Returns:
415417 str: Record ID.
@@ -439,7 +441,7 @@ def get_records(record_id: str | list | None = None) -> dict:
439441 record_id (str): ID of record to retrieve. Defaults to None.
440442
441443 Returns:
442- dict: REDCap records in JSON format.
444+ list[ dict[str, str]] : REDCap records in JSON format.
443445 """
444446 fields = {"content" : "record" , "format" : "json" , "type" : "flat" }
445447
@@ -449,9 +451,7 @@ def get_records(record_id: str | list | None = None) -> dict:
449451 for r in record_id :
450452 fields [f"records[{ record_id } ]" ] = r
451453
452- records = post_request (fields = fields ).json ()
453-
454- return [str_to_dt (r ) for r in records ]
454+ return post_request (fields = fields ).json ()
455455
456456
457457def get_participant (ppt_id : str ) -> Participant :
@@ -687,29 +687,23 @@ def warn_missing_record(r: requests.models.Response):
687687 warn ("Record does not exist!" , stacklevel = 2 )
688688
689689
690- def redcap_backup (path : Path | str = None ) -> dict :
690+ def redcap_backup (path : Path | str = Path ( "tmp" )) -> Path :
691691 """Download a backup of the REDCap database
692692
693693 Args:
694694 path (Path | str, optional): Output directory. Defaults to ``Path("tmp")``.
695695
696696 Returns:
697- dict: A dictionary with the key data and metadata of the project.
697+ Path: Path to the generated file with data and metadata of the project.
698698 """
699- if path is None :
700- path = Path ("tmp" )
701-
702- if isinstance (path , str ):
703- path = Path (path )
704-
705- if not path .exists ():
706- path .mkdir (exist_ok = True )
699+ path = Path (path )
700+ path .mkdir (exist_ok = True )
707701
708702 p = {}
709703 for k in ["project" , "metadata" , "instrument" ]:
710704 p [k ] = {"format" : "json" , "returnFormat" : "json" , "content" : k }
711705
712- d = {k : loads (post_request (v ).text ) for k , v in pl .items ()}
706+ d = {k : loads (post_request (v ).text ) for k , v in p .items ()}
713707
714708 with open (path / "records.csv" , "w+" , encoding = "utf-8" ) as f :
715709 fields = {
@@ -736,19 +730,20 @@ def redcap_backup(path: Path | str = None) -> dict:
736730 timestamp = datetime .strftime (datetime .now (), "%Y-%m-%d-%H-%M" )
737731 file = path / ("backup_" + timestamp + ".zip" )
738732
739- for root , _ , files in path . walk (top_down = False ):
733+ for root , _ , files in walk (str ( path ), topdown = False ):
740734 with ZipFile (file , "w" , ZIP_DEFLATED ) as z :
741735 for f in files :
742- z .write (root / f )
736+ z .write (join ( root , f ) )
743737
744738 return file
745739
746740
747741class Records :
748742 """REDCap records"""
749743
750- def __init__ (self , record_id : str | list = None ):
744+ def __init__ (self , record_id : str | list | None = None ):
751745 records = get_records (record_id )
746+ records = [str_to_dt (r ) for r in records ]
752747 ppt , apt , que = {}, {}, {}
753748
754749 for r in records :
@@ -823,11 +818,11 @@ def parse_age(age: tuple) -> tuple[int, int]:
823818 raise BadAgeFormat ("age must be in (months, age) format" ) from e
824819
825820
826- def parse_str_date (x : str ) -> datetime :
821+ def parse_str_date (x : str | datetime ) -> datetime :
827822 """Parse string data to datetime.
828823
829824 Args:
830- x (str): String date to parse.
825+ x (str | datetime ): String date to parse.
831826
832827 Returns:
833828 datetime: Parsed datetime.
@@ -844,25 +839,23 @@ def parse_str_date(x: str) -> datetime:
844839 return datetime .strptime (x , "%Y-%m-%d %H:%M" )
845840
846841
847- def get_age (age : str | tuple , ts : datetime | str , ts_new : datetime = None ):
842+ def get_age (
843+ age : tuple , ts : datetime | str , ts_new : datetime | None = None , tz : str = "UTC"
844+ ):
848845 """Calculate the age of a person in months and days at a new timestamp.
849846
850847 Args:
851848 age (tuple): Age in months and days as a tuple of type (months, days).
852849 ts (datetime | str): Birth date as ``datetime.datetime`` type.
853- ts_new (datetime.datetime, optional): Time for which the age is calculated. Defaults to current date (``datetime.datetime.now()``).
850+ ts_new (datetime.datetime | None , optional): Time for which the age is calculated. Defaults to current date (``datetime.datetime.now()``).
854851
855852 Returns:
856853 tuple: Age in at ``new_timestamp``.
857854 """
858- ts = parse_str_date (ts )
859- ts_new = datetime .now (utc ) if ts_new is None else ts_new
860-
861- if ts .tzinfo is None or ts .tzinfo .utcoffset (ts ) is None :
862- ts = utc .localize (ts , True )
863-
864- if ts_new .tzinfo is None or ts_new .tzinfo .utcoffset (ts_new ) is None :
865- ts_new = utc .localize (ts_new , True )
855+ tz = pytz .timezone (tz )
856+ ts = tz .localize (parse_str_date (ts ))
857+ ts_new = datetime .now () if ts_new is None else ts_new
858+ ts_new = tz .localize (ts_new )
866859
867860 tdiff = rdelta (ts_new , ts )
868861 months , days = parse_age (age )
0 commit comments