Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov
pip install pytest pytest-cov beautifulsoup4

- name: Run tests
run: |
Expand Down
6 changes: 5 additions & 1 deletion scripts/auth_noninteractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ async def main():
await client.sign_in(phone, code)
except Exception as e:
error_str = str(e)
if "Two-steps verification" in error_str or "password" in error_str.lower() or "SessionPasswordNeeded" in error_str:
if (
"Two-steps verification" in error_str
or "password" in error_str.lower()
or "SessionPasswordNeeded" in error_str
):
if not password:
print("2FA is enabled. Re-run with: verify CODE 2FA_PASSWORD")
await client.disconnect()
Expand Down
16 changes: 4 additions & 12 deletions src/db/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1845,9 +1845,7 @@ async def save_session(
async def get_session(self, token: str) -> dict[str, Any] | None:
"""Get a session by token."""
async with self.db_manager.async_session_factory() as session:
result = await session.execute(
select(ViewerSession).where(ViewerSession.token == token)
)
result = await session.execute(select(ViewerSession).where(ViewerSession.token == token))
row = result.scalar_one_or_none()
return self._viewer_session_to_dict(row) if row else None

Expand All @@ -1861,19 +1859,15 @@ async def load_all_sessions(self) -> list[dict[str, Any]]:
async def delete_session(self, token: str) -> bool:
"""Delete a single session by token."""
async with self.db_manager.async_session_factory() as session:
result = await session.execute(
delete(ViewerSession).where(ViewerSession.token == token)
)
result = await session.execute(delete(ViewerSession).where(ViewerSession.token == token))
await session.commit()
return result.rowcount > 0

@retry_on_locked()
async def delete_user_sessions(self, username: str) -> int:
"""Delete all sessions for a given username. Returns count deleted."""
async with self.db_manager.async_session_factory() as session:
result = await session.execute(
delete(ViewerSession).where(ViewerSession.username == username)
)
result = await session.execute(delete(ViewerSession).where(ViewerSession.username == username))
await session.commit()
return result.rowcount

Expand All @@ -1884,9 +1878,7 @@ async def cleanup_expired_sessions(self, max_age_seconds: float) -> int:

cutoff = time.time() - max_age_seconds
async with self.db_manager.async_session_factory() as session:
result = await session.execute(
delete(ViewerSession).where(ViewerSession.created_at < cutoff)
)
result = await session.execute(delete(ViewerSession).where(ViewerSession.created_at < cutoff))
await session.commit()
return result.rowcount

Expand Down
4 changes: 3 additions & 1 deletion src/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,9 @@ class PushSubscription(Base):
auth: Mapped[str] = mapped_column(String(255), nullable=False) # Auth secret
chat_id: Mapped[int | None] = mapped_column(BigInteger) # Optional: subscribe to specific chat only
username: Mapped[str | None] = mapped_column(String(255)) # User who created this subscription
allowed_chat_ids: Mapped[str | None] = mapped_column(Text) # JSON snapshot of user's allowed chats at subscribe time
allowed_chat_ids: Mapped[str | None] = mapped_column(
Text
) # JSON snapshot of user's allowed chats at subscribe time
user_agent: Mapped[str | None] = mapped_column(String(500)) # Browser info for debugging
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, server_default=func.now())
last_used_at: Mapped[datetime | None] = mapped_column(DateTime) # Track activity
Expand Down
13 changes: 6 additions & 7 deletions src/telegram_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def parse_date(msg: dict) -> datetime | None:
if "date_unixtime" in msg:
try:
return datetime.fromtimestamp(int(msg["date_unixtime"]), tz=UTC).replace(tzinfo=None)
except (ValueError, TypeError, OSError):
except ValueError, TypeError, OSError:
pass
if "date" in msg:
try:
return datetime.fromisoformat(msg["date"]).replace(tzinfo=None)
except (ValueError, TypeError):
except ValueError, TypeError:
pass
return None

Expand All @@ -138,12 +138,12 @@ def parse_edited_date(msg: dict) -> datetime | None:
if "edited_unixtime" in msg:
try:
return datetime.fromtimestamp(int(msg["edited_unixtime"]), tz=UTC).replace(tzinfo=None)
except (ValueError, TypeError, OSError):
except ValueError, TypeError, OSError:
pass
if "edited" in msg:
try:
return datetime.fromisoformat(msg["edited"]).replace(tzinfo=None)
except (ValueError, TypeError):
except ValueError, TypeError:
pass
return None

Expand Down Expand Up @@ -222,7 +222,7 @@ def parse_html_date(date_str: str) -> str | None:
try:
day, month, year = parts[0].split(".")
return f"{year}-{month}-{day}T{parts[1]}"
except (ValueError, IndexError):
except ValueError, IndexError:
return None


