Skip to content

Commit 0c1fcc8

Browse files
vipinpillaisoumith
authored andcommitted
Update all prec occurrences to acc (pytorch#420)
1 parent 95d5fdd commit 0c1fcc8

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

imagenet/main.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@
6262
parser.add_argument('--gpu', default=None, type=int,
6363
help='GPU id to use.')
6464

65-
best_prec1 = 0
65+
best_acc1 = 0
6666

6767

6868
def main():
69-
global args, best_prec1
69+
global args, best_acc1
7070
args = parser.parse_args()
7171

7272
if args.seed is not None:
@@ -122,7 +122,7 @@ def main():
122122
print("=> loading checkpoint '{}'".format(args.resume))
123123
checkpoint = torch.load(args.resume)
124124
args.start_epoch = checkpoint['epoch']
125-
best_prec1 = checkpoint['best_prec1']
125+
best_acc1 = checkpoint['best_acc1']
126126
model.load_state_dict(checkpoint['state_dict'])
127127
optimizer.load_state_dict(checkpoint['optimizer'])
128128
print("=> loaded checkpoint '{}' (epoch {})"
@@ -179,16 +179,16 @@ def main():
179179
train(train_loader, model, criterion, optimizer, epoch)
180180

181181
# evaluate on validation set
182-
prec1 = validate(val_loader, model, criterion)
182+
acc1 = validate(val_loader, model, criterion)
183183

184-
# remember best prec@1 and save checkpoint
185-
is_best = prec1 > best_prec1
186-
best_prec1 = max(prec1, best_prec1)
184+
# remember best acc@1 and save checkpoint
185+
is_best = acc1 > best_acc1
186+
best_acc1 = max(acc1, best_acc1)
187187
save_checkpoint({
188188
'epoch': epoch + 1,
189189
'arch': args.arch,
190190
'state_dict': model.state_dict(),
191-
'best_prec1': best_prec1,
191+
'best_acc1': best_acc1,
192192
'optimizer' : optimizer.state_dict(),
193193
}, is_best)
194194

@@ -217,10 +217,10 @@ def train(train_loader, model, criterion, optimizer, epoch):
217217
loss = criterion(output, target)
218218

219219
# measure accuracy and record loss
220-
prec1, prec5 = accuracy(output, target, topk=(1, 5))
220+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
221221
losses.update(loss.item(), input.size(0))
222-
top1.update(prec1[0], input.size(0))
223-
top5.update(prec5[0], input.size(0))
222+
top1.update(acc1[0], input.size(0))
223+
top5.update(acc5[0], input.size(0))
224224

225225
# compute gradient and do SGD step
226226
optimizer.zero_grad()
@@ -236,8 +236,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
236236
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
237237
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
238238
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
239-
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
240-
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
239+
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
240+
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
241241
epoch, i, len(train_loader), batch_time=batch_time,
242242
data_time=data_time, loss=losses, top1=top1, top5=top5))
243243

@@ -263,10 +263,10 @@ def validate(val_loader, model, criterion):
263263
loss = criterion(output, target)
264264

265265
# measure accuracy and record loss
266-
prec1, prec5 = accuracy(output, target, topk=(1, 5))
266+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
267267
losses.update(loss.item(), input.size(0))
268-
top1.update(prec1[0], input.size(0))
269-
top5.update(prec5[0], input.size(0))
268+
top1.update(acc1[0], input.size(0))
269+
top5.update(acc5[0], input.size(0))
270270

271271
# measure elapsed time
272272
batch_time.update(time.time() - end)
@@ -276,12 +276,12 @@ def validate(val_loader, model, criterion):
276276
print('Test: [{0}/{1}]\t'
277277
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
278278
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
279-
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
280-
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
279+
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
280+
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
281281
i, len(val_loader), batch_time=batch_time, loss=losses,
282282
top1=top1, top5=top5))
283283

284-
print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
284+
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
285285
.format(top1=top1, top5=top5))
286286

287287
return top1.avg

0 commit comments

Comments
 (0)