diff --git a/unpublished/sequence_prediction.ipynb b/unpublished/sequence_prediction.ipynb
new file mode 100644
index 0000000..636c977
--- /dev/null
+++ b/unpublished/sequence_prediction.ipynb
@@ -0,0 +1,209 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Prediction of time series with recurrent neural networks\n",
+    "\n",
+    "This notebook adapts the Flax seq2seq example,\n",
+    "https://github.com/google/flax/blob/main/examples/seq2seq/models.py,\n",
+    "to real-valued time series. Compared to the character-based original:\n",
+    "\n",
+    "* We don't need a vocabulary size\n",
+    "* We don't need one-hot encoding\n",
+    "* We don't need to sample from the output of the decoder"
+   ]
+  },
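+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before the full model, a minimal self-contained sketch of the `flax.linen.RNN`\n",
+    "API that the encoder and decoder below rely on (the sizes here are arbitrary):\n",
+    "`nn.RNN` wraps a recurrent cell such as `nn.GRUCell` and scans it over the\n",
+    "time axis of its input. With `return_carry=True` it also returns the final\n",
+    "hidden state, which the model below uses to condition the decoder on the\n",
+    "encoded input sequence."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import flax.linen as nn\n",
+    "\n",
+    "# toy RNN: scan a GRU cell over a batch of sequences\n",
+    "rnn = nn.RNN(nn.GRUCell(features=4), return_carry=True)\n",
+    "\n",
+    "x = jnp.ones((2, 5, 3))  # (batch, time steps, input dims)\n",
+    "variables = rnn.init(jax.random.PRNGKey(0), x)\n",
+    "carry, outputs = rnn.apply(variables, x)\n",
+    "print(carry.shape)  # (2, 4): final hidden state per batch element\n",
+    "print(outputs.shape)  # (2, 5, 4): hidden state at every time step"
+   ]
+  },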
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import flax.linen as nn\n",
+    "from flax.training import train_state\n",
+    "import optax\n",
+    "import numpy as np\n",
+    "\n",
+    "Array = jax.Array\n",
+    "PRNGKey = jax.Array\n",
+    "CellCarry = tuple[Array, Array]\n",
+    "\n",
+    "\n",
+    "# Set random seed for reproducibility\n",
+    "root_key = jax.random.PRNGKey(0)\n",
+    "\n",
+    "\n",
+    "# Generate multidimensional sequential data\n",
+    "def generate_multidim_data(num_samples, input_dim, input_seq_length, output_seq_length):\n",
+    "    assert input_dim == 3, \"the generator below produces three channels\"\n",
+    "    num_points = input_seq_length + output_seq_length\n",
+    "    rng = np.random.default_rng(0)\n",
+    "\n",
+    "    # random phase offsets turn one signal into many different windows\n",
+    "    t0 = rng.uniform(0, 2 * np.pi, size=(num_samples, 1))\n",
+    "    t = t0 + np.linspace(0, 8 * np.pi, num_points)\n",
+    "\n",
+    "    # generate a different pattern for each dimension\n",
+    "    X = np.stack([np.sin(t), np.cos(t), np.sin(2 * t)], axis=-1)\n",
+    "\n",
+    "    # scalar target derived from the input channels; one point overlaps\n",
+    "    # between input and output, it is used to seed the decoder\n",
+    "    Y = X[:, input_seq_length - 1 :].sum(axis=-1)\n",
+    "\n",
+    "    return jnp.array(X[:, :input_seq_length]), jnp.array(Y)\n",
+    "\n",
+    "\n",
+    "class DecoderCell(nn.RNNCellBase):\n",
+    "    \"\"\"GRU cell that feeds its own last prediction back as its input.\"\"\"\n",
+    "\n",
+    "    features: int\n",
+    "\n",
+    "    @nn.compact\n",
+    "    def __call__(self, carry: CellCarry, x: Array) -> tuple[CellCarry, Array]:\n",
+    "        state, last_prediction = carry\n",
+    "        # the scanned input x only sets the number of steps; the real\n",
+    "        # input is the prediction from the previous step\n",
+    "        state, hidden = nn.GRUCell(self.features)(state, last_prediction)\n",
+    "        prediction = nn.Dense(1)(hidden)\n",
+    "        return (state, prediction), prediction\n",
+    "\n",
+    "    def initialize_carry(self, rng: PRNGKey, input_shape) -> CellCarry:\n",
+    "        batch_dims = input_shape[:-1]\n",
+    "        return jnp.zeros(batch_dims + (self.features,)), jnp.zeros(batch_dims + (1,))\n",
+    "\n",
+    "    @property\n",
+    "    def num_feature_axes(self) -> int:\n",
+    "        return 1\n",
+    "\n",
+    "\n",
+    "class Seq2seq(nn.Module):\n",
+    "    \"\"\"Sequence-to-sequence model using an encoder/decoder architecture.\"\"\"\n",
+    "\n",
+    "    encoder_size: int\n",
+    "    decoder_size: int\n",
+    "    output_seq_length: int\n",
+    "\n",
+    "    @nn.compact\n",
+    "    def __call__(self, X: Array, y0: Array) -> Array:\n",
+    "        # X shape (batch size, input length, input dims)\n",
+    "        # y0 shape (batch size,): last observed target value, seeds the decoder\n",
+    "        encoder = nn.RNN(\n",
+    "            nn.GRUCell(self.encoder_size), return_carry=True, name=\"encoder\"\n",
+    "        )\n",
+    "        decoder = nn.RNN(DecoderCell(self.decoder_size), name=\"decoder\")\n",
+    "\n",
+    "        state, _ = encoder(X)\n",
+    "        # project the final encoder state into the decoder state space\n",
+    "        state = nn.Dense(self.decoder_size, name=\"bridge\")(state)\n",
+    "        # the decoder ignores these inputs; they only set the number of steps\n",
+    "        dummy = jnp.zeros((X.shape[0], self.output_seq_length, 1))\n",
+    "        y = decoder(dummy, initial_carry=(state, y0[:, None]))\n",
+    "        return y[..., 0]\n",
+    "\n",
+    "\n",
+    "# Create and initialize the model\n",
+    "def create_train_state(\n",
+    "    key, encoder_size, decoder_size, output_seq_length, learning_rate, input_shape\n",
+    "):\n",
+    "    model = Seq2seq(encoder_size, decoder_size, output_seq_length)\n",
+    "    params = model.init(key, jnp.ones((1, *input_shape)), jnp.ones(1))[\"params\"]\n",
+    "    tx = optax.nadamw(learning_rate)\n",
+    "    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n",
+    "\n",
+    "\n",
+    "# Define loss function\n",
+    "def mse_loss(params, apply_fn, X, y):\n",
+    "    # y[:, 0] is the overlap point that seeds the decoder\n",
+    "    pred = apply_fn({\"params\": params}, X, y[:, 0])\n",
+    "    return optax.squared_error(pred, y[:, 1:]).mean()\n",
+    "\n",
+    "\n",
+    "# Training step\n",
+    "@jax.jit\n",
+    "def train_step(state, X, y):\n",
+    "    loss, grads = jax.value_and_grad(mse_loss)(state.params, state.apply_fn, X, y)\n",
+    "    state = state.apply_gradients(grads=grads)\n",
+    "    return state, loss\n",
+    "\n",
+    "\n",
+    "# Main training loop\n",
+    "def train_seq2seq(input_dim, input_seq_length, output_seq_length):\n",
+    "    # Hyperparameters\n",
+    "    encoder_size = 64\n",
+    "    decoder_size = 64\n",
+    "    learning_rate = 0.01\n",
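+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, a sketch with arbitrary small sizes (not part of the\n",
+    "training run): initialize the model on a tiny batch and confirm that the\n",
+    "prediction has shape (batch size, output length)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# smoke test with arbitrary small sizes\n",
+    "X_demo, y_demo = generate_multidim_data(4, 3, 20, 10)\n",
+    "demo_state = create_train_state(jax.random.PRNGKey(1), 8, 8, 10, 0.01, X_demo.shape[1:])\n",
+    "pred = demo_state.apply_fn({\"params\": demo_state.params}, X_demo, y_demo[:, 0])\n",
+    "print(X_demo.shape, y_demo.shape, pred.shape)  # (4, 20, 3) (4, 11) (4, 10)"
+   ]
+  },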
+    "    num_epochs = 1000\n",
+    "    batch_size = 32\n",
+    "\n",
+    "    # Generate data\n",
+    "    X, y = generate_multidim_data(2000, input_dim, input_seq_length, output_seq_length)\n",
+    "\n",
+    "    # Create and initialize model\n",
+    "    state = create_train_state(\n",
+    "        root_key, encoder_size, decoder_size, output_seq_length, learning_rate, X.shape[1:]\n",
+    "    )\n",
+    "\n",
+    "    # Training loop\n",
+    "    for epoch in range(num_epochs):\n",
+    "        for i in range(0, len(X), batch_size):\n",
+    "            batch_X = X[i : i + batch_size]\n",
+    "            batch_y = y[i : i + batch_size]\n",
+    "            state, loss = train_step(state, batch_X, batch_y)\n",
+    "\n",
+    "        if epoch % 100 == 0:\n",
+    "            print(f\"Epoch {epoch}, Loss: {loss}\")\n",
+    "\n",
+    "    return state"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run training\n",
+    "input_dim = 3  # Number of input dimensions\n",
+    "input_seq_length = 20  # Length of input sequence (K)\n",
+    "output_seq_length = 10  # Length of output sequence (N)\n",
+    "trained_state = train_seq2seq(input_dim, input_seq_length, output_seq_length)\n",
+    "\n",
+    "\n",
+    "# Evaluate the model\n",
+    "def evaluate_model(state, X, y):\n",
+    "    predictions = state.apply_fn({\"params\": state.params}, X, y[:, 0])\n",
+    "    mse = jnp.mean((predictions - y[:, 1:]) ** 2)\n",
+    "    print(f\"Mean Squared Error: {mse}\")\n",
+    "\n",
+    "    # Plot true and predicted target series for the first test sample\n",
+    "    import matplotlib.pyplot as plt\n",
+    "\n",
+    "    plt.figure(figsize=(12, 6))\n",
+    "    plt.plot(y[0, 1:], label=\"True\")\n",
+    "    plt.plot(predictions[0], label=\"Predicted\")\n",
+    "    plt.legend()\n",
+    "    plt.title(\"GRU Sequence-to-Sequence Prediction\")\n",
+    "    plt.xlabel(\"Time Steps\")\n",
+    "    plt.ylabel(\"Value\")\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "# Generate test data and evaluate\n",
+    "X_test, y_test = generate_multidim_data(\n",
+    "    200, input_dim, input_seq_length, output_seq_length\n",
+    ")\n",
+    "\n",
+    "evaluate_model(trained_state, X_test, y_test)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}