Skip to content

Commit 655b3dd

Browse files
authored
fix: resolve CI test failures and lint formatting issues (GeiserX#85)
- Add beautifulsoup4 to test CI deps (needed for HTML import tests) - Add get_session/delete_session mocks to prevent 500 on logout test - Update rate limit test threshold from 5 to 15 to match production config - Apply ruff format to 6 files that failed format check
1 parent 7449e0b commit 655b3dd

8 files changed

Lines changed: 34 additions & 33 deletions

File tree

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
run: |
2424
python -m pip install --upgrade pip
2525
pip install -r requirements.txt
26-
pip install pytest pytest-cov
26+
pip install pytest pytest-cov beautifulsoup4
2727
2828
- name: Run tests
2929
run: |

scripts/auth_noninteractive.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,11 @@ async def main():
7474
await client.sign_in(phone, code)
7575
except Exception as e:
7676
error_str = str(e)
77-
if "Two-steps verification" in error_str or "password" in error_str.lower() or "SessionPasswordNeeded" in error_str:
77+
if (
78+
"Two-steps verification" in error_str
79+
or "password" in error_str.lower()
80+
or "SessionPasswordNeeded" in error_str
81+
):
7882
if not password:
7983
print("2FA is enabled. Re-run with: verify CODE 2FA_PASSWORD")
8084
await client.disconnect()

src/db/adapter.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1845,9 +1845,7 @@ async def save_session(
18451845
async def get_session(self, token: str) -> dict[str, Any] | None:
18461846
"""Get a session by token."""
18471847
async with self.db_manager.async_session_factory() as session:
1848-
result = await session.execute(
1849-
select(ViewerSession).where(ViewerSession.token == token)
1850-
)
1848+
result = await session.execute(select(ViewerSession).where(ViewerSession.token == token))
18511849
row = result.scalar_one_or_none()
18521850
return self._viewer_session_to_dict(row) if row else None
18531851

@@ -1861,19 +1859,15 @@ async def load_all_sessions(self) -> list[dict[str, Any]]:
18611859
async def delete_session(self, token: str) -> bool:
18621860
"""Delete a single session by token."""
18631861
async with self.db_manager.async_session_factory() as session:
1864-
result = await session.execute(
1865-
delete(ViewerSession).where(ViewerSession.token == token)
1866-
)
1862+
result = await session.execute(delete(ViewerSession).where(ViewerSession.token == token))
18671863
await session.commit()
18681864
return result.rowcount > 0
18691865

18701866
@retry_on_locked()
18711867
async def delete_user_sessions(self, username: str) -> int:
18721868
"""Delete all sessions for a given username. Returns count deleted."""
18731869
async with self.db_manager.async_session_factory() as session:
1874-
result = await session.execute(
1875-
delete(ViewerSession).where(ViewerSession.username == username)
1876-
)
1870+
result = await session.execute(delete(ViewerSession).where(ViewerSession.username == username))
18771871
await session.commit()
18781872
return result.rowcount
18791873

@@ -1884,9 +1878,7 @@ async def cleanup_expired_sessions(self, max_age_seconds: float) -> int:
18841878

18851879
cutoff = time.time() - max_age_seconds
18861880
async with self.db_manager.async_session_factory() as session:
1887-
result = await session.execute(
1888-
delete(ViewerSession).where(ViewerSession.created_at < cutoff)
1889-
)
1881+
result = await session.execute(delete(ViewerSession).where(ViewerSession.created_at < cutoff))
18901882
await session.commit()
18911883
return result.rowcount
18921884

src/db/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ class PushSubscription(Base):
248248
auth: Mapped[str] = mapped_column(String(255), nullable=False) # Auth secret
249249
chat_id: Mapped[int | None] = mapped_column(BigInteger) # Optional: subscribe to specific chat only
250250
username: Mapped[str | None] = mapped_column(String(255)) # User who created this subscription
251-
allowed_chat_ids: Mapped[str | None] = mapped_column(Text) # JSON snapshot of user's allowed chats at subscribe time
251+
allowed_chat_ids: Mapped[str | None] = mapped_column(
252+
Text
253+
) # JSON snapshot of user's allowed chats at subscribe time
252254
user_agent: Mapped[str | None] = mapped_column(String(500)) # Browser info for debugging
253255
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, server_default=func.now())
254256
last_used_at: Mapped[datetime | None] = mapped_column(DateTime) # Track activity

src/telegram_import.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,12 @@ def parse_date(msg: dict) -> datetime | None:
123123
if "date_unixtime" in msg:
124124
try:
125125
return datetime.fromtimestamp(int(msg["date_unixtime"]), tz=UTC).replace(tzinfo=None)
126-
except (ValueError, TypeError, OSError):
126+
except ValueError, TypeError, OSError:
127127
pass
128128
if "date" in msg:
129129
try:
130130
return datetime.fromisoformat(msg["date"]).replace(tzinfo=None)
131-
except (ValueError, TypeError):
131+
except ValueError, TypeError:
132132
pass
133133
return None
134134

@@ -138,12 +138,12 @@ def parse_edited_date(msg: dict) -> datetime | None:
138138
if "edited_unixtime" in msg:
139139
try:
140140
return datetime.fromtimestamp(int(msg["edited_unixtime"]), tz=UTC).replace(tzinfo=None)
141-
except (ValueError, TypeError, OSError):
141+
except ValueError, TypeError, OSError:
142142
pass
143143
if "edited" in msg:
144144
try:
145145
return datetime.fromisoformat(msg["edited"]).replace(tzinfo=None)
146-
except (ValueError, TypeError):
146+
except ValueError, TypeError:
147147
pass
148148
return None
149149

@@ -222,7 +222,7 @@ def parse_html_date(date_str: str) -> str | None:
222222
try:
223223
day, month, year = parts[0].split(".")
224224
return f"{year}-{month}-{day}T{parts[1]}"
225-
except (ValueError, IndexError):
225+
except ValueError, IndexError:
226226
return None
227227

228228

@@ -551,8 +551,7 @@ async def run(
551551
chats = [{"name": chat_name, "type": "html_export", "id": 0, "messages": messages}]
552552
else:
553553
raise FileNotFoundError(
554-
f"No result.json or messages.html found in {path}. "
555-
"Expected a Telegram Desktop export directory."
554+
f"No result.json or messages.html found in {path}. Expected a Telegram Desktop export directory."
556555
)
557556

558557
if not chats:

src/web/main.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
318318
if row["allowed_chat_ids"]:
319319
try:
320320
allowed = set(json.loads(row["allowed_chat_ids"]))
321-
except (json.JSONDecodeError, TypeError):
321+
except json.JSONDecodeError, TypeError:
322322
logger.warning(f"Skipping session with corrupted allowed_chat_ids for {row['username']}")
323323
continue
324324
_sessions[row["token"]] = SessionData(
@@ -498,8 +498,11 @@ async def _create_session(username: str, role: str, allowed_chat_ids: set[int] |
498498
now = time.time()
499499
token = secrets.token_urlsafe(32)
500500
_sessions[token] = SessionData(
501-
username=username, role=role, allowed_chat_ids=allowed_chat_ids,
502-
created_at=now, last_accessed=now,
501+
username=username,
502+
role=role,
503+
allowed_chat_ids=allowed_chat_ids,
504+
created_at=now,
505+
last_accessed=now,
503506
)
504507

505508
# Persist to database
@@ -563,7 +566,7 @@ async def _resolve_session(auth_cookie: str) -> SessionData | None:
563566
if row["allowed_chat_ids"]:
564567
try:
565568
allowed = set(json.loads(row["allowed_chat_ids"]))
566-
except (json.JSONDecodeError, TypeError):
569+
except json.JSONDecodeError, TypeError:
567570
logger.warning(f"Corrupted allowed_chat_ids for session {row['username']}, denying access")
568571
return None
569572

@@ -759,7 +762,7 @@ async def login(request: Request):
759762
if viewer["allowed_chat_ids"]:
760763
try:
761764
allowed = set(json.loads(viewer["allowed_chat_ids"]))
762-
except (json.JSONDecodeError, TypeError):
765+
except json.JSONDecodeError, TypeError:
763766
allowed = None
764767

765768
token = await _create_session(username, "viewer", allowed)
@@ -1259,8 +1262,7 @@ async def internal_push(request: Request):
12591262

12601263
allowed = False
12611264
if client_host and (
1262-
client_host in ("127.0.0.1", "localhost", "::1")
1263-
or client_host.startswith(("172.", "10.", "192.168."))
1265+
client_host in ("127.0.0.1", "localhost", "::1") or client_host.startswith(("172.", "10.", "192.168."))
12641266
):
12651267
allowed = True
12661268

@@ -1436,7 +1438,7 @@ async def create_viewer(request: Request, user: UserContext = Depends(require_ma
14361438
if allowed_chat_ids is not None:
14371439
try:
14381440
chat_ids_json = json.dumps([int(cid) for cid in allowed_chat_ids])
1439-
except (ValueError, TypeError):
1441+
except ValueError, TypeError:
14401442
raise HTTPException(status_code=400, detail="Invalid chat ID format")
14411443

14421444
account = await db.create_viewer_account(
@@ -1491,7 +1493,7 @@ async def update_viewer(viewer_id: int, request: Request, user: UserContext = De
14911493
else:
14921494
try:
14931495
updates["allowed_chat_ids"] = json.dumps([int(cid) for cid in allowed])
1494-
except (ValueError, TypeError):
1496+
except ValueError, TypeError:
14951497
raise HTTPException(status_code=400, detail="Invalid chat ID format")
14961498

14971499
if "is_active" in data:

src/web/push.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ async def get_subscriptions(self, chat_id: int | None = None) -> list[dict[str,
219219
user_chats = json.loads(sub.allowed_chat_ids)
220220
if chat_id not in user_chats:
221221
continue
222-
except (json.JSONDecodeError, TypeError):
222+
except json.JSONDecodeError, TypeError:
223223
continue
224224
filtered.append({"endpoint": sub.endpoint, "keys": {"p256dh": sub.p256dh, "auth": sub.auth}})
225225

tests/test_multi_user_auth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def _make_mock_db():
4848
db.calculate_and_store_statistics = AsyncMock(return_value={"total_chats": 3})
4949
db.get_all_folders = AsyncMock(return_value=[])
5050
db.get_archived_chat_count = AsyncMock(return_value=0)
51+
db.get_session = AsyncMock(return_value=None)
52+
db.delete_session = AsyncMock()
5153
return db
5254

5355

@@ -246,7 +248,7 @@ class TestRateLimiting:
246248

247249
def test_rate_limit_blocks_after_threshold(self, auth_env):
248250
client, mod, _ = _get_client()
249-
for _ in range(5):
251+
for _ in range(15):
250252
client.post("/api/login", json={"username": "admin", "password": "wrong"})
251253

252254
resp = client.post("/api/login", json={"username": "admin", "password": "wrong"})

0 commit comments

Comments
 (0)