diff --git a/pyproject.toml b/pyproject.toml index c15ec1c..8d1c027 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,14 +2,21 @@ dependencies = [ "alembic>=1.13", "asyncpg>=0.29", + "cachetools>=7.0.5", + "caveclient", "cloudpathlib[gs,s3]>=0.23.0", "deltalake>=0.17", "fastapi>=0.115.0", + "fsspec>=2026.3.0", + "gcsfs>=2026.3.0", "google-auth>=2.0", "httpx>=0.27", + "jinja2>=3.1.6", "polars>=1.0", + "pyarrow>=23.0.1", "pydantic-settings>=2.3", "pydantic>=2.7", + "python-multipart>=0.0.26", "requests>=2.31", "sqlalchemy[asyncio]>=2.0", "structlog>=24.2", @@ -48,3 +55,6 @@ testpaths = ["tests"] disallow_untyped_defs = true ignore_missing_imports = true packages = ["cave_catalog"] + +[tool.uv.sources] +caveclient = { workspace = true } diff --git a/src/cave_catalog/app.py b/src/cave_catalog/app.py index 8568071..3332d96 100644 --- a/src/cave_catalog/app.py +++ b/src/cave_catalog/app.py @@ -1,14 +1,18 @@ import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from pathlib import Path import structlog -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.staticfiles import StaticFiles from cave_catalog.config import get_settings from cave_catalog.db.session import get_engine -from cave_catalog.routers import assets, health +from cave_catalog.routers import assets, health, tables, ui +from cave_catalog.routers.ui import _RedirectException logger = structlog.get_logger() @@ -52,5 +56,14 @@ def create_app() -> FastAPI: app.include_router(health.router) app.include_router(assets.router) + app.include_router(tables.router) + app.include_router(ui.router) + + @app.exception_handler(_RedirectException) + async def _handle_redirect(request: Request, exc: _RedirectException): + return RedirectResponse(url=exc.url, status_code=302) + + _pkg_dir = Path(__file__).resolve().parent + app.mount("/static", 
StaticFiles(directory=_pkg_dir / "static"), name="static") return app diff --git a/src/cave_catalog/auth/middleware.py b/src/cave_catalog/auth/middleware.py index 13bcdd0..88ba03a 100644 --- a/src/cave_catalog/auth/middleware.py +++ b/src/cave_catalog/auth/middleware.py @@ -196,8 +196,11 @@ def create_token_cookie_response(redirect_url: str, token: str) -> Response: def get_authorize_url(settings: Settings, redirect_url: str) -> str: + # AUTH_SERVICE_URL is e.g. "https://globalv1.daf-apis.com/auth" (for API calls) + # but the OAuth authorize endpoint lives on the sticky auth app at /sticky_auth/api/v1/authorize auth_url = settings.auth.service_url.rstrip("/") - return f"{auth_url}/api/v1/authorize?redirect={quote(redirect_url)}" + base_url = auth_url.removesuffix("/auth") + return f"{base_url}/sticky_auth/api/v1/authorize?redirect={quote(redirect_url)}" async def get_current_user( diff --git a/src/cave_catalog/config.py b/src/cave_catalog/config.py index f491d27..148a0ac 100644 --- a/src/cave_catalog/config.py +++ b/src/cave_catalog/config.py @@ -1,3 +1,4 @@ +import json from functools import lru_cache from pydantic import Field @@ -24,8 +25,22 @@ class Settings(BaseSettings): service_name: str = Field(default="cave-catalog", alias="SERVICE_NAME") mat_engine_url: str | None = Field(default=None, alias="MAT_ENGINE_URL") log_level: str = Field(default="INFO", alias="LOG_LEVEL") + datastacks_raw: str = Field(default="", alias="DATASTACKS") + cave_token: str | None = Field(default=None, alias="CAVE_TOKEN") + caveclient_server_address: str | None = Field( + default=None, alias="CAVECLIENT_SERVER_ADDRESS" + ) auth: AuthSettings = Field(default_factory=AuthSettings) + @property + def datastacks(self) -> list[str]: + raw = self.datastacks_raw.strip() + if not raw: + return [] + if raw.startswith("["): + return json.loads(raw) + return [s.strip() for s in raw.split(",") if s.strip()] + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", diff --git 
a/src/cave_catalog/db/models.py b/src/cave_catalog/db/models.py index a9b0f50..a6646ea 100644 --- a/src/cave_catalog/db/models.py +++ b/src/cave_catalog/db/models.py @@ -1,4 +1,5 @@ import uuid +from collections.abc import MutableMapping from datetime import datetime, timezone from sqlalchemy import JSON, Boolean, DateTime, Index, Integer, String, text @@ -6,6 +7,13 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +class _FallbackPolymorphicMap(dict, MutableMapping): + """Allow unknown ``asset_type`` values to load as the base ``Asset`` class.""" + + def __missing__(self, key): + return self["asset"] + + class Base(DeclarativeBase): pass @@ -23,7 +31,7 @@ class Asset(Base): mat_version: Mapped[int | None] = mapped_column(Integer, nullable=True) revision: Mapped[int] = mapped_column(Integer, nullable=False, default=0) uri: Mapped[str] = mapped_column(String, nullable=False) - format: Mapped[str] = mapped_column(String, nullable=False) + format: Mapped[str | None] = mapped_column(String, nullable=True) asset_type: Mapped[str] = mapped_column(String, nullable=False) owner: Mapped[int] = mapped_column(Integer, nullable=False) is_managed: Mapped[bool] = mapped_column(Boolean, nullable=False) @@ -40,6 +48,20 @@ class Asset(Base): DateTime(timezone=True), nullable=True ) + __mapper_args__ = { + "polymorphic_on": "asset_type", + "polymorphic_identity": "asset", + "with_polymorphic": "*", + } + + # Table-specific nullable columns (populated only for asset_type="table") + source: Mapped[str | None] = mapped_column(String, nullable=True) + cached_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True) + metadata_cached_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + column_annotations: Mapped[list | None] = mapped_column(JSON, nullable=True) + __table_args__ = ( # Uniqueness when mat_version is present Index( @@ -61,3 +83,17 @@ class Asset(Base): postgresql_where=text("mat_version IS NULL"), ), ) + + 
+# Install fallback polymorphic map so unknown asset_type values load as Asset +Asset.__mapper__.polymorphic_map = _FallbackPolymorphicMap( + Asset.__mapper__.polymorphic_map +) + + +class Table(Asset): + """Table asset — single table inheritance subclass of Asset.""" + + __mapper_args__ = { + "polymorphic_identity": "table", + } diff --git a/src/cave_catalog/extractors.py b/src/cave_catalog/extractors.py new file mode 100644 index 0000000..acdc4fb --- /dev/null +++ b/src/cave_catalog/extractors.py @@ -0,0 +1,203 @@ +"""Metadata extractors for table assets. + +Each extractor reads lightweight metadata (schema, row count, size, partition +info) from a cloud storage URI and returns a ``TableMetadata`` instance. +""" + +from __future__ import annotations + +import abc +import asyncio +from typing import Any + +import structlog + +from cave_catalog.table_schemas import ColumnInfo, TableMetadata + +logger = structlog.get_logger() + + +# --------------------------------------------------------------------------- +# Base interface (task 2.1) +# --------------------------------------------------------------------------- + + +class MetadataExtractor(abc.ABC): + """Base interface for format-specific metadata extractors.""" + + @abc.abstractmethod + async def extract( + self, + uri: str, + storage_options: dict[str, Any] | None = None, + ) -> TableMetadata: + """Extract metadata from the table at *uri*. + + Parameters + ---------- + uri + Cloud or local path to the table/file. + storage_options + Optional storage credentials (e.g. GCS token dict). + + Returns + ------- + TableMetadata + Discovered metadata. 
+ """ + + +# --------------------------------------------------------------------------- +# Delta Lake extractor (task 2.2) +# --------------------------------------------------------------------------- + + +class DeltaMetadataExtractor(MetadataExtractor): + """Extract metadata from a Delta Lake table via the transaction log.""" + + async def extract( + self, + uri: str, + storage_options: dict[str, Any] | None = None, + ) -> TableMetadata: + from deltalake import DeltaTable + + logger.debug("delta_extract_start", uri=uri) + + kwargs: dict[str, Any] = {} + if storage_options: + kwargs["storage_options"] = storage_options + + try: + dt = await asyncio.to_thread(lambda: DeltaTable(uri, **kwargs)) + except Exception as exc: + msg = str(exc) + if "no files in log segment" in msg.lower() or "log segment" in msg.lower(): + raise ValueError( + f"No Delta transaction log found at '{uri}'. " + "This path may not contain a Delta Lake table." + ) from exc + raise + + schema = await asyncio.to_thread(lambda: dt.schema()) + columns = [ + ColumnInfo(name=field.name, dtype=str(field.type)) + for field in schema.fields + ] + + metadata = await asyncio.to_thread(lambda: dt.metadata()) + partition_columns = list(metadata.partition_columns) + + n_rows: int | None = None + n_bytes: int | None = None + + # Try to get row count and size from file stats + try: + actions_table = await asyncio.to_thread( + lambda: dt.get_add_actions(flatten=True) + ) + # Convert arro3 Table to dict-of-lists via pyarrow + import pyarrow as pa + + file_actions = pa.table(actions_table).to_pydict() + if "num_records" in file_actions: + row_counts = file_actions["num_records"] + if all(r is not None for r in row_counts): + n_rows = sum(row_counts) + if "size_bytes" in file_actions: + sizes = file_actions["size_bytes"] + if all(s is not None for s in sizes): + n_bytes = sum(sizes) + except Exception: + logger.debug("delta_stats_unavailable", uri=uri) + + return TableMetadata( + n_rows=n_rows, + 
n_columns=len(columns), + n_bytes=n_bytes, + columns=columns, + partition_columns=partition_columns, + ) + + +# --------------------------------------------------------------------------- +# Parquet extractor (task 2.3) +# --------------------------------------------------------------------------- + + +class ParquetMetadataExtractor(MetadataExtractor): + """Extract metadata from a Parquet file/dataset via polars.""" + + async def extract( + self, + uri: str, + storage_options: dict[str, Any] | None = None, + ) -> TableMetadata: + import polars as pl + + logger.debug("parquet_extract_start", uri=uri) + + kwargs: dict[str, Any] = {} + if storage_options: + kwargs["storage_options"] = storage_options + + lf = await asyncio.to_thread( + lambda: pl.scan_parquet(uri, **kwargs) + ) + schema = await asyncio.to_thread(lambda: lf.collect_schema()) + columns = [ + ColumnInfo(name=name, dtype=str(dtype)) + for name, dtype in schema.items() + ] + + # Get on-disk size via fsspec and row count from parquet metadata + n_rows: int | None = None + n_bytes: int | None = None + try: + import fsspec + import pyarrow.parquet as pq + + fs, path = await asyncio.to_thread( + lambda: fsspec.core.url_to_fs(uri, **(storage_options or {})) + ) + info = await asyncio.to_thread(lambda: fs.info(path)) + n_bytes = info.get("size") + + pq_meta = await asyncio.to_thread( + lambda: pq.read_metadata(path, filesystem=fs) + ) + n_rows = pq_meta.num_rows + except Exception as exc: + logger.warning("parquet_stats_unavailable", uri=uri, error=str(exc)) + + return TableMetadata( + n_rows=n_rows, + n_columns=len(columns), + n_bytes=n_bytes, + columns=columns, + partition_columns=[], + ) + + +# --------------------------------------------------------------------------- +# Extractor registry (task 2.4) +# --------------------------------------------------------------------------- + +EXTRACTORS: dict[str, MetadataExtractor] = { + "delta": DeltaMetadataExtractor(), + "parquet": ParquetMetadataExtractor(), +} + + 
+def get_extractor(fmt: str) -> MetadataExtractor: + """Look up extractor by format string. + + Raises ``ValueError`` if no extractor is registered for *fmt*. + """ + try: + return EXTRACTORS[fmt.lower()] + except KeyError: + raise ValueError( + f"No metadata extractor for format '{fmt}'. " + f"Supported: {', '.join(sorted(EXTRACTORS))}" + ) diff --git a/src/cave_catalog/mat_proxy.py b/src/cave_catalog/mat_proxy.py new file mode 100644 index 0000000..af66126 --- /dev/null +++ b/src/cave_catalog/mat_proxy.py @@ -0,0 +1,184 @@ +"""Materialization service proxy — cached CAVEclient queries for reference data.""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass + +from cachetools import TTLCache +from caveclient import CAVEclient + +from cave_catalog.config import get_settings + +logger = logging.getLogger(__name__) + +# Cache configuration +_CACHE_TTL = 300 # 5 minutes +_CACHE_MAXSIZE = 256 + +# Caches keyed by (datastack, version) or (datastack, version, target_name) +_tables_cache: TTLCache[tuple[str, int | None], list[str]] = TTLCache( + maxsize=_CACHE_MAXSIZE, ttl=_CACHE_TTL +) +_views_cache: TTLCache[tuple[str, int | None], list[str]] = TTLCache( + maxsize=_CACHE_MAXSIZE, ttl=_CACHE_TTL +) +_columns_cache: TTLCache[tuple[str, int | None, str, str], list[dict]] = TTLCache( + maxsize=_CACHE_MAXSIZE, ttl=_CACHE_TTL +) + + +class MatProxyError(Exception): + """Raised when a materialization proxy operation fails.""" + + +@dataclass +class LinkableTarget: + name: str + target_type: str # "table" or "view" + + +def _get_cave_client(datastack: str, version: int | None = None) -> CAVEclient: + """Create a CAVEclient instance with the service token.""" + settings = get_settings() + if not settings.cave_token: + raise MatProxyError( + "CAVE_TOKEN is not configured. Cannot query materialization service." 
+ ) + kwargs: dict = { + "datastack_name": datastack, + "auth_token": settings.cave_token, + } + if settings.caveclient_server_address: + kwargs["server_address"] = settings.caveclient_server_address + if version is not None: + kwargs["version"] = version + return CAVEclient(**kwargs) + + +def _sync_get_tables(datastack: str, version: int | None = None) -> list[str]: + """Synchronous: fetch table list via CAVEclient.""" + client = _get_cave_client(datastack, version) + return client.materialize.get_tables() + + +def _sync_get_views(datastack: str, version: int | None = None) -> list[str]: + """Synchronous: fetch view list via CAVEclient.""" + client = _get_cave_client(datastack, version) + return client.materialize.get_views() + + +def _sync_get_table_columns( + datastack: str, table_name: str, version: int | None = None +) -> list[dict]: + """Synchronous: resolve columns for a materialization table. + + Path: get_table_metadata() → schema_type → schema_definition() → columns. + """ + client = _get_cave_client(datastack, version) + metadata = client.materialize.get_table_metadata(table_name) + schema_type = metadata.get("schema_type") or metadata.get("schema") + if not schema_type: + raise MatProxyError( + f"Could not determine schema type for table '{table_name}'" + ) + schema_def = client.schema.schema_definition(schema_type) + # schema_def is a JSON Schema; resolve top-level $ref then read "properties" + resolved = schema_def + ref = schema_def.get("$ref", "") + if ref.startswith("#/definitions/"): + def_name = ref.split("/")[-1] + resolved = schema_def.get("definitions", {}).get(def_name, schema_def) + properties = resolved.get("properties", resolved) + columns = [] + for col_name, col_info in properties.items(): + columns.append({"name": col_name, "type": str(col_info)}) + return columns + + +def _sync_get_view_columns( + datastack: str, view_name: str, version: int | None = None +) -> list[dict]: + """Synchronous: resolve columns for a materialization 
view.""" + client = _get_cave_client(datastack, version) + schema = client.materialize.get_view_schema(view_name) + # schema is a dict with column names → type info + columns = [] + for col_name, col_info in schema.items(): + columns.append({"name": col_name, "type": str(col_info)}) + return columns + + +async def get_mat_tables(datastack: str, version: int | None = None) -> list[str]: + """Get materialization tables for a datastack (cached).""" + cache_key = (datastack, version) + if cache_key in _tables_cache: + return _tables_cache[cache_key] + try: + tables = await asyncio.to_thread(_sync_get_tables, datastack, version) + except MatProxyError: + raise + except Exception as e: + logger.exception("Failed to fetch mat tables for %s", datastack) + raise MatProxyError(f"Failed to fetch tables: {e}") from e + _tables_cache[cache_key] = tables + return tables + + +async def get_mat_views(datastack: str, version: int | None = None) -> list[str]: + """Get materialization views for a datastack (cached).""" + cache_key = (datastack, version) + if cache_key in _views_cache: + return _views_cache[cache_key] + try: + views = await asyncio.to_thread(_sync_get_views, datastack, version) + except MatProxyError: + raise + except Exception as e: + logger.exception("Failed to fetch mat views for %s", datastack) + raise MatProxyError(f"Failed to fetch views: {e}") from e + _views_cache[cache_key] = views + return views + + +async def get_linkable_targets( + datastack: str, version: int | None = None +) -> list[LinkableTarget]: + """Get combined list of tables and views as linkable targets.""" + tables = await get_mat_tables(datastack, version) + views = await get_mat_views(datastack, version) + targets = [LinkableTarget(name=t, target_type="table") for t in tables] + targets += [LinkableTarget(name=v, target_type="view") for v in views] + targets.sort(key=lambda t: t.name) + return targets + + +async def get_target_columns( + datastack: str, + target_name: str, + target_type: str, 
+ version: int | None = None, +) -> list[dict]: + """Get columns for a linkable target (table or view), cached.""" + cache_key = (datastack, version, target_name, target_type) + if cache_key in _columns_cache: + return _columns_cache[cache_key] + try: + if target_type == "view": + columns = await asyncio.to_thread( + _sync_get_view_columns, datastack, target_name, version + ) + else: + columns = await asyncio.to_thread( + _sync_get_table_columns, datastack, target_name, version + ) + except MatProxyError: + raise + except Exception as e: + logger.exception( + "Failed to fetch columns for %s/%s", datastack, target_name + ) + raise MatProxyError(f"Failed to fetch columns for '{target_name}': {e}") from e + _columns_cache[cache_key] = columns + return columns diff --git a/src/cave_catalog/routers/assets.py b/src/cave_catalog/routers/assets.py index f5f65db..787a032 100644 --- a/src/cave_catalog/routers/assets.py +++ b/src/cave_catalog/routers/assets.py @@ -10,7 +10,6 @@ import structlog from fastapi import APIRouter, Depends, HTTPException, Query, status -from httpx import AsyncClient from sqlalchemy import and_, or_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -20,8 +19,12 @@ from cave_catalog.db.models import Asset from cave_catalog.db.session import get_session from cave_catalog.routers.helpers import ( + asset_to_response, + find_duplicate, get_asset, + get_http_client, now_utc, + raise_if_validation_failed, require_asset_view_access, require_datastack_permission, ) @@ -32,60 +35,14 @@ ValidationCheck, ValidationReport, ) +from cave_catalog.table_schemas import TableResponse from cave_catalog.validation import run_validation_pipeline +from cave_catalog.validation import check_name_reservation as _check_name_reservation logger = structlog.get_logger() router = APIRouter(prefix="/api/v1/assets", tags=["assets"]) -# Shared httpx client (module-level singleton; fine for service lifetime) -_http_client: 
AsyncClient | None = None - - -def _get_http_client() -> AsyncClient: - global _http_client - if _http_client is None: - _http_client = AsyncClient() - return _http_client - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _asset_to_response(asset: Asset) -> AssetResponse: - return AssetResponse.model_validate(asset) - - -async def _find_duplicate( - session: AsyncSession, - datastack: str, - name: str, - mat_version: int | None, - revision: int, -) -> Asset | None: - if mat_version is not None: - stmt = select(Asset).where( - and_( - Asset.datastack == datastack, - Asset.name == name, - Asset.mat_version == mat_version, - Asset.revision == revision, - ) - ) - else: - stmt = select(Asset).where( - and_( - Asset.datastack == datastack, - Asset.name == name, - Asset.mat_version.is_(None), - Asset.revision == revision, - ) - ) - result = await session.execute(stmt) - return result.scalar_one_or_none() - # --------------------------------------------------------------------------- # POST /api/v1/assets/register @@ -111,7 +68,7 @@ async def register_asset( require_datastack_permission(user, settings, body.datastack, "edit") # Duplicate check - existing = await _find_duplicate( + existing = await find_duplicate( session, body.datastack, body.name, body.mat_version, body.revision ) if existing is not None: @@ -127,20 +84,10 @@ async def register_asset( uri=body.uri, fmt=body.format, properties=body.properties, - client=_get_http_client(), + client=get_http_client(), token=user.token, ) - - failures = { - k: v - for k, v in report.model_dump().items() - if v is not None and not v.get("passed", True) - } - if failures: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail={"message": "Validation failed", "checks": failures}, - ) + raise_if_validation_failed(report) asset = Asset( id=uuid.uuid4(), @@ -167,7 +114,7 @@ 
@router.get("/check-name")
async def check_name(
    datastack: str = Query(...),
    name: str = Query(...),
    mat_version: int | None = Query(default=None),
    revision: int = Query(default=0),
    user: AuthUser = Depends(require_auth),
    session: AsyncSession = Depends(get_session),
    settings: Settings = Depends(get_settings),
) -> dict:
    """Report whether *name* can still be registered in *datastack*.

    The caller must hold ``view`` permission on the datastack, like the other
    read endpoints of this router; without that check any authenticated user
    could enumerate asset names and IDs of datastacks they cannot see.

    Two checks then run in order and the first failure decides the answer:

    1. name reservation, via ``_check_name_reservation``;
    2. an existing asset with the same
       ``(datastack, name, mat_version, revision)``, via ``find_duplicate``.

    Returns:
        ``{"available": True}`` when both checks pass, otherwise
        ``{"available": False, "reason": "reserved"}`` or
        ``{"available": False, "reason": "duplicate", "existing_id": "<id>"}``.

    Raises:
        Whatever ``require_datastack_permission`` raises when the caller
        lacks ``view`` permission on *datastack*.
    """
    # Authorize before touching the mat service or the database.
    require_datastack_permission(user, settings, datastack, "view")

    # 1. Check name reservation against mat tables
    reservation = await _check_name_reservation(
        datastack=datastack,
        name=name,
        is_mat_source=False,
        client=get_http_client(),
        token=user.token,
    )
    if not reservation.passed:
        return {"available": False, "reason": "reserved"}

    # 2. Check for duplicate asset in DB
    existing = await find_duplicate(session, datastack, name, mat_version, revision)
    if existing is not None:
        return {
            "available": False,
            "reason": "duplicate",
            "existing_id": str(existing.id),
        }

    return {"available": True}
async def find_duplicate(
    session: AsyncSession,
    datastack: str,
    name: str,
    mat_version: int | None,
    revision: int,
) -> Asset | None:
    """Look up the asset occupying a ``(datastack, name, mat_version, revision)`` slot.

    Args:
        session: Async SQLAlchemy session used to run the query.
        datastack: Datastack the asset belongs to.
        name: Asset name.
        mat_version: Materialization version, or ``None`` for an unversioned
            asset.
        revision: Asset revision.

    Returns:
        The matching ``Asset``, or ``None`` when the slot is free.
    """
    # A missing version has to be matched with IS NULL rather than "= NULL",
    # so only this one predicate depends on whether mat_version was given.
    if mat_version is None:
        version_clause = Asset.mat_version.is_(None)
    else:
        version_clause = Asset.mat_version == mat_version

    query = select(Asset).where(
        and_(
            Asset.datastack == datastack,
            Asset.name == name,
            version_clause,
            Asset.revision == revision,
        )
    )
    rows = await session.execute(query)
    return rows.scalar_one_or_none()
owner=table.owner, + is_managed=table.is_managed, + mutability=table.mutability, + maturity=table.maturity, + properties=table.properties, + access_group=table.access_group, + created_at=table.created_at, + expires_at=table.expires_at, + source=table.source, + cached_metadata=metadata, + metadata_cached_at=table.metadata_cached_at, + column_annotations=annotations, + columns=columns, + ) + + +def asset_to_response(asset: Asset) -> AssetResponse | TableResponse: + """Build the correct response model based on the asset's type.""" + if isinstance(asset, Table): + return table_to_response(asset) + return AssetResponse.model_validate(asset) diff --git a/src/cave_catalog/routers/tables.py b/src/cave_catalog/routers/tables.py new file mode 100644 index 0000000..f2a105a --- /dev/null +++ b/src/cave_catalog/routers/tables.py @@ -0,0 +1,372 @@ +"""Table-specific endpoints. + +Handles table registration, preview, annotation updates, metadata refresh, +and table-specific listing. +""" + +from __future__ import annotations + +import uuid + +import structlog +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy import and_, or_, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from cave_catalog.auth.middleware import AuthUser, require_auth +from cave_catalog.config import Settings, get_settings +from cave_catalog.db.models import Table +from cave_catalog.db.session import get_session +from cave_catalog.extractors import get_extractor +from cave_catalog.routers.helpers import ( + find_duplicate, + get_asset, + get_http_client, + now_utc, + raise_if_validation_failed, + require_datastack_permission, + table_to_response, +) +from cave_catalog.table_schemas import ( + AnnotationUpdateRequest, + TablePreviewRequest, + TablePreviewResponse, + TableRequest, + TableResponse, +) +from cave_catalog.validation import run_validation_pipeline, validate_column_links + +logger = structlog.get_logger() + 
@router.post(
    "/register", response_model=TableResponse, status_code=status.HTTP_201_CREATED
)
async def register_table(
    body: TableRequest,
    user: AuthUser = Depends(require_auth),
    session: AsyncSession = Depends(get_session),
    settings: Settings = Depends(get_settings),
) -> TableResponse:
    """Register a new table asset and cache its extracted metadata.

    Steps, in order (each one aborts the request on failure):

    1. require ``edit`` permission on ``body.datastack``;
    2. reject an already-registered
       ``(datastack, name, mat_version, revision)`` with 409;
    3. run the content validation pipeline (422 on any failed check);
    4. validate column links when annotations were supplied (422);
    5. resolve the extractor for ``body.format`` and extract metadata (422);
    6. insert the row; a concurrent insert of the same key surfaces as 409.

    Returns:
        The stored table, including columns merged from metadata and
        annotations (see ``table_to_response``).

    Raises:
        HTTPException: 409 for duplicates, 422 for validation, unsupported
            format or metadata extraction failures.
    """
    logger.debug(
        "register_table",
        datastack=body.datastack,
        name=body.name,
        uri=body.uri,
        format=body.format,
    )
    require_datastack_permission(user, settings, body.datastack, "edit")

    # Duplicate check
    existing = await find_duplicate(
        session, body.datastack, body.name, body.mat_version, body.revision
    )
    if existing is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={"message": "Table already exists", "existing_id": str(existing.id)},
        )

    # Content validation pipeline
    report = await run_validation_pipeline(
        datastack=body.datastack,
        name=body.name,
        uri=body.uri,
        fmt=body.format,
        properties=body.properties,
        client=get_http_client(),
        token=user.token,
    )
    raise_if_validation_failed(report)

    # Column link validation (if annotations provided)
    annotations_dicts = [a.model_dump() for a in body.column_annotations]
    if annotations_dicts:
        link_result = await validate_column_links(
            annotations_dicts,
            body.datastack,
            get_http_client(),
            token=user.token,
        )
        if not link_result.passed:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail={
                    "message": "Column link validation failed",
                    "errors": [
                        {
                            "column_name": e.column_name,
                            "target_table": e.target_table,
                            "target_column": e.target_column,
                            "reason": e.reason,
                        }
                        for e in link_result.errors
                    ],
                },
            )

    # Extract metadata
    try:
        extractor = get_extractor(body.format)
    except ValueError as exc:
        # Chain the cause so the original error stays attached to the 422.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        ) from exc

    try:
        metadata = await extractor.extract(body.uri)
    except Exception as exc:
        # Deliberately broad: any extractor failure becomes a 422. Log it so
        # that genuine server-side bugs are not hidden behind the response.
        logger.warning(
            "metadata_extraction_failed",
            uri=body.uri,
            format=body.format,
            error=str(exc),
        )
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Metadata extraction failed: {exc}",
        ) from exc

    now = now_utc()
    table = Table(
        id=uuid.uuid4(),
        datastack=body.datastack,
        name=body.name,
        mat_version=body.mat_version,
        revision=body.revision,
        uri=body.uri,
        format=body.format,
        asset_type="table",
        owner=user.user_id,
        is_managed=body.is_managed,
        mutability=body.mutability.value,
        maturity=body.maturity.value,
        properties=body.properties,
        access_group=body.access_group,
        created_at=now,
        expires_at=body.expires_at,
        source=body.source,
        cached_metadata=metadata.model_dump(),
        metadata_cached_at=now,
        column_annotations=annotations_dicts,
    )

    session.add(table)
    try:
        await session.commit()
        await session.refresh(table)
    except IntegrityError as exc:
        await session.rollback()
        # A concurrent request inserted the same key between the duplicate
        # check above and this commit; report it like a regular duplicate.
        dup = await find_duplicate(
            session, body.datastack, body.name, body.mat_version, body.revision
        )
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail={
                "message": "Table already exists",
                "existing_id": str(dup.id) if dup else None,
            },
        ) from exc

    logger.info(
        "table_registered",
        id=str(table.id),
        datastack=body.datastack,
        name=body.name,
    )
    return table_to_response(table)
