@@ -34,9 +34,10 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features * 2)
+        assert hidden_features % 2 == 0
+        self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.fc2 = nn.Linear(hidden_features // 2, out_features)
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
@@ -47,3 +48,32 @@ def forward(self, x):
         x = self.fc2(x)
         x = self.drop(x)
         return x
+
+
+class GatedMlp(nn.Module):
+    """ MLP as used in gMLP
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+                 gate_layer=None, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        if gate_layer is not None:
+            assert hidden_features % 2 == 0
+            self.gate = gate_layer(hidden_features)
+            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
+        else:
+            self.gate = nn.Identity()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.gate(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
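
For context, below is a minimal usage sketch of the GatedMlp class added in this diff (assuming the class is in scope). SplitGate is a hypothetical gate layer invented purely for illustration, not part of the change: it is constructed with the full hidden width and halves the channel dimension, matching the hidden_features // 2 reduction the constructor assumes when a gate_layer is supplied.

import torch
import torch.nn as nn

class SplitGate(nn.Module):
    """Hypothetical gate for illustration only: splits channels in half and
    gates one half with a normalized copy of the other."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim // 2)

    def forward(self, x):
        u, v = x.chunk(2, dim=-1)   # split channel dim in half
        return u * self.norm(v)     # output has dim // 2 channels

tokens = torch.randn(2, 196, 256)                     # (batch, tokens, channels)
gated = GatedMlp(256, hidden_features=512, gate_layer=SplitGate)
plain = GatedMlp(256, hidden_features=512)            # gate_layer=None -> nn.Identity
print(gated(tokens).shape)                            # torch.Size([2, 196, 256])
print(plain(tokens).shape)                            # torch.Size([2, 196, 256])

With a gate_layer set, fc2 is built on hidden_features // 2 input channels, so the gate is expected to halve the channel dimension; without one, fc2 consumes the full hidden width.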