-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathchamfer_iou_clevr.py
More file actions
70 lines (56 loc) · 1.85 KB
/
chamfer_iou_clevr.py
File metadata and controls
70 lines (56 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from run_reconstruct_clevr import SSLR
import os
import data
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pickle
from utils import chamfer_score, cv_bbox
import torch.multiprocessing as mp
import gc
import cv2
import argparse
def get_args(argv=None):
    """Parse command-line options for the CLEVR chamfer evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, so existing
            command-line usage is unchanged.

    Returns:
        argparse.Namespace with attributes ``model_type`` ("srn" or "mlp"),
        ``batch_size`` (int) and ``resume`` (checkpoint path).
    """
    parser = argparse.ArgumentParser()
    # Only the two documented model types are accepted, so a typo fails loudly
    # instead of silently selecting the non-SRN path.
    parser.add_argument('--model_type', help='model type: srn | mlp',
                        default="srn", choices=["srn", "mlp"])
    parser.add_argument('--batch_size', type=int, help='batch size', default=32)
    # The evaluation always loads a checkpoint, so require it at parse time
    # rather than crashing later inside torch.load(None).
    parser.add_argument('--resume', help='path to resume a saved checkpoint',
                        required=True)
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    # Evaluate a trained SSLR checkpoint on the CLEVR validation split by
    # thresholding the network's per-slot outputs into a binary mask, turning
    # it into boxes with cv_bbox, and scoring against targets via chamfer_score.
    args = get_args()
    print(args)
    if args.resume is None:
        # torch.load(None) would fail with an unhelpful error further down.
        raise SystemExit("--resume is required: pass the path to a trained checkpoint")
    use_srn = args.model_type == "srn"

    dataset_test = data.CLEVR(
        "clevr_no_mask", "val", box=True, full=True, chamfer=True
    )
    batch_size = args.batch_size
    test_loader = data.get_loader(
        dataset_test, batch_size=batch_size, shuffle=False
    )

    net = SSLR(use_srn=use_srn).float().cuda()
    net.eval()
    net.load_state_dict(torch.load(args.resume))

    test_loader = tqdm(
        test_loader,
        ncols=0,
        desc="test"
    )

    def score_batch(model, sample):
        """Run one batch through ``model`` and return its chamfer score."""
        # Presumably forces release of the previous batch's tensors to keep
        # GPU memory in check -- kept from the original.
        gc.collect()
        image, masks = [x.cuda() for x in sample]
        _, _, gs = model(image)
        # Binarize: values below 1e-2 -> 0, everything else -> 1.
        thresh_mask = gs < 1e-2
        gs[thresh_mask] = 0
        gs[~thresh_mask] = 1
        # Collapse dimension 2 and clamp so overlapping entries stay binary.
        gs = gs.sum(2).clamp(0, 1)
        gs = gs.to(dtype=torch.uint8)
        # NOTE(review): assumes the collapsed masks are 128x128 images --
        # confirm against the dataset resolution.
        img = cv_bbox(gs.detach().cpu().numpy().reshape(-1, 128, 128))
        return chamfer_score(
            img.cuda().to(dtype=torch.uint8),
            masks.to(dtype=torch.uint8),
        )

    num_batches = len(test_loader)
    if num_batches == 0:
        raise SystemExit("validation loader is empty; nothing to evaluate")

    full_score = 0
    # Pure inference: disable autograd so no computation graph is built and
    # retained for each forward pass (net.eval() alone does not do this).
    with torch.no_grad():
        for sample in test_loader:
            full_score += score_batch(net, sample)
    # NOTE(review): this averages per-batch scores. If chamfer_score returns a
    # per-batch mean, a smaller final batch is over-weighted -- confirm.
    full_score /= num_batches
    print(full_score)