@router.post("/{table_id}/refresh", response_model=TableResponse)
async def refresh_metadata(
    table_id: uuid.UUID,
    user: AuthUser = Depends(require_auth),
    session: AsyncSession = Depends(get_session),
    settings: Settings = Depends(get_settings),
) -> TableResponse:
    """Re-extract a table's metadata from its URI and store the new snapshot.

    Only ``cached_metadata`` and ``metadata_cached_at`` are rewritten;
    ``column_annotations`` is left untouched.

    Returns:
        The refreshed table (see ``table_to_response``).

    Raises:
        HTTPException: 400 when the asset is not a table; 422 when the
            format has no extractor or extraction fails. ``get_asset`` and
            ``require_datastack_permission`` may raise their own errors.
    """
    logger.debug("refresh_metadata", table_id=str(table_id))
    table = await get_asset(session, table_id)

    if table.asset_type != "table":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Asset is not a table",
        )

    require_datastack_permission(user, settings, table.datastack, "edit")

    try:
        extractor = get_extractor(table.format)
    except ValueError as exc:
        # Chain the cause so the original error stays attached to the 422.
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=str(exc),
        ) from exc

    try:
        metadata = await extractor.extract(table.uri)
    except Exception as exc:
        # Deliberately broad: any extractor failure becomes a 422. Log it so
        # that genuine server-side bugs are not hidden behind the response.
        logger.warning(
            "metadata_extraction_failed",
            table_id=str(table_id),
            uri=table.uri,
            format=table.format,
            error=str(exc),
        )
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Metadata extraction failed: {exc}",
        ) from exc

    table.cached_metadata = metadata.model_dump()
    table.metadata_cached_at = now_utc()
    # column_annotations intentionally NOT modified
    await session.commit()
    await session.refresh(table)

    logger.info("metadata_refreshed", table_id=str(table_id))
    return table_to_response(table)
