1
1
import torch .nn as nn
2
2
import math
3
3
import torch .utils .model_zoo as model_zoo
4
+ from build_model import LoadPretrainedModel
4
5
5
6
6
7
__all__ = ['ResNet' , 'resnet18' , 'resnet34' , 'resnet50' , 'resnet101' ,
@@ -229,7 +230,8 @@ def Resnet18Templet(input_channel, pretrained=False, **kwargs):
229
230
"""
230
231
model = ResNetTemplet (BasicBlock , [2 , 2 , 2 , 2 ], input_channel , ** kwargs )
231
232
if pretrained :
232
- model .load_state_dict (model_zoo .load_url (model_urls ['resnet18' ]))
233
+ model_dict = LoadPretrainedModel (model , model_zoo .load_url (model_urls ['resnet18' ]))
234
+ model .load_state_dict (model_dict )
233
235
return model
234
236
235
237
@@ -253,7 +255,8 @@ def Resnet34Templet(input_channel, pretrained=False, **kwargs):
253
255
"""
254
256
model = ResNetTemplet (BasicBlock , [3 , 4 , 6 , 3 ], input_channel , ** kwargs )
255
257
if pretrained :
256
- model .load_state_dict (model_zoo .load_url (model_urls ['resnet34' ]))
258
+ model_dict = LoadPretrainedModel (model , model_zoo .load_url (model_urls ['resnet34' ]))
259
+ model .load_state_dict (model_dict )
257
260
return model
258
261
259
262
@@ -277,7 +280,8 @@ def Resnet50Templet(input_channel, pretrained=False, **kwargs):
277
280
"""
278
281
model = ResNetTemplet (Bottleneck , [3 , 4 , 6 , 3 ], input_channel , ** kwargs )
279
282
if pretrained :
280
- model .load_state_dict (model_zoo .load_url (model_urls ['resnet50' ]))
283
+ model_dict = LoadPretrainedModel (model , model_zoo .load_url (model_urls ['resnet50' ]))
284
+ model .load_state_dict (model_dict )
281
285
return model
282
286
283
287
@@ -300,7 +304,8 @@ def Resnet101Templet(input_channel, pretrained=False, **kwargs):
300
304
"""
301
305
model = ResNetTemplet (Bottleneck , [3 , 4 , 23 , 3 ], input_channel , ** kwargs )
302
306
if pretrained :
303
- model .load_state_dict (model_zoo .load_url (model_urls ['resnet101' ]))
307
+ model_dict = LoadPretrainedModel (model , model_zoo .load_url (model_urls ['resnet101' ]))
308
+ model .load_state_dict (model_dict )
304
309
return model
305
310
306
311
def resnet152 (pretrained = False , ** kwargs ):
@@ -322,6 +327,7 @@ def Resnet152Templet(input_channel, pretrained=False, **kwargs):
322
327
"""
323
328
model = ResNetTemplet (Bottleneck , [3 , 8 , 36 , 3 ], input_channel , ** kwargs )
324
329
if pretrained :
325
- model .load_state_dict (model_zoo .load_url (model_urls ['resnet152' ]))
330
+ model_dict = LoadPretrainedModel (model , model_zoo .load_url (model_urls ['resnet152' ]))
331
+ model .load_state_dict (model_dict )
326
332
return model
327
333
0 commit comments