diff --git a/src/web/main.py b/src/web/main.py index 34b4c834..5bf76d78 100644 --- a/src/web/main.py +++ b/src/web/main.py @@ -489,6 +489,18 @@ def _record_login_attempt(ip: str) -> None: _login_attempts.setdefault(ip, []).append(time.time()) +def _is_db_connection_error(exc: Exception) -> bool: + """Check if an exception indicates the database is unreachable.""" + current: BaseException | None = exc + for _ in range(10): + if current is None: + break + if isinstance(current, OSError): + return True + current = getattr(current, "__cause__", None) or getattr(current, "__context__", None) + return False + + async def _create_session( username: str, role: str, @@ -769,6 +781,31 @@ async def read_root(): ) +@app.exception_handler(Exception) +async def unhandled_exception_handler(request: Request, exc: Exception): + """Catch unhandled exceptions and return 503 for DB connection errors.""" + if _is_db_connection_error(exc): + logger.error(f"Database connection error on {request.url.path}: {exc}") + return JSONResponse(status_code=503, content={"detail": "Database temporarily unavailable"}) + logger.error(f"Unhandled error on {request.url.path}: {exc}", exc_info=True) + return JSONResponse(status_code=500, content={"detail": "Internal server error"}) + + +@app.get("/api/health") +async def health_check(): + """Health check endpoint for monitoring and Docker healthchecks.""" + result = {"status": "ok"} + if db: + try: + await db.get_chat_count() + result["database"] = "connected" + except Exception: + result["database"] = "unreachable" + result["status"] = "degraded" + return JSONResponse(status_code=503, content=result) + return result + + @app.get("/api/auth/check") async def check_auth(auth_cookie: str | None = Cookie(default=None, alias=AUTH_COOKIE_NAME)): """Check current authentication status. Returns role and username if authenticated.""" @@ -828,8 +865,15 @@ async def login(request: Request): user_agent = request.headers.get("user-agent", "")[:500] # 1. Check DB viewer accounts first + _db_reachable = True if db: - viewer = await db.get_viewer_by_username(username) + try: + viewer = await db.get_viewer_by_username(username) + except Exception as e: + logger.warning(f"Database unavailable during login, falling back to env credentials: {e}") + _db_reachable = False + viewer = None + if viewer and viewer["is_active"]: if _verify_password(password, viewer["salt"], viewer["password_hash"]): allowed = None @@ -851,7 +895,7 @@ async def login(request: Request): max_age=AUTH_SESSION_SECONDS, ) - if db: + try: await db.create_audit_log( username=username, role="viewer", @@ -860,6 +904,8 @@ async def login(request: Request): ip_address=client_ip, user_agent=user_agent, ) + except Exception: + logger.warning(f"Failed to write audit log for viewer login: {username}") return response # 2. Fall back to master env var credentials @@ -878,27 +924,36 @@ async def login(request: Request): max_age=AUTH_SESSION_SECONDS, ) + try: + if db: + await db.create_audit_log( + username=username, + role="master", + action="login_success", + endpoint="/api/login", + ip_address=client_ip, + user_agent=user_agent, + ) + except Exception: + logger.warning(f"Failed to write audit log for master login: {username}") + return response + + # Failed login — if DB was unreachable, viewer accounts couldn't be checked + if not _db_reachable: + raise HTTPException(status_code=503, detail="Database temporarily unavailable, please try again later") + + try: if db: await db.create_audit_log( - username=username, - role="master", - action="login_success", + username=username or "(empty)", + role="unknown", + action="login_failed", endpoint="/api/login", ip_address=client_ip, user_agent=user_agent, ) - return response - - # Failed login - if db: - await db.create_audit_log( - username=username or "(empty)", - role="unknown", - action="login_failed", - endpoint="/api/login", - ip_address=client_ip, - user_agent=user_agent, - ) + except Exception: + logger.warning(f"Failed to write audit log for failed login: {username}") raise HTTPException(status_code=401, detail="Invalid credentials") @@ -1138,6 +1193,8 @@ async def get_chats( } except Exception as e: logger.error(f"Error fetching chats: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1191,6 +1248,8 @@ async def get_messages( return messages except Exception as e: logger.error(f"Error fetching messages: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1206,6 +1265,8 @@ async def get_pinned_messages(chat_id: int, user: UserContext = Depends(require_ return pinned_messages # Returns empty list if no pinned messages except Exception as e: logger.error(f"Error fetching pinned messages: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1220,6 +1281,8 @@ async def get_folders(user: UserContext = Depends(require_auth)): return {"folders": folders} except Exception as e: logger.error(f"Error fetching folders: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1238,6 +1301,8 @@ async def get_chat_topics(chat_id: int, user: UserContext = Depends(require_auth return {"topics": topics} except Exception as e: logger.error(f"Error fetching topics: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1258,6 +1323,8 @@ async def get_archived_count(user: UserContext = Depends(require_auth)): return {"count": count} except Exception as e: logger.error(f"Error fetching archived count: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1285,6 +1352,8 @@ async def get_stats(user: UserContext = Depends(require_auth)): return stats except Exception as e: logger.error(f"Error fetching stats: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1297,6 +1366,8 @@ async def refresh_stats(user: UserContext = Depends(require_master)): return stats except Exception as e: logger.error(f"Error calculating stats: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1380,6 +1451,8 @@ async def push_subscribe(request: Request, user: UserContext = Depends(require_a raise except Exception as e: logger.error(f"Push subscribe error: {e}") + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1410,6 +1483,8 @@ async def push_unsubscribe(request: Request, user: UserContext = Depends(require raise except Exception as e: logger.error(f"Push unsubscribe error: {e}") + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1459,6 +1534,8 @@ async def get_chat_stats(chat_id: int, user: UserContext = Depends(require_auth) return stats except Exception as e: logger.error(f"Error getting chat stats: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1505,6 +1582,8 @@ async def get_message_by_date( raise except Exception as e: logger.error(f"Error finding message by date: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error") @@ -1549,6 +1628,8 @@ async def iter_json(): raise except Exception as e: logger.error(f"Error exporting chat: {e}", exc_info=True) + if _is_db_connection_error(e): + raise HTTPException(status_code=503, detail="Database temporarily unavailable") raise HTTPException(status_code=500, detail="Internal server error")