diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7fb5c158..dd13dac9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,7 +23,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install pytest pytest-cov + pip install pytest pytest-cov beautifulsoup4 - name: Run tests run: | diff --git a/scripts/auth_noninteractive.py b/scripts/auth_noninteractive.py index ff8eb94f..91de032d 100644 --- a/scripts/auth_noninteractive.py +++ b/scripts/auth_noninteractive.py @@ -74,7 +74,11 @@ async def main(): await client.sign_in(phone, code) except Exception as e: error_str = str(e) - if "Two-steps verification" in error_str or "password" in error_str.lower() or "SessionPasswordNeeded" in error_str: + if ( + "Two-steps verification" in error_str + or "password" in error_str.lower() + or "SessionPasswordNeeded" in error_str + ): if not password: print("2FA is enabled. Re-run with: verify CODE 2FA_PASSWORD") await client.disconnect() diff --git a/src/db/adapter.py b/src/db/adapter.py index ef547a6d..44661b6c 100644 --- a/src/db/adapter.py +++ b/src/db/adapter.py @@ -1845,9 +1845,7 @@ async def save_session( async def get_session(self, token: str) -> dict[str, Any] | None: """Get a session by token.""" async with self.db_manager.async_session_factory() as session: - result = await session.execute( - select(ViewerSession).where(ViewerSession.token == token) - ) + result = await session.execute(select(ViewerSession).where(ViewerSession.token == token)) row = result.scalar_one_or_none() return self._viewer_session_to_dict(row) if row else None @@ -1861,9 +1859,7 @@ async def load_all_sessions(self) -> list[dict[str, Any]]: async def delete_session(self, token: str) -> bool: """Delete a single session by token.""" async with self.db_manager.async_session_factory() as session: - result = await session.execute( - delete(ViewerSession).where(ViewerSession.token == token) - ) + result = await session.execute(delete(ViewerSession).where(ViewerSession.token == token)) await session.commit() return result.rowcount > 0 @@ -1871,9 +1867,7 @@ async def delete_session(self, token: str) -> bool: async def delete_user_sessions(self, username: str) -> int: """Delete all sessions for a given username. Returns count deleted.""" async with self.db_manager.async_session_factory() as session: - result = await session.execute( - delete(ViewerSession).where(ViewerSession.username == username) - ) + result = await session.execute(delete(ViewerSession).where(ViewerSession.username == username)) await session.commit() return result.rowcount @@ -1884,9 +1878,7 @@ async def cleanup_expired_sessions(self, max_age_seconds: float) -> int: cutoff = time.time() - max_age_seconds async with self.db_manager.async_session_factory() as session: - result = await session.execute( - delete(ViewerSession).where(ViewerSession.created_at < cutoff) - ) + result = await session.execute(delete(ViewerSession).where(ViewerSession.created_at < cutoff)) await session.commit() return result.rowcount diff --git a/src/db/models.py b/src/db/models.py index 5c58c041..cc944dc9 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -248,7 +248,9 @@ class PushSubscription(Base): auth: Mapped[str] = mapped_column(String(255), nullable=False) # Auth secret chat_id: Mapped[int | None] = mapped_column(BigInteger) # Optional: subscribe to specific chat only username: Mapped[str | None] = mapped_column(String(255)) # User who created this subscription - allowed_chat_ids: Mapped[str | None] = mapped_column(Text) # JSON snapshot of user's allowed chats at subscribe time + allowed_chat_ids: Mapped[str | None] = mapped_column( + Text + ) # JSON snapshot of user's allowed chats at subscribe time user_agent: Mapped[str | None] = mapped_column(String(500)) # Browser info for debugging created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, server_default=func.now()) last_used_at: Mapped[datetime | None] = mapped_column(DateTime) # Track activity diff --git a/src/telegram_import.py b/src/telegram_import.py index f7c1fc40..bf0280ae 100644 --- a/src/telegram_import.py +++ b/src/telegram_import.py @@ -123,12 +123,12 @@ def parse_date(msg: dict) -> datetime | None: if "date_unixtime" in msg: try: return datetime.fromtimestamp(int(msg["date_unixtime"]), tz=UTC).replace(tzinfo=None) - except (ValueError, TypeError, OSError): + except ValueError, TypeError, OSError: pass if "date" in msg: try: return datetime.fromisoformat(msg["date"]).replace(tzinfo=None) - except (ValueError, TypeError): + except ValueError, TypeError: pass return None @@ -138,12 +138,12 @@ def parse_edited_date(msg: dict) -> datetime | None: if "edited_unixtime" in msg: try: return datetime.fromtimestamp(int(msg["edited_unixtime"]), tz=UTC).replace(tzinfo=None) - except (ValueError, TypeError, OSError): + except ValueError, TypeError, OSError: pass if "edited" in msg: try: return datetime.fromisoformat(msg["edited"]).replace(tzinfo=None) - except (ValueError, TypeError): + except ValueError, TypeError: pass return None @@ -222,7 +222,7 @@ def parse_html_date(date_str: str) -> str | None: try: day, month, year = parts[0].split(".") return f"{year}-{month}-{day}T{parts[1]}" - except (ValueError, IndexError): + except ValueError, IndexError: return None @@ -551,8 +551,7 @@ async def run( chats = [{"name": chat_name, "type": "html_export", "id": 0, "messages": messages}] else: raise FileNotFoundError( - f"No result.json or messages.html found in {path}. " - "Expected a Telegram Desktop export directory." + f"No result.json or messages.html found in {path}. Expected a Telegram Desktop export directory." ) if not chats: diff --git a/src/web/main.py b/src/web/main.py index 49614dcf..5fca11f7 100644 --- a/src/web/main.py +++ b/src/web/main.py @@ -318,7 +318,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None]: if row["allowed_chat_ids"]: try: allowed = set(json.loads(row["allowed_chat_ids"])) - except (json.JSONDecodeError, TypeError): + except json.JSONDecodeError, TypeError: logger.warning(f"Skipping session with corrupted allowed_chat_ids for {row['username']}") continue _sessions[row["token"]] = SessionData( @@ -498,8 +498,11 @@ async def _create_session(username: str, role: str, allowed_chat_ids: set[int] | now = time.time() token = secrets.token_urlsafe(32) _sessions[token] = SessionData( - username=username, role=role, allowed_chat_ids=allowed_chat_ids, - created_at=now, last_accessed=now, + username=username, + role=role, + allowed_chat_ids=allowed_chat_ids, + created_at=now, + last_accessed=now, ) # Persist to database @@ -563,7 +566,7 @@ async def _resolve_session(auth_cookie: str) -> SessionData | None: if row["allowed_chat_ids"]: try: allowed = set(json.loads(row["allowed_chat_ids"])) - except (json.JSONDecodeError, TypeError): + except json.JSONDecodeError, TypeError: logger.warning(f"Corrupted allowed_chat_ids for session {row['username']}, denying access") return None @@ -759,7 +762,7 @@ async def login(request: Request): if viewer["allowed_chat_ids"]: try: allowed = set(json.loads(viewer["allowed_chat_ids"])) - except (json.JSONDecodeError, TypeError): + except json.JSONDecodeError, TypeError: allowed = None token = await _create_session(username, "viewer", allowed) @@ -1259,8 +1262,7 @@ async def internal_push(request: Request): allowed = False if client_host and ( - client_host in ("127.0.0.1", "localhost", "::1") - or client_host.startswith(("172.", "10.", "192.168.")) + client_host in ("127.0.0.1", "localhost", "::1") or client_host.startswith(("172.", "10.", "192.168.")) ): allowed = True @@ -1436,7 +1438,7 @@ async def create_viewer(request: Request, user: UserContext = Depends(require_ma if allowed_chat_ids is not None: try: chat_ids_json = json.dumps([int(cid) for cid in allowed_chat_ids]) - except (ValueError, TypeError): + except ValueError, TypeError: raise HTTPException(status_code=400, detail="Invalid chat ID format") account = await db.create_viewer_account( @@ -1491,7 +1493,7 @@ async def update_viewer(viewer_id: int, request: Request, user: UserContext = De else: try: updates["allowed_chat_ids"] = json.dumps([int(cid) for cid in allowed]) - except (ValueError, TypeError): + except ValueError, TypeError: raise HTTPException(status_code=400, detail="Invalid chat ID format") if "is_active" in data: diff --git a/src/web/push.py b/src/web/push.py index 05b93b90..255192b0 100644 --- a/src/web/push.py +++ b/src/web/push.py @@ -219,7 +219,7 @@ async def get_subscriptions(self, chat_id: int | None = None) -> list[dict[str, user_chats = json.loads(sub.allowed_chat_ids) if chat_id not in user_chats: continue - except (json.JSONDecodeError, TypeError): + except json.JSONDecodeError, TypeError: continue filtered.append({"endpoint": sub.endpoint, "keys": {"p256dh": sub.p256dh, "auth": sub.auth}}) diff --git a/tests/test_multi_user_auth.py b/tests/test_multi_user_auth.py index 4cd935b5..7080df4f 100644 --- a/tests/test_multi_user_auth.py +++ b/tests/test_multi_user_auth.py @@ -48,6 +48,8 @@ def _make_mock_db(): db.calculate_and_store_statistics = AsyncMock(return_value={"total_chats": 3}) db.get_all_folders = AsyncMock(return_value=[]) db.get_archived_chat_count = AsyncMock(return_value=0) + db.get_session = AsyncMock(return_value=None) + db.delete_session = AsyncMock() return db @@ -246,7 +248,7 @@ class TestRateLimiting: def test_rate_limit_blocks_after_threshold(self, auth_env): client, mod, _ = _get_client() - for _ in range(5): + for _ in range(15): client.post("/api/login", json={"username": "admin", "password": "wrong"}) resp = client.post("/api/login", json={"username": "admin", "password": "wrong"})