From fe33f307148f50458f9be15375926cb05b936b38 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sat, 29 Nov 2025 14:08:24 -0800 Subject: [PATCH 1/7] graph analysis context --- .../pipelines/graph_element_registry.py | 9 + .../spark_connect_graph_element_registry.py | 64 +++- .../connect/pipelines/PipelinesHandler.scala | 148 ++++++++- .../sql/connect/service/SessionHolder.scala | 1 + .../service/SparkConnectAnalyzeHandler.scala | 27 +- .../pipelines/PythonPipelineSuite.scala | 29 ++ ...SparkDeclarativePipelinesServerSuite.scala | 63 ++++ .../graph/CoreDataflowNodeProcessor.scala | 44 ++- .../sql/pipelines/graph/DataflowGraph.scala | 8 +- .../graph/DataflowGraphTransformer.scala | 293 +++++++++--------- .../spark/sql/pipelines/graph/Flow.scala | 17 + .../sql/pipelines/graph/FlowAnalysis.scala | 82 ++++- .../graph/GraphAnalysisContext.scala | 114 +++++++ .../graph/GraphRegistrationContext.scala | 20 ++ .../pipelines/graph/PipelineExecution.scala | 3 +- .../sql/pipelines/graph/PipelinesErrors.scala | 6 + .../graph/DistributedAnalysisSuite.scala | 262 ++++++++++++++++ 17 files changed, 981 insertions(+), 209 deletions(-) create mode 100644 sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala create mode 100644 sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/DistributedAnalysisSuite.scala diff --git a/python/pyspark/pipelines/graph_element_registry.py b/python/pyspark/pipelines/graph_element_registry.py index 8e311fc2ca98e..c9d4ec3092cec 100644 --- a/python/pyspark/pipelines/graph_element_registry.py +++ b/python/pyspark/pipelines/graph_element_registry.py @@ -51,6 +51,15 @@ def register_sql(self, sql_text: str, file_path: Path) -> None: :param file_path: The path to the file that the SQL txt came from. """ + @abstractmethod + def register_signalled_query_functions(self) -> None: + """Open up a stream to receive query function execution signals from the server. 
+ When a signal is received, execute the corresponding query function and register the + result with the server. + + This method should be called after all the flows and datasets have been registered. + """ + _graph_element_registry_context_var: ContextVar[Optional[GraphElementRegistry]] = ContextVar( "graph_element_registry_context", default=None diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index ab88317908302..fb51185342fb9 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -15,9 +15,10 @@ # limitations under the License. # from pathlib import Path +from pyspark.errors.exceptions.base import PySparkValueError from pyspark.errors import PySparkTypeError -from pyspark.sql import SparkSession +from pyspark.sql import SparkSession, DataFrame from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame from pyspark.pipelines.output import ( Output, @@ -35,6 +36,8 @@ from typing import Any, cast import pyspark.sql.connect.proto as pb2 from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context +from pyspark.pipelines.flow import QueryFunction +import uuid class SparkConnectGraphElementRegistry(GraphElementRegistry): @@ -46,6 +49,8 @@ def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None: self._spark = spark self._client = cast(Any, spark).client self._dataflow_graph_id = dataflow_graph_id + self._client_id = str(uuid.uuid4()) + self._query_funcs_by_flow_name: dict[str, QueryFunction] = {} def register_output(self, output: Output) -> None: table_details = None @@ -111,10 +116,14 @@ def register_output(self, output: Output) -> None: self._client.execute_command(command) def register_flow(self, flow: Flow) -> None: - with add_pipeline_analysis_context( - spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, 
flow_name=flow.name - ): - df = flow.func() + self._query_funcs_by_flow_name[flow.name] = flow.func + try: + df = self._execute_query_function(flow.name, flow.func) + except Exception as e: + raise PySparkValueError( + f"Error executing query function for flow {flow.name}: {e}" + ) + relation = cast(ConnectDataFrame, df)._plan.plan(self._client) relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails( @@ -128,6 +137,7 @@ def register_flow(self, flow: Flow) -> None: relation_flow_details=relation_flow_details, sql_conf=flow.spark_conf, source_code_location=source_code_location_to_proto(flow.source_code_location), + client_id=self._client_id, ) command = pb2.Command() command.pipeline_command.define_flow.CopyFrom(inner_command) @@ -143,6 +153,50 @@ def register_sql(self, sql_text: str, file_path: Path) -> None: command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command) self._client.execute_command(command) + def register_signalled_query_functions(self) -> None: + """Open up a stream to receive query function execution signals from the server. + When a signal is received, execute the corresponding query function and register the + result with the server. + + This method should be called after all the flows and datasets have been registered. 
+ """ + + inner_command = pb2.PipelineCommand.GetQueryFunctionExecutionSignalStream( + dataflow_graph_id=self._dataflow_graph_id, + client_id=self._client_id, + ) + command = pb2.Command() + command.pipeline_command.get_query_function_execution_signal_stream.CopyFrom(inner_command) + self._client.execute_command_as_iterator(command) + + result_iter = self._client.execute_command_as_iterator(inner_command) + for result in result_iter: + if "pipeline_query_function_execution_signal" not in result.keys(): + raise PySparkValueError( + f"GetQueryFunctionExecutionSignalStream received unexpected result: {result}" + ) + + signal = result["pipeline_query_function_execution_signal"] + flow_names = signal.flow_names + for flow_name in flow_names: + func = self._query_funcs_by_flow_name[flow_name] + df = self._execute_query_function(flow_name, func) + relation = cast(ConnectDataFrame, df)._plan.plan(self._client) + inner_command = pb2.PipelineCommand.DefineFlowQueryFunctionResult( + dataflow_graph_id=self._dataflow_graph_id, + flow_name=flow_name, + relation=relation, + ) + command = pb2.Command() + command.pipeline_command.define_flow_query_function_result.CopyFrom(inner_command) + self._client.execute_command(command) + + def _execute_query_function(self, flow_name: str, func: QueryFunction) -> DataFrame: + with add_pipeline_analysis_context( + spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow_name + ): + return func() + def source_code_location_to_proto( source_code_location: SourceCodeLocation, diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 62f060014117c..3fe5e34af5f31 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ 
-24,17 +24,18 @@ import scala.util.Using import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResult, Relation, ResolvedIdentifier} +import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineAnalysisContext, PipelineCommandResult, Relation, ResolvedIdentifier} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Command, CreateNamespace, CreateTable, CreateTableAsSelect, CreateView, DescribeRelation, DropView, InsertIntoStatement, LogicalPlan, RenameTable, ShowColumns, ShowCreateTable, ShowFunctions, ShowTableProperties, ShowTables, ShowViews} +import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.execution.command.{ShowCatalogsCommand, ShowNamespacesCommand} import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED} -import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} +import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryFunctionSuccess, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress} import org.apache.spark.sql.types.StructType @@ -123,6 +124,16 
@@ private[connect] object PipelinesHandler extends Logging { logInfo(s"Register sql datasets cmd received: $cmd") defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder) defaultResponse + case proto.PipelineCommand.CommandTypeCase.GET_QUERY_FUNCTION_EXECUTION_SIGNAL_STREAM => + logInfo(s"Get query function execution signal stream cmd received: $cmd") + getQueryFunctionExecutionSignalStream( + cmd.getGetQueryFunctionExecutionSignalStream, responseObserver, sessionHolder) + defaultResponse + case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW_QUERY_FUNCTION_RESULT => + logInfo(s"Define flow query function result cmd received: $cmd") + defineFlowQueryFunctionResult( + cmd.getDefineFlowQueryFunctionResult, transformRelationFunc, sessionHolder) + defaultResponse case other => throw new UnsupportedOperationException(s"$other not supported") } } @@ -367,12 +378,19 @@ private[connect] object PipelinesHandler extends Logging { } val relationFlowDetails = flow.getRelationFlowDetails + val flowFunction = if (relationFlowDetails.hasRelation()) { + FlowAnalysis.createFlowFunctionFromLogicalPlan( + transformRelationFunc(relationFlowDetails.getRelation)) + } else { + FlowAnalysis.createQueryFunctionResultPollingFlowFunction( + flowIdentifier, + graphElementRegistry) + } graphElementRegistry.registerFlow( UnresolvedFlow( identifier = flowIdentifier, destinationIdentifier = destinationIdentifier, - func = FlowAnalysis.createFlowFunctionFromLogicalPlan( - transformRelationFunc(relationFlowDetails.getRelation)), + func = flowFunction, sqlConf = flow.getSqlConfMap.asScala.toMap, once = false, queryContext = QueryContext(Option(defaultCatalog), Option(defaultDatabase)), @@ -528,4 +546,126 @@ private[connect] object PipelinesHandler extends Logging { * A case class to hold the table filters for full refresh and refresh operations. 
*/ private case class TableFilters(fullRefresh: TableFilter, refresh: TableFilter) + + /** + * Handles the GetQueryFunctionExecutionSignalStream command by monitoring for flows + * that have more resolved inputs than the last time we sent a signal. + */ + private def getQueryFunctionExecutionSignalStream( + cmd: proto.PipelineCommand.GetQueryFunctionExecutionSignalStream, + responseObserver: StreamObserver[ExecutePlanResponse], + sessionHolder: SessionHolder): Unit = { + val dataflowGraphId = cmd.getDataflowGraphId + val clientId = cmd.getClientId + + logInfo(s"Starting query function execution signal stream for " + + s"graph $dataflowGraphId, client $clientId") + + sessionHolder.getPipelineExecution(dataflowGraphId) match { + case Some(pipelineExecution) => + val execution = pipelineExecution.pipelineExecution + val graphAnalysisContext = execution.graphAnalysisContext + + try { + while (execution.executionStarted) { + val signal = proto.PipelineQueryFunctionExecutionSignal.newBuilder() + + while (!graphAnalysisContext.flowClientSignalQueue.isEmpty) { + val flowId = graphAnalysisContext.flowClientSignalQueue.remove() + signal.addFlowNames(flowId.unquotedString) + } + + if (signal.getFlowNamesCount > 0) { + logInfo(s"Sending execution signal for ${signal.getFlowNamesCount} flows") + + val response = ExecutePlanResponse.newBuilder() + .setPipelineQueryFunctionExecutionSignal(signal.build()) + .build() + + responseObserver.onNext(response) + } + + Thread.sleep(100) + } + } catch { + case e: Exception => + logError( + s"Error in query function execution signal stream for graph $dataflowGraphId", e) + responseObserver.onError(e) + } finally { + responseObserver.onCompleted() + } + + case None => + val error = new IllegalStateException( + s"No active pipeline execution found for graph $dataflowGraphId") + logError(error.getMessage) + responseObserver.onError(error) + } + } + + /** + * Handles the DefineFlowQueryFunctionResult command by storing the query function result 
+ * in the GraphRegistrationContext. + */ + private def defineFlowQueryFunctionResult( + cmd: proto.PipelineCommand.DefineFlowQueryFunctionResult, + transformRelationFunc: Relation => LogicalPlan, + sessionHolder: SessionHolder): Unit = { + val dataflowGraphId = cmd.getDataflowGraphId + val flowName = cmd.getFlowName + val relation = cmd.getRelation + + val graphElementRegistry = + sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) + + val logicalPlan = transformRelationFunc(relation) + // Assume the identifier is already fully qualified (TODO: verify this) + val flowIdentifier = GraphIdentifierManager.parseTableIdentifier( + flowName, + sessionHolder.session + ) + + graphElementRegistry.registerQueryFunctionResult( + flowIdentifier, + QueryFunctionSuccess(logicalPlan) + ) + sessionHolder.getPipelineExecution(dataflowGraphId) match { + case Some(pipelineUpdateContext) => + // TODO: what if we haven't yet started analysis? + val graphAnalysisContext = pipelineUpdateContext.pipelineExecution.graphAnalysisContext + graphAnalysisContext.markFlowPlanRegistered(flowIdentifier) + case None => + } + + logInfo(s"Registered query function result for flow '$flowName' in graph '$dataflowGraphId'") + } + + def analyze( + pipelineAnalysisContext: PipelineAnalysisContext, + logicalPlan: LogicalPlan, + sessionHolder: SessionHolder): DataFrame = { + val graphId = pipelineAnalysisContext.getDataflowGraphId + val session = sessionHolder.session + sessionHolder.getPipelineExecution(graphId) match { + case Some(pipelineUpdateContext) => + // Assume the identifier is already fully qualified + val flowIdentifier = GraphIdentifierManager.parseTableIdentifier( + pipelineAnalysisContext.getFlowName, + session + ) + + pipelineUpdateContext.pipelineExecution.graphAnalysisContext.analyze( + flowIdentifier, + logicalPlan, + pipelineUpdateContext.unresolvedGraph, + session + ) + case None => + throw new IllegalStateException( + s"Pipeline analysis context specifies flow 
'${pipelineAnalysisContext.getFlowName}' " + + s"but no active pipeline execution found for graph '$graphId'" + ) + } + } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index d0d0f0ba750a3..86906b2825b28 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -140,6 +140,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio // Registry for dataflow graphs specific to this session private[connect] lazy val dataflowGraphRegistry = new DataflowGraphRegistry() + // Handles Python process clean up for streaming queries. Initialized on first use in a query. private[connect] lazy val streamingForeachBatchRunnerCleanerCache = new StreamingForeachBatchHelper.CleanerCache(this) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index ec8d95271c762..0535952d79fce 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -23,13 +23,14 @@ import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, 
StorageLevelProtoConverter} +import org.apache.spark.sql.connect.pipelines.PipelinesHandler import org.apache.spark.sql.connect.planner.SparkConnectPlanner -import org.apache.spark.sql.connect.utils.{PipelineAnalysisContextUtils, PlanCompressionUtils} +import org.apache.spark.sql.connect.utils.PlanCompressionUtils import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, CostMode, ExtendedMode, FormattedMode, SimpleMode} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -62,21 +63,29 @@ private[connect] class SparkConnectAnalyzeHandler( lazy val planner = new SparkConnectPlanner(sessionHolder) val session = sessionHolder.session val builder = proto.AnalyzePlanResponse.newBuilder() - val userContext = request.getUserContext - // Pipeline has not yet supported eager analysis inside flow function. - if (PipelineAnalysisContextUtils.isInsidePipelineFlowFunction(userContext)) { - throw new AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map()) + def transformRelation(rel: proto.Relation) = { + planner.transformRelation(rel, cachePlan = true) } - def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true) def transformRelationPlan(plan: proto.Plan) = { transformRelation(PlanCompressionUtils.decompressPlan(plan).getRoot) } def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = { - val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP) - new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema)) + val userContextExtensions = request.getUserContext.getExtensionsList.asScala + val pipelineAnalysisContextOpt = userContextExtensions + .filter(_.is(classOf[proto.PipelineAnalysisContext])) + .map(_.unpack(classOf[proto.PipelineAnalysisContext])) + .find(ctx => ctx.hasFlowName && ctx.hasDataflowGraphId) + + pipelineAnalysisContextOpt match { + case Some(pipelineAnalysisContext) => + 
PipelinesHandler.analyze(pipelineAnalysisContext, rel, sessionHolder) + case None => + val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP) + new Dataset[Row](qe, () => RowEncoder.encoderFor(qe.analyzed.schema)) + } } request.getAnalyzeCase match { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index fd05b0cc357eb..ba0c0fcae8099 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -1169,6 +1169,35 @@ class PythonPipelineSuite |""".stripMargin) } + test("access upstream schema within query function") { + val graph = buildGraph(""" + |@dp.materialized_view + |def mv2(): + | spark.table("table1").schema + | return spark.table("table1") + | + |@dp.materialized_view + |def mv1(): + | return spark.range(5) + |""".stripMargin) + .resolve() + .validate() + assert(graph.flows.size == 2) + assert(graph.tables.size == 2) + } + + test("query function failure") { + val graph = buildGraph(""" + |@dp.materialized_view + |def mv(): + | raise ValueError("bla") + |""".stripMargin) + .resolve() + .validate() + assert(graph.flows.size == 2) + assert(graph.tables.size == 2) + } + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: Position): Unit = { if (PythonTestDepsChecker.isConnectDepsAvailable) { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index 3cb45fa6e1720..0c933f7bc95c7 100644 --- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -944,4 +944,67 @@ class SparkDeclarativePipelinesServerSuite relation.getCommand.getSqlCommand.getInput) } } + + test("SessionHolder provides pipeline execution data for analysis") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineOutput( + DefineOutput + .newBuilder() + .setDataflowGraphId(graphId) + .setOutputName("test_mv") + .setOutputType(OutputType.MATERIALIZED_VIEW)) + .build())) + + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineFlow( + DefineFlow + .newBuilder() + .setDataflowGraphId(graphId) + .setFlowName("test_flow") + .setTargetDatasetName("test_mv") + .setRelationFlowDetails( + DefineFlow.WriteRelationFlowDetails + .newBuilder() + .setRelation( + Relation + .newBuilder() + .setUnresolvedTableValuedFunction(UnresolvedTableValuedFunction + .newBuilder() + .setFunctionName("range") + .addArguments(Expression + .newBuilder() + .setLiteral(Expression.Literal.newBuilder().setInteger(10).build()) + .build()) + .build()) + .build()) + .build()) + .build()) + .build())) + + val sessionHolder = getDefaultSessionHolder + + assert(sessionHolder.getPipelineExecution(graphId).isEmpty, + "Pipeline execution should not exist before starting the run") + + val definition = sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) + val resolvedGraph = definition.toDataflowGraph.resolve() + + assert(resolvedGraph.inputIdentifiers.nonEmpty, + "Resolved graph should have input identifiers") + assert(resolvedGraph.inputIdentifiers.exists(_.table == "test_mv"), + "test_mv should be in the graph inputs") + + assert(resolvedGraph.flows.size == 1, "Graph should have 
one flow") + assert(resolvedGraph.tables.size == 1, "Graph should have one table") + } + } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index 38fde0bfec4a1..5d0e3858f6fc6 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.pipelines.graph -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} - -import scala.jdk.CollectionConverters._ - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ @@ -32,26 +28,19 @@ import org.apache.spark.sql.pipelines.graph.DataflowGraphTransformer.{ * Processor that is responsible for analyzing each flow and sort the nodes in * topological order */ -class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { - +class CoreDataflowNodeProcessor(rawGraph: DataflowGraph, context: GraphAnalysisContext) { private val flowResolver = new FlowResolver(rawGraph) - // Map of input identifier to resolved [[Input]]. 
- private val resolvedInputs = new ConcurrentHashMap[TableIdentifier, Input]() - // Map & queue of resolved flows identifiers - // queue is there to track the topological order while map is used to store the id -> flow - // mapping - private val resolvedFlowNodesMap = new ConcurrentHashMap[TableIdentifier, ResolvedFlow]() - private val resolvedFlowNodesQueue = new ConcurrentLinkedQueue[ResolvedFlow]() - - private def processUnresolvedFlow(flow: UnresolvedFlow): ResolvedFlow = { + private def processUnresolvedFlow( + flow: UnresolvedFlow, + context: GraphAnalysisContext): ResolvedFlow = { val resolvedFlow = flowResolver.attemptResolveFlow( flow, rawGraph.inputIdentifiers, - resolvedInputs.asScala.toMap + context.resolvedInputsByIdentifier ) - resolvedFlowNodesQueue.add(resolvedFlow) - resolvedFlowNodesMap.put(flow.identifier, resolvedFlow) + context.resolvedFlowsQueue.add(resolvedFlow) + context.resolvedFlowsMap.put(flow.identifier, resolvedFlow) resolvedFlow } @@ -65,8 +54,8 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { */ def processNode(node: GraphElement, upstreamNodes: Seq[GraphElement]): Seq[GraphElement] = { node match { - case flow: UnresolvedFlow => Seq(processUnresolvedFlow(flow)) - case failedFlow: ResolutionFailedFlow => Seq(processUnresolvedFlow(failedFlow.flow)) + case flow: UnresolvedFlow => Seq(processUnresolvedFlow(flow, context)) + case failedFlow: ResolutionFailedFlow => Seq(processUnresolvedFlow(failedFlow.flow, context)) case table: Table => // Ensure all upstreamNodes for a table are flows val flowsToTable = upstreamNodes.map { @@ -78,7 +67,7 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { ) } val resolvedFlowsToTable = flowsToTable.map { flow => - resolvedFlowNodesMap.get(flow.identifier) + context.resolvedFlowsMap.get(flow.identifier) } // We mark all tables as virtual to ensure resolution uses incoming flows // rather than previously materialized tables. 
@@ -88,14 +77,14 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph) { incomingFlowIdentifiers = flowsToTable.map(_.identifier).toSet, availableFlows = resolvedFlowsToTable ) - resolvedInputs.put(table.identifier, virtualTableInput) + context.putResolvedInput(virtualTableInput) Seq(table) case view: View => // For view, add the flow to resolvedInputs and return empty. require(upstreamNodes.size == 1, "Found multiple flows to view") upstreamNodes.head match { case f: Flow => - resolvedInputs.put(view.identifier, resolvedFlowNodesMap.get(f.destinationIdentifier)) + context.putResolvedInput(context.resolvedFlowsMap.get(f.destinationIdentifier)) Seq(view) case _ => throw new IllegalArgumentException( @@ -120,6 +109,7 @@ private class FlowResolver(rawGraph: DataflowGraph) { flowToResolve: UnresolvedFlow, allInputs: Set[TableIdentifier], availableResolvedInputs: Map[TableIdentifier, Input]): ResolvedFlow = { + val flowFunctionResult = flowToResolve.func.call( allInputs = allInputs, availableInputs = availableResolvedInputs.values.toList, @@ -184,11 +174,17 @@ private class FlowResolver(rawGraph: DataflowGraph) { case f => f.dataFrame.failed.toOption.collectFirst { case e: UnresolvedDatasetException => e + case e: QueryFunctionResultNotAvailableException => e case _ => None } match { case Some(e: UnresolvedDatasetException) => throw TransformNodeRetryableException( - e.identifier, + Some(e.identifier), + new ResolutionFailedFlow(flowToResolve, flowFunctionResult) + ) + case Some(_: QueryFunctionResultNotAvailableException) => + throw TransformNodeRetryableException( + None, new ResolutionFailedFlow(flowToResolve, flowFunctionResult) ) case _ => diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala index c9578ddd3b469..822d693ba9366 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala +++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraph.scala @@ -228,14 +228,16 @@ case class DataflowGraph( def resolved: Boolean = flows.forall(f => resolvedFlow.contains(f.identifier)) - def resolve(): DataflowGraph = + def resolve(contextOpt: Option[GraphAnalysisContext] = None): DataflowGraph = { + val context = contextOpt.getOrElse(new GraphAnalysisContext()) DataflowGraphTransformer.withDataflowGraphTransformer(this) { transformer => val coreDataflowNodeProcessor = - new CoreDataflowNodeProcessor(rawGraph = this) + new CoreDataflowNodeProcessor(rawGraph = this, context) transformer - .transformDownNodes(coreDataflowNodeProcessor.processNode) + .transformDownNodes(coreDataflowNodeProcessor.processNode, context) .getDataflowGraph } + } } object DataflowGraph { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala index 2523c0ae5502a..104b4910860af 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -18,9 +18,6 @@ package org.apache.spark.sql.pipelines.graph import java.util.concurrent.{ - ConcurrentHashMap, - ConcurrentLinkedDeque, - ConcurrentLinkedQueue, ExecutionException, Future } @@ -64,6 +61,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { private var sinks: Seq[Sink] = graph.sinks private var sinkMap: Map[TableIdentifier, Sink] = computeSinkMap() + // Fail analysis nodes // Failed flows are flows that are failed to resolve or its inputs are not available or its // destination failed to resolve. 
@@ -72,6 +70,8 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { private var failedTables: Seq[Table] = Seq.empty private var failedSinks: Seq[Sink] = Seq.empty + @volatile private var currentCoreDataflowNodeProcessor: Option[CoreDataflowNodeProcessor] = None + private val parallelism = 10 // Executor used to resolve nodes in parallel. It is lazily initialized to avoid creating it @@ -120,25 +120,15 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { */ def transformDownNodes( transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement], + context: GraphAnalysisContext, disableParallelism: Boolean = false): DataflowGraphTransformer = { val executor = if (disableParallelism) selfExecutor else fixedPoolExecutor val batchSize = if (disableParallelism) 1 else parallelism - // List of resolved tables, sinks and flows - val resolvedFlows = new ConcurrentLinkedQueue[ResolutionCompletedFlow]() - val resolvedTables = new ConcurrentLinkedQueue[Table]() - val resolvedViews = new ConcurrentLinkedQueue[View]() - val resolvedSinks = new ConcurrentLinkedQueue[Sink]() - // Flow identifier to a list of transformed flows mapping to track resolved flows - val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, Seq[Flow]]() - val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier, Boolean]() - val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]() - val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, Seq[ResolutionFailedFlow]]() var futures = ArrayBuffer[Future[Unit]]() - val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]() - toBeResolvedFlows.addAll(flows.asJava) + context.toBeResolvedFlows.addAll(flows.asJava) - while (futures.nonEmpty || toBeResolvedFlows.peekFirst() != null) { + while (futures.nonEmpty || context.toBeResolvedFlows.peekFirst() != null) { val (done, notDone) = futures.partition(_.isDone) // Explicitly call future.get() to propagate exceptions 
one by one if any try { @@ -152,7 +142,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { val flowOpt = { // We only schedule [[batchSize]] number of flows in parallel. if (futures.size < batchSize) { - Option(toBeResolvedFlows.pollFirst()) + Option(context.toBeResolvedFlows.pollFirst()) } else { None } @@ -160,131 +150,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { flowOpt.foreach { flow => futures.append( executor.submit( - () => - try { - try { - // Note: Flow don't need their inputs passed, so for now we send empty Seq. - val result = transformer(flow, Seq.empty) - require( - result.forall(_.isInstanceOf[ResolvedFlow]), - "transformer must return a Seq[Flow]" - ) - - val transformedFlows = result.map(_.asInstanceOf[ResolvedFlow]) - resolvedFlowsMap.put(flow.identifier, transformedFlows) - resolvedFlows.addAll(transformedFlows.asJava) - } catch { - case e: TransformNodeRetryableException => - val datasetIdentifier = e.datasetIdentifier - failedDependentFlows.compute( - datasetIdentifier, - (_, flows) => { - // Don't add the input flow back but the failed flow object - // back which has relevant failure information. - val failedFlow = e.failedNode - if (flows == null) { - Seq(failedFlow) - } else { - flows :+ failedFlow - } - } - ) - // Between the time the flow started and finished resolving, perhaps the - // dependent dataset was resolved - resolvedFlowDestinationsMap.computeIfPresent( - datasetIdentifier, - (_, resolved) => { - if (resolved) { - // Check if the dataset that the flow is dependent on has been resolved - // and if so, remove all dependent flows from the failedDependentFlows and - // add them to the toBeResolvedFlows queue for retry. 
- failedDependentFlows.computeIfPresent( - datasetIdentifier, - (_, toRetryFlows) => { - toRetryFlows.foreach(toBeResolvedFlows.addFirst(_)) - null - } - ) - } - resolved - } - ) - case other: Throwable => throw other - } - // If all flows to this particular destination are resolved, move to the destination - // node transformer - if (flowsTo(flow.destinationIdentifier).forall({ flowToDestination => - resolvedFlowsMap.containsKey(flowToDestination.identifier) - })) { - // If multiple flows completed in parallel, ensure we resolve the destination only - // once by electing a leader via computeIfAbsent - var isCurrentThreadLeader = false - resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => { - isCurrentThreadLeader = true - // Set initial value as false as flow destination is not resolved yet. - false - }) - if (isCurrentThreadLeader) { - if (tableMap.contains(flow.destinationIdentifier)) { - val transformed = - transformer( - tableMap(flow.destinationIdentifier), - flowsTo(flow.destinationIdentifier) - ) - resolvedTables.addAll( - transformed.collect { case t: Table => t }.asJava - ) - resolvedFlows.addAll( - transformed.collect { case f: ResolvedFlow => f }.asJava - ) - } else if (viewMap.contains(flow.destinationIdentifier)) { - resolvedViews.addAll { - val transformed = - transformer( - viewMap(flow.destinationIdentifier), - flowsTo(flow.destinationIdentifier) - ) - transformed.map(_.asInstanceOf[View]).asJava - } - } else if (sinkMap.contains(flow.destinationIdentifier)) { - resolvedSinks.addAll { - val transformed = - transformer( - sinkMap(flow.destinationIdentifier), flowsTo(flow.destinationIdentifier) - ) - require( - transformed.forall(_.isInstanceOf[Sink]), - "transformer must return a Seq[Sink]" - ) - transformed.map(_.asInstanceOf[Sink]).asJava - } - } else { - throw new IllegalArgumentException( - s"Unsupported destination ${flow.destinationIdentifier.unquotedString}" + - s" in flow: ${flow.displayName} at transformDownNodes" - 
) - } - // Set flow destination as resolved now. - resolvedFlowDestinationsMap.computeIfPresent( - flow.destinationIdentifier, - (_, _) => { - // If there are any other node failures dependent on this destination, retry - // them - failedDependentFlows.computeIfPresent( - flow.destinationIdentifier, - (_, toRetryFlows) => { - toRetryFlows.foreach(toBeResolvedFlows.addFirst(_)) - null - } - ) - true - } - ) - } - } - } catch { - case ex: TransformNodeFailedException => failedFlowsQueue.add(ex.failedNode) - } + () => transformFlowAndMaybeDestination(flow, transformer, context) ) ) } @@ -294,18 +160,18 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { // A table is failed to analyze if: // - It does not exist in the resolvedFlowDestinationsMap failedTables = tables.filterNot { table => - resolvedFlowDestinationsMap.getOrDefault(table.identifier, false) + context.resolvedFlowDestinationsMap.getOrDefault(table.identifier, false) } // A sink is failed to analyze if: // - It does not exist in the resolvedFlowDestinationsMap failedSinks = sinks.filterNot { sink => - resolvedFlowDestinationsMap.getOrDefault(sink.identifier, false) + context.resolvedFlowDestinationsMap.getOrDefault(sink.identifier, false) } // We maintain the topological sort order of successful flows always val (resolvedFlowsWithResolvedDest, resolvedFlowsWithFailedDest) = - resolvedFlows.asScala.toSeq.partition(flow => { - resolvedFlowDestinationsMap.getOrDefault(flow.destinationIdentifier, false) + context.resolvedFlowsQueue.asScala.toSeq.partition(flow => { + context.resolvedFlowDestinationsMap.getOrDefault(flow.destinationIdentifier, false) }) // A flow is failed to analyze if: @@ -318,22 +184,144 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { // All transformed flows that write to a destination that is failed to analyze. 
resolvedFlowsWithFailedDest ++ // All failed flows thrown by TransformNodeFailedException - failedFlowsQueue.asScala.toSeq ++ + context.failedFlowsQueue.asScala.toSeq ++ // All flows that have not been transformed and resolved yet - failedDependentFlows.values().asScala.flatten.toSeq + context.failedDependentFlows.values().asScala.flatten.toSeq // Mutate the resolved entities flows = resolvedFlowsWithResolvedDest flowsTo = computeFlowsTo() - tables = resolvedTables.asScala.toSeq - views = resolvedViews.asScala.toSeq - sinks = resolvedSinks.asScala.toSeq + tables = context.resolvedTables.asScala.toSeq + views = context.resolvedViews.asScala.toSeq + sinks = context.resolvedSinks.asScala.toSeq tableMap = computeTableMap() viewMap = computeViewMap() sinkMap = computeSinkMap() this } + private[pipelines] def transformFlowAndMaybeDestination( + flow: Flow, + transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement], + context: GraphAnalysisContext): Unit = { + try { + try { + // Note: Flow don't need their inputs passed, so for now we send empty Seq. + transformer(flow, Seq.empty) + } catch { + case e: TransformNodeRetryableException => + e.datasetIdentifier match { + case None => context.failedUnregisteredFlows.put(e.failedNode.identifier, e.failedNode) + case Some(datasetIdentifier) => + context.registerFailedDependentFlow(datasetIdentifier, e.failedNode) + // Between the time the flow started and finished resolving, perhaps the + // dependent dataset was resolved + context.resolvedFlowDestinationsMap.computeIfPresent( + datasetIdentifier, + (_, resolved) => { + if (resolved) { + // Check if the dataset that the flow is dependent on has been resolved + // and if so, remove all dependent flows from the failedDependentFlows and + // add them to the toBeResolvedFlows queue for retry. 
+ context.failedDependentFlows.computeIfPresent( + datasetIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach { f => + if (context.failedUnregisteredFlows.containsKey(f.identifier)) { + context.flowClientSignalQueue.add(f.identifier) + } else { + context.toBeResolvedFlows.addFirst(f) + } + } + null + } + ) + } + resolved + } + ) + } + case other: Throwable => throw other + } + // If all flows to this particular destination are resolved, move to the destination + // node transformer + if (flowsTo(flow.destinationIdentifier).forall({ flowToDestination => + context.resolvedFlowsMap.containsKey(flowToDestination.identifier) + })) { + // If multiple flows completed in parallel, ensure we resolve the destination only + // once by electing a leader via computeIfAbsent + var isCurrentThreadLeader = false + context.resolvedFlowDestinationsMap.computeIfAbsent(flow.destinationIdentifier, _ => { + isCurrentThreadLeader = true + // Set initial value as false as flow destination is not resolved yet. 
+ false + }) + if (isCurrentThreadLeader) { + if (tableMap.contains(flow.destinationIdentifier)) { + val transformed = + transformer( + tableMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + context.resolvedTables.addAll( + transformed.collect { case t: Table => t }.asJava + ) + } else if (viewMap.contains(flow.destinationIdentifier)) { + context.resolvedViews.addAll { + val transformed = + transformer( + viewMap(flow.destinationIdentifier), + flowsTo(flow.destinationIdentifier) + ) + transformed.map(_.asInstanceOf[View]).asJava + } + } else if (sinkMap.contains(flow.destinationIdentifier)) { + context.resolvedSinks.addAll { + val transformed = + transformer( + sinkMap(flow.destinationIdentifier), flowsTo(flow.destinationIdentifier) + ) + require( + transformed.forall(_.isInstanceOf[Sink]), + "transformer must return a Seq[Sink]" + ) + transformed.map(_.asInstanceOf[Sink]).asJava + } + } else { + throw new IllegalArgumentException( + s"Unsupported destination ${flow.destinationIdentifier.unquotedString}" + + s" in flow: ${flow.displayName} at transformDownNodes" + ) + } + // Set flow destination as resolved now. + context.resolvedFlowDestinationsMap.computeIfPresent( + flow.destinationIdentifier, + (_, _) => { + // If there are any other node failures dependent on this destination, retry + // them + context.failedDependentFlows.computeIfPresent( + flow.destinationIdentifier, + (_, toRetryFlows) => { + toRetryFlows.foreach { f => + if (context.failedUnregisteredFlows.containsKey(f.identifier)) { + context.flowClientSignalQueue.add(f.identifier) + } else { + context.toBeResolvedFlows.addFirst(f) + } + } + null + } + ) + true + } + ) + } + } + } catch { + case ex: TransformNodeFailedException => context.failedFlowsQueue.add(ex.failedNode) + } + } + def getDataflowGraph: DataflowGraph = { graph.copy( // Returns all flows (resolved and failed) in topological order. 
@@ -356,6 +344,7 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { object DataflowGraphTransformer { + /** * Exception thrown when transforming a node in the graph fails because at least one of its * dependencies weren't yet transformed. @@ -364,7 +353,7 @@ object DataflowGraphTransformer { * dataflow graph. */ case class TransformNodeRetryableException( - datasetIdentifier: TableIdentifier, + datasetIdentifier: Option[TableIdentifier], failedNode: ResolutionFailedFlow) extends Exception with NoStackTrace diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala index e329308502f0d..48c04c02afe5e 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/Flow.scala @@ -120,6 +120,23 @@ case class FlowFunctionResult( final def resolved: Boolean = failure.isEmpty // don't override this, override failure } +object FlowFunctionResult { + def fromFlowAnalysisContext( + ctx: FlowAnalysisContext, + df: Try[DataFrame], + confs: Map[String, String]): FlowFunctionResult = { + FlowFunctionResult( + requestedInputs = ctx.requestedInputs.toSet, + batchInputs = ctx.batchInputs.toSet, + streamingInputs = ctx.streamingInputs.toSet, + usedExternalInputs = ctx.externalInputs.toSet, + dataFrame = df, + sqlConf = confs, + analysisWarnings = ctx.analysisWarnings.toList + ) + } +} + /** A [[Flow]] whose output schema and dependencies aren't known. 
*/ case class UnresolvedFlow( identifier: TableIdentifier, diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala index 1a00a6339c4ba..5c068ae718fcf 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/FlowAnalysis.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.pipelines.graph -import scala.util.Try +import scala.util.{Failure, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} @@ -58,18 +58,78 @@ object FlowAnalysis { } finally { ctx.restoreOriginalConf() } - FlowFunctionResult( - requestedInputs = ctx.requestedInputs.toSet, - batchInputs = ctx.batchInputs.toSet, - streamingInputs = ctx.streamingInputs.toSet, - usedExternalInputs = ctx.externalInputs.toSet, - dataFrame = df, - sqlConf = confs, - analysisWarnings = ctx.analysisWarnings.toList - ) + FlowFunctionResult.fromFlowAnalysisContext(ctx, df, confs) } } + /** + * Creates a [[FlowFunction]] that looks up the LogicalPlan from the GraphRegistrationContext + * for the given flow name and analyzes it. + * + * @param flowIdentifier The ID of the flow to look up in the context. + * @param context The GraphRegistrationContext containing query function results. + * @return A FlowFunction that analyzes the stored LogicalPlan for the flow. 
+ */ + def createQueryFunctionResultPollingFlowFunction( + flowIdentifier: TableIdentifier, + context: GraphRegistrationContext): FlowFunction = { + (allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + confs: Map[String, String], + queryContext: QueryContext, + queryOrigin: QueryOrigin) => { + context.getQueryFunctionResult(flowIdentifier) match { + case Some(result: QueryFunctionSuccess) => + val flowFunc = createFlowFunctionFromLogicalPlan(result.plan) + flowFunc.call(allInputs, availableInputs, confs, queryContext, queryOrigin) + case Some(QueryFunctionTerminalFailure) => + FlowFunctionResult( + requestedInputs = Set.empty, + batchInputs = Set.empty, + streamingInputs = Set.empty, + usedExternalInputs = Set.empty, + dataFrame = Failure(QueryFunctionTerminalFailureException()), + sqlConf = confs, + analysisWarnings = Seq.empty + ) + case None => + FlowFunctionResult( + requestedInputs = Set.empty, + batchInputs = Set.empty, + streamingInputs = Set.empty, + usedExternalInputs = Set.empty, + dataFrame = Failure(QueryFunctionResultNotAvailableException()), + sqlConf = confs, + analysisWarnings = Seq.empty + ) + } + } + } + + /** + * Public wrapper method for flow analysis from Spark Connect. + * Creates FlowAnalysisContext internally and calls the analyze method. + */ + def analyze( + allInputs: Set[TableIdentifier], + availableInputs: Seq[Input], + currentCatalog: String, + currentDatabase: String, + spark: SparkSession, + plan: LogicalPlan + ): DataFrame = { + val context = FlowAnalysisContext( + allInputs = allInputs, + availableInputs = availableInputs, + queryContext = QueryContext( + currentCatalog = Some(currentCatalog), + currentDatabase = Some(currentDatabase) + ), + spark = spark + ) + analyze(context, plan) + } + /** * Constructs an analyzed [[DataFrame]] from a [[LogicalPlan]] by resolving Pipelines specific * TVFs and datasets that cannot be resolved directly by Catalyst. 
@@ -81,7 +141,7 @@ object FlowAnalysis { * @param plan The [[LogicalPlan]] defining a flow. * @return An analyzed [[DataFrame]]. */ - private def analyze( + def analyze( context: FlowAnalysisContext, plan: LogicalPlan ): DataFrame = { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala new file mode 100644 index 0000000000000..8d1900afc55bf --- /dev/null +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.pipelines.graph + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque, ConcurrentLinkedQueue} + +import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala +import scala.util.Failure + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.{DataFrame, SparkSession} + +class GraphAnalysisContext { + val toBeResolvedFlows = new ConcurrentLinkedDeque[Flow]() + + // Map of input identifier to resolved [[Input]]. 
+ private val resolvedInputsHashMap = new ConcurrentHashMap[TableIdentifier, Input]() + + // Destination identifier to boolean indicating whether the destination has been resolved + val resolvedFlowDestinationsMap = new ConcurrentHashMap[TableIdentifier, Boolean]() + + // Map & queue of resolved flows identifiers + // queue is there to track the topological order while map is used to store the id -> flow + // mapping + val resolvedFlowsMap = new ConcurrentHashMap[TableIdentifier, ResolvedFlow]() + val resolvedFlowsQueue = new ConcurrentLinkedQueue[ResolvedFlow]() + + // List of resolved tables, sinks and flows + val resolvedTables = new ConcurrentLinkedQueue[Table]() + val resolvedViews = new ConcurrentLinkedQueue[View]() + val resolvedSinks = new ConcurrentLinkedQueue[Sink]() + + // Flows that failed due to the client not yet having registered a plan. Keyed by flow identifier. + val failedUnregisteredFlows = new ConcurrentHashMap[TableIdentifier, ResolutionFailedFlow]() + + // Dataset identifier to list of flows that failed resolution due to missing this dataset + val failedDependentFlows = new ConcurrentHashMap[TableIdentifier, Seq[ResolutionFailedFlow]]() + val failedFlowsQueue = new ConcurrentLinkedQueue[ResolutionFailedFlow]() + + // Queue of flow identifiers that have had their upstream inputs resolved and should have their + // flow function retried on the client + val flowClientSignalQueue = new ConcurrentLinkedQueue[TableIdentifier]() + + def putResolvedInput(input: Input): Unit = { + resolvedInputsHashMap.put(input.identifier, input) + } + + def resolvedInputsByIdentifier: Map[TableIdentifier, Input] = resolvedInputsHashMap.asScala.toMap + + def registerFailedDependentFlow( + inputDatasetIdentifier: TableIdentifier, + failedFlow: ResolutionFailedFlow): Unit = { + failedDependentFlows.compute( + inputDatasetIdentifier, + (_, flows) => { + if (flows == null) { + Seq(failedFlow) + } else { + flows :+ failedFlow + } + } + ) + } + + def 
markFlowPlanRegistered(flowIdentifier: TableIdentifier): Unit = { + val flow = failedUnregisteredFlows.remove(flowIdentifier) + // Flow could be null if it failed on the client side and we never tried to execute the flow + // function on the server side. + if (flow != null) { + toBeResolvedFlows.addFirst(flow) + } + } + + def analyze( + flowIdentifier: TableIdentifier, + logicalPlan: LogicalPlan, + unresolvedGraph: DataflowGraph, + session: SparkSession): DataFrame = { + val unresolvedFlow = unresolvedGraph.flow(flowIdentifier).asInstanceOf[UnresolvedFlow] + val flowAnalysisContext = FlowAnalysisContext( + allInputs = unresolvedGraph.inputIdentifiers, + availableInputs = resolvedInputsByIdentifier.values.toSeq, + queryContext = unresolvedFlow.queryContext, + spark = session + ) + + try { + FlowAnalysis.analyze(flowAnalysisContext, plan = logicalPlan) + } catch { + case e: UnresolvedDatasetException => + val flowFunctionResult = + FlowFunctionResult.fromFlowAnalysisContext(flowAnalysisContext, Failure(e), Map.empty) + + val resolutionFailedFlow = new ResolutionFailedFlow(unresolvedFlow, flowFunctionResult) + registerFailedDependentFlow(e.identifier, resolutionFailedFlow) + throw e + } + } +} diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala index dadda0561b19f..f058917b4f7cd 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphRegistrationContext.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.pipelines.graph +import java.util.concurrent.ConcurrentHashMap + import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** * A mutable 
context for registering tables, views, and flows in a dataflow graph. @@ -37,6 +40,8 @@ class GraphRegistrationContext( protected val views = new mutable.ListBuffer[View] protected val sinks = new mutable.ListBuffer[Sink] protected val flows = new mutable.ListBuffer[UnresolvedFlow] + // keyed by flow ID + private val queryFunctionResults = new ConcurrentHashMap[TableIdentifier, QueryFunctionResult]() def registerTable(tableDef: Table): Unit = { tables += tableDef @@ -62,6 +67,15 @@ class GraphRegistrationContext( flows += flowDef.copy(sqlConf = defaultSqlConf ++ flowDef.sqlConf) } + def registerQueryFunctionResult( + flowIdentifier: TableIdentifier, result: QueryFunctionResult): Unit = { + queryFunctionResults.put(flowIdentifier, result) + } + + def getQueryFunctionResult(flowIdentifier: TableIdentifier): Option[QueryFunctionResult] = { + Option(queryFunctionResults.get(flowIdentifier)) + } + private def isEmpty: Boolean = { tables.isEmpty && views.collect { case v: PersistedView => v @@ -189,3 +203,9 @@ object GraphRegistrationContext { override def toString: String = "SINK" } } + +sealed trait QueryFunctionResult + +case class QueryFunctionSuccess(plan: LogicalPlan) extends QueryFunctionResult + +case object QueryFunctionTerminalFailure extends QueryFunctionResult diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala index d35d701d44e57..ffbecdda6c586 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala @@ -36,6 +36,7 @@ class PipelineExecution(context: PipelineUpdateContext) { /** [Visible for testing] */ private[pipelines] var graphExecution: Option[TriggeredGraphExecution] = None + val graphAnalysisContext = new GraphAnalysisContext() def executionStarted: Boolean = synchronized 
{ graphExecution.nonEmpty } @@ -110,7 +111,7 @@ class PipelineExecution(context: PipelineUpdateContext) { private def resolveGraph(): DataflowGraph = { try { - context.unresolvedGraph.resolve().validate() + context.unresolvedGraph.resolve(Some(graphAnalysisContext)).validate() } catch { case e: UnresolvedPipelineException => handleInvalidPipeline(e) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala index 7116f5fbcf068..979aaba388c97 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelinesErrors.scala @@ -35,6 +35,12 @@ case class UnresolvedDatasetException(identifier: TableIdentifier) s"pipeline but could not be resolved." ) +case class QueryFunctionResultNotAvailableException() + extends AnalysisException("Query function result is not yet available.") + +case class QueryFunctionTerminalFailureException() + extends AnalysisException("Query function failed in a way that further analysis won't fix.") + /** * Exception raised when a flow fails to read from a table defined within the pipeline * diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/DistributedAnalysisSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/DistributedAnalysisSuite.scala new file mode 100644 index 0000000000000..0f9aa71b12db5 --- /dev/null +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/DistributedAnalysisSuite.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.pipelines.graph + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.TableIdentifier + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext} +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Tests for client-assisted graph analysis, where flow query functions run on the client and register results via GraphAnalysisContext. + */ +class DistributedAnalysisSuite extends PipelineTest with SharedSparkSession { + case class DummyFlowInfo( + destinationTable: String, + eagerDeps: Seq[String] = Seq.empty, + lazyDeps: Seq[String] = Seq.empty, + failsTerminally: Boolean = false + ) { + val flowName: String = destinationTable.replace("table", "flow") + + def register(registrationContext: TestGraphRegistrationContext): Unit = { + val flowFunction = if (eagerDeps.nonEmpty) { + FlowAnalysis.createQueryFunctionResultPollingFlowFunction( + fullyQualifiedIdentifier(flowName), + registrationContext + ) + } else { + val plan = if (lazyDeps.nonEmpty) { + spark.sessionState.sqlParser + .parsePlan(eagerDeps.map(t => s"select * from $t").mkString(" union ")) + } else { + spark.range(4).logicalPlan + } + FlowAnalysis.createFlowFunctionFromLogicalPlan(plan) + } + registrationContext.registerFlow(destinationTable, flowName, flowFunction) + } + } + + class TestDistributedAnalyzer(val dummyFlows: Seq[DummyFlowInfo], val spark: 
SparkSession) { + val registrationContext = new TestGraphRegistrationContext(spark) + dummyFlows.foreach { f => + f.register(registrationContext) + registrationContext.registerTable(f.destinationTable) + } + private val dummyFlowsById = dummyFlows.map { f => + (fullyQualifiedIdentifier(f.flowName), f) + }.toMap + + val unresolvedGraph: DataflowGraph = registrationContext.toDataflowGraph + val analysisContext = new GraphAnalysisContext() + unresolvedGraph.flows.foreach(analysisContext.toBeResolvedFlows.add) + val transformer = new DataflowGraphTransformer(unresolvedGraph) + + val nodeProcessor = new CoreDataflowNodeProcessor(unresolvedGraph, analysisContext) + + // Signals received by the client, but not yet processed + val clientSignalQueue = new mutable.Queue[TableIdentifier]() + + def transformerIteration(): Unit = { + Option(analysisContext.toBeResolvedFlows.pollFirst()).foreach { flow => + transformer.transformFlowAndMaybeDestination( + flow, + nodeProcessor.processNode, + analysisContext + ) + } + } + + def analyze(flowIdentifier: TableIdentifier, plan: LogicalPlan): Unit = { + analysisContext.analyze(flowIdentifier, plan, unresolvedGraph, spark) + } + + def registerQueryFunctionResult( + flowIdentifier: TableIdentifier, + result: QueryFunctionResult): Unit = { + registrationContext.registerQueryFunctionResult(flowIdentifier, result) + analysisContext.markFlowPlanRegistered(flowIdentifier) + } + + def runToCompletion(randomSeed: Integer): Unit = { + val r = new scala.util.Random(randomSeed) + + var numIters = 0 + while (numIters < 200 + && analysisContext.toBeResolvedFlows.isEmpty + && clientSignalQueue.isEmpty + && analysisContext.flowClientSignalQueue.isEmpty) { + numIters += 1 + + r.nextInt(3) match { + case 0 => transformerIteration() + case 1 => // Simulate the signal-sending RPC thread + if (!analysisContext.flowClientSignalQueue.isEmpty) { + val flowToRetryId = analysisContext.flowClientSignalQueue.poll() + 
dummyFlowsById(flowToRetryId).eagerDeps.foreach { tableNames => + try { + tableNames.foreach { tableName => + val plan = spark.sessionState.sqlParser.parsePlan(s"select * from $tableName") + analyze(flowToRetryId, plan) + } + clientSignalQueue.append(flowToRetryId) + } catch { + case _: UnresolvedDatasetException => + } + } + } + case 2 => // Simulate the client + if (clientSignalQueue.nonEmpty) { + val flowId = clientSignalQueue.dequeue() + val upstreamTables = dummyFlowsById(flowId).eagerDeps + val dummyFlow = dummyFlowsById(flowId) + val result = if (dummyFlow.failsTerminally) { + QueryFunctionTerminalFailure + } else { + QueryFunctionSuccess( + spark.sessionState.sqlParser + .parsePlan(s"select * from $upstreamTables") + .logicalPlan + ) + } + registerQueryFunctionResult(flowId, result) + } + } + + assert( + analysisContext.toBeResolvedFlows.size + + analysisContext.resolvedFlowsMap.size + + analysisContext.failedUnregisteredFlows.size + + analysisContext.failedDependentFlows.size + == unresolvedGraph.flows.size + ) + } + } + } + + test("single node external .columns") { + val externalTableId = fullyQualifiedIdentifier("external") + spark.sql(s"CREATE TABLE $externalTableId AS SELECT * FROM RANGE(3)") + val dummyFlows = Seq( + DummyFlowInfo("table1", eagerDeps = Seq(externalTableId.quotedString)) + ) + val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark) + + // transformer processes node, flow function should fail because no relation for node, get + // popped back on queue + testDistributedAnalyzer.transformerIteration() + + // define query function result + testDistributedAnalyzer.registerQueryFunctionResult( + fullyQualifiedIdentifier("table1"), + QueryFunctionSuccess(spark.range(4).logicalPlan) + ) + + // transformer processes node again. flow function should succeed. 
+ testDistributedAnalyzer.transformerIteration() + } + + test("two nodes with second has .columns") { + // flow 1 -> table -> flow 2 -> table + // flow 2 can't be resolved immediately because it analyzes flow 1 + val dummyFlows = Seq( + DummyFlowInfo("table2", eagerDeps = Seq("table1")), + DummyFlowInfo("table1") + ) + + val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark) + val analysisContext = testDistributedAnalyzer.analysisContext + val selectFromTable1Plan = spark.sessionState.sqlParser.parsePlan("select * from table1") + + val flow1Id = fullyQualifiedIdentifier("flow1") + val table1Id = fullyQualifiedIdentifier("table1") + val flow2Id = fullyQualifiedIdentifier("flow2") + + // Attempt to analyze table1 on behalf of flow2. Should fail and record dependency. + intercept[UnresolvedDatasetException]( + testDistributedAnalyzer.analyze(flow2Id, selectFromTable1Plan) + ) + val failedDependentFlows = analysisContext.failedDependentFlows + assert(failedDependentFlows.size() == 1) + assert(failedDependentFlows.get(table1Id).map(_.identifier).toSet == Set(flow2Id)) + + // transformer processes flow2, flow function should fail because no relation for node, get + // added to unregistered list + testDistributedAnalyzer.transformerIteration() + assert(analysisContext.failedUnregisteredFlows.containsKey(flow2Id)) + + // transformer processes flow1, flow function should succeed + // flow2 shouldn't be added to the queue, because its plan is still unregistered + testDistributedAnalyzer.transformerIteration() + assert(analysisContext.toBeResolvedFlows.isEmpty) + assert(analysisContext.resolvedFlowsMap.containsKey(flow1Id)) + + // detects flow1 is resolved, sends signal to retry flow2 + assert( + testDistributedAnalyzer.analysisContext.flowClientSignalQueue.toArray.toSeq == Seq(flow2Id) + ) + testDistributedAnalyzer.analysisContext.flowClientSignalQueue.clear() + + testDistributedAnalyzer.analyze(flow2Id, selectFromTable1Plan) + + // define query 
function result + testDistributedAnalyzer.registerQueryFunctionResult( + flow2Id, + QueryFunctionSuccess(selectFromTable1Plan) + ) + assert(analysisContext.failedUnregisteredFlows.isEmpty) + assert(analysisContext.toBeResolvedFlows.size == 1) + + // transformer processes node again. flow function should succeed. + testDistributedAnalyzer.transformerIteration() + assert(analysisContext.toBeResolvedFlows.isEmpty) + assert(analysisContext.failedUnregisteredFlows.isEmpty) + assert(analysisContext.failedDependentFlows.isEmpty) + assert(analysisContext.resolvedFlowsMap.size == 2) + } + + test("random orderings") { + // flow 1 -> table -> flow 2 -> table + // flow 2 can't be resolved immediately because it analyzes flow 1 + val dummyFlows = Seq( + DummyFlowInfo("table2", eagerDeps = Seq("table1")), + DummyFlowInfo("table1") + ) + + val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark) + + testDistributedAnalyzer.runToCompletion(4367) + assert(testDistributedAnalyzer.analysisContext.failedUnregisteredFlows.isEmpty) + assert(testDistributedAnalyzer.analysisContext.failedDependentFlows.isEmpty) + assert(testDistributedAnalyzer.analysisContext.resolvedFlowsMap.size == 2) + } + + test("query function fails after eager analysis") { + val dummyFlows = Seq( + DummyFlowInfo("table2", eagerDeps = Seq("table1"), failsTerminally = true), + DummyFlowInfo("table1") + ) + val testDistributedAnalyzer = new TestDistributedAnalyzer(dummyFlows, spark) + testDistributedAnalyzer.runToCompletion(4367) + assert(testDistributedAnalyzer.analysisContext.resolvedFlowsMap.size == 0) + } +} From e8fa38d9708cdd8aa162f32073768f92d270b5d3 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sat, 20 Dec 2025 12:39:10 -0800 Subject: [PATCH 2/7] PythonPipelineSuite stuff --- .../spark_connect_graph_element_registry.py | 22 +++-- .../tests/python_pipeline_suite_helpers.py | 24 +++++ .../connect/pipelines/PipelinesHandler.scala | 12 ++- .../pipelines/PythonPipelineSuite.scala | 92 
+++++++++++++------ .../pipelines/graph/PipelineExecution.scala | 2 +- 5 files changed, 108 insertions(+), 44 deletions(-) create mode 100644 python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index fb51185342fb9..800321b7f238b 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -17,7 +17,7 @@ from pathlib import Path from pyspark.errors.exceptions.base import PySparkValueError -from pyspark.errors import PySparkTypeError +from pyspark.errors import PySparkException, PySparkTypeError from pyspark.sql import SparkSession, DataFrame from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame from pyspark.pipelines.output import ( @@ -119,12 +119,18 @@ def register_flow(self, flow: Flow) -> None: self._query_funcs_by_flow_name[flow.name] = flow.func try: df = self._execute_query_function(flow.name, flow.func) - except Exception as e: - raise PySparkValueError( - f"Error executing query function for flow {flow.name}: {e}" - ) + except PySparkException as e: + print("pizza: exception while executing query function") + if e.getCondition() == "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION": + print("pizza: exception is an analysis exception") + df = None + else: + raise e - relation = cast(ConnectDataFrame, df)._plan.plan(self._client) + if df is not None: + relation = cast(ConnectDataFrame, df)._plan.plan(self._client) + else: + relation = None relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails( relation=relation, @@ -167,9 +173,8 @@ def register_signalled_query_functions(self) -> None: ) command = pb2.Command() command.pipeline_command.get_query_function_execution_signal_stream.CopyFrom(inner_command) - self._client.execute_command_as_iterator(command) - result_iter = 
self._client.execute_command_as_iterator(inner_command) + result_iter = self._client.execute_command_as_iterator(command) for result in result_iter: if "pipeline_query_function_execution_signal" not in result.keys(): raise PySparkValueError( @@ -178,6 +183,7 @@ def register_signalled_query_functions(self) -> None: signal = result["pipeline_query_function_execution_signal"] flow_names = signal.flow_names + print("pizza: received signal with flow names: ", flow_names) for flow_name in flow_names: func = self._query_funcs_by_flow_name[flow_name] df = self._execute_query_function(flow_name, func) diff --git a/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py b/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py new file mode 100644 index 0000000000000..74c29728dd811 --- /dev/null +++ b/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py @@ -0,0 +1,24 @@ +from pyspark.sql import SparkSession +from pyspark import pipelines as dp +from pyspark.pipelines.spark_connect_graph_element_registry import ( + SparkConnectGraphElementRegistry, +) +from pyspark.pipelines.spark_connect_pipeline import create_dataflow_graph + + +def setup(server_port: str, session_identifier: str) -> tuple[SparkSession, SparkConnectGraphElementRegistry]: + spark = SparkSession.builder \ + .remote(f"sc://localhost:{server_port}") \ + .config("spark.connect.grpc.channel.timeout", "5s") \ + .config("spark.custom.identifier", session_identifier) \ + .create() + + dataflow_graph_id = create_dataflow_graph( + spark, + default_catalog=None, + default_database=None, + sql_conf={}, + ) + + registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) + return spark, registry diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 3fe5e34af5f31..fff772b53f5ac 100644 --- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -378,7 +378,7 @@ private[connect] object PipelinesHandler extends Logging { } val relationFlowDetails = flow.getRelationFlowDetails - val flowFunction = if (relationFlowDetails.hasRelation()) { + val flowFunction = if (relationFlowDetails.hasRelation) { FlowAnalysis.createFlowFunctionFromLogicalPlan( transformRelationFunc(relationFlowDetails.getRelation)) } else { @@ -662,10 +662,12 @@ private[connect] object PipelinesHandler extends Logging { session ) case None => - throw new IllegalStateException( - s"Pipeline analysis context specifies flow '${pipelineAnalysisContext.getFlowName}' " + - s"but no active pipeline execution found for graph '$graphId'" - ) + throw new AnalysisException("ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", Map()) + // TODO +// throw new IllegalStateException( +// s"Pipeline analysis context specifies flow '${pipelineAnalysisContext.getFlowName}' " + +// s"but no active pipeline execution found for graph '$graphId'" +// ) } } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index ba0c0fcae8099..1da87acc5a94a 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -33,11 +33,11 @@ import org.apache.spark.api.python.PythonUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.connect.PythonTestDepsChecker -import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.connect.service.{SessionHolder, 
SparkConnectService} import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.common.FlowStatus -import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl, QueryOrigin, QueryOriginType} +import org.apache.spark.sql.pipelines.graph.{DataflowGraph, GraphRegistrationContext, PipelineUpdateContextImpl, QueryOrigin, QueryOriginType} import org.apache.spark.sql.pipelines.logging.EventLevel import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} import org.apache.spark.sql.types.StructType @@ -57,37 +57,18 @@ class PythonPipelineSuite val customSessionIdentifier = UUID.randomUUID().toString val pythonCode = s""" - |from pyspark.sql import SparkSession - |from pyspark import pipelines as dp - |from pyspark.pipelines.spark_connect_graph_element_registry import ( - | SparkConnectGraphElementRegistry, - |) - |from pyspark.pipelines.spark_connect_pipeline import create_dataflow_graph - |from pyspark.pipelines.graph_element_registry import ( - | graph_element_registration_context, - |) + |from pyspark.pipelines.test.python_pipeline_suite_helpers import setup |from pyspark.pipelines.add_pipeline_analysis_context import ( | add_pipeline_analysis_context |) | - |spark = SparkSession.builder \\ - | .remote("sc://localhost:$serverPort") \\ - | .config("spark.connect.grpc.channel.timeout", "5s") \\ - | .config("spark.custom.identifier", "$customSessionIdentifier") \\ - | .create() + |spark, registry = setup("$serverPort", "$customSessionIdentifier") | - |dataflow_graph_id = create_dataflow_graph( - | spark, - | default_catalog=None, - | default_database=None, - | sql_conf={}, - |) - | - |registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) |with add_pipeline_analysis_context( | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None |): | with 
graph_element_registration_context(registry): + | |$indentedPythonText |""".stripMargin @@ -98,20 +79,71 @@ class PythonPipelineSuite throw new RuntimeException( s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}") } + + getCurrentGraphRegistrationContext(customSessionIdentifier).toDataflowGraph + } + + private def findSessionHolder(sessionId: String): Option[SessionHolder] = { val activeSessions = SparkConnectService.sessionManager.listActiveSessions - // get the session holder by finding the session with the custom UUID set in the conf - val sessionHolder = activeSessions + activeSessions .map(info => SparkConnectService.sessionManager.getIsolatedSession(info.key, None)) - .find(_.session.conf.get("spark.custom.identifier") == customSessionIdentifier) - .getOrElse( - throw new RuntimeException(s"Session with identifier $customSessionIdentifier not found")) + .find(_.session.conf.get("spark.custom.identifier") == sessionId) + } + + private def getCurrentGraphRegistrationContext(sessionId: String): GraphRegistrationContext = { + val sessionHolder = findSessionHolder(sessionId).getOrElse( + throw new RuntimeException(s"Session with identifier $sessionId not found")) // get all dataflow graphs from the session holder val dataflowGraphContexts = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs assert(dataflowGraphContexts.size == 1) - dataflowGraphContexts.head.toDataflowGraph + dataflowGraphContexts.head + } + + def buildAndResolveGraph(pythonText: String): DataflowGraph = { + val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") + val sessionId = UUID.randomUUID().toString + val pythonCode = + s""" + |from pyspark.pipelines.test.python_pipeline_suite_helpers import setup + |from pyspark.pipelines.add_pipeline_analysis_context import ( + | add_pipeline_analysis_context + |) + | + |spark, registry = setup("$serverPort", "$sessionId") + | + |with add_pipeline_analysis_context( + | spark=spark, 
dataflow_graph_id=dataflow_graph_id, flow_name=None + |): + | with graph_element_registration_context(registry): + | with graph_element_registration_context(registry): + | + |$indentedPythonText + | + |registry.register_signalled_query_functions() + |""".stripMargin + + import scala.concurrent.Future + + // Execute the code in a separate thread + val pythonExecutionFuture: Future[(Int, Seq[String])] = Future { + executePythonCode(pythonCode) + } + + // TODO: need to wait until GraphRegistrationContext session exists + + val sessionHolder = findSessionHolder(sessionId).get + + val graphRegistrationContext = getCurrentGraphRegistrationContext(sessionId) + val unresolvedGraph = graphRegistrationContext.toDataflowGraph + val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) + sessionHolder.cachePipelineExecution(dataflowGraphId, updateContext) + + val graph = updateContext.pipelineExecution.resolveGraph() + + graph } def graphIdentifier(name: String): TableIdentifier = { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala index ffbecdda6c586..609d2988642e5 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala @@ -109,7 +109,7 @@ class PipelineExecution(context: PipelineUpdateContext) { ) } - private def resolveGraph(): DataflowGraph = { + def resolveGraph(): DataflowGraph = { try { context.unresolvedGraph.resolve(Some(graphAnalysisContext)).validate() } catch { From a06a934d127131a433a36f5670a9595fad3e8752 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 21 Dec 2025 16:23:17 -0800 Subject: [PATCH 3/7] more test --- .../spark_connect_graph_element_registry.py | 4 + .../pipelines/DataflowGraphRegistry.scala | 15 ++-- .../pipelines/PythonPipelineSuite.scala 
| 79 ++++++++++++------- ...SparkDeclarativePipelinesServerSuite.scala | 14 ++-- 4 files changed, 67 insertions(+), 45 deletions(-) diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index 800321b7f238b..93011edc28293 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -52,6 +52,10 @@ def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None: self._client_id = str(uuid.uuid4()) self._query_funcs_by_flow_name: dict[str, QueryFunction] = {} + @property + def dataflow_graph_id(self) -> str: + return self._dataflow_graph_id + def register_output(self, output: Output) -> None: table_details = None sink_details = None diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala index e8114f38ec40c..9ff6d8068285c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala @@ -46,30 +46,25 @@ class DataflowGraphRegistry { graphId } - /** Retrieves the graph for a given id. */ - def getDataflowGraph(graphId: String): Option[GraphRegistrationContext] = { - Option(dataflowGraphs.get(graphId)) + /** Returns a Map of dataflow graph IDs to their corresponding GraphRegistrationContext. */ + def getDataflowGraphs: Map[String, GraphRegistrationContext] = { + dataflowGraphs.asScala.toMap } /** Retrieves the graph for a given id, and throws if the id could not be found. 
*/ def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext = - getDataflowGraph(dataflowGraphId).getOrElse { + getDataflowGraphs.getOrElse(dataflowGraphId, { throw new SparkException( errorClass = "DATAFLOW_GRAPH_NOT_FOUND", messageParameters = Map("graphId" -> dataflowGraphId), cause = null) - } + }) /** Removes the graph with a given id from the registry. */ def dropDataflowGraph(graphId: String): Unit = { dataflowGraphs.remove(graphId) } - /** Returns all graphs in the registry. */ - def getAllDataflowGraphs: Seq[GraphRegistrationContext] = { - dataflowGraphs.values().asScala.toSeq - } - /** Removes all graphs from the registry. */ def dropAllDataflowGraphs(): Unit = { dataflowGraphs.clear() diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 1da87acc5a94a..80e4675d88d72 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -24,10 +24,13 @@ import java.util.UUID import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{ExecutionContext, Future} import scala.util.Try import org.scalactic.source.Position import org.scalatest.Tag +import org.scalatest.concurrent.{Eventually, Futures} +import org.scalatest.time.{Millis, Seconds, Span} import org.apache.spark.api.python.PythonUtils import org.apache.spark.sql.AnalysisException @@ -57,15 +60,19 @@ class PythonPipelineSuite val customSessionIdentifier = UUID.randomUUID().toString val pythonCode = s""" - |from pyspark.pipelines.test.python_pipeline_suite_helpers import setup + |from pyspark import pipelines as dp + |from pyspark.pipelines.tests.python_pipeline_suite_helpers import setup |from 
pyspark.pipelines.add_pipeline_analysis_context import ( | add_pipeline_analysis_context |) + |from pyspark.pipelines.graph_element_registry import ( + | graph_element_registration_context, + |) | |spark, registry = setup("$serverPort", "$customSessionIdentifier") | |with add_pipeline_analysis_context( - | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None + | spark=spark, dataflow_graph_id=registry.dataflow_graph_id, flow_name=None |): | with graph_element_registration_context(registry): | @@ -80,7 +87,9 @@ class PythonPipelineSuite s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}") } - getCurrentGraphRegistrationContext(customSessionIdentifier).toDataflowGraph + getCurrentGraphRegistrationContext(customSessionIdentifier) + .getOrElse(throw new RuntimeException("Graph registration context not found")) + .toDataflowGraph } private def findSessionHolder(sessionId: String): Option[SessionHolder] = { @@ -91,15 +100,16 @@ class PythonPipelineSuite .find(_.session.conf.get("spark.custom.identifier") == sessionId) } - private def getCurrentGraphRegistrationContext(sessionId: String): GraphRegistrationContext = { - val sessionHolder = findSessionHolder(sessionId).getOrElse( - throw new RuntimeException(s"Session with identifier $sessionId not found")) - - // get all dataflow graphs from the session holder - val dataflowGraphContexts = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs - assert(dataflowGraphContexts.size == 1) - - dataflowGraphContexts.head + private def getCurrentGraphRegistrationContext( + sessionId: String): Option[GraphRegistrationContext] = { + findSessionHolder(sessionId).map { sessionHolder => + val dataflowGraphs = sessionHolder.dataflowGraphRegistry.getDataflowGraphs + if (dataflowGraphs.size == 1) { + dataflowGraphs.values.head + } else { + throw new RuntimeException(s"Expected exactly 1 graph, but found ${dataflowGraphs.size}") + } + } } def buildAndResolveGraph(pythonText: String): 
DataflowGraph = { @@ -107,42 +117,59 @@ class PythonPipelineSuite val sessionId = UUID.randomUUID().toString val pythonCode = s""" - |from pyspark.pipelines.test.python_pipeline_suite_helpers import setup + |from pyspark import pipelines as dp + |from pyspark.pipelines.tests.python_pipeline_suite_helpers import setup |from pyspark.pipelines.add_pipeline_analysis_context import ( | add_pipeline_analysis_context |) + |from pyspark.pipelines.graph_element_registry import ( + | graph_element_registration_context, + |) | |spark, registry = setup("$serverPort", "$sessionId") | |with add_pipeline_analysis_context( - | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None + | spark=spark, dataflow_graph_id=registry.dataflow_graph_id, flow_name=None |): | with graph_element_registration_context(registry): - | with graph_element_registration_context(registry): | |$indentedPythonText | |registry.register_signalled_query_functions() |""".stripMargin - import scala.concurrent.Future - + // Create a custom execution context for the Future + implicit val ec: ExecutionContext = ExecutionContext.global // Execute the code in a separate thread val pythonExecutionFuture: Future[(Int, Seq[String])] = Future { executePythonCode(pythonCode) } - // TODO: need to wait until GraphRegistrationContext session exists + // Wait until the session exists, timeout after 10 seconds + val sessionHolder = Eventually.eventually( + Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) { + findSessionHolder(sessionId).getOrElse( + throw new RuntimeException(s"Session with ID $sessionId not found")) + } - val sessionHolder = findSessionHolder(sessionId).get + // Wait until a graph registration context exists, timeout after 10 seconds + val graphRegistrationContext = Eventually.eventually( + Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) { + getCurrentGraphRegistrationContext(sessionId).getOrElse( + throw new RuntimeException("Graph registration 
context not found")) + } - val graphRegistrationContext = getCurrentGraphRegistrationContext(sessionId) val unresolvedGraph = graphRegistrationContext.toDataflowGraph val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) - sessionHolder.cachePipelineExecution(dataflowGraphId, updateContext) + + // Get the actual graph ID from the registry for this context + val dataflowGraphs = sessionHolder.dataflowGraphRegistry.getDataflowGraphs + val graphId = dataflowGraphs.find(_._2 == graphRegistrationContext).map(_._1) + .getOrElse(throw new RuntimeException("Could not find graph ID for context")) + sessionHolder.cachePipelineExecution(graphId, updateContext) val graph = updateContext.pipelineExecution.resolveGraph() - + graph } @@ -1202,7 +1229,7 @@ class PythonPipelineSuite } test("access upstream schema within query function") { - val graph = buildGraph(""" + val graph = buildAndResolveGraph(""" |@dp.materialized_view |def mv2(): | spark.table("table1").schema @@ -1212,20 +1239,16 @@ class PythonPipelineSuite |def mv1(): | return spark.range(5) |""".stripMargin) - .resolve() - .validate() assert(graph.flows.size == 2) assert(graph.tables.size == 2) } test("query function failure") { - val graph = buildGraph(""" + val graph = buildAndResolveGraph(""" |@dp.materialized_view |def mv(): | raise ValueError("bla") |""".stripMargin) - .resolve() - .validate() assert(graph.flows.size == 2) assert(graph.tables.size == 2) } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index 0c933f7bc95c7..cd1345319a63f 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ 
-321,7 +321,7 @@ class SparkDeclarativePipelinesServerSuite .build())) // Verify the graph exists in the default session - assert(getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs.size == 1) + assert(getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphs.size == 1) } // Create a second session with different user/session ID @@ -374,8 +374,8 @@ class SparkDeclarativePipelinesServerSuite .getIsolatedSession(SessionKey(newSessionUserId, newSessionId), None) val defaultSessionGraphs = - getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs - val newSessionGraphs = newSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphs.values + val newSessionGraphs = newSessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(defaultSessionGraphs.size == 1) assert(newSessionGraphs.size == 1) @@ -441,7 +441,7 @@ class SparkDeclarativePipelinesServerSuite .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId)) .get - val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsBefore = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsBefore.size == 1) // Close the session @@ -453,7 +453,7 @@ class SparkDeclarativePipelinesServerSuite assert(sessionAfterClose.isEmpty, "Session should be cleaned up after close") // Verify the graph is removed - val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsAfter = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsAfter.isEmpty, "Graph should be removed after session close") } } @@ -518,7 +518,7 @@ class SparkDeclarativePipelinesServerSuite // Verify the graph exists val sessionHolder = getDefaultSessionHolder - val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsBefore = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsBefore.size == 1) // Drop the 
graph @@ -532,7 +532,7 @@ class SparkDeclarativePipelinesServerSuite .build())) // Verify the graph is removed - val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val graphsAfter = sessionHolder.dataflowGraphRegistry.getDataflowGraphs.values assert(graphsAfter.isEmpty, "Graph should be removed after drop") } } From a53ff15eab83ff48fd02e0941b05a8e4ab3e625d Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 21 Dec 2025 16:56:09 -0800 Subject: [PATCH 4/7] run start_run in test --- .../tests/python_pipeline_suite_helpers.py | 23 +++++++++++++++++- .../connect/pipelines/PipelinesHandler.scala | 1 + .../sql/connect/service/SessionHolder.scala | 1 + .../pipelines/PythonPipelineSuite.scala | 24 +++++++++++++------ .../pipelines/graph/PipelineExecution.scala | 12 ++++++---- 5 files changed, 48 insertions(+), 13 deletions(-) diff --git a/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py b/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py index 74c29728dd811..dc92fe2d924a0 100644 --- a/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py +++ b/python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py @@ -4,7 +4,8 @@ SparkConnectGraphElementRegistry, ) from pyspark.pipelines.spark_connect_pipeline import create_dataflow_graph - +from pyspark.pipelines.spark_connect_pipeline import start_run, handle_pipeline_events +import threading def setup(server_port: str, session_identifier: str) -> tuple[SparkSession, SparkConnectGraphElementRegistry]: spark = SparkSession.builder \ @@ -22,3 +23,23 @@ def setup(server_port: str, session_identifier: str) -> tuple[SparkSession, Spar registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) return spark, registry + + +def run_and_handle_signals( + spark: SparkSession, + registry: SparkConnectGraphElementRegistry, + storage_root: str, +) -> None: + result_iter = start_run( + spark, + dataflow_graph_id=registry.dataflow_graph_id, + full_refresh=None, + 
full_refresh_all=False, + refresh=None, + dry=True, + storage=storage_root + ) + thread = threading.Thread(target=handle_pipeline_events, args=(result_iter,), daemon=True) + thread.start() + + registry.register_signalled_query_functions() diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index fff772b53f5ac..4e28f135a00eb 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -326,6 +326,7 @@ private[connect] object PipelinesHandler extends Logging { flow: proto.PipelineCommand.DefineFlow, transformRelationFunc: Relation => LogicalPlan, sessionHolder: SessionHolder): TableIdentifier = { + logInfo("pizza defining flow") if (flow.hasOnce) { throw new AnalysisException( "DEFINE_FLOW_ONCE_OPTION_NOT_SUPPORTED", diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 86906b2825b28..0a1d1546cc13d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -493,6 +493,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private[connect] def cachePipelineExecution( graphId: String, pipelineUpdateContext: PipelineUpdateContext): Unit = { + print("pizza: caching pipeline execution") pipelineExecutions.compute( graphId, (_, existing) => { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala 
index 80e4675d88d72..b3646baeac6b0 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -118,7 +118,7 @@ class PythonPipelineSuite val pythonCode = s""" |from pyspark import pipelines as dp - |from pyspark.pipelines.tests.python_pipeline_suite_helpers import setup + |from pyspark.pipelines.tests.python_pipeline_suite_helpers import * |from pyspark.pipelines.add_pipeline_analysis_context import ( | add_pipeline_analysis_context |) @@ -135,7 +135,7 @@ class PythonPipelineSuite | |$indentedPythonText | - |registry.register_signalled_query_functions() + |run_and_handle_signals(spark, registry, "$storageRoot") |""".stripMargin // Create a custom execution context for the Future @@ -159,18 +159,28 @@ class PythonPipelineSuite throw new RuntimeException("Graph registration context not found")) } - val unresolvedGraph = graphRegistrationContext.toDataflowGraph - val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) +// val unresolvedGraph = graphRegistrationContext.toDataflowGraph +// val updateContext = TestPipelineUpdateContext(spark, unresolvedGraph, storageRoot) // Get the actual graph ID from the registry for this context val dataflowGraphs = sessionHolder.dataflowGraphRegistry.getDataflowGraphs val graphId = dataflowGraphs.find(_._2 == graphRegistrationContext).map(_._1) .getOrElse(throw new RuntimeException("Could not find graph ID for context")) - sessionHolder.cachePipelineExecution(graphId, updateContext) - val graph = updateContext.pipelineExecution.resolveGraph() + val pipelineExecution = Eventually.eventually( + Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) { + sessionHolder.getPipelineExecution(graphId).getOrElse( + throw new RuntimeException("Pipeline execution not found")) + } +// sessionHolder.cachePipelineExecution(graphId, 
updateContext) + +// val graph = updateContext.pipelineExecution.resolveGraph() - graph + Eventually.eventually( + Futures.timeout(Span(10, Seconds)), Futures.interval(Span(100, Millis))) { + pipelineExecution.pipelineExecution.resolvedGraph.getOrElse( + throw new RuntimeException("Resolved graph not found")) + } } def graphIdentifier(name: String): TableIdentifier = { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala index 609d2988642e5..3e73d12cda78a 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala @@ -35,7 +35,9 @@ import org.apache.spark.sql.pipelines.logging.{ class PipelineExecution(context: PipelineUpdateContext) { /** [Visible for testing] */ - private[pipelines] var graphExecution: Option[TriggeredGraphExecution] = None + @volatile private[pipelines] var graphExecution: Option[TriggeredGraphExecution] = None + /** [Visible for testing] */ + @volatile var resolvedGraph: Option[DataflowGraph] = None val graphAnalysisContext = new GraphAnalysisContext() def executionStarted: Boolean = synchronized { graphExecution.nonEmpty } @@ -46,12 +48,12 @@ class PipelineExecution(context: PipelineUpdateContext) { */ def startPipeline(): Unit = synchronized { // Initialize the graph. - val resolvedGraph = resolveGraph() + resolvedGraph = Some(resolveGraph()) if (context.fullRefreshTables.nonEmpty) { - State.reset(resolvedGraph, context) + State.reset(resolvedGraph.get, context) } - val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph, context) + val initializedGraph = DatasetManager.materializeDatasets(resolvedGraph.get, context) // Execute the graph. 
graphExecution = Some( @@ -87,7 +89,7 @@ class PipelineExecution(context: PipelineUpdateContext) { /** Validates that the pipeline graph can be successfully resolved and validates it. */ def dryRunPipeline(): Unit = synchronized { - resolveGraph() + resolvedGraph = Some(resolveGraph()) context.eventCallback( constructTerminationEvent(RunCompletion()) ) From 5dad12ff8d1770815a72124adc567fb09a1e8ea0 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sun, 21 Dec 2025 17:11:41 -0800 Subject: [PATCH 5/7] resolvedGraph --- .../apache/spark/sql/connect/pipelines/PipelinesHandler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 4e28f135a00eb..9f2ef61411903 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -568,7 +568,7 @@ private[connect] object PipelinesHandler extends Logging { val graphAnalysisContext = execution.graphAnalysisContext try { - while (execution.executionStarted) { + while (execution.resolvedGraph.isEmpty) { val signal = proto.PipelineQueryFunctionExecutionSignal.newBuilder() while (!graphAnalysisContext.flowClientSignalQueue.isEmpty) { From 4dea491607cb68ef35a2ca358a17fd655999f8b5 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Mon, 22 Dec 2025 08:30:09 -0800 Subject: [PATCH 6/7] instrumentation and wait for pipeline execution --- .../connect/pipelines/PipelinesHandler.scala | 127 +++++++++++++----- .../sql/connect/service/SessionHolder.scala | 5 + .../graph/CoreDataflowNodeProcessor.scala | 3 + .../graph/DataflowGraphTransformer.scala | 32 ++++- .../graph/GraphAnalysisContext.scala | 11 ++ .../pipelines/graph/PipelineExecution.scala | 15 ++- 6 files changed, 153 insertions(+), 
40 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 9f2ef61411903..1abe077b5cc8f 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.execution.command.{ShowCatalogsCommand, ShowNamespacesCommand} import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED} -import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryFunctionSuccess, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} +import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContext, PipelineUpdateContextImpl, QueryContext, QueryFunctionSuccess, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress} import org.apache.spark.sql.types.StructType @@ -451,12 +451,20 @@ private[connect] object PipelinesHandler extends Logging { cmd.getStorage) sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext) + // scalastyle:off println + println(s"INSTRUMENTATION: Starting pipeline execution, dry=${cmd.getDry}") + // scalastyle:on println + if (cmd.getDry) { pipelineUpdateContext.pipelineExecution.dryRunPipeline() } else { 
pipelineUpdateContext.pipelineExecution.runPipeline() } + // scalastyle:off println + println(s"INSTRUMENTATION: Pipeline execution completed") + // scalastyle:on println + // Rethrow any exceptions that caused the pipeline run to fail so that the exception is // propagated back to the SC client / CLI. runFailureEvent.foreach { event => @@ -562,46 +570,88 @@ private[connect] object PipelinesHandler extends Logging { logInfo(s"Starting query function execution signal stream for " + s"graph $dataflowGraphId, client $clientId") - sessionHolder.getPipelineExecution(dataflowGraphId) match { - case Some(pipelineExecution) => - val execution = pipelineExecution.pipelineExecution - val graphAnalysisContext = execution.graphAnalysisContext - - try { - while (execution.resolvedGraph.isEmpty) { - val signal = proto.PipelineQueryFunctionExecutionSignal.newBuilder() - - while (!graphAnalysisContext.flowClientSignalQueue.isEmpty) { - val flowId = graphAnalysisContext.flowClientSignalQueue.remove() - signal.addFlowNames(flowId.unquotedString) - } - - if (signal.getFlowNamesCount > 0) { - logInfo(s"Sending execution signal for ${signal.getFlowNamesCount} flows") - - val response = ExecutePlanResponse.newBuilder() - .setPipelineQueryFunctionExecutionSignal(signal.build()) - .build() - - responseObserver.onNext(response) - } - - Thread.sleep(100) - } - } catch { - case e: Exception => - logError( - s"Error in query function execution signal stream for graph $dataflowGraphId", e) - responseObserver.onError(e) - } finally { - responseObserver.onCompleted() + var streamCompleted = false + try { + var pipelineExecution: Option[PipelineUpdateContext] = None + var waitAttempts = 0 + val maxWaitAttempts = 100 + + while (pipelineExecution.isEmpty && waitAttempts < maxWaitAttempts) { + pipelineExecution = sessionHolder.getPipelineExecution(dataflowGraphId) + if (pipelineExecution.isEmpty) { + Thread.sleep(100) + waitAttempts += 1 } + } - case None => + if (pipelineExecution.isEmpty) { val 
error = new IllegalStateException( - s"No active pipeline execution found for graph $dataflowGraphId") - logError(error.getMessage) + s"No active pipeline execution found for graph $dataflowGraphId " + + s"after ${maxWaitAttempts * 100}ms") responseObserver.onError(error) + streamCompleted = true + return + } + + val execution = pipelineExecution.get.pipelineExecution + val graphAnalysisContext = execution.graphAnalysisContext + var signalAttempts = 0 + val maxSignalAttempts = 600 + + // scalastyle:off println + println(s"INSTRUMENTATION: Starting signal loop") + // scalastyle:on println + + while (execution.resolvedGraph.isEmpty && signalAttempts < maxSignalAttempts) { + val signal = proto.PipelineQueryFunctionExecutionSignal.newBuilder() + + while (!graphAnalysisContext.flowClientSignalQueue.isEmpty) { + val flowId = graphAnalysisContext.flowClientSignalQueue.remove() + signal.addFlowNames(flowId.unquotedString) + } + + if (signal.getFlowNamesCount > 0) { + // scalastyle:off println + println(s"INSTRUMENTATION: Sending signal for ${signal.getFlowNamesCount} flows") + // scalastyle:on println + logInfo(s"Sending execution signal for ${signal.getFlowNamesCount} flows") + + val response = ExecutePlanResponse.newBuilder() + .setPipelineQueryFunctionExecutionSignal(signal.build()) + .build() + + responseObserver.onNext(response) + // scalastyle:off println + println(s"INSTRUMENTATION: Signal sent successfully") + // scalastyle:on println + } else { + // scalastyle:off println + println(s"INSTRUMENTATION: No signals, attempt $signalAttempts") + // scalastyle:on println + } + + Thread.sleep(100) + signalAttempts += 1 + } + + // scalastyle:off println + println(s"INSTRUMENTATION: Exited signal loop, attempts=$signalAttempts") + // scalastyle:on println + + responseObserver.onCompleted() + streamCompleted = true + } catch { + case e: Exception => + logError( + s"Error in query function execution signal stream for graph $dataflowGraphId", e) + if (!streamCompleted) { + 
responseObserver.onError(e) + streamCompleted = true + } + } finally { + if (!streamCompleted) { + responseObserver.onCompleted() + } } } @@ -635,6 +685,9 @@ private[connect] object PipelinesHandler extends Logging { case Some(pipelineUpdateContext) => // TODO: what if we haven't yet started analysis? val graphAnalysisContext = pipelineUpdateContext.pipelineExecution.graphAnalysisContext + // scalastyle:off println + println(s"INSTRUMENTATION: markFlowPlanRegistered called for flow $flowIdentifier") + // scalastyle:on println graphAnalysisContext.markFlowPlanRegistered(flowIdentifier) case None => } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 0a1d1546cc13d..07da1f66fb21e 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -332,6 +332,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio // It is not called under SparkConnectSessionManager.sessionsLock, but it's guaranteed to be // called only once, since removing the session from SparkConnectSessionManager.sessionStore is // synchronized and guaranteed to happen only once. + // scalastyle:off println + println(s"INSTRUMENTATION: SessionHolder.close() called for session $sessionId") + println(s"INSTRUMENTATION: Called from: ${Thread.currentThread().getStackTrace.take(10). 
+ mkString("\n")}") + // scalastyle:on println if (closedTimeMs.isDefined) { throw new IllegalStateException(s"Session $key is already closed.") } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index 5d0e3858f6fc6..8e7db57a02a54 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -40,6 +40,9 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph, context: GraphAnalysisC context.resolvedInputsByIdentifier ) context.resolvedFlowsQueue.add(resolvedFlow) + // scalastyle:off println + println(s"INSTRUMENTATION: Adding resolved flow ${flow.identifier} to resolvedFlowsMap") + // scalastyle:on println context.resolvedFlowsMap.put(flow.identifier, resolvedFlow) resolvedFlow } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala index 104b4910860af..b67512df182ec 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -122,13 +122,24 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement], context: GraphAnalysisContext, disableParallelism: Boolean = false): DataflowGraphTransformer = { + // scalastyle:off println + println("INSTRUMENTATION: transformDownNodes starting") + // scalastyle:on println val executor = if (disableParallelism) selfExecutor else fixedPoolExecutor val batchSize = if (disableParallelism) 1 else parallelism var futures = 
ArrayBuffer[Future[Unit]]() context.toBeResolvedFlows.addAll(flows.asJava) - while (futures.nonEmpty || context.toBeResolvedFlows.peekFirst() != null) { + while ( + futures.nonEmpty || + context.toBeResolvedFlows.peekFirst() != null || + !context.failedUnregisteredFlows.isEmpty) { + // scalastyle:off println + println("INSTRUMENTATION: transformDownNodes iteration") + println(s"INSTRUMENTATION: toBeResolvedFlows.size = ${context.toBeResolvedFlows.size}") + println(s"INSTRUMENTATION: failedUnregistered.size = ${context.failedUnregisteredFlows.size}") + // scalastyle:on println val (done, notDone) = futures.partition(_.isDone) // Explicitly call future.get() to propagate exceptions one by one if any try { @@ -156,6 +167,10 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { } } + // scalastyle:off println + println("INSTRUMENTATION: transformDownNodes - exited main while loop") + // scalastyle:on println + // Mutate the fail analysis entities // A table is failed to analyze if: // - It does not exist in the resolvedFlowDestinationsMap @@ -197,6 +212,9 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { tableMap = computeTableMap() viewMap = computeViewMap() sinkMap = computeSinkMap() + // scalastyle:off println + println("pizza: completed transformDownNodes") + // scalastyle: on this } @@ -211,7 +229,11 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { } catch { case e: TransformNodeRetryableException => e.datasetIdentifier match { - case None => context.failedUnregisteredFlows.put(e.failedNode.identifier, e.failedNode) + case None => + // scalastyle:off println + println(s"INSTRUMENTATION: Adding flow ${e.failedNode.identifier} to unregistered") + // scalastyle:on println + context.failedUnregisteredFlows.put(e.failedNode.identifier, e.failedNode) case Some(datasetIdentifier) => context.registerFailedDependentFlow(datasetIdentifier, e.failedNode) // Between the time the flow 
started and finished resolving, perhaps the @@ -228,6 +250,9 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { (_, toRetryFlows) => { toRetryFlows.foreach { f => if (context.failedUnregisteredFlows.containsKey(f.identifier)) { + // scalastyle:off println + println(s"INSTRUMENTATION: Adding signal for ${f.identifier}") + // scalastyle:on println context.flowClientSignalQueue.add(f.identifier) } else { context.toBeResolvedFlows.addFirst(f) @@ -304,6 +329,9 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { (_, toRetryFlows) => { toRetryFlows.foreach { f => if (context.failedUnregisteredFlows.containsKey(f.identifier)) { + // scalastyle:off println + println(s"INSTRUMENTATION: Adding signal for ${f.identifier} (location 2)") + // scalastyle:on println context.flowClientSignalQueue.add(f.identifier) } else { context.toBeResolvedFlows.addFirst(f) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala index 8d1900afc55bf..45814a5702677 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala @@ -78,11 +78,22 @@ class GraphAnalysisContext { } def markFlowPlanRegistered(flowIdentifier: TableIdentifier): Unit = { + // scalastyle:off println + println(s"INSTRUMENTATION: markFlowPlanRegistered - looking for flow $flowIdentifier") + println(s"INSTRUMENTATION: failedUnregisteredFlows size: ${failedUnregisteredFlows.size()}") + // scalastyle:on println val flow = failedUnregisteredFlows.remove(flowIdentifier) // Flow could be null if it failed on the client side and we never tried to execute the flow // function on the server side. 
if (flow != null) { + // scalastyle:off println + println(s"INSTRUMENTATION: Found flow $flowIdentifier, adding to toBeResolvedFlows queue") + // scalastyle:on println toBeResolvedFlows.addFirst(flow) + } else { + // scalastyle:off println + println(s"INSTRUMENTATION: Flow $flowIdentifier not found in failedUnregisteredFlows") + // scalastyle:on println } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala index 3e73d12cda78a..d4be01f063a88 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala @@ -89,7 +89,13 @@ class PipelineExecution(context: PipelineUpdateContext) { /** Validates that the pipeline graph can be successfully resolved and validates it. */ def dryRunPipeline(): Unit = synchronized { + // scalastyle:off println + println("INSTRUMENTATION: dryRunPipeline() starting") + // scalastyle:on println resolvedGraph = Some(resolveGraph()) + // scalastyle:off println + println("INSTRUMENTATION: dryRunPipeline() - resolveGraph completed") + // scalastyle:on println context.eventCallback( constructTerminationEvent(RunCompletion()) ) @@ -112,8 +118,15 @@ class PipelineExecution(context: PipelineUpdateContext) { } def resolveGraph(): DataflowGraph = { + // scalastyle:off println + println("INSTRUMENTATION: resolveGraph() starting") + // scalastyle:on println try { - context.unresolvedGraph.resolve(Some(graphAnalysisContext)).validate() + val resolved = context.unresolvedGraph.resolve(Some(graphAnalysisContext)).validate() + // scalastyle:off println + println("INSTRUMENTATION: resolveGraph() completed successfully") + // scalastyle:on println + resolved } catch { case e: UnresolvedPipelineException => handleInvalidPipeline(e) From 29580a829884eda704dffd5232164f42f5b37b61 Mon Sep 
17 00:00:00 2001 From: Sandy Ryza Date: Thu, 8 Jan 2026 13:24:35 -0800 Subject: [PATCH 7/7] remove prints --- .../spark_connect_graph_element_registry.py | 3 -- .../connect/pipelines/PipelinesHandler.scala | 28 ------------------- .../sql/connect/service/SessionHolder.scala | 6 ---- .../graph/CoreDataflowNodeProcessor.scala | 3 -- .../graph/DataflowGraphTransformer.scala | 22 --------------- .../graph/GraphAnalysisContext.scala | 10 ------- .../pipelines/graph/PipelineExecution.scala | 12 -------- 7 files changed, 84 deletions(-) diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index 93011edc28293..4229b35915580 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -124,9 +124,7 @@ def register_flow(self, flow: Flow) -> None: try: df = self._execute_query_function(flow.name, flow.func) except PySparkException as e: - print("pizza: exception while executing query function") if e.getCondition() == "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION": - print("pizza: exception is an analysis exception") df = None else: raise e @@ -187,7 +185,6 @@ def register_signalled_query_functions(self) -> None: signal = result["pipeline_query_function_execution_signal"] flow_names = signal.flow_names - print("pizza: received signal with flow names: ", flow_names) for flow_name in flow_names: func = self._query_funcs_by_flow_name[flow_name] df = self._execute_query_function(flow_name, func) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 1abe077b5cc8f..132e16e98a852 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -326,7 +326,6 @@ private[connect] object PipelinesHandler extends Logging { flow: proto.PipelineCommand.DefineFlow, transformRelationFunc: Relation => LogicalPlan, sessionHolder: SessionHolder): TableIdentifier = { - logInfo("pizza defining flow") if (flow.hasOnce) { throw new AnalysisException( "DEFINE_FLOW_ONCE_OPTION_NOT_SUPPORTED", @@ -451,20 +450,12 @@ private[connect] object PipelinesHandler extends Logging { cmd.getStorage) sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext) - // scalastyle:off println - println(s"INSTRUMENTATION: Starting pipeline execution, dry=${cmd.getDry}") - // scalastyle:on println - if (cmd.getDry) { pipelineUpdateContext.pipelineExecution.dryRunPipeline() } else { pipelineUpdateContext.pipelineExecution.runPipeline() } - // scalastyle:off println - println(s"INSTRUMENTATION: Pipeline execution completed") - // scalastyle:on println - // Rethrow any exceptions that caused the pipeline run to fail so that the exception is // propagated back to the SC client / CLI. 
runFailureEvent.foreach { event => @@ -598,10 +589,6 @@ private[connect] object PipelinesHandler extends Logging { var signalAttempts = 0 val maxSignalAttempts = 600 - // scalastyle:off println - println(s"INSTRUMENTATION: Starting signal loop") - // scalastyle:on println - while (execution.resolvedGraph.isEmpty && signalAttempts < maxSignalAttempts) { val signal = proto.PipelineQueryFunctionExecutionSignal.newBuilder() @@ -611,9 +598,6 @@ private[connect] object PipelinesHandler extends Logging { } if (signal.getFlowNamesCount > 0) { - // scalastyle:off println - println(s"INSTRUMENTATION: Sending signal for ${signal.getFlowNamesCount} flows") - // scalastyle:on println logInfo(s"Sending execution signal for ${signal.getFlowNamesCount} flows") val response = ExecutePlanResponse.newBuilder() @@ -621,12 +605,7 @@ private[connect] object PipelinesHandler extends Logging { .build() responseObserver.onNext(response) - // scalastyle:off println - println(s"INSTRUMENTATION: Signal sent successfully") - // scalastyle:on println } else { - // scalastyle:off println - println(s"INSTRUMENTATION: No signals, attempt $signalAttempts") // scalastyle:on println } @@ -634,10 +613,6 @@ private[connect] object PipelinesHandler extends Logging { signalAttempts += 1 } - // scalastyle:off println - println(s"INSTRUMENTATION: Exited signal loop, attempts=$signalAttempts") - // scalastyle:on println - responseObserver.onCompleted() streamCompleted = true } catch { @@ -685,9 +660,6 @@ private[connect] object PipelinesHandler extends Logging { case Some(pipelineUpdateContext) => // TODO: what if we haven't yet started analysis? 
val graphAnalysisContext = pipelineUpdateContext.pipelineExecution.graphAnalysisContext - // scalastyle:off println - println(s"INSTRUMENTATION: markFlowPlanRegistered called for flow $flowIdentifier") - // scalastyle:on println graphAnalysisContext.markFlowPlanRegistered(flowIdentifier) case None => } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 07da1f66fb21e..86906b2825b28 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -332,11 +332,6 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio // It is not called under SparkConnectSessionManager.sessionsLock, but it's guaranteed to be // called only once, since removing the session from SparkConnectSessionManager.sessionStore is // synchronized and guaranteed to happen only once. - // scalastyle:off println - println(s"INSTRUMENTATION: SessionHolder.close() called for session $sessionId") - println(s"INSTRUMENTATION: Called from: ${Thread.currentThread().getStackTrace.take(10). 
- mkString("\n")}") - // scalastyle:on println if (closedTimeMs.isDefined) { throw new IllegalStateException(s"Session $key is already closed.") } @@ -498,7 +493,6 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private[connect] def cachePipelineExecution( graphId: String, pipelineUpdateContext: PipelineUpdateContext): Unit = { - print("pizza: caching pipeline execution") pipelineExecutions.compute( graphId, (_, existing) => { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala index 8e7db57a02a54..5d0e3858f6fc6 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/CoreDataflowNodeProcessor.scala @@ -40,9 +40,6 @@ class CoreDataflowNodeProcessor(rawGraph: DataflowGraph, context: GraphAnalysisC context.resolvedInputsByIdentifier ) context.resolvedFlowsQueue.add(resolvedFlow) - // scalastyle:off println - println(s"INSTRUMENTATION: Adding resolved flow ${flow.identifier} to resolvedFlowsMap") - // scalastyle:on println context.resolvedFlowsMap.put(flow.identifier, resolvedFlow) resolvedFlow } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala index b67512df182ec..66dae5099f091 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/DataflowGraphTransformer.scala @@ -122,9 +122,6 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { transformer: (GraphElement, Seq[GraphElement]) => Seq[GraphElement], context: GraphAnalysisContext, 
disableParallelism: Boolean = false): DataflowGraphTransformer = { - // scalastyle:off println - println("INSTRUMENTATION: transformDownNodes starting") - // scalastyle:on println val executor = if (disableParallelism) selfExecutor else fixedPoolExecutor val batchSize = if (disableParallelism) 1 else parallelism @@ -135,11 +132,6 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { futures.nonEmpty || context.toBeResolvedFlows.peekFirst() != null || !context.failedUnregisteredFlows.isEmpty) { - // scalastyle:off println - println("INSTRUMENTATION: transformDownNodes iteration") - println(s"INSTRUMENTATION: toBeResolvedFlows.size = ${context.toBeResolvedFlows.size}") - println(s"INSTRUMENTATION: failedUnregistered.size = ${context.failedUnregisteredFlows.size}") - // scalastyle:on println val (done, notDone) = futures.partition(_.isDone) // Explicitly call future.get() to propagate exceptions one by one if any try { @@ -167,9 +159,6 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { } } - // scalastyle:off println - println("INSTRUMENTATION: transformDownNodes - exited main while loop") - // scalastyle:on println // Mutate the fail analysis entities // A table is failed to analyze if: @@ -212,8 +201,6 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { tableMap = computeTableMap() viewMap = computeViewMap() sinkMap = computeSinkMap() - // scalastyle:off println - println("pizza: completed transformDownNodes") // scalastyle: on this } @@ -230,9 +217,6 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { case e: TransformNodeRetryableException => e.datasetIdentifier match { case None => - // scalastyle:off println - println(s"INSTRUMENTATION: Adding flow ${e.failedNode.identifier} to unregistered") - // scalastyle:on println context.failedUnregisteredFlows.put(e.failedNode.identifier, e.failedNode) case Some(datasetIdentifier) => 
context.registerFailedDependentFlow(datasetIdentifier, e.failedNode) @@ -250,9 +234,6 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { (_, toRetryFlows) => { toRetryFlows.foreach { f => if (context.failedUnregisteredFlows.containsKey(f.identifier)) { - // scalastyle:off println - println(s"INSTRUMENTATION: Adding signal for ${f.identifier}") - // scalastyle:on println context.flowClientSignalQueue.add(f.identifier) } else { context.toBeResolvedFlows.addFirst(f) @@ -329,9 +310,6 @@ class DataflowGraphTransformer(graph: DataflowGraph) extends AutoCloseable { (_, toRetryFlows) => { toRetryFlows.foreach { f => if (context.failedUnregisteredFlows.containsKey(f.identifier)) { - // scalastyle:off println - println(s"INSTRUMENTATION: Adding signal for ${f.identifier} (location 2)") - // scalastyle:on println context.flowClientSignalQueue.add(f.identifier) } else { context.toBeResolvedFlows.addFirst(f) diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala index 45814a5702677..92462d0731a05 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/GraphAnalysisContext.scala @@ -78,22 +78,12 @@ class GraphAnalysisContext { } def markFlowPlanRegistered(flowIdentifier: TableIdentifier): Unit = { - // scalastyle:off println - println(s"INSTRUMENTATION: markFlowPlanRegistered - looking for flow $flowIdentifier") - println(s"INSTRUMENTATION: failedUnregisteredFlows size: ${failedUnregisteredFlows.size()}") - // scalastyle:on println val flow = failedUnregisteredFlows.remove(flowIdentifier) // Flow could be null if it failed on the client side and we never tried to execute the flow // function on the server side. 
if (flow != null) { - // scalastyle:off println - println(s"INSTRUMENTATION: Found flow $flowIdentifier, adding to toBeResolvedFlows queue") - // scalastyle:on println toBeResolvedFlows.addFirst(flow) } else { - // scalastyle:off println - println(s"INSTRUMENTATION: Flow $flowIdentifier not found in failedUnregisteredFlows") - // scalastyle:on println } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala index d4be01f063a88..f5f3c9dd01e66 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineExecution.scala @@ -89,13 +89,7 @@ class PipelineExecution(context: PipelineUpdateContext) { /** Validates that the pipeline graph can be successfully resolved and validates it. */ def dryRunPipeline(): Unit = synchronized { - // scalastyle:off println - println("INSTRUMENTATION: dryRunPipeline() starting") - // scalastyle:on println resolvedGraph = Some(resolveGraph()) - // scalastyle:off println - println("INSTRUMENTATION: dryRunPipeline() - resolveGraph completed") - // scalastyle:on println context.eventCallback( constructTerminationEvent(RunCompletion()) ) @@ -118,14 +112,8 @@ class PipelineExecution(context: PipelineUpdateContext) { } def resolveGraph(): DataflowGraph = { - // scalastyle:off println - println("INSTRUMENTATION: resolveGraph() starting") - // scalastyle:on println try { val resolved = context.unresolvedGraph.resolve(Some(graphAnalysisContext)).validate() - // scalastyle:off println - println("INSTRUMENTATION: resolveGraph() completed successfully") - // scalastyle:on println resolved } catch { case e: UnresolvedPipelineException =>