Skip to content

Commit 95155f6

Browse files
Add WithPodSpecOverrides option for pod customization
Signed-off-by: Abhijeet Dhumal <[email protected]>
1 parent dbba135 commit 95155f6

File tree

6 files changed

+218
-72
lines changed

6 files changed

+218
-72
lines changed

kubeflow/trainer/api/trainer_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,22 @@ def train(
123123

124124
if options:
125125
for option in options:
126-
option.apply(job_spec)
126+
option(job_spec)
127127

128128
metadata_section = job_spec.get("metadata", {})
129+
spec_section = job_spec.get("spec", {})
129130

130131
labels = metadata_section.get("labels") or None
131132
annotations = metadata_section.get("annotations") or None
133+
pod_spec_overrides = spec_section.get("podSpecOverrides") or None
132134

133135
return self.backend.train(
134136
runtime=runtime,
135137
initializer=initializer,
136138
trainer=trainer,
137139
labels=labels,
138140
annotations=annotations,
141+
pod_spec_overrides=pod_spec_overrides,
139142
)
140143

141144
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def train(
183183
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
184184
labels: Optional[Dict[str, str]] = None,
185185
annotations: Optional[Dict[str, str]] = None,
186+
pod_spec_overrides: Optional[list] = None,
186187
) -> str:
187188
if runtime is None:
188189
runtime = self.get_runtime(constants.TORCH_RUNTIME)
@@ -219,9 +220,7 @@ def train(
219220
apiVersion=constants.API_VERSION,
220221
kind=constants.TRAINJOB_KIND,
221222
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
222-
name=train_job_name,
223-
labels=labels,
224-
annotations=annotations
223+
name=train_job_name, labels=labels, annotations=annotations
225224
),
226225
spec=models.TrainerV1alpha1TrainJobSpec(
227226
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name),
@@ -234,6 +233,14 @@ def train(
234233
if isinstance(initializer, types.Initializer)
235234
else None
236235
),
236+
pod_spec_overrides=(
237+
[
238+
models.TrainerV1alpha1PodSpecOverride.from_dict(override)
239+
for override in pod_spec_overrides
240+
]
241+
if pod_spec_overrides
242+
else None
243+
),
237244
),
238245
)
239246

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333

