
Commit b7ced7c

torch.fx.wrap not working with older pytorch, trying register_notrace instead

1 parent: 842a786

2 files changed: +2 −2
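Why `maybe_add_mask` needs leaf treatment under torch.fx at all: during symbolic tracing `attn_mask` is a Proxy, so `attn_mask is None` is evaluated at trace time (always False) and the traced graph unconditionally adds the mask, breaking mask-free calls. Registering the function as a tracing leaf keeps it as a single graph node whose None check runs at call time. A minimal sketch of both the pitfall and the tracer-level fix; the `Attn` module here is illustrative, not timm code:

from typing import Optional

import torch
import torch.fx


def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    # Under tracing, attn_mask is a Proxy: `is None` is always False, so a
    # plain symbolic_trace would hard-code the `scores + attn_mask` branch.
    return scores if attn_mask is None else scores + attn_mask


class Attn(torch.nn.Module):
    def forward(self, scores, attn_mask=None):
        return maybe_add_mask(scores, attn_mask)


# Treat the helper as a leaf: the call stays one call_function node and the
# None check runs when the graph executes, not when it is traced.
tracer = torch.fx.Tracer(autowrap_functions=(maybe_add_mask,))
gm = torch.fx.GraphModule(Attn(), tracer.trace(Attn()))

print(gm(torch.ones(2, 2)))                     # no mask: scores unchanged
print(gm(torch.ones(2, 2), torch.zeros(2, 2)))  # mask applied at run time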

timm/layers/attention.py

Lines changed: 0 additions & 1 deletion

@@ -8,7 +8,6 @@
 from .pos_embed_sincos import apply_rot_embed_cat


-@torch.fx.wrap
 def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
     return scores if attn_mask is None else scores + attn_mask
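The decorator comes off because, per the commit message, `@torch.fx.wrap` misbehaves on older PyTorch. One known pitfall (an assumption about the specific failure here, not stated in the commit) is that early releases of `torch.fx.wrap` returned None, so decorator use rebound the function name to None. The original, version-tolerant form registers by name after the definition:

from typing import Optional

import torch
import torch.fx


def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    return scores if attn_mask is None else scores + attn_mask


# Call form: marks the function as an FX tracing leaf without depending on
# torch.fx.wrap returning the function (decorator-friendly behavior only
# appeared in later PyTorch releases).
torch.fx.wrap('maybe_add_mask')

The commit instead drops the global registration entirely and routes through timm's own `_autowrap_functions` set, as the second file shows.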

timm/models/_features_fx.py

Lines changed: 2 additions & 1 deletion

@@ -18,7 +18,7 @@

 # Layers we want to treat as leaf modules
 from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
-from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
+from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc, maybe_add_mask
 from timm.layers.non_local_attn import BilinearAttnTransform
 from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 from timm.layers.norm_act import (

@@ -79,6 +79,7 @@ def get_notrace_modules():
 _autowrap_functions = {
     resample_abs_pos_embed,
     resample_abs_pos_embed_nhwc,
+    maybe_add_mask,
 }
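Adding the function to `_autowrap_functions` is what timm's `register_notrace_function` decorator does under the hood; the set is then handed to the FX tracer whenever a feature graph is built. A sketch of that path, assuming torchvision's feature-extraction API; the model name and return node are placeholders:

import timm
import torch
from timm.models._features_fx import _autowrap_functions, _leaf_modules
from torchvision.models.feature_extraction import create_feature_extractor

model = timm.create_model('vit_small_patch16_224')  # placeholder model
extractor = create_feature_extractor(
    model,
    return_nodes={'blocks.11': 'feats'},  # placeholder node name
    tracer_kwargs={
        'leaf_modules': list(_leaf_modules),              # modules kept as single nodes
        'autowrap_functions': list(_autowrap_functions),  # now includes maybe_add_mask
    },
)
feats = extractor(torch.randn(1, 3, 224, 224))['feats']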
