2020import re
2121import string
2222import time
23- from typing import Optional , Union
23+ from typing import Any , Optional , Union
2424import uuid
2525
2626from kubeflow_trainer_api import models
@@ -87,15 +87,9 @@ def list_runtimes(self) -> list[types.Runtime]:
8787 result .append (self .__get_runtime_from_cr (runtime ))
8888
8989 except multiprocessing .TimeoutError as e :
90- raise TimeoutError (
91- f"Timeout to list { constants .CLUSTER_TRAINING_RUNTIME_KIND } s "
92- f"in namespace: { self .namespace } "
93- ) from e
90+ raise TimeoutError (f"Timeout to list { constants .CLUSTER_TRAINING_RUNTIME_KIND } s" ) from e
9491 except Exception as e :
95- raise RuntimeError (
96- f"Failed to list { constants .CLUSTER_TRAINING_RUNTIME_KIND } s "
97- f"in namespace: { self .namespace } "
98- ) from e
92+ raise RuntimeError (f"Failed to list { constants .CLUSTER_TRAINING_RUNTIME_KIND } s" ) from e
9993
10094 return result
10195
@@ -184,16 +178,62 @@ def train(
184178 trainer : Optional [
185179 Union [types .CustomTrainer , types .CustomTrainerContainer , types .BuiltinTrainer ]
186180 ] = None ,
181+ options : Optional [list ] = None ,
187182 ) -> str :
188- # Generate unique name for the TrainJob.
189- train_job_name = random .choice (string .ascii_lowercase ) + uuid .uuid4 ().hex [:11 ]
183+ if runtime is None :
184+ runtime = self .get_runtime (constants .TORCH_RUNTIME )
185+
186+ # Process options to extract configuration
187+ job_spec = {}
188+ labels = None
189+ annotations = None
190+ name = None
191+ spec_labels = None
192+ spec_annotations = None
193+ trainer_overrides = {}
194+ pod_template_overrides = None
195+
196+ if options :
197+ for option in options :
198+ option (job_spec , trainer , self )
199+
200+ metadata_section = job_spec .get ("metadata" , {})
201+ labels = metadata_section .get ("labels" )
202+ annotations = metadata_section .get ("annotations" )
203+ name = metadata_section .get ("name" )
204+
205+ # Extract spec-level labels/annotations and other spec configurations
206+ spec_section = job_spec .get ("spec" , {})
207+ spec_labels = spec_section .get ("labels" )
208+ spec_annotations = spec_section .get ("annotations" )
209+ trainer_overrides = spec_section .get ("trainer" , {})
210+ pod_template_overrides = spec_section .get ("podTemplateOverrides" )
211+
212+ # Generate unique name for the TrainJob if not provided
213+ train_job_name = name or (
214+ random .choice (string .ascii_lowercase )
215+ + uuid .uuid4 ().hex [: constants .JOB_NAME_UUID_LENGTH ]
216+ )
217+
218+ # Build the TrainJob spec using the common _get_trainjob_spec method
219+ trainjob_spec = self ._get_trainjob_spec (
220+ runtime = runtime ,
221+ initializer = initializer ,
222+ trainer = trainer ,
223+ trainer_overrides = trainer_overrides ,
224+ spec_labels = spec_labels ,
225+ spec_annotations = spec_annotations ,
226+ pod_template_overrides = pod_template_overrides ,
227+ )
190228
191229 # Build the TrainJob.
192230 train_job = models .TrainerV1alpha1TrainJob (
193231 apiVersion = constants .API_VERSION ,
194232 kind = constants .TRAINJOB_KIND ,
195- metadata = models .IoK8sApimachineryPkgApisMetaV1ObjectMeta (name = train_job_name ),
196- spec = self ._get_trainjob_spec (runtime , initializer , trainer ),
233+ metadata = models .IoK8sApimachineryPkgApisMetaV1ObjectMeta (
234+ name = train_job_name , labels = labels , annotations = annotations
235+ ),
236+ spec = trainjob_spec ,
197237 )
198238
199239 # Create the TrainJob.
@@ -549,6 +589,10 @@ def _get_trainjob_spec(
549589 trainer : Optional [
550590 Union [types .CustomTrainer , types .CustomTrainerContainer , types .BuiltinTrainer ]
551591 ] = None ,
592+ trainer_overrides : Optional [dict [str , Any ]] = None ,
593+ spec_labels : Optional [dict [str , str ]] = None ,
594+ spec_annotations : Optional [dict [str , str ]] = None ,
595+ pod_template_overrides : Optional [models .IoK8sApiCoreV1PodTemplateSpec ] = None ,
552596 ) -> models .TrainerV1alpha1TrainJobSpec :
553597 """Get TrainJob spec from the given parameters"""
554598 if runtime is None :
@@ -575,9 +619,16 @@ def _get_trainjob_spec(
575619 else :
576620 raise ValueError (
577621 f"The trainer type { type (trainer )} is not supported. "
578- "Please use CustomTrainer or BuiltinTrainer."
622+ "Please use CustomTrainer, CustomTrainerContainer, or BuiltinTrainer."
579623 )
580624
625+ # Apply "command"/"args" from trainer_overrides, when present, onto the trainer spec
626+ if trainer_overrides :
627+ if "command" in trainer_overrides :
628+ trainer_cr .command = trainer_overrides ["command" ]
629+ if "args" in trainer_overrides :
630+ trainer_cr .args = trainer_overrides ["args" ]
631+
581632 return models .TrainerV1alpha1TrainJobSpec (
582633 runtimeRef = models .TrainerV1alpha1RuntimeRef (name = runtime .name ),
583634 trainer = (trainer_cr if trainer_cr != models .TrainerV1alpha1Trainer () else None ),
@@ -589,4 +640,7 @@ def _get_trainjob_spec(
589640 if isinstance (initializer , types .Initializer )
590641 else None
591642 ),
643+ labels = spec_labels ,
644+ annotations = spec_annotations ,
645+ pod_template_overrides = pod_template_overrides ,
592646 )
0 commit comments