Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 172 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

mod get_objects;
use adbc_core::constants;
use datafusion::common::TableReference;
use datafusion::common::{ScalarValue, TableReference};
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::datasource::MemTable;
use datafusion::logical_expr::LogicalPlan;
Expand Down Expand Up @@ -487,8 +487,7 @@ impl Connection for DataFusionConnection {
runtime: self.runtime.clone(),
ctx: self.ctx.clone(),
query: None,
bound_batches: None,
bound_schema: None,
bound: None,
ingest: BulkIngestState::new(),
})
}
Expand Down Expand Up @@ -600,8 +599,7 @@ pub struct DataFusionStatement {
runtime: Arc<Runtime>,
ctx: Arc<SessionContext>,
query: Option<QueryState>,
bound_batches: Option<Vec<RecordBatch>>,
bound_schema: Option<SchemaRef>,
bound: Option<Box<dyn RecordBatchReader + Send>>,
ingest: BulkIngestState<ErrorHelper>,
}

Expand Down Expand Up @@ -656,7 +654,133 @@ impl Optionable for DataFusionStatement {
}
}

fn row_to_scalar_values(
batch: &RecordBatch,
row_index: usize,
) -> adbc_core::error::Result<Vec<ScalarValue>> {
let mut values = Vec::with_capacity(batch.num_columns());
for col_index in 0..batch.num_columns() {
let array = batch.column(col_index);
let scalar =
ScalarValue::try_from_array(array, row_index).map_err(ErrorHelper::from_datafusion)?;
values.push(scalar);
}
Ok(values)
}

impl DataFusionStatement {
fn execute_with_params(
&mut self,
reader: Box<dyn RecordBatchReader + Send>,
) -> Result<Box<dyn RecordBatchReader + Send>> {
self.runtime.block_on(async {
let mut all_results: Vec<RecordBatch> = Vec::new();
let mut result_schema: Option<SchemaRef> = None;

for batch in reader {
let batch = batch.map_err(ErrorHelper::from_arrow)?;
for row_idx in 0..batch.num_rows() {
let params = row_to_scalar_values(&batch, row_idx)?;

let df = match &self.query {
Some(QueryState::Sql(query)) => {
let df = self
.ctx
.sql(query)
.await
.map_err(ErrorHelper::from_datafusion)?;
df.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?
}
Some(QueryState::Prepared(plan)) => {
let plan_with_params = plan
.clone()
.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?;
self.ctx
.execute_logical_plan(plan_with_params)
.await
.map_err(ErrorHelper::from_datafusion)?
}
_ => {
return Err(ErrorHelper::invalid_state()
.message("no query has been set")
.to_adbc());
}
};

if result_schema.is_none() {
result_schema = Some(df.schema().as_arrow().clone().into());
}
let batches = df.collect().await.map_err(ErrorHelper::from_datafusion)?;
all_results.extend(batches);
}
}

let schema = result_schema.unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty()));

if all_results.is_empty() {
Ok(
Box::new(SingleBatchReader::new(RecordBatch::new_empty(schema)))
as Box<dyn RecordBatchReader + Send>,
)
} else {
let combined = arrow_select::concat::concat_batches(&schema, &all_results)
.map_err(ErrorHelper::from_arrow)?;
Ok(Box::new(SingleBatchReader::new(combined)) as Box<dyn RecordBatchReader + Send>)
}
})
}

fn execute_update_with_params(
&mut self,
reader: Box<dyn RecordBatchReader + Send>,
) -> adbc_core::error::Result<Option<i64>> {
let mut total_rows: i64 = 0;

self.runtime.block_on(async {
for batch in reader {
let batch = batch.map_err(ErrorHelper::from_arrow)?;
total_rows += batch.num_rows() as i64;
for row_idx in 0..batch.num_rows() {
let params = row_to_scalar_values(&batch, row_idx)?;

let df = match &self.query {
Some(QueryState::Sql(query)) => {
let df = self
.ctx
.sql(query)
.await
.map_err(ErrorHelper::from_datafusion)?;
df.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?
}
Some(QueryState::Prepared(plan)) => {
let plan_with_params = plan
.clone()
.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?;
self.ctx
.execute_logical_plan(plan_with_params)
.await
.map_err(ErrorHelper::from_datafusion)?
}
_ => {
return Err(ErrorHelper::invalid_state()
.message("no query has been set")
.to_adbc());
}
};

df.collect().await.map_err(ErrorHelper::from_datafusion)?;
}
}
Ok::<_, adbc_core::error::Error>(())
})?;

Ok(Some(total_rows))
}

