
Commit a0acd51

test multiple outer optimizers (#231)

Summary: change the DiLoCo integration tests to construct and pass multiple outer optimizers, one per model fragment, instead of a single outer optimizer over the whole model.
1 parent 347fd32 commit a0acd51
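
For context, a minimal sketch of the pattern this commit moves the tests to, in plain PyTorch; the two-layer model here is illustrative, not torchft's actual test model:

import torch
import torch.nn as nn

# Stand-in model; the "fragments" here are simply its two layers.
m = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

# Before: one outer optimizer over all parameters.
single_outer = torch.optim.SGD(m.parameters(), lr=0.7, momentum=0.9, nesterov=True)

# After: one outer optimizer per fragment, collected in a list.
outer_optimizers = [
    torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
    for layer in m
]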

File tree

1 file changed: torchft/local_sgd_integ_test.py (+40 −25 lines)
@@ -9,7 +9,7 @@
 from contextlib import ExitStack
 from dataclasses import field
 from datetime import timedelta
-from typing import Any, Dict
+from typing import Any, Dict, cast
 from unittest import TestCase, skipIf
 
 import torch
@@ -157,17 +157,24 @@ def diloco_train_loop(
     inner_optimizer: optim.Optimizer = torch.optim.AdamW(
         m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
     )
-    outer_optimizer: optim.Optimizer = torch.optim.SGD(
-        m.parameters(), lr=0.7, momentum=0.9, nesterov=True
-    )
+
+    # Create one outer optimizer per fragment
+    outer_optimizers = []
+    for _, layer in enumerate(m.layers):
+        outer_optimizers.append(
+            torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
+        )
 
     # pyre-ignore[53]
     def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
         m.load_state_dict(state_dict["model"])
         m.to(device)
 
+        # Load original parameters for each fragment
         for i, fragment in enumerate(diloco._fragments):
-            fragment.original_parameters = state_dict["original_params"][f"{i}"]
+            fragment.original_parameters = cast(
+                Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
+            )
 
         for fragment in diloco._fragments:
             for name in fragment.original_parameters.keys():
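
Why one SGD instance per fragment: in DiLoCo-style training the outer optimizer consumes a per-fragment pseudo-gradient, so fragment-local optimizers keep their momentum state separate. A conceptual sketch under that assumption (not torchft's actual implementation; cross-replica averaging of the pseudo-gradient is omitted here):

import torch
import torch.nn as nn
from typing import Dict

def outer_step(
    fragment: nn.Module,
    original: Dict[str, torch.Tensor],
    outer_opt: torch.optim.Optimizer,
) -> None:
    with torch.no_grad():
        for name, p in fragment.named_parameters():
            # Pseudo-gradient: drift of the local weights away from the
            # last globally synchronized ("original") weights.
            pseudo_grad = original[name].to(p.device) - p
            # Step from the synced weights, not the drifted local ones.
            p.copy_(original[name])
            p.grad = pseudo_grad
    outer_opt.step()
    outer_opt.zero_grad()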
@@ -176,7 +183,8 @@ def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
                 ].to(device)
 
         inner_optimizer.load_state_dict(state_dict["inner_optim"])
-        outer_optimizer.load_state_dict(state_dict["outer_optim"])
+        for i, optimizer in enumerate(outer_optimizers):
+            optimizer.load_state_dict(state_dict[f"outer_optim"][f"{i}"])
 
     def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
         return {
@@ -186,7 +194,10 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
                 for i, fragment in enumerate(diloco._fragments)
             },
             "inner_optim": inner_optimizer.state_dict(),
-            "outer_optim": outer_optimizer.state_dict(),
+            "outer_optim": {
+                f"{i}": optimizer.state_dict()
+                for i, optimizer in enumerate(outer_optimizers)
+            },
         }
 
     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
@@ -259,7 +270,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
            manager,
            [layer for layer in m.layers],
            inner_optimizer,
-           outer_optimizer,
+           outer_optimizers,
            backup_device=device,
            **diloco_args,
        ) as diloco:
@@ -305,11 +316,26 @@ def assert_equal_global_state(
            rep0[step]["user"]["default"]["original_params"],
            check_device=False,
        )
-       torch.testing.assert_close(
-           rep1[step]["user"]["default"]["outer_optim"],
-           rep0[step]["user"]["default"]["outer_optim"],
-           check_device=False,
-       )
+       # Check all outer optimizers
+       for i in range(
+           len(
+               cast(
+                   dict[str, dict[str, torch.Tensor]],
+                   rep0[step]["user"]["default"]["outer_optim"],
+               ).keys()
+           )
+       ):
+           torch.testing.assert_close(
+               cast(
+                   dict[str, dict[str, torch.Tensor]],
+                   rep1[step]["user"]["default"]["outer_optim"],
+               )[f"{i}"],
+               cast(
+                   dict[str, dict[str, torch.Tensor]],
+                   rep0[step]["user"]["default"]["outer_optim"],
+               )[f"{i}"],
+               check_device=False,
+           )
 
 
 class LocalSGDIntegTest(TestCase):
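
A hedged sketch of what the loop above checks: each replica stores index-keyed optimizer state dicts, and torch.testing.assert_close recurses into the nested containers. The names below are illustrative:

import torch

params_a = [torch.nn.Parameter(torch.zeros(2))]
params_b = [torch.nn.Parameter(torch.zeros(2))]
opt_a = torch.optim.SGD(params_a, lr=0.7, momentum=0.9, nesterov=True)
opt_b = torch.optim.SGD(params_b, lr=0.7, momentum=0.9, nesterov=True)

# Per-replica "outer_optim" layout: fragment index -> optimizer state dict.
rep0_outer = {"0": opt_a.state_dict()}
rep1_outer = {"0": opt_b.state_dict()}
for i in range(len(rep0_outer)):
    torch.testing.assert_close(
        rep1_outer[f"{i}"], rep0_outer[f"{i}"], check_device=False
    )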
@@ -420,18 +446,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
         lighthouse.shutdown()
 
         rep0, rep1 = state_dicts
-        for step, state_dict in rep1.items():
-            # inner optimizer will be different, outer optimizer and model should be the same
-            torch.testing.assert_close(
-                state_dict["user"]["default"]["model"],
-                rep0[step]["user"]["default"]["model"],
-                check_device=False,
-            )
-            torch.testing.assert_close(
-                state_dict["user"]["default"]["outer_optim"],
-                rep0[step]["user"]["default"]["outer_optim"],
-                check_device=False,
-            )
+        assert_equal_global_state(rep1, rep0)
 
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @skipIf(sys.platform == "darwin", "not reliable on mac")
