
Commit b2e8426

Make k=stride=2 ('avg2') pooling default for coatnet/maxvit. Add weight links. Rename 'combined' partition to 'parallel'.

1 parent 837c682 commit b2e8426

1 file changed: +27 −27 lines

timm/models/maxxvit.py

Lines changed: 27 additions & 27 deletions
@@ -74,26 +74,26 @@ def _cfg(url='', **kwargs):
     # Fiddling with configs / defaults / still pretraining
     'coatnet_pico_rw_224': _cfg(url=''),
     'coatnet_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
         crop_pct=0.9),
     'coatnet_0_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
     'coatnet_1_rw_224': _cfg(
-        url=''
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
     'coatnet_2_rw_224': _cfg(url=''),

     # Highly experimental configs
     'coatnet_bn_0_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         crop_pct=0.95),
     'coatnet_rmlp_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
         crop_pct=0.9),
     'coatnet_rmlp_0_rw_224': _cfg(url=''),
     'coatnet_rmlp_1_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
     'coatnet_nano_cc_224': _cfg(url=''),
     'coatnext_nano_rw_224': _cfg(url=''),

@@ -107,10 +107,12 @@ def _cfg(url='', **kwargs):

     # Experimental configs
     'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

     # Trying to be like the MaxViT paper configs
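With the checkpoint URLs above wired into the default cfgs, the new weights should be loadable through timm's usual factory. A minimal sketch, assuming a timm build that includes this commit; the dummy input and prints are illustrative only:

import timm
import torch

# Load one of the newly linked checkpoints and run a dummy batch (sketch only).
model = timm.create_model('maxvit_nano_rw_256', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))  # cfg above uses input_size=(3, 256, 256)
print(out.shape)  # expected torch.Size([1, 1000]) with the default ImageNet-1k head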
@@ -131,7 +133,7 @@ class MaxxVitTransformerCfg:
     attn_bias: bool = True
     attn_drop: float = 0.
     proj_drop: float = 0.
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
     window_size: Tuple[int, int] = (7, 7)
@@ -153,7 +155,7 @@ class MaxxVitConvCfg:
     pre_norm_act: bool = False  # activation after pre-norm
     output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
     stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     downsample_pool_type: str = 'avg2'
     attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
     attn_layer: str = 'se'
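Both cfg dataclasses above now default pool_type to 'avg2'. A rough sketch of overriding that back to the old 3x3 'avg' pooling when building a config by hand, assuming the dataclasses are importable from timm.models.maxxvit and all fields keep their defaults:

from timm.models.maxxvit import MaxxVitConvCfg, MaxxVitTransformerCfg

# Explicitly restore the previous 3x3, stride-2 average pooling (sketch, not the library default).
conv_cfg = MaxxVitConvCfg(pool_type='avg')
transformer_cfg = MaxxVitTransformerCfg(pool_type='avg')
print(conv_cfg.pool_type, transformer_cfg.pool_type)  # avg avg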
@@ -241,7 +243,7 @@ def _rw_coat_cfg(


 def _rw_max_cfg(
         stride_mode='dw',
-        pool_type='avg',
+        pool_type='avg2',
         conv_output_bias=False,
         conv_attn_ratio=1 / 16,
         conv_norm_layer='',
@@ -325,7 +327,6 @@ def _next_cfg(
         depths=(2, 3, 5, 2),
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -336,7 +337,6 @@ def _next_cfg(
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
             stride_mode='pool',
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -384,7 +384,6 @@ def _next_cfg(
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
         **_rw_max_cfg(
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
             rel_pos_type='mlp',
@@ -487,10 +486,10 @@ def _next_cfg(
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
-    maxvit_tiny_cm_256=MaxxVitCfg(
+    maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
-        block_type=('CM',) * 4,
+        block_type=('PM',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
@@ -663,13 +662,15 @@ def __init__(
             bias: bool = True,
     ):
         super().__init__()
-        assert pool_type in ('max', 'avg', 'avg2')
+        assert pool_type in ('max', 'max2', 'avg', 'avg2')
         if pool_type == 'max':
             self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        elif pool_type == 'max2':
+            self.pool = nn.MaxPool2d(2)  # kernel_size == stride == 2
         elif pool_type == 'avg':
             self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
         else:
-            self.pool = nn.AvgPool2d(2)
+            self.pool = nn.AvgPool2d(2)  # kernel_size == stride == 2

         if dim != dim_out:
             self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
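For reference, the two average-pooling variants handled by Downsample2d above differ only in window size and overlap; on even-sized feature maps they produce the same output shape. A standalone PyTorch comparison, illustrative only and not part of the commit:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)

# 'avg': 3x3 kernel, stride 2, padding 1 -> overlapping windows
avg = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
# 'avg2': 2x2 kernel, stride 2 -> non-overlapping windows (the new default)
avg2 = nn.AvgPool2d(2)

print(avg(x).shape)   # torch.Size([1, 64, 28, 28])
print(avg2(x).shape)  # torch.Size([1, 64, 28, 28])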
@@ -1073,7 +1074,7 @@ def forward(self, x):
         return x


-class CombinedPartitionAttention(nn.Module):
+class ParallelPartitionAttention(nn.Module):
     """ Experimental. Grid and Block partition + single FFN
     NxC tensor layout.
     """
@@ -1286,7 +1287,7 @@ def forward(self, x):
         return x


-class CombinedMaxxVitBlock(nn.Module):
+class ParallelMaxxVitBlock(nn.Module):
     """
     """

@@ -1309,7 +1310,7 @@ def __init__(
             self.conv = nn.Sequential(*convs)
         else:
             self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
-        self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
+        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)

     def init_weights(self, scheme=''):
         named_apply(partial(_init_transformer, scheme=scheme), self.attn)
@@ -1343,7 +1344,7 @@ def __init__(
         blocks = []
         for i, t in enumerate(block_types):
             block_stride = stride if i == 0 else 1
-            assert t in ('C', 'T', 'M', 'CM')
+            assert t in ('C', 'T', 'M', 'PM')
             if t == 'C':
                 conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                 blocks += [conv_cls(
@@ -1372,8 +1373,8 @@ def __init__(
                     transformer_cfg=transformer_cfg,
                     drop_path=drop_path[i],
                 )]
-            elif t == 'CM':
-                blocks += [CombinedMaxxVitBlock(
+            elif t == 'PM':
+                blocks += [ParallelMaxxVitBlock(
                     in_chs,
                     out_chs,
                     stride=block_stride,
@@ -1415,7 +1416,6 @@ def __init__(
         self.norm1 = norm_act_layer(out_chs[0])
         self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1)

-    @torch.jit.ignore
     def init_weights(self, scheme=''):
         named_apply(partial(_init_conv, scheme=scheme), self)

@@ -1659,8 +1659,8 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):


 @register_model
-def maxvit_tiny_cm_256(pretrained=False, **kwargs):
-    return _create_maxxvit('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs)
+def maxvit_tiny_pm_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)


 @register_model
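After the rename, the model is registered under maxvit_tiny_pm_256 and the old maxvit_tiny_cm_256 entry point is gone. A hypothetical usage sketch; no pretrained weights are linked for this config in this commit, so pretrained stays False:

import timm
import torch

model = timm.create_model('maxvit_tiny_pm_256', pretrained=False)  # renamed from maxvit_tiny_cm_256
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000]) with the default classification head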
