1818package org .apache .doris .nereids .rules .rewrite ;
1919
2020import org .apache .doris .mysql .MysqlCommand ;
21+ import org .apache .doris .nereids .CascadesContext ;
2122import org .apache .doris .nereids .jobs .JobContext ;
2223import org .apache .doris .nereids .trees .expressions .Expression ;
24+ import org .apache .doris .nereids .trees .expressions .IsNull ;
2325import org .apache .doris .nereids .trees .expressions .NamedExpression ;
26+ import org .apache .doris .nereids .trees .expressions .Or ;
2427import org .apache .doris .nereids .trees .expressions .Slot ;
2528import org .apache .doris .nereids .trees .expressions .StatementScopeIdGenerator ;
2629import org .apache .doris .nereids .trees .expressions .literal .BooleanLiteral ;
3740import org .apache .doris .nereids .util .PredicateInferUtils ;
3841import org .apache .doris .qe .ConnectContext ;
3942
43+ import com .google .common .base .Suppliers ;
4044import com .google .common .collect .ImmutableList ;
4145import com .google .common .collect .ImmutableSet ;
46+ import com .google .common .collect .Lists ;
4247import com .google .common .collect .Sets ;
4348
4449import java .util .HashMap ;
50+ import java .util .HashSet ;
4551import java .util .LinkedHashSet ;
52+ import java .util .List ;
4653import java .util .Map ;
4754import java .util .Optional ;
4855import java .util .Set ;
56+ import java .util .function .Supplier ;
4957
5058/**
5159 * infer additional predicates for `LogicalFilter` and `LogicalJoin`.
@@ -89,20 +97,23 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, J
8997 Plan right = join .right ();
9098 Set <Expression > expressions = getAllExpressions (left , right , join .getOnClauseCondition ());
9199 switch (join .getJoinType ()) {
92- case INNER_JOIN :
93100 case CROSS_JOIN :
94- case LEFT_SEMI_JOIN :
95- case RIGHT_SEMI_JOIN :
96101 left = inferNewPredicate (left , expressions );
97102 right = inferNewPredicate (right , expressions );
98103 break ;
104+ case INNER_JOIN :
105+ case LEFT_SEMI_JOIN :
106+ case RIGHT_SEMI_JOIN :
107+ left = inferNewPredicateRemoveUselessIsNull (left , expressions , join , context .getCascadesContext ());
108+ right = inferNewPredicateRemoveUselessIsNull (right , expressions , join , context .getCascadesContext ());
109+ break ;
99110 case LEFT_OUTER_JOIN :
100111 case LEFT_ANTI_JOIN :
101- right = inferNewPredicate (right , expressions );
112+ right = inferNewPredicateRemoveUselessIsNull (right , expressions , join , context . getCascadesContext () );
102113 break ;
103114 case RIGHT_OUTER_JOIN :
104115 case RIGHT_ANTI_JOIN :
105- left = inferNewPredicate (left , expressions );
116+ left = inferNewPredicateRemoveUselessIsNull (left , expressions , join , context . getCascadesContext () );
106117 break ;
107118 default :
108119 break ;
@@ -120,9 +131,16 @@ public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext
120131 return new LogicalEmptyRelation (StatementScopeIdGenerator .newRelationId (), filter .getOutput ());
121132 }
122133 filter = visitChildren (this , filter , context );
123- Set <Expression > filterPredicates = pullUpPredicates (filter );
124- filterPredicates .removeAll (pullUpAllPredicates (filter .child ()));
125- return new LogicalFilter <>(ImmutableSet .copyOf (filterPredicates ), filter .child ());
134+ Set <Expression > inferredPredicates = pullUpPredicates (filter );
135+ inferredPredicates .removeAll (pullUpAllPredicates (filter .child ()));
136+ if (inferredPredicates .isEmpty ()) {
137+ return filter .child ();
138+ }
139+ if (inferredPredicates .equals (filter .getConjuncts ())) {
140+ return filter ;
141+ } else {
142+ return new LogicalFilter <>(ImmutableSet .copyOf (inferredPredicates ), filter .child ());
143+ }
126144 }
127145
128146 @ Override
@@ -134,15 +152,18 @@ public Plan visitLogicalExcept(LogicalExcept except, JobContext context) {
134152 }
135153 ImmutableList .Builder <Plan > builder = ImmutableList .builder ();
136154 builder .add (except .child (0 ));
155+ boolean changed = false ;
137156 for (int i = 1 ; i < except .arity (); ++i ) {
138157 Map <Expression , Expression > replaceMap = new HashMap <>();
139158 for (int j = 0 ; j < except .getOutput ().size (); ++j ) {
140159 NamedExpression output = except .getOutput ().get (j );
141160 replaceMap .put (output , except .getRegularChildOutput (i ).get (j ));
142161 }
143- builder .add (inferNewPredicate (except .child (i ), ExpressionUtils .replace (baseExpressions , replaceMap )));
162+ Plan newChild = inferNewPredicate (except .child (i ), ExpressionUtils .replace (baseExpressions , replaceMap ));
163+ changed = changed || newChild != except .child (i );
164+ builder .add (newChild );
144165 }
145- return except .withChildren (builder .build ());
166+ return changed ? except .withChildren (builder .build ()) : except ;
146167 }
147168
148169 @ Override
@@ -153,15 +174,18 @@ public Plan visitLogicalIntersect(LogicalIntersect intersect, JobContext context
153174 return intersect ;
154175 }
155176 ImmutableList .Builder <Plan > builder = ImmutableList .builder ();
177+ boolean changed = false ;
156178 for (int i = 0 ; i < intersect .arity (); ++i ) {
157179 Map <Expression , Expression > replaceMap = new HashMap <>();
158180 for (int j = 0 ; j < intersect .getOutput ().size (); ++j ) {
159181 NamedExpression output = intersect .getOutput ().get (j );
160182 replaceMap .put (output , intersect .getRegularChildOutput (i ).get (j ));
161183 }
162- builder .add (inferNewPredicate (intersect .child (i ), ExpressionUtils .replace (baseExpressions , replaceMap )));
184+ Plan newChild = inferNewPredicate (intersect .child (i ), ExpressionUtils .replace (baseExpressions , replaceMap ));
185+ changed = changed || newChild != intersect .child (i );
186+ builder .add (newChild );
163187 }
164- return intersect .withChildren (builder .build ());
188+ return changed ? intersect .withChildren (builder .build ()) : intersect ;
165189 }
166190
167191 private Set <Expression > getAllExpressions (Plan left , Plan right , Optional <Expression > condition ) {
@@ -191,4 +215,60 @@ private Plan inferNewPredicate(Plan plan, Set<Expression> expressions) {
191215 predicates .removeAll (plan .accept (pullUpAllPredicates , null ));
192216 return PlanUtils .filterOrSelf (predicates , plan );
193217 }
218+
219+ // Remove redundant "or is null" from expressions.
220+ // For example, when we have a t2 left join t3 condition t2.a=t3.a, we can infer that t3.a is not null.
221+ // If we find a predicate like "t3.a = 1 or t3.a is null" in expressions, we change it to "t3.a=1".
222+ private Plan inferNewPredicateRemoveUselessIsNull (Plan plan , Set <Expression > expressions ,
223+ LogicalJoin <? extends Plan , ? extends Plan > join , CascadesContext cascadesContext ) {
224+ Supplier <Set <Slot >> supplier = Suppliers .memoize (() -> {
225+ Set <Expression > all = new HashSet <>();
226+ all .addAll (join .getHashJoinConjuncts ());
227+ all .addAll (join .getOtherJoinConjuncts ());
228+ return ExpressionUtils .inferNotNullSlots (all , cascadesContext );
229+ });
230+
231+ Set <Expression > predicates = new LinkedHashSet <>();
232+ Set <Slot > planOutputs = plan .getOutputSet ();
233+ for (Expression expr : expressions ) {
234+ Set <Slot > slots = expr .getInputSlots ();
235+ if (slots .isEmpty () || !planOutputs .containsAll (slots )) {
236+ continue ;
237+ }
238+ if (expr instanceof Or && expr .isInferred ()) {
239+ List <Expression > orChildren = ExpressionUtils .extractDisjunction (expr );
240+ List <Expression > newOrChildren = Lists .newArrayList ();
241+ boolean changed = false ;
242+ for (Expression orChild : orChildren ) {
243+ if (orChild instanceof IsNull && orChild .child (0 ) instanceof Slot
244+ && supplier .get ().contains (orChild .child (0 ))) {
245+ changed = true ;
246+ continue ;
247+ }
248+ newOrChildren .add (orChild );
249+ }
250+ if (changed ) {
251+ if (newOrChildren .size () == 1 ) {
252+ predicates .add (withInferredIfSupported (newOrChildren .get (0 ), expr ));
253+ } else if (newOrChildren .size () > 1 ) {
254+ predicates .add (ExpressionUtils .or (newOrChildren ).withInferred (true ));
255+ }
256+ } else {
257+ predicates .add (expr );
258+ }
259+ } else {
260+ predicates .add (expr );
261+ }
262+ }
263+ predicates .removeAll (plan .accept (pullUpAllPredicates , null ));
264+ return PlanUtils .filterOrSelf (predicates , plan );
265+ }
266+
267+ private Expression withInferredIfSupported (Expression expression , Expression originExpr ) {
268+ try {
269+ return expression .withInferred (true );
270+ } catch (RuntimeException e ) {
271+ return originExpr ;
272+ }
273+ }
194274}
0 commit comments