diff --git a/Makefile b/Makefile index f79736666..510a6114a 100644 --- a/Makefile +++ b/Makefile @@ -75,12 +75,8 @@ release: install-dev .PHONY: test-python test-python: uv-venv ## Run Python unit tests @uv sync - @uv run coverage run --source=kubeflow.trainer.backends.kubernetes.backend,kubeflow.trainer.utils.utils -m pytest \ - ./kubeflow/trainer/backends/kubernetes/backend_test.py \ - ./kubeflow/trainer/backends/kubernetes/utils_test.py - @uv run coverage report -m \ - kubeflow/trainer/backends/kubernetes/backend.py \ - kubeflow/trainer/backends/kubernetes/utils.py + @uv run coverage run --source=kubeflow -m pytest ./kubeflow/ + @uv run coverage report --omit='*_test.py' --skip-covered --skip-empty ifeq ($(report),xml) @uv run coverage xml else diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py index 79613f541..59238de8f 100644 --- a/kubeflow/trainer/api/trainer_client.py +++ b/kubeflow/trainer/api/trainer_client.py @@ -107,6 +107,7 @@ def train( trainer: Optional[ Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer] ] = None, + options: Optional[list] = None, ) -> str: """Create a TrainJob. You can configure the TrainJob using one of these trainers: @@ -124,6 +125,8 @@ def train( trainer: Optional configuration for a CustomTrainer, CustomTrainerContainer, or BuiltinTrainer. If not specified, the TrainJob will use the runtime's default values. + options: Optional list of configuration options to apply to the TrainJob. + Options can be imported from kubeflow.trainer.options. Returns: The unique name of the TrainJob that has been generated. @@ -133,7 +136,12 @@ def train( TimeoutError: Timeout to create TrainJobs. RuntimeError: Failed to create TrainJobs. """ - return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer) + return self.backend.train( + runtime=runtime, + initializer=initializer, + trainer=trainer, + options=options, + ) def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]: """List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with diff --git a/kubeflow/trainer/api/trainer_client_test.py b/kubeflow/trainer/api/trainer_client_test.py new file mode 100644 index 000000000..718a10bac --- /dev/null +++ b/kubeflow/trainer/api/trainer_client_test.py @@ -0,0 +1,72 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for TrainerClient backend selection. 
+""" + +from unittest.mock import Mock, patch + +import pytest + +from kubeflow.common.types import KubernetesBackendConfig +from kubeflow.trainer.api.trainer_client import TrainerClient +from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig + + +@pytest.mark.parametrize( + "test_case", + [ + { + "name": "default_backend_is_kubernetes", + "backend_config": None, + "expected_backend": "KubernetesBackend", + "use_k8s_mocks": True, + }, + { + "name": "local_process_backend_selection", + "backend_config": LocalProcessBackendConfig(), + "expected_backend": "LocalProcessBackend", + "use_k8s_mocks": False, + }, + { + "name": "kubernetes_backend_selection", + "backend_config": KubernetesBackendConfig(), + "expected_backend": "KubernetesBackend", + "use_k8s_mocks": True, + }, + ], +) +def test_backend_selection(test_case): + """Test TrainerClient backend selection logic.""" + if test_case["use_k8s_mocks"]: + with ( + patch("kubernetes.config.load_kube_config"), + patch("kubernetes.client.CustomObjectsApi") as mock_custom_api, + patch("kubernetes.client.CoreV1Api") as mock_core_api, + ): + mock_custom_api.return_value = Mock() + mock_core_api.return_value = Mock() + + if test_case["backend_config"]: + client = TrainerClient(backend_config=test_case["backend_config"]) + else: + client = TrainerClient() + + backend_name = client.backend.__class__.__name__ + assert backend_name == test_case["expected_backend"] + else: + client = TrainerClient(backend_config=test_case["backend_config"]) + backend_name = client.backend.__class__.__name__ + assert backend_name == test_case["expected_backend"] diff --git a/kubeflow/trainer/backends/base.py b/kubeflow/trainer/backends/base.py index b32bbf62b..1578ff827 100644 --- a/kubeflow/trainer/backends/base.py +++ b/kubeflow/trainer/backends/base.py @@ -21,6 +21,11 @@ class RuntimeBackend(abc.ABC): + """Base class for runtime backends. + + Options self-validate by checking the backend instance type in their __call__ method. 
+ """ + @abc.abstractmethod def list_runtimes(self) -> list[types.Runtime]: raise NotImplementedError() @@ -41,6 +46,7 @@ def train( trainer: Optional[ Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer] ] = None, + options: Optional[list] = None, ) -> str: raise NotImplementedError() diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index e96fb1b3e..c16903025 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -20,7 +20,7 @@ import re import string import time -from typing import Optional, Union +from typing import Any, Optional, Union import uuid from kubeflow_trainer_api import models @@ -87,15 +87,9 @@ def list_runtimes(self) -> list[types.Runtime]: result.append(self.__get_runtime_from_cr(runtime)) except multiprocessing.TimeoutError as e: - raise TimeoutError( - f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s " - f"in namespace: {self.namespace}" - ) from e + raise TimeoutError(f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s") from e except Exception as e: - raise RuntimeError( - f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s " - f"in namespace: {self.namespace}" - ) from e + raise RuntimeError(f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s") from e return result @@ -184,16 +178,62 @@ def train( trainer: Optional[ Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer] ] = None, + options: Optional[list] = None, ) -> str: - # Generate unique name for the TrainJob. - train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11] + if runtime is None: + runtime = self.get_runtime(constants.TORCH_RUNTIME) + + # Process options to extract configuration + job_spec = {} + labels = None + annotations = None + name = None + spec_labels = None + spec_annotations = None + trainer_overrides = {} + pod_template_overrides = None + + if options: + for option in options: + option(job_spec, trainer, self) + + metadata_section = job_spec.get("metadata", {}) + labels = metadata_section.get("labels") + annotations = metadata_section.get("annotations") + name = metadata_section.get("name") + + # Extract spec-level labels/annotations and other spec configurations + spec_section = job_spec.get("spec", {}) + spec_labels = spec_section.get("labels") + spec_annotations = spec_section.get("annotations") + trainer_overrides = spec_section.get("trainer", {}) + pod_template_overrides = spec_section.get("podTemplateOverrides") + + # Generate unique name for the TrainJob if not provided + train_job_name = name or ( + random.choice(string.ascii_lowercase) + + uuid.uuid4().hex[: constants.JOB_NAME_UUID_LENGTH] + ) + + # Build the TrainJob spec using the common _get_trainjob_spec method + trainjob_spec = self._get_trainjob_spec( + runtime=runtime, + initializer=initializer, + trainer=trainer, + trainer_overrides=trainer_overrides, + spec_labels=spec_labels, + spec_annotations=spec_annotations, + pod_template_overrides=pod_template_overrides, + ) # Build the TrainJob. 
train_job = models.TrainerV1alpha1TrainJob( apiVersion=constants.API_VERSION, kind=constants.TRAINJOB_KIND, - metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name), - spec=self._get_trainjob_spec(runtime, initializer, trainer), + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name=train_job_name, labels=labels, annotations=annotations + ), + spec=trainjob_spec, ) # Create the TrainJob. @@ -549,6 +589,10 @@ def _get_trainjob_spec( trainer: Optional[ Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer] ] = None, + trainer_overrides: Optional[dict[str, Any]] = None, + spec_labels: Optional[dict[str, str]] = None, + spec_annotations: Optional[dict[str, str]] = None, + pod_template_overrides: Optional[models.IoK8sApiCoreV1PodTemplateSpec] = None, ) -> models.TrainerV1alpha1TrainJobSpec: """Get TrainJob spec from the given parameters""" if runtime is None: @@ -575,9 +619,16 @@ def _get_trainjob_spec( else: raise ValueError( f"The trainer type {type(trainer)} is not supported. " - "Please use CustomTrainer or BuiltinTrainer." + "Please use CustomTrainer, CustomTrainerContainer, or BuiltinTrainer." ) + # Apply trainer overrides if trainer was not provided but overrides exist + if trainer_overrides: + if "command" in trainer_overrides: + trainer_cr.command = trainer_overrides["command"] + if "args" in trainer_overrides: + trainer_cr.args = trainer_overrides["args"] + return models.TrainerV1alpha1TrainJobSpec( runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name), trainer=(trainer_cr if trainer_cr != models.TrainerV1alpha1Trainer() else None), @@ -589,4 +640,7 @@ def _get_trainjob_spec( if isinstance(initializer, types.Initializer) else None ), + labels=spec_labels, + annotations=spec_annotations, + pod_template_overrides=pod_template_overrides, ) diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 2becf5874..c5fcc190c 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -35,6 +35,12 @@ from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend import kubeflow.trainer.backends.kubernetes.utils as utils from kubeflow.trainer.constants import constants +from kubeflow.trainer.options import ( + Annotations, + Labels, + SpecAnnotations, + SpecLabels, +) from kubeflow.trainer.test.common import ( DEFAULT_NAMESPACE, FAILED, @@ -274,6 +280,10 @@ def get_train_job( runtime_name: str, train_job_name: str = BASIC_TRAIN_JOB_NAME, train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None, + labels: Optional[dict[str, str]] = None, + annotations: Optional[dict[str, str]] = None, + spec_labels: Optional[dict[str, str]] = None, + spec_annotations: Optional[dict[str, str]] = None, ) -> models.TrainerV1alpha1TrainJob: """ Create a mock TrainJob object with optional trainer configurations. 
@@ -281,10 +291,16 @@ def get_train_job( train_job = models.TrainerV1alpha1TrainJob( apiVersion=constants.API_VERSION, kind=constants.TRAINJOB_KIND, - metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name), + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name=train_job_name, + labels=labels, + annotations=annotations, + ), spec=models.TrainerV1alpha1TrainJobSpec( runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name), trainer=train_job_trainer, + labels=spec_labels, + annotations=spec_annotations, ), ) @@ -879,6 +895,58 @@ def test_get_runtime_packages(kubernetes_backend, test_case): }, expected_error=ValueError, ), + TestCase( + name="train with metadata labels and annotations", + expected_status=SUCCESS, + config={ + "options": [ + Labels({"team": "ml-platform"}), + Annotations({"created-by": "sdk"}), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + labels={"team": "ml-platform"}, + annotations={"created-by": "sdk"}, + ), + ), + TestCase( + name="train with spec labels and annotations", + expected_status=SUCCESS, + config={ + "options": [ + SpecLabels({"app": "training", "version": "v1.0"}), + SpecAnnotations({"prometheus.io/scrape": "true"}), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + spec_labels={"app": "training", "version": "v1.0"}, + spec_annotations={"prometheus.io/scrape": "true"}, + ), + ), + TestCase( + name="train with both metadata and spec labels/annotations", + expected_status=SUCCESS, + config={ + "options": [ + Labels({"owner": "ml-team"}), + Annotations({"description": "Fine-tuning job"}), + SpecLabels({"app": "training", "version": "v1.0"}), + SpecAnnotations({"prometheus.io/scrape": "true"}), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + labels={"owner": "ml-team"}, + annotations={"description": "Fine-tuning job"}, + spec_labels={"app": "training", "version": "v1.0"}, + spec_annotations={"prometheus.io/scrape": "true"}, + ), + ), ], ) def test_train(kubernetes_backend, test_case): @@ -888,8 +956,12 @@ def test_train(kubernetes_backend, test_case): kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME)) + options = test_case.config.get("options", []) + train_job_name = kubernetes_backend.train( - runtime=runtime, trainer=test_case.config.get("trainer", None) + runtime=runtime, + trainer=test_case.config.get("trainer", None), + options=options, ) assert test_case.expected_status == SUCCESS diff --git a/kubeflow/trainer/backends/kubernetes/constants.py b/kubeflow/trainer/backends/kubernetes/constants.py new file mode 100644 index 000000000..0a1d9ba02 --- /dev/null +++ b/kubeflow/trainer/backends/kubernetes/constants.py @@ -0,0 +1,15 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Kubernetes backend-specific constants.""" diff --git a/kubeflow/trainer/backends/kubernetes/utils.py b/kubeflow/trainer/backends/kubernetes/utils.py index 75446c4df..c88db0eda 100644 --- a/kubeflow/trainer/backends/kubernetes/utils.py +++ b/kubeflow/trainer/backends/kubernetes/utils.py @@ -349,6 +349,10 @@ def get_trainer_cr_from_custom_trainer( ) -> models.TrainerV1alpha1Trainer: """ Get the Trainer CR from the custom trainer. + + Args: + runtime: The runtime configuration. + trainer: The custom trainer or container configuration. """ trainer_cr = models.TrainerV1alpha1Trainer() @@ -361,7 +365,7 @@ def get_trainer_cr_from_custom_trainer( trainer_cr.resources_per_node = get_resources_per_node(trainer.resources_per_node) if isinstance(trainer, types.CustomTrainer): - # If CustomTrainer is used, add command to the Trainer. + # If CustomTrainer is used, generate command from function. trainer_cr.command = get_command_using_train_func( runtime, trainer.func, diff --git a/kubeflow/trainer/backends/kubernetes/utils_test.py b/kubeflow/trainer/backends/kubernetes/utils_test.py index d55bfa41b..3405642d5 100644 --- a/kubeflow/trainer/backends/kubernetes/utils_test.py +++ b/kubeflow/trainer/backends/kubernetes/utils_test.py @@ -238,8 +238,6 @@ def test_get_script_for_python_packages(test_case): ], ) def test_get_command_using_train_func(test_case: TestCase): - print("Executing test:", test_case.name) - try: command = utils.get_command_using_train_func( runtime=test_case.config["runtime"], @@ -254,7 +252,6 @@ def test_get_command_using_train_func(test_case: TestCase): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") def test_get_dataset_initializer(): diff --git a/kubeflow/trainer/backends/localprocess/backend.py b/kubeflow/trainer/backends/localprocess/backend.py index a3597e777..f9a827406 100644 --- a/kubeflow/trainer/backends/localprocess/backend.py +++ b/kubeflow/trainer/backends/localprocess/backend.py @@ -62,11 +62,11 @@ def get_runtime(self, name: str) -> types.Runtime: return runtime def get_runtime_packages(self, runtime: types.Runtime): - runtime = next((rt for rt in local_runtimes if rt.name == runtime.name), None) - if not runtime: + local_runtime = next((rt for rt in local_runtimes if rt.name == runtime.name), None) + if not local_runtime: raise ValueError(f"Runtime '{runtime.name}' not found.") - return runtime.trainer.packages + return local_runtime.trainer.packages def train( self, @@ -75,9 +75,27 @@ def train( trainer: Optional[ Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer] ] = None, + options: Optional[list] = None, ) -> str: - # set train job name - train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11] + if runtime is None: + raise ValueError("Runtime must be provided for LocalProcessBackend") + + # Process options to extract configuration + name = None + if options: + job_spec = {} + for option in options: + option(job_spec, trainer, self) + + metadata_section = job_spec.get("metadata", {}) + name = metadata_section.get("name") + + # Generate train job name if not provided via options + train_job_name = name or ( + random.choice(string.ascii_lowercase) + + uuid.uuid4().hex[: constants.JOB_NAME_UUID_LENGTH] + ) + # localprocess backend only supports CustomTrainer if not isinstance(trainer, types.CustomTrainer): raise ValueError("CustomTrainer must be set with 
LocalProcessBackend") @@ -86,6 +104,7 @@ def train( venv_dir = tempfile.mkdtemp(prefix=train_job_name) logger.debug(f"operating in {venv_dir}") + # get local runtime trainer runtime.trainer = local_utils.get_local_runtime_trainer( runtime_name=runtime.name, venv_dir=venv_dir, @@ -156,7 +175,11 @@ def get_job(self, name: str) -> types.TrainJob: name=_job.name, creation_timestamp=_job.created, steps=[ - types.Step(name=_step.step_name, pod_name=_step.step_name, status=_step.job.status) + types.Step( + name=_step.step_name, + pod_name=_step.step_name, + status=_step.job.status, + ) for _step in _job.steps ], runtime=_job.runtime, @@ -197,7 +220,10 @@ def wait_for_job_status( raise ValueError(f"No TrainJob with name {name}") # find a better implementation for this for _step in _job.steps: - if _step.job.status in [constants.TRAINJOB_RUNNING, constants.TRAINJOB_CREATED]: + if _step.job.status in [ + constants.TRAINJOB_RUNNING, + constants.TRAINJOB_CREATED, + ]: _step.job.join(timeout=timeout) return self.get_job(name) @@ -233,14 +259,15 @@ def __register_job( job: LocalJob, runtime: types.Runtime = None, ): - _job = [j for j in self.__local_jobs if j.name == train_job_name] - if not _job: + existing_jobs = [j for j in self.__local_jobs if j.name == train_job_name] + if not existing_jobs: _job = LocalBackendJobs(name=train_job_name, runtime=runtime, created=datetime.now()) self.__local_jobs.append(_job) else: - _job = _job[0] - _step = [s for s in _job.steps if s.step_name == step_name] - if not _step: + _job = existing_jobs[0] + + existing_steps = [s for s in _job.steps if s.step_name == step_name] + if not existing_steps: _step = LocalBackendStep(step_name=step_name, job=job) _job.steps.append(_step) else: diff --git a/kubeflow/trainer/backends/localprocess/backend_test.py b/kubeflow/trainer/backends/localprocess/backend_test.py new file mode 100644 index 000000000..1702e9224 --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/backend_test.py @@ -0,0 +1,488 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the LocalProcessBackend class in the Kubeflow Trainer SDK. 
+""" + +from unittest.mock import Mock, patch + +import pytest + +from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend +from kubeflow.trainer.backends.localprocess.types import ( + LocalProcessBackendConfig, + LocalRuntimeTrainer, +) +from kubeflow.trainer.constants import constants +from kubeflow.trainer.options import ( + Annotations, + Labels, + Name, + PodTemplateOverride, + PodTemplateOverrides, +) +from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase +from kubeflow.trainer.types import types + +# Test constants +TORCH_RUNTIME = constants.TORCH_RUNTIME +BASIC_TRAIN_JOB_NAME = "test-job" + + +def dummy_training_function(): + """Dummy training function for testing.""" + print("Training started") + return {"loss": 0.5, "accuracy": 0.95} + + +@pytest.fixture +def local_backend(): + """Create LocalProcessBackend for testing.""" + cfg = LocalProcessBackendConfig() + backend = LocalProcessBackend(cfg) + yield backend + # Cleanup: Clear jobs to prevent test pollution + backend._LocalProcessBackend__local_jobs.clear() + + +@pytest.fixture +def mock_train_environment(): + """Mock the training environment to avoid actual subprocess execution.""" + with ( + patch("kubeflow.trainer.backends.localprocess.job.LocalJob.start") as mock_start, + patch( + "kubeflow.trainer.backends.localprocess.utils.get_local_runtime_trainer" + ) as mock_get_trainer, + patch( + "kubeflow.trainer.backends.localprocess.utils.get_local_train_job_script" + ) as mock_get_script, + patch("tempfile.mkdtemp") as mock_mkdtemp, + ): + # Setup mock return values + mock_mkdtemp.return_value = "/tmp/test-venv" + mock_get_script.return_value = ["/bin/bash", "-c", "echo 'training'"] + + mock_trainer = LocalRuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + device_count=1, + device="cpu", + packages=["torch"], + ) + mock_trainer.set_command = Mock() + mock_get_trainer.return_value = mock_trainer + + yield { + "start": mock_start, + "get_trainer": mock_get_trainer, + "get_script": mock_get_script, + "mkdtemp": mock_mkdtemp, + } + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="list_all_local_runtimes", + expected_status=SUCCESS, + config={}, + ), + ], +) +def test_list_runtimes(local_backend, test_case): + """Test LocalProcessBackend.list_runtimes().""" + runtimes = local_backend.list_runtimes() + assert len(runtimes) > 0 + assert all(isinstance(rt, types.Runtime) for rt in runtimes) + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="get_existing_runtime", + expected_status=SUCCESS, + config={"runtime_name": TORCH_RUNTIME}, + ), + TestCase( + name="get_nonexistent_runtime", + expected_status=FAILED, + config={"runtime_name": "nonexistent-runtime"}, + expected_error=ValueError, + ), + ], +) +def test_get_runtime(local_backend, test_case): + """Test LocalProcessBackend.get_runtime().""" + runtime_name = test_case.config.get("runtime_name") + + if test_case.expected_status == FAILED: + with pytest.raises(test_case.expected_error): + local_backend.get_runtime(runtime_name) + else: + runtime = local_backend.get_runtime(runtime_name) + assert runtime is not None + assert runtime.name == runtime_name + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="get_packages_for_existing_runtime", + expected_status=SUCCESS, + config={ + "runtime": types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + 
), + }, + ), + TestCase( + name="get_packages_for_nonexistent_runtime", + expected_status=FAILED, + config={ + "runtime": types.Runtime( + name="nonexistent-runtime", + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ), + }, + expected_error=ValueError, + ), + ], +) +def test_get_runtime_packages(local_backend, test_case): + """Test LocalProcessBackend.get_runtime_packages().""" + runtime = test_case.config.get("runtime") + + if test_case.expected_status == FAILED: + with pytest.raises(test_case.expected_error): + local_backend.get_runtime_packages(runtime) + else: + packages = local_backend.get_runtime_packages(runtime) + assert packages is not None + assert isinstance(packages, list) + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="train with basic custom trainer - no options", + expected_status=SUCCESS, + config={ + "runtime": types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ), + "trainer": types.CustomTrainer( + func=dummy_training_function, + packages_to_install=["numpy", "torch"], + ), + "options": [], + }, + ), + TestCase( + name="train with custom trainer and environment variables", + expected_status=SUCCESS, + config={ + "runtime": types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ), + "trainer": types.CustomTrainer( + func=dummy_training_function, + packages_to_install=["torch"], + env={"CUDA_VISIBLE_DEVICES": "0", "OMP_NUM_THREADS": "4"}, + ), + "options": [], + }, + ), + TestCase( + name="train rejects kubernetes labels option", + expected_status=FAILED, + config={ + "runtime": types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ), + "trainer": types.CustomTrainer( + func=dummy_training_function, + ), + "options": [Labels({"app": "test"})], + }, + expected_error=ValueError, + ), + TestCase( + name="train rejects kubernetes annotations option", + expected_status=FAILED, + config={ + "runtime": types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ), + "trainer": types.CustomTrainer( + func=dummy_training_function, + ), + "options": [Annotations({"description": "test"})], + }, + expected_error=ValueError, + ), + TestCase( + name="train rejects pod template overrides option", + expected_status=FAILED, + config={ + "runtime": types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + num_nodes=1, + ), + ), + "trainer": types.CustomTrainer( + func=dummy_training_function, + ), + "options": [ + PodTemplateOverrides( + PodTemplateOverride( + target_jobs=["node"], + ) + ) + ], + }, + expected_error=ValueError, + ), + TestCase( + name="train fails without runtime", + expected_status=FAILED, + config={ + "runtime": None, + "trainer": types.CustomTrainer( + func=dummy_training_function, + ), + "options": [], + }, + expected_error=ValueError, + ), + TestCase( + name="train fails without custom trainer", + expected_status=FAILED, + config={ + "runtime": types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + 
framework="torch", + num_nodes=1, + ), + ), + "trainer": None, + }, + expected_error=ValueError, + ), + ], +) +def test_train(local_backend, mock_train_environment, test_case): + """Test LocalProcessBackend.train() with success and failure cases.""" + runtime = test_case.config.get("runtime") + trainer = test_case.config.get("trainer") + options = test_case.config.get("options", []) + + mocks = mock_train_environment + + if test_case.expected_status == FAILED: + with pytest.raises(test_case.expected_error) as exc_info: + local_backend.train( + runtime=runtime, + trainer=trainer, + options=options, + ) + + # Verify specific error messages + error_msg = str(exc_info.value) + if "rejects kubernetes" in test_case.name: + assert "not compatible with" in error_msg + elif "without runtime" in test_case.name: + assert "Runtime must be provided" in error_msg + elif "without custom trainer" in test_case.name: + assert "CustomTrainer must be set" in error_msg + else: + train_job_name = local_backend.train( + runtime=runtime, + trainer=trainer, + options=options, + ) + + assert train_job_name is not None + assert len(train_job_name) > 0 + mocks["start"].assert_called_once() + mocks["get_trainer"].assert_called_once() + mocks["get_script"].assert_called_once() + + # Verify job is tracked + jobs = local_backend.list_jobs(runtime=runtime) + assert any(job.name == train_job_name for job in jobs) + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="get_nonexistent_job", + expected_status=FAILED, + config={"job_name": "nonexistent-job"}, + expected_error=ValueError, + ), + ], +) +def test_get_job(local_backend, test_case): + """Test LocalProcessBackend.get_job().""" + job_name = test_case.config.get("job_name") + + if test_case.expected_status == FAILED: + with pytest.raises(test_case.expected_error): + local_backend.get_job(job_name) + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="list_jobs_empty", + expected_status=SUCCESS, + config={"runtime": None}, + ), + ], +) +def test_list_jobs(local_backend, test_case): + """Test LocalProcessBackend.list_jobs().""" + runtime = test_case.config.get("runtime") + jobs = local_backend.list_jobs(runtime=runtime) + assert isinstance(jobs, list) + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="get_logs_nonexistent_job", + expected_status=FAILED, + config={"job_name": "nonexistent-job", "step": "train"}, + expected_error=ValueError, + ), + ], +) +def test_get_job_logs(local_backend, test_case): + """Test LocalProcessBackend.get_job_logs().""" + job_name = test_case.config.get("job_name") + step = test_case.config.get("step", "train") + + if test_case.expected_status == FAILED: + with pytest.raises(test_case.expected_error): + list(local_backend.get_job_logs(job_name, step=step)) + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="wait_for_nonexistent_job", + expected_status=FAILED, + config={"job_name": "nonexistent-job"}, + expected_error=ValueError, + ), + ], +) +def test_wait_for_job_status(local_backend, test_case): + """Test LocalProcessBackend.wait_for_job_status().""" + job_name = test_case.config.get("job_name") + + if test_case.expected_status == FAILED: + with pytest.raises(test_case.expected_error): + local_backend.wait_for_job_status(job_name) + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="delete_nonexistent_job", + expected_status=FAILED, + config={"job_name": "nonexistent-job"}, + expected_error=ValueError, + ), + ], +) +def test_delete_job(local_backend, 
test_case): + """Test LocalProcessBackend.delete_job().""" + job_name = test_case.config.get("job_name") + + if test_case.expected_status == FAILED: + with pytest.raises(test_case.expected_error): + local_backend.delete_job(job_name) + + +def test_name_option_sets_job_name(local_backend, mock_train_environment): + """Test that Name option sets the custom job name.""" + custom_name = "my-custom-job-name" + + def dummy_func(): + pass + + runtime = types.Runtime( + name=TORCH_RUNTIME, + trainer=types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="torch", + ), + ) + + trainer = types.CustomTrainer(func=dummy_func) + options = [Name(name=custom_name)] + + job_name = local_backend.train( + runtime=runtime, + trainer=trainer, + options=options, + ) + + assert job_name == custom_name diff --git a/kubeflow/trainer/backends/localprocess/types.py b/kubeflow/trainer/backends/localprocess/types.py index 09bd452a8..ebd4c3f85 100644 --- a/kubeflow/trainer/backends/localprocess/types.py +++ b/kubeflow/trainer/backends/localprocess/types.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass, field from datetime import datetime import typing diff --git a/kubeflow/trainer/backends/localprocess/utils.py b/kubeflow/trainer/backends/localprocess/utils.py index 1c18676e8..94c20a3a7 100644 --- a/kubeflow/trainer/backends/localprocess/utils.py +++ b/kubeflow/trainer/backends/localprocess/utils.py @@ -122,7 +122,8 @@ def get_local_runtime_trainer( Get the LocalRuntimeTrainer object. """ local_runtime = next( - (rt for rt in local_exec_constants.local_runtimes if rt.name == runtime_name), None + (rt for rt in local_exec_constants.local_runtimes if rt.name == runtime_name), + None, ) if not local_runtime: raise ValueError(f"Runtime {runtime_name} not found") @@ -267,9 +268,11 @@ def get_local_train_job_script( dependency_script = "\n" if trainer.packages_to_install: dependency_script = get_dependencies_command( - pip_index_urls=trainer.pip_index_urls - if trainer.pip_index_urls - else constants.DEFAULT_PIP_INDEX_URLS, + pip_index_urls=( + trainer.pip_index_urls + if trainer.pip_index_urls + else constants.DEFAULT_PIP_INDEX_URLS + ), runtime_packages=runtime_trainer.packages, trainer_packages=trainer.packages_to_install, quiet=False, diff --git a/kubeflow/trainer/constants/constants.py b/kubeflow/trainer/constants/constants.py index a15d402ae..52932e603 100644 --- a/kubeflow/trainer/constants/constants.py +++ b/kubeflow/trainer/constants/constants.py @@ -167,3 +167,7 @@ DEFAULT_FRAMEWORK_IMAGES = { "torch": "pytorch/pytorch:2.7.1-cuda12.8-cudnn9-runtime", } + +# The length of the UUID suffix for auto-generated job names. +# Total name length = 1 (random letter) + 11 (UUID hex) = 12 characters +JOB_NAME_UUID_LENGTH = 11 diff --git a/kubeflow/trainer/options/__init__.py b/kubeflow/trainer/options/__init__.py new file mode 100644 index 000000000..cd4f7b618 --- /dev/null +++ b/kubeflow/trainer/options/__init__.py @@ -0,0 +1,52 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training options for the Kubeflow Trainer SDK. + +All options are available from this single import location: + from kubeflow.trainer.options import Name, Labels, PodTemplateOverrides, ... + +Options self-validate their backend compatibility at runtime. +Check each option's docstring for supported backends. +""" + +from kubeflow.trainer.options.common import Name +from kubeflow.trainer.options.kubernetes import ( + Annotations, + ContainerOverride, + Labels, + PodSpecOverride, + PodTemplateOverride, + PodTemplateOverrides, + SpecAnnotations, + SpecLabels, + TrainerArgs, + TrainerCommand, +) + +__all__ = [ + # Common options (all backends) + "Name", + # Kubernetes options + "Annotations", + "ContainerOverride", + "Labels", + "PodSpecOverride", + "PodTemplateOverride", + "PodTemplateOverrides", + "SpecAnnotations", + "SpecLabels", + "TrainerArgs", + "TrainerCommand", +] diff --git a/kubeflow/trainer/options/common.py b/kubeflow/trainer/options/common.py new file mode 100644 index 000000000..abcad086d --- /dev/null +++ b/kubeflow/trainer/options/common.py @@ -0,0 +1,51 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common options and helper classes used across multiple backends.""" + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from kubeflow.trainer.backends.base import RuntimeBackend +from kubeflow.trainer.types.types import BuiltinTrainer, CustomTrainer, CustomTrainerContainer + + +@dataclass +class Name: + """Set a custom name for the TrainJob resource. + + This option works with all backends. + + Args: + name: Custom name for the job. Must be a valid identifier. + """ + + name: str + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[BuiltinTrainer, CustomTrainer, CustomTrainerContainer]], + backend: RuntimeBackend, + ) -> None: + """Apply custom name to the job specification. + + Args: + job_spec: Job specification dictionary to modify. + trainer: Optional trainer instance for context. + backend: Backend instance for validation and context. + """ + # Name option is generic - works with all backends + metadata = job_spec.setdefault("metadata", {}) + metadata["name"] = self.name diff --git a/kubeflow/trainer/options/kubernetes.py b/kubeflow/trainer/options/kubernetes.py new file mode 100644 index 000000000..a5b7596d9 --- /dev/null +++ b/kubeflow/trainer/options/kubernetes.py @@ -0,0 +1,498 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Kubernetes-specific training options for the Kubeflow Trainer SDK.""" + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from kubeflow.trainer.backends.base import RuntimeBackend +from kubeflow.trainer.types.types import BuiltinTrainer, CustomTrainer, CustomTrainerContainer + + +@dataclass +class ContainerOverride: + """Configuration for overriding a specific container in a pod. + + Args: + name: Name of the container to override (must exist in TrainingRuntime). + env: Environment variables to add/merge with the container. + Each dict should have 'name' and 'value' or 'valueFrom' keys. + volume_mounts: Volume mounts to add/merge with the container. + Each dict should have 'name' and 'mountPath' keys at minimum. + """ + + name: str + env: Optional[list[dict]] = None + volume_mounts: Optional[list[dict]] = None + + def __post_init__(self): + """Validate the container override configuration.""" + # Validate container name + if not self.name or not self.name.strip(): + raise ValueError("Container name must be a non-empty string") + + if self.env is not None: + if not isinstance(self.env, list): + raise ValueError("env must be a list of dictionaries") + for env_var in self.env: + if not isinstance(env_var, dict): + raise ValueError("Each env entry must be a dictionary") + if "name" not in env_var: + raise ValueError("Each env entry must have a 'name' key") + if not env_var.get("name"): + raise ValueError("env 'name' must be a non-empty string") + if "value" not in env_var and "valueFrom" not in env_var: + raise ValueError("Each env entry must have either 'value' or 'valueFrom' key") + # Validate valueFrom structure if present + if "valueFrom" in env_var: + value_from = env_var["valueFrom"] + if not isinstance(value_from, dict): + raise ValueError("env 'valueFrom' must be a dictionary") + # valueFrom must have one of these keys + valid_keys = {"configMapKeyRef", "secretKeyRef", "fieldRef", "resourceFieldRef"} + if not any(key in value_from for key in valid_keys): + raise ValueError( + f"env 'valueFrom' must contain one of: {', '.join(valid_keys)}" + ) + + if self.volume_mounts is not None: + if not isinstance(self.volume_mounts, list): + raise ValueError("volume_mounts must be a list of dictionaries") + for mount in self.volume_mounts: + if not isinstance(mount, dict): + raise ValueError("Each volume_mounts entry must be a dictionary") + if "name" not in mount: + raise ValueError("Each volume_mounts entry must have a 'name' key") + if not mount.get("name"): + raise ValueError("volume_mounts 'name' must be a non-empty string") + if "mountPath" not in mount: + raise ValueError("Each volume_mounts entry must have a 'mountPath' key") + mount_path = mount.get("mountPath") + if not mount_path or not isinstance(mount_path, str): + raise ValueError("volume_mounts 'mountPath' must be a non-empty string") + if not mount_path.startswith("/"): + raise ValueError( + f"volume_mounts 'mountPath' must be an absolute path " + f"(start with /): {mount_path}" + ) + + +@dataclass +class PodSpecOverride: + """Configuration for overriding pod template specifications. 
+ + Args: + service_account_name: Service account to use for the pods. + node_selector: Node selector to place pods on specific nodes. + affinity: Affinity rules for pod scheduling. + tolerations: Tolerations for pod scheduling. + volumes: Volumes to add/merge with the pod. + init_containers: Init containers to add/merge with the pod. + containers: Containers to add/merge with the pod. + scheduling_gates: Scheduling gates for the pods. + image_pull_secrets: Image pull secrets for the pods. + """ + + service_account_name: Optional[str] = None + node_selector: Optional[dict[str, str]] = None + affinity: Optional[dict] = None + tolerations: Optional[list[dict]] = None + volumes: Optional[list[dict]] = None + init_containers: Optional[list[ContainerOverride]] = None + containers: Optional[list[ContainerOverride]] = None + scheduling_gates: Optional[list[dict]] = None + image_pull_secrets: Optional[list[dict]] = None + + +@dataclass +class PodTemplateOverride: + """Configuration for overriding pod templates for specific job types. + + Args: + target_jobs: List of job names to apply the overrides to (e.g., ["node", "launcher"]). + metadata: Metadata overrides for the pod template (labels, annotations). + spec: Spec overrides for the pod template. + """ + + target_jobs: list[str] + metadata: Optional[dict] = None + spec: Optional[PodSpecOverride] = None + + +@dataclass +class Labels: + """Add labels to the TrainJob resource metadata (.metadata.labels). + + Supported backends: + - Kubernetes + + Args: + labels: Dictionary of label key-value pairs to add to TrainJob metadata. + """ + + labels: dict[str, str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[CustomTrainer, BuiltinTrainer]], + backend: RuntimeBackend, + ) -> None: + """Apply labels to the job specification. + + Args: + job_spec: Job specification dictionary to modify. + trainer: Optional trainer instance for context. + backend: Backend instance for validation. + + Raises: + ValueError: If backend does not support labels. + """ + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + if not isinstance(backend, KubernetesBackend): + raise ValueError( + f"Labels option is not compatible with {type(backend).__name__}. " + f"Supported backends: KubernetesBackend" + ) + + metadata = job_spec.setdefault("metadata", {}) + metadata["labels"] = self.labels + + +@dataclass +class Annotations: + """Add annotations to the TrainJob resource metadata (.metadata.annotations). + + Supported backends: + - Kubernetes + + Args: + annotations: Dictionary of annotation key-value pairs to add to TrainJob metadata. + """ + + annotations: dict[str, str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[CustomTrainer, BuiltinTrainer]], + backend: RuntimeBackend, + ) -> None: + """Apply annotations to the job specification. + + Args: + job_spec: Job specification dictionary to modify. + trainer: Optional trainer instance for context. + backend: Backend instance for validation. + + Raises: + ValueError: If backend does not support annotations. + """ + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + if not isinstance(backend, KubernetesBackend): + raise ValueError( + f"Annotations option is not compatible with {type(backend).__name__}. 
" + f"Supported backends: KubernetesBackend" + ) + + metadata = job_spec.setdefault("metadata", {}) + metadata["annotations"] = self.annotations + + +@dataclass +class SpecLabels: + """Add labels to derivative JobSet and Jobs (.spec.labels). + + These labels will be merged with the TrainingRuntime values and applied to + the JobSet and Jobs created by the TrainJob. + + Supported backends: + - Kubernetes + + Args: + labels: Dictionary of label key-value pairs to add to JobSet and Jobs. + """ + + labels: dict[str, str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[CustomTrainer, BuiltinTrainer]], + backend: RuntimeBackend, + ) -> None: + """Apply spec-level labels to the job specification. + + Args: + job_spec: Job specification dictionary to modify. + trainer: Optional trainer instance for context. + backend: Backend instance for validation. + + Raises: + ValueError: If backend does not support spec labels. + """ + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + if not isinstance(backend, KubernetesBackend): + raise ValueError( + f"SpecLabels option is not compatible with {type(backend).__name__}. " + f"Supported backends: KubernetesBackend" + ) + + spec = job_spec.setdefault("spec", {}) + spec["labels"] = self.labels + + +@dataclass +class SpecAnnotations: + """Add annotations to derivative JobSet and Jobs (.spec.annotations). + + These annotations will be merged with the TrainingRuntime values and applied to + the JobSet and Jobs created by the TrainJob. + + Supported backends: + - Kubernetes + + Args: + annotations: Dictionary of annotation key-value pairs to add to JobSet and Jobs. + """ + + annotations: dict[str, str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[CustomTrainer, BuiltinTrainer]], + backend: RuntimeBackend, + ) -> None: + """Apply spec-level annotations to the job specification. + + Args: + job_spec: Job specification dictionary to modify. + trainer: Optional trainer instance for context. + backend: Backend instance for validation. + + Raises: + ValueError: If backend does not support spec annotations. + """ + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + if not isinstance(backend, KubernetesBackend): + raise ValueError( + f"SpecAnnotations option is not compatible with {type(backend).__name__}. " + f"Supported backends: KubernetesBackend" + ) + + spec = job_spec.setdefault("spec", {}) + spec["annotations"] = self.annotations + + +class PodTemplateOverrides: + """Add pod template overrides to the TrainJob (.spec.podTemplateOverrides). + + Supported backends: + - Kubernetes + + Args: + *overrides: One or more PodTemplateOverride objects. + """ + + def __init__(self, *overrides: PodTemplateOverride): + """Initialize with variable number of PodTemplateOverride objects.""" + if not overrides: + raise ValueError("At least one PodTemplateOverride must be provided") + self.pod_overrides = list(overrides) + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[CustomTrainer, BuiltinTrainer]], + backend: RuntimeBackend, + ) -> None: + """Apply pod template overrides to the job specification. + + Args: + job_spec: Job specification dictionary to modify. + trainer: Optional trainer instance for context. + backend: Backend instance for validation. + + Raises: + ValueError: If backend does not support pod template overrides. 
+ """ + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + if not isinstance(backend, KubernetesBackend): + raise ValueError( + f"PodTemplateOverrides option is not compatible with {type(backend).__name__}. " + f"Supported backends: KubernetesBackend" + ) + spec = job_spec.setdefault("spec", {}) + pod_overrides = spec.setdefault("podTemplateOverrides", []) + + for override in self.pod_overrides: + api_override = {"targetJobs": [{"name": job} for job in override.target_jobs]} + + if override.metadata: + api_override["metadata"] = override.metadata + + if override.spec: + spec_dict = {} + + if override.spec.service_account_name: + spec_dict["serviceAccountName"] = override.spec.service_account_name + if override.spec.node_selector: + spec_dict["nodeSelector"] = override.spec.node_selector + if override.spec.affinity: + spec_dict["affinity"] = override.spec.affinity + if override.spec.tolerations: + spec_dict["tolerations"] = override.spec.tolerations + if override.spec.volumes: + spec_dict["volumes"] = override.spec.volumes + if override.spec.scheduling_gates: + spec_dict["schedulingGates"] = override.spec.scheduling_gates + if override.spec.image_pull_secrets: + spec_dict["imagePullSecrets"] = override.spec.image_pull_secrets + + # Handle container overrides + if override.spec.init_containers: + spec_dict["initContainers"] = [] + for container in override.spec.init_containers: + container_dict = {"name": container.name} + if container.env: + container_dict["env"] = container.env + if container.volume_mounts: + container_dict["volumeMounts"] = container.volume_mounts + spec_dict["initContainers"].append(container_dict) + + if override.spec.containers: + spec_dict["containers"] = [] + for container in override.spec.containers: + container_dict = {"name": container.name} + if container.env: + container_dict["env"] = container.env + if container.volume_mounts: + container_dict["volumeMounts"] = container.volume_mounts + spec_dict["containers"].append(container_dict) + + if spec_dict: + api_override["spec"] = spec_dict + + pod_overrides.append(api_override) + + +@dataclass +class TrainerCommand: + """Override the trainer container command (.spec.trainer.command). + + Can only be used with CustomTrainerContainer. CustomTrainer generates its own + command from the function, and BuiltinTrainer uses pre-configured commands. + + Supported backends: + - Kubernetes + + Args: + command: List of command strings to override the default trainer command. + """ + + command: list[str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[CustomTrainer, BuiltinTrainer, CustomTrainerContainer]], + backend: RuntimeBackend, + ) -> None: + """Apply trainer command override to the job specification. + + Args: + job_spec: The job specification to modify. + trainer: Optional trainer context for validation. + backend: Backend instance for validation. + + Raises: + ValueError: If backend doesn't support or trainer type conflicts. + """ + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + if not isinstance(backend, KubernetesBackend): + raise ValueError( + f"TrainerCommand option is not compatible with {type(backend).__name__}. " + f"Supported backends: KubernetesBackend" + ) + + if trainer is not None and not isinstance(trainer, CustomTrainerContainer): + raise ValueError( + "TrainerCommand can only be used with CustomTrainerContainer. 
" + "CustomTrainer generates its own command from the function, and " + "BuiltinTrainer uses pre-configured commands." + ) + + spec = job_spec.setdefault("spec", {}) + trainer_spec = spec.setdefault("trainer", {}) + trainer_spec["command"] = self.command + + +@dataclass +class TrainerArgs: + """Override the trainer container arguments (.spec.trainer.args). + + Can only be used with CustomTrainerContainer. CustomTrainer generates its own + arguments from the function, and BuiltinTrainer uses pre-configured arguments. + + Supported backends: + - Kubernetes + + Args: + args: List of argument strings to override the default trainer arguments. + """ + + args: list[str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union[CustomTrainer, BuiltinTrainer, CustomTrainerContainer]], + backend: RuntimeBackend, + ) -> None: + """Apply trainer args override to the job specification. + + Args: + job_spec: The job specification to modify. + trainer: Optional trainer context for validation. + backend: Backend instance for validation. + + Raises: + ValueError: If backend doesn't support or trainer type conflicts. + """ + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + if not isinstance(backend, KubernetesBackend): + raise ValueError( + f"TrainerArgs option is not compatible with {type(backend).__name__}. " + f"Supported backends: KubernetesBackend" + ) + + if trainer is not None and not isinstance(trainer, CustomTrainerContainer): + raise ValueError( + "TrainerArgs can only be used with CustomTrainerContainer. " + "CustomTrainer generates its own arguments from the function, and " + "BuiltinTrainer uses pre-configured arguments." + ) + + spec = job_spec.setdefault("spec", {}) + trainer_spec = spec.setdefault("trainer", {}) + trainer_spec["args"] = self.args diff --git a/kubeflow/trainer/options/kubernetes_test.py b/kubeflow/trainer/options/kubernetes_test.py new file mode 100644 index 000000000..73c897fe0 --- /dev/null +++ b/kubeflow/trainer/options/kubernetes_test.py @@ -0,0 +1,259 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for Kubernetes options.""" + +import pytest + +from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend +from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend +from kubeflow.trainer.options import ( + Annotations, + ContainerOverride, + Labels, + Name, + PodTemplateOverride, + PodTemplateOverrides, + SpecAnnotations, + SpecLabels, + TrainerArgs, + TrainerCommand, +) + + +@pytest.fixture +def mock_kubernetes_backend(): + """Mock Kubernetes backend for testing.""" + from unittest.mock import Mock + + backend = Mock(spec=KubernetesBackend) + backend.__class__ = KubernetesBackend + return backend + + +@pytest.fixture +def mock_localprocess_backend(): + """Mock LocalProcess backend for testing.""" + from unittest.mock import MagicMock + + # Create a proper mock that isinstance checks will work with + backend = MagicMock(spec=LocalProcessBackend) + # Make type(backend).__name__ return the correct class name + type(backend).__name__ = "LocalProcessBackend" + return backend + + +class TestKubernetesOptionBackendValidation: + """Test that Kubernetes options validate backend compatibility.""" + + @pytest.mark.parametrize( + "option_class,option_args", + [ + (Labels, {"app": "test", "version": "v1"}), + (Annotations, {"description": "test job"}), + (SpecLabels, {"app": "training"}), + (SpecAnnotations, {"prometheus.io/scrape": "true"}), + (TrainerCommand, ["python", "train.py"]), + (TrainerArgs, ["--epochs", "10"]), + ], + ) + def test_kubernetes_options_reject_wrong_backend( + self, option_class, option_args, mock_localprocess_backend + ): + """Test Kubernetes-specific options reject non-Kubernetes backends.""" + if option_class == TrainerCommand: + option = option_class(command=option_args) + elif option_class == TrainerArgs: + option = option_class(args=option_args) + else: + option = option_class(option_args) + + job_spec = {} + + with pytest.raises(ValueError) as exc_info: + option(job_spec, None, mock_localprocess_backend) + + assert "not compatible with" in str(exc_info.value) + assert "LocalProcessBackend" in str(exc_info.value) + + def test_pod_template_overrides_rejects_wrong_backend(self, mock_localprocess_backend): + """Test PodTemplateOverrides rejects non-Kubernetes backends.""" + override = PodTemplateOverride(target_jobs=["node"]) + option = PodTemplateOverrides(override) + + job_spec = {} + + with pytest.raises(ValueError) as exc_info: + option(job_spec, None, mock_localprocess_backend) + + assert "not compatible with" in str(exc_info.value) + + +class TestKubernetesOptionApplication: + """Test Kubernetes option application behavior.""" + + @pytest.mark.parametrize( + "option_class,option_args,expected_spec", + [ + ( + Labels, + {"app": "test", "version": "v1"}, + {"metadata": {"labels": {"app": "test", "version": "v1"}}}, + ), + ( + Annotations, + {"description": "test job"}, + {"metadata": {"annotations": {"description": "test job"}}}, + ), + ( + SpecLabels, + {"app": "training", "version": "v1.0"}, + {"spec": {"labels": {"app": "training", "version": "v1.0"}}}, + ), + ( + SpecAnnotations, + {"prometheus.io/scrape": "true"}, + {"spec": {"annotations": {"prometheus.io/scrape": "true"}}}, + ), + (Name, "custom-job-name", {"metadata": {"name": "custom-job-name"}}), + ( + TrainerCommand, + ["python", "train.py"], + {"spec": {"trainer": {"command": ["python", "train.py"]}}}, + ), + ( + TrainerArgs, + ["--epochs", "10"], + {"spec": {"trainer": {"args": ["--epochs", "10"]}}}, + ), + ], + ) + def test_option_application( + self, 
option_class, option_args, expected_spec, mock_kubernetes_backend + ): + """Test each option applies correctly to job spec with Kubernetes backend.""" + if option_class == TrainerCommand: + option = option_class(command=option_args) + elif option_class == TrainerArgs: + option = option_class(args=option_args) + else: + option = option_class(option_args) + + job_spec = {} + option(job_spec, None, mock_kubernetes_backend) + + assert job_spec == expected_spec + + +class TestTrainerOptionValidation: + """Test validation of trainer-specific options.""" + + @pytest.mark.parametrize( + "option_class,option_args,trainer_type,should_fail", + [ + # Validation failures + (TrainerCommand, ["python", "train.py"], "CustomTrainer", True), + (TrainerArgs, ["--epochs", "10"], "CustomTrainer", True), + (TrainerCommand, ["python", "train.py"], "BuiltinTrainer", True), + (TrainerArgs, ["--epochs", "10"], "BuiltinTrainer", True), + # Successful applications + (TrainerCommand, ["python", "train.py"], "CustomTrainerContainer", False), + (TrainerArgs, ["--epochs", "10"], "CustomTrainerContainer", False), + ], + ) + def test_trainer_option_validation( + self, option_class, option_args, trainer_type, should_fail, mock_kubernetes_backend + ): + """Test trainer option validation with different trainer types.""" + from kubeflow.trainer.types.types import ( + BuiltinTrainer, + CustomTrainer, + CustomTrainerContainer, + TorchTuneConfig, + ) + + # Create appropriate trainer instance + if trainer_type == "CustomTrainer": + + def dummy_func(): + pass + + trainer = CustomTrainer(func=dummy_func) + elif trainer_type == "BuiltinTrainer": + trainer = BuiltinTrainer(config=TorchTuneConfig()) + else: # CustomTrainerContainer + trainer = CustomTrainerContainer(image="custom-image:latest") + + # Create option + if option_class == TrainerCommand: + option = option_class(command=option_args) + else: # TrainerArgs + option = option_class(args=option_args) + + job_spec = {} + + if should_fail: + with pytest.raises(ValueError) as exc_info: + option(job_spec, trainer, mock_kubernetes_backend) + assert "TrainerCommand can only be used with CustomTrainerContainer" in str( + exc_info.value + ) or "TrainerArgs can only be used with CustomTrainerContainer" in str(exc_info.value) + else: + option(job_spec, trainer, mock_kubernetes_backend) + if option_class == TrainerCommand: + assert job_spec["spec"]["trainer"]["command"] == option_args + else: + assert job_spec["spec"]["trainer"]["args"] == option_args + + +class TestContainerOverride: + """Test ContainerOverride validation.""" + + @pytest.mark.parametrize( + "kwargs,expected_error", + [ + ({"name": ""}, "Container name must be a non-empty string"), + ( + {"name": "trainer", "env": [{"invalid": "structure"}]}, + "Each env entry must have a 'name' key", + ), + ( + {"name": "trainer", "volume_mounts": [{"name": "vol"}]}, + "Each volume_mounts entry must have a 'mountPath' key", + ), + ], + ) + def test_container_override_validation(self, kwargs, expected_error): + """Test ContainerOverride validates inputs correctly.""" + with pytest.raises(ValueError) as exc_info: + ContainerOverride(**kwargs) + assert expected_error in str(exc_info.value) + + +class TestPodTemplateOverrides: + """Test PodTemplateOverrides functionality.""" + + def test_pod_template_overrides_basic(self, mock_kubernetes_backend): + """Test basic PodTemplateOverrides application.""" + + override = PodTemplateOverride(target_jobs=["node"]) + option = PodTemplateOverrides(override) + + job_spec = {} + option(job_spec, None, 
mock_kubernetes_backend) + + assert "spec" in job_spec + assert "podTemplateOverrides" in job_spec["spec"] + assert len(job_spec["spec"]["podTemplateOverrides"]) == 1 + assert job_spec["spec"]["podTemplateOverrides"][0]["targetJobs"] == [{"name": "node"}] diff --git a/kubeflow/trainer/options/localprocess.py b/kubeflow/trainer/options/localprocess.py new file mode 100644 index 000000000..1bd246f8b --- /dev/null +++ b/kubeflow/trainer/options/localprocess.py @@ -0,0 +1,20 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LocalProcess-specific training options for the Kubeflow Trainer SDK. + +TODO: Add LocalProcess options (ProcessTimeout, WorkingDirectory, etc.) in future iteration. +""" + +# TODO: Implement LocalProcess options using LocalProcessCompatible base class
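
Illustrative usage of the options API introduced in this diff (a minimal sketch, assuming a reachable Kubernetes cluster via the default kubeconfig; the job name, label values, and training function below are placeholders, not part of the change):

    from kubeflow.trainer.api.trainer_client import TrainerClient
    from kubeflow.trainer.options import Annotations, Labels, Name, SpecLabels
    from kubeflow.trainer.types import types

    def train_func():
        # Placeholder training function for illustration.
        print("training...")

    client = TrainerClient()  # defaults to the Kubernetes backend
    job_name = client.train(
        trainer=types.CustomTrainer(func=train_func),
        options=[
            Name("my-train-job"),                # supported by all backends
            Labels({"team": "ml-platform"}),     # TrainJob .metadata.labels
            Annotations({"created-by": "sdk"}),  # TrainJob .metadata.annotations
            SpecLabels({"app": "training"}),     # merged into derivative JobSet/Jobs (.spec.labels)
        ],
    )

Because each option validates the backend type in its __call__ method, passing a Kubernetes-only option such as Labels or PodTemplateOverrides to a LocalProcessBackend client raises ValueError before any job is created.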