diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py index 355e4877..2b0e065e 100644 --- a/torchft/local_sgd_integ_test.py +++ b/torchft/local_sgd_integ_test.py @@ -9,7 +9,7 @@ from contextlib import ExitStack from dataclasses import field from datetime import timedelta -from typing import Any, Dict +from typing import Any, Dict, cast from unittest import TestCase, skipIf import torch @@ -157,17 +157,24 @@ def diloco_train_loop( inner_optimizer: optim.Optimizer = torch.optim.AdamW( m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) ) - outer_optimizer: optim.Optimizer = torch.optim.SGD( - m.parameters(), lr=0.7, momentum=0.9, nesterov=True - ) + + # Create one outer optimizer per fragment + outer_optimizers = [] + for _, layer in enumerate(m.layers): + outer_optimizers.append( + torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True) + ) # pyre-ignore[53] def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None: m.load_state_dict(state_dict["model"]) m.to(device) + # Load original parameters for each fragment for i, fragment in enumerate(diloco._fragments): - fragment.original_parameters = state_dict["original_params"][f"{i}"] + fragment.original_parameters = cast( + Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"] + ) for fragment in diloco._fragments: for name in fragment.original_parameters.keys(): @@ -176,7 +183,8 @@ def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None: ].to(device) inner_optimizer.load_state_dict(state_dict["inner_optim"]) - outer_optimizer.load_state_dict(state_dict["outer_optim"]) + for i, optimizer in enumerate(outer_optimizers): + optimizer.load_state_dict(state_dict[f"outer_optim"][f"{i}"]) def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] return { @@ -186,7 +194,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] for i, fragment in enumerate(diloco._fragments) }, "inner_optim": inner_optimizer.state_dict(), - "outer_optim": outer_optimizer.state_dict(), + "outer_optim": { + f"{i}": optimizer.state_dict() + for i, optimizer in enumerate(outer_optimizers) + }, } print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting") @@ -259,7 +270,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] manager, [layer for layer in m.layers], inner_optimizer, - outer_optimizer, + outer_optimizers, backup_device=device, **diloco_args, ) as diloco: @@ -305,11 +316,26 @@ def assert_equal_global_state( rep0[step]["user"]["default"]["original_params"], check_device=False, ) - torch.testing.assert_close( - rep1[step]["user"]["default"]["outer_optim"], - rep0[step]["user"]["default"]["outer_optim"], - check_device=False, - ) + # Check all outer optimizers + for i in range( + len( + cast( + dict[str, dict[str, torch.Tensor]], + rep0[step]["user"]["default"]["outer_optim"], + ).keys() + ) + ): + torch.testing.assert_close( + cast( + dict[str, dict[str, torch.Tensor]], + rep1[step]["user"]["default"]["outer_optim"], + )[f"{i}"], + cast( + dict[str, dict[str, torch.Tensor]], + rep0[step]["user"]["default"]["outer_optim"], + )[f"{i}"], + check_device=False, + ) class LocalSGDIntegTest(TestCase): @@ -420,18 +446,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None: lighthouse.shutdown() rep0, rep1 = state_dicts - for step, state_dict in rep1.items(): - # inner optimizer will be different, outer optimizer and model should be the same - torch.testing.assert_close( - state_dict["user"]["default"]["model"], - rep0[step]["user"]["default"]["model"], - check_device=False, - ) - torch.testing.assert_close( - state_dict["user"]["default"]["outer_optim"], - rep0[step]["user"]["default"]["outer_optim"], - check_device=False, - ) + assert_equal_global_state(rep1, rep0) # pyre-fixme[56]: Pyre was not able to infer the type of argument @skipIf(sys.platform == "darwin", "not reliable on mac")