test multiple outer optimizers #231
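This PR updates the DiLoCo integration test to exercise one outer optimizer per model fragment instead of a single model-wide outer optimizer, threading the list through checkpoint save/load and the cross-replica state checks. A minimal sketch of the pattern under test (`TinyModel` is a hypothetical stand-in; the optimizer hyperparameters are the ones used in the diff below):

```python
import torch
from torch import nn

# Hypothetical stand-in for the test model; each child module in
# `layers` is treated as one DiLoCo fragment.
class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))

m = TinyModel()

# One SGD outer optimizer per fragment, mirroring the change below.
outer_optimizers = [
    torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
    for layer in m.layers
]
```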

Merged: 1 commit, Jul 15, 2025
65 changes: 40 additions & 25 deletions torchft/local_sgd_integ_test.py
@@ -9,7 +9,7 @@
from contextlib import ExitStack
from dataclasses import field
from datetime import timedelta
from typing import Any, Dict
from typing import Any, Dict, cast
from unittest import TestCase, skipIf

import torch
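The new `cast` import supports the checkpoint-loading changes below. A minimal illustration of why it is needed (the entry value is hypothetical): checkpoint entries are typed as `object`, so the type checker needs a cast before one can be treated as a name-to-tensor mapping.

```python
from typing import Dict, cast

import torch

entry: object = {"weight": torch.zeros(2, 2)}  # hypothetical checkpoint entry

# cast() has no runtime effect; it only narrows the static type so that
# pyre/mypy allow dict-of-tensor operations on the entry.
params = cast(Dict[str, torch.Tensor], entry)
print(params["weight"].shape)
```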
@@ -157,17 +157,24 @@ def diloco_train_loop(
inner_optimizer: optim.Optimizer = torch.optim.AdamW(
m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer: optim.Optimizer = torch.optim.SGD(
m.parameters(), lr=0.7, momentum=0.9, nesterov=True
)

# Create one outer optimizer per fragment
outer_optimizers = []
for layer in m.layers:
outer_optimizers.append(
torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
)

# pyre-ignore[53]
def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
m.load_state_dict(state_dict["model"])
m.to(device)

# Load original parameters for each fragment
for i, fragment in enumerate(diloco._fragments):
fragment.original_parameters = state_dict["original_params"][f"{i}"]
fragment.original_parameters = cast(
Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
)

for fragment in diloco._fragments:
for name in fragment.original_parameters.keys():
@@ -176,7 +183,8 @@ def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
].to(device)

inner_optimizer.load_state_dict(state_dict["inner_optim"])
outer_optimizer.load_state_dict(state_dict["outer_optim"])
for i, optimizer in enumerate(outer_optimizers):
optimizer.load_state_dict(state_dict["outer_optim"][f"{i}"])

def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
return {
@@ -186,7 +194,10 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
for i, fragment in enumerate(diloco._fragments)
},
"inner_optim": inner_optimizer.state_dict(),
"outer_optim": outer_optimizer.state_dict(),
"outer_optim": {
f"{i}": optimizer.state_dict()
for i, optimizer in enumerate(outer_optimizers)
},
}

print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
@@ -259,7 +270,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
manager,
[layer for layer in m.layers],
inner_optimizer,
outer_optimizer,
outer_optimizers,
backup_device=device,
**diloco_args,
) as diloco:
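The context manager now receives the full list of outer optimizers. Continuing the sketch from the top of this page, this is the ordering invariant the test relies on (assumed from how both lists are built from `m.layers`):

```python
# The i-th outer optimizer must correspond to the i-th fragment handed
# to DiLoCo; both lists are derived from m.layers in the same order.
fragments = [layer for layer in m.layers]
assert len(outer_optimizers) == len(fragments)
```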
@@ -305,11 +316,26 @@ def assert_equal_global_state(
rep0[step]["user"]["default"]["original_params"],
check_device=False,
)
torch.testing.assert_close(
rep1[step]["user"]["default"]["outer_optim"],
rep0[step]["user"]["default"]["outer_optim"],
check_device=False,
)
# Check all outer optimizers
for i in range(
len(
cast(
dict[str, dict[str, torch.Tensor]],
rep0[step]["user"]["default"]["outer_optim"],
).keys()
)
):
torch.testing.assert_close(
cast(
dict[str, dict[str, torch.Tensor]],
rep1[step]["user"]["default"]["outer_optim"],
)[f"{i}"],
cast(
dict[str, dict[str, torch.Tensor]],
rep0[step]["user"]["default"]["outer_optim"],
)[f"{i}"],
check_device=False,
)
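The index-based loop above could also iterate the fragment keys directly; a simplified equivalent sketch, using the same variables as the helper:

```python
# Compare each fragment's outer-optimizer state across the two replicas.
outer0 = rep0[step]["user"]["default"]["outer_optim"]
outer1 = rep1[step]["user"]["default"]["outer_optim"]
for key in outer0:
    torch.testing.assert_close(outer1[key], outer0[key], check_device=False)
```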


class LocalSGDIntegTest(TestCase):
@@ -420,18 +446,7 @@ def test_diloco_healthy(self, use_cuda: bool) -> None:
lighthouse.shutdown()

rep0, rep1 = state_dicts
for step, state_dict in rep1.items():
# inner optimizer will be different, outer optimizer and model should be the same
torch.testing.assert_close(
state_dict["user"]["default"]["model"],
rep0[step]["user"]["default"]["model"],
check_device=False,
)
torch.testing.assert_close(
state_dict["user"]["default"]["outer_optim"],
rep0[step]["user"]["default"]["outer_optim"],
check_device=False,
)
assert_equal_global_state(rep1, rep0)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipIf(sys.platform == "darwin", "not reliable on mac")