diff --git a/README.md b/README.md index 7e94d653..0088b427 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ Kubeflow Trainer client supports local development without needing a Kubernetes ### Available Backends - **KubernetesBackend** (default) - Production training on Kubernetes -- **ContainerBackend** - Local development with Docker/Podman isolation +- **ContainerBackend** - Local development with Docker/Podman isolation - **LocalProcessBackend** - Quick prototyping with Python subprocesses **Quick Start:** diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index e96fb1b3..183700b7 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -59,21 +59,33 @@ def __init__(self, cfg: KubernetesBackendConfig): def list_runtimes(self) -> list[types.Runtime]: result = [] try: - thread = self.custom_api.list_cluster_custom_object( + cluster_thread = self.custom_api.list_cluster_custom_object( constants.GROUP, constants.VERSION, constants.CLUSTER_TRAINING_RUNTIME_PLURAL, async_req=True, ) - runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict( - thread.get(common_constants.DEFAULT_TIMEOUT) + namespace_thread = self.custom_api.list_namespaced_custom_object( + constants.GROUP, + constants.VERSION, + self.namespace, + constants.TRAINING_RUNTIME_PLURAL, + async_req=True, + ) + + cluster_runtime_list = models.TrainerV1alpha1ClusterTrainingRuntimeList.from_dict( + cluster_thread.get(constants.DEFAULT_TIMEOUT) + ) + + namespace_runtime_list = models.TrainerV1alpha1TrainingRuntimeList.from_dict( + namespace_thread.get(constants.DEFAULT_TIMEOUT) ) - if not runtime_list: + if not (cluster_runtime_list or namespace_runtime_list): return result - for runtime in runtime_list.items: + for runtime in namespace_runtime_list.items + cluster_runtime_list.items: if not ( runtime.metadata and runtime.metadata.labels @@ -88,22 +100,24 @@ def 
list_runtimes(self) -> list[types.Runtime]: except multiprocessing.TimeoutError as e: raise TimeoutError( - f"Timeout to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s " + "Timeout to list " + f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s " f"in namespace: {self.namespace}" ) from e except Exception as e: raise RuntimeError( - f"Failed to list {constants.CLUSTER_TRAINING_RUNTIME_KIND}s " + "Failed to list " + f"{constants.CLUSTER_TRAINING_RUNTIME_KIND}s/{constants.TRAINING_RUNTIME_KIND}s " f"in namespace: {self.namespace}" ) from e return result def get_runtime(self, name: str) -> types.Runtime: - """Get the the Runtime object""" + """Get the Runtime object, preferring namespaced and falling back to cluster-scoped.""" try: - thread = self.custom_api.get_cluster_custom_object( + cluster_thread = self.custom_api.get_cluster_custom_object( constants.GROUP, constants.VERSION, constants.CLUSTER_TRAINING_RUNTIME_PLURAL, @@ -111,10 +125,30 @@ def get_runtime(self, name: str) -> types.Runtime: async_req=True, ) - runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict( - thread.get(common_constants.DEFAULT_TIMEOUT) # type: ignore + namespace_thread = self.custom_api.get_namespaced_custom_object( + constants.GROUP, + constants.VERSION, + self.namespace, + constants.TRAINING_RUNTIME_PLURAL, + name, + async_req=True, ) + # Try namespaced runtime first, fall back to cluster-scoped one + try: + runtime = models.TrainerV1alpha1TrainingRuntime.from_dict( + namespace_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore + ) + except Exception as e: + logger.warning( + f"Namespaced TrainingRuntime '{self.namespace}/{name}' not found " + f"({type(e).__name__}: {e}); falling back to cluster-scoped runtime." 
+ ) + + runtime = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict( + cluster_thread.get(constants.DEFAULT_TIMEOUT) # type: ignore + ) + except multiprocessing.TimeoutError as e: raise TimeoutError( f"Timeout to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: " @@ -396,8 +430,13 @@ def delete_job(self, name: str): def __get_runtime_from_cr( self, - runtime_cr: models.TrainerV1alpha1ClusterTrainingRuntime, + runtime_cr: Union[ + models.TrainerV1alpha1ClusterTrainingRuntime, models.TrainerV1alpha1TrainingRuntime + ], ) -> types.Runtime: + crd_kind = getattr(runtime_cr, "kind", "UnknownKind") + crd_name = getattr(runtime_cr.metadata, "name", "UnknownName") + if not ( runtime_cr.metadata and runtime_cr.metadata.name @@ -406,7 +445,11 @@ def __get_runtime_from_cr( and runtime_cr.spec.template.spec and runtime_cr.spec.template.spec.replicated_jobs ): - raise Exception(f"ClusterTrainingRuntime CR is invalid: {runtime_cr}") + raise Exception( + f"{crd_kind} '{crd_name}' is invalid — missing one or more required fields: " + f"metadata.name, spec.mlPolicy, spec.template.spec.replicatedJobs.\n" + f"Full object: {runtime_cr}" + ) if not ( runtime_cr.metadata.labels diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 2fa5276a..10a8ba69 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -357,6 +357,16 @@ def list_namespaced_custom_object_response(*args, **kwargs): models.TrainerV1alpha1TrainJobList(items=items), models.TrainerV1alpha1TrainJobList, ) + elif args[3] == constants.TRAINING_RUNTIME_PLURAL: + # TODO: add test case for namespace scoped runtimes + # items = [ + # create_training_runtime(name="runtime-1"), + # create_training_runtime(name="runtime-2"), + # ] + mock_thread.get.return_value = normalize_model( + models.TrainerV1alpha1TrainingRuntimeList(items=[]), + models.TrainerV1alpha1TrainingRuntimeList, + ) 
return mock_thread @@ -490,6 +500,37 @@ def create_cluster_training_runtime( ) +def create_training_runtime( + name: str, + namespace: str = "default", +) -> models.TrainerV1alpha1TrainingRuntime: + """Create a mock namespaced TrainingRuntime object (not cluster-scoped).""" + return models.TrainerV1alpha1TrainingRuntime( + apiVersion=constants.API_VERSION, + kind="TrainingRuntime", + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name=name, + namespace=namespace, + labels={constants.RUNTIME_FRAMEWORK_LABEL: name}, + ), + spec=models.TrainerV1alpha1TrainingRuntimeSpec( + mlPolicy=models.TrainerV1alpha1MLPolicy( + torch=models.TrainerV1alpha1TorchMLPolicySource( + numProcPerNode=models.IoK8sApimachineryPkgUtilIntstrIntOrString(2) + ), + numNodes=2, + ), + template=models.TrainerV1alpha1JobSetTemplateSpec( + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name=name, + namespace=namespace, + ), + spec=models.JobsetV1alpha2JobSetSpec(replicatedJobs=[get_replicated_job()]), + ), + ), + ) + + def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob: return models.JobsetV1alpha2ReplicatedJob( name="node", diff --git a/kubeflow/trainer/constants/constants.py b/kubeflow/trainer/constants/constants.py index a15d402a..626bfd7a 100644 --- a/kubeflow/trainer/constants/constants.py +++ b/kubeflow/trainer/constants/constants.py @@ -26,6 +26,12 @@ # The plural for the ClusterTrainingRuntime. CLUSTER_TRAINING_RUNTIME_PLURAL = "clustertrainingruntimes" +# The Kind name for the TrainingRuntime. +TRAINING_RUNTIME_KIND = "TrainingRuntime" + +# The plural for the TrainingRuntime. +TRAINING_RUNTIME_PLURAL = "trainingruntimes" + # The Kind name for the TrainJob. TRAINJOB_KIND = "TrainJob"