Skip to content

Commit 67d12d8

Browse files
Add labels and annotations support for train client
Signed-off-by: Abhijeet Dhumal <[email protected]>
1 parent 98db2ad commit 67d12d8

File tree

3 files changed

+530
-411
lines changed

3 files changed

+530
-411
lines changed

kubeflow/trainer/api/trainer_client.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ def train(
230230
runtime: Optional[types.Runtime] = None,
231231
initializer: Optional[types.Initializer] = None,
232232
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
233+
labels: Optional[Dict[str, str]] = None,
234+
annotations: Optional[Dict[str, str]] = None,
235+
job_labels: Optional[Dict[str, str]] = None,
236+
job_annotations: Optional[Dict[str, str]] = None,
233237
) -> str:
234238
"""
235239
Create the TrainJob. You can configure these types of training task:
@@ -246,6 +250,15 @@ def train(
246250
Configuration for the dataset and model initializers.
247251
trainer:
248252
Configuration for Custom Training Task or Config-driven Task with Builtin Trainer.
253+
labels: Optional dictionary of labels to apply to the TrainJob metadata (.metadata.labels).
254+
Used for TrainJob resource organization and filtering.
255+
annotations: Optional dictionary of annotations to apply to the TrainJob metadata (.metadata.annotations).
256+
Useful for storing additional metadata about the training job resource.
257+
job_labels: Optional dictionary of labels to apply to the JobSet and Jobs (.spec.labels).
258+
These labels are propagated to the derivative JobSet and Jobs. Use this for Kueue
259+
integration (e.g., {"kueue.x-k8s.io/queue-name": "ml-queue"}).
260+
job_annotations: Optional dictionary of annotations to apply to the JobSet and Jobs (.spec.annotations).
261+
These annotations are propagated to the derivative JobSet and Jobs.
249262
250263
Returns:
251264
str: The unique name of the TrainJob that has been generated.
@@ -297,10 +310,14 @@ def train(
297310
apiVersion=constants.API_VERSION,
298311
kind=constants.TRAINJOB_KIND,
299312
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
300-
name=train_job_name
313+
name=train_job_name,
314+
labels=labels,
315+
annotations=annotations
301316
),
302317
spec=models.TrainerV1alpha1TrainJobSpec(
303318
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name),
319+
labels=job_labels,
320+
annotations=job_annotations,
304321
trainer=(
305322
trainer_crd
306323
if trainer_crd != models.TrainerV1alpha1Trainer()

kubeflow/trainer/api/trainer_client_test.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,16 +253,26 @@ def get_train_job(
253253
runtime_name: str,
254254
train_job_name: str = BASIC_TRAIN_JOB_NAME,
255255
train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None,
256+
labels: Optional[Dict[str, str]] = None,
257+
annotations: Optional[Dict[str, str]] = None,
258+
job_labels: Optional[Dict[str, str]] = None,
259+
job_annotations: Optional[Dict[str, str]] = None,
256260
) -> models.TrainerV1alpha1TrainJob:
257261
"""
258262
Create a mock TrainJob object with optional trainer configurations.
259263
"""
260264
train_job = models.TrainerV1alpha1TrainJob(
261265
apiVersion=constants.API_VERSION,
262266
kind=constants.TRAINJOB_KIND,
263-
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name),
267+
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
268+
name=train_job_name,
269+
labels=labels,
270+
annotations=annotations
271+
),
264272
spec=models.TrainerV1alpha1TrainJobSpec(
265273
runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name),
274+
labels=job_labels,
275+
annotations=job_annotations,
266276
trainer=train_job_trainer,
267277
),
268278
)
@@ -793,6 +803,86 @@ def test_get_runtime_packages(trainer_client, test_case):
793803
},
794804
expected_error=ValueError,
795805
),
806+
TestCase(
807+
name="valid flow with labels and annotations",
808+
expected_status=SUCCESS,
809+
config={
810+
"labels": {"kueue.x-k8s.io/queue-name": "ml-queue", "team": "ml-engineering"},
811+
"annotations": {"experiment.id": "exp-001", "description": "Test training job"},
812+
},
813+
expected_output=get_train_job(
814+
runtime_name=TORCH_RUNTIME,
815+
train_job_name=BASIC_TRAIN_JOB_NAME,
816+
labels={"kueue.x-k8s.io/queue-name": "ml-queue", "team": "ml-engineering"},
817+
annotations={"experiment.id": "exp-001", "description": "Test training job"},
818+
),
819+
),
820+
TestCase(
821+
name="valid flow with only labels",
822+
expected_status=SUCCESS,
823+
config={
824+
"labels": {"priority": "high"},
825+
},
826+
expected_output=get_train_job(
827+
runtime_name=TORCH_RUNTIME,
828+
train_job_name=BASIC_TRAIN_JOB_NAME,
829+
labels={"priority": "high"},
830+
),
831+
),
832+
TestCase(
833+
name="valid flow with only annotations",
834+
expected_status=SUCCESS,
835+
config={
836+
"annotations": {"created-by": "training-pipeline"},
837+
},
838+
expected_output=get_train_job(
839+
runtime_name=TORCH_RUNTIME,
840+
train_job_name=BASIC_TRAIN_JOB_NAME,
841+
annotations={"created-by": "training-pipeline"},
842+
),
843+
),
844+
TestCase(
845+
name="valid flow with job_labels for Kueue",
846+
expected_status=SUCCESS,
847+
config={
848+
"job_labels": {"kueue.x-k8s.io/queue-name": "ml-queue"},
849+
},
850+
expected_output=get_train_job(
851+
runtime_name=TORCH_RUNTIME,
852+
train_job_name=BASIC_TRAIN_JOB_NAME,
853+
job_labels={"kueue.x-k8s.io/queue-name": "ml-queue"},
854+
),
855+
),
856+
TestCase(
857+
name="valid flow with job_annotations",
858+
expected_status=SUCCESS,
859+
config={
860+
"job_annotations": {"experiment.id": "exp-001"},
861+
},
862+
expected_output=get_train_job(
863+
runtime_name=TORCH_RUNTIME,
864+
train_job_name=BASIC_TRAIN_JOB_NAME,
865+
job_annotations={"experiment.id": "exp-001"},
866+
),
867+
),
868+
TestCase(
869+
name="valid flow with both resource and JobSet labels/annotations",
870+
expected_status=SUCCESS,
871+
config={
872+
"labels": {"team": "ml-platform"},
873+
"annotations": {"created-by": "sdk"},
874+
"job_labels": {"kueue.x-k8s.io/queue-name": "gpu-queue"},
875+
"job_annotations": {"experiment.id": "exp-001"},
876+
},
877+
expected_output=get_train_job(
878+
runtime_name=TORCH_RUNTIME,
879+
train_job_name=BASIC_TRAIN_JOB_NAME,
880+
labels={"team": "ml-platform"},
881+
annotations={"created-by": "sdk"},
882+
job_labels={"kueue.x-k8s.io/queue-name": "gpu-queue"},
883+
job_annotations={"experiment.id": "exp-001"},
884+
),
885+
),
796886
],
797887
)
798888
def test_train(trainer_client, test_case):
@@ -805,7 +895,12 @@ def test_train(trainer_client, test_case):
805895
)
806896

807897
train_job_name = trainer_client.train(
808-
runtime=runtime, trainer=test_case.config.get("trainer", None)
898+
runtime=runtime,
899+
trainer=test_case.config.get("trainer", None),
900+
labels=test_case.config.get("labels", None),
901+
annotations=test_case.config.get("annotations", None),
902+
job_labels=test_case.config.get("job_labels", None),
903+
job_annotations=test_case.config.get("job_annotations", None)
809904
)
810905

811906
assert test_case.expected_status == SUCCESS

0 commit comments

Comments
 (0)