 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
-    create_conv2d, make_divisible
+    create_conv2d, get_act_layer, make_divisible, to_ntuple
 from .registry import register_model
 
 
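Note: the import change pulls in `get_act_layer` and `to_ntuple` from the layers package. A minimal sketch of what these helpers do, approximating timm's layer helpers rather than quoting this diff:

```python
# Illustrative sketch only: approximates timm's helpers, not copied from this diff.
from itertools import repeat

def to_ntuple(n):
    # Broadcast a scalar to an n-tuple; leave tuples/lists untouched.
    def parse(x):
        if isinstance(x, (tuple, list)):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse

print(to_ntuple(4)(7))             # (7, 7, 7, 7)
print(to_ntuple(4)((3, 5, 7, 9)))  # (3, 5, 7, 9)

# get_act_layer maps an activation name to its module class, e.g. 'gelu' -> nn.GELU,
# so configs can carry plain strings instead of class references.
```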
@@ -40,14 +39,13 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = dict(
-    convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
-    convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
-    convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
-    convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
-
     # timm specific variants
-    convnext_atto=_cfg(url=''),
-    convnext_atto_ols=_cfg(url=''),
+    convnext_atto=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    convnext_atto_ols=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=0.95),
     convnext_femto=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
@@ -70,16 +68,34 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
+    convnext_tiny=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_small=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_base=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    convnext_large=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
     convnext_tiny_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_small_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_base_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_large_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
     convnext_xlarge_in22ft1k=_cfg(
-        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
     convnext_tiny_384_in22ft1k=_cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
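These cfg entries now carry `test_input_size` and `test_crop_pct`, so evaluation can run at 288x288 with full-image cropping. A hedged sketch of resolving them, assuming timm's `resolve_data_config` helper and its `use_test_size` flag:

```python
# Hedged sketch: resolve the test-time data config for one updated entry.
import timm
from timm.data import resolve_data_config

model = timm.create_model('convnext_tiny', pretrained=False)
cfg = resolve_data_config({}, model=model, use_test_size=True)
print(cfg['input_size'], cfg['crop_pct'])  # expected: (3, 288, 288) 1.0 after this change
```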
@@ -121,37 +137,39 @@ class ConvNeXtBlock(nn.Module):
     is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
 
     Args:
-        dim (int): Number of input channels.
+        in_chs (int): Number of input channels.
         drop_path (float): Stochastic depth rate. Default: 0.0
         ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """
 
     def __init__(
             self,
-            dim,
-            dim_out=None,
+            in_chs,
+            out_chs=None,
+            kernel_size=7,
             stride=1,
             dilation=1,
             mlp_ratio=4,
             conv_mlp=False,
             conv_bias=True,
             ls_init_value=1e-6,
+            act_layer='gelu',
             norm_layer=None,
-            act_layer=nn.GELU,
             drop_path=0.,
     ):
         super().__init__()
-        dim_out = dim_out or dim
+        out_chs = out_chs or in_chs
+        act_layer = get_act_layer(act_layer)
         if not norm_layer:
             norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
 
         self.conv_dw = create_conv2d(
-            dim, dim_out, kernel_size=7, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
-        self.norm = norm_layer(dim_out)
-        self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
+            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+        self.norm = norm_layer(out_chs)
+        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
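With the renamed channel args and the new `kernel_size`/`act_layer` parameters, a block can be built standalone. A hedged usage sketch; the `'silu'` name is just an example activation, and shapes assume the depthwise conv pads to keep spatial size at stride 1:

```python
import torch
from timm.models.convnext import ConvNeXtBlock

# act_layer is now a string resolved via get_act_layer inside the block.
block = ConvNeXtBlock(in_chs=96, out_chs=96, kernel_size=9, act_layer='silu')
y = block(torch.randn(1, 96, 56, 56))
print(y.shape)  # torch.Size([1, 96, 56, 56])
```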
@@ -178,13 +196,15 @@ def __init__(
             self,
             in_chs,
             out_chs,
+            kernel_size=7,
             stride=2,
             depth=2,
             dilation=(1, 1),
             drop_path_rates=None,
             ls_init_value=1.0,
             conv_mlp=False,
             conv_bias=True,
+            act_layer='gelu',
             norm_layer=None,
             norm_layer_cl=None
     ):
@@ -208,13 +228,15 @@ def __init__(
         stage_blocks = []
         for i in range(depth):
             stage_blocks.append(ConvNeXtBlock(
-                dim=in_chs,
-                dim_out=out_chs,
+                in_chs=in_chs,
+                out_chs=out_chs,
+                kernel_size=kernel_size,
                 dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 conv_mlp=conv_mlp,
                 conv_bias=conv_bias,
+                act_layer=act_layer,
                 norm_layer=norm_layer if conv_mlp else norm_layer_cl
             ))
             in_chs = out_chs
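The stage simply forwards `kernel_size` and `act_layer` down to every block. A hedged construction sketch; the explicit norm layers mirror what `ConvNeXt.__init__` passes down and may be required when instantiating a stage directly:

```python
from functools import partial
import torch
import torch.nn as nn
from timm.models.convnext import ConvNeXtStage
from timm.models.layers import LayerNorm2d

stage = ConvNeXtStage(
    96, 192, kernel_size=5, stride=2, depth=2, act_layer='gelu',
    norm_layer=partial(LayerNorm2d, eps=1e-6),      # used by the downsample path
    norm_layer_cl=partial(nn.LayerNorm, eps=1e-6))  # used in blocks when conv_mlp=False
out = stage(torch.randn(1, 96, 56, 56))
print(out.shape)  # torch.Size([1, 192, 28, 28])
```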
@@ -252,19 +274,22 @@ def __init__(
             output_stride=32,
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 768),
+            kernel_sizes=7,
             ls_init_value=1e-6,
             stem_type='patch',
             patch_size=4,
             head_init_scale=1.,
             head_norm_first=False,
             conv_mlp=False,
             conv_bias=True,
+            act_layer='gelu',
             norm_layer=None,
             drop_rate=0.,
             drop_path_rate=0.,
     ):
         super().__init__()
         assert output_stride in (8, 16, 32)
+        kernel_sizes = to_ntuple(4)(kernel_sizes)
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
         norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
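The `to_ntuple(4)` call lets `kernel_sizes` be either a scalar or one value per stage. A hedged sketch of the two accepted forms:

```python
from timm.models.convnext import ConvNeXt

# Scalar: broadcast to (7, 7, 7, 7) by to_ntuple(4) inside __init__.
model_a = ConvNeXt(kernel_sizes=7)
# Per-stage tuple: one depthwise kernel size for each of the four stages.
model_b = ConvNeXt(kernel_sizes=(3, 5, 7, 9), act_layer='gelu')
```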
@@ -312,13 +337,15 @@ def __init__(
         stages.append(ConvNeXtStage(
             prev_chs,
             out_chs,
+            kernel_size=kernel_sizes[i],
             stride=stride,
             dilation=(first_dilation, dilation),
             depth=depths[i],
             drop_path_rates=dp_rates[i],
             ls_init_value=ls_init_value,
             conv_mlp=conv_mlp,
             conv_bias=conv_bias,
+            act_layer=act_layer,
             norm_layer=norm_layer,
             norm_layer_cl=norm_layer_cl
         ))
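An end-to-end spot check of the new cfg plumbing; a hedged sketch that assumes the release-tag weight URL above resolves (model name taken from this diff):

```python
import torch
import timm

m = timm.create_model('convnext_atto', pretrained=True).eval()
with torch.no_grad():
    logits = m(torch.randn(1, 3, 288, 288))  # matches the new test_input_size
print(logits.shape)  # torch.Size([1, 1000])
```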