Skip to content

Commit e8a689f

Browse files
committed
Fix linting in tests
1 parent 29a2180 commit e8a689f

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ test:
22
uvx ruff format
33
uvx ruff check --fix .
44
uvx ty check
5-
uvx pytest -v --exitfirst
5+
uv run pytest -v --exitfirst

babylab/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,14 @@ def count_col(
194194
return counts
195195

196196

197-
def get_year_weeks(year: int) -> Generator[datetime]:
197+
def get_year_weeks(year: int) -> Generator[datetime, datetime, datetime]:
198198
"""Get week numbers of the year.
199199
200200
Args:
201201
year (int): Year to get weeks for.
202202
203203
Yields:
204-
int: Number of weeks in the year.
204+
Generator[datetime, datetime, datetime]: Number of weeks in the year.
205205
"""
206206
date_first = date(year, 1, 1)
207207
date_first += timedelta(days=6 - date_first.weekday())

tests/test_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ def test_get_apt_table_id_list(apt_id: str | Sequence[str] | None = None, k: int
119119
apt_id = [apt_id]
120120

121121
ppt_id = set(i.split(":")[0] for i in apt_id)
122-
df = utils.get_apt_table(conftest.RECORDS, ppt_id=sample(ppt_id, k=k))
122+
123+
if k > len(ppt_id):
124+
k = len(ppt_id)
125+
126+
df = utils.get_apt_table(conftest.RECORDS, ppt_id=sample(list(ppt_id), k=k))
123127

124128
assert isinstance(df, pl.DataFrame)
125129
assert all(p in ppt_id for p in df["record_id"].unique().to_list())
@@ -148,7 +152,11 @@ def test_get_que_table_id_list(que_id: str | list[str] | None = None, k: int = 1
148152
que_id = [que_id]
149153

150154
ppt_id = set(i.split(":")[0] for i in que_id)
151-
df = utils.get_apt_table(conftest.RECORDS, ppt_id=sample(ppt_id, k=k))
155+
156+
if k > len(ppt_id):
157+
k = len(ppt_id)
158+
159+
df = utils.get_apt_table(conftest.RECORDS, ppt_id=sample(list(ppt_id), k=k))
152160
assert isinstance(df, pl.DataFrame)
153161
assert all(p in ppt_id for p in df["record_id"].unique().to_list())
154162

0 commit comments

Comments
 (0)