diff --git a/src/lib.rs b/src/lib.rs index 9a0a9bb..c9396da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -236,8 +236,11 @@ impl Driver for DataFusionDriver { type DatabaseType = DataFusionDatabase; fn new_database(&mut self) -> Result { + let config = SessionConfig::new().with_information_schema(true); + let ctx = SessionContext::new_with_config(config); Ok(Self::DatabaseType { handle: self.handle.clone(), + ctx: Arc::new(ctx), }) } @@ -250,8 +253,11 @@ impl Driver for DataFusionDriver { ), >, ) -> adbc_core::error::Result { + let config = SessionConfig::new().with_information_schema(true); + let ctx = SessionContext::new_with_config(config); let mut database = Self::DatabaseType { handle: self.handle.clone(), + ctx: Arc::new(ctx), }; for (key, value) in opts { database.set_option(key, value)?; @@ -262,6 +268,7 @@ impl Driver for DataFusionDriver { pub struct DataFusionDatabase { handle: Option, + ctx: Arc, } impl Optionable for DataFusionDatabase { @@ -296,9 +303,6 @@ impl Database for DataFusionDatabase { type ConnectionType = DataFusionConnection; fn new_connection(&self) -> Result { - let config = SessionConfig::new().with_information_schema(true); - let ctx = SessionContext::new_with_config(config); - let runtime = Runtime::new(self.handle.clone()).map_err(|e| { ErrorHelper::io() .context("create Tokio runtime") @@ -308,7 +312,7 @@ impl Database for DataFusionDatabase { Ok(DataFusionConnection { runtime: Arc::new(runtime), - ctx: Arc::new(ctx), + ctx: self.ctx.clone(), }) } @@ -321,9 +325,6 @@ impl Database for DataFusionDatabase { ), >, ) -> adbc_core::error::Result { - let config = SessionConfig::new().with_information_schema(true); - let ctx = SessionContext::new_with_config(config); - let runtime = Runtime::new(self.handle.clone()).map_err(|e| { ErrorHelper::io() .context("create Tokio runtime") @@ -333,7 +334,7 @@ impl Database for DataFusionDatabase { let mut connection = DataFusionConnection { runtime: Arc::new(runtime), - ctx: Arc::new(ctx), + ctx: self.ctx.clone(), }; for (key, value) in opts { @@ -360,6 +361,12 @@ impl Optionable for DataFusionConnection { match key.as_ref() { constants::ADBC_CONNECTION_OPTION_CURRENT_CATALOG => match value { OptionValue::String(value) => { + if !self.ctx.catalog_names().contains(&value) { + return Err(ErrorHelper::not_found() + .context("set current catalog") + .format(format_args!("catalog '{value}' does not exist")) + .to_adbc()); + } self.runtime.block_on(async { let query = format!("SET datafusion.catalog.default_catalog = {value}"); self.ctx @@ -379,6 +386,22 @@ impl Optionable for DataFusionConnection { }, constants::ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA => match value { OptionValue::String(value) => { + let state = self.ctx.state(); + let catalog_name = &state.config_options().catalog.default_catalog; + let catalog = self.ctx.catalog(catalog_name).ok_or_else(|| { + ErrorHelper::not_found() + .context("set current schema") + .format(format_args!("catalog '{catalog_name}' does not exist")) + .to_adbc() + })?; + if !catalog.schema_names().contains(&value) { + return Err(ErrorHelper::not_found() + .context("set current schema") + .format(format_args!( + "schema '{value}' does not exist in catalog '{catalog_name}'" + )) + .to_adbc()); + } self.runtime.block_on(async { let query = format!("SET datafusion.catalog.default_schema = {value}"); self.ctx diff --git a/tests/test_datafusion.rs b/tests/test_datafusion.rs index f44a486..e3c489b 100644 --- a/tests/test_datafusion.rs +++ b/tests/test_datafusion.rs @@ -93,10 +93,18 @@ fn test_connection_options() { assert_eq!(current_catalog, "datafusion"); - let _ = connection.set_option( - OptionConnection::CurrentCatalog, - OptionValue::String("datafusion2".to_string()), - ); + // Create the secondary catalog and schema before switching + let mut stmt = connection.new_statement().unwrap(); + stmt.set_sql_query("CREATE DATABASE IF NOT EXISTS datafusion2") + .unwrap(); + stmt.execute_update().unwrap(); + + connection + .set_option( + OptionConnection::CurrentCatalog, + OptionValue::String("datafusion2".to_string()), + ) + .unwrap(); let current_catalog = connection .get_option_string(OptionConnection::CurrentCatalog) @@ -104,22 +112,54 @@ fn test_connection_options() { assert_eq!(current_catalog, "datafusion2"); + // Switch back and create a secondary schema + connection + .set_option( + OptionConnection::CurrentCatalog, + OptionValue::String("datafusion".to_string()), + ) + .unwrap(); + + let mut stmt = connection.new_statement().unwrap(); + stmt.set_sql_query("CREATE SCHEMA IF NOT EXISTS public2") + .unwrap(); + stmt.execute_update().unwrap(); + let current_schema = connection .get_option_string(OptionConnection::CurrentSchema) .unwrap(); assert_eq!(current_schema, "public"); - let _ = connection.set_option( - OptionConnection::CurrentSchema, - OptionValue::String("public2".to_string()), - ); + connection + .set_option( + OptionConnection::CurrentSchema, + OptionValue::String("public2".to_string()), + ) + .unwrap(); let current_schema = connection .get_option_string(OptionConnection::CurrentSchema) .unwrap(); assert_eq!(current_schema, "public2"); + + // Verify setting nonexistent catalog/schema returns an error + let err = connection + .set_option( + OptionConnection::CurrentCatalog, + OptionValue::String("nonexistent".to_string()), + ) + .unwrap_err(); + assert_eq!(err.status, adbc_core::error::Status::NotFound); + + let err = connection + .set_option( + OptionConnection::CurrentSchema, + OptionValue::String("nonexistent".to_string()), + ) + .unwrap_err(); + assert_eq!(err.status, adbc_core::error::Status::NotFound); } #[test] diff --git a/validation/tests/conftest.py b/validation/tests/conftest.py index 2a6a99f..815dace 100644 --- a/validation/tests/conftest.py +++ b/validation/tests/conftest.py @@ -13,14 +13,16 @@ # limitations under the License. import sys +import typing from pathlib import Path +import adbc_driver_manager +import adbc_driver_manager.dbapi import adbc_drivers_validation.model import adbc_drivers_validation.tests.conftest import pytest from adbc_drivers_validation.tests.conftest import ( # noqa: F401 conn, - conn_factory, db_kwargs, manual_test, pytest_collection_modifyitems, @@ -51,3 +53,36 @@ def driver_path(driver: adbc_drivers_validation.model.DriverQuirks) -> str: Path(__file__).parent.parent.parent / f"build/libadbc_driver_{driver.name}.{ext}" ) + + +@pytest.fixture(scope="session") +def conn_factory( + driver_path: str, + db_kwargs: dict[str, typing.Any], # noqa:F811 +) -> typing.Callable[[], adbc_driver_manager.dbapi.Connection]: + kwargs = db_kwargs.copy() + kwargs["driver"] = driver_path + db = adbc_driver_manager.AdbcDatabase(**kwargs) + shared_db = adbc_driver_manager.dbapi._SharedDatabase(db) + + def _factory() -> adbc_driver_manager.dbapi.Connection: + adbc_conn = adbc_driver_manager.AdbcConnection(db) + return adbc_driver_manager.dbapi.Connection( + shared_db, adbc_conn, autocommit=True + ) + + return _factory + + +@pytest.fixture(scope="session", autouse=True) +def _setup_resources( + conn_factory: typing.Callable[[], adbc_driver_manager.dbapi.Connection], +) -> None: + with conn_factory() as c: + with c.cursor() as cursor: + for statement in [ + "CREATE SCHEMA IF NOT EXISTS secondary", + "CREATE DATABASE IF NOT EXISTS secondary_catalog", + "CREATE SCHEMA IF NOT EXISTS secondary_catalog.secondary_schema", + ]: + cursor.execute(statement) diff --git a/validation/tests/datafusion.py b/validation/tests/datafusion.py index bf6c170..9dc4f2e 100644 --- a/validation/tests/datafusion.py +++ b/validation/tests/datafusion.py @@ -30,9 +30,16 @@ class DataFusionQuirks(model.DriverQuirks): statement_prepare=True, current_catalog="datafusion", current_schema="public", + secondary_schema="secondary", + secondary_catalog="secondary_catalog", + secondary_catalog_schema="secondary_schema", connection_get_table_schema=True, + connection_set_current_catalog=True, + connection_set_current_schema=True, get_objects=True, statement_bulk_ingest=True, + statement_bulk_ingest_schema=True, + statement_bulk_ingest_catalog=True, statement_rows_affected=True, statement_execute_schema=True, )