|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Prediction of time series with recurrent neural networks\n", |
| 8 | + "\n", |
| 9 | + "https://github.com/google/flax/blob/main/examples/seq2seq/models.py\n", |
| 10 | + "\n", |
| 11 | + "* We don't need vocabulary size\n", |
| 12 | + "* We don't need one-hot encoding\n", |
| 13 | + "* We don't need to sample from the output of the decoder" |
| 14 | + ] |
| 15 | + }, |
| 16 | + { |
| 17 | + "cell_type": "code", |
| 18 | + "execution_count": 2, |
| 19 | + "metadata": {}, |
| 20 | + "outputs": [ |
| 21 | + { |
| 22 | + "ename": "TypeError", |
| 23 | + "evalue": "LSTMCell.__init__() missing 1 required positional argument: 'features'", |
| 24 | + "output_type": "error", |
| 25 | + "traceback": [ |
| 26 | + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", |
| 27 | + "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", |
| 28 | + "Cell \u001b[1;32mIn[2], line 115\u001b[0m\n\u001b[0;32m 113\u001b[0m input_seq_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m20\u001b[39m \u001b[38;5;66;03m# Length of input sequence (K)\u001b[39;00m\n\u001b[0;32m 114\u001b[0m output_seq_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m10\u001b[39m \u001b[38;5;66;03m# Length of output sequence (N)\u001b[39;00m\n\u001b[1;32m--> 115\u001b[0m trained_state \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_lstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_seq_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_seq_length\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 117\u001b[0m \u001b[38;5;66;03m# Evaluate the model\u001b[39;00m\n\u001b[0;32m 118\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate_model\u001b[39m(state, X, y):\n", |
| 29 | + "Cell \u001b[1;32mIn[2], line 97\u001b[0m, in \u001b[0;36mtrain_lstm\u001b[1;34m(input_dim, input_seq_length, output_seq_length)\u001b[0m\n\u001b[0;32m 94\u001b[0m X, y \u001b[38;5;241m=\u001b[39m generate_multidim_data(\u001b[38;5;241m2000\u001b[39m, input_dim, input_seq_length, output_seq_length)\n\u001b[0;32m 96\u001b[0m \u001b[38;5;66;03m# Create and initialize model\u001b[39;00m\n\u001b[1;32m---> 97\u001b[0m state \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_train_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_seq_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 99\u001b[0m \u001b[38;5;66;03m# Training loop\u001b[39;00m\n\u001b[0;32m 100\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n", |
| 30 | + "Cell \u001b[1;32mIn[2], line 69\u001b[0m, in \u001b[0;36mcreate_train_state\u001b[1;34m(key, hidden_size, output_dim, output_seq_length, learning_rate, input_shape)\u001b[0m\n\u001b[0;32m 67\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_train_state\u001b[39m(key, hidden_size, output_dim, output_seq_length, learning_rate, input_shape):\n\u001b[0;32m 68\u001b[0m model \u001b[38;5;241m=\u001b[39m LSTMEncoderDecoder(hidden_size\u001b[38;5;241m=\u001b[39mhidden_size, output_dim\u001b[38;5;241m=\u001b[39moutput_dim, output_seq_length\u001b[38;5;241m=\u001b[39moutput_seq_length)\n\u001b[1;32m---> 69\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_shape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m 70\u001b[0m tx \u001b[38;5;241m=\u001b[39m optax\u001b[38;5;241m.\u001b[39madam(learning_rate)\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m train_state\u001b[38;5;241m.\u001b[39mTrainState\u001b[38;5;241m.\u001b[39mcreate(apply_fn\u001b[38;5;241m=\u001b[39mmodel\u001b[38;5;241m.\u001b[39mapply, params\u001b[38;5;241m=\u001b[39mparams, tx\u001b[38;5;241m=\u001b[39mtx)\n", |
| 31 | + " \u001b[1;31m[... skipping hidden 9 frame]\u001b[0m\n", |
| 32 | + "Cell \u001b[1;32mIn[2], line 41\u001b[0m, in \u001b[0;36mLSTMEncoderDecoder.__call__\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[38;5;129m@nn\u001b[39m\u001b[38;5;241m.\u001b[39mcompact\n\u001b[0;32m 39\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m 40\u001b[0m \u001b[38;5;66;03m# Encoder\u001b[39;00m\n\u001b[1;32m---> 41\u001b[0m encoder_lstm \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mLSTMCell\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 42\u001b[0m encoder_dense \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mDense(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size)\n\u001b[0;32m 44\u001b[0m \u001b[38;5;66;03m# Decoder\u001b[39;00m\n", |
| 33 | + "File \u001b[1;32mc:\\Users\\HansDembinski\\AppData\\Local\\micromamba\\envs\\blog\\Lib\\site-packages\\flax\\linen\\kw_only_dataclasses.py:235\u001b[0m, in \u001b[0;36m_process_class.<locals>.init_wrapper\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 227\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_args \u001b[38;5;241m>\u001b[39m expected_num_args:\n\u001b[0;32m 228\u001b[0m \u001b[38;5;66;03m# we add + 1 to each to account for `self`, matching python's\u001b[39;00m\n\u001b[0;32m 229\u001b[0m \u001b[38;5;66;03m# default error message\u001b[39;00m\n\u001b[0;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[0;32m 231\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__init__() takes \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexpected_num_args\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m positional \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 232\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124marguments but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_args\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m were given\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 233\u001b[0m )\n\u001b[1;32m--> 235\u001b[0m \u001b[43mdataclass_init\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", |
| 34 | + "\u001b[1;31mTypeError\u001b[0m: LSTMCell.__init__() missing 1 required positional argument: 'features'" |
| 35 | + ] |
| 36 | + } |
| 37 | + ], |
| 38 | + "source": [ |
| 39 | + "import jax\n", |
| 40 | + "import jax.numpy as jnp\n", |
| 41 | + "import flax.linen as nn\n", |
| 42 | + "from flax.training import train_state\n", |
| 43 | + "import optax\n", |
| 44 | + "import numpy as np\n", |
| 45 | + "\n", |
| 46 | + "Array = jax.Array\n", |
| 47 | + "PRNGKey = jax.Array\n", |
| 48 | + "CellCarry = tuple[Array, Array]\n", |
| 49 | + "\n", |
| 50 | + "\n", |
| 51 | + "# Set random seed for reproducibility\n", |
| 52 | + "root_key = jax.random.PRNGKey(0)\n", |
| 53 | + "\n", |
| 54 | + "\n", |
| 55 | + "# Generate multidimensional sequential data\n", |
| 56 | + "def generate_multidim_data(input_seq_length, output_seq_length):\n", |
| 57 | + " num_points = input_seq_length + output_seq_length\n", |
| 58 | + " t = np.linspace(0, 8 * jnp.pi, num_points)\n", |
| 59 | + " X = np.empty((num_points, input_dim))\n", |
| 60 | + "\n", |
| 61 | + " # Generate different patterns for each dimension\n", |
| 62 | + " X[:, 0] = np.sin(t)\n", |
| 63 | + " X[:, 1] = np.cos(t)\n", |
| 64 | + " X[:, 2] = np.sin(2 * t)\n", |
| 65 | + "\n", |
| 66 | + " # let's have one point overlap between input and output\n", |
| 67 | + " Y = X[input_seq_length-1:, 0] + X[input_seq_length-1:, 1] + X[input_seq_length:-1, 2]\n", |
| 68 | + "\n", |
| 69 | + " return jnp.array(X[:input_seq_length]), jnp.array(Y)\n", |
| 70 | + "\n", |
| 71 | + "\n", |
| 72 | + "class DecoderCell(nn.RNNCellBase):\n", |
| 73 | + " features: int\n", |
| 74 | + "\n", |
| 75 | + " @nn.compact\n", |
| 76 | + " def __call__(self, carry, x):\n", |
| 77 | + " state, last_prediction = carry\n", |
| 78 | + " \n", |
| 79 | + "\n", |
| 80 | + "\n", |
| 81 | + "class Seq2seq(nn.Module):\n", |
| 82 | + " \"\"\"Sequence-to-sequence class using encoder/decoder architecture.\"\"\"\n", |
| 83 | + "\n", |
| 84 | + " encoder_size: int\n", |
| 85 | + " decoder_size: int\n", |
| 86 | + "\n", |
| 87 | + " @nn.compact\n", |
| 88 | + " def __call__(self, X: Array, y: Array) -> tuple[Array, Array]:\n", |
| 89 | + " \"\"\"Applies the seq2seq model.\"\"\"\n", |
| 90 | + " # X shape (batch size, input length, input dims)\n", |
| 91 | + " # y shape (batch size, output length)\n", |
| 92 | + " encoder = nn.RNN(\n", |
| 93 | + " nn.GRUCell(self.encoder_size), return_carry=True, name=\"encoder\"\n", |
| 94 | + " )\n", |
| 95 | + " decoder = nn.RNN(nn.GRUCell(self.decoder_size), name=\"decoder\")\n", |
| 96 | + "\n", |
| 97 | + " state, _ = encoder(X)\n", |
| 98 | + " y = decoder(state, y)\n", |
| 99 | + " return y\n", |
| 100 | + "\n", |
| 101 | + "\n", |
| 102 | + "# Create and initialize the model\n", |
| 103 | + "def create_train_state(key, encoder_size, decoder_size, learning_rate, input_shape, output_size):\n", |
| 104 | + " model = Seq2seq(encoder_size, decoder_size)\n", |
| 105 | + " params = model.init(key, jnp.ones(input_shape), output_size)[\"params\"]\n", |
| 106 | + " tx = optax.nadamw(learning_rate)\n", |
| 107 | + " return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n", |
| 108 | + "\n", |
| 109 | + "\n", |
| 110 | + "# Define loss function\n", |
| 111 | + "def mse_loss(params, model, X, y):\n", |
| 112 | + " pred = model.apply(params, X, y[:, 0], y.shape[-1])\n", |
| 113 | + " return optax.squared_error(pred, y[:, 1:]).mean()\n", |
| 114 | + "\n", |
| 115 | + "\n", |
| 116 | + "# Training step\n", |
| 117 | + "@jax.jit\n", |
| 118 | + "def train_step(state, X, y):\n", |
| 119 | + " loss, grads = jax.value_and_grad(mse_loss)(state.params, state.apply_fn, X, y)\n", |
| 120 | + " state = state.apply_gradients(grads=grads)\n", |
| 121 | + " return state, loss\n", |
| 122 | + "\n", |
| 123 | + "\n", |
| 124 | + "# Main training loop\n", |
| 125 | + "def train_lstm(input_dim, input_seq_length, output_seq_length):\n", |
| 126 | + " # Hyperparameters\n", |
| 127 | + " hidden_size = 64\n", |
| 128 | + " learning_rate = 0.01\n", |
| 129 | + " num_epochs = 1000\n", |
| 130 | + " batch_size = 32\n", |
| 131 | + "\n", |
| 132 | + " # Generate data\n", |
| 133 | + " X, y = generate_multidim_data(2000, input_dim, input_seq_length, output_seq_length)\n", |
| 134 | + "\n", |
| 135 | + " # Create and initialize model\n", |
| 136 | + " state = create_train_state(\n", |
| 137 | + " key, hidden_size, input_dim, output_seq_length, learning_rate, X.shape[1:]\n", |
| 138 | + " )\n", |
| 139 | + "\n", |
| 140 | + " # Training loop\n", |
| 141 | + " for epoch in range(num_epochs):\n", |
| 142 | + " for i in range(0, len(X), batch_size):\n", |
| 143 | + " batch_X = X[i : i + batch_size]\n", |
| 144 | + " batch_y = y[i : i + batch_size]\n", |
| 145 | + " state, loss = train_step(state, batch_X, batch_y)\n", |
| 146 | + "\n", |
| 147 | + " if epoch % 100 == 0:\n", |
| 148 | + " print(f\"Epoch {epoch}, Loss: {loss}\")\n", |
| 149 | + "\n", |
| 150 | + " return state\n", |
| 151 | + "\n", |
| 152 | + "\n", |
| 153 | + "# Run training\n", |
| 154 | + "input_dim = 5 # Number of input dimensions\n", |
| 155 | + "input_seq_length = 20 # Length of input sequence (K)\n", |
| 156 | + "output_seq_length = 10 # Length of output sequence (N)\n", |
| 157 | + "trained_state = train_lstm(input_dim, input_seq_length, output_seq_length)\n", |
| 158 | + "\n", |
| 159 | + "\n", |
| 160 | + "# Evaluate the model\n", |
| 161 | + "def evaluate_model(state, X, y):\n", |
| 162 | + " predictions = state.apply_fn({\"params\": state.params}, X)\n", |
| 163 | + " mse = jnp.mean((predictions - y) ** 2)\n", |
| 164 | + " print(f\"Mean Squared Error: {mse}\")\n", |
| 165 | + "\n", |
| 166 | + " # Plot results for the first dimension\n", |
| 167 | + " import matplotlib.pyplot as plt\n", |
| 168 | + "\n", |
| 169 | + " plt.figure(figsize=(12, 6))\n", |
| 170 | + " plt.plot(y[0, :, 0], label=\"True\")\n", |
| 171 | + " plt.plot(predictions[0, :, 0], label=\"Predicted\")\n", |
| 172 | + " plt.legend()\n", |
| 173 | + " plt.title(\"LSTM Multidimensional Sequence-to-Sequence Prediction\")\n", |
| 174 | + " plt.xlabel(\"Time Steps\")\n", |
| 175 | + " plt.ylabel(\"Value\")\n", |
| 176 | + " plt.show()\n", |
| 177 | + "\n", |
| 178 | + "\n", |
| 179 | + "# Generate test data and evaluate\n", |
| 180 | + "X_test, y_test = generate_multidim_data(\n", |
| 181 | + " 200, input_dim, input_seq_length, output_seq_length\n", |
| 182 | + ")\n", |
| 183 | + "\n", |
| 184 | + "evaluate_model(trained_state, X_test, y_test)" |
| 185 | + ] |
| 186 | + } |
| 187 | + ], |
| 188 | + "metadata": { |
| 189 | + "kernelspec": { |
| 190 | + "display_name": "Python 3", |
| 191 | + "language": "python", |
| 192 | + "name": "python3" |
| 193 | + }, |
| 194 | + "language_info": { |
| 195 | + "codemirror_mode": { |
| 196 | + "name": "ipython", |
| 197 | + "version": 3 |
| 198 | + }, |
| 199 | + "file_extension": ".py", |
| 200 | + "mimetype": "text/x-python", |
| 201 | + "name": "python", |
| 202 | + "nbconvert_exporter": "python", |
| 203 | + "pygments_lexer": "ipython3", |
| 204 | + "version": "3.12.6" |
| 205 | + } |
| 206 | + }, |
| 207 | + "nbformat": 4, |
| 208 | + "nbformat_minor": 2 |
| 209 | +} |
0 commit comments