diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py index 15d34a9..8ba974b 100644 --- a/src/datacustomcode/client.py +++ b/src/datacustomcode/client.py @@ -21,11 +21,10 @@ Optional, ) -from pyspark.sql import SparkSession - -from datacustomcode.config import SparkConfig, config +from datacustomcode.config import config from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.io.reader.base import BaseDataCloudReader +from datacustomcode.spark.default import DefaultSparkSessionProvider if TYPE_CHECKING: from pathlib import Path @@ -34,18 +33,7 @@ from datacustomcode.io.reader.base import BaseDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode - - -def _setup_spark(spark_config: SparkConfig) -> SparkSession: - """Setup Spark session from config.""" - builder = SparkSession.builder - if spark_config.master is not None: - builder = builder.master(spark_config.master) - - builder = builder.appName(spark_config.app_name) - for key, value in spark_config.options.items(): - builder = builder.config(key, value) - return builder.getOrCreate() + from datacustomcode.spark.base import BaseSparkSessionProvider class DataCloudObjectType(Enum): @@ -123,7 +111,8 @@ class Client: def __new__( cls, reader: Optional[BaseDataCloudReader] = None, - writer: Optional[BaseDataCloudWriter] = None, + writer: Optional["BaseDataCloudWriter"] = None, + spark_provider: Optional["BaseSparkSessionProvider"] = None, ) -> Client: if cls._instance is None: cls._instance = super().__new__(cls) @@ -136,7 +125,16 @@ def __new__( raise ValueError( "Spark config is required when reader/writer is not provided" ) - spark = _setup_spark(config.spark_config) + + provider: BaseSparkSessionProvider + if spark_provider is not None: + provider = spark_provider + elif config.spark_provider_config is not None: + provider = config.spark_provider_config.to_object() + else: + provider = DefaultSparkSessionProvider() + + spark = provider.get_session(config.spark_config) if config.reader_config is None and reader is None: raise ValueError( diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index a4d35db..8e18551 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -38,6 +38,7 @@ from datacustomcode.io.base import BaseDataAccessLayer from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001 from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001 +from datacustomcode.spark.base import BaseSparkSessionProvider DEFAULT_CONFIG_NAME = "config.yaml" @@ -89,10 +90,29 @@ class SparkConfig(ForceableConfig): ) +_P = TypeVar("_P", bound=BaseSparkSessionProvider) + + +class SparkProviderConfig(ForceableConfig, Generic[_P]): + model_config = ConfigDict(validate_default=True, extra="forbid") + type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider + type_config_name: str = Field( + description="CONFIG_NAME of the Spark session provider." 
+ ) + options: dict[str, Any] = Field(default_factory=dict) + + def to_object(self) -> _P: + type_ = self.type_base.subclass_from_config_name(self.type_config_name) + return cast(_P, type_(**self.options)) + + class ClientConfig(BaseModel): reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None spark_config: Union[SparkConfig, None] = None + spark_provider_config: Union[ + SparkProviderConfig[BaseSparkSessionProvider], None + ] = None def update(self, other: ClientConfig) -> ClientConfig: """Merge this ClientConfig with another, respecting force flags. @@ -117,6 +137,9 @@ def merge( self.reader_config = merge(self.reader_config, other.reader_config) self.writer_config = merge(self.writer_config, other.writer_config) self.spark_config = merge(self.spark_config, other.spark_config) + self.spark_provider_config = merge( + self.spark_provider_config, other.spark_provider_config + ) return self def load(self, config_path: str) -> ClientConfig: diff --git a/src/datacustomcode/spark/__init__.py b/src/datacustomcode/spark/__init__.py new file mode 100644 index 0000000..fabc660 --- /dev/null +++ b/src/datacustomcode/spark/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from datacustomcode.spark.base import BaseSparkSessionProvider +from datacustomcode.spark.default import DefaultSparkSessionProvider + +__all__ = ["BaseSparkSessionProvider", "DefaultSparkSessionProvider"] diff --git a/src/datacustomcode/spark/base.py b/src/datacustomcode/spark/base.py new file mode 100644 index 0000000..fe7bf92 --- /dev/null +++ b/src/datacustomcode/spark/base.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from datacustomcode.mixin import UserExtendableNamedConfigMixin + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + + from datacustomcode.config import SparkConfig + + +class BaseSparkSessionProvider(UserExtendableNamedConfigMixin): + def get_session(self, spark_config: SparkConfig) -> "SparkSession": + raise NotImplementedError diff --git a/src/datacustomcode/spark/default.py b/src/datacustomcode/spark/default.py new file mode 100644 index 0000000..d020dd1 --- /dev/null +++ b/src/datacustomcode/spark/default.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from datacustomcode.spark.base import BaseSparkSessionProvider + +if TYPE_CHECKING: + from pyspark.sql import SparkSession + + from datacustomcode.config import SparkConfig + + +class DefaultSparkSessionProvider(BaseSparkSessionProvider): + CONFIG_NAME = "DefaultSparkSessionProvider" + + def get_session(self, spark_config: SparkConfig) -> "SparkSession": + from pyspark.sql import SparkSession + + builder = SparkSession.builder + if spark_config.master is not None: + builder = builder.master(spark_config.master) + builder = builder.appName(spark_config.app_name) + for key, value in spark_config.options.items(): + builder = builder.config(key, value) + return builder.getOrCreate() diff --git a/tests/spark/__init__.py b/tests/spark/__init__.py new file mode 100644 index 0000000..6ae5294 --- /dev/null +++ b/tests/spark/__init__.py @@ -0,0 +1 @@ +# Package initialization file diff --git a/tests/spark/test_session_provider.py b/tests/spark/test_session_provider.py new file mode 100644 index 0000000..71f0e70 --- /dev/null +++ b/tests/spark/test_session_provider.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from datacustomcode.client import Client +from datacustomcode.config import ( + AccessLayerObjectConfig, + ClientConfig, + SparkConfig, + SparkProviderConfig, +) +from datacustomcode.io.reader.base import BaseDataCloudReader +from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode +from datacustomcode.spark.base import BaseSparkSessionProvider + +if TYPE_CHECKING: + from pyspark.sql import DataFrame as PySparkDataFrame + + +class _Sentinel: + pass + + +SENTINEL_SPARK = _Sentinel() + + +class MockReader(BaseDataCloudReader): + CONFIG_NAME = "MockReader" + last_spark: Any | None = None + + def __init__(self, spark): + super().__init__(spark) + MockReader.last_spark = spark + + def read_dlo(self, name: str): # type: ignore[override] + raise NotImplementedError + + def read_dmo(self, name: str): # type: ignore[override] + raise NotImplementedError + + +class MockWriter(BaseDataCloudWriter): + CONFIG_NAME = "MockWriter" + last_spark: Any | None = None + + def __init__(self, spark): + super().__init__(spark) + 
MockWriter.last_spark = spark + + def write_to_dlo( + self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode + ) -> None: # type: ignore[override] + raise NotImplementedError + + def write_to_dmo( + self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode + ) -> None: # type: ignore[override] + raise NotImplementedError + + +class FakeProvider(BaseSparkSessionProvider): + CONFIG_NAME = "FakeProvider" + + def get_session(self, spark_config: SparkConfig): # type: ignore[override] + return SENTINEL_SPARK + + +def _reset_singleton(): + # Reset Client singleton between tests + Client._instance = None # type: ignore[attr-defined] + + +def test_client_uses_provider_from_config(monkeypatch): + _reset_singleton() + + cfg = ClientConfig( + reader_config=AccessLayerObjectConfig( + type_config_name=MockReader.CONFIG_NAME, options={} + ), + writer_config=AccessLayerObjectConfig( + type_config_name=MockWriter.CONFIG_NAME, options={} + ), + spark_config=SparkConfig(app_name="test-app", master=None, options={}), + spark_provider_config=SparkProviderConfig( + type_config_name=FakeProvider.CONFIG_NAME, options={} + ), + ) + + from datacustomcode.config import config as global_config + + global_config.update(cfg) + + Client() + assert MockReader.last_spark is SENTINEL_SPARK + assert MockWriter.last_spark is SENTINEL_SPARK + + +class ExplicitProvider(BaseSparkSessionProvider): + CONFIG_NAME = "ExplicitProvider" + + def get_session(self, spark_config: SparkConfig): # type: ignore[override] + return SENTINEL_SPARK + + +def test_client_explicit_provider_overrides_config(monkeypatch): + _reset_singleton() + + cfg = ClientConfig( + reader_config=AccessLayerObjectConfig( + type_config_name=MockReader.CONFIG_NAME, options={} + ), + writer_config=AccessLayerObjectConfig( + type_config_name=MockWriter.CONFIG_NAME, options={} + ), + spark_config=SparkConfig(app_name="test-app", master=None, options={}), + spark_provider_config=None, + ) + + from datacustomcode.config import config as global_config + + global_config.update(cfg) + + provider = ExplicitProvider() + Client(spark_provider=provider) + assert MockReader.last_spark is SENTINEL_SPARK + assert MockWriter.last_spark is SENTINEL_SPARK diff --git a/tests/test_client.py b/tests/test_client.py index 8a546e2..5d97d58 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,7 +9,6 @@ Client, DataCloudAccessLayerException, DataCloudObjectType, - _setup_spark, ) from datacustomcode.config import ( AccessLayerObjectConfig, @@ -100,37 +99,42 @@ def test_singleton_pattern(self, reset_client, mock_spark): Client(reader=MagicMock(spec=BaseDataCloudReader)) @patch("datacustomcode.client.config") - @patch("datacustomcode.client._setup_spark") - def test_initialization_with_config( - self, mock_setup_spark, mock_config, reset_client, mock_spark - ): + def test_initialization_with_config(self, mock_config, reset_client, mock_spark): """Test client initialization using configuration.""" - mock_setup_spark.return_value = mock_spark + from unittest.mock import patch as mock_patch - mock_reader = MagicMock(spec=BaseDataCloudReader) - mock_reader_config = MagicMock() - mock_reader_config.to_object.return_value = mock_reader - mock_reader_config.force = False + from datacustomcode.spark.default import DefaultSparkSessionProvider - mock_writer = MagicMock(spec=BaseDataCloudWriter) - mock_writer_config = MagicMock() - mock_writer_config.to_object.return_value = mock_writer - mock_writer_config.force = False + with mock_patch.object( + 
DefaultSparkSessionProvider, "get_session" + ) as mock_get_session: + mock_get_session.return_value = mock_spark - mock_spark_config = MagicMock(spec=SparkConfig) + mock_reader = MagicMock(spec=BaseDataCloudReader) + mock_reader_config = MagicMock() + mock_reader_config.to_object.return_value = mock_reader + mock_reader_config.force = False - mock_config.reader_config = mock_reader_config - mock_config.writer_config = mock_writer_config - mock_config.spark_config = mock_spark_config + mock_writer = MagicMock(spec=BaseDataCloudWriter) + mock_writer_config = MagicMock() + mock_writer_config.to_object.return_value = mock_writer + mock_writer_config.force = False - client = Client() + mock_spark_config = MagicMock(spec=SparkConfig) + mock_config.spark_provider_config = None - mock_setup_spark.assert_called_once_with(mock_spark_config) - mock_reader_config.to_object.assert_called_once_with(mock_spark) - mock_writer_config.to_object.assert_called_once_with(mock_spark) + mock_config.reader_config = mock_reader_config + mock_config.writer_config = mock_writer_config + mock_config.spark_config = mock_spark_config - assert client._reader is mock_reader - assert client._writer is mock_writer + client = Client() + + mock_get_session.assert_called_once_with(mock_spark_config) + mock_reader_config.to_object.assert_called_once_with(mock_spark) + mock_writer_config.to_object.assert_called_once_with(mock_spark) + + assert client._reader is mock_reader + assert client._writer is mock_writer def test_read_dlo(self, reset_client, mock_spark): reader = MagicMock(spec=BaseDataCloudReader) @@ -249,12 +253,12 @@ def test_read_pattern_flow(self, reset_client, mock_spark): assert "source_dmo" in client._data_layer_history[DataCloudObjectType.DMO] -# Add tests for _setup_spark function -class TestSetupSpark: +# Add tests for DefaultSparkSessionProvider +class TestDefaultSparkSessionProvider: - @patch("datacustomcode.client.SparkSession") - def test_setup_spark_with_master(self, mock_spark_session): - """Test _setup_spark with master specified""" + @patch("pyspark.sql.SparkSession") + def test_get_session_with_master(self, mock_spark_session): + """Test DefaultSparkSessionProvider with master specified""" mock_builder = MagicMock() mock_master_builder = MagicMock() mock_app_name_builder = MagicMock() @@ -273,7 +277,10 @@ def test_setup_spark_with_master(self, mock_spark_session): options={"spark.executor.memory": "1g"}, ) - result = _setup_spark(spark_config) + from datacustomcode.spark.default import DefaultSparkSessionProvider + + provider = DefaultSparkSessionProvider() + result = provider.get_session(spark_config) mock_builder.master.assert_called_once_with("local[1]") mock_master_builder.appName.assert_called_once_with("test-app") @@ -283,9 +290,9 @@ def test_setup_spark_with_master(self, mock_spark_session): mock_config_builder.getOrCreate.assert_called_once() assert result is mock_session - @patch("datacustomcode.client.SparkSession") - def test_setup_spark_with_multiple_options(self, mock_spark_session): - """Test _setup_spark with multiple config options""" + @patch("pyspark.sql.SparkSession") + def test_get_session_with_multiple_options(self, mock_spark_session): + """Test DefaultSparkSessionProvider with multiple config options""" mock_builder = MagicMock() mock_app_name_builder = MagicMock() mock_config_builder1 = MagicMock() @@ -310,7 +317,10 @@ def test_setup_spark_with_multiple_options(self, mock_spark_session): }, ) - result = _setup_spark(spark_config) + from datacustomcode.spark.default import 
DefaultSparkSessionProvider + + provider = DefaultSparkSessionProvider() + result = provider.get_session(spark_config) mock_builder.appName.assert_called_once_with("test-app")
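
Usage sketch for the new provider hook (illustrative only; RemoteSparkSessionProvider and its remote_url option are hypothetical names, not part of this change). It assumes PySpark 3.4+ with Spark Connect available, and that UserExtendableNamedConfigMixin resolves subclasses by their CONFIG_NAME, as DefaultSparkSessionProvider relies on above:

    from pyspark.sql import SparkSession

    from datacustomcode.client import Client
    from datacustomcode.config import SparkConfig
    from datacustomcode.spark.base import BaseSparkSessionProvider


    class RemoteSparkSessionProvider(BaseSparkSessionProvider):
        CONFIG_NAME = "RemoteSparkSessionProvider"

        def __init__(self, remote_url: str = "sc://localhost:15002") -> None:
            # SparkProviderConfig passes its `options` mapping as constructor
            # kwargs, so `remote_url` could also be supplied from configuration.
            super().__init__()
            self.remote_url = remote_url

        def get_session(self, spark_config: SparkConfig) -> SparkSession:
            # Attach to a remote Spark Connect endpoint instead of building a
            # local session, while still honoring the shared SparkConfig.
            builder = SparkSession.builder.remote(self.remote_url)
            builder = builder.appName(spark_config.app_name)
            for key, value in spark_config.options.items():
                builder = builder.config(key, value)
            return builder.getOrCreate()


    # An explicit argument takes precedence over any configured provider, per
    # the resolution order in Client.__new__ above; reader and writer are still
    # resolved from the loaded config in this case.
    client = Client(spark_provider=RemoteSparkSessionProvider("sc://spark:15002"))

The same class could presumably be selected from configuration instead, by pointing spark_provider_config.type_config_name at "RemoteSparkSessionProvider" and placing the constructor keyword arguments under options, mirroring how SparkProviderConfig.to_object instantiates the provider.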