# limitations under the License.
################################################################################

+ import os
+ import socket
import time
import traceback
from datetime import datetime

+ import yaml
from data_processing.data_access import DataAccessFactoryBase
from data_processing.transform import TransformStatistics
from data_processing.utils import GB, get_logger
from data_processing_spark.runtime.spark import (
+     SparkTransformExecutionConfiguration,
    SparkTransformFileProcessor,
    SparkTransformRuntimeConfiguration,
-     SparkTransformExecutionConfiguration,
)
from pyspark import SparkConf, SparkContext
+ from pyspark.sql import SparkSession


logger = get_logger(__name__)


+ def _init_spark(runtime_config: SparkTransformRuntimeConfiguration) -> SparkSession:
+     server_port_https = int(os.getenv("KUBERNETES_SERVICE_PORT_HTTPS", "-1"))
+     if server_port_https == -1:
+         # running locally
+         spark_config = {"spark.driver.host": "127.0.0.1"}
+         return SparkSession.builder.appName(runtime_config.get_name()).config(map=spark_config).getOrCreate()
+     else:
+         # running in Kubernetes, use spark_profile.yml and
+         # environment variables for configuration
+         server_port = os.environ["KUBERNETES_SERVICE_PORT"]
+         master_url = f"k8s://https://kubernetes.default:{server_port}"
+
+         # Read Spark configuration profile
+         config_filepath = os.path.abspath(
+             os.path.join(os.getenv("SPARK_HOME"), "work-dir", "config", "spark_profile.yml")
+         )
+         with open(config_filepath, "r") as config_fp:
+             spark_config = yaml.safe_load(os.path.expandvars(config_fp.read()))
+         spark_config["spark.submit.deployMode"] = "client"
+
+         # configure the executor pods from template
+         executor_pod_template_file = os.path.join(
+             os.getenv("SPARK_HOME"),
+             "work-dir",
+             "src",
+             "templates",
+             "spark-executor-pod-template.yml",
+         )
+         spark_config["spark.kubernetes.executor.podTemplateFile"] = executor_pod_template_file
+         spark_config["spark.kubernetes.container.image.pullPolicy"] = "Always"
+
+         # Pass the driver IP address to the workers for callback
+         myservice_url = socket.gethostbyname(socket.gethostname())
+         spark_config["spark.driver.host"] = myservice_url
+         spark_config["spark.driver.bindAddress"] = "0.0.0.0"
+         spark_config["spark.decommission.enabled"] = True
+         logger.info(f"Launching Spark Session with configuration\n{yaml.dump(spark_config, indent=2)}")
+         app_name = spark_config.get("spark.app.name", "my-spark-app")
+         return SparkSession.builder.master(master_url).appName(app_name).config(map=spark_config).getOrCreate()
+
+
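For context on the profile handling above: os.path.expandvars substitutes ${VAR} placeholders in the raw file text with values from the pod environment, and yaml.safe_load then parses the result into the plain dict handed to the SparkSession builder. A minimal sketch follows; the profile keys and environment variable names are illustrative assumptions, not the contents of the repository's spark_profile.yml.

import os
import yaml

# Hypothetical profile text; the real spark_profile.yml is mounted under
# $SPARK_HOME/work-dir/config by the deployment and may differ.
profile_text = """
spark.app.name: my-transform
spark.executor.instances: ${NUM_EXECUTORS}
spark.executor.memory: 4g
"""

os.environ.setdefault("NUM_EXECUTORS", "4")

# Expand environment variables first, then parse the YAML into a dict of
# Spark properties, the shape accepted by SparkSession.builder.config(map=...).
spark_config = yaml.safe_load(os.path.expandvars(profile_text))
print(spark_config)
# {'spark.app.name': 'my-transform', 'spark.executor.instances': 4, 'spark.executor.memory': '4g'}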
def orchestrate(
    runtime_config: SparkTransformRuntimeConfiguration,
    execution_configuration: SparkTransformExecutionConfiguration,
@@ -45,14 +90,17 @@ def orchestrate(
    logger.info(f"orchestrator started at {start_ts}")
    # create data access
    data_access = data_access_factory.create_data_access()
+     bcast_params = runtime_config.get_bcast_params(data_access_factory)
    if data_access is None:
        logger.error("No DataAccess instance provided - exiting")
        return 1
    # initialize Spark
-     conf = SparkConf().setAppName(runtime_config.get_name()).set("spark.driver.host", "127.0.0.1")
-     sc = SparkContext(conf=conf)
+     spark_session = _init_spark(runtime_config)
+     sc = spark_session.sparkContext
+     # broadcast the configuration, data access factory and runtime parameters to the workers
    spark_runtime_config = sc.broadcast(runtime_config)
    daf = sc.broadcast(data_access_factory)
+     spark_bcast_params = sc.broadcast(bcast_params)

    def process_partition(iterator):
        """
@@ -63,6 +111,7 @@ def process_partition(iterator):
        # local statistics dictionary
        statistics = TransformStatistics()
        # create transformer runtime
+         bcast_params = spark_bcast_params.value
        d_access_factory = daf.value
        runtime_conf = spark_runtime_config.value
        runtime = runtime_conf.create_transform_runtime()
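A self-contained sketch of the broadcast pattern added by this change: the driver wraps read-only objects with sc.broadcast once, and each partition function reads them back through .value instead of capturing them in the task closure. Names and values here are illustrative.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("broadcast-demo").getOrCreate()
sc = spark.sparkContext

# Driver side: register the shared, read-only parameters once.
shared_params = sc.broadcast({"pipeline": "demo-pipeline", "job_id": "42"})

def tag_items(iterator):
    # Executor side: read the broadcast value inside the partition function.
    params = shared_params.value
    for item in iterator:
        yield f"{params['pipeline']}/{item}"

print(sc.parallelize([1, 2, 3], numSlices=2).mapPartitions(tag_items).collect())
# ['demo-pipeline/1', 'demo-pipeline/2', 'demo-pipeline/3']
spark.stop()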
@@ -77,8 +126,11 @@ def process_partition(iterator):
                logger.debug(f"partition {f}")
                # add additional parameters
                transform_params = (
-                     runtime.get_transform_config(partition=int(f[1]), data_access_factory=d_access_factory,
-                                                  statistics=statistics))
+                     runtime.get_transform_config(
+                         partition=int(f[1]), data_access_factory=d_access_factory, statistics=statistics
+                     )
+                     | bcast_params
+                 )
                # create transform with partition number
                file_processor.create_transform(transform_params)
                first = False
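The | in the reworked transform_params expression is Python's dict union operator (3.9+): the per-partition configuration returned by get_transform_config is merged with the broadcast parameters, and on a key collision the right-hand operand, bcast_params, wins. A small standalone illustration with made-up keys:

# Hypothetical per-partition and broadcast dictionaries.
partition_config = {"partition": 3, "batch_size": 100}
bcast_params = {"model_path": "/models/base", "batch_size": 500}

# dict union: the right-hand side takes precedence on duplicate keys.
transform_params = partition_config | bcast_params
print(transform_params)
# {'partition': 3, 'batch_size': 500, 'model_path': '/models/base'}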
@@ -128,7 +180,7 @@ def process_partition(iterator):
        memory = 0.0
        for i in range(executors.size()):
            memory += executors.toList().apply(i)._2()._1()
-         resources = {"cpus": cpus, "gpus": 0, "memory": round(memory / GB, 2), "object_store": 0}
+         resources = {"cpus": cpus, "gpus": 0, "memory": round(memory / GB, 2), "object_store": 0}
        input_params = runtime_config.get_transform_metadata() | execution_configuration.get_input_params()
        metadata = {
            "pipeline": execution_configuration.pipeline_id,
@@ -143,7 +195,8 @@ def process_partition(iterator):
            "execution_stats": {
                "num partitions": num_partitions,
                "execution time, min": round((time.time() - start_time) / 60, 3),
-             } | resources,
+             }
+             | resources,
            "job_output_stats": stats,
        }
        logger.debug(f"Saving job metadata: {metadata}.")