Skip to content

Commit dc0fb12

Browse files
feat: Add simplified Training Options for Kubernetes backend type with podTemplateOverrides
Signed-off-by: Abhijeet Dhumal <[email protected]>
1 parent b87b81b commit dc0fb12

File tree

19 files changed

+1670
-43
lines changed

19 files changed

+1670
-43
lines changed

Makefile

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,8 @@ release: install-dev
7575
.PHONY: test-python
7676
test-python: uv-venv ## Run Python unit tests
7777
@uv sync
78-
@uv run coverage run --source=kubeflow.trainer.backends.kubernetes.backend,kubeflow.trainer.utils.utils -m pytest \
79-
./kubeflow/trainer/backends/kubernetes/backend_test.py \
80-
./kubeflow/trainer/backends/kubernetes/utils_test.py
81-
@uv run coverage report -m \
82-
kubeflow/trainer/backends/kubernetes/backend.py \
83-
kubeflow/trainer/backends/kubernetes/utils.py
78+
@uv run coverage run --source=kubeflow -m pytest ./kubeflow/
79+
@uv run coverage report --omit='*_test.py' --skip-covered --skip-empty
8480
ifeq ($(report),xml)
8581
@uv run coverage xml
8682
else

kubeflow/trainer/api/trainer_client.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def train(
107107
trainer: Optional[
108108
Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
109109
] = None,
110+
options: Optional[list] = None,
110111
) -> str:
111112
"""Create a TrainJob. You can configure the TrainJob using one of these trainers:
112113
@@ -124,6 +125,8 @@ def train(
124125
trainer: Optional configuration for a CustomTrainer, CustomTrainerContainer, or
125126
BuiltinTrainer. If not specified, the TrainJob will use the
126127
runtime's default values.
128+
options: Optional list of configuration options to apply to the TrainJob.
129+
Options can be imported from kubeflow.trainer.options.
127130
128131
Returns:
129132
The unique name of the TrainJob that has been generated.
@@ -133,7 +136,12 @@ def train(
133136
TimeoutError: Timeout to create TrainJobs.
134137
RuntimeError: Failed to create TrainJobs.
135138
"""
136-
return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer)
139+
return self.backend.train(
140+
runtime=runtime,
141+
initializer=initializer,
142+
trainer=trainer,
143+
options=options,
144+
)
137145

138146
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
139147
"""List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Unit tests for TrainerClient backend selection.
17+
"""
18+
19+
from unittest.mock import Mock, patch
20+
21+
import pytest
22+
23+
from kubeflow.common.types import KubernetesBackendConfig
24+
from kubeflow.trainer.api.trainer_client import TrainerClient
25+
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
26+
27+
28+
@pytest.mark.parametrize(
29+
"test_case",
30+
[
31+
{
32+
"name": "default_backend_is_kubernetes",
33+
"backend_config": None,
34+
"expected_backend": "KubernetesBackend",
35+
"use_k8s_mocks": True,
36+
},
37+
{
38+
"name": "local_process_backend_selection",
39+
"backend_config": LocalProcessBackendConfig(),
40+
"expected_backend": "LocalProcessBackend",
41+
"use_k8s_mocks": False,
42+
},
43+
{
44+
"name": "kubernetes_backend_selection",
45+
"backend_config": KubernetesBackendConfig(),
46+
"expected_backend": "KubernetesBackend",
47+
"use_k8s_mocks": True,
48+
},
49+
],
50+
)
51+
def test_backend_selection(test_case):
52+
"""Test TrainerClient backend selection logic."""
53+
if test_case["use_k8s_mocks"]:
54+
with (
55+
patch("kubernetes.config.load_kube_config"),
56+
patch("kubernetes.client.CustomObjectsApi") as mock_custom_api,
57+
patch("kubernetes.client.CoreV1Api") as mock_core_api,
58+
):
59+
mock_custom_api.return_value = Mock()
60+
mock_core_api.return_value = Mock()
61+
62+
if test_case["backend_config"]:
63+
client = TrainerClient(backend_config=test_case["backend_config"])
64+
else:
65+
client = TrainerClient()
66+
67+
backend_name = client.backend.__class__.__name__
68+
assert backend_name == test_case["expected_backend"]
69+
else:
70+
client = TrainerClient(backend_config=test_case["backend_config"])
71+
backend_name = client.backend.__class__.__name__
72+
assert backend_name == test_case["expected_backend"]

kubeflow/trainer/backends/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121

2222

2323
class RuntimeBackend(abc.ABC):
24+
"""Base class for runtime backends.
25+
26+
Options self-validate by checking the backend instance type in their __call__ method.
27+
"""
28+
2429
@abc.abstractmethod
2530
def list_runtimes(self) -> list[types.Runtime]:
2631
raise NotImplementedError()
@@ -41,6 +46,7 @@ def train(
4146
trainer: Optional[
4247
Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
4348
] = None,
49+
options: Optional[list] = None,
4450
) -> str:
4551
raise NotImplementedError()
4652

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import re
2121
import string
2222
import time
23-
from typing import Optional, Union
23+
from typing import Any, Optional, Union
2424
import uuid
2525

