@@ -163,33 +163,63 @@ def _test_sharding(
163
163
lengths_dtype : torch .dtype = torch .int64 ,
164
164
) -> None :
165
165
self ._build_tables_and_groups (data_type = data_type )
166
- self ._run_multi_process_test (
167
- callable = sharding_single_rank_test ,
168
- world_size = world_size ,
169
- local_size = local_size ,
170
- world_size_2D = world_size_2D ,
171
- node_group_size = node_group_size ,
172
- model_class = model_class ,
173
- tables = self .tables if pooling == PoolingType .SUM else self .mean_tables ,
174
- weighted_tables = self .weighted_tables if has_weighted_tables else None ,
175
- embedding_groups = self .embedding_groups ,
176
- sharders = sharders ,
177
- backend = backend ,
178
- optim = EmbOptimType .EXACT_SGD ,
179
- constraints = constraints ,
180
- qcomms_config = qcomms_config ,
181
- variable_batch_size = variable_batch_size ,
182
- apply_optimizer_in_backward_config = apply_optimizer_in_backward_config ,
183
- variable_batch_per_feature = variable_batch_per_feature ,
184
- global_constant_batch = global_constant_batch ,
185
- use_inter_host_allreduce = use_inter_host_allreduce ,
186
- allow_zero_batch_size = allow_zero_batch_size ,
187
- custom_all_reduce = custom_all_reduce ,
188
- use_offsets = use_offsets ,
189
- indices_dtype = indices_dtype ,
190
- offsets_dtype = offsets_dtype ,
191
- lengths_dtype = lengths_dtype ,
192
- )
166
+ # directly run the test with single process
167
+ if world_size == 1 :
168
+ sharding_single_rank_test (
169
+ rank = 0 ,
170
+ world_size = world_size ,
171
+ local_size = local_size ,
172
+ world_size_2D = world_size_2D ,
173
+ node_group_size = node_group_size ,
174
+ model_class = model_class , # pyre-ignore[6]
175
+ tables = self .tables if pooling == PoolingType .SUM else self .mean_tables ,
176
+ weighted_tables = self .weighted_tables if has_weighted_tables else None ,
177
+ embedding_groups = self .embedding_groups ,
178
+ sharders = sharders ,
179
+ backend = backend ,
180
+ optim = EmbOptimType .EXACT_SGD ,
181
+ constraints = constraints ,
182
+ qcomms_config = qcomms_config ,
183
+ variable_batch_size = variable_batch_size ,
184
+ apply_optimizer_in_backward_config = apply_optimizer_in_backward_config ,
185
+ variable_batch_per_feature = variable_batch_per_feature ,
186
+ global_constant_batch = global_constant_batch ,
187
+ use_inter_host_allreduce = use_inter_host_allreduce ,
188
+ allow_zero_batch_size = allow_zero_batch_size ,
189
+ custom_all_reduce = custom_all_reduce ,
190
+ use_offsets = use_offsets ,
191
+ indices_dtype = indices_dtype ,
192
+ offsets_dtype = offsets_dtype ,
193
+ lengths_dtype = lengths_dtype ,
194
+ )
195
+ else :
196
+ self ._run_multi_process_test (
197
+ callable = sharding_single_rank_test ,
198
+ world_size = world_size ,
199
+ local_size = local_size ,
200
+ world_size_2D = world_size_2D ,
201
+ node_group_size = node_group_size ,
202
+ model_class = model_class ,
203
+ tables = self .tables if pooling == PoolingType .SUM else self .mean_tables ,
204
+ weighted_tables = self .weighted_tables if has_weighted_tables else None ,
205
+ embedding_groups = self .embedding_groups ,
206
+ sharders = sharders ,
207
+ backend = backend ,
208
+ optim = EmbOptimType .EXACT_SGD ,
209
+ constraints = constraints ,
210
+ qcomms_config = qcomms_config ,
211
+ variable_batch_size = variable_batch_size ,
212
+ apply_optimizer_in_backward_config = apply_optimizer_in_backward_config ,
213
+ variable_batch_per_feature = variable_batch_per_feature ,
214
+ global_constant_batch = global_constant_batch ,
215
+ use_inter_host_allreduce = use_inter_host_allreduce ,
216
+ allow_zero_batch_size = allow_zero_batch_size ,
217
+ custom_all_reduce = custom_all_reduce ,
218
+ use_offsets = use_offsets ,
219
+ indices_dtype = indices_dtype ,
220
+ offsets_dtype = offsets_dtype ,
221
+ lengths_dtype = lengths_dtype ,
222
+ )
193
223
194
224
def _test_dynamic_sharding (
195
225
self ,
0 commit comments