Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 183 additions & 1 deletion datafusion/catalog/src/memory/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! [`MemorySchemaProvider`]: In-memory implementations of [`SchemaProvider`].

use crate::{SchemaProvider, TableProvider};
use crate::{SchemaProvider, TableFunction, TableProvider};
use async_trait::async_trait;
use dashmap::DashMap;
use datafusion_common::{exec_err, DataFusionError};
Expand All @@ -28,13 +28,15 @@ use std::sync::Arc;
#[derive(Debug)]
pub struct MemorySchemaProvider {
/// Registered tables, keyed by table name.
tables: DashMap<String, Arc<dyn TableProvider>>,
/// Registered table functions (UDTFs), keyed by function name.
table_functions: DashMap<String, Arc<TableFunction>>,
}

impl MemorySchemaProvider {
    /// Creates a `MemorySchemaProvider` that starts out with no tables and no
    /// table functions registered.
    pub fn new() -> Self {
        let tables = DashMap::new();
        let table_functions = DashMap::new();
        Self {
            tables,
            table_functions,
        }
    }
}
Expand Down Expand Up @@ -86,4 +88,184 @@ impl SchemaProvider for MemorySchemaProvider {
/// Returns `true` when a table is registered under `name`, `false` otherwise.
fn table_exist(&self, name: &str) -> bool {
self.tables.contains_key(name)
}

/// Returns the names of all table functions currently registered in this schema.
///
/// The names are produced in the map's iteration order.
fn udtf_names(&self) -> Vec<String> {
    let registered = self.table_functions.iter();
    registered.map(|entry| entry.key().to_owned()).collect()
}

/// Looks up the table function registered under `name`.
///
/// Returns `Ok(Some(..))` with a shared handle to the function when it is
/// registered and `Ok(None)` when it is not.
fn udtf(
    &self,
    name: &str,
) -> datafusion_common::Result<Option<Arc<TableFunction>>, DataFusionError> {
    let found = self
        .table_functions
        .get(name)
        .map(|entry| Arc::clone(entry.value()));
    Ok(found)
}

fn register_udtf(
&self,
name: String,
function: Arc<TableFunction>,
) -> datafusion_common::Result<Option<Arc<TableFunction>>> {
if self.udtf_exist(name.as_str()) {
return exec_err!("The table function {name} already exists");
}
Ok(self.table_functions.insert(name, function))
}

/// Removes the table function registered under `name`.
///
/// Returns the removed function, or `Ok(None)` when nothing was registered
/// under that name.
fn deregister_udtf(
    &self,
    name: &str,
) -> datafusion_common::Result<Option<Arc<TableFunction>>> {
    let removed = self.table_functions.remove(name);
    Ok(removed.map(|(_name, function)| function))
}

/// Returns `true` when a table function is registered under `name`, `false` otherwise.
fn udtf_exist(&self, name: &str) -> bool {
self.table_functions.contains_key(name)
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::table::TableFunctionImpl;
    use crate::Session;
    use arrow::datatypes::Schema;
    use datafusion_common::Result;
    use datafusion_expr::{Expr, TableType};
    use datafusion_physical_plan::ExecutionPlan;

    /// Table function stub whose `call` always yields a [`DummyTable`] with an
    /// empty schema.
    #[derive(Debug)]
    struct DummyTableFunc;

    /// Minimal table provider returned by [`DummyTableFunc`]; it cannot be scanned.
    #[derive(Debug)]
    struct DummyTable {
        schema: arrow::datatypes::SchemaRef,
    }

    #[async_trait::async_trait]
    impl TableProvider for DummyTable {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> arrow::datatypes::SchemaRef {
            Arc::clone(&self.schema)
        }

        fn table_type(&self) -> TableType {
            TableType::Base
        }

        async fn scan(
            &self,
            _state: &dyn Session,
            _projection: Option<&Vec<usize>>,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            datafusion_common::plan_err!("DummyTable does not support scanning")
        }
    }

    impl TableFunctionImpl for DummyTableFunc {
        fn call(&self, _args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
            Ok(Arc::new(DummyTable {
                schema: Arc::new(Schema::empty()),
            }))
        }
    }

    /// Builds a [`TableFunction`] named `name` backed by [`DummyTableFunc`].
    fn dummy_udtf(name: &str) -> Arc<TableFunction> {
        Arc::new(TableFunction::new(
            name.to_string(),
            Arc::new(DummyTableFunc),
        ))
    }

    #[test]
    fn test_register_and_retrieve_udtf() {
        let schema = MemorySchemaProvider::new();
        let func = dummy_udtf("my_func");

        let previous = schema
            .register_udtf("my_func".to_string(), Arc::clone(&func))
            .unwrap();
        assert!(previous.is_none());

        assert!(schema.udtf_exist("my_func"));
        assert_eq!(schema.udtf_names(), vec!["my_func"]);

        let retrieved = schema
            .udtf("my_func")
            .unwrap()
            .expect("a registered table function should be retrievable");
        assert_eq!(retrieved.name(), "my_func");
        // The lookup must hand back the very function that was registered.
        assert!(Arc::ptr_eq(&retrieved, &func));
    }

    #[test]
    fn test_duplicate_udtf_registration_fails() {
        let schema = MemorySchemaProvider::new();
        let original = dummy_udtf("my_func");

        schema
            .register_udtf("my_func".to_string(), Arc::clone(&original))
            .unwrap();

        let err = schema
            .register_udtf("my_func".to_string(), dummy_udtf("my_func"))
            .unwrap_err();
        assert!(err.to_string().contains("already exists"));

        // The rejected registration must not replace the original function.
        let kept = schema
            .udtf("my_func")
            .unwrap()
            .expect("the original registration should survive");
        assert!(Arc::ptr_eq(&kept, &original));
        assert_eq!(schema.udtf_names(), vec!["my_func"]);
    }

    #[test]
    fn test_deregister_udtf() {
        let schema = MemorySchemaProvider::new();
        let func = dummy_udtf("my_func");

        schema
            .register_udtf("my_func".to_string(), Arc::clone(&func))
            .unwrap();
        assert!(schema.udtf_exist("my_func"));

        let removed = schema
            .deregister_udtf("my_func")
            .unwrap()
            .expect("deregistering a registered function should return it");
        assert!(Arc::ptr_eq(&removed, &func));
        assert!(!schema.udtf_exist("my_func"));
        assert!(schema.udtf_names().is_empty());

        // Removing the same name again finds nothing and is not an error.
        let removed_again = schema.deregister_udtf("my_func").unwrap();
        assert!(removed_again.is_none());
    }

    #[test]
    fn test_udtf_not_found() {
        let schema = MemorySchemaProvider::new();

        assert!(!schema.udtf_exist("nonexistent"));
        assert!(schema.udtf("nonexistent").unwrap().is_none());
        assert!(schema.udtf_names().is_empty());
    }

    #[test]
    fn test_multiple_udtfs() {
        let schema = MemorySchemaProvider::new();

        schema
            .register_udtf("func1".to_string(), dummy_udtf("func1"))
            .unwrap();
        schema
            .register_udtf("func2".to_string(), dummy_udtf("func2"))
            .unwrap();

        // Iteration order of the underlying map is unspecified, so sort first.
        let mut names = schema.udtf_names();
        names.sort();
        assert_eq!(names, vec!["func1", "func2"]);

        assert!(schema.udtf_exist("func1"));
        assert!(schema.udtf_exist("func2"));
    }
}
39 changes: 38 additions & 1 deletion datafusion/catalog/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

use crate::table::TableProvider;
use crate::table::{TableFunction, TableProvider};
use datafusion_common::Result;
use datafusion_expr::TableType;

Expand Down Expand Up @@ -88,4 +88,41 @@ pub trait SchemaProvider: Debug + Sync + Send {

/// Returns true if table exist in the schema provider, false otherwise.
fn table_exist(&self, name: &str) -> bool;

/// Lists the names of the table functions available in this schema.
///
/// The default implementation reports no table functions.
fn udtf_names(&self) -> Vec<String> {
    Vec::new()
}

/// Retrieves a specific table function from the schema by name, if it exists,
/// otherwise returns `None`.
///
/// The default implementation holds no table functions and always returns
/// `Ok(None)`.
fn udtf(&self, _name: &str) -> Result<Option<Arc<TableFunction>>> {
Ok(None)
}

/// If supported by the implementation, adds a new table function named `name` to
/// this schema.
///
/// Implementations that support registration are expected to return an error
/// when a table function of the same name is already registered.
///
/// The default implementation does not support registration and always
/// returns an error (via `exec_err!`).
fn register_udtf(
&self,
_name: String,
_function: Arc<TableFunction>,
) -> Result<Option<Arc<TableFunction>>> {
exec_err!("schema provider does not support registering table functions")
}

/// If supported by the implementation, removes the `name` table function from this
/// schema and returns the previously registered [`TableFunction`], if any.
///
/// If no `name` table function exists, returns Ok(None).
///
/// The default implementation does not support deregistration and always
/// returns an error (via `exec_err!`).
fn deregister_udtf(&self, _name: &str) -> Result<Option<Arc<TableFunction>>> {
exec_err!("schema provider does not support deregistering table functions")
}

/// Returns true if table function exists in the schema provider, false otherwise.
///
/// The default implementation holds no table functions and always returns `false`.
fn udtf_exist(&self, _name: &str) -> bool {
false
}
}
27 changes: 20 additions & 7 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1699,21 +1699,34 @@ impl ContextProvider for SessionContextProvider<'_> {
name: &str,
args: Vec<Expr>,
) -> datafusion_common::Result<Arc<dyn TableSource>> {
let tbl_func = self
.state
.table_functions
.get(name)
.cloned()
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
let table_ref = TableReference::parse_str(name);

let dummy_schema = DFSchema::empty();
let simplifier =
ExprSimplifier::new(SessionSimplifyProvider::new(self.state, &dummy_schema));
let args = args
.into_iter()
.map(|arg| simplifier.simplify(arg))
.collect::<datafusion_common::Result<Vec<_>>>()?;
let provider = tbl_func.create_table_provider(&args)?;

let tbl_func = if table_ref.schema().is_some() {
let func_name = table_ref.table().to_string();
let schema = self.state.schema_for_ref(table_ref)?;

schema.udtf(&func_name)?.ok_or_else(|| {
plan_datafusion_err!("Table function '{}' not found in schema", name)
})?
} else {
self.state
.table_functions
.get(name)
.cloned()
.ok_or_else(|| {
plan_datafusion_err!("table function '{name}' not found")
})?
};

let provider = tbl_func.create_table_provider(&args)?;
Ok(provider_as_source(provider))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use datafusion::execution::TaskContext;
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::prelude::SessionContext;
use datafusion_catalog::Session;
use datafusion_catalog::TableFunctionImpl;
use datafusion_catalog::{SchemaProvider, TableFunctionImpl};
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};

