diff --git a/torchrec/distributed/tests/test_init_parameters.py b/torchrec/distributed/tests/test_init_parameters.py
index 8c80db7a5..f9427f74e 100644
--- a/torchrec/distributed/tests/test_init_parameters.py
+++ b/torchrec/distributed/tests/test_init_parameters.py
@@ -60,6 +60,25 @@ def initialize_and_test_parameters(
     local_size: Optional[int] = None,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        # Set seed again in each process to ensure consistency
+        torch.manual_seed(42)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(42)
+
+        key = (
+            f"embeddings.{table_name}.weight"
+            if isinstance(embedding_tables, EmbeddingCollection)
+            else f"embedding_bags.{table_name}.weight"
+        )
+
+        # Create the same fixed tensor in each process
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
+
+        # Load the fixed tensor into the embedding_tables to ensure consistency
+        embedding_tables.load_state_dict({key: fixed_tensor})
+
+        # Store the original tensor on CPU for comparison BEFORE creating the model
+        original_tensor = embedding_tables.state_dict()[key].clone().cpu()
 
         module_sharding_plan = construct_module_sharding_plan(
             embedding_tables,
@@ -79,12 +98,8 @@ def initialize_and_test_parameters(
             env=ShardingEnv.from_process_group(ctx.pg),
             sharders=sharders,
             device=ctx.device,
-        )
-
-        key = (
-            f"embeddings.{table_name}.weight"
-            if isinstance(embedding_tables, EmbeddingCollection)
-            else f"embedding_bags.{table_name}.weight"
+            init_data_parallel=False,
+            init_parameters=False,
         )
 
         if isinstance(model.state_dict()[key], DTensor):
@@ -96,14 +111,11 @@ def initialize_and_test_parameters(
             gathered_tensor = model.state_dict()[key].full_tensor()
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    gathered_tensor,
-                    embedding_tables.state_dict()[key],
+                    gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6
                 )
         elif isinstance(model.state_dict()[key], ShardedTensor):
             if ctx.rank == 0:
-                gathered_tensor = torch.empty_like(
-                    embedding_tables.state_dict()[key], device=ctx.device
-                )
+                gathered_tensor = torch.empty_like(original_tensor, device=ctx.device)
             else:
                 gathered_tensor = None
 
@@ -111,13 +123,14 @@ def initialize_and_test_parameters(
 
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    none_throws(gathered_tensor).to("cpu"),
-                    embedding_tables.state_dict()[key],
+                    none_throws(gathered_tensor).cpu(),
+                    original_tensor,
+                    rtol=1e-5,
+                    atol=1e-6,
                 )
         elif isinstance(model.state_dict()[key], torch.Tensor):
             torch.testing.assert_close(
-                embedding_tables.state_dict()[key].cpu(),
-                model.state_dict()[key].cpu(),
+                model.state_dict()[key].cpu(), original_tensor, rtol=1e-5, atol=1e-6
             )
         else:
             raise AssertionError(
@@ -161,6 +174,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding table on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingCollection(
             tables=[
@@ -173,8 +189,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embeddings.{table_name}.weight": torch.randn(10, 64)}
+            {f"embeddings.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
@@ -210,6 +228,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding bag on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingBagCollection(
             tables=[
@@ -222,8 +243,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embedding_bags.{table_name}.weight": torch.randn(10, 64)}
+            {f"embedding_bags.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(