Commit ea825a5

Add Cuda support to mnist_hogwild (pytorch#508)

VitalyFedyunin authored and soumith committed Feb 15, 2019

1 parent fe8abc3 commit ea825a5
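As the diffs below show, GPU use becomes opt-in: passing the new `--cuda` flag (e.g. `python main.py --cuda`) selects the `cuda` device when one is available, and the example falls back to CPU otherwise.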

File tree

2 files changed (+27, -18 lines)


mnist_hogwild/main.py (+10, -5)
```diff
@@ -25,6 +25,8 @@
                     help='how many batches to wait before logging training status')
 parser.add_argument('--num-processes', type=int, default=2, metavar='N',
                     help='how many training processes to use (default: 2)')
+parser.add_argument('--cuda', action='store_true', default=False,
+                    help='enables CUDA training')
 
 class Net(nn.Module):
     def __init__(self):
@@ -47,21 +49,24 @@ def forward(self, x):
 if __name__ == '__main__':
     args = parser.parse_args()
 
+    use_cuda = args.cuda and torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+    dataloader_kwargs = {'pin_memory': True} if use_cuda else {}
+
     torch.manual_seed(args.seed)
+    mp.set_start_method('spawn')
 
-    model = Net()
+    model = Net().to(device)
     model.share_memory() # gradients are allocated lazily, so they are not shared here
 
     processes = []
     for rank in range(args.num_processes):
-        p = mp.Process(target=train, args=(rank, args, model))
+        p = mp.Process(target=train, args=(rank, args, model, device, dataloader_kwargs))
         # We first train the model across `num_processes` processes
         p.start()
         processes.append(p)
     for p in processes:
         p.join()
 
     # Once training is complete, we can test the model
-    test(args, model)
-
-
+    test(args, model, device, dataloader_kwargs)
```
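The subtle part of the main.py change is the interplay of CUDA and multiprocessing: the CUDA runtime does not survive a `fork()`, which is why the diff calls `mp.set_start_method('spawn')` before launching workers. Below is a minimal standalone sketch of that launch pattern, not part of the commit; the toy `worker` function and `Linear` model are illustrative assumptions standing in for the example's `train` and `Net`.

```python
import torch
import torch.multiprocessing as mp


def worker(rank, model, device):
    # Each worker receives the same shared parameters; as in train.py,
    # inputs are moved to `device` inside the worker process.
    x = torch.randn(1, 10).to(device)
    print(rank, model(x).shape)


if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # 'spawn' (rather than the default 'fork' on Linux) is required once
    # child processes touch CUDA state.
    mp.set_start_method('spawn')

    model = torch.nn.Linear(10, 2).to(device)
    model.share_memory()  # Hogwild: every worker updates one parameter set

    processes = [mp.Process(target=worker, args=(rank, model, device))
                 for rank in range(2)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```

The `pin_memory` kwarg follows the same opt-in logic: pinned host memory only speeds up host-to-GPU copies, so it is passed to the DataLoader only when CUDA is actually in use.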

mnist_hogwild/train.py (+17, -13)
```diff
@@ -4,7 +4,8 @@
 import torch.nn.functional as F
 from torchvision import datasets, transforms
 
-def train(rank, args, model):
+
+def train(rank, args, model, device, dataloader_kwargs):
     torch.manual_seed(args.seed + rank)
 
     train_loader = torch.utils.data.DataLoader(
@@ -13,32 +14,35 @@ def train(rank, args, model):
             transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))
         ])),
-        batch_size=args.batch_size, shuffle=True, num_workers=1)
+        batch_size=args.batch_size, shuffle=True, num_workers=1,
+        **dataloader_kwargs)
 
     optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
     for epoch in range(1, args.epochs + 1):
-        train_epoch(epoch, args, model, train_loader, optimizer)
+        train_epoch(epoch, args, model, device, train_loader, optimizer)
+
 
-def test(args, model):
+def test(args, model, device, dataloader_kwargs):
     torch.manual_seed(args.seed)
 
     test_loader = torch.utils.data.DataLoader(
         datasets.MNIST('../data', train=False, transform=transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))
         ])),
-        batch_size=args.batch_size, shuffle=True, num_workers=1)
+        batch_size=args.batch_size, shuffle=True, num_workers=1,
+        **dataloader_kwargs)
 
-    test_epoch(model, test_loader)
+    test_epoch(model, device, test_loader)
 
 
-def train_epoch(epoch, args, model, data_loader, optimizer):
+def train_epoch(epoch, args, model, device, data_loader, optimizer):
     model.train()
     pid = os.getpid()
     for batch_idx, (data, target) in enumerate(data_loader):
         optimizer.zero_grad()
-        output = model(data)
-        loss = F.nll_loss(output, target)
+        output = model(data.to(device))
+        loss = F.nll_loss(output, target.to(device))
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
@@ -47,16 +51,16 @@ def train_epoch(epoch, args, model, data_loader, optimizer):
             100. * batch_idx / len(data_loader), loss.item()))
 
 
-def test_epoch(model, data_loader):
+def test_epoch(model, device, data_loader):
     model.eval()
     test_loss = 0
     correct = 0
     with torch.no_grad():
         for data, target in data_loader:
-            output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
+            output = model(data.to(device))
+            test_loss += F.nll_loss(output, target.to(device), reduction='sum').item() # sum up batch loss
             pred = output.max(1)[1] # get the index of the max log-probability
-            correct += pred.eq(target).sum().item()
+            correct += pred.eq(target.to(device)).sum().item()
 
     test_loss /= len(data_loader.dataset)
     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
```
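The train.py side of the change follows a single pattern: the model already lives on `device`, so each incoming batch of inputs and targets is moved there before the forward pass and the loss. A condensed sketch of that per-batch movement, using a hypothetical `train_step` helper that is not part of the example:

```python
import torch.nn.functional as F


def train_step(model, optimizer, data, target, device):
    # Mirrors the body of train_epoch above: the model already lives on
    # `device`, so only the incoming batch needs to be moved there.
    optimizer.zero_grad()
    output = model(data.to(device))
    loss = F.nll_loss(output, target.to(device))
    loss.backward()
    optimizer.step()
    return loss.item()
```

Note that `test_epoch` also moves `target` to the device before `pred.eq(target)`, since comparing a GPU prediction tensor against a CPU target tensor would raise an error.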
