diff --git a/src/bind.rs b/src/bind.rs new file mode 100644 index 0000000..d29e32b --- /dev/null +++ b/src/bind.rs @@ -0,0 +1,136 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::vec::IntoIter; + +use arrow_array::{RecordBatch, RecordBatchReader}; +use arrow_schema::{ArrowError, SchemaRef}; +use datafusion::common::ScalarValue; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::*; + +use crate::{ErrorHelper, Runtime}; + +pub fn row_to_scalar_values( + batch: &RecordBatch, + row_index: usize, +) -> adbc_core::error::Result> { + let mut values = Vec::with_capacity(batch.num_columns()); + for col_index in 0..batch.num_columns() { + let array = batch.column(col_index); + let scalar = + ScalarValue::try_from_array(array, row_index).map_err(ErrorHelper::from_datafusion)?; + values.push(scalar); + } + Ok(values) +} + +pub struct BindReader { + template: LogicalPlan, + runtime: Arc, + ctx: Arc, + bound: Box, + current_batch: Option, + current_row: usize, + pending_results: IntoIter, + schema: SchemaRef, +} + +impl BindReader { + pub fn new( + template: LogicalPlan, + runtime: Arc, + ctx: Arc, + bound: Box, + ) -> Self { + let schema: SchemaRef = template.schema().as_arrow().clone().into(); + Self { + template, + runtime, + ctx, + bound, + current_batch: None, + current_row: 0, + pending_results: Vec::new().into_iter(), + schema, + } + } + + fn advance(&mut self) -> Option> { + loop { + if let Some(batch) = self.pending_results.next() { + return Some(Ok(batch)); + } + + let batch = match &self.current_batch { + Some(b) if self.current_row < b.num_rows() => b, + _ => match self.bound.next() { + Some(Ok(b)) => { + self.current_batch = Some(b); + self.current_row = 0; + self.current_batch.as_ref().unwrap() + } + Some(Err(e)) => return Some(Err(e)), + None => return None, + }, + }; + + let params = match row_to_scalar_values(batch, self.current_row) { + Ok(p) => p, + Err(e) => { + return Some(Err(ArrowError::ExternalError(Box::new(e)))); + } + }; + self.current_row += 1; + + let result = self.runtime.block_on(async { + let plan_with_params = self + .template + .clone() + .with_param_values(params) + .map_err(ErrorHelper::from_datafusion)?; + let df = self + .ctx + .execute_logical_plan(plan_with_params) + .await + .map_err(ErrorHelper::from_datafusion)?; + df.collect().await.map_err(ErrorHelper::from_datafusion) + }); + + match result { + Ok(batches) => { + self.pending_results = batches.into_iter(); + } + Err(e) => { + return Some(Err(ArrowError::ExternalError(Box::new(e.to_adbc())))); + } + } + } + } +} + +impl Iterator for BindReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.advance() + } +} + +impl RecordBatchReader for BindReader { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/src/lib.rs b/src/lib.rs index 94a9992..548dd79 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,9 +20,10 @@ // specific language governing permissions and limitations // under the License. +mod bind; mod get_objects; use adbc_core::constants; -use datafusion::common::{ScalarValue, TableReference}; +use datafusion::common::TableReference; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::MemTable; use datafusion::logical_expr::LogicalPlan; @@ -654,124 +655,53 @@ impl Optionable for DataFusionStatement { } } -fn row_to_scalar_values( - batch: &RecordBatch, - row_index: usize, -) -> adbc_core::error::Result> { - let mut values = Vec::with_capacity(batch.num_columns()); - for col_index in 0..batch.num_columns() { - let array = batch.column(col_index); - let scalar = - ScalarValue::try_from_array(array, row_index).map_err(ErrorHelper::from_datafusion)?; - values.push(scalar); +impl DataFusionStatement { + fn get_prepared_plan(&mut self) -> adbc_core::error::Result { + self.prepare()?; + match &self.query { + Some(QueryState::Prepared(plan)) => Ok(plan.clone()), + _ => Err(ErrorHelper::invalid_state() + .message("no query has been set") + .to_adbc()), + } } - Ok(values) -} -impl DataFusionStatement { fn execute_with_params( &mut self, reader: Box, ) -> Result> { - self.runtime.block_on(async { - let mut all_results: Vec = Vec::new(); - let mut result_schema: Option = None; - - for batch in reader { - let batch = batch.map_err(ErrorHelper::from_arrow)?; - for row_idx in 0..batch.num_rows() { - let params = row_to_scalar_values(&batch, row_idx)?; - - let df = match &self.query { - Some(QueryState::Sql(query)) => { - let df = self - .ctx - .sql(query) - .await - .map_err(ErrorHelper::from_datafusion)?; - df.with_param_values(params) - .map_err(ErrorHelper::from_datafusion)? - } - Some(QueryState::Prepared(plan)) => { - let plan_with_params = plan - .clone() - .with_param_values(params) - .map_err(ErrorHelper::from_datafusion)?; - self.ctx - .execute_logical_plan(plan_with_params) - .await - .map_err(ErrorHelper::from_datafusion)? - } - _ => { - return Err(ErrorHelper::invalid_state() - .message("no query has been set") - .to_adbc()); - } - }; - - if result_schema.is_none() { - result_schema = Some(df.schema().as_arrow().clone().into()); - } - let batches = df.collect().await.map_err(ErrorHelper::from_datafusion)?; - all_results.extend(batches); - } - } - - let schema = result_schema.unwrap_or_else(|| Arc::new(arrow_schema::Schema::empty())); - - if all_results.is_empty() { - Ok( - Box::new(SingleBatchReader::new(RecordBatch::new_empty(schema))) - as Box, - ) - } else { - let combined = arrow_select::concat::concat_batches(&schema, &all_results) - .map_err(ErrorHelper::from_arrow)?; - Ok(Box::new(SingleBatchReader::new(combined)) as Box) - } - }) + let template = self.get_prepared_plan()?; + Ok(Box::new(bind::BindReader::new( + template, + self.runtime.clone(), + self.ctx.clone(), + reader, + ))) } fn execute_update_with_params( &mut self, reader: Box, ) -> adbc_core::error::Result> { + let template = self.get_prepared_plan()?; let mut total_rows: i64 = 0; self.runtime.block_on(async { for batch in reader { let batch = batch.map_err(ErrorHelper::from_arrow)?; - total_rows += batch.num_rows() as i64; + total_rows = total_rows.saturating_add(batch.num_rows() as i64); for row_idx in 0..batch.num_rows() { - let params = row_to_scalar_values(&batch, row_idx)?; - - let df = match &self.query { - Some(QueryState::Sql(query)) => { - let df = self - .ctx - .sql(query) - .await - .map_err(ErrorHelper::from_datafusion)?; - df.with_param_values(params) - .map_err(ErrorHelper::from_datafusion)? - } - Some(QueryState::Prepared(plan)) => { - let plan_with_params = plan - .clone() - .with_param_values(params) - .map_err(ErrorHelper::from_datafusion)?; - self.ctx - .execute_logical_plan(plan_with_params) - .await - .map_err(ErrorHelper::from_datafusion)? - } - _ => { - return Err(ErrorHelper::invalid_state() - .message("no query has been set") - .to_adbc()); - } - }; + let params = bind::row_to_scalar_values(&batch, row_idx)?; + let plan_with_params = template + .clone() + .with_param_values(params) + .map_err(ErrorHelper::from_datafusion)?; + let df = self + .ctx + .execute_logical_plan(plan_with_params) + .await + .map_err(ErrorHelper::from_datafusion)?; df.collect().await.map_err(ErrorHelper::from_datafusion)?; } } @@ -1026,31 +956,42 @@ impl Statement for DataFusionStatement { } fn get_parameter_schema(&self) -> adbc_core::error::Result { - match &self.query { - Some(QueryState::Prepared(plan)) => { - let param_types = plan - .get_parameter_types() - .map_err(ErrorHelper::from_datafusion)?; + let plan = match &self.query { + Some(QueryState::Prepared(plan)) => plan.clone(), + Some(QueryState::Sql(sql)) => self + .runtime + .block_on(async { + self.ctx + .state() + .create_logical_plan(sql) + .await + .map_err(ErrorHelper::from_datafusion) + }) + .map_err(|e| e.to_adbc())?, + _ => { + return Err(ErrorHelper::invalid_state() + .message("no query has been set") + .to_adbc()); + } + }; - let mut params: Vec<_> = param_types.into_iter().collect(); - params.sort_by_key(|(name, _)| { - name.trim_start_matches('$').parse::().unwrap_or(0) - }); + let param_types = plan + .get_parameter_types() + .map_err(ErrorHelper::from_datafusion) + .map_err(|e| e.to_adbc())?; - let fields: Vec = params - .into_iter() - .map(|(name, dt)| { - let data_type = dt.unwrap_or(arrow_schema::DataType::Utf8); - arrow_schema::Field::new(name, data_type, true) - }) - .collect(); + let mut params: Vec<_> = param_types.into_iter().collect(); + params.sort_by_key(|(name, _)| name.trim_start_matches('$').parse::().unwrap_or(0)); - Ok(arrow_schema::Schema::new(fields)) - } - _ => Err(ErrorHelper::invalid_state() - .message("statement must be prepared before getting parameter schema") - .to_adbc()), - } + let fields: Vec = params + .into_iter() + .map(|(name, dt)| { + let data_type = dt.unwrap_or(arrow_schema::DataType::Utf8); + arrow_schema::Field::new(name, data_type, true) + }) + .collect(); + + Ok(arrow_schema::Schema::new(fields)) } fn prepare(&mut self) -> adbc_core::error::Result<()> {