
Commit 892d66d

aporialiao authored and facebook-github-bot committed
Add unsharded module reference to sharded modules
Summary: Adds a simple unsharded module reference to sharded modules. This will be used in Dynamic Sharding by `DistributedModelParallel` (DMP) to reshard an already-sharded module. Since DMP was designed with only a one-way (unsharded to sharded) relationship in mind, exposing the unsharded module type helps determine which sharder to use when resharding.

Differential Revision: D73407830
1 parent 9eaec09 commit 892d66d

9 files changed (+36, −0 lines)
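
For illustration, a minimal sketch (not part of this diff) of how resharding logic in `DistributedModelParallel` might use the new property to pick a sharder for an already-sharded module. The `sharder_for_resharding` function and `sharder_map` argument below are hypothetical, not TorchRec API:

# Hypothetical sketch: map each unsharded module type to its sharder, then
# use `unsharded_module_type` to look up the sharder for a module that is
# already sharded. Names here are illustrative, not part of this commit.
from typing import Dict, Type

import torch.nn as nn

from torchrec.distributed.types import ModuleSharder, ShardedModule


def sharder_for_resharding(
    sharded_module: ShardedModule,
    sharder_map: Dict[Type[nn.Module], ModuleSharder[nn.Module]],
) -> ModuleSharder[nn.Module]:
    # The sharded module reports the unsharded type it was created from,
    # so resharding no longer needs a live reference to the original module.
    return sharder_map[sharded_module.unsharded_module_type]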

torchrec/distributed/embedding.py

Lines changed: 4 additions & 0 deletions
@@ -1409,6 +1409,10 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
 
+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingCollection]:
+        return EmbeddingCollection
+
     def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])

torchrec/distributed/embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -1598,6 +1598,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
     def extend_shard_name(shard_name: str) -> str:
         return f"embedding_bags.{shard_name}.weight"
 
+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingBagCollection]:
+        return EmbeddingBagCollection
+
 
 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """

torchrec/distributed/fp_embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -161,6 +161,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
             if "_embedding_bag_collection" in fqn:
                 yield append_prefix(prefix, fqn)
 
+    @property
+    def unsharded_module_type(self) -> Type[FeatureProcessedEmbeddingBagCollection]:
+        return FeatureProcessedEmbeddingBagCollection
+
 
 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]

torchrec/distributed/fused_embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -85,6 +85,10 @@ def __init__(
         # We need to ensure that a checkpoint from DDP and a checkpoint from a
         # model parallel version are compatible.
 
+    @property
+    def unsharded_module_type(self) -> Type[FusedEmbeddingBagCollection]:
+        return FusedEmbeddingBagCollection
+
 
 class FusedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FusedEmbeddingBagCollection]

torchrec/distributed/mc_embedding.py

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,10 @@ def create_context(
     ) -> ManagedCollisionEmbeddingCollectionContext:
         return ManagedCollisionEmbeddingCollectionContext(sharding_contexts=[])
 
+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingCollection]:
+        return ManagedCollisionEmbeddingCollection
+
 
 class ManagedCollisionEmbeddingCollectionSharder(
     BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection]

torchrec/distributed/mc_embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -82,6 +82,10 @@ def create_context(
     ) -> ManagedCollisionEmbeddingBagCollectionContext:
         return ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts=[])
 
+    @property
+    def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
+        return ManagedCollisionEmbeddingBagCollection
+
 
 class ManagedCollisionEmbeddingBagCollectionSharder(
     BaseManagedCollisionEmbeddingCollectionSharder[

torchrec/distributed/quant_embedding.py

Lines changed: 4 additions & 0 deletions
@@ -1320,6 +1320,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for fqn, _ in self.named_buffers():
             yield append_prefix(prefix, fqn)
 
+    @property
+    def unsharded_module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]:
+        return QuantManagedCollisionEmbeddingCollection
+
 
 class QuantManagedCollisionEmbeddingCollectionSharder(
     BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection]

torchrec/distributed/quant_embeddingbag.py

Lines changed: 4 additions & 0 deletions
@@ -383,6 +383,10 @@ def create_context(self) -> NullShardedModuleContext:
 
         return NullShardedModuleContext()
 
+    @property
+    def unsharded_module_type(self) -> Type[QuantEmbeddingBagCollection]:
+        return QuantEmbeddingBagCollection
+
 
 class QuantEmbeddingBagCollectionSharder(
     BaseQuantEmbeddingSharder[QuantEmbeddingBagCollection]

torchrec/distributed/types.py

Lines changed: 4 additions & 0 deletions
@@ -1034,6 +1034,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key
 
+    @property
+    @abc.abstractmethod
+    def unsharded_module_type(self) -> Type[nn.Module]: ...
+
 
 def get_tensor_size_bytes(t: torch.Tensor) -> int:
     b: int = t.numel() * t.element_size()
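
Note that `unsharded_module_type` is added to the `ShardedModule` base class in types.py as an abstract property, so every concrete sharded module must implement it, as the eight overrides above do. A minimal sketch of the override pattern for a hypothetical custom module (other `ShardedModule` abstract methods elided):

# Hypothetical module pair; only the new property is shown.
from typing import Type

import torch.nn as nn


class MyEmbeddingModule(nn.Module):
    """Stand-in for an unsharded module."""


class MyShardedEmbeddingModule(nn.Module):  # would derive from ShardedModule
    @property
    def unsharded_module_type(self) -> Type[nn.Module]:
        # Report the unsharded type this sharded module was created from.
        return MyEmbeddingModule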
