From 3a8a1db65394bca22cd2dee2b8421a21cf29faf8 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Wed, 3 Jun 2026 16:50:57 +0800 Subject: [PATCH 1/5] split hash aggregation logic --- .../src/aggregates/group_values/metrics.rs | 9 +- .../src/aggregates/hash_aggregate.rs | 349 ++++++++++ .../src/aggregates/hash_table.rs | 615 ++++++++++++++++++ .../physical-plan/src/aggregates/mod.rs | 165 ++++- 4 files changed, 1132 insertions(+), 6 deletions(-) create mode 100644 datafusion/physical-plan/src/aggregates/hash_aggregate.rs create mode 100644 datafusion/physical-plan/src/aggregates/hash_table.rs diff --git a/datafusion/physical-plan/src/aggregates/group_values/metrics.rs b/datafusion/physical-plan/src/aggregates/group_values/metrics.rs index b6c32204e85f0..a0934b976ea79 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/metrics.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/metrics.rs @@ -59,6 +59,7 @@ mod tests { use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; @@ -135,7 +136,13 @@ mod tests { schema, )?); - let task_ctx = Arc::new(TaskContext::default()); + // This test is for `GroupByMetrics`, which are maintained by + // `GroupedHashAggregateStream`. Use a finite memory pool so the partial + // aggregate does not take the initial-partial stream path. 
+ let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build_arc()?; + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); let _result = collect(Arc::clone(&aggregate_exec) as _, Arc::clone(&task_ctx)).await?; diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs new file mode 100644 index 0000000000000..f716ef41279e8 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -0,0 +1,349 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Grouped hash aggregation for simple multi-stage aggregation paths. +//! +//! This module handles the basic grouped two-stage paths: +//! +//! ```text +//! input rows -> GROUP BY hash table -> accumulator state rows +//! state rows -> GROUP BY hash table -> final aggregate rows +//! ``` +//! +//! `AggregateExec` keeps finite-memory, ordered, limit, grouping-set, +//! `partial state -> partial state`, and single-stage aggregation on +//! `GroupedHashAggregateStream` for now. 
+ +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use futures::stream::{Stream, StreamExt}; + +use super::AggregateExec; +use super::hash_table::{AggregateHashTable, InitialPartial, PartialFinal}; +use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput, SpillMetrics}; +use crate::stream::EmptyRecordBatchStream; +use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metrics}; + +/// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream +/// is for the partial stage. +/// +/// # Example +/// +/// select k, avg(v) from t group by k; +/// +/// ## Plan +/// AggregateExec(stage=final) +/// -- RepartitionExec(hash(k)) +/// ---- AggregateExec(stage=partial) +/// +/// ## Partial Stage Behavior +/// Input: raw rows +/// Output: partial states for all groups (e.g. for avg(x), it's sum(x), count(x)) +/// +/// ## Final Stage Behavior +/// Input: partial states +/// Output: results for all groups (e.g. for avg(x), it's avg(x) calculated from the state) +pub(crate) struct InitialPartialHashAggregateStream { + /// Output schema: group columns followed by partial aggregate state columns. + schema: SchemaRef, + + /// Input batches containing raw rows, not partial aggregate state. + input: SendableRecordBatchStream, + + /// Hash table state for this aggregate stream. + hash_table: AggregateHashTable, + + /// Memory reservation for group keys and accumulators. + reservation: MemoryReservation, + + /// Execution metrics shared with the aggregate plan node. + baseline_metrics: BaselineMetrics, + + /// Tracks partial aggregation row reduction, matching `GroupedHashAggregateStream`. 
+ reduction_factor: metrics::RatioMetrics, +} + +/// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream +/// is for the final stage. +/// +/// See [`InitialPartialHashAggregateStream`] for details. +pub(crate) struct PartialFinalHashAggregateStream { + /// Output schema: group columns followed by final aggregate value columns. + schema: SchemaRef, + + /// Input batches containing partial aggregate state rows. + input: SendableRecordBatchStream, + + /// Hash table state for this aggregate stream. + hash_table: AggregateHashTable, + + /// Execution metrics shared with the aggregate plan node. + baseline_metrics: BaselineMetrics, + + /// Memory reservation for group keys and accumulators. + reservation: MemoryReservation, +} + +impl InitialPartialHashAggregateStream { + pub fn new( + agg: &AggregateExec, + context: &Arc, + partition: usize, + ) -> Result { + debug_assert_eq!(agg.mode, super::AggregateMode::Partial); + debug_assert_eq!(agg.input_order_mode, InputOrderMode::Linear); + + let schema = Arc::clone(&agg.schema); + let input = agg.input.execute(partition, Arc::clone(context))?; + let batch_size = context.session_config().batch_size(); + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + // Preserve the existing aggregate metric surface for this plan node. 
+ let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); + let reduction_factor = MetricBuilder::new(&agg.metrics) + .with_type(metrics::MetricType::Summary) + .ratio_metrics("reduction_factor", partition); + + let hash_table = AggregateHashTable::::new( + agg, + partition, + Arc::clone(&schema), + batch_size, + )?; + + let reservation = MemoryConsumer::new(format!( + "InitialPartialHashAggregateStream[{partition}]" + )) + .register(context.memory_pool()); + + Ok(Self { + schema, + input, + hash_table, + baseline_metrics, + reservation, + reduction_factor, + }) + } +} + +impl Stream for InitialPartialHashAggregateStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + + loop { + if self.hash_table.is_done() { + let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } else if self.hash_table.is_building() { + match self.input.poll_next_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(batch))) => { + let timer = elapsed_compute.timer(); + self.reduction_factor.add_total(batch.num_rows()); + let result = self.hash_table.aggregate_batch(&batch); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + // TODO: impl memory-limited aggr, when OOM directly send + // partial state to final aggregate stage + if let Err(e) = + self.reservation.try_resize(self.hash_table.memory_size()) + { + return Poll::Ready(Some(Err(e))); + } + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))); + } + Poll::Ready(None) => { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + + let timer = elapsed_compute.timer(); + let result = self.hash_table.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + } + } + } else { + let timer = 
elapsed_compute.timer(); + let result = self.hash_table.next_output_batch(); + timer.done(); + + match result { + Ok(Some(batch)) => { + let _ = + self.reservation.try_resize(self.hash_table.memory_size()); + self.reduction_factor.add_part(batch.num_rows()); + debug_assert!(batch.num_rows() > 0); + return Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))); + } + Ok(None) => { + let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + } + } +} + +impl RecordBatchStream for InitialPartialHashAggregateStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl PartialFinalHashAggregateStream { + pub fn new( + agg: &AggregateExec, + context: &Arc, + partition: usize, + ) -> Result { + debug_assert!(matches!( + agg.mode, + super::AggregateMode::Final | super::AggregateMode::FinalPartitioned + )); + debug_assert_eq!(agg.input_order_mode, InputOrderMode::Linear); + + let schema = Arc::clone(&agg.schema); + let input = agg.input.execute(partition, Arc::clone(context))?; + let batch_size = context.session_config().batch_size(); + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + // Preserve the existing aggregate metric surface for this plan node. 
+ let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); + + let hash_table = AggregateHashTable::::new( + agg, + partition, + Arc::clone(&schema), + batch_size, + )?; + + let reservation = + MemoryConsumer::new(format!("PartialFinalHashAggregateStream[{partition}]")) + .register(context.memory_pool()); + + Ok(Self { + schema, + input, + hash_table, + baseline_metrics, + reservation, + }) + } +} + +impl Stream for PartialFinalHashAggregateStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + + loop { + if self.hash_table.is_done() { + let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } else if self.hash_table.is_building() { + match self.input.poll_next_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(batch))) => { + let timer = elapsed_compute.timer(); + let result = self.hash_table.aggregate_batch(&batch); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + if let Err(e) = + self.reservation.try_resize(self.hash_table.memory_size()) + { + return Poll::Ready(Some(Err(e))); + } + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))); + } + Poll::Ready(None) => { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + + let timer = elapsed_compute.timer(); + let result = self.hash_table.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + } + } + } else { + let timer = elapsed_compute.timer(); + let result = self.hash_table.next_output_batch(); + timer.done(); + + match result { + Ok(Some(batch)) => { + let _ = + self.reservation.try_resize(self.hash_table.memory_size()); + debug_assert!(batch.num_rows() > 0); + return Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))); + } + Ok(None) => { 
+ let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + } + } +} + +impl RecordBatchStream for PartialFinalHashAggregateStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/datafusion/physical-plan/src/aggregates/hash_table.rs b/datafusion/physical-plan/src/aggregates/hash_table.rs new file mode 100644 index 0000000000000..6c412d0a137da --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/hash_table.rs @@ -0,0 +1,615 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, BooleanArray, new_null_array}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_err}; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_expr::{EmitTo, GroupsAccumulator}; + +use super::group_values::{GroupByMetrics, GroupValues, new_group_values}; +use super::order::GroupOrdering; +use super::row_hash::create_group_accumulator; +use super::{ + AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, + group_id_array, max_duplicate_ordinal, +}; +use crate::PhysicalExpr; +use crate::metrics::{MetricBuilder, MetricCategory}; + +/// Marker for raw rows -> partial state aggregation. +pub(super) struct InitialPartial; +/// Marker for partial state -> final value aggregation. +pub(super) struct PartialFinal; + +/// Grouped hash table shared by the initial-partial and partial-final paths. +/// +/// While building, it consumes input batches and updates group / accumulator +/// state. While outputting, it incrementally outputs the materialized batches. +/// +/// # Marker Type +/// `AggrMode` is a zero-sized marker type used with `PhantomData` to keep the +/// initial-partial and partial-final update logic in separate impl blocks. Different +/// stages have different semantics and applicable optimizations; this makes the +/// implementation easier to follow. +pub(super) struct AggregateHashTable { + /// Grouping and accumulator-specific timing metrics. + group_by_metrics: GroupByMetrics, + + /// Raw input schema, used to evaluate expressions and synthesize empty + /// grouping-set rows. + input_schema: SchemaRef, + + /// Output schema: group columns followed by aggregate state or final values. + output_schema: SchemaRef, + + /// Maximum rows per emitted output batch. 
+ batch_size: usize, + + /// Lifecycle-specific state: building stage / outputting stage + state: AggregateHashTableState, + + _mode: PhantomData, +} + +struct HashAggregateAccumulator { + /// Arguments to pass to this accumulator. + /// + /// Example: `CORR(x, y)` stores two expressions here, while `SUM(x)` stores one. + arguments: Vec>, + + /// Optional `FILTER` expression for this accumulator. + /// + /// Example: `SUM(x) FILTER (WHERE x > 10)` stores the `x > 10` predicate. + filter: Option>, + + /// Accumulator state for all groups for one aggregate expression. + accumulator: Box, +} + +struct EvaluatedHashAggregateAccumulator { + arguments: Vec, + filter: Option, +} + +/// Evaluated group-by keys and accumulator arguments for one input batch. +/// +/// e.g., for `select k+1, sum(v*v) from t group by (k+1)`, this holds the evaluated +/// `k+1` and `v*v` arrays. +struct EvaluatedAggregateBatch { + /// One entry per grouping set; each entry contains all evaluated group key + /// arrays for the current input batch. + grouping_set_args: Vec>, + + /// Evaluated arguments and filters, one entry per aggregate expression. + accumulator_args: Vec, +} + +/// Hash table state while grouped aggregation is consuming input. +/// +/// This owns the coupled state for: +/// - evaluating group keys, +/// - interning each distinct group, +/// - mapping each input row to its group index, +/// - evaluating aggregate inputs, +/// - updating per-group accumulator state. +struct BuildingHashTableState { + /// GROUP BY expressions evaluated for each input batch. + group_by: Arc, + + /// Interned group keys. Accumulator state is stored separately by group index. + group_values: Box, + + /// Group index for each row in the current input batch. + /// + /// Each value indexes into `group_values`, and the same index is used by every + /// accumulator to update that group's aggregate state. + batch_group_indices: Vec, + + /// One item per aggregate expression. + /// + /// Example: `COUNT(x), SUM(y)` creates two items. 
Each item owns the input + /// expressions, optional filter, and accumulator state for all groups. + accumulators: Vec, +} + +enum AggregateHashTableState { + Building(BuildingHashTableState), + Outputting { + output_batch: Option, + output_batch_offset: usize, + }, + Done, +} + +impl HashAggregateAccumulator { + fn new( + arguments: Vec>, + filter: Option>, + accumulator: Box, + ) -> Self { + Self { + arguments, + filter, + accumulator, + } + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arguments = self + .arguments + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|value| value.into_array(batch.num_rows())) + }) + .collect::>()?; + + let filter = self + .filter + .as_ref() + .map(|filter| { + filter + .evaluate(batch) + .and_then(|value| value.into_array(batch.num_rows())) + }) + .transpose()?; + + Ok(EvaluatedHashAggregateAccumulator { arguments, filter }) + } + + fn update_batch( + &mut self, + values: &EvaluatedHashAggregateAccumulator, + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + let filter = values.filter.as_ref().map(|filter| filter.as_boolean()); + self.accumulator.update_batch( + &values.arguments, + group_indices, + filter, + total_num_groups, + ) + } + + fn merge_batch( + &mut self, + values: &EvaluatedHashAggregateAccumulator, + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + debug_assert!(values.filter.is_none()); + self.accumulator.merge_batch( + &values.arguments, + group_indices, + None, + total_num_groups, + ) + } + + fn evaluate_final(&mut self, emit_to: EmitTo) -> Result { + self.accumulator.evaluate(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + self.accumulator.state(emit_to) + } + + fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state() + } + + fn null_arguments(&self, input_schema: &SchemaRef) -> Result> { + self.arguments + .iter() + .map(|expr| { + let data_type = expr.data_type(input_schema)?; 
+ Ok(new_null_array(&data_type, 1)) + }) + .collect() + } +} + +impl AggregateHashTableState { + fn building(&self) -> &BuildingHashTableState { + let Self::Building(state) = self else { + unreachable!("hash aggregate table is not building") + }; + state + } + + fn building_mut(&mut self) -> &mut BuildingHashTableState { + let Self::Building(state) = self else { + unreachable!("hash aggregate table is not building") + }; + state + } +} + +impl AggregateHashTable { + fn new_with_filters( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + filters: Vec>>, + ) -> Result { + let input_schema = agg.input().schema(); + let aggregate_arguments = aggregate_expressions( + &agg.aggr_expr, + &agg.mode, + agg.group_by.num_group_exprs(), + )?; + let accumulators: Vec<_> = agg + .aggr_expr + .iter() + .zip(aggregate_arguments) + .zip(filters) + .map(|((agg_expr, arguments), filter)| { + let accumulator = create_group_accumulator(agg_expr)?; + Ok(HashAggregateAccumulator::new( + arguments, + filter, + accumulator, + )) + }) + .collect::>()?; + + let group_schema = agg.group_by.group_schema(&input_schema)?; + let group_values = new_group_values(group_schema, &GroupOrdering::None)?; + + Ok(Self { + group_by_metrics: GroupByMetrics::new(&agg.metrics, partition), + input_schema, + output_schema, + batch_size, + state: AggregateHashTableState::Building(BuildingHashTableState { + group_by: Arc::clone(&agg.group_by), + group_values, + batch_group_indices: Default::default(), + accumulators, + }), + _mode: PhantomData, + }) + } + + /// See comments in [`EvaluatedAggregateBatch`] + fn evaluate_batch(&self, batch: &RecordBatch) -> Result { + let state = self.state.building(); + let timer = self.group_by_metrics.time_calculating_group_ids.timer(); + // outer vec: one per each grouping set + // inner vec: all group by exprs for the current grouping set + let grouping_set_args = evaluate_group_by(&state.group_by, batch)?; + drop(timer); + + let timer 
= self.group_by_metrics.aggregate_arguments_time.timer(); + // The evaluated args for each accumulator + let accumulator_args = self + .state + .building() + .accumulators + .iter() + .map(|acc| acc.evaluate(batch)) + .collect::>>()?; + drop(timer); + + Ok(EvaluatedAggregateBatch { + grouping_set_args, + accumulator_args, + }) + } + + pub(super) fn memory_size(&self) -> usize { + match &self.state { + AggregateHashTableState::Building(state) => { + let acc = state + .accumulators + .iter() + .map(|acc| acc.accumulator.size()) + .sum::(); + + acc + state.group_values.size() + + state.batch_group_indices.allocated_size() + } + AggregateHashTableState::Outputting { output_batch, .. } => { + output_batch_memory_size(output_batch) + } + AggregateHashTableState::Done => 0, + } + } + + pub(super) fn is_building(&self) -> bool { + matches!(self.state, AggregateHashTableState::Building(_)) + } + + pub(super) fn is_done(&self) -> bool { + matches!(self.state, AggregateHashTableState::Done) + } + + fn set_output_batch(&mut self, output_batch: Option) { + self.state = AggregateHashTableState::Outputting { + output_batch, + output_batch_offset: 0, + }; + } + + pub(super) fn next_output_batch(&mut self) -> Result> { + match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { + AggregateHashTableState::Outputting { + output_batch, + mut output_batch_offset, + } => { + let Some(batch) = output_batch.as_ref() else { + return Ok(None); + }; + + let num_rows = batch.num_rows(); + if output_batch_offset >= num_rows { + return Ok(None); + } + + debug_assert!(self.batch_size > 0); + let output_len = + self.batch_size.max(1).min(num_rows - output_batch_offset); + let output = batch.slice(output_batch_offset, output_len); + output_batch_offset += output_len; + + if output_batch_offset == num_rows { + self.state = AggregateHashTableState::Done; + } else { + self.state = AggregateHashTableState::Outputting { + output_batch, + output_batch_offset, + }; + } + + 
debug_assert!(output.num_rows() > 0); + debug_assert!(output.num_rows() <= self.batch_size.max(1)); + Ok(Some(output)) + } + _ => { + self.state = AggregateHashTableState::Done; + internal_err!("next_output_batch must be called in the outputting state") + } + } + } +} + +impl AggregateHashTable { + pub(super) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + ) -> Result { + let table = Self::new_with_filters( + agg, + partition, + output_schema, + batch_size, + agg.filter_expr.iter().cloned().collect(), + )?; + + if table + .state + .building() + .accumulators + .iter() + .all(|acc| acc.supports_convert_to_state()) + { + let _skipped_aggregation_rows = MetricBuilder::new(&agg.metrics) + .with_category(MetricCategory::Rows) + .counter("skipped_aggregation_rows", partition); + } + + Ok(table) + } + + pub(super) fn aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + let state = self.state.building_mut(); + + let timer = self.group_by_metrics.aggregation_time.timer(); + for group_values in &evaluated_batch.grouping_set_args { + state + .group_values + .intern(group_values, &mut state.batch_group_indices)?; + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); + + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.update_batch(values, group_indices, total_num_groups)?; + } + } + drop(timer); + + Ok(()) + } + + pub(super) fn start_output(&mut self) -> Result<()> { + self.init_empty_grouping_sets()?; + let state = self.state.building_mut(); + + let output_batch = if state.group_values.is_empty() { + None + } else { + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(EmitTo::All)?; + + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(EmitTo::All)?); + } + + let batch = 
RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; + debug_assert!(batch.num_rows() > 0); + drop(timer); + Some(batch) + }; + + self.set_output_batch(output_batch); + Ok(()) + } + + /// Creates the required empty grouping-set rows when the input is empty. + /// + /// For example, this query must still produce one grand-total group even if + /// `t` has no rows: + /// + /// ```sql + /// SELECT COUNT(v) + /// FROM t + /// GROUP BY GROUPING SETS (()); + /// ``` + /// + /// The synthetic row is filtered out before accumulator update so aggregates + /// see the same state they would see for an empty input, rather than a real + /// null-valued row. + fn init_empty_grouping_sets(&mut self) -> Result<()> { + let state = self.state.building_mut(); + if !state.group_by.has_grouping_set() || !state.group_values.is_empty() { + return Ok(()); + } + + let max_ordinal = max_duplicate_ordinal(state.group_by.groups()); + let mut ordinals: HashMap<&[bool], usize> = HashMap::new(); + let group_schema = state.group_by.group_schema(&self.input_schema)?; + let n_expr = state.group_by.expr().len(); + let mut any_interned = false; + + for group in state.group_by.groups() { + let ordinal = { + let entry = ordinals.entry(group.as_slice()).or_insert(0); + let ordinal = *entry; + *entry += 1; + ordinal + }; + + if !group.iter().all(|&is_null| is_null) { + continue; + } + + let mut cols: Vec = group_schema + .fields() + .iter() + .take(n_expr) + .map(|field| new_null_array(field.data_type(), 1)) + .collect(); + cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); + + state + .group_values + .intern(&cols, &mut state.batch_group_indices)?; + any_interned = true; + } + + if any_interned { + let total_groups = state.group_values.len(); + let false_filter = BooleanArray::from(vec![false]); + for acc in state.accumulators.iter_mut() { + let null_args = acc.null_arguments(&self.input_schema)?; + let values = EvaluatedHashAggregateAccumulator { + arguments: null_args, + filter: 
Some(Arc::new(false_filter.clone())), + }; + acc.update_batch(&values, &[0], total_groups)?; + } + } + + Ok(()) + } +} + +impl AggregateHashTable { + pub(super) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + ) -> Result { + Self::new_with_filters( + agg, + partition, + output_schema, + batch_size, + vec![None; agg.aggr_expr.len()], + ) + } + + pub(super) fn aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + let state = self.state.building_mut(); + + let timer = self.group_by_metrics.aggregation_time.timer(); + for group_values in &evaluated_batch.grouping_set_args { + state + .group_values + .intern(group_values, &mut state.batch_group_indices)?; + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); + + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.merge_batch(values, group_indices, total_num_groups)?; + } + } + drop(timer); + + Ok(()) + } + + pub(super) fn start_output(&mut self) -> Result<()> { + let state = self.state.building_mut(); + let output_batch = if state.group_values.is_empty() { + None + } else { + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(EmitTo::All)?; + + for acc in state.accumulators.iter_mut() { + output.push(acc.evaluate_final(EmitTo::All)?); + } + + let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; + debug_assert!(batch.num_rows() > 0); + drop(timer); + Some(batch) + }; + + self.set_output_batch(output_batch); + Ok(()) + } +} + +fn output_batch_memory_size(output_batch: &Option) -> usize { + output_batch + .as_ref() + .map(RecordBatch::get_array_memory_size) + .unwrap_or_default() +} diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index c8b825d576e02..3fc89aa3dc91a 
100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -22,7 +22,11 @@ use std::sync::Arc; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::aggregates::{ - no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, + hash_aggregate::{ + InitialPartialHashAggregateStream, PartialFinalHashAggregateStream, + }, + no_grouping::AggregateStream, + row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, }; use crate::execution_plan::{CardinalityEffect, EmissionType}; @@ -50,6 +54,7 @@ use datafusion_common::{ internal_err, not_impl_err, }; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryLimit; use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; @@ -69,6 +74,8 @@ use topk::hash_table::is_supported_hash_key_type; use topk::heap::is_supported_heap_type; pub mod group_values; +mod hash_aggregate; +mod hash_table; mod no_grouping; pub mod order; mod row_hash; @@ -499,6 +506,8 @@ impl PartialEq for PhysicalGroupBy { #[expect(clippy::large_enum_variant)] enum StreamType { AggregateStream(AggregateStream), + InitialPartialHash(InitialPartialHashAggregateStream), + PartialFinalHash(PartialFinalHashAggregateStream), GroupedHash(GroupedHashAggregateStream), GroupedPriorityQueue(GroupedTopKAggregateStream), } @@ -507,6 +516,8 @@ impl From for SendableRecordBatchStream { fn from(stream: StreamType) -> Self { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), + StreamType::InitialPartialHash(stream) => Box::pin(stream), + StreamType::PartialFinalHash(stream) => Box::pin(stream), StreamType::GroupedHash(stream) => Box::pin(stream), StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } @@ -964,12 +975,52 @@ impl AggregateExec { )); } + if 
self.should_use_initial_partial_hash_stream(context) { + return Ok(StreamType::InitialPartialHash( + InitialPartialHashAggregateStream::new(self, context, partition)?, + )); + } + + if self.should_use_partial_final_hash_stream(context) { + return Ok(StreamType::PartialFinalHash( + PartialFinalHashAggregateStream::new(self, context, partition)?, + )); + } + // grouping by something else and we need to just materialize all results Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new( self, context, partition, )?)) } + fn should_use_initial_partial_hash_stream(&self, context: &TaskContext) -> bool { + // TODO: implement memory-limited path and remove this limitation + if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { + return false; + } + + self.mode == AggregateMode::Partial + && self.limit_options.is_none() + && self.input_order_mode == InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + } + + fn should_use_partial_final_hash_stream(&self, context: &TaskContext) -> bool { + // TODO: implement memory-limited path and remove this limitation + if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { + return false; + } + + matches!( + self.mode, + AggregateMode::Final | AggregateMode::FinalPartitioned + ) && self.limit_options.is_none() + && self.input_order_mode == InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + } + /// Finds the DataType and SortDirection for this Aggregate, if there is one pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; @@ -2180,6 +2231,32 @@ pub(crate) fn max_duplicate_ordinal(groups: &[Vec]) -> usize { /// The outer Vec appears to be for grouping sets /// The inner Vec contains the results per expression /// The inner-inner Array contains the results per row +/// +/// For example, for `GROUP BY GROUPING SETS ((a, b), (a))` with input: 
+/// +/// ```text +/// a b +/// 1 1 +/// 1 2 +/// 2 1 +/// ``` +/// +/// The output is: +/// +/// ```text +/// [ +/// [ +/// a: [1, 1, 2] +/// b: [1, 2, 1] +/// grouping_id: [0, 0, 0] +/// ], +/// [ +/// a: [1, 1, 2] +/// b: [NULL, NULL, NULL] +/// grouping_id: [1, 1, 1] +/// ] +/// ] +/// ``` pub fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, @@ -2954,6 +3031,78 @@ mod tests { Ok(()) } + #[tokio::test] + async fn partial_grouped_aggregate_uses_raw_partial_stream() -> Result<()> { + let (schema, batches) = some_data(); + let input = TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?; + let group_by = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let udaf = Arc::new(AggregateUDF::from(InputTypeAssertingUdaf::new( + vec![DataType::Float64], + vec![DataType::Int32], + DataType::Int64, + ))); + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(udaf, vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("input_type_asserting(b)") + .build()?, + )]; + + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregates.clone(), + vec![None], + input, + Arc::clone(&schema), + )?); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(2)), + ); + + let stream = partial_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(stream, StreamType::InitialPartialHash(_))); + + let stream: SendableRecordBatchStream = stream.into(); + let batches = collect(stream).await?; + assert_eq!( + batches + .iter() + .map(RecordBatch::num_rows) + .collect::>(), + vec![2, 1] + ); + assert_eq!(batches.iter().map(RecordBatch::num_rows).sum::(), 3); + + let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); + let final_aggregate = AggregateExec::try_new( + AggregateMode::Final, + group_by.as_final(), + aggregates, + vec![None], + merge, + Arc::clone(&schema), + )?; + + 
let stream = final_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(stream, StreamType::PartialFinalHash(_))); + + let stream: SendableRecordBatchStream = stream.into(); + let batches = collect(stream).await?; + assert_eq!( + batches + .iter() + .map(RecordBatch::num_rows) + .collect::>(), + vec![2, 1] + ); + assert_eq!(batches.iter().map(RecordBatch::num_rows).sum::(), 3); + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -3682,8 +3831,11 @@ mod tests { &ScalarValue::Float64(Some(0.1)), ); - let ctx = TaskContext::default().with_session_config(session_config); - let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + let ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + let stream: SendableRecordBatchStream = Box::pin( + GroupedHashAggregateStream::new(aggregate_exec.as_ref(), &ctx, 0)?, + ); + let output = collect(stream).await?; allow_duplicates! { assert_snapshot!(batches_to_string(&output), @r" @@ -3769,8 +3921,11 @@ mod tests { &ScalarValue::Float64(Some(0.1)), ); - let ctx = TaskContext::default().with_session_config(session_config); - let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + let ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + let stream: SendableRecordBatchStream = Box::pin( + GroupedHashAggregateStream::new(aggregate_exec.as_ref(), &ctx, 0)?, + ); + let output = collect(stream).await?; allow_duplicates! 
{ assert_snapshot!(batches_to_string(&output), @r" From 3184e0044c1468a4660f5c89063253d47077b344 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 4 Jun 2026 20:31:10 +0800 Subject: [PATCH 2/5] review: more comments --- .../src/aggregates/hash_aggregate.rs | 15 +++++------ .../physical-plan/src/aggregates/mod.rs | 27 +++++++++++++++++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index f716ef41279e8..74749a33d29e1 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -15,18 +15,15 @@ // specific language governing permissions and limitations // under the License. -//! Grouped hash aggregation for simple multi-stage aggregation paths. +//! 2-stage hash aggregation stream implementation. //! -//! This module handles the basic grouped two-stage paths: +//! See comments in [`InitialPartialHashAggregateStream`] and [`PartialFinalHashAggregateStream`] +//! for details. //! -//! ```text -//! input rows -> GROUP BY hash table -> accumulator state rows -//! state rows -> GROUP BY hash table -> final aggregate rows -//! ``` +//! Note these streams are an incremental migration of the existing +//! [`crate::aggregates::row_hash::GroupedHashAggregateStream`]. //! -//! `AggregateExec` keeps finite-memory, ordered, limit, grouping-set, -//! `partial state -> partial state`, and single-stage aggregation on -//! `GroupedHashAggregateStream` for now. +//! 
See issue for details: use std::sync::Arc; use std::task::{Context, Poll}; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 3fc89aa3dc91a..e107de66a1430 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -503,12 +503,39 @@ impl PartialEq for PhysicalGroupBy { } } +/// Streams used by [`AggregateExec`]. +/// +/// # Stream Variant Schema Notation +/// For example, `SELECT g, AVG(x) FROM t GROUP BY g` uses these schemas: +/// +/// ```text +/// initial input: [g, x] +/// partial state: [g, AVG(x) state columns, e.g. sum/count] +/// final result: [g, AVG(x)] +/// ``` #[expect(clippy::large_enum_variant)] enum StreamType { + /// Single group (no group by) aggregate stream. + /// Input output scheme: initial input -> final result AggregateStream(AggregateStream), + /// Partial stage of the hash aggregation + /// Input output scheme: initial input -> partial state InitialPartialHash(InitialPartialHashAggregateStream), + /// Final stage of the hash aggregation + /// Input output scheme: partial state -> final result PartialFinalHash(PartialFinalHashAggregateStream), + /// Hash aggregation resused for multiple stages + /// + /// Note this is being incrementally migrated to dedicated streams like + /// [`StreamType::InitialPartialHash`] and [`StreamType::PartialFinalHash`] + /// + /// See issue for details: GroupedHash(GroupedHashAggregateStream), + /// Grouped TopK aggregate stream. + /// Input output scheme: initial input -> final result + /// + /// Used for grouped aggregation with LIMIT / ordering, where the stream keeps + /// only the top groups required by the query. 
GroupedPriorityQueue(GroupedTopKAggregateStream), } From 4a2b907285e3be77eab342353b8c2c6fd947035d Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 4 Jun 2026 21:00:22 +0800 Subject: [PATCH 3/5] add config to enable migration aggregate path --- datafusion/common/src/config.rs | 11 ++++ .../physical-plan/src/aggregates/mod.rs | 52 +++++++++++++------ .../test_files/information_schema.slt | 2 + 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index e4a3cea709b31..d533bf973cb85 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -587,6 +587,17 @@ config_namespace! { /// the new schema verification step. pub skip_physical_aggregate_schema_check: bool, default = false + /// Temporary switch for aggregate stream implementations that are being + /// migrated from `GroupedHashAggregateStream`. + /// + /// When set to true, DataFusion tries the migrated implementations when + /// their preconditions are satisfied. When set to false, grouped + /// aggregation falls back to `GroupedHashAggregateStream`. This option + /// will be removed after the migration is finished. + /// + /// See for details. + pub enable_migration_aggregate: bool, default = true + /// Sets the compression codec used when spilling data to disk. 
/// /// Since datafusion writes spill files using the Arrow IPC Stream format, diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e107de66a1430..b6c3b151e01d9 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -524,7 +524,7 @@ enum StreamType { /// Final stage of the hash aggregation /// Input output scheme: partial state -> final result PartialFinalHash(PartialFinalHashAggregateStream), - /// Hash aggregation resused for multiple stages + /// Hash aggregation reused for multiple stages /// /// Note this is being incrementally migrated to dedicated streams like /// [`StreamType::InitialPartialHash`] and [`StreamType::PartialFinalHash`] @@ -1002,16 +1002,23 @@ impl AggregateExec { )); } - if self.should_use_initial_partial_hash_stream(context) { - return Ok(StreamType::InitialPartialHash( - InitialPartialHashAggregateStream::new(self, context, partition)?, - )); - } + if context + .session_config() + .options() + .execution + .enable_migration_aggregate + { + if self.should_use_initial_partial_hash_stream(context) { + return Ok(StreamType::InitialPartialHash( + InitialPartialHashAggregateStream::new(self, context, partition)?, + )); + } - if self.should_use_partial_final_hash_stream(context) { - return Ok(StreamType::PartialFinalHash( - PartialFinalHashAggregateStream::new(self, context, partition)?, - )); + if self.should_use_partial_final_hash_stream(context) { + return Ok(StreamType::PartialFinalHash( + PartialFinalHashAggregateStream::new(self, context, partition)?, + )); + } } // grouping by something else and we need to just materialize all results @@ -3089,10 +3096,20 @@ mod tests { .with_session_config(SessionConfig::new().with_batch_size(2)), ); - let stream = partial_aggregate.execute_typed(0, &task_ctx)?; - assert!(matches!(stream, StreamType::InitialPartialHash(_))); + let partial_stream = partial_aggregate.execute_typed(0, 
&task_ctx)?; + assert!(matches!(partial_stream, StreamType::InitialPartialHash(_))); - let stream: SendableRecordBatchStream = stream.into(); + let fallback_task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(2) + .set_bool("datafusion.execution.enable_migration_aggregate", false), + ), + ); + let stream = partial_aggregate.execute_typed(0, &fallback_task_ctx)?; + assert!(matches!(stream, StreamType::GroupedHash(_))); + + let stream: SendableRecordBatchStream = partial_stream.into(); let batches = collect(stream).await?; assert_eq!( batches @@ -3113,10 +3130,13 @@ mod tests { Arc::clone(&schema), )?; - let stream = final_aggregate.execute_typed(0, &task_ctx)?; - assert!(matches!(stream, StreamType::PartialFinalHash(_))); + let final_stream = final_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(final_stream, StreamType::PartialFinalHash(_))); + + let stream = final_aggregate.execute_typed(0, &fallback_task_ctx)?; + assert!(matches!(stream, StreamType::GroupedHash(_))); - let stream: SendableRecordBatchStream = stream.into(); + let stream: SendableRecordBatchStream = final_stream.into(); let batches = collect(stream).await?; assert_eq!( batches diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index ec2055e5ad62d..0a5ee876253bd 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -218,6 +218,7 @@ datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics true datafusion.execution.enable_ansi_mode false +datafusion.execution.enable_migration_aggregate true datafusion.execution.enable_recursive_ctes true datafusion.execution.enforce_batch_size_in_joins false datafusion.execution.hash_join_buffering_capacity 0 @@ -370,6 +371,7 @@ datafusion.execution.batch_size 8192 Default batch 
size while creating new batch datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics true Should DataFusion collect statistics when first creating a table. Has no effect after the table is created. Applies to the default `ListingTableProvider` in DataFusion. Defaults to true. datafusion.execution.enable_ansi_mode false Whether to enable ANSI SQL mode. The flag is experimental and relevant only for DataFusion Spark built-in functions When `enable_ansi_mode` is set to `true`, the query engine follows ANSI SQL semantics for expressions, casting, and error handling. This means: - **Strict type coercion rules:** implicit casts between incompatible types are disallowed. - **Standard SQL arithmetic behavior:** operations such as division by zero, numeric overflow, or invalid casts raise runtime errors rather than returning `NULL` or adjusted values. - **Consistent ANSI behavior** for string concatenation, comparisons, and `NULL` handling. When `enable_ansi_mode` is `false` (the default), the engine uses a more permissive, non-ANSI mode designed for user convenience and backward compatibility. In this mode: - Implicit casts between types are allowed (e.g., string to integer when possible). - Arithmetic operations are more lenient — for example, `abs()` on the minimum representable integer value returns the input value instead of raising overflow. - Division by zero or invalid casts may return `NULL` instead of failing. # Default `false` — ANSI SQL mode is disabled by default. +datafusion.execution.enable_migration_aggregate true Temporary switch for aggregate stream implementations that are being migrated from `GroupedHashAggregateStream`. 
When set to true, DataFusion tries the migrated implementations when their preconditions are satisfied. When set to false, grouped aggregation falls back to `GroupedHashAggregateStream`. This option will be removed after the migration is finished. See for details. datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. datafusion.execution.hash_join_buffering_capacity 0 How many bytes to buffer in the probe side of hash joins while the build side is concurrently being built. Without this, hash joins will wait until the full materialization of the build side before polling the probe side. This is useful in scenarios where the query is not completely CPU bounded, allowing to do some early work concurrently and reducing the latency of the query. Note that when hash join buffering is enabled, the probe side will start eagerly polling data, not giving time for the producer side of dynamic filters to produce any meaningful predicate. Queries with dynamic filters might see performance degradation. Disabled by default, set to a number greater than 0 for enabling it. 
From 09462b9f7b8ea1c3fcccc0b2966b3163470734f4 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 4 Jun 2026 21:09:26 +0800 Subject: [PATCH 4/5] review: better struct naming --- .../src/aggregates/hash_aggregate.rs | 39 +++++++++---------- .../src/aggregates/hash_table.rs | 8 ++-- .../physical-plan/src/aggregates/mod.rs | 38 +++++++++--------- 3 files changed, 41 insertions(+), 44 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 74749a33d29e1..7a9818100678d 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -17,7 +17,7 @@ //! 2-stage hash aggregation stream implementation. //! -//! See comments in [`InitialPartialHashAggregateStream`] and [`PartialFinalHashAggregateStream`] +//! See comments in [`PartialHashAggregateStream`] and [`FinalHashAggregateStream`] //! for details. //! //! Note these streams are an incremental migration of the existing @@ -36,7 +36,7 @@ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use futures::stream::{Stream, StreamExt}; use super::AggregateExec; -use super::hash_table::{AggregateHashTable, InitialPartial, PartialFinal}; +use super::hash_table::{AggregateHashTable, Partial}; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput, SpillMetrics}; use crate::stream::EmptyRecordBatchStream; use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metrics}; @@ -60,7 +60,7 @@ use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metric /// ## Final Stage Behavior /// Input: partial states /// Output: results for all groups (e.g. for avg(x), it's avg(x) calculated from the state) -pub(crate) struct InitialPartialHashAggregateStream { +pub(crate) struct PartialHashAggregateStream { /// Output schema: group columns followed by partial aggregate state columns. 
schema: SchemaRef, @@ -68,7 +68,7 @@ pub(crate) struct InitialPartialHashAggregateStream { input: SendableRecordBatchStream, /// Hash table state for this aggregate stream. - hash_table: AggregateHashTable, + hash_table: AggregateHashTable, /// Memory reservation for group keys and accumulators. reservation: MemoryReservation, @@ -83,8 +83,8 @@ pub(crate) struct InitialPartialHashAggregateStream { /// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream /// is for the final stage. /// -/// See [`InitialPartialHashAggregateStream`] for details. -pub(crate) struct PartialFinalHashAggregateStream { +/// See [`PartialHashAggregateStream`] for details. +pub(crate) struct FinalHashAggregateStream { /// Output schema: group columns followed by final aggregate value columns. schema: SchemaRef, @@ -92,7 +92,7 @@ pub(crate) struct PartialFinalHashAggregateStream { input: SendableRecordBatchStream, /// Hash table state for this aggregate stream. - hash_table: AggregateHashTable, + hash_table: AggregateHashTable, /// Execution metrics shared with the aggregate plan node. 
baseline_metrics: BaselineMetrics, @@ -101,7 +101,7 @@ pub(crate) struct PartialFinalHashAggregateStream { reservation: MemoryReservation, } -impl InitialPartialHashAggregateStream { +impl PartialHashAggregateStream { pub fn new( agg: &AggregateExec, context: &Arc, @@ -121,17 +121,16 @@ impl InitialPartialHashAggregateStream { .with_type(metrics::MetricType::Summary) .ratio_metrics("reduction_factor", partition); - let hash_table = AggregateHashTable::::new( + let hash_table = AggregateHashTable::::new( agg, partition, Arc::clone(&schema), batch_size, )?; - let reservation = MemoryConsumer::new(format!( - "InitialPartialHashAggregateStream[{partition}]" - )) - .register(context.memory_pool()); + let reservation = + MemoryConsumer::new(format!("PartialHashAggregateStream[{partition}]")) + .register(context.memory_pool()); Ok(Self { schema, @@ -144,7 +143,7 @@ impl InitialPartialHashAggregateStream { } } -impl Stream for InitialPartialHashAggregateStream { +impl Stream for PartialHashAggregateStream { type Item = Result; fn poll_next( @@ -220,13 +219,13 @@ impl Stream for InitialPartialHashAggregateStream { } } -impl RecordBatchStream for InitialPartialHashAggregateStream { +impl RecordBatchStream for PartialHashAggregateStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } } -impl PartialFinalHashAggregateStream { +impl FinalHashAggregateStream { pub fn new( agg: &AggregateExec, context: &Arc, @@ -246,7 +245,7 @@ impl PartialFinalHashAggregateStream { // Preserve the existing aggregate metric surface for this plan node. 
let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); - let hash_table = AggregateHashTable::::new( + let hash_table = AggregateHashTable::::new( agg, partition, Arc::clone(&schema), @@ -254,7 +253,7 @@ impl PartialFinalHashAggregateStream { )?; let reservation = - MemoryConsumer::new(format!("PartialFinalHashAggregateStream[{partition}]")) + MemoryConsumer::new(format!("FinalHashAggregateStream[{partition}]")) .register(context.memory_pool()); Ok(Self { @@ -267,7 +266,7 @@ impl PartialFinalHashAggregateStream { } } -impl Stream for PartialFinalHashAggregateStream { +impl Stream for FinalHashAggregateStream { type Item = Result; fn poll_next( @@ -339,7 +338,7 @@ impl Stream for PartialFinalHashAggregateStream { } } -impl RecordBatchStream for PartialFinalHashAggregateStream { +impl RecordBatchStream for FinalHashAggregateStream { fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } diff --git a/datafusion/physical-plan/src/aggregates/hash_table.rs b/datafusion/physical-plan/src/aggregates/hash_table.rs index 6c412d0a137da..d23aa230aaed9 100644 --- a/datafusion/physical-plan/src/aggregates/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/hash_table.rs @@ -37,9 +37,9 @@ use crate::PhysicalExpr; use crate::metrics::{MetricBuilder, MetricCategory}; /// Marker for raw rows -> partial state aggregation. -pub(super) struct InitialPartial; +pub(super) struct Partial; /// Marker for partial state -> final value aggregation. -pub(super) struct PartialFinal; +pub(super) struct Partial; /// Grouped hash table shared by the initial-partial and partial-final paths. 
/// @@ -395,7 +395,7 @@ impl AggregateHashTable { } } -impl AggregateHashTable { +impl AggregateHashTable { pub(super) fn new( agg: &AggregateExec, partition: usize, @@ -543,7 +543,7 @@ impl AggregateHashTable { } } -impl AggregateHashTable { +impl AggregateHashTable { pub(super) fn new( agg: &AggregateExec, partition: usize, diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b6c3b151e01d9..ae0d9664ad85d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -22,9 +22,7 @@ use std::sync::Arc; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::aggregates::{ - hash_aggregate::{ - InitialPartialHashAggregateStream, PartialFinalHashAggregateStream, - }, + hash_aggregate::{FinalHashAggregateStream, PartialHashAggregateStream}, no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, @@ -520,14 +518,14 @@ enum StreamType { AggregateStream(AggregateStream), /// Partial stage of the hash aggregation /// Input output scheme: initial input -> partial state - InitialPartialHash(InitialPartialHashAggregateStream), + PartialHash(PartialHashAggregateStream), /// Final stage of the hash aggregation /// Input output scheme: partial state -> final result - PartialFinalHash(PartialFinalHashAggregateStream), + FinalHash(FinalHashAggregateStream), /// Hash aggregation reused for multiple stages /// /// Note this is being incrementally migrated to dedicated streams like - /// [`StreamType::InitialPartialHash`] and [`StreamType::PartialFinalHash`] + /// [`StreamType::PartialHash`] and [`StreamType::FinalHash`] /// /// See issue for details: GroupedHash(GroupedHashAggregateStream), @@ -543,8 +541,8 @@ impl From for SendableRecordBatchStream { fn from(stream: StreamType) -> Self { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), - 
StreamType::InitialPartialHash(stream) => Box::pin(stream), - StreamType::PartialFinalHash(stream) => Box::pin(stream), + StreamType::PartialHash(stream) => Box::pin(stream), + StreamType::FinalHash(stream) => Box::pin(stream), StreamType::GroupedHash(stream) => Box::pin(stream), StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } @@ -1008,16 +1006,16 @@ impl AggregateExec { .execution .enable_migration_aggregate { - if self.should_use_initial_partial_hash_stream(context) { - return Ok(StreamType::InitialPartialHash( - InitialPartialHashAggregateStream::new(self, context, partition)?, - )); + if self.should_use_partial_hash_stream(context) { + return Ok(StreamType::PartialHash(PartialHashAggregateStream::new( + self, context, partition, + )?)); } - if self.should_use_partial_final_hash_stream(context) { - return Ok(StreamType::PartialFinalHash( - PartialFinalHashAggregateStream::new(self, context, partition)?, - )); + if self.should_use_final_hash_stream(context) { + return Ok(StreamType::FinalHash(FinalHashAggregateStream::new( + self, context, partition, + )?)); } } @@ -1027,7 +1025,7 @@ impl AggregateExec { )?)) } - fn should_use_initial_partial_hash_stream(&self, context: &TaskContext) -> bool { + fn should_use_partial_hash_stream(&self, context: &TaskContext) -> bool { // TODO: implement memory-limited path and remove this limitation if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { return false; @@ -1040,7 +1038,7 @@ impl AggregateExec { && self.group_by.is_single() } - fn should_use_partial_final_hash_stream(&self, context: &TaskContext) -> bool { + fn should_use_final_hash_stream(&self, context: &TaskContext) -> bool { // TODO: implement memory-limited path and remove this limitation if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { return false; @@ -3097,7 +3095,7 @@ mod tests { ); let partial_stream = partial_aggregate.execute_typed(0, &task_ctx)?; - assert!(matches!(partial_stream, 
StreamType::InitialPartialHash(_))); + assert!(matches!(partial_stream, StreamType::PartialHash(_))); let fallback_task_ctx = Arc::new( TaskContext::default().with_session_config( @@ -3131,7 +3129,7 @@ mod tests { )?; let final_stream = final_aggregate.execute_typed(0, &task_ctx)?; - assert!(matches!(final_stream, StreamType::PartialFinalHash(_))); + assert!(matches!(final_stream, StreamType::FinalHash(_))); let stream = final_aggregate.execute_typed(0, &fallback_task_ctx)?; assert!(matches!(stream, StreamType::GroupedHash(_))); From bd75229ee6f492e63e802b3cb7c76aee792b85f1 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Thu, 4 Jun 2026 21:19:04 +0800 Subject: [PATCH 5/5] fix ci --- datafusion/physical-plan/src/aggregates/hash_aggregate.rs | 6 +++--- datafusion/physical-plan/src/aggregates/hash_table.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs index 7a9818100678d..f25299631a92c 100644 --- a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -36,7 +36,7 @@ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use futures::stream::{Stream, StreamExt}; use super::AggregateExec; -use super::hash_table::{AggregateHashTable, Partial}; +use super::hash_table::{AggregateHashTable, Final, Partial}; use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput, SpillMetrics}; use crate::stream::EmptyRecordBatchStream; use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metrics}; @@ -92,7 +92,7 @@ pub(crate) struct FinalHashAggregateStream { input: SendableRecordBatchStream, /// Hash table state for this aggregate stream. - hash_table: AggregateHashTable, + hash_table: AggregateHashTable, /// Execution metrics shared with the aggregate plan node. 
baseline_metrics: BaselineMetrics, @@ -245,7 +245,7 @@ impl FinalHashAggregateStream { // Preserve the existing aggregate metric surface for this plan node. let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); - let hash_table = AggregateHashTable::::new( + let hash_table = AggregateHashTable::::new( agg, partition, Arc::clone(&schema), diff --git a/datafusion/physical-plan/src/aggregates/hash_table.rs b/datafusion/physical-plan/src/aggregates/hash_table.rs index d23aa230aaed9..66f403ba97a49 100644 --- a/datafusion/physical-plan/src/aggregates/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/hash_table.rs @@ -39,7 +39,7 @@ use crate::metrics::{MetricBuilder, MetricCategory}; /// Marker for raw rows -> partial state aggregation. pub(super) struct Partial; /// Marker for partial state -> final value aggregation. -pub(super) struct Partial; +pub(super) struct Final; /// Grouped hash table shared by the initial-partial and partial-final paths. /// @@ -543,7 +543,7 @@ impl AggregateHashTable { } } -impl AggregateHashTable { +impl AggregateHashTable { pub(super) fn new( agg: &AggregateExec, partition: usize,