Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions torchrec/distributed/tests/test_init_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ def initialize_and_test_parameters(
local_size: Optional[int] = None,
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
# Set seed again in each process to ensure consistency
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)

key = (
f"embeddings.{table_name}.weight"
if isinstance(embedding_tables, EmbeddingCollection)
else f"embedding_bags.{table_name}.weight"
)

# Create the same fixed tensor in each process
fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))

# Load the fixed tensor into the embedding_tables to ensure consistency
embedding_tables.load_state_dict({key: fixed_tensor})

# Store the original tensor on CPU for comparison BEFORE creating the model
original_tensor = embedding_tables.state_dict()[key].clone().cpu()

module_sharding_plan = construct_module_sharding_plan(
embedding_tables,
Expand All @@ -79,12 +98,8 @@ def initialize_and_test_parameters(
env=ShardingEnv.from_process_group(ctx.pg),
sharders=sharders,
device=ctx.device,
)

key = (
f"embeddings.{table_name}.weight"
if isinstance(embedding_tables, EmbeddingCollection)
else f"embedding_bags.{table_name}.weight"
init_data_parallel=False,
init_parameters=False,
)

if isinstance(model.state_dict()[key], DTensor):
Expand All @@ -96,28 +111,26 @@ def initialize_and_test_parameters(
gathered_tensor = model.state_dict()[key].full_tensor()
if ctx.rank == 0:
torch.testing.assert_close(
gathered_tensor,
embedding_tables.state_dict()[key],
gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6
)
elif isinstance(model.state_dict()[key], ShardedTensor):
if ctx.rank == 0:
gathered_tensor = torch.empty_like(
embedding_tables.state_dict()[key], device=ctx.device
)
gathered_tensor = torch.empty_like(original_tensor, device=ctx.device)
else:
gathered_tensor = None

model.state_dict()[key].gather(dst=0, out=gathered_tensor)

if ctx.rank == 0:
torch.testing.assert_close(
none_throws(gathered_tensor).to("cpu"),
embedding_tables.state_dict()[key],
none_throws(gathered_tensor).cpu(),
original_tensor,
rtol=1e-5,
atol=1e-6,
)
elif isinstance(model.state_dict()[key], torch.Tensor):
torch.testing.assert_close(
embedding_tables.state_dict()[key].cpu(),
model.state_dict()[key].cpu(),
model.state_dict()[key].cpu(), original_tensor, rtol=1e-5, atol=1e-6
)
else:
raise AssertionError(
Expand Down Expand Up @@ -161,6 +174,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
backend = "nccl"
table_name = "free_parameters"

# Set seed for deterministic tensor generation
torch.manual_seed(42)

# Initialize embedding table on non-meta device, in this case cuda:0
embedding_tables = EmbeddingCollection(
tables=[
Expand All @@ -173,8 +189,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
],
)

# Use a fixed tensor with explicit seeding for consistent testing
fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
embedding_tables.load_state_dict(
{f"embeddings.{table_name}.weight": torch.randn(10, 64)}
{f"embeddings.{table_name}.weight": fixed_tensor}
)

self._run_multi_process_test(
Expand Down Expand Up @@ -210,6 +228,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
backend = "nccl"
table_name = "free_parameters"

# Set seed for deterministic tensor generation
torch.manual_seed(42)

# Initialize embedding bag on non-meta device, in this case cuda:0
embedding_tables = EmbeddingBagCollection(
tables=[
Expand All @@ -222,8 +243,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
],
)

# Use a fixed tensor with explicit seeding for consistent testing
fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
embedding_tables.load_state_dict(
{f"embedding_bags.{table_name}.weight": torch.randn(10, 64)}
{f"embedding_bags.{table_name}.weight": fixed_tensor}
)

self._run_multi_process_test(
Expand Down
Loading