diff --git a/CHANGES.md b/CHANGES.md index 5668356..11a6d8e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,11 @@ ## [Unreleased] +### Changed + +- update `stac-fastapi-*` version requirements to `>=5.2,<6.0` +- add pgstac health-check in `/_mgmt/health` + ## [5.0.2] - 2025-04-07 ### Fixed diff --git a/setup.py b/setup.py index 2183a6f..f5e7eb7 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ "orjson", "pydantic", "stac_pydantic==3.1.*", - "stac-fastapi.api>=5.1,<6.0", - "stac-fastapi.extensions>=5.1,<6.0", - "stac-fastapi.types>=5.1,<6.0", + "stac-fastapi.api>=5.2,<6.0", + "stac-fastapi.extensions>=5.2,<6.0", + "stac-fastapi.types>=5.2,<6.0", "asyncpg", "buildpg", "brotli_asgi", diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index fcc1288..5e9dd13 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -10,12 +10,12 @@ from brotli_asgi import BrotliMiddleware from fastapi import FastAPI -from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware from stac_fastapi.api.models import ( EmptyRequest, ItemCollectionUri, + JSONResponse, create_get_request_model, create_post_request_model, create_request_model, @@ -40,7 +40,7 @@ from starlette.middleware import Middleware from stac_fastapi.pgstac.config import Settings -from stac_fastapi.pgstac.core import CoreCrudClient +from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from stac_fastapi.pgstac.extensions import QueryExtension from stac_fastapi.pgstac.extensions.filter import FiltersClient @@ -54,7 +54,7 @@ "transaction": TransactionExtension( client=TransactionsClient(), settings=settings, - response_class=ORJSONResponse, + response_class=JSONResponse, ), "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } @@ -174,7 +174,7 @@ async def lifespan(app: FastAPI): settings=settings, extensions=application_extensions, client=CoreCrudClient(pgstac_search_model=post_request_model), - response_class=ORJSONResponse, + response_class=JSONResponse, items_get_request_model=items_get_request_model, search_get_request_model=get_request_model, search_post_request_model=post_request_model, @@ -188,6 +188,7 @@ async def lifespan(app: FastAPI): allow_methods=settings.cors_methods, ), ], + health_check=health_check, ) app = api.app diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 582b455..aab9625 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -605,3 +605,49 @@ def _clean_search_args( # noqa: C901 clean[k] = v return clean + + +async def health_check(request: Request) -> Union[Dict, JSONResponse]: + """PgSTAC HealthCheck.""" + resp = { + "status": "UP", + "lifespan": { + "status": "UP", + }, + } + if not hasattr(request.app.state, "get_connection"): + return JSONResponse( + status_code=503, + content={ + "status": "DOWN", + "lifespan": { + "status": "DOWN", + "message": "application lifespan wasn't run", + }, + "pgstac": { + "status": "DOWN", + "message": "Could not connect to database", + }, + }, + ) + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """SELECT pgstac.get_version();""", + ) + version = await conn.fetchval(q, *p) + except Exception as e: + resp["status"] = "DOWN" + resp["pgstac"] = { + "status": "DOWN", + "message": str(e), + } + return JSONResponse(status_code=503, content=resp) + + resp["pgstac"] = { + "status": "UP", + "pgstac_version": version, + } + + return resp diff --git a/tests/conftest.py b/tests/conftest.py index 7944e8d..052f260 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,7 +42,7 @@ from stac_pydantic import Collection, Item from stac_fastapi.pgstac.config import PostgresSettings, Settings -from stac_fastapi.pgstac.core import CoreCrudClient +from stac_fastapi.pgstac.core import CoreCrudClient, health_check from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from stac_fastapi.pgstac.extensions import QueryExtension from stac_fastapi.pgstac.extensions.filter import FiltersClient @@ -191,6 +191,7 @@ def api_client(request): collections_get_request_model=collection_search_extension.GET, response_class=ORJSONResponse, router=APIRouter(prefix=prefix), + health_check=health_check, ) return api @@ -302,6 +303,7 @@ def api_client_no_ext(): TransactionExtension(client=TransactionsClient(), settings=api_settings) ], client=CoreCrudClient(), + health_check=health_check, ) diff --git a/tests/resources/test_mgmt.py b/tests/resources/test_mgmt.py index 9d2bc3d..147966a 100644 --- a/tests/resources/test_mgmt.py +++ b/tests/resources/test_mgmt.py @@ -1,3 +1,11 @@ +from httpx import ASGITransport, AsyncClient +from stac_fastapi.api.app import StacApi + +from stac_fastapi.pgstac.config import PostgresSettings, Settings +from stac_fastapi.pgstac.core import CoreCrudClient, health_check +from stac_fastapi.pgstac.db import close_db_connection, connect_to_db + + async def test_ping_no_param(app_client): """ Test ping endpoint with a mocked client. @@ -7,3 +15,67 @@ async def test_ping_no_param(app_client): res = await app_client.get("/_mgmt/ping") assert res.status_code == 200 assert res.json() == {"message": "PONG"} + + +async def test_health(app_client): + """ + Test health endpoint + + Args: + app_client (TestClient): mocked client fixture + + """ + res = await app_client.get("/_mgmt/health") + assert res.status_code == 200 + body = res.json() + assert body["status"] == "UP" + assert body["pgstac"]["status"] == "UP" + assert body["pgstac"]["pgstac_version"] + + +async def test_health_503(database): + """Test health endpoint error.""" + + # No lifespan so no `get_connection` is application state + api = StacApi( + settings=Settings(testing=True), + extensions=[], + client=CoreCrudClient(), + health_check=health_check, + ) + + async with AsyncClient( + transport=ASGITransport(app=api.app), base_url="http://test" + ) as client: + res = await client.get("/_mgmt/health") + assert res.status_code == 503 + body = res.json() + assert body["status"] == "DOWN" + assert body["lifespan"]["status"] == "DOWN" + assert body["lifespan"]["message"] == "application lifespan wasn't run" + assert body["pgstac"]["status"] == "DOWN" + assert body["pgstac"]["message"] == "Could not connect to database" + + # No lifespan so no `get_connection` is application state + postgres_settings = PostgresSettings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + ) + # Create connection pool but close it just after + await connect_to_db(api.app, postgres_settings=postgres_settings) + await close_db_connection(api.app) + + async with AsyncClient( + transport=ASGITransport(app=api.app), base_url="http://test" + ) as client: + res = await client.get("/_mgmt/health") + assert res.status_code == 503 + body = res.json() + assert body["status"] == "DOWN" + assert body["lifespan"]["status"] == "UP" + assert body["pgstac"]["status"] == "DOWN" + assert body["pgstac"]["message"] == "pool is closed"