diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py
index 2731e8046b..0b9c884507 100755
--- a/distributed/tensor_parallelism/tensor_parallel_example.py
+++ b/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -91,9 +91,6 @@ def forward(self, x):
 # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
 tp_model = ToyModel().to("cuda")
 
-# Create an optimizer for the parallelized module.
-lr = 0.25
-optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
 
 # Custom parallelization plan for the model
 tp_model = parallelize_module(
@@ -104,6 +101,12 @@ def forward(self, x):
         "out_proj": RowwiseParallel(),
     },
 )
+
+# Create an optimizer for the parallelized module.
+lr = 0.25
+optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
+
+
 # Perform a num of iterations of forward/backward
 # and optimizations for the sharded module.
 num_iters = 10
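
Note on ordering (not part of the diff itself): the optimizer has to be constructed after parallelize_module, because parallelization replaces the module's parameters with sharded DTensor parameters; an optimizer built earlier would keep holding the original, now-stale tensors. Below is a minimal sketch of the corrected ordering under that assumption; the ToyModel layer sizes and the device-mesh/torchrun setup are illustrative and not taken from this diff.

    # Sketch only: assumes a launch such as `torchrun --nproc-per-node=4 tensor_parallel_example.py`.
    import os

    import torch
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyModel(nn.Module):
        # Layer sizes are assumed for illustration; only the names in_proj/out_proj
        # matter here, since they are the keys of the parallelization plan.
        def __init__(self):
            super().__init__()
            self.in_proj = nn.Linear(10, 32)
            self.relu = nn.ReLU()
            self.out_proj = nn.Linear(32, 5)

        def forward(self, x):
            return self.out_proj(self.relu(self.in_proj(x)))


    # One-dimensional device mesh over all ranks started by torchrun.
    world_size = int(os.environ["WORLD_SIZE"])
    device_mesh = init_device_mesh("cuda", (world_size,))

    tp_model = ToyModel().to("cuda")

    # Parallelize first: this swaps in_proj/out_proj parameters for DTensors.
    tp_model = parallelize_module(
        module=tp_model,
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(),
            "out_proj": RowwiseParallel(),
        },
    )

    # Only now build the optimizer, so it tracks the sharded DTensor parameters.
    lr = 0.25
    optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)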