Skip to content

Commit c1ec2f7

Browse files
committed
u
1 parent 1fab8ff commit c1ec2f7

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed

unpublished/sequence_prediction.ipynb

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Prediction of time series with recurrent neural networks\n",
8+
"\n",
9+
"https://github.com/google/flax/blob/main/examples/seq2seq/models.py\n",
10+
"\n",
11+
"* We don't need vocabulary size\n",
12+
"* We don't need one-hot encoding\n",
13+
"* We don't need to sample from the output of the decoder"
14+
]
15+
},
16+
{
17+
"cell_type": "code",
18+
"execution_count": 2,
19+
"metadata": {},
20+
"outputs": [
21+
{
22+
"ename": "TypeError",
23+
"evalue": "LSTMCell.__init__() missing 1 required positional argument: 'features'",
24+
"output_type": "error",
25+
"traceback": [
26+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
27+
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
28+
"Cell \u001b[1;32mIn[2], line 115\u001b[0m\n\u001b[0;32m 113\u001b[0m input_seq_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m20\u001b[39m \u001b[38;5;66;03m# Length of input sequence (K)\u001b[39;00m\n\u001b[0;32m 114\u001b[0m output_seq_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m10\u001b[39m \u001b[38;5;66;03m# Length of output sequence (N)\u001b[39;00m\n\u001b[1;32m--> 115\u001b[0m trained_state \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_lstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_seq_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_seq_length\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 117\u001b[0m \u001b[38;5;66;03m# Evaluate the model\u001b[39;00m\n\u001b[0;32m 118\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate_model\u001b[39m(state, X, y):\n",
29+
"Cell \u001b[1;32mIn[2], line 97\u001b[0m, in \u001b[0;36mtrain_lstm\u001b[1;34m(input_dim, input_seq_length, output_seq_length)\u001b[0m\n\u001b[0;32m 94\u001b[0m X, y \u001b[38;5;241m=\u001b[39m generate_multidim_data(\u001b[38;5;241m2000\u001b[39m, input_dim, input_seq_length, output_seq_length)\n\u001b[0;32m 96\u001b[0m \u001b[38;5;66;03m# Create and initialize model\u001b[39;00m\n\u001b[1;32m---> 97\u001b[0m state \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_train_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_seq_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 99\u001b[0m \u001b[38;5;66;03m# Training loop\u001b[39;00m\n\u001b[0;32m 100\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n",
30+
"Cell \u001b[1;32mIn[2], line 69\u001b[0m, in \u001b[0;36mcreate_train_state\u001b[1;34m(key, hidden_size, output_dim, output_seq_length, learning_rate, input_shape)\u001b[0m\n\u001b[0;32m 67\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_train_state\u001b[39m(key, hidden_size, output_dim, output_seq_length, learning_rate, input_shape):\n\u001b[0;32m 68\u001b[0m model \u001b[38;5;241m=\u001b[39m LSTMEncoderDecoder(hidden_size\u001b[38;5;241m=\u001b[39mhidden_size, output_dim\u001b[38;5;241m=\u001b[39moutput_dim, output_seq_length\u001b[38;5;241m=\u001b[39moutput_seq_length)\n\u001b[1;32m---> 69\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_shape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m 70\u001b[0m tx \u001b[38;5;241m=\u001b[39m optax\u001b[38;5;241m.\u001b[39madam(learning_rate)\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m train_state\u001b[38;5;241m.\u001b[39mTrainState\u001b[38;5;241m.\u001b[39mcreate(apply_fn\u001b[38;5;241m=\u001b[39mmodel\u001b[38;5;241m.\u001b[39mapply, params\u001b[38;5;241m=\u001b[39mparams, tx\u001b[38;5;241m=\u001b[39mtx)\n",
31+
" \u001b[1;31m[... skipping hidden 9 frame]\u001b[0m\n",
32+
"Cell \u001b[1;32mIn[2], line 41\u001b[0m, in \u001b[0;36mLSTMEncoderDecoder.__call__\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[38;5;129m@nn\u001b[39m\u001b[38;5;241m.\u001b[39mcompact\n\u001b[0;32m 39\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m 40\u001b[0m \u001b[38;5;66;03m# Encoder\u001b[39;00m\n\u001b[1;32m---> 41\u001b[0m encoder_lstm \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mLSTMCell\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 42\u001b[0m encoder_dense \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mDense(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size)\n\u001b[0;32m 44\u001b[0m \u001b[38;5;66;03m# Decoder\u001b[39;00m\n",
33+
"File \u001b[1;32mc:\\Users\\HansDembinski\\AppData\\Local\\micromamba\\envs\\blog\\Lib\\site-packages\\flax\\linen\\kw_only_dataclasses.py:235\u001b[0m, in \u001b[0;36m_process_class.<locals>.init_wrapper\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 227\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_args \u001b[38;5;241m>\u001b[39m expected_num_args:\n\u001b[0;32m 228\u001b[0m \u001b[38;5;66;03m# we add + 1 to each to account for `self`, matching python's\u001b[39;00m\n\u001b[0;32m 229\u001b[0m \u001b[38;5;66;03m# default error message\u001b[39;00m\n\u001b[0;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[0;32m 231\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__init__() takes \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexpected_num_args\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m positional \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 232\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124marguments but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_args\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m were given\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 233\u001b[0m )\n\u001b[1;32m--> 235\u001b[0m \u001b[43mdataclass_init\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
34+
"\u001b[1;31mTypeError\u001b[0m: LSTMCell.__init__() missing 1 required positional argument: 'features'"
35+
]
36+
}
37+
],
38+
"source": [
39+
"import jax\n",
40+
"import jax.numpy as jnp\n",
41+
"import flax.linen as nn\n",
42+
"from flax.training import train_state\n",
43+
"import optax\n",
44+
"import numpy as np\n",
45+
"\n",
46+
"Array = jax.Array\n",
47+
"PRNGKey = jax.Array\n",
48+
"CellCarry = tuple[Array, Array]\n",
49+
"\n",
50+
"\n",
51+
"# Set random seed for reproducibility\n",
52+
"root_key = jax.random.PRNGKey(0)\n",
53+
"\n",
54+
"\n",
55+
"# Generate multidimensional sequential data\n",
56+
"def generate_multidim_data(input_seq_length, output_seq_length):\n",
57+
" num_points = input_seq_length + output_seq_length\n",
58+
" t = np.linspace(0, 8 * jnp.pi, num_points)\n",
59+
" X = np.empty((num_points, input_dim))\n",
60+
"\n",
61+
" # Generate different patterns for each dimension\n",
62+
" X[:, 0] = np.sin(t)\n",
63+
" X[:, 1] = np.cos(t)\n",
64+
" X[:, 2] = np.sin(2 * t)\n",
65+
"\n",
66+
" # let's have one point overlap between input and output\n",
67+
" Y = X[input_seq_length-1:, 0] + X[input_seq_length-1:, 1] + X[input_seq_length:-1, 2]\n",
68+
"\n",
69+
" return jnp.array(X[:input_seq_length]), jnp.array(Y)\n",
70+
"\n",
71+
"\n",
72+
"class DecoderCell(nn.RNNCellBase):\n",
73+
" features: int\n",
74+
"\n",
75+
" @nn.compact\n",
76+
" def __call__(self, carry, x):\n",
77+
" state, last_prediction = carry\n",
78+
" \n",
79+
"\n",
80+
"\n",
81+
"class Seq2seq(nn.Module):\n",
82+
" \"\"\"Sequence-to-sequence class using encoder/decoder architecture.\"\"\"\n",
83+
"\n",
84+
" encoder_size: int\n",
85+
" decoder_size: int\n",
86+
"\n",
87+
" @nn.compact\n",
88+
" def __call__(self, X: Array, y: Array) -> tuple[Array, Array]:\n",
89+
" \"\"\"Applies the seq2seq model.\"\"\"\n",
90+
" # X shape (batch size, input length, input dims)\n",
91+
" # y shape (batch size, output length)\n",
92+
" encoder = nn.RNN(\n",
93+
" nn.GRUCell(self.encoder_size), return_carry=True, name=\"encoder\"\n",
94+
" )\n",
95+
" decoder = nn.RNN(nn.GRUCell(self.decoder_size), name=\"decoder\")\n",
96+
"\n",
97+
" state, _ = encoder(X)\n",
98+
" y = decoder(state, y)\n",
99+
" return y\n",
100+
"\n",
101+
"\n",
102+
"# Create and initialize the model\n",
103+
"def create_train_state(key, encoder_size, decoder_size, learning_rate, input_shape, output_size):\n",
104+
" model = Seq2seq(encoder_size, decoder_size)\n",
105+
" params = model.init(key, jnp.ones(input_shape), output_size)[\"params\"]\n",
106+
" tx = optax.nadamw(learning_rate)\n",
107+
" return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n",
108+
"\n",
109+
"\n",
110+
"# Define loss function\n",
111+
"def mse_loss(params, model, X, y):\n",
112+
" pred = model.apply(params, X, y[:, 0], y.shape[-1])\n",
113+
" return optax.squared_error(pred, y[:, 1:]).mean()\n",
114+
"\n",
115+
"\n",
116+
"# Training step\n",
117+
"@jax.jit\n",
118+
"def train_step(state, X, y):\n",
119+
" loss, grads = jax.value_and_grad(mse_loss)(state.params, state.apply_fn, X, y)\n",
120+
" state = state.apply_gradients(grads=grads)\n",
121+
" return state, loss\n",
122+
"\n",
123+
"\n",
124+
"# Main training loop\n",
125+
"def train_lstm(input_dim, input_seq_length, output_seq_length):\n",
126+
" # Hyperparameters\n",
127+
" hidden_size = 64\n",
128+
" learning_rate = 0.01\n",
129+
" num_epochs = 1000\n",
130+
" batch_size = 32\n",
131+
"\n",
132+
" # Generate data\n",
133+
" X, y = generate_multidim_data(2000, input_dim, input_seq_length, output_seq_length)\n",
134+
"\n",
135+
" # Create and initialize model\n",
136+
" state = create_train_state(\n",
137+
" key, hidden_size, input_dim, output_seq_length, learning_rate, X.shape[1:]\n",
138+
" )\n",
139+
"\n",
140+
" # Training loop\n",
141+
" for epoch in range(num_epochs):\n",
142+
" for i in range(0, len(X), batch_size):\n",
143+
" batch_X = X[i : i + batch_size]\n",
144+
" batch_y = y[i : i + batch_size]\n",
145+
" state, loss = train_step(state, batch_X, batch_y)\n",
146+
"\n",
147+
" if epoch % 100 == 0:\n",
148+
" print(f\"Epoch {epoch}, Loss: {loss}\")\n",
149+
"\n",
150+
" return state\n",
151+
"\n",
152+
"\n",
153+
"# Run training\n",
154+
"input_dim = 5 # Number of input dimensions\n",
155+
"input_seq_length = 20 # Length of input sequence (K)\n",
156+
"output_seq_length = 10 # Length of output sequence (N)\n",
157+
"trained_state = train_lstm(input_dim, input_seq_length, output_seq_length)\n",
158+
"\n",
159+
"\n",
160+
"# Evaluate the model\n",
161+
"def evaluate_model(state, X, y):\n",
162+
" predictions = state.apply_fn({\"params\": state.params}, X)\n",
163+
" mse = jnp.mean((predictions - y) ** 2)\n",
164+
" print(f\"Mean Squared Error: {mse}\")\n",
165+
"\n",
166+
" # Plot results for the first dimension\n",
167+
" import matplotlib.pyplot as plt\n",
168+
"\n",
169+
" plt.figure(figsize=(12, 6))\n",
170+
" plt.plot(y[0, :, 0], label=\"True\")\n",
171+
" plt.plot(predictions[0, :, 0], label=\"Predicted\")\n",
172+
" plt.legend()\n",
173+
" plt.title(\"LSTM Multidimensional Sequence-to-Sequence Prediction\")\n",
174+
" plt.xlabel(\"Time Steps\")\n",
175+
" plt.ylabel(\"Value\")\n",
176+
" plt.show()\n",
177+
"\n",
178+
"\n",
179+
"# Generate test data and evaluate\n",
180+
"X_test, y_test = generate_multidim_data(\n",
181+
" 200, input_dim, input_seq_length, output_seq_length\n",
182+
")\n",
183+
"\n",
184+
"evaluate_model(trained_state, X_test, y_test)"
185+
]
186+
}
187+
],
188+
"metadata": {
189+
"kernelspec": {
190+
"display_name": "Python 3",
191+
"language": "python",
192+
"name": "python3"
193+
},
194+
"language_info": {
195+
"codemirror_mode": {
196+
"name": "ipython",
197+
"version": 3
198+
},
199+
"file_extension": ".py",
200+
"mimetype": "text/x-python",
201+
"name": "python",
202+
"nbconvert_exporter": "python",
203+
"pygments_lexer": "ipython3",
204+
"version": "3.12.6"
205+
}
206+
},
207+
"nbformat": 4,
208+
"nbformat_minor": 2
209+
}

0 commit comments

Comments
 (0)