Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions src/bind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright (c) 2026 ADBC Drivers Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::vec::IntoIter;

use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_schema::{ArrowError, SchemaRef};
use datafusion::common::ScalarValue;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;

use crate::{ErrorHelper, Runtime};

pub fn row_to_scalar_values(
batch: &RecordBatch,
row_index: usize,
) -> adbc_core::error::Result<Vec<ScalarValue>> {
let mut values = Vec::with_capacity(batch.num_columns());
for col_index in 0..batch.num_columns() {
let array = batch.column(col_index);
let scalar =
ScalarValue::try_from_array(array, row_index).map_err(ErrorHelper::from_datafusion)?;
values.push(scalar);
}
Ok(values)
}

pub struct BindReader {
template: LogicalPlan,
runtime: Arc<Runtime>,
ctx: Arc<SessionContext>,
bound: Box<dyn RecordBatchReader + Send>,
current_batch: Option<RecordBatch>,
current_row: usize,
pending_results: IntoIter<RecordBatch>,
schema: SchemaRef,
}

impl BindReader {
pub fn new(
template: LogicalPlan,
runtime: Arc<Runtime>,
ctx: Arc<SessionContext>,
bound: Box<dyn RecordBatchReader + Send>,
) -> Self {
let schema: SchemaRef = template.schema().as_arrow().clone().into();
Self {
template,
runtime,
ctx,
bound,
current_batch: None,
current_row: 0,
pending_results: Vec::new().into_iter(),
schema,
}
}

fn advance(&mut self) -> Option<Result<RecordBatch, ArrowError>> {
loop {
if let Some(batch) = self.pending_results.next() {
return Some(Ok(batch));
}

let batch = match &self.current_batch {
Some(b) if self.current_row < b.num_rows() => b,
_ => match self.bound.next() {
Some(Ok(b)) => {
self.current_batch = Some(b);
self.current_row = 0;
self.current_batch.as_ref().unwrap()
}
Some(Err(e)) => return Some(Err(e)),
None => return None,
},
};

let params = match row_to_scalar_values(batch, self.current_row) {
Ok(p) => p,
Err(e) => {
return Some(Err(ArrowError::ExternalError(Box::new(e))));
}
};
self.current_row += 1;

let result = self.runtime.block_on(async {
let plan_with_params = self
.template
.clone()
.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?;
let df = self
.ctx
.execute_logical_plan(plan_with_params)
.await
.map_err(ErrorHelper::from_datafusion)?;
df.collect().await.map_err(ErrorHelper::from_datafusion)
});

match result {
Ok(batches) => {
self.pending_results = batches.into_iter();
}
Err(e) => {
return Some(Err(ArrowError::ExternalError(Box::new(e.to_adbc()))));
}
}
}
}
}

impl Iterator for BindReader {
type Item = Result<RecordBatch, ArrowError>;

fn next(&mut self) -> Option<Self::Item> {
self.advance()
}
}

impl RecordBatchReader for BindReader {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}
185 changes: 63 additions & 122 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
// specific language governing permissions and limitations
// under the License.

mod bind;
mod get_objects;
use adbc_core::constants;
use datafusion::common::{ScalarValue, TableReference};
use datafusion::common::TableReference;
use datafusion::dataframe::DataFrameWriteOptions;
use datafusion::datasource::MemTable;
use datafusion::logical_expr::LogicalPlan;
Expand Down Expand Up @@ -654,124 +655,53 @@ impl Optionable for DataFusionStatement {
}
}

