
Commit 842a786 (parent: dd2c141)

A few more maybe_add_mask situations

3 files changed (+5, −6)


timm/layers/attention_pool.py

Lines changed: 2 additions & 2 deletions
@@ -4,6 +4,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .attention import maybe_add_mask
 from .config import use_fused_attn
 from .mlp import Mlp
 from .weight_init import trunc_normal_tf_
@@ -95,8 +96,7 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, self.latent_len, C)
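
The two lines being replaced make the helper's contract clear: add the mask to the scores if one was given, otherwise pass the scores through unchanged. A minimal sketch of what `maybe_add_mask` in timm/layers/attention.py presumably boils down to (the actual definition may differ):

from typing import Optional

import torch


def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Additive attention masking: attn_mask holds 0 for positions to keep and
    # -inf for positions to drop, so adding it before softmax zeroes out the
    # masked positions. With no mask, the scores pass through unchanged.
    return scores if attn_mask is None else scores + attn_mask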

timm/models/vision_transformer.py

Lines changed: 2 additions & 3 deletions
@@ -43,7 +43,7 @@
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
     SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    get_act_layer, get_norm_layer, LayerType, maybe_add_mask
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -256,8 +256,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x_attn = attn @ v
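
A quick end-to-end check of the pattern these call sites rely on; the shapes and mask values here are illustrative, not from the commit:

import torch

from timm.layers import maybe_add_mask  # exported from timm.layers per the import change above

B, H, N = 2, 4, 8                       # batch, heads, tokens (illustrative)
attn = torch.randn(B, H, N, N)          # raw q @ k.transpose(-2, -1) scores

attn_mask = torch.zeros(N, N)
attn_mask[:, N // 2:] = float('-inf')   # drop the second half of the keys

attn = maybe_add_mask(attn, attn_mask)  # mask broadcasts over batch and heads
attn = attn.softmax(dim=-1)
assert torch.all(attn[..., N // 2:] == 0)             # masked keys get zero weight
assert torch.equal(maybe_add_mask(attn, None), attn)  # a None mask is a no-op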

timm/models/vision_transformer_flex.py

Lines changed: 1 addition & 1 deletion
@@ -823,7 +823,7 @@ def forward_features(
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        if attn_mask is None and patch_valid is not None:
+        if attn_mask is None:
             attn_mask = create_attention_mask(
                 patch_valid,
                 num_prefix_tokens=self.num_prefix_tokens,
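
Dropping the `patch_valid is not None` guard only works if `create_attention_mask` itself tolerates a missing `patch_valid`. The function name and call-site arguments come from the diff; the body below is a hypothetical sketch of that None-tolerant pattern, not the actual timm implementation:

from typing import Optional

import torch


def create_attention_mask(
        patch_valid: Optional[torch.Tensor],   # (B, N) bool, True for real patches
        num_prefix_tokens: int = 0,
) -> Optional[torch.Tensor]:
    # Presumed behavior given the simplified call site: with no validity info
    # there is nothing to mask, so return None and let maybe_add_mask (or
    # F.scaled_dot_product_attention) skip masking entirely.
    if patch_valid is None:
        return None
    if num_prefix_tokens:
        # Prefix tokens (class/register) are always valid.
        prefix = patch_valid.new_ones(patch_valid.shape[0], num_prefix_tokens)
        patch_valid = torch.cat([prefix, patch_valid], dim=1)
    # Additive mask: 0 where both query and key are valid, -inf elsewhere.
    valid = patch_valid.unsqueeze(1) & patch_valid.unsqueeze(2)    # (B, N, N)
    mask = torch.zeros(valid.shape, dtype=torch.float)
    return mask.masked_fill(~valid, float('-inf')).unsqueeze(1)    # (B, 1, N, N)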
