Skip to content

Commit d604b45

Browse files
branch-3.1: [fix](nereids) pull up left join right predicate with or is null #58372 (#58664)
picked from #58372
1 parent fb7628b commit d604b45

File tree

10 files changed

+319
-27
lines changed

10 files changed

+319
-27
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.apache.doris.nereids.trees.expressions.InPredicate;
3030
import org.apache.doris.nereids.trees.expressions.Like;
3131
import org.apache.doris.nereids.trees.expressions.Not;
32+
import org.apache.doris.nereids.trees.expressions.Or;
3233
import org.apache.doris.nereids.trees.expressions.Slot;
3334
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
3435
import org.apache.doris.nereids.trees.expressions.literal.Literal;
@@ -127,6 +128,14 @@ public Void visitLike(Like like, Map<Expression, Set<Expression>> context) {
127128
return null;
128129
}
129130

131+
/**
 * Visits an {@link Or} predicate and indexes it in {@code context}.
 *
 * <p>For every expression returned by {@code getAllSubExpressions(or)}, the whole
 * {@code or} predicate is added to the set stored under that expression, creating
 * the set on first use.
 * NOTE(review): {@code getAllSubExpressions} is defined outside this hunk; its exact
 * result (which sub-expressions it yields) should be confirmed there.
 *
 * @param or the disjunction being visited
 * @param context map from an expression to the predicates registered for it
 * @return always {@code null}
 */
@Override
132+
public Void visitOr(Or or, Map<Expression, Set<Expression>> context) {
133+
for (Expression expr : getAllSubExpressions(or)) {
134+
// LinkedHashSet keeps the registered predicates in insertion order.
context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(or);
135+
}
136+
return null;
137+
}
138+
130139
/**
 * Tells whether a comparison predicate has a literal on its right-hand side.
 *
 * @param comparisonPredicate the comparison to inspect
 * @return {@code true} if {@code comparisonPredicate.right()} is a {@link Literal}
 */
private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) {
131140
return comparisonPredicate.right() instanceof Literal;
132141
}

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java

Lines changed: 92 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
package org.apache.doris.nereids.rules.rewrite;
1919

2020
import org.apache.doris.mysql.MysqlCommand;
21+
import org.apache.doris.nereids.CascadesContext;
2122
import org.apache.doris.nereids.jobs.JobContext;
2223
import org.apache.doris.nereids.trees.expressions.Expression;
24+
import org.apache.doris.nereids.trees.expressions.IsNull;
2325
import org.apache.doris.nereids.trees.expressions.NamedExpression;
26+
import org.apache.doris.nereids.trees.expressions.Or;
2427
import org.apache.doris.nereids.trees.expressions.Slot;
2528
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
2629
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
@@ -37,15 +40,20 @@
3740
import org.apache.doris.nereids.util.PredicateInferUtils;
3841
import org.apache.doris.qe.ConnectContext;
3942

43+
import com.google.common.base.Suppliers;
4044
import com.google.common.collect.ImmutableList;
4145
import com.google.common.collect.ImmutableSet;
46+
import com.google.common.collect.Lists;
4247
import com.google.common.collect.Sets;
4348

4449
import java.util.HashMap;
50+
import java.util.HashSet;
4551
import java.util.LinkedHashSet;
52+
import java.util.List;
4653
import java.util.Map;
4754
import java.util.Optional;
4855
import java.util.Set;
56+
import java.util.function.Supplier;
4957

