diff --git a/ML/Pytorch/Basics/pytorch_lr_ratescheduler.py b/ML/Pytorch/Basics/pytorch_lr_ratescheduler.py index 8a342070..b5c8b468 100644 --- a/ML/Pytorch/Basics/pytorch_lr_ratescheduler.py +++ b/ML/Pytorch/Basics/pytorch_lr_ratescheduler.py @@ -91,7 +91,7 @@ def check_accuracy(loader, model): for x, y in loader: x = x.to(device=device) y = y.to(device=device) - + x=x.view(x.size(0),-1) scores = model(x) _, predictions = scores.max(1) num_correct += (predictions == y).sum()