@@ -301,11 +301,10 @@ def __init__(
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.head_norm_first = head_norm_first
         self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
                 ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-                ('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
+                ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
                 ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
                 ('drop', nn.Dropout(self.drop_rate)),
                 ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
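For context, a minimal usage sketch of the two head orderings after this change (not part of the diff; `convnext_tiny` and passing `head_norm_first` through `timm.create_model` kwargs are assumptions):

```python
import torch
import timm  # assumes a timm checkout that includes this ConvNeXt change

x = torch.randn(1, 3, 224, 224)

# default ordering (matches the FB pretrained weights): global_pool -> head.norm -> fc;
# with this patch head.norm is created even when num_classes=0
m = timm.create_model('convnext_tiny', num_classes=0)
feats = m(x)  # pooled + normed features, shape (1, m.num_features)

# head_norm_first=True: norm_pre -> global_pool -> fc, and head.norm stays Identity
m2 = timm.create_model('convnext_tiny', num_classes=0, head_norm_first=True)
feats2 = m2(x)
```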
@@ -336,14 +335,7 @@ def reset_classifier(self, num_classes=0, global_pool=None):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        if num_classes == 0:
-            self.head.norm = nn.Identity()
-            self.head.fc = nn.Identity()
-        else:
-            if not self.head_norm_first:
-                norm_layer = type(self.stem[-1])  # obtain type from stem norm
-                self.head.norm = norm_layer(self.num_features)
-            self.head.fc = nn.Linear(self.num_features, num_classes)
+        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
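With the simplified `reset_classifier` above, only `head.fc` is swapped and `head.norm` stays as built in `__init__`. A rough, illustrative sketch (not part of the diff):

```python
import torch
import timm  # assumes a timm checkout that includes this change

model = timm.create_model('convnext_tiny', num_classes=10)

model.reset_classifier(0)   # fc -> nn.Identity(); head.norm is left in place
feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)          # (1, model.num_features): pooled, normed features

model.reset_classifier(5)   # attach a fresh 5-way classifier head
print(model(torch.randn(1, 3, 224, 224)).shape)  # (1, 5)
```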
@@ -407,6 +399,11 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_convnext(variant, pretrained=False, **kwargs):
+    if kwargs.get('pretrained_cfg', '') == 'fcmae':
+        # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
+        # This is a workaround for loading with num_classes=0 w/o removing the norm-layer.
+        kwargs.setdefault('pretrained_strict', False)
+
     model = build_model_with_cfg(
         ConvNeXt, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
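The guard above only relaxes strict checkpoint loading when the selected pretrained cfg is the FCMAE one. An illustrative call, assuming an FCMAE weight tag such as `convnextv2_base.fcmae` is registered (the exact model name/tag is not part of this hunk):

```python
import timm  # assumes FCMAE (ConvNeXt-V2 masked autoencoder) weight configs are registered

# The '.fcmae' pretrained weights ship without a classifier or `head.norm`, so the
# pretrained_strict=False default set above lets the remaining weights load cleanly
# when building a headless backbone.
backbone = timm.create_model('convnextv2_base.fcmae', pretrained=True, num_classes=0)
```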