[FSTORE-1667] Feature store client doesn't work on Databricks with BYOK setup #482

Open
Wants to merge 10 commits into base: main
18 changes: 8 additions & 10 deletions python/hsfs/engine/spark.py
@@ -19,7 +19,6 @@
import json
import os
import re
-import shutil
import uuid
import warnings
from datetime import date, datetime, timezone
@@ -1121,22 +1120,18 @@ def add_file(self, file):

# for external clients, download the file
if client._is_external():
-tmp_file = os.path.join(SparkFiles.getRootDirectory(), file_name)
+tmp_file = f"/tmp/{file_name}"
print("Reading key file from storage connector.")
response = self._dataset_api.read_content(file, util.get_dataset_type(file))

with open(tmp_file, "wb") as f:
f.write(response.content)
-else:
-self._spark_context.addFile(file)

-# The file is not added to the driver current working directory
-# We should add it manually by copying from the download location
-# The file will be added to the executors current working directory
-# before the next task is executed
-shutil.copy(SparkFiles.get(file_name), file_name)
+file = f"file://{tmp_file}"

+self._spark_context.addFile(file)

-return file_name
+return SparkFiles.get(file_name)

def profile(
self,
@@ -1681,6 +1676,9 @@ def read_feature_log(query, time_col):
df = query.read()
return df.drop("log_id", time_col)

+def get_spark_version(self):
+return self._spark_session.version


class SchemaError(Exception):
"""Thrown when schemas don't match"""
53 changes: 49 additions & 4 deletions python/hsfs/storage_connector.py
@@ -1304,16 +1304,61 @@ def confluent_options(self) -> Dict[str, Any]:

return config

+def _read_pem(self, file_name):
+with open(file_name, "r") as file:
+return file.read()

def spark_options(self) -> Dict[str, Any]:
"""Return prepared options to be passed to Spark, based on the additional arguments.
This is done by just adding 'kafka.' prefix to kafka_options.
https://spark.apache.org/docs/latest/structured-streaming-kafka-integration.html#kafka-specific-configurations
"""
-config = {}
-for key, value in self.kafka_options().items():
-config[f"{KafkaConnector.SPARK_FORMAT}.{key}"] = value
+from packaging import version

-return config
+spark_config = {}

+kafka_options = self.kafka_options()

+for key, value in kafka_options.items():
+if key in [
+"ssl.truststore.location",
+"ssl.truststore.password",
+"ssl.keystore.location",
+"ssl.keystore.password",
+"ssl.key.password",
+] and version.parse(
+engine.get_instance().get_spark_version()
+) >= version.parse("3.2.0"):
+# We can only use this in newer versions of Spark, which depend on Kafka > 2.7.0.
+# Kafka 2.7.0 adds support for providing the SSL credentials as PEM objects.
+if not self._pem_files_created:
+(
+ca_chain_path,
+client_cert_path,
+client_key_path,
+) = client.get_instance()._write_pem(
+kafka_options["ssl.keystore.location"],
+kafka_options["ssl.keystore.password"],
+kafka_options["ssl.truststore.location"],
+kafka_options["ssl.truststore.password"],
+f"kafka_sc_{client.get_instance()._project_id}_{self._id}",
+)
+self._pem_files_created = True
+spark_config["kafka.ssl.truststore.certificates"] = self._read_pem(
+ca_chain_path
+)
+spark_config["kafka.ssl.keystore.certificate.chain"] = (
+self._read_pem(client_cert_path)
+)
+spark_config["kafka.ssl.keystore.key"] = self._read_pem(
+client_key_path
+)
+spark_config["kafka.ssl.truststore.type"] = "PEM"
+spark_config["kafka.ssl.keystore.type"] = "PEM"
+else:
+spark_config[f"{KafkaConnector.SPARK_FORMAT}.{key}"] = value

+return spark_config

def read(
self,
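
For context, the options built here are handed straight to Spark's Kafka source, and the version gate reflects that PEM credentials (ssl.keystore.type=PEM and the in-line certificate/key options) require the Kafka >= 2.7 client bundled with Spark >= 3.2. The sketch below is not part of this PR; it only shows how such a PEM-based configuration could be passed to a structured-streaming read. The broker address, topic, and /tmp/*.pem paths are hypothetical; the option names follow the diff and the Spark Kafka integration docs.

from packaging import version
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Same gate as in spark_options(): PEM credentials need the Kafka >= 2.7 client
# bundled with Spark >= 3.2.
assert version.parse(spark.version) >= version.parse("3.2.0")


def read_pem(path: str) -> str:
    with open(path, "r") as f:
        return f.read()


options = {
    "kafka.bootstrap.servers": "broker.example.com:9093",
    "kafka.security.protocol": "SSL",
    "kafka.ssl.truststore.type": "PEM",
    "kafka.ssl.keystore.type": "PEM",
    "kafka.ssl.truststore.certificates": read_pem("/tmp/ca_chain.pem"),
    "kafka.ssl.keystore.certificate.chain": read_pem("/tmp/client_cert.pem"),
    "kafka.ssl.keystore.key": read_pem("/tmp/client_key.pem"),
}

df = (
    spark.readStream.format("kafka")
    .options(**options)
    .option("subscribe", "my_topic")
    .load()
)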
3 changes: 3 additions & 0 deletions python/tests/core/test_kafka_engine.py
@@ -414,6 +414,7 @@ def test_spark_get_kafka_config(self, mocker, backend_fixtures):
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]
sc = storage_connector.StorageConnector.from_response_json(json)
mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

mocker.patch("hopsworks_common.client._is_external", return_value=False)
# Act
@@ -456,6 +457,7 @@ def test_spark_get_kafka_config_external_client(self, mocker, backend_fixtures):
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]
sc = storage_connector.StorageConnector.from_response_json(json)
mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

# Act
results = kafka_engine.get_kafka_config(
@@ -497,6 +499,7 @@ def test_spark_get_kafka_config_internal_kafka(self, mocker, backend_fixtures):
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]
sc = storage_connector.StorageConnector.from_response_json(json)
mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

# Act
results = kafka_engine.get_kafka_config(
9 changes: 8 additions & 1 deletion python/tests/engine/test_spark.py
@@ -866,6 +866,7 @@ def test_save_stream_dataframe(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -993,6 +994,7 @@ def test_save_stream_dataframe_query_name(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -1124,6 +1126,7 @@ def test_save_stream_dataframe_checkpoint_dir(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -1251,6 +1254,7 @@ def test_save_stream_dataframe_await_termination(self, mocker, backend_fixtures)
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -1515,6 +1519,7 @@ def test_save_online_dataframe(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -3221,8 +3226,10 @@ def test_read_location_format_tsv(self, mocker):

def test_read_stream(self, mocker):
# Arrange
mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mocker.patch("hopsworks_common.client.get_instance")
+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

mock_pyspark_getOrCreate = mocker.patch(
"pyspark.sql.session.SparkSession.builder.getOrCreate"
)
52 changes: 51 additions & 1 deletion python/tests/test_storage_connector.py
@@ -618,10 +618,12 @@ def test_kafka_options_external(self, mocker, backend_fixtures):

def test_spark_options(self, mocker, backend_fixtures):
# Arrange
mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_client_get_instance = mocker.patch("hopsworks_common.client.get_instance")
json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"]

+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

mock_client_get_instance.return_value._get_jks_trust_store_path.return_value = (
"result_from_get_jks_trust_store_path"
)
Expand Down Expand Up @@ -653,6 +655,7 @@ def test_spark_options_external(self, mocker, backend_fixtures):
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]

+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -675,6 +678,53 @@
"kafka.ssl.key.password": "test_ssl_key_password",
}

+def test_spark_options_spark_35(self, mocker, backend_fixtures):
+# Arrange
+mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
+mock_client_get_instance = mocker.patch("hopsworks_common.client.get_instance")
+json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"]

+mock_engine_get_instance.return_value.get_spark_version.return_value = "3.5.0"

+mock_client_get_instance.return_value._get_jks_trust_store_path.return_value = (
+"result_from_get_jks_trust_store_path"
+)
+mock_client_get_instance.return_value._get_jks_key_store_path.return_value = (
+"result_from_get_jks_key_store_path"
+)
+mock_client_get_instance.return_value._cert_key = "result_from_cert_key"
+mock_client_get_instance.return_value._write_pem.return_value = (
+None,
+None,
+None,
+)

+sc = storage_connector.StorageConnector.from_response_json(json)

+# Mock the read pem method in the storage connector itself
+sc._read_pem = mocker.Mock()
+sc._read_pem.side_effect = [
+"test_ssl_ca",
+"test_ssl_certificate",
+"test_ssl_key",
+]

+# Act
+config = sc.spark_options()

+# Assert
+assert config == {
+"kafka.test_option_name": "test_option_value",
+"kafka.bootstrap.servers": "test_bootstrap_servers",
+"kafka.security.protocol": "test_security_protocol",
+"kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm",
+"kafka.ssl.truststore.type": "PEM",
+"kafka.ssl.keystore.type": "PEM",
+"kafka.ssl.truststore.certificates": "test_ssl_ca",
+"kafka.ssl.keystore.certificate.chain": "test_ssl_certificate",
+"kafka.ssl.keystore.key": "test_ssl_key",
+}

def test_confluent_options(self, mocker, backend_fixtures):
# Arrange
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")