@@ -36,11 +36,27 @@ def add_submodule(seq, *args):
36
36
37
37
class Convertor (object ):
38
38
39
- def __init__ (self ):
39
+ def __init__ (self , model ):
40
40
self .prefix_code = []
41
41
self .t2pt_names = dict ()
42
42
self .t2pt_layers = dict ()
43
43
44
+ self .have_max_unpool = False
45
+
46
+ modules = []
47
+ modules .extend (model .modules )
48
+ containers = ['Sequential' , 'Concat' ]
49
+
50
+ while modules :
51
+ m = modules .pop ()
52
+ name = type (m ).__name__
53
+ if name in containers :
54
+ modules .extend (m .modules )
55
+
56
+ self .have_max_unpool = name == 'SpatialMaxUnpooling'
57
+ if self .have_max_unpool :
58
+ break
59
+
44
60
def lua_recursive_model (self , module , seq ):
45
61
for m in module .modules :
46
62
name = type (m ).__name__
@@ -69,9 +85,11 @@ def lua_recursive_model(self, module, seq):
69
85
n = nn .Sigmoid ()
70
86
add_submodule (seq , n )
71
87
elif name == 'SpatialMaxPooling' :
72
- # n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
73
- n = StatefulMaxPool2d ((m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), ceil_mode = m .ceil_mode )
74
- self .t2pt_layers [m ] = n
88
+ if not self .have_max_unpool :
89
+ n = nn .MaxPool2d ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), ceil_mode = m .ceil_mode )
90
+ else :
91
+ n = StatefulMaxPool2d ((m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), ceil_mode = m .ceil_mode )
92
+ self .t2pt_layers [m ] = n
75
93
add_submodule (seq , n )
76
94
elif name == 'SpatialMaxUnpooling' :
77
95
if m .pooling in self .t2pt_layers :
@@ -164,30 +182,33 @@ def lua_recursive_source(self, module):
164
182
165
183
if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM' :
166
184
if not hasattr (m , 'groups' ) or m .groups is None : m .groups = 1
167
- s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d' .format (m .nInputPlane ,
168
- m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), 1 , m .groups , m .bias is not None )]
185
+ s += ['nn.Conv2d({}, {}, {}, {}, {}, {}, {},bias={}), #Conv2d' .format (m .nInputPlane ,
186
+ m .nOutputPlane , (m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), 1 , m .groups ,
187
+ m .bias is not None )]
169
188
elif name == 'SpatialBatchNormalization' :
170
- s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
189
+ s += ['nn.BatchNorm2d({}, {}, {}, {}), #BatchNorm2d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
171
190
elif name == 'VolumetricBatchNormalization' :
172
191
s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d' .format (m .running_mean .size (0 ), m .eps , m .momentum , m .affine )]
173
192
elif name == 'ReLU' :
174
193
s += ['nn.ReLU()' ]
175
194
elif name == 'Sigmoid' :
176
195
s += ['nn.Sigmoid()' ]
177
196
elif name == 'SpatialMaxPooling' :
178
- # s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
179
- suffixes = sorted (int (re .match ('pooling_(\d*)' , v ).group (1 )) for v in self .t2pt_names .values ())
180
- name = 'pooling_{}' .format (suffixes [- 1 ] + 1 if suffixes else 1 )
181
- s += [name ]
182
- self .t2pt_names [m ] = name
183
- self .prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})' .format (name , (m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), m .ceil_mode )]
197
+ if not self .have_max_unpool :
198
+ s += ['nn.MaxPool2d({}, {}, {}, ceil_mode={}), #MaxPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
199
+ else :
200
+ suffixes = sorted (int (re .match ('pooling_(\d*)' , v ).group (1 )) for v in self .t2pt_names .values ())
201
+ name = 'pooling_{}' .format (suffixes [- 1 ] + 1 if suffixes else 1 )
202
+ s += [name ]
203
+ self .t2pt_names [m ] = name
204
+ self .prefix_code += ['{} = StatefulMaxPool2d({}, {}, {}, ceil_mode={})' .format (name , (m .kH , m .kW ), (m .dH , m .dW ), (m .padH , m .padW ), m .ceil_mode )]
184
205
elif name == 'SpatialMaxUnpooling' :
185
206
if m .pooling in self .t2pt_names :
186
207
s += ['StatefulMaxUnpool2d({}), #SpatialMaxUnpooling' .format (self .t2pt_names [m .pooling ])]
187
208
else :
188
209
s += ['# ' + name + ' Not Implement (can\' t find corresponding SpatialMaxUnpooling,\n ' ]
189
210
elif name == 'SpatialAveragePooling' :
190
- s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
211
+ s += ['nn.AvgPool2d({}, {}, {}, ceil_mode={}), #AvgPool2d' .format ((m .kW , m .kH ), (m .dW , m .dH ), (m .padW , m .padH ), m .ceil_mode )]
191
212
elif name == 'SpatialUpSamplingNearest' :
192
213
s += ['nn.UpsamplingNearest2d(scale_factor={})' .format (m .scale_factor )]
193
214
elif name == 'View' :
@@ -197,7 +218,7 @@ def lua_recursive_source(self, module):
197
218
elif name == 'Linear' :
198
219
s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
199
220
s2 = 'nn.Linear({},{},bias={})' .format (m .weight .size (1 ), m .weight .size (0 ), (m .bias is not None ))
200
- s += ['nn.Sequential({},{}),#Linear' .format (s1 , s2 )]
221
+ s += ['nn.Sequential({}, {}), #Linear' .format (s1 , s2 )]
201
222
elif name == 'Dropout' :
202
223
s += ['nn.Dropout({})' .format (m .p )]
203
224
elif name == 'SoftMax' :
@@ -245,20 +266,20 @@ def lua_recursive_source(self, module):
245
266
246
267
@staticmethod
247
268
def simplify_source (s ):
248
- s = map (lambda x : x .replace (',(1, 1),(0, 0),1,1, bias=True),#Conv2d' , ')' ), s )
249
- s = map (lambda x : x .replace (',(0, 0),1,1, bias=True),#Conv2d' , ')' ), s )
250
- s = map (lambda x : x .replace (',1,1, bias=True),#Conv2d' , ')' ), s )
251
- s = map (lambda x : x .replace (',bias=True),#Conv2d' , ')' ), s )
252
- s = map (lambda x : x .replace ('),#Conv2d' , ')' ), s )
253
- s = map (lambda x : x .replace (',1e-05,0.1,True),#BatchNorm2d' , ')' ), s )
254
- s = map (lambda x : x .replace ('),#BatchNorm2d' , ')' ), s )
255
- s = map (lambda x : x .replace (',(0, 0),ceil_mode=False),#MaxPool2d' , ')' ), s )
256
- s = map (lambda x : x .replace (',ceil_mode=False),#MaxPool2d' , ')' ), s )
257
- s = map (lambda x : x .replace ('),#MaxPool2d' , ')' ), s )
258
- s = map (lambda x : x .replace (',(0, 0),ceil_mode=False),#AvgPool2d' , ')' ), s )
259
- s = map (lambda x : x .replace (',ceil_mode=False),#AvgPool2d' , ')' ), s )
260
- s = map (lambda x : x .replace (',bias=True)),#Linear' , ')), # Linear' ), s )
261
- s = map (lambda x : x .replace (')),#Linear' , ')), # Linear' ), s )
269
+ s = map (lambda x : x .replace (', (1, 1), (0, 0), 1, 1, bias=True), #Conv2d' , ')' ), s )
270
+ s = map (lambda x : x .replace (', (0, 0), 1, 1, bias=True), #Conv2d' , ')' ), s )
271
+ s = map (lambda x : x .replace (', 1, 1, bias=True), #Conv2d' , ')' ), s )
272
+ s = map (lambda x : x .replace (', bias=True), #Conv2d' , ')' ), s )
273
+ s = map (lambda x : x .replace ('), #Conv2d' , ')' ), s )
274
+ s = map (lambda x : x .replace (', 1e-05, 0.1, True), #BatchNorm2d' , ')' ), s )
275
+ s = map (lambda x : x .replace ('), #BatchNorm2d' , ')' ), s )
276
+ s = map (lambda x : x .replace (', (0, 0), ceil_mode=False), #MaxPool2d' , ')' ), s )
277
+ s = map (lambda x : x .replace (', ceil_mode=False), #MaxPool2d' , ')' ), s )
278
+ s = map (lambda x : x .replace ('), #MaxPool2d' , ')' ), s )
279
+ s = map (lambda x : x .replace (', (0, 0), ceil_mode=False), #AvgPool2d' , ')' ), s )
280
+ s = map (lambda x : x .replace (', ceil_mode=False), #AvgPool2d' , ')' ), s )
281
+ s = map (lambda x : x .replace (', bias=True)), #Linear' , ')), # Linear' ), s )
282
+ s = map (lambda x : x .replace (')), #Linear' , ')), # Linear' ), s )
262
283
263
284
s = map (lambda x : '{},\n ' .format (x ), s )
264
285
s = map (lambda x : x [1 :], s )
@@ -272,17 +293,19 @@ def torch_to_pytorch(t7_filename, outputname=None):
272
293
model = model .model
273
294
model .gradInput = None
274
295
275
- cvt = Convertor ()
276
- slist = cvt .lua_recursive_source (lnn .Sequential ().add (model ))
277
- s = cvt .simplify_source (slist )
296
+ cvt = Convertor (model )
297
+ s = cvt .lua_recursive_source (lnn .Sequential ().add (model ))
298
+ s = cvt .simplify_source (s )
278
299
279
300
varname = os .path .basename (t7_filename ).replace ('.t7' , '' ).replace ('.' , '_' ).replace ('-' , '_' )
280
301
281
302
with open ("header.py" ) as f :
282
303
header = f .read ()
283
304
s = '{}\n {}\n \n {} = {}' .format (header , '\n ' .join (cvt .prefix_code ), varname , s [:- 2 ])
284
305
285
- if outputname is None : outputname = varname
306
+ if outputname is None :
307
+ outputname = varname
308
+
286
309
with open (outputname + '.py' , "w" ) as pyfile :
287
310
pyfile .write (s )
288
311
@@ -294,7 +317,7 @@ def torch_to_pytorch(t7_filename, outputname=None):
294
317
if __name__ == '__main__' :
295
318
parser = argparse .ArgumentParser (description = 'Convert torch t7 model to pytorch' )
296
319
parser .add_argument ('--model' , '-m' , type = str , required = True , help = 'torch model file in t7 format' )
297
- parser .add_argument ('--output' , '-o' , type = str , default = None , help = 'output file name prefix, xxx.py xxx.pth' )
320
+ parser .add_argument ('--output' , '-o' , type = str , default = '/tmp/model' , help = 'output file name prefix, xxx.py xxx.pth' )
298
321
args = parser .parse_args ()
299
322
300
323
torch_to_pytorch (args .model , args .output )
0 commit comments