diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 4e6089b50a46d..1524ff455ec75 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -415,13 +415,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( */ def writeNextInputToStream(dataOut: DataOutputStream): Boolean - def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions { + def open(outputStream: DataOutputStream): Unit = Utils.logUncaughtExceptions { val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED) lazy val sockPath = new File( authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR) .getOrElse(System.getProperty("java.io.tmpdir")), s".${UUID.randomUUID()}.sock") try { + // Buffer the initialization message, and send it together with its length. + val buffer = new ByteArrayOutputStream() + val dataOut = new DataOutputStream(buffer) + // Partition index dataOut.writeInt(partitionIndex) @@ -522,6 +526,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( writeCommand(dataOut) dataOut.flush() + + // The initialization message is complete, write it to the stream with its length. + val messageBytes = buffer.toByteArray + outputStream.writeInt(SpecialLengths.START_OF_INIT_MESSAGE) + outputStream.writeInt(messageBytes.length) + outputStream.write(messageBytes) + outputStream.flush() } catch { case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] => if (context.isCompleted() || context.isInterrupted()) { @@ -1085,6 +1096,7 @@ private[spark] object SpecialLengths { val NULL = -5 val START_ARROW_STREAM = -6 val END_OF_MICRO_BATCH = -7 + val START_OF_INIT_MESSAGE = -8 } private[spark] object BarrierTaskContextMessageProtocol { diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py index 911c50141e43f..8c201d4c25807 100755 --- a/python/packaging/classic/setup.py +++ b/python/packaging/classic/setup.py @@ -267,6 +267,8 @@ def run(self): "pyspark", "pyspark.core", "pyspark.cloudpickle", + "pyspark.messages", + "pyspark.messages.socket", "pyspark.mllib", "pyspark.mllib.linalg", "pyspark.mllib.stat", diff --git a/python/packaging/client/setup.py b/python/packaging/client/setup.py index 17475e9e065ad..182ec11ab2d77 100755 --- a/python/packaging/client/setup.py +++ b/python/packaging/client/setup.py @@ -148,6 +148,8 @@ connect_packages = [ "pyspark", "pyspark.cloudpickle", + "pyspark.messages", + "pyspark.messages.socket", "pyspark.mllib", "pyspark.mllib.linalg", "pyspark.mllib.stat", diff --git a/python/pyspark/messages/__init__.py b/python/pyspark/messages/__init__.py index ccb7b9323257f..69cfbf6bd53a2 100644 --- a/python/pyspark/messages/__init__.py +++ b/python/pyspark/messages/__init__.py @@ -15,8 +15,12 @@ # limitations under the License. 
# +from pyspark.messages.spark_message_receiver import SparkMessageReceiver from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream +from pyspark.messages.socket.spark_socket_message_receiver import SparkSocketMessageReceiver __all__ = [ + "SparkMessageReceiver", + "SparkSocketMessageReceiver", "ZeroCopyByteStream", ] diff --git a/python/pyspark/messages/socket/__init__.py b/python/pyspark/messages/socket/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/messages/socket/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/messages/socket/spark_socket_message_receiver.py b/python/pyspark/messages/socket/spark_socket_message_receiver.py new file mode 100644 index 0000000000000..fe46d988e8392 --- /dev/null +++ b/python/pyspark/messages/socket/spark_socket_message_receiver.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import BinaryIO + +from pyspark.serializers import read_int, SpecialLengths +from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream +from pyspark.messages.spark_message_receiver import ( + SparkMessageReceiver, +) + + +def _assert_message_id(message_id: int, expected: int) -> None: + assert message_id == expected, ( + f"Expected message with id {expected} " + f"but got message with id {message_id} instead." 
+ ) + + +class SparkSocketMessageReceiver(SparkMessageReceiver): + def __init__(self, infile: BinaryIO): + super().__init__() + self._infile = infile + + def _do_get_init_message(self) -> ZeroCopyByteStream: + message_id = read_int(self._infile) + _assert_message_id(message_id, SpecialLengths.START_OF_INIT_MESSAGE) + + # Read the length and init content + message_length = read_int(self._infile) + message_content = self._infile.read(message_length) + + return ZeroCopyByteStream(memoryview(message_content)) + + def _do_get_data_stream(self) -> BinaryIO: + # For socket communication, we just pass along the underlying socket + # for the data channel. We already stripped the initialization data + # at this point. Therefore, any bytes following this are data bytes. + # + # Note: We deliberately did not introduce a message header for + # data messages to reduce the overhead, especially for small + # batch sizes and real-time mode (RTM). + return self._infile + + def _do_is_stream_finished(self) -> bool: + # Check if the stream is finished. + # If everything finished properly, we should read an + # 'END_OF_STREAM'. If we read something else, this means + # the stream has unread data and something went wrong + # during processing. + return read_int(self._infile) == SpecialLengths.END_OF_STREAM diff --git a/python/pyspark/messages/spark_message_receiver.py b/python/pyspark/messages/spark_message_receiver.py new file mode 100644 index 0000000000000..903a2fb114083 --- /dev/null +++ b/python/pyspark/messages/spark_message_receiver.py @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from enum import Enum +from functools import wraps +from typing import BinaryIO, Callable, TypeVar +from abc import ABC, abstractmethod + +from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream + + +T = TypeVar("T", bound="SparkMessageReceiver") +R = TypeVar("R") + + +class MessageState(Enum): + WAITING_FOR_INIT = 1 + WAITING_FOR_DATA = 2 + WAITING_FOR_FINISH = 3 + DONE = 4 + + +class SparkMessageReceiver(ABC): + """ + Generic class that implements receiving messages from Spark. + Caution: This class is STATEFUL. It is expected that the + methods of this class are called in the following order: + + 1. Init -> 2. Data stream -> 3. Finish + + This order is verified using assertions in the class. Each function + can be called EXACTLY ONCE in the specified order.
+ """ + + def __init__(self) -> None: + self._state = MessageState.WAITING_FOR_INIT + + @staticmethod + def _state_transition( + required_state: MessageState, next_state: MessageState + ) -> Callable[[Callable[[T], R]], Callable[[T], R]]: + """Decorator to enforce state transitions.""" + + def decorator(func: Callable[[T], R]) -> Callable[[T], R]: + @wraps(func) + def wrapper(self: T) -> R: + assert self._state == required_state + result = func(self) + self._state = next_state + return result + + return wrapper + + return decorator + + @_state_transition(MessageState.WAITING_FOR_INIT, MessageState.WAITING_FOR_DATA) + def get_init_message(self) -> ZeroCopyByteStream: + """ + Returns: + the binary contents of the initial message as a ZeroCopyByteStream. + """ + return self._do_get_init_message() + + @_state_transition(MessageState.WAITING_FOR_DATA, MessageState.WAITING_FOR_FINISH) + def get_data_stream(self) -> BinaryIO: + """ + Returns: + A binary stream containing the data to invoke the UDF on. + """ + return self._do_get_data_stream() + + @_state_transition(MessageState.WAITING_FOR_FINISH, MessageState.DONE) + def is_stream_finished(self) -> bool: + """ + Checks if a finish message was received + from the JVM. The finish message itself only + has a message id and marks the end of the stream. + If bytes different from the finish id are read + this means something went wrong while consuming the stream. + """ + return self._do_is_stream_finished() + + @abstractmethod + def _do_get_init_message(self) -> ZeroCopyByteStream: + """ + Returns the contents of the init message + as a 'ZeroCopyByteStream'. + + To be implemented by child classes. + """ + pass + + @abstractmethod + def _do_get_data_stream(self) -> BinaryIO: + """ + Returns the Spark data stream. + + To be implemented by child classes. + """ + pass + + @abstractmethod + def _do_is_stream_finished(self) -> bool: + """ + Blocking call that returns whether + the data stream from the JVM is finished. + This is implemented differently, depending + on the transport channel. + + To be implemented by child classes. 
+ """ + pass diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 6de64a1062f0b..48166c948b5b1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -63,6 +63,7 @@ import zlib import itertools import pickle +import codecs pickle_protocol = pickle.HIGHEST_PROTOCOL @@ -84,6 +85,7 @@ class SpecialLengths: END_OF_STREAM = -4 NULL = -5 START_ARROW_STREAM = -6 + START_OF_INIT_MESSAGE = -8 class Serializer: @@ -539,7 +541,7 @@ def loads(self, stream): elif length == SpecialLengths.NULL: return None s = stream.read(length) - return s.decode("utf-8") if self.use_unicode else s + return codecs.decode(s, "utf-8") if self.use_unicode else s def load_stream(self, stream): try: diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 7ac6453232436..50cba601321ba 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -161,7 +161,7 @@ def _getOrCreate(cls: Type["TaskContext"]) -> "TaskContext": return cls._taskContext @classmethod - def _setTaskContext(cls: Type["TaskContext"], taskContext: "TaskContext") -> None: + def _setTaskContext(cls: Type["TaskContext"], taskContext: Optional["TaskContext"]) -> None: cls._taskContext = taskContext @classmethod diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index cb7310589540d..033761c8c3dc6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -39,6 +39,7 @@ Union, get_args, get_origin, + BinaryIO, ) T = TypeVar("T") @@ -58,7 +59,6 @@ from pyspark.serializers import ( write_int, write_long, - read_int, SpecialLengths, CPickleSerializer, BatchedSerializer, @@ -120,6 +120,10 @@ Conf, ) from pyspark.logger.worker_io import capture_outputs +from pyspark.messages import ( + SparkMessageReceiver, + SparkSocketMessageReceiver, +) class RunnerConf(Conf): @@ -3574,11 +3578,20 @@ def func(_, it): return func, None, ser, ser -@with_faulthandler -def main(infile, outfile): +def invoke_udf(message_receiver: SparkMessageReceiver, outfile: BinaryIO): + """ + This function is the main processing function for worker.py. + It receives messages from the JVM, processes the data, and sends back results. + This method goes through three phases: + + Initialization -> Processing -> Finish/Cleanup + """ try: boot_time = time.time() + # Initialization + infile = message_receiver.get_init_message() init_info = WorkerInitInfo.from_stream(infile) + start_faulthandler_periodic_traceback() check_python_version(init_info.python_version) @@ -3602,6 +3615,10 @@ def main(infile, outfile): runner_conf = RunnerConf(init_info.runner_conf) eval_conf = EvalConf(init_info.eval_conf) if eval_type == PythonEvalType.NON_UDF: + # The type checker needs some help here.. + # See the code in WorkerInitInfo.from_stream(infile) + # to see the correct type. 
+ assert isinstance(init_info.udf_info, memoryview) func, profiler, deserializer, serializer = read_command(pickleSer, init_info.udf_info) elif eval_type in ( PythonEvalType.SQL_TABLE_UDF, @@ -3618,8 +3635,13 @@ def main(infile, outfile): init_time = time.time() + # Processing + + # Fetch the input data stream + input_data_stream = message_receiver.get_data_stream() + def process(): - iterator = deserializer.load_stream(infile) + iterator = deserializer.load_stream(input_data_stream) out_iter = func(init_info.split_index, iterator) try: serializer.dump_stream(out_iter, outfile) @@ -3635,6 +3657,7 @@ def process(): process() processing_time_ms = int(1000 * (time.time() - processing_start_time)) + # Cleanup # Reset task context to None. This is a guard code to avoid residual context when worker # reuse. TaskContext._setTaskContext(None) @@ -3652,7 +3675,7 @@ def process(): send_accumulator_updates(outfile) # check end of stream - if read_int(infile) == SpecialLengths.END_OF_STREAM: + if message_receiver.is_stream_finished(): write_int(SpecialLengths.END_OF_STREAM, outfile) else: # write a different value to tell JVM to not reuse this worker @@ -3660,6 +3683,14 @@ def process(): sys.exit(-1) +@with_faulthandler +def main(infile, outfile): + # Instantiate socket message readers for executing the UDF + socket_reader = SparkSocketMessageReceiver(infile) + + invoke_udf(socket_reader, outfile) + + if __name__ == "__main__": with get_sock_file_to_executor() as sock_file: main(sock_file, sock_file) diff --git a/python/pyspark/worker_message.py b/python/pyspark/worker_message.py index b1519cb084421..707c481761fbe 100644 --- a/python/pyspark/worker_message.py +++ b/python/pyspark/worker_message.py @@ -18,13 +18,14 @@ import dataclasses import json import sys -from typing import Optional, Union, IO +from typing import Optional, Union, IO, Any from pyspark.errors import PySparkValueError from pyspark.serializers import read_bool, read_int, read_long, SpecialLengths from pyspark.taskcontext import BarrierTaskContext, ResourceInformation, TaskContext from pyspark.util import PythonEvalType from pyspark.worker_util import utf8_deserializer +from pyspark.messages import ZeroCopyByteStream @dataclasses.dataclass @@ -46,7 +47,7 @@ class ResourceInfo: local_properties: dict[str, str] @classmethod - def from_stream(cls, stream: IO) -> "TaskContextInfo": + def from_stream(cls, stream: ZeroCopyByteStream) -> "TaskContextInfo": task_context_json = json.loads(utf8_deserializer.loads(stream)) return cls( is_barrier=task_context_json["isBarrier"], @@ -100,7 +101,7 @@ class BroadcastInfo: variables: list[tuple[int, Optional[str]]] @classmethod - def from_stream(cls, stream: IO) -> "BroadcastInfo": + def from_stream(cls, stream: Union[ZeroCopyByteStream, IO[Any]]) -> "BroadcastInfo": needs_broadcast_decryption_server = read_bool(stream) num_broadcast_variables = read_int(stream) conn_info = None @@ -125,13 +126,13 @@ def from_stream(cls, stream: IO) -> "BroadcastInfo": @dataclasses.dataclass class UDFInfo: - udfs: list[bytes] + udfs: list[memoryview] args: list[int] kwargs: dict[str, int] result_id: int @classmethod - def from_stream(cls, stream: IO) -> "UDFInfo": + def from_stream(cls, stream: ZeroCopyByteStream) -> "UDFInfo": num_args = read_int(stream) udfs = [] args = [] @@ -167,13 +168,13 @@ class UDTFInfo: args: list[int] kwargs: dict[str, int] partition_child_indexes: list[int] - pickled_analyze_result: Optional[bytes] - handler: bytes + pickled_analyze_result: Optional[memoryview] + handler: memoryview 
return_type: str name: str @classmethod - def from_stream(cls, stream: IO) -> "UDTFInfo": + def from_stream(cls, stream: ZeroCopyByteStream) -> "UDTFInfo": # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand' args = [] kwargs = {} @@ -215,10 +216,10 @@ class WorkerInitInfo: eval_type: int runner_conf: dict[str, str] eval_conf: dict[str, str] - udf_info: Union[bytes, UDTFInfo, list[UDFInfo]] + udf_info: Union[memoryview, UDTFInfo, list[UDFInfo]] @classmethod - def from_stream(cls, stream: IO) -> "WorkerInitInfo": + def from_stream(cls, stream: ZeroCopyByteStream) -> "WorkerInitInfo": split_index = read_int(stream) if split_index == -1: sys.exit(-1) @@ -243,7 +244,7 @@ def from_stream(cls, stream: IO) -> "WorkerInitInfo": v = utf8_deserializer.loads(stream) eval_conf[k] = v - udf_info: Union[bytes, UDTFInfo, list[UDFInfo]] + udf_info: Union[memoryview, UDTFInfo, list[UDFInfo]] if eval_type == PythonEvalType.NON_UDF: udf_info = stream.read(read_int(stream)) diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py index 08edf0c5decb4..e427ccf285b89 100644 --- a/python/pyspark/worker_util.py +++ b/python/pyspark/worker_util.py @@ -27,6 +27,8 @@ from typing import Any, Generator, IO, Optional, Union, overload import warnings +from pyspark.messages import ZeroCopyByteStream + if "SPARK_TESTING" in os.environ: assert os.environ.get("SPARK_PYTHON_RUNTIME") == "PYTHON_WORKER", ( "This module can only be imported in python woker" @@ -65,11 +67,11 @@ def add_path(path: str) -> bool: return False -def read_command(serializer: FramedSerializer, file: Union[IO, bytes]) -> Any: +def read_command(serializer: FramedSerializer, file: Union[IO, memoryview]) -> Any: if not is_remote_only(): from pyspark.core.broadcast import Broadcast - if isinstance(file, bytes): + if isinstance(file, memoryview): command = serializer.loads(file) else: command = serializer._read_with_length(file) @@ -173,7 +175,9 @@ def setup_spark_files( @overload -def setup_broadcasts(infile_or_variables: IO) -> None: ... +def setup_broadcasts(infile_or_variables: IO[Any]) -> None: ... +@overload +def setup_broadcasts(infile_or_variables: ZeroCopyByteStream) -> None: ... @overload def setup_broadcasts( infile_or_variables: list[tuple[int, Union[str, None]]], conn_info: str, auth_secret: None @@ -184,10 +188,12 @@ def setup_broadcasts( ) -> None: ... @overload def setup_broadcasts( - infile_or_variables: list[tuple[int, Union[str, None]]], conn_info: None, auth_secret: None + infile_or_variables: list[tuple[int, Union[str, None]]], + conn_info: Optional[Union[str, int]], + auth_secret: Optional[str], ) -> None: ... def setup_broadcasts( - infile_or_variables: Union[IO, list[tuple[int, Union[str, None]]]], + infile_or_variables: Union[ZeroCopyByteStream, IO[Any], list[tuple[int, Union[str, None]]]], conn_info: Optional[Union[str, int]] = None, auth_secret: Optional[str] = None, ) -> None:
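
Note (illustration, not part of the patch): the sketch below shows the wire framing this change introduces, using only values visible in the diff. The JVM side buffers the entire initialization message and writes `START_OF_INIT_MESSAGE` (-8) followed by the message length and bytes, and the Python worker (via `SparkSocketMessageReceiver`) strips that frame, treats the remaining bytes as the data channel, and finally expects `END_OF_STREAM` (-4). The in-memory `BytesIO` channel, the local `read_int`/`write_int` helpers, and the placeholder payloads are assumptions made for a runnable example, not Spark APIs.

```python
import io
import struct

# Values taken from the patch (SpecialLengths in PythonRunner.scala / serializers.py).
START_OF_INIT_MESSAGE = -8
END_OF_STREAM = -4


def write_int(value, stream):
    # Big-endian 4-byte signed int, matching DataOutputStream.writeInt on the JVM side.
    stream.write(struct.pack("!i", value))


def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]


# "JVM" side: buffer the init message, then frame it with its id and length,
# followed by the raw data bytes and the end-of-stream marker.
init_payload = b"<worker init info>"  # placeholder payload (assumption)
data_payload = b"<input rows>"        # placeholder payload (assumption)

channel = io.BytesIO()
write_int(START_OF_INIT_MESSAGE, channel)
write_int(len(init_payload), channel)
channel.write(init_payload)
channel.write(data_payload)
write_int(END_OF_STREAM, channel)
channel.seek(0)

# Python worker side, mirroring SparkSocketMessageReceiver's required call order:
# init frame first, then the remaining bytes are the data stream, then the finish check.
assert read_int(channel) == START_OF_INIT_MESSAGE
init_message = channel.read(read_int(channel))   # handed to WorkerInitInfo.from_stream(...)
data_stream = channel                            # passed to deserializer.load_stream(...)
consumed = data_stream.read(len(data_payload))   # UDF processing would consume this
assert read_int(channel) == END_OF_STREAM        # the is_stream_finished() check
```

Only the init message carries a length header; the data channel stays unframed, which matches the comment in `_do_get_data_stream` about avoiding per-message overhead for small batch sizes and real-time mode.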