
Commit 2e83bba

Revert head norm changes to ConvNeXt as they broke some downstream use; alternate workaround for fcmae weights
1 parent 2c24cb9 commit 2e83bba

File tree

1 file changed (+7, -10 lines)


timm/models/convnext.py

Lines changed: 7 additions & 10 deletions
@@ -301,11 +301,10 @@ def __init__(
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.head_norm_first = head_norm_first
         self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
             ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
+            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
             ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
             ('drop', nn.Dropout(self.drop_rate)),
             ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
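The hunk above restores the unconditional head norm: unless `head_norm_first` is set, `head.norm` is always created, regardless of `num_classes`, so the default pool -> norm -> fc ordering (and its state_dict keys) stays stable for downstream users. A minimal sketch of that default ordering, with assumed feature/class sizes and a plain `nn.LayerNorm` standing in for the model's `norm_layer`:

```python
import torch
import torch.nn as nn

num_features, num_classes = 768, 1000     # assumed sizes, for illustration only
pooled = torch.randn(2, num_features)     # stand-in for globally pooled features

# default ConvNeXt head ordering (pretrained FB weights): pool -> norm -> fc;
# after this commit, 'norm' is a real layer even when num_classes == 0
head = nn.Sequential(
    nn.LayerNorm(num_features),           # head.norm (norm_layer in the real model)
    nn.Linear(num_features, num_classes), # head.fc
)
print(head(pooled).shape)                 # torch.Size([2, 1000])
```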
@@ -336,14 +335,7 @@ def reset_classifier(self, num_classes=0, global_pool=None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        if num_classes == 0:
-            self.head.norm = nn.Identity()
-            self.head.fc = nn.Identity()
-        else:
-            if not self.head_norm_first:
-                norm_layer = type(self.stem[-1])  # obtain type from stem norm
-                self.head.norm = norm_layer(self.num_features)
-            self.head.fc = nn.Linear(self.num_features, num_classes)
+        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
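With `head.norm` now fixed at construction time, `reset_classifier` no longer has to rebuild it (or recover the norm type from the stem); it only swaps `head.fc`. A minimal sketch of the resulting behavior on a stand-alone head module, with assumed sizes:

```python
from collections import OrderedDict
import torch.nn as nn

# stand-in head mirroring the structure built in __init__ (sizes are assumptions)
num_features = 768
head = nn.Sequential(OrderedDict([
    ('norm', nn.LayerNorm(num_features)),
    ('fc', nn.Linear(num_features, 1000)),
]))

# reset to feature-extraction mode: only fc changes, norm is left untouched
num_classes = 0
head.fc = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
assert isinstance(head.norm, nn.LayerNorm)  # norm survives the reset
```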
@@ -407,6 +399,11 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_convnext(variant, pretrained=False, **kwargs):
+    if kwargs.get('pretrained_cfg', '') == 'fcmae':
+        # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
+        # this is a workaround for loading with num_classes=0 w/o removing the norm-layer
+        kwargs.setdefault('pretrained_strict', False)
+
     model = build_model_with_cfg(
         ConvNeXt, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
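Because fcmae (ConvNeXt masked-autoencoder) checkpoints ship without `head.fc` and without the final `head.norm`, the guard above relaxes strict state_dict loading instead of mutating the model. A hedged usage sketch; the variant name is illustrative and not taken from this commit:

```python
import timm

# hypothetical call: passing an 'fcmae' pretrained_cfg tag would trigger the
# pretrained_strict=False default added in _create_convnext above
model = timm.create_model(
    'convnext_base',         # illustrative variant name
    pretrained=True,
    pretrained_cfg='fcmae',  # matches the kwargs.get('pretrained_cfg', '') check
    num_classes=0,           # feature extraction; missing classifier keys are tolerated
)
```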
