Skip to content

Commit c235b5f

Browse files
kiszkgatorsmile
authored andcommitted
[SPARK-22746][SQL] Avoid the generation of useless mutable states by SortMergeJoin
## What changes were proposed in this pull request? This PR reduce the number of global mutable variables in generated code of `SortMergeJoin`. Before this PR, global mutable variables are used to extend lifetime of variables in the nested loop. This can be achieved by declaring variable at the outer most loop level where the variables are used. In the following example, `smj_value8`, `smj_value8`, and `smj_value9` are declared as local variable at lines 145-147 in `With this PR`. This PR fixes potential assertion error by #19865. Without this PR, a global mutable variable is potentially passed to arguments in generated code of split function. Without this PR ``` /* 010 */ int smj_value8; /* 011 */ boolean smj_value8; /* 012 */ int smj_value9; .. /* 143 */ protected void processNext() throws java.io.IOException { /* 144 */ while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) { /* 145 */ boolean smj_loaded = false; /* 146 */ smj_isNull6 = smj_leftRow.isNullAt(1); /* 147 */ smj_value9 = smj_isNull6 ? -1 : (smj_leftRow.getInt(1)); /* 148 */ scala.collection.Iterator<UnsafeRow> smj_iterator = smj_matches.generateIterator(); /* 149 */ while (smj_iterator.hasNext()) { /* 150 */ InternalRow smj_rightRow1 = (InternalRow) smj_iterator.next(); /* 151 */ boolean smj_isNull8 = smj_rightRow1.isNullAt(1); /* 152 */ int smj_value11 = smj_isNull8 ? -1 : (smj_rightRow1.getInt(1)); /* 153 */ /* 154 */ boolean smj_value12 = (smj_isNull6 && smj_isNull8) || /* 155 */ (!smj_isNull6 && !smj_isNull8 && smj_value9 == smj_value11); /* 156 */ if (false || !smj_value12) continue; /* 157 */ if (!smj_loaded) { /* 158 */ smj_loaded = true; /* 159 */ smj_value8 = smj_leftRow.getInt(0); /* 160 */ } /* 161 */ int smj_value10 = smj_rightRow1.getInt(0); /* 162 */ smj_numOutputRows.add(1); /* 163 */ /* 164 */ smj_rowWriter.zeroOutNullBytes(); /* 165 */ /* 166 */ smj_rowWriter.write(0, smj_value8); /* 167 */ /* 168 */ if (smj_isNull6) { /* 169 */ smj_rowWriter.setNullAt(1); /* 170 */ } else { /* 171 */ smj_rowWriter.write(1, smj_value9); /* 172 */ } /* 173 */ /* 174 */ smj_rowWriter.write(2, smj_value10); /* 175 */ /* 176 */ if (smj_isNull8) { /* 177 */ smj_rowWriter.setNullAt(3); /* 178 */ } else { /* 179 */ smj_rowWriter.write(3, smj_value11); /* 180 */ } /* 181 */ append(smj_result.copy()); /* 182 */ /* 183 */ } /* 184 */ if (shouldStop()) return; /* 185 */ } /* 186 */ } ``` With this PR ``` /* 143 */ protected void processNext() throws java.io.IOException { /* 144 */ while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) { /* 145 */ int smj_value8 = -1; /* 146 */ boolean smj_isNull6 = false; /* 147 */ int smj_value9 = -1; /* 148 */ boolean smj_loaded = false; /* 149 */ smj_isNull6 = smj_leftRow.isNullAt(1); /* 150 */ smj_value9 = smj_isNull6 ? -1 : (smj_leftRow.getInt(1)); /* 151 */ scala.collection.Iterator<UnsafeRow> smj_iterator = smj_matches.generateIterator(); /* 152 */ while (smj_iterator.hasNext()) { /* 153 */ InternalRow smj_rightRow1 = (InternalRow) smj_iterator.next(); /* 154 */ boolean smj_isNull8 = smj_rightRow1.isNullAt(1); /* 155 */ int smj_value11 = smj_isNull8 ? -1 : (smj_rightRow1.getInt(1)); /* 156 */ /* 157 */ boolean smj_value12 = (smj_isNull6 && smj_isNull8) || /* 158 */ (!smj_isNull6 && !smj_isNull8 && smj_value9 == smj_value11); /* 159 */ if (false || !smj_value12) continue; /* 160 */ if (!smj_loaded) { /* 161 */ smj_loaded = true; /* 162 */ smj_value8 = smj_leftRow.getInt(0); /* 163 */ } /* 164 */ int smj_value10 = smj_rightRow1.getInt(0); /* 165 */ smj_numOutputRows.add(1); /* 166 */ /* 167 */ smj_rowWriter.zeroOutNullBytes(); /* 168 */ /* 169 */ smj_rowWriter.write(0, smj_value8); /* 170 */ /* 171 */ if (smj_isNull6) { /* 172 */ smj_rowWriter.setNullAt(1); /* 173 */ } else { /* 174 */ smj_rowWriter.write(1, smj_value9); /* 175 */ } /* 176 */ /* 177 */ smj_rowWriter.write(2, smj_value10); /* 178 */ /* 179 */ if (smj_isNull8) { /* 180 */ smj_rowWriter.setNullAt(3); /* 181 */ } else { /* 182 */ smj_rowWriter.write(3, smj_value11); /* 183 */ } /* 184 */ append(smj_result.copy()); /* 185 */ /* 186 */ } /* 187 */ if (shouldStop()) return; /* 188 */ } /* 189 */ } ``` ## How was this patch tested? Existing test cases Author: Kazuaki Ishizaki <[email protected]> Closes #19937 from kiszk/SPARK-22746.
1 parent a04f2be commit c235b5f

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -507,32 +507,38 @@ case class SortMergeJoinExec(
507507
}
508508

509509
/**
510-
* Creates variables for left part of result row.
510+
* Creates variables and declarations for left part of result row.
511511
*
512512
* In order to defer the access after condition and also only access once in the loop,
513513
* the variables should be declared separately from accessing the columns, we can't use the
514514
* codegen of BoundReference here.
515515
*/
516-
private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
516+
private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = {
517517
ctx.INPUT_ROW = leftRow
518518
left.output.zipWithIndex.map { case (a, i) =>
519519
val value = ctx.freshName("value")
520520
val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
521-
// declare it as class member, so we can access the column before or in the loop.
522-
ctx.addMutableState(ctx.javaType(a.dataType), value)
521+
val javaType = ctx.javaType(a.dataType)
522+
val defaultValue = ctx.defaultValue(a.dataType)
523523
if (a.nullable) {
524524
val isNull = ctx.freshName("isNull")
525-
ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull)
526525
val code =
527526
s"""
528527
|$isNull = $leftRow.isNullAt($i);
529-
|$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
528+
|$value = $isNull ? $defaultValue : ($valueCode);
530529
""".stripMargin
531-
ExprCode(code, isNull, value)
530+
val leftVarsDecl =
531+
s"""
532+
|boolean $isNull = false;
533+
|$javaType $value = $defaultValue;
534+
""".stripMargin
535+
(ExprCode(code, isNull, value), leftVarsDecl)
532536
} else {
533-
ExprCode(s"$value = $valueCode;", "false", value)
537+
val code = s"$value = $valueCode;"
538+
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
539+
(ExprCode(code, "false", value), leftVarsDecl)
534540
}
535-
}
541+
}.unzip
536542
}
537543

538544
/**
@@ -580,7 +586,7 @@ case class SortMergeJoinExec(
580586
val (leftRow, matches) = genScanner(ctx)
581587

582588
// Create variables for row from both sides.
583-
val leftVars = createLeftVars(ctx, leftRow)
589+
val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow)
584590
val rightRow = ctx.freshName("rightRow")
585591
val rightVars = createRightVar(ctx, rightRow)
586592

@@ -617,6 +623,7 @@ case class SortMergeJoinExec(
617623

618624
s"""
619625
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
626+
| ${leftVarDecl.mkString("\n")}
620627
| ${beforeLoop.trim}
621628
| scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
622629
| while ($iterator.hasNext()) {

0 commit comments

Comments
 (0)