17
17
18
18
package org .apache .spark .sql .execution .datasources .v2
19
19
20
+ import org .apache .spark .SparkUnsupportedOperationException
20
21
import org .roaringbitmap .longlong .Roaring64Bitmap
21
-
22
22
import org .apache .spark .rdd .RDD
23
23
import org .apache .spark .sql .AnalysisException
24
24
import org .apache .spark .sql .catalyst .InternalRow
25
25
import org .apache .spark .sql .catalyst .expressions .Attribute
26
26
import org .apache .spark .sql .catalyst .expressions .AttributeSet
27
27
import org .apache .spark .sql .catalyst .expressions .BasePredicate
28
+ import org .apache .spark .sql .catalyst .expressions .BindReferences
28
29
import org .apache .spark .sql .catalyst .expressions .Expression
29
30
import org .apache .spark .sql .catalyst .expressions .Projection
30
31
import org .apache .spark .sql .catalyst .expressions .UnsafeProjection
31
- import org .apache .spark .sql .catalyst .expressions .codegen .GeneratePredicate
32
+ import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , ExprCode , FalseLiteral , GeneratePredicate , JavaCode }
33
+ import org .apache .spark .sql .catalyst .expressions .codegen .Block .BlockHelper
32
34
import org .apache .spark .sql .catalyst .plans .logical .MergeRows .{Context , Copy , Delete , Discard , Insert , Instruction , Keep , ROW_ID , Split , Update }
33
35
import org .apache .spark .sql .catalyst .util .truncatedString
34
36
import org .apache .spark .sql .errors .QueryExecutionErrors
35
- import org .apache .spark .sql .execution .SparkPlan
36
- import org .apache .spark .sql .execution .UnaryExecNode
37
+ import org .apache .spark .sql .execution .{CodegenSupport , SparkPlan , UnaryExecNode }
37
38
import org .apache .spark .sql .execution .metric .{SQLMetric , SQLMetrics }
39
+ import org .apache .spark .sql .types .BooleanType
38
40
39
41
case class MergeRowsExec (
40
42
isSourceRowPresent : Expression ,
@@ -44,7 +46,7 @@ case class MergeRowsExec(
44
46
notMatchedBySourceInstructions : Seq [Instruction ],
45
47
checkCardinality : Boolean ,
46
48
output : Seq [Attribute ],
47
- child : SparkPlan ) extends UnaryExecNode {
49
+ child : SparkPlan ) extends UnaryExecNode with CodegenSupport {
48
50
49
51
override lazy val metrics : Map [String , SQLMetric ] = Map (
50
52
" numTargetRowsCopied" -> SQLMetrics .createMetric(sparkContext,
@@ -92,6 +94,277 @@ case class MergeRowsExec(
92
94
child.execute().mapPartitions(processPartition)
93
95
}
94
96
97
/**
 * Returns the input RDDs of the child, which is cast to [[CodegenSupport]].
 */
override def inputRDDs(): Seq[RDD[InternalRow]] = {
  val codegenChild = child.asInstanceOf[CodegenSupport]
  codegenChild.inputRDDs()
}
100
+
101
/**
 * Delegates row production to the child (cast to [[CodegenSupport]]), passing this
 * operator as the second argument of `produce`.
 */
protected override def doProduce(ctx: CodegenContext): String = {
  val codegenChild = child.asInstanceOf[CodegenSupport]
  codegenChild.produce(ctx, this)
}
104
+
105
/**
 * Generates the code that handles one input row by handing the column variables
 * received here to `generateInstructionExecutionCode`.
 *
 * The `row` argument is not used by this implementation.
 */
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String =
  generateInstructionExecutionCode(ctx, input)
112
+
113
+
114
/**
 * Generates the cardinality check for the current row.
 *
 * A [[Roaring64Bitmap]] is registered as mutable state under the name hint
 * "matchedRowIds". The emitted Java first evaluates the row ID column found at
 * `rowIdOrdinal` in `input`, throws `mergeCardinalityViolationError()` when that ID is
 * already contained in the bitmap, and otherwise adds the ID to the bitmap.
 *
 * @param ctx codegen context that receives the bitmap state
 * @param rowIdOrdinal position of the row ID column within `input`
 * @param input variables of the current input row
 * @return the check as an [[ExprCode]] whose value refers to the bitmap variable
 */
private def generateCardinalityValidationCode(
    ctx: CodegenContext,
    rowIdOrdinal: Int,
    input: Seq[ExprCode]): ExprCode = {
  val bitmapClazz = classOf[Roaring64Bitmap]
  val bitmapClassName = bitmapClazz.getName
  val seenRowIds = ctx.addMutableState(bitmapClassName, "matchedRowIds",
    v => s"$v = new $bitmapClassName();")

  val rowId = input(rowIdOrdinal)
  // Java-side reference to the Scala singleton that builds the error.
  val errorsObject = QueryExecutionErrors.getClass.getName + ".MODULE$"

  val validation =
    code"""
       |${rowId.code}
       |if ($seenRowIds.contains(${rowId.value})) {
       |  throw $errorsObject.mergeCardinalityViolationError();
       |}
       |$seenRowIds.add(${rowId.value});
     """.stripMargin
  ExprCode(validation, FalseLiteral, JavaCode.variable(seenRowIds, bitmapClazz))
}
135
+
136
/**
 * Generates the top-level dispatch for one input row.
 *
 * The emitted Java evaluates the source-present and target-present predicates and then
 * runs exactly one branch:
 *  - both present: the optional cardinality check followed by the matched instructions;
 *  - only source present: the not-matched instructions;
 *  - only target present: the not-matched-by-source instructions.
 *
 * @param ctx codegen context
 * @param inputExprs variables of the current input row
 * @return the generated Java source
 */
private def generateInstructionExecutionCode(
    ctx: CodegenContext,
    inputExprs: Seq[ExprCode]): String = {

  // Presence predicates are generated first, source before target.
  val srcPresent = generatePredicateCode(ctx, isSourceRowPresent, child.output, inputExprs)
  val tgtPresent = generatePredicateCode(ctx, isTargetRowPresent, child.output, inputExprs)

  // One snippet per instruction group, generated in a fixed order.
  val matchedCode = generateInstructionsCode(
    ctx, matchedInstructions, "matched", inputExprs, sourcePresent = true)
  val notMatchedCode = generateInstructionsCode(
    ctx, notMatchedInstructions, "notMatched", inputExprs, sourcePresent = true)
  val notMatchedBySourceCode = generateInstructionsCode(
    ctx, notMatchedBySourceInstructions, "notMatchedBySource", inputExprs,
    sourcePresent = false)

  val cardinalityCheck =
    if (!checkCardinality) {
      ""
    } else {
      val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
    }

  s"""
     |${srcPresent.code}
     |${tgtPresent.code}
     |
     |if (${tgtPresent.value} && ${srcPresent.value}) {
     |  $cardinalityCheck
     |  $matchedCode
     |} else if (${srcPresent.value}) {
     |  $notMatchedCode
     |} else if (${tgtPresent.value}) {
     |  $notMatchedBySourceCode
     |}
   """.stripMargin
}
176
+
177
/**
 * Generates code for a group of instructions.
 *
 * Each instruction is turned into its own snippet via `generateSingleInstructionCode`;
 * the snippets are joined with newlines and followed by a `return;` statement. An empty
 * group yields an empty string.
 *
 * @param ctx codegen context
 * @param instructions instructions of the group, in evaluation order
 * @param instructionType label of the group; not used by this method
 * @param inputExprs variables of the current input row
 * @param sourcePresent forwarded to the per-instruction generator
 * @return the generated Java source, or "" when there are no instructions
 */
private def generateInstructionsCode(
    ctx: CodegenContext,
    instructions: Seq[Instruction],
    instructionType: String,
    inputExprs: Seq[ExprCode],
    sourcePresent: Boolean): String = {
  if (instructions.nonEmpty) {
    val snippets = instructions.map { instruction =>
      generateSingleInstructionCode(ctx, instruction, inputExprs, sourcePresent)
    }

    s"""
       |${snippets.mkString("\n")}
       |return;
     """.stripMargin
  } else {
    ""
  }
}
196
+
197
/**
 * Generates the code for a single instruction.
 *
 * Every supported instruction evaluates its condition and, when it holds, updates the
 * metrics and ends with `return;`:
 *  - `Keep` passes one projected row to `consume`;
 *  - `Discard` passes no row on;
 *  - `Split` passes two projected rows to `consume`, one after the other.
 *
 * @param ctx codegen context
 * @param instruction instruction to translate
 * @param inputExprs variables of the current input row
 * @param sourcePresent selects which update/delete metrics are incremented
 * @return the generated Java source
 * @throws SparkUnsupportedOperationException for any other instruction type
 */
private def generateSingleInstructionCode(
    ctx: CodegenContext,
    instruction: Instruction,
    inputExprs: Seq[ExprCode],
    sourcePresent: Boolean): String = instruction match {

  case Keep(context, condition, outputExprs) =>
    val projected = generateProjectionCode(ctx, outputExprs, inputExprs)
    val cond = generatePredicateCode(ctx, condition, child.output, inputExprs)
    // Which metrics to bump depends on the Keep context.
    val metricsCode = generateMetricUpdateCode(ctx, context, sourcePresent)

    s"""
       |${cond.code}
       |if (${cond.value}) {
       |  $metricsCode
       |  ${consume(ctx, projected)}
       |  return;
       |}
     """.stripMargin

  case Discard(condition) =>
    val cond = generatePredicateCode(ctx, condition, child.output, inputExprs)
    val metricsCode = generateDeleteMetricUpdateCode(ctx, sourcePresent)

    s"""
       |${cond.code}
       |if (${cond.value}) {
       |  $metricsCode
       |  return; // Discar row
       |}
     """.stripMargin

  case Split(condition, outputExprs, otherOutputExprs) =>
    val firstProjected = generateProjectionCode(ctx, outputExprs, inputExprs)
    val secondProjected = generateProjectionCode(ctx, otherOutputExprs, inputExprs)
    val cond = generatePredicateCode(ctx, condition, child.output, inputExprs)
    val metricsCode = generateUpdateMetricUpdateCode(ctx, sourcePresent)

    s"""
       |${cond.code}
       |if (${cond.value}) {
       |  $metricsCode
       |  ${consume(ctx, firstProjected)}
       |  ${consume(ctx, secondProjected)}
       |  return;
       |}
     """.stripMargin

  case _ =>
    // No codegen implementation for this instruction type.
    throw new SparkUnsupportedOperationException(
      errorClass = "_LEGACY_ERROR_TEMP_3073",
      messageParameters = Map("instruction" -> instruction.toString))
}
252
+
253
/**
 * Generates the metric increments for a `Keep` instruction, chosen by its context.
 *
 * `Copy` and `Insert` each increment a single metric; `Update` and `Delete` delegate to
 * the dedicated update/delete generators, which also take `sourcePresent` into account.
 *
 * @param ctx codegen context
 * @param context context carried by the `Keep` instruction
 * @param sourcePresent forwarded to the update/delete generators
 * @return Java statements that increment the metrics
 * @throws IllegalArgumentException for any other context
 */
private def generateMetricUpdateCode(
    ctx: CodegenContext,
    context: Context,
    sourcePresent: Boolean): String = context match {
  case Copy =>
    val copiedTerm = metricTerm(ctx, "numTargetRowsCopied")
    s"$copiedTerm.add(1);"

  case Insert =>
    val insertedTerm = metricTerm(ctx, "numTargetRowsInserted")
    s"$insertedTerm.add(1);"

  case Update => generateUpdateMetricUpdateCode(ctx, sourcePresent)

  case Delete => generateDeleteMetricUpdateCode(ctx, sourcePresent)

  case _ =>
    throw new IllegalArgumentException(s"Unexpected context for KeepExec: $context")
}
277
+
278
/**
 * Generates the metric increments for an updated target row.
 *
 * "numTargetRowsUpdated" is always incremented. In addition, either
 * "numTargetRowsMatchedUpdated" (when `sourcePresent` is true) or
 * "numTargetRowsNotMatchedBySourceUpdated" (otherwise) is incremented.
 *
 * @param ctx codegen context
 * @param sourcePresent selects the second metric
 * @return Java statements that increment both metrics
 */
private def generateUpdateMetricUpdateCode(
    ctx: CodegenContext,
    sourcePresent: Boolean): String = {
  // The total is registered before the specific metric.
  val totalTerm = metricTerm(ctx, "numTargetRowsUpdated")
  val specificName =
    if (sourcePresent) "numTargetRowsMatchedUpdated"
    else "numTargetRowsNotMatchedBySourceUpdated"
  val specificTerm = metricTerm(ctx, specificName)

  s"""
     |$totalTerm.add(1);
     |$specificTerm.add(1);
   """.stripMargin
}
297
+
298
/**
 * Generates the metric increments for a deleted target row.
 *
 * "numTargetRowsDeleted" is always incremented. In addition, either
 * "numTargetRowsMatchedDeleted" (when `sourcePresent` is true) or
 * "numTargetRowsNotMatchedBySourceDeleted" (otherwise) is incremented.
 *
 * @param ctx codegen context
 * @param sourcePresent selects the second metric
 * @return Java statements that increment both metrics
 */
private def generateDeleteMetricUpdateCode(
    ctx: CodegenContext,
    sourcePresent: Boolean): String = {
  // The total is registered before the specific metric.
  val totalTerm = metricTerm(ctx, "numTargetRowsDeleted")
  val specificName =
    if (sourcePresent) "numTargetRowsMatchedDeleted"
    else "numTargetRowsNotMatchedBySourceDeleted"
  val specificTerm = metricTerm(ctx, specificName)

  s"""
     |$totalTerm.add(1);
     |$specificTerm.add(1);
   """.stripMargin
}
317
+
318
/**
 * Runs `block` with `ctx.currentVars` pointing at the given input variables.
 *
 * The previous values of `ctx.currentVars` and `ctx.INPUT_ROW` are captured up front
 * and written back in a `finally` clause, so they are restored even when `block`
 * throws. Only `currentVars` is overridden for the duration of the block; `INPUT_ROW`
 * is left as it was on entry and merely reset afterwards.
 *
 * @param ctx codegen context whose state is temporarily changed
 * @param inputCurrentVars variables to expose as `ctx.currentVars` during `block`
 * @param block code generation to run against the adjusted context
 * @return whatever `block` returns
 */
private def withCodegenContext[T](
    ctx: CodegenContext,
    inputCurrentVars: Seq[ExprCode])(block: => T): T = {
  val savedVars = ctx.currentVars
  val savedRow = ctx.INPUT_ROW
  try {
    ctx.currentVars = inputCurrentVars
    block
  } finally {
    ctx.currentVars = savedVars
    ctx.INPUT_ROW = savedRow
  }
}
341
+
342
/**
 * Generates code that evaluates `predicate` into a fresh Java `boolean` variable.
 *
 * The predicate is bound against `inputAttrs` and generated while `ctx.currentVars`
 * is set to `inputCurrentVars`. The resulting variable is true only when the predicate
 * is non-null and true, so a null result counts as false.
 *
 * @param ctx codegen context
 * @param predicate expression to evaluate
 * @param inputAttrs attributes used to bind the references of `predicate`
 * @param inputCurrentVars variables of the current input row
 * @return an [[ExprCode]] that is never null and whose value is the boolean variable
 */
private def generatePredicateCode(
    ctx: CodegenContext,
    predicate: Expression,
    inputAttrs: Seq[Attribute],
    inputCurrentVars: Seq[ExprCode]): ExprCode = {
  withCodegenContext(ctx, inputCurrentVars) {
    val bound = BindReferences.bindReference(predicate, inputAttrs)
    val evaluated = bound.genCode(ctx)
    val resultVar = ctx.freshName("predicateResult")
    val evalCode =
      code"""
         |${evaluated.code}
         |boolean $resultVar = !${evaluated.isNull} && ${evaluated.value};
       """.stripMargin
    ExprCode(evalCode, FalseLiteral, JavaCode.variable(resultVar, BooleanType))
  }
}
358
+
359
/**
 * Generates code for each output expression of a projection.
 *
 * Every expression is bound against `child.output` and generated while
 * `ctx.currentVars` is set to `inputCurrentVars`.
 *
 * @param ctx codegen context
 * @param outputExprs expressions that make up the projected row
 * @param inputCurrentVars variables of the current input row
 * @return one [[ExprCode]] per output expression, in the same order
 */
private def generateProjectionCode(
    ctx: CodegenContext,
    outputExprs: Seq[Expression],
    inputCurrentVars: Seq[ExprCode]): Seq[ExprCode] = {
  withCodegenContext(ctx, inputCurrentVars) {
    // Bind all expressions first, then generate them.
    val bound = outputExprs.map(expr => BindReferences.bindReference(expr, child.output))
    bound.map(expr => expr.genCode(ctx))
  }
}
367
+
95
368
private def processPartition (rowIterator : Iterator [InternalRow ]): Iterator [InternalRow ] = {
96
369
val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
97
370
val isTargetRowPresentPred = createPredicate(isTargetRowPresent)
0 commit comments