Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 98 additions & 17 deletions src/web/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,18 @@ def _record_login_attempt(ip: str) -> None:
_login_attempts.setdefault(ip, []).append(time.time())


def _is_db_connection_error(exc: Exception) -> bool:
"""Check if an exception indicates the database is unreachable."""
current: BaseException | None = exc
for _ in range(10):
if current is None:
break
if isinstance(current, OSError):
return True
current = getattr(current, "__cause__", None) or getattr(current, "__context__", None)
return False


async def _create_session(
username: str,
role: str,
Expand Down Expand Up @@ -769,6 +781,31 @@ async def read_root():
)


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
"""Catch unhandled exceptions and return 503 for DB connection errors."""
if _is_db_connection_error(exc):
logger.error(f"Database connection error on {request.url.path}: {exc}")
return JSONResponse(status_code=503, content={"detail": "Database temporarily unavailable"})
logger.error(f"Unhandled error on {request.url.path}: {exc}", exc_info=True)
return JSONResponse(status_code=500, content={"detail": "Internal server error"})


@app.get("/api/health")
async def health_check():
"""Health check endpoint for monitoring and Docker healthchecks."""
result = {"status": "ok"}
if db:
try:
await db.get_chat_count()
result["database"] = "connected"
except Exception:
result["database"] = "unreachable"
result["status"] = "degraded"
return JSONResponse(status_code=503, content=result)
return result


@app.get("/api/auth/check")
async def check_auth(auth_cookie: str | None = Cookie(default=None, alias=AUTH_COOKIE_NAME)):
"""Check current authentication status. Returns role and username if authenticated."""
Expand Down Expand Up @@ -828,8 +865,15 @@ async def login(request: Request):
user_agent = request.headers.get("user-agent", "")[:500]

# 1. Check DB viewer accounts first
_db_reachable = True
if db:
viewer = await db.get_viewer_by_username(username)
try:
viewer = await db.get_viewer_by_username(username)
except Exception as e:
logger.warning(f"Database unavailable during login, falling back to env credentials: {e}")
_db_reachable = False
viewer = None

if viewer and viewer["is_active"]:
if _verify_password(password, viewer["salt"], viewer["password_hash"]):
allowed = None
Expand All @@ -851,7 +895,7 @@ async def login(request: Request):
max_age=AUTH_SESSION_SECONDS,
)

if db:
try:
await db.create_audit_log(
username=username,
role="viewer",
Expand All @@ -860,6 +904,8 @@ async def login(request: Request):
ip_address=client_ip,
user_agent=user_agent,
)
except Exception:
logger.warning(f"Failed to write audit log for viewer login: {username}")
return response

# 2. Fall back to master env var credentials
Expand All @@ -878,27 +924,36 @@ async def login(request: Request):
max_age=AUTH_SESSION_SECONDS,
)

try:
if db:
await db.create_audit_log(
username=username,
role="master",
action="login_success",
endpoint="/api/login",
ip_address=client_ip,
user_agent=user_agent,
)
except Exception:
logger.warning(f"Failed to write audit log for master login: {username}")
return response

# Failed login — if DB was unreachable, viewer accounts couldn't be checked
if not _db_reachable:
raise HTTPException(status_code=503, detail="Database temporarily unavailable, please try again later")

try:
if db:
await db.create_audit_log(
username=username,
role="master",
action="login_success",
username=username or "(empty)",
role="unknown",
action="login_failed",
endpoint="/api/login",
ip_address=client_ip,
user_agent=user_agent,
)
return response

# Failed login
if db:
await db.create_audit_log(
username=username or "(empty)",
role="unknown",
action="login_failed",
endpoint="/api/login",
ip_address=client_ip,
user_agent=user_agent,
)
except Exception:
logger.warning(f"Failed to write audit log for failed login: {username}")
raise HTTPException(status_code=401, detail="Invalid credentials")


Expand Down Expand Up @@ -1138,6 +1193,8 @@ async def get_chats(
}
except Exception as e:
logger.error(f"Error fetching chats: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down Expand Up @@ -1191,6 +1248,8 @@ async def get_messages(
return messages
except Exception as e:
logger.error(f"Error fetching messages: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand All @@ -1206,6 +1265,8 @@ async def get_pinned_messages(chat_id: int, user: UserContext = Depends(require_
return pinned_messages # Returns empty list if no pinned messages
except Exception as e:
logger.error(f"Error fetching pinned messages: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand All @@ -1220,6 +1281,8 @@ async def get_folders(user: UserContext = Depends(require_auth)):
return {"folders": folders}
except Exception as e:
logger.error(f"Error fetching folders: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand All @@ -1238,6 +1301,8 @@ async def get_chat_topics(chat_id: int, user: UserContext = Depends(require_auth
return {"topics": topics}
except Exception as e:
logger.error(f"Error fetching topics: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand All @@ -1258,6 +1323,8 @@ async def get_archived_count(user: UserContext = Depends(require_auth)):
return {"count": count}
except Exception as e:
logger.error(f"Error fetching archived count: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down Expand Up @@ -1285,6 +1352,8 @@ async def get_stats(user: UserContext = Depends(require_auth)):
return stats
except Exception as e:
logger.error(f"Error fetching stats: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand All @@ -1297,6 +1366,8 @@ async def refresh_stats(user: UserContext = Depends(require_master)):
return stats
except Exception as e:
logger.error(f"Error calculating stats: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down Expand Up @@ -1380,6 +1451,8 @@ async def push_subscribe(request: Request, user: UserContext = Depends(require_a
raise
except Exception as e:
logger.error(f"Push subscribe error: {e}")
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down Expand Up @@ -1410,6 +1483,8 @@ async def push_unsubscribe(request: Request, user: UserContext = Depends(require
raise
except Exception as e:
logger.error(f"Push unsubscribe error: {e}")
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down Expand Up @@ -1459,6 +1534,8 @@ async def get_chat_stats(chat_id: int, user: UserContext = Depends(require_auth)
return stats
except Exception as e:
logger.error(f"Error getting chat stats: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down Expand Up @@ -1505,6 +1582,8 @@ async def get_message_by_date(
raise
except Exception as e:
logger.error(f"Error finding message by date: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down Expand Up @@ -1549,6 +1628,8 @@ async def iter_json():
raise
except Exception as e:
logger.error(f"Error exporting chat: {e}", exc_info=True)
if _is_db_connection_error(e):
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
raise HTTPException(status_code=500, detail="Internal server error")


Expand Down
Loading