diff --git a/test.py b/test.py old mode 100755 new mode 100644 index 2729657..8a5adf6 --- a/test.py +++ b/test.py @@ -19,13 +19,16 @@ def detect(net,img): img = img.transpose(2, 0, 1) img = img.reshape((1,)+img.shape) - img = Variable(torch.from_numpy(img).float(),volatile=True).cuda() + img = Variable(torch.from_numpy(img).float()) + if use_cuda: + img = img.cuda() BB,CC,HH,WW = img.size() - olist = net(img) + with torch.no_grad(): + olist = net(img) bboxlist = [] - for i in range(len(olist)/2): olist[i*2] = F.softmax(olist[i*2]) - for i in range(len(olist)/2): + for i in range(int(len(olist)/2)): olist[i*2] = F.softmax(olist[i*2]) + for i in range(int(len(olist)/2)): ocls,oreg = olist[i*2].data.cpu(),olist[i*2+1].data.cpu() FB,FC,FH,FW = ocls.size() # feature map size stride = 2**(i+2) # 4,8,16,32,64,128 @@ -58,7 +61,8 @@ def detect(net,img): net = getattr(net_s3fd,args.net)() if args.model!='' :net.load_state_dict(torch.load(args.model)) else: print('Please set --model parameter!') -net.cuda() +if use_cuda: + net.cuda() net.eval() @@ -82,4 +86,4 @@ def detect(net,img): if cv2.waitKey(1) & 0xFF == ord('q'): break else: cv2.imwrite(args.path[:-4]+'_output.png',imgshow) - if cv2.waitKey(0) or True: break \ No newline at end of file + if cv2.waitKey(0) or True: break