Skip to content

Commit 709d7c0

Browse files
committed
Merge branch 'rwightman:master' into master
2 parents 240e667 + cf4ce2f commit 709d7c0

File tree

5 files changed

+271
-39
lines changed

5 files changed

+271
-39
lines changed

inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,13 @@ def main():
114114
_logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
115115
batch_idx, len(loader), batch_time=batch_time))
116116

117-
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
117+
topk_ids = np.concatenate(topk_ids, axis=0)
118118

119119
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
120120
filenames = loader.dataset.filenames(basename=True)
121121
for filename, label in zip(filenames, topk_ids):
122-
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
123-
filename, label[0], label[1], label[2], label[3], label[4]))
122+
out_file.write('{0},{1}\n'.format(
123+
filename, ','.join([ str(v) for v in label])))
124124

125125

126126
if __name__ == '__main__':

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
torch._C._jit_set_profiling_mode(False)
1616

1717
# transformer models don't support many of the spatial / feature based model functionalities
18-
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', 'mixer_*']
18+
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*']
1919
NUM_NON_STD = len(NON_STD_FILTERS)
2020

2121
# exclude models that cause specific test failures

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .inplace_abn import InplaceAbn
2121
from .linear import Linear
2222
from .mixed_conv2d import MixedConv2d
23-
from .mlp import Mlp, GluMlp
23+
from .mlp import Mlp, GluMlp, GatedMlp
2424
from .norm import GroupNorm
2525
from .norm_act import BatchNormAct2d, GroupNormAct
2626
from .padding import get_padding, get_same_padding, pad_same

timm/models/layers/mlp.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
3434
super().__init__()
3535
out_features = out_features or in_features
3636
hidden_features = hidden_features or in_features
37-
self.fc1 = nn.Linear(in_features, hidden_features * 2)
37+
assert hidden_features % 2 == 0
38+
self.fc1 = nn.Linear(in_features, hidden_features)
3839
self.act = act_layer()
39-
self.fc2 = nn.Linear(hidden_features, out_features)
40+
self.fc2 = nn.Linear(hidden_features // 2, out_features)
4041
self.drop = nn.Dropout(drop)
4142

4243
def forward(self, x):
@@ -47,3 +48,32 @@ def forward(self, x):
4748
x = self.fc2(x)
4849
x = self.drop(x)
4950
return x
51+
52+
53+
class GatedMlp(nn.Module):
54+
""" MLP as used in gMLP
55+
"""
56+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
57+
gate_layer=None, drop=0.):
58+
super().__init__()
59+
out_features = out_features or in_features
60+
hidden_features = hidden_features or in_features
61+
self.fc1 = nn.Linear(in_features, hidden_features)
62+
self.act = act_layer()
63+
if gate_layer is not None:
64+
assert hidden_features % 2 == 0
65+
self.gate = gate_layer(hidden_features)
66+
hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
67+
else:
68+
self.gate = nn.Identity()
69+
self.fc2 = nn.Linear(hidden_features, out_features)
70+
self.drop = nn.Dropout(drop)
71+
72+
def forward(self, x):
73+
x = self.fc1(x)
74+
x = self.act(x)
75+
x = self.drop(x)
76+
x = self.gate(x)
77+
x = self.fc2(x)
78+
x = self.drop(x)
79+
return x

0 commit comments

Comments (0)