Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,11 @@ impl Driver for DataFusionDriver {
type DatabaseType = DataFusionDatabase;

fn new_database(&mut self) -> Result<Self::DatabaseType> {
let config = SessionConfig::new().with_information_schema(true);
let ctx = SessionContext::new_with_config(config);
Ok(Self::DatabaseType {
handle: self.handle.clone(),
ctx: Arc::new(ctx),
})
}

Expand All @@ -250,8 +253,11 @@ impl Driver for DataFusionDriver {
),
>,
) -> adbc_core::error::Result<Self::DatabaseType> {
let config = SessionConfig::new().with_information_schema(true);
let ctx = SessionContext::new_with_config(config);
let mut database = Self::DatabaseType {
handle: self.handle.clone(),
ctx: Arc::new(ctx),
};
for (key, value) in opts {
database.set_option(key, value)?;
Expand All @@ -262,6 +268,7 @@ impl Driver for DataFusionDriver {

pub struct DataFusionDatabase {
handle: Option<tokio::runtime::Handle>,
ctx: Arc<SessionContext>,
}

impl Optionable for DataFusionDatabase {
Expand Down Expand Up @@ -296,9 +303,6 @@ impl Database for DataFusionDatabase {
type ConnectionType = DataFusionConnection;

fn new_connection(&self) -> Result<Self::ConnectionType> {
let config = SessionConfig::new().with_information_schema(true);
let ctx = SessionContext::new_with_config(config);

let runtime = Runtime::new(self.handle.clone()).map_err(|e| {
ErrorHelper::io()
.context("create Tokio runtime")
Expand All @@ -308,7 +312,7 @@ impl Database for DataFusionDatabase {

Ok(DataFusionConnection {
runtime: Arc::new(runtime),
ctx: Arc::new(ctx),
ctx: self.ctx.clone(),
})
}

Expand All @@ -321,9 +325,6 @@ impl Database for DataFusionDatabase {
),
>,
) -> adbc_core::error::Result<Self::ConnectionType> {
let config = SessionConfig::new().with_information_schema(true);
let ctx = SessionContext::new_with_config(config);

let runtime = Runtime::new(self.handle.clone()).map_err(|e| {
ErrorHelper::io()
.context("create Tokio runtime")
Expand All @@ -333,7 +334,7 @@ impl Database for DataFusionDatabase {

let mut connection = DataFusionConnection {
runtime: Arc::new(runtime),
ctx: Arc::new(ctx),
ctx: self.ctx.clone(),
};

for (key, value) in opts {
Expand All @@ -360,6 +361,12 @@ impl Optionable for DataFusionConnection {
match key.as_ref() {
constants::ADBC_CONNECTION_OPTION_CURRENT_CATALOG => match value {
OptionValue::String(value) => {
if !self.ctx.catalog_names().contains(&value) {
return Err(ErrorHelper::not_found()
.context("set current catalog")
.format(format_args!("catalog '{value}' does not exist"))
.to_adbc());
}
self.runtime.block_on(async {
let query = format!("SET datafusion.catalog.default_catalog = {value}");
self.ctx
Expand All @@ -379,6 +386,22 @@ impl Optionable for DataFusionConnection {
},
constants::ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA => match value {
OptionValue::String(value) => {
let state = self.ctx.state();
let catalog_name = &state.config_options().catalog.default_catalog;
let catalog = self.ctx.catalog(catalog_name).ok_or_else(|| {
ErrorHelper::not_found()
.context("set current schema")
.format(format_args!("catalog '{catalog_name}' does not exist"))
.to_adbc()
})?;
if !catalog.schema_names().contains(&value) {
return Err(ErrorHelper::not_found()
.context("set current schema")
.format(format_args!(
"schema '{value}' does not exist in catalog '{catalog_name}'"
))
.to_adbc());
}
self.runtime.block_on(async {
let query = format!("SET datafusion.catalog.default_schema = {value}");
self.ctx
Expand Down
56 changes: 48 additions & 8 deletions tests/test_datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,33 +93,73 @@ fn test_connection_options() {

assert_eq!(current_catalog, "datafusion");

let _ = connection.set_option(
OptionConnection::CurrentCatalog,
OptionValue::String("datafusion2".to_string()),
);
// Create the secondary catalog and schema before switching
let mut stmt = connection.new_statement().unwrap();
stmt.set_sql_query("CREATE DATABASE IF NOT EXISTS datafusion2")
.unwrap();
stmt.execute_update().unwrap();

connection
.set_option(
OptionConnection::CurrentCatalog,
OptionValue::String("datafusion2".to_string()),
)
.unwrap();

let current_catalog = connection
.get_option_string(OptionConnection::CurrentCatalog)
.unwrap();

assert_eq!(current_catalog, "datafusion2");

// Switch back and create a secondary schema
connection
.set_option(
OptionConnection::CurrentCatalog,
OptionValue::String("datafusion".to_string()),
)
.unwrap();

let mut stmt = connection.new_statement().unwrap();
stmt.set_sql_query("CREATE SCHEMA IF NOT EXISTS public2")
.unwrap();
stmt.execute_update().unwrap();

let current_schema = connection
.get_option_string(OptionConnection::CurrentSchema)
.unwrap();

assert_eq!(current_schema, "public");

let _ = connection.set_option(
OptionConnection::CurrentSchema,
OptionValue::String("public2".to_string()),
);
connection
.set_option(
OptionConnection::CurrentSchema,
OptionValue::String("public2".to_string()),
)
.unwrap();

let current_schema = connection
.get_option_string(OptionConnection::CurrentSchema)
.unwrap();

assert_eq!(current_schema, "public2");

// Verify setting nonexistent catalog/schema returns an error
let err = connection
.set_option(
OptionConnection::CurrentCatalog,
OptionValue::String("nonexistent".to_string()),
)
.unwrap_err();
assert_eq!(err.status, adbc_core::error::Status::NotFound);

let err = connection
.set_option(
OptionConnection::CurrentSchema,
OptionValue::String("nonexistent".to_string()),
)
.unwrap_err();
assert_eq!(err.status, adbc_core::error::Status::NotFound);
}

#[test]
Expand Down
37 changes: 36 additions & 1 deletion validation/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
# limitations under the License.

import sys
import typing
from pathlib import Path

import adbc_driver_manager
import adbc_driver_manager.dbapi
import adbc_drivers_validation.model
import adbc_drivers_validation.tests.conftest
import pytest
from adbc_drivers_validation.tests.conftest import ( # noqa: F401
conn,
conn_factory,
db_kwargs,
manual_test,
pytest_collection_modifyitems,
Expand Down Expand Up @@ -51,3 +53,36 @@ def driver_path(driver: adbc_drivers_validation.model.DriverQuirks) -> str:
Path(__file__).parent.parent.parent
/ f"build/libadbc_driver_{driver.name}.{ext}"
)


@pytest.fixture(scope="session")
def conn_factory(
driver_path: str,
db_kwargs: dict[str, typing.Any], # noqa:F811
) -> typing.Callable[[], adbc_driver_manager.dbapi.Connection]:
kwargs = db_kwargs.copy()
kwargs["driver"] = driver_path
db = adbc_driver_manager.AdbcDatabase(**kwargs)
shared_db = adbc_driver_manager.dbapi._SharedDatabase(db)

def _factory() -> adbc_driver_manager.dbapi.Connection:
adbc_conn = adbc_driver_manager.AdbcConnection(db)
return adbc_driver_manager.dbapi.Connection(
shared_db, adbc_conn, autocommit=True
)

return _factory


@pytest.fixture(scope="session", autouse=True)
def _setup_resources(
conn_factory: typing.Callable[[], adbc_driver_manager.dbapi.Connection],
) -> None:
with conn_factory() as c:
with c.cursor() as cursor:
for statement in [
"CREATE SCHEMA IF NOT EXISTS secondary",
"CREATE DATABASE IF NOT EXISTS secondary_catalog",
"CREATE SCHEMA IF NOT EXISTS secondary_catalog.secondary_schema",
]:
cursor.execute(statement)
7 changes: 7 additions & 0 deletions validation/tests/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,16 @@ class DataFusionQuirks(model.DriverQuirks):
statement_prepare=True,
current_catalog="datafusion",
current_schema="public",
secondary_schema="secondary",
secondary_catalog="secondary_catalog",
secondary_catalog_schema="secondary_schema",
connection_get_table_schema=True,
connection_set_current_catalog=True,
connection_set_current_schema=True,
get_objects=True,
statement_bulk_ingest=True,
statement_bulk_ingest_schema=True,
statement_bulk_ingest_catalog=True,
statement_rows_affected=True,
statement_execute_schema=True,
)
Expand Down
Loading