Skip to content

Commit

Permalink
Merge pull request #87 from gdahia/num-classes-not-required
Browse files · Browse the repository at this point in the history
Allow None "num_classes" if not "classify"
  • Loading branch information
timesler authored Jun 14, 2020
2 parents e5a30d7 + 5fd1d05 commit 5dcce36
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions — models/inception_resnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,8 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
tmp_classes = 8631
elif pretrained == 'casia-webface':
tmp_classes = 10575
elif pretrained is None and self.num_classes is None:
raise Exception('At least one of "pretrained" or "num_classes" must be specified')
else:
tmp_classes = self.num_classes
elif pretrained is None and self.classify and self.num_classes is None:
raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified')


# Define layers
Expand Down Expand Up @@ -255,12 +253,12 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
self.dropout = nn.Dropout(dropout_prob)
self.last_linear = nn.Linear(1792, 512, bias=False)
self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True)
self.logits = nn.Linear(512, tmp_classes)

if pretrained is not None:
self.logits = nn.Linear(512, tmp_classes)
load_weights(self, pretrained)

if self.num_classes is not None:
if self.classify and self.num_classes is not None:
self.logits = nn.Linear(512, self.num_classes)

self.device = torch.device('cpu')
Expand Down

0 comments on commit 5dcce36

Please sign in to comment.