From 57a557c8a001975d7a9323c1d8ebf330b3bb5712 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Wed, 1 Oct 2025 11:41:47 -0700 Subject: [PATCH] Add row based sharding support for FeaturedProcessedEBC Summary: X-link: https://github.com/pytorch/torchrec/pull/3281 In this diff we introduce row based sharding (TWRW, RW, GRID) type support for feature processors. Previously, feature processors did not support row based sharding since feature processors are data parallel. This means by splitting up the input for row based shards the accessed feature processor weights were in correct. In column/data sharding based approaches, the data is duplicated ensuring the correct weight is accessed across ranks. The indices/buckets are calculated post input split/distribution, to make it compatible with row based sharding we calculate this pre input split/distribution. This couples the train pipeline and feature processors. For each feature, we preprocess the input and place the calculated indices in KJT.weights, this propagates the indices correctly and indexs into the right weight to use for the final step in the feature processing. This applies in both pipelined and non pipelined situations - the input modification is done either at the pipelined forward call or in the input dist of the FPEBC. This is determined by the pipelining flag set through rewrite_model in train pipeline. Differential Revision: D82248545 --- torchrec/distributed/fp_embeddingbag.py | 59 ++++++- .../distributed/tests/test_fp_embeddingbag.py | 1 - .../tests/test_fp_embeddingbag_utils.py | 7 +- .../tests/test_train_pipelines.py | 165 +++++++++++++++++- .../tests/test_train_pipelines_base.py | 2 +- torchrec/distributed/train_pipeline/utils.py | 3 + torchrec/distributed/types.py | 20 +++ torchrec/distributed/utils.py | 51 +++++- torchrec/modules/feature_processor_.py | 49 ++++-- 9 files changed, 331 insertions(+), 26 deletions(-) diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py index 4b069437f..3d7fd4140 100644 --- a/torchrec/distributed/fp_embeddingbag.py +++ b/torchrec/distributed/fp_embeddingbag.py @@ -8,7 +8,18 @@ # pyre-strict from functools import partial -from typing import Any, Dict, Iterator, List, Optional, Type, Union +from typing import ( + Any, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import torch from torch import nn @@ -31,7 +42,11 @@ ShardingEnv, ShardingType, ) -from torchrec.distributed.utils import append_prefix, init_parameters +from torchrec.distributed.utils import ( + append_prefix, + init_parameters, + modify_input_for_feature_processor, +) from torchrec.modules.feature_processor_ import FeatureProcessorsCollection from torchrec.modules.fp_embedding_modules import ( apply_feature_processors_to_kjt, @@ -39,6 +54,8 @@ ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor +_T = TypeVar("_T") + def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor: kt._values.add_(no_op_tensor) @@ -74,6 +91,16 @@ def __init__( ) ) + self._row_wise_sharded: bool = False + for param_sharding in table_name_to_parameter_sharding.values(): + if param_sharding.sharding_type in [ + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.GRID_SHARD.value, + ]: + self._row_wise_sharded = True + break + self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups self._is_collection: bool = False @@ -96,6 +123,11 @@ def __init__( def input_dist( self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor ) -> Awaitable[Awaitable[KJTList]]: + if not self.is_pipelined and self._row_wise_sharded: + # transform input to support row based sharding when not pipelined + modify_input_for_feature_processor( + features, self._feature_processors, self._is_collection + ) return self._embedding_bag_collection.input_dist(ctx, features) def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList: @@ -105,10 +137,7 @@ def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList: kjt_list.append(self._feature_processors(features)) else: kjt_list.append( - apply_feature_processors_to_kjt( - features, - self._feature_processors, - ) + apply_feature_processors_to_kjt(features, self._feature_processors) ) return KJTList(kjt_list) @@ -117,7 +146,6 @@ def compute( ctx: EmbeddingBagCollectionContext, dist_input: KJTList, ) -> List[torch.Tensor]: - fp_features = self.apply_feature_processors_to_kjt_list(dist_input) return self._embedding_bag_collection.compute(ctx, fp_features) @@ -166,6 +194,18 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa self._embedding_bag_collection._initialize_torch_state(skip_registering) + def preprocess_input( + self, args: List[_T], kwargs: Mapping[str, _T] + ) -> Tuple[List[_T], Mapping[str, _T]]: + for x in args + list(kwargs.values()): + if isinstance(x, KeyedJaggedTensor): + modify_input_for_feature_processor( + features=x, + feature_processors=self._feature_processors, + is_collection=self._is_collection, + ) + return args, kwargs + class FeatureProcessedEmbeddingBagCollectionSharder( BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection] @@ -191,7 +231,6 @@ def shard( device: Optional[torch.device] = None, module_fqn: Optional[str] = None, ) -> ShardedFeatureProcessedEmbeddingBagCollection: - if device is None: device = torch.device("cuda") @@ -228,12 +267,14 @@ def sharding_types(self, compute_device_type: str) -> List[str]: if compute_device_type in {"mtia"}: return [ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value] - # No row wise because position weighted FP and RW don't play well together. types = [ ShardingType.DATA_PARALLEL.value, ShardingType.TABLE_WISE.value, ShardingType.COLUMN_WISE.value, ShardingType.TABLE_COLUMN_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.GRID_SHARD.value, ] return types diff --git a/torchrec/distributed/tests/test_fp_embeddingbag.py b/torchrec/distributed/tests/test_fp_embeddingbag.py index 130776919..08f5dfdbb 100644 --- a/torchrec/distributed/tests/test_fp_embeddingbag.py +++ b/torchrec/distributed/tests/test_fp_embeddingbag.py @@ -231,7 +231,6 @@ class ShardedEmbeddingBagCollectionParallelTest(MultiProcessTestBase): def test_sharding_ebc( self, set_gradient_division: bool, use_dmp: bool, use_fp_collection: bool ) -> None: - import hypothesis # don't need to test entire matrix diff --git a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py index 8efacdbb8..f7027b198 100644 --- a/torchrec/distributed/tests/test_fp_embeddingbag_utils.py +++ b/torchrec/distributed/tests/test_fp_embeddingbag_utils.py @@ -86,7 +86,12 @@ def forward(self, kjt: KeyedJaggedTensor) -> Tuple[torch.Tensor, torch.Tensor]: pred = torch.cat( [ fp_ebc_out[key] - for key in ["feature_0", "feature_1", "feature_2", "feature_3"] + for key in [ + "feature_0", + "feature_1", + "feature_2", + "feature_3", + ] ], dim=1, ) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index a0ea00132..4728e831e 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -22,7 +22,10 @@ from torch._dynamo.testing import reduce_to_scalar_loss from torch._dynamo.utils import counters from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + EmbeddingTableConfig, +) from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.fp_embeddingbag import ( FeatureProcessedEmbeddingBagCollectionSharder, @@ -31,8 +34,13 @@ from torchrec.distributed.model_parallel import DMPCollection from torchrec.distributed.sharding_plan import ( construct_module_sharding_plan, + row_wise, table_wise, ) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) from torchrec.distributed.test_utils.test_model import ( ModelInput, TestEBCSharder, @@ -331,6 +339,161 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None: torch.testing.assert_close(pred_gpu.cpu(), pred) +def fp_ebc( + rank: int, + world_size: int, + tables: List[EmbeddingTableConfig], + weighted_tables: List[EmbeddingTableConfig], + data: List[Tuple[ModelInput, List[ModelInput]]], + backend: str = "nccl", + local_size: Optional[int] = None, +) -> None: + with MultiProcessContext(rank, world_size, backend, local_size) as ctx: + assert ctx.pg is not None + sharder = cast( + ModuleSharder[nn.Module], + FeatureProcessedEmbeddingBagCollectionSharder(), + ) + + class DummyWrapper(nn.Module): + def __init__(self, sparse_arch): + super().__init__() + self.m = sparse_arch + + def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]: + return self.m(model_input.idlist_features) + + max_feature_lengths = [10, 10, 12, 12] + sparse_arch = DummyWrapper( + create_module_and_freeze( + tables=tables, # pyre-ignore[6] + device=ctx.device, + use_fp_collection=False, + max_feature_lengths=max_feature_lengths, + ) + ) + + # compute_kernel = EmbeddingComputeKernel.FUSED.value + module_sharding_plan = construct_module_sharding_plan( + sparse_arch.m._fp_ebc, + per_param_sharding={ + "table_0": row_wise(), + "table_1": row_wise(), + "table_2": row_wise(), + "table_3": row_wise(), + }, + world_size=2, + device_type=ctx.device.type, + sharder=sharder, + ) + sharded_sparse_arch_pipeline = DistributedModelParallel( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), + env=ShardingEnv.from_process_group(ctx.pg), # pyre-ignore[6] + sharders=[sharder], + device=ctx.device, + ) + sharded_sparse_arch_no_pipeline = DistributedModelParallel( + module=copy.deepcopy(sparse_arch), + plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}), + env=ShardingEnv.from_process_group(ctx.pg), # pyre-ignore[6] + sharders=[sharder], + device=ctx.device, + ) + + batches = [] + for d in data: + batches.append(d[1][ctx.rank].to(ctx.device)) + dataloader = iter(batches) + + optimizer_no_pipeline = optim.SGD( + sharded_sparse_arch_no_pipeline.parameters(), lr=0.1 + ) + optimizer_pipeline = optim.SGD( + sharded_sparse_arch_pipeline.parameters(), lr=0.1 + ) + + pipeline = TrainPipelineSparseDist( + sharded_sparse_arch_pipeline, + optimizer_pipeline, + ctx.device, + ) + + for batch in batches[:-2]: + batch = batch.to(ctx.device) + optimizer_no_pipeline.zero_grad() + loss, pred = sharded_sparse_arch_no_pipeline(batch) + loss.backward() + optimizer_no_pipeline.step() + + pred_pipeline = pipeline.progress(dataloader) + torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu()) + + +class TrainPipelineGPUTest(MultiProcessTestBase): + def setUp(self, backend: str = "nccl") -> None: + super().setUp() + + self.pipeline_class = TrainPipelineSparseDist + num_features = 4 + num_weighted_features = 4 + self.tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 100, + embedding_dim=(i + 1) * 4, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_features) + ] + self.weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 100, + embedding_dim=(i + 1) * 4, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + + self.backend = backend + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + if self.backend == "nccl" and self.device == torch.device("cpu"): + self.skipTest("NCCL not supported on CPUs.") + + def _generate_data( + self, + num_batches: int = 5, + batch_size: int = 1, + max_feature_lengths: Optional[List[int]] = None, + ) -> List[Tuple[ModelInput, List[ModelInput]]]: + return [ + ModelInput.generate( + tables=self.tables, + weighted_tables=self.weighted_tables, + batch_size=batch_size, + world_size=2, + num_float_features=10, + max_feature_lengths=max_feature_lengths, + ) + for i in range(num_batches) + ] + + def test_fp_ebc_rw(self) -> None: + data = self._generate_data(max_feature_lengths=[10, 10, 12, 12]) + self._run_multi_process_test( + callable=fp_ebc, + world_size=2, + tables=self.tables, + weighted_tables=self.weighted_tables, + data=data, + ) + + class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase): # pyre-fixme[56]: Pyre was not able to infer the type of argument @unittest.skipIf( diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py index 56e6ac636..85148a480 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py @@ -40,7 +40,7 @@ def setUp(self) -> None: self.pg = init_distributed_single_host(backend=backend, rank=0, world_size=1) num_features = 4 - num_weighted_features = 2 + num_weighted_features = 4 self.tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 100, diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py index 08e5c2aab..a40356e30 100644 --- a/torchrec/distributed/train_pipeline/utils.py +++ b/torchrec/distributed/train_pipeline/utils.py @@ -147,6 +147,7 @@ def _start_data_dist( # and this info was done in the _rewrite_model by tracing the # entire model to get the arg_info_list args, kwargs = forward.args.build_args_kwargs(batch) + args, kwargs = module.preprocess_input(args, kwargs) # Start input distribution. module_ctx = module.create_context() @@ -382,6 +383,8 @@ def _rewrite_model( # noqa C901 logger.info(f"Module '{node.target}' will be pipelined") child = sharded_modules[node.target] original_forwards.append(child.forward) + # Set pipelining flag on the child module + child.is_pipelined = True # pyre-ignore[8] Incompatible attribute type child.forward = pipelined_forward( node.target, diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 46521ca6c..82b528130 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -19,7 +19,10 @@ Generic, Iterator, List, + Mapping, Optional, + ParamSpec, + Sequence, Tuple, Type, TypeVar, @@ -78,6 +81,8 @@ class GenericMeta(type): ) from torchrec.streamable import Multistreamable +_T = TypeVar("_T") + def _tabulate( table: List[List[Union[str, int]]], headers: Optional[List[str]] = None @@ -1015,6 +1020,8 @@ def __init__( if qcomm_codecs_registry is None: qcomm_codecs_registry = {} self._qcomm_codecs_registry = qcomm_codecs_registry + # In pipelining, this flag is flipped in rewrite_model when the forward is replaced with the pipelined forward + self.is_pipelined = False @abc.abstractmethod def create_context(self) -> ShrdCtx: @@ -1117,6 +1124,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: for key, _ in self.named_parameters(prefix): yield key + def preprocess_input( + self, + args: List[_T], + kwargs: Mapping[str, _T], + ) -> Tuple[List[_T], Mapping[str, _T]]: + """ + This function can be used to preprocess the input arguments prior to module forward call. + + For example, it is used in ShardedFeatureProcessorEmbeddingBagCollection to transform the input data + prior to the forward call. + """ + return args, kwargs + @property @abc.abstractmethod def unsharded_module_type(self) -> Type[nn.Module]: diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py index b12660e97..e69a88371 100644 --- a/torchrec/distributed/utils.py +++ b/torchrec/distributed/utils.py @@ -26,8 +26,10 @@ from torch import nn from torch.autograd.profiler import record_function from torchrec import optim as trec_optim -from torchrec.distributed.embedding_types import EmbeddingComputeKernel - +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + KeyedJaggedTensor, +) from torchrec.distributed.types import ( DataType, EmbeddingEvent, @@ -38,6 +40,7 @@ ShardMetadata, ) from torchrec.modules.embedding_configs import data_type_to_sparse_type +from torchrec.modules.feature_processor_ import FeatureProcessorsCollection from torchrec.types import CopyMixIn logger: logging.Logger = logging.getLogger(__name__) @@ -758,3 +761,47 @@ def _recalculate_torch_state_helper( _recalculate_torch_state_helper(child) _recalculate_torch_state_helper(module) + emb_kernel.weights_precision = converted_sparse_dtype # pyre-ignore [16] + + +def modify_input_for_feature_processor( + features: KeyedJaggedTensor, + feature_processors: Union[nn.ModuleDict, FeatureProcessorsCollection], + is_collection: bool, +) -> None: + """ + This function applies the feature processor pre input dist. This way we + can support row wise based sharding mechanisms. + + This is an inplace modifcation of the input KJT. + """ + with torch.no_grad(): + if features.weights_or_none() is None: + # force creation of weights, this way the feature jagged tensor weights are tied to the original KJT + features._weights = torch.zeros_like(features.values(), dtype=torch.float32) + + if is_collection: + if hasattr(feature_processors, "pre_process_pipeline_input"): + feature_processors.pre_process_pipeline_input(features) # pyre-ignore[29] + else: + logging.info( + f"[Feature Processor Pipeline] Skipping pre_process_pipeline_input for feature processor {feature_processors=}" + ) + else: + # per feature process + for feature in features.keys(): + if feature in feature_processors: # pyre-ignore[58] + feature_processor = feature_processors[feature] # pyre-ignore[29] + if hasattr(feature_processor, "pre_process_pipeline_input"): + feature_processor.pre_process_pipeline_input(features[feature]) + else: + logging.info( + f"[Feature Processor Pipeline] Skipping pre_process_pipeline_input for feature processor {feature_processor=}" + ) + else: + features[feature].weights().copy_( + torch.ones( + features[feature].values().shape[0], + device=features[feature].values().device, + ) + ) diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py index 707f5bd2b..f064ad5e3 100644 --- a/torchrec/modules/feature_processor_.py +++ b/torchrec/modules/feature_processor_.py @@ -14,7 +14,7 @@ import torch -from torch import nn +from torch import distributed as dist, nn from torch.nn.modules.module import _IncompatibleKeys from torchrec.pt2.checks import is_non_strict_exporting @@ -72,6 +72,7 @@ def __init__( torch.empty([max_feature_length], device=device), requires_grad=True, ) + self.pipelined = False self.reset_parameters() @@ -85,15 +86,18 @@ def forward( ) -> JaggedTensor: """ Args: - features (JaggedTensor]): feature representation + features (JaggedTensor): feature representation Returns: JaggedTensor: same as input features with `weights` field being populated. """ - - seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + if self.pipelined: + # position is embedded as weights + seq = features.weights().clone().to(torch.int64) + else: + seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) weighted_features = JaggedTensor( values=features.values(), lengths=features.lengths(), @@ -102,6 +106,20 @@ def forward( ) return weighted_features + def pre_process_pipeline_input(self, features: JaggedTensor) -> None: + """ + Args: + features (JaggedTensor]): feature representation + + Returns: + torch.Tensor: position weights + """ + self.pipelined = True + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + features.weights().copy_(cat_seq.to(torch.float32)) + class FeatureProcessorsCollection(nn.Module): """ @@ -169,7 +187,7 @@ def __init__( for length in self.max_feature_lengths.values(): if length <= 0: raise - + self.pipelined = False # if pipelined, input dist has performed part of input feature processing self.position_weights: nn.ParameterDict = nn.ParameterDict() # needed since nn.ParameterDict isn't torchscriptable (get_items) self.position_weights_dict: Dict[str, nn.Parameter] = {} @@ -191,7 +209,6 @@ def reset_parameters(self) -> None: self.position_weights_dict[key] = self.position_weights[key] def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: - # TODO unflattener doesnt work well with aten.to at submodule boundaries if is_non_strict_exporting(): offsets = features.offsets() if offsets.dtype == torch.int64: @@ -203,9 +220,12 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: features.offsets().long(), torch.numel(features.values()) ) else: - cat_seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + if self.pipelined: + cat_seq = features.weights().clone().to(torch.int64) + else: + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) return KeyedJaggedTensor( keys=features.keys(), @@ -245,3 +265,10 @@ def load_state_dict( for k, param in self.position_weights.items(): self.position_weights_dict[k] = param return result + + def pre_process_pipeline_input(self, features: KeyedJaggedTensor) -> None: + self.pipelined = True + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + features.weights().copy_(cat_seq.to(torch.float32))