-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict1.py
72 lines (62 loc) · 2 KB
/
predict1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.optim as optim
import sparseconvnet as scn
import uproot
import matplotlib.pyplot as plt
import numpy as np
from model import DeepVtx
from timeit import default_timer as timer
import csv
import util
# Device selection. CUDA is switched off by hand here, so inference always
# runs on the CPU; the auto-detection expression is kept for reference:
# use_cuda = torch.cuda.is_available() and scn.SCN.is_cuda_build()
use_cuda = False
# Restrict torch to a single intra-op thread.
torch.set_num_threads(1)
if use_cuda:
    device = 'cuda:0'
    print("Using CUDA.")
else:
    device = 'cpu'
    print("Using CPU.")
# Build the network and restore its trained weights from a checkpoint.
nIn = 1  # forwarded to DeepVtx as its nIn argument
model = DeepVtx(dimension=3, nIn=nIn, device=device)
# NOTE(review): the model is left in training mode although this script only
# predicts. If DeepVtx contains batch-norm or dropout layers, model.eval() is
# normally what inference wants -- confirm whether train() is intentional
# before changing it, since it alters the predictions.
model.train()
model_path = 't48k/m16-l5-lr5d-res0.5/CP24.pth'
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written from a GPU run still loads on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
# Run the model over the samples named in a validation list file, visualise
# each prediction, collect the per-sample results, and report the wall time.

# Samples are counted from 1; rows numbered below start_sample are skipped and
# the loop stops once the counter exceeds max_sample.
start_sample = 0
max_sample = 1000 + start_sample
resolution = 0.5  # forwarded to util.load
loose_cut = 1.0   # not referenced anywhere in this script
# val_list = 'list/numucc-24k-val.csv'
val_list = 'list/nuecc-21k-val.csv'

results = []
start = timer()
# newline='' is what the csv module asks for when it is handed a file object.
# torch.no_grad() stops autograd from recording a graph for every forward
# pass; nothing here calls backward(), so the numbers are unchanged.
with open(val_list, newline='') as f, torch.no_grad():
    reader = csv.reader(f, delimiter=' ')
    for isample, row in enumerate(reader, start=1):
        if isample < start_sample:
            continue
        if isample > max_sample:
            break
        if not row:
            # A blank line in the list file yields an empty row; skip it
            # instead of failing on row[0] below.
            continue
        print('isample: {} : {}'.format(isample,row[0]))

        coords_np, ft_np = util.load(row, vis=False, resolution=resolution)
        coords = torch.LongTensor(coords_np)
        # The last column of ft_np is used as the truth label; the remaining
        # columns are treated as features.
        truth = torch.FloatTensor(ft_np[:,-1]).to(device)
        ft = torch.FloatTensor(ft_np[:,0:-1]).to(device)

        # Only the first feature column is fed to the network.
        prediction = model([coords,ft[:,0:1]])
        pred_np = prediction.cpu().detach().numpy()
        # Collapse the two output columns into one score per point
        # (column 1 minus column 0) -- presumably signal vs. background
        # scores; confirm against DeepVtx.
        pred_np = pred_np[:,1] - pred_np[:,0]
        truth_np = truth.cpu().detach().numpy()

        # prediction and vis
        result = util.vis_prediction_regseg(
            np.column_stack((coords_np, pred_np)),
            np.column_stack((coords_np, truth_np)),
            cand=np.column_stack((coords_np, ft_np[:,1])),
            vis=True
        )
        results.append(result)

end = timer()
# Total elapsed time for the whole loop, in milliseconds.
print('time: {0:.1f} ms'.format((end - start) * 1000))