diff --git a/deeplog/deeplog.py b/deeplog/deeplog.py index f543002..dba0ee0 100644 --- a/deeplog/deeplog.py +++ b/deeplog/deeplog.py @@ -23,19 +23,19 @@ class Model(nn.Module): - def __init__(self, input_size, hidden_size, num_layers, num_classes): - super(Model, self).__init__() - self.hidden_size = hidden_size - self.num_layers = num_layers - self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) + def __init__(self, num_classes, embed_dim, hidden_size, num_layers): + super().__init__() + self.embedding = nn.Embedding(num_classes, embed_dim, padding_idx=0) + self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, num_classes) - def forward(self, input): - h0 = torch.zeros(self.num_layers, input.size(0), self.hidden_size).to(input.device) - c0 = torch.zeros(self.num_layers, input.size(0), self.hidden_size).to(input.device) - out, _ = self.lstm(input, (h0, c0)) - out = self.fc(out[:, -1, :]) - return out + def forward(self, x): + emb = self.embedding(x) + h0 = torch.zeros(self.lstm.num_layers, x.size(0), self.lstm.hidden_size, + device=x.device) + c0 = torch.zeros_like(h0) + out, _ = self.lstm(emb, (h0, c0)) + return self.fc(out[:, -1, :]) class Generate():