tests/test_models.py (5 additions, 5 deletions)
@@ -75,13 +75,13 @@
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
         '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*',
-        '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*']
-    NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*']
+        '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*', '*_7b_*']
+    NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*', '*_7b_*']
 else:
-    EXCLUDE_FILTERS = ['*enormous*']
-    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']
+    EXCLUDE_FILTERS = ['*enormous*', '*_7b_*']
+    NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*', '*_7b_*']
 
-EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']
+EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*', '*_7b_*']
 
 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
 TARGET_BWD_SIZE = 128
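For context, these exclusion entries are shell-style glob patterns matched against model names, so the new '*_7b_*' entry keeps 7B-parameter variants out of the memory-limited CI runs (and out of the JIT tests via EXCLUDE_JIT_FILTERS). A minimal sketch of the matching, assuming fnmatch-style globbing and a made-up 7B model name for illustration:

```python
import fnmatch

EXCLUDE_FILTERS = ['*enormous*', '*_7b_*']

# 'vit_7b_patch16_dinov3' is a hypothetical name used only for this example
model_names = ['vit_base_patch16_224', 'vit_7b_patch16_dinov3']
kept = [
    name for name in model_names
    if not any(fnmatch.fnmatch(name, pattern) for pattern in EXCLUDE_FILTERS)
]
print(kept)  # ['vit_base_patch16_224'] -- the 7B variant is skipped
```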
timm/layers/__init__.py (2 additions, 0 deletions)
@@ -133,7 +133,9 @@
     RotaryEmbedding,
     RotaryEmbeddingCat,
     RotaryEmbeddingMixed,
+    RotaryEmbeddingDinoV3,
     get_mixed_freqs,
+    create_rope_embed,
 )
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
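With the two re-exports above, the new rotary helpers should be importable straight from the package namespace, e.g. (usage sketch only; see pos_embed_sincos for the actual signatures):

```python
from timm.layers import RotaryEmbeddingDinoV3, create_rope_embed
```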
timm/layers/attention.py (6 additions, 2 deletions)
@@ -121,6 +121,7 @@ def __init__(
             qk_norm: bool = False,
             scale_norm: bool = False,
             proj_bias: bool = True,
+            rotate_half: bool = False,
     ):
         """Initialize the Attention module.
 
@@ -136,6 +137,7 @@ def __init__(
             norm_layer: Normalization layer constructor to use for QK and scale normalization
             qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
             scale_norm: Enable normalization (scaling) of attention output with norm_layer
+            rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
         """
         super().__init__()
         if scale_norm or qk_norm:
@@ -148,6 +150,7 @@
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
+        self.rotate_half = rotate_half
 
         if qkv_fused:
             self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
@@ -196,8 +199,9 @@ def forward(

         if rope is not None:
             npt = self.num_prefix_tokens
-            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
-            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)
+            half = getattr(self, 'rotate_half', False)
+            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
+            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
 
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
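The new rotate_half flag selects between the two common pairing conventions for rotary embeddings: the default 'interleaved' layout rotates adjacent element pairs (x0, x1), (x2, x3), ..., while the 'half' layout rotates the first and second halves of the head dimension against each other (the convention used by NeoX-style implementations and, presumably, the DINOv3 weights this PR targets). A self-contained sketch of the distinction, independent of apply_rot_embed_cat's actual signature and assuming sin/cos are already expanded to match the last dimension in the chosen order:

```python
import torch


def rotate_interleaved(x: torch.Tensor) -> torch.Tensor:
    # pair adjacent elements: (x0, x1) -> (-x1, x0), (x2, x3) -> (-x3, x2), ...
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    return torch.stack([-x_odd, x_even], dim=-1).reshape(x.shape)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # split the dim into two halves: (a, b) -> (-b, a)
    a, b = x.chunk(2, dim=-1)
    return torch.cat([-b, a], dim=-1)


def apply_rope(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, half: bool = False) -> torch.Tensor:
    rotated = rotate_half(x) if half else rotate_interleaved(x)
    return x * cos + rotated * sin
```

The flag only changes how dimensions are paired; it exists so that checkpoints trained with the 'half' convention get the same rotation applied at inference time.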
timm/layers/mlp.py (4 additions, 0 deletions)
@@ -119,13 +119,17 @@ def __init__(
             norm_layer=None,
             bias=True,
             drop=0.,
+            align_to=0,
     ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         bias = to_2tuple(bias)
         drop_probs = to_2tuple(drop)
 
+        if align_to:
+            hidden_features = hidden_features + (-hidden_features % align_to)
+
         self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
         self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
         self.act = act_layer()
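The align_to option rounds the hidden width up to the next multiple of the given value (presumably to keep the GLU projection shapes hardware-friendly); the padding amount is exactly -hidden_features % align_to. A quick standalone check of the arithmetic:

```python
def align_up(hidden_features: int, align_to: int) -> int:
    # add just enough to reach the next multiple of align_to (no-op when already aligned)
    return hidden_features + (-hidden_features % align_to)


assert align_up(1000, 64) == 1024  # padded up to the next multiple of 64
assert align_up(1024, 64) == 1024  # already aligned, unchanged
```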