2626
from kubeflow_trainer_api import models
@@ -87,15 +87,9 @@ def list_runtimes(self) -> list[types.Runtime]:
8787
result.append(self.__get_runtime_from_cr(runtime))
8888

8989
except multiprocessing.TimeoutError as e:
90-
raise TimeoutError(
91-
f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s "
92-
f"in namespace: {self.namespace}"
93-
) from e
90+
raise TimeoutError(f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s") from e
9491
except Exception as e:
95-
raise RuntimeError(
96-
f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s "
97-
f"in namespace: {self.namespace}"
98-
) from e
92+
raise RuntimeError(f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s") from e
9993

10094
return result
10195

@@ -184,16 +178,62 @@ def train(
184178
trainer: Optional[
185179
Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
186180
] = None,
181+
options: Optional[list] = None,
187182
) -> str:
188-
# Generate unique name for the TrainJob.
189-
train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]
183+
if runtime is None:
184+
runtime = self.get_runtime(constants.TORCH_RUNTIME)
185+
186+
# Process options to extract configuration
187+
job_spec = {}
188+
labels = None
189+
annotations = None
190+
name = None
191+
spec_labels = None
192+
spec_annotations = None
193+
trainer_overrides = {}
194+
pod_template_overrides = None
195+
196+
if options:
197+
for option in options:
198+
option(job_spec, trainer, self)
199+
200+
metadata_section = job_spec.get("metadata", {})
201+
labels = metadata_section.get("labels")
202+
annotations = metadata_section.get("annotations")
203+
name = metadata_section.get("name")
204+
205+
# Extract spec-level labels/annotations and other spec configurations
206+
spec_section = job_spec.get("spec", {})
207+
spec_labels = spec_section.get("labels")
208+
spec_annotations = spec_section.get("annotations")
209+
trainer_overrides = spec_section.get("trainer", {})
210+
pod_template_overrides = spec_section.get("podTemplateOverrides")
211+
212+
# Generate unique name for the TrainJob if not provided
213+
train_job_name = name or (
214+
random.choice(string.ascii_lowercase)
215+
+ uuid.uuid4().hex[: constants.JOB_NAME_UUID_LENGTH]
216+
)
217+
218+
# Build the TrainJob spec using the common _get_trainjob_spec method
219+
trainjob_spec = self._get_trainjob_spec(
220+
runtime=runtime,
221+
initializer=initializer,
222+
trainer=trainer,
223+
trainer_overrides=trainer_overrides,
224+
spec_labels=spec_labels,
225+
spec_annotations=spec_annotations,
226+
pod_template_overrides=pod_template_overrides,
227+
)
190228

191229
# Build the TrainJob.
192230
train_job = models.TrainerV1alpha1TrainJob(
193231
apiVersion=constants.API_VERSION,
194232
kind=constants.TRAINJOB_KIND,
195-
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name),
196-
spec=self._get_trainjob_spec(runtime, initializer, trainer),
233+
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
234+
name=train_job_name, labels=labels, annotations=annotations
235+
),
236+
spec=trainjob_spec,
197237
)
198238

