diff --git a/src/lib.rs b/src/lib.rs index c9396da..94a9992 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,7 @@ mod get_objects; use adbc_core::constants; -use datafusion::common::TableReference; +use datafusion::common::{ScalarValue, TableReference}; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::MemTable; use datafusion::logical_expr::LogicalPlan; @@ -487,8 +487,7 @@ impl Connection for DataFusionConnection { runtime: self.runtime.clone(), ctx: self.ctx.clone(), query: None, - bound_batches: None, - bound_schema: None, + bound: None, ingest: BulkIngestState::new(), }) } @@ -600,8 +599,7 @@ pub struct DataFusionStatement { runtime: Arc, ctx: Arc, query: Option, - bound_batches: Option>, - bound_schema: Option, + bound: Option>, ingest: BulkIngestState, } @@ -656,7 +654,133 @@ impl Optionable for DataFusionStatement { } } +fn row_to_scalar_values( + batch: &RecordBatch, + row_index: usize, +) -> adbc_core::error::Result> { + let mut values = Vec::with_capacity(batch.num_columns()); + for col_index in 0..batch.num_columns() { + let array = batch.column(col_index); + let scalar = + ScalarValue::try_from_array(array, row_index).map_err(ErrorHelper::from_datafusion)?; + values.push(scalar); + } + Ok(values) +} + impl DataFusionStatement { + fn execute_with_params( + &mut self, + reader: Box, + ) -> Result> { + self.runtime.block_on(async { + let mut all_results: Vec = Vec::new(); + let mut result_schema: Option = None; + + for batch in reader { + let batch = batch.map_err(ErrorHelper::from_arrow)?; + for row_idx in 0..batch.num_rows() { + let params = row_to_scalar_values(&batch, row_idx)?; + + let df = match &self.query { + Some(QueryState::Sql(query)) => { + let df = self + .ctx + .sql(query) + .await + .map_err(ErrorHelper::from_datafusion)?; + df.with_param_values(params) + .map_err(ErrorHelper::from_datafusion)? + } + Some(QueryState::Prepared(plan)) => { + let plan_with_params = plan + .clone() + .with_param_values(params) + .map_err(ErrorHelper::from_datafusion)?; + self.ctx + .execute_logical_plan(plan_with_params) + .await + .map_err(ErrorHelper::from_datafusion)? + } + _ => { + return Err(ErrorHelper::invalid_state() + .message("no query has been set") + .to_adbc()); + } + }; + + if result_schema.is_none() { + result_schema = Some(df.schema().as_arrow().clone().into()); + } + let batches = df.collect().await.map_err(ErrorHelper::from_datafusion)?; + all_results.extend(batches); + } + } + + let schema = result_schema.unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty())); + + if all_results.is_empty() { + Ok( + Box::new(SingleBatchReader::new(RecordBatch::new_empty(schema))) + as Box, + ) + } else { + let combined = arrow_select::concat::concat_batches(&schema, &all_results) + .map_err(ErrorHelper::from_arrow)?; + Ok(Box::new(SingleBatchReader::new(combined)) as Box) + } + }) + } + + fn execute_update_with_params( + &mut self, + reader: Box, + ) -> adbc_core::error::Result> { + let mut total_rows: i64 = 0; + + self.runtime.block_on(async { + for batch in reader { + let batch = batch.map_err(ErrorHelper::from_arrow)?; + total_rows += batch.num_rows() as i64; + for row_idx in 0..batch.num_rows() { + let params = row_to_scalar_values(&batch, row_idx)?; + + let df = match &self.query { + Some(QueryState::Sql(query)) => { + let df = self + .ctx + .sql(query) + .await + .map_err(ErrorHelper::from_datafusion)?; + df.with_param_values(params) + .map_err(ErrorHelper::from_datafusion)? + } + Some(QueryState::Prepared(plan)) => { + let plan_with_params = plan + .clone() + .with_param_values(params) + .map_err(ErrorHelper::from_datafusion)?; + self.ctx + .execute_logical_plan(plan_with_params) + .await + .map_err(ErrorHelper::from_datafusion)? + } + _ => { + return Err(ErrorHelper::invalid_state() + .message("no query has been set") + .to_adbc()); + } + }; + + df.collect().await.map_err(ErrorHelper::from_datafusion)?; + } + } + Ok::<_, adbc_core::error::Error>(()) + })?; + + Ok(Some(total_rows)) + } + fn make_table_ref(&self) -> TableReference { let table = self.ingest.table.as_deref().unwrap_or(""); match (&self.ingest.catalog, &self.ingest.schema) { @@ -669,17 +793,15 @@ impl DataFusionStatement { } fn execute_bulk_ingest(&mut self) -> adbc_core::error::Result> { - let batches = self.bound_batches.take(); - let schema = self.bound_schema.take(); + let reader = self.bound.take().ok_or_else(|| { + ErrorHelper::invalid_state() + .message("no data bound for bulk ingest") + .to_adbc() + })?; - let (batches, schema) = match (batches, schema) { - (Some(b), Some(s)) => (b, s), - _ => { - return Err(ErrorHelper::invalid_state() - .message("no data bound for bulk ingest") - .to_adbc()); - } - }; + let schema = reader.schema(); + let batches: std::result::Result, ArrowError> = reader.collect(); + let batches = batches.map_err(ErrorHelper::from_arrow)?; let row_count: i64 = batches.iter().map(|b| b.num_rows() as i64).sum(); let table_ref = self.make_table_ref(); @@ -780,9 +902,7 @@ impl DataFusionStatement { impl Statement for DataFusionStatement { fn bind(&mut self, batch: arrow_array::RecordBatch) -> adbc_core::error::Result<()> { - let schema = batch.schema(); - self.bound_batches = Some(vec![batch]); - self.bound_schema = Some(schema); + self.bound = Some(Box::new(SingleBatchReader::new(batch))); Ok(()) } @@ -790,15 +910,15 @@ impl Statement for DataFusionStatement { &mut self, reader: Box, ) -> adbc_core::error::Result<()> { - let schema = reader.schema(); - let batches: std::result::Result, ArrowError> = reader.collect(); - let batches = batches.map_err(ErrorHelper::from_arrow)?; - self.bound_batches = Some(batches); - self.bound_schema = Some(schema); + self.bound = Some(reader); Ok(()) } fn execute(&mut self) -> Result> { + if let Some(reader) = self.bound.take() { + return self.execute_with_params(reader); + } + self.runtime.block_on(async { let df = match &self.query { Some(QueryState::Sql(query)) => self @@ -836,6 +956,10 @@ impl Statement for DataFusionStatement { return self.execute_bulk_ingest(); } + if let Some(reader) = self.bound.take() { + return self.execute_update_with_params(reader); + } + self.runtime.block_on(async { let df = match &self.query { Some(QueryState::Sql(query)) => self @@ -902,9 +1026,31 @@ impl Statement for DataFusionStatement { } fn get_parameter_schema(&self) -> adbc_core::error::Result { - Err(ErrorHelper::not_implemented() - .message("get_parameter_schema") - .to_adbc()) + match &self.query { + Some(QueryState::Prepared(plan)) => { + let param_types = plan + .get_parameter_types() + .map_err(ErrorHelper::from_datafusion)?; + + let mut params: Vec<_> = param_types.into_iter().collect(); + params.sort_by_key(|(name, _)| { + name.trim_start_matches('$').parse::().unwrap_or(0) + }); + + let fields: Vec = params + .into_iter() + .map(|(name, dt)| { + let data_type = dt.unwrap_or(arrow_schema::DataType::Utf8); + arrow_schema::Field::new(name, data_type, true) + }) + .collect(); + + Ok(arrow_schema::Schema::new(fields)) + } + _ => Err(ErrorHelper::invalid_state() + .message("statement must be prepared before getting parameter schema") + .to_adbc()), + } } fn prepare(&mut self) -> adbc_core::error::Result<()> { diff --git a/validation/queries/type/bind/binary.txtcase b/validation/queries/type/bind/binary.txtcase new file mode 100644 index 0000000..dd8e59b --- /dev/null +++ b/validation/queries/type/bind/binary.txtcase @@ -0,0 +1,24 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +[tags] +sql-type-name = "BYTEA" + +// part: setup_query + +CREATE TABLE test_binary ( + res BYTEA +); diff --git a/validation/queries/type/bind/binary_view.txtcase b/validation/queries/type/bind/binary_view.txtcase new file mode 100644 index 0000000..db22e67 --- /dev/null +++ b/validation/queries/type/bind/binary_view.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion does not cast BinaryView bind parameters to Binary column type" diff --git a/validation/queries/type/bind/fixed_size_binary.txtcase b/validation/queries/type/bind/fixed_size_binary.txtcase new file mode 100644 index 0000000..cd236c4 --- /dev/null +++ b/validation/queries/type/bind/fixed_size_binary.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion does not cast FixedSizeBinary bind parameters to Binary column type" diff --git a/validation/queries/type/bind/float16.txtcase b/validation/queries/type/bind/float16.txtcase new file mode 100644 index 0000000..6c02960 --- /dev/null +++ b/validation/queries/type/bind/float16.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +hide = true diff --git a/validation/queries/type/bind/large_binary.txtcase b/validation/queries/type/bind/large_binary.txtcase new file mode 100644 index 0000000..4a8f326 --- /dev/null +++ b/validation/queries/type/bind/large_binary.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion does not cast LargeBinary bind parameters to Binary column type" diff --git a/validation/queries/type/bind/large_string.txtcase b/validation/queries/type/bind/large_string.txtcase new file mode 100644 index 0000000..4ee1772 --- /dev/null +++ b/validation/queries/type/bind/large_string.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion does not cast LargeUtf8 bind parameters to Utf8View column type" diff --git a/validation/queries/type/bind/string.txtcase b/validation/queries/type/bind/string.txtcase new file mode 100644 index 0000000..9771d27 --- /dev/null +++ b/validation/queries/type/bind/string.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion does not cast Utf8 bind parameters to Utf8View column type" diff --git a/validation/queries/type/bind/string_view.txtcase b/validation/queries/type/bind/string_view.txtcase new file mode 100644 index 0000000..520553b --- /dev/null +++ b/validation/queries/type/bind/string_view.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion does not cast Utf8View bind parameters to match column type" diff --git a/validation/queries/type/bind/time_ms.txtcase b/validation/queries/type/bind/time_ms.txtcase new file mode 100644 index 0000000..6c02960 --- /dev/null +++ b/validation/queries/type/bind/time_ms.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +hide = true diff --git a/validation/queries/type/bind/time_ns.txtcase b/validation/queries/type/bind/time_ns.txtcase new file mode 100644 index 0000000..2222ab6 --- /dev/null +++ b/validation/queries/type/bind/time_ns.txtcase @@ -0,0 +1,20 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: setup_query + +CREATE TABLE test_time ( + idx INT, + res TIME +); diff --git a/validation/queries/type/bind/time_s.txtcase b/validation/queries/type/bind/time_s.txtcase new file mode 100644 index 0000000..6c02960 --- /dev/null +++ b/validation/queries/type/bind/time_s.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +hide = true diff --git a/validation/queries/type/bind/time_us.txtcase b/validation/queries/type/bind/time_us.txtcase new file mode 100644 index 0000000..6c02960 --- /dev/null +++ b/validation/queries/type/bind/time_us.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +hide = true diff --git a/validation/queries/type/bind/timestamptz_ms.txtcase b/validation/queries/type/bind/timestamptz_ms.txtcase new file mode 100644 index 0000000..3dcb5f1 --- /dev/null +++ b/validation/queries/type/bind/timestamptz_ms.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion MemTable does not handle timezone metadata from bind parameters" diff --git a/validation/queries/type/bind/timestamptz_ns.txtcase b/validation/queries/type/bind/timestamptz_ns.txtcase new file mode 100644 index 0000000..3dcb5f1 --- /dev/null +++ b/validation/queries/type/bind/timestamptz_ns.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion MemTable does not handle timezone metadata from bind parameters" diff --git a/validation/queries/type/bind/timestamptz_s.txtcase b/validation/queries/type/bind/timestamptz_s.txtcase new file mode 100644 index 0000000..3dcb5f1 --- /dev/null +++ b/validation/queries/type/bind/timestamptz_s.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion MemTable does not handle timezone metadata from bind parameters" diff --git a/validation/queries/type/bind/timestamptz_us.txtcase b/validation/queries/type/bind/timestamptz_us.txtcase new file mode 100644 index 0000000..3dcb5f1 --- /dev/null +++ b/validation/queries/type/bind/timestamptz_us.txtcase @@ -0,0 +1,17 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// part: metadata + +skip = "DataFusion MemTable does not handle timezone metadata from bind parameters" diff --git a/validation/tests/datafusion.py b/validation/tests/datafusion.py index 9dc4f2e..7c733e1 100644 --- a/validation/tests/datafusion.py +++ b/validation/tests/datafusion.py @@ -26,7 +26,8 @@ class DataFusionQuirks(model.DriverQuirks): vendor_version = "53.1.0" short_version = "53" features = model.DriverFeatures( - statement_bind=False, + statement_bind=True, + statement_get_parameter_schema=True, statement_prepare=True, current_catalog="datafusion", current_schema="public", @@ -53,6 +54,9 @@ class DataFusionQuirks(model.DriverQuirks): def queries_paths(self) -> tuple[Path]: return (Path(__file__).parent.parent / "queries",) + def bind_parameter(self, index: int) -> str: + return f"${index}" + def is_table_not_found(self, table_name: str, error: Exception) -> bool: msg = str(error) if table_name and table_name not in msg: diff --git a/validation/tests/test_statement.py b/validation/tests/test_statement.py index efe8cbf..38fdd7a 100644 --- a/validation/tests/test_statement.py +++ b/validation/tests/test_statement.py @@ -24,10 +24,6 @@ def pytest_generate_tests(metafunc) -> None: class TestStatement(statement_tests.TestStatement): - @pytest.mark.xfail(reason="bind parameters not implemented") - def test_parameter_execute(self, driver, conn) -> None: - super().test_parameter_execute(driver, conn) - @pytest.mark.xfail( reason="DataFusion lightweight updates require special table settings" )