Skip to content

Commit 3e0cdfc

Browse files
committed
fix(substrait): Do not add implicit groupBy expressions in LogicalPlanBuilder or when building logical plans from Substrait (apache#14860)
* feat: add add_implicit_group_by_exprs option to logical plan builder * fix: do not add implicity group by exprs in substrait path * test: add substrait tests * test: add builder option tests * style: clippy errors
1 parent 7299d0e commit 3e0cdfc

File tree

7 files changed

+368
-12
lines changed

7 files changed

+368
-12
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ use crate::execution::context::{SessionState, TaskContext};
3333
use crate::execution::FunctionRegistry;
3434
use crate::logical_expr::utils::find_window_exprs;
3535
use crate::logical_expr::{
36-
col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType,
36+
col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions,
37+
Partitioning, TableType,
3738
};
3839
use crate::physical_plan::{
3940
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
@@ -526,7 +527,10 @@ impl DataFrame {
526527
) -> Result<DataFrame> {
527528
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
528529
let aggr_expr_len = aggr_expr.len();
530+
let options =
531+
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
529532
let plan = LogicalPlanBuilder::from(self.plan)
533+
.with_options(options)
530534
.aggregate(group_expr, aggr_expr)?
531535
.build()?;
532536
let plan = if is_grouping_set {

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 112 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ use datafusion_common::display::ToStringifiedPlan;
5353
use datafusion_common::file_options::file_type::FileType;
5454
use datafusion_common::{
5555
exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
56-
plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
57-
Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
56+
plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
57+
DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
5858
};
5959
use datafusion_expr_common::type_coercion::binary::type_union_resolution;
6060

@@ -63,6 +63,26 @@ use indexmap::IndexSet;
6363
/// Default table name for unnamed table
6464
pub const UNNAMED_TABLE: &str = "?table?";
6565

66+
/// Options for [`LogicalPlanBuilder`]
67+
#[derive(Default, Debug, Clone)]
68+
pub struct LogicalPlanBuilderOptions {
69+
/// Flag indicating whether the plan builder should add
70+
/// functionally dependent expressions as additional aggregation groupings.
71+
add_implicit_group_by_exprs: bool,
72+
}
73+
74+
impl LogicalPlanBuilderOptions {
75+
pub fn new() -> Self {
76+
Default::default()
77+
}
78+
79+
/// Should the builder add functionally dependent expressions as additional aggregation groupings.
80+
pub fn with_add_implicit_group_by_exprs(mut self, add: bool) -> Self {
81+
self.add_implicit_group_by_exprs = add;
82+
self
83+
}
84+
}
85+
6686
/// Builder for logical plans
6787
///
6888
/// # Example building a simple plan
@@ -103,19 +123,29 @@ pub const UNNAMED_TABLE: &str = "?table?";
103123
#[derive(Debug, Clone)]
104124
pub struct LogicalPlanBuilder {
105125
plan: Arc<LogicalPlan>,
126+
options: LogicalPlanBuilderOptions,
106127
}
107128

108129
impl LogicalPlanBuilder {
109130
/// Create a builder from an existing plan
110131
pub fn new(plan: LogicalPlan) -> Self {
111132
Self {
112133
plan: Arc::new(plan),
134+
options: LogicalPlanBuilderOptions::default(),
113135
}
114136
}
115137

116138
/// Create a builder from an existing plan
117139
pub fn new_from_arc(plan: Arc<LogicalPlan>) -> Self {
118-
Self { plan }
140+
Self {
141+
plan,
142+
options: LogicalPlanBuilderOptions::default(),
143+
}
144+
}
145+
146+
pub fn with_options(mut self, options: LogicalPlanBuilderOptions) -> Self {
147+
self.options = options;
148+
self
119149
}
120150

121151
/// Return the output schema of the plan build so far
@@ -1138,8 +1168,12 @@ impl LogicalPlanBuilder {
11381168
let group_expr = normalize_cols(group_expr, &self.plan)?;
11391169
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
11401170

1141-
let group_expr =
1142-
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
1171+
let group_expr = if self.options.add_implicit_group_by_exprs {
1172+
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?
1173+
} else {
1174+
group_expr
1175+
};
1176+
11431177
Aggregate::try_new(self.plan, group_expr, aggr_expr)
11441178
.map(LogicalPlan::Aggregate)
11451179
.map(Self::new)
@@ -1550,6 +1584,7 @@ pub fn add_group_by_exprs_from_dependencies(
15501584
}
15511585
Ok(group_expr)
15521586
}
1587+
15531588
/// Errors if one or more expressions have equal names.
15541589
pub fn validate_unique_names<'a>(
15551590
node_name: &str,
@@ -1685,7 +1720,21 @@ pub fn table_scan_with_filter_and_fetch(
16851720

16861721
pub fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
16871722
let table_schema = Arc::new(table_schema.clone());
1688-
Arc::new(LogicalTableSource { table_schema })
1723+
Arc::new(LogicalTableSource {
1724+
table_schema,
1725+
constraints: Default::default(),
1726+
})
1727+
}
1728+
1729+
pub fn table_source_with_constraints(
1730+
table_schema: &Schema,
1731+
constraints: Constraints,
1732+
) -> Arc<dyn TableSource> {
1733+
let table_schema = Arc::new(table_schema.clone());
1734+
Arc::new(LogicalTableSource {
1735+
table_schema,
1736+
constraints,
1737+
})
16891738
}
16901739

16911740
/// Wrap projection for a plan, if the join keys contains normal expression.
@@ -1756,12 +1805,21 @@ pub fn wrap_projection_for_join_if_necessary(
17561805
/// DefaultTableSource.
17571806
pub struct LogicalTableSource {
17581807
table_schema: SchemaRef,
1808+
constraints: Constraints,
17591809
}
17601810

17611811
impl LogicalTableSource {
17621812
/// Create a new LogicalTableSource
17631813
pub fn new(table_schema: SchemaRef) -> Self {
1764-
Self { table_schema }
1814+
Self {
1815+
table_schema,
1816+
constraints: Constraints::default(),
1817+
}
1818+
}
1819+
1820+
pub fn with_constraints(mut self, constraints: Constraints) -> Self {
1821+
self.constraints = constraints;
1822+
self
17651823
}
17661824
}
17671825

@@ -1774,6 +1832,10 @@ impl TableSource for LogicalTableSource {
17741832
Arc::clone(&self.table_schema)
17751833
}
17761834

1835+
fn constraints(&self) -> Option<&Constraints> {
1836+
Some(&self.constraints)
1837+
}
1838+
17771839
fn supports_filters_pushdown(
17781840
&self,
17791841
filters: &[&Expr],
@@ -2023,12 +2085,12 @@ pub fn unnest_with_options(
20232085

20242086
#[cfg(test)]
20252087
mod tests {
2026-
20272088
use super::*;
20282089
use crate::logical_plan::StringifiedPlan;
20292090
use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery};
20302091

2031-
use datafusion_common::{RecursionUnnestOption, SchemaError};
2092+
use crate::test::function_stub::sum;
2093+
use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError};
20322094

20332095
#[test]
20342096
fn plan_builder_simple() -> Result<()> {
@@ -2575,4 +2637,45 @@ mod tests {
25752637

25762638
Ok(())
25772639
}
2640+
2641+
#[test]
2642+
fn plan_builder_aggregate_without_implicit_group_by_exprs() -> Result<()> {
2643+
let constraints =
2644+
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
2645+
let table_source = table_source_with_constraints(&employee_schema(), constraints);
2646+
2647+
let plan =
2648+
LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))?
2649+
.aggregate(vec![col("id")], vec![sum(col("salary"))])?
2650+
.build()?;
2651+
2652+
let expected =
2653+
"Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\
2654+
\n TableScan: employee_csv projection=[id, state, salary]";
2655+
assert_eq!(expected, format!("{plan}"));
2656+
2657+
Ok(())
2658+
}
2659+
2660+
#[test]
2661+
fn plan_builder_aggregate_with_implicit_group_by_exprs() -> Result<()> {
2662+
let constraints =
2663+
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
2664+
let table_source = table_source_with_constraints(&employee_schema(), constraints);
2665+
2666+
let options =
2667+
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
2668+
let plan =
2669+
LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))?
2670+
.with_options(options)
2671+
.aggregate(vec![col("id")], vec![sum(col("salary"))])?
2672+
.build()?;
2673+
2674+
let expected =
2675+
"Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\
2676+
\n TableScan: employee_csv projection=[id, state, salary]";
2677+
assert_eq!(expected, format!("{plan}"));
2678+
2679+
Ok(())
2680+
}
25782681
}

datafusion/expr/src/logical_plan/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub mod tree_node;
2828

2929
pub use builder::{
3030
build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary,
31-
LogicalPlanBuilder, LogicalTableSource, UNNAMED_TABLE,
31+
LogicalPlanBuilder, LogicalPlanBuilderOptions, LogicalTableSource, UNNAMED_TABLE,
3232
};
3333
pub use ddl::{
3434
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,

datafusion/sql/src/select.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ use datafusion_expr::utils::{
3838
};
3939
use datafusion_expr::{
4040
qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter,
41-
GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning,
41+
GroupingSet, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions,
42+
Partitioning,
4243
};
4344

4445
use indexmap::IndexMap;
@@ -371,7 +372,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
371372
let agg_expr = agg.aggr_expr.clone();
372373
let (new_input, new_group_by_exprs) =
373374
self.try_process_group_by_unnest(agg)?;
375+
let options = LogicalPlanBuilderOptions::new()
376+
.with_add_implicit_group_by_exprs(true);
374377
LogicalPlanBuilder::from(new_input)
378+
.with_options(options)
375379
.aggregate(new_group_by_exprs, agg_expr)?
376380
.build()
377381
}
@@ -744,7 +748,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
744748
aggr_exprs: &[Expr],
745749
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
746750
// create the aggregate plan
751+
let options =
752+
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
747753
let plan = LogicalPlanBuilder::from(input.clone())
754+
.with_options(options)
748755
.aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())?
749756
.build()?;
750757
let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan {

datafusion/substrait/tests/cases/logical_plans.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,4 +91,22 @@ mod tests {
9191

9292
Ok(())
9393
}
94+
95+
#[tokio::test]
96+
async fn multilayer_aggregate() -> Result<()> {
97+
let proto_plan =
98+
read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json");
99+
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
100+
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
101+
102+
assert_eq!(
103+
format!("{}", plan),
104+
"Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\
105+
\n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\
106+
\n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\
107+
\n TableScan: sales"
108+
);
109+
110+
Ok(())
111+
}
94112
}

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,17 @@ async fn aggregate_grouping_rollup() -> Result<()> {
308308
).await
309309
}
310310

311+
#[tokio::test]
312+
async fn multilayer_aggregate() -> Result<()> {
313+
assert_expected_plan(
314+
"SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a",
315+
"Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\
316+
\n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\
317+
\n TableScan: data projection=[a, b]",
318+
true
319+
).await
320+
}
321+
311322
#[tokio::test]
312323
async fn decimal_literal() -> Result<()> {
313324
roundtrip("SELECT * FROM data WHERE b > 2.5").await

0 commit comments

Comments
 (0)