fn row_to_scalar_values(
batch: &RecordBatch,
row_index: usize,
) -> adbc_core::error::Result<Vec<ScalarValue>> {
let mut values = Vec::with_capacity(batch.num_columns());
for col_index in 0..batch.num_columns() {
let array = batch.column(col_index);
let scalar =
ScalarValue::try_from_array(array, row_index).map_err(ErrorHelper::from_datafusion)?;
values.push(scalar);
impl DataFusionStatement {
fn get_prepared_plan(&mut self) -> adbc_core::error::Result<LogicalPlan> {
self.prepare()?;
match &self.query {
Some(QueryState::Prepared(plan)) => Ok(plan.clone()),
_ => Err(ErrorHelper::invalid_state()
.message("no query has been set")
.to_adbc()),
}
}
Ok(values)
}

impl DataFusionStatement {
fn execute_with_params(
&mut self,
reader: Box<dyn RecordBatchReader + Send>,
) -> Result<Box<dyn RecordBatchReader + Send>> {
self.runtime.block_on(async {
let mut all_results: Vec<RecordBatch> = Vec::new();
let mut result_schema: Option<SchemaRef> = None;

for batch in reader {
let batch = batch.map_err(ErrorHelper::from_arrow)?;
for row_idx in 0..batch.num_rows() {
let params = row_to_scalar_values(&batch, row_idx)?;

let df = match &self.query {
Some(QueryState::Sql(query)) => {
let df = self
.ctx
.sql(query)
.await
.map_err(ErrorHelper::from_datafusion)?;
df.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?
}
Some(QueryState::Prepared(plan)) => {
let plan_with_params = plan
.clone()
.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?;
self.ctx
.execute_logical_plan(plan_with_params)
.await
.map_err(ErrorHelper::from_datafusion)?
}
_ => {
return Err(ErrorHelper::invalid_state()
.message("no query has been set")
.to_adbc());
}
};

if result_schema.is_none() {
result_schema = Some(df.schema().as_arrow().clone().into());
}
let batches = df.collect().await.map_err(ErrorHelper::from_datafusion)?;
all_results.extend(batches);
}
}

let schema = result_schema.unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty()));

if all_results.is_empty() {
Ok(
Box::new(SingleBatchReader::new(RecordBatch::new_empty(schema)))
as Box<dyn RecordBatchReader + Send>,
)
} else {
let combined = arrow_select::concat::concat_batches(&schema, &all_results)
.map_err(ErrorHelper::from_arrow)?;
Ok(Box::new(SingleBatchReader::new(combined)) as Box<dyn RecordBatchReader + Send>)
}
})
let template = self.get_prepared_plan()?;
Ok(Box::new(bind::BindReader::new(
template,
self.runtime.clone(),
self.ctx.clone(),
reader,
)))
}

fn execute_update_with_params(
&mut self,
reader: Box<dyn RecordBatchReader + Send>,
) -> adbc_core::error::Result<Option<i64>> {
let template = self.get_prepared_plan()?;
let mut total_rows: i64 = 0;

self.runtime.block_on(async {
for batch in reader {
let batch = batch.map_err(ErrorHelper::from_arrow)?;
total_rows += batch.num_rows() as i64;
total_rows = total_rows.saturating_add(batch.num_rows() as i64);
for row_idx in 0..batch.num_rows() {
let params = row_to_scalar_values(&batch, row_idx)?;

let df = match &self.query {
Some(QueryState::Sql(query)) => {
let df = self
.ctx
.sql(query)
.await
.map_err(ErrorHelper::from_datafusion)?;
df.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?
}
Some(QueryState::Prepared(plan)) => {
let plan_with_params = plan
.clone()
.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?;
self.ctx
.execute_logical_plan(plan_with_params)
.await
.map_err(ErrorHelper::from_datafusion)?
}
_ => {
return Err(ErrorHelper::invalid_state()
.message("no query has been set")
.to_adbc());
}
};
let params = bind::row_to_scalar_values(&batch, row_idx)?;

let plan_with_params = template
.clone()
.with_param_values(params)
.map_err(ErrorHelper::from_datafusion)?;
let df = self
.ctx
.execute_logical_plan(plan_with_params)
.await
.map_err(ErrorHelper::from_datafusion)?;
df.collect().await.map_err(ErrorHelper::from_datafusion)?;
}
}
Expand Down Expand Up @@ -1026,31 +956,42 @@ impl Statement for DataFusionStatement {
}

fn get_parameter_schema(&self) -> adbc_core::error::Result<arrow_schema::Schema> {
match &self.query {
Some(QueryState::Prepared(plan)) => {
let param_types = plan
.get_parameter_types()
.map_err(ErrorHelper::from_datafusion)?;
let plan = match &self.query {
Some(QueryState::Prepared(plan)) => plan.clone(),
Some(QueryState::Sql(sql)) => self
.runtime
.block_on(async {
self.ctx
.state()
.create_logical_plan(sql)
.await
.map_err(ErrorHelper::from_datafusion)
})
.map_err(|e| e.to_adbc())?,
_ => {
return Err(ErrorHelper::invalid_state()
.message("no query has been set")
.to_adbc());
}
};

let mut params: Vec<_> = param_types.into_iter().collect();
params.sort_by_key(|(name, _)| {
name.trim_start_matches('$').parse::<usize>().unwrap_or(0)
});
let param_types = plan
.get_parameter_types()
.map_err(ErrorHelper::from_datafusion)
.map_err(|e| e.to_adbc())?;

let fields: Vec<arrow_schema::Field> = params
.into_iter()
.map(|(name, dt)| {
let data_type = dt.unwrap_or(arrow_schema::DataType::Utf8);
arrow_schema::Field::new(name, data_type, true)
})
.collect();
let mut params: Vec<_> = param_types.into_iter().collect();
params.sort_by_key(|(name, _)| name.trim_start_matches('$').parse::<usize>().unwrap_or(0));

Ok(arrow_schema::Schema::new(fields))
}
_ => Err(ErrorHelper::invalid_state()
.message("statement must be prepared before getting parameter schema")
.to_adbc()),
}
let fields: Vec<arrow_schema::Field> = params
.into_iter()
.map(|(name, dt)| {
let data_type = dt.unwrap_or(arrow_schema::DataType::Utf8);
arrow_schema::Field::new(name, data_type, true)
})
.collect();

Ok(arrow_schema::Schema::new(fields))
}

fn prepare(&mut self) -> adbc_core::error::Result<()> {
Expand Down
Loading