@@ -36,11 +36,29 @@ def add_submodule(seq, *args):
36
36
37
37
class Convertor (object ):
38
38
39
- def __init__ (self ):
39
+ def __init__ (self , model ):
40
40
self .prefix_code = []
41
41
self .t2pt_names = dict ()
42
42
self .t2pt_layers = dict ()
43
43
44
+ def search_max_unpool (model ):
45
+ modules = []
46
+ modules .extend (model .modules )
47
+ containers = ['Sequential' , 'Concat' ]
48
+
49
+ while modules :
50
+ m = modules .pop ()
51
+ name = type (m ).__name__
52
+ if name in containers :
53
+ modules .extend (m .modules )
54
+
55
+ if name == 'SpatialMaxUnpooling' :
56
+ return True
57
+
58
+ return False
59
+
60
+ self .have_max_unpool = search_max_unpool (model )
61
+
44
62
def lua_recursive_model (self , module , seq ):
45
63
for m in module .modules :
46
64
name = type (m ).__name__
@@ -69,9 +87,11 @@ def lua_recursive_model(self, module, seq):
69
87
n = nn .Sigmoid ()
70
88
add_submodule (seq , n )
71
89
elif name == 'SpatialMaxPooling' :
72
- # n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
73
- n = StatefulMaxPool2d ((m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), ceil_mode = m .ceil_mode )
74
- self .t2pt_layers [m ] = n
90
+ if not self .have_max_unpool :
91
+ n = nn .MaxPool2d ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), ceil_mode = m .ceil_mode )
92
+ else :
93
+ n = StatefulMaxPool2d ((m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), ceil_mode = m .ceil_mode )
94
+ self .t2pt_layers [m ] = n
75
95
add_submodule (seq , n )
76
96
elif name == 'SpatialMaxUnpooling' :
77
97
if m .pooling in self .t2pt_layers :
@@ -164,30 +184,33 @@ def lua_recursive_source(self, module):
164
184
165
185
if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM' :
166
186
if not hasattr (m , 'groups' ) or m .groups is None : m .groups = 1
167
- s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d' .format (m .nInputPlane ,
168
- m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), 1 , m .groups , m .bias is not None )]
187
+ s += ['nn.Conv2d({}, {}, {}, {}, {}, {}, {},bias={}), #Conv2d' .format (m .nInputPlane ,
188
+ m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), 1 , m .groups ,
189
+ m .bias is not None )]
169
190
elif name == 'SpatialBatchNormalization' :
170
- s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
191
+ s += ['nn.BatchNorm2d({}, {}, {}, {}), #BatchNorm2d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
171
192
elif name == 'VolumetricBatchNormalization' :
172
193
s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
173
194
elif name == 'ReLU' :
174
195
s += ['nn.ReLU()' ]
175
196
elif name == 'Sigmoid' :
176
197
s += ['nn.Sigmoid()' ]
177
198
elif name == 'SpatialMaxPooling' :
178
- # s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
179
- suffixes = sorted (int (re .match ('pooling_(\d*)' , v ).group (1 )) for v in self .t2pt_names .values ())
180
- name = 'pooling_{}' .format (suffixes [- 1 ] + 1 if suffixes else 1 )
181
- s += [name ]
182
- self .t2pt_names [m ] = name
183
- self .prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})' .format (name , (m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), m .ceil_mode )]
199
+ if not self .have_max_unpool :
200
+ s += ['nn.MaxPool2d({}, {}, {}, ceil_mode={}), #MaxPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
201
+ else :
202
+ suffixes = sorted (int (re .match ('pooling_(\d*)' , v ).group (1 )) for v in self .t2pt_names .values ())
203
+ name = 'pooling_{}' .format (suffixes [- 1 ] + 1 if suffixes else 1 )
204
+ s += [name ]
205
+ self .t2pt_names [m ] = name
206
+ self .prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})' .format (name , (m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), m .ceil_mode )]
184
207
elif name == 'SpatialMaxUnpooling' :
185
208
if m .pooling in self .t2pt_names :
186
209
s += ['StatefulMaxUnpool2d({}), #SpatialMaxUnpooling' .format (self .t2pt_names [m .pooling ])]
187
210
else :
188
211
s += ['# ' + name + ' Not Implement (can\' t find corresponding SpatialMaxUnpooling,\n ' ]
189
212
elif name == 'SpatialAveragePooling' :
190
- s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
213
+ s += ['nn.AvgPool2d({}, {}, {}, ceil_mode={}), #AvgPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
191
214
elif name == 'SpatialUpSamplingNearest' :
192
215
s += ['nn.UpsamplingNearest2d(scale_factor={})' .format (m .scale_factor )]
193
216
elif name == 'View' :
@@ -197,7 +220,7 @@ def lua_recursive_source(self, module):
197
220
elif name == 'Linear' :
198
221
s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
199
222
s2 = 'nn.Linear({},{},bias={})' .format (m .weight .size (1 ), m .weight .size (0 ), (m .bias is not None ))
200
- s += ['nn.Sequential({},{}),#Linear' .format (s1 , s2 )]
223
+ s += ['nn.Sequential({}, {}), #Linear' .format (s1 , s2 )]
201
224
elif name == 'Dropout' :
202
225
s += ['nn.Dropout({})' .format (m .p )]
203
226
elif name == 'SoftMax' :
@@ -245,20 +268,20 @@ def lua_recursive_source(self, module):
245
268
246
269
@staticmethod
247
270
def simplify_source (s ):
248
- s = map (lambda x : x .replace (',(1, 1),(0, 0),1,1, bias=True),#Conv2d' , ')' ), s )
249
- s = map (lambda x : x .replace (',(0, 0),1,1, bias=True),#Conv2d' , ')' ), s )
250
- s = map (lambda x : x .replace (',1,1, bias=True),#Conv2d' , ')' ), s )
251
- s = map (lambda x : x .replace (',bias=True),#Conv2d' , ')' ), s )
252
- s = map (lambda x : x .replace ('),#Conv2d' , ')' ), s )
253
- s = map (lambda x : x .replace (',1e-05,0.1,True),#BatchNorm2d' , ')' ), s )
254
- s = map (lambda x : x .replace ('),#BatchNorm2d' , ')' ), s )
255
- s = map (lambda x : x .replace (',(0, 0),ceil_mode=False),#MaxPool2d' , ')' ), s )
256
- s = map (lambda x : x .replace (',ceil_mode=False),#MaxPool2d' , ')' ), s )
257
- s = map (lambda x : x .replace ('),#MaxPool2d' , ')' ), s )
258
- s = map (lambda x : x .replace (',(0, 0),ceil_mode=False),#AvgPool2d' , ')' ), s )
259
- s = map (lambda x : x .replace (',ceil_mode=False),#AvgPool2d' , ')' ), s )
260
- s = map (lambda x : x .replace (',bias=True)),#Linear' , ')), # Linear' ), s )
261
- s = map (lambda x : x .replace (')),#Linear' , ')), # Linear' ), s )
271
+ s = map (lambda x : x .replace (', (1, 1), (0, 0), 1, 1, bias=True), #Conv2d' , ')' ), s )
272
+ s = map (lambda x : x .replace (', (0, 0), 1, 1, bias=True), #Conv2d' , ')' ), s )
273
+ s = map (lambda x : x .replace (', 1, 1, bias=True), #Conv2d' , ')' ), s )
274
+ s = map (lambda x : x .replace (', bias=True), #Conv2d' , ')' ), s )
275
+ s = map (lambda x : x .replace ('), #Conv2d' , ')' ), s )
276
+ s = map (lambda x : x .replace (', 1e-05, 0.1, True), #BatchNorm2d' , ')' ), s )
277
+ s = map (lambda x : x .replace ('), #BatchNorm2d' , ')' ), s )
278
+ s = map (lambda x : x .replace (', (0, 0), ceil_mode=False), #MaxPool2d' , ')' ), s )
279
+ s = map (lambda x : x .replace (', ceil_mode=False), #MaxPool2d' , ')' ), s )
280
+ s = map (lambda x : x .replace ('), #MaxPool2d' , ')' ), s )
281
+ s = map (lambda x : x .replace (', (0, 0), ceil_mode=False), #AvgPool2d' , ')' ), s )
282
+ s = map (lambda x : x .replace (', ceil_mode=False), #AvgPool2d' , ')' ), s )
283
+ s = map (lambda x : x .replace (', bias=True)), #Linear' , ')), # Linear' ), s )
284
+ s = map (lambda x : x .replace (')), #Linear' , ')), # Linear' ), s )
262
285
263
286
s = map (lambda x : '{},\n ' .format (x ), s )
264
287
s = map (lambda x : x [1 :], s )
@@ -272,17 +295,19 @@ def torch_to_pytorch(t7_filename, outputname=None):
272
295
model = model .model
273
296
model .gradInput = None
274
297
275
- cvt = Convertor ()
276
- slist = cvt .lua_recursive_source (lnn .Sequential ().add (model ))
277
- s = cvt .simplify_source (slist )
298
+ cvt = Convertor (model )
299
+ s = cvt .lua_recursive_source (lnn .Sequential ().add (model ))
300
+ s = cvt .simplify_source (s )
278
301
279
302
varname = os .path .basename (t7_filename ).replace ('.t7' , '' ).replace ('.' , '_' ).replace ('-' , '_' )
280
303
281
304
with open ("header.py" ) as f :
282
305
header = f .read ()
283
306
s = '{}\n {}\n \n {} = {}' .format (header , '\n ' .join (cvt .prefix_code ), varname , s [:- 2 ])
284
307
285
- if outputname is None : outputname = varname
308
+ if outputname is None :
309
+ outputname = varname
310
+
286
311
with open (outputname + '.py' , "w" ) as pyfile :
287
312
pyfile .write (s )
288
313
@@ -294,7 +319,7 @@ def torch_to_pytorch(t7_filename, outputname=None):
294
319
if __name__ == '__main__' :
295
320
parser = argparse .ArgumentParser (description = 'Convert torch t7 model to pytorch' )
296
321
parser .add_argument ('--model' , '-m' , type = str , required = True , help = 'torch model file in t7 format' )
297
- parser .add_argument ('--output' , '-o' , type = str , default = None , help = 'output file name prefix, xxx.py xxx.pth' )
322
+ parser .add_argument ('--output' , '-o' , type = str , default = '/tmp/model' , help = 'output file name prefix, xxx.py xxx.pth' )
298
323
args = parser .parse_args ()
299
324
300
325
torch_to_pytorch (args .model , args .output )
0 commit comments