From ad89b5029b6ddf5f98cef06f69522c40f07cf6a5 Mon Sep 17 00:00:00 2001
From: Mughees Ahmad
Date: Thu, 24 Oct 2024 14:52:53 -0400
Subject: [PATCH] fixed evaluate to match train function

---
 imagenet/main.py | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/imagenet/main.py b/imagenet/main.py
index cc32d50733..1644ef0577 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -270,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args):
         num_workers=args.workers, pin_memory=True, sampler=val_sampler)

     if args.evaluate:
-        validate(val_loader, model, criterion, args)
+        validate(val_loader, model, criterion, device, args)
         return

     for epoch in range(args.start_epoch, args.epochs):
@@ -281,7 +281,7 @@ def main_worker(gpu, ngpus_per_node, args):
         train(train_loader, model, criterion, optimizer, epoch, device, args)

         # evaluate on validation set
-        acc1 = validate(val_loader, model, criterion, args)
+        acc1 = validate(val_loader, model, criterion, device, args)

         scheduler.step()

@@ -347,21 +347,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
             progress.display(i + 1)


-def validate(val_loader, model, criterion, args):
+def validate(val_loader, model, criterion, device, args):

     def run_validate(loader, base_progress=0):
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
-
+                images = images.to(device, non_blocking=True)
+                target = target.to(device, non_blocking=True)
                 # compute output
                 output = model(images)
                 loss = criterion(output, target)
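
For context: the patch assumes main_worker already constructs a single `device`
object (the same one it passes to train()), so validate() can move tensors with
one call per tensor instead of branching on CUDA/MPS availability. Below is a
minimal, self-contained sketch of that pattern; the helper name pick_device and
the dummy tensors are illustrative only and do not appear in main.py.

    import torch

    def pick_device(gpu=None):
        """Prefer CUDA (optionally a specific index), then Apple MPS, else CPU."""
        if torch.cuda.is_available():
            return torch.device(f"cuda:{gpu}") if gpu is not None else torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    device = pick_device()

    # A single .to(device) call per tensor replaces the per-backend branches.
    # non_blocking=True is only asynchronous when the host tensors come from a
    # DataLoader with pin_memory=True (as val_loader in main.py does).
    images = torch.randn(2, 3, 224, 224).to(device, non_blocking=True)
    target = torch.randint(0, 1000, (2,)).to(device, non_blocking=True)

With this sketch in mind, the diff simply gives validate() the same signature
shape as train(), which already receives device from main_worker.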