Commit fbb2de6

pytorchbot committed: 2025-05-29 nightly release (3df4260)
1 parent 7b09c7a commit fbb2de6

File tree

16 files changed: +1566 −569 lines

torchrec/distributed/embeddingbag.py

Lines changed: 8 additions & 8 deletions
@@ -1531,15 +1531,9 @@ def update_shards(
         current_state = self.state_dict()
         # TODO: Save Optimizers
 
-        saved_weights = {}
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
-        for i, lookup in enumerate(self._lookups):
-            for attribute, tbe_module in lookup.named_modules():
-                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
-                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
-                    # Note: lookup.purge should delete tbe_module and weights
-                    # del tbe_module.weights
-                    # del tbe_module
+        # TODO: Ensure lookup tensors are actually being deleted
+        for _, lookup in enumerate(self._lookups):
             # pyre-ignore
             lookup.purge()
 
@@ -1603,6 +1597,12 @@ def update_shards(
             for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]
 
+        # Reset input dists
+        self._has_uninitialized_input_dist = True
+        self._input_dists: List[nn.Module] = []
+        self._features_order: List[int] = []
+        self._feature_splits: List[int] = []
+
         self._create_lookups()
         self._update_output_dist()

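For context on the added reset: input dists in these sharded modules are built lazily on first forward, so flipping `_has_uninitialized_input_dist` and clearing the cached lists forces them to be rebuilt against the new shard placements. A minimal sketch of that lazy-rebuild pattern (a hypothetical toy class for illustration, not the real `ShardedEmbeddingBagCollection`):

```python
from typing import List

import torch
import torch.nn as nn


class LazyInputDistToy(nn.Module):
    """Toy illustration of the lazy rebuild pattern used by update_shards."""

    def __init__(self) -> None:
        super().__init__()
        self._has_uninitialized_input_dist = True
        self._input_dists: List[nn.Module] = []

    def _create_input_dists(self) -> None:
        # The real code builds one dist module per sharding; Identity stands in here.
        self._input_dists = [nn.Identity()]
        self._has_uninitialized_input_dist = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Rebuilt on the first forward after a reset, exactly like the flag flip above.
        if self._has_uninitialized_input_dist:
            self._create_input_dists()
        return self._input_dists[0](x)
```
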
torchrec/distributed/mc_modules.py

Lines changed: 3 additions & 1 deletion
@@ -671,7 +671,9 @@ def _kjt_list_to_tensor_list(
                     vals.append(feature_split.values() + offset)
                 remapped_ids_ret.append(torch.cat(vals).view(-1, 1))
             else:
-                remapped_ids_ret.append(kjt.values() + self._table_to_offset[tables[0]])
+                remapped_ids_ret.append(
+                    (kjt.values() + self._table_to_offset[tables[0]]).unsqueeze(-1)
+                )
         return remapped_ids_ret
 
     def global_to_local_index(

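The fix aligns the single-table branch with the multi-table branch, which already returns column vectors of shape `(N, 1)` via `view(-1, 1)`. A quick shape check with illustrative stand-in values:

```python
import torch

values = torch.tensor([3, 7, 1])  # stand-in for kjt.values()
offset = 10                       # stand-in for self._table_to_offset[tables[0]]

multi_table = torch.cat([values + offset]).view(-1, 1)  # existing branch: shape (3, 1)
single_table = (values + offset).unsqueeze(-1)          # fixed branch: shape (3, 1)
assert multi_table.shape == single_table.shape == torch.Size([3, 1])
```
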
torchrec/distributed/model_parallel.py

Lines changed: 79 additions & 0 deletions
@@ -35,6 +35,7 @@
 from torchrec.distributed.types import (
     EnumerableShardingSpec,
     ModuleSharder,
+    ParameterSharding,
     ShardedModule,
     ShardingEnv,
     ShardingEnv2D,
@@ -612,6 +613,84 @@ def _reset_parameters(module: nn.Module) -> None:
         if hasattr(m, "reset_parameters"):
             m.reset_parameters()
 
+    def reshard(
+        self,
+        sharded_module_fqn: str,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+    ) -> ShardedModule:
+        """
+        Reshards an already-sharded module in the DMP, given a set of
+        ParameterShardings describing the new placements.
+
+        This method allows you to dynamically change the sharding strategy for a
+        specific module without recreating the entire DMP. It is particularly
+        useful for:
+        1. Adapting to changing requirements during training
+        2. Implementing progressive sharding strategies
+        3. Rebalancing load across devices
+        4. A/B testing different sharding plans
+
+        Args:
+            sharded_module_fqn (str): The path to the sharded module in the DMP,
+                for example "sparse.ebc".
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary
+                mapping parameter names to their new ParameterSharding
+                configurations. Includes only the shards that need to be moved.
+
+        Example:
+            ```
+            # Original sharding plan might have the table sharded across 2 GPUs
+            original_plan = {
+                "table_0": ParameterSharding(
+                    sharding_type="table_wise",
+                    ranks=[0, 1],
+                    sharding_spec=EnumerableShardingSpec(...),
+                )
+            }
+
+            # New sharding plan to shard the table across 4 GPUs
+            new_plan = {
+                "table_0": ParameterSharding(
+                    sharding_type="table_wise",
+                    ranks=[0, 1, 2, 3],
+                    sharding_spec=EnumerableShardingSpec(...),
+                )
+            }
+
+            # Helper function that selects only the delta between the original and new plans
+            changed_sharding_params = output_sharding_plan_delta(original_plan, new_plan)
+
+            # Reshard the module and redistribute the tensors
+            model.reshard("embedding_module", changed_sharding_params)
+            ```
+
+        Notes:
+            - The sharder for the module must implement a `reshard` method
+            - Resharding involves redistributing tensor data across devices, which can be expensive
+            - After resharding, the optimizer state is maintained for the module
+            - The sharding plan is updated to reflect the new configuration
+        """
+        steps = sharded_module_fqn.split(".")
+        sharded_module = self.module
+        for s in steps:
+            sharded_module = getattr(sharded_module, s)
+
+        assert isinstance(sharded_module, ShardedModule)
+        assert changed_shard_to_params is not None
+        sharder_key = sharded_module.unsharded_module_type
+        sharder = self._sharder_map[sharder_key]
+        assert hasattr(
+            sharder, "reshard"
+        ), "reshard is not implemented for this sharder"
+        sharded_module = sharder.reshard(  # pyre-ignore
+            sharded_module,
+            changed_shard_to_params,
+            self._env,
+            self.device,
+        )
+
+        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
+        self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
+        return sharded_module
+
 
 class DMPCollection(DistributedModelParallel):
     """

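The module lookup at the top of `reshard` is a plain dotted-attribute walk over the wrapped module tree. A self-contained sketch of just that step (the helper name and toy module tree are hypothetical):

```python
import torch.nn as nn


def get_submodule_by_fqn(root: nn.Module, fqn: str) -> nn.Module:
    # Same traversal reshard() performs: follow each dotted step via getattr.
    module = root
    for step in fqn.split("."):
        module = getattr(module, step)
    return module


# Toy module tree mirroring a "sparse.ebc" path:
model = nn.Module()
model.sparse = nn.Module()
model.sparse.ebc = nn.EmbeddingBag(10, 4)
assert get_submodule_by_fqn(model, "sparse.ebc") is model.sparse.ebc
```
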
torchrec/distributed/planner/tests/test_types.py

Lines changed: 93 additions & 1 deletion
@@ -14,12 +14,17 @@
 import torch
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 
-from torchrec.distributed.planner.types import Shard, ShardingOption
+from torchrec.distributed.planner.types import (
+    ParameterConstraints,
+    Shard,
+    ShardingOption,
+)
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheAlgorithm,
     CacheParams,
     DataType,
+    KeyValueParams,
     ShardingType,
 )
 from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
@@ -207,3 +212,90 @@ def test_module_pooled_mch_ec(self) -> None:
             shards=[Shard(size=shard_size, offset=offset) for offset in shard_offsets],
         )
         self.assertEqual(sharding_option.is_pooled, False)
+
+
+class TestParameterConstraintsHash(unittest.TestCase):
+
+    def test_hash_equality(self) -> None:
+        # Create two identical instances
+        pc1 = ParameterConstraints(
+            sharding_types=["type1", "type2"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0, 2.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1", "feature2"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        pc2 = ParameterConstraints(
+            sharding_types=["type1", "type2"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0, 2.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1", "feature2"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        self.assertEqual(
+            hash(pc1), hash(pc2), "Hashes should be equal for identical instances"
+        )
+
+    def test_hash_inequality(self) -> None:
+        # Create two different instances
+        pc1 = ParameterConstraints(
+            sharding_types=["type1"],
+            compute_kernels=["kernel1"],
+            min_partition=4,
+            pooling_factors=[1.0],
+            num_poolings=[1.0],
+            batch_sizes=[32],
+            is_weighted=True,
+            cache_params=CacheParams(),
+            enforce_hbm=True,
+            stochastic_rounding=False,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature1"],
+            output_dtype=DataType.FP32,
+            device_group="cuda",
+            key_value_params=KeyValueParams(),
+        )
+
+        pc2 = ParameterConstraints(
+            sharding_types=["type2"],
+            compute_kernels=["kernel2"],
+            min_partition=8,
+            pooling_factors=[2.0],
+            num_poolings=[2.0],
+            batch_sizes=[64],
+            is_weighted=False,
+            cache_params=CacheParams(),
+            enforce_hbm=False,
+            stochastic_rounding=True,
+            bounds_check_mode=BoundsCheckMode(1),
+            feature_names=["feature2"],
+            output_dtype=DataType.FP16,
+            device_group="cpu",
+            key_value_params=KeyValueParams(),
+        )
+
+        self.assertNotEqual(
+            hash(pc1), hash(pc2), "Hashes should be different for different instances"
+        )

torchrec/distributed/planner/types.py

Lines changed: 21 additions & 0 deletions
@@ -703,6 +703,27 @@ class ParameterConstraints:
     device_group: Optional[str] = None
     key_value_params: Optional[KeyValueParams] = None
 
+    def __hash__(self) -> int:
+        return hash(
+            (
+                tuple(self.sharding_types) if self.sharding_types else None,
+                tuple(self.compute_kernels) if self.compute_kernels else None,
+                self.min_partition,
+                tuple(self.pooling_factors),
+                tuple(self.num_poolings) if self.num_poolings else None,
+                tuple(self.batch_sizes) if self.batch_sizes else None,
+                self.is_weighted,
+                self.cache_params,
+                self.enforce_hbm,
+                self.stochastic_rounding,
+                self.bounds_check_mode,
+                tuple(self.feature_names) if self.feature_names else None,
+                self.output_dtype,
+                self.device_group,
+                self.key_value_params,
+            )
+        )
+
 
 class PlannerErrorType(Enum):
     """

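With an explicit `__hash__`, structurally identical constraints hash alike, so `ParameterConstraints` instances can be used as set members or dict keys, for example to dedupe constraints or key cached planner results. A short sketch, assuming the dataclass defaults and field-based `__eq__` that `ParameterConstraints` appears to have in planner/types.py (the dedupe use case is illustrative, not part of this commit):

```python
from torchrec.distributed.planner.types import ParameterConstraints

pc_a = ParameterConstraints(sharding_types=["table_wise"], feature_names=["f1"])
pc_b = ParameterConstraints(sharding_types=["table_wise"], feature_names=["f1"])

# Field-based hash, per the new __hash__: identical fields -> identical hashes.
assert hash(pc_a) == hash(pc_b)

# With dataclass equality, a set should dedupe the two to one entry.
unique_constraints = {pc_a, pc_b}
assert len(unique_constraints) == 1
```
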
torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 24 additions & 0 deletions
@@ -7,13 +7,15 @@
 
 # pyre-strict
 
+import copy
 from typing import Any, Callable, Dict, List, Tuple
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import Shard
 from torchrec.distributed.types import (
+    EmbeddingModuleShardingPlan,
     ParameterSharding,
     ShardedModule,
     ShardedTensor,
@@ -364,3 +366,25 @@ def pad_tensor_to_max_dims(
         mode="constant",
         value=0,
     )
+
+
+# Utils
+def output_sharding_plan_delta(
+    old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan
+) -> EmbeddingModuleShardingPlan:
+    """
+    Compute and return a new sharding plan that is the delta
+    between the new and old embedding module plans. Assumes that the old and new
+    plans have the same number of parameters/tables.
+
+    This is useful for Dynamic Sharding, since the Resharding API takes in only
+    the ParameterShardings of the shards that need to be moved.
+    """
+    assert len(old_plan) == len(new_plan)
+    return EmbeddingModuleShardingPlan(
+        {
+            k: copy.deepcopy(v)
+            for k, v in new_plan.items()
+            if v.ranks != old_plan[k].ranks
+        }
+    )

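The delta rule itself is simple: keep only the entries whose `ranks` changed. A minimal, runnable sketch of that rule using stand-in objects (the `FakeSharding` dataclass below is hypothetical, standing in for `ParameterSharding`):

```python
import copy
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeSharding:
    """Stand-in for ParameterSharding; only the ranks field matters here."""
    ranks: Optional[List[int]]


old_plan = {"table_0": FakeSharding(ranks=[0, 1]), "table_1": FakeSharding(ranks=[2])}
new_plan = {"table_0": FakeSharding(ranks=[0, 1, 2, 3]), "table_1": FakeSharding(ranks=[2])}

# Same comprehension as output_sharding_plan_delta: keep only moved tables.
delta = {k: copy.deepcopy(v) for k, v in new_plan.items() if v.ranks != old_plan[k].ranks}
assert set(delta) == {"table_0"}
```
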
torchrec/distributed/sharding_plan.py

Lines changed: 14 additions & 0 deletions
@@ -410,6 +410,20 @@ def _get_parameter_sharding(
     ]
 
 
+def get_sharding_constructor_from_type(
+    sharding_type: ShardingType,
+) -> Callable[..., ParameterShardingGenerator]:
+    sharding_type_to_constructor = {
+        ShardingType.TABLE_WISE: table_wise,
+        ShardingType.ROW_WISE: row_wise,
+        ShardingType.COLUMN_WISE: column_wise,
+        ShardingType.TABLE_ROW_WISE: table_row_wise,
+        ShardingType.GRID_SHARD: grid_shard,
+        ShardingType.DATA_PARALLEL: data_parallel,
+    }
+    return sharding_type_to_constructor[sharding_type]
+
+
 def data_parallel() -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::DATA_PARALLEL` for construct_module_sharding_plan.

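This gives callers a programmatic way to go from a `ShardingType` enum value to the matching generator constructor when the type is chosen at runtime. A short usage sketch; the `rank=0` argument follows `table_wise`'s signature elsewhere in this file, but treat the exact arguments as illustrative:

```python
from torchrec.distributed.sharding_plan import get_sharding_constructor_from_type
from torchrec.distributed.types import ShardingType

# Resolve the constructor for a sharding type selected at runtime.
table_wise_constructor = get_sharding_constructor_from_type(ShardingType.TABLE_WISE)

# Build a ParameterShardingGenerator placing the whole table on rank 0.
generator = table_wise_constructor(rank=0)
```
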