diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 993ffd888e0e7..eacea816db137 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -498,6 +498,13 @@
     },
     "sqlState" : "0A000"
   },
+  "CANNOT_UPDATE_PARTITION_COLUMNS" : {
+    "message" : [
+      "Declared partitioning <requestedPartitionColumns> conflicts with existing table partitioning <existingPartitionColumns>.",
+      "Please delete the table or change the declared partitioning to match its partitions."
+    ],
+    "sqlState" : "42000"
+  },
   "CANNOT_UP_CAST_DATATYPE" : {
     "message" : [
       "Cannot up cast <expression> from <sourceType> to <targetType>.",
diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
index 1f997592dbfb7..b0fae3fd9443b 100644
--- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
+++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala
@@ -291,6 +291,7 @@
   case object FINAL_PATH extends LogKey
   case object FINISH_TIME extends LogKey
   case object FINISH_TRIGGER_DURATION extends LogKey
+  case object FLOW_NAME extends LogKey
   case object FREE_MEMORY_SIZE extends LogKey
   case object FROM_OFFSET extends LogKey
   case object FROM_TIME extends LogKey
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 478b92de0b8e4..8246fff00f7e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -5885,6 +5885,70 @@
     .booleanConf
     .createWithDefault(true)
 
+  val PIPELINES_STREAM_STATE_POLLING_INTERVAL = {
+    buildConf("spark.sql.pipelines.execution.streamstate.pollingInterval")
+      .doc(
+        "Interval in seconds at which the stream state is polled for changes. This is used to " +
+        "check if the stream has failed and needs to be restarted."
+      )
+      .version("4.1.0")
+      .timeConf(TimeUnit.SECONDS)
+      .createWithDefault(1)
+  }
+
+  val PIPELINES_WATCHDOG_MIN_RETRY_TIME_IN_SECONDS = {
+    buildConf("spark.sql.pipelines.execution.watchdog.minRetryTime")
+      .doc(
+        "Initial duration in seconds between the time when we notice a flow has failed and " +
+        "when we try to restart the flow. The interval between flow restarts doubles with " +
+        "every stream failure up to the maximum value set in " +
+        "`pipelines.execution.watchdog.maxRetryTime`."
+      )
+      .version("4.1.0")
+      .timeConf(TimeUnit.SECONDS)
+      .checkValue(v => v > 0, "Watchdog minimum retry time must be at least 1 second.")
+      .createWithDefault(5)
+  }
+
+  val PIPELINES_WATCHDOG_MAX_RETRY_TIME_IN_SECONDS = {
+    buildConf("spark.sql.pipelines.execution.watchdog.maxRetryTime")
+      .doc(
+        "Maximum time interval in seconds at which flows will be restarted."
+      )
+      .version("4.1.0")
+      .timeConf(TimeUnit.SECONDS)
+      .createWithDefault(3600)
+  }
+
+  val PIPELINES_MAX_CONCURRENT_FLOWS = {
+    buildConf("spark.sql.pipelines.execution.maxConcurrentFlows")
+      .doc(
+        "Max number of flows to execute at once. Used to tune performance for triggered " +
+        "pipelines. Has no effect on continuous pipelines."
+ ) + .version("4.1.0") + .intConf + .createWithDefault(16) + } + + + val PIPELINES_TIMEOUT_MS_FOR_TERMINATION_JOIN_AND_LOCK = { + buildConf("spark.sql.pipelines.timeoutMsForTerminationJoinAndLock") + .doc("Timeout in milliseconds to grab a lock for stopping update - default is 1hr.") + .version("4.1.0") + .timeConf(TimeUnit.MILLISECONDS) + .checkValue(v => v > 0L, "Timeout for lock must be at least 1 millisecond.") + .createWithDefault(60 * 60 * 1000) + } + + val PIPELINES_MAX_FLOW_RETRY_ATTEMPTS = { + buildConf("spark.sql.pipelines.maxFlowRetryAttempts") + .doc("Maximum number of times a flow can be retried") + .version("4.1.0") + .intConf + .createWithDefault(2) + } + /** * Holds information about keys that have been deprecated. * diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/common/GraphStates.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/common/GraphStates.scala index 7ec4c9147e5eb..6a4722ba8b3c8 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/common/GraphStates.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/common/GraphStates.scala @@ -43,6 +43,22 @@ object FlowStatus { case object IDLE extends FlowStatus } +sealed trait RunState + +object RunState { + // Run is currently executing queries. + case object RUNNING extends RunState + + // Run is complete and all necessary resources are cleaned up. + case object COMPLETED extends RunState + + // Run has run into an error that could not be recovered from. + case object FAILED extends RunState + + // Run was canceled. + case object CANCELED extends RunState +} + // The type of the dataset. sealed trait DatasetType object DatasetType { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index 585ba6295f239..0263a2fef4f44 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -199,7 +199,7 @@ case class DataflowGraph(flows: Seq[Flow], tables: Seq[Table], views: Seq[View]) * streaming tables without a query; such tables should still have at least one flow * writing to it. */ - def validateEveryDatasetHasFlow(): Unit = { + private def validateEveryDatasetHasFlow(): Unit = { (tables.map(_.identifier) ++ views.map(_.identifier)).foreach { identifier => if (!flows.exists(_.destinationIdentifier == identifier)) { throw new AnalysisException( diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala new file mode 100644 index 0000000000000..79c5ef36b0bc4 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DatasetManager.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.jdk.CollectionConverters._ +import scala.util.control.{NonFatal, NoStackTrace} + +import org.apache.spark.SparkException +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.connector.catalog.{ + CatalogV2Util, + Identifier, + TableCatalog, + TableChange, + TableInfo +} +import org.apache.spark.sql.connector.expressions.Expressions +import org.apache.spark.sql.pipelines.graph.QueryOrigin.ExceptionHelpers +import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils.diffSchemas +import org.apache.spark.sql.pipelines.util.SchemaMergingUtils + +/** + * `DatasetManager` is responsible for materializing tables in the catalog based on the given + * graph. For each table in the graph, it will create a table if none exists (or if this is a + * full refresh), or merge the schema of an existing table to match the new flows writing to it. + */ +object DatasetManager extends Logging { + + /** + * Wraps table materialization exceptions. + * + * The target use case of this exception is merely as a means to capture attribution - + * 1. Indicate that the exception is associated with table materialization. + * 2. Indicate which table materialization failed for. + * + * @param tableName The name of the table that failed to materialize. + * @param cause The underlying exception that caused the materialization to fail. + */ + case class TableMaterializationException( + tableName: String, + cause: Throwable + ) extends Exception(cause) + with NoStackTrace + + /** + * Materializes the tables in the given graph. This method will create or update the tables + * in the catalog based on the given graph and context. + * + * @param resolvedDataflowGraph The resolved [[DataflowGraph]] with resolved [[Flow]] sorted + * in topological order. + * @param context The context for the pipeline update. + * @return The graph with materialized tables. + */ + def materializeDatasets( + resolvedDataflowGraph: DataflowGraph, + context: PipelineUpdateContext + ): DataflowGraph = { + val (_, refreshTableIdentsSet, fullRefreshTableIdentsSet) = { + DatasetManager.constructFullRefreshSet(resolvedDataflowGraph.tables, context) + } + + /** Return all the tables that need to be materialized from the given graph. 
+     */
+    def collectTablesToMaterialize(graph: DataflowGraph): Seq[TableRefreshType] = {
+      graph.tables
+        .filter(t => fullRefreshTableIdentsSet.contains(t.identifier))
+        .map(table => TableRefreshType(table, isFullRefresh = true)) ++
+        graph.tables
+          .filter(t => refreshTableIdentsSet.contains(t.identifier))
+          .map(table => TableRefreshType(table, isFullRefresh = false))
+    }
+
+    val tablesToMaterialize = {
+      collectTablesToMaterialize(resolvedDataflowGraph).map(t => t.table.identifier -> t).toMap
+    }
+
+    // materialized [[DataflowGraph]] where each table has been materialized and each table
+    // has metadata (e.g., normalized table storage path) populated
+    val materializedGraph: DataflowGraph = try {
+      DataflowGraphTransformer
+        .withDataflowGraphTransformer(resolvedDataflowGraph) { transformer =>
+          transformer.transformTables { table =>
+            if (tablesToMaterialize.keySet.contains(table.identifier)) {
+              try {
+                materializeTable(
+                  resolvedDataflowGraph = resolvedDataflowGraph,
+                  table = table,
+                  isFullRefresh = tablesToMaterialize(table.identifier).isFullRefresh,
+                  context = context
+                )
+              } catch {
+                case NonFatal(e) =>
+                  throw TableMaterializationException(
+                    table.displayName,
+                    cause = e.addOrigin(table.origin)
+                  )
+              }
+            } else {
+              table
+            }
+          }
+          // TODO: Publish persisted views to the metastore.
+        }
+        .getDataflowGraph
+    } catch {
+      case e: SparkException if e.getCause != null => throw e.getCause
+    }
+
+    materializedGraph
+  }
+
+  /**
+   * Materializes a table in the catalog. This method will create or update the table in the
+   * catalog based on the given table and context.
+   * @param resolvedDataflowGraph The resolved [[DataflowGraph]] used to infer the table schema.
+   * @param table The table to be materialized.
+   * @param isFullRefresh Whether this table should be fully refreshed or not.
+   * @param context The context for the pipeline update.
+   * @return The materialized table (with additional metadata set).
+ */ + private def materializeTable( + resolvedDataflowGraph: DataflowGraph, + table: Table, + isFullRefresh: Boolean, + context: PipelineUpdateContext + ): Table = { + logInfo(log"Materializing metadata for table ${MDC(LogKeys.TABLE_NAME, table.identifier)}.") + val catalogManager = context.spark.sessionState.catalogManager + val catalog = (table.identifier.catalog match { + case Some(catalogName) => + catalogManager.catalog(catalogName) + case None => + catalogManager.currentCatalog + }).asInstanceOf[TableCatalog] + + val identifier = + Identifier.of(Array(table.identifier.database.get), table.identifier.identifier) + val outputSchema = table.specifiedSchema.getOrElse( + resolvedDataflowGraph.inferredSchema(table.identifier).asNullable + ) + val mergedProperties = resolveTableProperties(table, identifier) + val partitioning = table.partitionCols.toSeq.flatten.map(Expressions.identity) + + val existingTableOpt = if (catalog.tableExists(identifier)) { + Some(catalog.loadTable(identifier)) + } else { + None + } + + // Error if partitioning doesn't match + if (existingTableOpt.isDefined) { + val existingPartitioning = existingTableOpt.get.partitioning().toSeq + if (existingPartitioning != partitioning) { + throw new AnalysisException( + errorClass = "CANNOT_UPDATE_PARTITION_COLUMNS", + messageParameters = Map( + "existingPartitionColumns" -> existingPartitioning.mkString(", "), + "requestedPartitionColumns" -> partitioning.mkString(", ") + ) + ) + } + } + + // Wipe the data if we need to + if ((isFullRefresh || !table.isStreamingTableOpt.get) && existingTableOpt.isDefined) { + context.spark.sql(s"TRUNCATE TABLE ${table.identifier.quotedString}") + } + + // Alter the table if we need to + if (existingTableOpt.isDefined) { + val existingSchema = existingTableOpt.get.schema() + + val targetSchema = if (table.isStreamingTableOpt.get && !isFullRefresh) { + SchemaMergingUtils.mergeSchemas(existingSchema, outputSchema) + } else { + outputSchema + } + + val columnChanges = diffSchemas(existingSchema, targetSchema) + val setProperties = mergedProperties.map { case (k, v) => TableChange.setProperty(k, v) } + catalog.alterTable(identifier, (columnChanges ++ setProperties).toArray: _*) + } + + // Create the table if we need to + if (existingTableOpt.isEmpty) { + catalog.createTable( + identifier, + new TableInfo.Builder() + .withProperties(mergedProperties.asJava) + .withColumns(CatalogV2Util.structTypeToV2Columns(outputSchema)) + .withPartitions(partitioning.toArray) + .build() + ) + } + + table.copy( + normalizedPath = Option( + catalog.loadTable(identifier).properties().get(TableCatalog.PROP_LOCATION) + ) + ) + } + + /** + * Some fields on the [[Table]] object are represented as reserved table properties by the catalog + * APIs. This method creates a table properties map that merges the user-provided table properties + * with these reserved properties. 
+ */ + private def resolveTableProperties(table: Table, identifier: Identifier): Map[String, String] = { + val validatedAndCanonicalizedProps = + PipelinesTableProperties.validateAndCanonicalize( + table.properties, + warnFunction = s => logWarning(s) + ) + + val specialProps = Seq( + (table.comment, "comment", TableCatalog.PROP_COMMENT), + (table.format, "format", TableCatalog.PROP_PROVIDER) + ).map { + case (value, name, reservedPropKey) => + validatedAndCanonicalizedProps.get(reservedPropKey).foreach { pc => + if (value.isDefined && value.get != pc) { + throw new IllegalArgumentException( + s"For dataset $identifier, $name '${value.get}' does not match value '$pc' for " + + s"reserved table property '$reservedPropKey''" + ) + } + } + reservedPropKey -> value + } + .collect { case (key, Some(value)) => key -> value } + + validatedAndCanonicalizedProps ++ specialProps + } + + /** + * A case class that represents the type of refresh for a table. + * @param table The table to be refreshed. + * @param isFullRefresh Whether this table should be fully refreshed or not. + */ + private case class TableRefreshType(table: Table, isFullRefresh: Boolean) + + /** + * Constructs the set of tables that should be fully refreshed and the set of tables that + * should be refreshed. + */ + private def constructFullRefreshSet( + graphTables: Seq[Table], + context: PipelineUpdateContext + ): (Seq[Table], Seq[TableIdentifier], Seq[TableIdentifier]) = { + val (fullRefreshTablesSet, refreshTablesSet) = { + val specifiedFullRefreshTables = context.fullRefreshTables.filter(graphTables) + val specifiedRefreshTables = context.refreshTables.filter(graphTables) + + val (fullRefreshAllowed, fullRefreshNotAllowed) = specifiedFullRefreshTables.partition { t => + PipelinesTableProperties.resetAllowed.fromMap(t.properties) + } + + val refreshTables = (specifiedRefreshTables ++ fullRefreshNotAllowed).filterNot { t => + fullRefreshAllowed.contains(t) + } + + if (fullRefreshNotAllowed.nonEmpty) { + logInfo( + log"Skipping full refresh on some tables because " + + log"${MDC(LogKeys.PROPERTY_NAME, PipelinesTableProperties.resetAllowed.key)} " + + log"was set to false. Tables: " + + log"${MDC(LogKeys.TABLE_NAME, fullRefreshNotAllowed.map(_.identifier))}" + ) + } + + (fullRefreshAllowed, refreshTables) + } + val allRefreshTables = fullRefreshTablesSet ++ refreshTablesSet + val refreshTableIdentsSet = refreshTablesSet.map(_.identifier) + val fullRefreshTableIdentsSet = fullRefreshTablesSet.map(_.identifier) + (allRefreshTables, refreshTableIdentsSet, fullRefreshTableIdentsSet) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala index fb96c6cb5bb1d..1139946df59ac 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysisContext.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.pipelines.AnalysisWarning /** - * A context used when evaluating a [[Flow]]'s query into a concrete DataFrame. + * A context used when evaluating a `Flow`'s query into a concrete DataFrame. * - * @param allInputs Set of identifiers for all [[Input]]s defined in the DataflowGraph. + * @param allInputs Set of identifiers for all `Input`s defined in the DataflowGraph. 
* @param availableInputs Inputs available to be referenced with `read` or `readStream`. * @param queryContext The context of the query being evaluated. * @param requestedInputs A mutable buffer populated with names of all inputs that were @@ -49,7 +49,7 @@ private[pipelines] case class FlowAnalysisContext( externalInputs: mutable.HashSet[TableIdentifier] = mutable.HashSet.empty ) { - /** Map from [[Input]] name to the actual [[Input]] */ + /** Map from `Input` name to the actual `Input` */ val availableInput: Map[TableIdentifier, Input] = availableInputs.map(i => i.identifier -> i).toMap diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala new file mode 100644 index 0000000000000..5c981a2442edd --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowExecution.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.atomic.AtomicBoolean + +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} +import scala.util.control.NonFatal + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.pipelines.graph.QueryOrigin.ExceptionHelpers +import org.apache.spark.sql.pipelines.util.SparkSessionUtils +import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, Trigger} +import org.apache.spark.util.ThreadUtils + +/** + * A flow's execution may complete for two reasons: + * 1. it may finish performing all of its necessary work, or + * 2. it may be interrupted by a request from a user to stop it. + * + * We use this result to disambiguate these two cases, using 'ExecutionResult.FINISHED' + * for the former and 'ExecutionResult.STOPPED' for the latter. + */ +sealed trait ExecutionResult +object ExecutionResult { + case object FINISHED extends ExecutionResult + case object STOPPED extends ExecutionResult +} + +/** A `FlowExecution` specifies how to execute a flow and manages its execution. */ +trait FlowExecution { + + /** Identifier of this physical flow */ + def identifier: TableIdentifier + + /** + * Returns a user-visible name for the flow. + */ + final def displayName: String = identifier.unquotedString + + /** + * SparkSession to execute this physical flow with. 
+ * + * The default value for streaming flows is the pipeline's spark session because the source + * dataframe is resolved using the pipeline's spark session, and a new session will be started + * implicitly by the streaming query. + * + * The default value for batch flows is a cloned spark session from the pipeline's spark session. + * + * Please make sure that the execution thread runs in a different spark session than the + * pipeline's spark session. + */ + protected def spark: SparkSession = updateContext.spark + + /** + * Origin to use when recording events for this flow. + */ + def getOrigin: QueryOrigin + + /** + * Returns true if and only if this `FlowExecution` has been completed with + * either success or an exception. + */ + def isCompleted: Boolean = _future.exists(_.isCompleted) + + /** Returns true iff this `FlowExecution` executes using Spark Structured Streaming. */ + def isStreaming: Boolean + + /** Retrieves the future that can be used to track execution status. */ + def getFuture: Future[ExecutionResult] = { + _future.getOrElse( + throw new IllegalStateException(s"FlowExecution $identifier has not been executed.") + ) + } + + /** Tracks the currently running future. */ + private final var _future: Option[Future[ExecutionResult]] = None + + /** Context about this pipeline update. */ + def updateContext: PipelineUpdateContext + + /** The thread execution context for the current `FlowExecution`. */ + implicit val executionContext: ExecutionContext = { + ExecutionContext.fromExecutor(FlowExecution.threadPool) + } + + /** + * Stops execution of this `FlowExecution`. If you override this, please be sure to + * call `super.stop()` at the beginning of your method, so we can properly handle errors + * when a user tries to stop a flow. + */ + def stop(): Unit = { + stopped.set(true) + } + + /** Returns an optional exception that occurred during execution, if any. */ + def exception: Option[Throwable] = _future.flatMap(_.value).flatMap(_.failed.toOption) + + /** + * Executes this FlowExecution synchronously to perform its intended update. + * This method should be overridden by subclasses to provide the actual execution logic. + * + * @return a Future that completes when the execution is finished or stopped. + */ + def executeInternal(): Future[Unit] + + /** + * Executes this FlowExecution asynchronously to perform its intended update. A future that can be + * used to track execution status is saved, and can be retrieved with `getFuture`. + */ + final def executeAsync(): Unit = { + if (_future.isDefined) { + throw new IllegalStateException( + s"FlowExecution ${identifier.unquotedString} has already been executed." + ) + } + + val queryOrigin = QueryOrigin(filePath = getOrigin.filePath) + + _future = try { + Option( + executeInternal() + .transform { + case Success(_) => Success(ExecutionResult.FINISHED) + case Failure(e) => Failure(e) + } + .map(_ => ExecutionResult.FINISHED) + .recover { + case _: Throwable if stopped.get() => + ExecutionResult.STOPPED + } + ) + } catch { + case NonFatal(e) => + // Add query origin to exceptions raised while starting a flow + throw e.addOrigin(queryOrigin) + } + } + + /** The destination that this `FlowExecution` is writing to. */ + def destination: Output + + /** Whether this `FlowExecution` has been stopped. Set by `FlowExecution.stop()`. */ + private val stopped: AtomicBoolean = new AtomicBoolean(false) +} + +object FlowExecution { + + /** A thread pool used to execute `FlowExecutions`. 
*/ + private val threadPool: ThreadPoolExecutor = { + ThreadUtils.newDaemonCachedThreadPool("FlowExecution") + } +} + +/** A 'FlowExecution' that processes data statefully using Structured Streaming. */ +trait StreamingFlowExecution extends FlowExecution with Logging { + + /** The `ResolvedFlow` that this `StreamingFlowExecution` is executing. */ + def flow: ResolvedFlow + + /** Structured Streaming checkpoint. */ + def checkpointPath: String + + /** Structured Streaming trigger. */ + def trigger: Trigger + + def isStreaming: Boolean = true + + /** Spark confs that must be set when starting this flow. */ + protected def sqlConf: Map[String, String] + + /** Starts a stream and returns its streaming query. */ + protected def startStream(): StreamingQuery + + /** + * Executes this `StreamingFlowExecution` by starting its stream with the correct scheduling pool + * and confs. + */ + override final def executeInternal(): Future[Unit] = { + logInfo( + log"Starting ${MDC(LogKeys.TABLE_NAME, identifier)} with " + + log"checkpoint location ${MDC(LogKeys.CHECKPOINT_PATH, checkpointPath)}" + ) + val streamingQuery = SparkSessionUtils.withSqlConf(spark, sqlConf.toList: _*)(startStream()) + Future(streamingQuery.awaitTermination()) + } +} + +/** A `StreamingFlowExecution` that writes a streaming `DataFrame` to a `Table`. */ +class StreamingTableWrite( + val identifier: TableIdentifier, + val flow: ResolvedFlow, + val graph: DataflowGraph, + val updateContext: PipelineUpdateContext, + val checkpointPath: String, + val trigger: Trigger, + val destination: Table, + val sqlConf: Map[String, String] +) extends StreamingFlowExecution { + + override def getOrigin: QueryOrigin = flow.origin + + def startStream(): StreamingQuery = { + val data = graph.reanalyzeFlow(flow).df + val dataStreamWriter = data.writeStream + .queryName(displayName) + .option("checkpointLocation", checkpointPath) + .trigger(trigger) + .outputMode(OutputMode.Append()) + if (destination.format.isDefined) { + dataStreamWriter.format(destination.format.get) + } + dataStreamWriter.toTable(destination.identifier.unquotedString) + } +} + +/** A `FlowExecution` that writes a batch `DataFrame` to a `Table`. */ +class BatchTableWrite( + val identifier: TableIdentifier, + val flow: ResolvedFlow, + val graph: DataflowGraph, + val destination: Table, + val updateContext: PipelineUpdateContext, + val sqlConf: Map[String, String] +) extends FlowExecution { + + override def isStreaming: Boolean = false + override def getOrigin: QueryOrigin = flow.origin + + def executeInternal(): scala.concurrent.Future[Unit] = + SparkSessionUtils.withSqlConf(spark, sqlConf.toList: _*) { + updateContext.flowProgressEventLogger.recordRunning(flow = flow) + val data = graph.reanalyzeFlow(flow).df + Future { + val dataFrameWriter = data.write + if (destination.format.isDefined) { + dataFrameWriter.format(destination.format.get) + } + dataFrameWriter + .mode("append") + .saveAsTable(destination.identifier.unquotedString) + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala new file mode 100644 index 0000000000000..bb154a0081da5 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowPlanner.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.streaming.Trigger + +/** + * Plans execution of `Flow`s in a `DataflowGraph` by converting `Flow`s into + * 'FlowExecution's. + * + * @param graph `DataflowGraph` to help plan based on relationship to other elements. + * @param updateContext `PipelineUpdateContext` for this pipeline update (shared across flows). + * @param triggerFor Function that returns the correct streaming Trigger for the specified + * `Flow`. + */ +class FlowPlanner( + graph: DataflowGraph, + updateContext: PipelineUpdateContext, + triggerFor: Flow => Trigger +) { + + /** + * Turns a [[Flow]] into an executable [[FlowExecution]]. + */ + def plan(flow: ResolvedFlow): FlowExecution = { + val output = graph.output(flow.destinationIdentifier) + flow match { + case cf: CompleteFlow => + new BatchTableWrite( + graph = graph, + flow = flow, + identifier = cf.identifier, + sqlConf = cf.sqlConf, + destination = output.asInstanceOf[Table], + updateContext = updateContext + ) + case sf: StreamingFlow => + output match { + case o: Table => + new StreamingTableWrite( + graph = graph, + flow = flow, + identifier = sf.identifier, + destination = o, + updateContext = updateContext, + sqlConf = sf.sqlConf, + trigger = triggerFor(sf), + checkpointPath = output.path + ) + case _ => + throw new UnsupportedOperationException( + s"Streaming flow ${sf.identifier} cannot write to non-table destination: " + + s"${output.getClass.getSimpleName} (${flow.destinationIdentifier})" + ) + } + case _ => + throw new UnsupportedOperationException( + s"Unable to plan flow of type ${flow.getClass.getSimpleName}" + ) + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphExecution.scala new file mode 100644 index 0000000000000..fdb94c82d868c --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphExecution.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.pipelines.graph +import java.util.concurrent.{ConcurrentHashMap, TimeoutException} + +import scala.concurrent.ExecutionContext +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success} + +import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.pipelines.logging.StreamListener +import org.apache.spark.sql.streaming.Trigger + +abstract class GraphExecution( + val graphForExecution: DataflowGraph, + env: PipelineUpdateContext +) extends Logging { + + /** The `Trigger` configuration for a streaming flow. */ + def streamTrigger(flow: Flow): Trigger + + protected val pipelineConf: PipelineConf = env.pipelineConf + + /** Maps flow identifier to count of consecutive failures. Used to manage flow retries */ + private val flowToNumConsecutiveFailure = new ConcurrentHashMap[TableIdentifier, Int].asScala + + /** Maps flow identifier to count of successful runs. Used to populate batch id. */ + private val flowToNumSuccess = new ConcurrentHashMap[TableIdentifier, Long].asScala + + /** + * `FlowExecution`s currently being executed and tracked by the graph execution. + */ + val flowExecutions = new collection.concurrent.TrieMap[TableIdentifier, FlowExecution] + + /** Increments flow execution retry count for `flow`. */ + private def incrementFlowToNumConsecutiveFailure(flowIdentifier: TableIdentifier): Unit = { + flowToNumConsecutiveFailure.put(flowIdentifier, flowToNumConsecutiveFailure(flowIdentifier) + 1) + } + + /** + * Planner use to convert each logical dataflow (i.e., `Flow`) defined in the + * `DataflowGraph` into a concrete execution plan `FlowExecution` used by the + * pipeline execution. + */ + private val flowPlanner = new FlowPlanner( + graph = graphForExecution, + updateContext = env, + triggerFor = streamTrigger + ) + + /** Listener to process streaming events and metrics. */ + private val streamListener = new StreamListener(env, graphForExecution) + + /** + * Plans the logical `ResolvedFlow` into a `FlowExecution` and then starts executing it. + * Implementation note: Thread safe + * + * @return None if the flow planner decided that there is no actual update required here. + * Otherwise returns the corresponding physical flow. + */ + def planAndStartFlow(flow: ResolvedFlow): Option[FlowExecution] = { + try { + val flowExecution = flowPlanner.plan( + flow = graphForExecution.resolvedFlow(flow.identifier) + ) + + env.flowProgressEventLogger.recordStart(flowExecution) + + flowExecution.executeAsync() + flowExecutions.put(flow.identifier, flowExecution) + implicit val ec: ExecutionContext = flowExecution.executionContext + + // Note: The asynchronous handling here means that completed events might be recorded after + // initializing events for the next retry of this flow. + flowExecution.getFuture.onComplete { + case Failure(ex) if !flowExecution.isStreaming => + incrementFlowToNumConsecutiveFailure(flow.identifier) + env.flowProgressEventLogger.recordFailed( + flow = flow, + exception = ex, + // Log as warn if flow has retries left + logAsWarn = { + flowToNumConsecutiveFailure(flow.identifier) < + 1 + maxRetryAttemptsForFlow(flow.identifier) + } + ) + case Success(ExecutionResult.STOPPED) => + // We already recorded a STOPPED event in [[FlowExecution.stopFlow()]]. + // We don't need to log another one here. 
+        case Success(ExecutionResult.FINISHED) if !flowExecution.isStreaming =>
+          // Reset consecutive failure count on success
+          flowToNumConsecutiveFailure.put(flow.identifier, 0)
+          flowToNumSuccess.put(
+            flow.identifier,
+            flowToNumSuccess.getOrElse(flow.identifier, 0L) + 1L
+          )
+          env.flowProgressEventLogger.recordCompletion(flow)
+        case _ => // Handled by StreamListener
+      }
+      Option(flowExecution)
+    } catch {
+      // This is if the flow fails to even start.
+      case ex: Throwable =>
+        logError(
+          log"Unhandled exception while starting flow: ${MDC(LogKeys.FLOW_NAME, flow.displayName)}",
+          ex
+        )
+        // InterruptedException is thrown when the thread executing `startFlow` is interrupted.
+        if (ex.isInstanceOf[InterruptedException]) {
+          env.flowProgressEventLogger.recordStop(flow)
+        } else {
+          env.flowProgressEventLogger.recordFailed(
+            flow = flow,
+            exception = ex,
+            logAsWarn = false
+          )
+        }
+        throw ex
+    }
+  }
+
+  /**
+   * Starts the execution of flows in `graphForExecution`. Does not block.
+   */
+  def start(): Unit = {
+    env.spark.listenerManager.clear()
+    env.spark.streams.addListener(streamListener)
+  }
+
+  /**
+   * Stops this execution by stopping all streams and terminating any other resources.
+   *
+   * This method may be called multiple times due to race conditions and must be idempotent.
+   */
+  def stop(): Unit = {
+    env.spark.streams.removeListener(streamListener)
+  }
+
+  /** Stops execution of a `FlowExecution`. */
+  def stopFlow(pf: FlowExecution): Unit = {
+    if (!pf.isCompleted) {
+      val flow = graphForExecution.resolvedFlow(pf.identifier)
+      try {
+        logInfo(log"Stopping ${MDC(LogKeys.FLOW_NAME, pf.identifier)}")
+        pf.stop()
+      } catch {
+        case e: Throwable =>
+          val message = s"Error stopping flow ${pf.identifier}"
+          logError(message, e)
+          env.flowProgressEventLogger.recordFailed(
+            flow = flow,
+            exception = e,
+            logAsWarn = false,
+            messageOpt = Option(s"Flow '${pf.displayName}' has failed to stop.")
+          )
+          throw e
+      }
+      env.flowProgressEventLogger.recordStop(flow)
+      logInfo(log"Stopped ${MDC(LogKeys.FLOW_NAME, pf.identifier)}")
+    } else {
+      logWarning(
+        log"Flow ${MDC(LogKeys.FLOW_NAME, pf.identifier)} was not stopped because it " +
+        log"was already completed. Exception: ${MDC(LogKeys.EXCEPTION, pf.exception)}"
+      )
+    }
+  }
+
+  /**
+   * Blocks the current thread while any flows are queued or running. Returns when all flows that
+   * could be run have completed. When this returns, all flows are either SUCCESSFUL,
+   * TERMINATED_WITH_ERROR, SKIPPED, CANCELED, or EXCLUDED.
+   */
+  def awaitCompletion(): Unit
+
+  /**
+   * Returns the reason why this flow execution has terminated.
+   * If this is called before the flow has terminated, the behavior is undefined,
+   * and may return `UnexpectedRunFailure`.
+   */
+  def getRunTerminationReason: RunTerminationReason
+
+  def maxRetryAttemptsForFlow(flowName: TableIdentifier): Int = {
+    val flow = graphForExecution.flow(flowName)
+    flow.sqlConf
+      .get(SQLConf.PIPELINES_MAX_FLOW_RETRY_ATTEMPTS.key)
+      .map(_.toInt) // Flow-level conf
+      // Pipeline-level conf, else default flow retry limit
+      .getOrElse(pipelineConf.maxFlowRetryAttempts)
+  }
+
+  /**
+   * Stops a thread by joining it, waiting up to a configured timeout for it to terminate.
+   */
+  def stopThread(thread: Thread): Unit = {
+    // Don't wait to join if current thread is the thread to stop
+    if (thread.getId != Thread.currentThread().getId) {
+      thread.join(env.pipelineConf.timeoutMsForTerminationJoinAndLock)
+      // Check whether the thread is still alive after the join.
+ if (thread.isAlive) { + throw new TimeoutException("Failed to stop the update due to a hanging control thread.") + } + } + } +} + +object GraphExecution extends Logging { + + // Set of states after checking the exception for flow execution retryability analysis. + sealed trait FlowExecutionAction + + /** Indicates that the flow execution should be retried. */ + case object RetryFlowExecution extends FlowExecutionAction + + /** Indicates that the flow execution should be stopped with a specific reason. */ + case class StopFlowExecution(reason: FlowExecutionStopReason) extends FlowExecutionAction + + /** Represents the reason why a flow execution should be stopped. */ + sealed trait FlowExecutionStopReason { + def cause: Throwable + def flowDisplayName: String + def runTerminationReason: RunTerminationReason + def failureMessage: String + // If true, we record this flow execution as STOPPED with a WARNING instead a FAILED with ERROR. + def warnInsteadOfError: Boolean = false + } + + /** + * Represents the `FlowExecution` should be stopped due to it failed with some retryable errors + * and has exhausted all the retry attempts. + */ + private case class MaxRetryExceeded( + cause: Throwable, + flowDisplayName: String, + maxAllowedRetries: Int + ) extends FlowExecutionStopReason { + override lazy val runTerminationReason: RunTerminationReason = { + QueryExecutionFailure(flowDisplayName, maxAllowedRetries, Option(cause)) + } + override lazy val failureMessage: String = { + s"Flow '$flowDisplayName' has FAILED more than $maxAllowedRetries times and will not be " + + s"restarted." + } + } + + /** + * Analyze the exception thrown by flow execution and figure out if we should retry the execution, + * or we need to reanalyze the flow entirely to resolve issues like schema changes. + * This should be the narrow waist for all exception analysis in flow execution. + * TODO: currently it only handles schema change and max retries, we should aim to extend this to + * include other non-retryable exception as well so we can have a single SoT for all these error + * matching logic. + * @param ex Exception to analyze. + * @param flowDisplayName The user facing flow name with the error. + * @param pipelineConf Pipeline configuration. + * @param currentNumTries Number of times the flow has been tried. + * @param maxAllowedRetries Maximum number of retries allowed for the flow. + */ + def determineFlowExecutionActionFromError( + ex: => Throwable, + flowDisplayName: => String, + pipelineConf: => PipelineConf, + currentNumTries: => Int, + maxAllowedRetries: => Int + ): FlowExecutionAction = { + val flowExecutionNonRetryableReasonOpt = if (currentNumTries > maxAllowedRetries) { + Some(MaxRetryExceeded(ex, flowDisplayName, maxAllowedRetries)) + } else { + None + } + + if (flowExecutionNonRetryableReasonOpt.isDefined) { + StopFlowExecution(flowExecutionNonRetryableReasonOpt.get) + } else { + RetryFlowExecution + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphFilter.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphFilter.scala new file mode 100644 index 0000000000000..2f40e53cafac0 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphFilter.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * Specifies how we should filter Graph elements. + */ +sealed trait GraphFilter[E] { + + /** Returns the subset of elements provided that match this filter. */ + def filter(elements: Seq[E]): Seq[E] + + /** Returns the subset of elements provided that do not match this filter. */ + def filterNot(elements: Seq[E]): Seq[E] +} + +/** + * Specifies how we should filter Flows. + */ +sealed trait FlowFilter extends GraphFilter[ResolvedFlow] + +/** + * Specifies how we should filter Tables. + */ +sealed trait TableFilter extends GraphFilter[Table] { + + /** Returns whether at least one table will pass the filter. */ + def nonEmpty: Boolean +} + +/** + * Used in full graph update to select all flows. + */ +case object AllFlows extends FlowFilter { + override def filter(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = flows + override def filterNot(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = Seq.empty +} + +/** + * Used in partial graph updates to select flows that flow to "selectedTables". + */ +case class FlowsForTables(selectedTables: Set[TableIdentifier]) extends FlowFilter { + + private def filterCondition( + flows: Seq[ResolvedFlow], + useFilterNot: Boolean + ): Seq[ResolvedFlow] = { + val (matchingFlows, nonMatchingFlows) = flows.partition { f => + selectedTables.contains(f.destinationIdentifier) + } + + if (useFilterNot) nonMatchingFlows else matchingFlows + } + + override def filter(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = { + filterCondition(flows, useFilterNot = false) + } + + override def filterNot(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = { + filterCondition(flows, useFilterNot = true) + } +} + +/** Returns a flow filter that is a union of two flow filters */ +case class UnionFlowFilter(oneFilter: FlowFilter, otherFilter: FlowFilter) extends FlowFilter { + override def filter(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = { + (oneFilter.filter(flows).toSet ++ otherFilter.filter(flows).toSet).toSeq + } + + override def filterNot(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = { + (flows.toSet -- filter(flows).toSet).toSeq + } +} + +/** Used to specify that no flows should be refreshed. */ +case object NoFlows extends FlowFilter { + override def filter(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = Seq.empty + override def filterNot(flows: Seq[ResolvedFlow]): Seq[ResolvedFlow] = flows +} + +/** + * Used in full graph updates to select all tables. + */ +case object AllTables extends TableFilter { + override def filter(tables: Seq[Table]): Seq[Table] = tables + override def filterNot(tables: Seq[Table]): Seq[Table] = Seq.empty + + override def nonEmpty: Boolean = true +} + +/** + * Used to select no tables. 
+ */ +case object NoTables extends TableFilter { + override def filter(tables: Seq[Table]): Seq[Table] = Seq.empty + override def filterNot(tables: Seq[Table]): Seq[Table] = tables + + override def nonEmpty: Boolean = false +} + +/** + * Used in partial graph updates to select "selectedTables". + */ +case class SomeTables(selectedTables: Set[TableIdentifier]) extends TableFilter { + private def filterCondition(tables: Seq[Table], useFilterNot: Boolean): Seq[Table] = { + tables.filter { t => + useFilterNot ^ selectedTables.contains(t.identifier) + } + } + + override def filter(tables: Seq[Table]): Seq[Table] = + filterCondition(tables, useFilterNot = false) + + override def filterNot(tables: Seq[Table]): Seq[Table] = + filterCondition(tables, useFilterNot = true) + + override def nonEmpty: Boolean = selectedTables.nonEmpty +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala index c0b5a360afea7..1514bbbe3c797 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphOperations.scala @@ -38,7 +38,7 @@ trait GraphOperations { private lazy val destinationSet: Set[TableIdentifier] = flows.map(_.destinationIdentifier).toSet - /** A map from flow identifier to [[FlowNode]], which contains the input/output nodes. */ + /** A map from flow identifier to `FlowNode`, which contains the input/output nodes. */ lazy val flowNodes: Map[TableIdentifier, FlowNode] = { flows.map { f => val identifier = f.identifier diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala index 99142432f9cec..b7e0cf86e4dcc 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphValidations.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.pipelines.graph.DataflowGraph.mapUnique import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils -/** Validations performed on a [[DataflowGraph]]. */ +/** Validations performed on a `DataflowGraph`. */ trait GraphValidations extends Logging { this: DataflowGraph => diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineConf.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineConf.scala new file mode 100644 index 0000000000000..42648c6ef58b8 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineConf.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Configuration for the pipeline system, which is read from the Spark session's SQL configuration.
+ */
+@deprecated("TODO(SPARK-52410): Remove this class in favor of using SqlConf directly")
+class PipelineConf(spark: SparkSession) {
+  private val sqlConf: SQLConf = spark.sessionState.conf
+
+  /** Interval in seconds to poll the state of streaming flow execution. */
+  val streamStatePollingInterval: Long = sqlConf.getConf(
+    SQLConf.PIPELINES_STREAM_STATE_POLLING_INTERVAL
+  )
+
+  /** Minimum time in seconds between retries for the watchdog. */
+  val watchdogMinRetryTimeInSeconds: Long = {
+    sqlConf.getConf(SQLConf.PIPELINES_WATCHDOG_MIN_RETRY_TIME_IN_SECONDS)
+  }
+
+  /** Maximum time in seconds for the watchdog to retry before giving up. */
+  val watchdogMaxRetryTimeInSeconds: Long = {
+    val value = sqlConf.getConf(SQLConf.PIPELINES_WATCHDOG_MAX_RETRY_TIME_IN_SECONDS)
+    // TODO(SPARK-52410): Remove this check and use `checkValue` when defining the conf
+    // in `SqlConf`.
+    if (value < watchdogMinRetryTimeInSeconds) {
+      throw new IllegalArgumentException(
+        "Watchdog maximum retry time must be greater than or equal to the watchdog minimum " +
+        "retry time."
+      )
+    }
+    value
+  }
+
+  /** Maximum number of concurrent flows that can be executed. */
+  val maxConcurrentFlows: Int = sqlConf.getConf(SQLConf.PIPELINES_MAX_CONCURRENT_FLOWS)
+
+  /** Timeout in milliseconds for termination join and lock operations. */
+  val timeoutMsForTerminationJoinAndLock: Long = {
+    sqlConf.getConf(SQLConf.PIPELINES_TIMEOUT_MS_FOR_TERMINATION_JOIN_AND_LOCK)
+  }
+
+  /** Maximum number of retry attempts for a flow execution. */
+  val maxFlowRetryAttempts: Int = sqlConf.getConf(SQLConf.PIPELINES_MAX_FLOW_RETRY_ATTEMPTS)
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
new file mode 100644
index 0000000000000..3f9cc91ed487d
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.pipelines.logging.{
+  ConstructPipelineEvent,
+  EventLevel,
+  PipelineEventOrigin,
+  RunProgress
+}
+
+/**
+ * Executes a [[DataflowGraph]] by resolving the graph, materializing datasets, and running the
+ * flows.
+ *
+ * @param context The context for this pipeline update.
+ */ +class PipelineExecution(context: PipelineUpdateContext) { + + /** [Visible for testing] */ + private[pipelines] var graphExecution: Option[TriggeredGraphExecution] = None + + /** + * Executes all flows in the graph. + */ + def runPipeline(): Unit = synchronized { + // Initialize the graph. + val initializedGraph = initializeGraph() + + // Execute the graph. + graphExecution = Option( + new TriggeredGraphExecution(initializedGraph, context, onCompletion = terminationReason => { + context.eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = None, + datasetName = None, + sourceCodeLocation = None + ), + level = EventLevel.INFO, + message = terminationReason.message, + details = RunProgress(terminationReason.terminalState), + exception = terminationReason.cause.orNull + ) + ) + }) + ) + graphExecution.foreach(_.start()) + } + + /** Initializes the graph by resolving it and materializing datasets. */ + private def initializeGraph(): DataflowGraph = { + val resolvedGraph = try { + context.unresolvedGraph.resolve().validate() + } catch { + case e: UnresolvedPipelineException => + handleInvalidPipeline(e) + throw e + } + DatasetManager.materializeDatasets(resolvedGraph, context) + } + + /** Waits for the execution to complete. Only used in tests */ + private[sql] def awaitCompletion(): Unit = { + graphExecution.foreach(_.awaitCompletion()) + } + + /** + * Emits FlowProgress.FAILED events for each flow that failed to resolve. Downstream flow failures + * (flows that failed to resolve when reading from other flows that also failed to resolve) are + * written to the event log first at WARN level, while upstream flow failures which are expected + * to be "real" failures are written at ERROR level and come afterwards. This makes the real + * errors show up first in the UI. + * + * @param e The exception that was raised while executing a stage + */ + private def handleInvalidPipeline(e: UnresolvedPipelineException): Unit = { + e.downstreamFailures.foreach { failure => + val (flowIdentifier, ex) = failure + val flow = e.graph.resolutionFailedFlow(flowIdentifier) + context.flowProgressEventLogger.recordFailed( + flow = flow, + exception = ex, + logAsWarn = true, + messageOpt = Option( + s"Failed to resolve flow due to upstream failure: '${flow.displayName}'." + ) + ) + } + e.directFailures.foreach { failure => + val (flowIdentifier, ex) = failure + val flow = e.graph.resolutionFailedFlow(flowIdentifier) + context.flowProgressEventLogger.recordFailed( + flow = flow, + exception = ex, + logAsWarn = true, + messageOpt = Option(s"Failed to resolve flow: '${flow.displayName}'.") + ) + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala new file mode 100644 index 0000000000000..93d608dd7668d --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContext.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.pipelines.logging.{FlowProgressEventLogger, PipelineRunEventBuffer}
+
+trait PipelineUpdateContext {
+
+  /** The SparkSession for this update. */
+  def spark: SparkSession
+
+  /** Filter for which tables should be refreshed when performing this update. */
+  def refreshTables: TableFilter
+
+  /** Filter for which tables should be fully refreshed when performing this update. */
+  def fullRefreshTables: TableFilter
+
+  def resetCheckpointFlows: FlowFilter
+
+  /**
+   * Filter for which flows should be refreshed when performing this update. Should be a superset
+   * of the flows selected for full refresh.
+   */
+  final def refreshFlows: FlowFilter = {
+    val flowFilterForTables = (refreshTables, fullRefreshTables) match {
+      case (AllTables, _) => AllFlows
+      case (_, AllTables) => AllFlows
+      case (SomeTables(tablesRefresh), SomeTables(tablesFullRefresh)) =>
+        FlowsForTables(tablesRefresh ++ tablesFullRefresh)
+      case (SomeTables(tables), _) => FlowsForTables(tables)
+      case (_, SomeTables(tables)) => FlowsForTables(tables)
+      case _ => NoFlows
+    }
+    UnionFlowFilter(flowFilterForTables, resetCheckpointFlows)
+  }
+
+  /** `PipelineConf` based on the root SparkSession for this update. */
+  def pipelineConf: PipelineConf
+
+  /** Buffer containing internal events that are emitted during a run of a pipeline. */
+  def eventBuffer: PipelineRunEventBuffer
+
+  /** Emits internal flow progress events into the event buffer. */
+  def flowProgressEventLogger: FlowProgressEventLogger
+
+  /** The unresolved graph for this update. */
+  def unresolvedGraph: DataflowGraph
+
+  /** Defines operations related to end-to-end execution of a `DataflowGraph`. */
+  val pipelineExecution: PipelineExecution = new PipelineExecution(context = this)
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
new file mode 100644
index 0000000000000..8192068af67e5
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.pipelines.graph + +import scala.annotation.unused + +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.pipelines.logging.{FlowProgressEventLogger, PipelineEvent, PipelineRunEventBuffer} + +/** + * An implementation of the PipelineUpdateContext trait used in production. + * @param unresolvedGraph The graph (unresolved) to be executed in this update. + * @param eventCallback A callback function to be called when an event is added to the event buffer. + */ +@unused( + "TODO(SPARK-51727) construct this spark connect server when we expose APIs for users " + + "to interact with a pipeline" +) +class PipelineUpdateContextImpl( + override val unresolvedGraph: DataflowGraph, + eventCallback: PipelineEvent => Unit +) extends PipelineUpdateContext { + + override val spark: SparkSession = SparkSession.getActiveSession.getOrElse( + throw new IllegalStateException("SparkSession is not available") + ) + + override val pipelineConf: PipelineConf = new PipelineConf(spark) + + override val eventBuffer = new PipelineRunEventBuffer(eventCallback) + + override val flowProgressEventLogger: FlowProgressEventLogger = + new FlowProgressEventLogger(eventBuffer = eventBuffer) + + override val refreshTables: TableFilter = AllTables + override val fullRefreshTables: TableFilter = NoTables + override val resetCheckpointFlows: FlowFilter = NoFlows +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala index 4bed25f2aa1c7..a0e378f85bce7 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.pipelines.graph +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.SparkException +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier @@ -45,6 +48,120 @@ case class LoadTableException(name: String, cause: Option[Throwable]) cause = cause.orNull ) +object PipelinesErrors extends Logging { + + /** + * Gets the exception chain for a given exception by repeatedly calling getCause. + * + * @param originalErr The error on which getCause is repeatedly called + * @return An ArrayBuffer containing the original error and all the causes in its exception chain. + * For a given exception in the ArrayBuffer, the next element is its cause. + */ + private def getExceptionChain(originalErr: Throwable): ArrayBuffer[Throwable] = { + val exceptionChain = ArrayBuffer[Throwable]() + var lastException = originalErr + while (lastException != null) { + exceptionChain += lastException + lastException = lastException.getCause + } + exceptionChain + } + + /** + * Checks whether a throwable or any of its nested causes meets some condition + * @param throwable A Throwable to inspect + * @param check Function to run on each cause + * @return Whether or not `throwable` or any of its nested causes satisfy the check + */ + private def checkCauses(throwable: Throwable, check: Throwable => Boolean): Boolean = { + getExceptionChain(throwable).exists(check) + } + + /** + * Checks an error for streaming specific handling. 
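As a quick illustration (not part of the patch; `getExceptionChain` and `checkCauses` are private to this object), the two helpers above simply walk the getCause chain and test each link:

// Suppose a streaming query fails with a wrapped root cause:
val root = new AssertionError("3 sources in the checkpoint offsets and now there are " +
  "2 sources requested by the query. Cannot continue.")
val wrapped = new RuntimeException("stream terminated", root)

// getExceptionChain(wrapped) == ArrayBuffer(wrapped, root)
// checkCauses(wrapped, _.isInstanceOf[AssertionError]) == true
// checkCauses(wrapped, _.getMessage == "unrelated")    == false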
This is a pretty messy signature as a result + * of unifying some divergences between the triggered caller in TriggeredGraphExecution and the + * continuous caller in StreamWatchdog. + * + * @param ex the error to check + * @param env the update context + * @param graphExecution the graph execution + * @param flow the resolved logical flow + * @param shouldRethrow whether to throw an UpdateTerminationException wrapping `ex`. This is set + * to true for ContinuousFlowExecution so we can eagerly stop the execution. + * @param prevFailureCount the number of failures that have occurred so far + * @param maxRetries the max retries that were available (whether or not they're exhausted now) + */ + def checkStreamingErrorsAndRetry( + ex: Throwable, + env: PipelineUpdateContext, + graphExecution: GraphExecution, + flow: ResolvedFlow, + shouldRethrow: Boolean, + prevFailureCount: Int, + maxRetries: Int, + onRetry: => Unit + ): Unit = { + if (PipelinesErrors.checkCauses( + throwable = ex, + check = ex => { + ex.isInstanceOf[AssertionError] && + ex.getMessage != null && + ex.getMessage.contains("sources in the checkpoint offsets and now there are") && + ex.getMessage.contains("sources requested by the query. Cannot continue.") + } + )) { + val message = s""" + |Flow '${flow.displayName}' had streaming sources added or removed. Please perform a + |full refresh in order to rebuild '${flow.displayName}' against the current set of + |sources. + |""".stripMargin + + env.flowProgressEventLogger.recordFailed( + flow = flow, + exception = ex, + logAsWarn = false, + messageOpt = Option(message) + ) + } else if (flow.once && ex == null) { + // No need to do anything if this is a ONCE flow with no exception. That just means it's done. + } else { + val actionFromError = GraphExecution.determineFlowExecutionActionFromError( + ex = ex, + flowDisplayName = flow.displayName, + pipelineConf = env.pipelineConf, + currentNumTries = prevFailureCount + 1, + maxAllowedRetries = maxRetries + ) + actionFromError match { + // Simply retry + case GraphExecution.RetryFlowExecution => onRetry + // Schema change exception + case GraphExecution.StopFlowExecution(reason) => + val msg = reason.failureMessage + if (reason.warnInsteadOfError) { + logWarning(msg, reason.cause) + env.flowProgressEventLogger.recordStop( + flow = flow, + message = Option(msg), + cause = Option(reason.cause) + ) + } else { + logError(reason.failureMessage, reason.cause) + env.flowProgressEventLogger.recordFailed( + flow = flow, + exception = reason.cause, + logAsWarn = false, + messageOpt = Option(msg) + ) + } + if (shouldRethrow) { + throw RunTerminationException(reason.runTerminationReason) + } + } + } + } +} + /** * Exception raised when a pipeline has one or more flows that cannot be resolved * diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala index 042b4d9626fd6..e260d9693b6dc 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/QueryOrigin.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.pipelines.logging.SourceCodeLocation * Records information used to track the provenance of a given query to user code. * * @param language The language used by the user to define the query. - * @param fileName The file name of the user code that defines the query. 
+ * @param filePath Path to the file of the user code that defines the query. * @param sqlText The SQL text of the query. * @param line The line number of the query in the user code. * Line numbers are 1-indexed. @@ -38,7 +38,7 @@ import org.apache.spark.sql.pipelines.logging.SourceCodeLocation */ case class QueryOrigin( language: Option[Language] = None, - fileName: Option[String] = None, + filePath: Option[String] = None, sqlText: Option[String] = None, line: Option[Int] = None, startPosition: Option[Int] = None, @@ -55,7 +55,7 @@ case class QueryOrigin( def merge(other: QueryOrigin): QueryOrigin = { QueryOrigin( language = other.language.orElse(language), - fileName = other.fileName.orElse(fileName), + filePath = other.filePath.orElse(filePath), sqlText = other.sqlText.orElse(sqlText), line = other.line.orElse(line), startPosition = other.startPosition.orElse(startPosition), @@ -82,7 +82,7 @@ case class QueryOrigin( /** Generates a SourceCodeLocation using the details present in the query origin. */ def toSourceCodeLocation: SourceCodeLocation = SourceCodeLocation( - path = fileName, + path = filePath, // QueryOrigin tracks line numbers using a 1-indexed numbering scheme whereas SourceCodeLocation // tracks them using a 0-indexed numbering scheme. lineNumber = line.map(_ - 1), @@ -98,7 +98,7 @@ object QueryOrigin extends Logging { val empty: QueryOrigin = QueryOrigin() /** - * An exception that wraps [[QueryOrigin]] and lets us store it in errors as suppressed + * An exception that wraps `QueryOrigin` and lets us store it in errors as suppressed * exceptions. */ private case class QueryOriginWrapper(origin: QueryOrigin) extends Exception with NoStackTrace @@ -120,20 +120,21 @@ object QueryOrigin extends Logging { t.addSuppressed(QueryOriginWrapper(origin)) } } catch { - case NonFatal(e) => logError("Failed to add pipeline context", e) + case NonFatal(e) => + logError("Failed to add pipeline context", e) } t } } - /** Returns the [[QueryOrigin]] stored as a suppressed exception in the given throwable. + /** Returns the `QueryOrigin` stored as a suppressed exception in the given throwable. * * @return Some(origin) if the origin is recorded as part of the given throwable, `None` * otherwise. */ def getOrigin(t: Throwable): Option[QueryOrigin] = { try { - // Wrap in an `Option(_)` first to handle `null` throwables. + // Wrap in an `Option(_)` first to handle `null` throwable. Option(t).flatMap { ex => ex.getSuppressed.collectFirst { case QueryOriginWrapper(context) => context diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/RunTerminationReason.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/RunTerminationReason.scala new file mode 100644 index 0000000000000..c95ce6a197eeb --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/RunTerminationReason.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.common.RunState + +sealed trait RunTerminationReason { + + /** Terminal state for the given run. */ + def terminalState: RunState + + /** + * User visible message associated with run termination. This will also be set as the message + * in the associated terminal run progress log. + */ + def message: String + + /** + * Exception associated with the given run termination. This exception will be + * included in the error details in the associated terminal run progress event. + */ + def cause: Option[Throwable] +} + +/** + * Helper exception class that indicates that a run has to be terminated and + * tracks the associated termination reason. + */ +case class RunTerminationException(reason: RunTerminationReason) extends Exception + +// =============================================================== +// ============ Graceful run termination states ================== +// =============================================================== + +/** Indicates that a triggered run has successfully completed execution. */ +case class RunCompletion() extends RunTerminationReason { + override def terminalState: RunState = RunState.COMPLETED + override def message: String = s"Run is $terminalState." + override def cause: Option[Throwable] = None +} + +// =============================================================== +// ======================= Run failures ========================== +// =============================================================== + +/** Indicates that an run entered the failed state.. */ +abstract sealed class RunFailure extends RunTerminationReason { + + /** Whether or not this failure is considered fatal / irrecoverable. */ + def isFatal: Boolean + + override def terminalState: RunState = RunState.FAILED +} + +/** Indicates that run has failed due to a query execution failure. */ +case class QueryExecutionFailure( + flowName: String, + maxRetries: Int, + override val cause: Option[Throwable]) + extends RunFailure { + override def isFatal: Boolean = false + + override def message: String = + if (maxRetries == 0) { + s"Run is $terminalState since flow '$flowName' has failed." + } else { + s"Run is $terminalState since flow '$flowName' has failed more " + + s"than $maxRetries times." + } +} + +/** Abstract class used to identify failures related to failures stopping an operation/timeouts. */ +abstract class FailureStoppingOperation extends RunFailure { + + /** Name of the operation that failed to stop. */ + def operation: String +} + +/** Indicates that there was a failure while stopping the flow. 
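A small sketch (not part of the patch) of how a consumer, such as the onCompletion callback in PipelineExecution, can fold a RunTerminationReason into a single log line; it only uses the types defined in this file:

def summarizeTermination(reason: RunTerminationReason): String = reason match {
  case _: RunCompletion => reason.message
  case f: RunFailure =>
    val causeSuffix = f.cause.map(c => s" Cause: ${c.getMessage}").getOrElse("")
    s"${f.message}$causeSuffix (fatal=${f.isFatal})"
}

// summarizeTermination(QueryExecutionFailure("my_flow", maxRetries = 2, cause = None))
//   == "Run is FAILED since flow 'my_flow' has failed more than 2 times. (fatal=false)"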
*/ +case class FailureStoppingFlow(flowIdentifiers: Seq[TableIdentifier]) + extends FailureStoppingOperation { + override def isFatal: Boolean = false + override def operation: String = "flow execution" + override def message: String = { + if (flowIdentifiers.nonEmpty) { + val flowNamesToPrint = flowIdentifiers.map(_.toString).sorted.take(5).mkString(", ") + s"Run is $terminalState since following flows have failed to stop: " + + s"$flowNamesToPrint." + } else { + s"Run is $terminalState since stopping flow execution has failed." + } + } + override def cause: Option[Throwable] = None +} + +/** + * Run could not be associated with a proper root cause. + * This is not expected and likely indicates a bug. + */ +case class UnexpectedRunFailure() extends RunFailure { + override def isFatal: Boolean = false + override def message: String = + s"Run $terminalState unexpectedly." + override def cause: Option[Throwable] = None +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecution.scala new file mode 100644 index 0000000000000..fdac2cdb2fc90 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecution.scala @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import java.util.concurrent.{ConcurrentHashMap, Semaphore} + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ +import scala.util.Try +import scala.util.control.NonFatal + +import org.apache.spark.internal.{LogKeys, MDC} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.graph.TriggeredGraphExecution._ +import org.apache.spark.sql.pipelines.util.ExponentialBackoffStrategy +import org.apache.spark.sql.streaming.Trigger +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} + +/** + * Executes all of the flows in the given graph in topological order. Each flow processes + * all available data before downstream flows are triggered. + * + * @param graphForExecution the graph to execute. + * @param env the context in which the graph is executed. + * @param onCompletion a callback to execute after all streams are done. The boolean + * argument is true if the execution was successful. + * @param clock a clock used to determine the time of execution. 
+ */ +class TriggeredGraphExecution( + graphForExecution: DataflowGraph, + env: PipelineUpdateContext, + onCompletion: RunTerminationReason => Unit = _ => (), + clock: Clock = new SystemClock() +) extends GraphExecution(graphForExecution, env) { + + /** + * [Visible for testing] A map to store stream state of all flows which should be materialized. + * This includes flows whose streams have not yet been started, ie they are queued or have been + * marked as skipped. + */ + private[pipelines] val pipelineState = { + new ConcurrentHashMap[TableIdentifier, StreamState]().asScala + } + + /** + * Keeps track of flow failure information required for retry logic. + * This only contains values for flows that either failed previously or are currently in the + * failed state. + */ + private val failureTracker = { + new ConcurrentHashMap[TableIdentifier, TriggeredFailureInfo]().asScala + } + + /** Back-off strategy used to determine duration between retries. */ + private val backoffStrategy = ExponentialBackoffStrategy( + maxTime = (pipelineConf.watchdogMaxRetryTimeInSeconds * 1000).millis, + stepSize = (pipelineConf.watchdogMinRetryTimeInSeconds * 1000).millis + ) + + override def streamTrigger(flow: Flow): Trigger = { + Trigger.AvailableNow() + } + + /** The control thread responsible for topologically executing flows. */ + private var topologicalExecutionThread: Option[Thread] = None + + private def buildTopologicalExecutionThread(): Thread = { + new Thread("Topological Execution") { + override def run(): Unit = { + try { + topologicalExecution() + } finally { + TriggeredGraphExecution.super.stop() + } + } + } + } + + override def start(): Unit = { + super.start() + // If tablesToUpdate is empty, queue all flows; Otherwise, queue flows for which the + // destination tables are specified in tablesToUpdate. + env.refreshFlows + .filter(graphForExecution.materializedFlows) + .foreach { f => + env.flowProgressEventLogger.recordQueued(f) + pipelineState.put(f.identifier, StreamState.QUEUED) + } + env.refreshFlows + .filterNot(graphForExecution.materializedFlows) + .foreach { f => + env.flowProgressEventLogger.recordExcluded(f) + pipelineState.put(f.identifier, StreamState.EXCLUDED) + } + val thread = buildTopologicalExecutionThread() + UncaughtExceptionHandler.addHandler( + thread, { + case _: InterruptedException => + case _ => + try { + stopInternal(stopTopologicalExecutionThread = false) + } catch { + case ex: Throwable => + logError(s"Exception thrown while stopping the update...", ex) + } finally { + onCompletion(UnexpectedRunFailure()) + } + } + ) + thread.start() + topologicalExecutionThread = Option(thread) + } + + /** Used to control how many flows are executing at once. */ + private val concurrencyLimit: Semaphore = new Semaphore(pipelineConf.maxConcurrentFlows) + + /** + * Runs the pipeline in a topological order. + * + * Non-accepting states: Queued, Running + * Accepting states: Successful, TerminatedWithError, Skipped, Cancelled, Excluded + * All [[Flow]]s which can write to a stream begin in a queued state. The following state + * transitions describe the topological execution of a [[DataflowGraph]]. + * + * Queued -> Running if Flow has no parents or the parent tables of the queued [[Flow]] + * have run successfully. + * Running -> Successful if the stream associated with the [[Flow]] succeeds. + * Running -> TerminatedWithError if the stream associated with the [[Flow]] stops with an + * exception. + * + * Non-fatally failed flows are retried with exponential back-off a bounded no. 
of times. + * If a flow cannot be retried, all downstream flows of the failed flow are moved to Skipped + * state. + * Running -> Cancelled if the stream associated with the [[Flow]] is stopped mid-run by + * calling `stop`. All remaining [[Flow]]s in queue are moved to state Skipped. + * + * The execution is over once there are no [[Flow]]s left running or in the queue. + */ + private def topologicalExecution(): Unit = { + // Done executing once no flows remain running or in queue + def allFlowsDone = { + flowsWithState(StreamState.QUEUED).isEmpty && flowsWithState(StreamState.RUNNING).isEmpty && + flowsQueuedForRetry().isEmpty + } + + // LinkedHashSet returns elements in the order inserted. This ensures that flows queued but + // unable to run because we are at max concurrent execution will get priority on the next round. + val runnableFlows: mutable.LinkedHashSet[TableIdentifier] = new mutable.LinkedHashSet() + + while (!Thread.interrupted() && !allFlowsDone) { + // Since queries are managed by FlowExecutions, so update state based on [[FlowExecution]]s. + flowsWithState(StreamState.RUNNING).foreach { flowIdentifier => + flowExecutions(flowIdentifier) match { + case f if !f.isCompleted => // Nothing to be done; let this stream continue. + case f if f.isCompleted && f.exception.isEmpty => + recordSuccess(flowIdentifier) + case f => + recordFailed(flowIdentifier = flowIdentifier, e = f.exception.get) + } + } + + // Log info on if we're leaking Semaphore permits. Synchronize here so we don't double-count + // or mis-count because a batch flow is finishing asynchronously. + val (runningFlows, availablePermits) = concurrencyLimit.synchronized { + (flowsWithState(StreamState.RUNNING).size, concurrencyLimit.availablePermits) + } + if ((runningFlows + availablePermits) < pipelineConf.maxConcurrentFlows) { + val errorStr = + s"The max concurrency is ${pipelineConf.maxConcurrentFlows}, but there are only " + + s"$availablePermits permits available with $runningFlows flows running. If this " + + s"happens consistently, it's possible we're leaking permits." + logError(errorStr) + if (Utils.isTesting) { + throw new IllegalStateException(errorStr) + } + } + + // All flows which can potentially be run now if their parent tables have successfully + // completed or have been excluded. + val queuedForRetry = + flowsQueuedForRetry().filter(nextRetryTime(_) <= clock.getTimeMillis()) + // Take flows that have terminated but have retry attempts left and flows that are queued, and + // filter the ones whose parents have all successfully completed, excluded, or idled because + // they are ONCE flows which already ran. 
+ runnableFlows ++= (queuedForRetry ++ flowsWithState(StreamState.QUEUED)).filter { id => + graphForExecution + .upstreamFlows(id) + .intersect(graphForExecution.materializedFlowIdentifiers) + .forall { id => + pipelineState(id) == StreamState.SUCCESSFUL || + pipelineState(id) == StreamState.EXCLUDED || + pipelineState(id) == StreamState.IDLE + } + } + + // collect flow that are ready to start + val flowsToStart = mutable.ArrayBuffer[ResolvedFlow]() + while (runnableFlows.nonEmpty && concurrencyLimit.tryAcquire()) { + val flowIdentifier = runnableFlows.head + runnableFlows.remove(flowIdentifier) + flowsToStart.append(graphForExecution.resolvedFlow(flowIdentifier)) + } + + def startFlow(flow: ResolvedFlow): Unit = { + val flowIdentifier = flow.identifier + logInfo(log"Starting flow ${MDC(LogKeys.FLOW_NAME, flow.identifier)}") + env.flowProgressEventLogger.recordPlanningForBatchFlow(flow) + try { + val flowStarted = planAndStartFlow(flow) + if (flowStarted.nonEmpty) { + pipelineState.put(flowIdentifier, StreamState.RUNNING) + logInfo(log"Flow ${MDC(LogKeys.FLOW_NAME, flowIdentifier)} started.") + } else { + if (flow.once) { + // ONCE flows are marked as IDLE in the event buffer for consistency with continuous + // execution where all unstarted flows are IDLE. + env.flowProgressEventLogger.recordIdle(flow) + pipelineState.put(flowIdentifier, StreamState.IDLE) + concurrencyLimit.release() + } else { + env.flowProgressEventLogger.recordSkipped(flow) + concurrencyLimit.release() + pipelineState.put(flowIdentifier, StreamState.SKIPPED) + } + } + } catch { + case NonFatal(ex) => recordFailed(flowIdentifier, ex) + } + } + + // start each flow serially + flowsToStart.foreach(startFlow) + + try { + // Put thread to sleep for the configured polling interval to avoid busy-waiting + // and holding one CPU core. + Thread.sleep(pipelineConf.streamStatePollingInterval * 1000) + } catch { + case _: InterruptedException => return + } + } + if (allFlowsDone) { + onCompletion(getRunTerminationReason) + } + } + + /** Record the specified flow as successful. */ + private def recordSuccess(flowIdentifier: TableIdentifier): Unit = { + concurrencyLimit.synchronized { + concurrencyLimit.release() + pipelineState.put(flowIdentifier, StreamState.SUCCESSFUL) + } + logInfo( + log"Flow ${MDC(LogKeys.FLOW_NAME, flowIdentifier)} has COMPLETED " + + log"in TriggeredFlowExecution." + ) + } + + /** + * Record the specified flow as failed and any downstream flows as failed. + * + * @param e The error that caused the query to fail. 
+ */ + private def recordFailed( + flowIdentifier: TableIdentifier, + e: Throwable + ): Unit = { + logError(log"Flow ${MDC(LogKeys.FLOW_NAME, flowIdentifier)} failed", e) + concurrencyLimit.synchronized { + concurrencyLimit.release() + pipelineState.put(flowIdentifier, StreamState.TERMINATED_WITH_ERROR) + } + val prevFailureCount = failureTracker.get(flowIdentifier).map(_.numFailures).getOrElse(0) + val flow = graphForExecution.resolvedFlow(flowIdentifier) + + failureTracker.put( + flowIdentifier, + TriggeredFailureInfo( + lastFailTimestamp = clock.getTimeMillis(), + numFailures = prevFailureCount + 1, + lastException = e, + lastExceptionAction = GraphExecution.determineFlowExecutionActionFromError( + ex = e, + flowDisplayName = flow.displayName, + pipelineConf = pipelineConf, + currentNumTries = prevFailureCount + 1, + maxAllowedRetries = maxRetryAttemptsForFlow(flowIdentifier) + ) + ) + ) + if (graphForExecution.resolvedFlow(flow.identifier).df.isStreaming) { + // Batch query failure log comes from the batch execution thread. + PipelinesErrors.checkStreamingErrorsAndRetry( + ex = e, + env = env, + graphExecution = this, + flow = flow, + shouldRethrow = false, + prevFailureCount = prevFailureCount, + maxRetries = maxRetryAttemptsForFlow(flowIdentifier), + onRetry = { + env.flowProgressEventLogger.recordFailed( + flow = flow, + exception = e, + logAsWarn = true + ) + } + ) + } + + // Don't skip downstream outputs yet if this flow still has retries left and didn't fail + // fatally. + if (!flowsQueuedForRetry().contains(flowIdentifier)) { + graphForExecution + .downstreamFlows(flowIdentifier) + .intersect(graphForExecution.materializedFlowIdentifiers) + .foreach(recordSkippedIfSelected) + } + } + + /** + * Record the specified flow as skipped. This is no-op if the flow is already excluded in the + * refresh selection. + */ + private def recordSkippedIfSelected(flowIdentifier: TableIdentifier): Unit = { + if (pipelineState(flowIdentifier) != StreamState.EXCLUDED) { + val flow = graphForExecution.resolvedFlow(flowIdentifier) + pipelineState.put(flowIdentifier, StreamState.SKIPPED) + logWarning( + log"Flow ${MDC(LogKeys.FLOW_NAME, flowIdentifier)} SKIPPED due " + + log"to upstream failure(s)." + ) + env.flowProgressEventLogger.recordSkippedOnUpStreamFailure(flow) + } + } + + private def flowsWithState(state: StreamState): Set[TableIdentifier] = { + pipelineState + .filter { + case (_, flowState) => flowState == state + } + .keySet + .toSet + } + + /** Set of flows which have failed, but can be queued again for a retry. */ + private def flowsQueuedForRetry(): Set[TableIdentifier] = { + flowsWithState(StreamState.TERMINATED_WITH_ERROR).filter(!failureTracker(_).nonRetryable) + } + + /** Earliest time at which flow can be run next. 
*/ + private def nextRetryTime(flowIdentifier: TableIdentifier): Long = { + failureTracker + .get(flowIdentifier) + .map { failureInfo => + failureInfo.lastFailTimestamp + + backoffStrategy.waitDuration(failureInfo.numFailures).toMillis + } + .getOrElse(-1) + } + + private def stopInternal(stopTopologicalExecutionThread: Boolean): Unit = { + super.stop() + if (stopTopologicalExecutionThread) { + topologicalExecutionThread.filter(_.isAlive).foreach { t => + t.interrupt() + stopThread(t) + } + } + flowsWithState(StreamState.QUEUED).foreach(recordSkippedIfSelected) + + val flowsFailedToStop = ThreadUtils + .parmap(flowsWithState(StreamState.RUNNING).toSeq, "stop-flow", maxThreads = 10) { flowName => + pipelineState.put(flowName, StreamState.CANCELED) + flowExecutions.get(flowName).map { f => + ( + f.identifier, + Try(stopFlow(f)) + ) + } + } + .filter(_.nonEmpty) + .filter(_.get._2.isFailure) + .map(_.get._1) + + if (flowsFailedToStop.nonEmpty) { + throw RunTerminationException(FailureStoppingFlow(flowsFailedToStop)) + } + } + + override def awaitCompletion(): Unit = { + topologicalExecutionThread.foreach(_.join) + } + + override def stop(): Unit = { stopInternal(stopTopologicalExecutionThread = true) } + + override def getRunTerminationReason: RunTerminationReason = { + val success = + pipelineState.valuesIterator.forall(TERMINAL_NON_FAILURE_STREAM_STATES.contains) + if (success) { + return RunCompletion() + } + + val executionFailureOpt = failureTracker.iterator + .map { + case (flowIdentifier, failureInfo) => + ( + graphForExecution.flow(flowIdentifier), + failureInfo.lastException, + failureInfo.lastExceptionAction + ) + } + .collectFirst { + case (_, _, GraphExecution.StopFlowExecution(reason)) => + reason.runTerminationReason + } + + executionFailureOpt.getOrElse(UnexpectedRunFailure()) + } +} + +case class TriggeredFailureInfo( + lastFailTimestamp: Long, + numFailures: Int, + lastException: Throwable, + lastExceptionAction: GraphExecution.FlowExecutionAction) { + + // Whether this failure can be retried or not with flow execution. + lazy val nonRetryable: Boolean = { + lastExceptionAction.isInstanceOf[GraphExecution.StopFlowExecution] + } +} + +object TriggeredGraphExecution { + + // All possible states of a data stream for a flow + sealed trait StreamState + object StreamState { + // Stream is waiting on its parent tables to successfully finish processing + // data to start running, in triggered execution + case object QUEUED extends StreamState + + // Stream is processing data + case object RUNNING extends StreamState + + // Stream excluded if it's not selected in the partial graph update API call. + case object EXCLUDED extends StreamState + + // Stream will not be rerun because it is a ONCE flow. + case object IDLE extends StreamState + + // Stream will not be run due to parent tables not finishing successfully in triggered execution + case object SKIPPED extends StreamState + + // Stream has been stopped with a fatal error + case object TERMINATED_WITH_ERROR extends StreamState + + // Stream stopped before completion in triggered execution + case object CANCELED extends StreamState + + // Stream successfully processed all available data in triggered execution + case object SUCCESSFUL extends StreamState + } + + /** + * List of terminal states which we don't consider as failures. + * + * An update was successful if all rows either updated successfully or were skipped (if they + * didn't have any data to process) or excluded (if they were not selected in a refresh + * selection.) 
+ */ + private val TERMINAL_NON_FAILURE_STREAM_STATES: Set[StreamState] = Set( + StreamState.SUCCESSFUL, + StreamState.SKIPPED, + StreamState.EXCLUDED, + StreamState.IDLE + ) +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/UncaughtExceptionHandler.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/UncaughtExceptionHandler.scala new file mode 100644 index 0000000000000..0f27e50471d7f --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/UncaughtExceptionHandler.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +/** + * Uncaught exception handler which first calls the delegate and then calls the + * OnFailure function with the uncaught exception. + */ +class UncaughtExceptionHandler( + delegate: Option[Thread.UncaughtExceptionHandler], + onFailure: Throwable => Unit) + extends Thread.UncaughtExceptionHandler { + + override def uncaughtException(t: Thread, e: Throwable): Unit = { + try { + delegate.foreach(_.uncaughtException(t, e)) + } finally { + onFailure(e) + } + } +} + +object UncaughtExceptionHandler { + + /** + * Sets a handler which calls 'onFailure' function with the uncaught exception. + * If the thread already has a uncaught exception handler, it will be called first + * before calling the 'onFailure' function. + */ + def addHandler(thread: Thread, onFailure: Throwable => Unit): Unit = { + val currentHandler = Option(thread.getUncaughtExceptionHandler) + thread.setUncaughtExceptionHandler( + new UncaughtExceptionHandler(currentHandler, onFailure) + ) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEvent.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEvent.scala index db692ac5f6e5c..604747ea9e432 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEvent.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEvent.scala @@ -41,13 +41,14 @@ object ConstructPipelineEvent { Option(t.getCause).map(serializeException).getOrElse(Nil) } - def constructErrorDetails(t: Throwable): ErrorDetail = ErrorDetail(serializeException(t)) + private def constructErrorDetails(t: Throwable): ErrorDetail = ErrorDetail(serializeException(t)) /** * Returns a new event with the current or provided timestamp and the given origin/message. 
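A usage sketch (not part of the patch) for UncaughtExceptionHandler.addHandler defined above, mirroring how TriggeredGraphExecution.start wires up its topological execution thread: any handler already installed on the thread still runs first, then onFailure is invoked with the same exception.

val worker = new Thread(() => throw new RuntimeException("boom"), "example-worker")
UncaughtExceptionHandler.addHandler(worker, onFailure = e => println(s"worker died: ${e.getMessage}"))
worker.start()
worker.join()
// prints: worker died: boom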
*/ def apply( origin: PipelineEventOrigin, + level: EventLevel, message: String, details: EventDetails, exception: Throwable = null, @@ -55,6 +56,7 @@ object ConstructPipelineEvent { ): PipelineEvent = { ConstructPipelineEvent( origin = origin, + level = level, message = message, details = details, errorDetails = Option(exception).map(constructErrorDetails), @@ -67,6 +69,7 @@ object ConstructPipelineEvent { */ def apply( origin: PipelineEventOrigin, + level: EventLevel, message: String, details: EventDetails, errorDetails: Option[ErrorDetail], @@ -82,7 +85,8 @@ object ConstructPipelineEvent { message = message, details = details, error = errorDetails, - origin = origin + origin = origin, + level = level ) } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/FlowProgressEventLogger.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/FlowProgressEventLogger.scala new file mode 100644 index 0000000000000..b96bd64a9ce6f --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/FlowProgressEventLogger.scala @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.logging + +import java.util.concurrent.ConcurrentHashMap + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.pipelines.common.FlowStatus +import org.apache.spark.sql.pipelines.graph.{FlowExecution, ResolutionCompletedFlow, ResolvedFlow} + +/** + * This class should be used for all flow progress events logging, it controls the level at which + * events are logged. It uses execution mode, flow name and previous flow statuses to infer the + * level at which an event is to be logged. Below is a more details description of how flow + * progress events for batch/streaming flows will be logged: + * + * For batch & streaming flows in triggered execution mode: + * - All flow progress events other than errors/warnings will be logged at INFO level (including + * flow progress events with metrics) and error/warning messages will be logged at their level. + * + * @param eventBuffer Event log to log the flow progress events. + */ +class FlowProgressEventLogger(eventBuffer: PipelineRunEventBuffer) extends Logging { + + /** + * This map stores flow identifier to a boolean representing whether flow is running. + * - For a flow which is queued and has not yet run, there will be no entry present in the map. + * - For a flow which has started running but failed will have a value of false or will not be + * present in the map. + * - Flow which has started running with no failures will have a value of true. 
+ */ + private val runningFlows = new ConcurrentHashMap[TableIdentifier, Boolean]().asScala + + /** This map stores idle flows, it's a map of flow name to idle status (IDLE|SKIPPED). */ + private val knownIdleFlows = new ConcurrentHashMap[TableIdentifier, FlowStatus]().asScala + + /** + * Records flow progress events with flow status as QUEUED. This event will always be logged at + * INFO level, since flows are only queued once. + */ + def recordQueued(flow: ResolvedFlow): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = s"Flow ${flow.displayName} is QUEUED.", + details = FlowProgress(FlowStatus.QUEUED) + ) + ) + } + + /** + * Records flow progress events with flow status as PLANNING for batch flows. + */ + def recordPlanningForBatchFlow(batchFlow: ResolvedFlow): Unit = synchronized { + if (batchFlow.df.isStreaming) return + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(batchFlow.displayName), + datasetName = None, + sourceCodeLocation = Option(batchFlow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = s"Flow ${batchFlow.displayName} is PLANNING.", + details = FlowProgress(FlowStatus.PLANNING) + ) + ) + knownIdleFlows.remove(batchFlow.identifier) + } + + /** + * Records flow progress events with flow status as STARTING. For batch flows in continuous mode, + * event will be logged at INFO if the recent flow run had failed otherwise the event will be + * logged at METRICS. All other cases will be logged at INFO. + */ + def recordStart(flowExecution: FlowExecution): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flowExecution.displayName), + datasetName = None, + sourceCodeLocation = Option(flowExecution.getOrigin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = s"Flow ${flowExecution.displayName} is STARTING.", + details = FlowProgress(FlowStatus.STARTING) + ) + ) + knownIdleFlows.remove(flowExecution.identifier) + } + + /** Records flow progress events with flow status as RUNNING. */ + def recordRunning(flow: ResolvedFlow): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = s"Flow ${flow.displayName} is RUNNING.", + details = FlowProgress(FlowStatus.RUNNING) + ) + ) + runningFlows.put(flow.identifier, true) + knownIdleFlows.remove(flow.identifier) + } + + /** + * Records flow progress events with failure flow status. By default failed flow progress events + * are logged at ERROR level, logAsWarn serve as a way to log the event as a WARN. 
+ */ + def recordFailed( + flow: ResolutionCompletedFlow, + exception: Throwable, + logAsWarn: Boolean, + messageOpt: Option[String] = None + ): Unit = synchronized { + val eventLogMessage = messageOpt.getOrElse(s"Flow '${flow.displayName}' has FAILED.") + + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = if (logAsWarn) EventLevel.WARN else EventLevel.ERROR, + message = eventLogMessage, + details = FlowProgress(FlowStatus.FAILED), + exception = exception + ) + ) + // Since the flow failed, remove the flow from runningFlows. + runningFlows.remove(flow.identifier) + knownIdleFlows.remove(flow.identifier) + } + + /** + * Records flow progress events with flow status as SKIPPED at WARN level, this version of + * record skipped should be used when the flow is skipped because of upstream flow failures. + */ + def recordSkippedOnUpStreamFailure(flow: ResolvedFlow): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.WARN, + message = s"Flow '${flow.displayName}' SKIPPED due to upstream failure(s).", + details = FlowProgress(FlowStatus.SKIPPED) + ) + ) + runningFlows.remove(flow.identifier) + // Even though this is skipped it is a skipped because of a failure so this is not marked as + // a idle flow. + knownIdleFlows.remove(flow.identifier) + } + + /** + * Records flow progress events with flow status as SKIPPED. For flows skipped because of + * upstream failures use [[recordSkippedOnUpStreamFailure]] function. + */ + def recordSkipped(flow: ResolvedFlow): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = { + s"Flow '${flow.displayName}' has been processed by a previous iteration " + + s"and will not be rerun." + }, + details = FlowProgress(FlowStatus.SKIPPED) + ) + ) + knownIdleFlows.put(flow.identifier, FlowStatus.SKIPPED) + } + + /** Records flow progress events with flow status as EXCLUDED at INFO level. */ + def recordExcluded(flow: ResolvedFlow): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = s"Flow '${flow.displayName}' is EXCLUDED.", + details = FlowProgress(FlowStatus.EXCLUDED) + ) + ) + knownIdleFlows.remove(flow.identifier) + } + + /** + * Records flow progress events with flow status as STOPPED. This event will always be logged at + * INFO level, since flows wouldn't run after they are stopped. 
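A usage sketch (not part of the patch), mirroring PipelineExecution.handleInvalidPipeline: failures caused by an upstream flow are reported with logAsWarn = true so they surface at EventLevel.WARN, while root-cause failures use logAsWarn = false and surface at EventLevel.ERROR. Assumes a failed `flow` and its `exception` are in scope.

flowProgressEventLogger.recordFailed(
  flow = flow,
  exception = exception,
  logAsWarn = true, // WARN: downstream of another failure; use false for the real root cause
  messageOpt = Some(s"Failed to resolve flow due to upstream failure: '${flow.displayName}'.")
)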
+ */ + def recordStop( + flow: ResolvedFlow, + message: Option[String] = None, + cause: Option[Throwable] = None + ): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = message.getOrElse(s"Flow '${flow.displayName}' has STOPPED."), + details = FlowProgress(FlowStatus.STOPPED), + exception = cause.orNull + ) + ) + // Once a flow is stopped, remove it from running and idle. + runningFlows.remove(flow.identifier) + knownIdleFlows.remove(flow.identifier) + } + + /** Records flow progress events with flow status as IDLE. */ + def recordIdle(flow: ResolvedFlow): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = s"Flow '${flow.displayName}' is IDLE, waiting for new data.", + details = FlowProgress(FlowStatus.IDLE) + ) + ) + knownIdleFlows.put(flow.identifier, FlowStatus.IDLE) + } + + /** + * Records flow progress events with flow status as COMPLETED. For batch flows in continuous + * mode, events will be logged at METRICS since a completed status is always preceded by running + * status. + * + * Note that flow complete events for batch flows are expected to contain quality stats where as + * for streaming flows quality stats are not expected and hence not added to the flow progress + * event. + */ + def recordCompletion(flow: ResolvedFlow): Unit = synchronized { + eventBuffer.addEvent( + ConstructPipelineEvent( + origin = PipelineEventOrigin( + flowName = Option(flow.displayName), + datasetName = None, + sourceCodeLocation = Option(flow.origin.toSourceCodeLocation) + ), + level = EventLevel.INFO, + message = s"Flow ${flow.displayName} has COMPLETED.", + details = FlowProgress(FlowStatus.COMPLETED) + ) + ) + knownIdleFlows.remove(flow.identifier) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala index d0b32daebfef8..90dcbc6e911f8 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineEvent.scala @@ -16,13 +16,14 @@ */ package org.apache.spark.sql.pipelines.logging -import org.apache.spark.sql.pipelines.common.FlowStatus +import org.apache.spark.sql.pipelines.common.{FlowStatus, RunState} /** * An internal event that is emitted during the run of a pipeline. 
* @param id A globally unique id * @param timestamp The time of the event * @param origin Where the event originated from + * @param level Security level of the event * @param message A user friendly description of the event * @param details The details of the event * @param error An error that occurred during the event @@ -31,6 +32,7 @@ case class PipelineEvent( id: String, timestamp: String, origin: PipelineEventOrigin, + level: EventLevel, message: String, details: EventDetails, error: Option[ErrorDetail] @@ -65,11 +67,14 @@ case class SourceCodeLocation( ) // Additional details about the PipelineEvent -trait EventDetails +sealed trait EventDetails // An event indicating that a flow has made progress and transitioned to a different state case class FlowProgress(status: FlowStatus) extends EventDetails +// An event indicating that a run has made progress and transitioned to a different state +case class RunProgress(state: RunState) extends EventDetails + // Additional details about the error that occurred during the event case class ErrorDetail(exceptions: Seq[SerializedException]) @@ -78,3 +83,11 @@ case class SerializedException(className: String, message: String, stack: Seq[St // A stack frame of an exception case class StackFrame(declaringClass: String, methodName: String) + +// The severity level of the event. +sealed trait EventLevel +object EventLevel { + case object INFO extends EventLevel + case object WARN extends EventLevel + case object ERROR extends EventLevel +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineRunEventBuffer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineRunEventBuffer.scala new file mode 100644 index 0000000000000..1ef2a561a9913 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/PipelineRunEventBuffer.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.logging + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging + +/** + * An in-memory buffer which contains the internal events that are emitted during a run of a + * pipeline. + * + * @param eventCallback A callback function to be called when an event is added to the buffer. + */ +class PipelineRunEventBuffer(eventCallback: PipelineEvent => Unit) extends Logging { + + /** + * A buffer to hold the events emitted during a pipeline run. + * This buffer is thread-safe and can be accessed concurrently. + * + * TODO(SPARK-52409): Deprecate this class to be used in test only and use a more + * robust event logging system in production. 
+ */ + private val events = ArrayBuffer[PipelineEvent]() + + def addEvent(event: PipelineEvent): Unit = synchronized { + val eventToAdd = event + events.append(eventToAdd) + eventCallback(event) + } + + def clear(): Unit = synchronized { + events.clear() + } + + def getEvents: Seq[PipelineEvent] = events.toSeq +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/StreamListener.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/StreamListener.scala new file mode 100644 index 0000000000000..fdf7c821714f1 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/logging/StreamListener.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.pipelines.logging + +import java.util.UUID + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.pipelines.graph.{DataflowGraph, GraphIdentifierManager, PipelineUpdateContext, ResolvedFlow} +import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener} + +/** + * A streaming listener that converts streaming events into pipeline events for the relevant flows. 
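A minimal sketch (not part of the patch) of the PipelineRunEventBuffer contract defined above: every event appended via addEvent is also handed to the callback, so a caller such as PipelineUpdateContextImpl can stream events out while the buffer keeps the in-memory copy, and the new level field allows simple filtering afterwards.

val seen = scala.collection.mutable.ArrayBuffer[String]()
val buffer = new PipelineRunEventBuffer(eventCallback = event => seen += event.message)
// ... events are appended during the run via buffer.addEvent(...) ...
// buffer.getEvents.map(_.message) == seen.toSeq
// buffer.getEvents.filter(_.level == EventLevel.ERROR)  // e.g. keep only the failures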
+ */ +class StreamListener( + env: PipelineUpdateContext, + graphForExecution: DataflowGraph +) extends StreamingQueryListener + with Logging { + + private val queries = new java.util.concurrent.ConcurrentHashMap[UUID, StreamingQuery]() + + private def spark = SparkSession.active + + override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { + val stream = spark.streams.get(event.id) + queries.put(event.runId, stream) + env.flowProgressEventLogger.recordRunning(getFlowFromStreamName(stream.name)) + } + + override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {} + + override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { + // if the non-pipelines managed stream is started before flow execution started, + // onQueryStarted would not have captured the stream and it will not be in the queries map + if (!queries.containsKey(event.runId)) return + + val stream = queries.remove(event.runId) + env.flowProgressEventLogger.recordCompletion(getFlowFromStreamName(stream.name)) + } + + private def getFlowFromStreamName(streamName: String): ResolvedFlow = { + val flowIdentifier = GraphIdentifierManager.parseTableIdentifier(streamName, env.spark) + graphForExecution.resolvedFlow(flowIdentifier) + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/BackoffStrategy.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/BackoffStrategy.scala new file mode 100644 index 0000000000000..29712f7839dd2 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/BackoffStrategy.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import scala.concurrent.duration._ +import scala.math.{log10, pow} + +/** + * A `BackoffStrategy` determines the backoff duration (how long we should wait) for + * retries after failures. + */ +trait BackoffStrategy { + + /** Returns the amount of time to wait after `numFailures` failures. */ + def waitDuration(numFailures: Int): FiniteDuration +} + +/** + * A `BackoffStrategy` where the back-off time grows exponentially for each + * successive retry. + * + * The back-off time after `n` failures is min(maxTime, (2 ** n) * stepSize). + * + * @param maxTime Maximum back-off time. + * @param stepSize Minimum step size to increment back-off. + */ +case class ExponentialBackoffStrategy(maxTime: FiniteDuration, stepSize: FiniteDuration) + extends BackoffStrategy { + + require( + stepSize >= 0.seconds, + s"Back-off step size must be non-negative. Given value: $stepSize" + ) + require( + maxTime >= 0.seconds, + s"Back-off max time must be non-negative. 
Given value: $stepSize" + ) + + override def waitDuration(numFailures: Int): FiniteDuration = { + require( + numFailures >= 0, + s"Number of failures must be non-negative. Given value: $numFailures." + ) + + if (stepSize <= 0.seconds) return 0.seconds + if (stepSize >= maxTime) return maxTime + + def log2(x: Double) = log10(x) / log10(2.0) + val willExceedMax = numFailures >= log2(maxTime.toNanos.toDouble / stepSize.toNanos) + 1 + if (!willExceedMax) pow(2, numFailures - 1).toLong * stepSize else maxTime + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SparkSessionUtils.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SparkSessionUtils.scala new file mode 100644 index 0000000000000..95cf3af285e06 --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/util/SparkSessionUtils.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.util + +import org.apache.spark.sql.SparkSession + +object SparkSessionUtils { + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + */ + def withSqlConf[T](spark: SparkSession, pairs: (String, String)*)(f: => T): T = { + val conf = spark.conf + val (keys, values) = pairs.unzip + val currentValues = keys.map(conf.getOption) + keys.lazyZip(values).foreach((k, v) => conf.set(k, v)) + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.set(key, value) + case (key, None) => conf.unset(key) + } + } + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala new file mode 100644 index 0000000000000..90d752aff05cc --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala @@ -0,0 +1,827 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
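A worked example (not part of the patch) of ExponentialBackoffStrategy.waitDuration as implemented above: the wait after n failures is (2 ** (n - 1)) * stepSize, capped at maxTime once that product would exceed it. The numbers below are illustrative, not the pipeline defaults.

import scala.concurrent.duration._

val backoff = ExponentialBackoffStrategy(maxTime = 60.seconds, stepSize = 5.seconds)
backoff.waitDuration(1) //  5 seconds
backoff.waitDuration(2) // 10 seconds
backoff.waitDuration(3) // 20 seconds
backoff.waitDuration(4) // 40 seconds
backoff.waitDuration(5) // 60 seconds (capped at maxTime)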
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkThrowable +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} +import org.apache.spark.sql.connector.expressions.Expressions +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.pipelines.graph.DatasetManager.TableMaterializationException +import org.apache.spark.sql.pipelines.utils.{BaseCoreExecutionTest, TestGraphRegistrationContext} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils.exceptionString + +/** + * Local integration tests for materialization of `Table`s in a `DataflowGraph` to make sure + * tables are written with the appropriate schemas. + */ +class MaterializeTablesSuite extends BaseCoreExecutionTest { + + import originalSpark.implicits._ + + test("basic") { + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerFlow( + "a", + "a", + query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2")) + ) + registerTable( + "a", + specifiedSchema = Option( + new StructType() + .add("x", IntegerType, nullable = false, "comment1") + .add("x2", IntegerType, nullable = true, "comment2") + ), + comment = Option("p-comment") + ) + }.resolveToDataflowGraph() + ) + + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a") + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val catalogTable = catalog.loadTable(identifier) + + assert( + catalogTable.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType() + .add("x", IntegerType, nullable = false, "comment1") + .add("x2", IntegerType, nullable = true, "comment2") + ) + ) + assert(catalogTable.properties().get(TableCatalog.PROP_COMMENT) == "p-comment") + + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerFlow( + "a", + "a", + query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2")) + ) + registerTable( + "a", + specifiedSchema = Option( + new StructType() + .add("x", IntegerType, nullable = false, "comment3") + .add("x2", IntegerType, nullable = true, "comment4") + ), + comment = Option("p-comment") + ) + }.resolveToDataflowGraph() + ) + val catalogTable2 = catalog.loadTable(identifier) + assert( + catalogTable2.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType() + .add("x", IntegerType, nullable = false, "comment3") + .add("x2", IntegerType, nullable = true, "comment4") + ) + ) + assert(catalogTable2.properties().get(TableCatalog.PROP_COMMENT) == "p-comment") + + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerFlow( + "a", + "a", + query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2")) + ) + registerTable( + "a", + specifiedSchema = Option( + new StructType() + .add("x", IntegerType, nullable = false) + .add("x2", IntegerType, nullable = true) + ), + comment = Option("p-comment") + ) + }.resolveToDataflowGraph() + ) + + val catalogTable3 = catalog.loadTable(identifier) + assert( + catalogTable3.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType() + .add("x", IntegerType, nullable = false, comment = null) + .add("x2", IntegerType, nullable = true, comment = null) + ) + ) + assert(catalogTable3.properties().get(TableCatalog.PROP_COMMENT) == "p-comment") + } + + test("multiple") { + materializeGraph( + new 
TestGraphRegistrationContext(spark) {
+        registerFlow(
+          "t1",
+          "t1",
+          query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))
+        )
+        registerFlow(
+          "t2",
+          "t2",
+          query = dfFlowFunc(Seq("a", "b").toDF("y"))
+        )
+        registerTable("t1")
+        registerTable("t2")
+      }.resolveToDataflowGraph()
+    )
+
+    val identifier1 = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t1")
+    val identifier2 = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t2")
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val catalogTable1 = catalog.loadTable(identifier1)
+    val catalogTable2 = catalog.loadTable(identifier2)
+
+    assert(
+      catalogTable1.columns() sameElements CatalogV2Util
+        .structTypeToV2Columns(new StructType().add("x", IntegerType))
+    )
+    assert(
+      catalogTable2.columns() sameElements CatalogV2Util
+        .structTypeToV2Columns(new StructType().add("y", StringType))
+    )
+  }
+
+  test("temporary views don't get materialized") {
+    materializeGraph(
+      new TestGraphRegistrationContext(spark) {
+        registerFlow(
+          "t2",
+          "t2",
+          query = dfFlowFunc(Seq("a", "b").toDF("y"))
+        )
+        registerTable("t2")
+        registerView(
+          "t1",
+          dfFlowFunc(Seq(1, 2, 3).toDF("x"))
+        )
+      }.resolveToDataflowGraph()
+    )
+
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    assert(
+      !catalog.tableExists(
+        Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t1")
+      )
+    )
+    assert(
+      catalog.tableExists(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t2"))
+    )
+  }
+
+  // TableManager performs different validations for batch tables vs streaming tables when
+  // materializing tables. Flows writing to batch tables can have schemas that are incompatible
+  // with the existing table, since the table is being overwritten completely. This test ensures
+  // that it is possible to do that.
+  test("batch flow reading from streaming table") {
+    class P1 extends TestGraphRegistrationContext(spark) {
+      registerTable(
+        "a",
+        query = Option(dfFlowFunc(spark.readStream.format("rate").load()))
+      )
+      // Defines a column called timestamp as `bigint`.
+      registerTable(
+        "b",
+        query = Option(sqlFlowFunc(spark, "SELECT value AS timestamp FROM a"))
+      )
+    }
+    materializeGraph(new P1().resolveToDataflowGraph())
+
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val b =
+      catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b"))
+    assert(
+      b.columns() sameElements CatalogV2Util
+        .structTypeToV2Columns(new StructType().add("timestamp", LongType))
+    )
+
+    class P2 extends TestGraphRegistrationContext(spark) {
+      registerTable(
+        "a",
+        query = Option(dfFlowFunc(spark.readStream.format("rate").load()))
+      )
+      // Defines a column called timestamp as `timestamp`.
+ registerTable( + "b", + query = Option(sqlFlowFunc(spark, "SELECT timestamp FROM a")) + ) + } + materializeGraph(new P2().resolveToDataflowGraph()) + val b2 = + catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b")) + assert( + b2.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("timestamp", TimestampType)) + ) + } + + test("schema matches existing table schema") { + sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t2(x INT)") + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t2") + val table = catalog.loadTable(identifier) + assert( + table.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType().add("x", IntegerType) + ) + ) + + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerFlow("t2", "t2", query = dfFlowFunc(Seq(1, 2, 3).toDF("x"))) + registerTable("t2") + }.resolveToDataflowGraph() + ) + + val table2 = catalog.loadTable(identifier) + assert( + table2.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("x", IntegerType)) + ) + } + + test("invalid schema merge") { + val streamInts = MemoryStream[Int] + streamInts.addData(1, 2) + + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(streamInts.toDF())) + registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value AS x FROM STREAM a"))) + }.resolveToDataflowGraph() + ) + + val streamStrings = MemoryStream[String] + streamStrings.addData("a", "b") + val graph2 = new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(streamStrings.toDF())) + registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value AS x FROM STREAM a"))) + }.resolveToDataflowGraph() + + val ex = intercept[TableMaterializationException] { + materializeGraph(graph2) + } + val cause = ex.cause + val exStr = exceptionString(cause) + assert(exStr.contains("Failed to merge incompatible data types")) + } + + test("table materialized with specified schema, even if different from inferred") { + sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t4(x INT)") + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t4") + val table = catalog.loadTable(identifier) + assert( + table.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType().add("x", IntegerType) + ) + ) + + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerFlow("t4", "t4", query = dfFlowFunc(Seq[Short](1, 2).toDF("x"))) + registerTable( + "t4", + specifiedSchema = Option( + new StructType() + .add("x", IntegerType, nullable = true, "this is column x") + .add("z", LongType, nullable = true, "this is column z") + ) + ) + }.resolveToDataflowGraph() + ) + + val table2 = catalog.loadTable(identifier) + assert( + table2.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType() + .add("x", IntegerType, nullable = true, "this is column x") + .add("z", LongType, nullable = true, "this is column z") + ) + ) + } + + test("specified schema incompatible with existing table") { + sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t6(x BOOLEAN)") + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] 
+ val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t6") + val table = catalog.loadTable(identifier) + assert( + table.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType().add("x", BooleanType) + ) + ) + + val ex = intercept[TableMaterializationException] { + materializeGraph(new TestGraphRegistrationContext(spark) { + val source: MemoryStream[Int] = MemoryStream[Int] + source.addData(1, 2) + registerTable( + "t6", + specifiedSchema = Option(new StructType().add("x", IntegerType)), + query = Option(dfFlowFunc(source.toDF().select($"value" as "x"))) + ) + + }.resolveToDataflowGraph()) + } + val cause = ex.cause + val exStr = exceptionString(cause) + assert(exStr.contains("Failed to merge incompatible data types")) + + // Works fine for a complete table + materializeGraph(new TestGraphRegistrationContext(spark) { + registerTable( + "t6", + specifiedSchema = Option(new StructType().add("x", IntegerType)), + query = Option(dfFlowFunc(Seq(1, 2).toDF("x"))) + ) + }.resolveToDataflowGraph()) + val table2 = catalog.loadTable(identifier) + assert( + table2.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("x", IntegerType)) + ) + } + + test("partition columns with user schema") { + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerTable( + "a", + query = Option(dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x1", "x2"))), + specifiedSchema = Option( + new StructType() + .add("x1", IntegerType) + .add("x2", IntegerType) + ), + partitionCols = Option(Seq("x2")) + ) + }.resolveToDataflowGraph() + ) + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a") + val table = catalog.loadTable(identifier) + assert( + table.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType().add("x1", IntegerType).add("x2", IntegerType) + ) + ) + assert(table.partitioning().toSeq == Seq(Expressions.identity("x2"))) + } + + test("specifying partition column with existing partitioned table") { + sql( + s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t7(x BOOLEAN, y INT) " + + s"PARTITIONED BY (x)" + ) + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t7") + val table = catalog.loadTable(identifier) + assert( + table.columns().map(_.name()).toSet == new StructType() + .add("x", BooleanType) + .add("y", IntegerType) + .fieldNames + .toSet + ) + assert(table.partitioning().toSeq == Seq(Expressions.identity("x"))) + + // Specify the same partition column. + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerFlow( + "t7", + "t7", + query = dfFlowFunc(Seq((true, 1), (false, 3)).toDF("x", "y")) + ) + registerTable( + "t7", + partitionCols = Option(Seq("x")) + ) + }.resolveToDataflowGraph() + ) + + val table2 = catalog.loadTable(identifier) + assert( + table2.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("y", IntegerType).add("x", BooleanType)) + ) + assert(table2.partitioning().toSeq == Seq(Expressions.identity("x"))) + + // Don't specify any partition column; should throw. 
+ val ex = intercept[TableMaterializationException] { + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerFlow( + "t7", + "t7", + query = dfFlowFunc(Seq((true, 1), (false, 3)).toDF("x", "y")) + ) + registerTable("t7") + }.resolveToDataflowGraph() + ) + } + assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == "CANNOT_UPDATE_PARTITION_COLUMNS") + + val table3 = catalog.loadTable(identifier) + assert( + table3.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("y", IntegerType).add("x", BooleanType)) + ) + assert(table3.partitioning().toSeq == Seq(Expressions.identity("x"))) + } + + test("specifying partition column different from existing partitioned table") { + sql( + s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t8(x BOOLEAN, y INT) " + + s"PARTITIONED BY (x)" + ) + Seq((true, 1), (false, 1)).toDF("x", "y").write.mode("append").saveAsTable("t8") + + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t8") + + // Specify a different partition column. Should throw. + val graph = new TestGraphRegistrationContext(spark) { + registerFlow( + "t8", + "t8", + query = dfFlowFunc(Seq((true, 1), (false, 3)).toDF("x", "y")) + ) + registerTable("t8", partitionCols = Option(Seq("y"))) + }.resolveToDataflowGraph() + + val ex = intercept[TableMaterializationException] { + materializeGraph(graph) + } + assert(ex.cause.asInstanceOf[SparkThrowable].getCondition == "CANNOT_UPDATE_PARTITION_COLUMNS") + val table = catalog.loadTable(identifier) + assert(table.partitioning().toSeq == Seq(Expressions.identity("x"))) + } + + test("Table properties are set when table gets materialized") { + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerTable( + "a", + query = Option(dfFlowFunc(spark.readStream.format("rate").load())), + properties = Map( + "pipelines.reset.allowed" -> "true", + "some.prop" -> "foo" + ) + ) + registerTable( + "b", + query = Option(sqlFlowFunc(spark, "SELECT * FROM STREAM a")), + properties = Map("pipelines.reset.alloweD" -> "true", "some.prop" -> "foo") + ) + }.resolveToDataflowGraph() + ) + + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifierA = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a") + val identifierB = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b") + val tableA = catalog.loadTable(identifierA) + val tableB = catalog.loadTable(identifierB) + + val expectedProps = Map( + "pipelines.reset.allowed" -> "true", + "some.prop" -> "foo" + ) + + assert(expectedProps.forall { case (k, v) => tableA.properties().asScala.get(k).contains(v) }) + assert(expectedProps.forall { case (k, v) => tableB.properties().asScala.get(k).contains(v) }) + } + + test("Invalid table properties error during table materialization") { + // Invalid pipelines property + val graph1 = + new TestGraphRegistrationContext(spark) { + registerTable( + "a", + query = Option(dfFlowFunc(Seq(1).toDF())), + properties = Map("pipelines.reset.allowed" -> "123") + ) + }.resolveToDataflowGraph() + val ex1 = + intercept[TableMaterializationException] { + materializeGraph(graph1) + } + + assert(ex1.cause.isInstanceOf[IllegalArgumentException]) + assert(ex1.cause.getMessage.contains("pipelines.reset.allowed")) + } + + test( + "Materialization succeeds even if there are unknown pipeline properties on the 
existing table" + ) { + sql( + s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t9(x INT) " + + s"TBLPROPERTIES ('pipelines.someProperty' = 'foo')" + ) + + val graph1 = new TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(spark.readStream.format("rate").load()))) + }.resolveToDataflowGraph().validate() + + materializeGraph(graph1) + } + + for (isFullRefresh <- Seq(true, false)) { + test( + s"Complete tables should not evolve schema - isFullRefresh = $isFullRefresh" + ) { + val rawGraph = + new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq((1, 2), (2, 3)).toDF("x", "y"))) + registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT x FROM a"))) + }.resolveToDataflowGraph() + + val graph = materializeGraph(rawGraph) + val (refreshSelection, fullRefreshSelection) = if (isFullRefresh) { + (NoTables, AllTables) + } else { + (AllTables, NoTables) + } + + materializeGraph( + rawGraph, + contextOpt = Option( + TestPipelineUpdateContext( + spark = spark, + unresolvedGraph = graph, + refreshTables = refreshSelection, + fullRefreshTables = fullRefreshSelection + ) + ) + ) + + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b") + + val table = catalog.loadTable(identifier) + assert( + table.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("x", IntegerType)) + ) + + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(Seq((1, 2), (2, 3)).toDF("x", "y"))) + registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT y FROM a"))) + }.resolveToDataflowGraph() + ) + val table2 = catalog.loadTable(identifier) + assert( + table2.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("y", IntegerType)) + ) + } + } + + for (isFullRefresh <- Seq(true, false)) { + test( + s"Streaming tables should evolve schema only if not full refresh = $isFullRefresh" + ) { + val streamInts = MemoryStream[Int] + streamInts.addData(1 until 5: _*) + + val graph = + new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(streamInts.toDF())) + registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value AS x FROM STREAM a"))) + }.resolveToDataflowGraph().validate() + + val (refreshSelection, fullRefreshSelection) = if (isFullRefresh) { + (NoTables, AllTables) + } else { + (AllTables, NoTables) + } + val updateContextOpt = Option( + TestPipelineUpdateContext( + spark = spark, + unresolvedGraph = graph, + refreshTables = refreshSelection, + fullRefreshTables = fullRefreshSelection + ) + ) + materializeGraph(graph, contextOpt = updateContextOpt) + + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b") + val table = catalog.loadTable(identifier) + assert( + table.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("x", IntegerType)) + ) + + materializeGraph( + new TestGraphRegistrationContext(spark) { + registerView("a", query = dfFlowFunc(streamInts.toDF())) + registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT value AS y FROM STREAM a"))) + }.resolveToDataflowGraph().validate(), + contextOpt = updateContextOpt + ) + + val table2 = catalog.loadTable(identifier) + + if (isFullRefresh) { + assert( + table2.columns() 
sameElements CatalogV2Util.structTypeToV2Columns( + new StructType().add("y", IntegerType) + ) + ) + } else { + assert( + table2.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType() + .add("x", IntegerType) + .add("y", IntegerType) + ) + ) + } + } + } + + test( + "materialize only selected tables" + ) { + val graph = new TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq((1, 2), (2, 3)).toDF("x", "y")))) + registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT x FROM a"))) + registerTable("c", query = Option(sqlFlowFunc(spark, "SELECT y FROM a"))) + }.resolveToDataflowGraph() + materializeGraph( + graph, + contextOpt = Option( + TestPipelineUpdateContext( + spark = spark, + unresolvedGraph = graph, + refreshTables = SomeTables(Set(fullyQualifiedIdentifier("a"))), + fullRefreshTables = SomeTables(Set(fullyQualifiedIdentifier("c"))) + ) + ) + ) + + val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] + + val tableA = + catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a")) + assert( + !catalog.tableExists(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b")) + ) + val tableC = + catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "c")) + + assert( + tableA.columns() sameElements CatalogV2Util.structTypeToV2Columns( + new StructType() + .add("x", IntegerType) + .add("y", IntegerType) + ) + ) + + assert( + tableC.columns() sameElements CatalogV2Util + .structTypeToV2Columns(new StructType().add("y", IntegerType)) + ) + } + + test("tables with arrays and maps") { + val rawGraph = + new TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(sqlFlowFunc(spark, "select map(1, struct('a', 'b')) m"))) + registerTable( + "b", + query = Option(dfFlowFunc(Seq(Array(1, 3, 5), Array(2, 4, 6)).toDF("arr"))) + ) + registerTable( + "c", + query = Option( + sqlFlowFunc(spark, "select * from a join b where map_entries(m)[0].key = arr[0]") + ) + ) + }.resolveToDataflowGraph() + materializeGraph(rawGraph) + // Materialize twice because some logic compares the incoming schema with the previous one. 
+    materializeGraph(rawGraph)
+
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val tableA =
+      catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a"))
+    val tableB =
+      catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b"))
+    val tableC =
+      catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "c"))
+
+    assert(
+      tableA.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        StructType.fromDDL("m MAP<INT, STRUCT<col1: STRING, col2: STRING>>")
+      )
+    )
+    assert(
+      tableB.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        StructType.fromDDL("arr ARRAY<INT>")
+      )
+    )
+    assert(
+      tableC.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        StructType.fromDDL("m MAP<INT, STRUCT<col1: STRING, col2: STRING>>, arr ARRAY<INT>")
+      )
+    )
+  }
+
+  test("tables with nested arrays and maps") {
+    val rawGraph =
+      new TestGraphRegistrationContext(spark) {
+        registerTable(
+          "a",
+          query = Option(sqlFlowFunc(spark, "select map(0, map(0, struct('a', 'b'))) m"))
+        )
+        registerTable(
+          "b",
+          query = Option(
+            sqlFlowFunc(spark, "select array(array('a', 'b', 'c'), array('d', 'e', 'f')) arr")
+          )
+        )
+        registerTable(
+          "c",
+          query =
+            Option(sqlFlowFunc(spark, "select * from a join b where m[0][0].col1 = arr[0][0]"))
+        )
+
+      }.resolveToDataflowGraph()
+    materializeGraph(rawGraph)
+    // Materialize twice because some logic compares the incoming schema with the previous one.
+    materializeGraph(rawGraph)
+    val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+    val tableA =
+      catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "a"))
+    val tableB =
+      catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "b"))
+    val tableC =
+      catalog.loadTable(Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "c"))
+
+    assert(
+      tableA.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        StructType.fromDDL("m MAP<INT, MAP<INT, STRUCT<col1: STRING, col2: STRING>>>")
+      )
+    )
+    assert(
+      tableB.columns() sameElements CatalogV2Util
+        .structTypeToV2Columns(StructType.fromDDL("arr ARRAY<ARRAY<STRING>>"))
+    )
+    assert(
+      tableC.columns() sameElements CatalogV2Util.structTypeToV2Columns(
+        StructType.fromDDL(
+          "m MAP<INT, MAP<INT, STRUCT<col1: STRING, col2: STRING>>>, arr ARRAY<ARRAY<STRING>>"
+        )
+      )
+    )
+  }
+
+  test("materializing no tables doesn't throw") {
+    val graph1 =
+      new DataflowGraph(flows = Seq.empty, tables = Seq.empty, views = Seq.empty)
+    val graph2 = new TestGraphRegistrationContext(spark) {
+      registerFlow(
+        "a",
+        "a",
+        query = dfFlowFunc(Seq((1, 1), (2, 3)).toDF("x", "x2"))
+      )
+      registerTable("a")
+    }.resolveToDataflowGraph()
+
+    materializeGraph(graph1)
+    materializeGraph(
+      graph2,
+      contextOpt = Option(
+        TestPipelineUpdateContext(
+          spark = spark,
+          unresolvedGraph = graph2,
+          refreshTables = NoTables,
+          fullRefreshTables = NoTables
+        )
+      )
+    )
+  }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
new file mode 100644
index 0000000000000..54759c41ace5d
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
@@ -0,0 +1,1019 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.timeout +import org.scalatest.time.{Seconds, Span} + +import org.apache.spark.sql.{functions, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.classic.{DataFrame, Dataset} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.pipelines.common.{FlowStatus, RunState} +import org.apache.spark.sql.pipelines.graph.TriggeredGraphExecution.StreamState +import org.apache.spark.sql.pipelines.logging.EventLevel +import org.apache.spark.sql.pipelines.utils.{ExecutionTest, TestGraphRegistrationContext} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class TriggeredGraphExecutionSuite extends ExecutionTest { + + import originalSpark.implicits._ + + /** Returns a Dataset of Longs from the table with the given identifier. */ + private def getTable(identifier: TableIdentifier): Dataset[Long] = { + spark.read.table(identifier.toString).as[Long] + } + + /** Return flows with expected stream state. 
*/ + private def getFlowsWithState( + graphExecution: TriggeredGraphExecution, + state: StreamState + ): Set[TableIdentifier] = { + graphExecution.pipelineState.collect { + case (flowIdentifier, flowState) if flowState == state => flowIdentifier + }.toSet + } + + test("basic graph resolution and execution") { + val pipelineDef = new TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2).toDF("x")))) + registerTable("b", query = Option(readFlowFunc("a"))) + } + val unresolvedGraph = pipelineDef.toDataflowGraph + val resolvedGraph = unresolvedGraph.resolve() + assert(resolvedGraph.flows.size == 2) + assert(unresolvedGraph.flows.size == 2) + assert(unresolvedGraph.tables.size == 2) + val bFlow = + resolvedGraph.resolvedFlows.filter(_.identifier == fullyQualifiedIdentifier("b")).head + assert(bFlow.inputs == Set(fullyQualifiedIdentifier("a"))) + + val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + // start with queued + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("a"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.PLANNING), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("b"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.PLANNING), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + } + + test("graph materialization with streams") { + val pipelineDef = new TestGraphRegistrationContext(spark) { + registerTable("a", query = Option(dfFlowFunc(Seq(1, 2).toDF("x")))) + registerTable("b", query = Option(readFlowFunc("a"))) + registerView("c", query = readStreamFlowFunc("a")) + registerTable("d", query = Option(readStreamFlowFunc("c"))) + } + + val unresolvedGraph = pipelineDef.toDataflowGraph + val resolvedGraph = unresolvedGraph.resolve() + assert(resolvedGraph.flows.size == 4) + assert(resolvedGraph.tables.size == 3) + assert(resolvedGraph.views.size == 1) + + val bFlow = + resolvedGraph.resolvedFlows.filter(_.identifier == fullyQualifiedIdentifier("b")).head + assert(bFlow.inputs == Set(fullyQualifiedIdentifier("a"))) + + val cFlow = + resolvedGraph.resolvedFlows + .filter(_.identifier == fullyQualifiedIdentifier("c", isView = true)) + .head + assert(cFlow.inputs == Set(fullyQualifiedIdentifier("a"))) + + val dFlow = + resolvedGraph.resolvedFlows.filter(_.identifier == fullyQualifiedIdentifier("d")).head + assert(dFlow.inputs == Set(fullyQualifiedIdentifier("c", isView = true))) + + val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("a"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.PLANNING), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + assertFlowProgressStatusInOrder( + 
eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("b"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.PLANNING), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("d"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + + // no flow progress event for c, as it is a temporary view + assertNoFlowProgressEvent( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("c", isView = true), + flowStatus = FlowStatus.STARTING + ) + checkAnswer( + spark.read.table(fullyQualifiedIdentifier("a").toString), + Seq(Row(1), Row(2)) + ) + } + + test("three hop pipeline") { + // Construct pipeline + val pipelineDef = new TestGraphRegistrationContext(spark) { + private val ints = MemoryStream[Int] + ints.addData(1 until 10: _*) + registerView("input", query = dfFlowFunc(ints.toDF())) + registerTable( + "eights", + query = Option(sqlFlowFunc(spark, "SELECT value * 2 as value FROM STREAM fours")) + ) + registerTable( + "fours", + query = Option(sqlFlowFunc(spark, "SELECT value * 2 as value FROM STREAM evens")) + ) + registerTable( + "evens", + query = Option(sqlFlowFunc(spark, "SELECT * FROM STREAM input WHERE value % 2 = 0")) + ) + } + val graph = pipelineDef.toDataflowGraph + + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("evens"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("eights"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("fours"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + + val graphExecution = updateContext.pipelineExecution.graphExecution.get + assert( + getFlowsWithState(graphExecution, StreamState.SUCCESSFUL) == Set( + fullyQualifiedIdentifier("eights"), + fullyQualifiedIdentifier("fours"), + fullyQualifiedIdentifier("evens") + ) + ) + assert(getFlowsWithState(graphExecution, StreamState.TERMINATED_WITH_ERROR).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.SKIPPED).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.CANCELED).isEmpty) + + checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("evens")), 2L, 4L, 6L, 8L) + checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("fours")), 4L, 8L, 12L, 16L) + 
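    // For reference, the expected rows in these dataset checks follow directly from the
    // three-hop chain defined above:
    //   input  = 1..9
    //   evens  = {2, 4, 6, 8}      (input filtered to value % 2 = 0)
    //   fours  = {4, 8, 12, 16}    (evens * 2)
    //   eights = {8, 16, 24, 32}   (fours * 2)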
checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("eights")), 8L, 16L, 24L, 32L) + } + + test("all events are emitted even if there is no data") { + // Construct pipeline + val pipelineDef = new TestGraphRegistrationContext(spark) { + private val ints = MemoryStream[Int] + registerView("input", query = dfFlowFunc(ints.toDF())) + registerTable( + "evens", + query = Option(sqlFlowFunc(spark, "SELECT * FROM STREAM input WHERE value % 2 = 0")) + ) + } + val graph = pipelineDef.toDataflowGraph + + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val graphExecution = updateContext.pipelineExecution.graphExecution.get + assert( + getFlowsWithState(graphExecution, StreamState.SUCCESSFUL) == Set( + fullyQualifiedIdentifier("evens") + ) + ) + assert(getFlowsWithState(graphExecution, StreamState.TERMINATED_WITH_ERROR).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.SKIPPED).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.CANCELED).isEmpty) + + assertFlowProgressStatusInOrder( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("evens"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + + checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("evens"))) + } + + test("stream failure causes its downstream to be skipped") { + spark.sql("CREATE TABLE src USING PARQUET AS SELECT * FROM RANGE(10)") + + // A UDF which fails immediately + val failUDF = functions.udf((_: String) => { + throw new RuntimeException("Test error") + true + }) + + val pipelineDef = new TestGraphRegistrationContext(spark) { + private val memoryStream = MemoryStream[Int] + memoryStream.addData(1, 2) + registerView("input_view", query = dfFlowFunc(memoryStream.toDF())) + registerTable( + "input_table", + query = Option(readStreamFlowFunc("input_view")) + ) + registerTable( + "branch_1", + query = Option(readStreamFlowFunc("input_table")) + ) + registerTable( + "branch_2", + query = Option(dfFlowFunc(spark.readStream.table("src").filter(failUDF($"id")))) + ) + registerTable("x", query = Option(readStreamFlowFunc("branch_2"))) + registerView("y", query = readStreamFlowFunc("x")) + registerTable("z", query = Option(readStreamFlowFunc("x"))) + } + val graph = pipelineDef.toDataflowGraph + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val graphExecution = updateContext.pipelineExecution.graphExecution.get + assert( + getFlowsWithState(graphExecution, StreamState.SUCCESSFUL) == Set( + fullyQualifiedIdentifier("input_table"), + fullyQualifiedIdentifier("branch_1") + ) + ) + assert( + getFlowsWithState(graphExecution, StreamState.TERMINATED_WITH_ERROR) == Set( + fullyQualifiedIdentifier("branch_2") + ) + ) + assert( + getFlowsWithState(graphExecution, StreamState.SKIPPED) == Set( + fullyQualifiedIdentifier("x"), + fullyQualifiedIdentifier("z") + ) + ) + assert(getFlowsWithState(graphExecution, StreamState.CANCELED).isEmpty) + + // `input_table` and `branch_1` should succeed, while `branch_2` should fail. 
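    // The dependency chain here is src -> branch_2 -> x -> z (plus the temporary view y
    // reading from x), while input_view -> input_table -> branch_1 is an independent branch.
    // Because `branch_2` terminates with an error, everything downstream of it is never
    // started and is reported as SKIPPED, whereas the independent branch runs to completion.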
+    assertFlowProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      identifier = fullyQualifiedIdentifier("input_table"),
+      expectedFlowStatus = FlowStatus.COMPLETED,
+      expectedEventLevel = EventLevel.INFO
+    )
+    assertFlowProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      identifier = fullyQualifiedIdentifier("branch_1"),
+      expectedFlowStatus = FlowStatus.COMPLETED,
+      expectedEventLevel = EventLevel.INFO
+    )
+    assertFlowProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      identifier = fullyQualifiedIdentifier("branch_2"),
+      expectedFlowStatus = FlowStatus.FAILED,
+      expectedEventLevel = EventLevel.ERROR,
+      errorChecker = { ex =>
+        ex.exceptions.exists { ex =>
+          ex.message.contains("Test error")
+        }
+      }
+    )
+    // all the downstream flows of `branch_2` should be skipped.
+    assertFlowProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      identifier = fullyQualifiedIdentifier("x"),
+      expectedFlowStatus = FlowStatus.SKIPPED,
+      expectedEventLevel = EventLevel.INFO
+    )
+    assertFlowProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      identifier = fullyQualifiedIdentifier("z"),
+      expectedFlowStatus = FlowStatus.SKIPPED,
+      expectedEventLevel = EventLevel.INFO
+    )
+
+    // Since flows `x` and `z` are skipped, we should not see any progress events for them.
+    assertNoFlowProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      identifier = fullyQualifiedIdentifier("x"),
+      flowStatus = FlowStatus.STARTING
+    )
+    assertNoFlowProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      identifier = fullyQualifiedIdentifier("z"),
+      flowStatus = FlowStatus.STARTING
+    )
+
+    // `input_table` and `branch_1` should have valid data as their upstream has no failures.
+    checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("input_table")), 1L, 2L)
+    checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("branch_1")), 1L, 2L)
+
+    // the pipeline run should be marked as failed since `branch_2` has failed.
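    // The "more than 2 times" wording in the assertion below reflects the default value (2) of
    // `spark.sql.pipelines.maxFlowRetryAttempts`, i.e. one original attempt plus two retries
    // before the run is failed. As a rough sketch (assuming the conf can simply be set on the
    // session), a test could tighten this with:
    //   spark.conf.set("spark.sql.pipelines.maxFlowRetryAttempts", "0")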
+    assertRunProgressEvent(
+      eventBuffer = updateContext.eventBuffer,
+      state = RunState.FAILED,
+      expectedEventLevel = EventLevel.ERROR,
+      msgChecker = msg =>
+        msg.contains(
+          "Run is FAILED since flow 'spark_catalog.test_db.branch_2' has failed more than " +
+            "2 times"
+        )
+    )
+  }
+
+  test("stream failure on deletes and updates gives clear error") {
+    spark.sql("CREATE TABLE src1 USING PARQUET AS SELECT * FROM RANGE(10)")
+    spark.sql("CREATE TABLE src2 USING PARQUET AS SELECT * FROM RANGE(10)")
+
+    // First run: the flow reads the union of both sources and completes successfully.
+    val pipelineDef1 = new TestGraphRegistrationContext(spark) {
+      registerView(
+        "input_view",
+        query = dfFlowFunc(
+          spark.readStream
+            .table("src1")
+            .unionAll(spark.readStream.table("src2"))
+        )
+      )
+      registerTable(
+        "input_table",
+        query = Option(readStreamFlowFunc("input_view"))
+      )
+    }
+
+    val graph1 = pipelineDef1.toDataflowGraph
+    val updateContext1 = TestPipelineUpdateContext(spark, graph1)
+    updateContext1.pipelineExecution.runPipeline()
+    updateContext1.pipelineExecution.awaitCompletion()
+    assertFlowProgressEvent(
+      eventBuffer = updateContext1.eventBuffer,
+      identifier = fullyQualifiedIdentifier("input_table"),
+      expectedFlowStatus = FlowStatus.COMPLETED,
+      expectedEventLevel = EventLevel.INFO
+    )
+
+    val pipelineDef2 = new TestGraphRegistrationContext(spark) {
+      registerView(
+        "input_view",
+        query = dfFlowFunc(spark.readStream.table("src2"))
+      )
+      registerTable(
+        "input_table",
+        query = Option(readStreamFlowFunc("input_view"))
+      )
+    }
+    val graph2 = pipelineDef2.toDataflowGraph
+    val updateContext2 = TestPipelineUpdateContext(spark, graph2)
+    updateContext2.pipelineExecution.runPipeline()
+    updateContext2.pipelineExecution.awaitCompletion()
+
+    assertFlowProgressEvent(
+      eventBuffer = updateContext2.eventBuffer,
+      identifier = fullyQualifiedIdentifier("input_table"),
+      expectedFlowStatus = FlowStatus.FAILED,
+      expectedEventLevel = EventLevel.ERROR,
+      msgChecker = _.contains(
+        s"Flow '${eventLogName("input_table")}' had streaming sources added or removed."
+ ) + ) + } + + test("user-specified schema is applied to table") { + val specifiedSchema = new StructType().add("x", "int", nullable = true) + val pipelineDef = new TestGraphRegistrationContext(spark) { + registerTable( + "specified_schema", + query = Option(dfFlowFunc(Seq(1, 2).toDF("x"))), + specifiedSchema = Option(specifiedSchema) + ) + + registerTable( + "specified_schema_stream", + query = Option(dfFlowFunc(Seq(1, 2).toDF("x"))), + specifiedSchema = Option(specifiedSchema) + ) + + registerTable( + "specified_schema_downstream", + query = Option(readStreamFlowFunc("specified_schema")), + specifiedSchema = Option(specifiedSchema) + ) + + registerTable( + "specified_schema_downbatch", + query = Option(readFlowFunc("specified_schema_stream")), + specifiedSchema = Option(specifiedSchema) + ) + } + val ctx = TestPipelineUpdateContext( + spark, + pipelineDef.toDataflowGraph, + fullRefreshTables = AllTables, + resetCheckpointFlows = AllFlows + ) + ctx.pipelineExecution.runPipeline() + ctx.pipelineExecution.awaitCompletion() + + Seq( + fullyQualifiedIdentifier("specified_schema"), + fullyQualifiedIdentifier("specified_schema_stream"), + fullyQualifiedIdentifier("specified_schema_downstream"), + fullyQualifiedIdentifier("specified_schema_downbatch") + ).foreach { tableIdentifier => + val catalogId = + Identifier.of(Array(tableIdentifier.database.get), tableIdentifier.identifier) + assert( + spark.sessionState.catalogManager.currentCatalog + .asInstanceOf[TableCatalog] + .loadTable(catalogId) + .columns() sameElements CatalogV2Util.structTypeToV2Columns(specifiedSchema), + s"Table $catalogId's schema in the catalog does not match the specified schema" + ) + assert( + spark.table(tableIdentifier).schema == specifiedSchema, + s"Table $tableIdentifier's schema in storage does not match the specified schema" + ) + } + } + + test("stopping a pipeline mid-execution") { + // A UDF which adds a delay + val delayUDF = functions.udf((_: String) => { + Thread.sleep(5 * 1000) + true + }) + spark.sql("CREATE TABLE src1 USING PARQUET AS SELECT * FROM RANGE(10)") + + // Construct pipeline + val pipelineDef = new TestGraphRegistrationContext(spark) { + private val memoryStream = MemoryStream[Int] + memoryStream.addData(1, 2) + registerView("input_view", query = dfFlowFunc(memoryStream.toDF())) + registerTable( + "input_table", + query = Option(readStreamFlowFunc("input_view")) + ) + registerTable( + "query_with_delay", + query = Option(dfFlowFunc(spark.readStream.table("src1").filter(delayUDF($"id")))) + ) + registerTable("x", query = Option(readStreamFlowFunc("query_with_delay"))) + } + + val graph = pipelineDef.toDataflowGraph + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + + val graphExecution = updateContext.pipelineExecution.graphExecution.get + + eventually(timeout(Span(60, Seconds))) { + assert( + graphExecution.pipelineState( + fullyQualifiedIdentifier("input_table") + ) == StreamState.SUCCESSFUL + ) + assert( + graphExecution.pipelineState( + fullyQualifiedIdentifier("query_with_delay") + ) == StreamState.RUNNING + ) + graphExecution.stop() + } + assert( + getFlowsWithState(graphExecution, StreamState.SUCCESSFUL) == Set( + fullyQualifiedIdentifier("input_table") + ) + ) + assert(getFlowsWithState(graphExecution, StreamState.TERMINATED_WITH_ERROR).isEmpty) + assert( + getFlowsWithState(graphExecution, StreamState.SKIPPED) == Set(fullyQualifiedIdentifier("x")) + ) + assert( + getFlowsWithState(graphExecution, StreamState.CANCELED) == Set( 
+ fullyQualifiedIdentifier("query_with_delay") + ) + ) + assert( + latestFlowStatuses(updateContext.eventBuffer) == Map( + eventLogName("input_table") -> FlowStatus.COMPLETED, + eventLogName("query_with_delay") -> FlowStatus.STOPPED, + eventLogName("x") -> FlowStatus.SKIPPED + ) + ) + + checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("input_table")), 1L, 2L) + } + + test("two hop pipeline with partitioned graph") { + val pipelineDef = new TestGraphRegistrationContext(spark) { + registerTable("integer_input", query = Option(dfFlowFunc(Seq(1, 2, 3, 4).toDF("value")))) + registerTable( + "double", + query = Option(sqlFlowFunc(spark, "SELECT value * 2 as value FROM integer_input")) + ) + registerTable( + "string_input", + query = Option(dfFlowFunc(Seq("a", "b", "c", "d").toDF("value"))) + ) + registerTable( + "append_x", + query = Option( + sqlFlowFunc(spark, "SELECT CONCAT(value, 'x') as value FROM string_input") + ) + ) + } + + val graph = pipelineDef.toDataflowGraph + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val graphExecution = updateContext.pipelineExecution.graphExecution.get + + assert( + getFlowsWithState(graphExecution, StreamState.SUCCESSFUL) == Set( + fullyQualifiedIdentifier("integer_input"), + fullyQualifiedIdentifier("string_input"), + fullyQualifiedIdentifier("double"), + fullyQualifiedIdentifier("append_x") + ) + ) + assert(getFlowsWithState(graphExecution, StreamState.TERMINATED_WITH_ERROR).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.SKIPPED).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.CANCELED).isEmpty) + + assert( + latestFlowStatuses(updateContext.eventBuffer) == Map( + eventLogName("integer_input") -> FlowStatus.COMPLETED, + eventLogName("string_input") -> FlowStatus.COMPLETED, + eventLogName("double") -> FlowStatus.COMPLETED, + eventLogName("append_x") -> FlowStatus.COMPLETED + ) + ) + + checkDatasetUnorderly(getTable(fullyQualifiedIdentifier("double")), 2L, 4L, 6L, 8L) + checkAnswer( + spark.read.table(fullyQualifiedIdentifier("append_x").toString), + Seq(Row("ax"), Row("bx"), Row("cx"), Row("dx")) + ) + } + + test("multiple hop pipeline with merge from multiple sources") { + val pipelineDef = new TestGraphRegistrationContext(spark) { + registerTable("integer_input", query = Option(dfFlowFunc(Seq(1, 2, 3, 4).toDF("nums")))) + registerTable( + "double", + query = Option(sqlFlowFunc(spark, "SELECT nums * 2 as nums FROM integer_input")) + ) + registerTable( + "string_input", + query = Option(dfFlowFunc(Seq("a", "b", "c", "d").toDF("text"))) + ) + registerTable( + "append_x", + query = Option( + sqlFlowFunc(spark, "SELECT CONCAT(text, 'x') as text FROM string_input") + ) + ) + registerTable( + "merged", + query = Option( + sqlFlowFunc( + spark, + "SELECT * FROM double FULL OUTER JOIN append_x ON nums::STRING = text" + ) + ) + ) + } + + val graph = pipelineDef.toDataflowGraph + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val graphExecution = updateContext.pipelineExecution.graphExecution.get + + assert( + getFlowsWithState(graphExecution, StreamState.SUCCESSFUL) == Set( + fullyQualifiedIdentifier("integer_input"), + fullyQualifiedIdentifier("string_input"), + fullyQualifiedIdentifier("double"), + fullyQualifiedIdentifier("append_x"), + fullyQualifiedIdentifier("merged") + ) + ) + + 
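    // Note on the expected result further down: the join condition `nums::STRING = text`
    // never matches (e.g. "2" vs "ax"), so the FULL OUTER JOIN emits every row of `double`
    // with a NULL `text` and every row of `append_x` with a NULL `nums`, eight rows in total.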
assert(getFlowsWithState(graphExecution, StreamState.TERMINATED_WITH_ERROR).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.SKIPPED).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.CANCELED).isEmpty) + + assert( + latestFlowStatuses(updateContext.eventBuffer) == Map( + eventLogName("integer_input") -> FlowStatus.COMPLETED, + eventLogName("string_input") -> FlowStatus.COMPLETED, + eventLogName("double") -> FlowStatus.COMPLETED, + eventLogName("append_x") -> FlowStatus.COMPLETED, + eventLogName("merged") -> FlowStatus.COMPLETED + ) + ) + + val expectedSchema = new StructType().add("nums", IntegerType).add("text", StringType) + assert(spark.read.table(fullyQualifiedIdentifier("merged").toString).schema == expectedSchema) + checkAnswer( + spark.read.table(fullyQualifiedIdentifier("merged").toString), + Seq( + Row(2, null), + Row(4, null), + Row(6, null), + Row(8, null), + Row(null, "ax"), + Row(null, "bx"), + Row(null, "cx"), + Row(null, "dx") + ) + ) + } + + test("multiple hop pipeline with split and merge from single source") { + val pipelineDef = new TestGraphRegistrationContext(spark) { + registerTable( + "input_table", + query = Option( + dfFlowFunc( + Seq((1, 1), (1, 2), (2, 3), (2, 4)).toDF("x", "y") + ) + ) + ) + registerTable( + "left_split", + query = Option( + sqlFlowFunc( + spark, + "SELECT x FROM input_table WHERE x IS NOT NULL" + ) + ) + ) + registerTable( + "right_split", + query = Option(sqlFlowFunc(spark, "SELECT y FROM input_table WHERE y IS NOT NULL")) + ) + registerTable( + "merged", + query = Option( + sqlFlowFunc( + spark, + "SELECT * FROM left_split FULL OUTER JOIN right_split ON x = y" + ) + ) + ) + } + + val graph = pipelineDef.toDataflowGraph + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + val graphExecution = updateContext.pipelineExecution.graphExecution.get + + assert( + getFlowsWithState(graphExecution, StreamState.SUCCESSFUL) == Set( + fullyQualifiedIdentifier("input_table"), + fullyQualifiedIdentifier("left_split"), + fullyQualifiedIdentifier("right_split"), + fullyQualifiedIdentifier("merged") + ) + ) + + assert(getFlowsWithState(graphExecution, StreamState.TERMINATED_WITH_ERROR).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.SKIPPED).isEmpty) + assert(getFlowsWithState(graphExecution, StreamState.CANCELED).isEmpty) + + assert( + latestFlowStatuses(updateContext.eventBuffer) == Map( + eventLogName("input_table") -> FlowStatus.COMPLETED, + eventLogName("left_split") -> FlowStatus.COMPLETED, + eventLogName("right_split") -> FlowStatus.COMPLETED, + eventLogName("merged") -> FlowStatus.COMPLETED + ) + ) + + val expectedSchema = new StructType().add("x", IntegerType).add("y", IntegerType) + assert(spark.read.table(fullyQualifiedIdentifier("merged").toString).schema == expectedSchema) + checkAnswer( + spark.read.table(fullyQualifiedIdentifier("merged").toString), + Seq( + Row(1, 1), + Row(1, 1), + Row(2, 2), + Row(2, 2), + Row(null, 3), + Row(null, 4) + ) + ) + } + + test("test default flow retry is 2 and event WARN/ERROR levels accordingly") { + val fail = functions.udf((x: Int) => { + throw new RuntimeException("Intentionally failing UDF.") + x + }) + spark.sql("CREATE TABLE src USING parquet AS SELECT id AS value FROM RANGE(10)") + val pipelineDef = new TestGraphRegistrationContext(spark) { + registerTable( + "a", + query = Option(dfFlowFunc(spark.readStream.table("src").select(fail($"value") as 
"value"))) + ) + } + val graph = pipelineDef.toDataflowGraph + val updateContext = TestPipelineUpdateContext(spark, graph) + updateContext.pipelineExecution.runPipeline() + updateContext.pipelineExecution.awaitCompletion() + + // assert the flow failure event log tracked 3 times, 2 retries and 1 original attempt + assertFlowProgressEvent( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("a"), + expectedFlowStatus = FlowStatus.FAILED, + expectedEventLevel = EventLevel.INFO, + expectedNumOfEvents = Option(3) + ) + + assertFlowProgressEvent( + eventBuffer = updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("a"), + expectedFlowStatus = FlowStatus.FAILED, + expectedEventLevel = EventLevel.ERROR, + expectedNumOfEvents = Option(1), + msgChecker = _.contains("has FAILED more than 2 times and will not be restarted") + ) + + assertRunProgressEvent( + eventBuffer = updateContext.eventBuffer, + state = RunState.FAILED, + expectedEventLevel = EventLevel.ERROR, + msgChecker = msg => + msg.contains( + "Run is FAILED since flow 'spark_catalog.test_db.a' has failed more than 2 times" + ) + ) + } + + test("partial graph updates") { + val ints: MemoryStream[Int] = MemoryStream[Int] + val pipelineDef = new TestGraphRegistrationContext(spark) { + ints.addData(1, 2, 3) + registerTable("source", query = Option(dfFlowFunc(ints.toDF()))) + registerTable("all", query = Option(readStreamFlowFunc("source"))) + registerView( + "evens", + query = sqlFlowFunc(spark, "SELECT * FROM all WHERE value % 2 = 0") + ) + registerTable("max_evens", query = Option(sqlFlowFunc(spark, "SELECT MAX(value) FROM evens"))) + } + val graph1 = pipelineDef.toDataflowGraph + + // First update, which excludes "max_evens" table. + val updateContext1 = TestPipelineUpdateContext( + spark = spark, + unresolvedGraph = graph1, + refreshTables = SomeTables( + Set(fullyQualifiedIdentifier("source"), fullyQualifiedIdentifier("all")) + ), + resetCheckpointFlows = NoFlows + ) + updateContext1.pipelineExecution.runPipeline() + updateContext1.pipelineExecution.awaitCompletion() + + def readTable(tableName: String): DataFrame = { + spark.read.table(fullyQualifiedIdentifier(tableName).toString) + } + + // only `source` and `all_tables` flows are executed and their table are refreshed and have data + assertFlowProgressStatusInOrder( + eventBuffer = updateContext1.eventBuffer, + identifier = fullyQualifiedIdentifier("source"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + assertFlowProgressStatusInOrder( + eventBuffer = updateContext1.eventBuffer, + identifier = fullyQualifiedIdentifier("all"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + checkAnswer(readTable("source"), Seq(Row(1), Row(2), Row(3))) + checkAnswer(readTable("all"), Seq(Row(1), Row(2), Row(3))) + // `max_evens` flow shouldn't be executed and the table is not created + assertFlowProgressEvent( + eventBuffer = updateContext1.eventBuffer, + identifier = fullyQualifiedIdentifier("max_evens"), + expectedFlowStatus = FlowStatus.EXCLUDED, + expectedEventLevel = EventLevel.INFO + ) + assert(!spark.catalog.tableExists(fullyQualifiedIdentifier("max_evens").toString)) + + // Second update, which excludes "all" table. 
+ ints.addData(4) + val updateContext2 = TestPipelineUpdateContext( + spark = spark, + unresolvedGraph = graph1, + refreshTables = SomeTables( + Set(fullyQualifiedIdentifier("source"), fullyQualifiedIdentifier("max_evens")) + ), + resetCheckpointFlows = NoFlows + ) + updateContext2.pipelineExecution.runPipeline() + updateContext2.pipelineExecution.awaitCompletion() + checkAnswer(readTable("source"), Seq(Row(1), Row(2), Row(3), Row(4))) + // `all` flow should not be executed and 'all' table is not refreshed and has old data + checkAnswer(readTable("all"), Seq(Row(1), Row(2), Row(3))) + checkAnswer(readTable("max_evens"), Seq(Row(2))) + assertFlowProgressEvent( + eventBuffer = updateContext2.eventBuffer, + identifier = fullyQualifiedIdentifier("all"), + expectedFlowStatus = FlowStatus.EXCLUDED, + expectedEventLevel = EventLevel.INFO + ) + assertNoFlowProgressEvent( + eventBuffer = updateContext2.eventBuffer, + identifier = fullyQualifiedIdentifier("all"), + flowStatus = FlowStatus.STARTING + ) + + // `max_evens` and `source` flows should be executed and the table is created with new data + assertFlowProgressStatusInOrder( + eventBuffer = updateContext2.eventBuffer, + identifier = fullyQualifiedIdentifier("max_evens"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.PLANNING), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + assertFlowProgressStatusInOrder( + eventBuffer = updateContext2.eventBuffer, + identifier = fullyQualifiedIdentifier("source"), + expectedFlowProgressStatus = Seq( + (EventLevel.INFO, FlowStatus.QUEUED), + (EventLevel.INFO, FlowStatus.STARTING), + (EventLevel.INFO, FlowStatus.RUNNING), + (EventLevel.INFO, FlowStatus.COMPLETED) + ) + ) + } + + test("flow fails to resolve") { + val graph = new TestGraphRegistrationContext(spark) { + registerTable("table1", query = Option(sqlFlowFunc(spark, "SELECT * FROM nonexistent_src1"))) + registerTable("table2", query = Option(sqlFlowFunc(spark, "SELECT * FROM nonexistent_src2"))) + registerTable("table3", query = Option(sqlFlowFunc(spark, "SELECT * FROM table1"))) + }.toDataflowGraph + + val updateContext = TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph) + intercept[UnresolvedPipelineException] { + updateContext.pipelineExecution.runPipeline() + } + + assertFlowProgressEvent( + updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("table1"), + expectedFlowStatus = FlowStatus.FAILED, + expectedEventLevel = EventLevel.WARN, + msgChecker = _.contains("Failed to resolve flow: 'spark_catalog.test_db.table1'"), + errorChecker = { ex => + ex.exceptions.exists { ex => + ex.message.contains( + "The table or view `spark_catalog`.`test_db`.`nonexistent_src1` cannot be found" + ) + } + } + ) + + assertFlowProgressEvent( + updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("table2"), + expectedFlowStatus = FlowStatus.FAILED, + expectedEventLevel = EventLevel.WARN, + msgChecker = _.contains("Failed to resolve flow: 'spark_catalog.test_db.table2'"), + errorChecker = { ex => + ex.exceptions.exists { ex => + ex.message.contains( + "The table or view `spark_catalog`.`test_db`.`nonexistent_src2` cannot be found" + ) + } + } + ) + + assertFlowProgressEvent( + updateContext.eventBuffer, + identifier = fullyQualifiedIdentifier("table3"), + expectedFlowStatus = FlowStatus.FAILED, + expectedEventLevel = EventLevel.WARN, + msgChecker = _.contains( + "Failed to resolve 
flow due to upstream failure: 'spark_catalog.test_db.table3'" + ), + errorChecker = { ex => + ex.exceptions.exists { ex => + ex.message.contains( + "Failed to read dataset 'spark_catalog.test_db.table1'. Dataset is defined in the " + + "pipeline but could not be resolved." + ) + } + } + ) + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala index c1d6360a55065..81d5758b19d54 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/ConstructPipelineEventSuite.scala @@ -73,8 +73,7 @@ class ConstructPipelineEventSuite extends SparkFunSuite with BeforeAndAfterEach assert(serializedEx.map(_.message) == Seq("exception 2", "exception 1")) assert(serializedEx.head.className == "java.lang.IllegalStateException") - assert( - serializedEx.head.stack.nonEmpty, "Stack trace of main exception should not be empty") + assert(serializedEx.head.stack.nonEmpty, "Stack trace of main exception should not be empty") assert(serializedEx(1).stack.nonEmpty, "Stack trace of cause should not be empty") // Get the original and serialized stack traces for comparison @@ -110,6 +109,7 @@ class ConstructPipelineEventSuite extends SparkFunSuite with BeforeAndAfterEach ) ) ), + level = EventLevel.INFO, message = "Flow 'b' has failed", details = FlowProgress(FlowStatus.FAILED), eventTimestamp = Some(new Timestamp(1747338049615L)) @@ -117,6 +117,7 @@ class ConstructPipelineEventSuite extends SparkFunSuite with BeforeAndAfterEach assert(event.origin.datasetName.contains("dataset")) assert(event.origin.flowName.contains("flow")) assert(event.origin.sourceCodeLocation.get.path.contains("path")) + assert(event.level == EventLevel.INFO) assert(event.message == "Flow 'b' has failed") assert(event.details.asInstanceOf[FlowProgress].status == FlowStatus.FAILED) assert(event.timestamp == "2025-05-15T19:40:49.615Z") diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/PipelineEventSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/PipelineEventSuite.scala index 7430054fa3312..21facf0a6057a 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/PipelineEventSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/logging/PipelineEventSuite.scala @@ -49,8 +49,12 @@ class PipelineEventSuite extends SparkFunSuite with Logging { private def makeEvent() = { ConstructPipelineEvent( - origin = - PipelineEventOrigin(flowName = Option("a"), datasetName = None, sourceCodeLocation = None), + origin = PipelineEventOrigin( + flowName = Option("a"), + datasetName = None, + sourceCodeLocation = None + ), + level = EventLevel.INFO, message = "OK", details = FlowProgress(FlowStatus.STARTING) ) @@ -87,11 +91,16 @@ class PipelineEventSuite extends SparkFunSuite with Logging { test("basic flow progress event has expected fields set") { val event = ConstructPipelineEvent( - origin = - PipelineEventOrigin(flowName = Option("a"), datasetName = None, sourceCodeLocation = None), + origin = PipelineEventOrigin( + flowName = Option("a"), + datasetName = None, + sourceCodeLocation = None + ), + level = EventLevel.INFO, message = "Flow 'a' has completed", details = FlowProgress(FlowStatus.COMPLETED) ) + assert(event.level == EventLevel.INFO) 
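+ // The remaining fields should come back exactly as they were passed to ConstructPipelineEvent.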
assert(event.message == "Flow 'a' has completed") assert(event.details.isInstanceOf[FlowProgress]) assert(event.origin.flowName == Option("a")) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala new file mode 100644 index 0000000000000..3ee73f9394e92 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/BaseCoreExecutionTest.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.utils + +import org.apache.spark.sql.pipelines.graph.{DataflowGraph, DatasetManager, PipelineUpdateContext} + +trait BaseCoreExecutionTest extends ExecutionTest { + + /** + * Materializes the given graph using the provided context. + * If no context is provided, a default context is created. + */ + protected def materializeGraph( + graph: DataflowGraph, + contextOpt: Option[PipelineUpdateContext] = None + ): DataflowGraph = { + val contextToUse = contextOpt.getOrElse( + TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph) + ) + DatasetManager.materializeDatasets(graph, contextToUse) + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala new file mode 100644 index 0000000000000..7aa2f9d631064 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/ExecutionTest.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.pipelines.utils + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.pipelines.common.{FlowStatus, RunState} +import org.apache.spark.sql.pipelines.graph.{ + AllFlows, + AllTables, + DataflowGraph, + FlowFilter, + NoTables, + PipelineConf, + PipelineUpdateContext, + TableFilter +} +import org.apache.spark.sql.pipelines.logging.{ + ErrorDetail, + EventLevel, + FlowProgress, + FlowProgressEventLogger, + PipelineEvent, + PipelineRunEventBuffer, + RunProgress +} + +trait ExecutionTest + extends PipelineTest + with TestPipelineUpdateContextMixin + with EventVerificationTestHelpers + +trait TestPipelineUpdateContextMixin { + + /** + * A test implementation of the PipelineUpdateContext trait. + * @param spark The Spark session to use. + * @param unresolvedGraph The unresolved dataflow graph. + * @param fullRefreshTables Set of tables to be fully refreshed. + * @param refreshTables Set of tables to be refreshed. + * @param resetCheckpointFlows Set of flows to be reset. + */ + case class TestPipelineUpdateContext( + spark: SparkSession, + unresolvedGraph: DataflowGraph, + fullRefreshTables: TableFilter = NoTables, + refreshTables: TableFilter = AllTables, + resetCheckpointFlows: FlowFilter = AllFlows + ) extends PipelineUpdateContext { + val pipelineConf: PipelineConf = new PipelineConf(spark) + val eventBuffer = new PipelineRunEventBuffer(eventCallback = _ => ()) + + override def flowProgressEventLogger: FlowProgressEventLogger = { + new FlowProgressEventLogger( + eventBuffer = eventBuffer + ) + } + } +} + +trait EventVerificationTestHelpers { + + /** + * Asserts that there is a [[FlowProgress]] event in the event log with the specified flow name + * and status and matching the specified error condition. + * + * @param identifier Flow identifier to look for events for + * @param expectedFlowStatus Expected [[FlowStatus]] + * @param expectedEventLevel Expected [[EventLevel]] of the event. + * @param errorChecker Condition that the event's exception, if any, must pass in order for this + * function to return true. + * @param msgChecker Condition that the event's msg must pass in order for the function to return + * true. + * @param cond Predicate to filter the flow progress events by. Useful for more complex event + * verification. + */ + protected def assertFlowProgressEvent( + eventBuffer: PipelineRunEventBuffer, + identifier: TableIdentifier, + expectedFlowStatus: FlowStatus, + expectedEventLevel: EventLevel, + errorChecker: ErrorDetail => Boolean = _ => true, + msgChecker: String => Boolean = _ => true, + cond: PipelineEvent => Boolean = _ => true, + expectedNumOfEvents: Option[Int] = None + ): Unit = { + // Get all events for the flow. This list is logged if the assertion + // fails to help with debugging. Only minimal filtering is done here + // so we have a complete list of events to look at for debugging. 
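+ // The assertions below then narrow this list down to the events whose level, status,
+ // message, and error actually match the caller's expectations.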
+ val flowEvents = eventBuffer.getEvents.filter(_.details.isInstanceOf[FlowProgress]).filter {
+ event =>
+ event.origin.flowName == Option(identifier.unquotedString)
+ }
+
+ var matchingEvents = flowEvents.filter(e => msgChecker(e.message))
+ matchingEvents = matchingEvents
+ .filter {
+ _.details.isInstanceOf[FlowProgress]
+ }
+ .filter {
+ cond
+ }
+ .filter { e =>
+ e.level == expectedEventLevel &&
+ e.details.asInstanceOf[FlowProgress].status == expectedFlowStatus &&
+ (e.error.isEmpty || errorChecker(e.error.get))
+ }
+
+ assert(
+ matchingEvents.nonEmpty,
+ s"Could not find a matching event for $identifier. Logs for $identifier are $flowEvents"
+ )
+ assert(
+ expectedNumOfEvents.forall(_ == matchingEvents.size),
+ s"Found ${matchingEvents.size} events for $identifier but expected " +
+ s"$expectedNumOfEvents events. Logs for $identifier are $flowEvents"
+ )
+ }
+
+ /**
+ * Asserts emitted flow progress event logs for a given flow have the expected sequence of
+ * event levels and flow statuses in the expected order.
+ *
+ * @param eventBuffer The event buffer containing the events.
+ * @param identifier The identifier of the flow to check.
+ * @param expectedFlowProgressStatus A sequence of tuples containing the expected event level
+ * and flow status in the expected order.
+ */
+ protected def assertFlowProgressStatusInOrder(
+ eventBuffer: PipelineRunEventBuffer,
+ identifier: TableIdentifier,
+ expectedFlowProgressStatus: Seq[(EventLevel, FlowStatus)]
+ ): Unit = {
+ // Get all events for the flow. This list is logged if the assertion
+ // fails to help with debugging. Only minimal filtering is done here
+ // so we have a complete list of events to look at for debugging.
+ val actualFlowProgressStatus = eventBuffer.getEvents
+ .filter(_.details.isInstanceOf[FlowProgress])
+ .filter { event =>
+ event.origin.flowName == Option(identifier.unquotedString)
+ }
+ .map { event =>
+ val flowProgress = event.details.asInstanceOf[FlowProgress]
+ (event.level, flowProgress.status)
+ }
+
+ assert(
+ actualFlowProgressStatus == expectedFlowProgressStatus,
+ s"Expected flow progress status for $identifier to be " +
+ s"$expectedFlowProgressStatus but got $actualFlowProgressStatus. " +
+ s"Logs for $identifier are " +
+ s"${eventBuffer.getEvents.filter(_.origin.flowName == Option(identifier.unquotedString))}. " +
+ s"All events in the buffer are ${eventBuffer.getEvents.mkString("\n")}"
+ )
+ }
+
+ /**
+ * Asserts that there is no `FlowProgress` event
+ * in the event log with the specified flow name and status.
+ */
+ protected def assertNoFlowProgressEvent(
+ eventBuffer: PipelineRunEventBuffer,
+ identifier: TableIdentifier,
+ flowStatus: FlowStatus
+ ): Unit = {
+ val flowEvents = eventBuffer.getEvents
+ .filter(_.details.isInstanceOf[FlowProgress])
+ .filter(_.origin.flowName == Option(identifier.unquotedString))
+ assert(
+ !flowEvents.filter(_.details.isInstanceOf[FlowProgress]).exists { e =>
+ e.details.asInstanceOf[FlowProgress].status == flowStatus
+ },
+ s"Found a matching event for flow $identifier. Logs for $identifier are $flowEvents"
+ )
+ }
+
+ /** Returns a map of flow names to their latest [[FlowStatus]]es.
*/ + protected def latestFlowStatuses(eventBuffer: PipelineRunEventBuffer): Map[String, FlowStatus] = { + eventBuffer.getEvents + .filter(_.details.isInstanceOf[FlowProgress]) + .groupBy(_.origin.flowName.get) + .view + .mapValues { events: Seq[PipelineEvent] => + events.reverse + .map(_.details) + .collectFirst { case fp: FlowProgress => fp.status } + } + .collect { case (k, Some(v)) => k -> v } + .toMap + } + + /** + * Asserts that there is a planning event in the event log with the specified flow name. + */ + protected def assertPlanningEvent( + eventBuffer: PipelineRunEventBuffer, + identifier: TableIdentifier + ): Unit = { + val flowEventLogName = identifier.unquotedString + val expectedPlanningMessage = s"Flow '$flowEventLogName' is PLANNING." + val foundPlanningEvent = eventBuffer.getEvents + .filter { e => + val matchName = e.origin.flowName.contains(flowEventLogName) + val matchDetails = e.details match { + case fp: FlowProgress => fp.status == FlowStatus.PLANNING + case _ => false + } + matchName && matchDetails + } + assert( + foundPlanningEvent.nonEmpty && + foundPlanningEvent.head.message.contains(expectedPlanningMessage), + s"Planning event not found for flow $flowEventLogName" + ) + } + + /** + * Asserts that there is an [[RunProgress]] event in the event log with the specified id + * and state and matching the specified error condition. + * + * @param state Expected [[RunState]] + * @param expectedEventLevel Expected [[EventLevel]] of the event. + * @param errorChecker Condition that the event's exception, if any, must pass in order for this + * function to return true. + * @param msgChecker Condition that the event's msg must pass in order for the function to + * return true. + */ + protected def assertRunProgressEvent( + eventBuffer: PipelineRunEventBuffer, + state: RunState, + expectedEventLevel: EventLevel, + errorChecker: Option[ErrorDetail] => Boolean = null, + msgChecker: String => Boolean = _ => true): Unit = { + val errorCheckerWithDefault = Option(errorChecker).getOrElse { + if (state == RunState.FAILED) { (errorDetailsOpt: Option[ErrorDetail]) => + errorDetailsOpt.nonEmpty + } else { (errorDetailsOpt: Option[ErrorDetail]) => + errorDetailsOpt.isEmpty + } + } + + val runEvents = eventBuffer.getEvents.filter(_.details.isInstanceOf[RunProgress]) + val expectedEvent = eventBuffer.getEvents.find { e => + val matchingState = e.details match { + case RunProgress(up) => up == state + case _ => false + } + matchingState && errorCheckerWithDefault(e.error) && msgChecker(e.message) + } + assert( + expectedEvent.isDefined, + s"Could not find a matching event with run state $state. " + + s"Logs are $runEvents" + ) + } +} diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala index 815f2ef894b78..50c0dca02bd57 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala @@ -48,7 +48,9 @@ abstract class PipelineTest final protected val storageRoot = createTempDir() - var spark: SparkSession = _ + var spark: SparkSession = createAndInitializeSpark() + val originalSpark: SparkSession = spark.cloneSession() + implicit def sqlContext: SQLContext = spark.sqlContext def sql(text: String): DataFrame = spark.sql(text) @@ -169,17 +171,17 @@ abstract class PipelineTest } /** - * Creates individual tests for all items in [[params]]. 
+ * Creates individual tests for all items in `params`.
 *
 * The full test name will be "<testNamePrefix> (<paramName> = <param>)" where <param> is one
- * item in [[params]].
+ * item in `params`.
 *
 * @param testNamePrefix The test name prefix.
 * @param paramName A descriptive name for the parameter.
 * @param testTags Extra tags for the test.
 * @param params The list of parameters for which to generate tests.
 * @param testFun The actual test function. This function will be called with one argument of
- * type [[A]].
+ * type `A`.
 * @tparam A The type of the params.
 */
 protected def gridTest[A](testNamePrefix: String, paramName: String, testTags: Tag*)(
@@ -189,8 +191,8 @@ abstract class PipelineTest
 )(testFun)

 /**
- * Specialized version of gridTest where the params are two boolean values - [[true]] and
- * [[false]].
+ * Specialized version of gridTest where the params are two boolean values - `true` and
+ * `false`.
 */
 protected def booleanGridTest(testNamePrefix: String, paramName: String, testTags: Tag*)(
 testFun: Boolean => Unit): Unit = {
@@ -244,8 +246,8 @@ abstract class PipelineTest
 /**
 * Runs the plan and makes sure the answer matches the expected result.
 *
- * @param df the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param df the `DataFrame` to be executed
+ * @param expectedAnswer the expected result in a `Seq` of `Row`s.
 */
 protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
 checkAnswerAndPlan(df, expectedAnswer, None)
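A minimal sketch of how the helpers above are intended to be combined in a suite. The suite and table names are hypothetical; it assumes the same registration helpers used by the suites earlier in this patch (TestGraphRegistrationContext, sqlFlowFunc, fullyQualifiedIdentifier) are in scope, and that a table defined by a batch query is handled the same way `max_evens` is above.

package org.apache.spark.sql.pipelines.utils

import org.apache.spark.sql.Row
import org.apache.spark.sql.pipelines.common.FlowStatus
import org.apache.spark.sql.pipelines.logging.EventLevel

// Hypothetical example suite; not part of this change.
class ExampleExecutionSuite extends ExecutionTest {

  test("a single batch-defined table completes and is queryable") {
    // Register a one-table graph, mirroring how `max_evens` is registered above.
    val graph = new TestGraphRegistrationContext(spark) {
      registerTable(
        "numbers",
        query = Option(sqlFlowFunc(spark, "SELECT * FROM VALUES (1), (2), (3) AS t(value)"))
      )
    }.toDataflowGraph

    // Run the pipeline to completion with the default "refresh everything" context.
    val updateContext = TestPipelineUpdateContext(spark = spark, unresolvedGraph = graph)
    updateContext.pipelineExecution.runPipeline()
    updateContext.pipelineExecution.awaitCompletion()

    // The flow should have logged a COMPLETED event at INFO level ...
    assertFlowProgressEvent(
      eventBuffer = updateContext.eventBuffer,
      identifier = fullyQualifiedIdentifier("numbers"),
      expectedFlowStatus = FlowStatus.COMPLETED,
      expectedEventLevel = EventLevel.INFO
    )
    // ... and the materialized table should contain the query result.
    checkAnswer(
      spark.read.table(fullyQualifiedIdentifier("numbers").toString),
      Seq(Row(1), Row(2), Row(3))
    )
  }
}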