Skip to content

Commit e4397a7

Browse files
Jasper Shanfacebook-github-bot
Jasper Shan
authored andcommitted
Fix RW Support and checkpointing
Summary: Fixes a bug in RW / TWRW checkpointing where the buffers being saved wouldn't represent all of the shards that existed during training. Reviewed By: Nayef211 Differential Revision: D73144116
1 parent 6aaf1fa commit e4397a7

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

torchrec/distributed/itep_embeddingbag.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(
123123
lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP],
124124
pruning_interval=module._itep_module.pruning_interval,
125125
enable_pruning=module._itep_module.enable_pruning,
126+
pg=env.process_group,
126127
)
127128
self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule(
128129
table_name_to_sharding_type=self.table_name_to_sharding_type,
@@ -132,6 +133,7 @@ def __init__(
132133
lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP],
133134
pruning_interval=module._itep_module.pruning_interval,
134135
enable_pruning=module._itep_module.enable_pruning,
136+
pg=env.process_group,
135137
)
136138

137139
def prefetch(
@@ -389,6 +391,7 @@ def __init__(
389391
lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP],
390392
pruning_interval=module._itep_module.pruning_interval,
391393
enable_pruning=module._itep_module.enable_pruning,
394+
pg=env.process_group,
392395
)
393396
self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule(
394397
table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
@@ -398,6 +401,7 @@ def __init__(
398401
pruning_interval=module._itep_module.pruning_interval,
399402
table_name_to_sharding_type=self.table_name_to_sharding_type,
400403
enable_pruning=module._itep_module.enable_pruning,
404+
pg=env.process_group,
401405
)
402406

403407
# pyre-ignore

torchrec/modules/itep_modules.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,10 @@ def _get_local_unpruned_hash_sizes_and_offsets(
548548
shard_offsets: List[int] = [0]
549549

550550
for rank in range(num_devices):
551-
if sharding_type == ShardingType.ROW_WISE.value:
551+
if (
552+
sharding_type == ShardingType.ROW_WISE.value
553+
or sharding_type == ShardingType.TABLE_ROW_WISE.value
554+
):
552555
if rank < last_rank:
553556
local_row: int = block_size
554557
elif rank == last_rank:

0 commit comments

Comments
 (0)