Skip to content

Commit c0daaac

Browse files
author
Karuppayya Rajendran
committed
Codegen for MergeRowExec
1 parent 686d844 commit c0daaac

File tree

1 file changed

+278
-5
lines changed

1 file changed

+278
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala

Lines changed: 278 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,26 @@
1717

1818
package org.apache.spark.sql.execution.datasources.v2
1919

20+
import org.apache.spark.SparkUnsupportedOperationException
2021
import org.roaringbitmap.longlong.Roaring64Bitmap
21-
2222
import org.apache.spark.rdd.RDD
2323
import org.apache.spark.sql.AnalysisException
2424
import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.expressions.Attribute
2626
import org.apache.spark.sql.catalyst.expressions.AttributeSet
2727
import org.apache.spark.sql.catalyst.expressions.BasePredicate
28+
import org.apache.spark.sql.catalyst.expressions.BindReferences
2829
import org.apache.spark.sql.catalyst.expressions.Expression
2930
import org.apache.spark.sql.catalyst.expressions.Projection
3031
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
31-
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
32+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral, GeneratePredicate, JavaCode}
33+
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
3234
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Context, Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update}
3335
import org.apache.spark.sql.catalyst.util.truncatedString
3436
import org.apache.spark.sql.errors.QueryExecutionErrors
35-
import org.apache.spark.sql.execution.SparkPlan
36-
import org.apache.spark.sql.execution.UnaryExecNode
37+
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}
3738
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
39+
import org.apache.spark.sql.types.BooleanType
3840

