From 41cdb733b3c8b3a0da4fd43ac4b9a4b3f4fbe62e Mon Sep 17 00:00:00 2001
From: nrothGIT
Date: Sun, 6 Apr 2025 17:59:37 -0400
Subject: [PATCH] Update tensor_parallel_example.py

---
 .../tensor_parallelism/tensor_parallel_example.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py
index 2731e8046b..0b9c884507 100755
--- a/distributed/tensor_parallelism/tensor_parallel_example.py
+++ b/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -91,9 +91,6 @@ def forward(self, x):
 # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
 tp_model = ToyModel().to("cuda")
 
-# Create an optimizer for the parallelized module.
-lr = 0.25
-optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
 
 # Custom parallelization plan for the model
 tp_model = parallelize_module(
@@ -104,6 +101,12 @@ def forward(self, x):
         "out_proj": RowwiseParallel(),
     },
 )
+
+# Create an optimizer for the parallelized module.
+lr = 0.25
+optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
+
+
 # Perform a num of iterations of forward/backward
 # and optimizations for the sharded module.
 num_iters = 10
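
For reference, a minimal sketch of the ordering this patch establishes: parallelize the module first, then construct the optimizer, so the optimizer captures the sharded DTensor parameters rather than references to the pre-parallelization dense tensors. This sketch assumes the example's ToyModel and Colwise/Rowwise plan (layer sizes here are illustrative), a 1-D CUDA device mesh, and a launch under torchrun with one rank per GPU; it is not part of the patch itself.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    """Small MLP with an input and an output projection (sizes are illustrative)."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


# One mesh dimension spanning all ranks; torchrun sets WORLD_SIZE and
# init_device_mesh maps each rank to its local GPU.
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh("cuda", (world_size,))

tp_model = ToyModel().to("cuda")

# Parallelize first: this replaces the module's parameters with DTensor shards.
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)

# Only now create the optimizer, so it holds the sharded DTensor parameters.
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=0.25, foreach=True)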