Skip to content

Commit 111d7fe

Browse files
committed
merge pytorch0.3.1
2 parents 6baafec + 38e1792 commit 111d7fe

File tree

9 files changed

+373
-23
lines changed

9 files changed

+373
-23
lines changed

demo.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
#
1212
# 4. uncomment to train
1313
#
14-
# python multi_label_classifier.py --dir "./test/celeba/" --mode "Train" --name "test" --batch_size 64 --gpu_ids 0 --input_channel 3 --load_size 144 --input_size 128 --mean [0,0,0] --std [1,1,1] --ratio "[0.94, 0.03, 0.03]" --shuffle --load_thread 8 --sum_epoch 20 --lr_decay_in_epoch 4 --display_port 8900 --validate_ratio 0.1 --top_k "(1,)" --score_thres 0.1 --display_train_freq 20 --display_validate_freq 20 --save_epoch_freq 1 --display_image_ratio 0.2
14+
# python multi_label_classifier.py --dir "./test/celeba/" --mode "Train" --model VGG16 --name "test" --batch_size 64 --gpu_ids 0 --input_channel 3 --load_size 144 --input_size 128 --mean [0,0,0] --std [1,1,1] --ratio "[0.94, 0.03, 0.03]" --shuffle --load_thread 8 --sum_epoch 20 --lr_decay_in_epoch 4 --display_port 8900 --validate_ratio 0.1 --top_k "(1,)" --score_thres 0.1 --display_train_freq 20 --display_validate_freq 20 --save_epoch_freq 1 --display_image_ratio 0.2
15+
#
1516
# 5. open localhost:8900 on your browser and you will see loss and accuracy curves and training images samples later on.
1617
#
1718
#--------------test--------------

models/alexnet.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22
import torch.utils.model_zoo as model_zoo
3-
3+
from build_model import LoadPretrainedModel
44

55
__all__ = ['AlexNet', 'alexnet']
66

@@ -89,3 +89,15 @@ def alexnet(pretrained=False, **kwargs):
8989
if pretrained:
9090
model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
9191
return model
92+
93+
def AlexnetTemplet(input_channel, pretrained=False, **kwargs):
94+
r"""AlexNet model architecture from the
95+
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
96+
Args:
97+
pretrained (bool): If True, returns a model pre-trained on ImageNet
98+
"""
99+
model = AlexNetTemplet(input_channel)
100+
if pretrained:
101+
model_dict = LoadPretrainedModel(model, model_zoo.load_url(model_urls['alexnet']))
102+
model.load_state_dict(model_dict)
103+
return model

models/build_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ def forward(self, x):
1818
outs.append(out)
1919
return outs
2020

21+
def LoadPretrainedModel(model, pretrained_state_dict):
22+
model_dict = model.state_dict()
23+
union_dict = {k : v for k,v in pretrained_state_dict.iteritems() if k in model_dict}
24+
model_dict.update(union_dict)
25+
return model_dict
2126

2227
def BuildMultiLabelModel(basemodel, basemodel_output, num_classes):
2328
return MultiLabelModel(basemodel, basemodel_output, num_classes)

models/resnet.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch.nn as nn
22
import math
33
import torch.utils.model_zoo as model_zoo
4+
from build_model import LoadPretrainedModel
45

56

67
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
@@ -229,7 +230,8 @@ def Resnet18Templet(input_channel, pretrained=False, **kwargs):
229230
"""
230231
model = ResNetTemplet(BasicBlock, [2, 2, 2, 2], input_channel, **kwargs)
231232
if pretrained:
232-
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
233+
model_dict = LoadPretrainedModel(model, model_zoo.load_url(model_urls['resnet18']))
234+
model.load_state_dict(model_dict)
233235
return model
234236

235237

@@ -253,7 +255,8 @@ def Resnet34Templet(input_channel, pretrained=False, **kwargs):
253255
"""
254256
model = ResNetTemplet(BasicBlock, [3, 4, 6, 3], input_channel, **kwargs)
255257
if pretrained:
256-
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
258+
model_dict = LoadPretrainedModel(model, model_zoo.load_url(model_urls['resnet34']))
259+
model.load_state_dict(model_dict)
257260
return model
258261

259262

@@ -277,7 +280,8 @@ def Resnet50Templet(input_channel, pretrained=False, **kwargs):
277280
"""
278281
model = ResNetTemplet(Bottleneck, [3, 4, 6, 3], input_channel, **kwargs)
279282
if pretrained:
280-
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
283+
model_dict = LoadPretrainedModel(model, model_zoo.load_url(model_urls['resnet50']))
284+
model.load_state_dict(model_dict)
281285
return model
282286

283287

@@ -300,7 +304,8 @@ def Resnet101Templet(input_channel, pretrained=False, **kwargs):
300304
"""
301305
model = ResNetTemplet(Bottleneck, [3, 4, 23, 3], input_channel, **kwargs)
302306
if pretrained:
303-
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
307+
model_dict = LoadPretrainedModel(model, model_zoo.load_url(model_urls['resnet101']))
308+
model.load_state_dict(model_dict)
304309
return model
305310

306311
def resnet152(pretrained=False, **kwargs):
@@ -322,6 +327,7 @@ def Resnet152Templet(input_channel, pretrained=False, **kwargs):
322327
"""
323328
model = ResNetTemplet(Bottleneck, [3, 8, 36, 3], input_channel, **kwargs)
324329
if pretrained:
325-
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
330+
model_dict = LoadPretrainedModel(model, model_zoo.load_url(model_urls['resnet152']))
331+
model.load_state_dict(model_dict)
326332
return model
327333

0 commit comments

Comments
 (0)