diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py index d4c1ca48a..3e197b0c7 100644 --- a/torchrec/distributed/sharding/tw_sharding.py +++ b/torchrec/distributed/sharding/tw_sharding.py @@ -135,12 +135,8 @@ def _shard( dtensor_metadata = None if self._env.output_dtensor: dtensor_metadata = DTensorMetadata( - mesh=( - self._env.device_mesh["replicate"] # pyre-ignore[16] - if self._is_2D_parallel - else self._env.device_mesh - ), - placements=(Replicate(),), + mesh=self._env.device_mesh, + placements=(Replicate(),) * (self._env.device_mesh.ndim), # pyre-ignore[16] size=( info.embedding_config.num_embeddings, info.embedding_config.embedding_dim,