9 changes: 9 additions & 0 deletions python/pyspark/pipelines/graph_element_registry.py
@@ -51,6 +51,15 @@ def register_sql(self, sql_text: str, file_path: Path) -> None:
:param file_path: The path to the file that the SQL text came from.
"""

@abstractmethod
def register_signalled_query_functions(self) -> None:
"""Open up a stream to receive query function execution signals from the server.
When a signal is received, execute the corresponding query function and register the
result with the server.

This method should be called after all the flows and datasets have been registered.
"""


_graph_element_registry_context_var: ContextVar[Optional[GraphElementRegistry]] = ContextVar(
"graph_element_registry_context", default=None
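To make the new contract concrete, here is a minimal, hypothetical sketch (not part of this PR) of a registry that records each flow's query function when it is registered and executes the deferred functions only once register_signalled_query_functions() is called, i.e. after all flows and datasets are known. The class is standalone for illustration rather than a real GraphElementRegistry subclass, and the result handling is an assumption.

from typing import Callable, Dict

from pyspark.sql import DataFrame


class RecordingRegistrySketch:
    """Illustrative only: defers query function execution until signalled."""

    def __init__(self) -> None:
        # Remember, but do not execute, each flow's query function.
        self._query_funcs_by_flow_name: Dict[str, Callable[[], DataFrame]] = {}
        self._results_by_flow_name: Dict[str, DataFrame] = {}

    def register_flow(self, name: str, func: Callable[[], DataFrame]) -> None:
        self._query_funcs_by_flow_name[name] = func

    def register_signalled_query_functions(self) -> None:
        # A real implementation would block on server signals naming which flows
        # to analyze; this sketch simply executes every deferred function.
        for name, func in self._query_funcs_by_flow_name.items():
            self._results_by_flow_name[name] = func()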
75 changes: 68 additions & 7 deletions python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -15,9 +15,10 @@
# limitations under the License.
#
from pathlib import Path
from pyspark.errors.exceptions.base import PySparkValueError

from pyspark.errors import PySparkTypeError
from pyspark.sql import SparkSession
from pyspark.errors import PySparkException, PySparkTypeError
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.pipelines.output import (
Output,
@@ -35,6 +36,8 @@
from typing import Any, cast
import pyspark.sql.connect.proto as pb2
from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
from pyspark.pipelines.flow import QueryFunction
import uuid


class SparkConnectGraphElementRegistry(GraphElementRegistry):
@@ -46,6 +49,12 @@ def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
self._spark = spark
self._client = cast(Any, spark).client
self._dataflow_graph_id = dataflow_graph_id
self._client_id = str(uuid.uuid4())
self._query_funcs_by_flow_name: dict[str, QueryFunction] = {}

@property
def dataflow_graph_id(self) -> str:
return self._dataflow_graph_id

def register_output(self, output: Output) -> None:
table_details = None
@@ -111,11 +120,19 @@ def register_output(self, output: Output) -> None:
self._client.execute_command(command)

def register_flow(self, flow: Flow) -> None:
with add_pipeline_analysis_context(
spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name
):
df = flow.func()
relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
self._query_funcs_by_flow_name[flow.name] = flow.func
try:
df = self._execute_query_function(flow.name, flow.func)
except PySparkException as e:
if e.getCondition() == "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION":
df = None
else:
raise e

if df is not None:
relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
else:
relation = None

relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
relation=relation,
@@ -128,6 +145,7 @@ def register_flow(self, flow: Flow) -> None:
relation_flow_details=relation_flow_details,
sql_conf=flow.spark_conf,
source_code_location=source_code_location_to_proto(flow.source_code_location),
client_id=self._client_id,
)
command = pb2.Command()
command.pipeline_command.define_flow.CopyFrom(inner_command)
@@ -143,6 +161,49 @@ def register_sql(self, sql_text: str, file_path: Path) -> None:
command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command)
self._client.execute_command(command)

def register_signalled_query_functions(self) -> None:
"""Open up a stream to receive query function execution signals from the server.
When a signal is received, execute the corresponding query function and register the
result with the server.

This method should be called after all the flows and datasets have been registered.
"""

inner_command = pb2.PipelineCommand.GetQueryFunctionExecutionSignalStream(
dataflow_graph_id=self._dataflow_graph_id,
client_id=self._client_id,
)
command = pb2.Command()
command.pipeline_command.get_query_function_execution_signal_stream.CopyFrom(inner_command)

result_iter = self._client.execute_command_as_iterator(command)
for result in result_iter:
if "pipeline_query_function_execution_signal" not in result.keys():
raise PySparkValueError(
f"GetQueryFunctionExecutionSignalStream received unexpected result: {result}"
)

signal = result["pipeline_query_function_execution_signal"]
flow_names = signal.flow_names
for flow_name in flow_names:
func = self._query_funcs_by_flow_name[flow_name]
df = self._execute_query_function(flow_name, func)
relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
inner_command = pb2.PipelineCommand.DefineFlowQueryFunctionResult(
dataflow_graph_id=self._dataflow_graph_id,
flow_name=flow_name,
relation=relation,
)
command = pb2.Command()
command.pipeline_command.define_flow_query_function_result.CopyFrom(inner_command)
self._client.execute_command(command)

def _execute_query_function(self, flow_name: str, func: QueryFunction) -> DataFrame:
with add_pipeline_analysis_context(
spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow_name
):
return func()


def source_code_location_to_proto(
source_code_location: SourceCodeLocation,
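The deferral path above exists because a flow's query function may not be analyzable when register_flow runs: analysis can require datasets the server has not resolved yet. A hedged illustration of that situation (the decorator usage, session setup, and dataset names are assumptions for this sketch, and such definitions would normally run inside a pipeline's registration context rather than as a standalone script):

from pyspark import pipelines as dp
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()


@dp.materialized_view
def upstream():
    return spark.range(10)


@dp.materialized_view
def downstream():
    # Analyzing this read eagerly inside register_flow() can raise
    # ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION, in which case no relation is
    # sent with DefineFlow; the DataFrame is produced later, when the server
    # signals for this flow over GetQueryFunctionExecutionSignalStream, and the
    # result is reported back via DefineFlowQueryFunctionResult.
    return spark.read.table("upstream")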
45 changes: 45 additions & 0 deletions python/pyspark/pipelines/tests/python_pipeline_suite_helpers.py
@@ -0,0 +1,45 @@
from pyspark.sql import SparkSession
from pyspark import pipelines as dp
from pyspark.pipelines.spark_connect_graph_element_registry import (
SparkConnectGraphElementRegistry,
)
from pyspark.pipelines.spark_connect_pipeline import create_dataflow_graph
from pyspark.pipelines.spark_connect_pipeline import start_run, handle_pipeline_events
import threading

def setup(server_port: str, session_identifier: str) -> tuple[SparkSession, SparkConnectGraphElementRegistry]:
spark = SparkSession.builder \
.remote(f"sc://localhost:{server_port}") \
.config("spark.connect.grpc.channel.timeout", "5s") \
.config("spark.custom.identifier", session_identifier) \
.create()

dataflow_graph_id = create_dataflow_graph(
spark,
default_catalog=None,
default_database=None,
sql_conf={},
)

registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
return spark, registry


def run_and_handle_signals(
spark: SparkSession,
registry: SparkConnectGraphElementRegistry,
storage_root: str,
) -> None:
result_iter = start_run(
spark,
dataflow_graph_id=registry.dataflow_graph_id,
full_refresh=None,
full_refresh_all=False,
refresh=None,
dry=True,
storage=storage_root
)
thread = threading.Thread(target=handle_pipeline_events, args=(result_iter,), daemon=True)
thread.start()

registry.register_signalled_query_functions()
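A hypothetical test flow using these helpers (the port, session identifier, storage path, and the element-definition step are illustrative assumptions; defining elements requires the registration context that the dp decorators rely on, which is outside this sketch):

spark, registry = setup(server_port="15002", session_identifier="signal-stream-test")

# ... define datasets and flows against `registry` here (e.g. with the
# pyspark.pipelines decorators inside the registration context) so their
# query functions are recorded before the run starts ...

# Start a dry run on the server and answer its query function execution
# signals until the stream completes.
run_and_handle_signals(spark, registry, storage_root="/tmp/pipeline-storage")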
@@ -46,30 +46,25 @@ class DataflowGraphRegistry {
graphId
}

/** Retrieves the graph for a given id. */
def getDataflowGraph(graphId: String): Option[GraphRegistrationContext] = {
Option(dataflowGraphs.get(graphId))
/** Returns a Map of dataflow graph IDs to their corresponding GraphRegistrationContext. */
def getDataflowGraphs: Map[String, GraphRegistrationContext] = {
dataflowGraphs.asScala.toMap
}

/** Retrieves the graph for a given id, and throws if the id could not be found. */
def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext =
getDataflowGraph(dataflowGraphId).getOrElse {
getDataflowGraphs.getOrElse(dataflowGraphId, {
throw new SparkException(
errorClass = "DATAFLOW_GRAPH_NOT_FOUND",
messageParameters = Map("graphId" -> dataflowGraphId),
cause = null)
}
})

/** Removes the graph with a given id from the registry. */
def dropDataflowGraph(graphId: String): Unit = {
dataflowGraphs.remove(graphId)
}

/** Returns all graphs in the registry. */
def getAllDataflowGraphs: Seq[GraphRegistrationContext] = {
dataflowGraphs.values().asScala.toSeq
}

/** Removes all graphs from the registry. */
def dropAllDataflowGraphs(): Unit = {
dataflowGraphs.clear()
Expand Down