Skip to content

Commit 78a62de

Browse files
authored
feat: update grad_sync_done (kleveross#72)
1 parent 39d3133 commit 78a62de

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

ftlib/commlib/pytorch/impl.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -24,18 +24,28 @@ def __init__(
2424

2525
@BasicCommLib.register_api
2626
def grad_sync_done(self, *args, **kwargs):
27-
model = None
28-
if "model" in kwargs.keys():
29-
model = kwargs["model"]
30-
elif len(args) > 0:
31-
model = args[0]
32-
if model is None:
27+
# arguments must be passed by keyword: either "model" or "params"
28+
if "model" not in kwargs.keys() and "params" not in kwargs.keys():
3329
return CommLibStatus.FAIL
30+
31+
params = (
32+
kwargs["model"].parameters()
33+
if "model" in kwargs.keys()
34+
else kwargs["params"]
35+
)
36+
37+
get_data = (
38+
(lambda x: x.grad.data)
39+
if "model" in kwargs.keys()
40+
else (lambda x: torch.from_numpy(x))
41+
)
42+
3443
try:
3544
size = float(dist.get_world_size())
36-
for param in model.parameters():
37-
dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
38-
param.grad.data /= size
45+
for param in params:
46+
data = get_data(param)
47+
dist.all_reduce(data, op=dist.reduce_op.SUM)
48+
data /= size
3949
except Exception as e:
4050
logging.error(str(e))
4151
return CommLibStatus.FAIL

0 commit comments

Comments
 (0)