diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index e9a18833ed9a5..5f18e76375210 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -28,6 +28,9 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyedPartitioning, SinglePa
 import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.connector.write.RowLevelOperation.Command.DELETE
+import org.apache.spark.sql.connector.write.RowLevelOperationTable
+import org.apache.spark.sql.execution.metric.{SQLLastAttemptMetrics, SQLMetric, SQLMetrics}
 import org.apache.spark.util.ArrayImplicits._
 
 /**
@@ -44,6 +47,20 @@ case class BatchScanExec(
 
   @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch
 
+  override protected lazy val sparkMetrics: Map[String, SQLMetric] = {
+    val name = "number of output rows"
+    val metric = table match {
+      // Use SLAM for the scan-output count when this scan reads on behalf of a row-level DELETE,
+      // so that the driver-side derivation `numDeletedRows = numScannedRows - numCopiedRows` in
+      // `ReplaceDataExec.getWriteSummary` stays correct under stage retries.
+      case rlot: RowLevelOperationTable if rlot.operation.command() == DELETE =>
+        SQLLastAttemptMetrics.createMetric(sparkContext, name)
+      case _ =>
+        SQLMetrics.createMetric(sparkContext, name)
+    }
+    Map("numOutputRows" -> metric)
+  }
+
   // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
   override def equals(other: Any): Boolean = other match {
     case other: BatchScanExec =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
index 526ff843a1496..67212de165e91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Context, Copy, Del
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLLastAttemptMetrics, SQLMetric}
 import org.apache.spark.sql.types.BooleanType
 
 case class MergeRowsExec(
@@ -50,21 +50,21 @@ case class MergeRowsExec(
     child: SparkPlan) extends UnaryExecNode with CodegenSupport {
 
   override lazy val metrics: Map[String, SQLMetric] = Map(
-    "numTargetRowsCopied" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsCopied" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows copied unmodified because they did not match any action"),
-    "numTargetRowsInserted" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsInserted" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows inserted"),
-    "numTargetRowsDeleted" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsDeleted" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows deleted"),
-    "numTargetRowsUpdated" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsUpdated" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows updated"),
-    "numTargetRowsMatchedUpdated" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsMatchedUpdated" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows updated by a matched clause"),
-    "numTargetRowsMatchedDeleted" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsMatchedDeleted" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows deleted by a matched clause"),
-    "numTargetRowsNotMatchedBySourceUpdated" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsNotMatchedBySourceUpdated" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows updated by a not matched by source clause"),
-    "numTargetRowsNotMatchedBySourceDeleted" -> SQLMetrics.createMetric(sparkContext,
+    "numTargetRowsNotMatchedBySourceDeleted" -> SQLLastAttemptMetrics.createMetric(sparkContext,
       "number of target rows deleted by a not matched by source clause"))
 
   @transient override lazy val producedAttributes: AttributeSet = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index ccfcdc1855f04..2ceee6c1ce6a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.connector.write.RowLevelOperation.Command._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLLastAttemptMetric, SQLLastAttemptMetrics, SQLMetric}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES
 import org.apache.spark.util.{LongAccumulator, Utils}
@@ -480,20 +480,31 @@ trait RowLevelWriteExec extends V2ExistingTableWriteExec {
   }
 
   override protected lazy val sparkMetrics: Map[String, SQLMetric] = rowLevelCommand match {
     case UPDATE => Map(
-      "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"),
-      "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows"))
+      "numUpdatedRows" ->
+        SQLLastAttemptMetrics.createMetric(sparkContext, "number of updated rows"),
+      "numCopiedRows" ->
+        SQLLastAttemptMetrics.createMetric(sparkContext, "number of copied rows"))
     case DELETE => Map(
-      "numDeletedRows" -> SQLMetrics.createMetric(sparkContext, "number of deleted rows"),
-      "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows"))
+      "numDeletedRows" ->
+        SQLLastAttemptMetrics.createMetric(sparkContext, "number of deleted rows"),
+      "numCopiedRows" ->
+        SQLLastAttemptMetrics.createMetric(sparkContext, "number of copied rows"))
     case _ => Map.empty
   }
 
   /**
-   * Returns the value of the named metric, or -1 if the metric is not found.
+   * Returns the value of the named metric, or -1 if the metric is not found. For
+   * [[SQLLastAttemptMetric]] values, prefers the last-attempt value so the result is stable across
+   * stage retries; falls back to the regular accumulator value if the last-attempt value is
+   * unavailable (e.g. the accumulator bailed out).
    */
   protected def getMetricValue(metrics: Map[String, SQLMetric], name: String): Long = {
-    metrics.get(name).map(_.value).getOrElse(-1L)
+    metrics.get(name).map {
+      case slam: SQLLastAttemptMetric =>
+        slam.lastAttemptValueForHighestRDDId().getOrElse(slam.value)
+      case m => m.value
+    }.getOrElse(-1L)
   }
 
   override protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
