forked from AmrElsersy/Emotions-Recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
84 lines (70 loc) · 3.04 KB
/
test.py
File metadata and controls
84 lines (70 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Author: Amr Elsersy
email: amrelsersay@gmail.com
-----------------------------------------------------------------------------------
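Description: Visual test of a trained Mini_Xception emotion classifier, showing each FER2013 sample with its predicted and ground-truth emotion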
"""
import numpy as np
import argparse
import logging
import time
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim
import torch.utils.tensorboard as tensorboard
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.enabled = True
import cv2
import torchvision.transforms.transforms as transforms
from model.model import Mini_Xception
from dataset import FER2013
from utils import get_label_emotion
def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, exactly as before.

    Returns:
        argparse.Namespace with ``epochs``, ``datapath``, ``resume``,
        ``mode`` and ``pretrained`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Visualise Mini_Xception predictions on a FER2013 split')
    # --epochs and --resume are not read anywhere in this script; they are
    # kept so that existing command lines continue to parse.
    parser.add_argument('--epochs', type=int, default=300,
                        help='num of training epochs (unused when testing)')
    parser.add_argument('--datapath', type=str, default='data',
                        help='root path of the FER2013 dataset')
    parser.add_argument('--resume', action='store_true',
                        help='resume from the checkpoint given by --pretrained (unused when testing)')
    parser.add_argument('--mode', type=str, choices=['train', 'test', 'val'], default='test',
                        help='dataset split to visualise')
    parser.add_argument('--pretrained', type=str,
                        default='checkpoint/model_weights/weights_epoch_75.pth.tar',
                        help="checkpoint file whose 'mini_xception' entry holds the model state dict")
    return parser.parse_args(argv)
# ======================================================================
# Module-level runtime configuration read by main().
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Command-line options are parsed once, when this module is executed.
args = parse_args()
def _load_model():
    """Build Mini_Xception on ``device``, load ``args.pretrained`` into it and return it in eval mode."""
    model = Mini_Xception()
    model.to(device)
    model.eval()
    checkpoint = torch.load(args.pretrained, map_location=device)
    # strict=False is kept for compatibility with existing checkpoints, but a
    # key mismatch is reported rather than silently leaving layers untrained.
    incompatible = model.load_state_dict(checkpoint['mini_xception'], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f'\tWarning: missing keys {incompatible.missing_keys}, '
              f'unexpected keys {incompatible.unexpected_keys}')
    print(f'\tLoaded checkpoint from {args.pretrained}\n')
    return model


def _render(face_img, predicted, target):
    """Return a 200x200 copy of ``face_img`` with the predicted label (top) and true label (bottom) drawn on it."""
    canvas = cv2.resize(face_img, (200, 200))
    cv2.putText(canvas, get_label_emotion(predicted), (0, 20),
                cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 255))
    cv2.putText(canvas, get_label_emotion(target), (0, 190),
                cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 0))
    return canvas


def main():
    """Step through the selected dataset split, showing each face with its predicted and true emotion.

    Any key advances to the next sample; ESC (key code 27) quits early.
    """
    mini_xception = _load_model()
    dataset = FER2013(args.datapath, args.mode, transform=transforms.ToTensor())
    print(f'dataset size = {len(dataset)}')
    try:
        with torch.no_grad():
            for i in range(len(dataset)):
                face, label = dataset[i]
                # CPU numpy copy for display; squeeze drops size-1 dims
                # (presumably the single grayscale channel - confirm against FER2013).
                display = face.squeeze().numpy()
                logits = mini_xception(face.unsqueeze(0).to(device))
                predicted = int(torch.argmax(logits, dim=1).item())
                # int() accepts a one-element tensor, a numpy integer or a plain int label.
                cv2.imshow('face', _render(display, predicted, int(label)))
                if cv2.waitKey(0) == 27:
                    break
    finally:
        # Close the window on ESC, when the split is exhausted, or on error.
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()