1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from abc import ABC , abstractmethod
1615from dataclasses import dataclass
17- from typing import Dict
16+ from typing import Dict , List , Optional , Protocol
1817
1918
20- class Option (ABC ):
21- """Base class for TrainJob configuration options.
19+ class Option (Protocol ):
20+ """Protocol for TrainJob configuration options.
2221
2322 Options provide a composable way to configure different aspects of a TrainJob.
24- Each option implements the apply() method to modify the TrainJob specification.
23+ Each option implements the __call__ method to modify the TrainJob specification.
2524 """
2625
27- @abstractmethod
28- def apply (self , job_spec : dict ) -> None :
26+ def __call__ (self , job_spec : dict ) -> None :
2927 """Apply this option to the TrainJob specification.
3028
3129 Args:
3230 job_spec: The TrainJob specification dictionary to modify.
3331 """
34- pass
32+ ...
3533
3634
3735@dataclass
38- class WithLabels ( Option ) :
36+ class WithLabels :
3937 """Add labels to the TrainJob resource metadata (.metadata.labels).
4038
4139 These labels are applied to the TrainJob resource itself and are used
@@ -47,15 +45,15 @@ class WithLabels(Option):
4745
4846 labels : Dict [str , str ]
4947
50- def apply (self , job_spec : dict ) -> None :
48+ def __call__ (self , job_spec : dict ) -> None :
5149 """Apply labels to TrainJob metadata."""
5250 metadata = job_spec .setdefault ("metadata" , {})
5351 existing_labels = metadata .setdefault ("labels" , {})
5452 existing_labels .update (self .labels )
5553
5654
5755@dataclass
58- class WithAnnotations ( Option ) :
56+ class WithAnnotations :
5957 """Add annotations to the TrainJob resource metadata (.metadata.annotations).
6058
6159 These annotations are applied to the TrainJob resource itself and are used
@@ -67,10 +65,76 @@ class WithAnnotations(Option):
6765
6866 annotations : Dict [str , str ]
6967
70- def apply (self , job_spec : dict ) -> None :
68+ def __call__ (self , job_spec : dict ) -> None :
7169 """Apply annotations to TrainJob metadata."""
7270 metadata = job_spec .setdefault ("metadata" , {})
7371 existing_annotations = metadata .setdefault ("annotations" , {})
7472 existing_annotations .update (self .annotations )
7573
7674
75+ @dataclass
76+ class PodSpecOverride :
77+ """Configuration for overriding pod specifications for specific job types.
78+
79+ Args:
80+ target_jobs: List of job names to apply this override to.
81+ volumes: List of volume configurations to add to the pods.
82+ containers: List of container overrides.
83+ init_containers: List of init container overrides.
84+ node_selector: Node selector to place pods on specific nodes.
85+ service_account_name: Service account name for the pods.
86+ tolerations: List of tolerations for pod scheduling.
87+ """
88+
89+ target_jobs : List [str ]
90+ volumes : Optional [List [Dict ]] = None
91+ containers : Optional [List [Dict ]] = None
92+ init_containers : Optional [List [Dict ]] = None
93+ node_selector : Optional [Dict [str , str ]] = None
94+ service_account_name : Optional [str ] = None
95+ tolerations : Optional [List [Dict ]] = None
96+
97+
98+ @dataclass
99+ class WithPodSpecOverrides :
100+ """Add pod specification overrides to the TrainJob (.spec.podSpecOverrides).
101+
102+ This option allows you to customize pod specifications for different job types
103+ in your TrainJob. You can specify multiple overrides for different job types
104+ or different configurations.
105+
106+ Args:
107+ overrides: List of PodSpecOverride configurations to apply.
108+ """
109+
110+ overrides : List [PodSpecOverride ]
111+
112+ def __call__ (self , job_spec : dict ) -> None :
113+ """Apply pod spec overrides to TrainJob spec."""
114+ spec = job_spec .setdefault ("spec" , {})
115+ existing_overrides = spec .setdefault ("podSpecOverrides" , [])
116+
117+ for override in self .overrides :
118+ # Convert PodSpecOverride to TrainJob API format
119+ api_override = {"targetJobs" : [{"name" : job } for job in override .target_jobs ]}
120+
121+ # Add optional fields if provided
122+ if override .volumes :
123+ api_override ["volumes" ] = override .volumes
124+
125+ if override .containers :
126+ api_override ["containers" ] = override .containers
127+
128+ if override .init_containers :
129+ api_override ["initContainers" ] = override .init_containers
130+
131+ if override .node_selector :
132+ api_override ["nodeSelector" ] = override .node_selector
133+
134+ if override .service_account_name :
135+ api_override ["serviceAccountName" ] = override .service_account_name
136+
137+ if override .tolerations :
138+ api_override ["tolerations" ] = override .tolerations
139+
140+ existing_overrides .append (api_override )
0 commit comments