
Commit d90dbce

feat(trainer): Add get_runtime_packages() API (#57)
* feat(trainer): Add get_runtime_packages() API
  Signed-off-by: Andrey Velichkevich <[email protected]>
* Fix mpirun command
  Signed-off-by: Andrey Velichkevich <[email protected]>
* Print Python version
  Signed-off-by: Andrey Velichkevich <[email protected]>
* Add unit tests
  Signed-off-by: Andrey Velichkevich <[email protected]>
* Fix verify
  Signed-off-by: Andrey Velichkevich <[email protected]>

---------

Signed-off-by: Andrey Velichkevich <[email protected]>
1 parent 7432d9f commit d90dbce
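
For orientation, here is a minimal usage sketch of the API this commit adds. It is not part of the diff: the top-level TrainerClient import and the runtime name "torch-distributed" are assumptions for illustration.

    from kubeflow.trainer import TrainerClient

    client = TrainerClient()

    # Pick an existing Runtime that uses a CustomTrainer.
    runtime = client.get_runtime(name="torch-distributed")

    # Launches a single-node TrainJob that prints the Python version, the
    # `pip list` output, and `nvidia-smi` (when GPUs are present), then waits
    # for completion, prints the node-0 logs, and deletes the job.
    client.get_runtime_packages(runtime=runtime)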

File tree: 5 files changed, +169 −49 lines changed

python/kubeflow/trainer/api/trainer_client.py

Lines changed: 96 additions & 32 deletions

@@ -17,6 +17,7 @@
 import queue
 import random
 import string
+import time
 import uuid
 from typing import Dict, List, Optional, Union, Set

@@ -159,6 +160,71 @@ def get_runtime(self, name: str) -> types.Runtime:

         return self.__get_runtime_from_crd(runtime)  # type: ignore

+    def get_runtime_packages(self, runtime: types.Runtime):
+        """
+        Print the installed Python packages for the given Runtime. If Runtime has GPUs it also
+        prints available GPUs on the single training node.
+
+        Args:
+            runtime: Reference to one of existing Runtimes.
+
+        Raises:
+            ValueError: Input arguments are invalid.
+            RuntimeError: Failed to get Runtime.
+
+        """
+
+        if runtime.trainer.trainer_type == types.TrainerType.BUILTIN_TRAINER:
+            raise ValueError("Cannot get Runtime packages for BuiltinTrainer")
+
+        # Run mpirun only within the single process.
+        if runtime.trainer.command[0] == "mpirun":
+            mpi_command = list(constants.MPI_COMMAND)
+            mpi_command[1:3] = ["-np", "1"]
+            runtime.trainer.set_command(tuple(mpi_command))
+
+        def print_packages():
+            import subprocess
+            import shutil
+            import sys
+
+            # Print Python version.
+            print(f"Python: {sys.version}")
+
+            # Print Python packages.
+            if shutil.which("pip"):
+                pip_list = subprocess.run(
+                    ["pip", "list"], capture_output=True, text=True
+                )
+                print(pip_list.stdout)
+            else:
+                print("Unable to get installed packages: pip command not found")
+
+            # Print nvidia-smi if GPUs are available.
+            if shutil.which("nvidia-smi"):
+                print("Available GPUs on the single training node")
+                nvidia_smi = subprocess.run(
+                    ["nvidia-smi"], capture_output=True, text=True
+                )
+                print(nvidia_smi.stdout)
+
+        # Create the TrainJob and wait until it completes.
+        # If Runtime trainer has GPU resources use them, otherwise run TrainJob with 1 CPU.
+        job_name = self.train(
+            runtime=runtime,
+            trainer=types.CustomTrainer(
+                func=print_packages,
+                num_nodes=1,
+                resources_per_node=(
+                    {"cpu": 1} if runtime.trainer.device != "gpu" else None
+                ),
+            ),
+        )
+
+        self.wait_for_job_status(job_name)
+        print(self.get_job_logs(job_name)["node-0"])
+        self.delete_job(job_name)
+
     def train(
         self,
         runtime: Optional[types.Runtime] = None,

@@ -174,11 +240,11 @@ def train(
         the post-training logic, requiring only parameter adjustments, e.g. `BuiltinTrainer`.

         Args:
-            runtime (`types.Runtime`): Reference to one of existing Runtimes. By default the
+            runtime: Reference to one of existing Runtimes. By default the
                 torch-distributed Runtime is used.
-            initializer (`Optional[types.Initializer]`):
+            initializer:
                 Configuration for the dataset and model initializers.
-            trainer (`Optional[types.CustomTrainer, types.BuiltinTrainer]`):
+            trainer:
                 Configuration for Custom Training Task or Config-driven Task with Builtin Trainer.

         Returns:

@@ -460,6 +526,7 @@ def wait_for_job_status(
         name: str,
         status: Set[str] = {constants.TRAINJOB_COMPLETE},
         timeout: int = 600,
+        polling_interval: int = 2,
     ) -> types.TrainJob:
         """Wait for TrainJob to reach the desired status

@@ -468,6 +535,7 @@
             status: Set of expected statuses. It must be subset of Created, Running, Complete, and
                 Failed statuses.
             timeout: How many seconds to wait until TrainJob reaches one of the expected conditions.
+            polling_interval: The polling interval in seconds to check TrainJob status.

         Returns:
             TrainJob: The training job that reaches the desired status.

@@ -489,36 +557,28 @@ def wait_for_job_status(
                 f"Expected status {status} must be a subset of {job_statuses}"
             )

-        # Use Kubernetes watch API to monitor the TrainJob's Pods.
-        w = watch.Watch()
-        try:
-            for event in w.stream(
-                self.core_api.list_namespaced_pod,
-                self.namespace,
-                label_selector=constants.POD_LABEL_SELECTOR.format(trainjob_name=name),
-                timeout_seconds=timeout,
-            ):
-                # Check the status after event is generated for the TrainJob's Pods.
-                trainjob = self.get_job(name)
-                logger.debug(f"TrainJob {name}, status {trainjob.status}")
+        if polling_interval > timeout:
+            raise ValueError(
+                f"Polling interval {polling_interval} must be less than timeout: {timeout}"
+            )

-                # Raise an error if TrainJob is Failed and it is not the expected status.
-                if (
-                    constants.TRAINJOB_FAILED not in status
-                    and trainjob.status == constants.TRAINJOB_FAILED
-                ):
-                    raise RuntimeError(f"TrainJob {name} is Failed")
+        for _ in range(round(timeout / polling_interval)):
+            # Check the status after event is generated for the TrainJob's Pods.
+            trainjob = self.get_job(name)
+            logger.debug(f"TrainJob {name}, status {trainjob.status}")

-                # Return the TrainJob if it reaches the expected status.
-                if trainjob.status in status:
-                    return trainjob
+            # Raise an error if TrainJob is Failed and it is not the expected status.
+            if (
+                constants.TRAINJOB_FAILED not in status
+                and trainjob.status == constants.TRAINJOB_FAILED
+            ):
+                raise RuntimeError(f"TrainJob {name} is Failed")

-        except TimeoutError:
-            raise TimeoutError(f"Timeout to get the TrainJob {name}")
-        except Exception:
-            raise RuntimeError(f"Failed to watch Pods for TrainJob {name}")
-        finally:
-            w.stop()
+            # Return the TrainJob if it reaches the expected status.
+            if trainjob.status in status:
+                return trainjob
+
+            time.sleep(polling_interval)

         raise TimeoutError(
             f"Timeout waiting for TrainJob {name} to reach status: {status} status"

@@ -691,12 +751,16 @@ def __get_trainjob_from_crd(
         elif c.type == constants.TRAINJOB_FAILED and c.status == "True":
             trainjob.status = c.type
         else:
-            # The TrainJob running status is defined when all training node (e.g. Pods) are running.
+            # The TrainJob running status is defined when all training node (e.g. Pods) are
+            # running or succeeded.
             num_running_nodes = sum(
                 1
                 for step in trainjob.steps
                 if step.name.startswith(constants.NODE)
-                and step.status == constants.TRAINJOB_RUNNING
+                and (
+                    step.status == constants.TRAINJOB_RUNNING
+                    or step.status == constants.POD_SUCCEEDED
+                )
             )

             if trainjob.num_nodes == num_running_nodes:
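
The watch-based wait loop is replaced here by plain polling. A short sketch of calling the updated method with the new parameter, reusing the client from the sketch near the top of this page; the job name and values are illustrative:

    from kubeflow.trainer.constants import constants

    # Poll every 5 seconds, for up to 10 minutes, until the TrainJob completes.
    trainjob = client.wait_for_job_status(
        name="my-trainjob",
        status={constants.TRAINJOB_COMPLETE},
        timeout=600,
        polling_interval=5,
    )
    print(trainjob.status)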

python/kubeflow/trainer/api/trainer_client_test.py

Lines changed: 59 additions & 8 deletions

@@ -56,6 +56,10 @@ class TestCase:
 # In all tests runtime name is equal to the framework name.
 TORCH_RUNTIME = "torch"
 TORCH_TUNE_RUNTIME = "torchtune"
+
+# 2 nodes * 2 nproc
+RUNTIME_DEVICES = "4"
+
 FAIL_LOGS = "fail_logs"
 LIST_RUNTIMES = "list_runtimes"
 BASIC_TRAIN_JOB_NAME = "basic-job"

@@ -95,11 +99,6 @@ def trainer_client(request):
            list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response),
            read_namespaced_pod_log=Mock(side_effect=mock_read_namespaced_pod_log),
        ),
-    ), patch(
-        "kubernetes.watch.Watch",
-        return_value=Mock(
-            stream=Mock(side_effect=mock_watch),
-        ),
     ):
         yield TrainerClient()

@@ -509,7 +508,8 @@ def create_runtime_type(
         trainer_type=types.TrainerType.CUSTOM_TRAINER,
         framework=name,
         num_nodes=2,
-        accelerator_count=4,
+        device="gpu",
+        device_count=RUNTIME_DEVICES,
     )
     trainer.set_command(constants.TORCH_COMMAND)
     return types.Runtime(

@@ -528,7 +528,8 @@ def get_train_job_data_type(
     trainer = types.RuntimeTrainer(
         trainer_type=types.TrainerType.CUSTOM_TRAINER,
         framework=runtime_name,
-        accelerator_count=4,
+        device="gpu",
+        device_count=RUNTIME_DEVICES,
         num_nodes=2,
     )
     trainer.set_command(constants.TORCH_COMMAND)

@@ -644,6 +645,45 @@ def test_list_runtimes(trainer_client, test_case):
     print("test execution complete")


+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="valid flow with custom trainer runtime",
+            expected_status=SUCCESS,
+            config={"runtime": create_runtime_type(name=TORCH_RUNTIME)},
+        ),
+        TestCase(
+            name="value error with builtin trainer runtime",
+            expected_status=FAILED,
+            config={
+                "runtime": types.Runtime(
+                    name="torchtune-runtime",
+                    trainer=types.RuntimeTrainer(
+                        trainer_type=types.TrainerType.BUILTIN_TRAINER,
+                        framework="torchtune",
+                        num_nodes=1,
+                        device="cpu",
+                        device_count="1",
+                    ),
+                )
+            },
+            expected_error=ValueError,
+        ),
+    ],
+)
+def test_get_runtime_packages(trainer_client, test_case):
+    """Test TrainerClient.get_runtime_packages with basic success path."""
+    print("Executing test:", test_case.name)
+
+    try:
+        trainer_client.get_runtime_packages(**test_case.config)
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+
+    print("test execution complete")
+
+
 @pytest.mark.parametrize(
     "test_case",
     [

@@ -944,6 +984,16 @@ def test_get_job_logs(trainer_client, test_case):
             },
             expected_error=ValueError,
         ),
+        TestCase(
+            name="polling interval is more than timeout error",
+            expected_status=FAILED,
+            config={
+                "name": BASIC_TRAIN_JOB_NAME,
+                "timeout": 1,
+                "polling_interval": 2,
+            },
+            expected_error=ValueError,
+        ),
         TestCase(
             name="job failed when not expected",
             expected_status=FAILED,

@@ -959,7 +1009,8 @@
             config={
                 "name": BASIC_TRAIN_JOB_NAME,
                 "status": {constants.TRAINJOB_FAILED},
-                "timeout": 1,
+                "polling_interval": 1,
+                "timeout": 2,
             },
             expected_error=TimeoutError,
         ),
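
To run just the new parametrized test, an invocation along these lines should work; the pytest.main call simply passes the file path shown above plus a -k filter on the test name:

    import pytest

    # Run only test_get_runtime_packages from the updated test module.
    pytest.main([
        "python/kubeflow/trainer/api/trainer_client_test.py",
        "-k", "test_get_runtime_packages",
    ])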

python/kubeflow/trainer/constants/constants.py

Lines changed: 5 additions & 1 deletion

@@ -41,7 +41,8 @@
 # The default status for the TrainJob once users create it.
 TRAINJOB_CREATED = "Created"

-# The running status of the TrainJob, defined when all training node (e.g. Pods) are running.
+# The running status of the TrainJob, defined when all training node (e.g. Pods) are
+# running or succeeded.
 TRAINJOB_RUNNING = "Running"

 # The complete status of the TrainJob, defined when TrainJob CR has complete condition.

@@ -50,6 +51,9 @@
 # The failed status of the TrainJob, defined when TrainJob CR has failed condition.
 TRAINJOB_FAILED = "Failed"

+# The succeeded phase of the Pod.
+POD_SUCCEEDED = "Succeeded"
+
 # The label key to identify the relationship between TrainJob and Pod template in the runtime.
 # For example, what PodTemplate must be overridden by TrainJob's .spec.trainer APIs.
 TRAINJOB_ANCESTOR_LABEL = "trainer.kubeflow.org/trainjob-ancestor-step"

python/kubeflow/trainer/types/types.py

Lines changed: 4 additions & 3 deletions

@@ -16,7 +16,7 @@
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, Optional

 from kubeflow.trainer.constants import constants

@@ -168,7 +168,8 @@ class RuntimeTrainer:
     trainer_type: TrainerType
     framework: str
     num_nodes: int = 1  # The default value is set in the APIs.
-    accelerator_count: Union[str, float, int] = constants.UNKNOWN
+    device: str = constants.UNKNOWN
+    device_count: str = constants.UNKNOWN
     __command: tuple[str, ...] = field(init=False, repr=False)

     @property

@@ -194,7 +195,7 @@ class Step:
     status: Optional[str]
     pod_name: str
     device: str = constants.UNKNOWN
-    device_count: Union[str, int] = constants.UNKNOWN
+    device_count: str = constants.UNKNOWN


 # Representation for the TrainJob.
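
A small sketch of constructing the updated dataclass with the string-typed fields that replace accelerator_count. It mirrors the helpers in the test diff above; the exact import style is an assumption, not taken from this commit:

    from kubeflow.trainer import types
    from kubeflow.trainer.constants import constants

    # device and device_count are plain strings on RuntimeTrainer now.
    trainer = types.RuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework="torch",
        num_nodes=2,
        device="gpu",
        device_count="4",  # 2 nodes * 2 processes per node, kept as a string
    )
    trainer.set_command(constants.TORCH_COMMAND)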

python/kubeflow/trainer/utils/utils.py

Lines changed: 5 additions & 5 deletions

@@ -131,19 +131,19 @@ def get_runtime_trainer(

     # Get the container devices.
     if devices := get_container_devices(trainer_container.resources):
-        _, trainer.accelerator_count = devices
+        trainer.device, trainer.device_count = devices

     # Torch and MPI plugins override accelerator count.
     if ml_policy.torch and ml_policy.torch.num_proc_per_node:
         num_proc = ml_policy.torch.num_proc_per_node.actual_instance
         if isinstance(num_proc, int):
-            trainer.accelerator_count = num_proc
+            trainer.device_count = str(num_proc)
     elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
-        trainer.accelerator_count = ml_policy.mpi.num_proc_per_node
+        trainer.device_count = str(ml_policy.mpi.num_proc_per_node)

     # Multiply accelerator_count by the number of nodes.
-    if isinstance(trainer.accelerator_count, (int, float)) and ml_policy.num_nodes:
-        trainer.accelerator_count *= ml_policy.num_nodes
+    if trainer.device_count.isdigit() and ml_policy.num_nodes:
+        trainer.device_count = str(int(trainer.device_count) * ml_policy.num_nodes)

     # Add number of training nodes.
     if ml_policy.num_nodes:
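
Since device_count is now a string, the per-node multiplication only applies to numeric values. A standalone sketch of that rule (not the SDK function itself):

    def scale_device_count(device_count: str, num_nodes: int) -> str:
        """Multiply a numeric device count by the node count; pass through otherwise."""
        if device_count.isdigit() and num_nodes:
            return str(int(device_count) * num_nodes)
        return device_count

    assert scale_device_count("2", 2) == "4"              # 2 devices per node * 2 nodes
    assert scale_device_count("Unknown", 2) == "Unknown"  # non-numeric placeholder unchanged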
