
Commit 5b28ef4

Merge pull request #475 from rwightman/pre-release/0.4.5

Prep for PyPI release. Tweak NFNet, ResNetV2, and RexNet feature extraction to use pre-act features. Also add act args to RexNet (#202).

2 parents: 5fcdd4b + f57db99

File tree: 9 files changed (+39 −52 lines)

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@ jobs:
       matrix:
         os: [ubuntu-latest, macOS-latest]
         python: ['3.8']
-        torch: ['1.7.0']
-        torchvision: ['0.8.1']
+        torch: ['1.8.0']
+        torchvision: ['0.9.0']
     runs-on: ${{ matrix.os }}

     steps:

README.md

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,10 @@
 
 ## What's New
 
+### March 7, 2021
+* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
+* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
+
 ### Feb 18, 2021
 * Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
 * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
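For context, a minimal usage sketch of the pre-act feature change (the model name and the `features_only` call illustrate timm's feature-extraction API; they are not part of this diff):

```python
import torch
import timm

# Sketch: with this release, features_only models built from pre-activation
# nets (NFNets, ResNetV2) return feature maps taken before the activation.
model = timm.create_model('resnetv2_50x1_bitm', features_only=True, pretrained=False)
features = model(torch.randn(1, 3, 224, 224))
for f, c in zip(features, model.feature_info.channels()):
    print(f.shape, c)  # one pre-act feature map per feature_info entry
```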

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
     # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
-        '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS
+        '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS
 else:
     EXCLUDE_FILTERS = NON_STD_FILTERS
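For reference, these exclusion entries are shell-style wildcard patterns matched against model names; a small standalone sketch, assuming `fnmatch` semantics:

```python
import fnmatch

# Sketch: any pattern match excludes the model from the test run.
EXCLUDE_FILTERS = ['*nfnet_f3*', '*nfnet_f4*']
names = ['nfnet_f0', 'nfnet_f3', 'dm_nfnet_f4']
kept = [n for n in names if not any(fnmatch.fnmatch(n, f) for f in EXCLUDE_FILTERS)]
print(kept)  # -> ['nfnet_f0']
```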

timm/models/cspnet.py

Lines changed: 4 additions & 2 deletions
@@ -264,9 +264,11 @@ def forward(self, x):
         if self.conv_down is not None:
             x = self.conv_down(x)
         x = self.conv_exp(x)
-        xs, xb = x.chunk(2, dim=1)
+        split = x.shape[1] // 2
+        xs, xb = x[:, :split], x[:, split:]
         xb = self.blocks(xb)
-        out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1))
+        xb = self.conv_transition_b(xb).contiguous()
+        out = self.conv_transition(torch.cat([xs, xb], dim=1))
         return out
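A quick sanity check of the `chunk`-to-slicing swap (a standalone sketch; the equivalence holds for any even channel count, and the export/tracing-friendliness rationale is an assumption, not stated in the diff):

```python
import torch

# Sketch: for an even channel count, manual slicing matches torch.chunk.
x = torch.randn(2, 8, 4, 4)
xs_c, xb_c = x.chunk(2, dim=1)
split = x.shape[1] // 2
xs_s, xb_s = x[:, :split], x[:, split:]
assert torch.equal(xs_c, xs_s) and torch.equal(xb_c, xb_s)
```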

timm/models/layers/inplace_abn.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 def inplace_abn(x, weight, bias, running_mean, running_var,
                 training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
     raise ImportError(
-        "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
+        "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
 
 def inplace_abn_sync(**kwargs):
     inplace_abn(**kwargs)

timm/models/nfnet.py

Lines changed: 11 additions & 15 deletions
@@ -101,11 +101,12 @@ def _dcfg(url='', **kwargs):
         url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
 
     nfnet_l0a=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
     nfnet_l0b=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
     nfnet_l0c=_dcfg(
-        url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0c-ad1045c2.pth',
+        pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
 
     nf_regnet_b0=_dcfg(
         url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),

@@ -376,9 +377,9 @@ def forward(self, x):
         return out
 
 
-def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
+def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None, preact_feature=True):
     stem_stride = 2
-    stem_feature = dict(num_chs=out_chs, reduction=2, module='')
+    stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
     stem = OrderedDict()
     assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type:

@@ -388,14 +389,14 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
             stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
             strides = (2, 1, 1, 2)
             stem_stride = 4
-            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act4')
+            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3')
         else:
             if 'tiered' in stem_type:
                 stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs)  # 'T' resnets in resnet.py
             else:
                 stem_chs = (out_chs // 2, out_chs // 2, out_chs)  # 'D' ResNets
             strides = (2, 1, 1)
-            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
+            stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
     last_idx = len(stem_chs) - 1
     for i, (c, s) in enumerate(zip(stem_chs, strides)):
         stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)

@@ -477,7 +478,7 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
         self.stem, stem_stride, stem_feat = create_stem(
             in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
 
-        self.feature_info = [stem_feat] if stem_stride == 4 else []
+        self.feature_info = [stem_feat]
         drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
         prev_chs = stem_chs
         net_stride = stem_stride

@@ -486,8 +487,6 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
         stages = []
         for stage_idx, stage_depth in enumerate(cfg.depths):
             stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
-            if stride == 2:
-                self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
             if net_stride >= output_stride and stride > 1:
                 dilation *= stride
                 stride = 1

@@ -522,18 +521,19 @@ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg',
                 expected_var += cfg.alpha ** 2  # Even if reset occurs, increment expected variance
                 first_dilation = dilation
                 prev_chs = out_chs
+            self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')]
             stages += [nn.Sequential(*blocks)]
         self.stages = nn.Sequential(*stages)
 
         if cfg.num_features:
             # The paper NFRegNet models have an EfficientNet-like final head convolution.
             self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
             self.final_conv = conv_layer(prev_chs, self.num_features, 1)
+            self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv')
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
         self.final_act = act_layer(inplace=cfg.num_features > 0)
-        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
 
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

@@ -572,10 +572,6 @@ def forward(self, x):
 def _create_normfreenet(variant, pretrained=False, **kwargs):
     model_cfg = model_cfgs[variant]
     feature_cfg = dict(flatten_sequential=True)
-    feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
-    if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
-        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2 feat for stride 4, 1 layer maxpool stems
-
     return build_model_with_cfg(
         NormFreeNet, variant, pretrained,
         default_cfg=default_cfgs[variant],
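With feature maps now reported at whole-stage (pre-activation) granularity, hook-based extraction is no longer needed and stage selection is direct. A sketch, assuming timm's `features_only`/`out_indices` interface (`nf_resnet50` is illustrative; any NormFreeNet variant applies):

```python
import torch
import timm

# Sketch: out_indices selects whole stages from the revised feature_info.
model = timm.create_model('nf_resnet50', features_only=True, out_indices=(1, 2, 3), pretrained=False)
for f in model(torch.randn(1, 3, 224, 224)):
    print(f.shape)  # one map per selected stage, strides 8/16/32
```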

timm/models/resnetv2.py

Lines changed: 3 additions & 12 deletions
@@ -323,8 +323,8 @@ def __init__(self, layers, channels=(256, 512, 1024, 2048),
         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
         self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
-        # NOTE no, reduction 2 feature if preact
-        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
+        stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
+        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
 
         prev_chs = stem_chs
         curr_stride = 4

@@ -343,10 +343,7 @@ def __init__(self, layers, channels=(256, 512, 1024, 2048),
                 act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
             prev_chs = out_chs
             curr_stride *= stride
-            feat_name = f'stages.{stage_idx}'
-            if preact:
-                feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
-            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
+            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
             self.stages.add_module(str(stage_idx), stage)
 
         self.num_features = prev_chs

@@ -414,13 +411,7 @@ def load_pretrained(self, checkpoint_path, prefix='resnet/'):
 
 
 def _create_resnetv2(variant, pretrained=False, **kwargs):
-    # FIXME feature map extraction is not setup properly for pre-activation mode right now
-    preact = kwargs.get('preact', True)
     feature_cfg = dict(flatten_sequential=True)
-    if preact:
-        feature_cfg['feature_cls'] = 'hook'
-        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for preact
-
     return build_model_with_cfg(
         ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
         feature_cfg=feature_cfg, **kwargs)

timm/models/rexnet.py

Lines changed: 12 additions & 18 deletions
@@ -71,22 +71,23 @@ def forward(self, x):
 
 
 class LinearBottleneck(nn.Module):
-    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, drop_path=None):
+    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
+                 act_layer='swish', dw_act_layer='relu6', drop_path=None):
         super(LinearBottleneck, self).__init__()
         self.use_shortcut = stride == 1 and in_chs <= out_chs
         self.in_channels = in_chs
         self.out_channels = out_chs
 
         if exp_ratio != 1.:
             dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
-            self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer="swish")
+            self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer)
         else:
             dw_chs = in_chs
             self.conv_exp = None
 
         self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
         self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
-        self.act_dw = nn.ReLU6()
+        self.act_dw = create_act_layer(dw_act_layer)
 
         self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
         self.drop_path = drop_path

@@ -131,8 +132,7 @@ def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se
 
 
 def _build_blocks(
-        block_cfg, prev_chs, width_mult, ch_div=1, drop_path_rate=0., feature_location='bottleneck'):
-    feat_exp = feature_location == 'expansion'
+        block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.):
     feat_chs = [prev_chs]
     feature_info = []
     curr_stride = 2

@@ -141,41 +141,37 @@
     for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
         if stride > 1:
             fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
-            if block_idx > 0 and feat_exp:
-                fname += '.act_dw'
             feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
             curr_stride *= stride
         block_dpr = drop_path_rate * block_idx / (num_blocks - 1)  # stochastic depth linear decay rule
         drop_path = DropPath(block_dpr) if block_dpr > 0. else None
         features.append(LinearBottleneck(
             in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
-            ch_div=ch_div, drop_path=drop_path))
+            ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path))
         prev_chs = chs
-        feat_chs += [features[-1].feat_channels(feat_exp)]
+        feat_chs += [features[-1].feat_channels()]
     pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
-    feature_info += [dict(
-        num_chs=pen_chs if feat_exp else feat_chs[-1], reduction=curr_stride,
-        module=f'features.{len(features) - int(not feat_exp)}')]
-    features.append(ConvBnAct(prev_chs, pen_chs, act_layer="swish"))
+    feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')]
+    features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer))
     return features, feature_info
 
 
 class ReXNetV1(nn.Module):
     def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
                  initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
-                 ch_div=1, drop_rate=0.2, drop_path_rate=0., feature_location='bottleneck'):
+                 ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.):
         super(ReXNetV1, self).__init__()
         self.drop_rate = drop_rate
         self.num_classes = num_classes
 
         assert output_stride == 32  # FIXME support dilation
         stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
         stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
-        self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish')
+        self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
 
         block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
         features, self.feature_info = _build_blocks(
-            block_cfg, stem_chs, width_mult, ch_div, drop_path_rate, feature_location)
+            block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
         self.num_features = features[-1].out_channels
         self.features = nn.Sequential(*features)

@@ -202,8 +198,6 @@ def forward(self, x):
 
 def _create_rexnet(variant, pretrained, **kwargs):
     feature_cfg = dict(flatten_sequential=True)
-    if kwargs.get('feature_location', '') == 'expansion':
-        feature_cfg['feature_cls'] = 'hook'
     return build_model_with_cfg(
         ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs)
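To illustrate the new activation arguments, a usage sketch (kwargs passed to `create_model` flow through to `ReXNetV1`; the chosen activation names are assumptions that must be registered in timm's activation factory):

```python
import timm

# Sketch: act_layer / dw_act_layer now route through create_model into
# ReXNetV1, replacing the previously hard-coded swish / ReLU6.
model = timm.create_model('rexnet_100', act_layer='relu', dw_act_layer='relu6', pretrained=False)
```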

timm/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '0.4.4'
+__version__ = '0.4.5'