fn make_table_ref(&self) -> TableReference {
let table = self.ingest.table.as_deref().unwrap_or("");
match (&self.ingest.catalog, &self.ingest.schema) {
Expand All @@ -669,17 +793,15 @@ impl DataFusionStatement {
}

fn execute_bulk_ingest(&mut self) -> adbc_core::error::Result<Option<i64>> {
let batches = self.bound_batches.take();
let schema = self.bound_schema.take();
let reader = self.bound.take().ok_or_else(|| {
ErrorHelper::invalid_state()
.message("no data bound for bulk ingest")
.to_adbc()
})?;

let (batches, schema) = match (batches, schema) {
(Some(b), Some(s)) => (b, s),
_ => {
return Err(ErrorHelper::invalid_state()
.message("no data bound for bulk ingest")
.to_adbc());
}
};
let schema = reader.schema();
let batches: std::result::Result<Vec<RecordBatch>, ArrowError> = reader.collect();
let batches = batches.map_err(ErrorHelper::from_arrow)?;

let row_count: i64 = batches.iter().map(|b| b.num_rows() as i64).sum();
let table_ref = self.make_table_ref();
Expand Down Expand Up @@ -780,25 +902,23 @@ impl DataFusionStatement {

impl Statement for DataFusionStatement {
fn bind(&mut self, batch: arrow_array::RecordBatch) -> adbc_core::error::Result<()> {
let schema = batch.schema();
self.bound_batches = Some(vec![batch]);
self.bound_schema = Some(schema);
self.bound = Some(Box::new(SingleBatchReader::new(batch)));
Ok(())
}

fn bind_stream(
&mut self,
reader: Box<dyn arrow_array::RecordBatchReader + Send>,
) -> adbc_core::error::Result<()> {
let schema = reader.schema();
let batches: std::result::Result<Vec<RecordBatch>, ArrowError> = reader.collect();
let batches = batches.map_err(ErrorHelper::from_arrow)?;
self.bound_batches = Some(batches);
self.bound_schema = Some(schema);
self.bound = Some(reader);
Ok(())
}

fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send>> {
if let Some(reader) = self.bound.take() {
return self.execute_with_params(reader);
}

self.runtime.block_on(async {
let df = match &self.query {
Some(QueryState::Sql(query)) => self
Expand Down Expand Up @@ -836,6 +956,10 @@ impl Statement for DataFusionStatement {
return self.execute_bulk_ingest();
}

if let Some(reader) = self.bound.take() {
return self.execute_update_with_params(reader);
}

self.runtime.block_on(async {
let df = match &self.query {
Some(QueryState::Sql(query)) => self
Expand Down Expand Up @@ -902,9 +1026,31 @@ impl Statement for DataFusionStatement {
}

fn get_parameter_schema(&self) -> adbc_core::error::Result<arrow_schema::Schema> {
Err(ErrorHelper::not_implemented()
.message("get_parameter_schema")
.to_adbc())
match &self.query {
Some(QueryState::Prepared(plan)) => {
let param_types = plan
.get_parameter_types()
.map_err(ErrorHelper::from_datafusion)?;

let mut params: Vec<_> = param_types.into_iter().collect();
params.sort_by_key(|(name, _)| {
name.trim_start_matches('$').parse::<usize>().unwrap_or(0)
});

let fields: Vec<arrow_schema::Field> = params
.into_iter()
.map(|(name, dt)| {
let data_type = dt.unwrap_or(arrow_schema::DataType::Utf8);
arrow_schema::Field::new(name, data_type, true)
})
.collect();

Ok(arrow_schema::Schema::new(fields))
}
_ => Err(ErrorHelper::invalid_state()
.message("statement must be prepared before getting parameter schema")
.to_adbc()),
}
}

fn prepare(&mut self) -> adbc_core::error::Result<()> {
Expand Down
24 changes: 24 additions & 0 deletions validation/queries/type/bind/binary.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// part: metadata

[tags]
sql-type-name = "BYTEA"

// part: setup_query

CREATE TABLE test_binary (
res BYTEA
);
17 changes: 17 additions & 0 deletions validation/queries/type/bind/binary_view.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// part: metadata

skip = "DataFusion does not cast BinaryView bind parameters to Binary column type"
17 changes: 17 additions & 0 deletions validation/queries/type/bind/fixed_size_binary.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// part: metadata

skip = "DataFusion does not cast FixedSizeBinary bind parameters to Binary column type"
17 changes: 17 additions & 0 deletions validation/queries/type/bind/float16.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// part: metadata

hide = true
17 changes: 17 additions & 0 deletions validation/queries/type/bind/large_binary.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// part: metadata

skip = "DataFusion does not cast LargeBinary bind parameters to Binary column type"
17 changes: 17 additions & 0 deletions validation/queries/type/bind/large_string.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// part: metadata

skip = "DataFusion does not cast LargeUtf8 bind parameters to Utf8View column type"
17 changes: 17 additions & 0 deletions validation/queries/type/bind/string.txtcase
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// part: metadata

skip = "DataFusion does not cast Utf8 bind parameters to Utf8View column type"
Loading
Loading