refactor diloco test #232

Merged
merged 1 commit on Jul 15, 2025
torchft/local_sgd_integ_test.py — 184 changes: 5 additions & 179 deletions
@@ -33,38 +33,11 @@
ProcessGroupBabyNCCL,
ProcessGroupGloo,
)
from torchft.test.diloco_trainer import DiLoCoTrainer, MultiMyModel

logger: logging.Logger = logging.getLogger(__name__)


class MultiMyModel(torch.nn.Module):
def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
super().__init__()
self.in_dim = in_dim

self.layers = torch.nn.ModuleList()
for i in range(n_layers):
self.layers.append(MyModel(in_dim, out_dim))
in_dim, out_dim = out_dim, in_dim

self.out_dim = in_dim

def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = layer(x)
return x

def get_rand_inputs(
self, batch_size: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
return torch.rand(batch_size, self.in_dim, device=device)

def get_rand_labels(
self, batch_size: int, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
return torch.randint(self.out_dim, (batch_size,), device=device)


def local_sgd_train_loop(
rank: int,
store_port: int,
@@ -148,158 +121,11 @@ def diloco_train_loop(
diloco_args = train_loop_args.get("diloco_args", {})

with ExitStack() as stack:
# Declare the model and optimizers
m = MultiMyModel(2, 3, n_fragments)
m.load_state_dict(model_state_dict)
m.to(device)

# Setup optimizers
inner_optimizer: optim.Optimizer = torch.optim.AdamW(
m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
trainer = DiLoCoTrainer(
rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
)

# Create one outer optimizer per fragment
outer_optimizers = []
for _, layer in enumerate(m.layers):
outer_optimizers.append(
torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
)

# pyre-ignore[53]
def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
m.load_state_dict(state_dict["model"])
m.to(device)

# Load original parameters for each fragment
for i, fragment in enumerate(diloco._fragments):
fragment.original_parameters = cast(
Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
)

for fragment in diloco._fragments:
for name in fragment.original_parameters.keys():
fragment.original_parameters[name] = fragment.original_parameters[
name
].to(device)

inner_optimizer.load_state_dict(state_dict["inner_optim"])
for i, optimizer in enumerate(outer_optimizers):
optimizer.load_state_dict(state_dict[f"outer_optim"][f"{i}"])

def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
return {
"model": m.state_dict(),
"original_params": {
f"{i}": fragment.original_parameters
for i, fragment in enumerate(diloco._fragments)
},
"inner_optim": inner_optimizer.state_dict(),
"outer_optim": {
f"{i}": optimizer.state_dict()
for i, optimizer in enumerate(outer_optimizers)
},
}

print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

if device.type == "cuda":
pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
else:
pg = FakeProcessGroupWrapper(
ProcessGroupGloo(timeout=timedelta(seconds=10))
)
manager = Manager(
pg=pg,
min_replica_size=2,
use_async_quorum=False,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=str(runner.replica_id),
store_addr="localhost",
store_port=store_port,
rank=rank,
world_size=runner.world_size,
lighthouse_addr=runner.lighthouse_address,
port=19530 + runner.replica_id,
connect_timeout=timedelta(seconds=10),
quorum_timeout=timedelta(seconds=10),
timeout=timedelta(seconds=10),
# pyre-fixme[6]: Incompatible parameter type
**runner.manager_args,
)
runner.event_injector.set_pg(pg)
stack.callback(manager.shutdown)
# initialize default group for device mesh to work
if not torch.distributed.is_initialized():
# TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
try:
torch.distributed.init_process_group(
init_method="tcp://localhost:0",
rank=rank,
world_size=runner.world_size,
)
except ValueError:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "0"
os.environ["WORLD_SIZE"] = str(runner.world_size)
os.environ["RANK"] = str(rank)

device_type = device.type
ft_device_mesh = ft_init_device_mesh(
device_type=device_type,
mesh_shape=(runner.world_size, 1),
mesh_dim_names=("replicate", "none"),
replicate_dim=0,
manager=manager,
)
for layer in m.layers:
if isinstance(layer, nn.Linear):
for param in layer.parameters():
param = DTensor.from_local(
param,
device_mesh=ft_device_mesh,
)

criterion = nn.CrossEntropyLoss()
all_state_dicts = {}

if "sync_every" not in diloco_args:
diloco_args["sync_every"] = 2

with DiLoCo(
manager,
[layer for layer in m.layers],
inner_optimizer,
outer_optimizers,
backup_device=device,
**diloco_args,
) as diloco:
while True:
runner.event_injector.check(rank, manager.current_step())

manager_curr_step = manager.current_step()
if manager_curr_step not in all_state_dicts:
all_state_dicts[manager_curr_step] = copy.deepcopy(
manager._manager_state_dict()
)

batch_size = 1
inputs = m.get_rand_inputs(batch_size, device=device)
labels = m.get_rand_labels(batch_size, device=device)

out = m(inputs)
loss = criterion(out, labels)

inner_optimizer.zero_grad()
loss.backward()
inner_optimizer.step()

# after 4 model updates then break
if manager.current_step() >= 4:
break

# return state_dict so we can check consistency
return all_state_dicts
stack.callback(trainer.manager.shutdown)
return trainer.train_loop()
return {}
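
For context, the `DiLoCoTrainer` helper that the test now delegates to lives in `torchft/test/diloco_trainer.py` and is not part of this diff. Below is a minimal sketch of its likely shape, reconstructed from the deleted inline code: the constructor signature, the `manager` attribute, and `train_loop()` come from the new call sites above, while everything else (method names, the simplifications noted in the comments) is an assumption rather than the actual implementation.

```python
# Hypothetical sketch of the DiLoCoTrainer helper this PR factors the test body into.
# The real implementation is in torchft/test/diloco_trainer.py, which is not shown in
# this diff; everything below is reconstructed from the removed inline code. CUDA/NCCL,
# the FakeProcessGroupWrapper fault injection, per-fragment `original_parameters`
# checkpointing, and the device-mesh/DTensor setup are omitted for brevity.
import copy
from datetime import timedelta
from typing import Any, Dict

from torch import nn, optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager
from torchft.process_group import ProcessGroupGloo


class DiLoCoTrainer:
    """Bundles the model, optimizers, Manager, and DiLoCo loop that the test set up inline before this PR."""

    def __init__(self, rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args):
        self.device = device
        self.diloco_args = dict(diloco_args)
        self.diloco_args.setdefault("sync_every", 2)

        # MultiMyModel lives next to this class in the real file (see the first hunk above).
        self.model = MultiMyModel(2, 3, n_fragments)
        self.model.load_state_dict(model_state_dict)
        self.model.to(device)

        # One inner optimizer for the whole model, one outer optimizer per fragment.
        self.inner_optimizer = optim.AdamW(
            self.model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )
        self.outer_optimizers = [
            optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
            for layer in self.model.layers
        ]

        self.manager = Manager(
            pg=ProcessGroupGloo(timeout=timedelta(seconds=10)),
            min_replica_size=2,
            use_async_quorum=False,
            load_state_dict=self.load_state_dict,
            state_dict=self.state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            timeout=timedelta(seconds=10),
            **runner.manager_args,
        )

    def state_dict(self) -> Dict[str, Any]:
        return {
            "model": self.model.state_dict(),
            "inner_optim": self.inner_optimizer.state_dict(),
            "outer_optim": {f"{i}": o.state_dict() for i, o in enumerate(self.outer_optimizers)},
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.model.load_state_dict(state_dict["model"])
        self.model.to(self.device)
        self.inner_optimizer.load_state_dict(state_dict["inner_optim"])
        for i, o in enumerate(self.outer_optimizers):
            o.load_state_dict(state_dict["outer_optim"][f"{i}"])

    def train_loop(self) -> Dict[int, Dict[str, Any]]:
        criterion = nn.CrossEntropyLoss()
        all_state_dicts: Dict[int, Dict[str, Any]] = {}
        with DiLoCo(
            self.manager,
            list(self.model.layers),
            self.inner_optimizer,
            self.outer_optimizers,
            backup_device=self.device,
            **self.diloco_args,
        ):
            while True:
                # Snapshot the manager state at each step so the test can check consistency.
                step = self.manager.current_step()
                if step not in all_state_dicts:
                    all_state_dicts[step] = copy.deepcopy(self.manager._manager_state_dict())

                inputs = self.model.get_rand_inputs(1, device=self.device)
                labels = self.model.get_rand_labels(1, device=self.device)
                loss = criterion(self.model(inputs), labels)

                self.inner_optimizer.zero_grad()
                loss.backward()
                self.inner_optimizer.step()

                # Stop after four model updates, as the original inline loop did.
                if self.manager.current_step() >= 4:
                    break
        return all_state_dicts
```

Centralizing this setup means the integration test only needs to construct a `DiLoCoTrainer`, register `trainer.manager.shutdown` for cleanup, and call `trainer.train_loop()`, which is exactly what the new hunk above does.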