5058
/**
5159
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
@@ -89,20 +97,23 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, J
8997
Plan right = join.right();
9098
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
9199
switch (join.getJoinType()) {
92-
case INNER_JOIN:
93100
case CROSS_JOIN:
94-
case LEFT_SEMI_JOIN:
95-
case RIGHT_SEMI_JOIN:
96101
left = inferNewPredicate(left, expressions);
97102
right = inferNewPredicate(right, expressions);
98103
break;
104+
case INNER_JOIN:
105+
case LEFT_SEMI_JOIN:
106+
case RIGHT_SEMI_JOIN:
107+
left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext());
108+
right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext());
109+
break;
99110
case LEFT_OUTER_JOIN:
100111
case LEFT_ANTI_JOIN:
101-
right = inferNewPredicate(right, expressions);
112+
right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext());
102113
break;
103114
case RIGHT_OUTER_JOIN:
104115
case RIGHT_ANTI_JOIN:
105-
left = inferNewPredicate(left, expressions);
116+
left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext());
106117
break;
107118
default:
108119
break;
@@ -120,9 +131,16 @@ public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext
120131
return new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(), filter.getOutput());
121132
}
122133
filter = visitChildren(this, filter, context);
123-
Set<Expression> filterPredicates = pullUpPredicates(filter);
124-
filterPredicates.removeAll(pullUpAllPredicates(filter.child()));
125-
return new LogicalFilter<>(ImmutableSet.copyOf(filterPredicates), filter.child());
134+
Set<Expression> inferredPredicates = pullUpPredicates(filter);
135+
inferredPredicates.removeAll(pullUpAllPredicates(filter.child()));
136+
if (inferredPredicates.isEmpty()) {
137+
return filter.child();
138+
}
139+
if (inferredPredicates.equals(filter.getConjuncts())) {
140+
return filter;
141+
} else {
142+
return new LogicalFilter<>(ImmutableSet.copyOf(inferredPredicates), filter.child());
143+
}
126144
}
127145

128146
@Override
@@ -134,15 +152,18 @@ public Plan visitLogicalExcept(LogicalExcept except, JobContext context) {
134152
}
135153
ImmutableList.Builder<Plan> builder = ImmutableList.builder();
136154
builder.add(except.child(0));
155+
boolean changed = false;
137156
for (int i = 1; i < except.arity(); ++i) {
138157
Map<Expression, Expression> replaceMap = new HashMap<>();
139158
for (int j = 0; j < except.getOutput().size(); ++j) {
140159
NamedExpression output = except.getOutput().get(j);
141160
replaceMap.put(output, except.getRegularChildOutput(i).get(j));
142161
}
143-
builder.add(inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)));
162+
Plan newChild = inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap));
163+
changed = changed || newChild != except.child(i);
164+
builder.add(newChild);
144165
}
145-
return except.withChildren(builder.build());
166+
return changed ? except.withChildren(builder.build()) : except;
146167
}
147168

148169
@Override
@@ -153,15 +174,18 @@ public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context
153174
return intersect;
154175
}
155176
ImmutableList.Builder<Plan> builder = ImmutableList.builder();
177+
boolean changed = false;
156178
for (int i = 0; i < intersect.arity(); ++i) {
157179
Map<Expression, Expression> replaceMap = new HashMap<>();
158180
for (int j = 0; j < intersect.getOutput().size(); ++j) {
159181
NamedExpression output = intersect.getOutput().get(j);
160182
replaceMap.put(output, intersect.getRegularChildOutput(i).get(j));
161183
}
162-
builder.add(inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)));
184+
Plan newChild = inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap));
185+
changed = changed || newChild != intersect.child(i);
186+
builder.add(newChild);
163187
}
164-
return intersect.withChildren(builder.build());
188+
return changed ? intersect.withChildren(builder.build()) : intersect;
165189
}
166190

167191
private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) {
@@ -191,4 +215,60 @@ private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) {
191215
predicates.removeAll(plan.accept(pullUpAllPredicates, null));
192216
return PlanUtils.filterOrSelf(predicates, plan);
193217
}
218+
219+
// Builds the predicates to place on `plan` (one side of `join`) from `expressions`,
// removing redundant "or <slot> is null" disjuncts from inferred OR predicates.
220+
// For example, given "t2 left join t3 on t2.a = t3.a", the join conjuncts are passed to
// ExpressionUtils.inferNotNullSlots, which can report t3.a as a not-null slot.
221+
// If expressions then contains the inferred predicate "t3.a = 1 or t3.a is null",
// it is changed to "t3.a = 1".
// Only expressions whose input slots are all produced by `plan` are considered; predicates
// returned by plan.accept(pullUpAllPredicates, null) are removed before the result is built.
222+
private Plan inferNewPredicateRemoveUselessIsNull(Plan plan, Set<Expression> expressions,
223+
LogicalJoin<? extends Plan, ? extends Plan> join, CascadesContext cascadesContext) {
224+
// Memoized so the not-null slot set is computed at most once, and only if some
// inferred OR actually contains an "IS NULL" disjunct over a slot.
Supplier<Set<Slot>> supplier = Suppliers.memoize(() -> {
225+
Set<Expression> all = new HashSet<>();
226+
all.addAll(join.getHashJoinConjuncts());
227+
all.addAll(join.getOtherJoinConjuncts());
228+
return ExpressionUtils.inferNotNullSlots(all, cascadesContext);
229+
});
230+
231+
Set<Expression> predicates = new LinkedHashSet<>();
232+
Set<Slot> planOutputs = plan.getOutputSet();
233+
for (Expression expr : expressions) {
234+
Set<Slot> slots = expr.getInputSlots();
235+
// Skip expressions without slots and expressions that reference slots not output by plan.
if (slots.isEmpty() || !planOutputs.containsAll(slots)) {
236+
continue;
237+
}
238+
// Only OR predicates flagged as inferred are rewritten; everything else is kept as is.
if (expr instanceof Or && expr.isInferred()) {
239+
List<Expression> orChildren = ExpressionUtils.extractDisjunction(expr);
240+
List<Expression> newOrChildren = Lists.newArrayList();
241+
boolean changed = false;
242+
for (Expression orChild : orChildren) {
243+
// Drop "slot IS NULL" when the join conjuncts imply the slot is not null.
if (orChild instanceof IsNull && orChild.child(0) instanceof Slot
244+
&& supplier.get().contains(orChild.child(0))) {
245+
changed = true;
246+
continue;
247+
}
248+
newOrChildren.add(orChild);
249+
}
250+
if (changed) {
251+
// One disjunct left: use it directly. Several left: rebuild an inferred OR.
// None left (every disjunct was a removable IS NULL): nothing is added,
// i.e. the predicate is dropped.
if (newOrChildren.size() == 1) {
252+
predicates.add(withInferredIfSupported(newOrChildren.get(0), expr));
253+
} else if (newOrChildren.size() > 1) {
254+
predicates.add(ExpressionUtils.or(newOrChildren).withInferred(true));
255+
}
256+
} else {
257+
predicates.add(expr);
258+
}
259+
} else {
260+
predicates.add(expr);
261+
}
262+
}
263+
predicates.removeAll(plan.accept(pullUpAllPredicates, null));
264+
// NOTE(review): PlanUtils.filterOrSelf is defined elsewhere; presumably it returns plan
// itself when predicates is empty and a filter over plan otherwise - confirm.
return PlanUtils.filterOrSelf(predicates, plan);
265+
}
266+
267+
/**
 * Returns {@code expression} marked as inferred, or {@code originExpr} if marking fails.
 *
 * <p>Calls {@code expression.withInferred(true)}; if that call throws any
 * {@link RuntimeException}, the exception is discarded and {@code originExpr} is returned.
 * NOTE(review): presumably expression types that do not implement withInferred throw;
 * the broad catch would also hide unrelated runtime failures - confirm this is intended.
 *
 * @param expression the expression to mark as inferred
 * @param originExpr fallback returned when {@code withInferred} throws
 * @return the inferred copy of {@code expression}, or {@code originExpr}
 */