3434
from kubeflow.trainer.constants import constants
3535
from kubeflow.trainer.types import types
36-
from kubeflow.trainer.options import WithLabels, WithAnnotations
36+
from kubeflow.trainer.options import (
37+
WithLabels,
38+
WithAnnotations,
39+
WithPodSpecOverrides,
40+
PodSpecOverride,
41+
)
3742
from kubeflow.trainer.utils import utils
3843
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
3944
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
@@ -256,6 +261,7 @@ def get_train_job(
256261
train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None,
257262
labels: Optional[Dict[str, str]] = None,
258263
annotations: Optional[Dict[str, str]] = None,
264+
pod_spec_overrides: Optional[list] = None,
259265
) -> models.TrainerV1alpha1TrainJob:
260266
"""
261267
Create a mock TrainJob object with optional trainer configurations.
@@ -264,13 +270,19 @@ def get_train_job(
264270
apiVersion=constants.API_VERSION,
265271
kind=constants.TRAINJOB_KIND,
266272
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
267-
name=train_job_name,
268-
labels=labels,
269-
annotations=annotations
273+
name=train_job_name, labels=labels, annotations=annotations
270274
),
271275
spec=models.TrainerV1alpha1TrainJobSpec(
272276
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name),
273277
trainer=train_job_trainer,
278+
pod_spec_overrides=(
279+
[
280+
models.TrainerV1alpha1PodSpecOverride.from_dict(override)
281+
for override in pod_spec_overrides
282+
]
283+
if pod_spec_overrides
284+
else None
285+
),
274286
),
275287
)
276288

@@ -862,6 +874,56 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
862874
annotations={"created-by": "sdk"},
863875
),
864876
),
877+
TestCase(
878+
name="valid flow with WithPodSpecOverrides option",
879+
expected_status=SUCCESS,
880+
config={
881+
"options": [
882+
WithPodSpecOverrides(
883+
[
884+
PodSpecOverride(
885+
target_jobs=["node"],
886+
volumes=[
887+
{
888+
"name": "datasets",
889+
"persistentVolumeClaim": {"claimName": "training-data-pvc"},
890+
}
891+
],
892+
containers=[
893+
{
894+
"name": "node",
895+
"volumeMounts": [
896+
{"name": "datasets", "mountPath": "/data"}
897+
],
898+
}
899+
],
900+
)
901+
]
902+
)
903+
],
904+
},
905+
expected_output=get_train_job(
906+
runtime_name=TORCH_RUNTIME,
907+
train_job_name=BASIC_TRAIN_JOB_NAME,
908+
pod_spec_overrides=[
909+
{
910+
"targetJobs": [{"name": "node"}],
911+
"volumes": [
912+
{
913+
"name": "datasets",
914+
"persistentVolumeClaim": {"claimName": "training-data-pvc"},
915+
}
916+
],
917+
"containers": [
918+
{
919+
"name": "node",
920+
"volumeMounts": [{"name": "datasets", "mountPath": "/data"}],
921+
}
922+
],
923+
}
924+
],
925+
),
926+
),
865927
],
866928
)
867929
def test_train(kubernetes_backend, test_case):
@@ -876,12 +938,14 @@ def test_train(kubernetes_backend, test_case):
876938
options = test_case.config.get("options", None)
877939
if options:
878940
for option in options:
879-
option.apply(job_spec)
941+
option(job_spec)
880942

881943
metadata_section = job_spec.get("metadata", {})
944+
spec_section = job_spec.get("spec", {})
882945

883946
labels = metadata_section.get("labels") or None
884947
annotations = metadata_section.get("annotations") or None
948+
pod_spec_overrides = spec_section.get("podSpecOverrides") or None
885949

886950
# Merge individual parameters with options
887951
individual_labels = test_case.config.get("labels", None)
@@ -903,6 +967,7 @@ def test_train(kubernetes_backend, test_case):
903967
trainer=test_case.config.get("trainer", None),
904968
labels=labels,
905969
annotations=annotations,
970+
pod_spec_overrides=pod_spec_overrides,
906971
)
907972

908973
assert test_case.expected_status == SUCCESS

kubeflow/trainer/options/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414

1515
from kubeflow.trainer.options.options import (
1616
Option,
17+
PodSpecOverride,
1718
WithAnnotations,
1819
WithLabels,
20+
WithPodSpecOverrides,
1921
)
2022

2123
__all__ = [
2224
"Option",
25+
"PodSpecOverride",
2326
"WithAnnotations",
2427
"WithLabels",
28+
"WithPodSpecOverrides",
2529
]

kubeflow/trainer/options/options.py

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,28 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from abc import ABC, abstractmethod
1615
from dataclasses import dataclass
17-
from typing import Dict
16+
from typing import Dict, List, Optional, Protocol
1817

1918

20-
class Option(ABC):
21-
"""Base class for TrainJob configuration options.
19+
class Option(Protocol):
20+
"""Protocol for TrainJob configuration options.
2221
2322
Options provide a composable way to configure different aspects of a TrainJob.
24-
Each option implements the apply() method to modify the TrainJob specification.
23+
Each option implements the __call__ method to modify the TrainJob specification.
2524
"""
2625

27-
@abstractmethod
28-
def apply(self, job_spec: dict) -> None:
26+
def __call__(self, job_spec: dict) -> None:
2927
"""Apply this option to the TrainJob specification.
3028
3129
Args:
3230
job_spec: The TrainJob specification dictionary to modify.
3331
"""
34-
pass
32+
...
3533

3634

