2121from timm .layers import LayerNorm2d , NormMlpClassifierHead , DropPath ,\
2222 trunc_normal_ , resize_rel_pos_bias_table_levit , use_fused_attn
2323from ._builder import build_model_with_cfg
24+ from ._features_fx import register_notrace_module
2425from ._manipulate import checkpoint_seq
2526from ._registry import register_model , generate_default_cfgs
2627
@@ -178,18 +179,15 @@ def __init__(
178179 self .num_heads = num_heads
179180 self .scale = key_dim ** - 0.5
180181 self .key_dim = key_dim
181- self .nh_kd = nh_kd = key_dim * num_heads
182- self .d = int (attn_ratio * key_dim )
183- self .dh = int (attn_ratio * key_dim ) * num_heads
182+ self .val_dim = int (attn_ratio * key_dim )
183+ self .out_dim = self .val_dim * num_heads
184184 self .attn_ratio = attn_ratio
185185 self .resolution = resolution
186186 self .fused_attn = use_fused_attn ()
187187
188- h = self .dh + nh_kd * 2
189-
190188 self .norm = nn .LayerNorm (dim )
191- self .qkv = nn .Linear (dim , h )
192- self .proj = nn .Linear (self .dh , dim )
189+ self .qkv = nn .Linear (dim , num_heads * ( self . val_dim + 2 * key_dim ) )
190+ self .proj = nn .Linear (self .out_dim , dim )
193191
194192 points = list (itertools .product (range (resolution [0 ]), range (resolution [1 ])))
195193 N = len (points )
@@ -227,7 +225,7 @@ def forward(self, x):
227225 x = self .norm (x )
228226 qkv = self .qkv (x )
229227 # (B, N, num_heads, d)
230- q , k , v = qkv .view (B , N , self .num_heads , - 1 ).split ([self .key_dim , self .key_dim , self .d ], dim = 3 )
228+ q , k , v = qkv .view (B , N , self .num_heads , - 1 ).split ([self .key_dim , self .key_dim , self .val_dim ], dim = 3 )
231229 # (B, num_heads, N, d)
232230 q = q .permute (0 , 2 , 1 , 3 )
233231 k = k .permute (0 , 2 , 1 , 3 )
@@ -241,7 +239,7 @@ def forward(self, x):
241239 attn = attn + attn_bias
242240 attn = attn .softmax (dim = - 1 )
243241 x = attn @ v
244- x = x .transpose (1 , 2 ).reshape (B , N , self .dh )
242+ x = x .transpose (1 , 2 ).reshape (B , N , self .out_dim )
245243 x = self .proj (x )
246244 return x
247245
@@ -311,7 +309,6 @@ def forward(self, x):
311309 pad_b = (self .window_size - H % self .window_size ) % self .window_size
312310 pad_r = (self .window_size - W % self .window_size ) % self .window_size
313311 padding = pad_b > 0 or pad_r > 0
314-
315312 if padding :
316313 x = F .pad (x , (0 , 0 , 0 , pad_r , 0 , pad_b ))
317314
@@ -344,6 +341,9 @@ def extra_repr(self) -> str:
344341 f"window_size={ self .window_size } , mlp_ratio={ self .mlp_ratio } "
345342
346343
344+ register_notrace_module (TinyVitBlock )
345+
346+
347347class TinyVitStage (nn .Module ):
348348 """ A basic TinyViT layer for one stage.
349349
0 commit comments