@@ -109,6 +109,8 @@ def _cfg(url='', **kwargs):
     'vit_giant_patch14_224': _cfg(url=''),
     'vit_gigantic_patch14_224': _cfg(url=''),
 
+    'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
+
     # patch models, imagenet21k (weights from official Google JAX impl)
     'vit_tiny_patch16_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
@@ -202,6 +204,7 @@ def _cfg(url='', **kwargs):
 class Attention(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
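For reference, the added assert guards the head split on the following line: head_dim = dim // num_heads only partitions the embedding cleanly when dim is a multiple of num_heads. A minimal check using the values from the vit_base2 config registered below (illustration only, not part of the diff):

    dim, num_heads = 896, 14
    assert dim % num_heads == 0   # 896 = 14 * 64, so the new assert passes
    head_dim = dim // num_heads   # 64 channels per attention head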
@@ -634,6 +637,16 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base2_patch32_256(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/32)
+    # FIXME experiment
+    """
+    model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
+    model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_patch32_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).