199239
# Create the TrainJob.
@@ -549,6 +589,10 @@ def _get_trainjob_spec(
549589
trainer: Optional[
550590
Union[types.CustomTrainer, types.CustomTrainerContainer, types.BuiltinTrainer]
551591
] = None,
592+
trainer_overrides: Optional[dict[str, Any]] = None,
593+
spec_labels: Optional[dict[str, str]] = None,
594+
spec_annotations: Optional[dict[str, str]] = None,
595+
pod_template_overrides: Optional[models.IoK8sApiCoreV1PodTemplateSpec] = None,
552596
) -> models.TrainerV1alpha1TrainJobSpec:
553597
"""Get TrainJob spec from the given parameters"""
554598
if runtime is None:
@@ -575,9 +619,16 @@ def _get_trainjob_spec(
575619
else:
576620
raise ValueError(
577621
f"The trainer type {type(trainer)} is not supported. "
578-
"Please use CustomTrainer or BuiltinTrainer."
622+
"Please use CustomTrainer, CustomTrainerContainer, or BuiltinTrainer."
579623
)
580624

625+
# Apply command/args overrides collected from options; when present they replace any values already set on trainer_cr, whether or not a trainer was provided
626+
if trainer_overrides:
627+
if "command" in trainer_overrides:
628+
trainer_cr.command = trainer_overrides["command"]
629+
if "args" in trainer_overrides:
630+
trainer_cr.args = trainer_overrides["args"]
631+
581632
return models.TrainerV1alpha1TrainJobSpec(
582633
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name),
583634
trainer=(trainer_cr if trainer_cr != models.TrainerV1alpha1Trainer() else None),
@@ -589,4 +640,7 @@ def _get_trainjob_spec(
589640
if isinstance(initializer, types.Initializer)
590641
else None
591642
),
643+
labels=spec_labels,
644+
annotations=spec_annotations,
645+
pod_template_overrides=pod_template_overrides,
592646
)

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
3636
import kubeflow.trainer.backends.kubernetes.utils as utils
3737
from kubeflow.trainer.constants import constants
38+
from kubeflow.trainer.options import (
39+
Annotations,
40+
Labels,
41+
SpecAnnotations,
42+
SpecLabels,
43+
)
3844
from kubeflow.trainer.test.common import (
3945
DEFAULT_NAMESPACE,
4046
FAILED,
@@ -274,17 +280,27 @@ def get_train_job(
274280
runtime_name: str,
275281
train_job_name: str = BASIC_TRAIN_JOB_NAME,
276282
train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None,
283+
labels: Optional[dict[str, str]] = None,
284+
annotations: Optional[dict[str, str]] = None,
285+
spec_labels: Optional[dict[str, str]] = None,
286+
spec_annotations: Optional[dict[str, str]] = None,
277287
) -> models.TrainerV1alpha1TrainJob:
278288
"""
279289
Create a mock TrainJob object with optional trainer configuration, metadata labels/annotations, and spec-level labels/annotations.
280290
"""
281291
train_job = models.TrainerV1alpha1TrainJob(
282292
apiVersion=constants.API_VERSION,
283293
kind=constants.TRAINJOB_KIND,
284-
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name),
294+
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
295+
name=train_job_name,
296+
labels=labels,
297+
annotations=annotations,
298+
),
285299
spec=models.TrainerV1alpha1TrainJobSpec(
286300
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name),
287301
trainer=train_job_trainer,
302+
labels=spec_labels,
303+
annotations=spec_annotations,
288304
),
289305
)
290306

@@ -879,6 +895,58 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
879895
},
880896
expected_error=ValueError,
881897
),
898+
TestCase(
899+
name="train with metadata labels and annotations",
900+
expected_status=SUCCESS,
901+
config={
902+
"options": [
903+
Labels({"team": "ml-platform"}),
904+
Annotations({"created-by": "sdk"}),
905+
],
906+
},
907+
expected_output=get_train_job(
908+
runtime_name=TORCH_RUNTIME,
909+
train_job_name=BASIC_TRAIN_JOB_NAME,
910+
labels={"team": "ml-platform"},
911+
annotations={"created-by": "sdk"},
912+
),
913+
),
914+
TestCase(
915+
name="train with spec labels and annotations",
916+
expected_status=SUCCESS,
917+
config={
918+
"options": [
919+
SpecLabels({"app": "training", "version": "v1.0"}),
920+
SpecAnnotations({"prometheus.io/scrape": "true"}),
921+
],
922+
},
923+
expected_output=get_train_job(
924+
runtime_name=TORCH_RUNTIME,
925+
train_job_name=BASIC_TRAIN_JOB_NAME,
926+
spec_labels={"app": "training", "version": "v1.0"},
927+
spec_annotations={"prometheus.io/scrape": "true"},
928+
),
929+
),
930+
TestCase(
931+
name="train with both metadata and spec labels/annotations",
932+
expected_status=SUCCESS,
933+
config={
934+
"options": [
935+
Labels({"owner": "ml-team"}),
936+
Annotations({"description": "Fine-tuning job"}),
937+
SpecLabels({"app": "training", "version": "v1.0"}),
938+
SpecAnnotations({"prometheus.io/scrape": "true"}),
939+
],
940+
},
941+
expected_output=get_train_job(
942+
runtime_name=TORCH_RUNTIME,
943+
train_job_name=BASIC_TRAIN_JOB_NAME,
944+
labels={"owner": "ml-team"},
945+
annotations={"description": "Fine-tuning job"},
946+
spec_labels={"app": "training", "version": "v1.0"},
947+
spec_annotations={"prometheus.io/scrape": "true"},
948+
),
949+
),
882950
],
883951
)
884952
def test_train(kubernetes_backend, test_case):
@@ -888,8 +956,12 @@ def test_train(kubernetes_backend, test_case):
888956
kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
889957
runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME))
890958

959+
options = test_case.config.get("options", [])
960+
891961
train_job_name = kubernetes_backend.train(
892-
runtime=runtime, trainer=test_case.config.get("trainer", None)
962+
runtime=runtime,
963+
trainer=test_case.config.get("trainer", None),
964+
options=options,
893965
)
894966

895967
assert test_case.expected_status == SUCCESS

0 commit comments

Comments
 (0)