@@ -253,16 +253,26 @@ def get_train_job(
253253 runtime_name : str ,
254254 train_job_name : str = BASIC_TRAIN_JOB_NAME ,
255255 train_job_trainer : Optional [models .TrainerV1alpha1Trainer ] = None ,
256+ labels : Optional [Dict [str , str ]] = None ,
257+ annotations : Optional [Dict [str , str ]] = None ,
258+ job_labels : Optional [Dict [str , str ]] = None ,
259+ job_annotations : Optional [Dict [str , str ]] = None ,
256260) -> models .TrainerV1alpha1TrainJob :
257261 """
258262 Create a mock TrainJob object with optional trainer configurations.
259263 """
260264 train_job = models .TrainerV1alpha1TrainJob (
261265 apiVersion = constants .API_VERSION ,
262266 kind = constants .TRAINJOB_KIND ,
263- metadata = models .IoK8sApimachineryPkgApisMetaV1ObjectMeta (name = train_job_name ),
267+ metadata = models .IoK8sApimachineryPkgApisMetaV1ObjectMeta (
268+ name = train_job_name ,
269+ labels = labels ,
270+ annotations = annotations
271+ ),
264272 spec = models .TrainerV1alpha1TrainJobSpec (
265273 runtimeRef = models .TrainerV1alpha1RuntimeRef (name = runtime_name ),
274+ labels = job_labels ,
275+ annotations = job_annotations ,
266276 trainer = train_job_trainer ,
267277 ),
268278 )
@@ -793,6 +803,86 @@ def test_get_runtime_packages(trainer_client, test_case):
793803 },
794804 expected_error = ValueError ,
795805 ),
806+ TestCase (
807+ name = "valid flow with labels and annotations" ,
808+ expected_status = SUCCESS ,
809+ config = {
810+ "labels" : {"kueue.x-k8s.io/queue-name" : "ml-queue" , "team" : "ml-engineering" },
811+ "annotations" : {"experiment.id" : "exp-001" , "description" : "Test training job" },
812+ },
813+ expected_output = get_train_job (
814+ runtime_name = TORCH_RUNTIME ,
815+ train_job_name = BASIC_TRAIN_JOB_NAME ,
816+ labels = {"kueue.x-k8s.io/queue-name" : "ml-queue" , "team" : "ml-engineering" },
817+ annotations = {"experiment.id" : "exp-001" , "description" : "Test training job" },
818+ ),
819+ ),
820+ TestCase (
821+ name = "valid flow with only labels" ,
822+ expected_status = SUCCESS ,
823+ config = {
824+ "labels" : {"priority" : "high" },
825+ },
826+ expected_output = get_train_job (
827+ runtime_name = TORCH_RUNTIME ,
828+ train_job_name = BASIC_TRAIN_JOB_NAME ,
829+ labels = {"priority" : "high" },
830+ ),
831+ ),
832+ TestCase (
833+ name = "valid flow with only annotations" ,
834+ expected_status = SUCCESS ,
835+ config = {
836+ "annotations" : {"created-by" : "training-pipeline" },
837+ },
838+ expected_output = get_train_job (
839+ runtime_name = TORCH_RUNTIME ,
840+ train_job_name = BASIC_TRAIN_JOB_NAME ,
841+ annotations = {"created-by" : "training-pipeline" },
842+ ),
843+ ),
844+ TestCase (
845+ name = "valid flow with job_labels for Kueue" ,
846+ expected_status = SUCCESS ,
847+ config = {
848+ "job_labels" : {"kueue.x-k8s.io/queue-name" : "ml-queue" },
849+ },
850+ expected_output = get_train_job (
851+ runtime_name = TORCH_RUNTIME ,
852+ train_job_name = BASIC_TRAIN_JOB_NAME ,
853+ job_labels = {"kueue.x-k8s.io/queue-name" : "ml-queue" },
854+ ),
855+ ),
856+ TestCase (
857+ name = "valid flow with job_annotations" ,
858+ expected_status = SUCCESS ,
859+ config = {
860+ "job_annotations" : {"experiment.id" : "exp-001" },
861+ },
862+ expected_output = get_train_job (
863+ runtime_name = TORCH_RUNTIME ,
864+ train_job_name = BASIC_TRAIN_JOB_NAME ,
865+ job_annotations = {"experiment.id" : "exp-001" },
866+ ),
867+ ),
868+ TestCase (
869+ name = "valid flow with both resource and JobSet labels/annotations" ,
870+ expected_status = SUCCESS ,
871+ config = {
872+ "labels" : {"team" : "ml-platform" },
873+ "annotations" : {"created-by" : "sdk" },
874+ "job_labels" : {"kueue.x-k8s.io/queue-name" : "gpu-queue" },
875+ "job_annotations" : {"experiment.id" : "exp-001" },
876+ },
877+ expected_output = get_train_job (
878+ runtime_name = TORCH_RUNTIME ,
879+ train_job_name = BASIC_TRAIN_JOB_NAME ,
880+ labels = {"team" : "ml-platform" },
881+ annotations = {"created-by" : "sdk" },
882+ job_labels = {"kueue.x-k8s.io/queue-name" : "gpu-queue" },
883+ job_annotations = {"experiment.id" : "exp-001" },
884+ ),
885+ ),
796886 ],
797887)
798888def test_train (trainer_client , test_case ):
@@ -805,7 +895,12 @@ def test_train(trainer_client, test_case):
805895 )
806896
807897 train_job_name = trainer_client .train (
808- runtime = runtime , trainer = test_case .config .get ("trainer" , None )
898+ runtime = runtime ,
899+ trainer = test_case .config .get ("trainer" , None ),
900+ labels = test_case .config .get ("labels" , None ),
901+ annotations = test_case .config .get ("annotations" , None ),
902+ job_labels = test_case .config .get ("job_labels" , None ),
903+ job_annotations = test_case .config .get ("job_annotations" , None )
809904 )
810905
811906 assert test_case .expected_status == SUCCESS
0 commit comments