From 090a6a90b7833dab31bf807948ddfbcf4a37a286 Mon Sep 17 00:00:00 2001 From: Ben Pedigo Date: Thu, 23 Apr 2026 11:53:02 -0700 Subject: [PATCH] more modular base --- src/cave_catalog/routers/assets.py | 84 ++-------- src/cave_catalog/routers/helpers.py | 103 ++++++++++++ tests/test_helpers.py | 236 ++++++++++++++++++++++++++++ 3 files changed, 356 insertions(+), 67 deletions(-) create mode 100644 src/cave_catalog/routers/helpers.py create mode 100644 tests/test_helpers.py diff --git a/src/cave_catalog/routers/assets.py b/src/cave_catalog/routers/assets.py index 5500eaf..f5f65db 100644 --- a/src/cave_catalog/routers/assets.py +++ b/src/cave_catalog/routers/assets.py @@ -7,7 +7,6 @@ from __future__ import annotations import uuid -from datetime import datetime, timezone import structlog from fastapi import APIRouter, Depends, HTTPException, Query, status @@ -20,6 +19,12 @@ from cave_catalog.config import Settings, get_settings from cave_catalog.db.models import Asset from cave_catalog.db.session import get_session +from cave_catalog.routers.helpers import ( + get_asset, + now_utc, + require_asset_view_access, + require_datastack_permission, +) from cave_catalog.schemas import ( AccessResponse, AssetRequest, @@ -49,16 +54,6 @@ def _get_http_client() -> AsyncClient: # --------------------------------------------------------------------------- -def _now_utc() -> datetime: - return datetime.now(timezone.utc) - - -def _asset_is_expired(asset: Asset) -> bool: - if asset.expires_at is None: - return False - return asset.expires_at.replace(tzinfo=timezone.utc) < _now_utc() - - def _asset_to_response(asset: Asset) -> AssetResponse: return AssetResponse.model_validate(asset) @@ -113,12 +108,7 @@ async def register_asset( uri=body.uri, fmt=body.format, ) - # Auth check: user must have write permission on the datastack - if settings.auth.enabled and not user.has_permission(body.datastack, "edit"): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Write permission required on datastack '{body.datastack}'", - ) + require_datastack_permission(user, settings, body.datastack, "edit") # Duplicate check existing = await _find_duplicate( @@ -167,7 +157,7 @@ async def register_asset( maturity=body.maturity.value, properties=body.properties, access_group=body.access_group, - created_at=_now_utc(), + created_at=now_utc(), expires_at=body.expires_at, ) session.add(asset) @@ -271,13 +261,9 @@ async def list_assets( settings: Settings = Depends(get_settings), ) -> list[AssetResponse]: logger.debug("list_assets", datastack=datastack, name=name, mat_version=mat_version) - if settings.auth.enabled and not user.has_permission(datastack, "view"): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Read permission required on datastack '{datastack}'", - ) + require_datastack_permission(user, settings, datastack, "view") - now = _now_utc() + now = now_utc() stmt = select(Asset).where( and_( Asset.datastack == datastack, @@ -310,28 +296,15 @@ async def list_assets( @router.get("/{asset_id}", response_model=AssetResponse) -async def get_asset( +async def get_asset_by_id( asset_id: uuid.UUID, user: AuthUser = Depends(require_auth), session: AsyncSession = Depends(get_session), settings: Settings = Depends(get_settings), ) -> AssetResponse: logger.debug("get_asset", asset_id=str(asset_id)) - asset = await session.get(Asset, asset_id) - if asset is None or _asset_is_expired(asset): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Asset not found" - ) - - if settings.auth.enabled: - required_resource = asset.access_group or asset.datastack - if not user.has_permission(required_resource, "view") and not user.in_group( - required_resource - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Access denied" - ) - + asset = await get_asset(session, asset_id) + require_asset_view_access(user, settings, asset) return _asset_to_response(asset) @@ -348,17 +321,8 @@ async def delete_asset( settings: Settings = Depends(get_settings), ) -> None: logger.debug("delete_asset", asset_id=str(asset_id)) - asset = await session.get(Asset, asset_id) - if asset is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Asset not found" - ) - - if settings.auth.enabled and not user.has_permission(asset.datastack, "edit"): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Write permission required on datastack '{asset.datastack}'", - ) + asset = await get_asset(session, asset_id, check_expired=False) + require_datastack_permission(user, settings, asset.datastack, "edit") await session.delete(asset) await session.commit() @@ -378,22 +342,8 @@ async def get_asset_access( settings: Settings = Depends(get_settings), ) -> AccessResponse: logger.debug("get_asset_access", asset_id=str(asset_id)) - - asset = await session.get(Asset, asset_id) - if asset is None or _asset_is_expired(asset): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Asset not found" - ) - - # Permission gating: consistent with get_asset - if settings.auth.enabled: - required_resource = asset.access_group or asset.datastack - if not user.has_permission(required_resource, "view") and not user.in_group( - required_resource - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Access denied" - ) + asset = await get_asset(session, asset_id) + require_asset_view_access(user, settings, asset) # Unmanaged assets: passthrough (no credentials) if not asset.is_managed: diff --git a/src/cave_catalog/routers/helpers.py b/src/cave_catalog/routers/helpers.py new file mode 100644 index 0000000..9bca655 --- /dev/null +++ b/src/cave_catalog/routers/helpers.py @@ -0,0 +1,103 @@ +"""Shared helpers for router endpoints. + +Reusable building blocks for auth checks and asset lookups that are used +across multiple endpoints. These raise ``HTTPException`` directly so they +belong in the router layer. +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone + +import structlog +from fastapi import HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession + +from cave_catalog.auth.middleware import AuthUser +from cave_catalog.config import Settings +from cave_catalog.db.models import Asset + +logger = structlog.get_logger() + + +# --------------------------------------------------------------------------- +# Time helpers +# --------------------------------------------------------------------------- + + +def now_utc() -> datetime: + return datetime.now(timezone.utc) + + +def asset_is_expired(asset: Asset) -> bool: + if asset.expires_at is None: + return False + return asset.expires_at.replace(tzinfo=timezone.utc) < now_utc() + + +# --------------------------------------------------------------------------- +# Auth helpers +# --------------------------------------------------------------------------- + + +def require_datastack_permission( + user: AuthUser, + settings: Settings, + datastack: str, + permission: str, +) -> None: + """Raise 403 if auth is enabled and *user* lacks *permission* on *datastack*.""" + if not settings.auth.enabled: + return + if user.has_permission(datastack, permission): + return + label = "Write" if permission == "edit" else "Read" + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"{label} permission required on datastack '{datastack}'", + ) + + +def require_asset_view_access( + user: AuthUser, + settings: Settings, + asset: Asset, +) -> None: + """Raise 403 if auth is enabled and *user* can't view *asset*. + + Checks both permission on the asset's access group (or datastack) and + group membership — matching the existing access-control semantics. + """ + if not settings.auth.enabled: + return + required_resource = asset.access_group or asset.datastack + if user.has_permission(required_resource, "view") or user.in_group( + required_resource + ): + return + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied") + + +# --------------------------------------------------------------------------- +# Asset lookup +# --------------------------------------------------------------------------- + + +async def get_asset( + session: AsyncSession, + asset_id: uuid.UUID, + *, + check_expired: bool = True, +) -> Asset: + """Fetch an asset by ID, raising 404 if missing or (optionally) expired.""" + asset = await session.get(Asset, asset_id) + if asset is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Asset not found" + ) + if check_expired and asset_is_expired(asset): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Asset not found" + ) + return asset diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..63aaa9a --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,236 @@ +"""Tests for cave_catalog.routers.helpers.""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from cave_catalog.routers.helpers import ( + asset_is_expired, + get_asset, + now_utc, + require_asset_view_access, + require_datastack_permission, +) + +# --------------------------------------------------------------------------- +# Factories +# --------------------------------------------------------------------------- + + +def _make_settings(auth_enabled: bool = True) -> SimpleNamespace: + return SimpleNamespace(auth=SimpleNamespace(enabled=auth_enabled)) + + +def _make_user( + *, + permissions: dict | None = None, + groups: list | None = None, + is_admin: bool = False, +) -> SimpleNamespace: + from cave_catalog.auth.middleware import AuthUser + + return AuthUser( + user_id=1, + email="test@example.com", + permissions=permissions or {}, + groups=groups or [], + is_admin=is_admin, + ) + + +def _make_asset( + *, + datastack: str = "minnie65", + access_group: str | None = None, + expires_at: datetime | None = None, +) -> SimpleNamespace: + return SimpleNamespace( + id=uuid.uuid4(), + datastack=datastack, + access_group=access_group, + expires_at=expires_at, + ) + + +# --------------------------------------------------------------------------- +# now_utc / asset_is_expired +# --------------------------------------------------------------------------- + + +def test_now_utc_is_aware(): + t = now_utc() + assert t.tzinfo is not None + + +def test_asset_not_expired_when_no_expiry(): + asset = _make_asset() + assert not asset_is_expired(asset) + + +def test_asset_expired_in_past(): + asset = _make_asset(expires_at=datetime(2020, 1, 1, tzinfo=timezone.utc)) + assert asset_is_expired(asset) + + +def test_asset_not_expired_in_future(): + future = now_utc() + timedelta(days=1) + asset = _make_asset(expires_at=future) + assert not asset_is_expired(asset) + + +# --------------------------------------------------------------------------- +# require_datastack_permission +# --------------------------------------------------------------------------- + + +def test_permission_passes_when_auth_disabled(): + user = _make_user() + settings = _make_settings(auth_enabled=False) + # Should not raise + require_datastack_permission(user, settings, "any_ds", "edit") + + +def test_permission_passes_when_user_has_perm(): + user = _make_user(permissions={"ds1": ["edit"]}) + settings = _make_settings() + require_datastack_permission(user, settings, "ds1", "edit") + + +def test_permission_raises_403_on_missing_edit(): + from fastapi import HTTPException + + user = _make_user(permissions={"ds1": ["view"]}) + settings = _make_settings() + with pytest.raises(HTTPException) as exc_info: + require_datastack_permission(user, settings, "ds1", "edit") + assert exc_info.value.status_code == 403 + assert "Write permission" in exc_info.value.detail + + +def test_permission_raises_403_on_missing_view(): + from fastapi import HTTPException + + user = _make_user(permissions={}) + settings = _make_settings() + with pytest.raises(HTTPException) as exc_info: + require_datastack_permission(user, settings, "ds1", "view") + assert exc_info.value.status_code == 403 + assert "Read permission" in exc_info.value.detail + + +def test_permission_passes_for_admin(): + user = _make_user(is_admin=True) + settings = _make_settings() + require_datastack_permission(user, settings, "any_ds", "edit") + + +# --------------------------------------------------------------------------- +# require_asset_access +# --------------------------------------------------------------------------- + + +def test_asset_access_passes_when_auth_disabled(): + user = _make_user() + settings = _make_settings(auth_enabled=False) + asset = _make_asset() + require_asset_view_access(user, settings, asset) + + +def test_asset_access_passes_with_permission_on_datastack(): + user = _make_user(permissions={"minnie65": ["view"]}) + settings = _make_settings() + asset = _make_asset(datastack="minnie65") + require_asset_view_access(user, settings, asset) + + +def test_asset_access_passes_with_permission_on_access_group(): + user = _make_user(permissions={"my_group": ["view"]}) + settings = _make_settings() + asset = _make_asset(datastack="minnie65", access_group="my_group") + require_asset_view_access(user, settings, asset) + + +def test_asset_access_passes_with_group_membership(): + user = _make_user(groups=["minnie65"]) + settings = _make_settings() + asset = _make_asset(datastack="minnie65") + require_asset_view_access(user, settings, asset) + + +def test_asset_access_raises_403_when_denied(): + from fastapi import HTTPException + + user = _make_user(permissions={}, groups=[]) + settings = _make_settings() + asset = _make_asset(datastack="minnie65") + with pytest.raises(HTTPException) as exc_info: + require_asset_view_access(user, settings, asset) + assert exc_info.value.status_code == 403 + assert "Access denied" in exc_info.value.detail + + +def test_asset_access_passes_for_admin(): + user = _make_user(is_admin=True) + settings = _make_settings() + asset = _make_asset(datastack="minnie65") + require_asset_view_access(user, settings, asset) + + +# --------------------------------------------------------------------------- +# get_asset +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_asset_returns_asset(): + from cave_catalog.db.models import Asset + + asset_id = uuid.uuid4() + mock_asset = _make_asset() + mock_asset.expires_at = None + + session = AsyncMock() + session.get = AsyncMock(return_value=mock_asset) + + result = await get_asset(session, asset_id) + assert result is mock_asset + session.get.assert_awaited_once_with(Asset, asset_id) + + +@pytest.mark.asyncio +async def test_get_asset_raises_404_when_missing(): + from fastapi import HTTPException + + session = AsyncMock() + session.get = AsyncMock(return_value=None) + + with pytest.raises(HTTPException) as exc_info: + await get_asset(session, uuid.uuid4()) + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_asset_raises_404_when_expired(): + from fastapi import HTTPException + + expired = _make_asset(expires_at=datetime(2020, 1, 1, tzinfo=timezone.utc)) + session = AsyncMock() + session.get = AsyncMock(return_value=expired) + + with pytest.raises(HTTPException) as exc_info: + await get_asset(session, uuid.uuid4()) + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_get_asset_ignores_expiry_when_check_expired_false(): + expired = _make_asset(expires_at=datetime(2020, 1, 1, tzinfo=timezone.utc)) + session = AsyncMock() + session.get = AsyncMock(return_value=expired) + + result = await get_asset(session, uuid.uuid4(), check_expired=False) + assert result is expired