Expand Down Expand Up @@ -551,8 +551,7 @@ async def run(
chats = [{"name": chat_name, "type": "html_export", "id": 0, "messages": messages}]
else:
raise FileNotFoundError(
f"No result.json or messages.html found in {path}. "
"Expected a Telegram Desktop export directory."
f"No result.json or messages.html found in {path}. Expected a Telegram Desktop export directory."
)

if not chats:
Expand Down
20 changes: 11 additions & 9 deletions src/web/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
if row["allowed_chat_ids"]:
try:
allowed = set(json.loads(row["allowed_chat_ids"]))
except (json.JSONDecodeError, TypeError):
except json.JSONDecodeError, TypeError:
logger.warning(f"Skipping session with corrupted allowed_chat_ids for {row['username']}")
continue
_sessions[row["token"]] = SessionData(
Expand Down Expand Up @@ -498,8 +498,11 @@ async def _create_session(username: str, role: str, allowed_chat_ids: set[int] |
now = time.time()
token = secrets.token_urlsafe(32)
_sessions[token] = SessionData(
username=username, role=role, allowed_chat_ids=allowed_chat_ids,
created_at=now, last_accessed=now,
username=username,
role=role,
allowed_chat_ids=allowed_chat_ids,
created_at=now,
last_accessed=now,
)

# Persist to database
Expand Down Expand Up @@ -563,7 +566,7 @@ async def _resolve_session(auth_cookie: str) -> SessionData | None:
if row["allowed_chat_ids"]:
try:
allowed = set(json.loads(row["allowed_chat_ids"]))
except (json.JSONDecodeError, TypeError):
except json.JSONDecodeError, TypeError:
logger.warning(f"Corrupted allowed_chat_ids for session {row['username']}, denying access")
return None

Expand Down Expand Up @@ -759,7 +762,7 @@ async def login(request: Request):
if viewer["allowed_chat_ids"]:
try:
allowed = set(json.loads(viewer["allowed_chat_ids"]))
except (json.JSONDecodeError, TypeError):
except json.JSONDecodeError, TypeError:
allowed = None

token = await _create_session(username, "viewer", allowed)
Expand Down Expand Up @@ -1259,8 +1262,7 @@ async def internal_push(request: Request):

allowed = False
if client_host and (
client_host in ("127.0.0.1", "localhost", "::1")
or client_host.startswith(("172.", "10.", "192.168."))
client_host in ("127.0.0.1", "localhost", "::1") or client_host.startswith(("172.", "10.", "192.168."))
):
allowed = True

Expand Down Expand Up @@ -1436,7 +1438,7 @@ async def create_viewer(request: Request, user: UserContext = Depends(require_ma
if allowed_chat_ids is not None:
try:
chat_ids_json = json.dumps([int(cid) for cid in allowed_chat_ids])
except (ValueError, TypeError):
except ValueError, TypeError:
raise HTTPException(status_code=400, detail="Invalid chat ID format")

account = await db.create_viewer_account(
Expand Down Expand Up @@ -1491,7 +1493,7 @@ async def update_viewer(viewer_id: int, request: Request, user: UserContext = De
else:
try:
updates["allowed_chat_ids"] = json.dumps([int(cid) for cid in allowed])
except (ValueError, TypeError):
except ValueError, TypeError:
raise HTTPException(status_code=400, detail="Invalid chat ID format")

if "is_active" in data:
Expand Down
2 changes: 1 addition & 1 deletion src/web/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async def get_subscriptions(self, chat_id: int | None = None) -> list[dict[str,
user_chats = json.loads(sub.allowed_chat_ids)
if chat_id not in user_chats:
continue
except (json.JSONDecodeError, TypeError):
except json.JSONDecodeError, TypeError:
continue
filtered.append({"endpoint": sub.endpoint, "keys": {"p256dh": sub.p256dh, "auth": sub.auth}})

Expand Down
4 changes: 3 additions & 1 deletion tests/test_multi_user_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def _make_mock_db():
db.calculate_and_store_statistics = AsyncMock(return_value={"total_chats": 3})
db.get_all_folders = AsyncMock(return_value=[])
db.get_archived_chat_count = AsyncMock(return_value=0)
db.get_session = AsyncMock(return_value=None)
db.delete_session = AsyncMock()
return db


Expand Down Expand Up @@ -246,7 +248,7 @@ class TestRateLimiting:

def test_rate_limit_blocks_after_threshold(self, auth_env):
client, mod, _ = _get_client()
for _ in range(5):
for _ in range(15):
client.post("/api/login", json={"username": "admin", "password": "wrong"})

resp = client.post("/api/login", json={"username": "admin", "password": "wrong"})
Expand Down
Loading