diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java index 4fc9efc1943ba6..ce5c90b75a48d8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.Like; import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.expressions.literal.Literal; @@ -127,6 +128,14 @@ public Void visitLike(Like like, Map> context) { return null; } + @Override + public Void visitOr(Or or, Map> context) { + for (Expression expr : getAllSubExpressions(or)) { + context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(or); + } + return null; + } + private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) { return comparisonPredicate.right() instanceof Literal; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 4ca9cfad478796..07852d5274bf86 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -18,9 +18,12 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.mysql.MysqlCommand; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; @@ -37,15 +40,20 @@ import org.apache.doris.nereids.util.PredicateInferUtils; import org.apache.doris.qe.ConnectContext; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Supplier; /** * infer additional predicates for `LogicalFilter` and `LogicalJoin`. @@ -89,20 +97,23 @@ public Plan visitLogicalJoin(LogicalJoin join, J Plan right = join.right(); Set expressions = getAllExpressions(left, right, join.getOnClauseCondition()); switch (join.getJoinType()) { - case INNER_JOIN: case CROSS_JOIN: - case LEFT_SEMI_JOIN: - case RIGHT_SEMI_JOIN: left = inferNewPredicate(left, expressions); right = inferNewPredicate(right, expressions); break; + case INNER_JOIN: + case LEFT_SEMI_JOIN: + case RIGHT_SEMI_JOIN: + left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext()); + right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext()); + break; case LEFT_OUTER_JOIN: case LEFT_ANTI_JOIN: - right = inferNewPredicate(right, expressions); + right = inferNewPredicateRemoveUselessIsNull(right, expressions, join, context.getCascadesContext()); break; case RIGHT_OUTER_JOIN: case RIGHT_ANTI_JOIN: - left = inferNewPredicate(left, expressions); + left = inferNewPredicateRemoveUselessIsNull(left, expressions, join, context.getCascadesContext()); break; default: break; @@ -120,9 +131,16 @@ public Plan visitLogicalFilter(LogicalFilter filter, JobContext return new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(), filter.getOutput()); } filter = visitChildren(this, filter, context); - Set filterPredicates = pullUpPredicates(filter); - filterPredicates.removeAll(pullUpAllPredicates(filter.child())); - return new LogicalFilter<>(ImmutableSet.copyOf(filterPredicates), filter.child()); + Set inferredPredicates = pullUpPredicates(filter); + inferredPredicates.removeAll(pullUpAllPredicates(filter.child())); + if (inferredPredicates.isEmpty()) { + return filter.child(); + } + if (inferredPredicates.equals(filter.getConjuncts())) { + return filter; + } else { + return new LogicalFilter<>(ImmutableSet.copyOf(inferredPredicates), filter.child()); + } } @Override @@ -134,15 +152,18 @@ public Plan visitLogicalExcept(LogicalExcept except, JobContext context) { } ImmutableList.Builder builder = ImmutableList.builder(); builder.add(except.child(0)); + boolean changed = false; for (int i = 1; i < except.arity(); ++i) { Map replaceMap = new HashMap<>(); for (int j = 0; j < except.getOutput().size(); ++j) { NamedExpression output = except.getOutput().get(j); replaceMap.put(output, except.getRegularChildOutput(i).get(j)); } - builder.add(inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap))); + Plan newChild = inferNewPredicate(except.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)); + changed = changed || newChild != except.child(i); + builder.add(newChild); } - return except.withChildren(builder.build()); + return changed ? except.withChildren(builder.build()) : except; } @Override @@ -153,15 +174,18 @@ public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context return intersect; } ImmutableList.Builder builder = ImmutableList.builder(); + boolean changed = false; for (int i = 0; i < intersect.arity(); ++i) { Map replaceMap = new HashMap<>(); for (int j = 0; j < intersect.getOutput().size(); ++j) { NamedExpression output = intersect.getOutput().get(j); replaceMap.put(output, intersect.getRegularChildOutput(i).get(j)); } - builder.add(inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap))); + Plan newChild = inferNewPredicate(intersect.child(i), ExpressionUtils.replace(baseExpressions, replaceMap)); + changed = changed || newChild != intersect.child(i); + builder.add(newChild); } - return intersect.withChildren(builder.build()); + return changed ? intersect.withChildren(builder.build()) : intersect; } private Set getAllExpressions(Plan left, Plan right, Optional condition) { @@ -191,4 +215,60 @@ private Plan inferNewPredicate(Plan plan, Set expressions) { predicates.removeAll(plan.accept(pullUpAllPredicates, null)); return PlanUtils.filterOrSelf(predicates, plan); } + + // Remove redundant "or is null" from expressions. + // For example, when we have a t2 left join t3 condition t2.a=t3.a, we can infer that t3.a is not null. + // If we find a predicate like "t3.a = 1 or t3.a is null" in expressions, we change it to "t3.a=1". + private Plan inferNewPredicateRemoveUselessIsNull(Plan plan, Set expressions, + LogicalJoin join, CascadesContext cascadesContext) { + Supplier> supplier = Suppliers.memoize(() -> { + Set all = new HashSet<>(); + all.addAll(join.getHashJoinConjuncts()); + all.addAll(join.getOtherJoinConjuncts()); + return ExpressionUtils.inferNotNullSlots(all, cascadesContext); + }); + + Set predicates = new LinkedHashSet<>(); + Set planOutputs = plan.getOutputSet(); + for (Expression expr : expressions) { + Set slots = expr.getInputSlots(); + if (slots.isEmpty() || !planOutputs.containsAll(slots)) { + continue; + } + if (expr instanceof Or && expr.isInferred()) { + List orChildren = ExpressionUtils.extractDisjunction(expr); + List newOrChildren = Lists.newArrayList(); + boolean changed = false; + for (Expression orChild : orChildren) { + if (orChild instanceof IsNull && orChild.child(0) instanceof Slot + && supplier.get().contains(orChild.child(0))) { + changed = true; + continue; + } + newOrChildren.add(orChild); + } + if (changed) { + if (newOrChildren.size() == 1) { + predicates.add(withInferredIfSupported(newOrChildren.get(0), expr)); + } else if (newOrChildren.size() > 1) { + predicates.add(ExpressionUtils.or(newOrChildren).withInferred(true)); + } + } else { + predicates.add(expr); + } + } else { + predicates.add(expr); + } + } + predicates.removeAll(plan.accept(pullUpAllPredicates, null)); + return PlanUtils.filterOrSelf(predicates, plan); + } + + private Expression withInferredIfSupported(Expression expression, Expression originExpr) { + try { + return expression.withInferred(true); + } catch (RuntimeException e) { + return originExpr; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 57d3b0dc4e4415..23950c48e3cfab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -53,6 +53,7 @@ import com.google.common.collect.Maps; import com.google.common.collect.Sets; +import java.util.ArrayList; import java.util.HashMap; import java.util.IdentityHashMap; import java.util.LinkedHashSet; @@ -246,6 +247,10 @@ public ImmutableSet visitLogicalJoin(LogicalJoin visitLogicalJoin(LogicalJoin generateNullTolerantPredicates(Set predicates, Set nullableSlots) { + if (predicates.isEmpty() || nullableSlots.isEmpty()) { + return predicates; + } + Set tolerant = Sets.newLinkedHashSetWithExpectedSize(predicates.size()); + for (Expression predicate : predicates) { + Set predicateSlots = predicate.getInputSlots(); + List orChildren = new ArrayList<>(); + if (predicateSlots.size() == 1) { + Slot slot = predicateSlots.iterator().next(); + if (nullableSlots.contains(slot)) { + orChildren.add(new IsNull(slot)); + } + } + if (!orChildren.isEmpty()) { + List expandedOr = new ArrayList<>(2); + expandedOr.add(predicate); + expandedOr.addAll(orChildren); + tolerant.add(ExpressionUtils.or(expandedOr)); + } + } + return tolerant; + } + private ImmutableSet getFiltersFromUnionChild(LogicalUnion union, Void context) { Set filters = new LinkedHashSet<>(); for (int i = 0; i < union.getArity(); ++i) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java index 9b1535eb9cc3c9..3e04ee53d90d96 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java @@ -37,7 +37,11 @@ public abstract class CompoundPredicate extends Expression implements ExpectsInp private String symbol; public CompoundPredicate(List children, String symbol) { - super(children); + this(children, symbol, false); + } + + public CompoundPredicate(List children, String symbol, boolean inferred) { + super(children, inferred); this.symbol = symbol; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java index cf6c46c3ea4215..b62c0e76b40715 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Or.java @@ -36,19 +36,28 @@ public class Or extends CompoundPredicate { * @param right right child of comparison predicate */ public Or(Expression left, Expression right) { - super(ExpressionUtils.mergeList( + this(left, right, false); + } + + public Or(Expression left, Expression right, boolean inferred) { + this(ExpressionUtils.mergeList( ExpressionUtils.extractDisjunction(left), - ExpressionUtils.extractDisjunction(right)), "OR"); + ExpressionUtils.extractDisjunction(right)), inferred); } public Or(List children) { - super(children, "OR"); + this(children, false); + } + + public Or(List children, boolean inferred) { + super(children, "OR", inferred); + Preconditions.checkArgument(children.size() >= 2); } @Override public Expression withChildren(List children) { Preconditions.checkArgument(children.size() >= 2); - return new Or(children); + return new Or(children, this.isInferred()); } @Override @@ -89,4 +98,9 @@ public List children() { } return flattenChildren; } + + @Override + public Expression withInferred(boolean inferred) { + return new Or(children, inferred); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index fbea327f7a865a..fd224426c2cc6f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -284,7 +284,7 @@ public LogicalAggregate withChildGroupByAndOutput(List groupBy public LogicalAggregate withChildGroupByAndOutputAndSourceRepeat(List groupByExprList, List outputExpressionList, Plan newChild, - Optional> sourceRepeat) { + Optional> sourceRepeat) { return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), newChild); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java index 98fbbfbec13f2e..ad35028d7b9467 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java @@ -152,7 +152,7 @@ public void testInferWithOrPredicate() { inputs.add(equalTo); Set result = InferPredicateByReplace.infer(inputs); - Assertions.assertEquals(2, result.size()); + Assertions.assertEquals(3, result.size()); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java index 613965b1238e27..d7853012e6f3cd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java @@ -49,9 +49,10 @@ void testInferNotNullFromFilterAndEliminateOuter2() { .printlnTree() .matches( innerLogicalJoin( - logicalOlapScan(), - logicalFilter().when( - f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]")) + logicalFilter().when( + f -> f.getPredicate().toString().equals("OR[(id#2 = 4),(id#2 > 4)]")), + logicalFilter().when( + f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]")) ) ); @@ -70,7 +71,8 @@ void testInferNotNullFromFilterAndEliminateOuter3() { leftOuterLogicalJoin( logicalFilter().when( f -> f.getPredicate().toString().equals("OR[(id#0 = 4),(id#0 > 4)]")), - logicalOlapScan() + logicalFilter().when( + f -> f.getPredicate().toString().equals("OR[(id#2 = 4),(id#2 > 4)]")) ) ).when(f -> f.getPredicate().toString() .equals("OR[(id#0 = 4),AND[(id#0 > 4),score#3 IS NULL]]")) diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out index ed43d254b5063f..7ec049b663ce33 100644 --- a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out +++ b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out @@ -114,14 +114,16 @@ PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() ----filter(OR[(t1.a < 2),(t1.a > 10)]) ------PhysicalOlapScan[extend_infer_t3] -----PhysicalOlapScan[extend_infer_t4] +----filter(OR[(t2.a < 2),(t2.a > 10)]) +------PhysicalOlapScan[extend_infer_t4] -- !test_or2 -- PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() ----filter(OR[(t1.a < 2),(t1.a > 10)]) ------PhysicalOlapScan[extend_infer_t3] -----PhysicalOlapScan[extend_infer_t4] +----filter(OR[(t2.a < 2),(t2.a > 10)]) +------PhysicalOlapScan[extend_infer_t4] -- !test_sign_predicate -- PhysicalResultSink @@ -771,3 +773,116 @@ PhysicalResultSink -- !pull_up_from_agg -- 0 +-- !qt_leftjoin_right_pull_up_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5] + +-- !qt_leftjoin_right_pull_up_shape_result -- + +-- !qt_multi_leftjoin_right_pull_up_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t5.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t4.a = t3.a)) otherCondition=() +------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +--------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----------filter((t1.a = 1)) +------------PhysicalOlapScan[extend_infer_t3] +----------filter((t2.a = 1)) +------------PhysicalOlapScan[extend_infer_t4] +--------filter((t3.a = 1)) +----------PhysicalOlapScan[extend_infer_t5] +------filter((t4.a = 1)) +--------PhysicalOlapScan[extend_infer_t5] +----filter((t5.a = 1)) +------PhysicalOlapScan[extend_infer_t5] + +-- !qt_multi_leftjoin_right_pull_up_shape_result -- + +-- !qt_leftjoin_right_pull_up_in_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter(a IN (1, 2)) +--------PhysicalOlapScan[extend_infer_t3] +------filter(a IN (1, 2)) +--------PhysicalOlapScan[extend_infer_t4] +----filter(a IN (1, 2)) +------PhysicalOlapScan[extend_infer_t5] + +-- !qt_leftjoin_right_pull_up_in_shape_result -- + +-- !qt_leftjoin_right_pull_up_is_null_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter(a IS NULL) +--------PhysicalOlapScan[extend_infer_t3] +------PhysicalOlapScan[extend_infer_t4] +----PhysicalOlapScan[extend_infer_t5] + +-- !qt_leftjoin_right_pull_up_is_null_shape_result -- +\N \N 9 3 \N \N \N \N \N \N \N \N +\N d2 3 55 \N \N \N \N \N \N \N \N + +-- !qt_leftjoin_right_pull_up_is_not_null_shape_shape -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter(( not a IS NULL)) +--------PhysicalOlapScan[extend_infer_t3] +------PhysicalOlapScan[extend_infer_t4] +----PhysicalOlapScan[extend_infer_t5] + +-- !qt_leftjoin_right_pull_up_is_not_null_shape_result -- +0 d2 3 5 0 d2 2 2 \N \N \N \N +100 d2 3 5 100 d2 3 \N \N \N \N \N +12 \N 9 3 \N \N \N \N \N \N \N \N +33 d2 2 5 33 d2 23 5 \N \N \N \N +78 \N 9 3 78 d2 23 5 \N \N \N \N + +-- !qt_left_join_inner_shape -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5] + +-- !qt_left_join_inner_result -- + +-- !qt_left_join_semi_shape -- +PhysicalResultSink +--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5] + +-- !qt_left_join_semi_result -- + +-- !qt_left_join_anti_shape -- +PhysicalResultSink +--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t2.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter((t1.a = 1)) +--------PhysicalOlapScan[extend_infer_t3] +------filter((t2.a = 1)) +--------PhysicalOlapScan[extend_infer_t4] +----filter((t3.a = 1)) +------PhysicalOlapScan[extend_infer_t5] + +-- !qt_left_join_anti_result -- + diff --git a/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy b/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy index b7a6090e901ed3..3c93815ce6ac85 100644 --- a/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy +++ b/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy @@ -357,8 +357,8 @@ suite("extend_infer_equal_predicate") { qt_pull_up_window_order_column """select c1,a from (select a,b,sum(a) over(order by a) c1 from extend_infer_t3 where a<33 ) t where a<33 order by 1,2""" qt_pull_up_partition_topn """select * from (select a, c,row_number() over(partition by b order by c) as rn from extend_infer_t3 where a>5 and c>3)t where a>5 and c>3 and rn<3 order by 1,2,3;""" - qt_pull_up_generate """select a,b, age from (select * from extend_infer_t3 lateral view - EXPLODE(ARRAY(30,60)) t1 as age where a<10 ) t group by grouping sets ((age),(a,b)) having a <10 order by 1,2,3""" +// qt_pull_up_generate """select a,b, age from (select * from extend_infer_t3 lateral view +// EXPLODE(ARRAY(30,60)) t1 as age where a<10 ) t group by grouping sets ((age),(a,b)) having a <10 order by 1,2,3""" qt_pull_up_from_inner_join """select a,b from (select t1.a,t2.b from extend_infer_t3 t1 inner join extend_infer_t4 t2 on t1.a=t2.a where t1.a<10 limit 10) t where a<10 order by 1,2""" qt_pull_up_from_left_join """select a,b from (select t2.a,t2.b from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a and t2.a<10 limit 10) t where a<10 order by 1,2""" @@ -377,4 +377,39 @@ suite("extend_infer_equal_predicate") { qt_pull_up_from_intersect """select a from(select a from (select t1.a from extend_infer_t3 t1 where t1.a<10 intersect select t2.a from extend_infer_t4 t2 where t2.a<10 ) tt limit 10) t where a<10 order by 1 ;""" qt_pull_up_from_agg """select a from (select a from extend_infer_t3 t1 where a<10 group by a limit 10) t where a<10 order by 1""" + + def explain_and_result = { tag, sql -> + "qt_${tag}_shape" "explain shape plan ${sql}" + "order_qt_${tag}_result" "${sql}" + } + + // test left join right table predicate pull up + explain_and_result 'qt_leftjoin_right_pull_up_shape', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' + // test multi left join right table predicate pull up + explain_and_result "qt_multi_leftjoin_right_pull_up_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a left join extend_infer_t5 t4 on t4.a=t3.a left join extend_infer_t5 t5 on t2.a=t5.a where t1.a=1; + """ + explain_and_result "qt_leftjoin_right_pull_up_in_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a in (1,2); + """ + // is null may be can be inferred but we do not infer it now + explain_and_result "qt_leftjoin_right_pull_up_is_null_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a is null; + """ + // is not null may be need not be innfered + explain_and_result "qt_leftjoin_right_pull_up_is_not_null_shape", """ + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left join extend_infer_t5 t3 on t2.a= t3.a where t1.a is not null; + """ + + explain_and_result 'qt_left_join_inner', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a inner join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' + explain_and_result 'qt_left_join_semi', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left semi join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' + explain_and_result 'qt_left_join_anti', ''' + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a left anti join extend_infer_t5 t3 on t2.a= t3.a where t1.a=1; + ''' } \ No newline at end of file