async def list_tables(
    datastack: str = Query(...),
    name: str | None = Query(default=None),
    mat_version: int | None = Query(default=None),
    revision: int | None = Query(default=None),
    format: str | None = Query(default=None),
    source: str | None = Query(default=None),
    mutability: str | None = Query(default=None),
    maturity: str | None = Query(default=None),
    user: AuthUser = Depends(require_auth),
    session: AsyncSession = Depends(get_session),
    settings: Settings = Depends(get_settings),
) -> list[TableResponse]:
    """List the non-expired tables of *datastack*.

    Every optional query parameter that is supplied adds an exact-match
    filter on the column of the same name; omitted parameters do not
    constrain the result. Requires ``view`` permission on the datastack.
    """
    logger.debug("list_tables", datastack=datastack, name=name, format=format)
    require_datastack_permission(user, settings, datastack, "view")

    # Mandatory criteria: right datastack, table assets only, not yet expired.
    cutoff = now_utc()
    not_expired = or_(Table.expires_at.is_(None), Table.expires_at > cutoff)
    query = select(Table).where(
        and_(
            Table.datastack == datastack,
            Table.asset_type == "table",
            not_expired,
        )
    )

    # Optional criteria, applied in a fixed order when a value was given.
    optional_filters = (
        (Table.name, name),
        (Table.mat_version, mat_version),
        (Table.revision, revision),
        (Table.format, format),
        (Table.source, source),
        (Table.mutability, mutability),
        (Table.maturity, maturity),
    )
    for column, wanted in optional_filters:
        if wanted is not None:
            query = query.where(column == wanted)

    rows = (await session.execute(query)).scalars().all()
    return [table_to_response(row) for row in rows]
# ---------------------------------------------------------------------------
# Auth guard dependency — redirects to login instead of returning 401 JSON
# ---------------------------------------------------------------------------


async def require_ui_auth(
    request: Request,
    user: AuthUser | None = Depends(get_current_user),
    settings: Settings = Depends(get_settings),
) -> AuthUser:
    """Return the authenticated user, or redirect the browser to the login page.

    Raises:
        _RedirectException: when no user is authenticated. It carries the
            login URL with the current path as the ``next`` parameter.
    """
    from urllib.parse import urlencode

    if user is None:
        # Encode the path: characters such as "&", "?" or "#" would otherwise
        # truncate or corrupt the ``next`` value on the login side.
        login_url = "/ui/login?" + urlencode({"next": request.url.path})
        raise _redirect_exception(login_url)
    return user


class _RedirectException(Exception):
    """Signals that the request should be answered with a redirect to ``url``."""

    def __init__(self, url: str) -> None:
        self.url = url


def _redirect_exception(url: str) -> _RedirectException:
    """Build the exception that makes the app redirect to *url*."""
    return _RedirectException(url)


def _safe_next_path(target: str, default: str = "/ui/register") -> str:
    """Return *target* if it is a plain same-site path, otherwise *default*.

    ``next`` arrives from the query string, so it is attacker-controlled.
    Only absolute paths on this site are accepted. Absolute URLs and
    protocol-relative ``//host`` forms are rejected, as is anything containing
    a backslash or control character (browsers may normalise those into
    ``//host``).
    """
    if not target.startswith("/") or target.startswith("//"):
        return default
    if "\\" in target or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in target):
        return default
    return target


# ---------------------------------------------------------------------------
# Auth routes
# ---------------------------------------------------------------------------


@router.get("/login")
async def login(
    request: Request,
    next: str = "/ui/register",
    settings: Settings = Depends(get_settings),
):
    """Redirect to middle_auth OAuth authorize endpoint."""
    from urllib.parse import urlencode

    # Build callback URL that includes the final destination. The destination
    # is validated and URL-encoded so it survives the round trip intact.
    query = urlencode({"next": _safe_next_path(next)})
    callback_url = f"{request.url_for('ui_callback')}?{query}"
    authorize_url = get_authorize_url(settings, callback_url)
    return RedirectResponse(url=authorize_url)


@router.get("/callback", name="ui_callback")
async def callback(
    request: Request,
    next: str = "/ui/register",
):
    """OAuth callback — extract token from query param, set cookie, redirect.

    The token is read from the ``TOKEN_COOKIE_NAME`` query parameter, falling
    back to ``token``. Without a token the browser is sent back to the login
    page. The post-login destination is restricted to same-site paths so a
    crafted ``next`` cannot send the user to another site.
    """
    token = request.query_params.get(TOKEN_COOKIE_NAME) or request.query_params.get(
        "token"
    )
    if not token:
        return RedirectResponse(url="/ui/login")
    return create_token_cookie_response(
        redirect_url=_safe_next_path(next), token=token
    )


@router.get("/logout")
async def logout():
    """Clear the auth cookie and redirect to login."""
    response = RedirectResponse(url="/ui/login", status_code=302)
    response.delete_cookie(key=TOKEN_COOKIE_NAME)
    return response
@router.get("/select-datastack")
async def select_datastack(
    request: Request,
    datastack: str = "",
    settings: Settings = Depends(get_settings),
):
    """Set the selected datastack cookie (called via HTMX from the selector).

    The browser is sent back to the page named by the ``Referer`` header, or
    to ``/ui/register`` when the header is absent. Only the path and query of
    the referer are used, so a request arriving from another site cannot turn
    this endpoint into a redirect to that site. The cookie is only set when
    *datastack* is one of the configured datastacks.
    """
    from urllib.parse import urlsplit

    fallback = "/ui/register"
    try:
        parts = urlsplit(request.headers.get("referer", fallback))
    except ValueError:
        # Malformed header value (e.g. an invalid bracketed host).
        parts = urlsplit(fallback)

    # Collapse leading slashes so the target can never be read as "//host".
    path = "/" + parts.path.lstrip("/")
    if "\\" in path:
        # Browsers may treat "\" like "/", which could rebuild "//host".
        target = fallback
    elif parts.query:
        target = f"{path}?{parts.query}"
    else:
        target = path

    response = RedirectResponse(url=target, status_code=302)
    if datastack and datastack in settings.datastacks:
        response.set_cookie(
            key=DATASTACK_COOKIE, value=datastack, httponly=True, samesite="lax"
        )
    return response
