diff --git a/src/common/callback_stream.rs b/src/common/callback_stream.rs
deleted file mode 100644
index 2edf97a..0000000
--- a/src/common/callback_stream.rs
+++ /dev/null
@@ -1,84 +0,0 @@
-use futures::Stream;
-use pin_project::{pin_project, pinned_drop};
-use std::fmt::Display;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-
-/// The reason why the stream ended:
-/// - [CallbackStreamEndReason::Finished] if it finished gracefully
-/// - [CallbackStreamEndReason::Aborted] if it was abandoned.
-#[derive(Debug)]
-pub enum CallbackStreamEndReason {
-    /// The stream finished gracefully.
-    Finished,
-    /// The stream was abandoned.
-    Aborted,
-}
-
-impl Display for CallbackStreamEndReason {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{:?}", self)
-    }
-}
-
-/// Stream that executes a callback when it is fully consumed or gets cancelled.
-#[pin_project(PinnedDrop)]
-pub struct CallbackStream<S, F>
-where
-    S: Stream,
-    F: FnOnce(CallbackStreamEndReason),
-{
-    #[pin]
-    stream: S,
-    callback: Option<F>,
-}
-
-impl<S, F> Stream for CallbackStream<S, F>
-where
-    S: Stream,
-    F: FnOnce(CallbackStreamEndReason),
-{
-    type Item = S::Item;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let this = self.project();
-
-        match this.stream.poll_next(cx) {
-            Poll::Ready(None) => {
-                // Stream is fully consumed, execute the callback
-                if let Some(callback) = this.callback.take() {
-                    callback(CallbackStreamEndReason::Finished);
-                }
-                Poll::Ready(None)
-            }
-            other => other,
-        }
-    }
-}
-
-#[pinned_drop]
-impl<S, F> PinnedDrop for CallbackStream<S, F>
-where
-    S: Stream,
-    F: FnOnce(CallbackStreamEndReason),
-{
-    fn drop(self: Pin<&mut Self>) {
-        let this = self.project();
-        if let Some(callback) = this.callback.take() {
-            callback(CallbackStreamEndReason::Aborted);
-        }
-    }
-}
-
-/// Wrap a stream with a callback that will be executed when the stream is fully
-/// consumed or gets canceled.
-pub fn with_callback<S, F>(stream: S, callback: F) -> CallbackStream<S, F>
-where
-    S: Stream,
-    F: FnOnce(CallbackStreamEndReason) + Send + 'static,
-{
-    CallbackStream {
-        stream,
-        callback: Some(callback),
-    }
-}
diff --git a/src/common/map_last_stream.rs b/src/common/map_last_stream.rs
new file mode 100644
index 0000000..d0eb779
--- /dev/null
+++ b/src/common/map_last_stream.rs
@@ -0,0 +1,84 @@
+use futures::{Stream, StreamExt, stream};
+use std::task::Poll;
+
+/// Maps the last element of the provided stream.
+pub(crate) fn map_last_stream<T>(
+    mut input: impl Stream<Item = T> + Unpin,
+    map_f: impl FnOnce(T) -> T,
+) -> impl Stream<Item = T> + Unpin {
+    let mut final_closure = Some(map_f);
+
+    // this is used to peek the new value so that we can map upon emitting the last message
+    let mut current_value = None;
+
+    stream::poll_fn(move |cx| match futures::ready!(input.poll_next_unpin(cx)) {
+        Some(new_val) => {
+            match current_value.take() {
+                // This is the first value, so we store it and repoll to get the next value
+                None => {
+                    current_value = Some(new_val);
+                    cx.waker().wake_by_ref();
+                    Poll::Pending
+                }
+
+                Some(existing) => {
+                    current_value = Some(new_val);
+
+                    Poll::Ready(Some(existing))
+                }
+            }
+        }
+        // this is our last value, so we map it using the user provided closure
+        None => match current_value.take() {
+            Some(existing) => {
+                // make sure we wake ourselves to finish the stream
+                cx.waker().wake_by_ref();
+
+                if let Some(closure) = final_closure.take() {
+                    Poll::Ready(Some(closure(existing)))
+                } else {
+                    unreachable!("the closure is only executed once")
+                }
+            }
+            None => Poll::Ready(None),
+        },
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures::stream;
+
+    #[tokio::test]
+    async fn test_map_last_stream_empty_stream() {
+        let input = stream::empty::<i32>();
+        let mapped = map_last_stream(input, |x| x + 10);
+        let result: Vec<i32> = mapped.collect().await;
+        assert_eq!(result, Vec::<i32>::new());
+    }
+
+    #[tokio::test]
+    async fn test_map_last_stream_single_element() {
+        let input = stream::iter(vec![5]);
+        let mapped = map_last_stream(input, |x| x * 2);
+        let result: Vec<i32> = mapped.collect().await;
+        assert_eq!(result, vec![10]);
+    }
+
+    #[tokio::test]
+    async fn test_map_last_stream_multiple_elements() {
+        let input = stream::iter(vec![1, 2, 3, 4]);
+        let mapped = map_last_stream(input, |x| x + 100);
+        let result: Vec<i32> = mapped.collect().await;
+        assert_eq!(result, vec![1, 2, 3, 104]); // Only the last element is transformed
+    }
+
+    #[tokio::test]
+    async fn test_map_last_stream_preserves_order() {
+        let input = stream::iter(vec![10, 20, 30, 40, 50]);
+        let mapped = map_last_stream(input, |x| x - 50);
+        let result: Vec<i32> = mapped.collect().await;
+        assert_eq!(result, vec![10, 20, 30, 40, 0]); // Last element: 50 - 50 = 0
+    }
+}
diff --git a/src/common/mod.rs b/src/common/mod.rs
index fe9773f..9194388 100644
--- a/src/common/mod.rs
+++ b/src/common/mod.rs
@@ -1,7 +1,7 @@
-mod callback_stream;
+mod map_last_stream;
 mod partitioning;
 #[allow(unused)]
 pub mod ttl_map;
 
-pub(crate) use callback_stream::with_callback;
+pub(crate) use map_last_stream::map_last_stream;
 pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};
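Reviewer note: `map_last_stream` works by holding back one element — it only emits element N once it has seen element N+1, so when the inner stream returns `None` it still owns the true last element and can apply the `FnOnce` to it. Whenever it returns `Pending` (or re-polls) without having registered the waker with the inner stream, it must call `cx.waker().wake_by_ref()` itself, otherwise the task would hang. A hypothetical extra test (not part of this patch), assuming it sits in the same `tests` module so `map_last_stream` is in scope:

```rust
#[tokio::test]
async fn test_map_last_stream_owned_values() {
    // Tag the final element, the same way do_get.rs attaches trailing metrics.
    let input = stream::iter(vec!["a".to_string(), "b".to_string()]);
    let mapped = map_last_stream(input, |last| format!("{last} (final)"));
    let result: Vec<String> = mapped.collect().await;
    assert_eq!(result, vec!["a".to_string(), "b (final)".to_string()]);
}
```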
diff --git a/src/execution_plans/network_coalesce.rs b/src/execution_plans/network_coalesce.rs
index 0f44ea5..b3f0ad2 100644
--- a/src/execution_plans/network_coalesce.rs
+++ b/src/execution_plans/network_coalesce.rs
@@ -14,11 +14,12 @@ use arrow_flight::decode::FlightRecordBatchStream;
 use arrow_flight::error::FlightError;
 use dashmap::DashMap;
 use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
+use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
-use futures::{TryFutureExt, TryStreamExt};
+use futures::{StreamExt, TryFutureExt, TryStreamExt};
 use http::Extensions;
 use prost::Message;
 use std::any::Any;
@@ -283,6 +284,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
         };
 
         let metrics_collection_capture = self_ready.metrics_collection.clone();
+        let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
+        let (mapper, _indices) = adapter.map_schema(&self.schema())?;
         let stream = async move {
             let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
             let stream = client
@@ -297,7 +300,12 @@ impl ExecutionPlan for NetworkCoalesceExec {
 
             Ok(
                 FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-                    .map_err(map_flight_to_datafusion_error),
+                    .map_err(map_flight_to_datafusion_error)
+                    .map(move |batch| {
+                        let batch = batch?;
+
+                        mapper.map_batch(batch)
+                    }),
             )
         }
         .try_flatten_stream();
diff --git a/src/execution_plans/network_shuffle.rs b/src/execution_plans/network_shuffle.rs
index 54140aa..51beeb0 100644
--- a/src/execution_plans/network_shuffle.rs
+++ b/src/execution_plans/network_shuffle.rs
@@ -14,6 +14,7 @@ use arrow_flight::decode::FlightRecordBatchStream;
 use arrow_flight::error::FlightError;
 use dashmap::DashMap;
 use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
+use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_expr::Partitioning;
@@ -308,8 +309,12 @@ impl ExecutionPlan for NetworkShuffleExec {
         let task_context = DistributedTaskContext::from_ctx(&context);
         let off = self_ready.properties.partitioning.partition_count() * task_context.task_index;
 
+        let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
+        let (mapper, _indices) = adapter.map_schema(&self.schema())?;
+
         let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
             let channel_resolver = Arc::clone(&channel_resolver);
+            let mapper = mapper.clone();
 
             let ticket = Request::from_parts(
                 MetadataMap::from_headers(context_headers.clone()),
@@ -349,7 +354,12 @@ impl ExecutionPlan for NetworkShuffleExec {
 
                 Ok(
                     FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-                        .map_err(map_flight_to_datafusion_error),
+                        .map_err(map_flight_to_datafusion_error)
+                        .map(move |batch| {
+                            let batch = batch?;
+
+                            mapper.map_batch(batch)
+                        }),
                 )
             }
             .try_flatten_stream()
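Reviewer note: both network nodes now run every decoded batch through a `SchemaAdapter` so that what comes off the wire matches the schema the plan advertises. The test changes at the bottom of this diff suggest the motivation: the Flight encoder hydrates dictionary-encoded arrays by default, so a batch can decode with a plain `Utf8` column where the plan declares `Dictionary(UInt16, Utf8)`. A minimal sketch of the new mapping step in isolation (`remap_to_plan_schema` is a hypothetical helper; assumes a DataFusion version exposing this API):

```rust
use std::sync::Arc;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Result;
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;

/// Cast/reorder a decoded batch back to the schema the plan advertises.
fn remap_to_plan_schema(batch: RecordBatch, plan_schema: Arc<Schema>) -> Result<RecordBatch> {
    let adapter = DefaultSchemaAdapterFactory::from_schema(plan_schema);
    // `map_schema` receives the wire-side schema; the projected column indices
    // it also returns are ignored here, as in both network nodes.
    let (mapper, _indices) = adapter.map_schema(batch.schema().as_ref())?;
    mapper.map_batch(batch)
}
```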
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index 03585fc..8b754de 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -1,9 +1,8 @@
-use crate::common::with_callback;
+use crate::common::map_last_stream;
 use crate::config_extension_ext::ContextGrpcMetadata;
 use crate::execution_plans::{DistributedTaskContext, StageExec};
 use crate::flight_service::service::ArrowFlightEndpoint;
 use crate::flight_service::session_builder::DistributedSessionBuilderContext;
-use crate::flight_service::trailing_flight_data_stream::TrailingFlightDataStream;
 use crate::metrics::TaskMetricsCollector;
 use crate::metrics::proto::df_metrics_set_to_proto;
 use crate::protobuf::{
@@ -16,15 +15,10 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
 use bytes::Bytes;
-use datafusion::arrow::array::RecordBatch;
-use datafusion::arrow::datatypes::SchemaRef;
-use datafusion::arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
+
 use datafusion::common::exec_datafusion_err;
-use datafusion::execution::SendableRecordBatchStream;
-use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::prelude::SessionContext;
 use futures::TryStreamExt;
-use futures::{Stream, stream};
 use prost::Message;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -134,90 +128,39 @@ impl ArrowFlightEndpoint {
             .execute(doget.target_partition as usize, session_state.task_ctx())
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
-        let schema = stream.schema();
-
-        // TODO: We don't need to do this since the stage / plan is captured again by the
-        // TrailingFlightDataStream. However, we will eventuall only use the TrailingFlightDataStream
-        // if we are running an `explain (analyze)` command. We should update this section
-        // to only use one or the other - not both.
-        let plan_capture = stage.plan.clone();
-        let stream = with_callback(stream, move |_| {
-            // We need to hold a reference to the plan for at least as long as the stream is
-            // execution. Some plans might store state necessary for the stream to work, and
-            // dropping the plan early could drop this state too soon.
-            let _ = plan_capture;
-        });
-
-        let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
-        let task_data_capture = self.task_data_entries.clone();
-        Ok(flight_stream_from_record_batch_stream(
-            key.clone(),
-            stage_data.clone(),
-            move || {
-                task_data_capture.remove(key.clone());
-            },
-            record_batch_stream,
-        ))
-    }
-}
-
-fn missing(field: &'static str) -> impl FnOnce() -> Status {
-    move || Status::invalid_argument(format!("Missing field '{field}'"))
-}
-
-/// Creates a tonic response from a stream of record batches. Handles
-/// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics.
-fn flight_stream_from_record_batch_stream(
-    stage_key: StageKey,
-    stage_data: TaskData,
-    evict_stage: impl FnOnce() + Send + 'static,
-    stream: SendableRecordBatchStream,
-) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
-    let flight_data_stream =
-        FlightDataEncoderBuilder::new()
+        let stream = FlightDataEncoderBuilder::new()
             .with_schema(stream.schema().clone())
             .build(stream.map_err(|err| {
                 FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
             }));
 
-    let trailing_metrics_stream = TrailingFlightDataStream::new(
-        move || {
-            if stage_data
-                .num_partitions_remaining
-                .fetch_sub(1, Ordering::SeqCst)
-                == 1
-            {
-                evict_stage();
-
-                let metrics_stream =
-                    collect_and_create_metrics_flight_data(stage_key, stage_data.stage).map_err(
-                        |err| {
-                            Status::internal(format!(
-                                "error collecting metrics in arrow flight endpoint: {err}"
-                            ))
-                        },
-                    )?;
-
-                return Ok(Some(metrics_stream));
+        let task_data_entries = Arc::clone(&self.task_data_entries);
+        let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
+
+        let stream = map_last_stream(stream, move |last| {
+            if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) == 1 {
+                task_data_entries.remove(key.clone());
             }
+            last.and_then(|el| collect_and_create_metrics_flight_data(key, stage, el))
+        });
 
-            Ok(None)
-        },
-        flight_data_stream,
-    );
+        Ok(Response::new(Box::pin(stream.map_err(|err| match err {
+            FlightError::Tonic(status) => *status,
+            _ => Status::internal(format!("Error during flight stream: {err}")),
+        }))))
+    }
+}
 
-    Response::new(Box::pin(trailing_metrics_stream.map_err(|err| match err {
-        FlightError::Tonic(status) => *status,
-        _ => Status::internal(format!("Error during flight stream: {err}")),
-    })))
+fn missing(field: &'static str) -> impl FnOnce() -> Status {
+    move || Status::invalid_argument(format!("Missing field '{field}'"))
 }
 
-// Collects metrics from the provided stage and encodes it into a stream of flight data using
-// the schema of the stage.
+/// Collects metrics from the provided stage and includes them in the flight data
 fn collect_and_create_metrics_flight_data(
     stage_key: StageKey,
     stage: Arc<StageExec>,
-) -> Result<impl Stream<Item = Result<FlightData, FlightError>> + Send + 'static, FlightError> {
+    incoming: FlightData,
+) -> Result<FlightData, FlightError> {
     // Get the metrics for the task executed on this worker. Separately, collect metrics for child tasks.
     let mut result = TaskMetricsCollector::new()
         .collect(stage.plan.clone())
@@ -252,35 +195,12 @@ fn collect_and_create_metrics_flight_data(
         })),
     };
 
-    let metrics_flight_data =
-        empty_flight_data_with_app_metadata(flight_app_metadata, stage.plan.schema())?;
-    Ok(Box::pin(stream::once(
-        async move { Ok(metrics_flight_data) },
-    )))
-}
-
-/// Creates a FlightData with the given app_metadata and empty RecordBatch using the provided schema.
-/// We don't use [arrow_flight::encode::FlightDataEncoder] (and by extension, the [arrow_flight::encode::FlightDataEncoderBuilder])
-/// since they skip messages with empty RecordBatch data.
-pub fn empty_flight_data_with_app_metadata(
-    metadata: FlightAppMetadata,
-    schema: SchemaRef,
-) -> Result<FlightData, FlightError> {
     let mut buf = vec![];
-    metadata
+    flight_app_metadata
         .encode(&mut buf)
         .map_err(|err| FlightError::ProtocolError(err.to_string()))?;
 
-    let empty_batch = RecordBatch::new_empty(schema);
-    let options = IpcWriteOptions::default();
-    let data_gen = IpcDataGenerator::default();
-    let mut dictionary_tracker = DictionaryTracker::new(true);
-    let (_, encoded_data) = data_gen
-        .encoded_batch(&empty_batch, &mut dictionary_tracker, &options)
-        .map_err(|e| {
-            FlightError::ProtocolError(format!("Failed to create empty batch FlightData: {e}"))
-        })?;
-    Ok(FlightData::from(encoded_data).with_app_metadata(buf))
+    Ok(incoming.with_app_metadata(buf))
 }
 
 #[cfg(test)]
diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs
index db3bd91..96ea609 100644
--- a/src/flight_service/mod.rs
+++ b/src/flight_service/mod.rs
@@ -1,7 +1,6 @@
 mod do_get;
 mod service;
 mod session_builder;
-pub(super) mod trailing_flight_data_stream;
 
 pub(crate) use do_get::DoGet;
 pub use service::ArrowFlightEndpoint;
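Reviewer note: the endpoint previously appended a separate, hand-rolled empty-batch `FlightData` to carry the metrics (necessary because the encoder skips empty batches). With `map_last_stream` it instead piggybacks the encoded protobuf onto the `app_metadata` of the final data message, and evicts the task entry when the last partition finishes. A sketch of the piggybacking (`attach_trailer` is a hypothetical helper, not in this patch):

```rust
use arrow_flight::FlightData;
use prost::Message;

/// Attach an encoded protobuf trailer to the final FlightData message.
/// Decoders that ignore `app_metadata` still read the IPC payload unchanged.
fn attach_trailer<M: Message>(last: FlightData, trailer: &M) -> FlightData {
    let mut buf = Vec::new();
    trailer
        .encode(&mut buf)
        .expect("encoding into a Vec cannot fail");
    last.with_app_metadata(buf)
}
```

This design relies on the Flight encoder always emitting at least one message (the schema message): on a truly empty stream the `FnOnce` would never run, so neither the metrics nor the eviction would happen.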
diff --git a/src/flight_service/trailing_flight_data_stream.rs b/src/flight_service/trailing_flight_data_stream.rs
deleted file mode 100644
index 36e6778..0000000
--- a/src/flight_service/trailing_flight_data_stream.rs
+++ /dev/null
@@ -1,236 +0,0 @@
-use arrow_flight::{FlightData, error::FlightError};
-use futures::stream::Stream;
-use pin_project::pin_project;
-use std::pin::Pin;
-use std::task::{Context, Poll};
-use tokio::pin;
-
-/// TrailingFlightDataStream - wraps a FlightData stream. It calls the `on_complete` closure when the stream is finished.
-/// If the closure returns a new stream, it will be appended to the original stream and consumed.
-#[pin_project]
-pub struct TrailingFlightDataStream<F, S, T>
-where
-    S: Stream<Item = Result<FlightData, FlightError>> + Send,
-    T: Stream<Item = Result<FlightData, FlightError>> + Send,
-    F: FnOnce() -> Result<Option<T>, FlightError>,
-{
-    #[pin]
-    inner: S,
-    on_complete: Option<F>,
-    #[pin]
-    trailing_stream: Option<T>,
-}
-
-impl<F, S, T> TrailingFlightDataStream<F, S, T>
-where
-    S: Stream<Item = Result<FlightData, FlightError>> + Send,
-    T: Stream<Item = Result<FlightData, FlightError>> + Send,
-    F: FnOnce() -> Result<Option<T>, FlightError>,
-{
-    pub fn new(on_complete: F, inner: S) -> Self {
-        Self {
-            inner,
-            on_complete: Some(on_complete),
-            trailing_stream: None,
-        }
-    }
-}
-
-impl<F, S, T> Stream for TrailingFlightDataStream<F, S, T>
-where
-    S: Stream<Item = Result<FlightData, FlightError>> + Send,
-    T: Stream<Item = Result<FlightData, FlightError>> + Send,
-    F: FnOnce() -> Result<Option<T>, FlightError>,
-{
-    type Item = Result<FlightData, FlightError>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let mut this = self.as_mut().project();
-
-        match this.inner.poll_next(cx) {
-            Poll::Ready(Some(Ok(flight_data))) => Poll::Ready(Some(Ok(flight_data))),
-            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
-            Poll::Ready(None) => {
-                if let Some(trailing_stream) = this.trailing_stream.as_mut().as_pin_mut() {
-                    return trailing_stream.poll_next(cx);
-                }
-                if let Some(on_complete) = this.on_complete.take() {
-                    if let Some(trailing_stream) = on_complete()? {
-                        this.trailing_stream.set(Some(trailing_stream));
-                        return self.poll_next(cx);
-                    }
-                }
-                Poll::Ready(None)
-            }
-            Poll::Pending => Poll::Pending,
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use arrow::array::{Array, Int32Array, StringArray};
-    use arrow::datatypes::{DataType, Field, Schema};
-    use arrow::record_batch::RecordBatch;
-    use arrow_flight::FlightData;
-    use arrow_flight::decode::FlightRecordBatchStream;
-    use arrow_flight::encode::{FlightDataEncoder, FlightDataEncoderBuilder};
-    use futures::stream::{self, StreamExt};
-    use std::sync::Arc;
-
-    fn create_trailing_flight_data_stream(
-        name_array: StringArray,
-        value_array: Int32Array,
-    ) -> Pin<Box<dyn Stream<Item = Result<FlightData, FlightError>> + Send>> {
-        create_flight_data_stream_inner(name_array, value_array, true)
-    }
-
-    fn create_flight_data_stream(
-        name_array: StringArray,
-        value_array: Int32Array,
-    ) -> Pin<Box<dyn Stream<Item = Result<FlightData, FlightError>> + Send>> {
-        create_flight_data_stream_inner(name_array, value_array, false)
-    }
-
-    // Creates a stream of RecordBatches.
-    fn create_flight_data_stream_inner(
-        name_array: StringArray,
-        value_array: Int32Array,
-        is_trailing: bool,
-    ) -> Pin<Box<dyn Stream<Item = Result<FlightData, FlightError>> + Send>> {
-        assert_eq!(
-            name_array.len(),
-            value_array.len(),
-            "StringArray and Int32Array must have equal lengths"
-        );
-
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("name", DataType::Utf8, false),
-            Field::new("value", DataType::Int32, false),
-        ]));
-
-        let batches: Vec<RecordBatch> = (0..name_array.len())
-            .map(|i| {
-                let name_slice = name_array.slice(i, 1);
-                let value_slice = value_array.slice(i, 1);
-
-                RecordBatch::try_new(
-                    schema.clone(),
-                    vec![Arc::new(name_slice), Arc::new(value_slice)],
-                )
-                .unwrap()
-            })
-            .collect();
-
-        let batch_stream = futures::stream::iter(batches.into_iter().map(Ok));
-        let flight_stream = FlightDataEncoderBuilder::new()
-            .with_schema(schema)
-            .build(batch_stream);
-
-        // By default, this encoder will emit a schema message as the first message in the stream.
-        // Since we are concatenating streams, we need to drop the schema message from the trailing stream.
-        if is_trailing {
-            // Skip the schema message
-            return Box::pin(flight_stream.skip(1));
-        }
-        Box::pin(flight_stream)
-    }
-
-    #[tokio::test]
-    async fn test_basic_streaming_functionality() {
-        let name_array = StringArray::from(vec!["a", "b", "c"]);
-        let value_array = Int32Array::from(vec![1, 2, 3]);
-        let inner_stream = create_flight_data_stream(name_array, value_array);
-
-        let name_array = StringArray::from(vec!["d", "e", "f"]);
-        let value_array = Int32Array::from(vec![5, 6, 7]);
-        let trailing_stream = create_trailing_flight_data_stream(name_array, value_array);
-
-        let on_complete = || Ok(Some(trailing_stream));
-
-        let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
-        let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
-            .collect::<Vec<Result<RecordBatch, FlightError>>>()
-            .await;
-
-        assert_eq!(record_batches.len(), 6);
-        assert!(record_batches.iter().all(|batch| batch.is_ok()));
-        assert_eq!(
-            record_batches
-                .iter()
-                .map(|batch| batch
-                    .as_ref()
-                    .unwrap()
-                    .column(0)
-                    .as_any()
-                    .downcast_ref::<StringArray>()
-                    .unwrap()
-                    .value(0))
-                .collect::<Vec<_>>(),
-            vec!["a", "b", "c", "d", "e", "f"]
-        );
-    }
-
-    #[tokio::test]
-    async fn test_error_handling_in_inner_stream() {
-        let mut stream =
-            create_flight_data_stream(StringArray::from(vec!["item1"]), Int32Array::from(vec![1]));
-        let schema_message = stream.next().await.unwrap().unwrap();
-        let flight_data = stream.next().await.unwrap().unwrap();
-        let data = vec![
-            Ok(schema_message),
-            Ok(flight_data),
-            Err(FlightError::ExternalError(Box::new(std::io::Error::new(
-                std::io::ErrorKind::Other,
-                "test error",
-            )))),
-        ];
-        let inner_stream = stream::iter(data);
-        let on_complete = || Ok(None::<FlightDataEncoder>);
-        let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
-        let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
-            .collect::<Vec<Result<RecordBatch, FlightError>>>()
-            .await;
-
-        assert_eq!(record_batches.len(), 2);
-        assert!(record_batches[0].is_ok());
-        assert!(record_batches[1].is_err());
-    }
-
-    #[tokio::test]
-    async fn test_error_handling_in_on_complete_callback() {
-        let name_array = StringArray::from(vec!["item1"]);
-        let value_array = Int32Array::from(vec![1]);
-        let inner_stream = create_flight_data_stream(name_array, value_array);
-        let on_complete = || -> Result<Option<FlightDataEncoder>, FlightError> {
-            Err(FlightError::ExternalError(Box::new(std::io::Error::new(
-                std::io::ErrorKind::Other,
-                "callback error",
-            ))))
-        };
-
-        let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
-        let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
-            .collect::<Vec<Result<RecordBatch, FlightError>>>()
-            .await;
-        assert_eq!(record_batches.len(), 2);
-        assert!(record_batches[0].is_ok());
-        assert!(record_batches[1].is_err());
-    }
-
-    #[tokio::test]
-    async fn test_stream_with_no_trailer() {
-        let inner_stream = create_flight_data_stream(
-            StringArray::from(vec!["item1"] as Vec<&str>),
-            Int32Array::from(vec![1] as Vec<i32>),
-        );
-        let on_complete = || Ok(None::<FlightDataEncoder>);
-        let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
-        let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
-            .collect::<Vec<Result<RecordBatch, FlightError>>>()
-            .await;
-        assert_eq!(record_batches.len(), 1);
-        assert!(record_batches[0].is_ok());
-    }
-}
diff --git a/src/metrics/task_metrics_collector.rs b/src/metrics/task_metrics_collector.rs
index adee221..e4c054c 100644
--- a/src/metrics/task_metrics_collector.rs
+++ b/src/metrics/task_metrics_collector.rs
@@ -122,6 +122,7 @@ impl TaskMetricsCollector {
 mod tests {
     use super::*;
+    use arrow::datatypes::UInt16Type;
     use datafusion::arrow::array::{Int32Array, StringArray};
     use datafusion::arrow::record_batch::RecordBatch;
     use futures::StreamExt;
@@ -183,6 +184,11 @@ mod tests {
             Field::new("name", DataType::Utf8, false),
             Field::new("phone", DataType::Utf8, false),
             Field::new("balance", DataType::Float64, false),
+            Field::new(
+                "company",
+                DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
+                false,
+            ),
         ]));
 
         let batches2 = vec![
@@ -203,6 +209,11 @@ mod tests {
                 Arc::new(datafusion::arrow::array::Float64Array::from(vec![
                     100.5, 250.0, 50.25,
                 ])),
+                Arc::new(
+                    vec!["company1", "company1", "company1"]
+                        .into_iter()
+                        .collect::<DictionaryArray<UInt16Type>>(),
+                ),
             ],
         )
         .unwrap(),
@@ -239,9 +250,13 @@ mod tests {
         let task_ctx = ctx.task_ctx();
         let stream = stage_exec.execute(0, task_ctx).unwrap();
 
+        let schema = stream.schema();
+        let mut stream = stream;
+
         while let Some(batch) = stream.next().await {
-            batch.unwrap();
+            let batch = batch.unwrap();
+
+            assert_eq!(schema, batch.schema())
         }
     }
@@ -322,7 +337,7 @@ mod tests {
     #[tokio::test]
     async fn test_metrics_collection_e2e_3() {
         run_metrics_collection_e2e_test(
-            "SELECT 
+            "SELECT
                 substring(phone, 1, 2) as country_code,
                 count(*) as num_customers,
                 sum(balance) as total_balance
@@ -338,4 +353,9 @@ mod tests {
         )
         .await;
     }
+
+    #[tokio::test]
+    async fn test_metrics_collection_e2e_4() {
+        run_metrics_collection_e2e_test("SELECT distinct company from table2").await;
+    }
 }
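Reviewer note: the new dictionary-typed `company` column and `test_metrics_collection_e2e_4` appear to exist to exercise the schema-adapter change above: `SELECT distinct company` ships dictionary arrays through the Flight boundary, and the strengthened `assert_eq!(schema, batch.schema())` verifies the decoded batches still match the plan schema. A minimal sketch of the column in question (`company_column` is a hypothetical helper; assumes only the `arrow` crate):

```rust
use arrow::array::DictionaryArray;
use arrow::datatypes::{DataType, Field, Schema, UInt16Type};

/// The plan-side declaration and the array the test ships through Flight.
fn company_column() -> (Schema, DictionaryArray<UInt16Type>) {
    let schema = Schema::new(vec![Field::new(
        "company",
        DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
        false,
    )]);
    // The Flight encoder hydrates this to plain Utf8 by default; the schema
    // adapter in the network nodes casts it back to the declared type.
    let array: DictionaryArray<UInt16Type> =
        vec!["company1", "company1", "company1"].into_iter().collect();
    (schema, array)
}
```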