Skip to content

Commit 40a6454

Browse files
2010YOUY01martin-g
andauthored
refactor(hash-aggr): Forward port the soft limit optimization to the new hash aggregation impl (#22824)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - part of #22710 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Part of rewriting hash aggregation into several dedicated streams. In the first step #22729, `PartialHashAggregateStream` and `FinalHashAggregateStream` has been split from the old `GroupsHashAggregateStream`, but both stream only have basic implementation, no optimizations and extra features like spilling. \* it's incremental migration, so old impl won't change, we plan to delete it once migration is finished This PR forward ports the below optimization to the new implementation: - #8038 The optimizer part don't have to move, ported changes are only inside aggregate operator. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Extends `PartialHashAggregateStream` and `FinalHashAggregateStream` to apply the optimization. See code comment at `datafusion/physical-plan/src/aggregates/hash_aggregate.rs` for the background. ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Yes, the original test in #8038 is only at `ExecutionPlan` level, they're still passing after the change. This PR added new test coverage: check `explain analyze` to ensure the implementation actually respects this soft limit at runtime. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. --> --------- Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com>
1 parent 8f6876d commit 40a6454

4 files changed

Lines changed: 312 additions & 11 deletions

File tree

datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
3636
use datafusion_physical_plan::{
3737
ExecutionPlan,
3838
aggregates::{AggregateExec, AggregateMode},
39-
collect,
39+
collect, displayable,
4040
limit::{GlobalLimitExec, LocalLimitExec},
4141
};
4242

@@ -104,6 +104,121 @@ async fn test_partial_final() -> Result<()> {
104104
Ok(())
105105
}
106106

107+
// Ensure operator respect the soft limit and stops early: `AggregateExec`'s
108+
// `output_rows` metric should be smaller than then total distinct group count.
109+
#[tokio::test]
110+
async fn limited_distinct_aggregate_stream_respects_soft_limit() -> Result<()> {
111+
// Snapshot for an aggregate operator node from `EXPLAIN ANALYZE`.
112+
//
113+
// Example: In an `EXPLAIN ANALYZE` output
114+
// ```txt
115+
// AggregateExec: mode=partial, limit=10, metrics=[output_rows=100, ...]
116+
// ```
117+
// we get:
118+
// ```txt
119+
// AggregateRuntimeMetric {
120+
// mode: Partial,
121+
// limit: Some(10),
122+
// output_rows: 100,
123+
// }
124+
// ```
125+
#[derive(Debug)]
126+
struct AggregateRuntimeMetric {
127+
mode: AggregateMode,
128+
limit: Option<usize>,
129+
output_rows: usize,
130+
}
131+
132+
fn collect_aggregate_runtime_metrics(
133+
plan: &Arc<dyn ExecutionPlan>,
134+
metrics: &mut Vec<AggregateRuntimeMetric>,
135+
) {
136+
if let Some(agg) = plan.downcast_ref::<AggregateExec>() {
137+
let output_rows = agg
138+
.metrics()
139+
.and_then(|metrics| metrics.aggregate_by_name().output_rows())
140+
.expect("AggregateExec should record output_rows after execution");
141+
142+
metrics.push(AggregateRuntimeMetric {
143+
mode: *agg.mode(),
144+
limit: agg.limit_options().map(|config| config.limit()),
145+
output_rows,
146+
});
147+
}
148+
149+
for child in plan.children() {
150+
collect_aggregate_runtime_metrics(child, metrics);
151+
}
152+
}
153+
154+
fn aggregate_runtime_metrics(
155+
plan: &Arc<dyn ExecutionPlan>,
156+
) -> Vec<AggregateRuntimeMetric> {
157+
let mut metrics = vec![];
158+
collect_aggregate_runtime_metrics(plan, &mut metrics);
159+
metrics
160+
}
161+
162+
let cfg = SessionConfig::new()
163+
.with_target_partitions(2)
164+
.with_batch_size(10)
165+
.set_bool("datafusion.execution.enable_migration_aggregate", true);
166+
let ctx = SessionContext::new_with_config(cfg);
167+
168+
let dataframe = ctx
169+
.sql(
170+
"SELECT DISTINCT value % 100000 AS v \
171+
FROM generate_series(1000000) \
172+
LIMIT 10",
173+
)
174+
.await?;
175+
let plan = dataframe.create_physical_plan().await?;
176+
let formatted_plan = displayable(plan.as_ref()).indent(false).to_string();
177+
assert!(
178+
formatted_plan.contains("AggregateExec: mode=Partial"),
179+
"expected a partial aggregate in plan:\n{formatted_plan}"
180+
);
181+
assert!(
182+
formatted_plan.contains("AggregateExec: mode=FinalPartitioned"),
183+
"expected a final partitioned aggregate in plan:\n{formatted_plan}"
184+
);
185+
186+
let batches = collect(Arc::clone(&plan), ctx.task_ctx()).await?;
187+
assert_eq!(
188+
batches.iter().map(|batch| batch.num_rows()).sum::<usize>(),
189+
10
190+
);
191+
192+
let metrics = aggregate_runtime_metrics(&plan);
193+
let partial = metrics
194+
.iter()
195+
.find(|metric| metric.mode == AggregateMode::Partial)
196+
.expect("expected partial aggregate metrics");
197+
let final_aggregate = metrics
198+
.iter()
199+
.find(|metric| {
200+
matches!(
201+
metric.mode,
202+
AggregateMode::Final | AggregateMode::FinalPartitioned
203+
)
204+
})
205+
.expect("expected final aggregate metrics");
206+
207+
assert_eq!(partial.limit, Some(10));
208+
assert_eq!(final_aggregate.limit, Some(10));
209+
210+
assert!(
211+
partial.output_rows <= 100,
212+
"partial aggregate should stop before emitting all distinct groups: {metrics:?}"
213+
);
214+
assert!(
215+
final_aggregate.output_rows <= 100,
216+
"final aggregate should stop before emitting all distinct groups: {metrics:?}"
217+
);
218+
219+
Ok(())
220+
}
221+
107222
#[tokio::test]
108223
async fn test_single_local() -> Result<()> {
109224
let source = mock_data()?;

datafusion/physical-plan/src/aggregates/hash_aggregate.rs

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,33 @@ use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metric
6060
/// ## Final Stage Behavior
6161
/// Input: partial states
6262
/// Output: results for all groups (e.g. for avg(x), it's avg(x) calculated from the state)
63+
///
64+
/// # Optimization: DISTINCT LIMIT Soft Limit
65+
///
66+
/// This optimization applies to both [`PartialHashAggregateStream`] and [`FinalHashAggregateStream`]
67+
///
68+
/// Unordered distinct queries such as:
69+
///
70+
/// ```sql
71+
/// SELECT DISTINCT x FROM t LIMIT 10;
72+
/// ```
73+
///
74+
/// are optimized into a two-stage aggregate like:
75+
///
76+
/// ```txt
77+
/// LimitExec, limit=10
78+
/// --AggregateExec(Final), group_by=[x], aggr=[], soft_limit=10
79+
/// ---- RepartitionExec, partitioning=hash(x)
80+
/// ------ AggregateExec(Partial), group_by=[x], aggr=[], soft_limit=10
81+
/// -------- Scan(t)
82+
/// ```
83+
///
84+
/// After each input batch, the stream checks whether the soft limit has been
85+
/// reached. If so, it emits the accumulated groups and stops reading input.
86+
///
87+
/// This operator does not guarantee an exact limit because a single batch can
88+
/// cross the threshold. The downstream limit operator enforces the exact result
89+
/// size.
6390
pub(crate) struct PartialHashAggregateStream {
6491
/// Output schema: group columns followed by partial aggregate state columns.
6592
schema: SchemaRef,
@@ -78,6 +105,12 @@ pub(crate) struct PartialHashAggregateStream {
78105

79106
/// Tracks partial aggregation row reduction, matching `GroupedHashAggregateStream`.
80107
reduction_factor: metrics::RatioMetrics,
108+
109+
/// Optional soft limit on the number of groups to accumulate before output.
110+
///
111+
/// Invariant: when this is `Some(..)`, the accumulators inside `hash_table` must
112+
/// be empty. See struct comments for details.
113+
group_values_soft_limit: Option<usize>,
81114
}
82115

83116
/// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream
@@ -99,6 +132,9 @@ pub(crate) struct FinalHashAggregateStream {
99132

100133
/// Memory reservation for group keys and accumulators.
101134
reservation: MemoryReservation,
135+
136+
/// See comments for the same variable in [`PartialHashAggregateStream`]
137+
group_values_soft_limit: Option<usize>,
102138
}
103139

104140
impl PartialHashAggregateStream {
@@ -139,8 +175,21 @@ impl PartialHashAggregateStream {
139175
baseline_metrics,
140176
reservation,
141177
reduction_factor,
178+
group_values_soft_limit: agg.limit_options().map(|config| config.limit()),
142179
})
143180
}
181+
182+
/// See comments in [`Self::group_values_soft_limit`] for details.
183+
fn hit_soft_group_limit(&self) -> bool {
184+
self.group_values_soft_limit
185+
.is_some_and(|limit| limit <= self.hash_table.building_group_count())
186+
}
187+
188+
fn start_output(&mut self) -> Result<()> {
189+
let input_schema = self.input.schema();
190+
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
191+
self.hash_table.start_output()
192+
}
144193
}
145194

146195
impl Stream for PartialHashAggregateStream {
@@ -169,6 +218,18 @@ impl Stream for PartialHashAggregateStream {
169218
return Poll::Ready(Some(Err(e)));
170219
}
171220

221+
if self.hit_soft_group_limit() {
222+
let timer = elapsed_compute.timer();
223+
let result = self.start_output();
224+
timer.done();
225+
226+
if let Err(e) = result {
227+
return Poll::Ready(Some(Err(e)));
228+
}
229+
230+
continue;
231+
}
232+
172233
// TODO: impl memory-limited aggr, when OOM directly send
173234
// partial state to final aggregate stage
174235
if let Err(e) =
@@ -181,11 +242,8 @@ impl Stream for PartialHashAggregateStream {
181242
return Poll::Ready(Some(Err(e)));
182243
}
183244
Poll::Ready(None) => {
184-
let input_schema = self.input.schema();
185-
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
186-
187245
let timer = elapsed_compute.timer();
188-
let result = self.hash_table.start_output();
246+
let result = self.start_output();
189247
timer.done();
190248

191249
if let Err(e) = result {
@@ -262,8 +320,21 @@ impl FinalHashAggregateStream {
262320
hash_table,
263321
baseline_metrics,
264322
reservation,
323+
group_values_soft_limit: agg.limit_options().map(|config| config.limit()),
265324
})
266325
}
326+
327+
/// See comments in [`Self::group_values_soft_limit`] for details.
328+
fn hit_soft_group_limit(&self) -> bool {
329+
self.group_values_soft_limit
330+
.is_some_and(|limit| limit <= self.hash_table.building_group_count())
331+
}
332+
333+
fn start_output(&mut self) -> Result<()> {
334+
let input_schema = self.input.schema();
335+
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
336+
self.hash_table.start_output()
337+
}
267338
}
268339

269340
impl Stream for FinalHashAggregateStream {
@@ -291,6 +362,18 @@ impl Stream for FinalHashAggregateStream {
291362
return Poll::Ready(Some(Err(e)));
292363
}
293364

365+
if self.hit_soft_group_limit() {
366+
let timer = elapsed_compute.timer();
367+
let result = self.start_output();
368+
timer.done();
369+
370+
if let Err(e) = result {
371+
return Poll::Ready(Some(Err(e)));
372+
}
373+
374+
continue;
375+
}
376+
294377
if let Err(e) =
295378
self.reservation.try_resize(self.hash_table.memory_size())
296379
{
@@ -301,11 +384,8 @@ impl Stream for FinalHashAggregateStream {
301384
return Poll::Ready(Some(Err(e)));
302385
}
303386
Poll::Ready(None) => {
304-
let input_schema = self.input.schema();
305-
self.input = Box::pin(EmptyRecordBatchStream::new(input_schema));
306-
307387
let timer = elapsed_compute.timer();
308-
let result = self.hash_table.start_output();
388+
let result = self.start_output();
309389
timer.done();
310390

311391
if let Err(e) = result {

datafusion/physical-plan/src/aggregates/hash_table.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,10 @@ impl<Mode> AggregateHashTable<Mode> {
342342
}
343343
}
344344

345+
pub(super) fn building_group_count(&self) -> usize {
346+
self.state.building().group_values.len()
347+
}
348+
345349
pub(super) fn is_building(&self) -> bool {
346350
matches!(self.state, AggregateHashTableState::Building(_))
347351
}

0 commit comments

Comments
 (0)