train_torchrun.py
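"""Train an MLP on MNIST on XLA devices, launched with torchrun.

Example launch (assumption: two workers per node; adjust --nproc_per_node
to match your hardware):
    torchrun --nproc_per_node=2 train_torchrun.py
"""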
import os
import time
import torch
from model import MLP
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
# XLA imports
import torch_xla.core.xla_model as xm
# XLA imports for parallel loader and multi-processing
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data.distributed import DistributedSampler
# Initialize XLA process group for torchrun
import torch_xla.distributed.xla_backend
torch.distributed.init_process_group('xla')
# Global constants
EPOCHS = 4
WARMUP_STEPS = 2
BATCH_SIZE = 32
# Load MNIST train dataset
if not xm.is_master_ordinal(): xm.rendezvous('dataset_download')
train_dataset = mnist.MNIST(root='/tmp/MNIST_DATA_train',
                            train=True, download=True, transform=ToTensor())
if xm.is_master_ordinal(): xm.rendezvous('dataset_download')

def main():
    # XLA MP: get world size
    world_size = xm.xrt_world_size()
    # Multi-processing: ensure each worker has the same initial weights
    torch.manual_seed(0)

    # Move model to device and declare optimizer and loss function
    device = 'xla'
    model = MLP().to(device)
    # For multiprocessing, scale up learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * world_size)
    loss_fn = torch.nn.NLLLoss()

    # Prepare data loader
    train_sampler = None
    if world_size > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=world_size,
                                           rank=xm.get_ordinal(),
                                           shuffle=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=train_sampler,
                              shuffle=False if train_sampler else True)
    # XLA MP: use MpDeviceLoader from torch_xla.distributed
    train_device_loader = pl.MpDeviceLoader(train_loader, device)

    # Run the training loop
    print('----------Training ---------------')
    model.train()
    for epoch in range(EPOCHS):
        start = time.time()
        for idx, (train_x, train_label) in enumerate(train_device_loader):
            optimizer.zero_grad()
            train_x = train_x.view(train_x.size(0), -1)
            output = model(train_x)
            loss = loss_fn(output, train_label)
            loss.backward()
            xm.optimizer_step(optimizer)  # XLA MP: performs gradient allreduce and optimizer step
            if idx < WARMUP_STEPS:  # skip warmup iterations
                start = time.time()

    # Compute statistics for the last epoch
    interval = idx - WARMUP_STEPS  # skip warmup iterations
    throughput = interval / (time.time() - start)
    print("Train throughput (iter/sec): {}".format(throughput))
    print("Final loss is {:0.4f}".format(loss.detach().to('cpu')))

    # Save checkpoint for evaluation (xm.save ensures only one process saves)
    os.makedirs("checkpoints", exist_ok=True)
    checkpoint = {'state_dict': model.state_dict()}
    xm.save(checkpoint, 'checkpoints/checkpoint.pt')

    print('----------End Training ---------------')

if __name__ == '__main__':
    main()
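
The script imports MLP from model.py, which is not shown on this page. Below is a minimal sketch of a compatible model (hypothetical; the layer sizes are assumptions, and the repository's actual model.py may differ). It only needs to accept the flattened 784-dim MNIST inputs produced by train_x.view(...) and return log-probabilities, since the training loop uses NLLLoss.

# model.py (reference sketch, hypothetical)
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, output_size=10, hidden=(120, 84)):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden[0])
        self.fc2 = nn.Linear(hidden[0], hidden[1])
        self.fc3 = nn.Linear(hidden[1], output_size)

    def forward(self, x):
        # x: [batch, 784] flattened MNIST images
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Return log-probabilities to pair with NLLLoss in the training script
        return F.log_softmax(self.fc3(x), dim=1)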