From 103f39689a2015f6b6afb1a3ad51f0b4ae33cdfc Mon Sep 17 00:00:00 2001 From: Felicity Liao <11263993+aporialiao@users.noreply.github.com> Date: Wed, 23 Apr 2025 11:55:00 -0700 Subject: [PATCH] Add unsharded module property to sharded modules (#2901) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2901 Adding a simple unsharded module reference to sharded modules. This will be used in Dynamic Sharding by `DistributedModelParallel` to reshard an already-sharded_module. As DMP is created with only one-way relationship in mind, accessing the unsharded module type will help determine which sharder to use in 'resharding'. Most of the changes here are simply to add in the property in where ShardedModule or it's wrapper ShardedEmbeddingModule is used. Differential Revision: D73407830 --- torchrec/distributed/embedding.py | 4 ++++ .../distributed/embedding_tower_sharding.py | 8 +++++++ torchrec/distributed/embedding_types.py | 23 ++++++++++++++++++- torchrec/distributed/embeddingbag.py | 8 +++++++ torchrec/distributed/fp_embeddingbag.py | 4 ++++ torchrec/distributed/fused_embeddingbag.py | 4 ++++ torchrec/distributed/itep_embeddingbag.py | 8 +++++++ torchrec/distributed/mc_embedding.py | 4 ++++ torchrec/distributed/mc_embeddingbag.py | 4 ++++ torchrec/distributed/quant_embedding.py | 4 ++++ torchrec/distributed/quant_embeddingbag.py | 4 ++++ torchrec/distributed/quant_state.py | 14 ++++++++++- .../distributed/tests/test_embedding_types.py | 8 ++++++- torchrec/distributed/types.py | 13 +++++++++++ 14 files changed, 107 insertions(+), 3 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 2f006758e..a9c4e5c9d 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -1409,6 +1409,10 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int: def fused_optimizer(self) -> KeyedOptimizer: return self._optim + @property + def unsharded_module_type(self) -> Type[EmbeddingCollection]: + return EmbeddingCollection + def create_context(self) -> EmbeddingCollectionContext: return EmbeddingCollectionContext(sharding_contexts=[]) diff --git a/torchrec/distributed/embedding_tower_sharding.py b/torchrec/distributed/embedding_tower_sharding.py index 2fdca6b5c..44a58b916 100644 --- a/torchrec/distributed/embedding_tower_sharding.py +++ b/torchrec/distributed/embedding_tower_sharding.py @@ -438,6 +438,10 @@ def named_modules( def create_context(self) -> NullShardedModuleContext: return NullShardedModuleContext() + @property + def unsharded_module_type(self) -> Type[EmbeddingTower]: + return EmbeddingTower + class ShardedEmbeddingTowerCollection( ShardedEmbeddingModule[ @@ -941,6 +945,10 @@ def embedding_feature_names( kjt_features.extend(config.feature_names) return kjt_features, wkjt_features + @property + def unsharded_module_type(self) -> Type[EmbeddingTowerCollection]: + return EmbeddingTowerCollection + class EmbeddingTowerCollectionSharder(BaseEmbeddingSharder[EmbeddingTowerCollection]): def __init__( diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 3b2b127a0..b66f92f16 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -11,7 +11,18 @@ import copy from dataclasses import dataclass from enum import Enum, unique -from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Generic, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import torch from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation @@ -399,6 +410,16 @@ def train(self, mode: bool = True): # pyre-ignore[3] return self + @property + def unsharded_module_type(self) -> Type[nn.Module]: + """ + As this is the generic ShardedEmbeddingModule class, simply + return the generic nn.Module type. In the inherited classes of + ShardedEmbeddingModule, the specific unsharded module type will + be returned. + """ + return nn.Module + M = TypeVar("M", bound=nn.Module) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 4cb1d62c2..1a4263ea9 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -1598,6 +1598,10 @@ def create_context(self) -> EmbeddingBagCollectionContext: def extend_shard_name(shard_name: str) -> str: return f"embedding_bags.{shard_name}.weight" + @property + def unsharded_module_type(self) -> Type[EmbeddingBagCollection]: + return EmbeddingBagCollection + class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]): """ @@ -1887,6 +1891,10 @@ def fused_optimizer(self) -> KeyedOptimizer: def create_context(self) -> NullShardedModuleContext: return NullShardedModuleContext() + @property + def unsharded_module_type(self) -> Type[nn.EmbeddingBag]: + return nn.EmbeddingBag + class EmbeddingBagSharder(BaseEmbeddingSharder[nn.EmbeddingBag]): """ diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py index 872fa6aa6..9e209cbc5 100644 --- a/torchrec/distributed/fp_embeddingbag.py +++ b/torchrec/distributed/fp_embeddingbag.py @@ -161,6 +161,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: if "_embedding_bag_collection" in fqn: yield append_prefix(prefix, fqn) + @property + def unsharded_module_type(self) -> Type[FeatureProcessedEmbeddingBagCollection]: + return FeatureProcessedEmbeddingBagCollection + class FeatureProcessedEmbeddingBagCollectionSharder( BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection] diff --git a/torchrec/distributed/fused_embeddingbag.py b/torchrec/distributed/fused_embeddingbag.py index 43eeda323..067670668 100644 --- a/torchrec/distributed/fused_embeddingbag.py +++ b/torchrec/distributed/fused_embeddingbag.py @@ -85,6 +85,10 @@ def __init__( # We need to ensure that a checkpoint from DDP and a checkpoint from a # model parallel version are compatible. + @property + def unsharded_module_type(self) -> Type[FusedEmbeddingBagCollection]: + return FusedEmbeddingBagCollection + class FusedEmbeddingBagCollectionSharder( BaseEmbeddingSharder[FusedEmbeddingBagCollection] diff --git a/torchrec/distributed/itep_embeddingbag.py b/torchrec/distributed/itep_embeddingbag.py index 7250077a4..4fb1350da 100644 --- a/torchrec/distributed/itep_embeddingbag.py +++ b/torchrec/distributed/itep_embeddingbag.py @@ -274,6 +274,10 @@ def _group_lookups_and_table_unpruned_size_map( return grouped_lookups, grouped_table_unpruned_size_map + @property + def unsharded_module_type(self) -> Type[ITEPEmbeddingBagCollection]: + return ITEPEmbeddingBagCollection + class ITEPEmbeddingBagCollectionSharder( BaseEmbeddingSharder[ITEPEmbeddingBagCollection] @@ -523,6 +527,10 @@ def _group_lookups_and_table_unpruned_size_map( return grouped_lookups, grouped_table_unpruned_size_map + @property + def unsharded_module_type(self) -> Type[ITEPEmbeddingCollection]: + return ITEPEmbeddingCollection + class ITEPEmbeddingCollectionSharder(BaseEmbeddingSharder[ITEPEmbeddingCollection]): def __init__( diff --git a/torchrec/distributed/mc_embedding.py b/torchrec/distributed/mc_embedding.py index 0d939632e..0b1c0eee9 100644 --- a/torchrec/distributed/mc_embedding.py +++ b/torchrec/distributed/mc_embedding.py @@ -97,6 +97,10 @@ def create_context( ) -> ManagedCollisionEmbeddingCollectionContext: return ManagedCollisionEmbeddingCollectionContext(sharding_contexts=[]) + @property + def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingCollection]: + return ManagedCollisionEmbeddingCollection + class ManagedCollisionEmbeddingCollectionSharder( BaseManagedCollisionEmbeddingCollectionSharder[ManagedCollisionEmbeddingCollection] diff --git a/torchrec/distributed/mc_embeddingbag.py b/torchrec/distributed/mc_embeddingbag.py index e94d42d59..347582814 100644 --- a/torchrec/distributed/mc_embeddingbag.py +++ b/torchrec/distributed/mc_embeddingbag.py @@ -82,6 +82,10 @@ def create_context( ) -> ManagedCollisionEmbeddingBagCollectionContext: return ManagedCollisionEmbeddingBagCollectionContext(sharding_contexts=[]) + @property + def unsharded_module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]: + return ManagedCollisionEmbeddingBagCollection + class ManagedCollisionEmbeddingBagCollectionSharder( BaseManagedCollisionEmbeddingCollectionSharder[ diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 792fdeb0a..88e0f3d18 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -1320,6 +1320,10 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: for fqn, _ in self.named_buffers(): yield append_prefix(prefix, fqn) + @property + def unsharded_module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]: + return QuantManagedCollisionEmbeddingCollection + class QuantManagedCollisionEmbeddingCollectionSharder( BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection] diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py index e666841b9..09dd3f04f 100644 --- a/torchrec/distributed/quant_embeddingbag.py +++ b/torchrec/distributed/quant_embeddingbag.py @@ -383,6 +383,10 @@ def create_context(self) -> NullShardedModuleContext: return NullShardedModuleContext() + @property + def unsharded_module_type(self) -> Type[QuantEmbeddingBagCollection]: + return QuantEmbeddingBagCollection + class QuantEmbeddingBagCollectionSharder( BaseQuantEmbeddingSharder[QuantEmbeddingBagCollection] diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index 1de388e1b..03cd30d12 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -10,12 +10,13 @@ import copy from dataclasses import dataclass from functools import partial -from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union import torch from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( IntNBitTableBatchedEmbeddingBagsCodegen, ) +from torch import nn from torch.distributed import _remote_device from torch.distributed._shard.sharded_tensor import ( Shard, @@ -367,6 +368,17 @@ def _load_from_state_dict( missing_keys.extend(_missing_keys) unexpected_keys.extend(_unexpected_keys) + @property + def unsharded_module_type(self) -> Type[nn.Module]: + """ + Since ShardedQuantEmbeddingModuleState is not exactly a sharded module + but rather a class to utilize generic helper functions. Returns generic + nn.Module type. + """ + + # TODO: Add test in TorchRec for using ShardedQuantEmbeddingModuleState + return nn.Module + @dataclass class WeightSpec: diff --git a/torchrec/distributed/tests/test_embedding_types.py b/torchrec/distributed/tests/test_embedding_types.py index db9f660b7..cd7c3cca1 100644 --- a/torchrec/distributed/tests/test_embedding_types.py +++ b/torchrec/distributed/tests/test_embedding_types.py @@ -8,9 +8,10 @@ # pyre-strict import unittest -from typing import Dict, List +from typing import Dict, List, Type import torch +from torch import nn from torchrec.distributed.embedding_types import KJTList, ShardedEmbeddingModule from torchrec.distributed.embeddingbag import EmbeddingBagCollectionContext from torchrec.distributed.types import Awaitable, LazyAwaitable @@ -55,6 +56,11 @@ def compute(self, ctx: ShrdCtx, dist_input: CompIn) -> DistOut: def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]: pass + @property + def unsharded_module_type(self) -> Type[nn.Module]: + # Since this is a fake sharded embedding module, just returning default module + return nn.Module + class TestShardedEmbeddingModule(unittest.TestCase): def test_train_mode(self) -> None: diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 45300f733..59fe483dc 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -1034,6 +1034,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: for key, _ in self.named_parameters(prefix): yield key + @property + @abc.abstractmethod + def unsharded_module_type(self) -> Type[nn.Module]: + """ + This property is added as part of dynamic sharding implementation. + + When resharding an already-sharded module wrapped in DMP, the unsharded + module type is needed to identify the proper sharder to reshard. This is + due to DistributedModelParellel (DMP) references module Sharders based + on the unsharded module type. + """ + ... + def get_tensor_size_bytes(t: torch.Tensor) -> int: b: int = t.numel() * t.element_size()