preview_table( + request: Request, + user: AuthUser = Depends(require_ui_auth), + settings: Settings = Depends(get_settings), +): + """Extract metadata from a URI and return a preview HTML fragment.""" + form = await request.form() + uri = str(form.get("uri", "")).strip() + fmt = str(form.get("format", "delta")).strip() + + if not uri: + return templates.TemplateResponse( + request, + "fragments/preview_result.html", + {"error": "Please enter a URI."}, + ) + + # Resolve extractor + try: + extractor = get_extractor(fmt) + except ValueError: + return templates.TemplateResponse( + request, + "fragments/preview_result.html", + {"error": f"Unsupported format: '{fmt}'. Supported formats: delta, parquet."}, + ) + + # Run extraction + try: + metadata = await extractor.extract(uri) + except Exception as exc: + error_msg = str(exc) + # Distinguish error types for diagnostics + lower = error_msg.lower() + if "not found" in lower or "no such" in lower or "does not exist" in lower: + diagnostic = f"URI unreachable — the path does not exist or is not accessible: {error_msg}" + elif "permission" in lower or "forbidden" in lower or "access" in lower: + diagnostic = f"URI unreachable — permission denied: {error_msg}" + else: + diagnostic = f"Failed to read {fmt} data: {error_msg}" + return templates.TemplateResponse( + request, + "fragments/preview_result.html", + {"error": diagnostic}, + ) + + return templates.TemplateResponse( + request, + "fragments/preview_result.html", + {"metadata": metadata, "format": fmt, "error": None}, + ) + + +# --------------------------------------------------------------------------- +# Registration submit HTMX route (Section 8) +# --------------------------------------------------------------------------- + + +def _parse_column_annotations(form: dict) -> list[dict]: + """Parse column annotations and links from flat form data.""" + n_columns = int(form.get("n_columns", 0)) + annotations = [] + for i in range(n_columns): + col_name = 
form.get(f"col_name_{i}", "") + description = form.get(f"col_desc_{i}", "").strip() or None + + # Collect links for this column + links = [] + for key, val in form.items(): + if key.startswith(f"link_type_{i}_"): + link_id = key.split("_")[-1] + link_type = val + target = form.get(f"link_target_{i}_{link_id}", "") + column = form.get(f"link_column_{i}_{link_id}", "") + if target and column: + links.append( + { + "link_type": link_type, + "target_table": target, + "target_column": column, + } + ) + + if description or links: + annotations.append( + { + "column_name": col_name, + "description": description, + "links": links, + } + ) + return annotations + + +@router.post("/register/submit", response_class=HTMLResponse) +async def register_submit( + request: Request, + user: AuthUser = Depends(require_ui_auth), + settings: Settings = Depends(get_settings), + session: AsyncSession = Depends(get_session), +): + """Handle registration form submission — call the tables API internally.""" + form_data = await request.form() + form = dict(form_data) + + uri = str(form.get("uri", "")).strip() + fmt = str(form.get("format", "delta")).strip() + name = str(form.get("name", "")).strip() + datastack = _get_current_datastack(request, settings) + + if not uri or not name or not datastack: + return templates.TemplateResponse( + request, + "fragments/register_error.html", + {"error": "URI, name, and datastack are required."}, + ) + + mat_version_raw = str(form.get("mat_version", "")).strip() + mat_version = int(mat_version_raw) if mat_version_raw else None + + column_annotations = _parse_column_annotations(form) + + # Build the request payload for the tables API + import httpx + + payload = { + "datastack": datastack, + "name": name, + "mat_version": mat_version, + "revision": 0, + "uri": uri, + "format": fmt, + "asset_type": "table", + "is_managed": True, + "mutability": "static", + "maturity": "stable", + "properties": {}, + "column_annotations": column_annotations, + } + + # Call 
tables register API internally via ASGI transport + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=request.app), + base_url="http://localhost", + ) as client: + resp = await client.post( + "/api/v1/tables/register", + json=payload, + headers={"Authorization": f"Bearer {user.token}"}, + ) + + if resp.status_code == 201: + data = resp.json() + return templates.TemplateResponse( + request, + "fragments/register_success.html", + { + "table_id": data.get("id", ""), + "table_name": data.get("name", name), + "datastack": data.get("datastack", datastack), + "mat_version": data.get("mat_version"), + "uri": data.get("uri", uri), + "format": data.get("format", fmt), + }, + ) + else: + detail = resp.json().get("detail", "Unknown error") + if isinstance(detail, dict): + error_msg = detail.get("message", str(detail)) + details = [] + for err in detail.get("errors", []): + details.append( + f"{err.get('column_name', '?')}: {err.get('reason', '?')} " + f"(target: {err.get('target_table', '?')}.{err.get('target_column', '?')})" + ) + return templates.TemplateResponse( + request, + "fragments/register_error.html", + {"error": error_msg, "details": details or None}, + ) + return templates.TemplateResponse( + request, + "fragments/register_error.html", + {"error": str(detail)}, + ) + + +# --------------------------------------------------------------------------- +# Mat proxy HTMX fragment routes (for link builder) +# --------------------------------------------------------------------------- + + +@router.get("/fragments/linkable-targets", response_class=HTMLResponse) +async def linkable_targets_fragment( + request: Request, + user: AuthUser = Depends(require_ui_auth), + settings: Settings = Depends(get_settings), + version: int | None = None, +): + """Return HTML ') + try: + targets = await get_linkable_targets(datastack, version) + except MatProxyError as e: + return HTMLResponse(f'') + options = [''] + for t in targets: + label = f"{t.name} ({t.target_type})" + 
options.append( + f'' + ) + return HTMLResponse("\n".join(options)) + + +@router.get("/fragments/target-columns", response_class=HTMLResponse) +async def target_columns_fragment( + request: Request, + target_name: str, + target_type: str = "table", + user: AuthUser = Depends(require_ui_auth), + settings: Settings = Depends(get_settings), + version: int | None = None, +): + """Return HTML ') + try: + columns = await get_target_columns(datastack, target_name, target_type, version) + except MatProxyError as e: + return HTMLResponse(f'') + options = [''] + for col in columns: + options.append(f'') + return HTMLResponse("\n".join(options)) + + +@router.get("/fragments/check-name", response_class=HTMLResponse) +async def check_name_fragment( + request: Request, + name: str = Query(""), + mat_version: int | None = Query(default=None), + revision: int = Query(default=0), + user: AuthUser = Depends(require_ui_auth), + settings: Settings = Depends(get_settings), + session: AsyncSession = Depends(get_session), +): + """Return HTML fragment with ✓/✗ name availability indicator.""" + name = name.strip() + if not name: + return HTMLResponse("") + + datastack = _get_current_datastack(request, settings) + if not datastack: + return HTMLResponse( + 'No datastack selected' + ) + + # Check reservation against mat tables + reservation = await check_name_reservation( + datastack=datastack, + name=name, + is_mat_source=False, + client=get_http_client(), + token=user.token, + ) + if not reservation.passed: + return HTMLResponse( + '✗ Name is reserved for materialization' + ) + + # Check for duplicate in DB + existing = await find_duplicate(session, datastack, name, mat_version, revision) + if existing is not None: + return HTMLResponse( + f'✗ Already registered (ID: {existing.id})' + ) + + return HTMLResponse( + '✓ Available' + ) diff --git a/src/cave_catalog/schemas.py b/src/cave_catalog/schemas.py index 78deff7..4832ce6 100644 --- a/src/cave_catalog/schemas.py +++ 
b/src/cave_catalog/schemas.py @@ -23,7 +23,7 @@ class AssetRequest(BaseModel): mat_version: int | None = None revision: int = Field(default=0, ge=0) uri: str - format: str + format: str | None = None asset_type: str is_managed: bool mutability: Mutability = Mutability.STATIC @@ -40,7 +40,7 @@ class AssetResponse(BaseModel): mat_version: int | None revision: int uri: str - format: str + format: str | None asset_type: str owner: int is_managed: bool diff --git a/src/cave_catalog/static/style.css b/src/cave_catalog/static/style.css new file mode 100644 index 0000000..28f42a2 --- /dev/null +++ b/src/cave_catalog/static/style.css @@ -0,0 +1,346 @@ +/* CAVE Catalog — minimal admin UI styles */ + +:root { + --color-bg: #f5f6f8; + --color-surface: #ffffff; + --color-primary: #2563eb; + --color-primary-hover: #1d4ed8; + --color-text: #1e293b; + --color-text-muted: #64748b; + --color-border: #e2e8f0; + --color-sidebar-bg: #1e293b; + --color-sidebar-text: #cbd5e1; + --color-sidebar-active: #ffffff; + --color-success: #16a34a; + --color-error: #dc2626; + --topbar-height: 48px; + --sidebar-width: 200px; +} + +* { box-sizing: border-box; margin: 0; padding: 0; } + +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; + color: var(--color-text); + background: var(--color-bg); + line-height: 1.5; +} + +/* Top bar */ +.topbar { + display: flex; + align-items: center; + justify-content: space-between; + height: var(--topbar-height); + padding: 0 16px; + background: var(--color-surface); + border-bottom: 1px solid var(--color-border); + position: fixed; + top: 0; + left: 0; + right: 0; + z-index: 100; +} + +.topbar-left { display: flex; align-items: center; gap: 12px; } +.topbar-center { display: flex; align-items: center; gap: 8px; } +.topbar-right { display: flex; align-items: center; gap: 12px; } + +.service-name { + font-weight: 700; + font-size: 15px; + color: var(--color-primary); +} + +.topbar select { + padding: 4px 8px; + 
border: 1px solid var(--color-border); + border-radius: 4px; + font-size: 13px; + background: var(--color-surface); +} + +.user-info { + font-size: 13px; + color: var(--color-text-muted); +} + +.btn-link { + font-size: 13px; + color: var(--color-primary); + text-decoration: none; +} +.btn-link:hover { text-decoration: underline; } + +/* Layout */ +.layout { + display: flex; + margin-top: var(--topbar-height); + min-height: calc(100vh - var(--topbar-height)); +} + +/* Sidebar */ +.sidebar { + width: var(--sidebar-width); + background: var(--color-sidebar-bg); + padding: 16px 0; + flex-shrink: 0; +} + +.sidebar ul { list-style: none; } + +.sidebar a { + display: block; + padding: 10px 20px; + color: var(--color-sidebar-text); + text-decoration: none; + font-size: 14px; + transition: background 0.15s; +} + +.sidebar a:hover { + background: rgba(255, 255, 255, 0.08); +} + +.sidebar a.active { + color: var(--color-sidebar-active); + background: rgba(255, 255, 255, 0.12); + font-weight: 600; + border-left: 3px solid var(--color-primary); +} + +/* Content */ +.content { + flex: 1; + padding: 24px 32px; + max-width: 960px; +} + +.content h1 { + font-size: 22px; + margin-bottom: 16px; +} + +/* Utility classes */ +.text-success { color: var(--color-success); } +.text-error { color: var(--color-error); } +.text-muted { color: var(--color-text-muted); } + +.btn { + display: inline-block; + padding: 8px 16px; + font-size: 14px; + font-weight: 500; + border: none; + border-radius: 4px; + cursor: pointer; + text-decoration: none; + transition: background 0.15s; +} + +.btn-primary { + background: var(--color-primary); + color: #fff; +} +.btn-primary:hover { background: var(--color-primary-hover); } + +.btn-secondary { + background: var(--color-surface); + color: var(--color-text); + border: 1px solid var(--color-border); +} +.btn-secondary:hover { background: var(--color-bg); } + +/* Forms */ +input[type="text"], +input[type="number"], +input[type="url"], +textarea, +select { + 
padding: 6px 10px; + border: 1px solid var(--color-border); + border-radius: 4px; + font-size: 14px; + font-family: inherit; + width: 100%; +} + +input:focus, select:focus, textarea:focus { + outline: 2px solid var(--color-primary); + outline-offset: -1px; +} + +label { + font-size: 13px; + font-weight: 500; + color: var(--color-text-muted); +} + +/* Tables */ +table { + width: 100%; + border-collapse: collapse; + font-size: 14px; +} +th, td { + padding: 8px 12px; + text-align: left; + border-bottom: 1px solid var(--color-border); +} +th { + font-weight: 600; + font-size: 12px; + text-transform: uppercase; + letter-spacing: 0.05em; + color: var(--color-text-muted); + background: var(--color-bg); +} + +/* Inline indicators */ +.indicator { + display: inline-flex; + align-items: center; + gap: 4px; + font-size: 13px; + padding: 2px 0; +} + +/* Alert / message boxes */ +.alert { + padding: 12px 16px; + border-radius: 4px; + font-size: 14px; + margin-bottom: 16px; +} +.alert-error { + background: #fef2f2; + border: 1px solid #fecaca; + color: var(--color-error); +} +.alert-success { + background: #f0fdf4; + border: 1px solid #bbf7d0; + color: var(--color-success); +} + +/* Form sections */ +.form-section { + margin-bottom: 24px; +} +.form-section h2 { + font-size: 16px; + margin-bottom: 12px; + padding-bottom: 6px; + border-bottom: 1px solid var(--color-border); +} +.form-group { + margin-bottom: 12px; +} +.form-group label { + display: block; + margin-bottom: 4px; +} +.form-row { + display: flex; + gap: 12px; + align-items: flex-start; +} + +/* Metadata summary */ +.metadata-summary { + display: flex; + align-items: center; + gap: 8px; + margin-bottom: 12px; + font-size: 14px; +} +.badge { + display: inline-block; + padding: 2px 8px; + border-radius: 4px; + font-size: 12px; + font-weight: 600; + background: var(--color-primary); + color: #fff; +} + +/* Name check indicators */ +.name-check { + font-size: 13px; + margin-top: 4px; + display: inline-block; +} 
+.name-check.available { color: var(--color-success); } +.name-check.unavailable { color: var(--color-error); } +.name-check.error { color: var(--color-error); } + +/* HTMX indicator */ +.htmx-indicator { + display: none; + font-size: 13px; + color: var(--color-text-muted); +} +.htmx-request .htmx-indicator, +.htmx-request.htmx-indicator { + display: inline; +} + +/* Code inline */ +code { + font-family: "SF Mono", "Consolas", "Menlo", monospace; + font-size: 13px; +} + +/* Annotation table */ +.annotation-table td { vertical-align: top; } +.annotation-table .col-description { width: 100%; } +.links-cell { min-width: 320px; } +.links-container { display: flex; flex-direction: column; gap: 6px; margin-bottom: 4px; } + +/* Link row */ +.link-row { + display: flex; + align-items: center; + gap: 4px; + padding: 4px 0; +} +.link-row select { width: auto; min-width: 100px; font-size: 13px; padding: 4px 6px; } +.link-row .link-target { min-width: 140px; } +.link-row .link-column { min-width: 120px; } + +/* Small button variant */ +.btn-small { + padding: 3px 8px; + font-size: 12px; +} +.add-link-btn { margin-top: 2px; } +.remove-link-btn { + padding: 2px 6px; + font-size: 14px; + line-height: 1; + color: var(--color-error); + background: none; + border: 1px solid var(--color-border); + border-radius: 3px; + cursor: pointer; +} +.remove-link-btn:hover { background: #fef2f2; } + +/* Detail list (registration success) */ +.detail-list { margin: 12px 0; } +.detail-list dt { + font-size: 12px; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + color: var(--color-text-muted); + margin-top: 8px; +} +.detail-list dd { margin: 2px 0 0; } + +/* Error details */ +.error-details { + margin: 8px 0 0 16px; + font-size: 13px; + color: var(--color-error); +} diff --git a/src/cave_catalog/table_schemas.py b/src/cave_catalog/table_schemas.py new file mode 100644 index 0000000..b724896 --- /dev/null +++ b/src/cave_catalog/table_schemas.py @@ -0,0 +1,140 @@ 
+"""Pydantic models for table assets. + +Covers cached metadata (format-discriminated), column annotations with links, +and table-specific request/response schemas. +""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel, Field + +from cave_catalog.schemas import AssetRequest, AssetResponse + +# --------------------------------------------------------------------------- +# Cached metadata models (task 1.3) +# --------------------------------------------------------------------------- + + +class ColumnInfo(BaseModel): + """A single column discovered from file metadata.""" + + name: str + dtype: str + + +class TableMetadata(BaseModel): + """Cached metadata common to all table formats.""" + + n_rows: int | None = None + n_columns: int | None = None + n_bytes: int | None = None + columns: list[ColumnInfo] = Field(default_factory=list) + partition_columns: list[str] = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Column annotation models (task 1.4) +# --------------------------------------------------------------------------- + + +class ColumnLink(BaseModel): + """Semantic link from a column to a materialization service table/column.""" + + link_type: str + target_table: str + target_column: str + + +class ColumnAnnotation(BaseModel): + """User-provided annotation for a single column.""" + + column_name: str + description: str | None = None + links: list[ColumnLink] = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Merged column view (read-time merge of cached metadata + annotations) +# --------------------------------------------------------------------------- + + +class MergedColumn(BaseModel): + """Unified column view returned by the API.""" + + name: str + dtype: str + description: str | None = None + links: list[ColumnLink] = Field(default_factory=list) + + +# 
--------------------------------------------------------------------------- +# Request / response models (task 1.5) +# --------------------------------------------------------------------------- + + +class TablePreviewRequest(BaseModel): + uri: str + format: str + datastack: str + + +class TablePreviewResponse(BaseModel): + metadata: TableMetadata + + +class TableRequest(AssetRequest): + format: str # required for tables (not optional like base) + asset_type: str = "table" + source: str = "user" + column_annotations: list[ColumnAnnotation] = Field(default_factory=list) + + +class TableResponse(AssetResponse): + source: str | None = None + cached_metadata: TableMetadata | None = None + metadata_cached_at: datetime | None = None + column_annotations: list[ColumnAnnotation] = Field(default_factory=list) + columns: list[MergedColumn] = Field(default_factory=list) + + +class AnnotationUpdateRequest(BaseModel): + column_annotations: list[ColumnAnnotation] + + +# --------------------------------------------------------------------------- +# Column merging helper (task 4.7) +# --------------------------------------------------------------------------- + + +def merge_columns( + metadata: TableMetadata | None, + annotations: list[ColumnAnnotation] | None, +) -> list[MergedColumn]: + """Merge cached column schema with user-provided annotations by column name. + + For each column in ``metadata.columns``, look up matching annotation by + ``column_name``. Annotated columns get description + links; unannotated + columns get None/empty. Orphaned annotations (no matching column in + metadata) are silently dropped — they're inert until the column reappears. 
+ """ + if not metadata or not metadata.columns: + return [] + + ann_by_name: dict[str, ColumnAnnotation] = {} + for ann in annotations or []: + ann_by_name[ann.column_name] = ann + + merged: list[MergedColumn] = [] + for col in metadata.columns: + ann = ann_by_name.get(col.name) + merged.append( + MergedColumn( + name=col.name, + dtype=col.dtype, + description=ann.description if ann else None, + links=ann.links if ann else [], + ) + ) + return merged diff --git a/src/cave_catalog/templates/base.html b/src/cave_catalog/templates/base.html new file mode 100644 index 0000000..157015b --- /dev/null +++ b/src/cave_catalog/templates/base.html @@ -0,0 +1,52 @@ + + + + + + {% block title %}CAVE Catalog{% endblock %} + + + {% block head %}{% endblock %} + + +
+
+ CAVE Catalog +
+
+ + +
+
+ {% if user is defined and user %} + + Logout + {% else %} + Login + {% endif %} +
+
+ +
+ +
+ {% block content %}{% endblock %} +
+
+ + diff --git a/src/cave_catalog/templates/explore.html b/src/cave_catalog/templates/explore.html new file mode 100644 index 0000000..58f650b --- /dev/null +++ b/src/cave_catalog/templates/explore.html @@ -0,0 +1,6 @@ +{% extends "base.html" %} +{% block title %}Explore Assets — CAVE Catalog{% endblock %} +{% block content %} +

Explore Assets

+

Coming soon.

+{% endblock %} diff --git a/src/cave_catalog/templates/fragments/preview_result.html b/src/cave_catalog/templates/fragments/preview_result.html new file mode 100644 index 0000000..36ae58d --- /dev/null +++ b/src/cave_catalog/templates/fragments/preview_result.html @@ -0,0 +1,77 @@ +{# HTMX fragment: preview results after metadata extraction #} + +{% if error %} +
+ Preview failed: {{ error }} +
+{% else %} +
+

Preview Results

+ +
+ {{ format | upper }} + {% if metadata.n_rows is not none %} + {{ "{:,}".format(metadata.n_rows) }} rows + {% endif %} + {% if metadata.n_columns is not none %} + · {{ metadata.n_columns }} columns + {% endif %} + {% if metadata.n_bytes is not none %} + {% if metadata.n_bytes >= 1073741824 %} + · {{ "%.1f"|format(metadata.n_bytes / 1073741824) }} GB + {% elif metadata.n_bytes >= 1048576 %} + · {{ "%.1f"|format(metadata.n_bytes / 1048576) }} MB + {% elif metadata.n_bytes >= 1024 %} + · {{ "%.1f"|format(metadata.n_bytes / 1024) }} KB + {% else %} + · {{ metadata.n_bytes }} B + {% endif %} + {% endif %} + {% if metadata.partition_columns %} + · partitioned by {{ metadata.partition_columns | join(", ") }} + {% endif %} +
+
+ +{# Populate the column annotations table #} + +{% endif %} diff --git a/src/cave_catalog/templates/fragments/register_error.html b/src/cave_catalog/templates/fragments/register_error.html new file mode 100644 index 0000000..66a0300 --- /dev/null +++ b/src/cave_catalog/templates/fragments/register_error.html @@ -0,0 +1,12 @@ +{# HTMX fragment: registration error #} +
+ Registration failed: {{ error }} +
+ +{% if details %} + +{% endif %} diff --git a/src/cave_catalog/templates/fragments/register_success.html b/src/cave_catalog/templates/fragments/register_success.html new file mode 100644 index 0000000..0cad6e9 --- /dev/null +++ b/src/cave_catalog/templates/fragments/register_success.html @@ -0,0 +1,25 @@ +{# HTMX fragment: registration success #} +
+ Table registered successfully! +
+ +
+
+
Table ID
+
{{ table_id }}
+
Name
+
{{ table_name }}
+
Datastack
+
{{ datastack }}
+ {% if mat_version %} +
Mat Version
+
{{ mat_version }}
+ {% endif %} +
URI
+
{{ uri }}
+
Format
+
{{ format }}
+
+
+ +Register Another diff --git a/src/cave_catalog/templates/register.html b/src/cave_catalog/templates/register.html new file mode 100644 index 0000000..f59f0dd --- /dev/null +++ b/src/cave_catalog/templates/register.html @@ -0,0 +1,155 @@ +{% extends "base.html" %} +{% block title %}Register — CAVE Catalog{% endblock %} +{% block content %} +

Register a Table

+ +
+ + {# ── Step 1: URI + Preview ── #} +
+

1. Data Source

+ +
+ + +
+ +
+
+ + +
+
+ +
+
+
+ + {# ── Preview results area (populated by HTMX) ── #} +
+ + {# ── Step 2: Registration details (shown after preview) ── #} + + +
+ +{# ── Link row template (used by JS) ── #} + + + +{% endblock %} diff --git a/src/cave_catalog/templating.py b/src/cave_catalog/templating.py new file mode 100644 index 0000000..c479078 --- /dev/null +++ b/src/cave_catalog/templating.py @@ -0,0 +1,7 @@ +"""Jinja2 template configuration for the UI.""" + +from pathlib import Path + +from fastapi.templating import Jinja2Templates + +templates = Jinja2Templates(directory=Path(__file__).resolve().parent / "templates") diff --git a/src/cave_catalog/validation.py b/src/cave_catalog/validation.py index 1ef23e5..42f747e 100644 --- a/src/cave_catalog/validation.py +++ b/src/cave_catalog/validation.py @@ -9,6 +9,7 @@ import asyncio from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -213,6 +214,129 @@ async def check_name_reservation( return ValidationCheck(passed=True) +# --- Column link validation ------------------------------------------------- + + +@dataclass +class LinkValidationError: + """A single column link that failed validation.""" + + column_name: str + link_type: str + target_table: str + target_column: str + reason: str + + +@dataclass +class LinkValidationResult: + """Result of validating column links against the materialization service.""" + + passed: bool + errors: list[LinkValidationError] = field(default_factory=list) + skipped: bool = False + message: str | None = None + + +async def validate_column_links( + annotations: list[dict[str, Any]], + datastack: str, + client: AsyncClient, + token: str = "", +) -> LinkValidationResult: + """Validate column link targets against the materialization service. + + Checks that each ``target_table`` referenced in column annotations exists + in the materialization service for the given datastack. Column existence + is not validated (the ME API does not expose flat column names directly). 
+ + Parameters + ---------- + annotations + List of column annotation dicts, each with ``column_name``, optional + ``links`` list containing ``{link_type, target_table, target_column}``. + datastack + Datastack name to validate against. + client + httpx async client for ME calls. + token + Optional bearer token. + + Returns + ------- + LinkValidationResult + """ + # Collect all unique target tables referenced in links + links_by_table: dict[str, list[tuple[str, dict[str, str]]]] = {} + for ann in annotations: + col_name = ann.get("column_name", "") + for link in ann.get("links", []): + target = link.get("target_table", "") + if target: + links_by_table.setdefault(target, []).append((col_name, link)) + + if not links_by_table: + return LinkValidationResult(passed=True) + + settings = get_settings() + if not settings.mat_engine_url: + logger.warning("link_validation_skipped", reason="MAT_ENGINE_URL not configured") + return LinkValidationResult( + passed=True, + skipped=True, + message="Column link validation skipped: MAT_ENGINE_URL not configured", + ) + + # Fetch table list from ME + base = settings.mat_engine_url.rstrip("/") + url = f"{base}/api/v2/datastack/{datastack}/tables" + headers = {"Authorization": f"Bearer {token}"} if token else {} + + try: + response = await client.get(url, headers=headers, timeout=15.0) + except Exception as exc: + logger.warning("link_validation_skipped", reason=str(exc)) + return LinkValidationResult( + passed=True, + skipped=True, + message=f"Column link validation skipped: {exc}", + ) + + if response.status_code in (301, 302, 303, 307, 308, 401, 403): + return LinkValidationResult( + passed=True, + skipped=True, + message=f"Column link validation skipped: ME auth failed (HTTP {response.status_code})", + ) + if response.status_code != 200: + return LinkValidationResult( + passed=True, + skipped=True, + message=f"Column link validation skipped: ME returned HTTP {response.status_code}", + ) + + mat_tables: set[str] = 
set(response.json()) + + # Validate each target table exists + errors: list[LinkValidationError] = [] + for target_table, col_links in links_by_table.items(): + if target_table not in mat_tables: + for col_name, link in col_links: + errors.append( + LinkValidationError( + column_name=col_name, + link_type=link.get("link_type", ""), + target_table=target_table, + target_column=link.get("target_column", ""), + reason=f"Table '{target_table}' not found in materialization service for datastack '{datastack}'", + ) + ) + + if errors: + return LinkValidationResult(passed=False, errors=errors) + return LinkValidationResult(passed=True) + + # --- Main pipeline ---------------------------------------------------------- diff --git a/tests/test_assets.py b/tests/test_assets.py index 7a0eabb..1561082 100644 --- a/tests/test_assets.py +++ b/tests/test_assets.py @@ -1,4 +1,4 @@ -"""Tests for asset registry endpoints (task 2.8). +"""Tests for asset registry endpoints (task 2.8) and Phase 5 unified reads. Uses FastAPI dependency_overrides with a real in-memory SQLite DB so no live Postgres is needed. 
External HTTP calls (URI reachability, format sniff, ME @@ -12,6 +12,7 @@ from unittest.mock import AsyncMock from cave_catalog.schemas import ValidationCheck, ValidationReport +from cave_catalog.table_schemas import ColumnInfo, TableMetadata # --------------------------------------------------------------------------- # Helpers @@ -288,3 +289,143 @@ async def test_delete_asset_success(client, monkeypatch): async def test_delete_asset_not_found(client): response = await client.delete(f"/api/v1/assets/{uuid.uuid4()}") assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# Phase 5: Unified read surface returns table-specific fields +# --------------------------------------------------------------------------- + +_TABLE_METADATA = TableMetadata( + n_rows=100, + n_columns=2, + n_bytes=5000, + columns=[ + ColumnInfo(name="a", dtype="int64"), + ColumnInfo(name="b", dtype="string"), + ], + partition_columns=[], +) + + +def _patch_table_helpers(monkeypatch): + """Patch extraction, validation, and link validation for table registration.""" + mock_extractor = AsyncMock() + mock_extractor.extract = AsyncMock(return_value=_TABLE_METADATA) + monkeypatch.setattr( + "cave_catalog.routers.tables.get_extractor", + lambda fmt: mock_extractor, + ) + + from cave_catalog.validation import LinkValidationResult + + monkeypatch.setattr( + "cave_catalog.routers.tables.validate_column_links", + AsyncMock(return_value=LinkValidationResult(passed=True, errors=[])), + ) + + +def _table_payload(**overrides: Any) -> dict: + base = { + "datastack": "minnie65_public", + "name": "my_table", + "revision": 0, + "uri": "gs://bucket/tables/my_table/", + "format": "delta", + "is_managed": True, + "mutability": "static", + "maturity": "stable", + "properties": {}, + } + base.update(overrides) + return base + + +async def _register_table_via_tables_api(client, monkeypatch, **overrides) -> dict: + _patch_table_helpers(monkeypatch) + # Patch 
validation on the tables router (not the assets router) + monkeypatch.setattr( + "cave_catalog.routers.tables.run_validation_pipeline", + AsyncMock(return_value=_passing_report()), + ) + resp = await client.post( + "/api/v1/tables/register", json=_table_payload(**overrides) + ) + assert resp.status_code == 201, resp.text + return resp.json() + + +async def test_list_assets_includes_table_fields(client, monkeypatch): + """GET /assets/ should include table-specific fields for table assets.""" + table = await _register_table_via_tables_api(client, monkeypatch) + + response = await client.get("/api/v1/assets/?datastack=minnie65_public") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + item = data[0] + + # Table-specific fields should be present + assert item["id"] == table["id"] + assert item["asset_type"] == "table" + assert item["source"] is not None + assert item["cached_metadata"] is not None + assert item["cached_metadata"]["n_rows"] == 100 + assert "columns" in item + assert len(item["columns"]) == 2 + + +async def test_get_asset_by_id_returns_table_fields(client, monkeypatch): + """GET /assets/{id} should return TableResponse with merged columns for table assets.""" + table = await _register_table_via_tables_api(client, monkeypatch) + + response = await client.get(f"/api/v1/assets/{table['id']}") + assert response.status_code == 200 + data = response.json() + + assert data["asset_type"] == "table" + assert data["cached_metadata"]["n_columns"] == 2 + assert len(data["columns"]) == 2 + assert data["columns"][0]["name"] == "a" + assert data["columns"][0]["dtype"] == "int64" + + +async def test_list_assets_mixed_types(client, monkeypatch): + """GET /assets/ should return both table and non-table assets with correct fields.""" + # Register a plain asset + await _register(client, monkeypatch, name="plain_asset", asset_type="asset") + + # Register a table + await _register_table_via_tables_api( + client, monkeypatch, 
name="table_asset" + ) + + response = await client.get("/api/v1/assets/?datastack=minnie65_public") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + by_name = {item["name"]: item for item in data} + + # Plain asset should NOT have table fields + plain = by_name["plain_asset"] + assert "columns" not in plain or plain.get("columns") is None + + # Table asset should HAVE table fields + table = by_name["table_asset"] + assert table["asset_type"] == "table" + assert "columns" in table + assert len(table["columns"]) == 2 + assert table["cached_metadata"] is not None + + +async def test_register_non_table_asset_still_works(client, monkeypatch): + """POST /assets/register for non-table assets should still work without table fields.""" + _patch_validation(monkeypatch) + payload = _asset_payload(asset_type="asset", name="generic_file") + resp = await client.post("/api/v1/assets/register", json=payload) + assert resp.status_code == 201 + data = resp.json() + assert data["asset_type"] == "asset" + assert data["name"] == "generic_file" + # No table-specific fields in response + assert "columns" not in data or data.get("columns") is None diff --git a/tests/test_extractors.py b/tests/test_extractors.py new file mode 100644 index 0000000..f0f71da --- /dev/null +++ b/tests/test_extractors.py @@ -0,0 +1,149 @@ +"""Tests for metadata extractors: Delta and Parquet. + +Phase 2 tests — covers tasks 2.1–2.4. +Uses real fixture data on disk (same pattern as test_validation.py). 
+""" + +from __future__ import annotations + +import polars as pl +import pytest +from cave_catalog.extractors import ( + DeltaMetadataExtractor, + ParquetMetadataExtractor, + get_extractor, +) +from cave_catalog.table_schemas import TableMetadata + +# --------------------------------------------------------------------------- +# Fixtures (reuse the same pattern as test_validation.py) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def parquet_file(tmp_path) -> str: + path = tmp_path / "data.parquet" + pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}).write_parquet(str(path)) + return str(path) + + +@pytest.fixture +def delta_table(tmp_path) -> str: + path = tmp_path / "delta_table" + pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}).write_delta(str(path)) + return str(path) + + +@pytest.fixture +def partitioned_delta(tmp_path) -> str: + path = tmp_path / "partitioned_delta" + pl.DataFrame( + {"part": ["x", "x", "y"], "val": [1, 2, 3]} + ).write_delta(str(path), delta_write_options={"partition_by": ["part"]}) + return str(path) + + +# --------------------------------------------------------------------------- +# Delta extractor +# --------------------------------------------------------------------------- + + +async def test_delta_extract_columns(delta_table): + ext = DeltaMetadataExtractor() + result = await ext.extract(delta_table) + + assert isinstance(result, TableMetadata) + col_names = [c.name for c in result.columns] + assert "a" in col_names + assert "b" in col_names + + +async def test_delta_extract_counts(delta_table): + ext = DeltaMetadataExtractor() + result = await ext.extract(delta_table) + + assert result.n_columns == 2 + assert result.n_rows == 3 + assert result.n_bytes is not None and result.n_bytes > 0 + + +async def test_delta_extract_partition_columns(partitioned_delta): + ext = DeltaMetadataExtractor() + result = await ext.extract(partitioned_delta) + + assert result.partition_columns 
== ["part"] + + +async def test_delta_extract_no_partitions(delta_table): + ext = DeltaMetadataExtractor() + result = await ext.extract(delta_table) + + assert result.partition_columns == [] + + +async def test_delta_extract_invalid_uri(tmp_path): + ext = DeltaMetadataExtractor() + with pytest.raises(Exception): + await ext.extract(str(tmp_path / "nonexistent")) + + +# --------------------------------------------------------------------------- +# Parquet extractor +# --------------------------------------------------------------------------- + + +async def test_parquet_extract_columns(parquet_file): + ext = ParquetMetadataExtractor() + result = await ext.extract(parquet_file) + + assert isinstance(result, TableMetadata) + col_names = [c.name for c in result.columns] + assert "a" in col_names + assert "b" in col_names + + +async def test_parquet_extract_counts(parquet_file): + ext = ParquetMetadataExtractor() + result = await ext.extract(parquet_file) + + assert result.n_columns == 2 + assert result.n_rows == 3 + assert result.n_bytes is not None and result.n_bytes > 0 + + +async def test_parquet_extract_no_partition_columns(parquet_file): + ext = ParquetMetadataExtractor() + result = await ext.extract(parquet_file) + + assert result.partition_columns == [] + + +async def test_parquet_extract_invalid_uri(tmp_path): + ext = ParquetMetadataExtractor() + with pytest.raises(Exception): + await ext.extract(str(tmp_path / "nonexistent.parquet")) + + +# --------------------------------------------------------------------------- +# Extractor registry +# --------------------------------------------------------------------------- + + +def test_get_extractor_delta(): + ext = get_extractor("delta") + assert isinstance(ext, DeltaMetadataExtractor) + + +def test_get_extractor_parquet(): + ext = get_extractor("parquet") + assert isinstance(ext, ParquetMetadataExtractor) + + +def test_get_extractor_case_insensitive(): + ext = get_extractor("Delta") + assert isinstance(ext, 
DeltaMetadataExtractor) + + +def test_get_extractor_unknown_raises(): + with pytest.raises(ValueError, match="No metadata extractor"): + get_extractor("lance") diff --git a/tests/test_link_validation.py b/tests/test_link_validation.py new file mode 100644 index 0000000..9298920 --- /dev/null +++ b/tests/test_link_validation.py @@ -0,0 +1,270 @@ +"""Tests for column link validation against the materialization service. + +Phase 3 tests — covers task 3.1 (column link validator). +Uses httpx mocking to simulate ME API responses. +""" + +from __future__ import annotations + +import httpx +import pytest +from cave_catalog.validation import ( + LinkValidationError, + LinkValidationResult, + validate_column_links, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _annotations_with_links(*links: tuple[str, str, str, str]) -> list[dict]: + """Build annotations list from (col_name, link_type, target_table, target_col) tuples.""" + by_col: dict[str, list[dict]] = {} + for col_name, link_type, target_table, target_col in links: + by_col.setdefault(col_name, []).append( + { + "link_type": link_type, + "target_table": target_table, + "target_column": target_col, + } + ) + return [ + {"column_name": col, "links": col_links} + for col, col_links in by_col.items() + ] + + +def _mock_transport(status_code: int = 200, json_body: list | None = None): + """Return an httpx transport that returns a canned response.""" + + async def handler(request: httpx.Request) -> httpx.Response: + body = json_body if json_body is not None else [] + return httpx.Response(status_code, json=body) + + return httpx.MockTransport(handler) + + +# --------------------------------------------------------------------------- +# No links → passes trivially +# --------------------------------------------------------------------------- + + +async def test_no_links_passes(): + 
annotations = [{"column_name": "a", "description": "col a"}] + async with httpx.AsyncClient() as client: + result = await validate_column_links(annotations, "ds1", client) + assert result.passed is True + assert result.errors == [] + + +async def test_empty_annotations_passes(): + async with httpx.AsyncClient() as client: + result = await validate_column_links([], "ds1", client) + assert result.passed is True + + +# --------------------------------------------------------------------------- +# ME not configured → skipped +# --------------------------------------------------------------------------- + + +async def test_skipped_when_me_not_configured(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("col_a", "foreign_key", "synapses", "id"), + ) + async with httpx.AsyncClient() as client: + result = await validate_column_links(annotations, "ds1", client) + + assert result.passed is True + assert result.skipped is True + assert "not configured" in result.message + + get_settings.cache_clear() + + +# --------------------------------------------------------------------------- +# Valid links → passes +# --------------------------------------------------------------------------- + + +async def test_valid_table_link_passes(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("pre_pt_root_id", "foreign_key", "synapses", "pre_pt_root_id"), + ("post_pt_root_id", "foreign_key", "synapses", "post_pt_root_id"), + ) + transport = _mock_transport(200, ["synapses", "nucleus_detection_v0"]) + async with httpx.AsyncClient(transport=transport) as client: + result = await validate_column_links(annotations, "minnie65", client) + + assert result.passed is True + assert result.errors == [] + + get_settings.cache_clear() 
+ + +async def test_multiple_tables_all_valid(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("syn_id", "foreign_key", "synapses", "id"), + ("cell_id", "foreign_key", "nucleus_detection_v0", "id"), + ) + transport = _mock_transport(200, ["synapses", "nucleus_detection_v0"]) + async with httpx.AsyncClient(transport=transport) as client: + result = await validate_column_links(annotations, "minnie65", client) + + assert result.passed is True + + get_settings.cache_clear() + + +# --------------------------------------------------------------------------- +# Invalid links → fails +# --------------------------------------------------------------------------- + + +async def test_invalid_table_fails(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("col_a", "foreign_key", "nonexistent_table", "id"), + ) + transport = _mock_transport(200, ["synapses"]) + async with httpx.AsyncClient(transport=transport) as client: + result = await validate_column_links(annotations, "minnie65", client) + + assert result.passed is False + assert len(result.errors) == 1 + err = result.errors[0] + assert err.target_table == "nonexistent_table" + assert err.column_name == "col_a" + assert "not found" in err.reason + + get_settings.cache_clear() + + +async def test_mixed_valid_and_invalid(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("good_col", "foreign_key", "synapses", "id"), + ("bad_col", "foreign_key", "fake_table", "id"), + ) + transport = _mock_transport(200, ["synapses"]) + async with httpx.AsyncClient(transport=transport) as client: + result = await 
validate_column_links(annotations, "minnie65", client) + + assert result.passed is False + assert len(result.errors) == 1 + assert result.errors[0].target_table == "fake_table" + + get_settings.cache_clear() + + +async def test_multiple_links_to_same_bad_table(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("col_a", "foreign_key", "bad_table", "id"), + ("col_b", "derived_from", "bad_table", "value"), + ) + transport = _mock_transport(200, ["synapses"]) + async with httpx.AsyncClient(transport=transport) as client: + result = await validate_column_links(annotations, "minnie65", client) + + assert result.passed is False + assert len(result.errors) == 2 + + get_settings.cache_clear() + + +# --------------------------------------------------------------------------- +# ME errors → graceful skip +# --------------------------------------------------------------------------- + + +async def test_me_auth_failure_skips(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("col_a", "foreign_key", "synapses", "id"), + ) + transport = _mock_transport(403) + async with httpx.AsyncClient(transport=transport) as client: + result = await validate_column_links(annotations, "minnie65", client) + + assert result.passed is True + assert result.skipped is True + assert "auth failed" in result.message + + get_settings.cache_clear() + + +async def test_me_server_error_skips(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("col_a", "foreign_key", "synapses", "id"), + ) + transport = _mock_transport(500) + async with httpx.AsyncClient(transport=transport) as 
client: + result = await validate_column_links(annotations, "minnie65", client) + + assert result.passed is True + assert result.skipped is True + + get_settings.cache_clear() + + +async def test_me_connection_error_skips(monkeypatch): + monkeypatch.setenv("MAT_ENGINE_URL", "http://me:5000") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + annotations = _annotations_with_links( + ("col_a", "foreign_key", "synapses", "id"), + ) + + async def _raise(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("connection refused") + + transport = httpx.MockTransport(_raise) + async with httpx.AsyncClient(transport=transport) as client: + result = await validate_column_links(annotations, "minnie65", client) + + assert result.passed is True + assert result.skipped is True + + get_settings.cache_clear() diff --git a/tests/test_mat_proxy.py b/tests/test_mat_proxy.py new file mode 100644 index 0000000..b2a8be8 --- /dev/null +++ b/tests/test_mat_proxy.py @@ -0,0 +1,164 @@ +"""Tests for mat_proxy module — caching and CAVEclient integration.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from cave_catalog.mat_proxy import ( + MatProxyError, + _columns_cache, + _tables_cache, + _views_cache, + get_linkable_targets, + get_mat_tables, + get_mat_views, + get_target_columns, +) + + +@pytest.fixture(autouse=True) +def clear_caches(): + """Clear all caches before each test.""" + _tables_cache.clear() + _views_cache.clear() + _columns_cache.clear() + yield + _tables_cache.clear() + _views_cache.clear() + _columns_cache.clear() + + +@pytest.fixture +def mock_settings(): + """Provide settings with CAVE_TOKEN configured.""" + with patch("cave_catalog.mat_proxy.get_settings") as mock: + settings = MagicMock() + settings.cave_token = "test-token" + settings.caveclient_server_address = "https://test-server.com" + mock.return_value = settings + yield settings + + +@pytest.fixture +def mock_caveclient(): + """Mock CAVEclient 
constructor.""" + with patch("cave_catalog.mat_proxy.CAVEclient") as mock_cls: + client = MagicMock() + mock_cls.return_value = client + yield client + + +class TestGetMatTables: + async def test_fetches_tables(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_tables.return_value = [ + "synapses", + "nucleus_detection", + ] + result = await get_mat_tables("minnie65_phase3") + assert result == ["synapses", "nucleus_detection"] + + async def test_cache_hit(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_tables.return_value = ["synapses"] + await get_mat_tables("minnie65_phase3", version=1) + await get_mat_tables("minnie65_phase3", version=1) + # Should only call CAVEclient once due to cache + assert mock_caveclient.materialize.get_tables.call_count == 1 + + async def test_different_version_is_cache_miss( + self, mock_settings, mock_caveclient + ): + mock_caveclient.materialize.get_tables.return_value = ["synapses"] + await get_mat_tables("minnie65_phase3", version=1) + await get_mat_tables("minnie65_phase3", version=2) + assert mock_caveclient.materialize.get_tables.call_count == 2 + + async def test_error_without_cave_token(self): + with patch("cave_catalog.mat_proxy.get_settings") as mock: + settings = MagicMock() + settings.cave_token = None + mock.return_value = settings + with pytest.raises(MatProxyError, match="CAVE_TOKEN is not configured"): + await get_mat_tables("minnie65_phase3") + + async def test_wraps_unexpected_error(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_tables.side_effect = RuntimeError("timeout") + with pytest.raises(MatProxyError, match="Failed to fetch tables"): + await get_mat_tables("minnie65_phase3") + + +class TestGetMatViews: + async def test_fetches_views(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_views.return_value = [ + "synapse_with_nucleus", + "cell_type_view", + ] + result = await get_mat_views("minnie65_phase3") + assert 
result == ["synapse_with_nucleus", "cell_type_view"] + + async def test_cache_hit(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_views.return_value = ["view1"] + await get_mat_views("minnie65_phase3") + await get_mat_views("minnie65_phase3") + assert mock_caveclient.materialize.get_views.call_count == 1 + + +class TestGetLinkableTargets: + async def test_combines_tables_and_views(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_tables.return_value = ["b_table", "a_table"] + mock_caveclient.materialize.get_views.return_value = ["c_view"] + targets = await get_linkable_targets("minnie65_phase3") + # Should be sorted by name + names = [t.name for t in targets] + assert names == ["a_table", "b_table", "c_view"] + types = [t.target_type for t in targets] + assert types == ["table", "table", "view"] + + +class TestGetTargetColumns: + async def test_table_columns_via_schema(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_table_metadata.return_value = { + "schema_type": "synapse" + } + mock_caveclient.schema.schema_definition.return_value = { + "$schema": "http://json-schema.org/draft-07/schema#", + "definitions": { + "BoundSpatialPoint": {"type": "object"}, + "SynapseSchema": { + "type": "object", + "properties": { + "pre_pt": {"$ref": "#/definitions/BoundSpatialPoint"}, + "post_pt": {"$ref": "#/definitions/BoundSpatialPoint"}, + }, + }, + }, + "$ref": "#/definitions/SynapseSchema", + } + result = await get_target_columns( + "minnie65_phase3", "synapses", "table" + ) + assert len(result) == 2 + assert result[0]["name"] == "pre_pt" + + async def test_view_columns_direct(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_view_schema.return_value = { + "id": "integer", + "cell_type": "string", + } + result = await get_target_columns( + "minnie65_phase3", "cell_type_view", "view" + ) + assert len(result) == 2 + col_names = [c["name"] for c in result] + assert "id" in col_names + 
assert "cell_type" in col_names + + async def test_cache_hit(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_view_schema.return_value = {"id": "integer"} + await get_target_columns("minnie65_phase3", "v1", "view") + await get_target_columns("minnie65_phase3", "v1", "view") + assert mock_caveclient.materialize.get_view_schema.call_count == 1 + + async def test_error_on_missing_schema_type(self, mock_settings, mock_caveclient): + mock_caveclient.materialize.get_table_metadata.return_value = {} + with pytest.raises(MatProxyError, match="Could not determine schema type"): + await get_target_columns("minnie65_phase3", "bad_table", "table") diff --git a/tests/test_name_check.py b/tests/test_name_check.py new file mode 100644 index 0000000..738232a --- /dev/null +++ b/tests/test_name_check.py @@ -0,0 +1,173 @@ +"""Tests for the name availability check endpoint and UI fragment.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from cave_catalog.schemas import ValidationCheck + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _asset_payload(**overrides): + base = { + "datastack": "minnie65_public", + "name": "my_table", + "mat_version": 943, + "revision": 0, + "uri": "gs://bucket/data/", + "format": "delta", + "asset_type": "table", + "is_managed": True, + "mutability": "static", + "maturity": "stable", + "properties": {}, + } + base.update(overrides) + return base + + +def _passing_report(): + from cave_catalog.schemas import ValidationReport + + return ValidationReport( + auth_check=ValidationCheck(passed=True), + duplicate_check=ValidationCheck(passed=True), + name_reservation_check=ValidationCheck(passed=True), + uri_reachable=ValidationCheck(passed=True), + format_sniff=ValidationCheck(passed=True), + ) + + +# 
--------------------------------------------------------------------------- +# API endpoint tests: GET /api/v1/assets/check-name +# --------------------------------------------------------------------------- + + +class TestCheckNameAPI: + async def test_name_available(self, client, monkeypatch): + """Name that is not reserved and has no duplicate returns available.""" + monkeypatch.setattr( + "cave_catalog.routers.assets._check_name_reservation", + AsyncMock(return_value=ValidationCheck(passed=True)), + ) + resp = await client.get( + "/api/v1/assets/check-name", + params={"datastack": "minnie65_public", "name": "new_table"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["available"] is True + + async def test_name_reserved(self, client, monkeypatch): + """Name that matches a mat table returns reserved.""" + monkeypatch.setattr( + "cave_catalog.routers.assets._check_name_reservation", + AsyncMock( + return_value=ValidationCheck(passed=False, message="reserved") + ), + ) + resp = await client.get( + "/api/v1/assets/check-name", + params={"datastack": "minnie65_public", "name": "synapses"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["available"] is False + assert data["reason"] == "reserved" + + async def test_name_duplicate(self, client, monkeypatch): + """Name that already exists as an asset returns duplicate.""" + # Bypass validation for register + monkeypatch.setattr( + "cave_catalog.routers.assets.run_validation_pipeline", + AsyncMock(return_value=_passing_report()), + ) + # Register an asset first + await client.post( + "/api/v1/assets/register", + json=_asset_payload(name="taken_name"), + ) + + # Now check name — reservation passes but duplicate exists + monkeypatch.setattr( + "cave_catalog.routers.assets._check_name_reservation", + AsyncMock(return_value=ValidationCheck(passed=True)), + ) + resp = await client.get( + "/api/v1/assets/check-name", + params={ + "datastack": "minnie65_public", + "name": 
"taken_name", + "mat_version": 943, + "revision": 0, + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["available"] is False + assert data["reason"] == "duplicate" + assert "existing_id" in data + + +# --------------------------------------------------------------------------- +# UI fragment tests: GET /ui/fragments/check-name +# --------------------------------------------------------------------------- + + +class TestCheckNameFragment: + async def test_empty_name_returns_empty(self, client): + """Empty name returns empty response.""" + resp = await client.get( + "/ui/fragments/check-name", + params={"name": ""}, + ) + assert resp.status_code == 200 + assert resp.text == "" + + async def test_available_name_shows_check(self, client, monkeypatch): + """Available name returns ✓ fragment.""" + monkeypatch.setenv("DATASTACKS", "minnie65_public") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + monkeypatch.setattr( + "cave_catalog.routers.ui.check_name_reservation", + AsyncMock(return_value=ValidationCheck(passed=True)), + ) + resp = await client.get( + "/ui/fragments/check-name", + params={"name": "new_table"}, + cookies={"cave_catalog_datastack": "minnie65_public"}, + ) + assert resp.status_code == 200 + assert "✓" in resp.text or "✓" in resp.text + assert "Available" in resp.text + + async def test_reserved_name_shows_x(self, client, monkeypatch): + """Reserved name returns ✗ fragment.""" + monkeypatch.setenv("DATASTACKS", "minnie65_public") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + monkeypatch.setattr( + "cave_catalog.routers.ui.check_name_reservation", + AsyncMock( + return_value=ValidationCheck(passed=False, message="reserved") + ), + ) + resp = await client.get( + "/ui/fragments/check-name", + params={"name": "synapses"}, + cookies={"cave_catalog_datastack": "minnie65_public"}, + ) + assert resp.status_code == 200 + assert "✗" in resp.text or "✗" in resp.text + 
assert "reserved" in resp.text.lower() diff --git a/tests/test_table_model.py b/tests/test_table_model.py new file mode 100644 index 0000000..e45fd11 --- /dev/null +++ b/tests/test_table_model.py @@ -0,0 +1,373 @@ +"""Tests for table data model: single table inheritance and Pydantic schemas. + +Phase 1 tests — covers tasks 1.1–1.5. +""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +import pytest_asyncio +from cave_catalog.db.models import Asset, Base, Table +from cave_catalog.table_schemas import ( + AnnotationUpdateRequest, + ColumnAnnotation, + ColumnInfo, + ColumnLink, + MergedColumn, + TableMetadata, + TablePreviewRequest, + TablePreviewResponse, + TableRequest, + TableResponse, +) +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine + + +# --------------------------------------------------------------------------- +# DB fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def db_session(tmp_path): + """Async SQLAlchemy session backed by a per-test file SQLite DB.""" + db_path = tmp_path / "test.db" + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}", echo=False) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + factory = async_sessionmaker(engine, expire_on_commit=False) + async with factory() as session: + yield session + + await engine.dispose() + + +# --------------------------------------------------------------------------- +# 1.1 / 1.2 Single table inheritance +# --------------------------------------------------------------------------- + + +async def test_create_table_asset(db_session): + """Table assets are persisted via the Table model with asset_type='table'.""" + table = Table( + datastack="minnie65", + name="synapses_v943", + uri="gs://bucket/synapses/", + format="delta", + asset_type="table", + owner=1, + is_managed=True, + source="user", + ) + 
db_session.add(table) + await db_session.commit() + + result = await db_session.get(Table, table.id) + assert result is not None + assert result.asset_type == "table" + assert result.source == "user" + assert result.name == "synapses_v943" + + +async def test_table_loads_as_table_via_base_query(db_session): + """Querying Asset should return Table instances for asset_type='table'.""" + from sqlalchemy import select + + table = Table( + datastack="minnie65", + name="synapses", + uri="gs://bucket/synapses/", + format="delta", + asset_type="table", + owner=1, + is_managed=True, + source="managed", + ) + db_session.add(table) + await db_session.commit() + + stmt = select(Asset).where(Asset.name == "synapses") + result = (await db_session.execute(stmt)).scalar_one() + assert isinstance(result, Table) + assert result.source == "managed" + + +async def test_base_asset_still_works(db_session): + """Non-table assets use the base Asset class and have null table-specific fields.""" + asset = Asset( + datastack="minnie65", + name="image_volume", + uri="gs://bucket/images/", + format="precomputed", + asset_type="asset", + owner=1, + is_managed=False, + ) + db_session.add(asset) + await db_session.commit() + + result = await db_session.get(Asset, asset.id) + assert result is not None + assert result.asset_type == "asset" + assert result.source is None + assert result.cached_metadata is None + assert result.column_annotations is None + + +async def test_unknown_asset_type_loads_as_base(db_session): + """Unknown asset_type values load as Asset (fallback polymorphic map).""" + asset = Asset( + datastack="minnie65", + name="custom_thing", + uri="gs://bucket/custom/", + asset_type="unknown_type", + owner=1, + is_managed=False, + ) + db_session.add(asset) + await db_session.commit() + + result = await db_session.get(Asset, asset.id) + assert result is not None + assert type(result) is Asset + assert result.asset_type == "unknown_type" + + +async def 
test_table_with_cached_metadata(db_session): + """Table-specific JSONB fields round-trip through the DB.""" + now = datetime.now(timezone.utc) + meta = {"n_rows": 1000, "n_columns": 5, "n_bytes": 50000, "columns": [], "partition_columns": []} + annotations = [{"column_name": "pt_root_id", "description": "Root ID", "links": []}] + + table = Table( + datastack="minnie65", + name="cells", + uri="gs://bucket/cells/", + format="delta", + asset_type="table", + owner=1, + is_managed=True, + source="user", + cached_metadata=meta, + metadata_cached_at=now, + column_annotations=annotations, + ) + db_session.add(table) + await db_session.commit() + + result = await db_session.get(Table, table.id) + assert result.cached_metadata["n_rows"] == 1000 + assert result.metadata_cached_at is not None + assert result.column_annotations[0]["column_name"] == "pt_root_id" + + +async def test_mixed_asset_types_in_same_query(db_session): + """Base and Table assets coexist and can be queried together.""" + from sqlalchemy import select + + base = Asset( + datastack="minnie65", + name="image_vol", + uri="gs://bucket/images/", + asset_type="asset", + owner=1, + is_managed=False, + ) + table = Table( + datastack="minnie65", + name="synapses", + uri="gs://bucket/synapses/", + format="delta", + asset_type="table", + owner=1, + is_managed=True, + source="user", + ) + db_session.add_all([base, table]) + await db_session.commit() + + stmt = select(Asset).where(Asset.datastack == "minnie65") + results = (await db_session.execute(stmt)).scalars().all() + assert len(results) == 2 + types = {type(r) for r in results} + assert types == {Asset, Table} + + +# --------------------------------------------------------------------------- +# 1.3 TableMetadata Pydantic model +# --------------------------------------------------------------------------- + + +def test_table_metadata_defaults(): + meta = TableMetadata() + assert meta.n_rows is None + assert meta.n_columns is None + assert meta.n_bytes is None + 
assert meta.columns == [] + assert meta.partition_columns == [] + + +def test_table_metadata_full(): + meta = TableMetadata( + n_rows=100, + n_columns=3, + n_bytes=5000, + columns=[ColumnInfo(name="a", dtype="int64"), ColumnInfo(name="b", dtype="string")], + partition_columns=["a"], + ) + assert meta.n_rows == 100 + assert len(meta.columns) == 2 + assert meta.partition_columns == ["a"] + + +def test_table_metadata_roundtrip_json(): + meta = TableMetadata( + n_rows=10, + n_columns=2, + n_bytes=1000, + columns=[ColumnInfo(name="x", dtype="float64")], + partition_columns=["x"], + ) + data = meta.model_dump() + restored = TableMetadata.model_validate(data) + assert restored == meta + + +# --------------------------------------------------------------------------- +# 1.4 ColumnAnnotation / ColumnLink +# --------------------------------------------------------------------------- + + +def test_column_annotation_minimal(): + ann = ColumnAnnotation(column_name="pt_root_id") + assert ann.description is None + assert ann.links == [] + + +def test_column_annotation_with_links(): + ann = ColumnAnnotation( + column_name="pre_pt_root_id", + description="Pre-synaptic root ID", + links=[ + ColumnLink( + link_type="foreign_key", + target_table="synapses", + target_column="pre_pt_root_id", + ) + ], + ) + assert len(ann.links) == 1 + assert ann.links[0].link_type == "foreign_key" + + +# --------------------------------------------------------------------------- +# 1.5 Request / Response models +# --------------------------------------------------------------------------- + + +def test_table_request_defaults(): + req = TableRequest( + datastack="minnie65", + name="synapses", + uri="gs://bucket/synapses/", + format="delta", + owner=1, + is_managed=True, + ) + assert req.asset_type == "table" + assert req.source == "user" + assert req.column_annotations == [] + + +def test_table_request_inherits_base_fields(): + req = TableRequest( + datastack="minnie65", + name="synapses", + 
uri="gs://bucket/synapses/", + format="delta", + mat_version=943, + owner=1, + is_managed=True, + mutability="static", + maturity="stable", + ) + assert req.mat_version == 943 + assert req.mutability == "static" + + +def test_table_response_from_orm_dict(): + resp = TableResponse.model_validate( + { + "id": "00000000-0000-0000-0000-000000000001", + "datastack": "minnie65", + "name": "synapses", + "mat_version": 943, + "revision": 0, + "uri": "gs://bucket/synapses/", + "format": "delta", + "asset_type": "table", + "owner": 1, + "is_managed": True, + "mutability": "static", + "maturity": "stable", + "properties": {}, + "access_group": None, + "created_at": "2026-01-01T00:00:00Z", + "expires_at": None, + "source": "user", + "cached_metadata": { + "n_rows": 100, + "n_columns": 2, + "n_bytes": 5000, + "columns": [{"name": "a", "dtype": "int64"}], + "partition_columns": [], + }, + } + ) + assert resp.source == "user" + assert resp.cached_metadata.n_rows == 100 + + +def test_table_preview_request(): + req = TablePreviewRequest( + uri="gs://bucket/synapses/", + format="delta", + datastack="minnie65", + ) + assert req.format == "delta" + + +def test_table_preview_response(): + resp = TablePreviewResponse( + metadata=TableMetadata(n_rows=10, n_columns=2, n_bytes=500, columns=[]), + ) + assert resp.metadata.n_rows == 10 + + +def test_annotation_update_request(): + req = AnnotationUpdateRequest( + column_annotations=[ + ColumnAnnotation(column_name="a", description="col a"), + ] + ) + assert len(req.column_annotations) == 1 + + +def test_merged_column(): + col = MergedColumn( + name="pre_pt_root_id", + dtype="int64", + description="Pre-synaptic root", + links=[ + ColumnLink( + link_type="foreign_key", + target_table="synapses", + target_column="pre_pt_root_id", + ) + ], + ) + assert col.description == "Pre-synaptic root" + assert len(col.links) == 1 diff --git a/tests/test_tables.py b/tests/test_tables.py new file mode 100644 index 0000000..1221927 --- /dev/null +++ 
b/tests/test_tables.py @@ -0,0 +1,534 @@ +"""Tests for table endpoints and column merging. + +Phase 4 tests — covers tasks 4.1–4.8. +""" + +from __future__ import annotations + +import uuid +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest +from cave_catalog.table_schemas import ( + ColumnAnnotation, + ColumnInfo, + ColumnLink, + MergedColumn, + TableMetadata, + merge_columns, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_DELTA_METADATA = TableMetadata( + n_rows=100, + n_columns=3, + n_bytes=5000, + columns=[ + ColumnInfo(name="a", dtype="int64"), + ColumnInfo(name="b", dtype="string"), + ColumnInfo(name="c", dtype="float64"), + ], + partition_columns=[], +) + + +def _table_payload(**overrides: Any) -> dict: + base = { + "datastack": "minnie65_public", + "name": "my_table", + "revision": 0, + "uri": "gs://bucket/tables/my_table/", + "format": "delta", + "is_managed": True, + "mutability": "static", + "maturity": "stable", + "properties": {}, + } + base.update(overrides) + return base + + +def _preview_payload(**overrides: Any) -> dict: + base = { + "uri": "gs://bucket/tables/my_table/", + "format": "delta", + "datastack": "minnie65_public", + } + base.update(overrides) + return base + + +def _patch_extraction(monkeypatch, metadata: TableMetadata | None = None): + """Patch the extractor registry so extract() returns canned metadata.""" + meta = metadata or _DELTA_METADATA + mock_extractor = AsyncMock() + mock_extractor.extract = AsyncMock(return_value=meta) + monkeypatch.setattr( + "cave_catalog.routers.tables.get_extractor", + lambda fmt: mock_extractor, + ) + + +def _patch_validation(monkeypatch): + """Patch the validation pipeline to always pass.""" + from cave_catalog.schemas import ValidationCheck, ValidationReport + + ok = ValidationCheck(passed=True) + report = ValidationReport( + 
name_reservation_check=ok, uri_reachable=ok, format_sniff=ok + ) + monkeypatch.setattr( + "cave_catalog.routers.tables.run_validation_pipeline", + AsyncMock(return_value=report), + ) + + +def _patch_link_validation(monkeypatch, passed: bool = True, errors=None): + """Patch column link validation.""" + from cave_catalog.validation import LinkValidationResult + + result = LinkValidationResult(passed=passed, errors=errors or []) + monkeypatch.setattr( + "cave_catalog.routers.tables.validate_column_links", + AsyncMock(return_value=result), + ) + + +async def _register_table(client, monkeypatch, **overrides) -> dict: + """Register a table and return the response JSON.""" + _patch_validation(monkeypatch) + _patch_extraction(monkeypatch) + _patch_link_validation(monkeypatch) + resp = await client.post( + "/api/v1/tables/register", json=_table_payload(**overrides) + ) + assert resp.status_code == 201, resp.text + return resp.json() + + +# --------------------------------------------------------------------------- +# 4.7 Column merging +# --------------------------------------------------------------------------- + + +def test_merge_no_metadata(): + result = merge_columns(None, []) + assert result == [] + + +def test_merge_no_annotations(): + result = merge_columns(_DELTA_METADATA, []) + assert len(result) == 3 + assert all(isinstance(c, MergedColumn) for c in result) + assert all(c.description is None for c in result) + assert all(c.links == [] for c in result) + + +def test_merge_with_annotations(): + anns = [ + ColumnAnnotation(column_name="a", description="Column A"), + ColumnAnnotation( + column_name="b", + description="Column B", + links=[ColumnLink(link_type="fk", target_table="t", target_column="c")], + ), + ] + result = merge_columns(_DELTA_METADATA, anns) + by_name = {c.name: c for c in result} + + assert by_name["a"].description == "Column A" + assert by_name["b"].links[0].link_type == "fk" + assert by_name["c"].description is None # unannotated + + +def 
test_merge_orphaned_annotation_dropped(): + """Annotations for columns not in metadata are silently dropped.""" + anns = [ColumnAnnotation(column_name="nonexistent", description="Orphan")] + result = merge_columns(_DELTA_METADATA, anns) + names = [c.name for c in result] + assert "nonexistent" not in names + assert len(result) == 3 + + +# --------------------------------------------------------------------------- +# 4.2 Preview endpoint +# --------------------------------------------------------------------------- + + +async def test_preview_success(client, monkeypatch): + _patch_extraction(monkeypatch) + + resp = await client.post("/api/v1/tables/preview", json=_preview_payload()) + assert resp.status_code == 200 + + data = resp.json() + assert data["metadata"]["n_rows"] == 100 + assert len(data["metadata"]["columns"]) == 3 + + +async def test_preview_unsupported_format(client, monkeypatch): + monkeypatch.setattr( + "cave_catalog.routers.tables.get_extractor", + lambda fmt: (_ for _ in ()).throw(ValueError(f"No extractor for '{fmt}'")), + ) + resp = await client.post( + "/api/v1/tables/preview", json=_preview_payload(format="lance") + ) + assert resp.status_code == 422 + + +async def test_preview_extraction_failure(client, monkeypatch): + mock_ext = AsyncMock() + mock_ext.extract = AsyncMock(side_effect=Exception("read failed")) + monkeypatch.setattr( + "cave_catalog.routers.tables.get_extractor", lambda fmt: mock_ext + ) + + resp = await client.post("/api/v1/tables/preview", json=_preview_payload()) + assert resp.status_code == 422 + assert "extraction failed" in resp.json()["detail"].lower() + + +# --------------------------------------------------------------------------- +# 4.3 Registration endpoint +# --------------------------------------------------------------------------- + + +async def test_register_table_success(client, monkeypatch): + data = await _register_table(client, monkeypatch) + + assert data["name"] == "my_table" + assert data["asset_type"] == 
"table" + assert data["source"] == "user" + assert data["cached_metadata"]["n_rows"] == 100 + assert len(data["columns"]) == 3 # merged columns + + +async def test_register_table_with_annotations(client, monkeypatch): + annotations = [ + {"column_name": "a", "description": "Col A", "links": []}, + ] + data = await _register_table( + client, monkeypatch, column_annotations=annotations + ) + + assert data["column_annotations"][0]["column_name"] == "a" + # Merged column should have the annotation + by_name = {c["name"]: c for c in data["columns"]} + assert by_name["a"]["description"] == "Col A" + + +async def test_register_table_duplicate(client, monkeypatch): + await _register_table(client, monkeypatch) + + # Second registration with same key + _patch_validation(monkeypatch) + _patch_extraction(monkeypatch) + _patch_link_validation(monkeypatch) + resp = await client.post("/api/v1/tables/register", json=_table_payload()) + assert resp.status_code == 409 + + +async def test_register_table_validation_failure(client, monkeypatch): + from cave_catalog.schemas import ValidationCheck, ValidationReport + + report = ValidationReport( + name_reservation_check=ValidationCheck(passed=True), + uri_reachable=ValidationCheck(passed=False, message="not found"), + format_sniff=ValidationCheck(passed=True), + ) + monkeypatch.setattr( + "cave_catalog.routers.tables.run_validation_pipeline", + AsyncMock(return_value=report), + ) + _patch_extraction(monkeypatch) + + resp = await client.post("/api/v1/tables/register", json=_table_payload()) + assert resp.status_code == 422 + + +async def test_register_table_link_validation_failure(client, monkeypatch): + from cave_catalog.validation import LinkValidationError + + _patch_validation(monkeypatch) + _patch_extraction(monkeypatch) + _patch_link_validation( + monkeypatch, + passed=False, + errors=[ + LinkValidationError( + column_name="a", + link_type="fk", + target_table="bad_table", + target_column="id", + reason="not found", + ) + ], + ) + + 
payload = _table_payload( + column_annotations=[ + { + "column_name": "a", + "links": [ + { + "link_type": "fk", + "target_table": "bad_table", + "target_column": "id", + } + ], + } + ] + ) + resp = await client.post("/api/v1/tables/register", json=payload) + assert resp.status_code == 422 + assert "link validation" in resp.json()["detail"]["message"].lower() + + +# --------------------------------------------------------------------------- +# 4.4 Annotation update endpoint +# --------------------------------------------------------------------------- + + +async def test_update_annotations_success(client, monkeypatch): + table = await _register_table(client, monkeypatch) + table_id = table["id"] + + _patch_link_validation(monkeypatch) + resp = await client.patch( + f"/api/v1/tables/{table_id}/annotations", + json={ + "column_annotations": [ + {"column_name": "a", "description": "Updated A", "links": []}, + ] + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["column_annotations"][0]["description"] == "Updated A" + + +async def test_update_annotations_clear(client, monkeypatch): + table = await _register_table( + client, + monkeypatch, + column_annotations=[{"column_name": "a", "description": "Old", "links": []}], + ) + table_id = table["id"] + + _patch_link_validation(monkeypatch) + resp = await client.patch( + f"/api/v1/tables/{table_id}/annotations", + json={"column_annotations": []}, + ) + assert resp.status_code == 200 + assert resp.json()["column_annotations"] == [] + + +async def test_update_annotations_on_non_table_returns_400(client, monkeypatch): + """Patching annotations on a base asset should return 400.""" + # Register via the base asset endpoint + from cave_catalog.schemas import ValidationCheck, ValidationReport + + ok = ValidationCheck(passed=True) + monkeypatch.setattr( + "cave_catalog.routers.assets.run_validation_pipeline", + AsyncMock( + return_value=ValidationReport( + name_reservation_check=ok, uri_reachable=ok, 
format_sniff=ok + ) + ), + ) + resp = await client.post( + "/api/v1/assets/register", + json={ + "datastack": "minnie65_public", + "name": "image_vol", + "revision": 0, + "uri": "gs://bucket/images/", + "format": "precomputed", + "asset_type": "asset", + "is_managed": False, + "mutability": "static", + "maturity": "stable", + "properties": {}, + }, + ) + assert resp.status_code == 201 + asset_id = resp.json()["id"] + + resp2 = await client.patch( + f"/api/v1/tables/{asset_id}/annotations", + json={"column_annotations": []}, + ) + assert resp2.status_code == 400 + + +# --------------------------------------------------------------------------- +# 4.5 Metadata refresh endpoint +# --------------------------------------------------------------------------- + + +async def test_refresh_metadata_success(client, monkeypatch): + table = await _register_table(client, monkeypatch) + table_id = table["id"] + + # Refresh with updated metadata + new_meta = TableMetadata( + n_rows=200, + n_columns=3, + n_bytes=10000, + columns=_DELTA_METADATA.columns, + partition_columns=[], + ) + _patch_extraction(monkeypatch, metadata=new_meta) + + resp = await client.post(f"/api/v1/tables/{table_id}/refresh") + assert resp.status_code == 200 + data = resp.json() + assert data["cached_metadata"]["n_rows"] == 200 + + +async def test_refresh_preserves_annotations(client, monkeypatch): + table = await _register_table( + client, + monkeypatch, + column_annotations=[{"column_name": "a", "description": "Keep me", "links": []}], + ) + table_id = table["id"] + + _patch_extraction(monkeypatch) + resp = await client.post(f"/api/v1/tables/{table_id}/refresh") + assert resp.status_code == 200 + data = resp.json() + assert data["column_annotations"][0]["description"] == "Keep me" + + +async def test_refresh_non_table_returns_400(client, monkeypatch): + from cave_catalog.schemas import ValidationCheck, ValidationReport + + ok = ValidationCheck(passed=True) + monkeypatch.setattr( + 
"cave_catalog.routers.assets.run_validation_pipeline", + AsyncMock( + return_value=ValidationReport( + name_reservation_check=ok, uri_reachable=ok, format_sniff=ok + ) + ), + ) + resp = await client.post( + "/api/v1/assets/register", + json={ + "datastack": "minnie65_public", + "name": "image_vol2", + "revision": 0, + "uri": "gs://bucket/images/", + "format": "precomputed", + "asset_type": "asset", + "is_managed": False, + "mutability": "static", + "maturity": "stable", + "properties": {}, + }, + ) + asset_id = resp.json()["id"] + + resp2 = await client.post(f"/api/v1/tables/{asset_id}/refresh") + assert resp2.status_code == 400 + + +# --------------------------------------------------------------------------- +# 4.6 List tables endpoint +# --------------------------------------------------------------------------- + + +async def test_list_tables_empty(client): + resp = await client.get("/api/v1/tables/?datastack=minnie65_public") + assert resp.status_code == 200 + assert resp.json() == [] + + +async def test_list_tables_returns_tables(client, monkeypatch): + await _register_table(client, monkeypatch) + + resp = await client.get("/api/v1/tables/?datastack=minnie65_public") + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["asset_type"] == "table" + assert "columns" in data[0] # merged columns present + + +async def test_list_tables_filters_by_format(client, monkeypatch): + await _register_table(client, monkeypatch, name="delta_table", format="delta") + + resp = await client.get( + "/api/v1/tables/?datastack=minnie65_public&format=parquet" + ) + assert resp.json() == [] + + resp2 = await client.get( + "/api/v1/tables/?datastack=minnie65_public&format=delta" + ) + assert len(resp2.json()) == 1 + + +async def test_list_tables_filters_by_source(client, monkeypatch): + await _register_table(client, monkeypatch, name="t1", source="user") + await _register_table(client, monkeypatch, name="t2", source="materialization") + + resp 
= await client.get( + "/api/v1/tables/?datastack=minnie65_public&source=materialization" + ) + data = resp.json() + assert len(data) == 1 + assert data[0]["source"] == "materialization" + + +async def test_list_tables_excludes_non_tables(client, monkeypatch): + """Base assets should not appear in the tables list.""" + # Register a table + await _register_table(client, monkeypatch, name="real_table") + + # Register a base asset via the assets endpoint + from cave_catalog.schemas import ValidationCheck, ValidationReport + + ok = ValidationCheck(passed=True) + monkeypatch.setattr( + "cave_catalog.routers.assets.run_validation_pipeline", + AsyncMock( + return_value=ValidationReport( + name_reservation_check=ok, uri_reachable=ok, format_sniff=ok + ) + ), + ) + await client.post( + "/api/v1/assets/register", + json={ + "datastack": "minnie65_public", + "name": "image_vol", + "revision": 0, + "uri": "gs://bucket/images/", + "format": "precomputed", + "asset_type": "asset", + "is_managed": False, + "mutability": "static", + "maturity": "stable", + "properties": {}, + }, + ) + + resp = await client.get("/api/v1/tables/?datastack=minnie65_public") + data = resp.json() + assert len(data) == 1 + assert data[0]["name"] == "real_table" + + +async def test_list_tables_requires_datastack(client): + resp = await client.get("/api/v1/tables/") + assert resp.status_code == 422 diff --git a/tests/test_ui_auth.py b/tests/test_ui_auth.py new file mode 100644 index 0000000..a448a2e --- /dev/null +++ b/tests/test_ui_auth.py @@ -0,0 +1,91 @@ +"""Tests for the UI auth flow: login redirect, callback, logout, and auth guard.""" + +import pytest +from cave_catalog.auth.middleware import TOKEN_COOKIE_NAME + + +@pytest.fixture +def auth_client(client): + """Alias for the base client fixture.""" + return client + + +@pytest.fixture +def auth_enabled_env(monkeypatch): + """Enable auth for these tests.""" + monkeypatch.setenv("AUTH_ENABLED", "true") + from cave_catalog.config import get_settings + + 
get_settings.cache_clear() + yield + get_settings.cache_clear() + + +class TestAuthGuard: + """Unauthenticated users should be redirected to /ui/login.""" + + async def test_register_redirects_when_unauthenticated( + self, client, auth_enabled_env + ): + # Re-create client with auth enabled + response = await client.get("/ui/register", follow_redirects=False) + assert response.status_code == 302 + assert "/ui/login" in response.headers["location"] + + async def test_explore_redirects_when_unauthenticated( + self, client, auth_enabled_env + ): + response = await client.get("/ui/explore", follow_redirects=False) + assert response.status_code == 302 + assert "/ui/login" in response.headers["location"] + + async def test_pages_accessible_when_auth_disabled(self, client): + """With AUTH_ENABLED=false (default), pages are accessible.""" + response = await client.get("/ui/register") + assert response.status_code == 200 + assert "Register" in response.text + + +class TestLogin: + """GET /ui/login should redirect to middle_auth authorize URL.""" + + async def test_login_redirects_to_authorize(self, client): + response = await client.get("/ui/login", follow_redirects=False) + assert response.status_code == 307 or response.status_code == 302 + location = response.headers["location"] + assert "/sticky_auth/api/v1/authorize" in location + assert "redirect=" in location + + +class TestCallback: + """GET /ui/callback should set cookie and redirect.""" + + async def test_callback_sets_cookie(self, client): + response = await client.get( + f"/ui/callback?{TOKEN_COOKIE_NAME}=test-token-123&next=/ui/register", + follow_redirects=False, + ) + assert response.status_code == 302 + assert response.headers["location"] == "/ui/register" + # Cookie should be set in the response + set_cookie = response.headers.get("set-cookie", "") + assert TOKEN_COOKIE_NAME in set_cookie + assert "test-token-123" in set_cookie + + async def test_callback_without_token_redirects_to_login(self, client): + 
response = await client.get("/ui/callback", follow_redirects=False) + assert response.status_code == 307 or response.status_code == 302 + location = response.headers["location"] + assert "/ui/login" in location + + +class TestLogout: + """GET /ui/logout should clear cookie and redirect to login.""" + + async def test_logout_clears_cookie(self, client): + response = await client.get("/ui/logout", follow_redirects=False) + assert response.status_code == 302 + assert response.headers["location"] == "/ui/login" + set_cookie = response.headers.get("set-cookie", "") + # Cookie should be deleted (max-age=0 or expires in past) + assert TOKEN_COOKIE_NAME in set_cookie diff --git a/tests/test_ui_preview.py b/tests/test_ui_preview.py new file mode 100644 index 0000000..ef5bdee --- /dev/null +++ b/tests/test_ui_preview.py @@ -0,0 +1,120 @@ +"""Tests for the table preview UI route handler.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest + +from cave_catalog.table_schemas import ColumnInfo, TableMetadata + + +class TestPreviewRoute: + """Tests for POST /ui/preview.""" + + async def test_empty_uri_returns_error(self, client): + resp = await client.post("/ui/preview", data={"uri": "", "format": "delta"}) + assert resp.status_code == 200 + assert "Please enter a URI" in resp.text + + async def test_unsupported_format(self, client): + resp = await client.post( + "/ui/preview", data={"uri": "gs://bucket/path", "format": "csv"} + ) + assert resp.status_code == 200 + assert "Unsupported format" in resp.text + assert "csv" in resp.text + + async def test_uri_not_found(self, client, monkeypatch): + """URI that doesn't exist shows diagnostic error.""" + mock_extractor = AsyncMock() + mock_extractor.extract.side_effect = FileNotFoundError( + "No such file or directory: gs://bucket/missing" + ) + monkeypatch.setattr( + "cave_catalog.routers.ui.get_extractor", lambda fmt: mock_extractor + ) + resp = await client.post( + "/ui/preview", 
data={"uri": "gs://bucket/missing", "format": "delta"} + ) + assert resp.status_code == 200 + assert "URI unreachable" in resp.text + assert "does not exist" in resp.text + + async def test_permission_denied(self, client, monkeypatch): + """Permission error shows diagnostic.""" + mock_extractor = AsyncMock() + mock_extractor.extract.side_effect = PermissionError( + "403 Forbidden: Access denied" + ) + monkeypatch.setattr( + "cave_catalog.routers.ui.get_extractor", lambda fmt: mock_extractor + ) + resp = await client.post( + "/ui/preview", data={"uri": "gs://secret/data", "format": "delta"} + ) + assert resp.status_code == 200 + assert "URI unreachable" in resp.text or "permission" in resp.text.lower() + + async def test_extraction_failure(self, client, monkeypatch): + """Generic extraction error shows format-specific diagnostic.""" + mock_extractor = AsyncMock() + mock_extractor.extract.side_effect = RuntimeError("Corrupt delta log") + monkeypatch.setattr( + "cave_catalog.routers.ui.get_extractor", lambda fmt: mock_extractor + ) + resp = await client.post( + "/ui/preview", data={"uri": "gs://bucket/bad", "format": "delta"} + ) + assert resp.status_code == 200 + assert "Failed to read delta data" in resp.text + + async def test_successful_preview(self, client, monkeypatch): + """Successful preview returns metadata fragment.""" + metadata = TableMetadata( + n_rows=1000, + n_columns=3, + n_bytes=1048576, + columns=[ + ColumnInfo(name="id", dtype="int64"), + ColumnInfo(name="pt_position", dtype="list"), + ColumnInfo(name="label", dtype="string"), + ], + partition_columns=[], + ) + mock_extractor = AsyncMock() + mock_extractor.extract.return_value = metadata + monkeypatch.setattr( + "cave_catalog.routers.ui.get_extractor", lambda fmt: mock_extractor + ) + resp = await client.post( + "/ui/preview", data={"uri": "gs://bucket/table", "format": "delta"} + ) + assert resp.status_code == 200 + assert "DELTA" in resp.text + assert "1,000" in resp.text # formatted row count + 
assert "id" in resp.text + assert "pt_position" in resp.text + assert "registration-fields" in resp.text # JS to show next step + + async def test_successful_preview_parquet(self, client, monkeypatch): + """Parquet format shows correctly.""" + metadata = TableMetadata( + n_rows=500, + n_columns=2, + columns=[ + ColumnInfo(name="col_a", dtype="float64"), + ColumnInfo(name="col_b", dtype="string"), + ], + ) + mock_extractor = AsyncMock() + mock_extractor.extract.return_value = metadata + monkeypatch.setattr( + "cave_catalog.routers.ui.get_extractor", lambda fmt: mock_extractor + ) + resp = await client.post( + "/ui/preview", data={"uri": "gs://bucket/file.parquet", "format": "parquet"} + ) + assert resp.status_code == 200 + assert "PARQUET" in resp.text + assert "col_a" in resp.text diff --git a/tests/test_ui_register.py b/tests/test_ui_register.py new file mode 100644 index 0000000..c2a3a67 --- /dev/null +++ b/tests/test_ui_register.py @@ -0,0 +1,144 @@ +"""Tests for the registration page and submit flow.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from cave_catalog.schemas import ValidationCheck, ValidationReport +from cave_catalog.table_schemas import ColumnInfo, TableMetadata + + +def _passing_report(): + return ValidationReport( + auth_check=ValidationCheck(passed=True), + duplicate_check=ValidationCheck(passed=True), + name_reservation_check=ValidationCheck(passed=True), + uri_reachable=ValidationCheck(passed=True), + format_sniff=ValidationCheck(passed=True), + ) + + +def _mock_metadata(): + return TableMetadata( + n_rows=100, + n_columns=2, + columns=[ + ColumnInfo(name="id", dtype="int64"), + ColumnInfo(name="value", dtype="float64"), + ], + ) + + +class TestRegisterPageRenders: + async def test_register_page_renders(self, client): + """Register page loads with form elements.""" + resp = await client.get("/ui/register") + assert resp.status_code == 200 + assert "Register a Table" in resp.text + assert 
"Preview" in resp.text + assert "uri" in resp.text + + async def test_preview_returns_column_table(self, client, monkeypatch): + """Successful preview includes annotation table with description fields.""" + mock_extractor = AsyncMock() + mock_extractor.extract.return_value = _mock_metadata() + monkeypatch.setattr( + "cave_catalog.routers.ui.get_extractor", lambda fmt: mock_extractor + ) + resp = await client.post( + "/ui/preview", data={"uri": "gs://bucket/table", "format": "delta"} + ) + assert resp.status_code == 200 + # Should contain annotation inputs + assert "col_name_0" in resp.text + assert "col_desc_0" in resp.text + assert "Add Link" in resp.text + + +class TestRegisterSubmit: + async def test_successful_registration(self, client, monkeypatch): + """Full registration flow: preview + submit → success.""" + monkeypatch.setenv("DATASTACKS", "minnie65_public") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + # Mock validation and extraction + monkeypatch.setattr( + "cave_catalog.routers.tables.run_validation_pipeline", + AsyncMock(return_value=_passing_report()), + ) + mock_extractor = AsyncMock() + mock_extractor.extract.return_value = _mock_metadata() + monkeypatch.setattr( + "cave_catalog.routers.tables.get_extractor", lambda fmt: mock_extractor + ) + + resp = await client.post( + "/ui/register/submit", + data={ + "uri": "gs://bucket/table", + "format": "delta", + "name": "test_table", + "mat_version": "", + "n_columns": "2", + "col_name_0": "id", + "col_dtype_0": "int64", + "col_desc_0": "Primary key", + "col_name_1": "value", + "col_dtype_1": "float64", + "col_desc_1": "", + }, + cookies={"cave_catalog_datastack": "minnie65_public"}, + ) + assert resp.status_code == 200 + assert "successfully" in resp.text.lower() or "register_success" in resp.url.path if hasattr(resp, 'url') else True + assert "test_table" in resp.text + + async def test_missing_fields(self, client, monkeypatch): + """Submit without required fields 
returns error.""" + monkeypatch.setenv("DATASTACKS", "minnie65_public") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + resp = await client.post( + "/ui/register/submit", + data={"uri": "", "format": "delta", "name": ""}, + cookies={"cave_catalog_datastack": "minnie65_public"}, + ) + assert resp.status_code == 200 + assert "required" in resp.text.lower() + + async def test_validation_failure_shows_error(self, client, monkeypatch): + """Validation failure renders error fragment with details.""" + monkeypatch.setenv("DATASTACKS", "minnie65_public") + from cave_catalog.config import get_settings + + get_settings.cache_clear() + + # Make validation fail + failing_report = _passing_report() + failing_report.uri_reachable = ValidationCheck( + passed=False, message="URI not reachable" + ) + monkeypatch.setattr( + "cave_catalog.routers.tables.run_validation_pipeline", + AsyncMock(return_value=failing_report), + ) + + resp = await client.post( + "/ui/register/submit", + data={ + "uri": "gs://bad/path", + "format": "delta", + "name": "fail_table", + "n_columns": "0", + }, + cookies={"cave_catalog_datastack": "minnie65_public"}, + ) + assert resp.status_code == 200 + # Should show an error message + assert "failed" in resp.text.lower() or "error" in resp.text.lower()