Commit dd2c141
Fix tracing of attention module with attn_mask support
1 parent: 162f492

File tree

2 files changed: +8 −7 lines

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .attention import Attention, AttentionRope
+from .attention import Attention, AttentionRope, maybe_add_mask
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
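
With the re-export above, the new helper is importable from timm.layers alongside the attention classes, e.g.:

from timm.layers import Attention, AttentionRope, maybe_add_mask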

timm/layers/attention.py

Lines changed: 7 additions & 6 deletions
@@ -8,6 +8,11 @@
 from .pos_embed_sincos import apply_rot_embed_cat


+@torch.fx.wrap
+def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+    return scores if attn_mask is None else scores + attn_mask
+
+
 class Attention(nn.Module):
     """Standard Multi-head Self Attention module with QKV projection.

@@ -74,8 +79,7 @@ def forward(
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
@@ -196,10 +200,7 @@ def forward(
         else:
             q = q * self.scale
             attn = (q @ k.transpose(-2, -1))
-
-            if attn_mask is not None:
-                attn_mask = attn_mask.to(torch.bool)
-                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)

             attn = self.attn_drop(attn)
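
The fix moves the attn_mask None-check out of forward and into maybe_add_mask, which is decorated with @torch.fx.wrap so that torch.fx symbolic tracing records it as a single leaf call rather than tracing through the Python branch (an fx Proxy always compares `is not None` as true, so an inline check would bake the masked branch into the graph). Below is a minimal sketch, not part of the commit, of how the wrapped helper behaves under torch.fx.symbolic_trace; the ToyScores module and tensor shapes are made up for illustration.

# A minimal sketch (not from this commit): ToyScores and the shapes below are
# made up to show why wrapping the helper matters for torch.fx tracing.
from typing import Optional

import torch
import torch.nn as nn


@torch.fx.wrap
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    # Same helper as in the diff: the None-check runs at call time, and fx
    # records this call as one leaf node instead of tracing through the branch.
    return scores if attn_mask is None else scores + attn_mask


class ToyScores(nn.Module):
    def forward(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        attn = q @ k.transpose(-2, -1)
        # An inline `if attn_mask is not None:` would be evaluated against an
        # fx Proxy during symbolic_trace and always take the masked branch;
        # the wrapped helper keeps both the masked and unmasked paths working.
        attn = maybe_add_mask(attn, attn_mask)
        return attn.softmax(dim=-1)


traced = torch.fx.symbolic_trace(ToyScores())
q = k = torch.randn(2, 4, 8)
out_unmasked = traced(q, k)                       # attn_mask stays None
out_masked = traced(q, k, torch.zeros(2, 4, 4))   # additive float mask applied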
