diff --git a/torchrec/distributed/itep_embeddingbag.py b/torchrec/distributed/itep_embeddingbag.py index 018538020..7250077a4 100644 --- a/torchrec/distributed/itep_embeddingbag.py +++ b/torchrec/distributed/itep_embeddingbag.py @@ -123,6 +123,7 @@ def __init__( lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP], pruning_interval=module._itep_module.pruning_interval, enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, ) self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule( table_name_to_sharding_type=self.table_name_to_sharding_type, @@ -132,6 +133,7 @@ def __init__( lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP], pruning_interval=module._itep_module.pruning_interval, enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, ) def prefetch( @@ -389,6 +391,7 @@ def __init__( lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP], pruning_interval=module._itep_module.pruning_interval, enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, ) self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule( table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[ @@ -398,6 +401,7 @@ def __init__( pruning_interval=module._itep_module.pruning_interval, table_name_to_sharding_type=self.table_name_to_sharding_type, enable_pruning=module._itep_module.enable_pruning, + pg=env.process_group, ) # pyre-ignore diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py index 95ae6c849..8ffd4dbc0 100644 --- a/torchrec/modules/itep_modules.py +++ b/torchrec/modules/itep_modules.py @@ -548,7 +548,10 @@ def _get_local_unpruned_hash_sizes_and_offsets( shard_offsets: List[int] = [0] for rank in range(num_devices): - if sharding_type == ShardingType.ROW_WISE.value: + if ( + sharding_type == ShardingType.ROW_WISE.value + or sharding_type == ShardingType.TABLE_ROW_WISE.value + ): if rank < last_rank: local_row: int = block_size elif rank == last_rank: