import torch
import torch.nn as nn
from typing import Tuple, List, Optional

class LSTMCell(nn.Module):
    """
    LSTM cell implementation with layer normalization.

    Mathematical formulation of the LSTM:

        f_t = σ(W_f · [h_{t-1}, x_t] + b_f)     # Forget gate
        i_t = σ(W_i · [h_{t-1}, x_t] + b_i)     # Input gate
        g_t = tanh(W_g · [h_{t-1}, x_t] + b_g)  # Candidate cell state
        o_t = σ(W_o · [h_{t-1}, x_t] + b_o)     # Output gate

        c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t         # New cell state
        h_t = o_t ⊙ tanh(c_t)                   # New hidden state

    where:
        - σ is the sigmoid function
        - ⊙ is element-wise multiplication
        - [h_{t-1}, x_t] denotes concatenation
    """
    def __init__(self, input_size: int, hidden_size: int, dropout: float = 0.0):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

        # Combined weight matrices for efficiency:
        # weight_ih holds the [i_t, f_t, g_t, o_t] weights for the input x_t,
        # weight_hh holds the [i_t, f_t, g_t, o_t] weights for the hidden state h_{t-1}
        self.weight_ih = nn.Linear(input_size, 4 * hidden_size)
        self.weight_hh = nn.Linear(hidden_size, 4 * hidden_size)

        # Layer normalization for better training stability
        self.layer_norm_x = nn.LayerNorm(4 * hidden_size)  # Normalize gate pre-activations
        self.layer_norm_h = nn.LayerNorm(hidden_size)      # Normalize hidden state
        self.layer_norm_c = nn.LayerNorm(hidden_size)      # Normalize cell state

        self.init_parameters()
    def init_parameters(self) -> None:
        """
        Initialize parameters using best practices:
        1. Orthogonal initialization for better gradient flow.
        2. Forget gate bias initialized to 1.0 to prevent forgetting early in training.
        """
        for weight in [self.weight_ih.weight, self.weight_hh.weight]:
            nn.init.orthogonal_(weight)

        # Set the forget-gate slice of both biases to 1.0 so the combined
        # forget-gate pre-activation starts positive (helps with learning long sequences)
        nn.init.constant_(self.weight_ih.bias[self.hidden_size:2 * self.hidden_size], 1.0)
        nn.init.constant_(self.weight_hh.bias[self.hidden_size:2 * self.hidden_size], 1.0)
    def forward(self, x: torch.Tensor,
                hidden_state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the LSTM cell.

        Args:
            x: Input tensor of shape (batch_size, input_size)
            hidden_state: Tuple of (h_{t-1}, c_{t-1}), each of shape (batch_size, hidden_size)

        Returns:
            Tuple of (h_t, c_t), the new hidden and cell states
        """
        h_prev, c_prev = hidden_state

        # Combined matrix multiplication for all gates
        # Shape: (batch_size, 4 * hidden_size)
        gates_x = self.weight_ih(x)       # Transform input
        gates_h = self.weight_hh(h_prev)  # Transform previous hidden state

        # Apply layer normalization to the input projection
        gates_x = self.layer_norm_x(gates_x)
        gates = gates_x + gates_h  # Combined gate pre-activations

        # Split into individual gates
        # Each gate shape: (batch_size, hidden_size)
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=1)

        # Apply gate non-linearities
        i_t = torch.sigmoid(i_gate)  # Input gate
        f_t = torch.sigmoid(f_gate)  # Forget gate
        g_t = torch.tanh(g_gate)     # Cell state candidate
        o_t = torch.sigmoid(o_gate)  # Output gate

        # Update cell state: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
        c_t = f_t * c_prev + i_t * g_t
        c_t = self.layer_norm_c(c_t)

        # Update hidden state: h_t = o_t ⊙ tanh(c_t)
        h_t = o_t * torch.tanh(c_t)
        h_t = self.layer_norm_h(h_t)

        if self.dropout is not None:
            h_t = self.dropout(h_t)

        return h_t, c_t

    def init_hidden(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize the hidden state and cell state with zeros."""
        return (torch.zeros(batch_size, self.hidden_size, device=device),
                torch.zeros(batch_size, self.hidden_size, device=device))

class StackedLSTM(nn.Module):
    """
    Stacked LSTM implementation supporting multiple layers.
    Each layer processes the output of the previous layer.
    """
    def __init__(self, input_size: int, hidden_size: int, num_layers: int, dropout: float = 0.0):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # Create a list of LSTM cells, one for each layer
        self.layers = nn.ModuleList([
            LSTMCell(input_size if i == 0 else hidden_size, hidden_size,
                     dropout if i < num_layers - 1 else 0.0)  # No dropout on the last layer
            for i in range(num_layers)
        ])
    def forward(self, x: torch.Tensor,
                hidden_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
                ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Process an input sequence through the stacked LSTM layers.

        Args:
            x: Input tensor of shape (batch_size, seq_length, input_size)
            hidden_states: Optional initial hidden states, one (h, c) tuple per layer

        Returns:
            Tuple of (output, hidden_states), where output has shape (batch_size, seq_length, hidden_size)
        """
        batch_size, seq_length, _ = x.size()
        device = x.device

        if hidden_states is None:
            hidden_states = [layer.init_hidden(batch_size, device) for layer in self.layers]

        layer_outputs = []
        for t in range(seq_length):
            input_t = x[:, t, :]
            for i, lstm_cell in enumerate(self.layers):
                input_t, cell_state = lstm_cell(input_t, hidden_states[i])
                hidden_states[i] = (input_t, cell_state)
            layer_outputs.append(input_t)

        # Stack the top layer's outputs along the sequence dimension
        output = torch.stack(layer_outputs, dim=1)
        return output, hidden_states

class LSTMNetwork(nn.Module):
    """
    Complete LSTM network with bidirectional support.

    In bidirectional mode:
    - Forward LSTM processes the sequence from left to right
    - Backward LSTM processes the sequence from right to left
    - Outputs are concatenated for the final prediction
    """
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 num_layers: int,
                 output_size: int,
                 dropout: float = 0.0,
                 bidirectional: bool = False):
        super().__init__()
        self.bidirectional = bidirectional

        # Forward-direction LSTM
        self.stacked_lstm = StackedLSTM(input_size, hidden_size, num_layers, dropout)

        # Optional backward-direction LSTM for bidirectional processing
        if bidirectional:
            self.reverse_lstm = StackedLSTM(input_size, hidden_size, num_layers, dropout)
            hidden_size *= 2  # Double the hidden size due to concatenation

        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor,
                hidden_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None) -> torch.Tensor:
        """
        Forward pass of the network.

        For bidirectional processing:
        1. Process the sequence normally with the forward LSTM.
        2. Process the reversed sequence with the backward LSTM.
        3. Concatenate both outputs.
        4. Apply the final linear transformation.

        Args:
            x: Input tensor of shape (batch_size, seq_length, input_size)
            hidden_states: Optional initial hidden states for the forward LSTM

        Returns:
            Output tensor of shape (batch_size, output_size)
        """
        # Forward direction
        output, hidden_states = self.stacked_lstm(x, hidden_states)

        if self.bidirectional:
            # Process the sequence in the reverse direction
            reverse_output, _ = self.reverse_lstm(torch.flip(x, [1]))
            # Flip back to align with the forward sequence
            reverse_output = torch.flip(reverse_output, [1])
            # Concatenate forward and backward outputs along the feature dimension
            output = torch.cat([output, reverse_output], dim=-1)

        # Apply dropout before the final layer
        output = self.dropout(output)
        # Use the final timestep's output for the prediction
        final_output = self.fc(output[:, -1, :])
        return final_output

def create_lstm_model(config: dict) -> LSTMNetwork:
    """
    Factory function to create an LSTM model with the specified configuration.

    Args:
        config: Dictionary containing the model parameters:
            - input_size: Size of the input features
            - hidden_size: Size of the LSTM hidden state
            - num_layers: Number of stacked LSTM layers
            - output_size: Size of the final output
            - dropout: Dropout probability (optional)
            - bidirectional: Whether to use a bidirectional LSTM (optional)
    """
    return LSTMNetwork(
        input_size=config['input_size'],
        hidden_size=config['hidden_size'],
        num_layers=config['num_layers'],
        output_size=config['output_size'],
        dropout=config.get('dropout', 0.0),
        bidirectional=config.get('bidirectional', False)
    )

# Example usage
if __name__ == "__main__":
    # Configuration for a bidirectional LSTM
    config = {
        'input_size': 3,
        'hidden_size': 64,
        'num_layers': 2,
        'output_size': 1,
        'dropout': 0.3,
        'bidirectional': True  # Enable bidirectional processing
    }

    # Create the model
    model = create_lstm_model(config)

    # Generate dummy input
    batch_size, seq_length = 32, 10
    x = torch.randn(batch_size, seq_length, config['input_size'])

    # Forward pass
    output = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
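
    # Minimal single-cell sketch: run one timestep through a standalone LSTMCell
    # to illustrate the (h_t, c_t) interface described in its docstring. The
    # names cell, h0, c0, h1, c1 are illustrative only.
    cell = LSTMCell(input_size=config['input_size'], hidden_size=config['hidden_size'])
    h0, c0 = cell.init_hidden(batch_size, x.device)
    h1, c1 = cell(x[:, 0, :], (h0, c0))
    print(f"LSTMCell step: h_t shape {h1.shape}, c_t shape {c1.shape}")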