Skip to content

Commit e0c4eec

Browse files
committed
Default conv_mlp to False across the board for ConvNeXt; conv_mlp=True is causing issues on more setups than it's improving right now.
1 parent b669f4a commit e0c4eec

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

timm/models/convnext.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ class ConvNeXtBlock(nn.Module):
116116
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
117117
"""
118118

119-
def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None):
119+
def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None):
120120
super().__init__()
121121
if not norm_layer:
122122
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
@@ -148,7 +148,7 @@ def forward(self, x):
148148
class ConvNeXtStage(nn.Module):
149149

150150
def __init__(
151-
self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=True,
151+
self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False,
152152
norm_layer=None, cl_norm_layer=None, cross_stage=False):
153153
super().__init__()
154154

@@ -190,7 +190,7 @@ class ConvNeXt(nn.Module):
190190

191191
def __init__(
192192
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,
193-
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=True,
193+
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False,
194194
head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0.,
195195
):
196196
super().__init__()
@@ -356,7 +356,7 @@ def convnext_base(pretrained=False, **kwargs):
356356

357357
@register_model
358358
def convnext_large(pretrained=False, **kwargs):
359-
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], conv_mlp=False, **kwargs)
359+
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
360360
model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
361361
return model
362362

@@ -370,14 +370,14 @@ def convnext_base_in22ft1k(pretrained=False, **kwargs):
370370

371371
@register_model
372372
def convnext_large_in22ft1k(pretrained=False, **kwargs):
373-
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], conv_mlp=False, **kwargs)
373+
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
374374
model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args)
375375
return model
376376

377377

378378
@register_model
379379
def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
380-
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], conv_mlp=False, **kwargs)
380+
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
381381
model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args)
382382
return model
383383

@@ -391,14 +391,14 @@ def convnext_base_384_in22ft1k(pretrained=False, **kwargs):
391391

392392
@register_model
393393
def convnext_large_384_in22ft1k(pretrained=False, **kwargs):
394-
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], conv_mlp=False, **kwargs)
394+
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
395395
model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args)
396396
return model
397397

398398

399399
@register_model
400400
def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
401-
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], conv_mlp=False, **kwargs)
401+
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
402402
model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args)
403403
return model
404404

@@ -412,14 +412,14 @@ def convnext_base_in22k(pretrained=False, **kwargs):
412412

413413
@register_model
414414
def convnext_large_in22k(pretrained=False, **kwargs):
415-
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], conv_mlp=False, **kwargs)
415+
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
416416
model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args)
417417
return model
418418

419419

420420
@register_model
421421
def convnext_xlarge_in22k(pretrained=False, **kwargs):
422-
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], conv_mlp=False, **kwargs)
422+
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
423423
model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
424424
return model
425425

0 commit comments

Comments
 (0)