index f8d81ee086911..b6ebcac12f698 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.connector
 
+import org.apache.spark.internal.config
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.expressions.CheckInvariant
 import org.apache.spark.sql.catalyst.plans.logical.Filter
@@ -24,6 +25,7 @@ import org.apache.spark.sql.connector.catalog.{Aborted, Committed}
 import org.apache.spark.sql.connector.catalog.InMemoryTable
 import org.apache.spark.sql.connector.write.DeleteSummary
 import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec, WriteDeltaExec}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
 
 abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
@@ -425,6 +427,46 @@ abstract class DeleteFromTableSuiteBase extends RowLevelOperationSuiteBase {
     }
   }
 
+  test("metric values are stable across stage retries") {
+    // Force a shuffle in the DELETE plan via an IN-subquery (with broadcast disabled), then
+    // have the DAGScheduler corrupt the first attempt of every upstream shuffle map stage so
+    // the writer stage has to retry. With plain SQLMetrics the writer-side numCopiedRows /
+    // numDeletedRows and the scan-side numOutputRows would all double up across attempts;
+    // SQLLastAttemptMetric reports only the last attempt, so the values surfaced via
+    // `DeleteSummary` (including the group-based driver derivation
+    // numDeletedRows = numScannedRows - numCopiedRows in `ReplaceDataExec.getWriteSummary`)
+    // remain correct.
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |{ "pk": 3, "salary": 300, "dep": "hr" }
+            |{ "pk": 4, "salary": 400, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq(1, 2).toDF("pk")
+        sourceDF.createOrReplaceTempView("source")
+
+        withSparkContextConf(
+          config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") {
+          sql(
+            s"""DELETE FROM $tableNameAsString
+               |WHERE pk IN (SELECT pk FROM source)
+               |""".stripMargin)
+        }
+
+        checkDeleteMetrics(numDeletedRows = 2, numCopiedRows = 2)
+
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(3, 300, "hr"),
+            Row(4, 400, "software")))
+      }
+    }
+  }
+
   test("delete with NOT IN subqueries") {
     withTempView("deleted_id", "deleted_dep") {
       createAndInitTable("pk INT NOT NULL, id INT, dep STRING",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index aaf45f0f5f7a5..5abd05d389003 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector
 
 import org.apache.spark.SparkRuntimeException
+import org.apache.spark.internal.config
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, In, Not}
 import org.apache.spark.sql.catalyst.optimizer.BuildLeft
@@ -2663,6 +2664,57 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
     }
   }
 
