
Commit cf4ce2f

Merge pull request #647 from rwightman/more_mlp
Add preliminary gMLP and ResMLP impl to Mlp-Mixer
2 parents b3b503c + 6d81374 commit cf4ce2f

4 files changed: +268 -36 lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 torch._C._jit_set_profiling_mode(False)
 
 # transformer models don't support many of the spatial / feature based model functionalities
-NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
+NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
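Note the widened pattern: '*mixer_*' (previously 'mixer_*') also matches prefixed variants such as the new gmixer models, alongside the new 'gmlp_*' and 'resmlp_*' entries. A minimal sketch of how such wildcard filters are applied, assuming fnmatch-style globbing as in timm's test suite; the is_non_std helper and model names are illustrative, not project code:

from fnmatch import fnmatch

NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*']

def is_non_std(model_name):
    # any() over glob patterns; '*mixer_*' also catches names like 'gmixer_24_224'
    return any(fnmatch(model_name, f) for f in NON_STD_FILTERS)

assert is_non_std('gmlp_s16_224')      # new 'gmlp_*' filter
assert is_non_std('gmixer_24_224')     # caught only by the widened '*mixer_*'
assert not is_non_std('resnet50')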

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
-from .mlp import Mlp, GluMlp
+from .mlp import Mlp, GluMlp, GatedMlp
 from .norm import GroupNorm
 from .norm_act import BatchNormAct2d, GroupNormAct
 from .padding import get_padding, get_same_padding, pad_same
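With this re-export, GatedMlp can be imported from timm.models.layers alongside Mlp and GluMlp. A quick usage sketch (shapes illustrative; with gate_layer left as None the gate is an nn.Identity, per the mlp.py diff below):

import torch
from timm.models.layers import GatedMlp

x = torch.randn(2, 196, 256)                          # (batch, tokens, channels)
mlp = GatedMlp(in_features=256, hidden_features=512)  # no gate_layer -> nn.Identity
print(mlp(x).shape)                                   # torch.Size([2, 196, 256])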

timm/models/layers/mlp.py

Lines changed: 32 additions & 2 deletions
@@ -34,9 +34,10 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features * 2)
+        assert hidden_features % 2 == 0
+        self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.fc2 = nn.Linear(hidden_features // 2, out_features)
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
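This hunk changes GluMlp's sizing convention: hidden_features now names fc1's full output width (previously fc1 produced hidden_features * 2), and after the GLU split only hidden_features // 2 channels reach fc2, hence the new evenness assert. The commit doesn't show GluMlp.forward, so the chunk-and-gate sketch below is an assumption about how the two halves are consumed, not the file's verbatim code:

import torch
import torch.nn as nn

fc1 = nn.Linear(256, 512)    # in_features=256, hidden_features=512 (even)
fc2 = nn.Linear(256, 256)    # hidden_features // 2 -> out_features

x = torch.randn(2, 196, 256)
h = fc1(x)                   # (2, 196, 512)
v, g = h.chunk(2, dim=-1)    # two (2, 196, 256) halves
h = v * torch.sigmoid(g)     # GLU: one half gates the other
out = fc2(h)                 # (2, 196, 256)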
@@ -47,3 +48,32 @@ def forward(self, x):
         x = self.fc2(x)
         x = self.drop(x)
         return x
+
+
+class GatedMlp(nn.Module):
+    """ MLP as used in gMLP
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+                 gate_layer=None, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        if gate_layer is not None:
+            assert hidden_features % 2 == 0
+            self.gate = gate_layer(hidden_features)
+            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
+        else:
+            self.gate = nn.Identity()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.gate(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
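The gate_layer is constructed with the pre-split width and is expected to halve the channel dimension itself (the FIXME notes the reduction isn't yet derived from the gate). In gMLP this role is played by a spatial gating unit; the ToyGate below is a hypothetical stand-in that just splits channels and multiplies the halves, matching the hidden_features // 2 input that fc2 expects:

import torch
import torch.nn as nn
from timm.models.layers import GatedMlp

class ToyGate(nn.Module):  # hypothetical; not part of this commit
    def __init__(self, dim):
        super().__init__()

    def forward(self, x):
        u, v = x.chunk(2, dim=-1)  # split channels: dim -> dim // 2 each
        return u * v               # elementwise gating halves the channel dim

x = torch.randn(2, 196, 256)
mlp = GatedMlp(in_features=256, hidden_features=512, gate_layer=ToyGate)
print(mlp(x).shape)  # torch.Size([2, 196, 256])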
