4
4
import torch .nn .functional as F
5
5
from torchvision import datasets , transforms
6
6
7
- def train (rank , args , model ):
7
+
8
+ def train (rank , args , model , device , dataloader_kwargs ):
8
9
torch .manual_seed (args .seed + rank )
9
10
10
11
train_loader = torch .utils .data .DataLoader (
@@ -13,32 +14,35 @@ def train(rank, args, model):
13
14
transforms .ToTensor (),
14
15
transforms .Normalize ((0.1307 ,), (0.3081 ,))
15
16
])),
16
- batch_size = args .batch_size , shuffle = True , num_workers = 1 )
17
+ batch_size = args .batch_size , shuffle = True , num_workers = 1 ,
18
+ ** dataloader_kwargs )
17
19
18
20
optimizer = optim .SGD (model .parameters (), lr = args .lr , momentum = args .momentum )
19
21
for epoch in range (1 , args .epochs + 1 ):
20
- train_epoch (epoch , args , model , train_loader , optimizer )
22
+ train_epoch (epoch , args , model , device , train_loader , optimizer )
23
+
21
24
22
def test(args, model, device, dataloader_kwargs):
    """Evaluate *model* on the MNIST test split.

    Args:
        args: namespace providing at least ``seed`` and ``batch_size``.
        model: the network to evaluate (evaluation itself happens in
            ``test_epoch``, which moves each batch to *device*).
        device: torch.device the evaluation batches are moved to.
        dataloader_kwargs: extra keyword arguments forwarded verbatim to
            ``DataLoader`` (e.g. ``pin_memory`` for CUDA runs).
    """
    # Fixed seed (no rank offset, unlike train) so evaluation shuffling
    # is reproducible across runs.
    torch.manual_seed(args.seed)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.batch_size, shuffle=True, num_workers=1,
        **dataloader_kwargs)

    # BUG FIX: original called test_epoch(model, device, test_loader, device),
    # passing `device` twice; test_epoch takes exactly three parameters
    # (model, device, data_loader), so that call raises TypeError.
    test_epoch(model, device, test_loader)
33
37
34
38
35
- def train_epoch (epoch , args , model , data_loader , optimizer ):
39
+ def train_epoch (epoch , args , model , device , data_loader , optimizer ):
36
40
model .train ()
37
41
pid = os .getpid ()
38
42
for batch_idx , (data , target ) in enumerate (data_loader ):
39
43
optimizer .zero_grad ()
40
- output = model (data )
41
- loss = F .nll_loss (output , target )
44
+ output = model (data . to ( device ) )
45
+ loss = F .nll_loss (output , target . to ( device ) )
42
46
loss .backward ()
43
47
optimizer .step ()
44
48
if batch_idx % args .log_interval == 0 :
@@ -47,16 +51,16 @@ def train_epoch(epoch, args, model, data_loader, optimizer):
47
51
100. * batch_idx / len (data_loader ), loss .item ()))
48
52
49
53
50
- def test_epoch (model , data_loader ):
54
+ def test_epoch (model , device , data_loader ):
51
55
model .eval ()
52
56
test_loss = 0
53
57
correct = 0
54
58
with torch .no_grad ():
55
59
for data , target in data_loader :
56
- output = model (data )
57
- test_loss += F .nll_loss (output , target , reduction = 'sum' ).item () # sum up batch loss
60
+ output = model (data . to ( device ) )
61
+ test_loss += F .nll_loss (output , target . to ( device ) , reduction = 'sum' ).item () # sum up batch loss
58
62
pred = output .max (1 )[1 ] # get the index of the max log-probability
59
- correct += pred .eq (target ).sum ().item ()
63
+ correct += pred .eq (target . to ( device ) ).sum ().item ()
60
64
61
65
test_loss /= len (data_loader .dataset )
62
66
print ('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n ' .format (
0 commit comments