From 4a41b5dd8e7526b1a5482b359975c5d8addb3816 Mon Sep 17 00:00:00 2001 From: Tianhao-Gu Date: Thu, 12 Mar 2026 15:37:52 -0500 Subject: [PATCH 1/5] Revert "Merge pull request #132 from BERDataLakehouse/feature/revert_polaris" This reverts commit c1508253be46eb6f7ce23dd0500b33fa7d325c8e, reversing changes made to 7ce659e0ebfa668fa39cd51ec7af236dca77bacf. --- Dockerfile | 2 +- configs/ipython_startup/00-notebookutils.py | 6 + configs/ipython_startup/01-credentials.py | 51 ++ .../ipython_startup/01-minio-credentials.py | 29 - configs/jupyter_server_config.py | 93 +++- configs/spark-defaults.conf.template | 14 +- docker-compose.yaml | 116 +++- docs/data_sharing_guide.md | 8 +- docs/iceberg_migration_guide.md | 426 +++++++++++++++ docs/tenant_sql_warehouse_guide.md | 176 +++--- docs/user_guide.md | 11 +- .../berdl_notebook_utils/__init__.py | 3 + .../berdl_notebook_utils/agent/tools.py | 6 +- .../berdl_notebook_utils/berdl_settings.py | 14 +- .../berdl_notebook_utils/mcp/operations.py | 33 +- .../minio_governance/__init__.py | 10 + .../minio_governance/operations.py | 303 ++++++++-- .../berdl_notebook_utils/refresh.py | 116 ++++ .../setup_spark_session.py | 94 +++- .../berdl_notebook_utils/spark/__init__.py | 2 +- .../spark/connect_server.py | 28 +- .../berdl_notebook_utils/spark/data_store.py | 5 +- .../berdl_notebook_utils/spark/database.py | 58 +- notebook_utils/pyproject.toml | 4 +- notebook_utils/tests/agent/test_mcp_tools.py | 257 ++++++++- notebook_utils/tests/agent/test_prompts.py | 91 +++ notebook_utils/tests/mcp/test_operations.py | 13 - .../tests/minio_governance/test_operations.py | 517 ++++++++++++++++++ .../tests/spark/test_connect_server.py | 250 ++++++++- notebook_utils/tests/spark/test_data_store.py | 12 +- notebook_utils/tests/spark/test_database.py | 62 ++- notebook_utils/tests/test_cache.py | 116 ++++ .../tests/test_get_spark_session.py | 204 +++++++ notebook_utils/tests/test_refresh.py | 216 ++++++++ notebook_utils/uv.lock | 114 ++-- 
scripts/init-polaris-db.sh | 13 + scripts/migrate_delta_to_iceberg.py | 379 +++++++++++++ 37 files changed, 3474 insertions(+), 378 deletions(-) create mode 100644 configs/ipython_startup/01-credentials.py delete mode 100644 configs/ipython_startup/01-minio-credentials.py create mode 100644 docs/iceberg_migration_guide.md create mode 100644 notebook_utils/berdl_notebook_utils/refresh.py create mode 100644 notebook_utils/tests/agent/test_prompts.py create mode 100644 notebook_utils/tests/test_cache.py create mode 100644 notebook_utils/tests/test_refresh.py create mode 100755 scripts/init-polaris-db.sh create mode 100644 scripts/migrate_delta_to_iceberg.py diff --git a/Dockerfile b/Dockerfile index 519951a..d03f263 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_TAG=pr-111 +ARG BASE_TAG=pr-115 ARG BASE_REGISTRY=ghcr.io/berdatalakehouse/ FROM ${BASE_REGISTRY}spark_notebook_base:${BASE_TAG} diff --git a/configs/ipython_startup/00-notebookutils.py b/configs/ipython_startup/00-notebookutils.py index 665a1a7..295dd29 100644 --- a/configs/ipython_startup/00-notebookutils.py +++ b/configs/ipython_startup/00-notebookutils.py @@ -102,6 +102,7 @@ create_tenant_and_assign_users, get_group_sql_warehouse, get_minio_credentials, + get_polaris_credentials, get_my_accessible_paths, get_my_groups, get_my_policies, @@ -117,6 +118,11 @@ unshare_table, ) +# ============================================================================ +# Environment Refresh +# ============================================================================ +from berdl_notebook_utils.refresh import refresh_spark_environment # noqa: F401 + # ============================================================================ # Help Utilities # ============================================================================ diff --git a/configs/ipython_startup/01-credentials.py b/configs/ipython_startup/01-credentials.py new file mode 100644 index 0000000..2ce303f --- /dev/null +++ 
b/configs/ipython_startup/01-credentials.py @@ -0,0 +1,51 @@ +""" +Initialize MinIO and Polaris credentials. + +This runs after 00-notebookutils.py loads all the imports, so get_minio_credentials +is already available in the global namespace. +""" + +# Setup logging +import logging + +logger = logging.getLogger("berdl.startup") +logger.setLevel(logging.INFO) +if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + +# --- MinIO Credentials --- +try: + # Set MinIO credentials to environment - also creates user if they don't exist + credentials = get_minio_credentials() # noqa: F821 + logger.info(f"✅ MinIO credentials set for user: {credentials.username}") + +except Exception as e: + import warnings + + warnings.warn(f"Failed to set MinIO credentials: {str(e)}", UserWarning) + logger.error(f"❌ Failed to set MinIO credentials: {str(e)}") + credentials = None + +# --- Polaris Credentials --- +try: + polaris_creds = get_polaris_credentials() # noqa: F821 + if polaris_creds: + logger.info(f"✅ Polaris credentials set for catalog: {polaris_creds['personal_catalog']}") + if polaris_creds["tenant_catalogs"]: + logger.info(f" Tenant catalogs: {', '.join(polaris_creds['tenant_catalogs'])}") + # Clear the settings cache so downstream code (e.g., Spark Connect server startup) + # picks up the POLARIS_CREDENTIAL, POLARIS_PERSONAL_CATALOG, and + # POLARIS_TENANT_CATALOGS env vars that get_polaris_credentials() just set. 
+ get_settings.cache_clear() # noqa: F821 + else: + logger.info("ℹ️ Polaris not configured, skipping Polaris credential setup") + +except Exception as e: + import warnings + + warnings.warn(f"Failed to set Polaris credentials: {str(e)}", UserWarning) + logger.warning(f"⚠️ Failed to set Polaris credentials: {str(e)}") + polaris_creds = None diff --git a/configs/ipython_startup/01-minio-credentials.py b/configs/ipython_startup/01-minio-credentials.py deleted file mode 100644 index 379cad7..0000000 --- a/configs/ipython_startup/01-minio-credentials.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -Initialize MinIO credentials and basic MinIO client. - -This runs after 00-notebookutils.py loads all the imports, so get_minio_credentials -is already available in the global namespace. -""" - -# Setup logging -import logging - -logger = logging.getLogger("berdl.startup") -logger.setLevel(logging.INFO) -if not logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter("%(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) - -try: - # Set MinIO credentials to environment - also creates user if they don't exist - credentials = get_minio_credentials() # noqa: F821 - logger.info(f"✅ MinIO credentials set for user: {credentials.username}") - -except Exception as e: - import warnings - - warnings.warn(f"Failed to set MinIO credentials: {str(e)}", UserWarning) - logger.error(f"❌ Failed to set MinIO credentials: {str(e)}") - credentials = None diff --git a/configs/jupyter_server_config.py b/configs/jupyter_server_config.py index 378f2f9..93fc9dd 100644 --- a/configs/jupyter_server_config.py +++ b/configs/jupyter_server_config.py @@ -1,13 +1,12 @@ import os import sys import logging -import json -from pathlib import Path # Add config directory to path for local imports # Note: __file__ may not be defined when exec'd by traitlets config loader sys.path.insert(0, "/etc/jupyter") +from berdl_notebook_utils.berdl_settings import get_settings from 
hybridcontents import HybridContentsManager from jupyter_server.services.contents.largefilemanager import LargeFileManager from grouped_s3_contents import GroupedS3ContentsManager @@ -48,31 +47,20 @@ def get_minio_config(): - """Extract MinIO configuration from credentials file or environment.""" + """Extract MinIO configuration, provisioning credentials via governance API if needed.""" + from berdl_notebook_utils.minio_governance import get_minio_credentials + + # Provision user + fetch credentials (checks cache first, calls API if needed, + # sets MINIO_ACCESS_KEY/MINIO_SECRET_KEY env vars) + credentials = get_minio_credentials() + access_key = credentials.access_key + secret_key = credentials.secret_key - # Default values endpoint = os.environ.get("MINIO_ENDPOINT_URL") - access_key = os.environ.get("MINIO_ACCESS_KEY") - secret_key = os.environ.get("MINIO_SECRET_KEY") use_ssl = os.environ.get("MINIO_SECURE", "false").lower() == "true" - # Try reading from credential file - try: - username = os.environ.get("NB_USER", "jovyan") - cred_path = Path(f"/home/{username}/.berdl_minio_credentials") - if cred_path.exists(): - data = json.loads(cred_path.read_text()) - access_key = data.get("access_key") or access_key - secret_key = data.get("secret_key") or secret_key - logger.info(f"Loaded MinIO credentials from {cred_path} for user: {data.get('username', 'unknown')}") - except Exception as e: - logger.warning(f"Failed to read credential file: {e}") - - # Validate required config - if not endpoint or not access_key or not secret_key: - configs = [("MINIO_ENDPOINT_URL", endpoint), ("MINIO_ACCESS_KEY", access_key), ("MINIO_SECRET_KEY", secret_key)] - missing = [k for k, v in configs if not v] - raise ValueError(f"Missing required MinIO configuration: {missing}") + if not endpoint: + raise ValueError("MINIO_ENDPOINT_URL is required") if not endpoint.startswith(("http://", "https://")): protocol = "https://" if use_ssl else "http://" @@ -147,17 +135,72 @@ def 
get_user_governance_paths(): return sources +def provision_polaris(): + """Provision Polaris credentials at server startup and set env vars. + + Called once at Jupyter Server startup so credentials are available + before any notebook kernel opens. Subsequent calls from IPython startup + scripts will hit the file cache and return immediately. + """ + try: + from berdl_notebook_utils.minio_governance import get_polaris_credentials + + polaris_creds = get_polaris_credentials() + if polaris_creds: + logger.info(f"\u2705 Polaris credentials provisioned for catalog: {polaris_creds['personal_catalog']}") + if polaris_creds["tenant_catalogs"]: + logger.info(f" Tenant catalogs: {', '.join(polaris_creds['tenant_catalogs'])}") + else: + logger.info("\u2139\ufe0f Polaris not configured, skipping Polaris credential provisioning") + except Exception as e: + logger.error(f"Failed to provision Polaris credentials: {e}") + + +def start_spark_connect(): + """Start Spark Connect server at Jupyter Server startup. + + Runs in a background thread so it doesn't block the server from accepting + connections. Idempotent: reuses existing process if already running. + """ + import threading + + def _start(): + try: + from berdl_notebook_utils.spark.connect_server import start_spark_connect_server + + server_info = start_spark_connect_server() + logger.info(f"\u2705 Spark Connect server ready at {server_info['url']}") + except Exception as e: + logger.error(f"\u274c Failed to start Spark Connect server: {e}") + + t = threading.Thread(target=_start, name="spark-connect-startup", daemon=True) + t.start() + + # --- Main Configuration Logic --- # 1. Local Manager (Root) # We map the root directory to the user's home username = os.environ.get("NB_USER", "jovyan") -# 2. Get MinIO configuration +# 2. Get MinIO configuration (also provisions/caches the user in MinIO) endpoint_url, access_key, secret_key, use_ssl = get_minio_config() governance_paths = get_user_governance_paths() -# 3. 
Configure HybridContentsManager +# 3. Provision Polaris credentials — MUST be before Spark Connect so that +# POLARIS_CREDENTIAL env vars are set when generating spark-defaults.conf +provision_polaris() + +# Clear the settings cache so start_spark_connect picks up the new +# POLARIS_CREDENTIAL/POLARIS_PERSONAL_CATALOG/POLARIS_TENANT_CATALOGS env vars +# that provision_polaris() just set. Without this, the lru_cache returns the +# stale settings object captured before Polaris provisioning ran. +get_settings.cache_clear() + +# 4. Start Spark Connect server in background (non-blocking) +start_spark_connect() + +# 5. Configure HybridContentsManager # - Root ("") -> Local filesystem # - "datalake_minio" -> GroupedS3ContentsManager with all S3 paths as subdirectories c.HybridContentsManager.manager_classes = { diff --git a/configs/spark-defaults.conf.template b/configs/spark-defaults.conf.template index 6801757..8d86261 100644 --- a/configs/spark-defaults.conf.template +++ b/configs/spark-defaults.conf.template @@ -43,13 +43,14 @@ # ============================================================================== # ------------------------------------------------------------------------------ -# Delta Lake Configuration (STATIC - Server-Side Only) +# SQL Extensions and Catalog Configuration (STATIC - Server-Side Only) # ------------------------------------------------------------------------------ # These SQL extensions must be loaded when the Spark server starts. -# They initialize Delta Lake support by registering custom SparkSessionExtensions -# and catalog implementations that handle Delta table operations. +# Delta Lake + Iceberg + Sedona run side-by-side. +# The default catalog remains spark_catalog (Delta/Hive) for backward compatibility. +# Iceberg catalogs (my, tenant aliases) are added dynamically by connect_server.py. 
-spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension,org.apache.sedona.sql.SedonaSqlExtensions +spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension,org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,org.apache.sedona.sql.SedonaSqlExtensions spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog # Delta Lake settings @@ -90,6 +91,9 @@ spark.hadoop.fs.s3a.path.style.access=true spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem spark.hadoop.fs.s3a.connection.ssl.enabled=false +# Polaris Iceberg catalog configuration will be appended dynamically by connect_server.py +# based on POLARIS_* environment variables (personal catalog "my" + tenant catalogs) + # ------------------------------------------------------------------------------ # KBase Authentication Interceptor (STATIC - Server-Side Only) # ------------------------------------------------------------------------------ @@ -116,7 +120,7 @@ spark.hadoop.fs.s3a.connection.ssl.enabled=false # Environment variables used by the namespace interceptor: # - BERDL_ALLOWED_NAMESPACE_PREFIXES: Comma-separated allowed prefixes # (e.g., "u_tgu2__,kbase_,research_"). Set dynamically by connect_server.py. 
-spark.connect.grpc.interceptor.classes=us.kbase.spark.KBaseAuthServerInterceptor,us.kbase.spark.NamespaceValidationInterceptor +spark.connect.grpc.interceptor.classes=us.kbase.spark.KBaseAuthServerInterceptor # ------------------------------------------------------------------------------ # Session Timeout (STATIC - Server-Side Only) diff --git a/docker-compose.yaml b/docker-compose.yaml index f2cb87e..fad1678 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -68,7 +68,9 @@ # - MinIO: minio/minio123 # - PostgreSQL: hive/hivepassword services: - spark-notebook: + # Service names use the pattern: spark-notebook-{CI_KBASE_USERNAME} + # Update these keys if you change the usernames in .env + spark-notebook-tgu2: # image: ghcr.io/berdatalakehouse/spark_notebook:main # platform: linux/amd64 build: @@ -85,7 +87,7 @@ services: - CDM_TASK_SERVICE_URL=http://localhost:8080 - SPARK_CLUSTER_MANAGER_API_URL=http://localhost:8000 - SPARK_MASTER_URL=spark://spark-master:7077 - - BERDL_POD_IP=spark-notebook + - BERDL_POD_IP=spark-notebook-${CI_KBASE_USERNAME} - BERDL_HIVE_METASTORE_URI=thrift://hive-metastore:9083 # MINIO CONFIGURATION @@ -98,6 +100,9 @@ services: # DATALAKE MCP SERVER CONFIGURATION - DATALAKE_MCP_SERVER_URL=http://datalake-mcp-server:8000/apis/mcp + # POLARIS CONFIGURATION (per-user credentials provisioned dynamically by 01-credentials.py) + - POLARIS_CATALOG_URI=http://polaris:8181/api/catalog + # AUTHENTICATION CONFIGURATION - KBASE_AUTH_URL=https://ci.kbase.us/services/auth/ - KBASE_AUTH_TOKEN=${CI_KBASE_AUTH_TOKEN} @@ -112,6 +117,20 @@ services: - hive-metastore command: start-notebook.py --NotebookApp.token='' --NotebookApp.password='' --notebook-dir=/home/${CI_KBASE_USERNAME} + spark-connect-proxy: + build: + context: ../spark_connect_proxy + dockerfile: Dockerfile + ports: + - "15004:15002" # Proxy exposed on host port 15004 + environment: + - KBASE_AUTH_URL=https://ci.kbase.us/services/auth/ + - PROXY_LISTEN_PORT=15002 + - BACKEND_PORT=15002 
+ - SERVICE_TEMPLATE=spark-notebook-{username} + - TOKEN_CACHE_TTL=300 + - MFA_EXEMPT_USERS=${CI_KBASE_USERNAME} + minio-manager-service: # image: ghcr.io/berdatalakehouse/minio_manager_service:main # platform: linux/amd64 @@ -128,6 +147,9 @@ services: - KBASE_ADMIN_ROLES=CDM_JUPYTERHUB_ADMIN - KBASE_APPROVED_ROLES=BERDL_USER - REDIS_URL=redis://redis:6379 + # Polaris admin credentials (only the governance service needs root access) + - POLARIS_CATALOG_URI=http://polaris:8181/api/catalog + - POLARIS_CREDENTIAL=root:s3cr3t datalake-mcp-server: # image: ghcr.io/berdatalakehouse/datalake-mcp-server:main @@ -142,11 +164,11 @@ services: - "8005:8000" # MCP server port environment: # Shared MCP service - connects to user-specific Spark Connect servers - # For docker-compose local dev, we default to the primary notebook (tgu2) - # NOTE: Due to different port mappings in docker-compose (tgu2:15002, tgu1:15003), - # the MCP server will primarily work with tgu2. In Kubernetes, all users + # For docker-compose local dev, we default to the primary notebook + # NOTE: Due to different port mappings in docker-compose (user1:15002, user2:15003), + # the MCP server will primarily work with the primary user. 
In Kubernetes, all users # follow the same pattern: sc://jupyter-{username}.jupyterhub-{env}:15002 - - SPARK_CONNECT_URL_TEMPLATE=sc://spark-notebook:15002 + - SPARK_CONNECT_URL_TEMPLATE=sc://spark-notebook-${CI_KBASE_USERNAME}:15002 - BERDL_HIVE_METASTORE_URI=thrift://hive-metastore:9083 - MINIO_ENDPOINT_URL=minio:9002 - GOVERNANCE_API_URL=http://minio-manager-service:8000 @@ -158,6 +180,8 @@ services: - POSTGRES_PASSWORD=readonly_password - KBASE_REQUIRED_ROLES=BERDL_USER - MFA_EXEMPT_USERS=${CI_KBASE_USERNAME} + # POLARIS CONFIGURATION (per-user credentials provisioned dynamically) + - POLARIS_CATALOG_URI=http://polaris:8181/api/catalog volumes: # Mount the shared /home directory to access all users' credentials # This allows the MCP server to dynamically read any user's credentials @@ -235,10 +259,6 @@ services: - SPARK_WORKER_MEMORY=5g - SPARK_WORKER_PORT=8081 - SPARK_WORKER_WEBUI_PORT=8081 - - BERDL_REDIS_HOST=redis - - BERDL_REDIS_PORT=6379 - - BERDL_DELTALAKE_WAREHOUSE_DIRECTORY_PATH=s3a://cdm-lake/users-sql-warehouse - - BERDL_HIVE_METASTORE_URI=thrift://hive-metastore:9083 depends_on: - spark-master @@ -260,7 +280,14 @@ services: - POSTGRES_DB=hive volumes: - postgres_data:/var/lib/postgresql/data - - ./scripts/init-postgres-readonly.sh:/docker-entrypoint-initdb.d/init-postgres-readonly.sh:ro + - ./scripts/init-postgres-readonly.sh:/docker-entrypoint-initdb.d/01-init-readonly.sh:ro + - ./scripts/init-polaris-db.sh:/docker-entrypoint-initdb.d/02-init-polaris-db.sh:ro + healthcheck: + test: ["CMD-SHELL", "pg_isready -U hive"] + interval: 5s + timeout: 2s + retries: 15 + hive-metastore: # image: ghcr.io/berdatalakehouse/hive_metastore:main @@ -326,13 +353,78 @@ services: echo 'Creating buckets...'; mc mb --ignore-existing local/cdm-lake; + mc mb --ignore-existing local/cdm-spark-job-logs; echo 'MinIO bucket creation complete.'; " + polaris-bootstrap: + image: apache/polaris-admin-tool:latest + environment: + - POLARIS_PERSISTENCE_TYPE=relational-jdbc + - 
QUARKUS_DATASOURCE_DB_KIND=postgresql + - QUARKUS_DATASOURCE_JDBC_URL=jdbc:postgresql://postgres:5432/polaris + - QUARKUS_DATASOURCE_USERNAME=hive + - QUARKUS_DATASOURCE_PASSWORD=hivepassword + # Bootstrap exits 3 if already bootstrapped (expected with persistent storage). + # Treat exit 3 as success so docker compose doesn't fail on subsequent runs. + entrypoint: ["sh", "-c"] + command: + - | + java -jar /deployments/polaris-admin-tool.jar bootstrap --realm=POLARIS --credential=POLARIS,root,s3cr3t + rc=$$? + if [ $$rc -eq 3 ]; then + echo "Already bootstrapped — skipping (OK)" + exit 0 + fi + exit $$rc + depends_on: + postgres: + condition: service_healthy + + polaris: + image: apache/polaris:latest + ports: + - "8181:8181" + environment: + # Persistence — PostgreSQL instead of in-memory + - POLARIS_PERSISTENCE_TYPE=relational-jdbc + - QUARKUS_DATASOURCE_DB_KIND=postgresql + - QUARKUS_DATASOURCE_JDBC_URL=jdbc:postgresql://postgres:5432/polaris + - QUARKUS_DATASOURCE_USERNAME=hive + - QUARKUS_DATASOURCE_PASSWORD=hivepassword + # Realm configuration + - POLARIS_REALM_NAME=default-realm + # MinIO credentials for Polaris's own S3 access (metadata files). + # Polaris reads endpointInternal + pathStyleAccess from each catalog's storageConfigInfo. + # STS is disabled per-catalog via stsUnavailable:true (not the global SKIP_CREDENTIAL flag). 
+ - AWS_REGION=us-east-1 + - AWS_ACCESS_KEY_ID=minio + - AWS_SECRET_ACCESS_KEY=minio123 + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8182/q/health"] + interval: 10s + timeout: 5s + retries: 5 + depends_on: + polaris-bootstrap: + condition: service_completed_successfully + + polaris-ui: + image: ghcr.io/binarycat0/apache-polaris-ui:latest + ports: + - "3000:3000" + environment: + # Server-side env vars used by the Next.js API routes for proxying auth + - POLARIS_MANAGEMENT_API_URL=http://polaris:8181/api/management/v1 + - POLARIS_CATALOG_API_URL=http://polaris:8181/api/catalog/v1 + depends_on: + polaris: + condition: service_started + volumes: postgres_data: minio_data: redis_data: global_share: - users_home: # Shared volume for all user home directories \ No newline at end of file + users_home: # Shared volume for all user home directories diff --git a/docs/data_sharing_guide.md b/docs/data_sharing_guide.md index dbac753..433b5f3 100644 --- a/docs/data_sharing_guide.md +++ b/docs/data_sharing_guide.md @@ -45,11 +45,13 @@ All BERDL JupyterHub notebooks automatically import these data governance functi **Pre-Initialized Client:** - `governance` - Pre-initialized `DataGovernanceClient()` instance for advanced operations - **Other Auto-Imported Functions:** -- `get_spark_session()` - Create Spark sessions with Delta Lake support +- `get_spark_session()` - Create Spark sessions with Iceberg + Delta Lake support +- `create_namespace_if_not_exists()` - Create namespaces (use `iceberg=True` for Iceberg catalogs) - Plus many other utility functions for data operations +> **Note:** With the migration to Iceberg, **tenant catalogs** are the recommended way to share data. Create tables in a tenant catalog (e.g., `kbase`) and all members can access them. See the [Iceberg Migration Guide](iceberg_migration_guide.md) for details. 
+ ### Quick Start ```python @@ -258,7 +260,7 @@ if response.errors: ## Public and Private Table Access (DEPRECATED) -> **⚠️ DEPRECATION WARNING**: Direct public path sharing functions (`make_table_public`, `make_table_private`) are deprecated. Please create a namespace under the `globalusers` tenant for public sharing activities instead. +> **⚠️ DEPRECATION WARNING**: Direct public path sharing functions (`make_table_public`, `make_table_private`) are deprecated. Please create a namespace under the `kbase` tenant for public sharing activities instead. ### Make Tables Publicly Accessible diff --git a/docs/iceberg_migration_guide.md b/docs/iceberg_migration_guide.md new file mode 100644 index 0000000..b0d7a88 --- /dev/null +++ b/docs/iceberg_migration_guide.md @@ -0,0 +1,426 @@ +# Polaris Catalog Migration Guide + +## Why We Migrated + +KBERDL previously used **Delta Lake + Hive Metastore** with namespace isolation enforced by naming conventions — every database had to be prefixed with `u_{username}__` (personal) or `{tenant}_` (shared). This worked but had several limitations: + +- **Naming conventions are fragile** — isolation depends on every user following prefix rules correctly +- **No catalog-level boundaries** — all users share the same Hive Metastore, so a misconfigured namespace could leak data +- **Single-engine lock-in** — Delta Lake tables are only accessible through Spark with the Delta extension +- **No time travel or schema evolution** — Delta supports these, but Hive Metastore doesn't track them natively + +We migrated to **Apache Polaris + Apache Iceberg**, which provides: + +- **Catalog-level isolation** — each user gets their own Polaris catalog (`my`), and each tenant gets a shared catalog (e.g., `kbase`). No naming prefixes needed. 
+- **Multi-engine support** — Iceberg tables can be read by Spark, Trino, DuckDB, PyIceberg, and other engines +- **Native time travel** — query any previous snapshot of your data +- **Schema evolution** — add, rename, or drop columns without rewriting data +- **ACID transactions** — concurrent reads and writes are safe + +## What Changed + +| Aspect | Before (Delta/Hive) | After (Polaris/Iceberg) | +|--------|---------------------|------------------------| +| **Metadata catalog** | Hive Metastore | Apache Polaris (Iceberg REST catalog) | +| **Table format** | Delta Lake | Apache Iceberg | +| **Isolation model** | Naming prefixes (`u_user__`, `tenant_`) | Separate catalogs per user/tenant | +| **Your personal catalog** | Hive (shared, prefix-isolated) | `my` (dedicated Polaris catalog) | +| **Tenant catalogs** | Hive (shared, prefix-isolated) | One catalog per tenant (e.g., `kbase`) | +| **Credentials** | MinIO S3 keys only | MinIO S3 keys + Polaris OAuth2 (auto-provisioned) | +| **Spark session** | `get_spark_session()` | `get_spark_session()` (unchanged) | + +> **Migration complete:** All existing Delta Lake tables have been migrated to Iceberg in Polaris. Your data is available in the new Iceberg catalogs (`my` for personal, tenant name for shared). The original Delta tables remain accessible during the dual-read period for backward compatibility, but all new tables should be created in Iceberg. + +## Side-by-Side Comparison + +### Create a Namespace + +The function call is **unchanged** — only the return value format differs. + + + + + + + +
Before (Delta/Hive) | After (Iceberg/Polaris)
+ +```python +spark = get_spark_session() + +# Personal namespace +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# Returns: "u_tgu2__analysis" + +# Tenant namespace +ns = create_namespace_if_not_exists( + spark, "research", + tenant_name="kbase" +) +# Returns: "kbase_research" +``` + + + +```python +spark = get_spark_session() + +# Personal namespace (same call) +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# Returns: "my.analysis" + +# Tenant namespace (same call) +ns = create_namespace_if_not_exists( + spark, "research", + tenant_name="kbase" +) +# Returns: "kbase.research" +``` + +
+ +### Write a Table + + + + + + + +
Before (Delta/Hive) | After (Iceberg/Polaris)
+ +```python +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# ns = "u_tgu2__analysis" + +df = spark.createDataFrame(data, columns) + +# Delta format +df.write.format("delta").saveAsTable( + f"{ns}.my_table" +) +``` + + + +```python +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# ns = "my.analysis" + +df = spark.createDataFrame(data, columns) + +# Iceberg format (via writeTo API) +df.writeTo(f"{ns}.my_table").createOrReplace() + +# Or append to existing table +df.writeTo(f"{ns}.my_table").append() +``` + +
+ +### Read a Table + + + + + + + +
Before (Delta/Hive) | After (Iceberg/Polaris)
+ +```python +# Query with prefixed namespace +df = spark.sql(""" + SELECT * FROM u_tgu2__analysis.my_table +""") + +# Or use the variable +df = spark.sql( + f"SELECT * FROM {ns}.my_table" +) +``` + + + +```python +# Query with catalog.namespace +df = spark.sql(""" + SELECT * FROM my.analysis.my_table +""") + +# Or use the variable +df = spark.sql( + f"SELECT * FROM {ns}.my_table" +) +``` + +
+ +### Cross-Catalog Queries + + + + + + + +
Before (Delta/Hive) | After (Iceberg/Polaris)
+ +```python +# Everything in one Hive catalog +# Must know the naming convention +spark.sql(""" + SELECT u.name, d.dept_name + FROM u_tgu2__analysis.users u + JOIN kbase_shared.depts d + ON u.dept_id = d.id +""") +``` + + + +```python +# Explicit catalog boundaries +spark.sql(""" + SELECT u.name, d.dept_name + FROM my.analysis.users u + JOIN kbase.shared.depts d + ON u.dept_id = d.id +""") +``` + +
+ +### List Namespaces and Tables + + + + + + + +
Before (Delta/Hive) | After (Iceberg/Polaris)
+ +```python +# Lists all Hive databases +# (filtered by u_{user}__ prefix) +list_namespaces(spark) + +# List tables in a namespace +list_tables(spark, "u_tgu2__analysis") +``` + + + +```python +# List namespaces in your catalog +spark.sql("SHOW NAMESPACES IN my") + +# List tables in a namespace +spark.sql( + "SHOW TABLES IN my.analysis" +) + +# list_tables still works +list_tables(spark, "my.analysis") +``` + +
+ +### Drop a Table + + + + + + + +
Before (Delta/Hive) | After (Iceberg/Polaris)
+ +```python +spark.sql( + "DROP TABLE IF EXISTS " + "u_tgu2__analysis.my_table" +) +``` + + + +```python +spark.sql( + "DROP TABLE IF EXISTS " + "my.analysis.my_table" +) +``` + +> **Note:** `DROP TABLE` removes the catalog entry but does **not** delete the underlying S3 data files. `DROP TABLE ... PURGE` also does not delete files due to a [known Iceberg bug](https://github.com/apache/iceberg/issues/14743). To fully remove data, delete files from S3 directly using `get_minio_client()`. + +
+ +## Iceberg-Only Features + +These features are only available with Iceberg tables. + +### Time Travel + +Query a previous version of your table: + +```python +# View snapshot history +spark.sql("SELECT * FROM my.analysis.my_table.snapshots") + +# Read data as it was at a specific snapshot +spark.sql(""" + SELECT * FROM my.analysis.my_table + VERSION AS OF 1234567890 +""") + +# Read data as it was at a specific timestamp +spark.sql(""" + SELECT * FROM my.analysis.my_table + TIMESTAMP AS OF '2026-03-01 12:00:00' +""") +``` + +### Schema Evolution + +Modify table schema without rewriting data: + +```python +# Add a column +spark.sql("ALTER TABLE my.analysis.my_table ADD COLUMN email STRING") + +# Rename a column +spark.sql("ALTER TABLE my.analysis.my_table RENAME COLUMN name TO full_name") + +# Drop a column +spark.sql("ALTER TABLE my.analysis.my_table DROP COLUMN temp_field") +``` + +### Snapshot Management + +```python +# View snapshot history +display_df(spark.sql("SELECT * FROM my.analysis.my_table.snapshots")) + +# View file-level details +display_df(spark.sql("SELECT * FROM my.analysis.my_table.files")) + +# View table history +display_df(spark.sql("SELECT * FROM my.analysis.my_table.history")) +``` + +## Complete Example + +```python +# 1. Create a Spark session +print("1. Creating Spark session...") +spark = get_spark_session("MyAnalysis") +print(" Spark session ready.") + +# 2. Create a personal namespace +print("\n2. Creating personal namespace...") +ns = create_namespace_if_not_exists(spark, "demo") +print(f" Namespace: {ns}") # "my.demo" + +# 3. Create a table +print(f"\n3. Creating table {ns}.employees...") +data = [(1, "Alice", 25), (2, "Bob", 30), (3, "Charlie", 35)] +df = spark.createDataFrame(data, ["id", "name", "age"]) +df.writeTo(f"{ns}.employees").createOrReplace() +print(f" Table {ns}.employees created with {df.count()} rows.") + +# 4. Query the table +print(f"\n4. 
Querying {ns}.employees:") +result = spark.sql(f"SELECT * FROM {ns}.employees") +display_df(result) + +# 5. Append more data +print(f"\n5. Appending data to {ns}.employees...") +new_data = [(4, "Diana", 28)] +new_df = spark.createDataFrame(new_data, ["id", "name", "age"]) +new_df.writeTo(f"{ns}.employees").append() +print(f" Appended {new_df.count()} row(s). Total: {spark.sql(f'SELECT * FROM {ns}.employees').count()} rows.") +display_df(spark.sql(f"SELECT * FROM {ns}.employees")) + +# 6. View snapshots and files (two snapshots now: create + append) +print(f"\n6. Viewing snapshots and files for {ns}.employees:") +print(" Snapshots:") +display_df(spark.sql(f"SELECT * FROM {ns}.employees.snapshots")) +print(" Data files:") +display_df(spark.sql(f"SELECT * FROM {ns}.employees.files")) + +# 7. Time travel to the original version +print(f"\n7. Time travel to original version (before append)...") +first_snapshot = spark.sql( + f"SELECT snapshot_id FROM {ns}.employees.snapshots " + f"ORDER BY committed_at LIMIT 1" +).collect()[0]["snapshot_id"] +print(f" First snapshot ID: {first_snapshot}") +original = spark.sql( + f"SELECT * FROM {ns}.employees VERSION AS OF {first_snapshot}" +) +print(f" Original version ({original.count()} rows, before append):") +display_df(original) + +# 8. Tenant namespace (shared with your team) +print("\n8. Creating tenant namespace and shared table...") +tenant_ns = create_namespace_if_not_exists( + spark, "shared_data", tenant_name="globalusers" +) +print(f" Tenant namespace: {tenant_ns}") +df.writeTo(f"{tenant_ns}.team_employees").createOrReplace() +print(f" Table {tenant_ns}.team_employees created.") + +# 9. Cross-catalog query +print(f"\n9. Cross-catalog query ({ns} + {tenant_ns}):") +cross_result = spark.sql(f""" + SELECT * FROM {ns}.employees + UNION ALL + SELECT * FROM {tenant_ns}.team_employees +""") +display_df(cross_result) + +# 10. Schema evolution +print(f"\n10. 
Adding 'email' column to {ns}.employees...") +spark.sql(f"ALTER TABLE {ns}.employees ADD COLUMN email STRING") +print(f" Updated schema:") +spark.sql(f"DESCRIBE {ns}.employees").show() +display_df(spark.sql(f"SELECT * FROM {ns}.employees")) +print("Done!") +``` + +## FAQ + +**Q: Do I need to change my `get_spark_session()` call?** +No. `get_spark_session()` automatically configures both Delta and Iceberg catalogs. Your Polaris catalogs are ready to use. + +**Q: Can I still access my old Delta tables?** +Yes. During the dual-read period, your Delta tables remain accessible at their original names (e.g., `u_{username}__analysis.my_table`). Iceberg copies are at `my.analysis.my_table`. + +**Q: What happened to my namespace prefixes (`u_{username}__`)?** +They're no longer needed. With Iceberg, your personal catalog `my` is isolated at the catalog level — only you can access it. No prefix is required. + +**Q: How do I share data with my team?** +Create a table in a tenant catalog: +```python +ns = create_namespace_if_not_exists( + spark, "shared_data", tenant_name="kbase" +) +df.writeTo(f"{ns}.my_shared_table").createOrReplace() +``` +All members of the `kbase` tenant can read this table. + +**Q: Why does `DROP TABLE PURGE` leave files on S3?** +This is a [known Iceberg bug](https://github.com/apache/iceberg/issues/14743) — Spark's `SparkCatalog` ignores the `PURGE` flag when talking to REST catalogs. `DROP TABLE` only removes the catalog entry. To delete the S3 files, use the MinIO client directly. + +**Q: Can I use `df.write.format("iceberg").saveAsTable(...)` instead of `writeTo`?** +Yes, both work. `writeTo` is the recommended Iceberg API since it supports `createOrReplace()` and `append()` natively. 
diff --git a/docs/tenant_sql_warehouse_guide.md b/docs/tenant_sql_warehouse_guide.md index 6c5b4ed..6d9015d 100644 --- a/docs/tenant_sql_warehouse_guide.md +++ b/docs/tenant_sql_warehouse_guide.md @@ -6,152 +6,114 @@ This guide explains how to configure your Spark session to write tables to eithe The BERDL JupyterHub environment supports two types of SQL warehouses: -1. **User SQL Warehouse**: Your personal workspace for tables -2. **Tenant SQL Warehouse**: Shared workspace for team/organization tables +1. **User SQL Warehouse**: Your personal workspace for tables (Iceberg catalog: `my`) +2. **Tenant SQL Warehouse**: Shared workspace for team/organization tables (Iceberg catalog per tenant, e.g., `kbase`) -## Spark Connect Architecture - -BERDL uses **Spark Connect**, which provides a client-server architecture: -- **Connection Protocol**: Uses gRPC for efficient communication with remote Spark clusters -- **Spark Connect Server**: Runs locally in your notebook pod as a proxy -- **Spark Connect URL**: `sc://localhost:15002` -- **Driver and Executors**: Runs locally in your notebook pod - -### Automatic Credential Management - -Your MinIO S3 credentials are automatically configured when your notebook server starts: - -1. **JupyterHub** calls the governance API to create/retrieve your MinIO credentials -2. **Environment Variables** are set in your notebook: `MINIO_ACCESS_KEY`, `MINIO_SECRET_KEY`, `MINIO_USERNAME` -3. **Spark Cluster** receives your credentials and configures S3 access for Hive Metastore operations -4. **Notebooks** use credentials from environment (no API call needed) - -This ensures secure, per-user S3 access without exposing credentials in code. +Isolation is provided at the **catalog level** — no naming prefixes needed. 
-## Quick Comparison +| Configuration | Catalog | Example Namespace | +|--------------|---------|-------------------| +| `create_namespace_if_not_exists(spark, "analysis")` | `my` (personal) | `my.analysis` | +| `create_namespace_if_not_exists(spark, "research", tenant_name="kbase")` | `kbase` (tenant) | `kbase.research` | -| Configuration | Warehouse Location | Tables Location | Default Namespace | -|--------------|-------------------|----------------|-------------------| -| `create_namespace_if_not_exists(spark)` | `s3a://cdm-lake/users-sql-warehouse/{username}/` | Personal workspace | `u_{username}__default` | -| `create_namespace_if_not_exists(spark, tenant_name="kbase")` | `s3a://cdm-lake/tenant-sql-warehouse/kbase/` | Tenant workspace | `kbase_default` | +## Personal SQL Warehouse -## Personal SQL Warehouse (Default) +Your personal catalog (`my`) is automatically provisioned and only accessible by you. ### Usage -```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists - -# Create Spark session -spark = get_spark_session("MyPersonalApp") - -# Create namespace in your personal SQL warehouse -namespace = create_namespace_if_not_exists(spark) -``` -### Where Your Tables Go -- **Warehouse Directory**: `s3a://cdm-lake/users-sql-warehouse/{username}/` -- **Default Namespace**: `u_{username}__default` (username prefix added automatically) -- **Table Location**: `s3a://cdm-lake/users-sql-warehouse/{username}/u_{username}__default.db/your_table/` - -### Example ```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists - -# Create Spark session -spark = get_spark_session("MyPersonalApp") +spark = get_spark_session("MyApp") -# Create default namespace (creates "u_{username}__default" with username prefix) -namespace = create_namespace_if_not_exists(spark) -print(f"Created namespace: {namespace}") # 
Output: "Created namespace: u_{username}__default" +# Create namespace in your personal Iceberg catalog +ns = create_namespace_if_not_exists(spark, "analysis") +# Returns: "my.analysis" -# Or create custom namespace (creates "u_{username}__analysis" with username prefix) -analysis_namespace = create_namespace_if_not_exists(spark, namespace="analysis") -print(f"Created namespace: {analysis_namespace}") # Output: "Created namespace: u_{username}__analysis" - -# Create a DataFrame and save as table using returned namespace +# Write a table df = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"]) -df.write.format("delta").saveAsTable(f"{namespace}.my_personal_table") +df.writeTo(f"{ns}.my_table").createOrReplace() -# Table will be stored at: -# s3a://cdm-lake/users-sql-warehouse/{username}/u_{username}__default.db/my_personal_table/ +# Read it back +spark.sql(f"SELECT * FROM {ns}.my_table").show() ``` +### Where Your Tables Go + +- **Catalog**: `my` (dedicated Polaris catalog) +- **Namespace**: User-defined (e.g., `analysis`, `experiments`) +- **Full table path**: `my.{namespace}.{table_name}` +- **S3 location**: Managed automatically by Polaris + ## Tenant SQL Warehouse +Tenant catalogs are shared across all members of a tenant/group. 
+ ### Usage + ```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists +spark = get_spark_session("TeamApp") + +# Create namespace in tenant Iceberg catalog +ns = create_namespace_if_not_exists( + spark, "research", tenant_name="kbase" +) +# Returns: "kbase.research" -# Create Spark session -spark = get_spark_session("TeamAnalysis") +# Write a shared table +df.writeTo(f"{ns}.shared_dataset").createOrReplace() -# Create namespace in tenant SQL warehouse (validates membership) -namespace = create_namespace_if_not_exists(spark, tenant_name="kbase") +# All kbase members can read this table +spark.sql(f"SELECT * FROM {ns}.shared_dataset").show() ``` ### Where Your Tables Go -- **Warehouse Directory**: `s3a://cdm-lake/tenant-sql-warehouse/{tenant}/` -- **Default Namespace**: `{tenant}_default` (tenant prefix added automatically) -- **Table Location**: `s3a://cdm-lake/tenant-sql-warehouse/{tenant}/{tenant}_default.db/your_table/` + +- **Catalog**: Tenant name (e.g., `kbase`) +- **Namespace**: User-defined (e.g., `research`, `shared_data`) +- **Full table path**: `{tenant}.{namespace}.{table_name}` +- **S3 location**: Managed automatically by Polaris ### Requirements + - You must be a member of the specified tenant/group - The system validates your membership before allowing access -### Example -```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists - -# Create Spark session -spark = get_spark_session("TeamAnalysis") - -# Create default namespace in tenant warehouse (creates "{tenant}_default" with tenant prefix) -namespace = create_namespace_if_not_exists(spark, tenant_name="kbase") -print(f"Created namespace: {namespace}") # Output: "Created namespace: kbase_default" +## Spark Connect Architecture -# Or create custom namespace (creates "{tenant}_research" with tenant prefix) -research_namespace = 
create_namespace_if_not_exists(spark, namespace="research", tenant_name="kbase") -print(f"Created namespace: {research_namespace}") # Output: "Created namespace: kbase_research" +BERDL uses **Spark Connect**, which provides a client-server architecture: +- **Connection Protocol**: Uses gRPC for efficient communication with remote Spark clusters +- **Spark Connect Server**: Runs locally in your notebook pod as a proxy +- **Spark Connect URL**: `sc://localhost:15002` +- **Driver and Executors**: Run locally in your notebook pod -# Create a DataFrame and save as table using returned namespace -df = spark.createDataFrame([(1, "Dataset A"), (2, "Dataset B")], ["id", "dataset"]) -df.write.format("delta").saveAsTable(f"{namespace}.shared_analysis") +### Automatic Credential Management -# Table will be stored at: -# s3a://cdm-lake/tenant-sql-warehouse/kbase/kbase_default.db/shared_analysis/ -``` +Your credentials are automatically configured when your notebook server starts: -## Advanced Namespace Management +1. **JupyterHub** calls the governance API to create/retrieve your MinIO and Polaris credentials +2. **Environment Variables** are set in your notebook (MinIO S3 keys + Polaris OAuth2 client credentials) +3. **Spark Session** is pre-configured with access to your personal catalog and tenant catalogs +4. **Notebooks** use credentials from environment (no API call needed) -### Custom Namespaces (Default Behavior) -```python -# Personal warehouse with custom namespace (prefix enabled by default) -spark = get_spark_session() -exp_namespace = create_namespace_if_not_exists(spark, "experiments") # Returns "u_{username}__experiments" +This ensures secure, per-user access without exposing credentials in code.
-# Tenant warehouse with custom namespace (prefix enabled by default) -data_namespace = create_namespace_if_not_exists(spark, "research_data", tenant_name="kbase") # Returns "kbase_research_data" +## Cross-Catalog Queries -# Use returned namespace names for table operations -df.write.format("delta").saveAsTable(f"{exp_namespace}.my_experiment_table") -df.write.format("delta").saveAsTable(f"{data_namespace}.shared_dataset") +You can query across personal and tenant catalogs in a single query: -# Tables will be stored at: -# s3a://cdm-lake/users-sql-warehouse/{username}/u_{username}__experiments.db/my_experiment_table/ -# s3a://cdm-lake/tenant-sql-warehouse/kbase/kbase_research_data.db/shared_dataset/ +```python +spark.sql(""" + SELECT u.name, d.dept_name + FROM my.analysis.users u + JOIN kbase.shared.depts d + ON u.dept_id = d.id +""") ``` ## Tips -- **Always use the returned namespace**: `create_namespace_if_not_exists()` returns the actual namespace name (with prefixes applied). Always use this value when creating tables. -- **Permission issues with manual namespaces**: If you create namespaces manually (without using `create_namespace_if_not_exists()`) and the namespace doesn't follow the expected naming rules (e.g., missing the `u_{username}__` or `{tenant}_` prefix), you may not have the correct permissions to read/write to that namespace. The governance system enforces permissions based on namespace naming conventions: - - User namespaces must start with `u_{username}__` to grant you access - - Tenant namespaces must start with `{tenant}_` and you must be a member of that tenant - - Namespaces without proper prefixes will result in "Access Denied" errors from MinIO +- **Always use the returned namespace**: `create_namespace_if_not_exists()` returns the fully qualified namespace (e.g., `my.analysis`). Always use this value when creating or querying tables. - **Tenant membership required**: Attempting to access a tenant warehouse without membership will fail. 
-- **Credentials are automatic**: MinIO credentials are set by JupyterHub - you don't need to call any API to get them. +- **Credentials are automatic**: MinIO and Polaris credentials are set by JupyterHub — you don't need to call any API to get them. - **Spark Connect is default**: All sessions use Spark Connect for better stability and resource isolation. +- **Iceberg features**: Your tables support time travel, schema evolution, and snapshot management. See the [Iceberg Migration Guide](iceberg_migration_guide.md) for details. diff --git a/docs/user_guide.md b/docs/user_guide.md index 6e29941..0e848a2 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -52,8 +52,9 @@ Log in using your KBase credentials (the same 3rd party identity provider, usern ### 3. Data Access By default, you have **read/write access** to: -- Your personal SQL warehouse (`s3a://cdm-lake/users-sql-warehouse/{username}/`) -- Any tenant SQL warehouses that you belong to (e.g., `s3a://cdm-lake/tenant-sql-warehouse/kbase/`) +- Your personal Iceberg catalog (`my`) — create namespaces and tables here +- Any tenant Iceberg catalogs you belong to (e.g., `kbase`) — shared team data +- Your personal S3 storage (`s3a://cdm-lake/users-sql-warehouse/{username}/`) For questions about data access or permissions, please reach out to the BERDL Platform team. @@ -80,9 +81,11 @@ spark = get_spark_session() This automatically configures: - Spark Connect server connection -- Hive Metastore integration +- Apache Iceberg catalogs (personal `my` + tenant catalogs via Polaris) - MinIO S3 access -- Delta Lake support +- Delta Lake support (legacy, for backward compatibility) + +> **Note:** BERDL has migrated from Delta Lake to Apache Iceberg. See the [Iceberg Migration Guide](iceberg_migration_guide.md) for details on the new catalog structure and Iceberg-specific features like time travel and schema evolution. 
#### 5.2 Displaying DataFrames diff --git a/notebook_utils/berdl_notebook_utils/__init__.py b/notebook_utils/berdl_notebook_utils/__init__.py index 3a7fc4f..b817750 100644 --- a/notebook_utils/berdl_notebook_utils/__init__.py +++ b/notebook_utils/berdl_notebook_utils/__init__.py @@ -56,6 +56,7 @@ AgentSettings, get_agent_settings, ) +from berdl_notebook_utils.refresh import refresh_spark_environment __all__ = [ "BERDLSettings", @@ -106,6 +107,8 @@ "BERDLAgent", "AgentSettings", "get_agent_settings", + # Environment refresh + "refresh_spark_environment", ] diff --git a/notebook_utils/berdl_notebook_utils/agent/tools.py b/notebook_utils/berdl_notebook_utils/agent/tools.py index 3d4bf4d..47f0cc1 100644 --- a/notebook_utils/berdl_notebook_utils/agent/tools.py +++ b/notebook_utils/berdl_notebook_utils/agent/tools.py @@ -82,7 +82,7 @@ def list_databases(dummy_input: str = "") -> str: JSON-formatted list of database names or error message """ try: - databases = mcp_ops.mcp_list_databases(use_hms=True) + databases = mcp_ops.mcp_list_databases() if not databases: return "No databases found. You may need to create a database first." return _format_result(databases) @@ -102,7 +102,7 @@ def list_tables(database: str) -> str: JSON-formatted list of table names or error message """ try: - tables = mcp_ops.mcp_list_tables(database=database, use_hms=True) + tables = mcp_ops.mcp_list_tables(database=database) if not tables: return f"No tables found in database '{database}'. The database may be empty." 
return _format_result(tables) @@ -143,7 +143,7 @@ def get_database_structure(input_str: str = "false") -> str: try: # Parse boolean input with_schema = input_str.lower() in ("true", "1", "yes") - structure = mcp_ops.mcp_get_database_structure(with_schema=with_schema, use_hms=True) + structure = mcp_ops.mcp_get_database_structure(with_schema=with_schema) return _format_result(structure) except Exception as e: logger.error(f"Error getting database structure: {e}") diff --git a/notebook_utils/berdl_notebook_utils/berdl_settings.py b/notebook_utils/berdl_notebook_utils/berdl_settings.py index d80c4b5..4831c3a 100644 --- a/notebook_utils/berdl_notebook_utils/berdl_settings.py +++ b/notebook_utils/berdl_notebook_utils/berdl_settings.py @@ -43,7 +43,7 @@ class BERDLSettings(BaseSettings): ) # Hive configuration - BERDL_HIVE_METASTORE_URI: AnyUrl # Accepts thrift:// + BERDL_HIVE_METASTORE_URI: AnyUrl | None = Field(default=None, description="Hive metastore URI (thrift://...)") # Profile-specific Spark configuration from JupyterHub SPARK_WORKER_COUNT: int = Field(default=1, description="Number of Spark workers from profile") @@ -74,6 +74,18 @@ class BERDLSettings(BaseSettings): default=None, description="Tenant Access Request Service URL for Slack-based approval workflow" ) + # Polaris Iceberg Catalog configuration + POLARIS_CATALOG_URI: AnyHttpUrl | None = Field( + default=None, description="Polaris REST Catalog endpoint (e.g., http://polaris:8181/api/catalog)" + ) + POLARIS_CREDENTIAL: str | None = Field(default=None, description="Polaris client_id:client_secret credential") + POLARIS_PERSONAL_CATALOG: str | None = Field( + default=None, description="Polaris personal catalog name (e.g., user_tgu2)" + ) + POLARIS_TENANT_CATALOGS: str | None = Field( + default=None, description="Comma-separated Polaris tenant catalog names" + ) + def validate_environment(): """ diff --git a/notebook_utils/berdl_notebook_utils/mcp/operations.py 
b/notebook_utils/berdl_notebook_utils/mcp/operations.py index ae86844..9d828bf 100644 --- a/notebook_utils/berdl_notebook_utils/mcp/operations.py +++ b/notebook_utils/berdl_notebook_utils/mcp/operations.py @@ -62,18 +62,17 @@ def _handle_error_response(response: Any, operation: str) -> None: raise Exception(error_msg) -def mcp_list_databases(use_hms: bool = True) -> list[str]: +def mcp_list_databases() -> list[str]: """ - List all databases in the Hive metastore via MCP server. + List all databases (Iceberg catalog namespaces) via MCP server. This function connects to the global datalake-mcp-server, which will use your authentication token to connect to your personal Spark Connect server. - Args: - use_hms: If True, uses Hive Metastore client for faster retrieval (default: True) + Note: Only Iceberg catalogs are listed (spark_catalog is excluded). Returns: - List of database names + List of database names (e.g., ['my.demo_dataset', 'kbase.pangenome']) Raises: Exception: If the MCP server returns an error or is unreachable @@ -81,12 +80,12 @@ def mcp_list_databases(use_hms: bool = True) -> list[str]: Example: >>> databases = mcp_list_databases() >>> print(databases) - ['default', 'my_database', 'analytics'] + ['my.demo_dataset', 'kbase.analytics'] """ client = get_datalake_mcp_client() - request = DatabaseListRequest(use_hms=use_hms) + request = DatabaseListRequest() - logger.debug(f"Listing databases via MCP server (use_hms={use_hms})") + logger.debug("Listing databases via MCP server") response = list_databases.sync(client=client, body=request) _handle_error_response(response, "list_databases") @@ -98,13 +97,12 @@ def mcp_list_databases(use_hms: bool = True) -> list[str]: return response.databases -def mcp_list_tables(database: str, use_hms: bool = True) -> list[str]: +def mcp_list_tables(database: str) -> list[str]: """ List all tables in a specific database via MCP server. 
Args: - database: Name of the database - use_hms: If True, uses Hive Metastore client for faster retrieval (default: True) + database: Name of the database (e.g., 'my.demo_dataset') Returns: List of table names in the database @@ -113,12 +111,12 @@ def mcp_list_tables(database: str, use_hms: bool = True) -> list[str]: Exception: If the MCP server returns an error or is unreachable Example: - >>> tables = mcp_list_tables("my_database") + >>> tables = mcp_list_tables("my.demo_dataset") >>> print(tables) ['users', 'orders', 'products'] """ client = get_datalake_mcp_client() - request = TableListRequest(database=database, use_hms=use_hms) + request = TableListRequest(database=database) logger.debug(f"Listing tables in database '{database}' via MCP server") response = list_database_tables.sync(client=client, body=request) @@ -167,14 +165,13 @@ def mcp_get_table_schema(database: str, table: str) -> list[str]: def mcp_get_database_structure( - with_schema: bool = False, use_hms: bool = True + with_schema: bool = False, ) -> dict[str, list[str] | dict[str, list[str]]]: """ Get the complete structure of all databases via MCP server. 
Args: with_schema: If True, includes table schemas (column names) (default: False) - use_hms: If True, uses Hive Metastore client for faster retrieval (default: True) Returns: Dictionary mapping database names to either: @@ -188,15 +185,15 @@ def mcp_get_database_structure( >>> # Without schema >>> structure = mcp_get_database_structure() >>> print(structure) - {'default': ['table1', 'table2'], 'analytics': ['metrics', 'events']} + {'my.demo': ['table1', 'table2'], 'kbase.analytics': ['metrics', 'events']} >>> # With schema >>> structure = mcp_get_database_structure(with_schema=True) >>> print(structure) - {'default': {'table1': ['col1', 'col2'], 'table2': ['col3', 'col4']}} + {'my.demo': {'table1': ['col1', 'col2'], 'table2': ['col3', 'col4']}} """ client = get_datalake_mcp_client() - request = DatabaseStructureRequest(with_schema=with_schema, use_hms=use_hms) + request = DatabaseStructureRequest(with_schema=with_schema) logger.debug(f"Getting database structure via MCP server (with_schema={with_schema})") response = get_database_structure.sync(client=client, body=request) diff --git a/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py b/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py index 624c413..1a09308 100644 --- a/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py +++ b/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py @@ -10,6 +10,7 @@ check_governance_health, get_group_sql_warehouse, get_minio_credentials, + get_polaris_credentials, get_my_accessible_paths, get_my_groups, get_my_policies, @@ -20,6 +21,7 @@ add_group_member, create_tenant_and_assign_users, list_groups, + list_user_names, list_users, remove_group_member, # Table operations @@ -31,6 +33,9 @@ # Tenant access requests list_available_groups, request_tenant_access, + # Migration (admin-only) + ensure_polaris_resources, + regenerate_policies, ) __all__ = [ @@ -38,6 +43,7 @@ "check_governance_health", "get_group_sql_warehouse", 
"get_minio_credentials", + "get_polaris_credentials", "get_my_accessible_paths", "get_my_groups", "get_my_policies", @@ -48,6 +54,7 @@ "add_group_member", "create_tenant_and_assign_users", "list_groups", + "list_user_names", "list_users", "remove_group_member", # Table operations @@ -59,4 +66,7 @@ # Tenant access requests "list_available_groups", "request_tenant_access", + # Migration (admin-only) + "ensure_polaris_resources", + "regenerate_policies", ] diff --git a/notebook_utils/berdl_notebook_utils/minio_governance/operations.py b/notebook_utils/berdl_notebook_utils/minio_governance/operations.py index 0f1b581..22a07fb 100644 --- a/notebook_utils/berdl_notebook_utils/minio_governance/operations.py +++ b/notebook_utils/berdl_notebook_utils/minio_governance/operations.py @@ -7,22 +7,33 @@ import logging import os import time +import warnings from pathlib import Path +from collections.abc import Callable +from typing import TypeVar + from typing import TypedDict import httpx from governance_client.api.credentials import get_credentials_credentials_get from governance_client.api.health import health_check_health_get +from governance_client.api.polaris import provision_polaris_user_polaris_user_provision_username_post from governance_client.api.management import ( add_group_member_management_groups_group_name_members_username_post, create_group_management_groups_group_name_post, + ensure_all_polaris_resources_management_migrate_ensure_polaris_resources_post, list_groups_management_groups_get, list_users_management_users_get, + regenerate_all_policies_management_migrate_regenerate_policies_post, remove_group_member_management_groups_group_name_members_username_delete, ) from governance_client.api.management.list_group_names_management_groups_names_get import ( sync as list_group_names_sync, ) +from governance_client.api.management.list_user_names_management_users_names_get import ( + sync as list_user_names_sync, +) +from governance_client.models.user_names_response 
import UserNamesResponse from governance_client.api.sharing import ( get_path_access_info_sharing_get_path_access_info_post, make_path_private_sharing_make_private_post, @@ -84,20 +95,54 @@ class TenantCreationResult(TypedDict): # Credential caching configuration CREDENTIALS_CACHE_FILE = ".berdl_minio_credentials" +POLARIS_CREDENTIALS_CACHE_FILE = ".berdl_polaris_credentials" # ============================================================================= # HELPER FUNCTIONS # ============================================================================= +_T = TypeVar("_T") + + +def _fetch_with_file_cache( + cache_path: Path, + read_cache: Callable[[Path], _T | None], + fetch: Callable[[], _T | None], + write_cache: Callable[[Path, _T], None], +) -> _T | None: + """Fetch credentials using file-based caching with exclusive file locking. + + The lock is released when the file handle is closed (exiting the `with` block). + We intentionally do NOT delete the lock file afterward — another process + may have already acquired a lock on it between our unlock and unlink. + """ + lock_path = cache_path.with_suffix(".lock") + with open(lock_path, "w") as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + + cached = read_cache(cache_path) + if cached is not None: + return cached + + result = fetch() + if result is not None: + write_cache(cache_path, result) + return result + def _get_credentials_cache_path() -> Path: - """Get the path to the credentials cache file in the user's home directory.""" + """Get the path to the MinIO credentials cache file in the user's home directory.""" return Path.home() / CREDENTIALS_CACHE_FILE +def _get_polaris_cache_path() -> Path: + """Get the path to the Polaris credentials cache file in the user's home directory.""" + return Path.home() / POLARIS_CREDENTIALS_CACHE_FILE + + def _read_cached_credentials(cache_path: Path) -> CredentialsResponse | None: - """Read credentials from cache file. 
Returns None if file doesn't exist or is corrupted.""" + """Read MinIO credentials from cache file. Returns None if file doesn't exist or is corrupted.""" try: if not cache_path.exists(): return None @@ -109,7 +154,7 @@ def _read_cached_credentials(cache_path: Path) -> CredentialsResponse | None: def _write_credentials_cache(cache_path: Path, credentials: CredentialsResponse) -> None: - """Write credentials to cache file.""" + """Write MinIO credentials to cache file.""" try: with open(cache_path, "w") as f: json.dump(credentials.to_dict(), f) @@ -117,6 +162,30 @@ def _write_credentials_cache(cache_path: Path, credentials: CredentialsResponse) pass +def _read_cached_polaris_credentials(cache_path: Path) -> "PolarisCredentials | None": + """Read Polaris credentials from cache file. Returns None if file doesn't exist or is corrupted.""" + try: + if not cache_path.exists(): + return None + with open(cache_path, "r") as f: + data = json.load(f) + # Validate required keys are present + if all(k in data for k in ("client_id", "client_secret", "personal_catalog", "tenant_catalogs")): + return data + return None + except (json.JSONDecodeError, TypeError, KeyError, OSError): + return None + + +def _write_polaris_credentials_cache(cache_path: Path, credentials: "PolarisCredentials") -> None: + """Write Polaris credentials to cache file.""" + try: + with open(cache_path, "w") as f: + json.dump(credentials, f) + except (OSError, TypeError): + pass + + def _build_table_path(username: str, namespace: str, table_name: str) -> str: """ Build S3 path for a SQL warehouse table. 
@@ -147,6 +216,15 @@ def check_governance_health() -> HealthResponse: return health_check_health_get.sync(client=client) +def _fetch_minio_credentials() -> CredentialsResponse | None: + """Fetch fresh MinIO credentials from the governance API.""" + client = get_governance_client() + api_response = get_credentials_credentials_get.sync(client=client) + if isinstance(api_response, CredentialsResponse): + return api_response + return None + + def get_minio_credentials() -> CredentialsResponse: """ Get MinIO credentials for the current user and set them as environment variables. @@ -161,37 +239,14 @@ def get_minio_credentials() -> CredentialsResponse: Returns: CredentialsResponse with username, access_key, and secret_key """ - cache_path = _get_credentials_cache_path() - lock_path = cache_path.with_suffix(".lock") - - # Use file locking to prevent concurrent access - with open(lock_path, "w") as lock_file: - try: - # Acquire exclusive lock (blocks until available) - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - - # Try to load from cache first (double-check after acquiring lock) - cached_credentials = _read_cached_credentials(cache_path) - if cached_credentials: - credentials = cached_credentials - else: - # No cache or cache corrupted, fetch fresh credentials - client = get_governance_client() - api_response = get_credentials_credentials_get.sync(client=client) - if isinstance(api_response, CredentialsResponse): - credentials = api_response - _write_credentials_cache(cache_path, credentials) - else: - raise RuntimeError("Failed to fetch credentials from API") - finally: - # Lock is automatically released when file is closed - pass - - # Clean up lock file if it exists - try: - lock_path.unlink(missing_ok=True) - except OSError: - pass + credentials = _fetch_with_file_cache( + _get_credentials_cache_path(), + _read_cached_credentials, + _fetch_minio_credentials, + _write_credentials_cache, + ) + if credentials is None: + raise RuntimeError("Failed to fetch 
credentials from API") # Set MinIO credentials as environment variables os.environ["MINIO_ACCESS_KEY"] = credentials.access_key @@ -200,6 +255,82 @@ def get_minio_credentials() -> CredentialsResponse: return credentials +class PolarisCredentials(TypedDict): + """Polaris credential provisioning result.""" + + client_id: str + client_secret: str + personal_catalog: str + tenant_catalogs: list[str] + + +def _fetch_polaris_credentials() -> PolarisCredentials | None: + """Fetch fresh Polaris credentials from the governance API.""" + settings = get_settings() + polaris_logger = logging.getLogger(__name__) + + client = get_governance_client() + api_response = provision_polaris_user_polaris_user_provision_username_post.sync( + username=settings.USER, client=client + ) + + if isinstance(api_response, ErrorResponse): + polaris_logger.warning(f"Polaris provisioning failed: {api_response.message}") + return None + if api_response is None: + polaris_logger.warning("Polaris provisioning returned no response") + return None + + data = api_response.to_dict() + return { + "client_id": data.get("client_id", ""), + "client_secret": data.get("client_secret", ""), + "personal_catalog": data.get("personal_catalog", ""), + "tenant_catalogs": data.get("tenant_catalogs", []), + } + + +def get_polaris_credentials() -> PolarisCredentials | None: + """ + Provision a Polaris catalog for the current user and set credentials as environment variables. + + Uses file locking and caching to prevent race conditions and avoid unnecessary + API calls when credentials are already cached (same pattern as get_minio_credentials). + + Calls POST /polaris/user_provision/{username} on the governance API on cache miss. + This provisions the user's Polaris environment (catalog, principal, roles, credentials, + tenant access) and returns the configuration. 
+ + Sets the following environment variables: + - POLARIS_CREDENTIAL: client_id:client_secret for authenticating with Polaris + - POLARIS_PERSONAL_CATALOG: Name of the user's personal Polaris catalog + - POLARIS_TENANT_CATALOGS: Comma-separated list of tenant catalogs the user has access to + + Returns: + PolarisCredentials dict, or None if Polaris is not configured + """ + settings = get_settings() + + if not settings.POLARIS_CATALOG_URI: + return None + + result = _fetch_with_file_cache( + _get_polaris_cache_path(), + _read_cached_polaris_credentials, + _fetch_polaris_credentials, + _write_polaris_credentials_cache, + ) + if result is None: + return None + + # Set as environment variables for Spark catalog configuration + os.environ["POLARIS_CREDENTIAL"] = f"{result['client_id']}:{result['client_secret']}" + os.environ["POLARIS_PERSONAL_CATALOG"] = result["personal_catalog"] + os.environ["POLARIS_TENANT_CATALOGS"] = ",".join(result["tenant_catalogs"]) + + return result + + def get_my_sql_warehouse() -> UserSqlWarehousePrefixResponse: """ Get SQL warehouse prefix for the current user. @@ -346,6 +477,13 @@ def share_table( Example: share_table("analytics", "user_metrics", with_users=["alice", "bob"]) """ + warnings.warn( + "share_table is deprecated and will be removed in a future release. " + "Direct path sharing is no longer recommended. Please create a Tenant Workspace " + "and request access to the tenant for sharing activities.", + DeprecationWarning, + stacklevel=2, + ) client = get_governance_client() # Get current user's username from environment variable username = get_settings().USER @@ -384,6 +522,13 @@ def unshare_table( Example: unshare_table("analytics", "user_metrics", from_users=["alice"]) """ + warnings.warn( + "unshare_table is deprecated and will be removed in a future release. " + "Direct path sharing is no longer recommended. 
Please create a Tenant Workspace " + "and request access to the tenant for unsharing activities.", + DeprecationWarning, + stacklevel=2, + ) client = get_governance_client() # Get current user's username from environment variable username = get_settings().USER @@ -418,6 +563,13 @@ def make_table_public( Example: make_table_public("research", "public_dataset") """ + warnings.warn( + "make_table_public is deprecated and will be removed in a future release. " + "Direct public path sharing is no longer recommended. Please create a namespace " + "under the `globalusers` tenant for public sharing activities.", + DeprecationWarning, + stacklevel=2, + ) client = get_governance_client() # Get current user's username from environment variable username = get_settings().USER @@ -444,6 +596,13 @@ def make_table_private( Example: make_table_private("research", "sensitive_data") """ + warnings.warn( + "make_table_private is deprecated and will be removed in a future release. " + "Direct public path sharing is no longer recommended. Please remove the namespace " + "under the `globalusers` tenant to revoke public access.", + DeprecationWarning, + stacklevel=2, + ) client = get_governance_client() # Get current user's username from environment variable username = get_settings().USER @@ -507,9 +666,17 @@ def list_groups() -> dict | ErrorResponse | None: return list_groups_management_groups_get.sync(client=client) -def list_users(): +def list_users(page: int = 1, page_size: int = 500): """ - List all users in the system. + List all users in the system with full details. + + This fetches full user info (policies, groups, paths) for each user, + which can be slow with many users. If you only need usernames, + use ``list_user_names()`` instead. + + Args: + page: Page number (1-based). Default: 1. + page_size: Number of users per page. Default: 500. Returns: UserListResponse with user information, or ErrorResponse on failure. 
@@ -519,7 +686,32 @@ def list_users(): # Returns list of user information """ client = get_governance_client() - return list_users_management_users_get.sync(client=client) + return list_users_management_users_get.sync(client=client, page=page, page_size=page_size) + + +def list_user_names() -> list[str]: + """ + List all usernames in the system (lightweight). + + This is much faster than ``list_users()`` because it only returns + usernames without fetching full user details (policies, groups, paths). + + Returns: + List of usernames. + + Raises: + RuntimeError: If the API call fails. + """ + client = get_governance_client() + response = list_user_names_sync(client=client) + + if isinstance(response, ErrorResponse): + raise RuntimeError(f"Failed to list usernames: {response.message}") + + if not isinstance(response, UserNamesResponse): + raise RuntimeError("Failed to list usernames: no response from API") + + return response.usernames def add_group_member( @@ -777,3 +969,40 @@ def request_tenant_access( raise RuntimeError(f"Failed to submit access request: {e.response.status_code} - {e.response.text}") except httpx.RequestError as e: raise RuntimeError(f"Failed to connect to tenant access service: {e}") + + +# ============================================================================= +# MIGRATION - Admin-only bulk operations for IAM + Polaris migration +# ============================================================================= + + +def regenerate_policies(): + """ + Force-regenerate all MinIO IAM HOME policies from the current template. + + This admin-only endpoint updates pre-existing policies to include new path + statements (e.g., Iceberg paths). Each regeneration is independent — errors + do not block others. + + Returns: + RegeneratePoliciesResponse with users_updated, groups_updated, errors, + or ErrorResponse on failure. 
+ """ + client = get_governance_client() + return regenerate_all_policies_management_migrate_regenerate_policies_post.sync(client=client) + + +def ensure_polaris_resources(): + """ + Ensure Polaris resources exist for all users and groups. + + Creates Polaris principals, personal catalogs, and roles for all users. + Creates tenant catalogs for all base groups. Grants correct principal roles + based on group memberships. All operations are idempotent. + + Returns: + EnsurePolarisResponse with users_provisioned, groups_provisioned, errors, + or ErrorResponse on failure. + """ + client = get_governance_client() + return ensure_all_polaris_resources_management_migrate_ensure_polaris_resources_post.sync(client=client) diff --git a/notebook_utils/berdl_notebook_utils/refresh.py b/notebook_utils/berdl_notebook_utils/refresh.py new file mode 100644 index 0000000..d9affc7 --- /dev/null +++ b/notebook_utils/berdl_notebook_utils/refresh.py @@ -0,0 +1,116 @@ +""" +Refresh credentials and Spark environment. + +Provides a single function to clear all credential caches, re-provision +MinIO and Polaris credentials, restart the Spark Connect server, and stop +any existing Spark session — ensuring get_spark_session() works afterward. +""" + +import logging +from pathlib import Path + +from pyspark.sql import SparkSession + +from berdl_notebook_utils.berdl_settings import get_settings +from berdl_notebook_utils.minio_governance.operations import ( + CREDENTIALS_CACHE_FILE, + POLARIS_CREDENTIALS_CACHE_FILE, + get_minio_credentials, + get_polaris_credentials, +) +from berdl_notebook_utils.spark.connect_server import start_spark_connect_server + +logger = logging.getLogger("berdl.refresh") + + +def _remove_cache_file(path: Path) -> bool: + """Remove a cache file. 
Returns True if the main file existed.""" + existed = False + try: + if path.exists(): + path.unlink() + existed = True + except OSError: + pass + return existed + + +def refresh_spark_environment() -> dict: + """Clear all credential caches, re-provision credentials, and restart Spark. + + Steps performed: + 1. Delete MinIO and Polaris credential cache files + 2. Clear the in-memory ``get_settings()`` LRU cache + 3. Re-fetch MinIO credentials (sets MINIO_ACCESS_KEY/SECRET_KEY env vars) + 4. Re-fetch Polaris credentials (sets POLARIS_CREDENTIAL and catalog env vars) + 5. Clear settings cache again so downstream code sees fresh env vars + 6. Stop any existing Spark session + 7. Restart the Spark Connect server with regenerated spark-defaults.conf + + Returns: + dict with keys ``minio``, ``polaris``, ``spark_connect``, + ``spark_session_stopped`` summarising what happened. + """ + home = Path.home() + result: dict = {} + + # 1. Delete credential cache files + minio_removed = _remove_cache_file(home / CREDENTIALS_CACHE_FILE) + polaris_removed = _remove_cache_file(home / POLARIS_CREDENTIALS_CACHE_FILE) + logger.info( + "Cleared credential caches (minio=%s, polaris=%s)", + minio_removed, + polaris_removed, + ) + + # 2. Clear in-memory settings cache + get_settings.cache_clear() + + # 3. Re-fetch MinIO credentials + try: + minio_creds = get_minio_credentials() + result["minio"] = {"status": "ok", "username": minio_creds.username} + logger.info("MinIO credentials refreshed for user: %s", minio_creds.username) + except Exception as exc: + result["minio"] = {"status": "error", "error": str(exc)} + logger.warning("Failed to refresh MinIO credentials: %s", exc) + + # 4. 
Re-fetch Polaris credentials + try: + polaris_creds = get_polaris_credentials() + if polaris_creds: + result["polaris"] = { + "status": "ok", + "personal_catalog": polaris_creds["personal_catalog"], + "tenant_catalogs": polaris_creds.get("tenant_catalogs", []), + } + logger.info("Polaris credentials refreshed for catalog: %s", polaris_creds["personal_catalog"]) + else: + result["polaris"] = {"status": "skipped", "reason": "Polaris not configured"} + logger.info("Polaris not configured, skipping credential refresh") + except Exception as exc: + result["polaris"] = {"status": "error", "error": str(exc)} + logger.warning("Failed to refresh Polaris credentials: %s", exc) + + # 5. Clear settings cache again so get_settings() picks up new env vars + get_settings.cache_clear() + + # 6. Stop existing Spark session + existing = SparkSession.getActiveSession() + if existing: + existing.stop() + result["spark_session_stopped"] = True + logger.info("Stopped existing Spark session") + else: + result["spark_session_stopped"] = False + + # 7. 
Restart Spark Connect server with fresh config + try: + sc_result = start_spark_connect_server(force_restart=True) + result["spark_connect"] = sc_result + logger.info("Spark Connect server restarted: %s", sc_result.get("status", "unknown")) + except Exception as exc: + result["spark_connect"] = {"status": "error", "error": str(exc)} + logger.warning("Failed to restart Spark Connect server: %s", exc) + + return result diff --git a/notebook_utils/berdl_notebook_utils/setup_spark_session.py b/notebook_utils/berdl_notebook_utils/setup_spark_session.py index 0373af2..209d641 100644 --- a/notebook_utils/berdl_notebook_utils/setup_spark_session.py +++ b/notebook_utils/berdl_notebook_utils/setup_spark_session.py @@ -14,8 +14,8 @@ from pyspark.conf import SparkConf from pyspark.sql import SparkSession -from berdl_notebook_utils import BERDLSettings, get_settings -from berdl_notebook_utils.minio_governance.operations import ( +from .berdl_settings import BERDLSettings, get_settings +from .minio_governance.operations import ( get_group_sql_warehouse, get_my_sql_warehouse, ) @@ -165,6 +165,60 @@ def _get_spark_defaults_conf() -> dict[str, str]: } +def _get_catalog_conf(settings: BERDLSettings) -> dict[str, str]: + """Get Iceberg catalog configuration for Polaris REST catalog.""" + config = {} + + if not settings.POLARIS_CATALOG_URI: + return config + + polaris_uri = str(settings.POLARIS_CATALOG_URI).rstrip("/") + + # S3/MinIO properties for Iceberg's S3FileIO (used by executors to read/write data files). + # Iceberg does NOT use Spark's spark.hadoop.fs.s3a.* — it has its own AWS SDK S3 client. 
+ s3_endpoint = settings.MINIO_ENDPOINT_URL + if not s3_endpoint.startswith("http"): + s3_endpoint = f"http://{s3_endpoint}" + s3_props = { + "s3.endpoint": s3_endpoint, + "s3.access-key-id": settings.MINIO_ACCESS_KEY, + "s3.secret-access-key": settings.MINIO_SECRET_KEY, + "s3.path-style-access": "true", + "s3.region": "us-east-1", + } + + def _catalog_props(prefix: str, warehouse: str) -> dict[str, str]: + props = { + f"{prefix}": "org.apache.iceberg.spark.SparkCatalog", + f"{prefix}.type": "rest", + f"{prefix}.uri": polaris_uri, + f"{prefix}.credential": settings.POLARIS_CREDENTIAL or "", + f"{prefix}.warehouse": warehouse, + f"{prefix}.scope": "PRINCIPAL_ROLE:ALL", + f"{prefix}.token-refresh-enabled": "false", + f"{prefix}.client.region": "us-east-1", + } + # Add S3 properties scoped to this catalog + for k, v in s3_props.items(): + props[f"{prefix}.{k}"] = v + return props + + # 1. Add Personal Catalog (if configured) + if settings.POLARIS_PERSONAL_CATALOG: + config.update(_catalog_props("spark.sql.catalog.my", settings.POLARIS_PERSONAL_CATALOG)) + + # 2. 
Add Tenant Catalogs (if configured) + if settings.POLARIS_TENANT_CATALOGS: + for tenant_catalog in settings.POLARIS_TENANT_CATALOGS.split(","): + tenant_catalog = tenant_catalog.strip() + if not tenant_catalog: + continue + catalog_alias = tenant_catalog[7:] if tenant_catalog.startswith("tenant_") else tenant_catalog + config.update(_catalog_props(f"spark.sql.catalog.{catalog_alias}", tenant_catalog)) + + return config + + def _get_delta_conf() -> dict[str, str]: return { "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension", @@ -245,6 +299,18 @@ def _get_s3_conf(settings: BERDLSettings, tenant_name: str | None = None) -> dic } +def _is_immutable_config(key: str) -> bool: + """Check if a Spark config key is immutable in Spark Connect mode.""" + if key in IMMUTABLE_CONFIGS: + return True + # Iceberg catalog configs (spark.sql.catalog.{name}.*) are all static/immutable + # because catalogs must be registered at server startup. The catalog names are + # dynamic (personal "my" + tenant aliases), so we match by prefix. + if key.startswith("spark.sql.catalog.") and key != "spark.sql.catalog.spark_catalog": + return True + return False + + def _filter_immutable_spark_connect_configs(config: dict[str, str]) -> dict[str, str]: """ Filter out configurations that cannot be modified in Spark Connect mode. 
@@ -259,7 +325,7 @@ def _filter_immutable_spark_connect_configs(config: dict[str, str]) -> dict[str, Filtered configuration dictionary with only mutable configs """ - return {k: v for k, v in config.items() if k not in IMMUTABLE_CONFIGS} + return {k: v for k, v in config.items() if not _is_immutable_config(k)} def _set_scheduler_pool(spark: SparkSession, scheduler_pool: str) -> None: @@ -313,6 +379,9 @@ def generate_spark_conf( if use_hive: config.update(_get_hive_conf(settings)) + # Always add Polaris catalogs if they are configured + config.update(_get_catalog_conf(settings)) + if use_spark_connect: # Spark Connect: filter out immutable configs that cannot be modified from the client config = _filter_immutable_spark_connect_configs(config) @@ -392,4 +461,23 @@ def get_spark_session( spark.sparkContext.setLogLevel("DEBUG") _set_scheduler_pool(spark, scheduler_pool) + # Warm up Polaris REST catalogs so they appear in SHOW CATALOGS immediately. + # Spark lazily initializes REST catalog plugins — they only show up in + # CatalogManager._catalogs (and therefore SHOW CATALOGS) after first access. 
+ if use_spark_connect and not local: + _settings = settings or get_settings() + _catalog_aliases: list[str] = [] + if _settings.POLARIS_PERSONAL_CATALOG: + _catalog_aliases.append("my") + if _settings.POLARIS_TENANT_CATALOGS: + for _raw in _settings.POLARIS_TENANT_CATALOGS.split(","): + _raw = _raw.strip() + if _raw: + _catalog_aliases.append(_raw.removeprefix("tenant_")) + for _alias in _catalog_aliases: + try: + spark.sql(f"SHOW NAMESPACES IN {_alias}").collect() + except Exception: + pass # catalog may not have any namespaces yet; access is enough to register it + return spark diff --git a/notebook_utils/berdl_notebook_utils/spark/__init__.py b/notebook_utils/berdl_notebook_utils/spark/__init__.py index 55b99bd..287a5e2 100644 --- a/notebook_utils/berdl_notebook_utils/spark/__init__.py +++ b/notebook_utils/berdl_notebook_utils/spark/__init__.py @@ -4,7 +4,7 @@ This package provides comprehensive Spark utilities organized into focused modules: - database: Catalog and namespace management utilities - dataframe: DataFrame operations and display functions -- data_store: Hive metastore and database information utilities +- data_store: Iceberg catalog browsing utilities - connect_server: Spark Connect server management All functions are imported at the package level for convenient access. 
diff --git a/notebook_utils/berdl_notebook_utils/spark/connect_server.py b/notebook_utils/berdl_notebook_utils/spark/connect_server.py index e517d8c..a9522db 100644 --- a/notebook_utils/berdl_notebook_utils/spark/connect_server.py +++ b/notebook_utils/berdl_notebook_utils/spark/connect_server.py @@ -9,20 +9,22 @@ import os import shutil import signal +import socket import subprocess import time from pathlib import Path from typing import Optional -from berdl_notebook_utils.berdl_settings import BERDLSettings, get_settings -from berdl_notebook_utils.minio_governance.operations import ( +from ..berdl_settings import BERDLSettings, get_settings +from ..minio_governance.operations import ( get_my_groups, get_my_sql_warehouse, get_namespace_prefix, ) -from berdl_notebook_utils.setup_spark_session import ( +from ..setup_spark_session import ( DRIVER_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD, + _get_catalog_conf, convert_memory_format, ) @@ -119,6 +121,13 @@ def generate_spark_config(self) -> None: warehouse_response = get_my_sql_warehouse() f.write(f"spark.sql.warehouse.dir={warehouse_response.sql_warehouse_prefix}\n") + # Add Polaris Iceberg Catalogs + catalog_configs = _get_catalog_conf(self.settings) + if catalog_configs: + f.write("\n# Polaris Catalog Configuration\n") + for key, value in catalog_configs.items(): + f.write(f"{key}={value}\n") + logger.info(f"Spark configuration written to {self.spark_defaults_path}") def compute_allowed_namespace_prefixes(self) -> str: @@ -288,8 +297,6 @@ def _wait_for_port_release(self, timeout: int) -> bool: Returns: True if port is free, False if timeout reached. """ - import socket - port = self.config.spark_connect_port start_time = time.time() @@ -319,9 +326,9 @@ def start(self, force_restart: bool = False) -> dict: Dictionary with server information. 
""" # Check if server is already running - if self.is_running(): + server_info = self.get_server_info() + if server_info is not None: if not force_restart: - server_info = self.get_server_info() logger.info(f"✅ Spark Connect server already running (PID: {server_info['pid']})") logger.info(" Reusing existing server - no need to start a new one") return server_info @@ -382,6 +389,9 @@ def start(self, force_restart: bool = False) -> dict: f.write(str(process.pid)) server_info = self.get_server_info() + if server_info is None: + raise RuntimeError("Failed to get server info after startup") + logger.info(f"✅ Spark Connect server started successfully (PID: {process.pid})") logger.info(f" Connect URL: {server_info['url']}") logger.info(f" Logs: {server_info['log_file']}") @@ -401,8 +411,8 @@ def status(self) -> dict: Returns: Dictionary with status information. """ - if self.is_running(): - info = self.get_server_info() + info = self.get_server_info() + if info is not None: return { "status": "running", **info, diff --git a/notebook_utils/berdl_notebook_utils/spark/data_store.py b/notebook_utils/berdl_notebook_utils/spark/data_store.py index 32ab985..fc7b921 100644 --- a/notebook_utils/berdl_notebook_utils/spark/data_store.py +++ b/notebook_utils/berdl_notebook_utils/spark/data_store.py @@ -10,10 +10,11 @@ from functools import wraps from typing import Any, Callable, Dict, List, Optional, TypeVar, Union -from berdl_notebook_utils import hive_metastore from pyspark.sql import SparkSession -from berdl_notebook_utils.setup_spark_session import get_spark_session + +from berdl_notebook_utils import hive_metastore from berdl_notebook_utils.minio_governance import get_my_accessible_paths, get_my_groups, get_namespace_prefix +from berdl_notebook_utils.setup_spark_session import get_spark_session # ============================================================================= # TTL CACHE FOR GOVERNANCE API CALLS diff --git a/notebook_utils/berdl_notebook_utils/spark/database.py 
b/notebook_utils/berdl_notebook_utils/spark/database.py index 3e2d05c..60b7f1e 100644 --- a/notebook_utils/berdl_notebook_utils/spark/database.py +++ b/notebook_utils/berdl_notebook_utils/spark/database.py @@ -3,6 +3,10 @@ This module contains utility functions to interact with the Spark catalog, including tenant-aware namespace management for BERDL SQL warehouses. + +create_namespace_if_not_exists() supports two flows: +- Delta/Hive: governance prefixes (u_user__, t_tenant__) when iceberg=False (the default) +- Polaris Iceberg: catalog-level isolation (no prefixes) when iceberg=True """ from pyspark.sql import SparkSession @@ -40,6 +44,11 @@ def generate_namespace_location(namespace: str | None = None, tenant_name: str | # Always fetch warehouse directory from governance API for proper S3 location # Don't rely on spark.sql.warehouse.dir as it may be set to local path by Spark Connect server warehouse_response = get_group_sql_warehouse(tenant_name) if tenant_name else get_my_sql_warehouse() + + if hasattr(warehouse_response, "message") and not getattr(warehouse_response, "sql_warehouse_prefix", None): + print(f"Warning: Failed to get warehouse location: {getattr(warehouse_response, 'message', 'Unknown error')}") + return (namespace, None) + warehouse_dir = warehouse_response.sql_warehouse_prefix if warehouse_dir and ("users-sql-warehouse" in warehouse_dir or "tenant-sql-warehouse" in warehouse_dir): @@ -73,35 +82,48 @@ def generate_namespace_location(namespace: str | None = None, tenant_name: str | def create_namespace_if_not_exists( spark: SparkSession, namespace: str | None = DEFAULT_NAMESPACE, - append_target: bool = True, tenant_name: str | None = None, + iceberg: bool = False, ) -> str: """ Create a namespace in the Spark catalog if it does not exist. - If append_target is True, automatically prepends the governance-provided namespace prefix - based on the warehouse directory type (user vs tenant) to create the properly formatted namespace. 
+ Supports two flows controlled by the *iceberg* flag: - For Spark Connect, this function explicitly sets the database LOCATION to ensure tables are - written to the correct S3 path, since spark.sql.warehouse.dir cannot be modified per session. + **Iceberg flow** (``iceberg=True``): + Creates ``{catalog}.{namespace}`` with no governance prefixes — the + catalog itself provides isolation. The catalog is determined by + *tenant_name*: ``None`` → ``"my"`` (user catalog), otherwise the + tenant name is used as the catalog name. + + **Delta/Hive flow** (``iceberg=False``, the default): + Prepends the governance-provided namespace prefix based on the + warehouse directory type (user vs tenant) and explicitly sets the + database LOCATION for Spark Connect compatibility. :param spark: The Spark session. :param namespace: The name of the namespace. - :param append_target: If True, prepends governance namespace prefix based on warehouse type. - If False, uses namespace as-is. - :param tenant_name: Optional tenant name. If provided, uses tenant warehouse. Otherwise uses user warehouse. - :return: The name of the namespace. + :param tenant_name: Delta/Hive: optional tenant name for tenant warehouse. + Iceberg: used as catalog name (defaults to ``"my"`` when None). + :param iceberg: If True, uses the Iceberg flow with catalog-level isolation. + :return: The fully-qualified namespace name. 
""" - db_location = None + namespace = _namespace_norm(namespace) - if append_target: - try: - namespace, db_location = generate_namespace_location(namespace, tenant_name) - except Exception as e: - print(f"Error creating namespace: {e}") - raise e - else: - namespace = _namespace_norm(namespace) + # Iceberg flow: catalog-level isolation, no governance prefixes + if iceberg: + catalog = tenant_name or "my" + full_ns = f"{catalog}.{namespace}" + spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {full_ns}") + print(f"Namespace {full_ns} is ready to use.") + return full_ns + + # Delta/Hive flow + try: + namespace, db_location = generate_namespace_location(namespace, tenant_name) + except Exception as e: + print(f"Error creating namespace: {e}") + raise e if spark.catalog.databaseExists(namespace): print(f"Namespace {namespace} is already registered and ready to use") diff --git a/notebook_utils/pyproject.toml b/notebook_utils/pyproject.toml index 5b339c9..793d3d3 100644 --- a/notebook_utils/pyproject.toml +++ b/notebook_utils/pyproject.toml @@ -102,8 +102,8 @@ DATALAKE_MCP_SERVER_URL = "http://localhost:8080" [tool.uv.sources] cdm-task-service-client = { git = "https://github.com/kbase/cdm-task-service-client", rev = "0.2.3" } cdm-spark-manager-client = { git = "https://github.com/kbase/cdm-kube-spark-manager-client.git", rev = "0.0.1" } -minio-manager-service-client = { git = "https://github.com/BERDataLakehouse/minio_manager_service_client.git", rev = "v0.0.7" } -datalake-mcp-server-client = { git = "https://github.com/BERDataLakehouse/datalake-mcp-server-client.git", rev = "v0.0.6" } +minio-manager-service-client = { git = "https://github.com/BERDataLakehouse/minio_manager_service_client.git", rev = "v0.0.10" } +datalake-mcp-server-client = { git = "https://github.com/BERDataLakehouse/datalake-mcp-server-client.git", rev = "v0.0.7" } [tool.ruff.lint.per-file-ignores] "**/cdm_spark_cluster_manager_api_client/**/*.py" = ["E501"] diff --git 
a/notebook_utils/tests/agent/test_mcp_tools.py b/notebook_utils/tests/agent/test_mcp_tools.py index a7b069a..89a9108 100644 --- a/notebook_utils/tests/agent/test_mcp_tools.py +++ b/notebook_utils/tests/agent/test_mcp_tools.py @@ -2,10 +2,11 @@ Tests for agent/mcp_tools.py - Native MCP tool integration. """ +import asyncio from typing import Optional -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field import berdl_notebook_utils.agent.mcp_tools as mcp_module from berdl_notebook_utils.agent.mcp_tools import ( @@ -46,6 +47,56 @@ class TestSchema(BaseModel): assert "required_field" in result.model_fields assert "optional_field" in result.model_fields + def test_optional_field_with_non_none_default(self): + """Test Optional field with a non-None default value.""" + + class TestSchema(BaseModel): + opt_with_default: Optional[int] = 42 + + result = _simplify_schema_for_openai(TestSchema) + + assert "opt_with_default" in result.model_fields + + def test_field_with_description(self): + """Test field with description preserves it.""" + + class TestSchema(BaseModel): + described_field: str = Field(description="A described field") + opt_described: Optional[str] = Field(default=None, description="Optional described") + opt_described_with_val: Optional[str] = Field(default="hello", description="Has default") + + result = _simplify_schema_for_openai(TestSchema) + + assert "described_field" in result.model_fields + assert "opt_described" in result.model_fields + assert "opt_described_with_val" in result.model_fields + + def test_required_field_without_description(self): + """Test required field without description uses Ellipsis default.""" + + class TestSchema(BaseModel): + name: str + + result = _simplify_schema_for_openai(TestSchema) + + assert "name" in result.model_fields + + def test_exception_returns_original(self): + """Test that exception during 
simplification returns original class.""" + # Create a model that will cause an error during simplification + mock_class = Mock() + mock_class.__name__ = "BrokenSchema" + mock_class.__fields__ = {"field": Mock(annotation=Mock(side_effect=Exception("broken")))} + # Make field access raise + mock_field = Mock() + mock_field.annotation = str + type(mock_field).default = property(lambda self: (_ for _ in ()).throw(Exception("broken"))) + mock_class.__fields__ = {"field": mock_field} + + result = _simplify_schema_for_openai(mock_class) + + assert result is mock_class + class TestWrapAsyncTool: """Tests for _wrap_async_tool function.""" @@ -64,6 +115,93 @@ def test_sync_wrapper_raises_if_no_loop(self): with pytest.raises(RuntimeError, match="MCP event loop not initialized"): sync_tool.func(test_arg="value") + def test_sync_wrapper_success(self): + """Test sync wrapper executes async tool and returns result.""" + mock_async_tool = Mock() + mock_async_tool.name = "test_tool" + mock_async_tool.description = "Test tool" + mock_async_tool.args_schema = None + + # Create a real event loop for this test + loop = asyncio.new_event_loop() + original_loop = mcp_module._mcp_event_loop + mcp_module._mcp_event_loop = loop + + # Start the loop in a background thread + import threading + + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + + try: + # Set up the async tool to return a value + async def mock_ainvoke(kwargs): + return "tool_result" + + mock_async_tool.ainvoke = mock_ainvoke + + with patch("berdl_notebook_utils.agent.mcp_tools._simplify_schema_for_openai", return_value=None): + sync_tool = _wrap_async_tool(mock_async_tool) + + result = sync_tool.func(key="value") + assert result == "tool_result" + finally: + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=2) + loop.close() + mcp_module._mcp_event_loop = original_loop + + def test_sync_wrapper_timeout(self): + """Test sync wrapper handles timeout.""" + mock_async_tool = Mock() + 
mock_async_tool.name = "test_tool" + mock_async_tool.description = "Test tool" + mock_async_tool.args_schema = None + + mock_loop = Mock() + original_loop = mcp_module._mcp_event_loop + mcp_module._mcp_event_loop = mock_loop + + try: + mock_future = Mock() + mock_future.result.side_effect = TimeoutError() + mock_loop_run = Mock(return_value=mock_future) + + with patch("asyncio.run_coroutine_threadsafe", mock_loop_run): + with patch("berdl_notebook_utils.agent.mcp_tools._simplify_schema_for_openai", return_value=None): + sync_tool = _wrap_async_tool(mock_async_tool) + + result = sync_tool.func(key="value") + assert "timed out" in result + + finally: + mcp_module._mcp_event_loop = original_loop + + def test_sync_wrapper_exception(self): + """Test sync wrapper handles exceptions from tool execution.""" + mock_async_tool = Mock() + mock_async_tool.name = "test_tool" + mock_async_tool.description = "Test tool" + mock_async_tool.args_schema = None + + mock_loop = Mock() + original_loop = mcp_module._mcp_event_loop + mcp_module._mcp_event_loop = mock_loop + + try: + mock_future = Mock() + mock_future.result.side_effect = ValueError("something broke") + + with patch("asyncio.run_coroutine_threadsafe", return_value=mock_future): + with patch("berdl_notebook_utils.agent.mcp_tools._simplify_schema_for_openai", return_value=None): + sync_tool = _wrap_async_tool(mock_async_tool) + + result = sync_tool.func(key="value") + assert "Error executing tool" in result + + finally: + mcp_module._mcp_event_loop = original_loop + class TestGetMcpTools: """Tests for get_mcp_tools function.""" @@ -95,6 +233,95 @@ def test_raises_import_error_if_langchain_mcp_tools_not_installed(self, mock_set with pytest.raises(ImportError, match="langchain-mcp-tools is required"): get_mcp_tools() + @patch("berdl_notebook_utils.agent.mcp_tools._wrap_async_tool") + @patch("asyncio.run_coroutine_threadsafe") + @patch("berdl_notebook_utils.agent.mcp_tools.get_settings") + def test_get_mcp_tools_full_flow(self, 
mock_settings, mock_run_coro, mock_wrap): + """Test full flow of get_mcp_tools: discover, wrap, cache.""" + # Clear global state + original_cache = mcp_module._mcp_tools_cache + original_loop = mcp_module._mcp_event_loop + original_thread = mcp_module._mcp_loop_thread + mcp_module._mcp_tools_cache = None + mcp_module._mcp_event_loop = None + + try: + mock_settings.return_value.DATALAKE_MCP_SERVER_URL = "http://localhost:8000" + mock_settings.return_value.KBASE_AUTH_TOKEN = "token123" + + # Mock the convert_mcp_to_langchain_tools function + mock_async_tool = Mock() + mock_async_tool.name = "list_databases" + mock_async_tool.description = "List all databases" + + mock_cleanup = AsyncMock() + + # Mock the discovery future + mock_future = Mock() + mock_future.result.return_value = [mock_async_tool] + mock_run_coro.return_value = mock_future + + # Mock the tool wrapping + mock_sync_tool = Mock() + mock_sync_tool.name = "list_databases" + mock_sync_tool.description = "List all databases" + mock_wrap.return_value = mock_sync_tool + + mock_convert = AsyncMock(return_value=([mock_async_tool], mock_cleanup)) + + with patch.dict("sys.modules", {"langchain_mcp_tools": Mock(convert_mcp_to_langchain_tools=mock_convert)}): + result = get_mcp_tools(server_url="http://custom:8000") + + assert len(result) == 1 + assert result == [mock_sync_tool] + # Verify tools are cached + assert mcp_module._mcp_tools_cache == [mock_sync_tool] + + finally: + mcp_module._mcp_tools_cache = original_cache + if mcp_module._mcp_event_loop and mcp_module._mcp_event_loop != original_loop: + mcp_module._mcp_event_loop.call_soon_threadsafe(mcp_module._mcp_event_loop.stop) + mcp_module._mcp_event_loop = original_loop + mcp_module._mcp_loop_thread = original_thread + + @patch("berdl_notebook_utils.agent.mcp_tools.get_settings") + def test_get_mcp_tools_generic_exception(self, mock_settings): + """Test get_mcp_tools re-raises generic exceptions.""" + mcp_module._mcp_tools_cache = None + original_loop = 
mcp_module._mcp_event_loop + mcp_module._mcp_event_loop = None + + try: + mock_settings.return_value.DATALAKE_MCP_SERVER_URL = "http://localhost:8000" + mock_settings.return_value.KBASE_AUTH_TOKEN = "token" + + mock_convert = Mock(side_effect=ConnectionError("Cannot connect")) + + with patch.dict("sys.modules", {"langchain_mcp_tools": Mock(convert_mcp_to_langchain_tools=mock_convert)}): + with patch("asyncio.run_coroutine_threadsafe") as mock_run_coro: + mock_future = Mock() + mock_future.result.side_effect = ConnectionError("Cannot connect") + mock_run_coro.return_value = mock_future + + with pytest.raises(ConnectionError): + get_mcp_tools() + + finally: + mcp_module._mcp_tools_cache = None + if mcp_module._mcp_event_loop and mcp_module._mcp_event_loop != original_loop: + mcp_module._mcp_event_loop.call_soon_threadsafe(mcp_module._mcp_event_loop.stop) + mcp_module._mcp_event_loop = original_loop + + def test_uses_settings_url_when_no_server_url(self): + """Test uses DATALAKE_MCP_SERVER_URL from settings when server_url is None.""" + mcp_module._mcp_tools_cache = ["already_cached"] + + try: + result = get_mcp_tools() + assert result == ["already_cached"] + finally: + mcp_module._mcp_tools_cache = None + class TestClearMcpToolsCache: """Tests for clear_mcp_tools_cache function.""" @@ -179,3 +406,29 @@ async def mock_cleanup_coro(): finally: mcp_module._mcp_event_loop = original_loop mcp_module._mcp_cleanup = original_cleanup + + @patch("asyncio.run_coroutine_threadsafe") + def test_cleanup_handles_exception(self, mock_run_coro): + """Test cleanup handles exception during MCP cleanup gracefully.""" + mock_loop = Mock() + + mock_cleanup_func = Mock(return_value=Mock()) + + original_loop = mcp_module._mcp_event_loop + original_cleanup = mcp_module._mcp_cleanup + + mcp_module._mcp_event_loop = mock_loop + mcp_module._mcp_cleanup = mock_cleanup_func + + # Make future.result raise an exception + mock_future = Mock() + mock_future.result.side_effect = Exception("Cleanup 
failed") + mock_run_coro.return_value = mock_future + + try: + # Should not raise, just log warning + _cleanup_on_exit() + mock_loop.call_soon_threadsafe.assert_called() + finally: + mcp_module._mcp_event_loop = original_loop + mcp_module._mcp_cleanup = original_cleanup diff --git a/notebook_utils/tests/agent/test_prompts.py b/notebook_utils/tests/agent/test_prompts.py new file mode 100644 index 0000000..60e2120 --- /dev/null +++ b/notebook_utils/tests/agent/test_prompts.py @@ -0,0 +1,91 @@ +"""Tests for agent/prompts.py - System prompt generation.""" + +from unittest.mock import Mock, patch + +from berdl_notebook_utils.agent.prompts import TOOL_DESCRIPTIONS, get_system_prompt + + +class TestGetSystemPrompt: + """Tests for get_system_prompt function.""" + + def test_with_explicit_username(self): + """Test prompt generation with an explicit username.""" + prompt = get_system_prompt(username="alice") + + assert "alice" in prompt + assert "u_alice__" in prompt + assert "BERDL Data Lake Assistant" in prompt + + def test_with_none_username_auto_detects(self): + """Test auto-detection of username from settings when None.""" + mock_settings = Mock() + mock_settings.USER = "bob" + + with patch("berdl_notebook_utils.agent.prompts.get_settings", return_value=mock_settings): + prompt = get_system_prompt(username=None) + + assert "bob" in prompt + assert "u_bob__" in prompt + + def test_with_none_username_fallback_on_error(self): + """Test falls back to 'unknown' when settings raise an exception.""" + with patch("berdl_notebook_utils.agent.prompts.get_settings", side_effect=Exception("no settings")): + prompt = get_system_prompt(username=None) + + assert "unknown" in prompt + assert "u_unknown__" in prompt + + def test_prompt_contains_platform_sections(self): + """Test prompt contains all expected platform sections.""" + prompt = get_system_prompt(username="test_user") + + # Key sections that should be present + assert "## User Context" in prompt + assert "## BERDL Platform 
Architecture" in prompt + assert "## Data Organization" in prompt + assert "## Available Tools" in prompt + assert "## Best Practices" in prompt + assert "## Example Workflows" in prompt + assert "## Error Handling" in prompt + assert "## Important Constraints" in prompt + assert "## Response Style" in prompt + + def test_prompt_contains_tool_references(self): + """Test prompt references the available tools.""" + prompt = get_system_prompt(username="test_user") + + assert "list_databases" in prompt + assert "list_tables" in prompt + assert "get_table_schema" in prompt + assert "sample_table" in prompt + assert "query_table" in prompt + + def test_returns_string(self): + """Test return type is string.""" + result = get_system_prompt(username="test") + assert isinstance(result, str) + assert len(result) > 100 # Non-trivial prompt + + +class TestToolDescriptions: + """Tests for TOOL_DESCRIPTIONS constant.""" + + def test_contains_expected_tools(self): + """Test TOOL_DESCRIPTIONS has all expected tool names.""" + expected_tools = [ + "list_databases", + "list_tables", + "get_table_schema", + "get_database_structure", + "sample_table", + "count_table_rows", + "query_table", + ] + for tool in expected_tools: + assert tool in TOOL_DESCRIPTIONS + + def test_descriptions_are_non_empty_strings(self): + """Test all tool descriptions are non-empty strings.""" + for name, desc in TOOL_DESCRIPTIONS.items(): + assert isinstance(desc, str), f"{name} description is not a string" + assert len(desc) > 10, f"{name} description is too short" diff --git a/notebook_utils/tests/mcp/test_operations.py b/notebook_utils/tests/mcp/test_operations.py index c54bb7a..eb077b4 100644 --- a/notebook_utils/tests/mcp/test_operations.py +++ b/notebook_utils/tests/mcp/test_operations.py @@ -76,19 +76,6 @@ def test_returns_list_of_databases(self, mock_client): assert result == ["default", "analytics", "user_data"] - def test_passes_use_hms_parameter(self, mock_client): - """Test that use_hms parameter 
is passed correctly.""" - mock_response = Mock() - mock_response.databases = [] - - with patch("berdl_notebook_utils.mcp.operations.list_databases") as mock_api: - mock_api.sync.return_value = mock_response - - mcp_list_databases(use_hms=False) - - call_kwargs = mock_api.sync.call_args[1] - assert call_kwargs["body"].use_hms is False - def test_raises_on_none_response(self, mock_client): """Test that None response raises an exception.""" with patch("berdl_notebook_utils.mcp.operations.list_databases") as mock_api: diff --git a/notebook_utils/tests/minio_governance/test_operations.py b/notebook_utils/tests/minio_governance/test_operations.py index 79cf953..a354201 100644 --- a/notebook_utils/tests/minio_governance/test_operations.py +++ b/notebook_utils/tests/minio_governance/test_operations.py @@ -9,13 +9,20 @@ import httpx import pytest +from governance_client.models.user_names_response import UserNamesResponse + from berdl_notebook_utils.minio_governance.operations import ( + _fetch_with_file_cache, _get_credentials_cache_path, + _get_polaris_cache_path, _read_cached_credentials, + _read_cached_polaris_credentials, _write_credentials_cache, + _write_polaris_credentials_cache, _build_table_path, check_governance_health, get_minio_credentials, + get_polaris_credentials, get_my_sql_warehouse, get_group_sql_warehouse, get_namespace_prefix, @@ -30,12 +37,14 @@ make_table_private, list_available_groups, list_groups, + list_user_names, list_users, add_group_member, remove_group_member, create_tenant_and_assign_users, request_tenant_access, CREDENTIALS_CACHE_FILE, + POLARIS_CREDENTIALS_CACHE_FILE, CredentialsResponse, ErrorResponse, ) @@ -115,6 +124,59 @@ def test_builds_path_with_db_suffix(self): assert path == "s3a://cdm-lake/users-sql-warehouse/user1/analytics.db/users" +class TestFetchWithFileCache: + """Tests for _fetch_with_file_cache helper.""" + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_returns_cached_value_on_cache_hit(self, 
mock_fcntl, tmp_path): + """Test returns cached value without calling fetch when cache hits.""" + cache_path = tmp_path / "creds.json" + sentinel = {"key": "cached_value"} + + read_cache = Mock(return_value=sentinel) + fetch = Mock() + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result == sentinel + read_cache.assert_called_once_with(cache_path) + fetch.assert_not_called() + write_cache.assert_not_called() + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_fetches_and_writes_cache_on_cache_miss(self, mock_fcntl, tmp_path): + """Test fetches fresh data and writes cache when cache misses.""" + cache_path = tmp_path / "creds.json" + sentinel = {"key": "fresh_value"} + + read_cache = Mock(return_value=None) + fetch = Mock(return_value=sentinel) + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result == sentinel + read_cache.assert_called_once_with(cache_path) + fetch.assert_called_once() + write_cache.assert_called_once_with(cache_path, sentinel) + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_returns_none_without_writing_when_fetch_fails(self, mock_fcntl, tmp_path): + """Test returns None and does not write cache when fetch returns None.""" + cache_path = tmp_path / "creds.json" + + read_cache = Mock(return_value=None) + fetch = Mock(return_value=None) + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result is None + fetch.assert_called_once() + write_cache.assert_not_called() + + class TestCheckGovernanceHealth: """Tests for check_governance_health function.""" @@ -494,6 +556,45 @@ def test_list_users(self, mock_list_users, mock_get_client): assert result.users == ["user1", "user2"] +class TestListUserNames: + """Tests for list_user_names function.""" + + 
@patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") + def test_list_user_names_success(self, mock_list_user_names, mock_get_client): + """Test list_user_names returns list of usernames.""" + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_list_user_names.return_value = Mock(spec=UserNamesResponse, usernames=["user1", "user2", "user3"]) + + result = list_user_names() + + assert result == ["user1", "user2", "user3"] + mock_list_user_names.assert_called_once_with(client=mock_client) + + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") + def test_list_user_names_error_response(self, mock_list_user_names, mock_get_client): + """Test list_user_names raises on error response.""" + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_list_user_names.return_value = Mock(spec=ErrorResponse, message="Forbidden") + + with pytest.raises(RuntimeError, match="Failed to list usernames"): + list_user_names() + + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") + def test_list_user_names_none_response(self, mock_list_user_names, mock_get_client): + """Test list_user_names raises on None response.""" + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_list_user_names.return_value = None + + with pytest.raises(RuntimeError, match="no response from API"): + list_user_names() + + class TestAddGroupMember: """Tests for add_group_member function.""" @@ -638,3 +739,419 @@ def test_request_tenant_access_http_error(self, mock_settings, mock_httpx): with pytest.raises(RuntimeError, match="Failed to submit access request"): request_tenant_access("kbase") + + 
@patch("berdl_notebook_utils.minio_governance.operations.httpx") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_request_tenant_access_connection_error(self, mock_settings, mock_httpx): + """Test request_tenant_access handles connection errors.""" + mock_settings.return_value.TENANT_ACCESS_SERVICE_URL = "http://service:8000" + mock_settings.return_value.KBASE_AUTH_TOKEN = "token" + + mock_httpx.HTTPStatusError = httpx.HTTPStatusError + mock_httpx.RequestError = httpx.RequestError + mock_httpx.post.side_effect = httpx.RequestError("Connection refused") + + with pytest.raises(RuntimeError, match="Failed to connect to tenant access service"): + request_tenant_access("kbase") + + @patch("berdl_notebook_utils.minio_governance.operations.httpx") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_request_tenant_access_with_justification(self, mock_settings, mock_httpx): + """Test request_tenant_access includes justification in payload.""" + mock_settings.return_value.TENANT_ACCESS_SERVICE_URL = "http://service:8000" + mock_settings.return_value.KBASE_AUTH_TOKEN = "token" + + mock_response = Mock() + mock_response.json.return_value = { + "status": "pending", + "message": "Request submitted", + "requester": "test_user", + "tenant_name": "kbase", + "permission": "read_write", + } + mock_httpx.post.return_value = mock_response + + result = request_tenant_access("kbase", permission="read_write", justification="Need access for project X") + + assert result["status"] == "pending" + # Verify justification was included in the payload + call_kwargs = mock_httpx.post.call_args + assert call_kwargs[1]["json"]["justification"] == "Need access for project X" + + +# ============================================================================= +# Polaris credential caching tests +# ============================================================================= + + +class TestGetPolarisCachePath: + """Tests for 
_get_polaris_cache_path helper.""" + + def test_returns_path_in_home(self): + """Test returns path in home directory.""" + path = _get_polaris_cache_path() + + assert path == Path.home() / POLARIS_CREDENTIALS_CACHE_FILE + + +class TestReadCachedPolarisCredentials: + """Tests for _read_cached_polaris_credentials helper.""" + + def test_returns_none_if_file_not_exists(self, tmp_path): + """Test returns None if cache file doesn't exist.""" + result = _read_cached_polaris_credentials(tmp_path / "nonexistent.json") + + assert result is None + + def test_returns_none_on_invalid_json(self, tmp_path): + """Test returns None on invalid JSON.""" + cache_file = tmp_path / "cache.json" + cache_file.write_text("not valid json") + + result = _read_cached_polaris_credentials(cache_file) + + assert result is None + + def test_returns_none_on_missing_keys(self, tmp_path): + """Test returns None when required keys are missing.""" + cache_file = tmp_path / "cache.json" + cache_file.write_text('{"client_id": "test", "client_secret": "test"}') + + result = _read_cached_polaris_credentials(cache_file) + + assert result is None + + def test_returns_credentials_on_valid_cache(self, tmp_path): + """Test returns credentials on valid cache file.""" + cache_file = tmp_path / "cache.json" + data = { + "client_id": "test_id", + "client_secret": "test_secret", + "personal_catalog": "user_test", + "tenant_catalogs": ["tenant_kbase"], + } + cache_file.write_text(json.dumps(data)) + + result = _read_cached_polaris_credentials(cache_file) + + assert result is not None + assert result["client_id"] == "test_id" + assert result["personal_catalog"] == "user_test" + + +class TestWritePolarisCachedCredentials: + """Tests for _write_polaris_credentials_cache helper.""" + + def test_writes_credentials_to_file(self, tmp_path): + """Test writes Polaris credentials to cache file.""" + cache_file = tmp_path / "cache.json" + creds = { + "client_id": "test_id", + "client_secret": "test_secret", + 
"personal_catalog": "user_test", + "tenant_catalogs": [], + } + + _write_polaris_credentials_cache(cache_file, creds) + + assert cache_file.exists() + content = json.loads(cache_file.read_text()) + assert content["client_id"] == "test_id" + + def test_handles_os_error(self, tmp_path): + """Test handles OSError gracefully.""" + # Use a path that can't be written to + cache_file = tmp_path / "nonexistent_dir" / "cache.json" + + # Should not raise + _write_polaris_credentials_cache(cache_file, {"client_id": "test"}) + + +class TestWriteCredentialsCacheErrors: + """Tests for _write_credentials_cache error handling.""" + + def test_handles_os_error(self, tmp_path): + """Test handles OSError gracefully.""" + cache_file = tmp_path / "nonexistent_dir" / "cache.json" + mock_creds = Mock() + mock_creds.to_dict.return_value = {"access_key": "test"} + + # Should not raise + _write_credentials_cache(cache_file, mock_creds) + + +# ============================================================================= +# get_polaris_credentials tests +# ============================================================================= + + +class TestGetPolarisCredentials: + """Tests for get_polaris_credentials function.""" + + @patch("berdl_notebook_utils.minio_governance.operations.os") + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_none_when_polaris_not_configured(self, mock_settings, mock_fcntl, mock_os): + """Test returns None when POLARIS_CATALOG_URI is not set.""" + mock_settings.return_value.POLARIS_CATALOG_URI = None + + result = get_polaris_credentials() + + assert result is None + + @patch("berdl_notebook_utils.minio_governance.operations.os") + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._write_polaris_credentials_cache") + 
@patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_cached_credentials( + self, mock_settings, mock_cache_path, mock_read_cache, mock_write_cache, mock_fcntl, mock_os, tmp_path + ): + """Test returns cached Polaris credentials when available.""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + + cached = { + "client_id": "cached_id", + "client_secret": "cached_secret", + "personal_catalog": "user_test", + "tenant_catalogs": ["tenant_kbase"], + } + mock_read_cache.return_value = cached + + result = get_polaris_credentials() + + assert result["client_id"] == "cached_id" + mock_os.environ.__setitem__.assert_any_call("POLARIS_CREDENTIAL", "cached_id:cached_secret") + mock_os.environ.__setitem__.assert_any_call("POLARIS_PERSONAL_CATALOG", "user_test") + + @patch("berdl_notebook_utils.minio_governance.operations.os") + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._write_polaris_credentials_cache") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch( + "berdl_notebook_utils.minio_governance.operations.provision_polaris_user_polaris_user_provision_username_post" + ) + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_fetches_fresh_credentials( + self, + mock_settings, + mock_get_client, + mock_provision, + mock_cache_path, + mock_read_cache, + mock_write_cache, + mock_fcntl, + mock_os, + tmp_path, + ): + """Test fetches fresh credentials 
from API when no cache.""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_settings.return_value.USER = "test_user" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + mock_read_cache.return_value = None + + mock_api_response = Mock() + mock_api_response.to_dict.return_value = { + "client_id": "new_id", + "client_secret": "new_secret", + "personal_catalog": "user_test_user", + "tenant_catalogs": ["tenant_team"], + } + mock_provision.sync.return_value = mock_api_response + + result = get_polaris_credentials() + + assert result["client_id"] == "new_id" + assert result["personal_catalog"] == "user_test_user" + mock_write_cache.assert_called_once() + mock_provision.sync.assert_called_once_with(username="test_user", client=mock_get_client.return_value) + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch( + "berdl_notebook_utils.minio_governance.operations.provision_polaris_user_polaris_user_provision_username_post" + ) + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_none_on_error_response( + self, mock_settings, mock_get_client, mock_provision, mock_cache_path, mock_read_cache, mock_fcntl, tmp_path + ): + """Test returns None when API returns ErrorResponse.""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_settings.return_value.USER = "test_user" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + mock_read_cache.return_value = None + + mock_error = Mock(spec=ErrorResponse) + mock_error.message = "Internal server error" + mock_provision.sync.return_value = mock_error + + result = get_polaris_credentials() + + assert result is None + + 
@patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch( + "berdl_notebook_utils.minio_governance.operations.provision_polaris_user_polaris_user_provision_username_post" + ) + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_none_on_no_response( + self, mock_settings, mock_get_client, mock_provision, mock_cache_path, mock_read_cache, mock_fcntl, tmp_path + ): + """Test returns None when API returns None (unexpected status).""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_settings.return_value.USER = "test_user" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + mock_read_cache.return_value = None + + mock_provision.sync.return_value = None + + result = get_polaris_credentials() + + assert result is None + + +# ============================================================================= +# unshare_table error logging tests +# ============================================================================= + + +class TestUnshareTableErrors: + """Tests for unshare_table error logging.""" + + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.unshare_data_sharing_unshare_post") + def test_unshare_table_logs_errors(self, mock_unshare, mock_get_client, mock_settings, caplog): + """Test unshare_table logs error messages when present.""" + mock_settings.return_value.USER = "test_user" + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_unshare.sync.return_value = Mock(errors=["User not found", "Permission denied"]) + + 
with caplog.at_level(logging.WARNING): + unshare_table("test_db", "test_table", from_users=["invalid_user"]) + + assert "Error unsharing table" in caplog.text + + +# ============================================================================= +# get_minio_credentials edge cases +# ============================================================================= + + +class TestGetMinioCredentialsEdgeCases: + """Tests for get_minio_credentials edge cases.""" + + @patch("berdl_notebook_utils.minio_governance.operations.os") + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations.get_credentials_credentials_get") + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations._write_credentials_cache") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_credentials_cache_path") + def test_raises_on_api_error_response( + self, + mock_cache_path, + mock_read_cache, + mock_write_cache, + mock_get_client, + mock_get_creds, + mock_fcntl, + mock_os, + tmp_path, + ): + """Test raises RuntimeError when API returns non-CredentialsResponse.""" + mock_cache_path.return_value = tmp_path / ".cache" + mock_read_cache.return_value = None + mock_client = Mock() + mock_get_client.return_value = mock_client + # Return something that's not a CredentialsResponse + mock_get_creds.sync.return_value = Mock(spec=ErrorResponse) + + with pytest.raises(RuntimeError, match="Failed to fetch credentials from API"): + get_minio_credentials() + + +# ============================================================================= +# create_tenant_and_assign_users edge cases +# ============================================================================= + + +class TestCreateTenantEdgeCases: + """Tests for create_tenant_and_assign_users edge cases.""" + + 
@patch("berdl_notebook_utils.minio_governance.operations.time") + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.create_group_management_groups_group_name_post") + @patch( + "berdl_notebook_utils.minio_governance.operations.add_group_member_management_groups_group_name_members_username_post" + ) + def test_add_member_error_response(self, mock_add_member, mock_create, mock_get_client, mock_time): + """Test create_tenant handles ErrorResponse when adding members.""" + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_create.sync.return_value = Mock(success=True) + mock_create.sync.return_value.__class__ = Mock # Not ErrorResponse + + # Return ErrorResponse for add_member + error_resp = Mock(spec=ErrorResponse) + error_resp.message = "User not found" + mock_add_member.sync.return_value = error_resp + + result = create_tenant_and_assign_users("tenant1", ["bad_user"]) + + assert len(result["add_members"]) == 1 + assert result["add_members"][0][1] == error_resp + + @patch("berdl_notebook_utils.minio_governance.operations.time") + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.create_group_management_groups_group_name_post") + @patch( + "berdl_notebook_utils.minio_governance.operations.add_group_member_management_groups_group_name_members_username_post" + ) + def test_add_member_exception(self, mock_add_member, mock_create, mock_get_client, mock_time): + """Test create_tenant handles exception when adding members.""" + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_create.sync.return_value = Mock(success=True) + mock_create.sync.return_value.__class__ = Mock # Not ErrorResponse + + # Raise exception for add_member + mock_add_member.sync.side_effect = Exception("API timeout") + + result = create_tenant_and_assign_users("tenant1", ["user1"]) + 
+ assert len(result["add_members"]) == 1 + # Should have an ErrorResponse tuple + username, error = result["add_members"][0] + assert username == "user1" + assert isinstance(error, ErrorResponse) + + +# ============================================================================= +# list_available_groups edge cases +# ============================================================================= + + +class TestListAvailableGroupsEdgeCases: + """Tests for list_available_groups edge cases.""" + + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.list_group_names_sync") + def test_list_available_groups_none_response(self, mock_list_groups, mock_get_client): + """Test list_available_groups raises on None response.""" + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_list_groups.return_value = None + + with pytest.raises(RuntimeError, match="Failed to list groups: no response"): + list_available_groups() diff --git a/notebook_utils/tests/spark/test_connect_server.py b/notebook_utils/tests/spark/test_connect_server.py index 3d83bbf..4d01465 100644 --- a/notebook_utils/tests/spark/test_connect_server.py +++ b/notebook_utils/tests/spark/test_connect_server.py @@ -102,6 +102,8 @@ def test_generate_spark_config(self, mock_get_warehouse, mock_get_settings, mock mock_settings.SPARK_MASTER_CORES = 1 mock_settings.SPARK_MASTER_MEMORY = "8G" mock_settings.BERDL_POD_IP = "192.168.1.100" + # Polaris not configured — _get_catalog_conf returns empty dict + mock_settings.POLARIS_CATALOG_URI = None mock_get_settings.return_value = mock_settings mock_convert.return_value = "8g" @@ -125,6 +127,60 @@ def test_generate_spark_config(self, mock_get_warehouse, mock_get_settings, mock assert "test_user" in content assert "spark.eventLog.dir" in content + @patch("berdl_notebook_utils.spark.connect_server._get_catalog_conf") + 
@patch("berdl_notebook_utils.spark.connect_server.shutil.copy") + @patch("berdl_notebook_utils.spark.connect_server.convert_memory_format") + @patch("berdl_notebook_utils.spark.connect_server.get_settings") + @patch("berdl_notebook_utils.spark.connect_server.get_my_sql_warehouse") + def test_generate_spark_config_with_catalog( + self, mock_get_warehouse, mock_get_settings, mock_convert, mock_copy, mock_catalog_conf, tmp_path + ): + """Test generate_spark_config includes Polaris catalog config when present.""" + mock_settings = Mock() + mock_settings.USER = "test_user" + mock_settings.SPARK_HOME = "/opt/spark" + mock_settings.SPARK_CONNECT_DEFAULTS_TEMPLATE = str(tmp_path / "template.conf") + mock_url = Mock() + mock_url.port = 15002 + mock_settings.SPARK_CONNECT_URL = mock_url + mock_settings.SPARK_MASTER_URL = "spark://master:7077" + mock_settings.BERDL_HIVE_METASTORE_URI = "thrift://localhost:9083" + mock_settings.MINIO_ENDPOINT_URL = "http://localhost:9000" + mock_settings.MINIO_ACCESS_KEY = "minioadmin" + mock_settings.MINIO_SECRET_KEY = "minioadmin" + mock_settings.SPARK_WORKER_COUNT = 2 + mock_settings.SPARK_WORKER_CORES = 2 + mock_settings.SPARK_WORKER_MEMORY = "10G" + mock_settings.SPARK_MASTER_CORES = 1 + mock_settings.SPARK_MASTER_MEMORY = "8G" + mock_settings.BERDL_POD_IP = "192.168.1.100" + mock_get_settings.return_value = mock_settings + + mock_convert.return_value = "8g" + mock_warehouse = Mock() + mock_warehouse.sql_warehouse_prefix = "s3a://cdm-lake/users-sql-warehouse/test_user" + mock_get_warehouse.return_value = mock_warehouse + + # Return catalog config entries + mock_catalog_conf.return_value = { + "spark.sql.catalog.my": "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.my.type": "rest", + } + + # Create template file + template_file = tmp_path / "template.conf" + template_file.write_text("# Base config") + + config = SparkConnectServerConfig() + config.spark_defaults_path = tmp_path / "spark-defaults.conf" + + 
config.generate_spark_config() + + content = config.spark_defaults_path.read_text() + assert "Polaris Catalog Configuration" in content + assert "spark.sql.catalog.my=org.apache.iceberg.spark.SparkCatalog" in content + assert "spark.sql.catalog.my.type=rest" in content + @patch("berdl_notebook_utils.spark.connect_server.get_settings") def test_generate_spark_config_template_not_found(self, mock_get_settings): """Test generate_spark_config raises if template not found.""" @@ -143,6 +199,59 @@ def test_generate_spark_config_template_not_found(self, mock_get_settings): with pytest.raises(FileNotFoundError, match="Spark config template not found"): config.generate_spark_config() + @patch("berdl_notebook_utils.spark.connect_server.get_my_groups") + @patch("berdl_notebook_utils.spark.connect_server.get_namespace_prefix") + @patch("berdl_notebook_utils.spark.connect_server.get_settings") + def test_compute_allowed_namespace_prefixes(self, mock_get_settings, mock_ns_prefix, mock_groups): + """Test compute_allowed_namespace_prefixes returns correct prefixes.""" + mock_settings = Mock() + mock_settings.USER = "test_user" + mock_settings.SPARK_HOME = "/opt/spark" + mock_settings.SPARK_CONNECT_DEFAULTS_TEMPLATE = "/etc/template.conf" + mock_url = Mock() + mock_url.port = 15002 + mock_settings.SPARK_CONNECT_URL = mock_url + mock_settings.SPARK_MASTER_URL = "spark://master:7077" + mock_get_settings.return_value = mock_settings + + # Mock governance responses + mock_ns_prefix.return_value = Mock(user_namespace_prefix="u_test_user__") + mock_groups.return_value = Mock(groups=["kbase", "kbasero", "research"]) + + config = SparkConnectServerConfig() + result = config.compute_allowed_namespace_prefixes() + + assert "u_test_user__" in result + assert "kbase_" in result + assert "research_" in result + # "kbasero" ends with "ro" so should be excluded + assert "kbasero_" not in result + + @patch("berdl_notebook_utils.spark.connect_server.get_my_groups") + 
@patch("berdl_notebook_utils.spark.connect_server.get_namespace_prefix") + @patch("berdl_notebook_utils.spark.connect_server.get_settings") + def test_compute_allowed_namespace_prefixes_errors(self, mock_get_settings, mock_ns_prefix, mock_groups): + """Test compute_allowed_namespace_prefixes handles API errors gracefully.""" + mock_settings = Mock() + mock_settings.USER = "test_user" + mock_settings.SPARK_HOME = "/opt/spark" + mock_settings.SPARK_CONNECT_DEFAULTS_TEMPLATE = "/etc/template.conf" + mock_url = Mock() + mock_url.port = 15002 + mock_settings.SPARK_CONNECT_URL = mock_url + mock_settings.SPARK_MASTER_URL = "spark://master:7077" + mock_get_settings.return_value = mock_settings + + # Both API calls fail + mock_ns_prefix.side_effect = Exception("API error") + mock_groups.side_effect = Exception("API error") + + config = SparkConnectServerConfig() + result = config.compute_allowed_namespace_prefixes() + + # Should return empty string (no prefixes) + assert result == "" + class TestSparkConnectServerManager: """Tests for SparkConnectServerManager class.""" @@ -438,14 +547,62 @@ def test_wait_for_port_release_port_free(self, mock_config_class): assert result is True + @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerConfig") + def test_wait_for_port_release_timeout(self, mock_config_class): + """Test _wait_for_port_release returns False on timeout.""" + mock_config = Mock() + mock_config.spark_connect_port = 15002 + mock_config_class.return_value = mock_config + + # Simulate port always in use (connect_ex returns 0 = success = port in use) + mock_sock = Mock() + mock_sock.connect_ex.return_value = 0 + + manager = SparkConnectServerManager() + # Patch socket inside the method's local import + with patch("socket.socket", return_value=mock_sock): + result = manager._wait_for_port_release(timeout=0.1) + + assert result is False + + @patch("berdl_notebook_utils.spark.connect_server.subprocess.run") + 
@patch("berdl_notebook_utils.spark.connect_server.os.kill") + @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerConfig") + def test_kill_java_process_pgrep_finds_but_kill_fails(self, mock_config_class, mock_kill, mock_run): + """Test _kill_java_process handles failure to kill found processes.""" + mock_config = Mock() + mock_config_class.return_value = mock_config + + mock_run.return_value = Mock(returncode=0, stdout="12345\n12346") + # os.kill raises for both PIDs + mock_kill.side_effect = [ProcessLookupError(), OSError("Permission denied")] + + manager = SparkConnectServerManager() + # Should not raise + manager._kill_java_process() + + @patch("berdl_notebook_utils.spark.connect_server.subprocess.run") + @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerConfig") + def test_kill_java_process_both_pgrep_pkill_not_found(self, mock_config_class, mock_run): + """Test _kill_java_process handles neither pgrep nor pkill available.""" + mock_config = Mock() + mock_config_class.return_value = mock_config + + # Both pgrep and pkill raise FileNotFoundError + mock_run.side_effect = [FileNotFoundError(), FileNotFoundError()] + + manager = SparkConnectServerManager() + # Should not raise + manager._kill_java_process() + class TestSparkConnectServerManagerForceRestart: """Tests for force_restart functionality.""" @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerManager.stop") - @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerManager.is_running") + @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerManager.get_server_info") @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerConfig") - def test_start_force_restart_calls_stop(self, mock_config_class, mock_is_running, mock_stop, tmp_path): + def test_start_force_restart_calls_stop(self, mock_config_class, mock_get_info, mock_stop, tmp_path): """Test start with force_restart=True calls stop first.""" mock_config = Mock() 
mock_config.username = "test_user" @@ -457,8 +614,9 @@ def test_start_force_restart_calls_stop(self, mock_config_class, mock_is_running mock_config.pid_file_path = tmp_path / "pid" mock_config_class.return_value = mock_config - # Server is running initially - mock_is_running.side_effect = [True, False] # First check: running, after stop: not running + # start() calls get_server_info() — always return a dict (server running) so that + # force_restart=True must call stop() before attempting to start a new server + mock_get_info.return_value = {"pid": 12345, "port": 15002} # Mock the start script check to fail (we don't want to actually start) with patch("pathlib.Path.exists", return_value=False): @@ -469,3 +627,87 @@ def test_start_force_restart_calls_stop(self, mock_config_class, mock_is_running pass # Expected - start script doesn't exist mock_stop.assert_called_once() + + @patch("berdl_notebook_utils.spark.connect_server.time") + @patch("berdl_notebook_utils.spark.connect_server.subprocess.Popen") + @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerConfig") + def test_start_new_server_success(self, mock_config_class, mock_popen, mock_time, tmp_path): + """Test starting a new Spark Connect server successfully.""" + mock_config = Mock() + mock_config.username = "test_user" + mock_config.spark_home = str(tmp_path) + mock_config.spark_master_url = "spark://master:7077" + mock_config.spark_connect_port = 15002 + mock_config.user_conf_dir = tmp_path / "conf" + mock_config.log_file_path = tmp_path / "log" + mock_config.pid_file_path = tmp_path / "pid" + mock_config_class.return_value = mock_config + + # Create the start script + sbin_dir = tmp_path / "sbin" + sbin_dir.mkdir() + start_script = sbin_dir / "start-connect-server.sh" + start_script.write_text("#!/bin/bash\n") + start_script.chmod(0o755) + + # Mock process + mock_process = Mock() + mock_process.poll.return_value = None # Process still running + mock_process.pid = 54321 + mock_popen.return_value =
mock_process + + manager = SparkConnectServerManager() + + # get_server_info: None (not running), then return info after start + with patch.object(manager, "get_server_info") as mock_get_info: + mock_get_info.side_effect = [ + None, # First call: not running + { # Second call: after start + "pid": 54321, + "port": 15002, + "url": "sc://localhost:15002", + "log_file": str(tmp_path / "log"), + "master_url": "spark://master:7077", + }, + ] + + result = manager.start() + + assert result["pid"] == 54321 + assert result["port"] == 15002 + mock_config.create_directories.assert_called_once() + mock_config.generate_spark_config.assert_called_once() + + @patch("berdl_notebook_utils.spark.connect_server.time") + @patch("berdl_notebook_utils.spark.connect_server.subprocess.Popen") + @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerConfig") + def test_start_new_server_fails(self, mock_config_class, mock_popen, mock_time, tmp_path): + """Test start raises RuntimeError when server fails to start.""" + mock_config = Mock() + mock_config.username = "test_user" + mock_config.spark_home = str(tmp_path) + mock_config.spark_master_url = "spark://master:7077" + mock_config.spark_connect_port = 15002 + mock_config.user_conf_dir = tmp_path / "conf" + mock_config.log_file_path = tmp_path / "log" + mock_config.pid_file_path = tmp_path / "pid" + mock_config_class.return_value = mock_config + + # Create the start script + sbin_dir = tmp_path / "sbin" + sbin_dir.mkdir() + start_script = sbin_dir / "start-connect-server.sh" + start_script.write_text("#!/bin/bash\n") + start_script.chmod(0o755) + + # Mock process - failed (poll returns exit code) + mock_process = Mock() + mock_process.poll.return_value = 1 # Process exited with error + mock_process.returncode = 1 + mock_popen.return_value = mock_process + + manager = SparkConnectServerManager() + + with patch.object(manager, "get_server_info", return_value=None): + with pytest.raises(RuntimeError, match="Spark Connect server 
failed to start"): + manager.start() diff --git a/notebook_utils/tests/spark/test_data_store.py b/notebook_utils/tests/spark/test_data_store.py index ae28c84..b71a3ff 100644 --- a/notebook_utils/tests/spark/test_data_store.py +++ b/notebook_utils/tests/spark/test_data_store.py @@ -2,19 +2,19 @@ Tests for spark/data_store.py - Data store operations. """ -from unittest.mock import Mock, patch import json +from unittest.mock import Mock, patch from berdl_notebook_utils.spark.data_store import ( + _execute_with_spark, + _extract_databases_from_paths, + _format_output, _ttl_cache, clear_governance_cache, - _format_output, - _extract_databases_from_paths, get_databases, - get_tables, - get_table_schema, get_db_structure, - _execute_with_spark, + get_table_schema, + get_tables, ) diff --git a/notebook_utils/tests/spark/test_database.py b/notebook_utils/tests/spark/test_database.py index 19e5e2c..4f70732 100644 --- a/notebook_utils/tests/spark/test_database.py +++ b/notebook_utils/tests/spark/test_database.py @@ -110,12 +110,16 @@ def test_generate_namespace_location_no_match_warns( assert "Warning: Could not determine target name from warehouse directory" in captured.out +# ============================================================================ +# create_namespace_if_not_exists tests (Delta/Hive flow) +# ============================================================================ + + @pytest.mark.parametrize("tenant", [None, TENANT_NAME]) @pytest.mark.parametrize("namespace_arg", EXPECTED_NS) def test_create_namespace_if_not_exists_user_tenant_warehouse(namespace_arg: str | None, tenant: str | None) -> None: """Test user and tenant namespace creation.""" mock_spark = make_mock_spark() - # Run with append_target=True (default) ns = create_namespace_if_not_exists(mock_spark, namespace=namespace_arg, tenant_name=tenant) # type: ignore namespace = EXPECTED_NS[namespace_arg] if tenant: @@ -128,18 +132,6 @@ def 
test_create_namespace_if_not_exists_user_tenant_warehouse(namespace_arg: str mock_spark.sql.assert_called_once_with(f"CREATE DATABASE IF NOT EXISTS {ns} LOCATION '{expected_location}'") -@pytest.mark.parametrize("tenant", [None, TENANT_NAME]) -@pytest.mark.parametrize("namespace_arg", EXPECTED_NS) -def test_create_namespace_if_not_exists_without_prefix(namespace_arg: str | None, tenant: str | None) -> None: - """Test namespace creation when append_target is set to false.""" - mock_spark = make_mock_spark() - ns = create_namespace_if_not_exists(mock_spark, namespace=namespace_arg, append_target=False, tenant_name=tenant) # type: ignore - namespace = EXPECTED_NS[namespace_arg] - assert ns == namespace - # Should create database without LOCATION clause - mock_spark.sql.assert_called_once_with(f"CREATE DATABASE IF NOT EXISTS {namespace}") - - @pytest.mark.parametrize("tenant", [None, TENANT_NAME]) @pytest.mark.parametrize("namespace_arg", EXPECTED_NS) def test_create_namespace_if_not_exists_already_exists( @@ -147,13 +139,17 @@ def test_create_namespace_if_not_exists_already_exists( ) -> None: """Test namespace creation when the namespace has already been registered.""" mock_spark = make_mock_spark(database_exists=True) - ns = create_namespace_if_not_exists(mock_spark, namespace=namespace_arg, append_target=False, tenant_name=tenant) # type: ignore + ns = create_namespace_if_not_exists(mock_spark, namespace=namespace_arg, tenant_name=tenant) # type: ignore namespace = EXPECTED_NS[namespace_arg] - assert ns == namespace + if tenant: + expected_ns = f"tenant__{namespace}" + else: + expected_ns = f"user__{namespace}" + assert ns == expected_ns # No call to spark.sql as the namespace already exists mock_spark.sql.assert_not_called() captured = capfd.readouterr() - assert f"Namespace {namespace} is already registered and ready to use" in captured.out + assert f"Namespace {expected_ns} is already registered and ready to use" in captured.out 
@pytest.mark.parametrize("namespace_arg", EXPECTED_NS) @@ -186,3 +182,37 @@ def test_create_namespace_if_not_exists_error() -> None: pytest.raises(RuntimeError, match="things went wrong"), ): create_namespace_if_not_exists(Mock(), "some_namespace") + + +# ============================================================================ +# create_namespace_if_not_exists iceberg=True tests (Polaris Iceberg flow) +# ============================================================================ + + +@pytest.mark.parametrize("namespace_arg", EXPECTED_NS) +def test_create_namespace_iceberg_default_catalog(namespace_arg: str | None, capfd: pytest.CaptureFixture[str]) -> None: + """Test Iceberg namespace creation uses 'my' catalog by default (no tenant_name).""" + mock_spark = make_mock_spark() + namespace = EXPECTED_NS[namespace_arg] + result = create_namespace_if_not_exists(mock_spark, namespace, iceberg=True) + assert result == f"my.{namespace}" + mock_spark.sql.assert_called_once_with(f"CREATE NAMESPACE IF NOT EXISTS my.{namespace}") + captured = capfd.readouterr() + assert f"Namespace my.{namespace} is ready to use." 
in captured.out + + +@pytest.mark.parametrize("tenant", ["globalusers", "research"]) +def test_create_namespace_iceberg_tenant_as_catalog(tenant: str) -> None: + """Test Iceberg namespace creation uses tenant_name as catalog.""" + mock_spark = make_mock_spark() + result = create_namespace_if_not_exists(mock_spark, "test_db", iceberg=True, tenant_name=tenant) + assert result == f"{tenant}.test_db" + mock_spark.sql.assert_called_once_with(f"CREATE NAMESPACE IF NOT EXISTS {tenant}.test_db") + + +def test_create_namespace_iceberg_no_governance_prefix() -> None: + """Test that iceberg=True does NOT call governance API for prefixes.""" + mock_spark = make_mock_spark() + with patch("berdl_notebook_utils.spark.database.get_namespace_prefix") as mock_prefix: + create_namespace_if_not_exists(mock_spark, "test_db", iceberg=True) + mock_prefix.assert_not_called() diff --git a/notebook_utils/tests/test_cache.py b/notebook_utils/tests/test_cache.py new file mode 100644 index 0000000..f0920f6 --- /dev/null +++ b/notebook_utils/tests/test_cache.py @@ -0,0 +1,116 @@ +"""Tests for cache.py - Token-dependent cache management.""" + +from functools import lru_cache + +from berdl_notebook_utils.cache import ( + _token_change_caches, + clear_kbase_token_caches, + kbase_token_dependent, +) + + +class TestKbaseTokenDependent: + """Tests for kbase_token_dependent decorator.""" + + def test_registers_function(self): + """Test that decorator registers the function in _token_change_caches.""" + initial_len = len(_token_change_caches) + + # Production order: @kbase_token_dependent on top of @lru_cache + # This means lru_cache wraps first, then kbase_token_dependent registers the wrapper + @kbase_token_dependent + @lru_cache + def dummy_func(): + return "value" + + assert len(_token_change_caches) == initial_len + 1 + assert dummy_func in _token_change_caches + + # Clean up + _token_change_caches.remove(dummy_func) + + def test_returns_function_unchanged(self): + """Test that decorator returns 
the function without modification.""" + + @lru_cache + def original(): + return 42 + + result = kbase_token_dependent(original) + assert result is original + assert result() == 42 + + # Clean up + _token_change_caches.remove(original) + + +class TestClearKbaseTokenCaches: + """Tests for clear_kbase_token_caches function.""" + + def test_clears_all_registered_caches(self): + """Test that clear_kbase_token_caches calls cache_clear on all registered functions.""" + call_count = 0 + + @lru_cache + def cached_func(): + nonlocal call_count + call_count += 1 + return call_count + + _token_change_caches.append(cached_func) + + try: + # Call once to populate cache + result1 = cached_func() + assert result1 == 1 + + # Call again - should return cached value + result2 = cached_func() + assert result2 == 1 + + # Clear caches + clear_kbase_token_caches() + + # Call again - should recompute + result3 = cached_func() + assert result3 == 2 + finally: + _token_change_caches.remove(cached_func) + + def test_handles_multiple_caches(self): + """Test clearing multiple registered caches.""" + + @lru_cache + def func_a(): + return "a" + + @lru_cache + def func_b(): + return "b" + + _token_change_caches.append(func_a) + _token_change_caches.append(func_b) + + try: + # Populate caches + func_a() + func_b() + assert func_a.cache_info().hits == 0 + assert func_b.cache_info().hits == 0 + + func_a() + func_b() + assert func_a.cache_info().hits == 1 + assert func_b.cache_info().hits == 1 + + # Clear all + clear_kbase_token_caches() + + # Verify caches were cleared + assert func_a.cache_info().hits == 0 + assert func_a.cache_info().misses == 0 + assert func_b.cache_info().hits == 0 + assert func_b.cache_info().misses == 0 + finally: + _token_change_caches.remove(func_a) + _token_change_caches.remove(func_b) diff --git a/notebook_utils/tests/test_get_spark_session.py b/notebook_utils/tests/test_get_spark_session.py index 7c7727a..35d8caf 100644 --- a/notebook_utils/tests/test_get_spark_session.py 
+++ b/notebook_utils/tests/test_get_spark_session.py @@ -347,3 +347,207 @@ def test_executor_conf_no_auth_token_for_legacy_mode(monkeypatch: pytest.MonkeyP assert "spark.driver.host" in config # Verify master URL does NOT contain auth token (legacy mode doesn't use URL-based auth) assert "authorization" not in config.get("spark.master", "") + + +# ============================================================================= +# convert_memory_format tests +# ============================================================================= + + +class TestConvertMemoryFormat: + """Tests for convert_memory_format edge cases.""" + + def test_invalid_memory_format_raises(self): + """Test that invalid memory format raises ValueError.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + with pytest.raises(ValueError, match="Invalid memory format"): + convert_memory_format("not_a_memory_value") + + def test_invalid_memory_format_no_unit(self): + """Test that bare number without unit raises ValueError.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + with pytest.raises(ValueError, match="Invalid memory format"): + convert_memory_format("1024") + + def test_small_memory_returns_kb(self): + """Test small memory values return kilobyte format.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + # 1 MiB with 10% overhead = ~921.6 KiB → "922k" + result = convert_memory_format("1MiB", 0.1) + assert result.endswith("k") + + def test_very_small_memory_returns_bytes(self): + """Test very small memory values (< 1 KiB) return byte format (no unit suffix).""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + # The regex requires [kmgtKMGT] prefix, so smallest valid unit is "kb" + # 0.5 kb = 512 bytes, with 10% overhead = 460.8 bytes → "461" (no unit) + result = convert_memory_format("0.5kb", 0.1) + assert not result.endswith("k") + assert not result.endswith("m") + 
assert not result.endswith("g") + + def test_gib_memory(self): + """Test GiB memory format.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + result = convert_memory_format("4GiB", 0.1) + assert result.endswith("g") + + def test_unit_key_fallback(self): + """Test unit key fallback for unusual unit formats.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + # "4gb" should work + result = convert_memory_format("4gb", 0.1) + assert result.endswith("g") + + +# ============================================================================= +# _get_catalog_conf tests +# ============================================================================= + + +class TestGetCatalogConf: + """Tests for _get_catalog_conf with Polaris configuration.""" + + def test_returns_empty_when_no_polaris_uri(self): + """Test returns empty dict when POLARIS_CATALOG_URI is None.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = None + + result = _get_catalog_conf(settings) + + assert result == {} + + def test_personal_catalog_config(self): + """Test generates personal catalog config when POLARIS_PERSONAL_CATALOG is set.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = "user_tgu2" + settings.POLARIS_TENANT_CATALOGS = None + + result = _get_catalog_conf(settings) + + assert "spark.sql.catalog.my" in result + assert result["spark.sql.catalog.my"] == "org.apache.iceberg.spark.SparkCatalog" + assert result["spark.sql.catalog.my.type"] == "rest" + assert result["spark.sql.catalog.my.warehouse"] == "user_tgu2" + assert "spark.sql.catalog.my.s3.endpoint" in result + assert 
result["spark.sql.catalog.my.s3.path-style-access"] == "true" + + def test_tenant_catalog_config(self): + """Test generates tenant catalog config with 'tenant_' prefix stripped.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = None + settings.POLARIS_TENANT_CATALOGS = "tenant_globalusers,tenant_research" + + result = _get_catalog_conf(settings) + + # "tenant_globalusers" → alias "globalusers" + assert "spark.sql.catalog.globalusers" in result + assert result["spark.sql.catalog.globalusers.warehouse"] == "tenant_globalusers" + # "tenant_research" → alias "research" + assert "spark.sql.catalog.research" in result + assert result["spark.sql.catalog.research.warehouse"] == "tenant_research" + + def test_empty_tenant_catalog_entries_skipped(self): + """Test empty entries in POLARIS_TENANT_CATALOGS are skipped.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = None + settings.POLARIS_TENANT_CATALOGS = "tenant_kbase,, " + + result = _get_catalog_conf(settings) + + assert "spark.sql.catalog.kbase" in result + # Empty entries should not produce catalog configs + catalog_keys = [ + k for k in result if k.startswith("spark.sql.catalog.") and "." 
not in k[len("spark.sql.catalog.") :] + ] + assert all("kbase" in k for k in catalog_keys) + + def test_s3_endpoint_without_http_prefix(self): + """Test s3 endpoint gets http:// prefix if missing.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = "user_test" + settings.POLARIS_TENANT_CATALOGS = None + settings.MINIO_ENDPOINT_URL = "minio:9000" + + result = _get_catalog_conf(settings) + + assert result["spark.sql.catalog.my.s3.endpoint"] == "http://minio:9000" + + def test_both_personal_and_tenant_catalogs(self): + """Test generates config for both personal and tenant catalogs.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = "user_alice" + settings.POLARIS_TENANT_CATALOGS = "tenant_team" + + result = _get_catalog_conf(settings) + + assert "spark.sql.catalog.my" in result + assert "spark.sql.catalog.team" in result + + +# ============================================================================= +# _is_immutable_config tests +# ============================================================================= + + +class TestIsImmutableConfig: + """Tests for _is_immutable_config function.""" + + def test_known_immutable_configs(self): + """Test known immutable config keys are detected.""" + from berdl_notebook_utils.setup_spark_session import _is_immutable_config + + for key in IMMUTABLE_CONFIGS: + assert _is_immutable_config(key) is True, f"Expected {key} to be immutable" + + def test_catalog_config_keys_are_immutable(self): + """Test that spark.sql.catalog.<name>.* keys are immutable.""" + from
berdl_notebook_utils.setup_spark_session import _is_immutable_config + + assert _is_immutable_config("spark.sql.catalog.my") is True + assert _is_immutable_config("spark.sql.catalog.my.type") is True + assert _is_immutable_config("spark.sql.catalog.globalusers") is True + assert _is_immutable_config("spark.sql.catalog.globalusers.warehouse") is True + + def test_spark_catalog_is_not_custom_catalog(self): + """Test spark_catalog is handled by IMMUTABLE_CONFIGS set, not the prefix check.""" + from berdl_notebook_utils.setup_spark_session import _is_immutable_config + + # spark.sql.catalog.spark_catalog is in IMMUTABLE_CONFIGS explicitly + assert _is_immutable_config("spark.sql.catalog.spark_catalog") is True + + def test_mutable_config_keys(self): + """Test that non-immutable keys return False.""" + from berdl_notebook_utils.setup_spark_session import _is_immutable_config + + assert _is_immutable_config("spark.app.name") is False + assert _is_immutable_config("spark.sql.autoBroadcastJoinThreshold") is False + assert _is_immutable_config("spark.hadoop.fs.s3a.endpoint") is False diff --git a/notebook_utils/tests/test_refresh.py b/notebook_utils/tests/test_refresh.py new file mode 100644 index 0000000..380663e --- /dev/null +++ b/notebook_utils/tests/test_refresh.py @@ -0,0 +1,216 @@ +"""Tests for refresh.py - Credential and Spark environment refresh.""" + +from pathlib import Path +from unittest.mock import Mock, patch + +from berdl_notebook_utils.refresh import _remove_cache_file, refresh_spark_environment + + +class TestRemoveCacheFile: + """Tests for _remove_cache_file helper.""" + + def test_removes_existing_file(self, tmp_path): + """Test removes file and returns True when file exists.""" + cache_file = tmp_path / "test_cache" + cache_file.write_text("cached data") + + result = _remove_cache_file(cache_file) + + assert result is True + assert not cache_file.exists() + + def test_preserves_lock_file(self, tmp_path): + """Test leaves the .lock companion file in 
place.""" + cache_file = tmp_path / "test_cache" + lock_file = tmp_path / "test_cache.lock" + cache_file.write_text("cached data") + lock_file.write_text("lock") + + _remove_cache_file(cache_file) + + assert not cache_file.exists() + assert lock_file.exists() + + def test_returns_false_when_file_missing(self, tmp_path): + """Test returns False when file doesn't exist.""" + result = _remove_cache_file(tmp_path / "nonexistent") + + assert result is False + + def test_handles_os_error(self, tmp_path): + """Test silently handles OSError on unlink.""" + with ( + patch.object(Path, "exists", return_value=True), + patch.object(Path, "unlink", side_effect=OSError("permission denied")), + ): + result = _remove_cache_file(tmp_path / "test_cache") + + assert result is False + + +class TestRefreshSparkEnvironment: + """Tests for refresh_spark_environment function.""" + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.get_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_happy_path(self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start): + """Test full refresh with all services succeeding.""" + mock_minio_creds = Mock() + mock_minio_creds.username = "u_testuser" + mock_minio.return_value = mock_minio_creds + + mock_polaris.return_value = { + "client_id": "abc", + "client_secret": "xyz", + "personal_catalog": "user_testuser", + "tenant_catalogs": ["tenant_a"], + } + + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} + + result = refresh_spark_environment() + + assert result["minio"] == {"status": "ok", "username": "u_testuser"} + assert result["polaris"]["status"] == "ok" + assert result["polaris"]["personal_catalog"] == 
"user_testuser" + assert result["spark_session_stopped"] is False + assert result["spark_connect"] == {"status": "running"} + assert mock_settings.cache_clear.call_count == 2 + mock_sc_start.assert_called_once_with(force_restart=True) + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.get_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_stops_existing_spark_session( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start + ): + """Test stops active Spark session before restarting Connect server.""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + + mock_session = Mock() + mock_spark.getActiveSession.return_value = mock_session + mock_sc_start.return_value = {"status": "running"} + + result = refresh_spark_environment() + + mock_session.stop.assert_called_once() + assert result["spark_session_stopped"] is True + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.get_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_polaris_not_configured( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start + ): + """Test handles Polaris not being configured (returns None).""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} + + result = refresh_spark_environment() + + assert result["polaris"] == {"status": 
"skipped", "reason": "Polaris not configured"} + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.get_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_minio_error_does_not_block( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start + ): + """Test that MinIO failure doesn't prevent Polaris/Spark refresh.""" + mock_minio.side_effect = ConnectionError("minio unreachable") + mock_polaris.return_value = { + "client_id": "x", + "client_secret": "y", + "personal_catalog": "user_test", + "tenant_catalogs": [], + } + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} + + result = refresh_spark_environment() + + assert result["minio"]["status"] == "error" + assert "minio unreachable" in result["minio"]["error"] + assert result["polaris"]["status"] == "ok" + assert result["spark_connect"] == {"status": "running"} + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.get_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_spark_connect_error_captured( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start + ): + """Test that Spark Connect restart failure is captured in result.""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + mock_spark.getActiveSession.return_value = None + mock_sc_start.side_effect = RuntimeError("server start failed") + + result = refresh_spark_environment() + + assert 
result["spark_connect"]["status"] == "error" + assert "server start failed" in result["spark_connect"]["error"] + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.get_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_cache_files_removed_first( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start + ): + """Test that cache files are removed before credentials are re-fetched.""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} + + refresh_spark_environment() + + # _remove_cache_file called twice (minio + polaris) before get_minio_credentials + assert mock_remove.call_count == 2 + # Settings cache cleared before credential fetches + mock_settings.cache_clear.assert_called() + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.get_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_all_errors_still_returns_result( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start + ): + """Test that even if everything fails, we get a complete result dict.""" + mock_minio.side_effect = Exception("minio fail") + mock_polaris.side_effect = Exception("polaris fail") + mock_spark.getActiveSession.return_value = None + mock_sc_start.side_effect = Exception("spark fail") + + result = refresh_spark_environment() + + assert result["minio"]["status"] == 
"error" + assert result["polaris"]["status"] == "error" + assert result["spark_session_stopped"] is False + assert result["spark_connect"]["status"] == "error" diff --git a/notebook_utils/uv.lock b/notebook_utils/uv.lock index 63bd27b..2daf78d 100644 --- a/notebook_utils/uv.lock +++ b/notebook_utils/uv.lock @@ -239,7 +239,7 @@ requires-dist = [ { name = "attrs", specifier = ">=25.4.0" }, { name = "cdm-spark-manager-client", git = "https://github.com/kbase/cdm-kube-spark-manager-client.git?rev=0.0.1" }, { name = "cdm-task-service-client", git = "https://github.com/kbase/cdm-task-service-client?rev=0.2.3" }, - { name = "datalake-mcp-server-client", git = "https://github.com/BERDataLakehouse/datalake-mcp-server-client.git?rev=v0.0.6" }, + { name = "datalake-mcp-server-client", git = "https://github.com/BERDataLakehouse/datalake-mcp-server-client.git?rev=v0.0.7" }, { name = "hmsclient", specifier = ">=0.1" }, { name = "httpx", specifier = ">=0.24" }, { name = "ipython", specifier = ">=9.10.0" }, @@ -250,7 +250,7 @@ requires-dist = [ { name = "langchain-community", specifier = ">=0.3.31" }, { name = "langchain-openai", specifier = ">=0.3.35" }, { name = "minio", specifier = ">=7.0" }, - { name = "minio-manager-service-client", git = "https://github.com/BERDataLakehouse/minio_manager_service_client.git?rev=v0.0.7" }, + { name = "minio-manager-service-client", git = "https://github.com/BERDataLakehouse/minio_manager_service_client.git?rev=v0.0.10" }, { name = "pandas", specifier = ">=3.0.1" }, { name = "pydantic-settings", specifier = ">=2.0" }, { name = "pyspark", extras = ["connect"], specifier = ">=4.0" }, @@ -338,43 +338,43 @@ wheels = [ [[package]] name = "charset-normalizer" -version = "3.4.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = 
"sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091, upload-time = "2025-10-14T04:41:13.346Z" }, - { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936, upload-time = "2025-10-14T04:41:14.461Z" }, - { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180, upload-time = "2025-10-14T04:41:15.588Z" }, - { url = "https://files.pythonhosted.org/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346, upload-time = "2025-10-14T04:41:16.738Z" }, - { url = "https://files.pythonhosted.org/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874, upload-time = "2025-10-14T04:41:17.923Z" }, - { url = 
"https://files.pythonhosted.org/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076, upload-time = "2025-10-14T04:41:19.106Z" }, - { url = "https://files.pythonhosted.org/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601, upload-time = "2025-10-14T04:41:20.245Z" }, - { url = "https://files.pythonhosted.org/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376, upload-time = "2025-10-14T04:41:21.398Z" }, - { url = "https://files.pythonhosted.org/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825, upload-time = "2025-10-14T04:41:22.583Z" }, - { url = "https://files.pythonhosted.org/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583, upload-time = "2025-10-14T04:41:23.754Z" }, - { url = "https://files.pythonhosted.org/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366, upload-time = "2025-10-14T04:41:25.27Z" }, 
- { url = "https://files.pythonhosted.org/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300, upload-time = "2025-10-14T04:41:26.725Z" }, - { url = "https://files.pythonhosted.org/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465, upload-time = "2025-10-14T04:41:28.322Z" }, - { url = "https://files.pythonhosted.org/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404, upload-time = "2025-10-14T04:41:29.95Z" }, - { url = "https://files.pythonhosted.org/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092, upload-time = "2025-10-14T04:41:31.188Z" }, - { url = "https://files.pythonhosted.org/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408, upload-time = "2025-10-14T04:41:32.624Z" }, - { url = "https://files.pythonhosted.org/packages/2a/35/7051599bd493e62411d6ede36fd5af83a38f37c4767b92884df7301db25d/charset_normalizer-3.4.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd", size = 207746, upload-time = "2025-10-14T04:41:33.773Z" }, - { url = 
"https://files.pythonhosted.org/packages/10/9a/97c8d48ef10d6cd4fcead2415523221624bf58bcf68a802721a6bc807c8f/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb", size = 147889, upload-time = "2025-10-14T04:41:34.897Z" }, - { url = "https://files.pythonhosted.org/packages/10/bf/979224a919a1b606c82bd2c5fa49b5c6d5727aa47b4312bb27b1734f53cd/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e", size = 143641, upload-time = "2025-10-14T04:41:36.116Z" }, - { url = "https://files.pythonhosted.org/packages/ba/33/0ad65587441fc730dc7bd90e9716b30b4702dc7b617e6ba4997dc8651495/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14", size = 160779, upload-time = "2025-10-14T04:41:37.229Z" }, - { url = "https://files.pythonhosted.org/packages/67/ed/331d6b249259ee71ddea93f6f2f0a56cfebd46938bde6fcc6f7b9a3d0e09/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191", size = 159035, upload-time = "2025-10-14T04:41:38.368Z" }, - { url = "https://files.pythonhosted.org/packages/67/ff/f6b948ca32e4f2a4576aa129d8bed61f2e0543bf9f5f2b7fc3758ed005c9/charset_normalizer-3.4.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838", size = 152542, upload-time = "2025-10-14T04:41:39.862Z" }, - { url = 
"https://files.pythonhosted.org/packages/16/85/276033dcbcc369eb176594de22728541a925b2632f9716428c851b149e83/charset_normalizer-3.4.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6", size = 149524, upload-time = "2025-10-14T04:41:41.319Z" }, - { url = "https://files.pythonhosted.org/packages/9e/f2/6a2a1f722b6aba37050e626530a46a68f74e63683947a8acff92569f979a/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e", size = 150395, upload-time = "2025-10-14T04:41:42.539Z" }, - { url = "https://files.pythonhosted.org/packages/60/bb/2186cb2f2bbaea6338cad15ce23a67f9b0672929744381e28b0592676824/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c", size = 143680, upload-time = "2025-10-14T04:41:43.661Z" }, - { url = "https://files.pythonhosted.org/packages/7d/a5/bf6f13b772fbb2a90360eb620d52ed8f796f3c5caee8398c3b2eb7b1c60d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090", size = 162045, upload-time = "2025-10-14T04:41:44.821Z" }, - { url = "https://files.pythonhosted.org/packages/df/c5/d1be898bf0dc3ef9030c3825e5d3b83f2c528d207d246cbabe245966808d/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152", size = 149687, upload-time = "2025-10-14T04:41:46.442Z" }, - { url = "https://files.pythonhosted.org/packages/a5/42/90c1f7b9341eef50c8a1cb3f098ac43b0508413f33affd762855f67a410e/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828", size = 160014, upload-time = "2025-10-14T04:41:47.631Z" }, - { url = 
"https://files.pythonhosted.org/packages/76/be/4d3ee471e8145d12795ab655ece37baed0929462a86e72372fd25859047c/charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec", size = 154044, upload-time = "2025-10-14T04:41:48.81Z" }, - { url = "https://files.pythonhosted.org/packages/b0/6f/8f7af07237c34a1defe7defc565a9bc1807762f672c0fde711a4b22bf9c0/charset_normalizer-3.4.4-cp314-cp314-win32.whl", hash = "sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9", size = 99940, upload-time = "2025-10-14T04:41:49.946Z" }, - { url = "https://files.pythonhosted.org/packages/4b/51/8ade005e5ca5b0d80fb4aff72a3775b325bdc3d27408c8113811a7cbe640/charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c", size = 107104, upload-time = "2025-10-14T04:41:51.051Z" }, - { url = "https://files.pythonhosted.org/packages/da/5f/6b8f83a55bb8278772c5ae54a577f3099025f9ade59d0136ac24a0df4bde/charset_normalizer-3.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2", size = 100743, upload-time = "2025-10-14T04:41:52.122Z" }, - { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/35/02daf95b9cd686320bb622eb148792655c9412dbb9b67abb5694e5910a24/charset_normalizer-3.4.5.tar.gz", hash = "sha256:95adae7b6c42a6c5b5b559b1a99149f090a57128155daeea91732c8d970d8644", size = 134804, upload-time = "2026-03-06T06:03:19.46Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f5/48/9f34ec4bb24aa3fdba1890c1bddb97c8a4be1bd84ef5c42ac2352563ad05/charset_normalizer-3.4.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ac59c15e3f1465f722607800c68713f9fbc2f672b9eb649fe831da4019ae9b23", size = 280788, upload-time = "2026-03-06T06:01:37.126Z" }, + { url = "https://files.pythonhosted.org/packages/0e/09/6003e7ffeb90cc0560da893e3208396a44c210c5ee42efff539639def59b/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:165c7b21d19365464e8f70e5ce5e12524c58b48c78c1f5a57524603c1ab003f8", size = 188890, upload-time = "2026-03-06T06:01:38.73Z" }, + { url = "https://files.pythonhosted.org/packages/42/1e/02706edf19e390680daa694d17e2b8eab4b5f7ac285e2a51168b4b22ee6b/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:28269983f25a4da0425743d0d257a2d6921ea7d9b83599d4039486ec5b9f911d", size = 206136, upload-time = "2026-03-06T06:01:40.016Z" }, + { url = "https://files.pythonhosted.org/packages/c7/87/942c3def1b37baf3cf786bad01249190f3ca3d5e63a84f831e704977de1f/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d27ce22ec453564770d29d03a9506d449efbb9fa13c00842262b2f6801c48cce", size = 202551, upload-time = "2026-03-06T06:01:41.522Z" }, + { url = "https://files.pythonhosted.org/packages/94/0a/af49691938dfe175d71b8a929bd7e4ace2809c0c5134e28bc535660d5262/charset_normalizer-3.4.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0625665e4ebdddb553ab185de5db7054393af8879fb0c87bd5690d14379d6819", size = 195572, upload-time = "2026-03-06T06:01:43.208Z" }, + { url = "https://files.pythonhosted.org/packages/20/ea/dfb1792a8050a8e694cfbde1570ff97ff74e48afd874152d38163d1df9ae/charset_normalizer-3.4.5-cp313-cp313-manylinux_2_31_armv7l.whl", hash = 
"sha256:c23eb3263356d94858655b3e63f85ac5d50970c6e8febcdde7830209139cc37d", size = 184438, upload-time = "2026-03-06T06:01:44.755Z" }, + { url = "https://files.pythonhosted.org/packages/72/12/c281e2067466e3ddd0595bfaea58a6946765ace5c72dfa3edc2f5f118026/charset_normalizer-3.4.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e6302ca4ae283deb0af68d2fbf467474b8b6aedcd3dab4db187e07f94c109763", size = 193035, upload-time = "2026-03-06T06:01:46.051Z" }, + { url = "https://files.pythonhosted.org/packages/ba/4f/3792c056e7708e10464bad0438a44708886fb8f92e3c3d29ec5e2d964d42/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e51ae7d81c825761d941962450f50d041db028b7278e7b08930b4541b3e45cb9", size = 191340, upload-time = "2026-03-06T06:01:47.547Z" }, + { url = "https://files.pythonhosted.org/packages/e7/86/80ddba897127b5c7a9bccc481b0cd36c8fefa485d113262f0fe4332f0bf4/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:597d10dec876923e5c59e48dbd366e852eacb2b806029491d307daea6b917d7c", size = 185464, upload-time = "2026-03-06T06:01:48.764Z" }, + { url = "https://files.pythonhosted.org/packages/4d/00/b5eff85ba198faacab83e0e4b6f0648155f072278e3b392a82478f8b988b/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:5cffde4032a197bd3b42fd0b9509ec60fb70918d6970e4cc773f20fc9180ca67", size = 208014, upload-time = "2026-03-06T06:01:50.371Z" }, + { url = "https://files.pythonhosted.org/packages/c8/11/d36f70be01597fd30850dde8a1269ebc8efadd23ba5785808454f2389bde/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:2da4eedcb6338e2321e831a0165759c0c620e37f8cd044a263ff67493be8ffb3", size = 193297, upload-time = "2026-03-06T06:01:51.933Z" }, + { url = "https://files.pythonhosted.org/packages/1a/1d/259eb0a53d4910536c7c2abb9cb25f4153548efb42800c6a9456764649c0/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = 
"sha256:65a126fb4b070d05340a84fc709dd9e7c75d9b063b610ece8a60197a291d0adf", size = 204321, upload-time = "2026-03-06T06:01:53.887Z" }, + { url = "https://files.pythonhosted.org/packages/84/31/faa6c5b9d3688715e1ed1bb9d124c384fe2fc1633a409e503ffe1c6398c1/charset_normalizer-3.4.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c7a80a9242963416bd81f99349d5f3fce1843c303bd404f204918b6d75a75fd6", size = 197509, upload-time = "2026-03-06T06:01:56.439Z" }, + { url = "https://files.pythonhosted.org/packages/fd/a5/c7d9dd1503ffc08950b3260f5d39ec2366dd08254f0900ecbcf3a6197c7c/charset_normalizer-3.4.5-cp313-cp313-win32.whl", hash = "sha256:f1d725b754e967e648046f00c4facc42d414840f5ccc670c5670f59f83693e4f", size = 132284, upload-time = "2026-03-06T06:01:57.812Z" }, + { url = "https://files.pythonhosted.org/packages/b9/0f/57072b253af40c8aa6636e6de7d75985624c1eb392815b2f934199340a89/charset_normalizer-3.4.5-cp313-cp313-win_amd64.whl", hash = "sha256:e37bd100d2c5d3ba35db9c7c5ba5a9228cbcffe5c4778dc824b164e5257813d7", size = 142630, upload-time = "2026-03-06T06:01:59.062Z" }, + { url = "https://files.pythonhosted.org/packages/31/41/1c4b7cc9f13bd9d369ce3bc993e13d374ce25fa38a2663644283ecf422c1/charset_normalizer-3.4.5-cp313-cp313-win_arm64.whl", hash = "sha256:93b3b2cc5cf1b8743660ce77a4f45f3f6d1172068207c1defc779a36eea6bb36", size = 133254, upload-time = "2026-03-06T06:02:00.281Z" }, + { url = "https://files.pythonhosted.org/packages/43/be/0f0fd9bb4a7fa4fb5067fb7d9ac693d4e928d306f80a0d02bde43a7c4aee/charset_normalizer-3.4.5-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8197abe5ca1ffb7d91e78360f915eef5addff270f8a71c1fc5be24a56f3e4873", size = 280232, upload-time = "2026-03-06T06:02:01.508Z" }, + { url = "https://files.pythonhosted.org/packages/28/02/983b5445e4bef49cd8c9da73a8e029f0825f39b74a06d201bfaa2e55142a/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:a2aecdb364b8a1802afdc7f9327d55dad5366bc97d8502d0f5854e50712dbc5f", size = 189688, upload-time = "2026-03-06T06:02:02.857Z" }, + { url = "https://files.pythonhosted.org/packages/d0/88/152745c5166437687028027dc080e2daed6fe11cfa95a22f4602591c42db/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a66aa5022bf81ab4b1bebfb009db4fd68e0c6d4307a1ce5ef6a26e5878dfc9e4", size = 206833, upload-time = "2026-03-06T06:02:05.127Z" }, + { url = "https://files.pythonhosted.org/packages/cb/0f/ebc15c8b02af2f19be9678d6eed115feeeccc45ce1f4b098d986c13e8769/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d77f97e515688bd615c1d1f795d540f32542d514242067adcb8ef532504cb9ee", size = 202879, upload-time = "2026-03-06T06:02:06.446Z" }, + { url = "https://files.pythonhosted.org/packages/38/9c/71336bff6934418dc8d1e8a1644176ac9088068bc571da612767619c97b3/charset_normalizer-3.4.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01a1ed54b953303ca7e310fafe0fe347aab348bd81834a0bcd602eb538f89d66", size = 195764, upload-time = "2026-03-06T06:02:08.763Z" }, + { url = "https://files.pythonhosted.org/packages/b7/95/ce92fde4f98615661871bc282a856cf9b8a15f686ba0af012984660d480b/charset_normalizer-3.4.5-cp314-cp314-manylinux_2_31_armv7l.whl", hash = "sha256:b2d37d78297b39a9eb9eb92c0f6df98c706467282055419df141389b23f93362", size = 183728, upload-time = "2026-03-06T06:02:10.137Z" }, + { url = "https://files.pythonhosted.org/packages/1c/e7/f5b4588d94e747ce45ae680f0f242bc2d98dbd4eccfab73e6160b6893893/charset_normalizer-3.4.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e71bbb595973622b817c042bd943c3f3667e9c9983ce3d205f973f486fec98a7", size = 192937, upload-time = "2026-03-06T06:02:11.663Z" }, + { url = 
"https://files.pythonhosted.org/packages/f9/29/9d94ed6b929bf9f48bf6ede6e7474576499f07c4c5e878fb186083622716/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:4cd966c2559f501c6fd69294d082c2934c8dd4719deb32c22961a5ac6db0df1d", size = 192040, upload-time = "2026-03-06T06:02:13.489Z" }, + { url = "https://files.pythonhosted.org/packages/15/d2/1a093a1cf827957f9445f2fe7298bcc16f8fc5e05c1ed2ad1af0b239035e/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:d5e52d127045d6ae01a1e821acfad2f3a1866c54d0e837828538fabe8d9d1bd6", size = 184107, upload-time = "2026-03-06T06:02:14.83Z" }, + { url = "https://files.pythonhosted.org/packages/0f/7d/82068ce16bd36135df7b97f6333c5d808b94e01d4599a682e2337ed5fd14/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:30a2b1a48478c3428d047ed9690d57c23038dac838a87ad624c85c0a78ebeb39", size = 208310, upload-time = "2026-03-06T06:02:16.165Z" }, + { url = "https://files.pythonhosted.org/packages/84/4e/4dfb52307bb6af4a5c9e73e482d171b81d36f522b21ccd28a49656baa680/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:d8ed79b8f6372ca4254955005830fd61c1ccdd8c0fac6603e2c145c61dd95db6", size = 192918, upload-time = "2026-03-06T06:02:18.144Z" }, + { url = "https://files.pythonhosted.org/packages/08/a4/159ff7da662cf7201502ca89980b8f06acf3e887b278956646a8aeb178ab/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:c5af897b45fa606b12464ccbe0014bbf8c09191e0a66aab6aa9d5cf6e77e0c94", size = 204615, upload-time = "2026-03-06T06:02:19.821Z" }, + { url = "https://files.pythonhosted.org/packages/d6/62/0dd6172203cb6b429ffffc9935001fde42e5250d57f07b0c28c6046deb6b/charset_normalizer-3.4.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1088345bcc93c58d8d8f3d783eca4a6e7a7752bbff26c3eee7e73c597c191c2e", size = 197784, upload-time = "2026-03-06T06:02:21.86Z" }, + { url = 
"https://files.pythonhosted.org/packages/c7/5e/1aab5cb737039b9c59e63627dc8bbc0d02562a14f831cc450e5f91d84ce1/charset_normalizer-3.4.5-cp314-cp314-win32.whl", hash = "sha256:ee57b926940ba00bca7ba7041e665cc956e55ef482f851b9b65acb20d867e7a2", size = 133009, upload-time = "2026-03-06T06:02:23.289Z" }, + { url = "https://files.pythonhosted.org/packages/40/65/e7c6c77d7aaa4c0d7974f2e403e17f0ed2cb0fc135f77d686b916bf1eead/charset_normalizer-3.4.5-cp314-cp314-win_amd64.whl", hash = "sha256:4481e6da1830c8a1cc0b746b47f603b653dadb690bcd851d039ffaefe70533aa", size = 143511, upload-time = "2026-03-06T06:02:26.195Z" }, + { url = "https://files.pythonhosted.org/packages/ba/91/52b0841c71f152f563b8e072896c14e3d83b195c188b338d3cc2e582d1d4/charset_normalizer-3.4.5-cp314-cp314-win_arm64.whl", hash = "sha256:97ab7787092eb9b50fb47fa04f24c75b768a606af1bcba1957f07f128a7219e4", size = 133775, upload-time = "2026-03-06T06:02:27.473Z" }, + { url = "https://files.pythonhosted.org/packages/c5/60/3a621758945513adfd4db86827a5bafcc615f913dbd0b4c2ed64a65731be/charset_normalizer-3.4.5-py3-none-any.whl", hash = "sha256:9db5e3fcdcee89a78c04dffb3fe33c79f77bd741a624946db2591c81b2fc85b0", size = 55455, upload-time = "2026-03-06T06:03:17.827Z" }, ] [[package]] @@ -492,7 +492,7 @@ wheels = [ [[package]] name = "datalake-mcp-server-client" version = "0.0.1" -source = { git = "https://github.com/BERDataLakehouse/datalake-mcp-server-client.git?rev=v0.0.6#525e40a5fa481bff8b556051612250bd6735a7a5" } +source = { git = "https://github.com/BERDataLakehouse/datalake-mcp-server-client.git?rev=v0.0.7#9a2eb8f4160ed244f6f6ad671b312ef75fa0ee14" } dependencies = [ { name = "httpx" }, { name = "pydantic" }, @@ -609,14 +609,14 @@ wheels = [ [[package]] name = "googleapis-common-protos" -version = "1.72.0" +version = "1.73.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/e5/7b/adfd75544c415c487b33061fe7ae526165241c1ea133f9a9125a56b39fd8/googleapis_common_protos-1.72.0.tar.gz", hash = "sha256:e55a601c1b32b52d7a3e65f43563e2aa61bcd737998ee672ac9b951cd49319f5", size = 147433, upload-time = "2025-11-06T18:29:24.087Z" } +sdist = { url = "https://files.pythonhosted.org/packages/99/96/a0205167fa0154f4a542fd6925bdc63d039d88dab3588b875078107e6f06/googleapis_common_protos-1.73.0.tar.gz", hash = "sha256:778d07cd4fbeff84c6f7c72102f0daf98fa2bfd3fa8bea426edc545588da0b5a", size = 147323, upload-time = "2026-03-06T21:53:09.727Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/ab/09169d5a4612a5f92490806649ac8d41e3ec9129c636754575b3553f4ea4/googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038", size = 297515, upload-time = "2025-11-06T18:29:13.14Z" }, + { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash = "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" }, ] [[package]] @@ -777,7 +777,7 @@ wheels = [ [[package]] name = "ipython" -version = "9.10.0" +version = "9.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, @@ -791,9 +791,9 @@ dependencies = [ { name = "stack-data" }, { name = "traitlets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/60/2111715ea11f39b1535bed6024b7dec7918b71e5e5d30855a5b503056b50/ipython-9.10.0.tar.gz", hash = "sha256:cd9e656be97618a0676d058134cd44e6dc7012c0e5cb36a9ce96a8c904adaf77", size = 4426526, upload-time = "2026-02-02T10:00:33.594Z" } +sdist = { url = "https://files.pythonhosted.org/packages/86/28/a4698eda5a8928a45d6b693578b135b753e14fa1c2b36ee9441e69a45576/ipython-9.11.0.tar.gz", hash = 
"sha256:2a94bc4406b22ecc7e4cb95b98450f3ea493a76bec8896cda11b78d7752a6667", size = 4427354, upload-time = "2026-03-05T08:57:30.549Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/aa/898dec789a05731cd5a9f50605b7b44a72bd198fd0d4528e11fc610177cc/ipython-9.10.0-py3-none-any.whl", hash = "sha256:c6ab68cc23bba8c7e18e9b932797014cc61ea7fd6f19de180ab9ba73e65ee58d", size = 622774, upload-time = "2026-02-02T10:00:31.503Z" }, + { url = "https://files.pythonhosted.org/packages/b2/90/45c72becc57158facc6a6404f663b77bbcea2519ca57f760e2879ae1315d/ipython-9.11.0-py3-none-any.whl", hash = "sha256:6922d5bcf944c6e525a76a0a304451b60a2b6f875e86656d8bc2dfda5d710e19", size = 624222, upload-time = "2026-03-05T08:57:28.94Z" }, ] [[package]] @@ -928,7 +928,7 @@ wheels = [ [[package]] name = "langchain" -version = "0.3.27" +version = "0.3.28" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, @@ -939,9 +939,9 @@ dependencies = [ { name = "requests" }, { name = "sqlalchemy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/83/f6/f4f7f3a56626fe07e2bb330feb61254dbdf06c506e6b59a536a337da51cf/langchain-0.3.27.tar.gz", hash = "sha256:aa6f1e6274ff055d0fd36254176770f356ed0a8994297d1df47df341953cec62", size = 10233809, upload-time = "2025-07-24T14:42:32.959Z" } +sdist = { url = "https://files.pythonhosted.org/packages/87/bb/a65e29c8e4aaf0348c2617962e427c8e760d82a67adbd197019e49c7769d/langchain-0.3.28.tar.gz", hash = "sha256:30a32f44cc6690bcc6a6fb7c14d61a15406d5eda1a0e7eab60b3660944888741", size = 10242473, upload-time = "2026-03-06T22:45:17.911Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/d5/4861816a95b2f6993f1360cfb605aacb015506ee2090433a71de9cca8477/langchain-0.3.27-py3-none-any.whl", hash = "sha256:7b20c4f338826acb148d885b20a73a16e410ede9ee4f19bb02011852d5f98798", size = 1018194, upload-time = "2025-07-24T14:42:30.23Z" }, + { url = 
"https://files.pythonhosted.org/packages/5b/f5/ecd71e5b78e67944b2600a155ef63000bc00148e6794e8e7809b2453887a/langchain-0.3.28-py3-none-any.whl", hash = "sha256:1ba1244477b67b812b775f346209fa596e78bf055a34e45ce22acb7a45842a32", size = 1024717, upload-time = "2026-03-06T22:45:15.545Z" }, ] [[package]] @@ -1028,7 +1028,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.7.11" +version = "0.7.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -1041,9 +1041,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0d/43/db660d35fb59577490b072fa7bee4043ee4ba9d21c3185882efb3713fe59/langsmith-0.7.11.tar.gz", hash = "sha256:71df5fb9fa1ee0d3b494c14393566d33130739656de5ef96486bcbb0b5e4d329", size = 1109819, upload-time = "2026-03-03T20:29:18.406Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/32/b3931027ff7d635a66a0edbeec9f8a285fe77b04f1f0cbbc58fd20f2555a/langsmith-0.7.14.tar.gz", hash = "sha256:95606314a8dea0ea1ff3650da4cf0433737b14c4c296579c6b770b43cb5e0b37", size = 1113666, upload-time = "2026-03-06T20:13:17.308Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/da/c1/aec40ba797c3ce0f9c41536491394704ae2d7253794405cb813748dcddbe/langsmith-0.7.11-py3-none-any.whl", hash = "sha256:0aff5b4316341d6ab6bcb6abf405a6a098f469020bad4889cafb6098650b8603", size = 346485, upload-time = "2026-03-03T20:29:16.685Z" }, + { url = "https://files.pythonhosted.org/packages/6e/4f/b81ee2d06e1d69aa689b43d2b777901c060d257507806cad7cd9035d5ca4/langsmith-0.7.14-py3-none-any.whl", hash = "sha256:754dcb474a3f3f83cfefbd9694b897bce2a1a0b412bf75e256f85a64206ddcb7", size = 347350, upload-time = "2026-03-06T20:13:15.706Z" }, ] [[package]] @@ -1089,7 +1089,7 @@ wheels = [ [[package]] name = "minio-manager-service-client" version = "0.0.1" -source = { git = 
"https://github.com/BERDataLakehouse/minio_manager_service_client.git?rev=v0.0.7#d60b6ea2f4cf49979eac22b9ee3b884eb24f39d3" } +source = { git = "https://github.com/BERDataLakehouse/minio_manager_service_client.git?rev=v0.0.10#14fdfe1377358c4d5e154010a4fa8724f12bdb2d" } dependencies = [ { name = "httpx" }, { name = "pydantic" }, @@ -1237,7 +1237,7 @@ wheels = [ [[package]] name = "openai" -version = "2.24.0" +version = "2.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1249,9 +1249,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717, upload-time = "2026-02-24T20:02:07.958Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/91/2a06c4e9597c338cac1e5e5a8dd6f29e1836fc229c4c523529dca387fda8/openai-2.26.0.tar.gz", hash = "sha256:b41f37c140ae0034a6e92b0c509376d907f3a66109935fba2c1b471a7c05a8fb", size = 666702, upload-time = "2026-03-05T23:17:35.874Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122, upload-time = "2026-02-24T20:02:05.669Z" }, + { url = "https://files.pythonhosted.org/packages/c6/2e/3f73e8ca53718952222cacd0cf7eecc9db439d020f0c1fe7ae717e4e199a/openai-2.26.0-py3-none-any.whl", hash = "sha256:6151bf8f83802f036117f06cc8a57b3a4da60da9926826cc96747888b57f394f", size = 1136409, upload-time = "2026-03-05T23:17:34.072Z" }, ] [[package]] @@ -1359,7 +1359,7 @@ name = "pexpect" version = "4.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ptyprocess" }, + { name = "ptyprocess", marker = "sys_platform != 
'emscripten' and sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } wheels = [ diff --git a/scripts/init-polaris-db.sh b/scripts/init-polaris-db.sh new file mode 100755 index 0000000..323d9cb --- /dev/null +++ b/scripts/init-polaris-db.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +# Create the polaris database in the shared PostgreSQL instance. +# This script is mounted into /docker-entrypoint-initdb.d/ and runs once +# when the PostgreSQL container is first initialized. + +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + SELECT 'CREATE DATABASE polaris' + WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'polaris')\gexec +EOSQL + +echo "Polaris database created (or already exists)" diff --git a/scripts/migrate_delta_to_iceberg.py b/scripts/migrate_delta_to_iceberg.py new file mode 100644 index 0000000..558efd9 --- /dev/null +++ b/scripts/migrate_delta_to_iceberg.py @@ -0,0 +1,379 @@ +""" +Delta Lake to Iceberg migration utilities for BERDL Phase 4. + +This script migrates Delta Lake tables (Hive Metastore) to Iceberg tables +(Polaris REST catalog), preserving partitions and validating row counts. 
+ +Functions: + migrate_table - Migrate a single Delta table to an Iceberg catalog + migrate_user - Migrate all of a user's Delta databases to their Iceberg catalog + migrate_tenant - Migrate all of a tenant's Delta databases to their Iceberg catalog + +Usage in the migration notebook (migration_phase4.ipynb): + + # Import after adding scripts/ to sys.path + from migrate_delta_to_iceberg import MigrationTracker, migrate_user, migrate_tenant + + tracker = MigrationTracker() + + # Migrate all tables for a user (idempotent — skips existing tables) + migrate_user(spark, "tgu2", target_catalog="user_tgu2", tracker=tracker) + + # Migrate all tables for a tenant + migrate_tenant(spark, "globalusers", target_catalog="tenant_globalusers", tracker=tracker) + + # View results + tracker.to_dataframe(spark).show(truncate=False) + print(tracker.summary()) + +Force re-migration (drops existing Iceberg table and re-copies from Delta): + + # Force re-migrate a single user table + migrate_table( + spark, + hive_db="u_tian_gu_test__demo_personal", + table_name="personal_test_table", + target_catalog="user_tian_gu_test", + target_ns="demo_personal", + tracker=tracker, + force=True, + ) + + # Force re-migrate a single tenant table + migrate_table( + spark, + hive_db="globalusers_demo_shared", + table_name="tenant_test_table", + target_catalog="tenant_globalusers", + target_ns="demo_shared", + tracker=tracker, + force=True, + ) + + # Force re-migrate all tables for a user + migrate_user(spark, "tian_gu_test", target_catalog="user_tian_gu_test", tracker=tracker, force=True) + + # Force re-migrate all tables for a tenant + migrate_tenant(spark, "globalusers", target_catalog="tenant_globalusers", tracker=tracker, force=True) + +Note: + - Requires an admin Spark session configured with cross-user catalog access + (see Section 3 of migration_phase4.ipynb) + - By default, migration is idempotent: tables that already exist in the target + catalog are skipped. 
Use force=True to drop and re-migrate. + - DROP TABLE PURGE does not delete S3 data files due to an Iceberg bug (#14743). + To fully clean up, delete files from S3 directly using get_minio_client(). +""" + +import logging +from dataclasses import dataclass, field + +from pyspark.sql import SparkSession + +logger = logging.getLogger(__name__) + + +def _clean_error(e: Exception, max_len: int = 150) -> str: + """Extract a short, readable error message without JVM stacktraces.""" + msg = str(e) + # Strip everything after "JVM stacktrace:" if present + if "JVM stacktrace:" in msg: + msg = msg[: msg.index("JVM stacktrace:")].strip() + # Take only the first line + msg = msg.split("\n")[0].strip() + if len(msg) > max_len: + msg = msg[:max_len] + "..." + return msg + + +@dataclass +class TableResult: + """Result of a single table migration.""" + + source: str + target: str + status: str # "migrated", "skipped", "failed" + row_count: int = 0 + error: str = "" + + +@dataclass +class MigrationTracker: + """Tracks migration progress across users and tenants.""" + + results: list[TableResult] = field(default_factory=list) + + @property + def migrated(self) -> list[TableResult]: + return [r for r in self.results if r.status == "migrated"] + + @property + def skipped(self) -> list[TableResult]: + return [r for r in self.results if r.status == "skipped"] + + @property + def failed(self) -> list[TableResult]: + return [r for r in self.results if r.status == "failed"] + + def add(self, result: TableResult): + self.results.append(result) + + def summary(self) -> str: + return ( + f"Total: {len(self.results)} | " + f"Migrated: {len(self.migrated)} | " + f"Skipped: {len(self.skipped)} | " + f"Failed: {len(self.failed)}" + ) + + def to_dataframe(self, spark: SparkSession): + """Convert results to a Spark DataFrame for notebook display.""" + rows = [(r.source, r.target, r.status, r.row_count, r.error) for r in self.results] + return spark.createDataFrame(rows, ["source", "target", 
"status", "row_count", "error"]) + + +def _validate_target_catalog(spark: SparkSession, target_catalog: str) -> None: + """Raise ValueError if target_catalog is not configured in the Spark session. + + Uses spark.conf.get() instead of SHOW CATALOGS because Spark lazily loads + catalogs — they won't appear in SHOW CATALOGS until first access. + """ + config_key = f"spark.sql.catalog.{target_catalog}" + try: + spark.conf.get(config_key) + except Exception: + raise ValueError( + f"Catalog '{target_catalog}' is not configured in the current Spark session " + f"(no {config_key} property found). " + f"Did you run Section 3 (Configure Admin Spark) and restart Spark Connect?" + ) + + +def table_exists_in_catalog(spark: SparkSession, catalog: str, namespace: str, table_name: str) -> bool: + """Check if a table already exists in the target Iceberg catalog. + + Uses SHOW TABLES instead of DESCRIBE TABLE to avoid JVM-level + TABLE_OR_VIEW_NOT_FOUND error logs when the table doesn't exist. + """ + try: + tables = spark.sql(f"SHOW TABLES IN {catalog}.{namespace}").collect() + return any(row["tableName"] == table_name for row in tables) + except Exception: + return False + + +def migrate_table( + spark: SparkSession, + hive_db: str, + table_name: str, + target_catalog: str, + target_ns: str, + tracker: MigrationTracker | None = None, + force: bool = False, +): + """ + Migrate a single Delta table to Iceberg via Polaris, preserving partitions. 
+ + Args: + spark: Active Spark session + hive_db: Original Hive/Delta database name + table_name: Original table name + target_catalog: Target Iceberg catalog name (e.g., 'my') + target_ns: Target namespace in Iceberg (e.g., 'test_db') + tracker: Optional MigrationTracker for progress tracking + force: If True, drop existing target table and re-migrate + """ + source_ref = f"{hive_db}.{table_name}" + target_table_ref = f"{target_catalog}.{target_ns}.{table_name}" + print(f" {source_ref} -> {target_table_ref}") + logger.info(f"Starting migration for {source_ref} -> {target_table_ref}") + + # 0. Idempotency: skip if target already exists (unless force=True) + if table_exists_in_catalog(spark, target_catalog, target_ns, table_name): + if force: + logger.info(f"Force mode: dropping existing {target_table_ref}") + spark.sql(f"DROP TABLE {target_table_ref} PURGE") + else: + logger.info(f"Skipping {target_table_ref} — already exists in target catalog") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="skipped")) + return + + # 1. Read from Delta using spark.table fallback + try: + df = spark.table(source_ref) + except Exception as e: + err_str = str(e) + if "DELTA_READ_TABLE_WITHOUT_COLUMNS" in err_str: + msg = "Delta table has no columns (empty schema) — skipping" + logger.warning(f"Skipping {source_ref} — {msg}") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="skipped", error=msg)) + return + short_err = _clean_error(e) + print(f" FAILED (read): {short_err}") + logger.error(f"Failed to read source table {source_ref}: {short_err}") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="failed", error=short_err)) + return + + # 1b. 
Skip tables with empty schema (corrupt Delta tables with no columns) + if len(df.columns) == 0: + msg = "Delta table has no columns (empty schema)" + logger.warning(f"Skipping {source_ref} — {msg}") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="skipped", error=msg)) + return + + # 2. Extract partition columns from the original Delta table + partition_cols: list[str] = [] + try: + partition_cols = [row.name for row in spark.catalog.listColumns(f"{hive_db}.{table_name}") if row.isPartition] + if partition_cols: + logger.info(f"Found partition columns: {partition_cols}") + except Exception as e: + logger.warning(f"Could not fetch partitions for {source_ref} via catalog API: {e}") + + # 3. Create target namespace, write data, and validate + try: + spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {target_catalog}.{target_ns}") + + # 4. Write as Iceberg (applying partition logic if it existed) + writer = df.writeTo(target_table_ref) + if partition_cols: + writer = writer.partitionedBy(*partition_cols) + + logger.info(f"Writing data to {target_table_ref}...") + writer.create() + logger.info(f"Write completed for {target_table_ref}") + + # 5. 
Validate row counts + original_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {source_ref}").collect()[0]["cnt"] + migrated_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {target_table_ref}").collect()[0]["cnt"] + + if original_count != migrated_count: + msg = f"Row count mismatch: {original_count} vs {migrated_count}" + logger.error(f"Validation FAILED: {msg}") + if tracker: + tracker.add( + TableResult( + source=source_ref, + target=target_table_ref, + status="failed", + row_count=migrated_count, + error=msg, + ) + ) + raise ValueError(msg) + + logger.info(f"Validation SUCCESS: {migrated_count} rows migrated exactly.") + if tracker: + tracker.add( + TableResult( + source=source_ref, + target=target_table_ref, + status="migrated", + row_count=migrated_count, + ) + ) + except Exception as e: + short_err = _clean_error(e) + print(f" FAILED: {short_err}") + logger.error(f"Failed to migrate {source_ref} -> {target_table_ref}: {short_err}") + if tracker and not any(r.target == target_table_ref for r in tracker.results): + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="failed", error=short_err)) + + +def migrate_user( + spark: SparkSession, + username: str, + target_catalog: str = "my", + tracker: MigrationTracker | None = None, + force: bool = False, +): + """ + Migrate all of a user's Delta databases to their Iceberg catalog. 
+ + Args: + spark: Active SparkSession + username: The user's username + target_catalog: The target catalog (e.g., 'user_{username}') + tracker: Optional MigrationTracker for progress tracking + force: If True, drop existing target tables and re-migrate + """ + _validate_target_catalog(spark, target_catalog) + + prefix = f"u_{username}__" + databases = [db[0] for db in spark.sql("SHOW DATABASES").collect() if db[0].startswith(prefix)] + + if not databases: + logger.info(f"No databases found for username {username} with prefix {prefix}") + return + + for hive_db in databases: + iceberg_ns = hive_db.replace(prefix, "", 1) # "u_tgu2__test_db" -> "test_db" + logger.info(f"Scanning database {hive_db}...") + try: + tables = spark.sql(f"SHOW TABLES IN {hive_db}").collect() + except Exception as e: + print(f" Error listing tables in {hive_db}: {_clean_error(e)}") + continue + + for table_row in tables: + table_name = table_row["tableName"] + try: + migrate_table(spark, hive_db, table_name, target_catalog, iceberg_ns, tracker, force=force) + except Exception as e: + print(f" FAILED (unexpected): {_clean_error(e)}") + + +def migrate_tenant( + spark: SparkSession, + tenant_name: str, + target_catalog: str, + tracker: MigrationTracker | None = None, + force: bool = False, +): + """ + Migrate all Delta databases for a tenant to their Iceberg catalog. + + Tenant databases follow the pattern: {tenant_name}_{dbname} in Hive. + The {tenant_name}_ prefix is stripped to get the Iceberg namespace. 
+ + Args: + spark: Active SparkSession + tenant_name: The tenant/group name (e.g., 'kbase') + target_catalog: Target Iceberg catalog (e.g., 'tenant_kbase') + tracker: Optional MigrationTracker for progress tracking + force: If True, drop existing target tables and re-migrate + """ + _validate_target_catalog(spark, target_catalog) + + prefix = f"{tenant_name}_" + databases = [db[0] for db in spark.sql("SHOW DATABASES").collect() if db[0].startswith(prefix)] + + if not databases: + logger.info(f"No databases found for tenant {tenant_name} with prefix {prefix}") + return + + for hive_db in databases: + iceberg_ns = hive_db.replace(prefix, "", 1) + logger.info(f"Scanning tenant database {hive_db}...") + try: + tables = spark.sql(f"SHOW TABLES IN {hive_db}").collect() + except Exception as e: + print(f" Error listing tables in {hive_db}: {_clean_error(e)}") + continue + + for table_row in tables: + table_name = table_row["tableName"] + try: + migrate_table(spark, hive_db, table_name, target_catalog, iceberg_ns, tracker, force=force) + except Exception as e: + print(f" FAILED (unexpected): {_clean_error(e)}") + + +if __name__ == "__main__": + # If this is run independently we set up basic print logging + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + print("Migration functions are loaded. Supply an active spark session to begin.") From e30e37ed97b1a05aed30de1fffcf3a9953103b8b Mon Sep 17 00:00:00 2001 From: Tianhao-Gu Date: Fri, 13 Mar 2026 15:07:03 -0500 Subject: [PATCH 2/5] =?UTF-8?q?Merge=20main=20into=20feature/polaris=20?= =?UTF-8?q?=E2=80=94=20resolve=20test=20coverage=20conflicts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merged latest main (PR #135 test coverage improvements) into feature/polaris. 
Resolved 5 conflicts by keeping both sides' test additions where applicable, preferring main's safer patterns (try/finally cleanup, concurrent.futures.TimeoutError, time.sleep patching). --- configs/jupyter_server_config.py | 8 +- docker-compose.yaml | 2 +- notebook_utils/tests/agent/test_mcp_tools.py | 3 +- notebook_utils/tests/agent/test_settings.py | 21 ++ notebook_utils/tests/mcp/test_operations.py | 116 ++++++++ .../tests/spark/test_connect_server.py | 43 +-- notebook_utils/tests/spark/test_data_store.py | 124 ++++++++ notebook_utils/tests/spark/test_metrics.py | 273 ++++++++++++++++++ notebook_utils/tests/test_cache.py | 40 +-- 9 files changed, 586 insertions(+), 44 deletions(-) create mode 100644 notebook_utils/tests/agent/test_settings.py diff --git a/configs/jupyter_server_config.py b/configs/jupyter_server_config.py index 93fc9dd..b8fe90d 100644 --- a/configs/jupyter_server_config.py +++ b/configs/jupyter_server_config.py @@ -147,11 +147,11 @@ def provision_polaris(): polaris_creds = get_polaris_credentials() if polaris_creds: - logger.info(f"\u2705 Polaris credentials provisioned for catalog: {polaris_creds['personal_catalog']}") + logger.info(f"Polaris credentials provisioned for catalog: {polaris_creds['personal_catalog']}") if polaris_creds["tenant_catalogs"]: logger.info(f" Tenant catalogs: {', '.join(polaris_creds['tenant_catalogs'])}") else: - logger.info("\u2139\ufe0f Polaris not configured, skipping Polaris credential provisioning") + logger.info("Polaris not configured, skipping Polaris credential provisioning") except Exception as e: logger.error(f"Failed to provision Polaris credentials: {e}") @@ -169,9 +169,9 @@ def _start(): from berdl_notebook_utils.spark.connect_server import start_spark_connect_server server_info = start_spark_connect_server() - logger.info(f"\u2705 Spark Connect server ready at {server_info['url']}") + logger.info(f"Spark Connect server ready at {server_info['url']}") except Exception as e: - logger.error(f"\u274c Failed 
to start Spark Connect server: {e}") + logger.error(f"Failed to start Spark Connect server: {e}") t = threading.Thread(target=_start, name="spark-connect-startup", daemon=True) t.start() diff --git a/docker-compose.yaml b/docker-compose.yaml index 917f154..4b38e77 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -129,7 +129,7 @@ services: - KBASE_AUTH_URL=https://ci.kbase.us/services/auth/ - PROXY_LISTEN_PORT=15002 - BACKEND_PORT=15002 - - SERVICE_TEMPLATE=spark-notebook-{username} + - SERVICE_TEMPLATE=spark-notebook - TOKEN_CACHE_TTL=300 - MFA_EXEMPT_USERS=${CI_KBASE_USERNAME} diff --git a/notebook_utils/tests/agent/test_mcp_tools.py b/notebook_utils/tests/agent/test_mcp_tools.py index 89a9108..aa0cc5b 100644 --- a/notebook_utils/tests/agent/test_mcp_tools.py +++ b/notebook_utils/tests/agent/test_mcp_tools.py @@ -5,6 +5,7 @@ import asyncio from typing import Optional from unittest.mock import AsyncMock, Mock, patch +import concurrent.futures import pytest from pydantic import BaseModel, Field @@ -164,7 +165,7 @@ def test_sync_wrapper_timeout(self): try: mock_future = Mock() - mock_future.result.side_effect = TimeoutError() + mock_future.result.side_effect = concurrent.futures.TimeoutError() mock_loop_run = Mock(return_value=mock_future) with patch("asyncio.run_coroutine_threadsafe", mock_loop_run): diff --git a/notebook_utils/tests/agent/test_settings.py b/notebook_utils/tests/agent/test_settings.py new file mode 100644 index 0000000..c4ae5e7 --- /dev/null +++ b/notebook_utils/tests/agent/test_settings.py @@ -0,0 +1,21 @@ +""" +Tests for agent/settings.py - Agent settings configuration. 
+""" + +from berdl_notebook_utils.agent.settings import AgentSettings, get_agent_settings + + +class TestGetAgentSettings: + """Tests for get_agent_settings function.""" + + def test_returns_agent_settings_instance(self): + """Test get_agent_settings returns an AgentSettings instance.""" + settings = get_agent_settings() + assert isinstance(settings, AgentSettings) + + def test_default_values(self): + """Test default settings values.""" + settings = get_agent_settings() + assert settings.AGENT_MODEL_PROVIDER == "openai" + assert settings.AGENT_TEMPERATURE == 0.0 + assert settings.AGENT_VERBOSE is True diff --git a/notebook_utils/tests/mcp/test_operations.py b/notebook_utils/tests/mcp/test_operations.py index eb077b4..a20b7b9 100644 --- a/notebook_utils/tests/mcp/test_operations.py +++ b/notebook_utils/tests/mcp/test_operations.py @@ -500,3 +500,119 @@ def test_with_distinct(self, mock_client): call_kwargs = mock_api.sync.call_args[1] assert call_kwargs["body"].distinct is True + + +class TestNoneResponseErrors: + """Tests for None response error paths across all MCP operations.""" + + def test_list_tables_none_response(self, mock_client): + """Test mcp_list_tables raises on None response.""" + with patch("berdl_notebook_utils.mcp.operations.list_database_tables") as mock_api: + mock_api.sync.return_value = None + + with pytest.raises(Exception, match="no response for list_tables"): + mcp_list_tables("test_db") + + def test_get_table_schema_none_response(self, mock_client): + """Test mcp_get_table_schema raises on None response.""" + with patch("berdl_notebook_utils.mcp.operations.get_table_schema") as mock_api: + mock_api.sync.return_value = None + + with pytest.raises(Exception, match="no response for get_table_schema"): + mcp_get_table_schema("db", "table") + + def test_get_database_structure_none_response(self, mock_client): + """Test mcp_get_database_structure raises on None response.""" + with patch("berdl_notebook_utils.mcp.operations.get_database_structure") 
as mock_api: + mock_api.sync.return_value = None + + with pytest.raises(Exception, match="no response for get_database_structure"): + mcp_get_database_structure() + + def test_count_table_none_response(self, mock_client): + """Test mcp_count_table raises on None response.""" + with patch("berdl_notebook_utils.mcp.operations.count_delta_table") as mock_api: + mock_api.sync.return_value = None + + with pytest.raises(Exception, match="no response for count_table"): + mcp_count_table("db", "table") + + def test_sample_table_none_response(self, mock_client): + """Test mcp_sample_table raises on None response.""" + with patch("berdl_notebook_utils.mcp.operations.sample_delta_table") as mock_api: + mock_api.sync.return_value = None + + with pytest.raises(Exception, match="no response for sample_table"): + mcp_sample_table("db", "table") + + def test_query_table_none_response(self, mock_client): + """Test mcp_query_table raises on None response.""" + with patch("berdl_notebook_utils.mcp.operations.query_delta_table") as mock_api: + mock_api.sync.return_value = None + + with pytest.raises(Exception, match="no response for query_table"): + mcp_query_table("SELECT 1") + + def test_select_table_none_response(self, mock_client): + """Test mcp_select_table raises on None response.""" + with patch("berdl_notebook_utils.mcp.operations.select_delta_table") as mock_api: + mock_api.sync.return_value = None + + with pytest.raises(Exception, match="no response for select_table"): + mcp_select_table("db", "table") + + +class TestDatabaseStructureResponseConversion: + """Tests for get_database_structure response structure conversion.""" + + def test_structure_with_to_dict(self, mock_client): + """Test response structure with to_dict method is converted.""" + mock_response = Mock() + mock_structure = Mock() + mock_structure.to_dict.return_value = {"db1": ["t1"]} + mock_response.structure = mock_structure + + with patch("berdl_notebook_utils.mcp.operations.get_database_structure") as 
mock_api: + mock_api.sync.return_value = mock_response + + result = mcp_get_database_structure() + + assert result == {"db1": ["t1"]} + + def test_structure_without_to_dict_uses_dict(self, mock_client): + """Test response structure without to_dict uses dict() conversion.""" + mock_response = Mock(spec=[]) # No to_dict + # Use a real dict-like object without to_dict + mock_response.structure = {"db1": ["t1"]} + + with patch("berdl_notebook_utils.mcp.operations.get_database_structure") as mock_api: + mock_api.sync.return_value = mock_response + + result = mcp_get_database_structure() + + assert result == {"db1": ["t1"]} + + +class TestSelectTableHavingSpec: + """Tests for mcp_select_table having parameter.""" + + def test_with_having(self, mock_client): + """Test having parameter builds FilterCondition specs.""" + mock_response = Mock() + mock_response.data = [] + mock_response.pagination = Mock(limit=100, offset=0, total_count=0, has_more=False) + + with patch("berdl_notebook_utils.mcp.operations.select_delta_table") as mock_api: + mock_api.sync.return_value = mock_response + + mcp_select_table( + "db", + "table", + group_by=["status"], + having=[{"column": "count", "operator": ">", "value": "10"}], + ) + + call_kwargs = mock_api.sync.call_args[1] + having = call_kwargs["body"].having + assert len(having) == 1 + assert having[0].column == "count" diff --git a/notebook_utils/tests/spark/test_connect_server.py b/notebook_utils/tests/spark/test_connect_server.py index 4d01465..e3cb0f4 100644 --- a/notebook_utils/tests/spark/test_connect_server.py +++ b/notebook_utils/tests/spark/test_connect_server.py @@ -181,24 +181,6 @@ def test_generate_spark_config_with_catalog( assert "spark.sql.catalog.my=org.apache.iceberg.spark.SparkCatalog" in content assert "spark.sql.catalog.my.type=rest" in content - @patch("berdl_notebook_utils.spark.connect_server.get_settings") - def test_generate_spark_config_template_not_found(self, mock_get_settings): - """Test generate_spark_config 
raises if template not found.""" - mock_settings = Mock() - mock_settings.USER = "test_user" - mock_settings.SPARK_HOME = "/opt/spark" - mock_settings.SPARK_CONNECT_DEFAULTS_TEMPLATE = "/nonexistent/template.conf" - mock_url = Mock() - mock_url.port = 15002 - mock_settings.SPARK_CONNECT_URL = mock_url - mock_settings.SPARK_MASTER_URL = "spark://master:7077" - mock_get_settings.return_value = mock_settings - - config = SparkConnectServerConfig() - - with pytest.raises(FileNotFoundError, match="Spark config template not found"): - config.generate_spark_config() - @patch("berdl_notebook_utils.spark.connect_server.get_my_groups") @patch("berdl_notebook_utils.spark.connect_server.get_namespace_prefix") @patch("berdl_notebook_utils.spark.connect_server.get_settings") @@ -252,6 +234,24 @@ def test_compute_allowed_namespace_prefixes_errors(self, mock_get_settings, mock # Should return empty string (no prefixes) assert result == "" + @patch("berdl_notebook_utils.spark.connect_server.get_settings") + def test_generate_spark_config_template_not_found(self, mock_get_settings): + """Test generate_spark_config raises if template not found.""" + mock_settings = Mock() + mock_settings.USER = "test_user" + mock_settings.SPARK_HOME = "/opt/spark" + mock_settings.SPARK_CONNECT_DEFAULTS_TEMPLATE = "/nonexistent/template.conf" + mock_url = Mock() + mock_url.port = 15002 + mock_settings.SPARK_CONNECT_URL = mock_url + mock_settings.SPARK_MASTER_URL = "spark://master:7077" + mock_get_settings.return_value = mock_settings + + config = SparkConnectServerConfig() + + with pytest.raises(FileNotFoundError, match="Spark config template not found"): + config.generate_spark_config() + class TestSparkConnectServerManager: """Tests for SparkConnectServerManager class.""" @@ -559,8 +559,11 @@ def test_wait_for_port_release_timeout(self, mock_config_class): mock_sock.connect_ex.return_value = 0 manager = SparkConnectServerManager() - # Patch socket inside the method's local import - with 
patch("socket.socket", return_value=mock_sock): + # Patch socket and sleep inside the method's local imports to avoid real waiting + with ( + patch("socket.socket", return_value=mock_sock), + patch("berdl_notebook_utils.spark.connect_server.time.sleep", return_value=None), + ): result = manager._wait_for_port_release(timeout=0.1) assert result is False diff --git a/notebook_utils/tests/spark/test_data_store.py b/notebook_utils/tests/spark/test_data_store.py index b71a3ff..c99fc25 100644 --- a/notebook_utils/tests/spark/test_data_store.py +++ b/notebook_utils/tests/spark/test_data_store.py @@ -5,7 +5,12 @@ import json from unittest.mock import Mock, patch +import pytest + from berdl_notebook_utils.spark.data_store import ( + _cached_get_my_accessible_paths, + _cached_get_my_groups, + _cached_get_namespace_prefix, _execute_with_spark, _extract_databases_from_paths, _format_output, @@ -274,3 +279,122 @@ def test_func(spark, arg1): result = _execute_with_spark(test_func, mock_spark, "test") assert result == "result_test" + + +class TestCachedWrappers: + """Tests for cached governance API wrappers.""" + + @patch("berdl_notebook_utils.spark.data_store.get_my_groups") + def test_cached_get_my_groups(self, mock_get_my_groups): + """Test _cached_get_my_groups calls through to get_my_groups.""" + mock_get_my_groups.return_value = Mock(groups=["team1"]) + _cached_get_my_groups.clear_cache() + + result = _cached_get_my_groups() + + assert result.groups == ["team1"] + mock_get_my_groups.assert_called_once() + + @patch("berdl_notebook_utils.spark.data_store.get_namespace_prefix") + def test_cached_get_namespace_prefix(self, mock_get_ns): + """Test _cached_get_namespace_prefix calls through to get_namespace_prefix.""" + mock_get_ns.return_value = Mock(user_namespace_prefix="u_test__") + _cached_get_namespace_prefix.clear_cache() + + result = _cached_get_namespace_prefix() + + assert result.user_namespace_prefix == "u_test__" + mock_get_ns.assert_called_once() + + 
@patch("berdl_notebook_utils.spark.data_store.get_my_accessible_paths") + def test_cached_get_my_accessible_paths(self, mock_get_paths): + """Test _cached_get_my_accessible_paths calls through to get_my_accessible_paths.""" + mock_get_paths.return_value = Mock(accessible_paths=["s3a://bucket/path"]) + _cached_get_my_accessible_paths.clear_cache() + + result = _cached_get_my_accessible_paths() + + assert result.accessible_paths == ["s3a://bucket/path"] + mock_get_paths.assert_called_once() + + +class TestGetDatabasesFilterError: + """Tests for get_databases filter_by_namespace error path.""" + + @patch("berdl_notebook_utils.spark.data_store._cached_get_my_accessible_paths") + @patch("berdl_notebook_utils.spark.data_store._cached_get_namespace_prefix") + @patch("berdl_notebook_utils.spark.data_store._cached_get_my_groups") + @patch("berdl_notebook_utils.spark.data_store.hive_metastore") + def test_filter_error_raises(self, mock_hms, mock_groups, mock_prefix, mock_paths): + """Test get_databases raises when filter_by_namespace fails.""" + mock_hms.get_databases.return_value = ["db1"] + mock_groups.side_effect = Exception("API error") + + with pytest.raises(Exception, match="Could not filter databases by namespace"): + get_databases(use_hms=True, filter_by_namespace=True, return_json=False) + + +class TestGetTablesSparkInnerFunction: + """Tests for get_tables using Spark (inner _get_tbls function).""" + + def test_get_tables_spark_calls_catalog(self): + """Test get_tables with use_hms=False uses Spark catalog.""" + mock_spark = Mock() + mock_table1 = Mock() + mock_table1.name = "table1" + mock_table2 = Mock() + mock_table2.name = "table2" + mock_spark.catalog.listTables.return_value = [mock_table1, mock_table2] + + result = get_tables("test_db", spark=mock_spark, use_hms=False, return_json=False) + + assert result == ["table1", "table2"] + mock_spark.catalog.listTables.assert_called_once_with(dbName="test_db") + + +class TestGetTableSchemaErrorPath: + """Tests for 
get_table_schema inner error handling.""" + + def test_schema_error_returns_empty_list(self): + """Test _get_schema returns [] when catalog raises Exception.""" + mock_spark = Mock() + mock_spark.catalog.listColumns.side_effect = Exception("table not found") + + result = get_table_schema("test_db", "broken_table", spark=mock_spark, return_json=False) + + assert result == [] + + +class TestGetDbStructureSparkPath: + """Tests for get_db_structure with use_hms=False (Spark inner function).""" + + @patch("berdl_notebook_utils.spark.data_store.get_table_schema") + @patch("berdl_notebook_utils.spark.data_store.get_tables") + @patch("berdl_notebook_utils.spark.data_store.get_databases") + @patch("berdl_notebook_utils.spark.data_store.get_spark_session") + def test_spark_path_without_schema(self, mock_get_session, mock_get_dbs, mock_get_tables, mock_get_schema): + """Test get_db_structure via Spark without schema.""" + mock_spark = Mock() + mock_get_session.return_value = mock_spark + mock_get_dbs.return_value = ["db1"] + mock_get_tables.return_value = ["t1", "t2"] + + result = get_db_structure(with_schema=False, use_hms=False, return_json=False) + + assert result == {"db1": ["t1", "t2"]} + + @patch("berdl_notebook_utils.spark.data_store.get_table_schema") + @patch("berdl_notebook_utils.spark.data_store.get_tables") + @patch("berdl_notebook_utils.spark.data_store.get_databases") + @patch("berdl_notebook_utils.spark.data_store.get_spark_session") + def test_spark_path_with_schema(self, mock_get_session, mock_get_dbs, mock_get_tables, mock_get_schema): + """Test get_db_structure via Spark with schema.""" + mock_spark = Mock() + mock_get_session.return_value = mock_spark + mock_get_dbs.return_value = ["db1"] + mock_get_tables.return_value = ["t1"] + mock_get_schema.return_value = ["col1", "col2"] + + result = get_db_structure(with_schema=True, use_hms=False, return_json=False) + + assert result == {"db1": {"t1": ["col1", "col2"]}} diff --git 
a/notebook_utils/tests/spark/test_metrics.py b/notebook_utils/tests/spark/test_metrics.py index e5ecb47..b7ddfcf 100644 --- a/notebook_utils/tests/spark/test_metrics.py +++ b/notebook_utils/tests/spark/test_metrics.py @@ -396,3 +396,276 @@ def test_returns_sorted_usernames(self, metrics): users = metrics.list_users() assert users == ["alice", "bob", "charlie"] + + +class TestInitEndpointParsing: + """Tests for __init__ endpoint URL parsing and secure inference.""" + + def test_https_url_strips_scheme_and_sets_secure(self): + """Test https:// URL is stripped and secure is inferred as True.""" + with patch("berdl_notebook_utils.spark.metrics.Minio") as mock_minio: + SparkJobMetrics(endpoint="https://minio.example.com", access_key="ak", secret_key="sk") + mock_minio.assert_called_once_with("minio.example.com", access_key="ak", secret_key="sk", secure=True) + + def test_http_url_strips_scheme_and_sets_insecure(self): + """Test http:// URL is stripped and secure is inferred as False.""" + with patch("berdl_notebook_utils.spark.metrics.Minio") as mock_minio: + SparkJobMetrics(endpoint="http://localhost:9000", access_key="ak", secret_key="sk") + mock_minio.assert_called_once_with("localhost:9000", access_key="ak", secret_key="sk", secure=False) + + def test_explicit_secure_overrides_scheme(self): + """Test explicit secure=False with https:// URL — explicit wins.""" + with patch("berdl_notebook_utils.spark.metrics.Minio") as mock_minio: + SparkJobMetrics(endpoint="https://minio.example.com", access_key="ak", secret_key="sk", secure=False) + mock_minio.assert_called_once_with("minio.example.com", access_key="ak", secret_key="sk", secure=False) + + def test_no_scheme_falls_back_to_env_var(self): + """Test no scheme falls back to MINIO_SECURE env var.""" + with ( + patch("berdl_notebook_utils.spark.metrics.Minio") as mock_minio, + patch.dict("os.environ", {"MINIO_SECURE": "true"}), + ): + SparkJobMetrics(endpoint="minio.example.com:9000", access_key="ak", secret_key="sk") + 
mock_minio.assert_called_once_with("minio.example.com:9000", access_key="ak", secret_key="sk", secure=True) + + +class TestGetJobSummaryAllUsers: + """Tests for get_job_summary with username=None (all-users path).""" + + def test_all_users_returns_summary(self, metrics): + """Test get_job_summary without username queries all users.""" + compressed = _compress_events(SAMPLE_EVENTS) + + # _list_all_app_dirs: recursive list returning objects from multiple users + metrics._client.list_objects.side_effect = [ + # First call: recursive list for _list_all_app_dirs + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + # Second call: list files in app dir for _read_event_files + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + ] + metrics._client.get_object.return_value = _make_minio_response(compressed) + + df = metrics.get_job_summary(username=None) + + assert len(df) == 1 + assert df.iloc[0]["username"] == "alice" + + def test_all_users_skips_empty_events(self, metrics): + """Test all-users path skips apps with no events.""" + metrics._client.list_objects.side_effect = [ + # _list_all_app_dirs finds one app + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + # _read_event_files returns no .zstd files + [], + ] + + df = metrics.get_job_summary(username=None) + assert df.empty + + def test_all_users_skips_no_app_id(self, metrics): + """Test all-users path skips jobs with empty app_id.""" + events_no_app_id = [ + {"Event": "SparkListenerEnvironmentUpdate", "Spark Properties": []}, + ] + compressed = _compress_events(events_no_app_id) + + metrics._client.list_objects.side_effect = [ + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + ] + metrics._client.get_object.return_value = _make_minio_response(compressed) + + df = 
metrics.get_job_summary(username=None) + assert df.empty + + +class TestGetTaskDetailAllUsers: + """Tests for get_task_detail all-users path and edge cases.""" + + def test_all_users_returns_task_rows(self, metrics): + """Test get_task_detail without username queries all users.""" + compressed = _compress_events(SAMPLE_EVENTS) + + metrics._client.list_objects.side_effect = [ + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + ] + metrics._client.get_object.return_value = _make_minio_response(compressed) + + df = metrics.get_task_detail(username=None) + assert len(df) == 1 + + def test_skips_empty_task_metrics(self, metrics): + """Test that TaskEnd events with empty Task Metrics are skipped.""" + events = [ + {"Event": "SparkListenerApplicationStart", "App ID": "app-1", "App Name": "test"}, + {"Event": "SparkListenerTaskEnd", "Task Metrics": {}}, + {"Event": "SparkListenerTaskEnd", "Task Metrics": None}, + { + "Event": "SparkListenerTaskEnd", + "Task Metrics": { + "Executor Run Time": 100, + "Executor CPU Time": 0, + "JVM GC Time": 0, + "Peak Execution Memory": 1000, + "Memory Bytes Spilled": 0, + "Disk Bytes Spilled": 0, + "Input Metrics": {}, + "Output Metrics": {}, + "Shuffle Read Metrics": {}, + "Shuffle Write Metrics": {}, + }, + "Task Info": {"Task ID": 0, "Executor ID": "1", "Host": "w1", "Launch Time": 0, "Finish Time": 100}, + }, + ] + compressed = _compress_events(events) + + metrics._client.list_objects.side_effect = [ + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/")], + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + ] + metrics._client.get_object.return_value = _make_minio_response(compressed) + + df = metrics.get_task_detail(username="alice") + # Only 1 task has valid Task Metrics + assert len(df) == 1 + + +class TestListAllAppDirsEdgeCases: + """Tests for 
_list_all_app_dirs edge cases.""" + + def test_skips_entries_with_few_path_parts(self, metrics): + """Test that entries with <2 path parts after prefix are skipped.""" + metrics._client.list_objects.return_value = [ + _make_obj("spark-job-logs/lonely-file.txt"), # Only 1 part after prefix + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260311000000-0001/events.zstd"), + ] + + result = metrics._list_all_app_dirs() + assert len(result) == 1 + assert result[0][0] == "alice" + + def test_extract_ts_fallback_on_malformed_dir(self, metrics): + """Test that _extract_ts returns empty string for malformed eventlog dirs.""" + metrics._client.list_objects.return_value = [ + # Has eventlog_v2_ but no valid app-TIMESTAMP format + _make_obj("spark-job-logs/alice/eventlog_v2_malformed/events.zstd"), + _make_obj("spark-job-logs/bob/eventlog_v2_app-20260311000000-0001/events.zstd"), + ] + + result = metrics._list_all_app_dirs() + # Both should be returned; malformed one sorts with empty timestamp + assert len(result) == 2 + + +class TestListAppDirsFiltering: + """Tests for _list_app_dirs since and limit filtering.""" + + def test_since_filters_old_dirs(self, metrics): + """Test since parameter filters out old dirs.""" + metrics._client.list_objects.return_value = [ + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260313000000-0001/"), + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260301000000-0001/"), + ] + + result = metrics._list_app_dirs("alice", since="20260310") + assert len(result) == 1 + assert "20260313" in result[0] + + def test_since_handles_malformed_timestamp(self, metrics): + """Test since filter gracefully handles dirs without valid timestamp.""" + metrics._client.list_objects.return_value = [ + _make_obj("spark-job-logs/alice/eventlog_v2_malformed/"), + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260313000000-0001/"), + ] + + result = metrics._list_app_dirs("alice", since="20260310") + # Malformed dir is kept (IndexError caught, not skipped) + 
assert len(result) == 2 + + def test_limit_returns_most_recent(self, metrics): + """Test limit returns N most recent dirs.""" + metrics._client.list_objects.return_value = [ + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260311000000-0001/"), + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260313000000-0001/"), + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260312000000-0001/"), + ] + + result = metrics._list_app_dirs("alice", limit=2) + assert len(result) == 2 + # Should be sorted descending + assert "20260313" in result[0] + assert "20260312" in result[1] + + def test_skips_non_eventlog_dirs(self, metrics): + """Test that non-eventlog directories are skipped.""" + metrics._client.list_objects.return_value = [ + _make_obj("spark-job-logs/alice/some-other-dir/"), + _make_obj("spark-job-logs/alice/eventlog_v2_app-20260311000000-0001/"), + ] + + result = metrics._list_app_dirs("alice") + assert len(result) == 1 + + +class TestParseUserLogsEdgeCases: + """Tests for _parse_user_logs edge cases.""" + + def test_skips_jobs_without_app_id(self, metrics): + """Test _parse_user_logs skips jobs with empty app_id.""" + events_no_app = [{"Event": "SparkListenerStageCompleted"}] + compressed = _compress_events(events_no_app) + + metrics._client.list_objects.side_effect = [ + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/")], + [_make_obj("spark-job-logs/alice/eventlog_v2_app-20260311022136-0001/events.zstd")], + ] + metrics._client.get_object.return_value = _make_minio_response(compressed) + + jobs = metrics._parse_user_logs("alice") + assert len(jobs) == 0 + + +class TestEventsToJobMetricsAdditionalEvents: + """Tests for event types not covered by existing tests.""" + + def test_resource_profile_added(self, metrics): + """Test SparkListenerResourceProfileAdded sets allocated memory/cores.""" + events = [ + {"Event": "SparkListenerApplicationStart", "App ID": "app-1"}, + { + "Event": "SparkListenerResourceProfileAdded", + "Executor Resource 
Requests": { + "memory": {"Amount": 4096}, + "cores": {"Amount": 4}, + }, + }, + ] + job = metrics._events_to_job_metrics(events, "alice") + + assert job.allocated_executor_memory_mb == 4096 + assert job.allocated_executor_cores == 4 + + def test_block_manager_added_tracks_max(self, metrics): + """Test SparkListenerBlockManagerAdded tracks maximum memory values.""" + events = [ + {"Event": "SparkListenerApplicationStart", "App ID": "app-1"}, + {"Event": "SparkListenerBlockManagerAdded", "Maximum Memory": 1000, "Maximum Onheap Memory": 800}, + {"Event": "SparkListenerBlockManagerAdded", "Maximum Memory": 2000, "Maximum Onheap Memory": 1500}, + {"Event": "SparkListenerBlockManagerAdded", "Maximum Memory": 500, "Maximum Onheap Memory": 400}, + ] + job = metrics._events_to_job_metrics(events, "alice") + + assert job.block_manager_max_memory_bytes == 2000 + assert job.block_manager_max_onheap_bytes == 1500 + + def test_task_end_with_empty_metrics_skipped(self, metrics): + """Test SparkListenerTaskEnd with empty Task Metrics is skipped.""" + events = [ + {"Event": "SparkListenerApplicationStart", "App ID": "app-1"}, + {"Event": "SparkListenerTaskEnd", "Task Metrics": {}}, + {"Event": "SparkListenerTaskEnd"}, # No Task Metrics key at all + ] + job = metrics._events_to_job_metrics(events, "alice") + + assert job.total_tasks == 0 diff --git a/notebook_utils/tests/test_cache.py b/notebook_utils/tests/test_cache.py index f0920f6..c29d816 100644 --- a/notebook_utils/tests/test_cache.py +++ b/notebook_utils/tests/test_cache.py @@ -16,18 +16,20 @@ def test_registers_function(self): """Test that decorator registers the function in _token_change_caches.""" initial_len = len(_token_change_caches) - # Production order: @kbase_token_dependent on top of @lru_cache - # This means lru_cache wraps first, then kbase_token_dependent registers the wrapper - @kbase_token_dependent - @lru_cache - def dummy_func(): - return "value" - - assert len(_token_change_caches) == initial_len + 1 - 
assert dummy_func in _token_change_caches - - # Clean up - _token_change_caches.remove(dummy_func) + try: + # Production order: @kbase_token_dependent on top of @lru_cache + # This means lru_cache wraps first, then kbase_token_dependent registers the wrapper + @kbase_token_dependent + @lru_cache + def dummy_func(): + return "value" + + assert len(_token_change_caches) == initial_len + 1 + assert dummy_func in _token_change_caches + finally: + # Clean up without assuming registration always occurred + if "dummy_func" in locals() and dummy_func in _token_change_caches: + _token_change_caches.remove(dummy_func) def test_returns_function_unchanged(self): """Test that decorator returns the function without modification.""" @@ -36,12 +38,14 @@ def test_returns_function_unchanged(self): def original(): return 42 - result = kbase_token_dependent(original) - assert result is original - assert result() == 42 - - # Clean up - _token_change_caches.remove(original) + try: + result = kbase_token_dependent(original) + assert result is original + assert result() == 42 + finally: + # Clean up without assuming registration always occurred + if original in _token_change_caches: + _token_change_caches.remove(original) class TestClearKbaseTokenCaches: From 8b0f05d36c00fb55f7b3c6a777e94a65cfb416f2 Mon Sep 17 00:00:00 2001 From: Tianhao-Gu Date: Sun, 22 Mar 2026 21:03:04 -0500 Subject: [PATCH 3/5] Fix lint errors: remove duplicate imports and test class - Remove duplicate list_user_names import - Remove duplicate UserNamesResponse import (already from governance_client.models) - Remove duplicate TestListUserNames class (kept original with more thorough assertions) --- .../tests/minio_governance/test_operations.py | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/notebook_utils/tests/minio_governance/test_operations.py b/notebook_utils/tests/minio_governance/test_operations.py index 2c6aaf2..9a413d7 100644 --- a/notebook_utils/tests/minio_governance/test_operations.py 
+++ b/notebook_utils/tests/minio_governance/test_operations.py @@ -48,7 +48,6 @@ list_groups, list_user_names, list_users, - list_user_names, add_group_member, remove_group_member, create_tenant_and_assign_users, @@ -60,7 +59,6 @@ CredentialsResponse, ErrorResponse, GroupManagementResponse, - UserNamesResponse, ) @@ -1127,40 +1125,6 @@ def test_list_available_groups_none_response(self, mock_list_groups, mock_get_cl list_available_groups() -class TestListUserNames: - """Tests for list_user_names function.""" - - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") - def test_list_user_names_success(self, mock_list_names, mock_get_client): - mock_get_client.return_value = Mock() - mock_response = Mock(spec=UserNamesResponse) - mock_response.usernames = ["alice", "bob", "charlie"] - mock_list_names.return_value = mock_response - - result = list_user_names() - - assert result == ["alice", "bob", "charlie"] - - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") - def test_list_user_names_error_response(self, mock_list_names, mock_get_client): - mock_get_client.return_value = Mock() - mock_list_names.return_value = ErrorResponse(message="forbidden", error_type="error") - - with pytest.raises(RuntimeError, match="Failed to list usernames: forbidden"): - list_user_names() - - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") - def test_list_user_names_none_response(self, mock_list_names, mock_get_client): - mock_get_client.return_value = Mock() - mock_list_names.return_value = None - - with pytest.raises(RuntimeError, match="Failed to list usernames: no response from API"): - list_user_names() - - class 
TestCreateTenantAddMemberErrorAndException: """Tests for create_tenant_and_assign_users error/exception paths.""" From 745af4d0bcc4fbb0df2e520fc96db7f67bf9a698 Mon Sep 17 00:00:00 2001 From: Tianhao-Gu Date: Sun, 22 Mar 2026 21:04:15 -0500 Subject: [PATCH 4/5] run formatter --- notebook_utils/tests/minio_governance/test_operations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/notebook_utils/tests/minio_governance/test_operations.py b/notebook_utils/tests/minio_governance/test_operations.py index 9a413d7..1dbcfee 100644 --- a/notebook_utils/tests/minio_governance/test_operations.py +++ b/notebook_utils/tests/minio_governance/test_operations.py @@ -979,7 +979,6 @@ def test_request_tenant_access_http_error(self, mock_settings, mock_httpx): request_tenant_access("kbase") - # ============================================================================= # Additional tests for uncovered lines # ============================================================================= From 322409c3f02e65e9217d419d45ba66a4f0ce5ded Mon Sep 17 00:00:00 2001 From: Tianhao-Gu Date: Mon, 23 Mar 2026 14:06:36 -0500 Subject: [PATCH 5/5] Remove MinIO file-based credential caching, keep Polaris caching Align MinIO credential flow with main (PR #142): use direct API calls instead of local file cache with file locking. The governance API caches credentials server-side in PostgreSQL. Polaris credentials still use file-based caching since the provisioning API is heavier weight. 
--- .../minio_governance/operations.py | 83 +--- .../berdl_notebook_utils/refresh.py | 20 +- .../tests/minio_governance/test_operations.py | 467 +++++------------- notebook_utils/tests/test_refresh.py | 4 +- 4 files changed, 150 insertions(+), 424 deletions(-) diff --git a/notebook_utils/berdl_notebook_utils/minio_governance/operations.py b/notebook_utils/berdl_notebook_utils/minio_governance/operations.py index c2fd79e..e5d6e83 100644 --- a/notebook_utils/berdl_notebook_utils/minio_governance/operations.py +++ b/notebook_utils/berdl_notebook_utils/minio_governance/operations.py @@ -105,8 +105,7 @@ class TenantCreationResult(TypedDict): SQL_WAREHOUSE_BUCKET = "cdm-lake" # TODO: change to berdl-lake SQL_USER_WAREHOUSE_PATH = "users-sql-warehouse" -# Credential caching configuration -CREDENTIALS_CACHE_FILE = ".berdl_minio_credentials" +# Credential caching configuration (Polaris only — MinIO uses direct API calls) POLARIS_CREDENTIALS_CACHE_FILE = ".berdl_polaris_credentials" @@ -143,37 +142,11 @@ def _fetch_with_file_cache( return result -def _get_credentials_cache_path() -> Path: - """Get the path to the MinIO credentials cache file in the user's home directory.""" - return Path.home() / CREDENTIALS_CACHE_FILE - - def _get_polaris_cache_path() -> Path: """Get the path to the Polaris credentials cache file in the user's home directory.""" return Path.home() / POLARIS_CREDENTIALS_CACHE_FILE -def _read_cached_credentials(cache_path: Path) -> CredentialsResponse | None: - """Read MinIO credentials from cache file. 
Returns None if file doesn't exist or is corrupted.""" - try: - if not cache_path.exists(): - return None - with open(cache_path, "r") as f: - data = json.load(f) - return CredentialsResponse.from_dict(data) - except (json.JSONDecodeError, TypeError, KeyError, OSError): - return None - - -def _write_credentials_cache(cache_path: Path, credentials: CredentialsResponse) -> None: - """Write MinIO credentials to cache file.""" - try: - with open(cache_path, "w") as f: - json.dump(credentials.to_dict(), f) - except (OSError, TypeError): - pass - - def _read_cached_polaris_credentials(cache_path: Path) -> "PolarisCredentials | None": """Read Polaris credentials from cache file. Returns None if file doesn't exist or is corrupted.""" try: @@ -236,21 +209,12 @@ def check_governance_health() -> HealthResponse: return response -def _fetch_minio_credentials() -> CredentialsResponse | None: - """Fetch fresh MinIO credentials from the governance API.""" - client = get_governance_client() - api_response = get_credentials_credentials_get.sync(client=client) - if isinstance(api_response, CredentialsResponse): - return api_response - return None - - def get_minio_credentials() -> CredentialsResponse: """ Get MinIO credentials for the current user and set them as environment variables. - Uses file locking to prevent race conditions when multiple processes/notebooks - try to access credentials simultaneously. + Fetches credentials from the governance API (MMS). The API caches credentials + server-side in PostgreSQL, so repeated calls are fast. 
Sets the following environment variables: - MINIO_ACCESS_KEY: User's MinIO access key @@ -259,24 +223,19 @@ def get_minio_credentials() -> CredentialsResponse: Returns: CredentialsResponse with username, access_key, and secret_key """ - credentials = _fetch_with_file_cache( - _get_credentials_cache_path(), - _read_cached_credentials, - _fetch_minio_credentials, - _write_credentials_cache, - ) - if credentials is None: + client = get_governance_client() + api_response = get_credentials_credentials_get.sync(client=client) + if not isinstance(api_response, CredentialsResponse): raise RuntimeError("Failed to fetch credentials from API") - # Set MinIO credentials as environment variables - os.environ["MINIO_ACCESS_KEY"] = credentials.access_key - os.environ["MINIO_SECRET_KEY"] = credentials.secret_key + os.environ["MINIO_ACCESS_KEY"] = api_response.access_key + os.environ["MINIO_SECRET_KEY"] = api_response.secret_key # Clear the cached settings so subsequent get_settings() calls pick up the # new MINIO_ACCESS_KEY / MINIO_SECRET_KEY env vars. get_settings.cache_clear() - return credentials + return api_response class PolarisCredentials(TypedDict): @@ -361,13 +320,10 @@ def get_polaris_credentials() -> PolarisCredentials | None: def rotate_minio_credentials() -> CredentialsResponse: """ - Rotate MinIO credentials for the current user and update local caches. + Rotate MinIO credentials for the current user and update environment variables. Calls POST /credentials/rotate to generate new credentials in MinIO, - then updates the local cache file and environment variables. - - Uses the same file locking strategy as get_minio_credentials() to prevent - concurrent access from corrupting the cache file. + then updates the environment variables. 
Returns: CredentialsResponse with username, access_key, and secret_key @@ -377,23 +333,6 @@ def rotate_minio_credentials() -> CredentialsResponse: if not isinstance(api_response, CredentialsResponse): raise RuntimeError("Failed to rotate credentials from API") - # Update the local credential cache under lock - cache_path = _get_credentials_cache_path() - lock_path = cache_path.with_suffix(".lock") - - with open(lock_path, "w") as lock_file: - try: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - _write_credentials_cache(cache_path, api_response) - finally: - pass - - try: - lock_path.unlink(missing_ok=True) - except OSError: - pass - - # Update environment variables os.environ["MINIO_ACCESS_KEY"] = api_response.access_key os.environ["MINIO_SECRET_KEY"] = api_response.secret_key diff --git a/notebook_utils/berdl_notebook_utils/refresh.py b/notebook_utils/berdl_notebook_utils/refresh.py index 6dc3bbd..e26da87 100644 --- a/notebook_utils/berdl_notebook_utils/refresh.py +++ b/notebook_utils/berdl_notebook_utils/refresh.py @@ -1,9 +1,9 @@ """ Refresh credentials and Spark environment. -Provides a single function to clear all credential caches, re-provision -MinIO and Polaris credentials, restart the Spark Connect server, and stop -any existing Spark session — ensuring get_spark_session() works afterward. +Provides a single function to re-provision MinIO and Polaris credentials, +restart the Spark Connect server, and stop any existing Spark session — +ensuring get_spark_session() works afterward. """ import logging @@ -13,7 +13,6 @@ from berdl_notebook_utils.berdl_settings import get_settings from berdl_notebook_utils.minio_governance.operations import ( - _get_credentials_cache_path, _get_polaris_cache_path, rotate_minio_credentials, get_polaris_credentials, @@ -36,10 +35,10 @@ def _remove_cache_file(path: Path) -> bool: def refresh_spark_environment() -> dict: - """Clear all credential caches, re-provision credentials, and restart Spark. 
+ """Re-provision credentials and restart Spark. Steps performed: - 1. Delete MinIO and Polaris credential cache files + 1. Delete Polaris credential cache file 2. Clear the in-memory ``get_settings()`` LRU cache 3. Rotate MinIO credentials via MMS (generates new secret key, updates env vars) 4. Re-fetch Polaris credentials (sets POLARIS_CREDENTIAL and catalog env vars) @@ -53,14 +52,9 @@ def refresh_spark_environment() -> dict: """ result: dict = {} - # 1. Delete credential cache files - minio_removed = _remove_cache_file(_get_credentials_cache_path()) + # 1. Delete Polaris credential cache file (MinIO uses direct API calls, no local cache) polaris_removed = _remove_cache_file(_get_polaris_cache_path()) - logger.info( - "Cleared credential caches (minio=%s, polaris=%s)", - minio_removed, - polaris_removed, - ) + logger.info("Cleared Polaris credential cache (removed=%s)", polaris_removed) # 2. Clear in-memory settings cache get_settings.cache_clear() diff --git a/notebook_utils/tests/minio_governance/test_operations.py b/notebook_utils/tests/minio_governance/test_operations.py index 1dbcfee..bed3e07 100644 --- a/notebook_utils/tests/minio_governance/test_operations.py +++ b/notebook_utils/tests/minio_governance/test_operations.py @@ -6,27 +6,23 @@ import logging from pathlib import Path from unittest.mock import Mock, patch + import httpx import pytest - from governance_client.models import ( HealthResponse, NamespacePrefixResponse, PathAccessResponse, UserAccessiblePathsResponse, UserGroupsResponse, - UserNamesResponse, UserPoliciesResponse, UserSqlWarehousePrefixResponse, ) from berdl_notebook_utils.minio_governance.operations import ( _fetch_with_file_cache, - _get_credentials_cache_path, _get_polaris_cache_path, - _read_cached_credentials, _read_cached_polaris_credentials, - _write_credentials_cache, _write_polaris_credentials_cache, _build_table_path, check_governance_health, @@ -46,80 +42,22 @@ make_table_private, list_available_groups, list_groups, - 
list_user_names, list_users, + list_user_names, add_group_member, remove_group_member, create_tenant_and_assign_users, request_tenant_access, rotate_minio_credentials, regenerate_policies, - CREDENTIALS_CACHE_FILE, POLARIS_CREDENTIALS_CACHE_FILE, CredentialsResponse, ErrorResponse, GroupManagementResponse, + UserNamesResponse, ) -class TestGetCredentialsCachePath: - """Tests for _get_credentials_cache_path helper.""" - - def test_returns_path_in_home(self): - """Test returns path in home directory.""" - path = _get_credentials_cache_path() - - assert path == Path.home() / CREDENTIALS_CACHE_FILE - - -class TestReadCachedCredentials: - """Tests for _read_cached_credentials helper.""" - - def test_returns_none_if_file_not_exists(self, tmp_path): - """Test returns None if cache file doesn't exist.""" - result = _read_cached_credentials(tmp_path / "nonexistent.json") - - assert result is None - - def test_returns_none_on_invalid_json(self, tmp_path): - """Test returns None on invalid JSON.""" - cache_file = tmp_path / "cache.json" - cache_file.write_text("not valid json") - - result = _read_cached_credentials(cache_file) - - assert result is None - - @patch("berdl_notebook_utils.minio_governance.operations.CredentialsResponse") - def test_returns_credentials_on_valid_cache(self, mock_creds_class, tmp_path): - """Test returns credentials on valid cache file.""" - cache_file = tmp_path / "cache.json" - cache_file.write_text('{"access_key": "key", "secret_key": "secret"}') - - mock_creds = Mock() - mock_creds_class.from_dict.return_value = mock_creds - - result = _read_cached_credentials(cache_file) - - assert result == mock_creds - - -class TestWriteCredentialsCache: - """Tests for _write_credentials_cache helper.""" - - def test_writes_credentials_to_file(self, tmp_path): - """Test writes credentials to cache file.""" - cache_file = tmp_path / "cache.json" - mock_creds = Mock() - mock_creds.to_dict.return_value = {"access_key": "test_key"} - - 
_write_credentials_cache(cache_file, mock_creds) - - assert cache_file.exists() - content = json.loads(cache_file.read_text()) - assert content["access_key"] == "test_key" - - class TestBuildTablePath: """Tests for _build_table_path helper.""" @@ -136,59 +74,6 @@ def test_builds_path_with_db_suffix(self): assert path == "s3a://cdm-lake/users-sql-warehouse/user1/analytics.db/users" -class TestFetchWithFileCache: - """Tests for _fetch_with_file_cache helper.""" - - @patch("berdl_notebook_utils.minio_governance.operations.fcntl") - def test_returns_cached_value_on_cache_hit(self, mock_fcntl, tmp_path): - """Test returns cached value without calling fetch when cache hits.""" - cache_path = tmp_path / "creds.json" - sentinel = {"key": "cached_value"} - - read_cache = Mock(return_value=sentinel) - fetch = Mock() - write_cache = Mock() - - result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) - - assert result == sentinel - read_cache.assert_called_once_with(cache_path) - fetch.assert_not_called() - write_cache.assert_not_called() - - @patch("berdl_notebook_utils.minio_governance.operations.fcntl") - def test_fetches_and_writes_cache_on_cache_miss(self, mock_fcntl, tmp_path): - """Test fetches fresh data and writes cache when cache misses.""" - cache_path = tmp_path / "creds.json" - sentinel = {"key": "fresh_value"} - - read_cache = Mock(return_value=None) - fetch = Mock(return_value=sentinel) - write_cache = Mock() - - result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) - - assert result == sentinel - read_cache.assert_called_once_with(cache_path) - fetch.assert_called_once() - write_cache.assert_called_once_with(cache_path, sentinel) - - @patch("berdl_notebook_utils.minio_governance.operations.fcntl") - def test_returns_none_without_writing_when_fetch_fails(self, mock_fcntl, tmp_path): - """Test returns None and does not write cache when fetch returns None.""" - cache_path = tmp_path / "creds.json" - - read_cache = 
Mock(return_value=None) - fetch = Mock(return_value=None) - write_cache = Mock() - - result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) - - assert result is None - fetch.assert_called_once() - write_cache.assert_not_called() - - class TestCheckGovernanceHealth: """Tests for check_governance_health function.""" @@ -226,55 +111,12 @@ def test_check_governance_health_none_response(self, mock_health_check, mock_get class TestGetMinioCredentials: """Tests for get_minio_credentials function.""" + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") @patch("berdl_notebook_utils.minio_governance.operations.os") - @patch("berdl_notebook_utils.minio_governance.operations.fcntl") - @patch("berdl_notebook_utils.minio_governance.operations._write_credentials_cache") - @patch("berdl_notebook_utils.minio_governance.operations._read_cached_credentials") - @patch("berdl_notebook_utils.minio_governance.operations._get_credentials_cache_path") - def test_returns_cached_credentials( - self, - mock_cache_path, - mock_read_cache, - mock_write_cache, - mock_fcntl, - mock_os, - tmp_path, - ): - """Test returns cached credentials when available.""" - mock_cache_path.return_value = tmp_path / ".cache" - mock_creds = Mock() - mock_creds.access_key = "cached_key" - mock_creds.secret_key = "cached_secret" - mock_read_cache.return_value = mock_creds - - result = get_minio_credentials() - - assert result == mock_creds - mock_os.environ.__setitem__.assert_any_call("MINIO_ACCESS_KEY", "cached_key") - mock_os.environ.__setitem__.assert_any_call("MINIO_SECRET_KEY", "cached_secret") - - @patch("berdl_notebook_utils.minio_governance.operations.os") - @patch("berdl_notebook_utils.minio_governance.operations.fcntl") @patch("berdl_notebook_utils.minio_governance.operations.get_credentials_credentials_get") @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - 
@patch("berdl_notebook_utils.minio_governance.operations._write_credentials_cache") - @patch("berdl_notebook_utils.minio_governance.operations._read_cached_credentials") - @patch("berdl_notebook_utils.minio_governance.operations._get_credentials_cache_path") - def test_fetches_fresh_credentials_when_no_cache( - self, - mock_cache_path, - mock_read_cache, - mock_write_cache, - mock_get_client, - mock_get_creds, - mock_fcntl, - mock_os, - tmp_path, - ): - """Test fetches fresh credentials when cache is empty.""" - mock_cache_path.return_value = tmp_path / ".cache" - mock_read_cache.return_value = None - + def test_fetches_credentials_from_api(self, mock_get_client, mock_get_creds, mock_os, mock_get_settings): + """Test fetches credentials from API and sets env vars.""" mock_client = Mock() mock_get_client.return_value = mock_client @@ -286,7 +128,30 @@ def test_fetches_fresh_credentials_when_no_cache( result = get_minio_credentials() assert result == mock_creds - mock_write_cache.assert_called_once() + mock_get_creds.sync.assert_called_once_with(client=mock_client) + mock_os.environ.__setitem__.assert_any_call("MINIO_ACCESS_KEY", "new_key") + mock_os.environ.__setitem__.assert_any_call("MINIO_SECRET_KEY", "new_secret") + mock_get_settings.cache_clear.assert_called_once() + + @patch("berdl_notebook_utils.minio_governance.operations.get_credentials_credentials_get") + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + def test_raises_on_error_response(self, mock_get_client, mock_get_creds): + """Test raises RuntimeError when API returns an error response.""" + mock_get_client.return_value = Mock() + mock_get_creds.sync.return_value = ErrorResponse(message="unauthorized", error_type="error") + + with pytest.raises(RuntimeError, match="Failed to fetch credentials from API"): + get_minio_credentials() + + @patch("berdl_notebook_utils.minio_governance.operations.get_credentials_credentials_get") + 
@patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + def test_raises_on_none_response(self, mock_get_client, mock_get_creds): + """Test raises RuntimeError when API returns None.""" + mock_get_client.return_value = Mock() + mock_get_creds.sync.return_value = None + + with pytest.raises(RuntimeError, match="Failed to fetch credentials from API"): + get_minio_credentials() class TestRotateMinioCredentials: @@ -294,21 +159,10 @@ class TestRotateMinioCredentials: @patch("berdl_notebook_utils.minio_governance.operations.get_settings") @patch("berdl_notebook_utils.minio_governance.operations.os") - @patch("berdl_notebook_utils.minio_governance.operations._write_credentials_cache") - @patch("berdl_notebook_utils.minio_governance.operations._get_credentials_cache_path") @patch("berdl_notebook_utils.minio_governance.operations.rotate_credentials_credentials_rotate_post") @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - def test_rotates_and_updates_cache( - self, - mock_get_client, - mock_rotate_api, - mock_cache_path, - mock_write_cache, - mock_os, - mock_get_settings, - tmp_path, - ): - """Test rotate calls API and updates local cache and env vars.""" + def test_rotates_and_updates_env_vars(self, mock_get_client, mock_rotate_api, mock_os, mock_get_settings): + """Test rotate calls API and updates env vars.""" mock_client = Mock() mock_get_client.return_value = mock_client @@ -318,13 +172,10 @@ def test_rotates_and_updates_cache( mock_creds.username = "testuser" mock_rotate_api.sync.return_value = mock_creds - mock_cache_path.return_value = tmp_path / ".cache" - result = rotate_minio_credentials() assert result == mock_creds mock_rotate_api.sync.assert_called_once_with(client=mock_client) - mock_write_cache.assert_called_once() mock_os.environ.__setitem__.assert_any_call("MINIO_ACCESS_KEY", "rotated_key") mock_os.environ.__setitem__.assert_any_call("MINIO_SECRET_KEY", "rotated_secret") 
mock_get_settings.cache_clear.assert_called_once() @@ -794,45 +645,6 @@ def test_list_users(self, mock_list_users, mock_get_client): assert result.users == ["user1", "user2"] -class TestListUserNames: - """Tests for list_user_names function.""" - - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") - def test_list_user_names_success(self, mock_list_user_names, mock_get_client): - """Test list_user_names returns list of usernames.""" - mock_client = Mock() - mock_get_client.return_value = mock_client - mock_list_user_names.return_value = Mock(spec=UserNamesResponse, usernames=["user1", "user2", "user3"]) - - result = list_user_names() - - assert result == ["user1", "user2", "user3"] - mock_list_user_names.assert_called_once_with(client=mock_client) - - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") - def test_list_user_names_error_response(self, mock_list_user_names, mock_get_client): - """Test list_user_names raises on error response.""" - mock_client = Mock() - mock_get_client.return_value = mock_client - mock_list_user_names.return_value = Mock(spec=ErrorResponse, message="Forbidden") - - with pytest.raises(RuntimeError, match="Failed to list usernames"): - list_user_names() - - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") - def test_list_user_names_none_response(self, mock_list_user_names, mock_get_client): - """Test list_user_names raises on None response.""" - mock_client = Mock() - mock_get_client.return_value = mock_client - mock_list_user_names.return_value = None - - with pytest.raises(RuntimeError, match="no response from API"): - list_user_names() - - class TestAddGroupMember: """Tests for add_group_member 
function.""" @@ -984,117 +796,8 @@ def test_request_tenant_access_http_error(self, mock_settings, mock_httpx): # ============================================================================= -class TestWriteCredentialsCacheErrors: - """Tests for _write_credentials_cache error handling.""" - - def test_silently_handles_os_error(self, tmp_path): - """Test swallows OSError when writing fails (e.g. read-only dir).""" - bad_path = tmp_path / "nonexistent_dir" / "cache.json" - mock_creds = Mock() - mock_creds.to_dict.return_value = {"access_key": "key"} - - # Should not raise - _write_credentials_cache(bad_path, mock_creds) - - def test_silently_handles_type_error(self, tmp_path): - """Test swallows TypeError when serialization fails.""" - cache_file = tmp_path / "cache.json" - mock_creds = Mock() - mock_creds.to_dict.return_value = {"bad": object()} # Not JSON-serializable - - # Should not raise - _write_credentials_cache(cache_file, mock_creds) - - -class TestGetMinioCredentialsFreshFetchFailure: - """Tests for get_minio_credentials when API returns non-CredentialsResponse.""" - - @patch("berdl_notebook_utils.minio_governance.operations.get_settings") - @patch("berdl_notebook_utils.minio_governance.operations.os") - @patch("berdl_notebook_utils.minio_governance.operations.fcntl") - @patch("berdl_notebook_utils.minio_governance.operations.get_credentials_credentials_get") - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations._write_credentials_cache") - @patch("berdl_notebook_utils.minio_governance.operations._read_cached_credentials") - @patch("berdl_notebook_utils.minio_governance.operations._get_credentials_cache_path") - def test_raises_when_api_returns_error( - self, - mock_cache_path, - mock_read_cache, - mock_write_cache, - mock_get_client, - mock_get_creds, - mock_fcntl, - mock_os, - mock_get_settings, - tmp_path, - ): - mock_cache_path.return_value = tmp_path / ".cache" - 
mock_read_cache.return_value = None - mock_get_client.return_value = Mock() - mock_get_creds.sync.return_value = ErrorResponse(message="unauthorized", error_type="error") - - with pytest.raises(RuntimeError, match="Failed to fetch credentials from API"): - get_minio_credentials() - - -class TestGetMinioCredentialsLockCleanupOSError: - """Tests for OSError during lock file cleanup.""" - - @patch("berdl_notebook_utils.minio_governance.operations.get_settings") - @patch("berdl_notebook_utils.minio_governance.operations.fcntl") - @patch("berdl_notebook_utils.minio_governance.operations._read_cached_credentials") - @patch("berdl_notebook_utils.minio_governance.operations._get_credentials_cache_path") - def test_handles_lock_cleanup_oserror( - self, - mock_cache_path, - mock_read_cache, - mock_fcntl, - mock_get_settings, - tmp_path, - ): - mock_cache_path.return_value = tmp_path / ".cache" - mock_creds = CredentialsResponse(username="u", access_key="ak", secret_key="sk") - mock_read_cache.return_value = mock_creds - - with patch.object(Path, "unlink", side_effect=OSError("permission denied")): - result = get_minio_credentials() - - assert result.access_key == "ak" - - -class TestRotateMinioCredentialsLockCleanupOSError: - """Tests for OSError during lock file cleanup in rotate.""" - - @patch("berdl_notebook_utils.minio_governance.operations.get_settings") - @patch("berdl_notebook_utils.minio_governance.operations._write_credentials_cache") - @patch("berdl_notebook_utils.minio_governance.operations._get_credentials_cache_path") - @patch("berdl_notebook_utils.minio_governance.operations.rotate_credentials_credentials_rotate_post") - @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - def test_handles_lock_cleanup_oserror( - self, - mock_get_client, - mock_rotate_api, - mock_cache_path, - mock_write_cache, - mock_get_settings, - tmp_path, - ): - mock_get_client.return_value = Mock() - mock_creds = Mock(spec=CredentialsResponse) - 
mock_creds.access_key = "new_key" - mock_creds.secret_key = "new_secret" - mock_rotate_api.sync.return_value = mock_creds - mock_cache_path.return_value = tmp_path / ".cache" - - with patch.object(Path, "unlink", side_effect=OSError("permission denied")): - result = rotate_minio_credentials() - - assert result == mock_creds - - class TestUnshareTableLogsErrors: - """Tests for unshare_table error logging.""" + """Tests for unshare_table error logging (lines 517-519).""" @patch("berdl_notebook_utils.minio_governance.operations.get_settings") @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") @@ -1112,7 +815,7 @@ def test_unshare_table_logs_errors(self, mock_unshare, mock_get_client, mock_set class TestListAvailableGroupsNoneResponse: - """Tests for list_available_groups None response.""" + """Tests for list_available_groups None response (line 632).""" @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") @patch("berdl_notebook_utils.minio_governance.operations.list_group_names_sync") @@ -1124,8 +827,42 @@ def test_list_available_groups_none_response(self, mock_list_groups, mock_get_cl list_available_groups() +class TestListUserNames: + """Tests for list_user_names function.""" + + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") + def test_list_user_names_success(self, mock_list_names, mock_get_client): + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_list_names.return_value = Mock(spec=UserNamesResponse, usernames=["alice", "bob", "charlie"]) + + result = list_user_names() + + assert result == ["alice", "bob", "charlie"] + mock_list_names.assert_called_once_with(client=mock_client) + + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") + def 
test_list_user_names_error_response(self, mock_list_names, mock_get_client): + mock_get_client.return_value = Mock() + mock_list_names.return_value = Mock(spec=ErrorResponse, message="Forbidden") + + with pytest.raises(RuntimeError, match="Failed to list usernames"): + list_user_names() + + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") + def test_list_user_names_none_response(self, mock_list_names, mock_get_client): + mock_get_client.return_value = Mock() + mock_list_names.return_value = None + + with pytest.raises(RuntimeError, match="no response from API"): + list_user_names() + + class TestCreateTenantAddMemberErrorAndException: - """Tests for create_tenant_and_assign_users error/exception paths.""" + """Tests for create_tenant_and_assign_users error/exception paths (lines 847, 851-856).""" @patch("berdl_notebook_utils.minio_governance.operations.time") @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") @@ -1162,6 +899,7 @@ def test_handles_exception_during_add_member( with caplog.at_level(logging.ERROR): result = create_tenant_and_assign_users("tenant1", ["user1", "user2"]) + # user1 failed with exception, user2 succeeded assert len(result["add_members"]) == 2 username1, resp1 = result["add_members"][0] assert username1 == "user1" @@ -1171,7 +909,7 @@ def test_handles_exception_during_add_member( class TestRequestTenantAccessJustification: - """Tests for request_tenant_access with justification.""" + """Tests for request_tenant_access with justification (line 930).""" @patch("berdl_notebook_utils.minio_governance.operations.httpx") @patch("berdl_notebook_utils.minio_governance.operations.get_settings") @@ -1192,13 +930,15 @@ def test_includes_justification_in_payload(self, mock_settings, mock_httpx): result = request_tenant_access("kbase", permission="read_write", justification="Need data for project X") assert 
result["status"] == "pending" + assert result["permission"] == "read_write" + # Verify justification was in the payload call_kwargs = mock_httpx.post.call_args payload = call_kwargs[1]["json"] assert payload["justification"] == "Need data for project X" class TestRequestTenantAccessConnectionError: - """Tests for request_tenant_access RequestError.""" + """Tests for request_tenant_access RequestError (lines 955-956).""" @patch("berdl_notebook_utils.minio_governance.operations.httpx") @patch("berdl_notebook_utils.minio_governance.operations.get_settings") @@ -1215,7 +955,7 @@ def test_request_tenant_access_connection_error(self, mock_settings, mock_httpx) class TestRegeneratePolicies: - """Tests for regenerate_policies function.""" + """Tests for regenerate_policies function (lines 978-987).""" @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") @patch( @@ -1273,9 +1013,62 @@ def test_remove_group_member_read_only(self, mock_remove_member, mock_get_client assert calls[0][1]["group_name"] == "kbasero" -# ============================================================================= +# --------------------------------------------------------------------------- # Polaris credential caching tests -# ============================================================================= +# --------------------------------------------------------------------------- + + +class TestFetchWithFileCache: + """Tests for _fetch_with_file_cache helper.""" + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_returns_cached_value_on_cache_hit(self, mock_fcntl, tmp_path): + """Test returns cached value without calling fetch when cache hits.""" + cache_path = tmp_path / "creds.json" + sentinel = {"key": "cached_value"} + + read_cache = Mock(return_value=sentinel) + fetch = Mock() + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result == sentinel + 
read_cache.assert_called_once_with(cache_path) + fetch.assert_not_called() + write_cache.assert_not_called() + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_fetches_and_writes_cache_on_cache_miss(self, mock_fcntl, tmp_path): + """Test fetches fresh data and writes cache when cache misses.""" + cache_path = tmp_path / "creds.json" + sentinel = {"key": "fresh_value"} + + read_cache = Mock(return_value=None) + fetch = Mock(return_value=sentinel) + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result == sentinel + read_cache.assert_called_once_with(cache_path) + fetch.assert_called_once() + write_cache.assert_called_once_with(cache_path, sentinel) + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_returns_none_without_writing_when_fetch_fails(self, mock_fcntl, tmp_path): + """Test returns None and does not write cache when fetch returns None.""" + cache_path = tmp_path / "creds.json" + + read_cache = Mock(return_value=None) + fetch = Mock(return_value=None) + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result is None + fetch.assert_called_once() + write_cache.assert_not_called() class TestGetPolarisCachePath: diff --git a/notebook_utils/tests/test_refresh.py b/notebook_utils/tests/test_refresh.py index 3660e82..ead4ad1 100644 --- a/notebook_utils/tests/test_refresh.py +++ b/notebook_utils/tests/test_refresh.py @@ -188,8 +188,8 @@ def test_cache_files_removed_first( refresh_spark_environment() - # _remove_cache_file called twice (minio + polaris) before rotate_minio_credentials - assert mock_remove.call_count == 2 + # _remove_cache_file called once (polaris only — MinIO uses direct API calls) + assert mock_remove.call_count == 1 # Settings cache cleared before credential fetches mock_settings.cache_clear.assert_called()