-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
24 lines (20 loc) · 852 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn
from torchvision.models import alexnet, resnet50, densenet121, densenet169, densenet201, densenet161
class Classifier(nn.Module):
    """Hold the name of a torchvision architecture and build it on demand.

    Args:
        model: Name of the torchvision architecture to build; must be one of
            ``Classifier.SUPPORTED_MODELS``.
        pretrained: Forwarded to the torchvision constructor as
            ``pretrained=``. Defaults to ``True`` so existing callers keep
            getting pretrained weights for every architecture.
    """

    # Names accepted by ``build_model``. Kept as plain strings so a bad name
    # can be rejected before any torchvision constructor is touched.
    SUPPORTED_MODELS = (
        'alexnet',
        'resnet50',
        'densenet121',
        'densenet169',
        'densenet201',
        'densenet161',
    )

    def __init__(self, model, pretrained=True):
        super(Classifier, self).__init__()
        self.model = model
        self.pretrained = pretrained

    def build_model(self):
        """Instantiate and return the torchvision model named by ``self.model``.

        Returns:
            The object returned by the matching ``torchvision.models``
            constructor, called with ``pretrained=self.pretrained``.

        Raises:
            ValueError: If ``self.model`` is not in ``SUPPORTED_MODELS``.
        """
        if self.model not in self.SUPPORTED_MODELS:
            # Previously an unknown name fell through the if/elif chain and
            # failed with UnboundLocalError on the return; fail clearly instead.
            raise ValueError(
                'Unsupported model {!r}; expected one of: {}'.format(
                    self.model, ', '.join(self.SUPPORTED_MODELS)
                )
            )
        constructors = {
            'alexnet': alexnet,
            'resnet50': resnet50,
            'densenet121': densenet121,
            'densenet169': densenet169,
            'densenet201': densenet201,
            'densenet161': densenet161,
        }
        # Every architecture gets the same ``pretrained`` flag; the original
        # densenet161 branch omitted ``pretrained=True`` unlike all the others.
        return constructors[self.model](pretrained=self.pretrained)