32 changes: 15 additions & 17 deletions src/datacustomcode/client.py
@@ -21,11 +21,10 @@
    Optional,
)

-from pyspark.sql import SparkSession
-
-from datacustomcode.config import SparkConfig, config
+from datacustomcode.config import config
from datacustomcode.file.path.default import DefaultFindFilePath
from datacustomcode.io.reader.base import BaseDataCloudReader
+from datacustomcode.spark.default import DefaultSparkSessionProvider

if TYPE_CHECKING:
    from pathlib import Path
@@ -34,18 +33,7 @@

    from datacustomcode.io.reader.base import BaseDataCloudReader
    from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
-
-
-def _setup_spark(spark_config: SparkConfig) -> SparkSession:
-    """Setup Spark session from config."""
-    builder = SparkSession.builder
-    if spark_config.master is not None:
-        builder = builder.master(spark_config.master)
-
-    builder = builder.appName(spark_config.app_name)
-    for key, value in spark_config.options.items():
-        builder = builder.config(key, value)
-    return builder.getOrCreate()
+    from datacustomcode.spark.base import BaseSparkSessionProvider


class DataCloudObjectType(Enum):
@@ -123,7 +111,8 @@ class Client:
    def __new__(
        cls,
        reader: Optional[BaseDataCloudReader] = None,
-        writer: Optional[BaseDataCloudWriter] = None,
+        writer: Optional["BaseDataCloudWriter"] = None,
+        spark_provider: Optional["BaseSparkSessionProvider"] = None,
    ) -> Client:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
@@ -136,7 +125,16 @@ def __new__(
                raise ValueError(
                    "Spark config is required when reader/writer is not provided"
                )
-            spark = _setup_spark(config.spark_config)
+
+            provider: BaseSparkSessionProvider
+            if spark_provider is not None:
+                provider = spark_provider
+            elif config.spark_provider_config is not None:
+                provider = config.spark_provider_config.to_object()
+            else:
+                provider = DefaultSparkSessionProvider()
+
+            spark = provider.get_session(config.spark_config)

            if config.reader_config is None and reader is None:
                raise ValueError(
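Net effect on the Client API (a usage sketch, not part of the diff): the session provider is resolved in precedence order, an explicit spark_provider argument first, then config.spark_provider_config, then the built-in DefaultSparkSessionProvider:

from datacustomcode.client import Client
from datacustomcode.spark.default import DefaultSparkSessionProvider

# Passing a provider explicitly bypasses config-based resolution; any
# BaseSparkSessionProvider subclass works, the default is used here only as
# a stand-in. Assumes reader/writer configuration is already in place, as in
# the tests at the end of this PR.
client = Client(spark_provider=DefaultSparkSessionProvider())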
23 changes: 23 additions & 0 deletions src/datacustomcode/config.py
@@ -38,6 +38,7 @@
from datacustomcode.io.base import BaseDataAccessLayer
from datacustomcode.io.reader.base import BaseDataCloudReader  # noqa: TCH001
from datacustomcode.io.writer.base import BaseDataCloudWriter  # noqa: TCH001
+from datacustomcode.spark.base import BaseSparkSessionProvider

DEFAULT_CONFIG_NAME = "config.yaml"

@@ -89,10 +90,29 @@ class SparkConfig(ForceableConfig):
    )


+_P = TypeVar("_P", bound=BaseSparkSessionProvider)
+
+
+class SparkProviderConfig(ForceableConfig, Generic[_P]):
+    model_config = ConfigDict(validate_default=True, extra="forbid")
+    type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
+    type_config_name: str = Field(
+        description="CONFIG_NAME of the Spark session provider."
+    )
+    options: dict[str, Any] = Field(default_factory=dict)
+
+    def to_object(self) -> _P:
+        type_ = self.type_base.subclass_from_config_name(self.type_config_name)
+        return cast(_P, type_(**self.options))
+
+
class ClientConfig(BaseModel):
    reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
    writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
    spark_config: Union[SparkConfig, None] = None
+    spark_provider_config: Union[
+        SparkProviderConfig[BaseSparkSessionProvider], None
+    ] = None

    def update(self, other: ClientConfig) -> ClientConfig:
        """Merge this ClientConfig with another, respecting force flags.
@@ -117,6 +137,9 @@ def merge(
        self.reader_config = merge(self.reader_config, other.reader_config)
        self.writer_config = merge(self.writer_config, other.writer_config)
        self.spark_config = merge(self.spark_config, other.spark_config)
+        self.spark_provider_config = merge(
+            self.spark_provider_config, other.spark_provider_config
+        )
        return self

    def load(self, config_path: str) -> ClientConfig:
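A short sketch (not part of the diff) of how SparkProviderConfig is expected to resolve a provider: to_object looks the class up by CONFIG_NAME via subclass_from_config_name, which assumes BaseSparkSessionProvider subclasses are registered through UserExtendableNamedConfigMixin; MyProvider is a hypothetical name.

from datacustomcode.config import SparkProviderConfig
from datacustomcode.spark.base import BaseSparkSessionProvider


class MyProvider(BaseSparkSessionProvider):
    # Hypothetical subclass; defining it makes its CONFIG_NAME resolvable.
    CONFIG_NAME = "MyProvider"


provider_cfg = SparkProviderConfig(type_config_name="MyProvider", options={})
provider = provider_cfg.to_object()  # instantiates MyProvider(**options)
assert isinstance(provider, MyProvider)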
20 changes: 20 additions & 0 deletions src/datacustomcode/spark/__init__.py
@@ -0,0 +1,20 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from datacustomcode.spark.base import BaseSparkSessionProvider
from datacustomcode.spark.default import DefaultSparkSessionProvider

__all__ = ["BaseSparkSessionProvider", "DefaultSparkSessionProvider"]
29 changes: 29 additions & 0 deletions src/datacustomcode/spark/base.py
@@ -0,0 +1,29 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from datacustomcode.mixin import UserExtendableNamedConfigMixin

if TYPE_CHECKING:
    from pyspark.sql import SparkSession

    from datacustomcode.config import SparkConfig


class BaseSparkSessionProvider(UserExtendableNamedConfigMixin):
    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
        raise NotImplementedError
39 changes: 39 additions & 0 deletions src/datacustomcode/spark/default.py
@@ -0,0 +1,39 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from datacustomcode.spark.base import BaseSparkSessionProvider

if TYPE_CHECKING:
    from pyspark.sql import SparkSession

    from datacustomcode.config import SparkConfig


class DefaultSparkSessionProvider(BaseSparkSessionProvider):
    CONFIG_NAME = "DefaultSparkSessionProvider"

    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
        from pyspark.sql import SparkSession

        builder = SparkSession.builder
        if spark_config.master is not None:
            builder = builder.master(spark_config.master)
        builder = builder.appName(spark_config.app_name)
        for key, value in spark_config.options.items():
            builder = builder.config(key, value)
        return builder.getOrCreate()
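To illustrate the extension point, a hedged sketch of a custom provider (hypothetical, not part of this diff): a subclass only needs a unique CONFIG_NAME and a get_session implementation. This one pins a small local master for test runs while still applying the options from SparkConfig.

from __future__ import annotations

from typing import TYPE_CHECKING

from datacustomcode.spark.base import BaseSparkSessionProvider

if TYPE_CHECKING:
    from pyspark.sql import SparkSession

    from datacustomcode.config import SparkConfig


class LocalTestSparkSessionProvider(BaseSparkSessionProvider):
    CONFIG_NAME = "LocalTestSparkSessionProvider"

    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
        from pyspark.sql import SparkSession

        # Ignore spark_config.master and force a two-core local session.
        builder = SparkSession.builder.master("local[2]")
        builder = builder.appName(spark_config.app_name)
        for key, value in spark_config.options.items():
            builder = builder.config(key, value)
        return builder.getOrCreate()

Such a provider could then be selected declaratively via spark_provider_config (type_config_name="LocalTestSparkSessionProvider") or passed directly as Client(spark_provider=LocalTestSparkSessionProvider()), which the tests below exercise.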
1 change: 1 addition & 0 deletions tests/spark/__init__.py
@@ -0,0 +1 @@
# Package initialization file
126 changes: 126 additions & 0 deletions tests/spark/test_session_provider.py
@@ -0,0 +1,126 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from datacustomcode.client import Client
from datacustomcode.config import (
    AccessLayerObjectConfig,
    ClientConfig,
    SparkConfig,
    SparkProviderConfig,
)
from datacustomcode.io.reader.base import BaseDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
from datacustomcode.spark.base import BaseSparkSessionProvider

if TYPE_CHECKING:
    from pyspark.sql import DataFrame as PySparkDataFrame


class _Sentinel:
    pass


SENTINEL_SPARK = _Sentinel()


class MockReader(BaseDataCloudReader):
    CONFIG_NAME = "MockReader"
    last_spark: Any | None = None

    def __init__(self, spark):
        super().__init__(spark)
        MockReader.last_spark = spark

    def read_dlo(self, name: str):  # type: ignore[override]
        raise NotImplementedError

    def read_dmo(self, name: str):  # type: ignore[override]
        raise NotImplementedError


class MockWriter(BaseDataCloudWriter):
    CONFIG_NAME = "MockWriter"
    last_spark: Any | None = None

    def __init__(self, spark):
        super().__init__(spark)
        MockWriter.last_spark = spark

    def write_to_dlo(
        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
    ) -> None:  # type: ignore[override]
        raise NotImplementedError

    def write_to_dmo(
        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
    ) -> None:  # type: ignore[override]
        raise NotImplementedError


class FakeProvider(BaseSparkSessionProvider):
    CONFIG_NAME = "FakeProvider"

    def get_session(self, spark_config: SparkConfig):  # type: ignore[override]
        return SENTINEL_SPARK


def _reset_singleton():
    # Reset Client singleton between tests
    Client._instance = None  # type: ignore[attr-defined]


def test_client_uses_provider_from_config(monkeypatch):
    _reset_singleton()

    cfg = ClientConfig(
        reader_config=AccessLayerObjectConfig(
            type_config_name=MockReader.CONFIG_NAME, options={}
        ),
        writer_config=AccessLayerObjectConfig(
            type_config_name=MockWriter.CONFIG_NAME, options={}
        ),
        spark_config=SparkConfig(app_name="test-app", master=None, options={}),
        spark_provider_config=SparkProviderConfig(
            type_config_name=FakeProvider.CONFIG_NAME, options={}
        ),
    )

    from datacustomcode.config import config as global_config

    global_config.update(cfg)

    Client()
    assert MockReader.last_spark is SENTINEL_SPARK
    assert MockWriter.last_spark is SENTINEL_SPARK


class ExplicitProvider(BaseSparkSessionProvider):
    CONFIG_NAME = "ExplicitProvider"

    def get_session(self, spark_config: SparkConfig):  # type: ignore[override]
        return SENTINEL_SPARK


def test_client_explicit_provider_overrides_config(monkeypatch):
    _reset_singleton()

    cfg = ClientConfig(
        reader_config=AccessLayerObjectConfig(
            type_config_name=MockReader.CONFIG_NAME, options={}
        ),
        writer_config=AccessLayerObjectConfig(
            type_config_name=MockWriter.CONFIG_NAME, options={}
        ),
        spark_config=SparkConfig(app_name="test-app", master=None, options={}),
        spark_provider_config=None,
    )

    from datacustomcode.config import config as global_config

    global_config.update(cfg)

    provider = ExplicitProvider()
    Client(spark_provider=provider)
    assert MockReader.last_spark is SENTINEL_SPARK
    assert MockWriter.last_spark is SENTINEL_SPARK