-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassifier.py
More file actions
41 lines (32 loc) · 1.22 KB
/
classifier.py
File metadata and controls
41 lines (32 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from torch import nn
import torch.nn.functional as F
class Classifier(nn.Module):
    """Fully connected classification head with two hidden layers.

    The network applies ``Linear -> BatchNorm1d -> LeakyReLU -> Dropout`` twice
    and finishes with a plain ``Linear`` projection to ``d_out`` features.
    ``forward`` returns the output of that last linear layer unchanged; no
    softmax or log-softmax is applied here.

    Sub-module attribute names (``linear1``, ``bn6``, ``dp1``, ``linear2``,
    ``bn7``, ``dp2``, ``linear3``) determine the ``state_dict`` keys, so they
    must stay as they are for saved checkpoints to keep loading.

    Args:
        d_in: Number of input features per sample.
        d_out: Number of output features (one score per class).
        hidden_dims: Widths of the two hidden layers, in order.
        dropout: Drop probability used by both dropout layers.
        negative_slope: Slope of the LeakyReLU activations for negative inputs.

    Raises:
        ValueError: If ``hidden_dims`` does not contain exactly two widths.
    """

    def __init__(self, d_in, d_out, *, hidden_dims=(512, 256), dropout=0.5,
                 negative_slope=0.2):
        super().__init__()

        hidden_dims = tuple(hidden_dims)
        if len(hidden_dims) != 2:
            raise ValueError(
                "hidden_dims must contain exactly two layer widths, "
                f"got {len(hidden_dims)}"
            )
        hidden1, hidden2 = hidden_dims

        self.din = d_in
        self.dout = d_out
        self.negative_slope = negative_slope

        # NOTE: keep the construction order below unchanged. It fixes both the
        # order of state_dict entries and the order in which random initial
        # weights are drawn under a fixed seed.
        # The first linear layer has no bias; it is followed directly by a
        # BatchNorm layer, which carries its own learnable shift.
        self.linear1 = nn.Linear(self.din, hidden1, bias=False)
        self.bn6 = nn.BatchNorm1d(hidden1)
        self.dp1 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.bn7 = nn.BatchNorm1d(hidden2)
        self.dp2 = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(hidden2, self.dout)

    def forward(self, x):
        """Compute class scores for a batch of feature vectors.

        Args:
            x: Input tensor whose last dimension has size ``d_in``
                (typically shaped ``(batch, d_in)``).

        Returns:
            Tensor of raw scores whose last dimension has size ``d_out``.
        """
        x = F.leaky_relu(
            self.bn6(self.linear1(x)), negative_slope=self.negative_slope
        )
        x = self.dp1(x)
        x = F.leaky_relu(
            self.bn7(self.linear2(x)), negative_slope=self.negative_slope
        )
        x = self.dp2(x)
        return self.linear3(x)