Merged
84 changes: 0 additions & 84 deletions src/common/callback_stream.rs

This file was deleted.

84 changes: 84 additions & 0 deletions src/common/map_last_stream.rs
@@ -0,0 +1,84 @@
use futures::{Stream, StreamExt, stream};
use std::task::Poll;

/// Maps the last element of the provided stream.
pub(crate) fn map_last_stream<T>(
mut input: impl Stream<Item = T> + Unpin,
map_f: impl FnOnce(T) -> T,
) -> impl Stream<Item = T> + Unpin {
let mut final_closure = Some(map_f);

// buffer the most recent value so that, once the input ends, we know which element was last and can map it before emitting
let mut current_value = None;

stream::poll_fn(move |cx| match futures::ready!(input.poll_next_unpin(cx)) {
Some(new_val) => {
match current_value.take() {
// This is the first value, so we store it and repoll to get the next value
None => {
current_value = Some(new_val);
cx.waker().wake_by_ref();
Poll::Pending
}

Some(existing) => {
current_value = Some(new_val);

Poll::Ready(Some(existing))
}
}
}
// the input is exhausted, so the buffered value (if any) is the last one: map it with the user-provided closure
None => match current_value.take() {
Some(existing) => {
// make sure we wake ourselves to finish the stream
cx.waker().wake_by_ref();

if let Some(closure) = final_closure.take() {
Poll::Ready(Some(closure(existing)))
} else {
unreachable!("the closure is only executed once")
}
}
None => Poll::Ready(None),
},
})
}

#[cfg(test)]
mod tests {
use super::*;
use futures::stream;

#[tokio::test]
async fn test_map_last_stream_empty_stream() {
Collaborator:

I don't think this works in all cases. Say the last partition for a task is empty. This behavior means we won't send any metrics for any partitions of the task (because we only send metrics for the entire task after the last partition is done).

It also means we may lose metrics from child tasks, because this task may have collected them.

Unfortunately we don't have a test for this case. We would certainly benefit from having one.

Contributor Author @cetra3 (Oct 6, 2025):

The Arrow Flight encoder will always send something, as far as I can tell. Even if no record batches are returned, you will still receive the encoded schema. (A sketch demonstrating this follows the network_shuffle.rs diff below.)

Collaborator:

Ah okay. This makes sense.

let input = stream::empty::<i32>();
let mapped = map_last_stream(input, |x| x + 10);
let result: Vec<i32> = mapped.collect().await;
assert_eq!(result, Vec::<i32>::new());
}

#[tokio::test]
async fn test_map_last_stream_single_element() {
let input = stream::iter(vec![5]);
let mapped = map_last_stream(input, |x| x * 2);
let result: Vec<i32> = mapped.collect().await;
assert_eq!(result, vec![10]);
}

#[tokio::test]
async fn test_map_last_stream_multiple_elements() {
let input = stream::iter(vec![1, 2, 3, 4]);
let mapped = map_last_stream(input, |x| x + 100);
let result: Vec<i32> = mapped.collect().await;
assert_eq!(result, vec![1, 2, 3, 104]); // Only the last element is transformed
}

#[tokio::test]
async fn test_map_last_stream_preserves_order() {
let input = stream::iter(vec![10, 20, 30, 40, 50]);
let mapped = map_last_stream(input, |x| x - 50);
let result: Vec<i32> = mapped.collect().await;
assert_eq!(result, vec![10, 20, 30, 40, 0]); // Last element: 50 - 50 = 0
}
}
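
For context on how this combinator is used by the PR (attaching task metrics to the final FlightData message of a stream, per the review thread above), here is a minimal self-contained sketch. The Msg type and the "encoded-task-metrics" payload are hypothetical stand-ins, not this repo's types, and map_last_stream from the file above is assumed to be in scope:

use futures::{StreamExt, stream};

// Hypothetical stand-in for a FlightData-like message with app_metadata.
#[allow(dead_code)]
struct Msg {
    payload: &'static str,
    app_metadata: Vec<u8>,
}

#[tokio::main]
async fn main() {
    let input = stream::iter(vec![
        Msg { payload: "batch-0", app_metadata: vec![] },
        Msg { payload: "batch-1", app_metadata: vec![] },
    ]);

    // Only the final message gets the (hypothetical) encoded metrics attached.
    let tagged = map_last_stream(input, |mut last| {
        last.app_metadata = b"encoded-task-metrics".to_vec();
        last
    });

    let out: Vec<Msg> = tagged.collect().await;
    assert!(out[0].app_metadata.is_empty());
    assert_eq!(out[1].app_metadata, b"encoded-task-metrics".to_vec());
}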
4 changes: 2 additions & 2 deletions src/common/mod.rs
@@ -1,7 +1,7 @@
-mod callback_stream;
+mod map_last_stream;
mod partitioning;
#[allow(unused)]
pub mod ttl_map;

-pub(crate) use callback_stream::with_callback;
+pub(crate) use map_last_stream::map_last_stream;
pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};
12 changes: 10 additions & 2 deletions src/execution_plans/network_coalesce.rs
@@ -14,11 +14,12 @@ use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use dashmap::DashMap;
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
+use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
-use futures::{TryFutureExt, TryStreamExt};
+use futures::{StreamExt, TryFutureExt, TryStreamExt};
use http::Extensions;
use prost::Message;
use std::any::Any;
@@ -283,6 +284,8 @@ impl ExecutionPlan for NetworkCoalesceExec {
};

let metrics_collection_capture = self_ready.metrics_collection.clone();
+let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
+let (mapper, _indices) = adapter.map_schema(&self.schema())?;
Collaborator:

This looks like 1:1 schema mapping. What does it do? Is this just a way to assert that the schema hasn't changed? I think adding a test which shows why this is necessary would be good.

Contributor Author @cetra3:

The schema does change. Arrow Flight hydrates dictionary values into real values, so the schema of the incoming record batch is different. We use the mapper here to map back to what the execution plan expects. (See the sketch after this file's diff.)

Collaborator:

I noticed that tests still pass without this line.

IIUC, the root problem was on the server: we were sending empty flight data to the client without sending the schema / dictionary message first. You've fixed this problem.

I don't see an issue on the client that this solves. The flight decoder in the client should be able to handle any message sent by the encoder on the server.

The metrics collector on the client passes flight data through unchanged, minus clearing the app_metadata.

Collaborator:

I would prefer to either have a test which shows why this is needed, or remove the lines. Let me know if you think otherwise, though!

Once again, I appreciate the contribution 🙏🏽 - the old empty flight data code was sketchy for sure.

Contributor Author @cetra3 (Oct 8, 2025):

test_metrics_collection_e2e_4 fails with this removed from both network plans.

Collaborator:

Ah, sorry, I commented on one but not the other. This LGTM.

Contributor Author @cetra3:

I've added an assert here to make sure the schema matches: a141a3b

let stream = async move {
let mut client = channel_resolver.get_flight_client_for_url(&url).await?;
let stream = client
@@ -297,7 +300,12 @@ impl ExecutionPlan for NetworkCoalesceExec {

Ok(
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-.map_err(map_flight_to_datafusion_error),
+.map_err(map_flight_to_datafusion_error)
+.map(move |batch| {
+let batch = batch?;
+
+mapper.map_batch(batch)
+}),
)
}
.try_flatten_stream();
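To illustrate the dictionary-hydration point from the thread above, here is a minimal sketch of the schema-mapping step, using the same DefaultSchemaAdapterFactory API the PR uses. The tag column and its types are illustrative, not taken from this repo:

use datafusion::arrow::array::StringArray;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
use std::sync::Arc;

fn main() -> datafusion::error::Result<()> {
    // The schema the execution plan advertises: a dictionary-encoded column.
    let plan_schema = Arc::new(Schema::new(vec![Field::new(
        "tag",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        true,
    )]));

    // What arrives over Flight after dictionary hydration: plain Utf8.
    let wire_schema = Arc::new(Schema::new(vec![Field::new("tag", DataType::Utf8, true)]));
    let batch = RecordBatch::try_new(
        wire_schema.clone(),
        vec![Arc::new(StringArray::from(vec!["a", "b"]))],
    )?;

    // Map the incoming batch back to the schema the plan expects.
    let adapter = DefaultSchemaAdapterFactory::from_schema(plan_schema.clone());
    let (mapper, _indices) = adapter.map_schema(&wire_schema)?;
    let mapped = mapper.map_batch(batch)?;
    assert_eq!(mapped.schema(), plan_schema);
    Ok(())
}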
12 changes: 11 additions & 1 deletion src/execution_plans/network_shuffle.rs
@@ -14,6 +14,7 @@ use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use dashmap::DashMap;
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, plan_err};
+use datafusion::datasource::schema_adapter::DefaultSchemaAdapterFactory;
use datafusion::error::DataFusionError;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::Partitioning;
@@ -308,8 +309,12 @@ impl ExecutionPlan for NetworkShuffleExec {
let task_context = DistributedTaskContext::from_ctx(&context);
let off = self_ready.properties.partitioning.partition_count() * task_context.task_index;

+let adapter = DefaultSchemaAdapterFactory::from_schema(self.schema());
+let (mapper, _indices) = adapter.map_schema(&self.schema())?;

let stream = input_stage_tasks.into_iter().enumerate().map(|(i, task)| {
let channel_resolver = Arc::clone(&channel_resolver);
+let mapper = mapper.clone();

let ticket = Request::from_parts(
MetadataMap::from_headers(context_headers.clone()),
@@ -349,7 +354,12 @@

Ok(
FlightRecordBatchStream::new_from_flight_data(metrics_collecting_stream)
-.map_err(map_flight_to_datafusion_error),
+.map_err(map_flight_to_datafusion_error)
+.map(move |batch| {
+let batch = batch?;
+
+mapper.map_batch(batch)
+}),
)
}
.try_flatten_stream()
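Finally, to back the claim from the review threads that the server-side encoder always sends a schema message, here is a minimal sketch, assuming arrow-flight's FlightDataEncoderBuilder (the encoder this project uses); the column is illustrative:

use arrow_array::RecordBatch;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_schema::{DataType, Field, Schema};
use futures::{StreamExt, stream};
use std::sync::Arc;

#[tokio::main]
async fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // Zero record batches: only the schema message should be produced.
    let empty = stream::empty::<Result<RecordBatch, FlightError>>();
    let encoder = FlightDataEncoderBuilder::new()
        .with_schema(schema)
        .build(empty);

    let messages: Vec<_> = encoder.collect().await;
    // Even with no data, the stream is not empty: the first (and here only)
    // FlightData message carries the encoded schema.
    assert!(!messages.is_empty());
    assert!(messages.iter().all(|m| m.is_ok()));
}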