@@ -61,6 +61,7 @@ def __init__(
61
61
mounts : Optional [List [str ]] = None ,
62
62
rdzv_port : int = 29500 ,
63
63
scheduler_args : Optional [Dict [str , str ]] = None ,
64
+ image : Optional [str ] = None ,
64
65
):
65
66
if bool (script ) == bool (m ): # logical XOR
66
67
raise ValueError (
@@ -82,6 +83,7 @@ def __init__(
82
83
self .scheduler_args : Dict [str , str ] = (
83
84
scheduler_args if scheduler_args is not None else dict ()
84
85
)
86
+ self .image = image
85
87
86
88
def _dry_run (self , cluster : "Cluster" ):
87
89
j = f"{ cluster .config .max_worker } x{ max (cluster .config .gpu , 1 )} " # # of proc. = # of gpus
@@ -108,15 +110,58 @@ def _dry_run(self, cluster: "Cluster"):
108
110
workspace = f"file://{ Path .cwd ()} " ,
109
111
)
110
112
111
- def submit (self , cluster : "Cluster" ) -> "Job" :
113
+ def _missing_spec (self , spec : str ):
114
+ raise ValueError (f"Job definition missing arg: { spec } " )
115
+
116
def _dry_run_no_cluster(self):
    """Build a TorchX dry-run app handle without a backing Cluster.

    Unlike ``_dry_run``, every resource field must be supplied on the job
    definition itself; any mandatory field left as ``None`` raises a
    ``ValueError`` via ``_missing_spec``. Targets the ``kubernetes_mcad``
    scheduler with an empty workspace.

    Returns:
        The ``AppDryRunInfo`` produced by ``torchx_runner.dryrun``.

    Raises:
        ValueError: if name, cpu, gpu, memMB, j, or image is unset.
    """

    def require(value, spec: str):
        # Fail fast with a descriptive error when a mandatory field is unset;
        # _missing_spec always raises, so this only returns non-None values.
        return value if value is not None else self._missing_spec(spec)

    return torchx_runner.dryrun(
        app=ddp(
            *self.script_args,
            script=self.script,
            m=self.m,
            name=require(self.name, "name"),
            h=self.h,
            cpu=require(self.cpu, "cpu (# cpus per worker)"),
            gpu=require(self.gpu, "gpu (# gpus per worker)"),
            memMB=require(self.memMB, "memMB (memory in MB)"),
            # # of proc. = # of gpus
            j=require(self.j, "j (`workers`x`procs`)"),
            env=self.env,  # should this still exist?
            max_retries=self.max_retries,
            rdzv_port=self.rdzv_port,  # should this still exist?
            mounts=self.mounts,
            image=require(self.image, "image"),
        ),
        scheduler="kubernetes_mcad",
        # __init__ coerces scheduler_args to a dict, so it is never None here.
        cfg=self.scheduler_args,
        workspace="",
    )
150
+
151
def submit(self, cluster: Optional["Cluster"] = None) -> "Job":
    """Schedule this job definition and return the resulting ``DDPJob``.

    Args:
        cluster: target Cluster to run on; when ``None`` the job is
            scheduled through the cluster-less (``kubernetes_mcad``)
            dry-run path instead.
    """
    return DDPJob(self, cluster)
113
153
114
154
115
155
class DDPJob (Job ):
116
156
def __init__(self, job_definition: "DDPJobDefinition", cluster: Optional["Cluster"]):
    """Schedule ``job_definition`` and record the resulting app handle.

    Args:
        job_definition: the DDP job spec to schedule.
        cluster: target Cluster, or ``None`` to schedule via the
            cluster-less (``kubernetes_mcad``) dry-run path.
    """
    self.job_definition = job_definition
    self.cluster = cluster
    # None is the sentinel for "no cluster" — check identity explicitly
    # rather than relying on Cluster truthiness.
    if self.cluster is not None:
        self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster))
    else:
        self._app_handle = torchx_runner.schedule(
            job_definition._dry_run_no_cluster()
        )
    all_jobs.append(self)
121
166
122
167
def status (self ) -> str :
0 commit comments