Skip to content

Commit 507cb08

Browse files
committed
TinyVitBlock now needs to be registered as a leaf module for FX tracing; also tweak a few dim names
1 parent 9caf32b commit 507cb08

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

timm/models/tiny_vit.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
2222
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
2323
from ._builder import build_model_with_cfg
24+
from ._features_fx import register_notrace_module
2425
from ._manipulate import checkpoint_seq
2526
from ._registry import register_model, generate_default_cfgs
2627

@@ -178,18 +179,15 @@ def __init__(
178179
self.num_heads = num_heads
179180
self.scale = key_dim ** -0.5
180181
self.key_dim = key_dim
181-
self.nh_kd = nh_kd = key_dim * num_heads
182-
self.d = int(attn_ratio * key_dim)
183-
self.dh = int(attn_ratio * key_dim) * num_heads
182+
self.val_dim = int(attn_ratio * key_dim)
183+
self.out_dim = self.val_dim * num_heads
184184
self.attn_ratio = attn_ratio
185185
self.resolution = resolution
186186
self.fused_attn = use_fused_attn()
187187

188-
h = self.dh + nh_kd * 2
189-
190188
self.norm = nn.LayerNorm(dim)
191-
self.qkv = nn.Linear(dim, h)
192-
self.proj = nn.Linear(self.dh, dim)
189+
self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim))
190+
self.proj = nn.Linear(self.out_dim, dim)
193191

194192
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
195193
N = len(points)
@@ -227,7 +225,7 @@ def forward(self, x):
227225
x = self.norm(x)
228226
qkv = self.qkv(x)
229227
# (B, N, num_heads, d)
230-
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
228+
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
231229
# (B, num_heads, N, d)
232230
q = q.permute(0, 2, 1, 3)
233231
k = k.permute(0, 2, 1, 3)
@@ -241,7 +239,7 @@ def forward(self, x):
241239
attn = attn + attn_bias
242240
attn = attn.softmax(dim=-1)
243241
x = attn @ v
244-
x = x.transpose(1, 2).reshape(B, N, self.dh)
242+
x = x.transpose(1, 2).reshape(B, N, self.out_dim)
245243
x = self.proj(x)
246244
return x
247245

@@ -311,7 +309,6 @@ def forward(self, x):
311309
pad_b = (self.window_size - H % self.window_size) % self.window_size
312310
pad_r = (self.window_size - W % self.window_size) % self.window_size
313311
padding = pad_b > 0 or pad_r > 0
314-
315312
if padding:
316313
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
317314

@@ -344,6 +341,9 @@ def extra_repr(self) -> str:
344341
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
345342

346343

344+
register_notrace_module(TinyVitBlock)
345+
346+
347347
class TinyVitStage(nn.Module):
348348
""" A basic TinyViT layer for one stage.
349349

0 commit comments

Comments (0)