Skip to content

Commit 50b846b

Browse files
committed
Fix import error in manager.py and switch train_fsdp.py to sync quorum mode
1 parent 7cc11f2 commit 50b846b

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

torchft/manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
36 36   from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
37 37
38 38   import torch
   39 + import torch.distributed as dist
39 40   from torch.distributed import ReduceOp, TCPStore
40 41
41 42   from torchft.checkpointing import CheckpointServer

train_fsdp.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def state_dict():
126 126   load_state_dict=load_state_dict,
127 127   state_dict=state_dict,
128 128   replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
    129 + use_async_quorum=False,
129 130   )
130 131
131 132   mesh = hsdp_device_mesh(NUM_REPLICA_GROUPS, NUM_REPLICAS, "cuda" if torch.cuda.is_available() else "cpu", manager=manager)
@@ -136,8 +137,6 @@ def state_dict():
136 137
137 138   optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))
138 139
139     - optimizer.zero_grad()
140     -
141 140   while manager.current_step() < 500:
142 141   model.train()
143 142   for batch in tqdm(train_dataloader):

0 commit comments

Comments
 (0)