
Commit 4e24f75

Merge pull request #1593 from rwightman/multi-weight_effnet_convnext
Update efficientnet.py and convnext.py to multi-weight, add new 12k pretrained weights
2 parents 18ec173 + 8ece53e commit 4e24f75

26 files changed: 2090 additions, 1579 deletions

.gitignore

Lines changed: 10 additions & 0 deletions
@@ -106,6 +106,16 @@ output/
 *.tar
 *.pth
 *.pt
+*.torch
 *.gz
 Untitled.ipynb
 Testing notebook.ipynb
+
+# Root dir exclusions
+/*.csv
+/*.yaml
+/*.json
+/*.jpg
+/*.png
+/*.zip
+/*.tar.*

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
+    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

timm/data/dataset_factory.py

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def create_dataset(
     elif name.startswith('hfds/'):
         # NOTE right now, HF datasets default arrow format is a random-access Dataset,
         # There will be a IterableDataset variant too, TBD
-        ds = ImageDataset(root, reader=name, split=split, **kwargs)
+        ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
     elif name.startswith('tfds/'):
         ds = IterableImageDataset(
             root,
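
For context, a minimal sketch of what this change enables: class_map is now forwarded from create_dataset to the Hugging Face datasets reader. The dataset name and class-map path below are hypothetical placeholders, not part of the commit.

from timm.data import create_dataset

# Hypothetical usage: remap labels of an HF hub dataset via a user-supplied class map.
ds = create_dataset(
    'hfds/my-org/my-image-dataset',   # hypothetical 'hfds/' prefixed hub dataset name
    root='./hf_cache',
    split='validation',
    class_map='./class_map.txt',      # hypothetical mapping file, now passed through to the reader
)
img, label = ds[0]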

timm/data/readers/reader_factory.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@

 def create_reader(name, root, split='train', **kwargs):
     name = name.lower()
-    name = name.split('/', 2)
+    name = name.split('/', 1)
     prefix = ''
     if len(name) > 1:
         prefix = name[0]
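
The maxsplit change matters when the dataset name itself contains extra slashes (e.g. a Hugging Face hub 'org/name' path); a quick illustration with a hypothetical name:

name = 'hfds/my-org/my-dataset'   # hypothetical reader-prefixed name

name.split('/', 2)   # old behaviour: ['hfds', 'my-org', 'my-dataset'] -- hub path split apart
name.split('/', 1)   # new behaviour: ['hfds', 'my-org/my-dataset'] -- prefix plus intact hub path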

timm/data/readers/reader_hfds.py

Lines changed: 15 additions & 5 deletions
@@ -13,13 +13,14 @@
 except ImportError as e:
     print("Please install Hugging Face datasets package `pip install datasets`.")
     exit(1)
+from .class_map import load_class_map
 from .reader import Reader


-def get_class_labels(info):
+def get_class_labels(info, label_key='label'):
     if 'label' not in info.features:
         return {}
-    class_label = info.features['label']
+    class_label = info.features[label_key]
     class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
     return class_to_idx

@@ -32,6 +33,7 @@ def __init__(
             name,
             split='train',
             class_map=None,
+            label_key='label',
             download=False,
     ):
         """
@@ -43,12 +45,17 @@
             name,  # 'name' maps to path arg in hf datasets
             split=split,
             cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
-            #use_auth_token=True,
         )
         # leave decode for caller, plus we want easy access to original path names...
         self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))

-        self.class_to_idx = get_class_labels(self.dataset.info)
+        self.label_key = label_key
+        self.remap_class = False
+        if class_map:
+            self.class_to_idx = load_class_map(class_map)
+            self.remap_class = True
+        else:
+            self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
         self.split_info = self.dataset.info.splits[split]
         self.num_samples = self.split_info.num_examples

@@ -60,7 +67,10 @@ def __getitem__(self, index):
         else:
             assert 'path' in image and image['path']
             image = open(image['path'], 'rb')
-        return image, item['label']
+        label = item[self.label_key]
+        if self.remap_class:
+            label = self.class_to_idx[label]
+        return image, label

     def __len__(self):
         return len(self.dataset)
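
A hedged sketch of the new class_map / label_key arguments in use; the dataset name and label column are hypothetical, and class_map.txt is assumed to be a plain text file with one class name per line (the format load_class_map accepts):

from timm.data.readers.reader_hfds import ReaderHfds

reader = ReaderHfds(
    root='./hf_cache',
    name='my-org/scenes',          # hypothetical HF hub dataset with an 'image' column
    split='train',
    class_map='./class_map.txt',   # sets remap_class=True; labels are looked up in this map
    label_key='scene',             # read labels from item['scene'] instead of item['label']
)
img_file, label = reader[0]        # file-like image handle, label remapped via the class map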

timm/layers/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier
 from .cond_conv2d import CondConv2d, get_condconv_initializer
@@ -30,8 +31,12 @@
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
-from .patch_embed import PatchEmbed
+from .patch_embed import PatchEmbed, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
+from .pos_embed import resample_abs_pos_embed
+from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
+from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
+    FourierEmbed, RotaryEmbedding
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
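
The practical effect of these export additions is that the resampling and position-embedding helpers become importable from timm.layers directly; a quick smoke check, assuming a timm build that includes this commit:

from timm.layers import (
    resample_patch_embed,
    resample_abs_pos_embed,
    RelPosBias,
    RotaryEmbedding,
    build_sincos2d_pos_embed,
)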

timm/layers/attention_pool2d.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 import torch.nn as nn

 from .helpers import to_2tuple
-from .pos_embed import apply_rot_embed, RotaryEmbedding
+from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
 from .weight_init import trunc_normal_


timm/layers/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 def _ntuple(n):
     def parse(x):
         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
-            return x
+            return tuple(x)
         return tuple(repeat(x, n))
     return parse
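
For reference, the effect of this one-line fix: iterable inputs now come back as tuples instead of being passed through unchanged, so a list argument no longer leaks into code that expects a tuple.

from timm.layers.helpers import to_2tuple

to_2tuple(7)         # (7, 7) -- scalar behaviour unchanged
to_2tuple([3, 5])    # old: [3, 5] (the original list); new: (3, 5)
to_2tuple((4, 4))    # (4, 4)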

timm/layers/patch_embed.py

Lines changed: 129 additions & 1 deletion
@@ -2,15 +2,24 @@

 A convolution based approach to patchifying a 2D image w/ embedding projection.

-Based on the impl in https://github.com/google-research/vision_transformer
+Based on code in:
+  * https://github.com/google-research/vision_transformer
+  * https://github.com/google-research/big_vision/tree/main/big_vision

 Hacked together by / Copyright 2020 Ross Wightman
 """
+import logging
+from typing import List
+
+import torch
 from torch import nn as nn
+import torch.nn.functional as F

 from .helpers import to_2tuple
 from .trace_utils import _assert

+_logger = logging.getLogger(__name__)
+

 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding
@@ -46,3 +55,122 @@ def forward(self, x):
         x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
         x = self.norm(x)
         return x
+
+
+def resample_patch_embed(
+        patch_embed,
+        new_size: List[int],
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+        verbose: bool = False,
+):
+    """Resample the weights of the patch embedding kernel to target resolution.
+    We resample the patch embedding kernel by approximately inverting the effect
+    of patch resizing.
+
+    Code based on:
+      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
+
+    With this resizing, we can for example load a B/8 filter into a B/16 model
+    and, on 2x larger input image, the result will match.
+
+    Args:
+        patch_embed: original parameter to be resized.
+        new_size (tuple(int, int): target shape (height, width)-only.
+        interpolation (str): interpolation for resize
+        antialias (bool): use anti-aliasing filter in resize
+        verbose (bool): log operation
+    Returns:
+        Resized patch embedding kernel.
+    """
+    import numpy as np
+
+    assert len(patch_embed.shape) == 4, "Four dimensions expected"
+    assert len(new_size) == 2, "New shape should only be hw"
+    old_size = patch_embed.shape[-2:]
+    if tuple(old_size) == tuple(new_size):
+        return patch_embed
+
+    if verbose:
+        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
+
+    def resize(x_np, _new_size):
+        x_tf = torch.Tensor(x_np)[None, None, ...]
+        x_upsampled = F.interpolate(
+            x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
+        return x_upsampled
+
+    def get_resize_mat(_old_size, _new_size):
+        mat = []
+        for i in range(np.prod(_old_size)):
+            basis_vec = np.zeros(_old_size)
+            basis_vec[np.unravel_index(i, _old_size)] = 1.
+            mat.append(resize(basis_vec, _new_size).reshape(-1))
+        return np.stack(mat).T
+
+    resize_mat = get_resize_mat(old_size, new_size)
+    resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
+
+    def resample_kernel(kernel):
+        resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
+        return resampled_kernel.reshape(new_size)
+
+    v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
+    return v_resample_kernel(patch_embed)
+
+
+# def divs(n, m=None):
+#     m = m or n // 2
+#     if m == 1:
+#         return [1]
+#     if n % m == 0:
+#         return [m] + divs(n, m - 1)
+#     return divs(n, m - 1)
+#
+#
+# class FlexiPatchEmbed(nn.Module):
+#     """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
+#     FIXME WIP
+#     """
+#     def __init__(
+#             self,
+#             img_size=240,
+#             patch_size=16,
+#             in_chans=3,
+#             embed_dim=768,
+#             base_img_size=240,
+#             base_patch_size=32,
+#             norm_layer=None,
+#             flatten=True,
+#             bias=True,
+#     ):
+#         super().__init__()
+#         self.img_size = to_2tuple(img_size)
+#         self.patch_size = to_2tuple(patch_size)
+#         self.num_patches = 0
+#
+#         # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
+#         self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
+#
+#         self.base_img_size = to_2tuple(base_img_size)
+#         self.base_patch_size = to_2tuple(base_patch_size)
+#         self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
+#         self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
+#
+#         self.flatten = flatten
+#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
+#         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+#
+#     def forward(self, x):
+#         B, C, H, W = x.shape
+#
+#         if self.patch_size == self.base_patch_size:
+#             weight = self.proj.weight
+#         else:
+#             weight = resample_patch_embed(self.proj.weight, self.patch_size)
+#         patch_size = self.patch_size
+#         x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
+#         if self.flatten:
+#             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+#         x = self.norm(x)
+#         return x
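
A hedged usage sketch for the new resample_patch_embed above; the tensor shape and sizes are illustrative only, and torch.vmap is assumed to be available (PyTorch 2.0+).

import torch
from timm.layers import resample_patch_embed

# Hypothetical example: a ViT patch projection weight of shape
# (embed_dim, in_chans, 16, 16) resampled for an 8x8 patch size.
weight = torch.randn(768, 3, 16, 16)
weight_8x8 = resample_patch_embed(
    weight,
    new_size=[8, 8],
    interpolation='bicubic',
    antialias=True,
    verbose=True,            # logs the resize via the module logger
)
print(weight_8x8.shape)      # torch.Size([768, 3, 8, 8])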
