@@ -57,8 +57,6 @@
 from pyspark.util import PythonEvalType
 from pyspark.serializers import (
     write_int,
-    read_long,
-    read_bool,
     write_long,
     read_int,
     SpecialLengths,
@@ -109,6 +107,7 @@
 )
 from pyspark import _NoValue, shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
+from pyspark.worker_message import WorkerInitInfo
 from pyspark.worker_util import (
     check_python_version,
     get_sock_file_to_executor,
@@ -118,7 +117,6 @@
     setup_broadcasts,
     setup_memory_limits,
     setup_spark_files,
-    utf8_deserializer,
     Conf,
 )
 from pyspark.logger.worker_io import capture_outputs
@@ -981,39 +979,25 @@ def profiling_func(*args, **kwargs):
     return profiling_func


-def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
-    num_arg = read_int(infile)
-
-    args_offsets = []
-    kwargs_offsets = {}
-    for _ in range(num_arg):
-        offset = read_int(infile)
-        if read_bool(infile):
-            name = utf8_deserializer.loads(infile)
-            kwargs_offsets[name] = offset
-        else:
-            args_offsets.append(offset)
-
+def read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index):
     chained_func = None
-    for i in range(read_int(infile)):
-        f, return_type = read_command(pickleSer, infile)
+    for udf in udf_info.udfs:
+        f, return_type = read_command(pickleSer, udf)
         if chained_func is None:
             chained_func = f
         else:
             chained_func = chain(chained_func, f)

-    result_id = read_long(infile)
-
     # If chained_func is from pyspark.sql.worker, it is to read/write data source.
     # In this case, we check the data_source_profiler config.
     if getattr(chained_func, "__module__", "").startswith("pyspark.sql.worker."):
         profiler = runner_conf.data_source_profiler
     else:
         profiler = runner_conf.udf_profiler
     if profiler == "perf":
-        profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
+        profiling_func = wrap_perf_profiler(chained_func, eval_type, udf_info.result_id)
     elif profiler == "memory":
-        profiling_func = wrap_memory_profiler(chained_func, eval_type, result_id)
+        profiling_func = wrap_memory_profiler(chained_func, eval_type, udf_info.result_id)
     else:
         profiling_func = chained_func
@@ -1027,6 +1011,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     # when they are processed in a for loop, raise them as RuntimeError's instead
     func = fail_on_stopiteration(profiling_func)

+    args_offsets, kwargs_offsets = udf_info.args, udf_info.kwargs
+
    # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return func, args_offsets, kwargs_offsets, return_type
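Note: read_single_udf now takes its argument offsets, chained commands, and profiler result id from a single udf_info message instead of reading them off the stream one field at a time. A minimal sketch of the per-UDF payload this implies; the field names are inferred from the accesses above, and the actual pyspark.worker_message definition is not part of this diff:

    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass
    class UDFInfo:  # hypothetical shape, for illustration only
        udfs: List[bytes]  # one pickled command per chained function
        result_id: int  # accumulator id used by the perf/memory profilers
        args: List[int] = field(default_factory=list)  # positional column offsets
        kwargs: Dict[str, int] = field(default_factory=dict)  # keyword name -> offset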
@@ -1105,7 +1091,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
 # It expects the UDTF to be in a specific format and performs various checks to
 # ensure the UDTF is valid. This function also prepares a mapper function for applying
 # the UDTF logic to input rows.
-def read_udtf(pickleSer, infile, eval_type, runner_conf, eval_conf):
+def read_udtf(pickleSer, udtf_info, eval_type, runner_conf, eval_conf):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
         if runner_conf.use_legacy_pandas_udtf_conversion:
             # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
@@ -1126,50 +1112,35 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf, eval_conf):
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)

-    # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
-    num_arg = read_int(infile)
-    args_offsets = []
-    kwargs_offsets = {}
-    for _ in range(num_arg):
-        offset = read_int(infile)
-        if read_bool(infile):
-            name = utf8_deserializer.loads(infile)
-            kwargs_offsets[name] = offset
-        else:
-            args_offsets.append(offset)
-    num_partition_child_indexes = read_int(infile)
-    partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
-    has_pickled_analyze_result = read_bool(infile)
-    if has_pickled_analyze_result:
-        pickled_analyze_result = pickleSer._read_with_length(infile)
+    if udtf_info.pickled_analyze_result is not None:
+        pickled_analyze_result = pickleSer.loads(udtf_info.pickled_analyze_result)
     else:
         pickled_analyze_result = None
     # Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult,
     # although we may set this to false later if we find otherwise.
-    handler = read_command(pickleSer, infile)
+    handler = read_command(pickleSer, udtf_info.handler)
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
             f"Invalid UDTF handler type. Expected a class (type 'type'), but "
             f"got an instance of {type(handler).__name__}."
         )

-    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+    return_type = _parse_datatype_json_string(udtf_info.return_type)
     if not isinstance(return_type, StructType):
         raise PySparkRuntimeError(
             f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
         )
-    udtf_name = utf8_deserializer.loads(infile)

     # Update the handler that creates a new UDTF instance to first try calling the UDTF constructor
     # with one argument containing the previous AnalyzeResult. If that fails, then try a constructor
     # with no arguments. In this way each UDTF class instance can decide if it wants to inspect the
     # AnalyzeResult.
     udtf_init_args = inspect.getfullargspec(handler)
-    if has_pickled_analyze_result:
+    if pickled_analyze_result is not None:
         if len(udtf_init_args.args) > 2:
             raise PySparkRuntimeError(
                 errorClass="UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD",
-                messageParameters={"name": udtf_name},
+                messageParameters={"name": udtf_info.name},
             )
         elif len(udtf_init_args.args) == 2:
             prev_handler = handler
@@ -1182,7 +1153,7 @@ def construct_udtf():
     elif len(udtf_init_args.args) > 1:
         raise PySparkRuntimeError(
             errorClass="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD",
-            messageParameters={"name": udtf_name},
+            messageParameters={"name": udtf_info.name},
         )

     class UDTFWithPartitions:
@@ -1220,7 +1191,7 @@ def __init__(self, create_udtf: Callable, partition_child_indexes: list):
             self._create_udtf: Callable = create_udtf
             self._udtf = create_udtf()
             self._prev_arguments: list = list()
-            self._partition_child_indexes: list = partition_child_indexes
+            self._partition_child_indexes: list = udtf_info.partition_child_indexes
             self._eval_raised_skip_rest_of_input_table: bool = False

         def eval(self, *args, **kwargs) -> Iterator:
@@ -1614,13 +1585,13 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:

     # Instantiate the UDTF class.
     try:
-        if len(partition_child_indexes) > 0:
+        if len(udtf_info.partition_child_indexes) > 0:
             # Determine if this is an Arrow UDTF
             is_arrow_udtf = eval_type == PythonEvalType.SQL_ARROW_UDTF
             if is_arrow_udtf:
-                udtf = ArrowUDTFWithPartition(handler, partition_child_indexes)
+                udtf = ArrowUDTFWithPartition(handler, udtf_info.partition_child_indexes)
             else:
-                udtf = UDTFWithPartitions(handler, partition_child_indexes)
+                udtf = UDTFWithPartitions(handler, udtf_info.partition_child_indexes)
         else:
             udtf = handler()
     except Exception as e:
@@ -1640,11 +1611,11 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:
     # Check that the arguments provided to the UDTF call match the expected parameters defined
     # in the 'eval' method signature.
     try:
-        inspect.signature(udtf.eval).bind(*args_offsets, **kwargs_offsets)
+        inspect.signature(udtf.eval).bind(*udtf_info.args, **udtf_info.kwargs)
     except TypeError as e:
         raise PySparkRuntimeError(
             errorClass="UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE",
-            messageParameters={"name": udtf_name, "reason": str(e)},
+            messageParameters={"name": udtf_info.name, "reason": str(e)},
         ) from None

     def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
@@ -1848,7 +1819,7 @@ def evaluate(*args: pd.Series, num_rows=1):
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)

@@ -2010,7 +1981,7 @@ def evaluate(*args: list, num_rows=1):
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)

@@ -2136,7 +2107,7 @@ def evaluate(*args: pa.RecordBatch):
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_pyarrow_udtf(eval_func_kwargs_support, return_type)

@@ -2238,7 +2209,7 @@ def evaluate(*a) -> tuple:
         return evaluate

     eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
-        getattr(udtf, "eval"), args_offsets, kwargs_offsets
+        getattr(udtf, "eval"), udtf_info.args, udtf_info.kwargs
     )
     eval = wrap_udtf(eval_func_kwargs_support, return_type)

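Note: the wrap_kwargs_support(...) pairing repeated across the four eval flavors above rebinds keyword offsets so the runner can pass every column positionally. A rough sketch of the idea, using a hypothetical helper name (the real pyspark implementation is not shown in this diff):

    def wrap_kwargs_support_sketch(f, args_offsets, kwargs_offsets):
        # Flatten name -> offset keywords so callers can pass columns positionally.
        if not kwargs_offsets:
            return f, args_offsets
        names = list(kwargs_offsets)

        def wrapped(*cols):
            pos = cols[: len(args_offsets)]  # positional columns come first
            kw = dict(zip(names, cols[len(args_offsets):]))  # the rest map to keywords
            return f(*pos, **kw)

        return wrapped, args_offsets + [kwargs_offsets[n] for n in names]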
@@ -2266,7 +2237,7 @@ def mapper(_, it):
     return mapper, None, ser, ser


-def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
+def read_udfs(pickleSer, udf_info_list, eval_type, runner_conf, eval_conf):
     if eval_type in (
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_UDF,
@@ -2400,13 +2371,13 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
     batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
     ser = BatchedSerializer(CPickleSerializer(), batch_size)

-    # Read all UDFs
-    num_udfs = read_int(infile)
     udfs = [
-        read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)
-        for i in range(num_udfs)
+        read_single_udf(pickleSer, udf_info, eval_type, runner_conf, udf_index=udf_index)
+        for udf_index, udf_info in enumerate(udf_info_list)
     ]

+    num_udfs = len(udfs)
+
     def extract_key_value_indexes(grouped_arg_offsets):
         """
         Helper function to extract the key and value indexes from arg_offsets for the grouped and
@@ -3599,52 +3570,49 @@ def func(_, it):
 def main(infile, outfile):
     try:
         boot_time = time.time()
-        split_index = read_int(infile)
-        if split_index == -1:  # for unit tests
-            sys.exit(-1)
+        init_info = WorkerInitInfo.from_stream(infile)
         start_faulthandler_periodic_traceback()
-        check_python_version(infile)
+        check_python_version(init_info.python_version)

         memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)

-        task_context_json = json.loads(utf8_deserializer.loads(infile))
-        if task_context_json["isBarrier"]:
-            taskContext = BarrierTaskContext.from_json(task_context_json)
-        else:
-            taskContext = TaskContext.from_json(task_context_json)
-        TaskContext._setTaskContext(taskContext)
+        TaskContext._setTaskContext(init_info.task_context.to_task_context())

         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0

-        setup_spark_files(infile)
-        setup_broadcasts(infile)
+        setup_spark_files(init_info.spark_files_dir, init_info.python_includes)
+        setup_broadcasts(
+            init_info.broadcast.variables,
+            init_info.broadcast.conn_info,
+            init_info.broadcast.auth_secret,
+        )

         _accumulatorRegistry.clear()
-        eval_type = read_int(infile)
-        runner_conf = RunnerConf(infile)
-        eval_conf = EvalConf(infile)
+        eval_type = init_info.eval_type
+        runner_conf = RunnerConf(init_info.runner_conf)
+        eval_conf = EvalConf(init_info.eval_conf)
         if eval_type == PythonEvalType.NON_UDF:
-            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+            func, profiler, deserializer, serializer = read_command(pickleSer, init_info.udf_info)
         elif eval_type in (
             PythonEvalType.SQL_TABLE_UDF,
             PythonEvalType.SQL_ARROW_TABLE_UDF,
             PythonEvalType.SQL_ARROW_UDTF,
         ):
             func, profiler, deserializer, serializer = read_udtf(
-                pickleSer, infile, eval_type, runner_conf, eval_conf
+                pickleSer, init_info.udf_info, eval_type, runner_conf, eval_conf
             )
         else:
             func, profiler, deserializer, serializer = read_udfs(
-                pickleSer, infile, eval_type, runner_conf, eval_conf
+                pickleSer, init_info.udf_info, eval_type, runner_conf, eval_conf
             )

         init_time = time.time()

         def process():
             iterator = deserializer.load_stream(infile)
-            out_iter = func(split_index, iterator)
+            out_iter = func(init_info.split_index, iterator)
             try:
                 serializer.dump_stream(out_iter, outfile)
             finally:
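Note: the overall shape of the change is that main() no longer pulls the split index, Python version, task context, Spark files, broadcasts, and configs off the socket one read at a time; it receives them in a single structured init message. A minimal sketch of the interface main() relies on, with attribute names inferred from the accesses above (the real WorkerInitInfo in pyspark.worker_message may differ):

    from dataclasses import dataclass
    from typing import Any, List

    @dataclass
    class WorkerInitInfoSketch:  # hypothetical stand-in for pyspark.worker_message.WorkerInitInfo
        split_index: int
        python_version: str
        task_context: Any  # exposes to_task_context()
        spark_files_dir: str
        python_includes: List[str]
        broadcast: Any  # carries .variables, .conn_info, .auth_secret
        eval_type: int
        runner_conf: Any
        eval_conf: Any
        udf_info: Any  # per-eval-type payload consumed by read_udfs/read_udtf/read_command

        @classmethod
        def from_stream(cls, infile) -> "WorkerInitInfoSketch":
            # The real implementation deserializes every field from the worker's
            # input stream up front; omitted here.
            raise NotImplementedError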