9
9
from contextlib import ExitStack
10
10
from dataclasses import field
11
11
from datetime import timedelta
12
- from typing import Any , Dict
12
+ from typing import Any , Dict , cast
13
13
from unittest import TestCase , skipIf
14
14
15
15
import torch
@@ -157,17 +157,24 @@ def diloco_train_loop(
157
157
inner_optimizer : optim .Optimizer = torch .optim .AdamW (
158
158
m .parameters (), lr = 4e-4 , weight_decay = 0.1 , betas = (0.9 , 0.95 )
159
159
)
160
- outer_optimizer : optim .Optimizer = torch .optim .SGD (
161
- m .parameters (), lr = 0.7 , momentum = 0.9 , nesterov = True
162
- )
160
+
161
+ # Create one outer optimizer per fragment
162
+ outer_optimizers = []
163
+ for _ , layer in enumerate (m .layers ):
164
+ outer_optimizers .append (
165
+ torch .optim .SGD (layer .parameters (), lr = 0.7 , momentum = 0.9 , nesterov = True )
166
+ )
163
167
164
168
# pyre-ignore[53]
165
169
def load_state_dict (state_dict : Dict [str , Dict [str , object ]]) -> None :
166
170
m .load_state_dict (state_dict ["model" ])
167
171
m .to (device )
168
172
173
+ # Load original parameters for each fragment
169
174
for i , fragment in enumerate (diloco ._fragments ):
170
- fragment .original_parameters = state_dict ["original_params" ][f"{ i } " ]
175
+ fragment .original_parameters = cast (
176
+ Dict [str , torch .Tensor ], state_dict ["original_params" ][f"{ i } " ]
177
+ )
171
178
172
179
for fragment in diloco ._fragments :
173
180
for name in fragment .original_parameters .keys ():
@@ -176,7 +183,8 @@ def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
176
183
].to (device )
177
184
178
185
inner_optimizer .load_state_dict (state_dict ["inner_optim" ])
179
- outer_optimizer .load_state_dict (state_dict ["outer_optim" ])
186
+ for i , optimizer in enumerate (outer_optimizers ):
187
+ optimizer .load_state_dict (state_dict [f"outer_optim" ][f"{ i } " ])
180
188
181
189
def state_dict () -> Dict [str , Dict [str , object ]]: # pyre-ignore[53]
182
190
return {
@@ -186,7 +194,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
186
194
for i , fragment in enumerate (diloco ._fragments )
187
195
},
188
196
"inner_optim" : inner_optimizer .state_dict (),
189
- "outer_optim" : outer_optimizer .state_dict (),
197
+ "outer_optim" : {
198
+ f"{ i } " : optimizer .state_dict ()
199
+ for i , optimizer in enumerate (outer_optimizers )
200
+ },
190
201
}
191
202
192
203
print (f"worker { runner .replica_id = } { rank = } { runner .world_size = } starting" )
@@ -259,7 +270,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
259
270
manager ,
260
271
[layer for layer in m .layers ],
261
272
inner_optimizer ,
262
- outer_optimizer ,
273
+ outer_optimizers ,
263
274
backup_device = device ,
264
275
** diloco_args ,
265
276
) as diloco :
@@ -305,11 +316,26 @@ def assert_equal_global_state(
305
316
rep0 [step ]["user" ]["default" ]["original_params" ],
306
317
check_device = False ,
307
318
)
308
- torch .testing .assert_close (
309
- rep1 [step ]["user" ]["default" ]["outer_optim" ],
310
- rep0 [step ]["user" ]["default" ]["outer_optim" ],
311
- check_device = False ,
312
- )
319
+ # Check all outer optimizers
320
+ for i in range (
321
+ len (
322
+ cast (
323
+ dict [str , dict [str , torch .Tensor ]],
324
+ rep0 [step ]["user" ]["default" ]["outer_optim" ],
325
+ ).keys ()
326
+ )
327
+ ):
328
+ torch .testing .assert_close (
329
+ cast (
330
+ dict [str , dict [str , torch .Tensor ]],
331
+ rep1 [step ]["user" ]["default" ]["outer_optim" ],
332
+ )[f"{ i } " ],
333
+ cast (
334
+ dict [str , dict [str , torch .Tensor ]],
335
+ rep0 [step ]["user" ]["default" ]["outer_optim" ],
336
+ )[f"{ i } " ],
337
+ check_device = False ,
338
+ )
313
339
314
340
315
341
class LocalSGDIntegTest (TestCase ):
@@ -420,18 +446,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
420
446
lighthouse .shutdown ()
421
447
422
448
rep0 , rep1 = state_dicts
423
- for step , state_dict in rep1 .items ():
424
- # inner optimizer will be different, outer optimizer and model should be the same
425
- torch .testing .assert_close (
426
- state_dict ["user" ]["default" ]["model" ],
427
- rep0 [step ]["user" ]["default" ]["model" ],
428
- check_device = False ,
429
- )
430
- torch .testing .assert_close (
431
- state_dict ["user" ]["default" ]["outer_optim" ],
432
- rep0 [step ]["user" ]["default" ]["outer_optim" ],
433
- check_device = False ,
434
- )
449
+ assert_equal_global_state (rep1 , rep0 )
435
450
436
451
# pyre-fixme[56]: Pyre was not able to infer the type of argument
437
452
@skipIf (sys .platform == "darwin" , "not reliable on mac" )
0 commit comments