Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
988 changes: 960 additions & 28 deletions datafusion/optimizer/src/eliminate_join.rs

Large diffs are not rendered by default.

28 changes: 4 additions & 24 deletions datafusion/optimizer/src/optimize_projections/required_indices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

//! [`RequiredIndices`] helper for OptimizeProjection

use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use crate::utils::for_each_referenced_index;
use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_common::{Column, DFSchemaRef, Result};
use datafusion_expr::{Expr, LogicalPlan};

Expand Down Expand Up @@ -112,29 +113,8 @@ impl RequiredIndices {
/// * `input_schema`: The input schema to analyze for index requirements.
/// * `expr`: An expression for which we want to find necessary field indices.
fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) {
// `apply` does not descend into subqueries, so recurse manually to
// handle those cases.
expr.apply(|e| {
match e {
Expr::Column(c) | Expr::OuterReferenceColumn(_, c) => {
if let Some(idx) = input_schema.maybe_index_of_column(c) {
self.indices.push(idx);
}
}
Expr::ScalarSubquery(sub) => {
self.add_exprs(input_schema, &sub.outer_ref_columns);
}
Expr::Exists(ex) => {
self.add_exprs(input_schema, &ex.subquery.outer_ref_columns);
}
Expr::InSubquery(isq) => {
self.add_exprs(input_schema, &isq.subquery.outer_ref_columns);
}
_ => {}
}
Ok(TreeNodeRecursion::Continue)
})
.expect("traversal is infallible");
for_each_referenced_index(expr, input_schema, |idx| self.indices.push(idx))
.expect("traversal is infallible");
}

/// Like [`Self::add_expr`], but for multiple expressions.
Expand Down
53 changes: 52 additions & 1 deletion datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use arrow::array::{Array, RecordBatch, new_null_array};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::TableReference;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::tree_node::{TransformedResult, TreeNode};
use datafusion_common::tree_node::{TransformedResult, TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DFSchema, Result, ScalarValue};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::{Exists, InSubquery, SetComparison};
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan};
use datafusion_physical_expr::create_physical_expr;
Expand All @@ -37,6 +38,56 @@ use std::sync::Arc;
/// as it was initially placed here and then moved elsewhere.
pub use datafusion_expr::expr_rewriter::NamePreserver;

/// Invokes `f` with the index, within `schema`, of every column referenced by
/// `expr` — including columns reached through a correlated subquery's outer
/// references. Columns absent from `schema` are skipped.
///
/// A subquery's own plan is intentionally not traversed: its internal columns
/// index into its own schema, not `schema`; only the outer (correlated) columns
/// it references from `schema` are relevant. The comparison expression of an
/// `IN`/set-comparison subquery is reached by the normal expression walk.
///
/// This is the shared primitive behind the top-down "which of a node's output
/// columns does an ancestor still need" analyses, namely
/// [`OptimizeProjections`](crate::optimize_projections::OptimizeProjections)
/// and [`EliminateJoin`](crate::eliminate_join::EliminateJoin). The two keep
/// their own required-index containers (an ordered set vs. a hash set), so this
/// reports indices through a callback rather than populating a shared type.
pub(crate) fn for_each_referenced_index(
expr: &Expr,
schema: &DFSchema,
mut f: impl FnMut(usize),
) -> Result<()> {
visit_referenced_indices(expr, schema, &mut f)
}

fn visit_referenced_indices(
expr: &Expr,
schema: &DFSchema,
f: &mut dyn FnMut(usize),
) -> Result<()> {
expr.apply(|expr| {
match expr {
Expr::Column(column) | Expr::OuterReferenceColumn(_, column) => {
if let Some(idx) = schema.maybe_index_of_column(column) {
f(idx);
}
}
Expr::Exists(Exists { subquery, .. })
| Expr::InSubquery(InSubquery { subquery, .. })
| Expr::SetComparison(SetComparison { subquery, .. })
| Expr::ScalarSubquery(subquery) => {
for outer in &subquery.outer_ref_columns {
visit_referenced_indices(outer, schema, f)?;
}
}
_ => {}
}
Ok(TreeNodeRecursion::Continue)
})?;
Ok(())
}

