Skip to content

Commit 24f00a7

Browse files
committed
Fixed bugs and validated current test cases
1 parent 1b8704b commit 24f00a7

File tree

3 files changed

+43
-2
lines changed

3 files changed

+43
-2
lines changed

kubeflow/trainer/backends/kubernetes/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def list_runtimes(self) -> list[types.Runtime]:
8383
namespace_thread.get(constants.DEFAULT_TIMEOUT)
8484
)
8585

86-
if not (cluster_runtime_list and namespace_runtime_list):
86+
if not (cluster_runtime_list or namespace_runtime_list):
8787
return result
8888

8989
for runtime in namespace_runtime_list.items + cluster_runtime_list.items:

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,16 @@ def list_namespaced_custom_object_response(*args, **kwargs):
339339
models.TrainerV1alpha1TrainJobList(items=items),
340340
models.TrainerV1alpha1TrainJobList,
341341
)
342+
elif args[3] == constants.TRAINING_RUNTIME_PLURAL:
343+
# TODO: add a test case for namespace-scoped runtimes
344+
# items = [
345+
# create_training_runtime(name="runtime-1"),
346+
# create_training_runtime(name="runtime-2"),
347+
# ]
348+
mock_thread.get.return_value = normalize_model(
349+
models.TrainerV1alpha1TrainingRuntimeList(items=[]),
350+
models.TrainerV1alpha1TrainingRuntimeList,
351+
)
342352

343353
return mock_thread
344354

@@ -472,6 +482,37 @@ def create_cluster_training_runtime(
472482
)
473483

474484

485+
def create_training_runtime(
486+
name: str,
487+
namespace: str = "default",
488+
) -> models.TrainerV1alpha1TrainingRuntime:
489+
"""Create a mock namespaced TrainingRuntime object (not cluster-scoped)."""
490+
return models.TrainerV1alpha1TrainingRuntime(
491+
apiVersion=constants.API_VERSION,
492+
kind="TrainingRuntime",
493+
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
494+
name=name,
495+
namespace=namespace,
496+
labels={constants.RUNTIME_FRAMEWORK_LABEL: name},
497+
),
498+
spec=models.TrainerV1alpha1TrainingRuntimeSpec(
499+
mlPolicy=models.TrainerV1alpha1MLPolicy(
500+
torch=models.TrainerV1alpha1TorchMLPolicySource(
501+
numProcPerNode=models.IoK8sApimachineryPkgUtilIntstrIntOrString(2)
502+
),
503+
numNodes=2,
504+
),
505+
template=models.TrainerV1alpha1JobSetTemplateSpec(
506+
metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(
507+
name=name,
508+
namespace=namespace,
509+
),
510+
spec=models.JobsetV1alpha2JobSetSpec(replicatedJobs=[get_replicated_job()]),
511+
),
512+
),
513+
)
514+
515+
475516
def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob:
476517
return models.JobsetV1alpha2ReplicatedJob(
477518
name="node",

kubeflow/trainer/constants/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
CLUSTER_TRAINING_RUNTIME_PLURAL = "clustertrainingruntimes"
3434

3535
# The Kind name for the TrainingRuntime.
36-
TRAINING_RUNTIME_KIND = "ClusterTrainingRuntime"
36+
TRAINING_RUNTIME_KIND = "TrainingRuntime"
3737

3838
# The plural for the TrainingRuntime.
3939
TRAINING_RUNTIME_PLURAL = "trainingruntimes"

0 commit comments

Comments
 (0)