test.py
import argparse
import datetime
import os
import random

import cv2
import numpy as np
import torch
from torch import autograd
from torch import nn, optim
from tensorboardX import SummaryWriter
from ipdb import set_trace

from gmair.models.model import gmair
from gmair.config import config as cfg
from gmair.dataset.fruit2d import FruitDataset
from gmair.dataset.multi_mnist import SimpleScatteredMNISTDataset
from gmair.utils import debug_tools
from gmair.test import metric, cluster_metric
parser = argparse.ArgumentParser()
parser.add_argument('--test_num_iters', type=int, default=1000, help='maximum number of test batches to evaluate')
args = parser.parse_args()

# Fix the random seeds so test runs are reproducible.
random.seed(3)
torch.manual_seed(3)

# Each test run logs into its own timestamped TensorBoard directory.
log_dir = os.path.join(
    cfg.log_dir,
    'test',
    datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))
log_dir = os.path.abspath(log_dir)
writer = SummaryWriter(log_dir)
print('log path : {}'.format(log_dir))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
def test():
    image_shape = cfg.input_image_shape

    # Select the test dataset according to the configuration.
    if cfg.dataset == 'multi_mnist':
        data = SimpleScatteredMNISTDataset(
            os.path.join(cfg.dataset_dir, 'scattered_mnist_128x128_obj14x14.hdf5'), mode='test')
    elif cfg.dataset == 'fruit2d':
        data = FruitDataset(
            os.path.join(cfg.dataset_dir, 'train_images'),
            os.path.join(cfg.dataset_dir, 'train_labels'))
    else:
        print('No implementation for {}'.format(cfg.dataset))
        exit(0)

    # Build the model and load the checkpoint to be evaluated.
    gmair_net = gmair(image_shape, writer, device).to(device)
    gmair_net.load_state_dict(torch.load(cfg.test_model_path))
    # os.path.join(cfg.test_model_path, 'checkpoints', 'step_{}.pkl'.format(cfg.test_iter))

    torch.manual_seed(10)

    dataloader = torch.utils.data.DataLoader(
        data,
        batch_size=cfg.test_batch_size,
        pin_memory=True,
        num_workers=cfg.num_workers,
        drop_last=False,
        shuffle=True)

    tot_iter = min(len(dataloader), args.test_num_iters)
    # print(len(dataloader))

    # Accumulate predictions and ground truth over all test batches.
    all_z_where = None
    all_z_cls = None
    all_obj_prob = None
    all_y_bbox = None
    all_y_obj_count = None

    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= tot_iter:
            break
        if (batch_idx + 1) % 100 == 0:
            print('test {} / {}'.format(batch_idx + 1, tot_iter))

        x_image, y_bbox, y_obj_count = batch
        x_image = x_image.to(device)
        y_bbox = y_bbox.to(device)
        y_obj_count = y_obj_count.to(device)

        # Run inference without tracking gradients.
        with torch.no_grad():
            gmair_net.eval()
            out_img, z_cls, z_what, z_where, obj_prob = gmair_net(x_image, mode='infer')

        all_z_where = z_where if all_z_where is None else torch.cat([all_z_where, z_where], dim=0)
        all_z_cls = z_cls if all_z_cls is None else torch.cat([all_z_cls, z_cls], dim=0)
        all_obj_prob = obj_prob if all_obj_prob is None else torch.cat([all_obj_prob, obj_prob], dim=0)
        all_y_bbox = y_bbox if all_y_bbox is None else torch.cat([all_y_bbox, y_bbox], dim=0)
        all_y_obj_count = y_obj_count if all_y_obj_count is None else torch.cat([all_y_obj_count, y_obj_count], dim=0)

        # debug_tools.plot_infer_render_components(x_image, y_bbox, obj_prob, z_cls, z_where, out_img, writer, batch_idx)

    # Detection quality: mean average precision of the predicted bounding boxes.
    meanAP = metric.mAP(all_z_where, all_obj_prob, all_y_bbox[:, :, :4], all_y_obj_count)
    print('Bbox Average Precision : ', meanAP)

    # Clustering quality of the inferred object classes.
    acc, nmi = cluster_metric.test_cluster(all_z_where, all_z_cls, all_obj_prob, all_y_bbox)
    print('Cluster Accuracy : ', acc)
    print('Cluster NMI : ', nmi)
if __name__ == '__main__':
    test()
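
# Example invocation (a sketch; the concrete paths and values are assumptions --
# the dataset, batch size, and checkpoint location all come from gmair.config,
# only the iteration cap is exposed on the command line):
#
#   python test.py --test_num_iters 500
#
# cfg.test_model_path is expected to point at a checkpoint saved during training,
# e.g. something like <log_dir>/checkpoints/step_<iter>.pkl, as suggested by the
# commented-out load path above.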