@@ -156,21 +156,22 @@ def _test_distrib_all_reduce_group(device):

 def _test_distrib_all_gather(device):
     rank = idist.get_rank()
+    ws = idist.get_world_size()

     res = torch.tensor(idist.all_gather(10), device=device)
-    true_res = torch.tensor([10] * idist.get_world_size(), device=device)
+    true_res = torch.tensor([10] * ws, device=device)
     assert (res == true_res).all()

     t = torch.tensor(rank, device=device)
     res = idist.all_gather(t)
-    true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
+    true_res = torch.tensor([i for i in range(ws)], device=device)
     assert (res == true_res).all()

     x = "test-test"
     if rank == 0:
         x = "abc"
     res = idist.all_gather(x)
-    true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + ["test-test"] * (ws - 1)
     assert res == true_res

     base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
@@ -179,27 +180,46 @@ def _test_distrib_all_gather(device):
         x = "abc"

     res = idist.all_gather(x)
-    true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
+    true_res = ["abc"] + [base_x] * (ws - 1)
     assert res == true_res

     t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
     in_dtype = t.dtype
     res = idist.all_gather(t)
-    assert res.shape == (idist.get_world_size() * 4, 25)
+    assert res.shape == (ws * 4, 25)
     assert res.dtype == in_dtype
-    true_res = torch.zeros(idist.get_world_size() * 4, 25, device=device)
-    for i in range(idist.get_world_size()):
+    true_res = torch.zeros(ws * 4, 25, device=device)
+    for i in range(ws):
         true_res[i * 4 : (i + 1) * 4, ...] = torch.arange(100, device=device).reshape(4, 25) * (i + 1)
     assert (res == true_res).all()

-    if idist.get_world_size() > 1:
-        with pytest.raises(TypeError, match=r"Unhandled input type"):
-            idist.all_reduce([0, 1, 2])
+    if ws > 1 and idist.backend() != "xla-tpu":
+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        res = idist.all_gather(t)
+        assert isinstance(res, list) and len(res) == ws
+        for i, obj in enumerate(res):
+            assert isinstance(obj, dict)
+            assert list(obj.keys()) == ["a", "b", "c"], obj
+            expected_device = (
+                device if torch.device(device).type == "cpu" else torch.device(f"{torch.device(device).type}:{i}")
+            )
+            expected = {
+                "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+            }
+            assert obj["a"] == expected["a"]
+            assert (obj["b"] == expected["b"]).all()
+            assert obj["c"] == expected["c"]


 def _test_distrib_all_gather_group(device):
     if idist.get_world_size() > 1:
-        ranks = [0, 1]
+        ranks = list(range(idist.get_world_size() - 1, 0, -1))  # [0, 1, 2, 3] -> [3, 2, 1]
         rank = idist.get_rank()
         bnd = idist.backend()
@@ -226,6 +246,40 @@ def _test_distrib_all_gather_group(device):
         else:
             assert res == t

+        t = {
+            "a": [rank + 1, rank + 2, torch.tensor(rank + 3, device=device)],
+            "b": torch.tensor([[rank + 1, rank + 2, rank + 3]], device=device),
+            "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8, device=device)},
+        }
+        if bnd in ("xla-tpu",):
+            with pytest.raises(NotImplementedError, match=r"all_gather on object is not implemented for xla"):
+                res = idist.all_gather(t, group=ranks)
+        elif bnd in ("horovod",):
+            with pytest.raises(NotImplementedError, match=r"all_gather with group for horovod is not implemented"):
+                res = idist.all_gather(t, group=ranks)
+        else:
+            res = idist.all_gather(t, group=ranks)
+            if rank in ranks:
+                assert isinstance(res, list) and len(res) == len(ranks)
+                for i, obj in zip(ranks, res):
+                    assert isinstance(obj, dict)
+                    assert list(obj.keys()) == ["a", "b", "c"], obj
+                    expected_device = (
+                        device
+                        if torch.device(device).type == "cpu"
+                        else torch.device(f"{torch.device(device).type}:{i}")
+                    )
+                    expected = {
+                        "a": [i + 1, i + 2, torch.tensor(i + 3, device=expected_device)],
+                        "b": torch.tensor([[i + 1, i + 2, i + 3]], device=expected_device),
+                        "c": {"abcd": i, "cdfg": torch.tensor(i, dtype=torch.uint8, device=expected_device)},
+                    }
+                    assert obj["a"] == expected["a"], (obj, expected)
+                    assert (obj["b"] == expected["b"]).all(), (obj, expected)
+                    assert obj["c"] == expected["c"], (obj, expected)
+            else:
+                assert res == t
+
         if bnd in ("nccl", "gloo", "mpi"):
             with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
                 res = idist.all_gather(t, group="abc")
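
Note: the object all_gather behavior these new assertions exercise can be reproduced with the stock torch.distributed primitive all_gather_object. The sketch below is illustrative only, not part of this commit: it assumes a 2-process CPU gloo group, an arbitrary free MASTER_PORT, and a hypothetical _worker entry point, with the same payload shape as the test.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size):
    # Minimal process-group setup; 29500 is just an assumed free port.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Same payload shape as the test: list, tensor and nested dict per rank.
    payload = {
        "a": [rank + 1, rank + 2, torch.tensor(rank + 3)],
        "b": torch.tensor([[rank + 1, rank + 2, rank + 3]]),
        "c": {"abcd": rank, "cdfg": torch.tensor(rank, dtype=torch.uint8)},
    }
    gathered = [None] * world_size
    dist.all_gather_object(gathered, payload)  # pickles arbitrary objects

    # Every rank now holds one dict per rank, in rank order.
    assert [obj["c"]["abcd"] for obj in gathered] == list(range(world_size))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)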
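The group-restricted variant can be sketched the same way with dist.new_group plus the group argument of all_gather_object. Again a sketch under the same assumptions (CPU gloo, hypothetical _worker), not the commit's implementation: only group members enter the collective, which mirrors the test's expectation that non-member ranks get their input back unchanged.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Reverse-ordered group excluding rank 0, as in the test: ws=3 -> [2, 1].
    ranks = list(range(world_size - 1, 0, -1))
    group = dist.new_group(ranks)  # every process must call new_group

    if rank in ranks:
        gathered = [None] * len(ranks)
        dist.all_gather_object(gathered, {"rank": rank}, group=group)
        assert sorted(obj["rank"] for obj in gathered) == sorted(ranks)
    # Non-member ranks skip the collective here; idist.all_gather instead
    # returns their input unchanged, which is what the test asserts.
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(3,), nprocs=3)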