diff --git a/python/pyspark/pipelines/graph_element_registry.py b/python/pyspark/pipelines/graph_element_registry.py
index 8e311fc2ca98e..c9d4ec3092cec 100644
--- a/python/pyspark/pipelines/graph_element_registry.py
+++ b/python/pyspark/pipelines/graph_element_registry.py
@@ -51,6 +51,15 @@ def register_sql(self, sql_text: str, file_path: Path) -> None:
         :param file_path: The path to the file that the SQL txt came from.
         """
 
+    @abstractmethod
+    def register_signalled_query_functions(self) -> None:
+        """Open up a stream to receive query function execution signals from the server.
+        When a signal is received, execute the corresponding query function and register the
+        result with the server.
+
+        This method should be called after all the flows and datasets have been registered.
+        """
+
 
 _graph_element_registry_context_var: ContextVar[Optional[GraphElementRegistry]] = ContextVar(
     "graph_element_registry_context", default=None
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index ab88317908302..4229b35915580 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -15,9 +15,10 @@
 # limitations under the License.
 #
 from pathlib import Path
+from pyspark.errors.exceptions.base import PySparkValueError
 
-from pyspark.errors import PySparkTypeError
-from pyspark.sql import SparkSession
+from pyspark.errors import PySparkException, PySparkTypeError
+from pyspark.sql import SparkSession, DataFrame
 from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 from pyspark.pipelines.output import (
     Output,
@@ -35,6 +36,8 @@
 from typing import Any, cast
 import pyspark.sql.connect.proto as pb2
 from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
+from pyspark.pipelines.flow import QueryFunction
+import uuid
 
 
 class SparkConnectGraphElementRegistry(GraphElementRegistry):
@@ -46,6 +49,12 @@ def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
         self._spark = spark
         self._client = cast(Any, spark).client
         self._dataflow_graph_id = dataflow_graph_id
+        self._client_id = str(uuid.uuid4())
+        self._query_funcs_by_flow_name: dict[str, QueryFunction] = {}
+
+    @property
+    def dataflow_graph_id(self) -> str:
+        return self._dataflow_graph_id
 
     def register_output(self, output: Output) -> None:
         table_details = None
@@ -111,11 +120,22 @@ def register_output(self, output: Output) -> None:
         self._client.execute_command(command)
 
     def register_flow(self, flow: Flow) -> None:
-        with add_pipeline_analysis_context(
-            spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name
-        ):
-            df = flow.func()
-            relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
+        self._query_funcs_by_flow_name[flow.name] = flow.func
+        try:
+            df = self._execute_query_function(flow.name, flow.func)
+        except PySparkException as e:
+            if e.getCondition() == "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION":
+                # The query function attempted eager analysis that the server can only
+                # answer once graph resolution is underway. Register the flow without a
+                # relation and defer planning until the server signals for it.
+                df = None
+            else:
+                raise
+
+        if df is not None:
+            relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
+        else:
+            relation = None
 
         relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
             relation=relation,
@@ -128,6 +148,7 @@ def register_flow(self, flow: Flow) -> None:
             relation_flow_details=relation_flow_details,
             sql_conf=flow.spark_conf,
             source_code_location=source_code_location_to_proto(flow.source_code_location),
+            client_id=self._client_id,
         )
         command = pb2.Command()
         command.pipeline_command.define_flow.CopyFrom(inner_command)
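Aside: the deferred branch above is exactly what a query function hits when it needs analysis results before the graph has started resolving. A minimal sketch of such a function (dataset names are illustrative, and `spark` is assumed to be the ambient session of a pipeline definition file):

    from pyspark import pipelines as dp

    @dp.materialized_view
    def downstream():
        # Reading another pipeline dataset's schema forces eager analysis. With no
        # active pipeline execution yet, the server answers with
        # ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION, so DefineFlow is sent without
        # a relation and the plan is supplied later over the signal stream.
        schema = spark.table("upstream").schema
        return spark.table("upstream").select(*schema.fieldNames())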
@@ -143,6 +164,49 @@ def register_sql(self, sql_text: str, file_path: Path) -> None:
         command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command)
         self._client.execute_command(command)
 
+    def register_signalled_query_functions(self) -> None:
+        """Open up a stream to receive query function execution signals from the server.
+        When a signal is received, execute the corresponding query function and register the
+        result with the server.
+
+        This method should be called after all the flows and datasets have been registered.
+        """
+
+        inner_command = pb2.PipelineCommand.GetQueryFunctionExecutionSignalStream(
+            dataflow_graph_id=self._dataflow_graph_id,
+            client_id=self._client_id,
+        )
+        command = pb2.Command()
+        command.pipeline_command.get_query_function_execution_signal_stream.CopyFrom(inner_command)
+
+        result_iter = self._client.execute_command_as_iterator(command)
+        for result in result_iter:
+            if "pipeline_query_function_execution_signal" not in result:
+                raise PySparkValueError(
+                    f"GetQueryFunctionExecutionSignalStream received unexpected result: {result}"
+                )
+
+            signal = result["pipeline_query_function_execution_signal"]
+            flow_names = signal.flow_names
+            for flow_name in flow_names:
+                func = self._query_funcs_by_flow_name[flow_name]
+                df = self._execute_query_function(flow_name, func)
+                relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
+                inner_command = pb2.PipelineCommand.DefineFlowQueryFunctionResult(
+                    dataflow_graph_id=self._dataflow_graph_id,
+                    flow_name=flow_name,
+                    relation=relation,
+                )
+                command = pb2.Command()
+                command.pipeline_command.define_flow_query_function_result.CopyFrom(inner_command)
+                self._client.execute_command(command)
+
+    def _execute_query_function(self, flow_name: str, func: QueryFunction) -> DataFrame:
+        with add_pipeline_analysis_context(
+            spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow_name
+        ):
+            return func()
+
 
 def source_code_location_to_proto(
     source_code_location: SourceCodeLocation,
diff --git a/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py b/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py
new file mode 100644
index 0000000000000..dc92fe2d924a0
--- /dev/null
+++ b/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.sql import SparkSession
+from pyspark import pipelines as dp
+from pyspark.pipelines.spark_connect_graph_element_registry import (
+    SparkConnectGraphElementRegistry,
+)
+from pyspark.pipelines.spark_connect_pipeline import create_dataflow_graph
+from pyspark.pipelines.spark_connect_pipeline import start_run, handle_pipeline_events
+import threading
+
+
+def setup(
+    server_port: str, session_identifier: str
+) -> tuple[SparkSession, SparkConnectGraphElementRegistry]:
+    spark = SparkSession.builder \
+        .remote(f"sc://localhost:{server_port}") \
+        .config("spark.connect.grpc.channel.timeout", "5s") \
+        .config("spark.custom.identifier", session_identifier) \
+        .create()
+
+    dataflow_graph_id = create_dataflow_graph(
+        spark,
+        default_catalog=None,
+        default_database=None,
+        sql_conf={},
+    )
+
+    registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
+    return spark, registry
+
+
+def run_and_handle_signals(
+    spark: SparkSession,
+    registry: SparkConnectGraphElementRegistry,
+    storage_root: str,
+) -> None:
+    result_iter = start_run(
+        spark,
+        dataflow_graph_id=registry.dataflow_graph_id,
+        full_refresh=None,
+        full_refresh_all=False,
+        refresh=None,
+        dry=True,
+        storage=storage_root,
+    )
+    thread = threading.Thread(target=handle_pipeline_events, args=(result_iter,), daemon=True)
+    thread.start()
+
+    registry.register_signalled_query_functions()
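For orientation, a generated definition script drives these helpers roughly as follows; this is a sketch, and the port, session identifier, dataset name, and storage path are all illustrative:

    from pyspark import pipelines as dp
    from pyspark.pipelines.graph_element_registry import graph_element_registration_context
    from pyspark.pipelines.tests.python_pipeline_suite_helpers import (
        setup,
        run_and_handle_signals,
    )

    spark, registry = setup("15002", "test-session-id")

    with graph_element_registration_context(registry):

        @dp.materialized_view
        def events():
            return spark.range(10)

    # Starts a dry run in the background and blocks on the execution-signal stream
    # until the server has obtained a plan for every flow.
    run_and_handle_signals(spark, registry, "/tmp/pipelines-storage")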
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
index e8114f38ec40c..9ff6d8068285c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
@@ -46,30 +46,25 @@ class DataflowGraphRegistry {
     graphId
   }
 
-  /** Retrieves the graph for a given id. */
-  def getDataflowGraph(graphId: String): Option[GraphRegistrationContext] = {
-    Option(dataflowGraphs.get(graphId))
+  /** Returns a Map of dataflow graph IDs to their corresponding GraphRegistrationContext. */
+  def getDataflowGraphs: Map[String, GraphRegistrationContext] = {
+    dataflowGraphs.asScala.toMap
   }
 
   /** Retrieves the graph for a given id, and throws if the id could not be found. */
   def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext =
-    getDataflowGraph(dataflowGraphId).getOrElse {
+    Option(dataflowGraphs.get(dataflowGraphId)).getOrElse {
       throw new SparkException(
         errorClass = "DATAFLOW_GRAPH_NOT_FOUND",
         messageParameters = Map("graphId" -> dataflowGraphId),
         cause = null)
-  }
+    }
 
   /** Removes the graph with a given id from the registry. */
   def dropDataflowGraph(graphId: String): Unit = {
     dataflowGraphs.remove(graphId)
   }
 
-  /** Returns all graphs in the registry. */
-  def getAllDataflowGraphs: Seq[GraphRegistrationContext] = {
-    dataflowGraphs.values().asScala.toSeq
-  }
-
   /** Removes all graphs from the registry.
*/ def dropAllDataflowGraphs(): Unit = { dataflowGraphs.clear() diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 62f060014117c..132e16e98a852 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -24,17 +24,18 @@ import scala.util.Using import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResult, Relation, ResolvedIdentifier} +import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineAnalysisContext, PipelineCommandResult, Relation, ResolvedIdentifier} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Command, CreateNamespace, CreateTable, CreateTableAsSelect, CreateView, DescribeRelation, DropView, InsertIntoStatement, LogicalPlan, RenameTable, ShowColumns, ShowCreateTable, ShowFunctions, ShowTableProperties, ShowTables, ShowViews} +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.execution.command.{ShowCatalogsCommand, ShowNamespacesCommand} import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED} -import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} +import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContext, PipelineUpdateContextImpl, QueryContext, QueryFunctionSuccess, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress} import org.apache.spark.sql.types.StructType @@ -123,6 +124,16 @@ private[connect] object PipelinesHandler extends Logging { logInfo(s"Register sql datasets cmd received: $cmd") defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder) defaultResponse + case proto.PipelineCommand.CommandTypeCase.GET_QUERY_FUNCTION_EXECUTION_SIGNAL_STREAM => + logInfo(s"Get query function execution signal stream cmd received: $cmd") + getQueryFunctionExecutionSignalStream( + cmd.getGetQueryFunctionExecutionSignalStream, responseObserver, sessionHolder) + defaultResponse + case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW_QUERY_FUNCTION_RESULT => + logInfo(s"Define flow query function result cmd received: $cmd") + defineFlowQueryFunctionResult( + cmd.getDefineFlowQueryFunctionResult, transformRelationFunc, sessionHolder) + defaultResponse case other => throw new UnsupportedOperationException(s"$other not supported") } } @@ -367,12 +378,19 @@ private[connect] object PipelinesHandler extends Logging { } val relationFlowDetails = flow.getRelationFlowDetails + val 
flowFunction = if (relationFlowDetails.hasRelation) {
+        FlowAnalysis.createFlowFunctionFromLogicalPlan(
+          transformRelationFunc(relationFlowDetails.getRelation))
+      } else {
+        // No relation was attached at registration time: the client deferred planning,
+        // so resolution polls the registry for a plan supplied later by the client.
+        FlowAnalysis.createQueryFunctionResultPollingFlowFunction(
+          flowIdentifier,
+          graphElementRegistry)
+      }
     graphElementRegistry.registerFlow(
       UnresolvedFlow(
         identifier = flowIdentifier,
         destinationIdentifier = destinationIdentifier,
-        func = FlowAnalysis.createFlowFunctionFromLogicalPlan(
-          transformRelationFunc(relationFlowDetails.getRelation)),
+        func = flowFunction,
         sqlConf = flow.getSqlConfMap.asScala.toMap,
         once = false,
         queryContext = QueryContext(Option(defaultCatalog), Option(defaultDatabase)),
@@ -528,4 +546,151 @@ private[connect] object PipelinesHandler extends Logging {
    * A case class to hold the table filters for full refresh and refresh operations.
    */
   private case class TableFilters(fullRefresh: TableFilter, refresh: TableFilter)
+
+  /**
+   * Handles the GetQueryFunctionExecutionSignalStream command by monitoring for flows
+   * that have more resolved inputs than the last time we sent a signal.
+   */
+  private def getQueryFunctionExecutionSignalStream(
+      cmd: proto.PipelineCommand.GetQueryFunctionExecutionSignalStream,
+      responseObserver: StreamObserver[ExecutePlanResponse],
+      sessionHolder: SessionHolder): Unit = {
+    val dataflowGraphId = cmd.getDataflowGraphId
+    val clientId = cmd.getClientId
+
+    logInfo(s"Starting query function execution signal stream for " +
+      s"graph $dataflowGraphId, client $clientId")
+
+    var streamCompleted = false
+    try {
+      var pipelineExecution: Option[PipelineUpdateContext] = None
+      var waitAttempts = 0
+      val maxWaitAttempts = 100
+
+      while (pipelineExecution.isEmpty && waitAttempts < maxWaitAttempts) {
+        pipelineExecution = sessionHolder.getPipelineExecution(dataflowGraphId)
+        if (pipelineExecution.isEmpty) {
+          Thread.sleep(100)
+          waitAttempts += 1
+        }
+      }
+
+      if (pipelineExecution.isEmpty) {
+        val error = new IllegalStateException(
+          s"No active pipeline execution found for graph $dataflowGraphId " +
+            s"after ${maxWaitAttempts * 100}ms")
+        responseObserver.onError(error)
+        streamCompleted = true
+        return
+      }
+
+      val execution = pipelineExecution.get.pipelineExecution
+      val graphAnalysisContext = execution.graphAnalysisContext
+      var signalAttempts = 0
+      val maxSignalAttempts = 600
+
+      while (execution.resolvedGraph.isEmpty && signalAttempts < maxSignalAttempts) {
+        val signal = proto.PipelineQueryFunctionExecutionSignal.newBuilder()
+
+        while (!graphAnalysisContext.flowClientSignalQueue.isEmpty) {
+          val flowId = graphAnalysisContext.flowClientSignalQueue.remove()
+          signal.addFlowNames(flowId.unquotedString)
+        }
+
+        if (signal.getFlowNamesCount > 0) {
+          logInfo(s"Sending execution signal for ${signal.getFlowNamesCount} flows")
+
+          val response = ExecutePlanResponse.newBuilder()
+            .setPipelineQueryFunctionExecutionSignal(signal.build())
+            .build()
+
+          responseObserver.onNext(response)
+        }
+
+        Thread.sleep(100)
+        signalAttempts += 1
+      }
+
+      responseObserver.onCompleted()
+      streamCompleted = true
+    } catch {
+      case e: Exception =>
+        logError(
+          s"Error in query function execution signal stream for graph $dataflowGraphId", e)
+        if (!streamCompleted) {
+          responseObserver.onError(e)
+          streamCompleted = true
+        }
+    } finally {
+      if (!streamCompleted) {
+        responseObserver.onCompleted()
+      }
+    }
+  }
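To make the stream's contract concrete: each response carries one PipelineQueryFunctionExecutionSignal naming the flows whose inputs have since resolved, and the client answers per flow with a DefineFlowQueryFunctionResult. A wire-level sketch from the client side, mirroring the PySpark registry above (`client`, `graph_id`, `client_id`, and `plan_for` are assumed to exist):

    import pyspark.sql.connect.proto as pb2

    cmd = pb2.Command()
    cmd.pipeline_command.get_query_function_execution_signal_stream.CopyFrom(
        pb2.PipelineCommand.GetQueryFunctionExecutionSignalStream(
            dataflow_graph_id=graph_id, client_id=client_id
        )
    )
    for result in client.execute_command_as_iterator(cmd):
        signal = result["pipeline_query_function_execution_signal"]
        for flow_name in signal.flow_names:
            answer = pb2.Command()
            answer.pipeline_command.define_flow_query_function_result.CopyFrom(
                pb2.PipelineCommand.DefineFlowQueryFunctionResult(
                    dataflow_graph_id=graph_id,
                    flow_name=flow_name,
                    # plan_for is an assumed helper that re-runs the user's query
                    # function and extracts its Spark Connect relation.
                    relation=plan_for(flow_name),
                )
            )
            client.execute_command(answer)

The server completes the stream once the graph is fully resolved, or after the 600 x 100 ms polling budget above is exhausted.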
+
+  /**
+   * Handles the DefineFlowQueryFunctionResult command by storing the query function result
+   * in the GraphRegistrationContext.
+   */
+  private def defineFlowQueryFunctionResult(
+      cmd: proto.PipelineCommand.DefineFlowQueryFunctionResult,
+      transformRelationFunc: Relation => LogicalPlan,
+      sessionHolder: SessionHolder): Unit = {
+    val dataflowGraphId = cmd.getDataflowGraphId
+    val flowName = cmd.getFlowName
+    val relation = cmd.getRelation
+
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+
+    val logicalPlan = transformRelationFunc(relation)
+    // Assume the identifier is already fully qualified (TODO: verify this)
+    val flowIdentifier = GraphIdentifierManager.parseTableIdentifier(
+      flowName,
+      sessionHolder.session
+    )
+
+    graphElementRegistry.registerQueryFunctionResult(
+      flowIdentifier,
+      QueryFunctionSuccess(logicalPlan)
+    )
+    sessionHolder.getPipelineExecution(dataflowGraphId) match {
+      case Some(pipelineUpdateContext) =>
+        // TODO: what if we haven't yet started analysis?
+        val graphAnalysisContext = pipelineUpdateContext.pipelineExecution.graphAnalysisContext
+        graphAnalysisContext.markFlowPlanRegistered(flowIdentifier)
+      case None =>
+        // Analysis has not started; the stored result will be picked up when the
+        // graph is resolved.
+    }
+
+    logInfo(s"Registered query function result for flow '$flowName' in graph '$dataflowGraphId'")
+  }
+
+  def analyze(
+      pipelineAnalysisContext: PipelineAnalysisContext,
+      logicalPlan: LogicalPlan,
+      sessionHolder: SessionHolder): DataFrame = {
+    val graphId = pipelineAnalysisContext.getDataflowGraphId
+    val session = sessionHolder.session
+    sessionHolder.getPipelineExecution(graphId) match {
+      case Some(pipelineUpdateContext) =>
+        // Assume the identifier is already fully qualified
+        val flowIdentifier = GraphIdentifierManager.parseTableIdentifier(
+          pipelineAnalysisContext.getFlowName,
+          session
+        )
+
+        pipelineUpdateContext.pipelineExecution.graphAnalysisContext.analyze(
+          flowIdentifier,
+          logicalPlan,
+          pipelineUpdateContext.unresolvedGraph,
+          session
+        )
+      case None =>
+        // TODO: surface a dedicated error once eager analysis without an active pipeline
+        // execution is supported. The Python client currently treats this condition as a
+        // signal to register the flow without a relation and defer planning.
+        throw new AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map())
+    }
+  }
 }
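For pipeline authors, routing analysis through the active execution means eager analysis inside a query function now succeeds once resolution is underway instead of failing outright. A sketch of the newly supported pattern (dataset names are illustrative, `spark` is the ambient session of a definition file):

    from pyspark import pipelines as dp

    @dp.materialized_view
    def silver():
        # This lookup used to raise ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION
        # unconditionally; it is now answered against the partially resolved graph
        # once `bronze` has a resolved plan server-side.
        schema = spark.table("bronze").schema
        keep = [f for f in schema.fieldNames() if not f.startswith("_")]
        return spark.table("bronze").select(*keep)

    @dp.materialized_view
    def bronze():
        return spark.range(5)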
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index d0d0f0ba750a3..86906b2825b28 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -140,6 +140,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
   // Registry for dataflow graphs specific to this session
   private[connect] lazy val dataflowGraphRegistry = new DataflowGraphRegistry()
 
+  // Handles Python process cleanup for streaming queries. Initialized on first use in a query.
   private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
     new StreamingForeachBatchHelper.CleanerCache(this)
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index ec8d95271c762..0535952d79fce 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -23,13 +23,14 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.classic.{DataFrame, Dataset}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
+import org.apache.spark.sql.connect.pipelines.PipelinesHandler
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
-import org.apache.spark.sql.connect.utils.{PipelineAnalysisContextUtils, PlanCompressionUtils}
+import org.apache.spark.sql.connect.utils.PlanCompressionUtils
 import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._
@@ -62,21 +63,29 @@ private[connect] class SparkConnectAnalyzeHandler(
     lazy val planner = new SparkConnectPlanner(sessionHolder)
     val session = sessionHolder.session
     val builder = proto.AnalyzePlanResponse.newBuilder()
-    val userContext = request.getUserContext
 
-    // Pipeline has not yet supported eager analysis inside flow function.
- if (PipelineAnalysisContextUtils.isInsidePipelineFlowFunction(userContext)) { - throw new AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map()) + def transformRelation(rel: proto.Relation) = { + planner.transformRelation(rel, cachePlan = true) } - def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true) def transformRelationPlan(plan: proto.Plan) = { transformRelation(PlanCompressionUtils.decompressPlan(plan).getRoot) } def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = { - val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP) - new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema)) + val userContextExtensions = request.getUserContext.getExtensionsList.asScala + val pipelineAnalysisContextOpt = userContextExtensions + .filter(_.is(classOf[proto.PipelineAnalysisContext])) + .map(_.unpack(classOf[proto.PipelineAnalysisContext])) + .find(ctx => ctx.hasFlowName && ctx.hasDataflowGraphId) + + pipelineAnalysisContextOpt match { + case Some(pipelineAnalysisContext) => + PipelinesHandler.analyze(pipelineAnalysisContext, rel, sessionHolder) + case None => + val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP) + new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema)) + } } request.getAnalyzeCase match { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index fd05b0cc357eb..b3646baeac6b0 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -24,20 +24,23 @@ import java.util.UUID import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{ExecutionContext, Future} import scala.util.Try import org.scalactic.source.Position import org.scalatest.Tag +import org.scalatest.concurrent.{Eventually, Futures} +import org.scalatest.time.{Millis, Seconds, Span} import org.apache.spark.api.python.PythonUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.connect.PythonTestDepsChecker -import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService} import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.common.FlowStatus -import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl, QueryOrigin, QueryOriginType} +import org.apache.spark.sql.pipelines.graph.{DataflowGraph, GraphRegistrationContext, PipelineUpdateContextImpl, QueryOrigin, QueryOriginType} import org.apache.spark.sql.pipelines.logging.EventLevel import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} import org.apache.spark.sql.types.StructType @@ -57,37 +60,22 @@ class PythonPipelineSuite val customSessionIdentifier = UUID.randomUUID().toString val pythonCode = s""" - |from pyspark.sql import SparkSession |from pyspark import pipelines as dp - |from pyspark.pipelines.spark_connect_graph_element_registry import ( - | SparkConnectGraphElementRegistry, + |from 
pyspark.pipelines.tests.python_pipeline_suite_helpers import setup + |from pyspark.pipelines.add_pipeline_analysis_context import ( + | add_pipeline_analysis_context |) - |from pyspark.pipelines.spark_connect_pipeline import create_dataflow_graph |from pyspark.pipelines.graph_element_registry import ( | graph_element_registration_context, |) - |from pyspark.pipelines.add_pipeline_analysis_context import ( - | add_pipeline_analysis_context - |) - | - |spark = SparkSession.builder \\ - | .remote("sc://localhost:$serverPort") \\ - | .config("spark.connect.grpc.channel.timeout", "5s") \\ - | .config("spark.custom.identifier", "$customSessionIdentifier") \\ - | .create() | - |dataflow_graph_id = create_dataflow_graph( - | spark, - | default_catalog=None, - | default_database=None, - | sql_conf={}, - |) + |spark, registry = setup("$serverPort", "$customSessionIdentifier") | - |registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) |with add_pipeline_analysis_context( - | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None + | spark=spark, dataflow_graph_id=registry.dataflow_graph_id, flow_name=None |): | with graph_element_registration_context(registry): + | |$indentedPythonText |""".stripMargin @@ -98,20 +86,101 @@ class PythonPipelineSuite throw new RuntimeException( s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}") } + + getCurrentGraphRegistrationContext(customSessionIdentifier) + .getOrElse(throw new RuntimeException("Graph registration context not found")) + .toDataflowGraph + } + + private def findSessionHolder(sessionId: String): Option[SessionHolder] = { val activeSessions = SparkConnectService.sessionManager.listActiveSessions - // get the session holder by finding the session with the custom UUID set in the conf - val sessionHolder = activeSessions + activeSessions .map(info => SparkConnectService.sessionManager.getIsolatedSession(info.key, None)) - .find(_.session.conf.get("spark.custom.identifier") == customSessionIdentifier) - .getOrElse( - throw new RuntimeException(s"Session with identifier $customSessionIdentifier not found")) + .find(_.session.conf.get("spark.custom.identifier") == sessionId) + } - // get all dataflow graphs from the session holder - val dataflowGraphContexts = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs - assert(dataflowGraphContexts.size == 1) + private def getCurrentGraphRegistrationContext( + sessionId: String): Option[GraphRegistrationContext] = { + findSessionHolder(sessionId).map { sessionHolder => + val dataflowGraphs = sessionHolder.dataflowGraphRegistry.getDataflowGraphs + if (dataflowGraphs.size == 1) { + dataflowGraphs.values.head + } else { + throw new RuntimeException(s"Expected exactly 1 graph, but found ${dataflowGraphs.size}") + } + } + } - dataflowGraphContexts.head.toDataflowGraph + def buildAndResolveGraph(pythonText: String): DataflowGraph = { + val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") + val sessionId = UUID.randomUUID().toString + val pythonCode = + s""" + |from pyspark import pipelines as dp + |from pyspark.pipelines.tests.python_pipeline_suite_helpers import * + |from pyspark.pipelines.add_pipeline_analysis_context import ( + | add_pipeline_analysis_context + |) + |from pyspark.pipelines.graph_element_registry import ( + | graph_element_registration_context, + |) + | + |spark, registry = setup("$serverPort", "$sessionId") + | + |with add_pipeline_analysis_context( + | spark=spark, dataflow_graph_id=registry.dataflow_graph_id, 
flow_name=None
+        |):
+        |    with graph_element_registration_context(registry):
+        |
+        |$indentedPythonText
+        |
+        |run_and_handle_signals(spark, registry, "$storageRoot")
+        |""".stripMargin
+
+    // Run the generated definition script on its own thread so this test can
+    // observe the server side while the client is still executing.
+    implicit val ec: ExecutionContext = ExecutionContext.global
+    val pythonExecutionFuture: Future[(Int, Seq[String])] = Future {
+      executePythonCode(pythonCode)
+    }
+
+    // Wait until the session exists, timeout after 10 seconds
+    val sessionHolder = Eventually.eventually(
+      Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) {
+      findSessionHolder(sessionId).getOrElse(
+        throw new RuntimeException(s"Session with ID $sessionId not found"))
+    }
+
+    // Wait until a graph registration context exists, timeout after 10 seconds
+    val graphRegistrationContext = Eventually.eventually(
+      Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) {
+      getCurrentGraphRegistrationContext(sessionId).getOrElse(
+        throw new RuntimeException("Graph registration context not found"))
+    }
+
+    // Get the actual graph ID from the registry for this context
+    val dataflowGraphs = sessionHolder.dataflowGraphRegistry.getDataflowGraphs
+    val graphId = dataflowGraphs.find(_._2 == graphRegistrationContext).map(_._1)
+      .getOrElse(throw new RuntimeException("Could not find graph ID for context"))
+
+    val pipelineExecution = Eventually.eventually(
+      Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) {
+      sessionHolder.getPipelineExecution(graphId).getOrElse(
+        throw new RuntimeException("Pipeline execution not found"))
+    }
+
+    Eventually.eventually(
+      Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) {
+      pipelineExecution.pipelineExecution.resolvedGraph.getOrElse(
+        throw new RuntimeException("Resolved graph not found"))
+    }
+  }
 
   def graphIdentifier(name: String): TableIdentifier = {
@@ -1169,6 +1238,31 @@ class PythonPipelineSuite
       |""".stripMargin)
   }
 
+  test("access upstream schema within query function") {
+    val graph = buildAndResolveGraph("""
+        |@dp.materialized_view
+        |def mv2():
+        |    spark.table("mv1").schema
+        |    return spark.table("mv1")
+        |
+        |@dp.materialized_view
+        |def mv1():
+        |    return spark.range(5)
+        |""".stripMargin)
+    assert(graph.flows.size == 2)
+    assert(graph.tables.size == 2)
+  }
+
+  test("query function failure") {
+    val graph = buildAndResolveGraph("""
+        |@dp.materialized_view
+        |def mv():
+        |    raise ValueError("bla")
+        |""".stripMargin)
+    assert(graph.flows.size == 1)
+    assert(graph.tables.size == 1)
+  }
+
   override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
       pos: Position): Unit = {
     if (PythonTestDepsChecker.isConnectDepsAvailable) {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
index 3cb45fa6e1720..cd1345319a63f 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -321,7 +321,7 @@ class SparkDeclarativePipelinesServerSuite .build())) // Verify the graph exists in the default session - assert(getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs.size == 1) + assert(getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphs.size == 1) } // Create a second session with different user/session ID @@ -374,8 +374,8 @@ class SparkDeclarativePipelinesServerSuite .getIsolatedSession(SessionKey(newSessionUserId, newSessionId), None) val defaultSessionGraphs = - getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs - val newSessionGraphs = newSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphs.values + val newSessionGraphs = newSessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(defaultSessionGraphs.size == 1) assert(newSessionGraphs.size == 1) @@ -441,7 +441,7 @@ class SparkDeclarativePipelinesServerSuite .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId)) .get - val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsBefore = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsBefore.size == 1) // Close the session @@ -453,7 +453,7 @@ class SparkDeclarativePipelinesServerSuite assert(sessionAfterClose.isEmpty, "Session should be cleaned up after close") // Verify the graph is removed - val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsAfter = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsAfter.isEmpty, "Graph should be removed after session close") } } @@ -518,7 +518,7 @@ class SparkDeclarativePipelinesServerSuite // Verify the graph exists val sessionHolder = getDefaultSessionHolder - val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsBefore = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsBefore.size == 1) // Drop the graph @@ -532,7 +532,7 @@ class SparkDeclarativePipelinesServerSuite .build())) // Verify the graph is removed - val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsAfter = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsAfter.isEmpty, "Graph should be removed after drop") } } @@ -944,4 +944,67 @@ class SparkDeclarativePipelinesServerSuite relation.getCommand.getSqlCommand.getInput) } } + + test("SessionHolder provides pipeline execution data for analysis") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineOutput( + DefineOutput + .newBuilder() + .setDataflowGraphId(graphId) + .setOutputName("test_mv") + .setOutputType(OutputType.MATERIALIZED_VIEW)) + .build())) + + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineFlow( + DefineFlow + .newBuilder() + .setDataflowGraphId(graphId) + .setFlowName("test_flow") + .setTargetDatasetName("test_mv") + .setRelationFlowDetails( + DefineFlow.WriteRelationFlowDetails + .newBuilder() + .setRelation( + Relation + .newBuilder() + .setUnresolvedTableValuedFunction(UnresolvedTableValuedFunction + .newBuilder() + .setFunctionName("range") + .addArguments(Expression + .newBuilder() + .setLiteral(Expression.Literal.newBuilder().setInteger(10).build()) + .build()) + 
.build()) + .build()) + .build()) + .build()) + .build())) + + val sessionHolder = getDefaultSessionHolder + + assert(sessionHolder.getPipelineExecution(graphId).isEmpty, + "Pipeline execution should not exist before starting the run") + + val definition = sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) + val resolvedGraph = definition.toDataflowGraph.resolve() + + assert(resolvedGraph.inputIdentifiers.nonEmpty, + "Resolved graph should have input identifiers") + assert(resolvedGraph.inputIdentifiers.exists(_.table == "test_mv"), + "test_mv should be in the graph inputs") + + assert(resolvedGraph.flows.size == 1, "Graph should have one flow") + assert(resolvedGraph.tables.size == 1, "Graph should have one table") + } + } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index 38fde0bfec4a1..5d0e3858f6fc6 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.pipelines.graph -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} - -import scala.jdk.CollectionConverters._ - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ @@ -32,26 +28,19 @@ import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ * Processor that is responsible for analyzing each flow and sort the nodes in * topological order */ -class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { - +class CoreDataflowNodeProcessor(rawGraph: DataflowGraph, context: GraphAnalysisContext) { private val flowResolver = new FlowResolver(rawGraph) - // Map of input identifier to resolved [[Input]]. 
- private val resolvedInputs = new ConcurrentHashMap[TableIdentifier, Input]() - // Map & queue of resolved flows identifiers - // queue is there to track the topological order while map is used to store the id -> flow - // mapping - private val resolvedFlowNodesMap = new ConcurrentHashMap[TableIdentifier, ResolvedFlow]() - private val resolvedFlowNodesQueue = new ConcurrentLinkedQueue[ResolvedFlow]() - - private def processUnresolvedFlow(flow: UnresolvedFlow): ResolvedFlow = { + private def processUnresolvedFlow( + flow: UnresolvedFlow, + context: GraphAnalysisContext): ResolvedFlow = { val resolvedFlow = flowResolver.attemptResolveFlow( flow, rawGraph.inputIdentifiers, - resolvedInputs.asScala.toMap + context.resolvedInputsByIdentifier ) - resolvedFlowNodesQueue.add(resolvedFlow) - resolvedFlowNodesMap.put(flow.identifier, resolvedFlow) + context.resolvedFlowsQueue.add(resolvedFlow) + context.resolvedFlowsMap.put(flow.identifier, resolvedFlow) resolvedFlow } @@ -65,8 +54,8 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { */ def processNode(node: GraphElement, upstreamNodes: Seq[GraphElement]): Seq[GraphElement] = { node match { - case flow: UnresolvedFlow => Seq(processUnresolvedFlow(flow)) - case failedFlow: ResolutionFailedFlow => Seq(processUnresolvedFlow(failedFlow.flow)) + case flow: UnresolvedFlow => Seq(processUnresolvedFlow(flow, context)) + case failedFlow: ResolutionFailedFlow => Seq(processUnresolvedFlow(failedFlow.flow, context)) case table: Table => // Ensure all upstreamNodes for a table are flows val flowsToTable = upstreamNodes.map { @@ -78,7 +67,7 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { ) } val resolvedFlowsToTable = flowsToTable.map { flow => - resolvedFlowNodesMap.get(flow.identifier) + context.resolvedFlowsMap.get(flow.identifier) } // We mark all tables as virtual to ensure resolution uses incoming flows // rather than previously materialized tables. @@ -88,14 +77,14 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { incomingFlowIdentifiers = flowsToTable.map(_.identifier).toSet, availableFlows = resolvedFlowsToTable ) - resolvedInputs.put(table.identifier, virtualTableInput) + context.putResolvedInput(virtualTableInput) Seq(table) case view: View => // For view, add the flow to resolvedInputs and return empty. 
require(upstreamNodes.size == 1, "Found multiple flows to view") upstreamNodes.head match { case f: Flow => - resolvedInputs.put(view.identifier, resolvedFlowNodesMap.get(f.destinationIdentifier)) + context.putResolvedInput(context.resolvedFlowsMap.get(f.destinationIdentifier)) Seq(view) case _ => throw new IllegalArgumentException( @@ -120,6 +109,7 @@ private class FlowResolver(rawGraph: DataflowGraph) { flowToResolve: UnresolvedFlow, allInputs: Set[TableIdentifier], availableResolvedInputs: Map[TableIdentifier, Input]): ResolvedFlow = { + val flowFunctionResult = flowToResolve.func.call( allInputs = allInputs, availableInputs = availableResolvedInputs.values.toList, @@ -184,11 +174,17 @@ private class FlowResolver(rawGraph: DataflowGraph) { case f => f.dataFrame.failed.toOption.collectFirst { case e: UnresolvedDatasetException => e + case e: QueryFunctionResultNotAvailableException => e case _ => None } match { case Some(e: UnresolvedDatasetException) => throw TransformNodeRetryableException( - e.identifier, + Some(e.identifier), + new ResolutionFailedFlow(flowToResolve, flowFunctionResult) + ) + case Some(_: QueryFunctionResultNotAvailableException) => + throw TransformNodeRetryableException( + None, new ResolutionFailedFlow(flowToResolve, flowFunctionResult) ) case _ => diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index c9578ddd3b469..822d693ba9366 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -228,14 +228,16 @@ case class DataflowGraph( def resolved: Boolean = flows.forall(f => resolvedFlow.contains(f.identifier)) - def resolve(): DataflowGraph = + def resolve(contextOpt: Option[GraphAnalysisContext] = None): DataflowGraph = { + val context = contextOpt.getOrElse(new GraphAnalysisContext()) DataflowGraphTransformer.withDataflowGraphTransformer(this) { transformer => val coreDataflowNodeProcessor = - new CoreDataflowNodeProcessor(rawGraph = this) + new CoreDataflowNodeProcessor(rawGraph = this, context) transformer - .transformDownNodes(coreDataflowNodeProcessor.processNode) + .transformDownNodes(coreDataflowNodeProcessor.processNode, context) .getDataflowGraph } + } } object DataflowGraph { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala index 2523c0ae5502a..66dae5099f091 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -18,9 +18,6 @@ package org.apache.spark.sql.pipelines.graph import java.util.concurrent.{ - ConcurrentHashMap, - ConcurrentLinkedDeque, - ConcurrentLinkedQueue, ExecutionException, Future } @@ -64,6 +61,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { private var sinks: Seq[Sink] = graph.sinks private var sinkMap: Map[TableIdentifier, Sink] = computeSinkMap() + // Fail analysis nodes // Failed flows are flows that are failed to resolve or its inputs are not available or its // destination failed to resolve. 
@@ -72,6 +70,8 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { private var failedTables: Seq[Table] = Seq.empty private var failedSinks: Seq[Sink] = Seq.empty + @volatile private var currentCoreDataflowNodeProcessor: Option[CoreDataflowNodeProcessor] = None + private val parallelism = 10 // Executor used to resolve nodes in parallel. It is lazily initialized to avoid creating it @@ -120,25 +120,18 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { */ def transformDownNodes( transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement], + context: GraphAnalysisContext, disableParallelism: Boolean = false): DataflowGraphTransformer = { val executor = if (disableParallelism) selfExecutor else fixedPoolExecutor val batchSize = if (disableParallelism) 1 else parallelism - // List of resolved tables, sinks and flows - val resolvedFlows = new ConcurrentLinkedQueue[ResolutionCompletedFlow]() - val resolvedTables = new ConcurrentLinkedQueue[Table]() - val resolvedViews = new ConcurrentLinkedQueue[View]() - val resolvedSinks = new ConcurrentLinkedQueue[Sink]() - // Flow identifier to a list of transformed flows mapping to track resolved flows - val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, Seq[Flow]]() - val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier, Boolean]() - val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]() - val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, Seq[ResolutionFailedFlow]]() var futures = ArrayBuffer[Future[Unit]]() - val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]() - toBeResolvedFlows.addAll(flows.asJava) + context.toBeResolvedFlows.addAll(flows.asJava) - while (futures.nonEmpty || toBeResolvedFlows.peekFirst() != null) { + while ( + futures.nonEmpty || + context.toBeResolvedFlows.peekFirst() != null || + !context.failedUnregisteredFlows.isEmpty) { val (done, notDone) = futures.partition(_.isDone) // Explicitly call future.get() to propagate exceptions one by one if any try { @@ -152,7 +145,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { val flowOpt = { // We only schedule [[batchSize]] number of flows in parallel. if (futures.size < batchSize) { - Option(toBeResolvedFlows.pollFirst()) + Option(context.toBeResolvedFlows.pollFirst()) } else { None } @@ -160,152 +153,29 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { flowOpt.foreach { flow => futures.append( executor.submit( - () => - try { - try { - // Note: Flow don't need their inputs passed, so for now we send empty Seq. - val result = transformer(flow, Seq.empty) - require( - result.forall(_.isInstanceOf[ResolvedFlow]), - "transformer must return a Seq[Flow]" - ) - - val transformedFlows = result.map(_.asInstanceOf[ResolvedFlow]) - resolvedFlowsMap.put(flow.identifier, transformedFlows) - resolvedFlows.addAll(transformedFlows.asJava) - } catch { - case e: TransformNodeRetryableException => - val datasetIdentifier = e.datasetIdentifier - failedDependentFlows.compute( - datasetIdentifier, - (_, flows) => { - // Don't add the input flow back but the failed flow object - // back which has relevant failure information. 
- val failedFlow = e.failedNode - if (flows == null) { - Seq(failedFlow) - } else { - flows :+ failedFlow - } - } - ) - // Between the time the flow started and finished resolving, perhaps the - // dependent dataset was resolved - resolvedFlowDestinationsMap.computeIfPresent( - datasetIdentifier, - (_, resolved) => { - if (resolved) { - // Check if the dataset that the flow is dependent on has been resolved - // and if so, remove all dependent flows from the failedDependentFlows and - // add them to the toBeResolvedFlows queue for retry. - failedDependentFlows.computeIfPresent( - datasetIdentifier, - (_, toRetryFlows) => { - toRetryFlows.foreach(toBeResolvedFlows.addFirst(_)) - null - } - ) - } - resolved - } - ) - case other: Throwable => throw other - } - // If all flows to this particular destination are resolved, move to the destination - // node transformer - if (flowsTo(flow.destinationIdentifier).forall({ flowToDestination => - resolvedFlowsMap.containsKey(flowToDestination.identifier) - })) { - // If multiple flows completed in parallel, ensure we resolve the destination only - // once by electing a leader via computeIfAbsent - var isCurrentThreadLeader = false - resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => { - isCurrentThreadLeader = true - // Set initial value as false as flow destination is not resolved yet. - false - }) - if (isCurrentThreadLeader) { - if (tableMap.contains(flow.destinationIdentifier)) { - val transformed = - transformer( - tableMap(flow.destinationIdentifier), - flowsTo(flow.destinationIdentifier) - ) - resolvedTables.addAll( - transformed.collect { case t: Table => t }.asJava - ) - resolvedFlows.addAll( - transformed.collect { case f: ResolvedFlow => f }.asJava - ) - } else if (viewMap.contains(flow.destinationIdentifier)) { - resolvedViews.addAll { - val transformed = - transformer( - viewMap(flow.destinationIdentifier), - flowsTo(flow.destinationIdentifier) - ) - transformed.map(_.asInstanceOf[View]).asJava - } - } else if (sinkMap.contains(flow.destinationIdentifier)) { - resolvedSinks.addAll { - val transformed = - transformer( - sinkMap(flow.destinationIdentifier), flowsTo(flow.destinationIdentifier) - ) - require( - transformed.forall(_.isInstanceOf[Sink]), - "transformer must return a Seq[Sink]" - ) - transformed.map(_.asInstanceOf[Sink]).asJava - } - } else { - throw new IllegalArgumentException( - s"Unsupported destination ${flow.destinationIdentifier.unquotedString}" + - s" in flow: ${flow.displayName} at transformDownNodes" - ) - } - // Set flow destination as resolved now. 
-                  resolvedFlowDestinationsMap.computeIfPresent(
-                    flow.destinationIdentifier,
-                    (_, _) => {
-                      // If there are any other node failures dependent on this destination, retry
-                      // them
-                      failedDependentFlows.computeIfPresent(
-                        flow.destinationIdentifier,
-                        (_, toRetryFlows) => {
-                          toRetryFlows.foreach(toBeResolvedFlows.addFirst(_))
-                          null
-                        }
-                      )
-                      true
-                    }
-                  )
-                }
-              }
-            } catch {
-              case ex: TransformNodeFailedException => failedFlowsQueue.add(ex.failedNode)
-            }
+            () => transformFlowAndMaybeDestination(flow, transformer, context)
           )
         )
       }
     }
 
+
     // Mutate the fail analysis entities
     // A table is failed to analyze if:
     // - It does not exist in the resolvedFlowDestinationsMap
     failedTables = tables.filterNot { table =>
-      resolvedFlowDestinationsMap.getOrDefault(table.identifier, false)
+      context.resolvedFlowDestinationsMap.getOrDefault(table.identifier, false)
     }
 
     // A sink is failed to analyze if:
     // - It does not exist in the resolvedFlowDestinationsMap
     failedSinks = sinks.filterNot { sink =>
-      resolvedFlowDestinationsMap.getOrDefault(sink.identifier, false)
+      context.resolvedFlowDestinationsMap.getOrDefault(sink.identifier, false)
     }
 
     // We maintain the topological sort order of successful flows always
     val (resolvedFlowsWithResolvedDest, resolvedFlowsWithFailedDest) =
-      resolvedFlows.asScala.toSeq.partition(flow => {
-        resolvedFlowDestinationsMap.getOrDefault(flow.destinationIdentifier, false)
+      context.resolvedFlowsQueue.asScala.toSeq.partition(flow => {
+        context.resolvedFlowDestinationsMap.getOrDefault(flow.destinationIdentifier, false)
       })
 
     // A flow is failed to analyze if:
@@ -318,22 +188,146 @@
       // All transformed flows that write to a destination that is failed to analyze.
       resolvedFlowsWithFailedDest ++
       // All failed flows thrown by TransformNodeFailedException
-      failedFlowsQueue.asScala.toSeq ++
+      context.failedFlowsQueue.asScala.toSeq ++
       // All flows that have not been transformed and resolved yet
-      failedDependentFlows.values().asScala.flatten.toSeq
+      context.failedDependentFlows.values().asScala.flatten.toSeq
 
     // Mutate the resolved entities
     flows = resolvedFlowsWithResolvedDest
     flowsTo = computeFlowsTo()
-    tables = resolvedTables.asScala.toSeq
-    views = resolvedViews.asScala.toSeq
-    sinks = resolvedSinks.asScala.toSeq
+    tables = context.resolvedTables.asScala.toSeq
+    views = context.resolvedViews.asScala.toSeq
+    sinks = context.resolvedSinks.asScala.toSeq
     tableMap = computeTableMap()
     viewMap = computeViewMap()
     sinkMap = computeSinkMap()
   }
 
+  private[pipelines] def transformFlowAndMaybeDestination(
+      flow: Flow,
+      transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement],
+      context: GraphAnalysisContext): Unit = {
+    try {
+      try {
+        // Note: Flows don't need their inputs passed, so for now we send empty Seq.
+ transformer(flow, Seq.empty) + } catch { + case e: TransformNodeRetryableException => + e.datasetIdentifier match { + case None => + context.failedUnregisteredFlows.put(e.failedNode.identifier, e.failedNode) + case Some(datasetIdentifier) => + context.registerFailedDependentFlow(datasetIdentifier, e.failedNode) + // Between the time the flow started and finished resolving, perhaps the + // dependent dataset was resolved + context.resolvedFlowDestinationsMap.computeIfPresent( + datasetIdentifier, + (_, resolved) => { + if (resolved) { + // Check if the dataset that the flow is dependent on has been resolved + // and if so, remove all dependent flows from the failedDependentFlows and + // add them to the toBeResolvedFlows queue for retry. + context.failedDependentFlows.computeIfPresent( + datasetIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach { f => + if (context.failedUnregisteredFlows.containsKey(f.identifier)) { + context.flowClientSignalQueue.add(f.identifier) + } else { + context.toBeResolvedFlows.addFirst(f) + } + } + null + } + ) + } + resolved + } + ) + } + case other: Throwable => throw other + } + // If all flows to this particular destination are resolved, move to the destination + // node transformer + if (flowsTo(flow.destinationIdentifier).forall({ flowToDestination => + context.resolvedFlowsMap.containsKey(flowToDestination.identifier) + })) { + // If multiple flows completed in parallel, ensure we resolve the destination only + // once by electing a leader via computeIfAbsent + var isCurrentThreadLeader = false + context.resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => { + isCurrentThreadLeader = true + // Set initial value as false as flow destination is not resolved yet. + false + }) + if (isCurrentThreadLeader) { + if (tableMap.contains(flow.destinationIdentifier)) { + val transformed = + transformer( + tableMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + context.resolvedTables.addAll( + transformed.collect { case t: Table => t }.asJava + ) + } else if (viewMap.contains(flow.destinationIdentifier)) { + context.resolvedViews.addAll { + val transformed = + transformer( + viewMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + transformed.map(_.asInstanceOf[View]).asJava + } + } else if (sinkMap.contains(flow.destinationIdentifier)) { + context.resolvedSinks.addAll { + val transformed = + transformer( + sinkMap(flow.destinationIdentifier), flowsTo(flow.destinationIdentifier) + ) + require( + transformed.forall(_.isInstanceOf[Sink]), + "transformer must return a Seq[Sink]" + ) + transformed.map(_.asInstanceOf[Sink]).asJava + } + } else { + throw new IllegalArgumentException( + s"Unsupported destination ${flow.destinationIdentifier.unquotedString}" + + s" in flow: ${flow.displayName} at transformDownNodes" + ) + } + // Set flow destination as resolved now. 
+ context.resolvedFlowDestinationsMap.computeIfPresent( + flow.destinationIdentifier, + (_, _) => { + // If there are any other node failures dependent on this destination, retry + // them + context.failedDependentFlows.computeIfPresent( + flow.destinationIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach { f => + if (context.failedUnregisteredFlows.containsKey(f.identifier)) { + context.flowClientSignalQueue.add(f.identifier) + } else { + context.toBeResolvedFlows.addFirst(f) + } + } + null + } + ) + true + } + ) + } + } + } catch { + case ex: TransformNodeFailedException => context.failedFlowsQueue.add(ex.failedNode) + } + } + def getDataflowGraph: DataflowGraph = { graph.copy( // Returns all flows (resolved and failed) in topological order. @@ -356,6 +350,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { object DataflowGraphTransformer { + /** * Exception thrown when transforming a node in the graph fails because at least one of its * dependencies weren't yet transformed. @@ -364,7 +359,7 @@ object DataflowGraphTransformer { * dataflow graph. */ case class TransformNodeRetryableException( - datasetIdentifier: TableIdentifier, + datasetIdentifier: Option[TableIdentifier], failedNode: ResolutionFailedFlow) extends Exception with NoStackTrace diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index e329308502f0d..48c04c02afe5e 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -120,6 +120,23 @@ case class FlowFunctionResult( final def resolved: Boolean = failure.isEmpty // don't override this, override failure } +object FlowFunctionResult { + def fromFlowAnalysisContext( + ctx: FlowAnalysisContext, + df: Try[DataFrame], + confs: Map[String, String]): FlowFunctionResult = { + FlowFunctionResult( + requestedInputs = ctx.requestedInputs.toSet, + batchInputs = ctx.batchInputs.toSet, + streamingInputs = ctx.streamingInputs.toSet, + usedExternalInputs = ctx.externalInputs.toSet, + dataFrame = df, + sqlConf = confs, + analysisWarnings = ctx.analysisWarnings.toList + ) + } +} + /** A [[Flow]] whose output schema and dependencies aren't known. 
*/ case class UnresolvedFlow( identifier: TableIdentifier, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index 1a00a6339c4ba..5c068ae718fcf 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.pipelines.graph -import scala.util.Try +import scala.util.{Failure, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} @@ -58,18 +58,78 @@ object FlowAnalysis { } finally { ctx.restoreOriginalConf() } - FlowFunctionResult( - requestedInputs = ctx.requestedInputs.toSet, - batchInputs = ctx.batchInputs.toSet, - streamingInputs = ctx.streamingInputs.toSet, - usedExternalInputs = ctx.externalInputs.toSet, - dataFrame = df, - sqlConf = confs, - analysisWarnings = ctx.analysisWarnings.toList - ) + FlowFunctionResult.fromFlowAnalysisContext(ctx, df, confs) } } + /** + * Creates a [[FlowFunction]] that looks up the LogicalPlan from the GraphRegistrationContext + * for the given flow name and analyzes it. + * + * @param flowIdentifier The ID of the flow to look up in the context. + * @param context The GraphRegistrationContext containing query function results. + * @return A FlowFunction that analyzes the stored LogicalPlan for the flow. + */ + def createQueryFunctionResultPollingFlowFunction( + flowIdentifier: TableIdentifier, + context: GraphRegistrationContext): FlowFunction = { + (allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + confs: Map[String, String], + queryContext: QueryContext, + queryOrigin: QueryOrigin) => { + context.getQueryFunctionResult(flowIdentifier) match { + case Some(result: QueryFunctionSuccess) => + val flowFunc = createFlowFunctionFromLogicalPlan(result.plan) + flowFunc.call(allInputs, availableInputs, confs, queryContext, queryOrigin) + case Some(QueryFunctionTerminalFailure) => + FlowFunctionResult( + requestedInputs = Set.empty, + batchInputs = Set.empty, + streamingInputs = Set.empty, + usedExternalInputs = Set.empty, + dataFrame = Failure(QueryFunctionTerminalFailureException()), + sqlConf = confs, + analysisWarnings = Seq.empty + ) + case None => + FlowFunctionResult( + requestedInputs = Set.empty, + batchInputs = Set.empty, + streamingInputs = Set.empty, + usedExternalInputs = Set.empty, + dataFrame = Failure(QueryFunctionResultNotAvailableException()), + sqlConf = confs, + analysisWarnings = Seq.empty + ) + } + } + } + + /** + * Public wrapper method for flow analysis from Spark Connect. + * Creates FlowAnalysisContext internally and calls the analyze method. + */ + def analyze( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + currentCatalog: String, + currentDatabase: String, + spark: SparkSession, + plan: LogicalPlan + ): DataFrame = { + val context = FlowAnalysisContext( + allInputs = allInputs, + availableInputs = availableInputs, + queryContext = QueryContext( + currentCatalog = Some(currentCatalog), + currentDatabase = Some(currentDatabase) + ), + spark = spark + ) + analyze(context, plan) + } + /** * Constructs an analyzed [[DataFrame]] from a [[LogicalPlan]] by resolving Pipelines specific * TVFs and datasets that cannot be resolved directly by Catalyst. @@ -81,7 +141,7 @@ object FlowAnalysis { * @param plan The [[LogicalPlan]] defining a flow. 
   * @return An analyzed [[DataFrame]].
   */
-  private def analyze(
+  def analyze(
       context: FlowAnalysisContext,
       plan: LogicalPlan
   ): DataFrame = {
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala
new file mode 100644
index 0000000000000..92462d0731a05
--- /dev/null
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pipelines.graph
+
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque, ConcurrentLinkedQueue}
+
+import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala
+import scala.util.Failure
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.{DataFrame, SparkSession}
+
+class GraphAnalysisContext {
+  val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]()
+
+  // Map of input identifier to resolved [[Input]].
+  private val resolvedInputsHashMap = new ConcurrentHashMap[TableIdentifier, Input]()
+
+  // Destination identifier to a boolean indicating whether the destination has been resolved.
+  val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier, Boolean]()
+
+  // Map and queue of resolved flows: the queue preserves the topological (resolution) order,
+  // while the map provides identifier -> flow lookup.
+  val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, ResolvedFlow]()
+  val resolvedFlowsQueue = new ConcurrentLinkedQueue[ResolvedFlow]()
+
+  // Queues of resolved tables, views, and sinks.
+  val resolvedTables = new ConcurrentLinkedQueue[Table]()
+  val resolvedViews = new ConcurrentLinkedQueue[View]()
+  val resolvedSinks = new ConcurrentLinkedQueue[Sink]()
+
+  // Flows that failed because the client has not yet registered a plan. Keyed by flow identifier.
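+  // Such flows are re-queued for resolution by markFlowPlanRegistered once the client
+  // registers a query function result for them.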
+  val failedUnregisteredFlows = new ConcurrentHashMap[TableIdentifier, ResolutionFailedFlow]()
+
+  // Dataset identifier to the list of flows that failed resolution because this dataset was
+  // missing.
+  val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, Seq[ResolutionFailedFlow]]()
+  val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]()
+
+  // Queue of flow identifiers that have had their upstream inputs resolved and should have their
+  // flow function retried on the client.
+  val flowClientSignalQueue = new ConcurrentLinkedQueue[TableIdentifier]()
+
+  def putResolvedInput(input: Input): Unit = {
+    resolvedInputsHashMap.put(input.identifier, input)
+  }
+
+  def resolvedInputsByIdentifier: Map[TableIdentifier, Input] = resolvedInputsHashMap.asScala.toMap
+
+  def registerFailedDependentFlow(
+      inputDatasetIdentifier: TableIdentifier,
+      failedFlow: ResolutionFailedFlow): Unit = {
+    failedDependentFlows.compute(
+      inputDatasetIdentifier,
+      (_, flows) => {
+        if (flows == null) {
+          Seq(failedFlow)
+        } else {
+          flows :+ failedFlow
+        }
+      }
+    )
+  }
+
+  def markFlowPlanRegistered(flowIdentifier: TableIdentifier): Unit = {
+    val flow = failedUnregisteredFlows.remove(flowIdentifier)
+    // The flow can be null if it failed on the client side and we never tried to execute the
+    // flow function on the server side.
+    if (flow != null) {
+      toBeResolvedFlows.addFirst(flow)
+    }
+  }
+
+  def analyze(
+      flowIdentifier: TableIdentifier,
+      logicalPlan: LogicalPlan,
+      unresolvedGraph: DataflowGraph,
+      session: SparkSession): DataFrame = {
+    val unresolvedFlow = unresolvedGraph.flow(flowIdentifier).asInstanceOf[UnresolvedFlow]
+    val flowAnalysisContext = FlowAnalysisContext(
+      allInputs = unresolvedGraph.inputIdentifiers,
+      availableInputs = resolvedInputsByIdentifier.values.toSeq,
+      queryContext = unresolvedFlow.queryContext,
+      spark = session
+    )
+
+    try {
+      FlowAnalysis.analyze(flowAnalysisContext, plan = logicalPlan)
+    } catch {
+      case e: UnresolvedDatasetException =>
+        val flowFunctionResult =
+          FlowFunctionResult.fromFlowAnalysisContext(flowAnalysisContext, Failure(e), Map.empty)
+
+        val resolutionFailedFlow = new ResolutionFailedFlow(unresolvedFlow, flowFunctionResult)
+        registerFailedDependentFlow(e.identifier, resolutionFailedFlow)
+        throw e
+    }
+  }
+}
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
index dadda0561b19f..f058917b4f7cd 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala
@@ -16,10 +16,13 @@
  */
 package org.apache.spark.sql.pipelines.graph
 
+import java.util.concurrent.ConcurrentHashMap
+
 import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 
 /**
  * A mutable context for registering tables, views, and flows in a dataflow graph.
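  * With distributed analysis, it also stores query function results registered by the client,
  * keyed by flow identifier.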
@@ -37,6 +40,8 @@ class GraphRegistrationContext( protected val views = new mutable.ListBuffer[View] protected val sinks = new mutable.ListBuffer[Sink] protected val flows = new mutable.ListBuffer[UnresolvedFlow] + // keyed by flow ID + private val queryFunctionResults = new ConcurrentHashMap[TableIdentifier, QueryFunctionResult]() def registerTable(tableDef: Table): Unit = { tables += tableDef @@ -62,6 +67,15 @@ class GraphRegistrationContext( flows += flowDef.copy(sqlConf = defaultSqlConf ++ flowDef.sqlConf) } + def registerQueryFunctionResult( + flowIdentifier: TableIdentifier, result: QueryFunctionResult): Unit = { + queryFunctionResults.put(flowIdentifier, result) + } + + def getQueryFunctionResult(flowIdentifier: TableIdentifier): Option[QueryFunctionResult] = { + Option(queryFunctionResults.get(flowIdentifier)) + } + private def isEmpty: Boolean = { tables.isEmpty && views.collect { case v: PersistedView => v @@ -189,3 +203,9 @@ object GraphRegistrationContext { override def toString: String = "SINK" } } + +sealed trait QueryFunctionResult + +case class QueryFunctionSuccess(plan: LogicalPlan) extends QueryFunctionResult + +case object QueryFunctionTerminalFailure extends QueryFunctionResult diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala index d35d701d44e57..f5f3c9dd01e66 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala @@ -35,7 +35,10 @@ import org.apache.spark.sql.pipelines.logging.{ class PipelineExecution(context: PipelineUpdateContext) { /** [Visible for testing] */ - private[pipelines] var graphExecution: Option[TriggeredGraphExecution] = None + @volatile private[pipelines] var graphExecution: Option[TriggeredGraphExecution] = None + /** [Visible for testing] */ + @volatile var resolvedGraph: Option[DataflowGraph] = None + val graphAnalysisContext = new GraphAnalysisContext() def executionStarted: Boolean = synchronized { graphExecution.nonEmpty } @@ -45,12 +48,12 @@ class PipelineExecution(context: PipelineUpdateContext) { */ def startPipeline(): Unit = synchronized { // Initialize the graph. - val resolvedGraph = resolveGraph() + resolvedGraph = Some(resolveGraph()) if (context.fullRefreshTables.nonEmpty) { - State.reset(resolvedGraph, context) + State.reset(resolvedGraph.get, context) } - val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph, context) + val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph.get, context) // Execute the graph. graphExecution = Some( @@ -86,7 +89,7 @@ class PipelineExecution(context: PipelineUpdateContext) { /** Validates that the pipeline graph can be successfully resolved and validates it. 
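   * The resolved graph is also stored in [[resolvedGraph]] so it can be inspected after the run.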
   */
  def dryRunPipeline(): Unit = synchronized {
-    resolveGraph()
+    resolvedGraph = Some(resolveGraph())
     context.eventCallback(
       constructTerminationEvent(RunCompletion())
     )
@@ -108,9 +111,10 @@ class PipelineExecution(context: PipelineUpdateContext) {
     )
   }
 
-  private def resolveGraph(): DataflowGraph = {
+  def resolveGraph(): DataflowGraph = {
     try {
-      context.unresolvedGraph.resolve().validate()
+      context.unresolvedGraph.resolve(Some(graphAnalysisContext)).validate()
     } catch {
       case e: UnresolvedPipelineException =>
         handleInvalidPipeline(e)
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
index 7116f5fbcf068..979aaba388c97 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala
@@ -35,6 +35,12 @@ case class UnresolvedDatasetException(identifier: TableIdentifier)
     s"pipeline but could not be resolved."
   )
 
+case class QueryFunctionResultNotAvailableException()
+  extends AnalysisException("Query function result is not yet available.")
+
+case class QueryFunctionTerminalFailureException()
+  extends AnalysisException("Query function failed in a way that further analysis won't fix.")
+
 /**
  * Exception raised when a flow fails to read from a table defined within the pipeline
  *
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/DistributedAnalysisSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/DistributedAnalysisSuite.scala
new file mode 100644
index 0000000000000..0f9aa71b12db5
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/DistributedAnalysisSuite.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext}
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Tests for distributed graph analysis, where flow query functions are executed on the client
+ * and their resulting plans are registered with the server out of band. The suite simulates the
+ * server-side transformer, the signal-sending RPC thread, and the client, and exercises them in
+ * both fixed and randomized interleavings.
+ */
+class DistributedAnalysisSuite extends PipelineTest with SharedSparkSession {
+  case class DummyFlowInfo(
+      destinationTable: String,
+      eagerDeps: Seq[String] = Seq.empty,
+      lazyDeps: Seq[String] = Seq.empty,
+      failsTerminally: Boolean = false
+  ) {
+    val flowName: String = destinationTable.replace("table", "flow")
+
+    def register(registrationContext: TestGraphRegistrationContext): Unit = {
+      val flowFunction = if (eagerDeps.nonEmpty) {
+        FlowAnalysis.createQueryFunctionResultPollingFlowFunction(
+          fullyQualifiedIdentifier(flowName),
+          registrationContext
+        )
+      } else {
+        val plan = if (lazyDeps.nonEmpty) {
+          spark.sessionState.sqlParser
+            .parsePlan(lazyDeps.map(t => s"select * from $t").mkString(" union "))
+        } else {
+          spark.range(4).logicalPlan
+        }
+        FlowAnalysis.createFlowFunctionFromLogicalPlan(plan)
+      }
+      registrationContext.registerFlow(destinationTable, flowName, flowFunction)
+    }
+  }
+
+  class TestDistributedAnalyzer(val dummyFlows: Seq[DummyFlowInfo], val spark: SparkSession) {
+    val registrationContext = new TestGraphRegistrationContext(spark)
+    dummyFlows.foreach { f =>
+      f.register(registrationContext)
+      registrationContext.registerTable(f.destinationTable)
+    }
+    private val dummyFlowsById = dummyFlows.map { f =>
+      (fullyQualifiedIdentifier(f.flowName), f)
+    }.toMap
+
+    val unresolvedGraph: DataflowGraph = registrationContext.toDataflowGraph
+    val analysisContext = new GraphAnalysisContext()
+    unresolvedGraph.flows.foreach(analysisContext.toBeResolvedFlows.add)
+    val transformer = new DataflowGraphTransformer(unresolvedGraph)
+
+    val nodeProcessor = new CoreDataflowNodeProcessor(unresolvedGraph, analysisContext)
+
+    // Signals received by the client, but not yet processed.
+    val clientSignalQueue = new mutable.Queue[TableIdentifier]()
+
+    def transformerIteration(): Unit = {
+      Option(analysisContext.toBeResolvedFlows.pollFirst()).foreach { flow =>
+        transformer.transformFlowAndMaybeDestination(
+          flow,
+          nodeProcessor.processNode,
+          analysisContext
+        )
+      }
+    }
+
+    def analyze(flowIdentifier: TableIdentifier, plan: LogicalPlan): Unit = {
+      analysisContext.analyze(flowIdentifier, plan, unresolvedGraph, spark)
+    }
+
+    def registerQueryFunctionResult(
+        flowIdentifier: TableIdentifier,
+        result: QueryFunctionResult): Unit = {
+      registrationContext.registerQueryFunctionResult(flowIdentifier, result)
+      analysisContext.markFlowPlanRegistered(flowIdentifier)
+    }
+
+    def runToCompletion(randomSeed: Int): Unit = {
+      val r = new scala.util.Random(randomSeed)
+
+      var numIters = 0
+      // Keep iterating while any server or client queue still has work, with an iteration cap
+      // as a safety net against livelock.
+      while (numIters < 200
+        && (!analysisContext.toBeResolvedFlows.isEmpty
+          || clientSignalQueue.nonEmpty
+          || !analysisContext.flowClientSignalQueue.isEmpty)) {
+        numIters += 1
+
+        r.nextInt(3) match {
+          case 0 => transformerIteration()
+          case 1 => // Simulate the signal-sending RPC thread
+            if (!analysisContext.flowClientSignalQueue.isEmpty) {
+              val flowToRetryId = analysisContext.flowClientSignalQueue.poll()
+              try {
+                dummyFlowsById(flowToRetryId).eagerDeps.foreach { tableName =>
+                  val plan = spark.sessionState.sqlParser.parsePlan(s"select * from $tableName")
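+                  // Server-side analysis of the client's plan; throws UnresolvedDatasetException
+                  // while an upstream table is still unresolved.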
+                  analyze(flowToRetryId, plan)
+                }
+                clientSignalQueue.append(flowToRetryId)
+              } catch {
+                case _: UnresolvedDatasetException =>
+              }
+            }
+          case 2 => // Simulate the client
+            if (clientSignalQueue.nonEmpty) {
+              val flowId = clientSignalQueue.dequeue()
+              val dummyFlow = dummyFlowsById(flowId)
+              val upstreamTables = dummyFlow.eagerDeps
+              val result = if (dummyFlow.failsTerminally) {
+                QueryFunctionTerminalFailure
+              } else {
+                QueryFunctionSuccess(
+                  spark.sessionState.sqlParser
+                    .parsePlan(upstreamTables.map(t => s"select * from $t").mkString(" union "))
+                )
+              }
+              registerQueryFunctionResult(flowId, result)
+            }
+        }
+
+        assert(
+          analysisContext.toBeResolvedFlows.size
+            + analysisContext.resolvedFlowsMap.size
+            + analysisContext.failedUnregisteredFlows.size
+            + analysisContext.failedDependentFlows.size
+          == unresolvedGraph.flows.size
+        )
+      }
+    }
+  }
+
+  test("single node eagerly analyzing an external table (.columns)") {
+    val externalTableId = fullyQualifiedIdentifier("external")
+    spark.sql(s"CREATE TABLE $externalTableId AS SELECT * FROM RANGE(3)")
+    val dummyFlows = Seq(
+      DummyFlowInfo("table1", eagerDeps = Seq(externalTableId.quotedString))
+    )
+    val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark)
+
+    // The transformer processes the node. The flow function fails because no relation has been
+    // registered for the node, so the flow is parked until its plan is registered.
+    testDistributedAnalyzer.transformerIteration()
+
+    // Register the query function result.
+    testDistributedAnalyzer.registerQueryFunctionResult(
+      fullyQualifiedIdentifier("table1"),
+      QueryFunctionSuccess(spark.range(4).logicalPlan)
+    )
+
+    // The transformer processes the node again; this time the flow function succeeds.
+    testDistributedAnalyzer.transformerIteration()
+  }
+
+  test("two nodes where the second eagerly analyzes the first (.columns)") {
+    // flow1 -> table1 -> flow2 -> table2
+    // flow2 can't be resolved immediately because it eagerly analyzes table1, flow1's output.
+    val dummyFlows = Seq(
+      DummyFlowInfo("table2", eagerDeps = Seq("table1")),
+      DummyFlowInfo("table1")
+    )
+
+    val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark)
+    val analysisContext = testDistributedAnalyzer.analysisContext
+    val selectFromTable1Plan = spark.sessionState.sqlParser.parsePlan("select * from table1")
+
+    val flow1Id = fullyQualifiedIdentifier("flow1")
+    val table1Id = fullyQualifiedIdentifier("table1")
+    val flow2Id = fullyQualifiedIdentifier("flow2")
+
+    // Attempt to analyze table1 on behalf of flow2. This should fail and record the dependency.
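+    // The dependency is recorded in failedDependentFlows, keyed by the missing table1.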
+    intercept[UnresolvedDatasetException](
+      testDistributedAnalyzer.analyze(flow2Id, selectFromTable1Plan)
+    )
+    val failedDependentFlows = analysisContext.failedDependentFlows
+    assert(failedDependentFlows.size() == 1)
+    assert(failedDependentFlows.get(table1Id).map(_.identifier).toSet == Set(flow2Id))
+
+    // The transformer processes flow2. The flow function fails because no relation is
+    // registered for the node, so flow2 is added to the unregistered list.
+    testDistributedAnalyzer.transformerIteration()
+    assert(analysisContext.failedUnregisteredFlows.containsKey(flow2Id))
+
+    // The transformer processes flow1, whose flow function succeeds. flow2 shouldn't be added
+    // back to the resolution queue, because its plan is still unregistered.
+    testDistributedAnalyzer.transformerIteration()
+    assert(analysisContext.toBeResolvedFlows.isEmpty)
+    assert(analysisContext.resolvedFlowsMap.containsKey(flow1Id))
+
+    // The transformer detects that flow1's destination is resolved and signals the client to
+    // retry flow2.
+    assert(
+      testDistributedAnalyzer.analysisContext.flowClientSignalQueue.toArray.toSeq == Seq(flow2Id)
+    )
+    testDistributedAnalyzer.analysisContext.flowClientSignalQueue.clear()
+
+    testDistributedAnalyzer.analyze(flow2Id, selectFromTable1Plan)
+
+    // Register the query function result.
+    testDistributedAnalyzer.registerQueryFunctionResult(
+      flow2Id,
+      QueryFunctionSuccess(selectFromTable1Plan)
+    )
+    assert(analysisContext.failedUnregisteredFlows.isEmpty)
+    assert(analysisContext.toBeResolvedFlows.size == 1)
+
+    // The transformer processes flow2 again; this time the flow function succeeds.
+    testDistributedAnalyzer.transformerIteration()
+    assert(analysisContext.toBeResolvedFlows.isEmpty)
+    assert(analysisContext.failedUnregisteredFlows.isEmpty)
+    assert(analysisContext.failedDependentFlows.isEmpty)
+    assert(analysisContext.resolvedFlowsMap.size == 2)
+  }
+
+  test("random orderings") {
+    // flow1 -> table1 -> flow2 -> table2
+    // flow2 can't be resolved immediately because it eagerly analyzes table1, flow1's output.
+    val dummyFlows = Seq(
+      DummyFlowInfo("table2", eagerDeps = Seq("table1")),
+      DummyFlowInfo("table1")
+    )
+
+    val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark)
+
+    testDistributedAnalyzer.runToCompletion(4367)
+    assert(testDistributedAnalyzer.analysisContext.failedUnregisteredFlows.isEmpty)
+    assert(testDistributedAnalyzer.analysisContext.failedDependentFlows.isEmpty)
+    assert(testDistributedAnalyzer.analysisContext.resolvedFlowsMap.size == 2)
+  }
+
+  test("query function fails after eager analysis") {
+    val dummyFlows = Seq(
+      DummyFlowInfo("table2", eagerDeps = Seq("table1"), failsTerminally = true),
+      DummyFlowInfo("table1")
+    )
+    val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark)
+    testDistributedAnalyzer.runToCompletion(4367)
+    // flow1 has no dependencies and resolves normally; only flow2's query function fails
+    // terminally.
+    assert(testDistributedAnalyzer.analysisContext.resolvedFlowsMap.size == 1)
+  }
+}
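A minimal sketch of the registration protocol exercised above, assuming a `registrationContext: GraphRegistrationContext`, an `analysisContext: GraphAnalysisContext`, and a client-produced `plan: LogicalPlan` (illustrative only; it mirrors what DistributedAnalysisSuite drives step by step):

import org.apache.spark.sql.catalyst.TableIdentifier

// Server side: register the flow with a polling flow function instead of an eager plan.
val flowId = TableIdentifier("flow1")
val flowFunction =
  FlowAnalysis.createQueryFunctionResultPollingFlowFunction(flowId, registrationContext)

// Until the client registers a result, invoking the flow function yields
// Failure(QueryFunctionResultNotAvailableException()) and the flow is parked in
// failedUnregisteredFlows until a client signal round trip completes.

// Client round trip: register the executed query function's plan and re-queue the flow.
registrationContext.registerQueryFunctionResult(flowId, QueryFunctionSuccess(plan))
analysisContext.markFlowPlanRegistered(flowId)

// A terminal client-side failure is registered as QueryFunctionTerminalFailure instead,
// which surfaces QueryFunctionTerminalFailureException on the next resolution attempt.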