3941
case class MergeRowsExec(
4042
isSourceRowPresent: Expression,
@@ -44,7 +46,7 @@ case class MergeRowsExec(
4446
notMatchedBySourceInstructions: Seq[Instruction],
4547
checkCardinality: Boolean,
4648
output: Seq[Attribute],
47-
child: SparkPlan) extends UnaryExecNode {
49+
child: SparkPlan) extends UnaryExecNode with CodegenSupport {
4850

4951
override lazy val metrics: Map[String, SQLMetric] = Map(
5052
"numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
@@ -92,6 +94,277 @@ case class MergeRowsExec(
9294
child.execute().mapPartitions(processPartition)
9395
}
9496

97+
/**
 * Returns the input RDDs used by generated code by delegating to the child,
 * which is cast to [[CodegenSupport]].
 */
override def inputRDDs(): Seq[RDD[InternalRow]] = {
  val codegenChild = child.asInstanceOf[CodegenSupport]
  codegenChild.inputRDDs()
}
100+
101+
/**
 * Emits no code of its own: asks the child (cast to [[CodegenSupport]]) to produce its
 * code, passing this node as the parent.
 */
protected override def doProduce(ctx: CodegenContext): String = {
  val codegenChild = child.asInstanceOf[CodegenSupport]
  codegenChild.produce(ctx, this)
}
104+
105+
/**
 * Generates the per-row code for this node by forwarding the incoming column variables
 * to `generateInstructionExecutionCode`. The `row` argument is not used.
 */
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String =
  generateInstructionExecutionCode(ctx, input)
112+
113+
114+
/**
 * Builds the generated-code fragment that validates MERGE cardinality.
 *
 * A `Roaring64Bitmap` is registered as mutable state under the name hint "matchedRowIds".
 * The fragment evaluates the row ID variable, throws the error returned by
 * `QueryExecutionErrors.mergeCardinalityViolationError()` when the bitmap already contains
 * that ID, and otherwise adds the ID to the bitmap.
 *
 * @param ctx codegen context the bitmap state is registered with
 * @param rowIdOrdinal position of the row ID variable within `input`
 * @param input column variables of the current row
 * @return an [[ExprCode]] whose `code` is the fragment, whose `isNull` is `FalseLiteral`
 *         and whose `value` refers to the bitmap variable
 */
private def generateCardinalityValidationCode(ctx: CodegenContext, rowIdOrdinal: Int,
    input: Seq[ExprCode]): ExprCode = {
  val bitmapType = classOf[Roaring64Bitmap]
  val bitmapTypeName = bitmapType.getName
  val matchedIdsTerm = ctx.addMutableState(bitmapTypeName, "matchedRowIds",
    term => s"$term = new $bitmapTypeName();")

  val rowIdVar = input(rowIdOrdinal)
  // The Scala object is reached from generated Java through its MODULE$ singleton field.
  val errorsSingleton = QueryExecutionErrors.getClass.getName + ".MODULE$"

  val validation =
    code"""
       |${rowIdVar.code}
       |if ($matchedIdsTerm.contains(${rowIdVar.value})) {
       | throw $errorsSingleton.mergeCardinalityViolationError();
       |}
       |$matchedIdsTerm.add(${rowIdVar.value});
     """.stripMargin

  ExprCode(validation, FalseLiteral, JavaCode.variable(matchedIdsTerm, bitmapType))
}
135+
136+
/**
 * Generates the top-level per-row code: evaluates the source/target presence predicates
 * and dispatches to one of the three instruction groups.
 *
 * The emitted code has the shape:
 *  - target present and source present: optional cardinality check, then the matched
 *    instructions;
 *  - otherwise, source present: the not-matched instructions;
 *  - otherwise, target present: the not-matched-by-source instructions.
 *
 * The cardinality check is emitted only when `checkCardinality` is set; it requires an
 * attribute in `child.output` whose name resolves to `ROW_ID`.
 *
 * @param ctx codegen context
 * @param inputExprs column variables of the current row, as received by `doConsume`
 * @return the generated Java source for processing one row
 */
private def generateInstructionExecutionCode(ctx: CodegenContext,
    inputExprs: Seq[ExprCode]): String = {
  val inputAttrs = child.output

  // Presence predicates are generated first, in source-then-target order.
  val sourcePresent = generatePredicateCode(ctx, isSourceRowPresent, inputAttrs, inputExprs)
  val targetPresent = generatePredicateCode(ctx, isTargetRowPresent, inputAttrs, inputExprs)

  // One code fragment per instruction group.
  val matchedCode = generateInstructionsCode(ctx, matchedInstructions,
    "matched", inputExprs, sourcePresent = true)
  val notMatchedCode = generateInstructionsCode(ctx, notMatchedInstructions,
    "notMatched", inputExprs, sourcePresent = true)
  val notMatchedBySourceCode = generateInstructionsCode(ctx,
    notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, sourcePresent = false)

  val cardinalityCheck: String =
    if (!checkCardinality) {
      ""
    } else {
      val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code.toString
    }

  s"""
     |${sourcePresent.code}
     |${targetPresent.code}
     |
     |if (${targetPresent.value} && ${sourcePresent.value}) {
     | $cardinalityCheck
     | $matchedCode
     |} else if (${sourcePresent.value}) {
     | $notMatchedCode
     |} else if (${targetPresent.value}) {
     | $notMatchedBySourceCode
     |}
   """.stripMargin
}
176+
177+
/**
 * Generates the code for one group of instructions.
 *
 * Returns an empty string for an empty group. Otherwise the per-instruction fragments are
 * joined with newlines, in order, and followed by a generated `return;`.
 *
 * NOTE(review): the trailing `return;` leaves whichever generated method this fragment
 * ends up in; presumably it is meant to finish the current row. Confirm this is safe
 * wherever the fragment is inlined.
 *
 * @param ctx codegen context
 * @param instructions instructions of the group, in evaluation order
 * @param instructionType label of the group; not used by this method
 * @param inputExprs column variables of the current row
 * @param sourcePresent passed through to the per-instruction generator
 * @return the generated Java source for the group
 */
private def generateInstructionsCode(ctx: CodegenContext, instructions: Seq[Instruction],
    instructionType: String,
    inputExprs: Seq[ExprCode],
    sourcePresent: Boolean): String = instructions match {
  case Seq() =>
    ""
  case nonEmpty =>
    val fragments = nonEmpty.map { instruction =>
      generateSingleInstructionCode(ctx, instruction, inputExprs, sourcePresent)
    }
    val joined = fragments.mkString("\n")

    s"""
       |$joined
       |return;
     """.stripMargin
}
196+
197+
/**
 * Generates the code for a single instruction. Every supported instruction evaluates its
 * condition and, when it holds, updates metrics and ends with a generated `return;`:
 *  - `Keep`: passes the projected row on via `consume`;
 *  - `Discard`: updates the delete metrics and passes nothing on;
 *  - `Split`: updates the update metrics and passes both projected rows on, in order.
 *
 * Any other instruction raises a `SparkUnsupportedOperationException`.
 *
 * NOTE(review): the generated `return;` leaves whichever generated method this fragment
 * ends up in; confirm this is the intended way to stop handling the current row.
 *
 * @param ctx codegen context
 * @param instruction instruction to generate code for
 * @param inputExprs column variables of the current row
 * @param sourcePresent selects which metrics the update/delete helpers increment
 * @return the generated Java source for the instruction
 */
private def generateSingleInstructionCode(ctx: CodegenContext,
    instruction: Instruction,
    inputExprs: Seq[ExprCode],
    sourcePresent: Boolean): String = {
  instruction match {
    case Keep(context, condition, outputExprs) =>
      // Helpers are invoked in a fixed order: projection, predicate, then metrics.
      val projected = generateProjectionCode(ctx, outputExprs, inputExprs)
      val cond = generatePredicateCode(ctx, condition, child.output, inputExprs)
      val metricUpdates = generateMetricUpdateCode(ctx, context, sourcePresent)

      s"""
         |${cond.code}
         |if (${cond.value}) {
         | $metricUpdates
         | ${consume(ctx, projected)}
         | return;
         |}
       """.stripMargin

    case Discard(condition) =>
      val cond = generatePredicateCode(ctx, condition, child.output, inputExprs)
      val metricUpdates = generateDeleteMetricUpdateCode(ctx, sourcePresent)

      s"""
         |${cond.code}
         |if (${cond.value}) {
         | $metricUpdates
         | return; // Discar row
         |}
       """.stripMargin

    case Split(condition, outputExprs, otherOutputExprs) =>
      val firstProjected = generateProjectionCode(ctx, outputExprs, inputExprs)
      val secondProjected = generateProjectionCode(ctx, otherOutputExprs, inputExprs)
      val cond = generatePredicateCode(ctx, condition, child.output, inputExprs)
      val metricUpdates = generateUpdateMetricUpdateCode(ctx, sourcePresent)

      s"""
         |${cond.code}
         |if (${cond.value}) {
         | $metricUpdates
         | ${consume(ctx, firstProjected)}
         | ${consume(ctx, secondProjected)}
         | return;
         |}
       """.stripMargin

    case _ =>
      // No code generation exists for this instruction type.
      throw new SparkUnsupportedOperationException(
        errorClass = "_LEGACY_ERROR_TEMP_3073",
        messageParameters = Map("instruction" -> instruction.toString))
  }
}
252+
253+
/**
 * Generates the metric-update statements for a `Keep` instruction, chosen by its context.
 *
 * `Copy` and `Insert` increment a single metric; `Update` and `Delete` defer to the
 * dedicated helpers, which also take `sourcePresent` into account.
 *
 * @param ctx codegen context
 * @param context context of the `Keep` instruction
 * @param sourcePresent forwarded to the update/delete helpers
 * @return generated Java statements that increment the metrics
 * @throws IllegalArgumentException for any other context
 */
private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
    sourcePresent: Boolean): String = {
  // Statement that adds 1 to the named metric.
  def incrementOnly(metricName: String): String = {
    val term = metricTerm(ctx, metricName)
    s"$term.add(1);"
  }

  context match {
    case Copy => incrementOnly("numTargetRowsCopied")
    case Insert => incrementOnly("numTargetRowsInserted")
    case Update => generateUpdateMetricUpdateCode(ctx, sourcePresent)
    case Delete => generateDeleteMetricUpdateCode(ctx, sourcePresent)
    case _ =>
      throw new IllegalArgumentException(s"Unexpected context for KeepExec: $context")
  }
}
277+
278+
/**
 * Generates statements that add 1 to "numTargetRowsUpdated" and to one breakdown metric:
 * "numTargetRowsMatchedUpdated" when `sourcePresent` is true, otherwise
 * "numTargetRowsNotMatchedBySourceUpdated".
 *
 * @param ctx codegen context
 * @param sourcePresent selects the breakdown metric
 * @return generated Java statements that increment both metrics
 */
private def generateUpdateMetricUpdateCode(ctx: CodegenContext,
    sourcePresent: Boolean): String = {
  val totalTerm = metricTerm(ctx, "numTargetRowsUpdated")
  val breakdownName =
    if (sourcePresent) "numTargetRowsMatchedUpdated"
    else "numTargetRowsNotMatchedBySourceUpdated"
  val breakdownTerm = metricTerm(ctx, breakdownName)

  s"""
     |$totalTerm.add(1);
     |$breakdownTerm.add(1);
   """.stripMargin
}
297+
298+
/**
 * Generates statements that add 1 to "numTargetRowsDeleted" and to one breakdown metric:
 * "numTargetRowsMatchedDeleted" when `sourcePresent` is true, otherwise
 * "numTargetRowsNotMatchedBySourceDeleted".
 *
 * @param ctx codegen context
 * @param sourcePresent selects the breakdown metric
 * @return generated Java statements that increment both metrics
 */
private def generateDeleteMetricUpdateCode(ctx: CodegenContext,
    sourcePresent: Boolean): String = {
  val totalTerm = metricTerm(ctx, "numTargetRowsDeleted")
  val breakdownName =
    if (sourcePresent) "numTargetRowsMatchedDeleted"
    else "numTargetRowsNotMatchedBySourceDeleted"
  val breakdownTerm = metricTerm(ctx, breakdownName)

  s"""
     |$totalTerm.add(1);
     |$breakdownTerm.add(1);
   """.stripMargin
}
317+
318+
/**
 * Runs `block` with `ctx.currentVars` pointing at the given input variables, then puts
 * the context back the way it was.
 *
 * Both `ctx.currentVars` and `ctx.INPUT_ROW` are captured up front and restored afterwards,
 * even when `block` throws. Only `currentVars` is overridden by this method; `INPUT_ROW`
 * is left as found while `block` runs.
 *
 * @param ctx codegen context to adjust temporarily
 * @param inputCurrentVars input variables, as saved from `doConsume`
 * @param block code-generating computation to run
 * @return whatever `block` returns
 */
private def withCodegenContext[T](
    ctx: CodegenContext,
    inputCurrentVars: Seq[ExprCode])(block: => T): T = {
  val (savedCurrentVars, savedInputRow) = (ctx.currentVars, ctx.INPUT_ROW)
  ctx.currentVars = inputCurrentVars
  try {
    block
  } finally {
    ctx.currentVars = savedCurrentVars
    ctx.INPUT_ROW = savedInputRow
  }
}
341+
342+
/**
 * Generates code that evaluates `predicate` against the input variables and stores the
 * outcome in a fresh Java `boolean`.
 *
 * The predicate is bound to `inputAttrs` and generated with `ctx.currentVars` set to
 * `inputCurrentVars`. The boolean is true only when the predicate is both non-null and
 * true.
 *
 * @param ctx codegen context
 * @param predicate expression to evaluate
 * @param inputAttrs attributes the predicate's references are bound against
 * @param inputCurrentVars column variables of the current row
 * @return an [[ExprCode]] whose `code` declares the boolean, whose `isNull` is
 *         `FalseLiteral` and whose `value` is the boolean variable
 */
private def generatePredicateCode(ctx: CodegenContext,
    predicate: Expression,
    inputAttrs: Seq[Attribute],
    inputCurrentVars: Seq[ExprCode]): ExprCode = {
  withCodegenContext(ctx, inputCurrentVars) {
    val bound = BindReferences.bindReference(predicate, inputAttrs)
    val evaluated = bound.genCode(ctx)
    val resultVar = ctx.freshName("predicateResult")
    val declaration = code"""
       |${evaluated.code}
       |boolean $resultVar = !${evaluated.isNull} && ${evaluated.value};
     """.stripMargin
    val resultRef = JavaCode.variable(resultVar, BooleanType)
    ExprCode(declaration, FalseLiteral, resultRef)
  }
}
358+
359+
/**
 * Generates code for each output expression of an instruction.
 *
 * All expressions are first bound to `child.output`, then each bound expression is
 * generated with `ctx.currentVars` set to `inputCurrentVars`.
 *
 * @param ctx codegen context
 * @param outputExprs expressions that make up the projected row
 * @param inputCurrentVars column variables of the current row
 * @return one [[ExprCode]] per output expression, in the same order
 */
private def generateProjectionCode(ctx: CodegenContext,
    outputExprs: Seq[Expression],
    inputCurrentVars: Seq[ExprCode]): Seq[ExprCode] = {
  withCodegenContext(ctx, inputCurrentVars) {
    // Bind every expression before generating any code.
    val bound = outputExprs.map(expr => BindReferences.bindReference(expr, child.output))
    bound.map(expr => expr.genCode(ctx))
  }
}
367+
95368
private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
96369
val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
97370
val isTargetRowPresentPred = createPredicate(isTargetRowPresent)

0 commit comments

Comments
 (0)