Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.Like;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
Expand Down Expand Up @@ -127,6 +128,14 @@ public Void visitLike(Like like, Map<Expression, Set<Expression>> context) {
return null;
}

@Override
public Void visitOr(Or or, Map<Expression, Set<Expression>> context) {
for (Expression expr : getAllSubExpressions(or)) {
context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(or);
}
return null;
}

private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) {
return comparisonPredicate.right() instanceof Literal;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
Expand All @@ -37,15 +40,20 @@
import org.apache.doris.nereids.util.PredicateInferUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;

/**
* infer additional predicates for `LogicalFilter` and `LogicalJoin`.
Expand Down Expand Up @@ -89,20 +97,23 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, J
Plan right = join.right();
Set<Expression> expressions = getAllExpressions(left, right, join.getOnClauseCondition());
switch (join.getJoinType()) {
case INNER_JOIN:
case CROSS_JOIN:
case LEFT_SEMI_JOIN:
case RIGHT_SEMI_JOIN:
left = inferNewPredicate(left, expressions);
right = inferNewPredicate(right, expressions);
break;
case INNER_JOIN:
case LEFT_SEMI_JOIN:
case RIGHT_SEMI_JOIN:
left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext());
right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext());
break;
case LEFT_OUTER_JOIN:
case LEFT_ANTI_JOIN:
right = inferNewPredicate(right, expressions);
right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext());
break;
case RIGHT_OUTER_JOIN:
case RIGHT_ANTI_JOIN:
left = inferNewPredicate(left, expressions);
left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext());
break;
default:
break;
Expand All @@ -120,9 +131,16 @@ public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext
return new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(), filter.getOutput());
}
filter = visitChildren(this, filter, context);
Set<Expression> filterPredicates = pullUpPredicates(filter);
filterPredicates.removeAll(pullUpAllPredicates(filter.child()));
return new LogicalFilter<>(ImmutableSet.copyOf(filterPredicates), filter.child());
Set<Expression> inferredPredicates = pullUpPredicates(filter);
inferredPredicates.removeAll(pullUpAllPredicates(filter.child()));
if (inferredPredicates.isEmpty()) {
return filter.child();
}
if (inferredPredicates.equals(filter.getConjuncts())) {
return filter;
} else {
return new LogicalFilter<>(ImmutableSet.copyOf(inferredPredicates), filter.child());
}
}

@Override
Expand All @@ -134,15 +152,18 @@ public Plan visitLogicalExcept(LogicalExcept except, JobContext context) {
}
ImmutableList.Builder<Plan> builder = ImmutableList.builder();
builder.add(except.child(0));
boolean changed = false;
for (int i = 1; i < except.arity(); ++i) {
Map<Expression, Expression> replaceMap = new HashMap<>();
for (int j = 0; j < except.getOutput().size(); ++j) {
NamedExpression output = except.getOutput().get(j);
replaceMap.put(output, except.getRegularChildOutput(i).get(j));
}
builder.add(inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)));
Plan newChild = inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap));
changed = changed || newChild != except.child(i);
builder.add(newChild);
}
return except.withChildren(builder.build());
return changed ? except.withChildren(builder.build()) : except;
}

@Override
Expand All @@ -153,15 +174,18 @@ public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context
return intersect;
}
ImmutableList.Builder<Plan> builder = ImmutableList.builder();
boolean changed = false;
for (int i = 0; i < intersect.arity(); ++i) {
Map<Expression, Expression> replaceMap = new HashMap<>();
for (int j = 0; j < intersect.getOutput().size(); ++j) {
NamedExpression output = intersect.getOutput().get(j);
replaceMap.put(output, intersect.getRegularChildOutput(i).get(j));
}
builder.add(inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)));
Plan newChild = inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap));
changed = changed || newChild != intersect.child(i);
builder.add(newChild);
}
return intersect.withChildren(builder.build());
return changed ? intersect.withChildren(builder.build()) : intersect;
}

private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) {
Expand Down Expand Up @@ -191,4 +215,60 @@ private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) {
predicates.removeAll(plan.accept(pullUpAllPredicates, null));
return PlanUtils.filterOrSelf(predicates, plan);
}

// Remove redundant "or is null" from expressions.
// For example, when we have a t2 left join t3 condition t2.a=t3.a, we can infer that t3.a is not null.
// If we find a predicate like "t3.a = 1 or t3.a is null" in expressions, we change it to "t3.a=1".
private Plan inferNewPredicateRemoveUselessIsNull(Plan plan, Set<Expression> expressions,
LogicalJoin<? extends Plan, ? extends Plan> join, CascadesContext cascadesContext) {
Supplier<Set<Slot>> supplier = Suppliers.memoize(() -> {
Set<Expression> all = new HashSet<>();
all.addAll(join.getHashJoinConjuncts());
all.addAll(join.getOtherJoinConjuncts());
return ExpressionUtils.inferNotNullSlots(all, cascadesContext);
});

Set<Expression> predicates = new LinkedHashSet<>();
Set<Slot> planOutputs = plan.getOutputSet();
for (Expression expr : expressions) {
Set<Slot> slots = expr.getInputSlots();
if (slots.isEmpty() || !planOutputs.containsAll(slots)) {
continue;
}
if (expr instanceof Or && expr.isInferred()) {
List<Expression> orChildren = ExpressionUtils.extractDisjunction(expr);
List<Expression> newOrChildren = Lists.newArrayList();
boolean changed = false;
for (Expression orChild : orChildren) {
if (orChild instanceof IsNull && orChild.child(0) instanceof Slot
&& supplier.get().contains(orChild.child(0))) {
changed = true;
continue;
}
newOrChildren.add(orChild);
}
if (changed) {
if (newOrChildren.size() == 1) {
predicates.add(withInferredIfSupported(newOrChildren.get(0), expr));
} else if (newOrChildren.size() > 1) {
predicates.add(ExpressionUtils.or(newOrChildren).withInferred(true));
}
} else {
predicates.add(expr);
}
} else {
predicates.add(expr);
}
}
predicates.removeAll(plan.accept(pullUpAllPredicates, null));
return PlanUtils.filterOrSelf(predicates, plan);
}

private Expression withInferredIfSupported(Expression expression, Expression originExpr) {
try {
return expression.withInferred(true);
} catch (RuntimeException e) {
return originExpr;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -246,13 +247,21 @@ public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? e
break;
}
case LEFT_OUTER_JOIN:
predicates.addAll(leftPredicates.get());
predicates.addAll(
generateNullTolerantPredicates(rightPredicates.get(), join.right().getOutputSet()));
break;
case LEFT_SEMI_JOIN:
case LEFT_ANTI_JOIN:
case NULL_AWARE_LEFT_ANTI_JOIN: {
predicates.addAll(leftPredicates.get());
break;
}
case RIGHT_OUTER_JOIN:
predicates.addAll(rightPredicates.get());
predicates.addAll(
generateNullTolerantPredicates(leftPredicates.get(), join.left().getOutputSet()));
break;
case RIGHT_SEMI_JOIN:
case RIGHT_ANTI_JOIN: {
predicates.addAll(rightPredicates.get());
Expand Down Expand Up @@ -346,6 +355,30 @@ private boolean hasAgg(Expression expression) {
return expression.anyMatch(AggregateFunction.class::isInstance);
}

private Set<Expression> generateNullTolerantPredicates(Set<Expression> predicates, Set<Slot> nullableSlots) {
if (predicates.isEmpty() || nullableSlots.isEmpty()) {
return predicates;
}
Set<Expression> tolerant = Sets.newLinkedHashSetWithExpectedSize(predicates.size());
for (Expression predicate : predicates) {
Set<Slot> predicateSlots = predicate.getInputSlots();
List<Expression> orChildren = new ArrayList<>();
if (predicateSlots.size() == 1) {
Slot slot = predicateSlots.iterator().next();
if (nullableSlots.contains(slot)) {
orChildren.add(new IsNull(slot));
}
}
if (!orChildren.isEmpty()) {
List<Expression> expandedOr = new ArrayList<>(2);
expandedOr.add(predicate);
expandedOr.addAll(orChildren);
tolerant.add(ExpressionUtils.or(expandedOr));
}
}
return tolerant;
}

private ImmutableSet<Expression> getFiltersFromUnionChild(LogicalUnion union, Void context) {
Set<Expression> filters = new LinkedHashSet<>();
for (int i = 0; i < union.getArity(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ public abstract class CompoundPredicate extends Expression implements ExpectsInp
private String symbol;

public CompoundPredicate(List<Expression> children, String symbol) {
super(children);
this(children, symbol, false);
}

public CompoundPredicate(List<Expression> children, String symbol, boolean inferred) {
super(children, inferred);
this.symbol = symbol;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,28 @@ public class Or extends CompoundPredicate {
* @param right right child of comparison predicate
*/
public Or(Expression left, Expression right) {
super(ExpressionUtils.mergeList(
this(left, right, false);
}

public Or(Expression left, Expression right, boolean inferred) {
this(ExpressionUtils.mergeList(
ExpressionUtils.extractDisjunction(left),
ExpressionUtils.extractDisjunction(right)), "OR");
ExpressionUtils.extractDisjunction(right)), inferred);
}

public Or(List<Expression> children) {
super(children, "OR");
this(children, false);
}

public Or(List<Expression> children, boolean inferred) {
super(children, "OR", inferred);
Preconditions.checkArgument(children.size() >= 2);
}

@Override
public Expression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 2);
return new Or(children);
return new Or(children, this.isInferred());
}

