Skip to content

Commit 8266307

Browse files
authored
Adding TorchX-MCAD Scheduler Support to Jobs (#78)
* First pass torchx-mcad * Make workspace blank * Updated requirements * Feedback applied
1 parent 7945f5a commit 8266307

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ homepage = "https://github.com/project-codeflare/codeflare-sdk"
2020
keywords = ['codeflare', 'python', 'sdk', 'client', 'batch', 'scale']
2121

2222
[tool.poetry.dependencies]
23-
python = "^3.6.3"
23+
python = "^3.7"
2424
openshift-client = "1.0.18"
2525
rich = "^12.5"
2626
ray = {version = "2.1.0", extras = ["default"]}
27+
kubernetes = "26.1.0"
28+
torchx = {git = "https://github.com/project-codeflare/torchx", rev = "OCP"}

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
openshift-client==1.0.18
22
rich==12.5.1
33
ray[default]==2.1.0
4-
git+https://github.com/project-codeflare/torchx@6517d5b060e4fe32b9ad41019c3bef647095c35f#egg=torchx
4+
kubernetes==26.1.0
5+
git+https://github.com/project-codeflare/torchx@OCP

src/codeflare_sdk/job/jobs.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
mounts: Optional[List[str]] = None,
6262
rdzv_port: int = 29500,
6363
scheduler_args: Optional[Dict[str, str]] = None,
64+
image: Optional[str] = None,
6465
):
6566
if bool(script) == bool(m): # logical XOR
6667
raise ValueError(
@@ -82,6 +83,7 @@ def __init__(
8283
self.scheduler_args: Dict[str, str] = (
8384
scheduler_args if scheduler_args is not None else dict()
8485
)
86+
self.image = image
8587

8688
def _dry_run(self, cluster: "Cluster"):
8789
j = f"{cluster.config.max_worker}x{max(cluster.config.gpu, 1)}" # # of proc. = # of gpus
@@ -108,15 +110,58 @@ def _dry_run(self, cluster: "Cluster"):
108110
workspace=f"file://{Path.cwd()}",
109111
)
110112

111-
def submit(self, cluster: "Cluster") -> "Job":
113+
def _missing_spec(self, spec: str):
114+
raise ValueError(f"Job definition missing arg: {spec}")
115+
116+
def _dry_run_no_cluster(self):
117+
return torchx_runner.dryrun(
118+
app=ddp(
119+
*self.script_args,
120+
script=self.script,
121+
m=self.m,
122+
name=self.name if self.name is not None else self._missing_spec("name"),
123+
h=self.h,
124+
cpu=self.cpu
125+
if self.cpu is not None
126+
else self._missing_spec("cpu (# cpus per worker)"),
127+
gpu=self.gpu
128+
if self.gpu is not None
129+
else self._missing_spec("gpu (# gpus per worker)"),
130+
memMB=self.memMB
131+
if self.memMB is not None
132+
else self._missing_spec("memMB (memory in MB)"),
133+
j=self.j
134+
if self.j is not None
135+
else self._missing_spec(
136+
"j (`workers`x`procs`)"
137+
), # # of proc. = # of gpus,
138+
env=self.env, # should this still exist?
139+
max_retries=self.max_retries,
140+
rdzv_port=self.rdzv_port, # should this still exist?
141+
mounts=self.mounts,
142+
image=self.image
143+
if self.image is not None
144+
else self._missing_spec("image"),
145+
),
146+
scheduler="kubernetes_mcad",
147+
cfg=self.scheduler_args if self.scheduler_args is not None else None,
148+
workspace="",
149+
)
150+
151+
def submit(self, cluster: "Cluster" = None) -> "Job":
112152
return DDPJob(self, cluster)
113153

114154

115155
class DDPJob(Job):
116156
def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster"):
117157
self.job_definition = job_definition
118158
self.cluster = cluster
119-
self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
159+
if self.cluster:
160+
self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
161+
else:
162+
self._app_handle = torchx_runner.schedule(
163+
job_definition._dry_run_no_cluster()
164+
)
120165
all_jobs.append(self)
121166

122167
def status(self) -> str:

0 commit comments

Comments
 (0)