+  test("metric values are stable across stage retries") {
+    // The join in the MERGE plan introduces a shuffle (with broadcast disabled). The
+    // DAGScheduler corrupts the first attempt of every upstream shuffle map stage, forcing
+    // the MergeRowsExec stage to retry. With plain SQLMetrics the row counters would double
+    // up across attempts, but SQLLastAttemptMetric reports only the last attempt, so the
+    // values surfaced via `MergeSummary` remain correct.
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |{ "pk": 3, "salary": 300, "dep": "hr" }
+            |""".stripMargin)
+
+        val sourceDF = Seq(1, 2, 10).toDF("pk")
+        sourceDF.createOrReplaceTempView("source")
+
+        withSparkContextConf(
+          config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") {
+          sql(
+            s"""MERGE INTO $tableNameAsString t
+               |USING source s
+               |ON t.pk = s.pk
+               |WHEN MATCHED THEN
+               | UPDATE SET salary = salary + 100
+               |WHEN NOT MATCHED THEN
+               | INSERT (pk, salary, dep) VALUES (s.pk, 999, 'unknown')
+               |""".stripMargin)
+        }
+
+        val mergeSummary = getMergeSummary()
+        assert(mergeSummary.numTargetRowsUpdated === 2L)
+        assert(mergeSummary.numTargetRowsMatchedUpdated === 2L)
+        assert(mergeSummary.numTargetRowsInserted === 1L)
+        assert(mergeSummary.numTargetRowsCopied === (if (deltaMerge) 0L else 1L))
+        assert(mergeSummary.numTargetRowsDeleted === 0L)
+        assert(mergeSummary.numTargetRowsMatchedDeleted === 0L)
+        assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L)
+        assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L)
+
+        checkAnswer(
+          sql(s"SELECT pk, salary FROM $tableNameAsString ORDER BY pk"),
+          Seq(
+            Row(1, 200),
+            Row(2, 300),
+            Row(3, 300),
+            Row(10, 999)))
+      }
+    }
+  }
+
   test("SPARK-55074: merge with type coercion from INT to STRING") {
     // INT -> STRING is allowed in ANSI mode, merge should succeed via type coercion
     // without requiring schema evolution
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
index c5264ca87a70d..dc5cc3774eecb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector
 
 import org.apache.spark.SparkRuntimeException
+import org.apache.spark.internal.config
 import org.apache.spark.sql.{sources, AnalysisException, Row}
 import org.apache.spark.sql.connector.catalog.{Aborted, Column, ColumnDefaultValue, Committed, InMemoryTable, TableChange, TableInfo}
 import org.apache.spark.sql.connector.expressions.{GeneralScalarExpression, LiteralValue}
@@ -340,6 +341,46 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
     checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 1)
   }
 
+  test("metric values are stable across stage retries") {
+    // Force a shuffle in the UPDATE plan via an IN-subquery (with broadcast disabled), then
+    // have the DAGScheduler corrupt the first attempt of every upstream shuffle map stage so
+    // the writer stage has to retry. With a plain SQLMetric the row counters would double up
+    // across attempts, but SQLLastAttemptMetric reports only the last attempt, so the values
+    // surfaced via `UpdateSummary` remain correct.
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("source") {
+        createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+          """{ "pk": 1, "salary": 100, "dep": "hr" }
+            |{ "pk": 2, "salary": 200, "dep": "software" }
+            |{ "pk": 3, "salary": 300, "dep": "hr" }
+            |{ "pk": 4, "salary": 400, "dep": "software" }
+            |""".stripMargin)
+
+        val sourceDF = Seq(1, 2).toDF("pk")
+        sourceDF.createOrReplaceTempView("source")
+
+        withSparkContextConf(
+          config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") {
+          sql(
+            s"""UPDATE $tableNameAsString
+               |SET salary = salary + 100
+               |WHERE pk IN (SELECT pk FROM source)
+               |""".stripMargin)
+        }
+
+        checkUpdateMetrics(numUpdatedRows = 2, numCopiedRows = 2)
+
+        checkAnswer(
+          sql(s"SELECT * FROM $tableNameAsString"),
+          Seq(
+            Row(1, 200, "hr"),
+            Row(2, 300, "software"),
+            Row(3, 300, "hr"),
+            Row(4, 400, "software")))
+      }
+    }
+  }
+
   test("update nested struct fields") {
     createAndInitTable(
       s"""pk INT NOT NULL,