Expand Down Expand Up @@ -109,6 +109,76 @@ async fn test_deregister_udtf() -> Result<()> {
Ok(())
}

/// Test that a table function registered on a schema provider can be called
/// through a schema-qualified name.
#[tokio::test]
async fn test_schema_qualified_udtf() -> Result<()> {
    let ctx = SessionContext::new();

    // Reach the concrete in-memory provider behind the `public` schema of the
    // `datafusion` catalog so the function can be registered on it directly.
    let public_schema = ctx
        .catalog("datafusion")
        .unwrap()
        .schema("public")
        .unwrap();
    let memory_schema = public_schema
        .as_any()
        .downcast_ref::<datafusion_catalog::MemorySchemaProvider>()
        .unwrap();

    let table_function = datafusion_catalog::TableFunction::new(
        "schema_func".to_string(),
        Arc::new(SimpleCsvTableFunc {}),
    );
    memory_schema
        .register_udtf("schema_func".to_string(), Arc::new(table_function))
        .unwrap();

    let csv_file = "tests/tpch-csv/nation.csv";
    let query = format!("SELECT * FROM public.schema_func('{csv_file}', 3);");
    let rbs = ctx.sql(&query).await?.collect().await?;

    assert_eq!(rbs[0].num_rows(), 3);

    Ok(())
}

/// Test that a table function registered on the session context is still
/// resolved when called by its unqualified name (backward compatibility)
#[tokio::test]
async fn test_unqualified_uses_global_registry() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_udtf("global_func", Arc::new(SimpleCsvTableFunc {}));

    let csv_file = "tests/tpch-csv/nation.csv";
    let query = format!("SELECT * FROM global_func('{csv_file}', 2);");
    let rbs = ctx.sql(&query).await?.collect().await?;

    assert_eq!(rbs[0].num_rows(), 2);

    Ok(())
}

/// Test that a schema-qualified call does not fall back to a table function
/// that was registered only on the session context.
#[tokio::test]
async fn test_schema_qualified_not_in_global() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_udtf("global_only", Arc::new(SimpleCsvTableFunc {}));

    let csv_file = "tests/tpch-csv/nation.csv";
    let query = format!("SELECT * FROM public.global_only('{csv_file}');");
    let result = ctx.sql(&query).await;

    assert!(result.is_err());
    let message = result.unwrap_err().to_string();
    assert!(message.contains("not found in schema"));

    Ok(())
}

#[derive(Debug)]
struct SimpleCsvTable {
schema: SchemaRef,
Expand Down
Loading