183 changes: 168 additions & 15 deletions crates/integrations/datafusion/src/catalog.rs
@@ -50,30 +50,21 @@ impl IcebergCatalogProvider {
// TODO:
// Schemas and providers should be cached and evicted based on time
// As of right now, schemas might become stale.
- let schema_names: Vec<_> = client
- .list_namespaces(None)
- .await?
- .iter()
- .flat_map(|ns| ns.as_ref().clone())
- .collect();
+ let namespace_idents = fetch_all_namespaces(client.as_ref()).await?;

let providers = try_join_all(
- schema_names
+ namespace_idents
.iter()
- .map(|name| {
- IcebergSchemaProvider::try_new(
- client.clone(),
- NamespaceIdent::new(name.clone()),
- )
- })
+ .map(|nsi| IcebergSchemaProvider::try_new(client.clone(), nsi.clone()))
.collect::<Vec<_>>(),
)
.await?;

- let schemas: HashMap<String, Arc<dyn SchemaProvider>> = schema_names
+ let schemas: HashMap<String, Arc<dyn SchemaProvider>> = namespace_idents
.into_iter()
.zip(providers.into_iter())
- .map(|(name, provider)| {
+ .map(|(nsi, provider)| {
+ let name = nsi.as_ref().join(".");
let provider = Arc::new(provider) as Arc<dyn SchemaProvider>;
(name, provider)
})
@@ -83,6 +74,22 @@ impl IcebergCatalogProvider
}
}

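/// Collects every namespace in the catalog, including nested children, via a breadth-first walk from the root.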
async fn fetch_all_namespaces(client: &dyn Catalog) -> Result<Vec<NamespaceIdent>> {
let mut all_namespaces = Vec::new();
let mut queue = std::collections::VecDeque::new();
queue.push_back(None);

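// Breadth-first traversal: `None` stands for the root; each namespace found is enqueued so its own children get listed as well.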
while let Some(parent) = queue.pop_front() {
let children = client.list_namespaces(parent.as_ref()).await?;
for child in children {
all_namespaces.push(child.clone());
queue.push_back(Some(child));
}
}

Ok(all_namespaces)
}

impl CatalogProvider for IcebergCatalogProvider {
fn as_any(&self) -> &dyn Any {
self
@@ -96,3 +103,149 @@ impl CatalogProvider for IcebergCatalogProvider {
self.schemas.get(name).cloned()
}
}

#[cfg(test)]
mod tests {
use iceberg::memory::{MEMORY_CATALOG_WAREHOUSE, MemoryCatalogBuilder};
use iceberg::spec::{NestedField, PrimitiveType, Schema, Type};
use iceberg::{CatalogBuilder, NamespaceIdent, TableCreation};

use super::*;

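// Helper: builds an in-memory catalog (warehouse `memory://`) shared by the tests below.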
async fn create_catalog() -> Arc<dyn Catalog> {
let mut props = HashMap::new();
props.insert(
MEMORY_CATALOG_WAREHOUSE.to_string(),
"memory://".to_string(),
);

let catalog = MemoryCatalogBuilder::default()
.load("test", props)
.await
.unwrap();

Arc::new(catalog)
}

#[tokio::test]
async fn test_iceberg_catalog_provider_empty() {
let catalog = create_catalog().await;

let provider = IcebergCatalogProvider::try_new(catalog).await.unwrap();

assert!(provider.schema_names().is_empty());
}

#[tokio::test]
async fn test_iceberg_catalog_provider_single_namespace() {
let catalog = create_catalog().await;
let ns = NamespaceIdent::new("a".to_string());
catalog.create_namespace(&ns, HashMap::new()).await.unwrap();

let provider = IcebergCatalogProvider::try_new(catalog).await.unwrap();
let schema_names = provider.schema_names();

assert_eq!(schema_names.len(), 1);
assert!(schema_names.contains(&"a".to_string()));
assert!(provider.schema("a").is_some());
}

#[tokio::test]
async fn test_iceberg_catalog_provider_with_table() {
let catalog = create_catalog().await;
let ns = NamespaceIdent::new("a".to_string());
catalog.create_namespace(&ns, HashMap::new()).await.unwrap();

let schema = Schema::builder()
.with_fields(vec![
NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(),
])
.build()
.unwrap();

let table_creation = TableCreation::builder()
.name("t".to_string())
.schema(schema)
.build();

catalog.create_table(&ns, table_creation).await.unwrap();

let provider = IcebergCatalogProvider::try_new(catalog).await.unwrap();

let schema_provider = provider.schema("a").unwrap();
let table_names = schema_provider.table_names();

assert!(table_names.contains(&"t".to_string()));
assert!(schema_provider.table("t").await.unwrap().is_some());
}

#[tokio::test]
async fn test_iceberg_catalog_provider_nested_namespaces() {
let catalog = create_catalog().await;
let ns1 = NamespaceIdent::new("a".to_string());
let ns2 = NamespaceIdent::from_vec(vec!["a".to_string(), "b".to_string()]).unwrap();
catalog
.create_namespace(&ns1, HashMap::new())
.await
.unwrap();
catalog
.create_namespace(&ns2, HashMap::new())
.await
.unwrap();

let provider = IcebergCatalogProvider::try_new(catalog).await.unwrap();
let schema_names = provider.schema_names();

// fetch_all_namespaces lists recursively, so the nested namespace "a.b" is discovered too.
assert!(schema_names.contains(&"a".to_string()));
assert!(schema_names.contains(&"a.b".to_string()));
assert_eq!(schema_names.len(), 2);

assert!(provider.schema("a").is_some());
assert!(provider.schema("a.b").is_some());
}

#[tokio::test]
async fn test_fetch_all_namespaces_empty() {
let catalog = create_catalog().await;
let namespaces = fetch_all_namespaces(catalog.as_ref()).await.unwrap();
assert!(namespaces.is_empty());
}

#[tokio::test]
async fn test_fetch_all_namespaces_one() {
let catalog = create_catalog().await;
let ns = NamespaceIdent::new("a".to_string());
catalog.create_namespace(&ns, HashMap::new()).await.unwrap();

let namespaces = fetch_all_namespaces(catalog.as_ref()).await.unwrap();
assert_eq!(namespaces.len(), 1);
assert!(namespaces.contains(&ns));
}

#[tokio::test]
async fn test_fetch_all_namespaces_nested() {
let catalog = create_catalog().await;
let ns1 = NamespaceIdent::new("a".to_string());
let ns2 = NamespaceIdent::from_vec(vec!["a".to_string(), "b".to_string()]).unwrap();
let ns3 = NamespaceIdent::from_vec(vec!["a".to_string(), "b".to_string(), "c".to_string()])
.unwrap();
catalog
.create_namespace(&ns1, HashMap::new())
.await
.unwrap();
catalog
.create_namespace(&ns2, HashMap::new())
.await
.unwrap();
catalog
.create_namespace(&ns3, HashMap::new())
.await
.unwrap();

let namespaces = fetch_all_namespaces(catalog.as_ref()).await.unwrap();
assert!(namespaces.contains(&ns1));
assert!(namespaces.contains(&ns2));
assert!(namespaces.contains(&ns3));
}
}
@@ -598,8 +598,8 @@ async fn test_insert_into_nested() -> Result<()> {
// Insert data with nested structs
let insert_sql = r#"
INSERT INTO catalog.test_insert_nested.nested_table
- SELECT
- 1 as id,
+ SELECT
+ 1 as id,
'Alice' as name,
named_struct(
'address', named_struct(
@@ -613,8 +613,8 @@
)
) as profile
UNION ALL
- SELECT
- 2 as id,
+ SELECT
+ 2 as id,
'Bob' as name,
named_struct(
'address', named_struct(
@@ -736,15 +736,15 @@ async fn test_insert_into_nested() -> Result<()> {
let df = ctx
.sql(
r#"
- SELECT
- id,
+ SELECT
+ id,
name,
profile.address.street,
profile.address.city,
profile.address.zip,
profile.contact.email,
profile.contact.phone
- FROM catalog.test_insert_nested.nested_table
+ FROM catalog.test_insert_nested.nested_table
ORDER BY id
"#,
)
@@ -850,8 +850,8 @@ async fn test_insert_into_partitioned() -> Result<()> {
let df = ctx
.sql(
r#"
- INSERT INTO catalog.test_partitioned_write.partitioned_table
- VALUES
+ INSERT INTO catalog.test_partitioned_write.partitioned_table
+ VALUES
(1, 'electronics', 'laptop'),
(2, 'electronics', 'phone'),
(3, 'books', 'novel'),
@@ -943,3 +943,64 @@ async fn test_insert_into_partitioned() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_child_namespace_crud() -> Result<()> {
let iceberg_catalog = get_iceberg_catalog().await;

let parent_ns = NamespaceIdent::new("parent_ns".to_string());
set_test_namespace(&iceberg_catalog, &parent_ns).await?;
let child_ns = NamespaceIdent::from_vec(vec!["parent_ns".to_string(), "child_ns".to_string()])?;
set_test_namespace(&iceberg_catalog, &child_ns).await?;

let creation = get_table_creation(temp_path(), "t", None)?;
iceberg_catalog.create_table(&child_ns, creation).await?;

let client = Arc::new(iceberg_catalog);
let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?);

let ctx = SessionContext::new();
ctx.register_catalog("catalog", catalog);

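// Quoting the dotted schema name makes DataFusion treat "parent_ns.child_ns" as one schema identifier instead of two path segments.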
let df = ctx
.sql("INSERT INTO catalog.\"parent_ns.child_ns\".t VALUES (1, 'test')")
.await
.unwrap();

let batches = df.collect().await.unwrap();
assert_eq!(batches.len(), 1);
let batch = &batches[0];
let rows_inserted = batch
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();
assert_eq!(rows_inserted.value(0), 1);

let df = ctx
.sql("SELECT * FROM catalog.\"parent_ns.child_ns\".t")
.await
.unwrap();

let batches = df.collect().await.unwrap();

check_record_batches(
batches,
expect![[r#"
Field { "foo1": Int32, metadata: {"PARQUET:field_id": "1"} },
Field { "foo2": Utf8, metadata: {"PARQUET:field_id": "2"} }"#]],
expect![[r#"
foo1: PrimitiveArray<Int32>
[
1,
],
foo2: StringArray
[
"test",
]"#]],
&[],
Some("foo1"),
);

Ok(())
}