
Commit d73bbb0

Add files via upload
0 parents  commit d73bbb0

File tree

2 files changed: +445 -0 lines changed


lstm.py

+266
@@ -0,0 +1,266 @@
import torch
import torch.nn as nn
import math
from typing import Tuple, List, Optional


class LSTMCell(nn.Module):
    """
    LSTM Cell implementation with layer normalization.

    Mathematical formulation of LSTM:
        f_t = σ(W_f · [h_{t-1}, x_t] + b_f)     # Forget gate
        i_t = σ(W_i · [h_{t-1}, x_t] + b_i)     # Input gate
        g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)  # Candidate cell state
        o_t = σ(W_o · [h_{t-1}, x_t] + b_o)     # Output gate

        c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t         # New cell state
        h_t = o_t ⊙ tanh(c_t)                   # New hidden state

    where:
        - σ is the sigmoid function
        - ⊙ is element-wise multiplication
        - [h_{t-1}, x_t] represents concatenation
    """
    def __init__(self, input_size: int, hidden_size: int, dropout: float = 0.0):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Combined weight matrices for efficiency
        # W_ih combines weights for [i_t, f_t, g_t, o_t] for input x_t
        # W_hh combines weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1}
        self.weight_ih = nn.Linear(input_size, 4 * hidden_size)
        self.weight_hh = nn.Linear(hidden_size, 4 * hidden_size)

        # Layer Normalization for better training stability
        self.layer_norm_x = nn.LayerNorm(4 * hidden_size)  # Normalize gate pre-activations
        self.layer_norm_h = nn.LayerNorm(hidden_size)      # Normalize hidden state
        self.layer_norm_c = nn.LayerNorm(hidden_size)      # Normalize cell state

        self.init_parameters()

    def init_parameters(self) -> None:
        """
        Initialize parameters using best practices:
        1. Orthogonal initialization for better gradient flow
        2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training
        """
        for weight in [self.weight_ih.weight, self.weight_hh.weight]:
            nn.init.orthogonal_(weight)

        # Set forget gate bias to 1.0 (helps with learning long sequences)
        nn.init.constant_(self.weight_ih.bias[self.hidden_size:2*self.hidden_size], 1.0)
        nn.init.constant_(self.weight_hh.bias[self.hidden_size:2*self.hidden_size], 1.0)

    def forward(self, x: torch.Tensor,
                hidden_state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of LSTM cell.

        Args:
            x: Input tensor of shape (batch_size, input_size)
            hidden_state: Tuple of (h_{t-1}, c_{t-1}) each of shape (batch_size, hidden_size)

        Returns:
            Tuple of (h_t, c_t) representing new hidden and cell states
        """
        h_prev, c_prev = hidden_state

        # Combined matrix multiplication for all gates
        # Shape: (batch_size, 4 * hidden_size)
        gates_x = self.weight_ih(x)       # Transform input
        gates_h = self.weight_hh(h_prev)  # Transform previous hidden state

        # Apply layer normalization
        gates_x = self.layer_norm_x(gates_x)
        gates = gates_x + gates_h  # Combined gate pre-activations

        # Split into individual gates
        # Each gate shape: (batch_size, hidden_size)
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=1)

        # Apply gate non-linearities
        i_t = torch.sigmoid(i_gate)  # Input gate
        f_t = torch.sigmoid(f_gate)  # Forget gate
        g_t = torch.tanh(g_gate)     # Cell state candidate
        o_t = torch.sigmoid(o_gate)  # Output gate

        # Update cell state: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
        c_t = f_t * c_prev + i_t * g_t
        c_t = self.layer_norm_c(c_t)

        # Update hidden state: h_t = o_t ⊙ tanh(c_t)
        h_t = o_t * torch.tanh(c_t)
        h_t = self.layer_norm_h(h_t)

        if self.dropout is not None:
            h_t = self.dropout(h_t)

        return h_t, c_t

    def init_hidden(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize hidden state and cell state with zeros."""
        return (torch.zeros(batch_size, self.hidden_size, device=device),
                torch.zeros(batch_size, self.hidden_size, device=device))


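# A minimal usage sketch for a single LSTMCell step (illustrative only; the sizes
# below are arbitrary assumptions, not values used elsewhere in this file):
#
#   cell = LSTMCell(input_size=8, hidden_size=16)
#   h0, c0 = cell.init_hidden(batch_size=4, device=torch.device("cpu"))
#   h1, c1 = cell(torch.randn(4, 8), (h0, c0))
#   # h1 and c1 both have shape (4, 16)
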
class StackedLSTM(nn.Module):
    """
    Stacked LSTM implementation supporting multiple layers.
    Each layer processes the output of the previous layer.
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.0):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # Create list of LSTM cells, one for each layer
        self.layers = nn.ModuleList([
            LSTMCell(input_size if i == 0 else hidden_size, hidden_size,
                     dropout if i < num_layers - 1 else 0.0)  # No dropout on last layer
            for i in range(num_layers)
        ])

    def forward(self, x: torch.Tensor,
                hidden_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Process input sequence through stacked LSTM layers.

        Args:
            x: Input tensor of shape (batch_size, seq_length, input_size)
            hidden_states: Optional initial hidden states for each layer

        Returns:
            Tuple of (output, hidden_states) where output has shape (batch_size, seq_length, hidden_size)
        """
        batch_size, seq_length, _ = x.size()
        device = x.device

        if hidden_states is None:
            hidden_states = [layer.init_hidden(batch_size, device) for layer in self.layers]

        layer_outputs = []
        for t in range(seq_length):
            input_t = x[:, t, :]
            for i, lstm_cell in enumerate(self.layers):
                input_t, cell_state = lstm_cell(input_t, hidden_states[i])
                hidden_states[i] = (input_t, cell_state)
            layer_outputs.append(input_t)

        # Stack outputs along sequence dimension
        output = torch.stack(layer_outputs, dim=1)
        return output, hidden_states


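# A minimal usage sketch for StackedLSTM (illustrative only; arbitrary sizes).
# The stack consumes a whole sequence and returns the top layer's output at every
# timestep plus the final (h, c) pair for each layer:
#
#   stack = StackedLSTM(input_size=8, hidden_size=16, num_layers=2, dropout=0.1)
#   out, states = stack(torch.randn(4, 10, 8))
#   # out has shape (4, 10, 16); states is a list of 2 (h, c) tuples
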
class LSTMNetwork(nn.Module):
    """
    Complete LSTM network with bidirectional support.

    In bidirectional mode:
    - Forward LSTM processes sequence from left to right
    - Backward LSTM processes sequence from right to left
    - Outputs are concatenated for final prediction
    """
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 num_layers: int,
                 output_size: int,
                 dropout: float = 0.0,
                 bidirectional: bool = False):
        super().__init__()
        self.bidirectional = bidirectional

        # Forward direction LSTM
        self.stacked_lstm = StackedLSTM(input_size, hidden_size, num_layers, dropout)

        # Optional backward direction LSTM for bidirectional processing
        if bidirectional:
            self.reverse_lstm = StackedLSTM(input_size, hidden_size, num_layers, dropout)
            hidden_size *= 2  # Double hidden size due to concatenation

        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                hidden_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None) -> torch.Tensor:
        """
        Forward pass of the network.

        For bidirectional processing:
        1. Process sequence normally with forward LSTM
        2. Process reversed sequence with backward LSTM
        3. Concatenate both outputs
        4. Apply final linear transformation

        Args:
            x: Input tensor of shape (batch_size, seq_length, input_size)
            hidden_states: Optional initial hidden states

        Returns:
            Output tensor of shape (batch_size, output_size)
        """
        # Forward direction
        output, hidden_states = self.stacked_lstm(x, hidden_states)

        if self.bidirectional:
            # Process sequence in reverse direction
            reverse_output, _ = self.reverse_lstm(torch.flip(x, [1]))
            # Flip back to align with forward sequence
            reverse_output = torch.flip(reverse_output, [1])
            # Concatenate forward and backward outputs along feature dimension
            output = torch.cat([output, reverse_output], dim=-1)

        # Apply dropout before final layer
        output = self.dropout(output)
        # Use final timestep output for prediction
        final_output = self.fc(output[:, -1, :])
        return final_output


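# A minimal usage sketch for LSTMNetwork (illustrative only; arbitrary sizes).
# With bidirectional=True the forward and reverse features are concatenated, so the
# final Linear layer receives 2 * hidden_size features, while the prediction keeps
# shape (batch_size, output_size):
#
#   net = LSTMNetwork(input_size=8, hidden_size=16, num_layers=2,
#                     output_size=3, dropout=0.1, bidirectional=True)
#   y = net(torch.randn(4, 10, 8))
#   # y has shape (4, 3); net.fc.in_features == 32
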
def create_lstm_model(config: dict) -> LSTMNetwork:
    """
    Factory function to create an LSTM model with specified configuration.

    Args:
        config: Dictionary containing model parameters:
            - input_size: Size of input features
            - hidden_size: Size of LSTM hidden state
            - num_layers: Number of stacked LSTM layers
            - output_size: Size of final output
            - dropout: Dropout probability (optional)
            - bidirectional: Whether to use bidirectional LSTM (optional)
    """
    return LSTMNetwork(
        input_size=config['input_size'],
        hidden_size=config['hidden_size'],
        num_layers=config['num_layers'],
        output_size=config['output_size'],
        dropout=config.get('dropout', 0.0),
        bidirectional=config.get('bidirectional', False)
    )


# Example usage
if __name__ == "__main__":
    # Configuration for a bidirectional LSTM
    config = {
        'input_size': 3,
        'hidden_size': 64,
        'num_layers': 2,
        'output_size': 1,
        'dropout': 0.3,
        'bidirectional': True  # Enable bidirectional processing
    }

    # Create model
    model = create_lstm_model(config)

    # Generate dummy input
    batch_size, seq_length = 32, 10
    x = torch.randn(batch_size, seq_length, config['input_size'])

    # Forward pass
    output = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

0 commit comments
