
Commit a673b35

Author: Anatoly Baksheev
Committed: Jul 6, 2018

formatting + use nn.MaxPool2d if no SpatialMaxUnpooling

1 parent e77709f · commit a673b35

File tree

1 file changed, +59 −34 lines

 

convert_torch.py

+59 −34
@@ -36,11 +36,29 @@ def add_submodule(seq, *args):
 
 class Convertor(object):
 
-    def __init__(self):
+    def __init__(self, model):
         self.prefix_code = []
         self.t2pt_names = dict()
         self.t2pt_layers = dict()
 
+        def search_max_unpool(model):
+            modules = []
+            modules.extend(model.modules)
+            containers = ['Sequential', 'Concat']
+
+            while modules:
+                m = modules.pop()
+                name = type(m).__name__
+                if name in containers:
+                    modules.extend(m.modules)
+
+                if name == 'SpatialMaxUnpooling':
+                    return True
+
+            return False
+
+        self.have_max_unpool = search_max_unpool(model)
+
     def lua_recursive_model(self, module, seq):
         for m in module.modules:
             name = type(m).__name__
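
The new search_max_unpool helper walks the module lists of 'Sequential' and 'Concat' containers and reports whether any SpatialMaxUnpooling layer is reachable, so the converter can decide up front which pooling class to emit. Below is a standalone sketch of the same scan, generalized to look for an arbitrary layer type; it is illustrative only and not part of convert_torch.py, and it assumes 'modules' is the list attribute that loaded Lua-torch containers expose.

    def contains_layer(model, layer_name, containers=('Sequential', 'Concat')):
        # copy the top-level module list so the original is left untouched
        stack = list(model.modules)
        while stack:
            m = stack.pop()
            name = type(m).__name__
            if name in containers:
                stack.extend(m.modules)   # descend into nested containers
            if name == layer_name:
                return True
        return False

    # e.g. have_unpool = contains_layer(loaded_t7_model, 'SpatialMaxUnpooling')
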
@@ -69,9 +87,11 @@ def lua_recursive_model(self, module, seq):
                 n = nn.Sigmoid()
                 add_submodule(seq, n)
             elif name == 'SpatialMaxPooling':
-                # n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
-                n = StatefulMaxPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), ceil_mode=m.ceil_mode)
-                self.t2pt_layers[m] = n
+                if not self.have_max_unpool:
+                    n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
+                else:
+                    n = StatefulMaxPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), ceil_mode=m.ceil_mode)
+                    self.t2pt_layers[m] = n
                 add_submodule(seq, n)
             elif name == 'SpatialMaxUnpooling':
                 if m.pooling in self.t2pt_layers:
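
With the flag in place, models that never unpool are converted with plain nn.MaxPool2d, while models that do unpool keep the StatefulMaxPool2d wrapper so the matching unpooling layer can look up the pooling indices. A minimal sketch of why those indices matter, using only the standard PyTorch API (shapes and sizes here are made up):

    import torch
    import torch.nn as nn

    pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

    x = torch.randn(1, 3, 8, 8)
    y, idx = pool(x)        # the indices must be kept for the unpooling stage
    x_rec = unpool(y, idx)  # without an unpooling layer, a plain nn.MaxPool2d is enough
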
@@ -164,30 +184,33 @@ def lua_recursive_source(self, module):
 
         if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
             if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
-            s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
-                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, m.bias is not None)]
+            s += ['nn.Conv2d({}, {}, {}, {}, {}, {}, {},bias={}), #Conv2d'.format(m.nInputPlane,
+                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups,
+                m.bias is not None)]
         elif name == 'SpatialBatchNormalization':
-            s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
+            s += ['nn.BatchNorm2d({}, {}, {}, {}), #BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
         elif name == 'VolumetricBatchNormalization':
             s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
         elif name == 'ReLU':
             s += ['nn.ReLU()']
         elif name == 'Sigmoid':
             s += ['nn.Sigmoid()']
         elif name == 'SpatialMaxPooling':
-            # s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
-            suffixes = sorted(int(re.match('pooling_(\d*)', v).group(1)) for v in self.t2pt_names.values())
-            name = 'pooling_{}'.format(suffixes[-1] + 1 if suffixes else 1)
-            s += [name]
-            self.t2pt_names[m] = name
-            self.prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})'.format(name, (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode)]
+            if not self.have_max_unpool:
+                s += ['nn.MaxPool2d({}, {}, {}, ceil_mode={}), #MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
+            else:
+                suffixes = sorted(int(re.match('pooling_(\d*)', v).group(1)) for v in self.t2pt_names.values())
+                name = 'pooling_{}'.format(suffixes[-1] + 1 if suffixes else 1)
+                s += [name]
+                self.t2pt_names[m] = name
+                self.prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})'.format(name, (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode)]
         elif name == 'SpatialMaxUnpooling':
             if m.pooling in self.t2pt_names:
                 s += ['StatefulMaxUnpool2d({}), #SpatialMaxUnpooling'.format(self.t2pt_names[m.pooling])]
             else:
                 s += ['# ' + name + ' Not Implement (can\'t find corresponding SpatialMaxUnpooling,\n']
         elif name == 'SpatialAveragePooling':
-            s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
+            s += ['nn.AvgPool2d({}, {}, {}, ceil_mode={}), #AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
         elif name == 'SpatialUpSamplingNearest':
             s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
         elif name == 'View':
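
In the source-generating path, the unpooling case emits a module-level pooling_N variable through prefix_code so the generated StatefulMaxUnpool2d line can reference it by name. A standalone illustration of the pooling_N allocation and the emitted prefix line (dictionary contents and kernel values here are made up):

    import re

    t2pt_names = {'poolA': 'pooling_1', 'poolB': 'pooling_2'}

    suffixes = sorted(int(re.match(r'pooling_(\d+)', v).group(1)) for v in t2pt_names.values())
    name = 'pooling_{}'.format(suffixes[-1] + 1 if suffixes else 1)   # -> 'pooling_3'

    prefix_line = '{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})'.format(
        name, (2, 2), (2, 2), (0, 0), False)
    print(prefix_line)
    # pooling_3 = StatefulMaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=False)
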
@@ -197,7 +220,7 @@ def lua_recursive_source(self, module):
         elif name == 'Linear':
             s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
             s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), m.weight.size(0), (m.bias is not None))
-            s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
+            s += ['nn.Sequential({}, {}), #Linear'.format(s1, s2)]
         elif name == 'Dropout':
             s += ['nn.Dropout({})'.format(m.p)]
         elif name == 'SoftMax':
@@ -245,20 +268,20 @@ def lua_recursive_source(self, module):
 
     @staticmethod
     def simplify_source(s):
-        s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'), s)
-        s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s)
-        s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s)
-        s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s)
-        s = map(lambda x: x.replace('),#Conv2d', ')'), s)
-        s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s)
-        s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s)
-        s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s)
-        s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s)
-        s = map(lambda x: x.replace('),#MaxPool2d', ')'), s)
-        s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s)
-        s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s)
-        s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s)
-        s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s)
+        s = map(lambda x: x.replace(', (1, 1), (0, 0), 1, 1, bias=True), #Conv2d', ')'), s)
+        s = map(lambda x: x.replace(', (0, 0), 1, 1, bias=True), #Conv2d', ')'), s)
+        s = map(lambda x: x.replace(', 1, 1, bias=True), #Conv2d', ')'), s)
+        s = map(lambda x: x.replace(', bias=True), #Conv2d', ')'), s)
+        s = map(lambda x: x.replace('), #Conv2d', ')'), s)
+        s = map(lambda x: x.replace(', 1e-05, 0.1, True), #BatchNorm2d', ')'), s)
+        s = map(lambda x: x.replace('), #BatchNorm2d', ')'), s)
+        s = map(lambda x: x.replace(', (0, 0), ceil_mode=False), #MaxPool2d', ')'), s)
+        s = map(lambda x: x.replace(', ceil_mode=False), #MaxPool2d', ')'), s)
+        s = map(lambda x: x.replace('), #MaxPool2d', ')'), s)
+        s = map(lambda x: x.replace(', (0, 0), ceil_mode=False), #AvgPool2d', ')'), s)
+        s = map(lambda x: x.replace(', ceil_mode=False), #AvgPool2d', ')'), s)
+        s = map(lambda x: x.replace(', bias=True)), #Linear', ')), # Linear'), s)
+        s = map(lambda x: x.replace(')), #Linear', ')), # Linear'), s)
 
         s = map(lambda x: '{},\n'.format(x), s)
         s = map(lambda x: x[1:], s)
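
simplify_source drops default arguments by plain string replacement, so a pattern only fires when it matches the spacing of the generated line exactly; the patterns here are updated alongside the formatting changes above. A quick standalone illustration with the MaxPool2d pattern (the line content is made up):

    line = 'nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=False), #MaxPool2d'
    line = line.replace(', (0, 0), ceil_mode=False), #MaxPool2d', ')')
    print(line)   # -> nn.MaxPool2d((2, 2), (2, 2))
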
@@ -272,17 +295,19 @@ def torch_to_pytorch(t7_filename, outputname=None):
     model = model.model
     model.gradInput = None
 
-    cvt = Convertor()
-    slist = cvt.lua_recursive_source(lnn.Sequential().add(model))
-    s = cvt.simplify_source(slist)
+    cvt = Convertor(model)
+    s = cvt.lua_recursive_source(lnn.Sequential().add(model))
+    s = cvt.simplify_source(s)
 
     varname = os.path.basename(t7_filename).replace('.t7', '').replace('.', '_').replace('-', '_')
 
     with open("header.py") as f:
         header = f.read()
     s = '{}\n{}\n\n{} = {}'.format(header, '\n'.join(cvt.prefix_code), varname, s[:-2])
 
-    if outputname is None: outputname = varname
+    if outputname is None:
+        outputname = varname
+
     with open(outputname + '.py', "w") as pyfile:
         pyfile.write(s)
 
@@ -294,7 +319,7 @@ def torch_to_pytorch(t7_filename, outputname=None):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
     parser.add_argument('--model', '-m', type=str, required=True, help='torch model file in t7 format')
-    parser.add_argument('--output', '-o', type=str, default=None, help='output file name prefix, xxx.py xxx.pth')
+    parser.add_argument('--output', '-o', type=str, default='/tmp/model', help='output file name prefix, xxx.py xxx.pth')
     args = parser.parse_args()
 
     torch_to_pytorch(args.model, args.output)
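
A typical invocation, assuming a .t7 model file is available (file names are placeholders); with the new default, omitting -o writes the result under /tmp/model:

    python convert_torch.py -m vgg16.t7 -o /tmp/vgg16
    # expected outputs, per the --output help text: /tmp/vgg16.py and /tmp/vgg16.pth
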

0 commit comments