3735
@dataclass
38-
class WithLabels(Option):
36+
class WithLabels:
3937
"""Add labels to the TrainJob resource metadata (.metadata.labels).
4038
4139
These labels are applied to the TrainJob resource itself and are used
@@ -47,15 +45,15 @@ class WithLabels(Option):
4745

4846
labels: Dict[str, str]
4947

50-
def apply(self, job_spec: dict) -> None:
48+
def __call__(self, job_spec: dict) -> None:
5149
"""Apply labels to TrainJob metadata."""
5250
metadata = job_spec.setdefault("metadata", {})
5351
existing_labels = metadata.setdefault("labels", {})
5452
existing_labels.update(self.labels)
5553

5654

5755
@dataclass
58-
class WithAnnotations(Option):
56+
class WithAnnotations:
5957
"""Add annotations to the TrainJob resource metadata (.metadata.annotations).
6058
6159
These annotations are applied to the TrainJob resource itself and are used
@@ -67,10 +65,76 @@ class WithAnnotations(Option):
6765

6866
annotations: Dict[str, str]
6967

70-
def apply(self, job_spec: dict) -> None:
68+
def __call__(self, job_spec: dict) -> None:
7169
"""Apply annotations to TrainJob metadata."""
7270
metadata = job_spec.setdefault("metadata", {})
7371
existing_annotations = metadata.setdefault("annotations", {})
7472
existing_annotations.update(self.annotations)
7573

7674

75+
@dataclass
76+
class PodSpecOverride:
77+
"""Configuration for overriding pod specifications for specific job types.
78+
79+
Args:
80+
target_jobs: List of job names to apply this override to.
81+
volumes: List of volume configurations to add to the pods.
82+
containers: List of container overrides.
83+
init_containers: List of init container overrides.
84+
node_selector: Node selector to place pods on specific nodes.
85+
service_account_name: Service account name for the pods.
86+
tolerations: List of tolerations for pod scheduling.
87+
"""
88+
89+
target_jobs: List[str]
90+
volumes: Optional[List[Dict]] = None
91+
containers: Optional[List[Dict]] = None
92+
init_containers: Optional[List[Dict]] = None
93+
node_selector: Optional[Dict[str, str]] = None
94+
service_account_name: Optional[str] = None
95+
tolerations: Optional[List[Dict]] = None
96+
97+
98+
@dataclass
99+
class WithPodSpecOverrides:
100+
"""Add pod specification overrides to the TrainJob (.spec.podSpecOverrides).
101+
102+
This option allows you to customize pod specifications for different job types
103+
in your TrainJob. You can specify multiple overrides for different job types
104+
or different configurations.
105+
106+
Args:
107+
overrides: List of PodSpecOverride configurations to apply.
108+
"""
109+
110+
overrides: List[PodSpecOverride]
111+
112+
def __call__(self, job_spec: dict) -> None:
113+
"""Apply pod spec overrides to TrainJob spec."""
114+
spec = job_spec.setdefault("spec", {})
115+
existing_overrides = spec.setdefault("podSpecOverrides", [])
116+
117+
for override in self.overrides:
118+
# Convert PodSpecOverride to TrainJob API format
119+
api_override = {"targetJobs": [{"name": job} for job in override.target_jobs]}
120+
121+
# Add optional fields if provided
122+
if override.volumes:
123+
api_override["volumes"] = override.volumes
124+
125+
if override.containers:
126+
api_override["containers"] = override.containers
127+
128+
if override.init_containers:
129+
api_override["initContainers"] = override.init_containers
130+
131+
if override.node_selector:
132+
api_override["nodeSelector"] = override.node_selector
133+
134+
if override.service_account_name:
135+
api_override["serviceAccountName"] = override.service_account_name
136+
137+
if override.tolerations:
138+
api_override["tolerations"] = override.tolerations
139+
140+
existing_overrides.append(api_override)

0 commit comments

Comments
 (0)