Commit 2869eea

Fuzzy dedup modifications (#687)
* Driver broadcast of large spark config variables
* Launch k8s spark cluster using spark native k8s API
* Add PyYAML to the base image dependencies
* Updated documentation for get_bcast_params() method
* Updated documentation

Signed-off-by: Constantin M Adam <[email protected]>
1 parent fa76928 commit 2869eea

File tree

5 files changed: 95 additions, 16 deletions


data-processing-lib/doc/spark-runtime.md

Lines changed: 5 additions & 3 deletions
@@ -41,9 +41,11 @@ of this parameter:
 
 ## Transforms
 
-* [SparkTransformRuntimeConfiguration](../spark/src/data_processing_spark/transform/runtime_configuration.py) allows
-to configure transform to use PySpark
-
+* [SparkTransformRuntimeConfiguration](../spark/src/data_processing_spark/runtime/spark/runtime_configuration.py)
+allows to configure transform to use PySpark. In addition to its base class
+[TransformRuntimeConfiguration](../python//src/data_processing/runtime/runtime_configuration.py) features,
+this class includes `get_bcast_params()` method to get very large configuration settings. Before starting the
+transform execution, the Spark runtime will broadcast these settings to all the workers.
 
 ## Runtime
 

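To make the documented contract concrete, here is a minimal sketch (not part of this commit) of a transform-specific runtime configuration overriding get_bcast_params(). The class name, the sample IDs, and the way the ID list would be obtained are illustrative assumptions; the base class and the create_data_access() call come from the code changed in this commit.

# Hypothetical sketch: a runtime configuration that broadcasts a large list
# of duplicate document IDs to all workers.
from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing_spark.runtime.spark import SparkTransformRuntimeConfiguration


class FuzzyDedupSparkRuntimeConfiguration(SparkTransformRuntimeConfiguration):  # hypothetical name
    def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
        # Called by the Spark runtime after Spark initialization and before
        # spark_context.parallelize(); the returned dict is broadcast to every worker.
        data_access = data_access_factory.create_data_access()
        # In a real transform, data_access would be used here to download the
        # (potentially very large) ID list; a tiny literal stands in for it.
        doc_ids_to_remove = ["doc-0001", "doc-0042"]
        return {"doc_ids_to_remove": set(doc_ids_to_remove)}
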
data-processing-lib/spark/pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@ authors = [
 dependencies = [
     "data-prep-toolkit==0.2.2.dev0",
     "pyspark>=3.5.2",
-    "psutil>=6.0.0"
+    "psutil>=6.0.0",
+    "PyYAML>=6.0.2"
 ]
 
 [project_urls]

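The new PyYAML dependency is what lets the Spark runtime parse the Kubernetes Spark profile (see transform_orchestrator.py below). A small self-contained sketch of the expand-then-parse pattern; the profile content and the environment variable name here are illustrative assumptions:

import os

import yaml

# Illustrative stand-in for a spark_profile.yml that references environment variables.
profile_text = """
spark.app.name: ${SPARK_APP_NAME}
spark.executor.instances: 4
"""

os.environ.setdefault("SPARK_APP_NAME", "fuzzy-dedup")  # placeholder value
# Environment variables are substituted first, then the YAML is parsed.
spark_config = yaml.safe_load(os.path.expandvars(profile_text))
print(spark_config["spark.app.name"])  # -> fuzzy-dedup
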
data-processing-lib/spark/src/data_processing_spark/runtime/spark/runtime_configuration.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,9 @@
 # limitations under the License.
 ################################################################################
 
+from typing import Any
+
+from data_processing.data_access import DataAccessFactoryBase
 from data_processing.runtime import TransformRuntimeConfiguration
 from data_processing.transform import TransformConfiguration
 from data_processing_spark.runtime.spark import DefaultSparkTransformRuntime
@@ -29,6 +32,16 @@ def __init__(
         super().__init__(transform_config=transform_config)
         self.runtime_class = runtime_class
 
+    def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
+        """Allows retrieving and broadcasting to all the workers very large
+        configuration parameters, like the list of document IDs to remove for
+        fuzzy dedup, or the list of blocked web domains for block listing. This
+        function is called by the spark runtime after spark initialization, and
+        before spark_context.parallelize()
+        :param data_access_factory - creates data_access object to download the large config parameter
+        """
+        return {}
+
     def create_transform_runtime(self) -> DefaultSparkTransformRuntime:
         """
         Create transform runtime with the parameters captured during apply_input_params()

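The default implementation returning an empty dict keeps existing transforms working unchanged: the orchestrator (next file) merges the broadcast parameters into the transform parameters with the dict-union operator, and a union with {} is a no-op. A tiny illustration with made-up parameter names; note that the broadcast dict is the right-hand operand, so its keys win on collision:

transform_params = {"partition": 3, "files_to_process": ["part-0.parquet"]}

print(transform_params | {})                                   # no broadcast params: unchanged
print(transform_params | {"doc_ids_to_remove": {"doc-0001"}})  # broadcast params merged in
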
data-processing-lib/spark/src/data_processing_spark/runtime/spark/transform_orchestrator.py

Lines changed: 60 additions & 7 deletions
@@ -10,24 +10,69 @@
 # limitations under the License.
 ################################################################################
 
+import os
+import socket
 import time
 import traceback
 from datetime import datetime
 
+import yaml
 from data_processing.data_access import DataAccessFactoryBase
 from data_processing.transform import TransformStatistics
 from data_processing.utils import GB, get_logger
 from data_processing_spark.runtime.spark import (
+    SparkTransformExecutionConfiguration,
     SparkTransformFileProcessor,
     SparkTransformRuntimeConfiguration,
-    SparkTransformExecutionConfiguration,
 )
 from pyspark import SparkConf, SparkContext
+from pyspark.sql import SparkSession
 
 
 logger = get_logger(__name__)
 
 
+def _init_spark(runtime_config: SparkTransformRuntimeConfiguration) -> SparkSession:
+    server_port_https = int(os.getenv("KUBERNETES_SERVICE_PORT_HTTPS", "-1"))
+    if server_port_https == -1:
+        # running locally
+        spark_config = {"spark.driver.host": "127.0.0.1"}
+        return SparkSession.builder.appName(runtime_config.get_name()).config(map=spark_config).getOrCreate()
+    else:
+        # running in Kubernetes, use spark_profile.yml and
+        # environment variables for configuration
+        server_port = os.environ["KUBERNETES_SERVICE_PORT"]
+        master_url = f"k8s://https://kubernetes.default:{server_port}"
+
+        # Read Spark configuration profile
+        config_filepath = os.path.abspath(
+            os.path.join(os.getenv("SPARK_HOME"), "work-dir", "config", "spark_profile.yml")
+        )
+        with open(config_filepath, "r") as config_fp:
+            spark_config = yaml.safe_load(os.path.expandvars(config_fp.read()))
+        spark_config["spark.submit.deployMode"] = "client"
+
+        # configure the executor pods from template
+        executor_pod_template_file = os.path.join(
+            os.getenv("SPARK_HOME"),
+            "work-dir",
+            "src",
+            "templates",
+            "spark-executor-pod-template.yml",
+        )
+        spark_config["spark.kubernetes.executor.podTemplateFile"] = executor_pod_template_file
+        spark_config["spark.kubernetes.container.image.pullPolicy"] = "Always"
+
+        # Pass the driver IP address to the workers for callback
+        myservice_url = socket.gethostbyname(socket.gethostname())
+        spark_config["spark.driver.host"] = myservice_url
+        spark_config["spark.driver.bindAddress"] = "0.0.0.0"
+        spark_config["spark.decommission.enabled"] = True
+        logger.info(f"Launching Spark Session with configuration\n" f"{yaml.dump(spark_config, indent=2)}")
+        app_name = spark_config.get("spark.app.name", "my-spark-app")
+        return SparkSession.builder.master(master_url).appName(app_name).config(map=spark_config).getOrCreate()
+
+
 def orchestrate(
     runtime_config: SparkTransformRuntimeConfiguration,
     execution_configuration: SparkTransformExecutionConfiguration,
@@ -45,14 +90,17 @@ def orchestrate(
     logger.info(f"orchestrator started at {start_ts}")
     # create data access
     data_access = data_access_factory.create_data_access()
+    bcast_params = runtime_config.get_bcast_params(data_access_factory)
     if data_access is None:
         logger.error("No DataAccess instance provided - exiting")
         return 1
     # initialize Spark
-    conf = SparkConf().setAppName(runtime_config.get_name()).set("spark.driver.host", "127.0.0.1")
-    sc = SparkContext(conf=conf)
+    spark_session = _init_spark(runtime_config)
+    sc = spark_session.sparkContext
+    # broadcast
     spark_runtime_config = sc.broadcast(runtime_config)
     daf = sc.broadcast(data_access_factory)
+    spark_bcast_params = sc.broadcast(bcast_params)
 
     def process_partition(iterator):
         """
@@ -63,6 +111,7 @@ def process_partition(iterator):
         # local statistics dictionary
         statistics = TransformStatistics()
         # create transformer runtime
+        bcast_params = spark_bcast_params.value
         d_access_factory = daf.value
         runtime_conf = spark_runtime_config.value
         runtime = runtime_conf.create_transform_runtime()
@@ -77,8 +126,11 @@ def process_partition(iterator):
                 logger.debug(f"partition {f}")
                 # add additional parameters
                 transform_params = (
-                    runtime.get_transform_config(partition=int(f[1]), data_access_factory=d_access_factory,
-                                                 statistics=statistics))
+                    runtime.get_transform_config(
+                        partition=int(f[1]), data_access_factory=d_access_factory, statistics=statistics
+                    )
+                    | bcast_params
+                )
                 # create transform with partition number
                 file_processor.create_transform(transform_params)
                 first = False
@@ -128,7 +180,7 @@ def process_partition(iterator):
     memory = 0.0
     for i in range(executors.size()):
         memory += executors.toList().apply(i)._2()._1()
-    resources = {"cpus": cpus, "gpus": 0, "memory": round(memory/GB, 2), "object_store": 0}
+    resources = {"cpus": cpus, "gpus": 0, "memory": round(memory / GB, 2), "object_store": 0}
     input_params = runtime_config.get_transform_metadata() | execution_configuration.get_input_params()
     metadata = {
         "pipeline": execution_configuration.pipeline_id,
@@ -143,7 +195,8 @@ def process_partition(iterator):
         "execution_stats": {
            "num partitions": num_partitions,
            "execution time, min": round((time.time() - start_time) / 60, 3),
-        } | resources,
+        }
+        | resources,
         "job_output_stats": stats,
     }
     logger.debug(f"Saving job metadata: {metadata}.")

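A self-contained local sketch of the broadcast pattern the orchestrator now follows: the driver broadcasts the large parameters once, and each partition reads them back with .value before building its transform parameters. The payload and "file" names are made up; the config(map=...) call mirrors the one used in _init_spark above.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("bcast-sketch").config(
    map={"spark.driver.host": "127.0.0.1"}
).getOrCreate()
sc = spark.sparkContext

bcast_params = {"doc_ids_to_remove": {"doc-0001", "doc-0042"}}  # stand-in payload
spark_bcast_params = sc.broadcast(bcast_params)


def process_partition(iterator):
    # Each worker reads the broadcast value locally instead of re-fetching it.
    params = spark_bcast_params.value
    return [(f, len(params["doc_ids_to_remove"])) for f in iterator]


print(sc.parallelize(["file-a", "file-b"], 2).mapPartitions(process_partition).collect())
spark.stop()
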
data-processing-lib/spark/src/data_processing_spark/runtime/spark/transform_runtime.py

Lines changed: 15 additions & 5 deletions
@@ -34,19 +34,29 @@ def get_transform_config(
         """
         Get the dictionary of configuration that will be provided to the transform's initializer.
         This is the opportunity for this runtime to create a new set of configuration based on the
-        config/params provided to this instance's initializer. This may include the addition
-        of new configuration data such as ray shared memory, new actors, etc, that might be needed and
-        expected by the transform in its initializer and/or transform() methods.
+        config/params provided to this instance's initializer.
+        :param partition - the partition assigned to this worker, needed by transforms like doc_id
         :param data_access_factory - data access factory class being used by the RayOrchestrator.
         :param statistics - reference to statistics actor
         :return: dictionary of transform init params
         """
         return self.params
 
+    def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
+        """Allows retrieving and broadcasting to all the workers very large
+        configuration parameters, like the list of document IDs to remove for
+        fuzzy dedup, or the list of blocked web domains for block listing. This
+        function is called by the spark runtime after spark initialization, and
+        before spark_context.parallelize()
+        :param data_access_factory - creates data_access object to download the large config parameter
+        """
+        return {}
+
     def compute_execution_stats(self, stats: TransformStatistics) -> None:
         """
         Update/augment the given statistics object with runtime-specific additions/modifications.
+        This method does not return a value; the job execution statistics are generally reported
+        as metadata by the Spark Orchestrator.
         :param stats: output of statistics as aggregated across all calls to all transforms.
-        :return: job execution statistics. These are generally reported as metadata by the Ray Orchestrator.
         """
-        pass
+        pass

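For completeness, a hypothetical sketch of a runtime subclass using the partition parameter documented in get_transform_config() above. The subclass name and the partition_index key are assumptions; the base class, the parameter names, and self.params are taken from the diff, and this is only one plausible way a doc_id-style transform might use them.

# Hypothetical sketch: a runtime that forwards the Spark partition number to
# its transform, as described for transforms like doc_id above.
from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing.transform import TransformStatistics
from data_processing_spark.runtime.spark import DefaultSparkTransformRuntime


class DocIdSparkTransformRuntime(DefaultSparkTransformRuntime):  # hypothetical name
    def get_transform_config(
        self, partition: int, data_access_factory: DataAccessFactoryBase, statistics: TransformStatistics
    ) -> dict[str, Any]:
        # Add the worker's partition number so the transform can mint unique IDs.
        return self.params | {"partition_index": partition}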