Skip to content

Commit 10f1c7d

Browse files
TroyGarden authored and facebook-github-bot committed
PMT (#3023)
Summary: Pull Request resolved: #3023 # context * `_test_sharding` is frequently used test function covering many TorchRec sharding test cases * the multiprocess env often introduces additional difficulties when debugging, espeically for kernel-size issues (the multiprocess env is not actually needed) * this change make it run on the main process when the `world_size==1` so that a simple `breakpoint()` can just work. Reviewed By: iamzainhuda Differential Revision: D74131796 fbshipit-source-id: ccc34ab589c0153cc0ce1187bba3df7dd63cbfc6
1 parent 151aa02 commit 10f1c7d

File tree

1 file changed

+57
-27
lines changed

1 file changed

+57
-27
lines changed

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -163,33 +163,63 @@ def _test_sharding(
163163
lengths_dtype: torch.dtype = torch.int64,
164164
) -> None:
165165
self._build_tables_and_groups(data_type=data_type)
166-
self._run_multi_process_test(
167-
callable=sharding_single_rank_test,
168-
world_size=world_size,
169-
local_size=local_size,
170-
world_size_2D=world_size_2D,
171-
node_group_size=node_group_size,
172-
model_class=model_class,
173-
tables=self.tables if pooling == PoolingType.SUM else self.mean_tables,
174-
weighted_tables=self.weighted_tables if has_weighted_tables else None,
175-
embedding_groups=self.embedding_groups,
176-
sharders=sharders,
177-
backend=backend,
178-
optim=EmbOptimType.EXACT_SGD,
179-
constraints=constraints,
180-
qcomms_config=qcomms_config,
181-
variable_batch_size=variable_batch_size,
182-
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
183-
variable_batch_per_feature=variable_batch_per_feature,
184-
global_constant_batch=global_constant_batch,
185-
use_inter_host_allreduce=use_inter_host_allreduce,
186-
allow_zero_batch_size=allow_zero_batch_size,
187-
custom_all_reduce=custom_all_reduce,
188-
use_offsets=use_offsets,
189-
indices_dtype=indices_dtype,
190-
offsets_dtype=offsets_dtype,
191-
lengths_dtype=lengths_dtype,
192-
)
166+
# directly run the test with single process
167+
if world_size == 1:
168+
sharding_single_rank_test(
169+
rank=0,
170+
world_size=world_size,
171+
local_size=local_size,
172+
world_size_2D=world_size_2D,
173+
node_group_size=node_group_size,
174+
model_class=model_class, # pyre-ignore[6]
175+
tables=self.tables if pooling == PoolingType.SUM else self.mean_tables,
176+
weighted_tables=self.weighted_tables if has_weighted_tables else None,
177+
embedding_groups=self.embedding_groups,
178+
sharders=sharders,
179+
backend=backend,
180+
optim=EmbOptimType.EXACT_SGD,
181+
constraints=constraints,
182+
qcomms_config=qcomms_config,
183+
variable_batch_size=variable_batch_size,
184+
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
185+
variable_batch_per_feature=variable_batch_per_feature,
186+
global_constant_batch=global_constant_batch,
187+
use_inter_host_allreduce=use_inter_host_allreduce,
188+
allow_zero_batch_size=allow_zero_batch_size,
189+
custom_all_reduce=custom_all_reduce,
190+
use_offsets=use_offsets,
191+
indices_dtype=indices_dtype,
192+
offsets_dtype=offsets_dtype,
193+
lengths_dtype=lengths_dtype,
194+
)
195+
else:
196+
self._run_multi_process_test(
197+
callable=sharding_single_rank_test,
198+
world_size=world_size,
199+
local_size=local_size,
200+
world_size_2D=world_size_2D,
201+
node_group_size=node_group_size,
202+
model_class=model_class,
203+
tables=self.tables if pooling == PoolingType.SUM else self.mean_tables,
204+
weighted_tables=self.weighted_tables if has_weighted_tables else None,
205+
embedding_groups=self.embedding_groups,
206+
sharders=sharders,
207+
backend=backend,
208+
optim=EmbOptimType.EXACT_SGD,
209+
constraints=constraints,
210+
qcomms_config=qcomms_config,
211+
variable_batch_size=variable_batch_size,
212+
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
213+
variable_batch_per_feature=variable_batch_per_feature,
214+
global_constant_batch=global_constant_batch,
215+
use_inter_host_allreduce=use_inter_host_allreduce,
216+
allow_zero_batch_size=allow_zero_batch_size,
217+
custom_all_reduce=custom_all_reduce,
218+
use_offsets=use_offsets,
219+
indices_dtype=indices_dtype,
220+
offsets_dtype=offsets_dtype,
221+
lengths_dtype=lengths_dtype,
222+
)
193223

194224
def _test_dynamic_sharding(
195225
self,

0 commit comments

Comments (0)