private Expression withInferredIfSupported(Expression expression, Expression originExpr) {
268+
try {
269+
return expression.withInferred(true);
270+
} catch (RuntimeException e) {
271+
// Fall back to the untouched original expression.
return originExpr;
272+
}
273+
}
194274
}

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import com.google.common.collect.Maps;
5454
import com.google.common.collect.Sets;
5555

56+
import java.util.ArrayList;
5657
import java.util.HashMap;
5758
import java.util.IdentityHashMap;
5859
import java.util.LinkedHashSet;
@@ -246,13 +247,21 @@ public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? e
246247
break;
247248
}
248249
case LEFT_OUTER_JOIN:
250+
predicates.addAll(leftPredicates.get());
251+
predicates.addAll(
252+
generateNullTolerantPredicates(rightPredicates.get(), join.right().getOutputSet()));
253+
break;
249254
case LEFT_SEMI_JOIN:
250255
case LEFT_ANTI_JOIN:
251256
case NULL_AWARE_LEFT_ANTI_JOIN: {
252257
predicates.addAll(leftPredicates.get());
253258
break;
254259
}
255260
case RIGHT_OUTER_JOIN:
261+
predicates.addAll(rightPredicates.get());
262+
predicates.addAll(
263+
generateNullTolerantPredicates(leftPredicates.get(), join.left().getOutputSet()));
264+
break;
256265
case RIGHT_SEMI_JOIN:
257266
case RIGHT_ANTI_JOIN: {
258267
predicates.addAll(rightPredicates.get());
@@ -346,6 +355,30 @@ private boolean hasAgg(Expression expression) {
346355
return expression.anyMatch(AggregateFunction.class::isInstance);
347356
}
348357

358+
/**
 * Rewrites predicates into a form that also accepts NULL for the slot they reference.
 *
 * <p>For each predicate whose input slots consist of exactly one slot contained in
 * {@code nullableSlots}, the result contains {@code predicate OR slot IS NULL}.
 * Predicates with zero or several input slots, or whose single slot is not in
 * {@code nullableSlots}, are not included in the result.
 *
 * <p>NOTE(review): when {@code predicates} or {@code nullableSlots} is empty the input
 * set is returned unchanged (not extended with IS NULL). In this file the callers pass
 * {@code join.right().getOutputSet()} / {@code join.left().getOutputSet()} as
 * {@code nullableSlots}; confirm that an empty slot set with non-empty predicates
 * cannot occur there.
 *
 * @param predicates predicates to rewrite
 * @param nullableSlots slots for which an IS NULL alternative is added
 * @return the rewritten predicates in iteration order, or {@code predicates} itself
 *         on the early return
 */
private Set<Expression> generateNullTolerantPredicates(Set<Expression> predicates, Set<Slot> nullableSlots) {
359+
if (predicates.isEmpty() || nullableSlots.isEmpty()) {
360+
return predicates;
361+
}
362+
Set<Expression> tolerant = Sets.newLinkedHashSetWithExpectedSize(predicates.size());
363+
for (Expression predicate : predicates) {
364+
Set<Slot> predicateSlots = predicate.getInputSlots();
365+
// Holds at most one element: the IS NULL test for the predicate's single slot.
List<Expression> orChildren = new ArrayList<>();
366+
if (predicateSlots.size() == 1) {
367+
Slot slot = predicateSlots.iterator().next();
368+
if (nullableSlots.contains(slot)) {
369+
orChildren.add(new IsNull(slot));
370+
}
371+
}
372+
// Predicates that received no IS NULL alternative are skipped entirely.
if (!orChildren.isEmpty()) {
373+
List<Expression> expandedOr = new ArrayList<>(2);
374+
expandedOr.add(predicate);
375+
expandedOr.addAll(orChildren);
376+
tolerant.add(ExpressionUtils.or(expandedOr));
377+
}
378+
}
379+
return tolerant;
380+
}
381+
349382
private ImmutableSet<Expression> getFiltersFromUnionChild(LogicalUnion union, Void context) {
350383
Set<Expression> filters = new LinkedHashSet<>();
351384
for (int i = 0; i < union.getArity(); ++i) {

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ public abstract class CompoundPredicate extends Expression implements ExpectsInp
3737
private String symbol;
3838

3939
public CompoundPredicate(List<Expression> children, String symbol) {
40-
super(children);
40+
this(children, symbol, false);
41+
}
42+
43+
/**
 * Creates a compound predicate.
 *
 * @param children operand expressions, forwarded to the superclass constructor
 * @param symbol value stored in the {@code symbol} field
 * @param inferred inferred flag, forwarded to the superclass constructor
 */
public CompoundPredicate(List<Expression> children, String symbol, boolean inferred) {
44+
super(children, inferred);
4145
this.symbol = symbol;
4246
}
4347

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,28 @@ public class Or extends CompoundPredicate {
3636
* @param right right child of comparison predicate
3737
*/
3838
public Or(Expression left, Expression right) {
39-
super(ExpressionUtils.mergeList(
39+
this(left, right, false);
40+
}
41+
42+
public Or(Expression left, Expression right, boolean inferred) {
43+
this(ExpressionUtils.mergeList(
4044
ExpressionUtils.extractDisjunction(left),
41-
ExpressionUtils.extractDisjunction(right)), "OR");
45+
ExpressionUtils.extractDisjunction(right)), inferred);
4246
}
4347

4448
public Or(List<Expression> children) {
45-
super(children, "OR");
49+
this(children, false);
50+
}
51+
52+
/**
 * Creates an OR predicate over the given children with an explicit inferred flag.
 *
 * @param children the disjuncts; must contain at least two elements
 * @param inferred inferred flag, forwarded to the superclass constructor
 * @throws IllegalArgumentException if fewer than two children are given
 */
public Or(List<Expression> children, boolean inferred) {
53+
super(children, "OR", inferred);
54+
// Checked after the superclass constructor has already run.
Preconditions.checkArgument(children.size() >= 2);
4655
}
4756

4857
@Override
4958
public Expression withChildren(List<Expression> children) {
5059
Preconditions.checkArgument(children.size() >= 2);
51-
return new Or(children);
60+
return new Or(children, this.isInferred());
5261
}
5362

5463
@Override
@@ -89,4 +98,9 @@ public List<Expression> children() {
8998
}
9099
return flattenChildren;
91100
}
101+
102+
/**
 * Returns a new {@code Or} built from this predicate's {@code children} with the given
 * inferred flag.
 *
 * @param inferred the inferred flag for the new instance
 * @return a new {@code Or}; this instance is not modified by this method
 */
@Override
103+
public Expression withInferred(boolean inferred) {
104+
return new Or(children, inferred);
105+
}
92106
}

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ public LogicalAggregate<Plan> withChildGroupByAndOutput(List<Expression> groupBy
284284

285285
public LogicalAggregate<Plan> withChildGroupByAndOutputAndSourceRepeat(List<Expression> groupByExprList,
286286
List<NamedExpression> outputExpressionList, Plan newChild,
287-
Optional<LogicalRepeat<?>> sourceRepeat) {
287+
Optional<LogicalRepeat<? extends Plan>> sourceRepeat) {
288288
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated,
289289
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), newChild);
290290
}

fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public void testInferWithOrPredicate() {
152152
inputs.add(equalTo);
153153

154154
Set<Expression> result = InferPredicateByReplace.infer(inputs);
155-
Assertions.assertEquals(2, result.size());
155+
Assertions.assertEquals(3, result.size());
156156
}
157157

158158
@Test

fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ void testInferNotNullFromFilterAndEliminateOuter2() {
4949
.printlnTree()
5050
.matches(
5151
innerLogicalJoin(
52-
logicalOlapScan(),
53-
logicalFilter().when(
54-
f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]"))
52+
logicalFilter().when(
53+
f -> f.getPredicate().toString().equals("OR[(id#2 = 4),(id#2 > 4)]")),
54+
logicalFilter().when(
55+
f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]"))
5556
)
5657

5758
);
@@ -70,7 +71,8 @@ void testInferNotNullFromFilterAndEliminateOuter3() {
7071
leftOuterLogicalJoin(
7172
logicalFilter().when(
7273
f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]")),
73-
logicalOlapScan()
74+
logicalFilter().when(
75+
f -> f.getPredicate().toString().equals("OR[(id#2 = 4),(id#2 > 4)]"))
7476
)
7577
).when(f -> f.getPredicate().toString()
7678
.equals("OR[(id#0 = 4),AND[(id#0 > 4),score#3 IS NULL]]"))

0 commit comments

Comments
 (0)