forked from wwwanghao/caffe2pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_caffe_loader2.py
43 lines (37 loc) · 1.26 KB
/
test_caffe_loader2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# 2017.12.16 by xiaohang
import sys
from caffenet import *
import numpy as np
import argparse
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import time
def create_network(protofile, weightfile, cuda=None):
    """Build a CaffeNet from a Caffe network definition and load its weights.

    Args:
        protofile: path to the Caffe network definition; passed straight to
            ``CaffeNet``.
        weightfile: path to the trained weights; passed to
            ``net.load_weights``.
        cuda: when true, move the network to the GPU with ``net.cuda()``.
            When ``None`` (the default, kept for backward compatibility) the
            flag is taken from a module-level ``args.cuda``, which is set when
            this file is run as a script. If no such ``args`` exists (e.g. the
            function was imported from another module), it falls back to
            ``False`` instead of raising ``NameError``.

    Returns:
        The constructed network, with weights loaded and ``train()`` called.
    """
    if cuda is None:
        # The original implementation read the global ``args`` unconditionally.
        # That global is only created under ``if __name__ == '__main__'``, so
        # any import-and-call use of this function crashed with NameError.
        cli_args = globals().get('args')
        cuda = getattr(cli_args, 'cuda', False)
    net = CaffeNet(protofile)
    if cuda:
        net.cuda()
    print(net)
    net.load_weights(weightfile)
    # NOTE(review): train() leaves the network in training mode. If CaffeNet
    # is a torch nn.Module, this affects layers such as dropout/batch-norm;
    # confirm this is intended rather than eval().
    net.train()
    return net
if __name__ == '__main__':
    # Command-line smoke test: build the network, then run a few forward
    # passes and print the shape and mean of every blob the network returns.
    parser = argparse.ArgumentParser(description='convert caffe to pytorch')
    parser.add_argument('--protofile', default='', type=str)
    parser.add_argument('--weightfile', default='', type=str)
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--iters', default=10, type=int,
                        help='number of forward passes to run')
    # ``args`` must stay module-global: create_network() may read args.cuda
    # from module scope.
    args = parser.parse_args()
    print(args)
    # Fail early with a usage message instead of an obscure error inside
    # CaffeNet / load_weights when a path is missing.
    if not args.protofile or not args.weightfile:
        parser.error('both --protofile and --weightfile are required')
    protofile = args.protofile
    weightfile = args.weightfile
    net = create_network(protofile, weightfile)
    net.set_verbose(False)
    for i in range(args.iters):
        # net() is called with no inputs and returns a mapping of
        # blob name -> tensor-like output (it supports keys() and indexing).
        blobs = net()
        for blob_name, blob in blobs.items():
            # .cpu() returns the tensor unchanged when it already lives on the
            # CPU, so one path covers both the --cuda and CPU cases.
            blob_data = blob.data.cpu().numpy()
            print('[%d] %-30s pytorch_shape: %-20s mean: %f'
                  % (i, blob_name, blob_data.shape, blob_data.mean()))