/// Returns true if `expr` contains all columns in `schema_cols`
pub(crate) fn has_all_column_refs(
expr: &Expr,
Expand Down
57 changes: 47 additions & 10 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1333,19 +1333,57 @@ inner join join_t2 on join_t1.t1_id = join_t2.t2_id
----
logical_plan
01)Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[]]
02)--Projection: join_t1.t1_id
03)----Inner Join: join_t1.t1_id = join_t2.t2_id
04)------TableScan: join_t1 projection=[t1_id]
05)------TableScan: join_t2 projection=[t2_id]
02)--LeftSemi Join: join_t1.t1_id = join_t2.t2_id
03)----TableScan: join_t1 projection=[t1_id]
04)----TableScan: join_t2 projection=[t2_id]
physical_plan
01)AggregateExec: mode=FinalPartitioned, gby=[t1_id@0 as t1_id], aggr=[]
02)--RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
03)----AggregateExec: mode=Partial, gby=[t1_id@0 as t1_id], aggr=[]
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
04)------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)]
05)--------DataSourceExec: partitions=1, partition_sizes=[1]
06)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
07)----------DataSourceExec: partitions=1, partition_sizes=[1]

statement ok
set datafusion.explain.logical_plan_only = true;

# A single `count(DISTINCT col)` over a join whose other side is used only as an
# existence filter can be rewritten to a semi join.
query TT
EXPLAIN
select join_t1.t1_id, count(distinct join_t1.t1_int)
from join_t1
inner join join_t2 on join_t1.t1_id = join_t2.t2_id
group by join_t1.t1_id
----
logical_plan
01)Projection: join_t1.t1_id, count(alias1) AS count(DISTINCT join_t1.t1_int)
02)--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(alias1)]]
03)----Aggregate: groupBy=[[join_t1.t1_id, join_t1.t1_int AS alias1]], aggr=[[]]
04)------LeftSemi Join: join_t1.t1_id = join_t2.t2_id
05)--------TableScan: join_t1 projection=[t1_id, t1_int]
06)--------TableScan: join_t2 projection=[t2_id]

# A similar query with two DISTINCT aggregates is currently not rewritten
# TODO: https://github.com/apache/datafusion/issues/22644
query TT
EXPLAIN
select join_t1.t1_id, count(distinct join_t1.t1_int), count(distinct join_t1.t1_name)
from join_t1
inner join join_t2 on join_t1.t1_id = join_t2.t2_id
group by join_t1.t1_id
----
logical_plan
01)Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(DISTINCT join_t1.t1_int), count(DISTINCT join_t1.t1_name)]]
02)--Projection: join_t1.t1_id, join_t1.t1_name, join_t1.t1_int
03)----Inner Join: join_t1.t1_id = join_t2.t2_id
04)------TableScan: join_t1 projection=[t1_id, t1_name, t1_int]
05)------TableScan: join_t2 projection=[t2_id]

statement ok
set datafusion.explain.logical_plan_only = false;

