diff --git a/configs/ipython_startup/00-notebookutils.py b/configs/ipython_startup/00-notebookutils.py index ead8e79..aa2c89e 100644 --- a/configs/ipython_startup/00-notebookutils.py +++ b/configs/ipython_startup/00-notebookutils.py @@ -103,6 +103,7 @@ create_tenant_and_assign_users, get_group_sql_warehouse, get_minio_credentials, + get_polaris_credentials, get_my_accessible_paths, get_my_groups, get_my_policies, diff --git a/configs/ipython_startup/01-credentials.py b/configs/ipython_startup/01-credentials.py new file mode 100644 index 0000000..2ce303f --- /dev/null +++ b/configs/ipython_startup/01-credentials.py @@ -0,0 +1,51 @@ +""" +Initialize MinIO and Polaris credentials. + +This runs after 00-notebookutils.py loads all the imports, so get_minio_credentials +is already available in the global namespace. +""" + +# Setup logging +import logging + +logger = logging.getLogger("berdl.startup") +logger.setLevel(logging.INFO) +if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + +# --- MinIO Credentials --- +try: + # Set MinIO credentials to environment - also creates user if they don't exist + credentials = get_minio_credentials() # noqa: F821 + logger.info(f"✅ MinIO credentials set for user: {credentials.username}") + +except Exception as e: + import warnings + + warnings.warn(f"Failed to set MinIO credentials: {str(e)}", UserWarning) + logger.error(f"❌ Failed to set MinIO credentials: {str(e)}") + credentials = None + +# --- Polaris Credentials --- +try: + polaris_creds = get_polaris_credentials() # noqa: F821 + if polaris_creds: + logger.info(f"✅ Polaris credentials set for catalog: {polaris_creds['personal_catalog']}") + if polaris_creds["tenant_catalogs"]: + logger.info(f" Tenant catalogs: {', '.join(polaris_creds['tenant_catalogs'])}") + # Clear the settings cache so downstream code (e.g., Spark Connect server startup) + # picks up the 
POLARIS_CREDENTIAL, POLARIS_PERSONAL_CATALOG, and + # POLARIS_TENANT_CATALOGS env vars that get_polaris_credentials() just set. + get_settings.cache_clear() # noqa: F821 + else: + logger.info("ℹ️ Polaris not configured, skipping Polaris credential setup") + +except Exception as e: + import warnings + + warnings.warn(f"Failed to set Polaris credentials: {str(e)}", UserWarning) + logger.warning(f"⚠️ Failed to set Polaris credentials: {str(e)}") + polaris_creds = None diff --git a/configs/ipython_startup/01-minio-credentials.py b/configs/ipython_startup/01-minio-credentials.py deleted file mode 100644 index 379cad7..0000000 --- a/configs/ipython_startup/01-minio-credentials.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -Initialize MinIO credentials and basic MinIO client. - -This runs after 00-notebookutils.py loads all the imports, so get_minio_credentials -is already available in the global namespace. -""" - -# Setup logging -import logging - -logger = logging.getLogger("berdl.startup") -logger.setLevel(logging.INFO) -if not logger.handlers: - handler = logging.StreamHandler() - formatter = logging.Formatter("%(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) - -try: - # Set MinIO credentials to environment - also creates user if they don't exist - credentials = get_minio_credentials() # noqa: F821 - logger.info(f"✅ MinIO credentials set for user: {credentials.username}") - -except Exception as e: - import warnings - - warnings.warn(f"Failed to set MinIO credentials: {str(e)}", UserWarning) - logger.error(f"❌ Failed to set MinIO credentials: {str(e)}") - credentials = None diff --git a/configs/jupyter_server_config.py b/configs/jupyter_server_config.py index 5de2403..b8fe90d 100644 --- a/configs/jupyter_server_config.py +++ b/configs/jupyter_server_config.py @@ -6,6 +6,7 @@ # Note: __file__ may not be defined when exec'd by traitlets config loader sys.path.insert(0, "/etc/jupyter") +from berdl_notebook_utils.berdl_settings import get_settings from 
hybridcontents import HybridContentsManager from jupyter_server.services.contents.largefilemanager import LargeFileManager from grouped_s3_contents import GroupedS3ContentsManager @@ -134,6 +135,27 @@ def get_user_governance_paths(): return sources +def provision_polaris(): + """Provision Polaris credentials at server startup and set env vars. + + Called once at Jupyter Server startup so credentials are available + before any notebook kernel opens. Subsequent calls from IPython startup + scripts will hit the file cache and return immediately. + """ + try: + from berdl_notebook_utils.minio_governance import get_polaris_credentials + + polaris_creds = get_polaris_credentials() + if polaris_creds: + logger.info(f"Polaris credentials provisioned for catalog: {polaris_creds['personal_catalog']}") + if polaris_creds["tenant_catalogs"]: + logger.info(f" Tenant catalogs: {', '.join(polaris_creds['tenant_catalogs'])}") + else: + logger.info("Polaris not configured, skipping Polaris credential provisioning") + except Exception as e: + logger.error(f"Failed to provision Polaris credentials: {e}") + + def start_spark_connect(): """Start Spark Connect server at Jupyter Server startup. @@ -165,10 +187,20 @@ def _start(): endpoint_url, access_key, secret_key, use_ssl = get_minio_config() governance_paths = get_user_governance_paths() -# 3. Start Spark Connect server in background (non-blocking) +# 3. Provision Polaris credentials — MUST be before Spark Connect so that +# POLARIS_CREDENTIAL env vars are set when generating spark-defaults.conf +provision_polaris() + +# Clear the settings cache so start_spark_connect picks up the new +# POLARIS_CREDENTIAL/POLARIS_PERSONAL_CATALOG/POLARIS_TENANT_CATALOGS env vars +# that provision_polaris() just set. Without this, the lru_cache returns the +# stale settings object captured before Polaris provisioning ran. +get_settings.cache_clear() + +# 4. Start Spark Connect server in background (non-blocking) start_spark_connect() -# 4. 
Configure HybridContentsManager +# 5. Configure HybridContentsManager # - Root ("") -> Local filesystem # - "datalake_minio" -> GroupedS3ContentsManager with all S3 paths as subdirectories c.HybridContentsManager.manager_classes = { diff --git a/configs/spark-defaults.conf.template b/configs/spark-defaults.conf.template index 6801757..8d86261 100644 --- a/configs/spark-defaults.conf.template +++ b/configs/spark-defaults.conf.template @@ -43,13 +43,14 @@ # ============================================================================== # ------------------------------------------------------------------------------ -# Delta Lake Configuration (STATIC - Server-Side Only) +# SQL Extensions and Catalog Configuration (STATIC - Server-Side Only) # ------------------------------------------------------------------------------ # These SQL extensions must be loaded when the Spark server starts. -# They initialize Delta Lake support by registering custom SparkSessionExtensions -# and catalog implementations that handle Delta table operations. +# Delta Lake + Iceberg + Sedona run side-by-side. +# The default catalog remains spark_catalog (Delta/Hive) for backward compatibility. +# Iceberg catalogs (my, tenant aliases) are added dynamically by connect_server.py. 
-spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension,org.apache.sedona.sql.SedonaSqlExtensions +spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension,org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,org.apache.sedona.sql.SedonaSqlExtensions spark.sql.catalog.spark_catalog=org.apache.spark.sql.delta.catalog.DeltaCatalog # Delta Lake settings @@ -90,6 +91,9 @@ spark.hadoop.fs.s3a.path.style.access=true spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem spark.hadoop.fs.s3a.connection.ssl.enabled=false +# Polaris Iceberg catalog configuration will be appended dynamically by connect_server.py +# based on POLARIS_* environment variables (personal catalog "my" + tenant catalogs) + # ------------------------------------------------------------------------------ # KBase Authentication Interceptor (STATIC - Server-Side Only) # ------------------------------------------------------------------------------ @@ -116,7 +120,7 @@ spark.hadoop.fs.s3a.connection.ssl.enabled=false # Environment variables used by the namespace interceptor: # - BERDL_ALLOWED_NAMESPACE_PREFIXES: Comma-separated allowed prefixes # (e.g., "u_tgu2__,kbase_,research_"). Set dynamically by connect_server.py. 
-spark.connect.grpc.interceptor.classes=us.kbase.spark.KBaseAuthServerInterceptor,us.kbase.spark.NamespaceValidationInterceptor +spark.connect.grpc.interceptor.classes=us.kbase.spark.KBaseAuthServerInterceptor # ------------------------------------------------------------------------------ # Session Timeout (STATIC - Server-Side Only) diff --git a/docker-compose.yaml b/docker-compose.yaml index 2b80564..19e0bb6 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -68,7 +68,9 @@ # - MinIO: minio/minio123 # - PostgreSQL: hive/hivepassword services: - spark-notebook: + # Service names use the pattern: spark-notebook-{CI_KBASE_USERNAME} + # Update these keys if you change the usernames in .env + spark-notebook-tgu2: # image: ghcr.io/berdatalakehouse/spark_notebook:main # platform: linux/amd64 build: @@ -85,7 +87,7 @@ services: - CDM_TASK_SERVICE_URL=http://localhost:8080 - SPARK_CLUSTER_MANAGER_API_URL=http://localhost:8000 - SPARK_MASTER_URL=spark://spark-master:7077 - - BERDL_POD_IP=spark-notebook + - BERDL_POD_IP=spark-notebook-${CI_KBASE_USERNAME} - BERDL_HIVE_METASTORE_URI=thrift://hive-metastore:9083 # MINIO CONFIGURATION @@ -98,6 +100,9 @@ services: # DATALAKE MCP SERVER CONFIGURATION - DATALAKE_MCP_SERVER_URL=http://datalake-mcp-server:8000/apis/mcp + # POLARIS CONFIGURATION (per-user credentials provisioned dynamically by 01-credentials.py) + - POLARIS_CATALOG_URI=http://polaris:8181/api/catalog + # TRINO CONFIGURATION - TRINO_HOST=trino - TRINO_PORT=8080 @@ -148,6 +153,9 @@ services: - KBASE_ADMIN_ROLES=CDM_JUPYTERHUB_ADMIN - KBASE_APPROVED_ROLES=BERDL_USER - REDIS_URL=redis://redis:6379 + # Polaris admin credentials (only the governance service needs root access) + - POLARIS_CATALOG_URI=http://polaris:8181/api/catalog + - POLARIS_CREDENTIAL=root:s3cr3t # Credential store (PostgreSQL) - MMS_DB_HOST=postgres - MMS_DB_PORT=5432 @@ -182,8 +190,16 @@ services: - KBASE_AUTH_URL=https://ci.kbase.us/services/auth/ - KBASE_REQUIRED_ROLES=BERDL_USER - 
MFA_EXEMPT_USERS=${CI_KBASE_USERNAME} + # POLARIS CONFIGURATION (per-user credentials provisioned dynamically) + - POLARIS_CATALOG_URI=http://polaris:8181/api/catalog - BERDL_REDIS_HOST=redis - BERDL_REDIS_PORT=6379 + volumes: + # Mount the shared /home directory to access all users' credentials + # This allows the MCP server to dynamically read any user's credentials + # from /home/{username}/.berdl_minio_credentials + # In K8s: mount the parent directory or use a shared volume + - users_home:/home:ro depends_on: - hive-metastore - minio @@ -298,7 +314,13 @@ services: volumes: - postgres_data:/var/lib/postgresql/data - ./scripts/init-postgres-readonly.sh:/docker-entrypoint-initdb.d/01-init-postgres-readonly.sh:ro - - ./scripts/init-mms-db.sh:/docker-entrypoint-initdb.d/02-init-mms-db.sh:ro + - ./scripts/init-polaris-db.sh:/docker-entrypoint-initdb.d/02-init-polaris-db.sh:ro + - ./scripts/init-mms-db.sh:/docker-entrypoint-initdb.d/03-init-mms-db.sh:ro + healthcheck: + test: ["CMD-SHELL", "pg_isready -U hive"] + interval: 5s + timeout: 2s + retries: 15 hive-metastore: # image: ghcr.io/berdatalakehouse/hive_metastore:main @@ -369,9 +391,73 @@ services: echo 'MinIO bucket creation complete.'; " + polaris-bootstrap: + image: apache/polaris-admin-tool:latest + environment: + - POLARIS_PERSISTENCE_TYPE=relational-jdbc + - QUARKUS_DATASOURCE_DB_KIND=postgresql + - QUARKUS_DATASOURCE_JDBC_URL=jdbc:postgresql://postgres:5432/polaris + - QUARKUS_DATASOURCE_USERNAME=hive + - QUARKUS_DATASOURCE_PASSWORD=hivepassword + # Bootstrap exits 3 if already bootstrapped (expected with persistent storage). + # Treat exit 3 as success so docker compose doesn't fail on subsequent runs. + entrypoint: ["sh", "-c"] + command: + - | + java -jar /deployments/polaris-admin-tool.jar bootstrap --realm=POLARIS --credential=POLARIS,root,s3cr3t + rc=$$? 
+ if [ $$rc -eq 3 ]; then + echo "Already bootstrapped — skipping (OK)" + exit 0 + fi + exit $$rc + depends_on: + postgres: + condition: service_healthy + + polaris: + image: apache/polaris:latest + ports: + - "8181:8181" + environment: + # Persistence — PostgreSQL instead of in-memory + - POLARIS_PERSISTENCE_TYPE=relational-jdbc + - QUARKUS_DATASOURCE_DB_KIND=postgresql + - QUARKUS_DATASOURCE_JDBC_URL=jdbc:postgresql://postgres:5432/polaris + - QUARKUS_DATASOURCE_USERNAME=hive + - QUARKUS_DATASOURCE_PASSWORD=hivepassword + # Realm configuration + - POLARIS_REALM_NAME=default-realm + # MinIO credentials for Polaris's own S3 access (metadata files). + # Polaris reads endpointInternal + pathStyleAccess from each catalog's storageConfigInfo. + # STS is disabled per-catalog via stsUnavailable:true (not the global SKIP_CREDENTIAL flag). + - AWS_REGION=us-east-1 + - AWS_ACCESS_KEY_ID=minio + - AWS_SECRET_ACCESS_KEY=minio123 + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8182/q/health"] + interval: 10s + timeout: 5s + retries: 5 + depends_on: + polaris-bootstrap: + condition: service_completed_successfully + + polaris-ui: + image: ghcr.io/binarycat0/apache-polaris-ui:latest + ports: + - "3000:3000" + environment: + # Server-side env vars used by the Next.js API routes for proxying auth + - POLARIS_MANAGEMENT_API_URL=http://polaris:8181/api/management/v1 + - POLARIS_CATALOG_API_URL=http://polaris:8181/api/catalog/v1 + depends_on: + polaris: + condition: service_started + volumes: postgres_data: minio_data: redis_data: global_share: - users_home: # Shared volume for all user home directories \ No newline at end of file + users_home: # Shared volume for all user home directories diff --git a/docs/data_sharing_guide.md b/docs/data_sharing_guide.md index dbac753..433b5f3 100644 --- a/docs/data_sharing_guide.md +++ b/docs/data_sharing_guide.md @@ -45,11 +45,13 @@ All BERDL JupyterHub notebooks automatically import these data governance functi **Pre-Initialized 
Client:** - `governance` - Pre-initialized `DataGovernanceClient()` instance for advanced operations - **Other Auto-Imported Functions:** -- `get_spark_session()` - Create Spark sessions with Delta Lake support +- `get_spark_session()` - Create Spark sessions with Iceberg + Delta Lake support +- `create_namespace_if_not_exists()` - Create namespaces (use `iceberg=True` for Iceberg catalogs) - Plus many other utility functions for data operations +> **Note:** With the migration to Iceberg, **tenant catalogs** are the recommended way to share data. Create tables in a tenant catalog (e.g., `kbase`) and all members can access them. See the [Iceberg Migration Guide](iceberg_migration_guide.md) for details. + ### Quick Start ```python @@ -258,7 +260,7 @@ if response.errors: ## Public and Private Table Access (DEPRECATED) -> **⚠️ DEPRECATION WARNING**: Direct public path sharing functions (`make_table_public`, `make_table_private`) are deprecated. Please create a namespace under the `globalusers` tenant for public sharing activities instead. +> **⚠️ DEPRECATION WARNING**: Direct public path sharing functions (`make_table_public`, `make_table_private`) are deprecated. Please create a namespace under the `kbase` tenant for public sharing activities instead. ### Make Tables Publicly Accessible diff --git a/docs/iceberg_migration_guide.md b/docs/iceberg_migration_guide.md new file mode 100644 index 0000000..b0d7a88 --- /dev/null +++ b/docs/iceberg_migration_guide.md @@ -0,0 +1,426 @@ +# Polaris Catalog Migration Guide + +## Why We Migrated + +KBERDL previously used **Delta Lake + Hive Metastore** with namespace isolation enforced by naming conventions — every database had to be prefixed with `u_{username}__` (personal) or `{tenant}_` (shared). 
This worked but had several limitations: + +- **Naming conventions are fragile** — isolation depends on every user following prefix rules correctly +- **No catalog-level boundaries** — all users share the same Hive Metastore, so a misconfigured namespace could leak data +- **Single-engine lock-in** — Delta Lake tables are only accessible through Spark with the Delta extension +- **No time travel or schema evolution** — Delta supports these, but Hive Metastore doesn't track them natively + +We migrated to **Apache Polaris + Apache Iceberg**, which provides: + +- **Catalog-level isolation** — each user gets their own Polaris catalog (`my`), and each tenant gets a shared catalog (e.g., `kbase`). No naming prefixes needed. +- **Multi-engine support** — Iceberg tables can be read by Spark, Trino, DuckDB, PyIceberg, and other engines +- **Native time travel** — query any previous snapshot of your data +- **Schema evolution** — add, rename, or drop columns without rewriting data +- **ACID transactions** — concurrent reads and writes are safe + +## What Changed + +| Aspect | Before (Delta/Hive) | After (Polaris/Iceberg) | +|--------|---------------------|------------------------| +| **Metadata catalog** | Hive Metastore | Apache Polaris (Iceberg REST catalog) | +| **Table format** | Delta Lake | Apache Iceberg | +| **Isolation model** | Naming prefixes (`u_user__`, `tenant_`) | Separate catalogs per user/tenant | +| **Your personal catalog** | Hive (shared, prefix-isolated) | `my` (dedicated Polaris catalog) | +| **Tenant catalogs** | Hive (shared, prefix-isolated) | One catalog per tenant (e.g., `kbase`) | +| **Credentials** | MinIO S3 keys only | MinIO S3 keys + Polaris OAuth2 (auto-provisioned) | +| **Spark session** | `get_spark_session()` | `get_spark_session()` (unchanged) | + +> **Migration complete:** All existing Delta Lake tables have been migrated to Iceberg in Polaris. 
Your data is available in the new Iceberg catalogs (`my` for personal, tenant name for shared). The original Delta tables remain accessible during the dual-read period for backward compatibility, but all new tables should be created in Iceberg. + +## Side-by-Side Comparison + +### Create a Namespace + +The function call is **unchanged** — only the return value format differs. + + + + + + + +
Before (Delta/Hive)After (Iceberg/Polaris)
+ +```python +spark = get_spark_session() + +# Personal namespace +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# Returns: "u_tgu2__analysis" + +# Tenant namespace +ns = create_namespace_if_not_exists( + spark, "research", + tenant_name="kbase" +) +# Returns: "kbase_research" +``` + + + +```python +spark = get_spark_session() + +# Personal namespace (same call) +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# Returns: "my.analysis" + +# Tenant namespace (same call) +ns = create_namespace_if_not_exists( + spark, "research", + tenant_name="kbase" +) +# Returns: "kbase.research" +``` + +
+ +### Write a Table + + + + + + + +
Before (Delta/Hive)After (Iceberg/Polaris)
+ +```python +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# ns = "u_tgu2__analysis" + +df = spark.createDataFrame(data, columns) + +# Delta format +df.write.format("delta").saveAsTable( + f"{ns}.my_table" +) +``` + + + +```python +ns = create_namespace_if_not_exists( + spark, "analysis" +) +# ns = "my.analysis" + +df = spark.createDataFrame(data, columns) + +# Iceberg format (via writeTo API) +df.writeTo(f"{ns}.my_table").createOrReplace() + +# Or append to existing table +df.writeTo(f"{ns}.my_table").append() +``` + +
+ +### Read a Table + + + + + + + +
Before (Delta/Hive)After (Iceberg/Polaris)
+ +```python +# Query with prefixed namespace +df = spark.sql(""" + SELECT * FROM u_tgu2__analysis.my_table +""") + +# Or use the variable +df = spark.sql( + f"SELECT * FROM {ns}.my_table" +) +``` + + + +```python +# Query with catalog.namespace +df = spark.sql(""" + SELECT * FROM my.analysis.my_table +""") + +# Or use the variable +df = spark.sql( + f"SELECT * FROM {ns}.my_table" +) +``` + +
+ +### Cross-Catalog Queries + + + + + + + +
Before (Delta/Hive)After (Iceberg/Polaris)
+ +```python +# Everything in one Hive catalog +# Must know the naming convention +spark.sql(""" + SELECT u.name, d.dept_name + FROM u_tgu2__analysis.users u + JOIN kbase_shared.depts d + ON u.dept_id = d.id +""") +``` + + + +```python +# Explicit catalog boundaries +spark.sql(""" + SELECT u.name, d.dept_name + FROM my.analysis.users u + JOIN kbase.shared.depts d + ON u.dept_id = d.id +""") +``` + +
+ +### List Namespaces and Tables + + + + + + + +
Before (Delta/Hive)After (Iceberg/Polaris)
+ +```python +# Lists all Hive databases +# (filtered by u_{user}__ prefix) +list_namespaces(spark) + +# List tables in a namespace +list_tables(spark, "u_tgu2__analysis") +``` + + + +```python +# List namespaces in your catalog +spark.sql("SHOW NAMESPACES IN my") + +# List tables in a namespace +spark.sql( + "SHOW TABLES IN my.analysis" +) + +# list_tables still works +list_tables(spark, "my.analysis") +``` + +
+ +### Drop a Table + + + + + + + +
Before (Delta/Hive)After (Iceberg/Polaris)
+ +```python +spark.sql( + "DROP TABLE IF EXISTS " + "u_tgu2__analysis.my_table" +) +``` + + + +```python +spark.sql( + "DROP TABLE IF EXISTS " + "my.analysis.my_table" +) +``` + +> **Note:** `DROP TABLE` removes the catalog entry but does **not** delete the underlying S3 data files. `DROP TABLE ... PURGE` also does not delete files due to a [known Iceberg bug](https://github.com/apache/iceberg/issues/14743). To fully remove data, delete files from S3 directly using `get_minio_client()`. + +
+ +## Iceberg-Only Features + +These features are only available with Iceberg tables. + +### Time Travel + +Query a previous version of your table: + +```python +# View snapshot history +spark.sql("SELECT * FROM my.analysis.my_table.snapshots") + +# Read data as it was at a specific snapshot +spark.sql(""" + SELECT * FROM my.analysis.my_table + VERSION AS OF 1234567890 +""") + +# Read data as it was at a specific timestamp +spark.sql(""" + SELECT * FROM my.analysis.my_table + TIMESTAMP AS OF '2026-03-01 12:00:00' +""") +``` + +### Schema Evolution + +Modify table schema without rewriting data: + +```python +# Add a column +spark.sql("ALTER TABLE my.analysis.my_table ADD COLUMN email STRING") + +# Rename a column +spark.sql("ALTER TABLE my.analysis.my_table RENAME COLUMN name TO full_name") + +# Drop a column +spark.sql("ALTER TABLE my.analysis.my_table DROP COLUMN temp_field") +``` + +### Snapshot Management + +```python +# View snapshot history +display_df(spark.sql("SELECT * FROM my.analysis.my_table.snapshots")) + +# View file-level details +display_df(spark.sql("SELECT * FROM my.analysis.my_table.files")) + +# View table history +display_df(spark.sql("SELECT * FROM my.analysis.my_table.history")) +``` + +## Complete Example + +```python +# 1. Create a Spark session +print("1. Creating Spark session...") +spark = get_spark_session("MyAnalysis") +print(" Spark session ready.") + +# 2. Create a personal namespace +print("\n2. Creating personal namespace...") +ns = create_namespace_if_not_exists(spark, "demo") +print(f" Namespace: {ns}") # "my.demo" + +# 3. Create a table +print(f"\n3. Creating table {ns}.employees...") +data = [(1, "Alice", 25), (2, "Bob", 30), (3, "Charlie", 35)] +df = spark.createDataFrame(data, ["id", "name", "age"]) +df.writeTo(f"{ns}.employees").createOrReplace() +print(f" Table {ns}.employees created with {df.count()} rows.") + +# 4. Query the table +print(f"\n4. 
Querying {ns}.employees:") +result = spark.sql(f"SELECT * FROM {ns}.employees") +display_df(result) + +# 5. Append more data +print(f"\n5. Appending data to {ns}.employees...") +new_data = [(4, "Diana", 28)] +new_df = spark.createDataFrame(new_data, ["id", "name", "age"]) +new_df.writeTo(f"{ns}.employees").append() +print(f" Appended {new_df.count()} row(s). Total: {spark.sql(f'SELECT * FROM {ns}.employees').count()} rows.") +display_df(spark.sql(f"SELECT * FROM {ns}.employees")) + +# 6. View snapshots and files (two snapshots now: create + append) +print(f"\n6. Viewing snapshots and files for {ns}.employees:") +print(" Snapshots:") +display_df(spark.sql(f"SELECT * FROM {ns}.employees.snapshots")) +print(" Data files:") +display_df(spark.sql(f"SELECT * FROM {ns}.employees.files")) + +# 7. Time travel to the original version +print(f"\n7. Time travel to original version (before append)...") +first_snapshot = spark.sql( + f"SELECT snapshot_id FROM {ns}.employees.snapshots " + f"ORDER BY committed_at LIMIT 1" +).collect()[0]["snapshot_id"] +print(f" First snapshot ID: {first_snapshot}") +original = spark.sql( + f"SELECT * FROM {ns}.employees VERSION AS OF {first_snapshot}" +) +print(f" Original version ({original.count()} rows, before append):") +display_df(original) + +# 8. Tenant namespace (shared with your team) +print("\n8. Creating tenant namespace and shared table...") +tenant_ns = create_namespace_if_not_exists( + spark, "shared_data", tenant_name="globalusers" +) +print(f" Tenant namespace: {tenant_ns}") +df.writeTo(f"{tenant_ns}.team_employees").createOrReplace() +print(f" Table {tenant_ns}.team_employees created.") + +# 9. Cross-catalog query +print(f"\n9. Cross-catalog query ({ns} + {tenant_ns}):") +cross_result = spark.sql(f""" + SELECT * FROM {ns}.employees + UNION ALL + SELECT * FROM {tenant_ns}.team_employees +""") +display_df(cross_result) + +# 10. Schema evolution +print(f"\n10. 
Adding 'email' column to {ns}.employees...") +spark.sql(f"ALTER TABLE {ns}.employees ADD COLUMN email STRING") +print(f" Updated schema:") +spark.sql(f"DESCRIBE {ns}.employees").show() +display_df(spark.sql(f"SELECT * FROM {ns}.employees")) +print("Done!") +``` + +## FAQ + +**Q: Do I need to change my `get_spark_session()` call?** +No. `get_spark_session()` automatically configures both Delta and Iceberg catalogs. Your Polaris catalogs are ready to use. + +**Q: Can I still access my old Delta tables?** +Yes. During the dual-read period, your Delta tables remain accessible at their original names (e.g., `u_{username}__analysis.my_table`). Iceberg copies are at `my.analysis.my_table`. + +**Q: What happened to my namespace prefixes (`u_{username}__`)?** +They're no longer needed. With Iceberg, your personal catalog `my` is isolated at the catalog level — only you can access it. No prefix is required. + +**Q: How do I share data with my team?** +Create a table in a tenant catalog: +```python +ns = create_namespace_if_not_exists( + spark, "shared_data", tenant_name="kbase" +) +df.writeTo(f"{ns}.my_shared_table").createOrReplace() +``` +All members of the `kbase` tenant can read this table. + +**Q: Why does `DROP TABLE PURGE` leave files on S3?** +This is a [known Iceberg bug](https://github.com/apache/iceberg/issues/14743) — Spark's `SparkCatalog` ignores the `PURGE` flag when talking to REST catalogs. `DROP TABLE` only removes the catalog entry. To delete the S3 files, use the MinIO client directly. + +**Q: Can I use `df.write.format("iceberg").saveAsTable(...)` instead of `writeTo`?** +Yes, both work. `writeTo` is the recommended Iceberg API since it supports `createOrReplace()` and `append()` natively. 
diff --git a/docs/tenant_sql_warehouse_guide.md b/docs/tenant_sql_warehouse_guide.md index 6c5b4ed..6d9015d 100644 --- a/docs/tenant_sql_warehouse_guide.md +++ b/docs/tenant_sql_warehouse_guide.md @@ -6,152 +6,114 @@ This guide explains how to configure your Spark session to write tables to eithe The BERDL JupyterHub environment supports two types of SQL warehouses: -1. **User SQL Warehouse**: Your personal workspace for tables -2. **Tenant SQL Warehouse**: Shared workspace for team/organization tables +1. **User SQL Warehouse**: Your personal workspace for tables (Iceberg catalog: `my`) +2. **Tenant SQL Warehouse**: Shared workspace for team/organization tables (Iceberg catalog per tenant, e.g., `kbase`) -## Spark Connect Architecture - -BERDL uses **Spark Connect**, which provides a client-server architecture: -- **Connection Protocol**: Uses gRPC for efficient communication with remote Spark clusters -- **Spark Connect Server**: Runs locally in your notebook pod as a proxy -- **Spark Connect URL**: `sc://localhost:15002` -- **Driver and Executors**: Runs locally in your notebook pod - -### Automatic Credential Management - -Your MinIO S3 credentials are automatically configured when your notebook server starts: - -1. **JupyterHub** calls the governance API to create/retrieve your MinIO credentials -2. **Environment Variables** are set in your notebook: `MINIO_ACCESS_KEY`, `MINIO_SECRET_KEY`, `MINIO_USERNAME` -3. **Spark Cluster** receives your credentials and configures S3 access for Hive Metastore operations -4. **Notebooks** use credentials from environment (no API call needed) - -This ensures secure, per-user S3 access without exposing credentials in code. +Isolation is provided at the **catalog level** — no naming prefixes needed. 
-## Quick Comparison +| Configuration | Catalog | Example Namespace | +|--------------|---------|-------------------| +| `create_namespace_if_not_exists(spark, "analysis")` | `my` (personal) | `my.analysis` | +| `create_namespace_if_not_exists(spark, "research", tenant_name="kbase")` | `kbase` (tenant) | `kbase.research` | -| Configuration | Warehouse Location | Tables Location | Default Namespace | -|--------------|-------------------|----------------|-------------------| -| `create_namespace_if_not_exists(spark)` | `s3a://cdm-lake/users-sql-warehouse/{username}/` | Personal workspace | `u_{username}__default` | -| `create_namespace_if_not_exists(spark, tenant_name="kbase")` | `s3a://cdm-lake/tenant-sql-warehouse/kbase/` | Tenant workspace | `kbase_default` | +## Personal SQL Warehouse -## Personal SQL Warehouse (Default) +Your personal catalog (`my`) is automatically provisioned and only accessible by you. ### Usage -```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists - -# Create Spark session -spark = get_spark_session("MyPersonalApp") - -# Create namespace in your personal SQL warehouse -namespace = create_namespace_if_not_exists(spark) -``` -### Where Your Tables Go -- **Warehouse Directory**: `s3a://cdm-lake/users-sql-warehouse/{username}/` -- **Default Namespace**: `u_{username}__default` (username prefix added automatically) -- **Table Location**: `s3a://cdm-lake/users-sql-warehouse/{username}/u_{username}__default.db/your_table/` - -### Example ```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists - -# Create Spark session -spark = get_spark_session("MyPersonalApp") +spark = get_spark_session("MyApp") -# Create default namespace (creates "u_{username}__default" with username prefix) -namespace = create_namespace_if_not_exists(spark) -print(f"Created namespace: {namespace}") # 
Output: "Created namespace: u_{username}__default" +# Create namespace in your personal Iceberg catalog +ns = create_namespace_if_not_exists(spark, "analysis") +# Returns: "my.analysis" -# Or create custom namespace (creates "u_{username}__analysis" with username prefix) -analysis_namespace = create_namespace_if_not_exists(spark, namespace="analysis") -print(f"Created namespace: {analysis_namespace}") # Output: "Created namespace: u_{username}__analysis" - -# Create a DataFrame and save as table using returned namespace +# Write a table df = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"]) -df.write.format("delta").saveAsTable(f"{namespace}.my_personal_table") +df.writeTo(f"{ns}.my_table").createOrReplace() -# Table will be stored at: -# s3a://cdm-lake/users-sql-warehouse/{username}/u_{username}__default.db/my_personal_table/ +# Read it back +spark.sql(f"SELECT * FROM {ns}.my_table").show() ``` +### Where Your Tables Go + +- **Catalog**: `my` (dedicated Polaris catalog) +- **Namespace**: User-defined (e.g., `analysis`, `experiments`) +- **Full table path**: `my.{namespace}.{table_name}` +- **S3 location**: Managed automatically by Polaris + ## Tenant SQL Warehouse +Tenant catalogs are shared across all members of a tenant/group. 
+ ### Usage + ```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists +spark = get_spark_session("TeamApp") + +# Create namespace in tenant Iceberg catalog +ns = create_namespace_if_not_exists( + spark, "research", tenant_name="kbase" +) +# Returns: "kbase.research" -# Create Spark session -spark = get_spark_session("TeamAnalysis") +# Write a shared table +df.writeTo(f"{ns}.shared_dataset").createOrReplace() -# Create namespace in tenant SQL warehouse (validates membership) -namespace = create_namespace_if_not_exists(spark, tenant_name="kbase") +# All kbase members can read this table +spark.sql(f"SELECT * FROM {ns}.shared_dataset").show() ``` ### Where Your Tables Go -- **Warehouse Directory**: `s3a://cdm-lake/tenant-sql-warehouse/{tenant}/` -- **Default Namespace**: `{tenant}_default` (tenant prefix added automatically) -- **Table Location**: `s3a://cdm-lake/tenant-sql-warehouse/{tenant}/{tenant}_default.db/your_table/` + +- **Catalog**: Tenant name (e.g., `kbase`) +- **Namespace**: User-defined (e.g., `research`, `shared_data`) +- **Full table path**: `{tenant}.{namespace}.{table_name}` +- **S3 location**: Managed automatically by Polaris ### Requirements + - You must be a member of the specified tenant/group - The system validates your membership before allowing access -### Example -```python -from berdl_notebook_utils import get_spark_session -from berdl_notebook_utils.spark.database import create_namespace_if_not_exists - -# Create Spark session -spark = get_spark_session("TeamAnalysis") - -# Create default namespace in tenant warehouse (creates "{tenant}_default" with tenant prefix) -namespace = create_namespace_if_not_exists(spark, tenant_name="kbase") -print(f"Created namespace: {namespace}") # Output: "Created namespace: kbase_default" +## Spark Connect Architecture -# Or create custom namespace (creates "{tenant}_research" with tenant prefix) -research_namespace = 
create_namespace_if_not_exists(spark, namespace="research", tenant_name="kbase") -print(f"Created namespace: {research_namespace}") # Output: "Created namespace: kbase_research" +BERDL uses **Spark Connect**, which provides a client-server architecture: +- **Connection Protocol**: Uses gRPC for efficient communication with remote Spark clusters +- **Spark Connect Server**: Runs locally in your notebook pod as a proxy +- **Spark Connect URL**: `sc://localhost:15002` +- **Driver and Executors**: Runs locally in your notebook pod -# Create a DataFrame and save as table using returned namespace -df = spark.createDataFrame([(1, "Dataset A"), (2, "Dataset B")], ["id", "dataset"]) -df.write.format("delta").saveAsTable(f"{namespace}.shared_analysis") +### Automatic Credential Management -# Table will be stored at: -# s3a://cdm-lake/tenant-sql-warehouse/kbase/kbase_default.db/shared_analysis/ -``` +Your credentials are automatically configured when your notebook server starts: -## Advanced Namespace Management +1. **JupyterHub** calls the governance API to create/retrieve your MinIO and Polaris credentials +2. **Environment Variables** are set in your notebook (MinIO S3 keys + Polaris OAuth2 tokens) +3. **Spark Session** is pre-configured with access to your personal catalog and tenant catalogs +4. **Notebooks** use credentials from environment (no API call needed) -### Custom Namespaces (Default Behavior) -```python -# Personal warehouse with custom namespace (prefix enabled by default) -spark = get_spark_session() -exp_namespace = create_namespace_if_not_exists(spark, "experiments") # Returns "u_{username}__experiments" +This ensures secure, per-user access without exposing credentials in code. 
-# Tenant warehouse with custom namespace (prefix enabled by default) -data_namespace = create_namespace_if_not_exists(spark, "research_data", tenant_name="kbase") # Returns "kbase_research_data" +## Cross-Catalog Queries -# Use returned namespace names for table operations -df.write.format("delta").saveAsTable(f"{exp_namespace}.my_experiment_table") -df.write.format("delta").saveAsTable(f"{data_namespace}.shared_dataset") +You can query across personal and tenant catalogs in a single query: -# Tables will be stored at: -# s3a://cdm-lake/users-sql-warehouse/{username}/u_{username}__experiments.db/my_experiment_table/ -# s3a://cdm-lake/tenant-sql-warehouse/kbase/kbase_research_data.db/shared_dataset/ +```python +spark.sql(""" + SELECT u.name, d.dept_name + FROM my.analysis.users u + JOIN kbase.shared.depts d + ON u.dept_id = d.id +""") ``` ## Tips -- **Always use the returned namespace**: `create_namespace_if_not_exists()` returns the actual namespace name (with prefixes applied). Always use this value when creating tables. -- **Permission issues with manual namespaces**: If you create namespaces manually (without using `create_namespace_if_not_exists()`) and the namespace doesn't follow the expected naming rules (e.g., missing the `u_{username}__` or `{tenant}_` prefix), you may not have the correct permissions to read/write to that namespace. The governance system enforces permissions based on namespace naming conventions: - - User namespaces must start with `u_{username}__` to grant you access - - Tenant namespaces must start with `{tenant}_` and you must be a member of that tenant - - Namespaces without proper prefixes will result in "Access Denied" errors from MinIO +- **Always use the returned namespace**: `create_namespace_if_not_exists()` returns the fully qualified namespace (e.g., `my.analysis`). Always use this value when creating or querying tables. - **Tenant membership required**: Attempting to access a tenant warehouse without membership will fail. 
-- **Credentials are automatic**: MinIO credentials are set by JupyterHub - you don't need to call any API to get them. +- **Credentials are automatic**: MinIO and Polaris credentials are set by JupyterHub — you don't need to call any API to get them. - **Spark Connect is default**: All sessions use Spark Connect for better stability and resource isolation. +- **Iceberg features**: Your tables support time travel, schema evolution, and snapshot management. See the [Iceberg Migration Guide](iceberg_migration_guide.md) for details. diff --git a/docs/user_guide.md b/docs/user_guide.md index 6e29941..0e848a2 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -52,8 +52,9 @@ Log in using your KBase credentials (the same 3rd party identity provider, usern ### 3. Data Access By default, you have **read/write access** to: -- Your personal SQL warehouse (`s3a://cdm-lake/users-sql-warehouse/{username}/`) -- Any tenant SQL warehouses that you belong to (e.g., `s3a://cdm-lake/tenant-sql-warehouse/kbase/`) +- Your personal Iceberg catalog (`my`) — create namespaces and tables here +- Any tenant Iceberg catalogs you belong to (e.g., `kbase`) — shared team data +- Your personal S3 storage (`s3a://cdm-lake/users-sql-warehouse/{username}/`) For questions about data access or permissions, please reach out to the BERDL Platform team. @@ -80,9 +81,11 @@ spark = get_spark_session() This automatically configures: - Spark Connect server connection -- Hive Metastore integration +- Apache Iceberg catalogs (personal `my` + tenant catalogs via Polaris) - MinIO S3 access -- Delta Lake support +- Delta Lake support (legacy, for backward compatibility) + +> **Note:** BERDL has migrated from Delta Lake to Apache Iceberg. See the [Iceberg Migration Guide](iceberg_migration_guide.md) for details on the new catalog structure and Iceberg-specific features like time travel and schema evolution. 
#### 5.2 Displaying DataFrames diff --git a/notebook_utils/berdl_notebook_utils/agent/tools.py b/notebook_utils/berdl_notebook_utils/agent/tools.py index 3d4bf4d..47f0cc1 100644 --- a/notebook_utils/berdl_notebook_utils/agent/tools.py +++ b/notebook_utils/berdl_notebook_utils/agent/tools.py @@ -82,7 +82,7 @@ def list_databases(dummy_input: str = "") -> str: JSON-formatted list of database names or error message """ try: - databases = mcp_ops.mcp_list_databases(use_hms=True) + databases = mcp_ops.mcp_list_databases() if not databases: return "No databases found. You may need to create a database first." return _format_result(databases) @@ -102,7 +102,7 @@ def list_tables(database: str) -> str: JSON-formatted list of table names or error message """ try: - tables = mcp_ops.mcp_list_tables(database=database, use_hms=True) + tables = mcp_ops.mcp_list_tables(database=database) if not tables: return f"No tables found in database '{database}'. The database may be empty." return _format_result(tables) @@ -143,7 +143,7 @@ def get_database_structure(input_str: str = "false") -> str: try: # Parse boolean input with_schema = input_str.lower() in ("true", "1", "yes") - structure = mcp_ops.mcp_get_database_structure(with_schema=with_schema, use_hms=True) + structure = mcp_ops.mcp_get_database_structure(with_schema=with_schema) return _format_result(structure) except Exception as e: logger.error(f"Error getting database structure: {e}") diff --git a/notebook_utils/berdl_notebook_utils/berdl_settings.py b/notebook_utils/berdl_notebook_utils/berdl_settings.py index 8accfc3..8f68136 100644 --- a/notebook_utils/berdl_notebook_utils/berdl_settings.py +++ b/notebook_utils/berdl_notebook_utils/berdl_settings.py @@ -43,7 +43,7 @@ class BERDLSettings(BaseSettings): ) # Hive configuration - BERDL_HIVE_METASTORE_URI: AnyUrl # Accepts thrift:// + BERDL_HIVE_METASTORE_URI: AnyUrl | None = Field(default=None, description="Hive metastore URI (thrift://...)") # Profile-specific Spark 
configuration from JupyterHub SPARK_WORKER_COUNT: int = Field(default=1, description="Number of Spark workers from profile") @@ -78,6 +78,18 @@ class BERDLSettings(BaseSettings): default=None, description="Tenant Access Request Service URL for Slack-based approval workflow" ) + # Polaris Iceberg Catalog configuration + POLARIS_CATALOG_URI: AnyHttpUrl | None = Field( + default=None, description="Polaris REST Catalog endpoint (e.g., http://polaris:8181/api/catalog)" + ) + POLARIS_CREDENTIAL: str | None = Field(default=None, description="Polaris client_id:client_secret credential") + POLARIS_PERSONAL_CATALOG: str | None = Field( + default=None, description="Polaris personal catalog name (e.g., user_tgu2)" + ) + POLARIS_TENANT_CATALOGS: str | None = Field( + default=None, description="Comma-separated Polaris tenant catalog names" + ) + def validate_environment(): """ diff --git a/notebook_utils/berdl_notebook_utils/mcp/operations.py b/notebook_utils/berdl_notebook_utils/mcp/operations.py index ae86844..9d828bf 100644 --- a/notebook_utils/berdl_notebook_utils/mcp/operations.py +++ b/notebook_utils/berdl_notebook_utils/mcp/operations.py @@ -62,18 +62,17 @@ def _handle_error_response(response: Any, operation: str) -> None: raise Exception(error_msg) -def mcp_list_databases(use_hms: bool = True) -> list[str]: +def mcp_list_databases() -> list[str]: """ - List all databases in the Hive metastore via MCP server. + List all databases (Iceberg catalog namespaces) via MCP server. This function connects to the global datalake-mcp-server, which will use your authentication token to connect to your personal Spark Connect server. - Args: - use_hms: If True, uses Hive Metastore client for faster retrieval (default: True) + Note: Only Iceberg catalogs are listed (spark_catalog is excluded). 
Returns: - List of database names + List of database names (e.g., ['my.demo_dataset', 'kbase.pangenome']) Raises: Exception: If the MCP server returns an error or is unreachable @@ -81,12 +80,12 @@ def mcp_list_databases(use_hms: bool = True) -> list[str]: Example: >>> databases = mcp_list_databases() >>> print(databases) - ['default', 'my_database', 'analytics'] + ['my.demo_dataset', 'kbase.analytics'] """ client = get_datalake_mcp_client() - request = DatabaseListRequest(use_hms=use_hms) + request = DatabaseListRequest() - logger.debug(f"Listing databases via MCP server (use_hms={use_hms})") + logger.debug("Listing databases via MCP server") response = list_databases.sync(client=client, body=request) _handle_error_response(response, "list_databases") @@ -98,13 +97,12 @@ def mcp_list_databases(use_hms: bool = True) -> list[str]: return response.databases -def mcp_list_tables(database: str, use_hms: bool = True) -> list[str]: +def mcp_list_tables(database: str) -> list[str]: """ List all tables in a specific database via MCP server. 
Args: - database: Name of the database - use_hms: If True, uses Hive Metastore client for faster retrieval (default: True) + database: Name of the database (e.g., 'my.demo_dataset') Returns: List of table names in the database @@ -113,12 +111,12 @@ def mcp_list_tables(database: str, use_hms: bool = True) -> list[str]: Exception: If the MCP server returns an error or is unreachable Example: - >>> tables = mcp_list_tables("my_database") + >>> tables = mcp_list_tables("my.demo_dataset") >>> print(tables) ['users', 'orders', 'products'] """ client = get_datalake_mcp_client() - request = TableListRequest(database=database, use_hms=use_hms) + request = TableListRequest(database=database) logger.debug(f"Listing tables in database '{database}' via MCP server") response = list_database_tables.sync(client=client, body=request) @@ -167,14 +165,13 @@ def mcp_get_table_schema(database: str, table: str) -> list[str]: def mcp_get_database_structure( - with_schema: bool = False, use_hms: bool = True + with_schema: bool = False, ) -> dict[str, list[str] | dict[str, list[str]]]: """ Get the complete structure of all databases via MCP server. 
Args: with_schema: If True, includes table schemas (column names) (default: False) - use_hms: If True, uses Hive Metastore client for faster retrieval (default: True) Returns: Dictionary mapping database names to either: @@ -188,15 +185,15 @@ def mcp_get_database_structure( >>> # Without schema >>> structure = mcp_get_database_structure() >>> print(structure) - {'default': ['table1', 'table2'], 'analytics': ['metrics', 'events']} + {'my.demo': ['table1', 'table2'], 'kbase.analytics': ['metrics', 'events']} >>> # With schema >>> structure = mcp_get_database_structure(with_schema=True) >>> print(structure) - {'default': {'table1': ['col1', 'col2'], 'table2': ['col3', 'col4']}} + {'my.demo': {'table1': ['col1', 'col2'], 'table2': ['col3', 'col4']}} """ client = get_datalake_mcp_client() - request = DatabaseStructureRequest(with_schema=with_schema, use_hms=use_hms) + request = DatabaseStructureRequest(with_schema=with_schema) logger.debug(f"Getting database structure via MCP server (with_schema={with_schema})") response = get_database_structure.sync(client=client, body=request) diff --git a/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py b/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py index 73ff6a0..d2da3e7 100644 --- a/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py +++ b/notebook_utils/berdl_notebook_utils/minio_governance/__init__.py @@ -10,6 +10,7 @@ check_governance_health, get_group_sql_warehouse, get_minio_credentials, + get_polaris_credentials, rotate_minio_credentials, get_my_accessible_paths, get_my_groups, @@ -21,6 +22,7 @@ add_group_member, create_tenant_and_assign_users, list_groups, + list_user_names, list_users, remove_group_member, # Table operations @@ -32,8 +34,8 @@ # Tenant access requests list_available_groups, request_tenant_access, - # Lightweight management queries (direct HTTP) - list_user_names, + # Migration (admin-only) + ensure_polaris_resources, regenerate_policies, ) from 
.tenant_management import ( @@ -55,6 +57,7 @@ "check_governance_health", "get_group_sql_warehouse", "get_minio_credentials", + "get_polaris_credentials", "rotate_minio_credentials", "get_my_accessible_paths", "get_my_groups", @@ -91,4 +94,7 @@ # Tenant access requests "list_available_groups", "request_tenant_access", + # Migration (admin-only) + "ensure_polaris_resources", + "regenerate_policies", ] diff --git a/notebook_utils/berdl_notebook_utils/minio_governance/operations.py b/notebook_utils/berdl_notebook_utils/minio_governance/operations.py index 5711918..e5d6e83 100644 --- a/notebook_utils/berdl_notebook_utils/minio_governance/operations.py +++ b/notebook_utils/berdl_notebook_utils/minio_governance/operations.py @@ -2,10 +2,16 @@ Utility functions for BERDL MinIO Data Governance integration """ +import fcntl +import json import logging import os import time import warnings +from pathlib import Path +from collections.abc import Callable +from typing import TypeVar + from typing import TypedDict import httpx @@ -18,14 +24,26 @@ add_group_member_management_groups_group_name_members_username_post, create_group_management_groups_group_name_post, list_groups_management_groups_get, - list_user_names_management_users_names_get, list_users_management_users_get, regenerate_all_policies_management_migrate_regenerate_policies_post, remove_group_member_management_groups_group_name_members_username_delete, ) + +# Polaris-specific imports — only available when governance client includes polaris endpoints +try: + from governance_client.api.polaris import provision_polaris_user_polaris_user_provision_username_post + from governance_client.api.management import ( + ensure_all_polaris_resources_management_migrate_ensure_polaris_resources_post, + ) +except ImportError: + provision_polaris_user_polaris_user_provision_username_post = None # type: ignore[assignment] + ensure_all_polaris_resources_management_migrate_ensure_polaris_resources_post = None # type: ignore[assignment] 
from governance_client.api.management.list_group_names_management_groups_names_get import ( sync as list_group_names_sync, ) +from governance_client.api.management.list_user_names_management_users_names_get import ( + sync as list_user_names_sync, +) from governance_client.api.sharing import ( get_path_access_info_sharing_get_path_access_info_post, make_path_private_sharing_make_private_post, @@ -87,10 +105,71 @@ class TenantCreationResult(TypedDict): SQL_WAREHOUSE_BUCKET = "cdm-lake" # TODO: change to berdl-lake SQL_USER_WAREHOUSE_PATH = "users-sql-warehouse" +# Credential caching configuration (Polaris only — MinIO uses direct API calls) +POLARIS_CREDENTIALS_CACHE_FILE = ".berdl_polaris_credentials" + + # ============================================================================= # HELPER FUNCTIONS # ============================================================================= +_T = TypeVar("_T") + + +def _fetch_with_file_cache( + cache_path: Path, + read_cache: Callable[[Path], _T | None], + fetch: Callable[[], _T | None], + write_cache: Callable[[Path, _T], None], +) -> _T | None: + """Fetch credentials using file-based caching with exclusive file locking. + + The lock is released when the file handle is closed (exiting the `with` block). + We intentionally do NOT delete the lock file afterward — another process + may have already acquired a lock on it between our unlock and unlink. 
+ """ + lock_path = cache_path.with_suffix(".lock") + with open(lock_path, "w") as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + + cached = read_cache(cache_path) + if cached is not None: + return cached + + result = fetch() + if result is not None: + write_cache(cache_path, result) + return result + + +def _get_polaris_cache_path() -> Path: + """Get the path to the Polaris credentials cache file in the user's home directory.""" + return Path.home() / POLARIS_CREDENTIALS_CACHE_FILE + + +def _read_cached_polaris_credentials(cache_path: Path) -> "PolarisCredentials | None": + """Read Polaris credentials from cache file. Returns None if file doesn't exist or is corrupted.""" + try: + if not cache_path.exists(): + return None + with open(cache_path, "r") as f: + data = json.load(f) + # Validate required keys are present + if all(k in data for k in ("client_id", "client_secret", "personal_catalog", "tenant_catalogs")): + return data + return None + except (json.JSONDecodeError, TypeError, KeyError, OSError): + return None + + +def _write_polaris_credentials_cache(cache_path: Path, credentials: "PolarisCredentials") -> None: + """Write Polaris credentials to cache file.""" + try: + with open(cache_path, "w") as f: + json.dump(credentials, f) + except (OSError, TypeError): + pass + def _build_table_path(username: str, namespace: str, table_name: str) -> str: """ @@ -159,6 +238,86 @@ def get_minio_credentials() -> CredentialsResponse: return api_response +class PolarisCredentials(TypedDict): + """Polaris credential provisioning result.""" + + client_id: str + client_secret: str + personal_catalog: str + tenant_catalogs: list[str] + + +def _fetch_polaris_credentials() -> PolarisCredentials | None: + """Fetch fresh Polaris credentials from the governance API.""" + settings = get_settings() + polaris_logger = logging.getLogger(__name__) + + if provision_polaris_user_polaris_user_provision_username_post is None: + polaris_logger.warning("Polaris API not 
available — governance client does not include polaris endpoints")
+        return None
+
+    client = get_governance_client()
+    api_response = provision_polaris_user_polaris_user_provision_username_post.sync(
+        username=settings.USER, client=client
+    )
+
+    if isinstance(api_response, ErrorResponse):
+        polaris_logger.warning(f"Polaris provisioning failed: {api_response.message}")
+        return None
+    if api_response is None:
+        polaris_logger.warning("Polaris provisioning returned no response")
+        return None
+
+    data = api_response.to_dict()
+    return {
+        "client_id": data.get("client_id", ""),
+        "client_secret": data.get("client_secret", ""),
+        "personal_catalog": data.get("personal_catalog", ""),
+        "tenant_catalogs": data.get("tenant_catalogs", []),
+    }
+
+
+def get_polaris_credentials() -> PolarisCredentials | None:
+    """
+    Provision a Polaris catalog for the current user and set credentials as environment variables.
+
+    Uses file locking and caching to prevent race conditions and avoid unnecessary
+    API calls when credentials are already cached (this file cache is Polaris-specific; MinIO credentials are fetched via direct API calls).
+
+    Calls POST /polaris/user_provision/{username} on the governance API on cache miss.
+    This provisions the user's Polaris environment (catalog, principal, roles, credentials,
+    tenant access) and returns the configuration.
+ + Sets the following environment variables: + - POLARIS_CREDENTIAL: client_id:client_secret for authenticating with Polaris + - POLARIS_PERSONAL_CATALOG: Name of the user's personal Polaris catalog + - POLARIS_TENANT_CATALOGS: Comma-separated list of tenant catalogs the user has access to + + Returns: + PolarisCredentials dict, or None if Polaris is not configured + """ + settings = get_settings() + + if not settings.POLARIS_CATALOG_URI: + return None + + result = _fetch_with_file_cache( + _get_polaris_cache_path(), + _read_cached_polaris_credentials, + _fetch_polaris_credentials, + _write_polaris_credentials_cache, + ) + if result is None: + return None + + # Set as environment variables for Spark catalog configuration + os.environ["POLARIS_CREDENTIAL"] = f"{result['client_id']}:{result['client_secret']}" + os.environ["POLARIS_PERSONAL_CATALOG"] = result["personal_catalog"] + os.environ["POLARIS_TENANT_CATALOGS"] = ",".join(result["tenant_catalogs"]) + + return result + + def rotate_minio_credentials() -> CredentialsResponse: """ Rotate MinIO credentials for the current user and update environment variables. @@ -385,6 +544,13 @@ def share_table( Example: share_table("analytics", "user_metrics", with_users=["alice", "bob"]) """ + warnings.warn( + "share_table is deprecated and will be removed in a future release. " + "Direct path sharing is no longer recommended. Please create a Tenant Workspace " + "and request access to the tenant for sharing activities.", + DeprecationWarning, + stacklevel=2, + ) client = get_governance_client() # Get current user's username from environment variable username = get_settings().USER @@ -423,6 +589,13 @@ def unshare_table( Example: unshare_table("analytics", "user_metrics", from_users=["alice"]) """ + warnings.warn( + "unshare_table is deprecated and will be removed in a future release. " + "Direct path sharing is no longer recommended. 
Please create a Tenant Workspace " + "and request access to the tenant for unsharing activities.", + DeprecationWarning, + stacklevel=2, + ) client = get_governance_client() # Get current user's username from environment variable username = get_settings().USER @@ -607,7 +780,7 @@ def list_user_names() -> list[str]: RuntimeError: If the API call fails. """ client = get_governance_client() - response = list_user_names_management_users_names_get.sync(client=client) + response = list_user_names_sync(client=client) if isinstance(response, ErrorResponse): raise RuntimeError(f"Failed to list usernames: {response.message}") @@ -875,13 +1048,18 @@ def request_tenant_access( raise RuntimeError(f"Failed to connect to tenant access service: {e}") +# ============================================================================= +# MIGRATION - Admin-only bulk operations for IAM + Polaris migration +# ============================================================================= + + def regenerate_policies() -> RegeneratePoliciesResponse: """ - Regenerate all IAM policies for all users and groups. + Force-regenerate all MinIO IAM HOME policies from the current template. - This is an admin-only endpoint that recalculates and applies MinIO IAM - policies for every user and group in the system. Useful after bulk - permission changes or to ensure policy consistency. + This admin-only endpoint updates pre-existing policies to include new path + statements (e.g., Iceberg paths). Each regeneration is independent — errors + do not block others. Returns: RegeneratePoliciesResponse with users_updated, groups_updated, errors, @@ -904,3 +1082,22 @@ def regenerate_policies() -> RegeneratePoliciesResponse: raise RuntimeError("Failed to regenerate policies: no response from API") return response + + +def ensure_polaris_resources(): + """ + Ensure Polaris resources exist for all users and groups. + + Creates Polaris principals, personal catalogs, and roles for all users. 
+ Creates tenant catalogs for all base groups. Grants correct principal roles + based on group memberships. All operations are idempotent. + + Returns: + EnsurePolarisResponse with users_provisioned, groups_provisioned, errors, + or ErrorResponse on failure. + """ + if ensure_all_polaris_resources_management_migrate_ensure_polaris_resources_post is None: + raise RuntimeError("Polaris API not available — governance client does not include polaris endpoints") + + client = get_governance_client() + return ensure_all_polaris_resources_management_migrate_ensure_polaris_resources_post.sync(client=client) diff --git a/notebook_utils/berdl_notebook_utils/refresh.py b/notebook_utils/berdl_notebook_utils/refresh.py index a989db6..e26da87 100644 --- a/notebook_utils/berdl_notebook_utils/refresh.py +++ b/notebook_utils/berdl_notebook_utils/refresh.py @@ -1,42 +1,65 @@ """ Refresh credentials and Spark environment. -Provides a single function to re-provision MinIO credentials, restart -the Spark Connect server, and stop any existing Spark session — ensuring -get_spark_session() works afterward. +Provides a single function to re-provision MinIO and Polaris credentials, +restart the Spark Connect server, and stop any existing Spark session — +ensuring get_spark_session() works afterward. """ import logging +from pathlib import Path from pyspark.sql import SparkSession from berdl_notebook_utils.berdl_settings import get_settings -from berdl_notebook_utils.minio_governance.operations import rotate_minio_credentials +from berdl_notebook_utils.minio_governance.operations import ( + _get_polaris_cache_path, + rotate_minio_credentials, + get_polaris_credentials, +) from berdl_notebook_utils.spark.connect_server import start_spark_connect_server logger = logging.getLogger("berdl.refresh") +def _remove_cache_file(path: Path) -> bool: + """Remove a cache file. 
Returns True if the main file existed.""" + existed = False + try: + if path.exists(): + path.unlink() + existed = True + except OSError: + pass + return existed + + def refresh_spark_environment() -> dict: """Re-provision credentials and restart Spark. Steps performed: - 1. Clear the in-memory ``get_settings()`` LRU cache - 2. Rotate MinIO credentials via MMS (generates new secret key, updates env vars) - 3. Clear settings cache again so downstream code sees fresh env vars - 4. Stop any existing Spark session - 5. Restart the Spark Connect server with regenerated spark-defaults.conf + 1. Delete Polaris credential cache file + 2. Clear the in-memory ``get_settings()`` LRU cache + 3. Rotate MinIO credentials via MMS (generates new secret key, updates env vars) + 4. Re-fetch Polaris credentials (sets POLARIS_CREDENTIAL and catalog env vars) + 5. Clear settings cache again so downstream code sees fresh env vars + 6. Stop any existing Spark session + 7. Restart the Spark Connect server with regenerated spark-defaults.conf Returns: - dict with keys ``minio``, ``spark_connect``, + dict with keys ``minio``, ``polaris``, ``spark_connect``, ``spark_session_stopped`` summarising what happened. """ result: dict = {} - # 1. Clear in-memory settings cache + # 1. Delete Polaris credential cache file (MinIO uses direct API calls, no local cache) + polaris_removed = _remove_cache_file(_get_polaris_cache_path()) + logger.info("Cleared Polaris credential cache (removed=%s)", polaris_removed) + + # 2. Clear in-memory settings cache get_settings.cache_clear() - # 2. Rotate MinIO credentials (generates new secret key) + # 3. Rotate MinIO credentials (generates new secret key) try: minio_creds = rotate_minio_credentials() result["minio"] = {"status": "ok", "username": minio_creds.username} @@ -45,10 +68,27 @@ def refresh_spark_environment() -> dict: result["minio"] = {"status": "error", "error": str(exc)} logger.warning("Failed to rotate MinIO credentials: %s", exc) - # 3. 
Clear settings cache again so get_settings() picks up new env vars + # 4. Re-fetch Polaris credentials + try: + polaris_creds = get_polaris_credentials() + if polaris_creds: + result["polaris"] = { + "status": "ok", + "personal_catalog": polaris_creds["personal_catalog"], + "tenant_catalogs": polaris_creds.get("tenant_catalogs", []), + } + logger.info("Polaris credentials refreshed for catalog: %s", polaris_creds["personal_catalog"]) + else: + result["polaris"] = {"status": "skipped", "reason": "Polaris not configured"} + logger.info("Polaris not configured, skipping credential refresh") + except Exception as exc: + result["polaris"] = {"status": "error", "error": str(exc)} + logger.warning("Failed to refresh Polaris credentials: %s", exc) + + # 5. Clear settings cache again so get_settings() picks up new env vars get_settings.cache_clear() - # 4. Stop existing Spark session + # 6. Stop existing Spark session existing = SparkSession.getActiveSession() if existing: existing.stop() @@ -57,7 +97,7 @@ def refresh_spark_environment() -> dict: else: result["spark_session_stopped"] = False - # 5. Restart Spark Connect server with fresh config + # 7. 
Restart Spark Connect server with fresh config try: sc_result = start_spark_connect_server(force_restart=True) result["spark_connect"] = sc_result diff --git a/notebook_utils/berdl_notebook_utils/setup_spark_session.py b/notebook_utils/berdl_notebook_utils/setup_spark_session.py index 66d81bc..209d641 100644 --- a/notebook_utils/berdl_notebook_utils/setup_spark_session.py +++ b/notebook_utils/berdl_notebook_utils/setup_spark_session.py @@ -165,6 +165,60 @@ def _get_spark_defaults_conf() -> dict[str, str]: } +def _get_catalog_conf(settings: BERDLSettings) -> dict[str, str]: + """Get Iceberg catalog configuration for Polaris REST catalog.""" + config = {} + + if not settings.POLARIS_CATALOG_URI: + return config + + polaris_uri = str(settings.POLARIS_CATALOG_URI).rstrip("/") + + # S3/MinIO properties for Iceberg's S3FileIO (used by executors to read/write data files). + # Iceberg does NOT use Spark's spark.hadoop.fs.s3a.* — it has its own AWS SDK S3 client. + s3_endpoint = settings.MINIO_ENDPOINT_URL + if not s3_endpoint.startswith("http"): + s3_endpoint = f"http://{s3_endpoint}" + s3_props = { + "s3.endpoint": s3_endpoint, + "s3.access-key-id": settings.MINIO_ACCESS_KEY, + "s3.secret-access-key": settings.MINIO_SECRET_KEY, + "s3.path-style-access": "true", + "s3.region": "us-east-1", + } + + def _catalog_props(prefix: str, warehouse: str) -> dict[str, str]: + props = { + f"{prefix}": "org.apache.iceberg.spark.SparkCatalog", + f"{prefix}.type": "rest", + f"{prefix}.uri": polaris_uri, + f"{prefix}.credential": settings.POLARIS_CREDENTIAL or "", + f"{prefix}.warehouse": warehouse, + f"{prefix}.scope": "PRINCIPAL_ROLE:ALL", + f"{prefix}.token-refresh-enabled": "false", + f"{prefix}.client.region": "us-east-1", + } + # Add S3 properties scoped to this catalog + for k, v in s3_props.items(): + props[f"{prefix}.{k}"] = v + return props + + # 1. 
Add Personal Catalog (if configured) + if settings.POLARIS_PERSONAL_CATALOG: + config.update(_catalog_props("spark.sql.catalog.my", settings.POLARIS_PERSONAL_CATALOG)) + + # 2. Add Tenant Catalogs (if configured) + if settings.POLARIS_TENANT_CATALOGS: + for tenant_catalog in settings.POLARIS_TENANT_CATALOGS.split(","): + tenant_catalog = tenant_catalog.strip() + if not tenant_catalog: + continue + catalog_alias = tenant_catalog[7:] if tenant_catalog.startswith("tenant_") else tenant_catalog + config.update(_catalog_props(f"spark.sql.catalog.{catalog_alias}", tenant_catalog)) + + return config + + def _get_delta_conf() -> dict[str, str]: return { "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension", @@ -245,6 +299,18 @@ def _get_s3_conf(settings: BERDLSettings, tenant_name: str | None = None) -> dic } +def _is_immutable_config(key: str) -> bool: + """Check if a Spark config key is immutable in Spark Connect mode.""" + if key in IMMUTABLE_CONFIGS: + return True + # Iceberg catalog configs (spark.sql.catalog..*) are all static/immutable + # because catalogs must be registered at server startup. The catalog names are + # dynamic (personal "my" + tenant aliases), so we match by prefix. + if key.startswith("spark.sql.catalog.") and key != "spark.sql.catalog.spark_catalog": + return True + return False + + def _filter_immutable_spark_connect_configs(config: dict[str, str]) -> dict[str, str]: """ Filter out configurations that cannot be modified in Spark Connect mode. 
@@ -259,7 +325,7 @@ def _filter_immutable_spark_connect_configs(config: dict[str, str]) -> dict[str, Filtered configuration dictionary with only mutable configs """ - return {k: v for k, v in config.items() if k not in IMMUTABLE_CONFIGS} + return {k: v for k, v in config.items() if not _is_immutable_config(k)} def _set_scheduler_pool(spark: SparkSession, scheduler_pool: str) -> None: @@ -313,6 +379,9 @@ def generate_spark_conf( if use_hive: config.update(_get_hive_conf(settings)) + # Always add Polaris catalogs if they are configured + config.update(_get_catalog_conf(settings)) + if use_spark_connect: # Spark Connect: filter out immutable configs that cannot be modified from the client config = _filter_immutable_spark_connect_configs(config) @@ -392,4 +461,23 @@ def get_spark_session( spark.sparkContext.setLogLevel("DEBUG") _set_scheduler_pool(spark, scheduler_pool) + # Warm up Polaris REST catalogs so they appear in SHOW CATALOGS immediately. + # Spark lazily initializes REST catalog plugins — they only show up in + # CatalogManager._catalogs (and therefore SHOW CATALOGS) after first access. 
+ if use_spark_connect and not local: + _settings = settings or get_settings() + _catalog_aliases: list[str] = [] + if _settings.POLARIS_PERSONAL_CATALOG: + _catalog_aliases.append("my") + if _settings.POLARIS_TENANT_CATALOGS: + for _raw in _settings.POLARIS_TENANT_CATALOGS.split(","): + _raw = _raw.strip() + if _raw: + _catalog_aliases.append(_raw.removeprefix("tenant_")) + for _alias in _catalog_aliases: + try: + spark.sql(f"SHOW NAMESPACES IN {_alias}").collect() + except Exception: + pass # catalog may not have any namespaces yet; access is enough to register it + return spark diff --git a/notebook_utils/berdl_notebook_utils/spark/__init__.py b/notebook_utils/berdl_notebook_utils/spark/__init__.py index 55b99bd..287a5e2 100644 --- a/notebook_utils/berdl_notebook_utils/spark/__init__.py +++ b/notebook_utils/berdl_notebook_utils/spark/__init__.py @@ -4,7 +4,7 @@ This package provides comprehensive Spark utilities organized into focused modules: - database: Catalog and namespace management utilities - dataframe: DataFrame operations and display functions -- data_store: Hive metastore and database information utilities +- data_store: Iceberg catalog browsing utilities - connect_server: Spark Connect server management All functions are imported at the package level for convenient access. 
diff --git a/notebook_utils/berdl_notebook_utils/spark/connect_server.py b/notebook_utils/berdl_notebook_utils/spark/connect_server.py index 890c4eb..a9522db 100644 --- a/notebook_utils/berdl_notebook_utils/spark/connect_server.py +++ b/notebook_utils/berdl_notebook_utils/spark/connect_server.py @@ -9,6 +9,7 @@ import os import shutil import signal +import socket import subprocess import time from pathlib import Path @@ -23,6 +24,7 @@ from ..setup_spark_session import ( DRIVER_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD, + _get_catalog_conf, convert_memory_format, ) @@ -119,6 +121,13 @@ def generate_spark_config(self) -> None: warehouse_response = get_my_sql_warehouse() f.write(f"spark.sql.warehouse.dir={warehouse_response.sql_warehouse_prefix}\n") + # Add Polaris Iceberg Catalogs + catalog_configs = _get_catalog_conf(self.settings) + if catalog_configs: + f.write("\n# Polaris Catalog Configuration\n") + for key, value in catalog_configs.items(): + f.write(f"{key}={value}\n") + logger.info(f"Spark configuration written to {self.spark_defaults_path}") def compute_allowed_namespace_prefixes(self) -> str: @@ -288,8 +297,6 @@ def _wait_for_port_release(self, timeout: int) -> bool: Returns: True if port is free, False if timeout reached. """ - import socket - port = self.config.spark_connect_port start_time = time.time() @@ -319,9 +326,9 @@ def start(self, force_restart: bool = False) -> dict: Dictionary with server information. 
""" # Check if server is already running - if self.is_running(): + server_info = self.get_server_info() + if server_info is not None: if not force_restart: - server_info = self.get_server_info() logger.info(f"✅ Spark Connect server already running (PID: {server_info['pid']})") logger.info(" Reusing existing server - no need to start a new one") return server_info @@ -382,6 +389,9 @@ def start(self, force_restart: bool = False) -> dict: f.write(str(process.pid)) server_info = self.get_server_info() + if server_info is None: + raise RuntimeError("Failed to get server info after startup") + logger.info(f"✅ Spark Connect server started successfully (PID: {process.pid})") logger.info(f" Connect URL: {server_info['url']}") logger.info(f" Logs: {server_info['log_file']}") @@ -401,8 +411,8 @@ def status(self) -> dict: Returns: Dictionary with status information. """ - if self.is_running(): - info = self.get_server_info() + info = self.get_server_info() + if info is not None: return { "status": "running", **info, diff --git a/notebook_utils/berdl_notebook_utils/spark/database.py b/notebook_utils/berdl_notebook_utils/spark/database.py index 3e2d05c..60b7f1e 100644 --- a/notebook_utils/berdl_notebook_utils/spark/database.py +++ b/notebook_utils/berdl_notebook_utils/spark/database.py @@ -3,6 +3,10 @@ This module contains utility functions to interact with the Spark catalog, including tenant-aware namespace management for BERDL SQL warehouses. 
+ +create_namespace_if_not_exists() supports two flows: +- Delta/Hive: governance prefixes (u_user__, t_tenant__) when no catalog is specified +- Polaris Iceberg: catalog-level isolation (no prefixes) when catalog is specified """ from pyspark.sql import SparkSession @@ -40,6 +44,11 @@ def generate_namespace_location(namespace: str | None = None, tenant_name: str | # Always fetch warehouse directory from governance API for proper S3 location # Don't rely on spark.sql.warehouse.dir as it may be set to local path by Spark Connect server warehouse_response = get_group_sql_warehouse(tenant_name) if tenant_name else get_my_sql_warehouse() + + if hasattr(warehouse_response, "message") and not getattr(warehouse_response, "sql_warehouse_prefix", None): + print(f"Warning: Failed to get warehouse location: {getattr(warehouse_response, 'message', 'Unknown error')}") + return (namespace, None) + warehouse_dir = warehouse_response.sql_warehouse_prefix if warehouse_dir and ("users-sql-warehouse" in warehouse_dir or "tenant-sql-warehouse" in warehouse_dir): @@ -73,35 +82,48 @@ def generate_namespace_location(namespace: str | None = None, tenant_name: str | def create_namespace_if_not_exists( spark: SparkSession, namespace: str | None = DEFAULT_NAMESPACE, - append_target: bool = True, tenant_name: str | None = None, + iceberg: bool = False, ) -> str: """ Create a namespace in the Spark catalog if it does not exist. - If append_target is True, automatically prepends the governance-provided namespace prefix - based on the warehouse directory type (user vs tenant) to create the properly formatted namespace. + Supports two flows controlled by the *iceberg* flag: - For Spark Connect, this function explicitly sets the database LOCATION to ensure tables are - written to the correct S3 path, since spark.sql.warehouse.dir cannot be modified per session. 
+ **Iceberg flow** (``iceberg=True``): + Creates ``{catalog}.{namespace}`` with no governance prefixes — the + catalog itself provides isolation. The catalog is determined by + *tenant_name*: ``None`` → ``"my"`` (user catalog), otherwise the + tenant name is used as the catalog name. + + **Delta/Hive flow** (``iceberg=False``, the default): + Prepends the governance-provided namespace prefix based on the + warehouse directory type (user vs tenant) and explicitly sets the + database LOCATION for Spark Connect compatibility. :param spark: The Spark session. :param namespace: The name of the namespace. - :param append_target: If True, prepends governance namespace prefix based on warehouse type. - If False, uses namespace as-is. - :param tenant_name: Optional tenant name. If provided, uses tenant warehouse. Otherwise uses user warehouse. - :return: The name of the namespace. + :param tenant_name: Delta/Hive: optional tenant name for tenant warehouse. + Iceberg: used as catalog name (defaults to ``"my"`` when None). + :param iceberg: If True, uses the Iceberg flow with catalog-level isolation. + :return: The fully-qualified namespace name. 
""" - db_location = None + namespace = _namespace_norm(namespace) - if append_target: - try: - namespace, db_location = generate_namespace_location(namespace, tenant_name) - except Exception as e: - print(f"Error creating namespace: {e}") - raise e - else: - namespace = _namespace_norm(namespace) + # Iceberg flow: catalog-level isolation, no governance prefixes + if iceberg: + catalog = tenant_name or "my" + full_ns = f"{catalog}.{namespace}" + spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {full_ns}") + print(f"Namespace {full_ns} is ready to use.") + return full_ns + + # Delta/Hive flow + try: + namespace, db_location = generate_namespace_location(namespace, tenant_name) + except Exception as e: + print(f"Error creating namespace: {e}") + raise e if spark.catalog.databaseExists(namespace): print(f"Namespace {namespace} is already registered and ready to use") diff --git a/notebook_utils/tests/mcp/test_operations.py b/notebook_utils/tests/mcp/test_operations.py index c43788d..a20b7b9 100644 --- a/notebook_utils/tests/mcp/test_operations.py +++ b/notebook_utils/tests/mcp/test_operations.py @@ -76,19 +76,6 @@ def test_returns_list_of_databases(self, mock_client): assert result == ["default", "analytics", "user_data"] - def test_passes_use_hms_parameter(self, mock_client): - """Test that use_hms parameter is passed correctly.""" - mock_response = Mock() - mock_response.databases = [] - - with patch("berdl_notebook_utils.mcp.operations.list_databases") as mock_api: - mock_api.sync.return_value = mock_response - - mcp_list_databases(use_hms=False) - - call_kwargs = mock_api.sync.call_args[1] - assert call_kwargs["body"].use_hms is False - def test_raises_on_none_response(self, mock_client): """Test that None response raises an exception.""" with patch("berdl_notebook_utils.mcp.operations.list_databases") as mock_api: diff --git a/notebook_utils/tests/minio_governance/test_operations.py b/notebook_utils/tests/minio_governance/test_operations.py index 06ff352..bed3e07 100644 
--- a/notebook_utils/tests/minio_governance/test_operations.py +++ b/notebook_utils/tests/minio_governance/test_operations.py @@ -2,7 +2,9 @@ Tests for minio_governance/operations.py module. """ +import json import logging +from pathlib import Path from unittest.mock import Mock, patch import httpx @@ -18,9 +20,14 @@ ) from berdl_notebook_utils.minio_governance.operations import ( + _fetch_with_file_cache, + _get_polaris_cache_path, + _read_cached_polaris_credentials, + _write_polaris_credentials_cache, _build_table_path, check_governance_health, get_minio_credentials, + get_polaris_credentials, get_my_sql_warehouse, get_group_sql_warehouse, get_namespace_prefix, @@ -43,6 +50,7 @@ request_tenant_access, rotate_minio_credentials, regenerate_policies, + POLARIS_CREDENTIALS_CACHE_FILE, CredentialsResponse, ErrorResponse, GroupManagementResponse, @@ -820,36 +828,36 @@ def test_list_available_groups_none_response(self, mock_list_groups, mock_get_cl class TestListUserNames: - """Tests for list_user_names function (lines 690-699).""" + """Tests for list_user_names function.""" @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_management_users_names_get") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") def test_list_user_names_success(self, mock_list_names, mock_get_client): - mock_get_client.return_value = Mock() - mock_response = Mock(spec=UserNamesResponse) - mock_response.usernames = ["alice", "bob", "charlie"] - mock_list_names.sync.return_value = mock_response + mock_client = Mock() + mock_get_client.return_value = mock_client + mock_list_names.return_value = Mock(spec=UserNamesResponse, usernames=["alice", "bob", "charlie"]) result = list_user_names() assert result == ["alice", "bob", "charlie"] + mock_list_names.assert_called_once_with(client=mock_client) 
@patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_management_users_names_get") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") def test_list_user_names_error_response(self, mock_list_names, mock_get_client): mock_get_client.return_value = Mock() - mock_list_names.sync.return_value = ErrorResponse(message="forbidden", error_type="error") + mock_list_names.return_value = Mock(spec=ErrorResponse, message="Forbidden") - with pytest.raises(RuntimeError, match="Failed to list usernames: forbidden"): + with pytest.raises(RuntimeError, match="Failed to list usernames"): list_user_names() @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") - @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_management_users_names_get") + @patch("berdl_notebook_utils.minio_governance.operations.list_user_names_sync") def test_list_user_names_none_response(self, mock_list_names, mock_get_client): mock_get_client.return_value = Mock() - mock_list_names.sync.return_value = None + mock_list_names.return_value = None - with pytest.raises(RuntimeError, match="Failed to list usernames: no response from API"): + with pytest.raises(RuntimeError, match="no response from API"): list_user_names() @@ -1003,3 +1011,277 @@ def test_remove_group_member_read_only(self, mock_remove_member, mock_get_client assert len(result) == 1 calls = mock_remove_member.sync.call_args_list assert calls[0][1]["group_name"] == "kbasero" + + +# --------------------------------------------------------------------------- +# Polaris credential caching tests +# --------------------------------------------------------------------------- + + +class TestFetchWithFileCache: + """Tests for _fetch_with_file_cache helper.""" + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_returns_cached_value_on_cache_hit(self, 
mock_fcntl, tmp_path): + """Test returns cached value without calling fetch when cache hits.""" + cache_path = tmp_path / "creds.json" + sentinel = {"key": "cached_value"} + + read_cache = Mock(return_value=sentinel) + fetch = Mock() + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result == sentinel + read_cache.assert_called_once_with(cache_path) + fetch.assert_not_called() + write_cache.assert_not_called() + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_fetches_and_writes_cache_on_cache_miss(self, mock_fcntl, tmp_path): + """Test fetches fresh data and writes cache when cache misses.""" + cache_path = tmp_path / "creds.json" + sentinel = {"key": "fresh_value"} + + read_cache = Mock(return_value=None) + fetch = Mock(return_value=sentinel) + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result == sentinel + read_cache.assert_called_once_with(cache_path) + fetch.assert_called_once() + write_cache.assert_called_once_with(cache_path, sentinel) + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + def test_returns_none_without_writing_when_fetch_fails(self, mock_fcntl, tmp_path): + """Test returns None and does not write cache when fetch returns None.""" + cache_path = tmp_path / "creds.json" + + read_cache = Mock(return_value=None) + fetch = Mock(return_value=None) + write_cache = Mock() + + result = _fetch_with_file_cache(cache_path, read_cache, fetch, write_cache) + + assert result is None + fetch.assert_called_once() + write_cache.assert_not_called() + + +class TestGetPolarisCachePath: + """Tests for _get_polaris_cache_path helper.""" + + def test_returns_path_in_home(self): + """Test returns path in home directory.""" + path = _get_polaris_cache_path() + + assert path == Path.home() / POLARIS_CREDENTIALS_CACHE_FILE + + +class TestReadCachedPolarisCredentials: + """Tests for 
_read_cached_polaris_credentials helper.""" + + def test_returns_none_if_file_not_exists(self, tmp_path): + """Test returns None if cache file doesn't exist.""" + result = _read_cached_polaris_credentials(tmp_path / "nonexistent.json") + + assert result is None + + def test_returns_none_on_invalid_json(self, tmp_path): + """Test returns None on invalid JSON.""" + cache_file = tmp_path / "cache.json" + cache_file.write_text("not valid json") + + result = _read_cached_polaris_credentials(cache_file) + + assert result is None + + def test_returns_none_on_missing_keys(self, tmp_path): + """Test returns None when required keys are missing.""" + cache_file = tmp_path / "cache.json" + cache_file.write_text('{"client_id": "test", "client_secret": "test"}') + + result = _read_cached_polaris_credentials(cache_file) + + assert result is None + + def test_returns_credentials_on_valid_cache(self, tmp_path): + """Test returns credentials on valid cache file.""" + cache_file = tmp_path / "cache.json" + data = { + "client_id": "test_id", + "client_secret": "test_secret", + "personal_catalog": "user_test", + "tenant_catalogs": ["tenant_kbase"], + } + cache_file.write_text(json.dumps(data)) + + result = _read_cached_polaris_credentials(cache_file) + + assert result is not None + assert result["client_id"] == "test_id" + assert result["personal_catalog"] == "user_test" + + +class TestWritePolarisCachedCredentials: + """Tests for _write_polaris_credentials_cache helper.""" + + def test_writes_credentials_to_file(self, tmp_path): + """Test writes Polaris credentials to cache file.""" + cache_file = tmp_path / "cache.json" + creds = { + "client_id": "test_id", + "client_secret": "test_secret", + "personal_catalog": "user_test", + "tenant_catalogs": [], + } + + _write_polaris_credentials_cache(cache_file, creds) + + assert cache_file.exists() + content = json.loads(cache_file.read_text()) + assert content["client_id"] == "test_id" + + def test_handles_os_error(self, tmp_path): + """Test 
handles OSError gracefully.""" + cache_file = tmp_path / "nonexistent_dir" / "cache.json" + + # Should not raise + _write_polaris_credentials_cache(cache_file, {"client_id": "test"}) + + +class TestGetPolarisCredentials: + """Tests for get_polaris_credentials function.""" + + @patch("berdl_notebook_utils.minio_governance.operations.os") + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_none_when_polaris_not_configured(self, mock_settings, mock_fcntl, mock_os): + """Test returns None when POLARIS_CATALOG_URI is not set.""" + mock_settings.return_value.POLARIS_CATALOG_URI = None + + result = get_polaris_credentials() + + assert result is None + + @patch("berdl_notebook_utils.minio_governance.operations.os") + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._write_polaris_credentials_cache") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_cached_credentials( + self, mock_settings, mock_cache_path, mock_read_cache, mock_write_cache, mock_fcntl, mock_os, tmp_path + ): + """Test returns cached Polaris credentials when available.""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + + cached = { + "client_id": "cached_id", + "client_secret": "cached_secret", + "personal_catalog": "user_test", + "tenant_catalogs": ["tenant_kbase"], + } + mock_read_cache.return_value = cached + + result = get_polaris_credentials() + + assert result["client_id"] == "cached_id" + mock_os.environ.__setitem__.assert_any_call("POLARIS_CREDENTIAL", "cached_id:cached_secret") + 
mock_os.environ.__setitem__.assert_any_call("POLARIS_PERSONAL_CATALOG", "user_test") + + @patch("berdl_notebook_utils.minio_governance.operations.os") + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._write_polaris_credentials_cache") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch( + "berdl_notebook_utils.minio_governance.operations.provision_polaris_user_polaris_user_provision_username_post" + ) + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_fetches_fresh_credentials( + self, + mock_settings, + mock_get_client, + mock_provision, + mock_cache_path, + mock_read_cache, + mock_write_cache, + mock_fcntl, + mock_os, + tmp_path, + ): + """Test fetches fresh credentials from API when no cache.""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_settings.return_value.USER = "test_user" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + mock_read_cache.return_value = None + + mock_api_response = Mock() + mock_api_response.to_dict.return_value = { + "client_id": "new_id", + "client_secret": "new_secret", + "personal_catalog": "user_test_user", + "tenant_catalogs": ["tenant_team"], + } + mock_provision.sync.return_value = mock_api_response + + result = get_polaris_credentials() + + assert result["client_id"] == "new_id" + assert result["personal_catalog"] == "user_test_user" + mock_write_cache.assert_called_once() + mock_provision.sync.assert_called_once_with(username="test_user", client=mock_get_client.return_value) + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + 
@patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch( + "berdl_notebook_utils.minio_governance.operations.provision_polaris_user_polaris_user_provision_username_post" + ) + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_none_on_error_response( + self, mock_settings, mock_get_client, mock_provision, mock_cache_path, mock_read_cache, mock_fcntl, tmp_path + ): + """Test returns None when API returns ErrorResponse.""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_settings.return_value.USER = "test_user" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + mock_read_cache.return_value = None + + mock_error = Mock(spec=ErrorResponse) + mock_error.message = "Internal server error" + mock_provision.sync.return_value = mock_error + + result = get_polaris_credentials() + + assert result is None + + @patch("berdl_notebook_utils.minio_governance.operations.fcntl") + @patch("berdl_notebook_utils.minio_governance.operations._read_cached_polaris_credentials") + @patch("berdl_notebook_utils.minio_governance.operations._get_polaris_cache_path") + @patch( + "berdl_notebook_utils.minio_governance.operations.provision_polaris_user_polaris_user_provision_username_post" + ) + @patch("berdl_notebook_utils.minio_governance.operations.get_governance_client") + @patch("berdl_notebook_utils.minio_governance.operations.get_settings") + def test_returns_none_on_no_response( + self, mock_settings, mock_get_client, mock_provision, mock_cache_path, mock_read_cache, mock_fcntl, tmp_path + ): + """Test returns None when API returns None (unexpected status).""" + mock_settings.return_value.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" + mock_settings.return_value.USER = "test_user" + mock_cache_path.return_value = tmp_path / ".polaris_cache" + mock_read_cache.return_value 
= None + + mock_provision.sync.return_value = None + + result = get_polaris_credentials() + + assert result is None diff --git a/notebook_utils/tests/spark/test_connect_server.py b/notebook_utils/tests/spark/test_connect_server.py index 8cc130a..e3cb0f4 100644 --- a/notebook_utils/tests/spark/test_connect_server.py +++ b/notebook_utils/tests/spark/test_connect_server.py @@ -102,6 +102,8 @@ def test_generate_spark_config(self, mock_get_warehouse, mock_get_settings, mock mock_settings.SPARK_MASTER_CORES = 1 mock_settings.SPARK_MASTER_MEMORY = "8G" mock_settings.BERDL_POD_IP = "192.168.1.100" + # Polaris not configured — _get_catalog_conf returns empty dict + mock_settings.POLARIS_CATALOG_URI = None mock_get_settings.return_value = mock_settings mock_convert.return_value = "8g" @@ -125,6 +127,60 @@ def test_generate_spark_config(self, mock_get_warehouse, mock_get_settings, mock assert "test_user" in content assert "spark.eventLog.dir" in content + @patch("berdl_notebook_utils.spark.connect_server._get_catalog_conf") + @patch("berdl_notebook_utils.spark.connect_server.shutil.copy") + @patch("berdl_notebook_utils.spark.connect_server.convert_memory_format") + @patch("berdl_notebook_utils.spark.connect_server.get_settings") + @patch("berdl_notebook_utils.spark.connect_server.get_my_sql_warehouse") + def test_generate_spark_config_with_catalog( + self, mock_get_warehouse, mock_get_settings, mock_convert, mock_copy, mock_catalog_conf, tmp_path + ): + """Test generate_spark_config includes Polaris catalog config when present.""" + mock_settings = Mock() + mock_settings.USER = "test_user" + mock_settings.SPARK_HOME = "/opt/spark" + mock_settings.SPARK_CONNECT_DEFAULTS_TEMPLATE = str(tmp_path / "template.conf") + mock_url = Mock() + mock_url.port = 15002 + mock_settings.SPARK_CONNECT_URL = mock_url + mock_settings.SPARK_MASTER_URL = "spark://master:7077" + mock_settings.BERDL_HIVE_METASTORE_URI = "thrift://localhost:9083" + mock_settings.MINIO_ENDPOINT_URL = 
"http://localhost:9000" + mock_settings.MINIO_ACCESS_KEY = "minioadmin" + mock_settings.MINIO_SECRET_KEY = "minioadmin" + mock_settings.SPARK_WORKER_COUNT = 2 + mock_settings.SPARK_WORKER_CORES = 2 + mock_settings.SPARK_WORKER_MEMORY = "10G" + mock_settings.SPARK_MASTER_CORES = 1 + mock_settings.SPARK_MASTER_MEMORY = "8G" + mock_settings.BERDL_POD_IP = "192.168.1.100" + mock_get_settings.return_value = mock_settings + + mock_convert.return_value = "8g" + mock_warehouse = Mock() + mock_warehouse.sql_warehouse_prefix = "s3a://cdm-lake/users-sql-warehouse/test_user" + mock_get_warehouse.return_value = mock_warehouse + + # Return catalog config entries + mock_catalog_conf.return_value = { + "spark.sql.catalog.my": "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.my.type": "rest", + } + + # Create template file + template_file = tmp_path / "template.conf" + template_file.write_text("# Base config") + + config = SparkConnectServerConfig() + config.spark_defaults_path = tmp_path / "spark-defaults.conf" + + config.generate_spark_config() + + content = config.spark_defaults_path.read_text() + assert "Polaris Catalog Configuration" in content + assert "spark.sql.catalog.my=org.apache.iceberg.spark.SparkCatalog" in content + assert "spark.sql.catalog.my.type=rest" in content + @patch("berdl_notebook_utils.spark.connect_server.get_my_groups") @patch("berdl_notebook_utils.spark.connect_server.get_namespace_prefix") @patch("berdl_notebook_utils.spark.connect_server.get_settings") @@ -547,9 +603,9 @@ class TestSparkConnectServerManagerForceRestart: """Tests for force_restart functionality.""" @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerManager.stop") - @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerManager.is_running") + @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerManager.get_server_info") @patch("berdl_notebook_utils.spark.connect_server.SparkConnectServerConfig") - def 
test_start_force_restart_calls_stop(self, mock_config_class, mock_is_running, mock_stop, tmp_path): + def test_start_force_restart_calls_stop(self, mock_config_class, mock_get_info, mock_stop, tmp_path): """Test start with force_restart=True calls stop first.""" mock_config = Mock() mock_config.username = "test_user" @@ -561,8 +617,9 @@ def test_start_force_restart_calls_stop(self, mock_config_class, mock_is_running mock_config.pid_file_path = tmp_path / "pid" mock_config_class.return_value = mock_config - # Server is running initially - mock_is_running.side_effect = [True, False] # First check: running, after stop: not running + # start() calls get_server_info() — return a dict (server running) on first call, + # then None after stop() to trigger the "start new server" path + mock_get_info.return_value = {"pid": 12345, "port": 15002} # Mock the start script check to fail (we don't want to actually start) with patch("pathlib.Path.exists", return_value=False): diff --git a/notebook_utils/tests/spark/test_database.py b/notebook_utils/tests/spark/test_database.py index 19e5e2c..4f70732 100644 --- a/notebook_utils/tests/spark/test_database.py +++ b/notebook_utils/tests/spark/test_database.py @@ -110,12 +110,16 @@ def test_generate_namespace_location_no_match_warns( assert "Warning: Could not determine target name from warehouse directory" in captured.out +# ============================================================================ +# create_namespace_if_not_exists tests (Delta/Hive flow) +# ============================================================================ + + @pytest.mark.parametrize("tenant", [None, TENANT_NAME]) @pytest.mark.parametrize("namespace_arg", EXPECTED_NS) def test_create_namespace_if_not_exists_user_tenant_warehouse(namespace_arg: str | None, tenant: str | None) -> None: """Test user and tenant namespace creation.""" mock_spark = make_mock_spark() - # Run with append_target=True (default) ns = create_namespace_if_not_exists(mock_spark, 
namespace=namespace_arg, tenant_name=tenant) # type: ignore namespace = EXPECTED_NS[namespace_arg] if tenant: @@ -128,18 +132,6 @@ def test_create_namespace_if_not_exists_user_tenant_warehouse(namespace_arg: str mock_spark.sql.assert_called_once_with(f"CREATE DATABASE IF NOT EXISTS {ns} LOCATION '{expected_location}'") -@pytest.mark.parametrize("tenant", [None, TENANT_NAME]) -@pytest.mark.parametrize("namespace_arg", EXPECTED_NS) -def test_create_namespace_if_not_exists_without_prefix(namespace_arg: str | None, tenant: str | None) -> None: - """Test namespace creation when append_target is set to false.""" - mock_spark = make_mock_spark() - ns = create_namespace_if_not_exists(mock_spark, namespace=namespace_arg, append_target=False, tenant_name=tenant) # type: ignore - namespace = EXPECTED_NS[namespace_arg] - assert ns == namespace - # Should create database without LOCATION clause - mock_spark.sql.assert_called_once_with(f"CREATE DATABASE IF NOT EXISTS {namespace}") - - @pytest.mark.parametrize("tenant", [None, TENANT_NAME]) @pytest.mark.parametrize("namespace_arg", EXPECTED_NS) def test_create_namespace_if_not_exists_already_exists( @@ -147,13 +139,17 @@ def test_create_namespace_if_not_exists_already_exists( ) -> None: """Test namespace creation when the namespace has already been registered.""" mock_spark = make_mock_spark(database_exists=True) - ns = create_namespace_if_not_exists(mock_spark, namespace=namespace_arg, append_target=False, tenant_name=tenant) # type: ignore + ns = create_namespace_if_not_exists(mock_spark, namespace=namespace_arg, tenant_name=tenant) # type: ignore namespace = EXPECTED_NS[namespace_arg] - assert ns == namespace + if tenant: + expected_ns = f"tenant__{namespace}" + else: + expected_ns = f"user__{namespace}" + assert ns == expected_ns # No call to spark.sql as the namespace already exists mock_spark.sql.assert_not_called() captured = capfd.readouterr() - assert f"Namespace {namespace} is already registered and ready to use" in 
captured.out + assert f"Namespace {expected_ns} is already registered and ready to use" in captured.out @pytest.mark.parametrize("namespace_arg", EXPECTED_NS) @@ -186,3 +182,37 @@ def test_create_namespace_if_not_exists_error() -> None: pytest.raises(RuntimeError, match="things went wrong"), ): create_namespace_if_not_exists(Mock(), "some_namespace") + + +# ============================================================================ +# create_namespace_if_not_exists iceberg=True tests (Polaris Iceberg flow) +# ============================================================================ + + +@pytest.mark.parametrize("namespace_arg", EXPECTED_NS) +def test_create_namespace_iceberg_default_catalog(namespace_arg: str | None, capfd: pytest.CaptureFixture[str]) -> None: + """Test Iceberg namespace creation uses 'my' catalog by default (no tenant_name).""" + mock_spark = make_mock_spark() + namespace = EXPECTED_NS[namespace_arg] + result = create_namespace_if_not_exists(mock_spark, namespace, iceberg=True) + assert result == f"my.{namespace}" + mock_spark.sql.assert_called_once_with(f"CREATE NAMESPACE IF NOT EXISTS my.{namespace}") + captured = capfd.readouterr() + assert f"Namespace my.{namespace} is ready to use." 
in captured.out + + +@pytest.mark.parametrize("tenant", ["globalusers", "research"]) +def test_create_namespace_iceberg_tenant_as_catalog(tenant: str) -> None: + """Test Iceberg namespace creation uses tenant_name as catalog.""" + mock_spark = make_mock_spark() + result = create_namespace_if_not_exists(mock_spark, "test_db", iceberg=True, tenant_name=tenant) + assert result == f"{tenant}.test_db" + mock_spark.sql.assert_called_once_with(f"CREATE NAMESPACE IF NOT EXISTS {tenant}.test_db") + + +def test_create_namespace_iceberg_no_governance_prefix() -> None: + """Test that iceberg=True does NOT call governance API for prefixes.""" + mock_spark = make_mock_spark() + with patch("berdl_notebook_utils.spark.database.get_namespace_prefix") as mock_prefix: + create_namespace_if_not_exists(mock_spark, "test_db", iceberg=True) + mock_prefix.assert_not_called() diff --git a/notebook_utils/tests/test_get_spark_session.py b/notebook_utils/tests/test_get_spark_session.py index 7c7727a..35d8caf 100644 --- a/notebook_utils/tests/test_get_spark_session.py +++ b/notebook_utils/tests/test_get_spark_session.py @@ -347,3 +347,207 @@ def test_executor_conf_no_auth_token_for_legacy_mode(monkeypatch: pytest.MonkeyP assert "spark.driver.host" in config # Verify master URL does NOT contain auth token (legacy mode doesn't use URL-based auth) assert "authorization" not in config.get("spark.master", "") + + +# ============================================================================= +# convert_memory_format tests +# ============================================================================= + + +class TestConvertMemoryFormat: + """Tests for convert_memory_format edge cases.""" + + def test_invalid_memory_format_raises(self): + """Test that invalid memory format raises ValueError.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + with pytest.raises(ValueError, match="Invalid memory format"): + convert_memory_format("not_a_memory_value") + + def 
test_invalid_memory_format_no_unit(self): + """Test that bare number without unit raises ValueError.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + with pytest.raises(ValueError, match="Invalid memory format"): + convert_memory_format("1024") + + def test_small_memory_returns_kb(self): + """Test small memory values return kilobyte format.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + # 1 MiB with 10% overhead = ~921.6 KiB → "922k" + result = convert_memory_format("1MiB", 0.1) + assert result.endswith("k") + + def test_very_small_memory_returns_bytes(self): + """Test very small memory values (< 1 KiB) return byte format (no unit suffix).""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + # The regex requires [kmgtKMGT] prefix, so smallest valid unit is "kb" + # 0.5 kb = 512 bytes, with 10% overhead = 460.8 bytes → "461" (no unit) + result = convert_memory_format("0.5kb", 0.1) + assert not result.endswith("k") + assert not result.endswith("m") + assert not result.endswith("g") + + def test_gib_memory(self): + """Test GiB memory format.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + result = convert_memory_format("4GiB", 0.1) + assert result.endswith("g") + + def test_unit_key_fallback(self): + """Test unit key fallback for unusual unit formats.""" + from berdl_notebook_utils.setup_spark_session import convert_memory_format + + # "4gb" should work + result = convert_memory_format("4gb", 0.1) + assert result.endswith("g") + + +# ============================================================================= +# _get_catalog_conf tests +# ============================================================================= + + +class TestGetCatalogConf: + """Tests for _get_catalog_conf with Polaris configuration.""" + + def test_returns_empty_when_no_polaris_uri(self): + """Test returns empty dict when POLARIS_CATALOG_URI is None.""" + 
from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = None + + result = _get_catalog_conf(settings) + + assert result == {} + + def test_personal_catalog_config(self): + """Test generates personal catalog config when POLARIS_PERSONAL_CATALOG is set.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = "user_tgu2" + settings.POLARIS_TENANT_CATALOGS = None + + result = _get_catalog_conf(settings) + + assert "spark.sql.catalog.my" in result + assert result["spark.sql.catalog.my"] == "org.apache.iceberg.spark.SparkCatalog" + assert result["spark.sql.catalog.my.type"] == "rest" + assert result["spark.sql.catalog.my.warehouse"] == "user_tgu2" + assert "spark.sql.catalog.my.s3.endpoint" in result + assert result["spark.sql.catalog.my.s3.path-style-access"] == "true" + + def test_tenant_catalog_config(self): + """Test generates tenant catalog config with 'tenant_' prefix stripped.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = None + settings.POLARIS_TENANT_CATALOGS = "tenant_globalusers,tenant_research" + + result = _get_catalog_conf(settings) + + # "tenant_globalusers" → alias "globalusers" + assert "spark.sql.catalog.globalusers" in result + assert result["spark.sql.catalog.globalusers.warehouse"] == "tenant_globalusers" + # "tenant_research" → alias "research" + assert "spark.sql.catalog.research" in result + assert result["spark.sql.catalog.research.warehouse"] == "tenant_research" + + def 
test_empty_tenant_catalog_entries_skipped(self): + """Test empty entries in POLARIS_TENANT_CATALOGS are skipped.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = None + settings.POLARIS_TENANT_CATALOGS = "tenant_kbase,, " + + result = _get_catalog_conf(settings) + + assert "spark.sql.catalog.kbase" in result + # Empty entries should not produce catalog configs + catalog_keys = [ + k for k in result if k.startswith("spark.sql.catalog.") and "." not in k[len("spark.sql.catalog.") :] + ] + assert all("kbase" in k for k in catalog_keys) + + def test_s3_endpoint_without_http_prefix(self): + """Test s3 endpoint gets http:// prefix if missing.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = "user_test" + settings.POLARIS_TENANT_CATALOGS = None + settings.MINIO_ENDPOINT_URL = "minio:9000" + + result = _get_catalog_conf(settings) + + assert result["spark.sql.catalog.my.s3.endpoint"] == "http://minio:9000" + + def test_both_personal_and_tenant_catalogs(self): + """Test generates config for both personal and tenant catalogs.""" + from berdl_notebook_utils.setup_spark_session import _get_catalog_conf + + settings = BERDLSettings() + settings.POLARIS_CATALOG_URI = "http://polaris:8181/api/catalog" # type: ignore + settings.POLARIS_CREDENTIAL = "client_id:client_secret" + settings.POLARIS_PERSONAL_CATALOG = "user_alice" + settings.POLARIS_TENANT_CATALOGS = "tenant_team" + + result = _get_catalog_conf(settings) + + assert "spark.sql.catalog.my" in result + assert "spark.sql.catalog.team" in result + + +# 
============================================================================= +# _is_immutable_config tests +# ============================================================================= + + +class TestIsImmutableConfig: + """Tests for _is_immutable_config function.""" + + def test_known_immutable_configs(self): + """Test known immutable config keys are detected.""" + from berdl_notebook_utils.setup_spark_session import _is_immutable_config + + for key in IMMUTABLE_CONFIGS: + assert _is_immutable_config(key) is True, f"Expected {key} to be immutable" + + def test_catalog_config_keys_are_immutable(self): + """Test that spark.sql.catalog..* keys are immutable.""" + from berdl_notebook_utils.setup_spark_session import _is_immutable_config + + assert _is_immutable_config("spark.sql.catalog.my") is True + assert _is_immutable_config("spark.sql.catalog.my.type") is True + assert _is_immutable_config("spark.sql.catalog.globalusers") is True + assert _is_immutable_config("spark.sql.catalog.globalusers.warehouse") is True + + def test_spark_catalog_is_not_custom_catalog(self): + """Test spark_catalog is handled by IMMUTABLE_CONFIGS set, not the prefix check.""" + from berdl_notebook_utils.setup_spark_session import _is_immutable_config + + # spark.sql.catalog.spark_catalog is in IMMUTABLE_CONFIGS explicitly + assert _is_immutable_config("spark.sql.catalog.spark_catalog") is True + + def test_mutable_config_keys(self): + """Test that non-immutable keys return False.""" + from berdl_notebook_utils.setup_spark_session import _is_immutable_config + + assert _is_immutable_config("spark.app.name") is False + assert _is_immutable_config("spark.sql.autoBroadcastJoinThreshold") is False + assert _is_immutable_config("spark.hadoop.fs.s3a.endpoint") is False diff --git a/notebook_utils/tests/test_refresh.py b/notebook_utils/tests/test_refresh.py index bcc9701..ead4ad1 100644 --- a/notebook_utils/tests/test_refresh.py +++ b/notebook_utils/tests/test_refresh.py @@ -1,8 +1,51 @@ 
-"""Tests for berdl_notebook_utils.refresh module.""" +"""Tests for refresh.py - Credential and Spark environment refresh.""" +from pathlib import Path from unittest.mock import Mock, patch -from berdl_notebook_utils.refresh import refresh_spark_environment +from berdl_notebook_utils.refresh import _remove_cache_file, refresh_spark_environment + + +class TestRemoveCacheFile: + """Tests for _remove_cache_file helper.""" + + def test_removes_existing_file(self, tmp_path): + """Test removes file and returns True when file exists.""" + cache_file = tmp_path / "test_cache" + cache_file.write_text("cached data") + + result = _remove_cache_file(cache_file) + + assert result is True + assert not cache_file.exists() + + def test_preserves_lock_file(self, tmp_path): + """Test leaves the .lock companion file in place.""" + cache_file = tmp_path / "test_cache" + lock_file = tmp_path / "test_cache.lock" + cache_file.write_text("cached data") + lock_file.write_text("lock") + + _remove_cache_file(cache_file) + + assert not cache_file.exists() + assert lock_file.exists() + + def test_returns_false_when_file_missing(self, tmp_path): + """Test returns False when file doesn't exist.""" + result = _remove_cache_file(tmp_path / "nonexistent") + + assert result is False + + def test_handles_os_error(self, tmp_path): + """Test silently handles OSError on unlink.""" + with ( + patch.object(Path, "exists", return_value=True), + patch.object(Path, "unlink", side_effect=OSError("permission denied")), + ): + result = _remove_cache_file(tmp_path / "test_cache") + + assert result is False class TestRefreshSparkEnvironment: @@ -10,137 +53,164 @@ class TestRefreshSparkEnvironment: @patch("berdl_notebook_utils.refresh.start_spark_connect_server") @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") @patch("berdl_notebook_utils.refresh.rotate_minio_credentials") @patch("berdl_notebook_utils.refresh.get_settings") - def 
test_full_refresh_success( - self, - mock_get_settings, - mock_rotate_minio, - mock_spark_session, - mock_start_connect, - ): - """Test successful full refresh with all steps.""" + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_happy_path(self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start): + """Test full refresh with all services succeeding.""" mock_minio_creds = Mock() - mock_minio_creds.username = "testuser" - mock_rotate_minio.return_value = mock_minio_creds - mock_spark_session.getActiveSession.return_value = None - mock_start_connect.return_value = {"pid": 123, "port": 15002, "url": "sc://localhost:15002"} + mock_minio_creds.username = "u_testuser" + mock_minio.return_value = mock_minio_creds + + mock_polaris.return_value = { + "client_id": "abc", + "client_secret": "xyz", + "personal_catalog": "user_testuser", + "tenant_catalogs": ["tenant_a"], + } + + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} result = refresh_spark_environment() - assert result["minio"] == {"status": "ok", "username": "testuser"} + assert result["minio"] == {"status": "ok", "username": "u_testuser"} + assert result["polaris"]["status"] == "ok" + assert result["polaris"]["personal_catalog"] == "user_testuser" assert result["spark_session_stopped"] is False - assert result["spark_connect"] == {"pid": 123, "port": 15002, "url": "sc://localhost:15002"} - assert mock_get_settings.cache_clear.call_count == 2 - mock_rotate_minio.assert_called_once() - mock_start_connect.assert_called_once_with(force_restart=True) + assert result["spark_connect"] == {"status": "running"} + assert mock_settings.cache_clear.call_count == 2 + mock_sc_start.assert_called_once_with(force_restart=True) @patch("berdl_notebook_utils.refresh.start_spark_connect_server") @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") 
@patch("berdl_notebook_utils.refresh.rotate_minio_credentials") @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") def test_stops_existing_spark_session( - self, - mock_get_settings, - mock_rotate_minio, - mock_spark_session, - mock_start_connect, + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start ): - """Test that an existing Spark session is stopped.""" - mock_rotate_minio.return_value = Mock(username="testuser") - mock_existing_session = Mock() - mock_spark_session.getActiveSession.return_value = mock_existing_session - mock_start_connect.return_value = {"pid": 456} + """Test stops active Spark session before restarting Connect server.""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + + mock_session = Mock() + mock_spark.getActiveSession.return_value = mock_session + mock_sc_start.return_value = {"status": "running"} result = refresh_spark_environment() + mock_session.stop.assert_called_once() assert result["spark_session_stopped"] is True - mock_existing_session.stop.assert_called_once() @patch("berdl_notebook_utils.refresh.start_spark_connect_server") @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") @patch("berdl_notebook_utils.refresh.rotate_minio_credentials") @patch("berdl_notebook_utils.refresh.get_settings") - def test_minio_rotation_failure_continues( - self, - mock_get_settings, - mock_rotate_minio, - mock_spark_session, - mock_start_connect, + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_polaris_not_configured( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start ): - """Test that MinIO rotation failure doesn't stop the rest of the refresh.""" - mock_rotate_minio.side_effect = RuntimeError("API unavailable") - mock_spark_session.getActiveSession.return_value = None - 
mock_start_connect.return_value = {"pid": 789} + """Test handles Polaris not being configured (returns None).""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} result = refresh_spark_environment() - assert result["minio"]["status"] == "error" - assert "API unavailable" in result["minio"]["error"] - assert result["spark_connect"] == {"pid": 789} - mock_start_connect.assert_called_once_with(force_restart=True) + assert result["polaris"] == {"status": "skipped", "reason": "Polaris not configured"} @patch("berdl_notebook_utils.refresh.start_spark_connect_server") @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") @patch("berdl_notebook_utils.refresh.rotate_minio_credentials") @patch("berdl_notebook_utils.refresh.get_settings") - def test_spark_connect_failure_captured( - self, - mock_get_settings, - mock_rotate_minio, - mock_spark_session, - mock_start_connect, + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_minio_error_does_not_block( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start ): - """Test that Spark Connect failure is captured in result.""" - mock_rotate_minio.return_value = Mock(username="testuser") - mock_spark_session.getActiveSession.return_value = None - mock_start_connect.side_effect = RuntimeError("start script not found") + """Test that MinIO failure doesn't prevent Polaris/Spark refresh.""" + mock_minio.side_effect = ConnectionError("minio unreachable") + mock_polaris.return_value = { + "client_id": "x", + "client_secret": "y", + "personal_catalog": "user_test", + "tenant_catalogs": [], + } + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} result = refresh_spark_environment() - assert result["minio"]["status"] == "ok" - assert 
result["spark_connect"]["status"] == "error" - assert "start script not found" in result["spark_connect"]["error"] + assert result["minio"]["status"] == "error" + assert "minio unreachable" in result["minio"]["error"] + assert result["polaris"]["status"] == "ok" + assert result["spark_connect"] == {"status": "running"} @patch("berdl_notebook_utils.refresh.start_spark_connect_server") @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") @patch("berdl_notebook_utils.refresh.rotate_minio_credentials") @patch("berdl_notebook_utils.refresh.get_settings") - def test_both_failures_captured( - self, - mock_get_settings, - mock_rotate_minio, - mock_spark_session, - mock_start_connect, + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_spark_connect_error_captured( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start ): - """Test that failures in both MinIO and Spark Connect are captured.""" - mock_rotate_minio.side_effect = ConnectionError("MMS down") - mock_spark_session.getActiveSession.return_value = None - mock_start_connect.side_effect = FileNotFoundError("no start script") + """Test that Spark Connect restart failure is captured in result.""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + mock_spark.getActiveSession.return_value = None + mock_sc_start.side_effect = RuntimeError("server start failed") result = refresh_spark_environment() - assert result["minio"]["status"] == "error" assert result["spark_connect"]["status"] == "error" - assert result["spark_session_stopped"] is False + assert "server start failed" in result["spark_connect"]["error"] @patch("berdl_notebook_utils.refresh.start_spark_connect_server") @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") @patch("berdl_notebook_utils.refresh.rotate_minio_credentials") 
@patch("berdl_notebook_utils.refresh.get_settings") - def test_settings_cache_cleared_twice( - self, - mock_get_settings, - mock_rotate_minio, - mock_spark_session, - mock_start_connect, + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_cache_files_removed_first( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start ): - """Test that get_settings cache is cleared before and after credential rotation.""" - mock_rotate_minio.return_value = Mock(username="testuser") - mock_spark_session.getActiveSession.return_value = None - mock_start_connect.return_value = {} + """Test that cache files are removed before credentials are re-fetched.""" + mock_minio.return_value = Mock(username="u_test") + mock_polaris.return_value = None + mock_spark.getActiveSession.return_value = None + mock_sc_start.return_value = {"status": "running"} refresh_spark_environment() - # Should be cleared once before rotation and once after - assert mock_get_settings.cache_clear.call_count == 2 + # _remove_cache_file called once (polaris only — MinIO uses direct API calls) + assert mock_remove.call_count == 1 + # Settings cache cleared before credential fetches + mock_settings.cache_clear.assert_called() + + @patch("berdl_notebook_utils.refresh.start_spark_connect_server") + @patch("berdl_notebook_utils.refresh.SparkSession") + @patch("berdl_notebook_utils.refresh.get_polaris_credentials") + @patch("berdl_notebook_utils.refresh.rotate_minio_credentials") + @patch("berdl_notebook_utils.refresh.get_settings") + @patch("berdl_notebook_utils.refresh._remove_cache_file") + def test_all_errors_still_returns_result( + self, mock_remove, mock_settings, mock_minio, mock_polaris, mock_spark, mock_sc_start + ): + """Test that even if everything fails, we get a complete result dict.""" + mock_minio.side_effect = Exception("minio fail") + mock_polaris.side_effect = Exception("polaris fail") + mock_spark.getActiveSession.return_value = None + 
mock_sc_start.side_effect = Exception("spark fail") + + result = refresh_spark_environment() + + assert result["minio"]["status"] == "error" + assert result["polaris"]["status"] == "error" + assert result["spark_session_stopped"] is False + assert result["spark_connect"]["status"] == "error" diff --git a/scripts/init-polaris-db.sh b/scripts/init-polaris-db.sh new file mode 100755 index 0000000..323d9cb --- /dev/null +++ b/scripts/init-polaris-db.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +# Create the polaris database in the shared PostgreSQL instance. +# This script is mounted into /docker-entrypoint-initdb.d/ and runs once +# when the PostgreSQL container is first initialized. + +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + SELECT 'CREATE DATABASE polaris' + WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'polaris')\gexec +EOSQL + +echo "Polaris database created (or already exists)" diff --git a/scripts/migrate_delta_to_iceberg.py b/scripts/migrate_delta_to_iceberg.py new file mode 100644 index 0000000..558efd9 --- /dev/null +++ b/scripts/migrate_delta_to_iceberg.py @@ -0,0 +1,379 @@ +""" +Delta Lake to Iceberg migration utilities for BERDL Phase 4. + +This script migrates Delta Lake tables (Hive Metastore) to Iceberg tables +(Polaris REST catalog), preserving partitions and validating row counts. 
+ +Functions: + migrate_table - Migrate a single Delta table to an Iceberg catalog + migrate_user - Migrate all of a user's Delta databases to their Iceberg catalog + migrate_tenant - Migrate all of a tenant's Delta databases to their Iceberg catalog + +Usage in the migration notebook (migration_phase4.ipynb): + + # Import after adding scripts/ to sys.path + from migrate_delta_to_iceberg import MigrationTracker, migrate_user, migrate_tenant + + tracker = MigrationTracker() + + # Migrate all tables for a user (idempotent — skips existing tables) + migrate_user(spark, "tgu2", target_catalog="user_tgu2", tracker=tracker) + + # Migrate all tables for a tenant + migrate_tenant(spark, "globalusers", target_catalog="tenant_globalusers", tracker=tracker) + + # View results + tracker.to_dataframe(spark).show(truncate=False) + print(tracker.summary()) + +Force re-migration (drops existing Iceberg table and re-copies from Delta): + + # Force re-migrate a single user table + migrate_table( + spark, + hive_db="u_tian_gu_test__demo_personal", + table_name="personal_test_table", + target_catalog="user_tian_gu_test", + target_ns="demo_personal", + tracker=tracker, + force=True, + ) + + # Force re-migrate a single tenant table + migrate_table( + spark, + hive_db="globalusers_demo_shared", + table_name="tenant_test_table", + target_catalog="tenant_globalusers", + target_ns="demo_shared", + tracker=tracker, + force=True, + ) + + # Force re-migrate all tables for a user + migrate_user(spark, "tian_gu_test", target_catalog="user_tian_gu_test", tracker=tracker, force=True) + + # Force re-migrate all tables for a tenant + migrate_tenant(spark, "globalusers", target_catalog="tenant_globalusers", tracker=tracker, force=True) + +Note: + - Requires an admin Spark session configured with cross-user catalog access + (see Section 3 of migration_phase4.ipynb) + - By default, migration is idempotent: tables that already exist in the target + catalog are skipped. 
Use force=True to drop and re-migrate. + - DROP TABLE PURGE does not delete S3 data files due to an Iceberg bug (#14743). + To fully clean up, delete files from S3 directly using get_minio_client(). +""" + +import logging +from dataclasses import dataclass, field + +from pyspark.sql import SparkSession + +logger = logging.getLogger(__name__) + + +def _clean_error(e: Exception, max_len: int = 150) -> str: + """Extract a short, readable error message without JVM stacktraces.""" + msg = str(e) + # Strip everything after "JVM stacktrace:" if present + if "JVM stacktrace:" in msg: + msg = msg[: msg.index("JVM stacktrace:")].strip() + # Take only the first line + msg = msg.split("\n")[0].strip() + if len(msg) > max_len: + msg = msg[:max_len] + "..." + return msg + + +@dataclass +class TableResult: + """Result of a single table migration.""" + + source: str + target: str + status: str # "migrated", "skipped", "failed" + row_count: int = 0 + error: str = "" + + +@dataclass +class MigrationTracker: + """Tracks migration progress across users and tenants.""" + + results: list[TableResult] = field(default_factory=list) + + @property + def migrated(self) -> list[TableResult]: + return [r for r in self.results if r.status == "migrated"] + + @property + def skipped(self) -> list[TableResult]: + return [r for r in self.results if r.status == "skipped"] + + @property + def failed(self) -> list[TableResult]: + return [r for r in self.results if r.status == "failed"] + + def add(self, result: TableResult): + self.results.append(result) + + def summary(self) -> str: + return ( + f"Total: {len(self.results)} | " + f"Migrated: {len(self.migrated)} | " + f"Skipped: {len(self.skipped)} | " + f"Failed: {len(self.failed)}" + ) + + def to_dataframe(self, spark: SparkSession): + """Convert results to a Spark DataFrame for notebook display.""" + rows = [(r.source, r.target, r.status, r.row_count, r.error) for r in self.results] + return spark.createDataFrame(rows, ["source", "target", 
"status", "row_count", "error"]) + + +def _validate_target_catalog(spark: SparkSession, target_catalog: str) -> None: + """Raise ValueError if target_catalog is not configured in the Spark session. + + Uses spark.conf.get() instead of SHOW CATALOGS because Spark lazily loads + catalogs — they won't appear in SHOW CATALOGS until first access. + """ + config_key = f"spark.sql.catalog.{target_catalog}" + try: + spark.conf.get(config_key) + except Exception: + raise ValueError( + f"Catalog '{target_catalog}' is not configured in the current Spark session " + f"(no {config_key} property found). " + f"Did you run Section 3 (Configure Admin Spark) and restart Spark Connect?" + ) + + +def table_exists_in_catalog(spark: SparkSession, catalog: str, namespace: str, table_name: str) -> bool: + """Check if a table already exists in the target Iceberg catalog. + + Uses SHOW TABLES instead of DESCRIBE TABLE to avoid JVM-level + TABLE_OR_VIEW_NOT_FOUND error logs when the table doesn't exist. + """ + try: + tables = spark.sql(f"SHOW TABLES IN {catalog}.{namespace}").collect() + return any(row["tableName"] == table_name for row in tables) + except Exception: + return False + + +def migrate_table( + spark: SparkSession, + hive_db: str, + table_name: str, + target_catalog: str, + target_ns: str, + tracker: MigrationTracker | None = None, + force: bool = False, +): + """ + Migrate a single Delta table to Iceberg via Polaris, preserving partitions. 
+ + Args: + spark: Active Spark session + hive_db: Original Hive/Delta database name + table_name: Original table name + target_catalog: Target Iceberg catalog name (e.g., 'my') + target_ns: Target namespace in Iceberg (e.g., 'test_db') + tracker: Optional MigrationTracker for progress tracking + force: If True, drop existing target table and re-migrate + """ + source_ref = f"{hive_db}.{table_name}" + target_table_ref = f"{target_catalog}.{target_ns}.{table_name}" + print(f" {source_ref} -> {target_table_ref}") + logger.info(f"Starting migration for {source_ref} -> {target_table_ref}") + + # 0. Idempotency: skip if target already exists (unless force=True) + if table_exists_in_catalog(spark, target_catalog, target_ns, table_name): + if force: + logger.info(f"Force mode: dropping existing {target_table_ref}") + spark.sql(f"DROP TABLE {target_table_ref} PURGE") + else: + logger.info(f"Skipping {target_table_ref} — already exists in target catalog") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="skipped")) + return + + # 1. Read from Delta using spark.table fallback + try: + df = spark.table(source_ref) + except Exception as e: + err_str = str(e) + if "DELTA_READ_TABLE_WITHOUT_COLUMNS" in err_str: + msg = "Delta table has no columns (empty schema) — skipping" + logger.warning(f"Skipping {source_ref} — {msg}") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="skipped", error=msg)) + return + short_err = _clean_error(e) + print(f" FAILED (read): {short_err}") + logger.error(f"Failed to read source table {source_ref}: {short_err}") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="failed", error=short_err)) + return + + # 1b. 
Skip tables with empty schema (corrupt Delta tables with no columns) + if len(df.columns) == 0: + msg = "Delta table has no columns (empty schema)" + logger.warning(f"Skipping {source_ref} — {msg}") + if tracker: + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="skipped", error=msg)) + return + + # 2. Extract partition columns from the original Delta table + partition_cols: list[str] = [] + try: + partition_cols = [row.name for row in spark.catalog.listColumns(f"{hive_db}.{table_name}") if row.isPartition] + if partition_cols: + logger.info(f"Found partition columns: {partition_cols}") + except Exception as e: + logger.warning(f"Could not fetch partitions for {source_ref} via catalog API: {e}") + + # 3. Create target namespace, write data, and validate + try: + spark.sql(f"CREATE NAMESPACE IF NOT EXISTS {target_catalog}.{target_ns}") + + # 4. Write as Iceberg (applying partition logic if it existed) + writer = df.writeTo(target_table_ref) + if partition_cols: + writer = writer.partitionedBy(*partition_cols) + + logger.info(f"Writing data to {target_table_ref}...") + writer.create() + logger.info(f"Write completed for {target_table_ref}") + + # 5. 
Validate row counts + original_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {source_ref}").collect()[0]["cnt"] + migrated_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {target_table_ref}").collect()[0]["cnt"] + + if original_count != migrated_count: + msg = f"Row count mismatch: {original_count} vs {migrated_count}" + logger.error(f"Validation FAILED: {msg}") + if tracker: + tracker.add( + TableResult( + source=source_ref, + target=target_table_ref, + status="failed", + row_count=migrated_count, + error=msg, + ) + ) + raise ValueError(msg) + + logger.info(f"Validation SUCCESS: {migrated_count} rows migrated exactly.") + if tracker: + tracker.add( + TableResult( + source=source_ref, + target=target_table_ref, + status="migrated", + row_count=migrated_count, + ) + ) + except Exception as e: + short_err = _clean_error(e) + print(f" FAILED: {short_err}") + logger.error(f"Failed to migrate {source_ref} -> {target_table_ref}: {short_err}") + if tracker and not any(r.target == target_table_ref for r in tracker.results): + tracker.add(TableResult(source=source_ref, target=target_table_ref, status="failed", error=short_err)) + + +def migrate_user( + spark: SparkSession, + username: str, + target_catalog: str = "my", + tracker: MigrationTracker | None = None, + force: bool = False, +): + """ + Migrate all of a user's Delta databases to their Iceberg catalog. 
def migrate_user(
    spark: SparkSession,
    username: str,
    target_catalog: str = "my",
    tracker: MigrationTracker | None = None,
    force: bool = False,
):
    """
    Migrate every Delta database owned by a user into their Iceberg catalog.

    User databases live in Hive under the naming scheme ``u_{username}__{dbname}``;
    the ``u_{username}__`` prefix is stripped to form the Iceberg namespace.

    Args:
        spark: Active SparkSession
        username: The user's username
        target_catalog: The target catalog (e.g., 'user_{username}')
        tracker: Optional MigrationTracker for progress tracking
        force: If True, drop existing target tables and re-migrate
    """
    _validate_target_catalog(spark, target_catalog)

    user_prefix = f"u_{username}__"
    all_rows = spark.sql("SHOW DATABASES").collect()
    user_dbs = [row[0] for row in all_rows if row[0].startswith(user_prefix)]

    if not user_dbs:
        logger.info(f"No databases found for username {username} with prefix {user_prefix}")
        return

    for source_db in user_dbs:
        # Strip the ownership prefix: "u_tgu2__test_db" -> "test_db"
        namespace = source_db.replace(user_prefix, "", 1)
        logger.info(f"Scanning database {source_db}...")
        try:
            table_rows = spark.sql(f"SHOW TABLES IN {source_db}").collect()
        except Exception as e:
            print(f"  Error listing tables in {source_db}: {_clean_error(e)}")
            continue

        # Migrate tables one at a time; a failure in one table must not stop the rest.
        for row in table_rows:
            try:
                migrate_table(spark, source_db, row["tableName"], target_catalog, namespace, tracker, force=force)
            except Exception as e:
                print(f"  FAILED (unexpected): {_clean_error(e)}")
def migrate_tenant(
    spark: SparkSession,
    tenant_name: str,
    target_catalog: str,
    tracker: MigrationTracker | None = None,
    force: bool = False,
):
    """
    Migrate all Delta databases for a tenant to their Iceberg catalog.

    Tenant databases follow the pattern: {tenant_name}_{dbname} in Hive.
    The {tenant_name}_ prefix is stripped to get the Iceberg namespace.

    Args:
        spark: Active SparkSession
        tenant_name: The tenant/group name (e.g., 'kbase')
        target_catalog: Target Iceberg catalog (e.g., 'tenant_kbase')
        tracker: Optional MigrationTracker for progress tracking
        force: If True, drop existing target tables and re-migrate
    """
    _validate_target_catalog(spark, target_catalog)

    tenant_prefix = f"{tenant_name}_"
    tenant_dbs = [
        row[0]
        for row in spark.sql("SHOW DATABASES").collect()
        if row[0].startswith(tenant_prefix)
    ]

    if not tenant_dbs:
        logger.info(f"No databases found for tenant {tenant_name} with prefix {tenant_prefix}")
        return

    for source_db in tenant_dbs:
        # Strip the tenant prefix to recover the plain database name.
        namespace = source_db.replace(tenant_prefix, "", 1)
        logger.info(f"Scanning tenant database {source_db}...")
        try:
            rows = spark.sql(f"SHOW TABLES IN {source_db}").collect()
        except Exception as e:
            print(f"  Error listing tables in {source_db}: {_clean_error(e)}")
            continue

        # Per-table isolation: one bad table must not abort the tenant migration.
        for row in rows:
            try:
                migrate_table(spark, source_db, row["tableName"], target_catalog, namespace, tracker, force=force)
            except Exception as e:
                print(f"  FAILED (unexpected): {_clean_error(e)}")


if __name__ == "__main__":
    # Standalone execution: wire up basic console logging so logger.* output is visible.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    print("Migration functions are loaded. Supply an active spark session to begin.")