@Override
Expand Down Expand Up @@ -89,4 +98,9 @@ public List<Expression> children() {
}
return flattenChildren;
}

@Override
public Expression withInferred(boolean inferred) {
return new Or(children, inferred);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ public LogicalAggregate<Plan> withChildGroupByAndOutput(List<Expression> groupBy

public LogicalAggregate<Plan> withChildGroupByAndOutputAndSourceRepeat(List<Expression> groupByExprList,
List<NamedExpression> outputExpressionList, Plan newChild,
Optional<LogicalRepeat<?>> sourceRepeat) {
Optional<LogicalRepeat<? extends Plan>> sourceRepeat) {
return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated,
hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), newChild);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ public void testInferWithOrPredicate() {
inputs.add(equalTo);

Set<Expression> result = InferPredicateByReplace.infer(inputs);
Assertions.assertEquals(2, result.size());
Assertions.assertEquals(3, result.size());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ void testInferNotNullFromFilterAndEliminateOuter2() {
.printlnTree()
.matches(
innerLogicalJoin(
logicalOlapScan(),
logicalFilter().when(
f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]"))
logicalFilter().when(
f -> f.getPredicate().toString().equals("OR[(id#2 = 4),(id#2 > 4)]")),
logicalFilter().when(
f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]"))
)

);
Expand All @@ -70,7 +71,8 @@ void testInferNotNullFromFilterAndEliminateOuter3() {
leftOuterLogicalJoin(
logicalFilter().when(
f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]")),
logicalOlapScan()
logicalFilter().when(
f -> f.getPredicate().toString().equals("OR[(id#2 = 4),(id#2 > 4)]"))
)
).when(f -> f.getPredicate().toString()
.equals("OR[(id#0 = 4),AND[(id#0 > 4),score#3 IS NULL]]"))
Expand Down
Loading
Loading