1 file changed: +19 -9 lines changed
Original file line number | Diff line number | Diff line change
@@ -24,18 +24,28 @@ def __init__(
2424
2525 @BasicCommLib .register_api
2626 def grad_sync_done (self , * args , ** kwargs ):
27- model = None
28- if "model" in kwargs .keys ():
29- model = kwargs ["model" ]
30- elif len (args ) > 0 :
31- model = args [0 ]
32- if model is None :
27+ # it is required to pass args with keys
28+ if "model" not in kwargs .keys () and "params" not in kwargs .keys ():
3329 return CommLibStatus .FAIL
30+
31+ params = (
32+ kwargs ["model" ].parameters ()
33+ if "model" in kwargs .keys ()
34+ else kwargs ["params" ]
35+ )
36+
37+ get_data = (
38+ (lambda x : x .grad .data )
39+ if "model" in kwargs .keys ()
40+ else (lambda x : torch .from_numpy (x ))
41+ )
42+
3443 try :
3544 size = float (dist .get_world_size ())
36- for param in model .parameters ():
37- dist .all_reduce (param .grad .data , op = dist .reduce_op .SUM )
38- param .grad .data /= size
45+ for param in params :
46+ data = get_data (param )
47+ dist .all_reduce (data , op = dist .reduce_op .SUM )
48+ data /= size
3949 except Exception as e :
4050 logging .error (str (e ))
4151 return CommLibStatus .FAIL
You can’t perform that action at this time.
0 commit comments