# Join on struct
query TT
explain select join_t3.s3, join_t4.s4
Expand Down Expand Up @@ -1411,10 +1449,9 @@ logical_plan
01)Projection: count(alias1) AS count(DISTINCT join_t1.t1_id)
02)--Aggregate: groupBy=[[]], aggr=[[count(alias1)]]
03)----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]]
04)------Projection: join_t1.t1_id
05)--------Inner Join: join_t1.t1_id = join_t2.t2_id
06)----------TableScan: join_t1 projection=[t1_id]
07)----------TableScan: join_t2 projection=[t2_id]
04)------LeftSemi Join: join_t1.t1_id = join_t2.t2_id
05)--------TableScan: join_t1 projection=[t1_id]
06)--------TableScan: join_t2 projection=[t2_id]
physical_plan
01)ProjectionExec: expr=[count(alias1)@0 as count(DISTINCT join_t1.t1_id)]
02)--AggregateExec: mode=Final, gby=[], aggr=[count(alias1)]
Expand All @@ -1423,7 +1460,7 @@ physical_plan
05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
06)----------RepartitionExec: partitioning=Hash([alias1@0], 2), input_partitions=2
07)------------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
08)--------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
08)--------------HashJoinExec: mode=CollectLeft, join_type=LeftSemi, on=[(t1_id@0, t2_id@0)]
09)----------------DataSourceExec: partitions=1, partition_sizes=[1]
10)----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
11)------------------DataSourceExec: partitions=1, partition_sizes=[1]
Expand Down
14 changes: 7 additions & 7 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,13 @@ where c_acctbal < (
logical_plan
01)Sort: customer.c_custkey ASC NULLS LAST
02)--Projection: customer.c_custkey
03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice)
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice)
04)------TableScan: customer projection=[c_custkey, c_acctbal]
05)------SubqueryAlias: __scalar_sq_1
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]]
08)------------Projection: orders.o_custkey, orders.o_totalprice
09)--------------Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price
09)--------------LeftSemi Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price
10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]
11)----------------SubqueryAlias: __scalar_sq_2
12)------------------Projection: sum(lineitem.l_extendedprice) AS price, lineitem.l_orderkey
Expand Down Expand Up @@ -555,7 +555,7 @@ logical_plan
02)--TableScan: t0 projection=[t0_id, t0_name]
03)--SubqueryAlias: __correlated_sq_2
04)----Projection: t1.t1_name
05)------Inner Join: t1.t1_id = t2.t2_id
05)------LeftSemi Join: t1.t1_id = t2.t2_id
06)--------TableScan: t1 projection=[t1_id, t1_name]
07)--------TableScan: t2 projection=[t2_id]

Expand All @@ -568,7 +568,7 @@ logical_plan
02)--TableScan: t0 projection=[t0_id, t0_name]
03)--SubqueryAlias: __correlated_sq_1
04)----Projection: t2.t2_name
05)------Inner Join: t1.t1_id = t2.t2_id
05)------RightSemi Join: t1.t1_id = t2.t2_id
06)--------TableScan: t1 projection=[t1_id]
07)--------SubqueryAlias: t2
08)----------TableScan: t2 projection=[t2_id, t2_name]
Expand Down Expand Up @@ -1675,7 +1675,7 @@ where c_acctbal < (
logical_plan
01)Sort: customer.c_custkey ASC NULLS LAST
02)--Projection: customer.c_custkey
03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
04)------TableScan: customer projection=[c_custkey, c_acctbal]
05)------SubqueryAlias: __scalar_sq_2
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
Expand All @@ -1701,7 +1701,7 @@ where c_acctbal < (
logical_plan
01)Sort: customer.c_custkey ASC NULLS LAST
02)--Projection: customer.c_custkey
03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
03)----LeftSemi Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice)
04)------TableScan: customer projection=[c_custkey, c_acctbal]
05)------SubqueryAlias: __scalar_sq_2
06)--------Projection: sum(orders.o_totalprice), orders.o_custkey
Expand Down Expand Up @@ -1746,7 +1746,7 @@ WHERE e1.salary > (
----
logical_plan
01)Projection: e1.employee_name, e1.salary
02)--Inner Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary)
02)--LeftSemi Join: e1.dept_id = __scalar_sq_1.dept_id Filter: CAST(e1.salary AS Decimal128(38, 14)) > __scalar_sq_1.avg(e2.salary)
03)----SubqueryAlias: e1
04)------TableScan: employees projection=[employee_name, dept_id, salary]
05)----SubqueryAlias: __scalar_sq_1
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/tpch/plans/q11.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ logical_plan
05)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15))
06)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]]
07)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost
08)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey
08)--------------LeftSemi Join: supplier.s_nationkey = nation.n_nationkey
09)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
10)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
11)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost]
Expand All @@ -64,7 +64,7 @@ logical_plan
15)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8View("GERMANY")]
16)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]]
17)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost
18)----------Inner Join: supplier.s_nationkey = nation.n_nationkey
18)----------LeftSemi Join: supplier.s_nationkey = nation.n_nationkey
19)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
20)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
21)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
Expand All @@ -81,7 +81,7 @@ physical_plan
06)----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
07)------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4
08)--------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2]
09)----------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2]
10)------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4
11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5]
12)----------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4
Expand All @@ -96,7 +96,7 @@ physical_plan
21)----AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
22)------CoalescePartitionsExec
23)--------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)]
24)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1]
24)----------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1]
25)------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4
26)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4]
27)----------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4
Expand Down
Loading
Loading