
Commit 0fc7546

gaogaotiantian authored and HyukjinKwon committed
[SPARK-56519][PYTHON] Isolate communication part from python udf worker
### What changes were proposed in this pull request?

We read everything we need for creating the Python UDF worker from the JVM at the very beginning, so we don't need to read anything while we process. This isolates the communication part from the processing logic.

A message interface is designed to hold all the information from the JVM. For now it does not change the protocol at all, but it is designed to support multiple protocols. The goal is to have a layer that contains all the needed information in a language-agnostic form (`int`, `str`, `bytes`, `None`, `list`, `dict`). In the future we can support an entirely different protocol while keeping support for the existing socket protocol - the new protocol just needs to fill in the blanks.

We still have some random workers that rely on `setup_spark_files` or `setup_broadcasts`, so in this PR we made all these worker utils support both the old and the new input: if the argument is an IO object, the util follows the original code path; if it is fed the concrete data, it simply uses that data to initialize everything.

### Why are the changes needed?

The current protocol is fragile because we mix communication with processing: no one knows exactly where the socket is going to be read, and the flow depends heavily on the actual eval type, which is horrible. We want to isolate the communication part and read just once - read everything we need, then process. That makes the protocol much more robust and easier to debug.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

`test_udf.py` and `test_udtf.py` passed locally. Still need to wait for the CI result.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #55380 from gaogaotiantian/isolate-init-data.

Authored-by: Tian Gao <gaogaotiantian@hotmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
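As a rough illustration of the dual-input pattern described above, here is a minimal sketch of a worker util that accepts either the raw stream (old protocol) or concrete, already-decoded data (new message layer). The signature, the wire-format helper, and the body are assumptions for illustration only, not the actual `pyspark.worker_util` code:

```python
import os
import sys
from typing import IO, List, Optional, Union


def _read_utf8(infile: IO) -> str:
    # Placeholder for the length-prefixed UTF-8 read the real worker uses.
    length = int.from_bytes(infile.read(4), "big")
    return infile.read(length).decode("utf-8")


def setup_spark_files(source: Union[IO, str],
                      python_includes: Optional[List[str]] = None) -> None:
    """Accept either the raw stream (old protocol) or concrete data (new)."""
    if hasattr(source, "read"):
        # Old path: decode the directory and the include list from the socket.
        spark_files_dir = _read_utf8(source)
        num_includes = int.from_bytes(source.read(4), "big")
        python_includes = [_read_utf8(source) for _ in range(num_includes)]
    else:
        # New path: the caller already extracted everything from the message.
        spark_files_dir = source
    for include in python_includes or []:
        sys.path.append(os.path.join(spark_files_dir, include))
```

Dispatching on `hasattr(source, "read")` keeps the old socket path alive while letting a new protocol hand in plain Python values.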
1 parent a2dd757 · commit 0fc7546

4 files changed: 414 additions & 135 deletions


python/pyspark/taskcontext.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import ClassVar, Type, TypeVar, Dict, List, Optional, Union, cast
+from typing import Any, ClassVar, Type, TypeVar, Dict, List, Optional, Union, cast

 from pyspark.util import local_connect_and_auth
 from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
@@ -130,7 +130,7 @@ class TaskContext:
     _cpus: Optional[int] = None
     _resources: Optional[Dict[str, "ResourceInformation"]] = None

-    def __new__(cls: Type["TaskContext"], **kwargs: Dict) -> "TaskContext":
+    def __new__(cls: Type["TaskContext"], **kwargs: Any) -> "TaskContext":
         """
         Even if users construct :class:`TaskContext` instead of using get, give them the singleton.
         """
@@ -142,7 +142,7 @@ def __new__(cls: Type["TaskContext"], **kwargs: Dict) -> "TaskContext":

     def __init__(
         self,
-        **kwargs: Dict,
+        **kwargs: Any,
     ) -> None:
         # Set attributes only if they are passed in and not None
         # The kwargs are auto-mapped to the private attributes of TaskContext
```
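A side note on the annotation change above: `**kwargs: T` types each individual keyword value as `T`, not the mapping as a whole, so `**kwargs: Dict` wrongly required every value to be a dict. A standalone illustration (not PySpark code):

```python
from typing import Any, Dict


def old_style(**kwargs: Dict) -> None:
    # Each value must itself be a dict here, so a type checker rejects
    # old_style(stageId=1) because 1 is an int, not a Dict.
    ...


def new_style(**kwargs: Any) -> None:
    # Values of any type are accepted, matching how TaskContext receives
    # heterogeneous attributes.
    ...


new_style(stageId=1, partitionId=0, attemptNumber=0)  # type-checks cleanly
```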

python/pyspark/worker.py

Lines changed: 47 additions & 79 deletions
```diff
@@ -57,8 +57,6 @@
 from pyspark.util import PythonEvalType
 from pyspark.serializers import (
     write_int,
-    read_long,
-    read_bool,
     write_long,
     read_int,
     SpecialLengths,
@@ -109,6 +107,7 @@
 )
 from pyspark import _NoValue, shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
+from pyspark.worker_message import WorkerInitInfo
 from pyspark.worker_util import (
     check_python_version,
     get_sock_file_to_executor,
@@ -118,7 +117,6 @@
     setup_broadcasts,
     setup_memory_limits,
     setup_spark_files,
-    utf8_deserializer,
     Conf,
 )
 from pyspark.logger.worker_io import capture_outputs
@@ -981,39 +979,25 @@ def profiling_func(*args, **kwargs):
     return profiling_func


-def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
-    num_arg = read_int(infile)
-
-    args_offsets = []
-    kwargs_offsets = {}
-    for _ in range(num_arg):
-        offset = read_int(infile)
-        if read_bool(infile):
-            name = utf8_deserializer.loads(infile)
-            kwargs_offsets[name] = offset
-        else:
-            args_offsets.append(offset)
-
+def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index):
     chained_func = None
-    for i in range(read_int(infile)):
-        f, return_type = read_command(pickleSer, infile)
+    for udf in udf_info.udfs:
+        f, return_type = read_command(pickleSer, udf)
         if chained_func is None:
             chained_func = f
         else:
             chained_func = chain(chained_func, f)

-    result_id = read_long(infile)
-
     # If chained_func is from pyspark.sql.worker, it is to read/write data source.
     # In this case, we check the data_source_profiler config.
     if getattr(chained_func, "__module__", "").startswith("pyspark.sql.worker."):
         profiler = runner_conf.data_source_profiler
     else:
         profiler = runner_conf.udf_profiler
     if profiler == "perf":
-        profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
+        profiling_func = wrap_perf_profiler(chained_func, eval_type, udf_info.result_id)
     elif profiler == "memory":
-        profiling_func = wrap_memory_profiler(chained_func, eval_type, result_id)
+        profiling_func = wrap_memory_profiler(chained_func, eval_type, udf_info.result_id)
     else:
         profiling_func = chained_func

@@ -1027,6 +1011,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     # when they are processed in a for loop, raise them as RuntimeError's instead
     func = fail_on_stopiteration(profiling_func)

+    args_offsets, kwargs_offsets = udf_info.args, udf_info.kwargs
+
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return func, args_offsets, kwargs_offsets, return_type
@@ -1105,7 +1091,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
 # It expects the UDTF to be in a specific format and performs various checks to
 # ensure the UDTF is valid. This function also prepares a mapper function for applying
 # the UDTF logic to input rows.
-def read_udtf(pickleSer, infile, eval_type, runner_conf, eval_conf):
+def read_udtf(pickleSer, udtf_info, eval_type, runner_conf, eval_conf):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
         if runner_conf.use_legacy_pandas_udtf_conversion:
             # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
@@ -1126,50 +1112,35 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf, eval_conf):
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)

-    # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
-    num_arg = read_int(infile)
-    args_offsets = []
-    kwargs_offsets = {}
-    for _ in range(num_arg):
-        offset = read_int(infile)
-        if read_bool(infile):
-            name = utf8_deserializer.loads(infile)
-            kwargs_offsets[name] = offset
-        else:
-            args_offsets.append(offset)
-    num_partition_child_indexes = read_int(infile)
-    partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
-    has_pickled_analyze_result = read_bool(infile)
-    if has_pickled_analyze_result:
-        pickled_analyze_result = pickleSer._read_with_length(infile)
+    if udtf_info.pickled_analyze_result is not None:
+        pickled_analyze_result = pickleSer.loads(udtf_info.pickled_analyze_result)
     else:
         pickled_analyze_result = None
     # Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult,
     # although we may set this to false later if we find otherwise.
-    handler = read_command(pickleSer, infile)
+    handler = read_command(pickleSer, udtf_info.handler)
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
             f"Invalid UDTF handler type. Expected a class (type 'type'), but "
             f"got an instance of {type(handler).__name__}."
         )

-    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+    return_type = _parse_datatype_json_string(udtf_info.return_type)
     if not isinstance(return_type, StructType):
         raise PySparkRuntimeError(
             f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
         )
-    udtf_name = utf8_deserializer.loads(infile)

     # Update the handler that creates a new UDTF instance to first try calling the UDTF constructor
     # with one argument containing the previous AnalyzeResult. If that fails, then try a constructor
     # with no arguments. In this way each UDTF class instance can decide if it wants to inspect the
     # AnalyzeResult.
     udtf_init_args = inspect.getfullargspec(handler)
-    if has_pickled_analyze_result:
+    if pickled_analyze_result is not None:
         if len(udtf_init_args.args) > 2:
             raise PySparkRuntimeError(
                 errorClass="UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD",
-                messageParameters={"name": udtf_name},
+                messageParameters={"name": udtf_info.name},
             )
         elif len(udtf_init_args.args) == 2:
             prev_handler = handler
@@ -1182,7 +1153,7 @@ def construct_udtf():
     elif len(udtf_init_args.args) > 1:
         raise PySparkRuntimeError(
             errorClass="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD",
-            messageParameters={"name": udtf_name},
+            messageParameters={"name": udtf_info.name},
         )

     class UDTFWithPartitions:
@@ -1220,7 +1191,7 @@ def __init__(self, create_udtf: Callable, partition_child_indexes: list):
             self._create_udtf: Callable = create_udtf
             self._udtf = create_udtf()
             self._prev_arguments: list = list()
-            self._partition_child_indexes: list = partition_child_indexes
+            self._partition_child_indexes: list = udtf_info.partition_child_indexes
             self._eval_raised_skip_rest_of_input_table: bool = False

         def eval(self, *args, **kwargs) -> Iterator:
@@ -1614,13 +1585,13 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:

     # Instantiate the UDTF class.
     try:
-        if len(partition_child_indexes) > 0:
+        if len(udtf_info.partition_child_indexes) > 0:
             # Determine if this is an Arrow UDTF
             is_arrow_udtf = eval_type == PythonEvalType.SQL_ARROW_UDTF
             if is_arrow_udtf:
-                udtf = ArrowUDTFWithPartition(handler, partition_child_indexes)
+                udtf = ArrowUDTFWithPartition(handler, udtf_info.partition_child_indexes)
             else:
-                udtf = UDTFWithPartitions(handler, partition_child_indexes)
+                udtf = UDTFWithPartitions(handler, udtf_info.partition_child_indexes)
         else:
             udtf = handler()
     except Exception as e:
@@ -1640,11 +1611,11 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:
     # Check that the arguments provided to the UDTF call match the expected parameters defined
     # in the 'eval' method signature.
     try:
-        inspect.signature(udtf.eval).bind(*args_offsets, **kwargs_offsets)
+        inspect.signature(udtf.eval).bind(*udtf_info.args, **udtf_info.kwargs)
     except TypeError as e:
         raise PySparkRuntimeError(
             errorClass="UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE",
-            messageParameters={"name": udtf_name, "reason": str(e)},
+            messageParameters={"name": udtf_info.name, "reason": str(e)},
         ) from None

 def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
```
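The signature check kept in the hunk above relies on `inspect.signature(...).bind(...)`, which validates a call shape without actually invoking the function. A standalone illustration of the mechanism:

```python
import inspect


class MyUDTF:
    def eval(self, a, b, *, scale=1):
        yield (a + b) * scale


udtf = MyUDTF()

# Succeeds: the positional/keyword arguments fit eval's signature
# (self is excluded because udtf.eval is a bound method).
inspect.signature(udtf.eval).bind(1, 2, scale=3)

# Fails with TypeError, which the worker surfaces as
# UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE.
try:
    inspect.signature(udtf.eval).bind(1, 2, 3)
except TypeError as e:
    print(e)  # e.g. "too many positional arguments"
```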
```diff
@@ -1848,7 +1819,7 @@ def evaluate(*args: pd.Series, num_rows=1):
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)

@@ -2010,7 +1981,7 @@ def evaluate(*args: list, num_rows=1):
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)

@@ -2136,7 +2107,7 @@ def evaluate(*args: pa.RecordBatch):
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_pyarrow_udtf(eval_func_kwargs_support, return_type)

@@ -2238,7 +2209,7 @@ def evaluate(*a) -> tuple:
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_udtf(eval_func_kwargs_support, return_type)

@@ -2266,7 +2237,7 @@ def mapper(_, it):
         return mapper, None, ser, ser


-def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
+def read_udfs(pickleSer, udf_info_list, eval_type, runner_conf, eval_conf):
     if eval_type in (
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
@@ -2400,13 +2371,13 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
         ser = BatchedSerializer(CPickleSerializer(), batch_size)

-    # Read all UDFs
-    num_udfs = read_int(infile)
     udfs = [
-        read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)
-        for i in range(num_udfs)
+        read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index=udf_index)
+        for udf_index, udf_info in enumerate(udf_info_list)
     ]

+    num_udfs = len(udfs)
+
     def extract_key_value_indexes(grouped_arg_offsets):
         """
         Helper function to extract the key and value indexes from arg_offsets for the grouped and
@@ -3599,52 +3570,49 @@ def func(_, it):
 def main(infile, outfile):
     try:
         boot_time = time.time()
-        split_index = read_int(infile)
-        if split_index == -1:  # for unit tests
-            sys.exit(-1)
+        init_info = WorkerInitInfo.from_stream(infile)
         start_faulthandler_periodic_traceback()
-        check_python_version(infile)
+        check_python_version(init_info.python_version)

         memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)

-        task_context_json = json.loads(utf8_deserializer.loads(infile))
-        if task_context_json["isBarrier"]:
-            taskContext = BarrierTaskContext.from_json(task_context_json)
-        else:
-            taskContext = TaskContext.from_json(task_context_json)
-        TaskContext._setTaskContext(taskContext)
+        TaskContext._setTaskContext(init_info.task_context.to_task_context())

         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0

-        setup_spark_files(infile)
-        setup_broadcasts(infile)
+        setup_spark_files(init_info.spark_files_dir, init_info.python_includes)
+        setup_broadcasts(
+            init_info.broadcast.variables,
+            init_info.broadcast.conn_info,
+            init_info.broadcast.auth_secret,
+        )

         _accumulatorRegistry.clear()
-        eval_type = read_int(infile)
-        runner_conf = RunnerConf(infile)
-        eval_conf = EvalConf(infile)
+        eval_type = init_info.eval_type
+        runner_conf = RunnerConf(init_info.runner_conf)
+        eval_conf = EvalConf(init_info.eval_conf)
         if eval_type == PythonEvalType.NON_UDF:
-            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+            func, profiler, deserializer, serializer = read_command(pickleSer, init_info.udf_info)
         elif eval_type in (
             PythonEvalType.SQL_TABLE_UDF,
             PythonEvalType.SQL_ARROW_TABLE_UDF,
             PythonEvalType.SQL_ARROW_UDTF,
         ):
             func, profiler, deserializer, serializer = read_udtf(
-                pickleSer, infile, eval_type, runner_conf, eval_conf
+                pickleSer, init_info.udf_info, eval_type, runner_conf, eval_conf
             )
         else:
             func, profiler, deserializer, serializer = read_udfs(
-                pickleSer, infile, eval_type, runner_conf, eval_conf
+                pickleSer, init_info.udf_info, eval_type, runner_conf, eval_conf
             )

         init_time = time.time()

         def process():
             iterator = deserializer.load_stream(infile)
-            out_iter = func(split_index, iterator)
+            out_iter = func(init_info.split_index, iterator)
             try:
                 serializer.dump_stream(out_iter, outfile)
             finally:
```

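The `pyspark.worker_message.WorkerInitInfo` class imported in `main` is not shown in this excerpt. As a rough sketch of its shape, inferred purely from the fields the new `main` reads above (field types, defaults, and the `from_stream` body are assumptions, not the actual implementation):

```python
from dataclasses import dataclass, field
from typing import IO, Any, Dict, List, Optional


@dataclass
class BroadcastInfo:
    variables: List[Any] = field(default_factory=list)
    conn_info: Optional[Any] = None   # socket/port details; exact type assumed
    auth_secret: Optional[str] = None


@dataclass
class WorkerInitInfo:
    split_index: int = -1
    python_version: str = ""
    task_context: Any = None          # wraps the decoded TaskContext data
    spark_files_dir: str = ""
    python_includes: List[str] = field(default_factory=list)
    broadcast: BroadcastInfo = field(default_factory=BroadcastInfo)
    eval_type: int = -1
    runner_conf: Dict[str, str] = field(default_factory=dict)
    eval_conf: Dict[str, str] = field(default_factory=dict)
    udf_info: Any = None              # per-UDF payloads: pickled bytes, offsets, result_id, ...

    @classmethod
    def from_stream(cls, infile: IO) -> "WorkerInitInfo":
        # The real implementation drains everything the JVM sends under the
        # current socket protocol into plain Python values (int, str, bytes,
        # None, list, dict), so the rest of the worker never touches the
        # stream during setup. A different protocol only has to produce the
        # same fields.
        raise NotImplementedError("protocol-specific decoding goes here")
```

Keeping the message limited to language-agnostic values is what lets a future protocol "fill in the blanks" without touching any of the processing code.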