Skip to content

Commit

Permalink
u
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski committed Oct 12, 2024
1 parent 1fab8ff commit c1ec2f7
Showing 1 changed file with 209 additions and 0 deletions.
209 changes: 209 additions & 0 deletions unpublished/sequence_prediction.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Prediction of time series with recurrent neural networks\n",
"\n",
"https://github.com/google/flax/blob/main/examples/seq2seq/models.py\n",
"\n",
"* We don't need vocabulary size\n",
"* We don't need one-hot encoding\n",
"* We don't need to sample from the output of the decoder"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "LSTMCell.__init__() missing 1 required positional argument: 'features'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[2], line 115\u001b[0m\n\u001b[0;32m 113\u001b[0m input_seq_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m20\u001b[39m \u001b[38;5;66;03m# Length of input sequence (K)\u001b[39;00m\n\u001b[0;32m 114\u001b[0m output_seq_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m10\u001b[39m \u001b[38;5;66;03m# Length of output sequence (N)\u001b[39;00m\n\u001b[1;32m--> 115\u001b[0m trained_state \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_lstm\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_seq_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_seq_length\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 117\u001b[0m \u001b[38;5;66;03m# Evaluate the model\u001b[39;00m\n\u001b[0;32m 118\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate_model\u001b[39m(state, X, y):\n",
"Cell \u001b[1;32mIn[2], line 97\u001b[0m, in \u001b[0;36mtrain_lstm\u001b[1;34m(input_dim, input_seq_length, output_seq_length)\u001b[0m\n\u001b[0;32m 94\u001b[0m X, y \u001b[38;5;241m=\u001b[39m generate_multidim_data(\u001b[38;5;241m2000\u001b[39m, input_dim, input_seq_length, output_seq_length)\n\u001b[0;32m 96\u001b[0m \u001b[38;5;66;03m# Create and initialize model\u001b[39;00m\n\u001b[1;32m---> 97\u001b[0m state \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_train_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhidden_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moutput_seq_length\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 99\u001b[0m \u001b[38;5;66;03m# Training loop\u001b[39;00m\n\u001b[0;32m 100\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_epochs):\n",
"Cell \u001b[1;32mIn[2], line 69\u001b[0m, in \u001b[0;36mcreate_train_state\u001b[1;34m(key, hidden_size, output_dim, output_seq_length, learning_rate, input_shape)\u001b[0m\n\u001b[0;32m 67\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_train_state\u001b[39m(key, hidden_size, output_dim, output_seq_length, learning_rate, input_shape):\n\u001b[0;32m 68\u001b[0m model \u001b[38;5;241m=\u001b[39m LSTMEncoderDecoder(hidden_size\u001b[38;5;241m=\u001b[39mhidden_size, output_dim\u001b[38;5;241m=\u001b[39moutput_dim, output_seq_length\u001b[38;5;241m=\u001b[39moutput_seq_length)\n\u001b[1;32m---> 69\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_shape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mparams\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m 70\u001b[0m tx \u001b[38;5;241m=\u001b[39m optax\u001b[38;5;241m.\u001b[39madam(learning_rate)\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m train_state\u001b[38;5;241m.\u001b[39mTrainState\u001b[38;5;241m.\u001b[39mcreate(apply_fn\u001b[38;5;241m=\u001b[39mmodel\u001b[38;5;241m.\u001b[39mapply, params\u001b[38;5;241m=\u001b[39mparams, tx\u001b[38;5;241m=\u001b[39mtx)\n",
" \u001b[1;31m[... skipping hidden 9 frame]\u001b[0m\n",
"Cell \u001b[1;32mIn[2], line 41\u001b[0m, in \u001b[0;36mLSTMEncoderDecoder.__call__\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[38;5;129m@nn\u001b[39m\u001b[38;5;241m.\u001b[39mcompact\n\u001b[0;32m 39\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m 40\u001b[0m \u001b[38;5;66;03m# Encoder\u001b[39;00m\n\u001b[1;32m---> 41\u001b[0m encoder_lstm \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mLSTMCell\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 42\u001b[0m encoder_dense \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mDense(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhidden_size)\n\u001b[0;32m 44\u001b[0m \u001b[38;5;66;03m# Decoder\u001b[39;00m\n",
"File \u001b[1;32mc:\\Users\\HansDembinski\\AppData\\Local\\micromamba\\envs\\blog\\Lib\\site-packages\\flax\\linen\\kw_only_dataclasses.py:235\u001b[0m, in \u001b[0;36m_process_class.<locals>.init_wrapper\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 227\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_args \u001b[38;5;241m>\u001b[39m expected_num_args:\n\u001b[0;32m 228\u001b[0m \u001b[38;5;66;03m# we add + 1 to each to account for `self`, matching python's\u001b[39;00m\n\u001b[0;32m 229\u001b[0m \u001b[38;5;66;03m# default error message\u001b[39;00m\n\u001b[0;32m 230\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[0;32m 231\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__init__() takes \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexpected_num_args\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m positional \u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 232\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124marguments but \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_args\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m were given\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m 233\u001b[0m )\n\u001b[1;32m--> 235\u001b[0m \u001b[43mdataclass_init\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[1;31mTypeError\u001b[0m: LSTMCell.__init__() missing 1 required positional argument: 'features'"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import flax.linen as nn\n",
"from flax.training import train_state\n",
"import optax\n",
"import numpy as np\n",
"\n",
"Array = jax.Array\n",
"PRNGKey = jax.Array\n",
"CellCarry = tuple[Array, Array]\n",
"\n",
"\n",
"# Set random seed for reproducibility\n",
"root_key = jax.random.PRNGKey(0)\n",
"\n",
"\n",
"# Generate multidimensional sequential data\n",
"def generate_multidim_data(input_seq_length, output_seq_length):\n",
" num_points = input_seq_length + output_seq_length\n",
" t = np.linspace(0, 8 * jnp.pi, num_points)\n",
" X = np.empty((num_points, input_dim))\n",
"\n",
" # Generate different patterns for each dimension\n",
" X[:, 0] = np.sin(t)\n",
" X[:, 1] = np.cos(t)\n",
" X[:, 2] = np.sin(2 * t)\n",
"\n",
" # let's have one point overlap between input and output\n",
" Y = X[input_seq_length-1:, 0] + X[input_seq_length-1:, 1] + X[input_seq_length:-1, 2]\n",
"\n",
" return jnp.array(X[:input_seq_length]), jnp.array(Y)\n",
"\n",
"\n",
"class DecoderCell(nn.RNNCellBase):\n",
" features: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, carry, x):\n",
" state, last_prediction = carry\n",
" \n",
"\n",
"\n",
"class Seq2seq(nn.Module):\n",
" \"\"\"Sequence-to-sequence class using encoder/decoder architecture.\"\"\"\n",
"\n",
" encoder_size: int\n",
" decoder_size: int\n",
"\n",
" @nn.compact\n",
" def __call__(self, X: Array, y: Array) -> tuple[Array, Array]:\n",
" \"\"\"Applies the seq2seq model.\"\"\"\n",
" # X shape (batch size, input length, input dims)\n",
" # y shape (batch size, output length)\n",
" encoder = nn.RNN(\n",
" nn.GRUCell(self.encoder_size), return_carry=True, name=\"encoder\"\n",
" )\n",
" decoder = nn.RNN(nn.GRUCell(self.decoder_size), name=\"decoder\")\n",
"\n",
" state, _ = encoder(X)\n",
" y = decoder(state, y)\n",
" return y\n",
"\n",
"\n",
"# Create and initialize the model\n",
"def create_train_state(key, encoder_size, decoder_size, learning_rate, input_shape, output_size):\n",
" model = Seq2seq(encoder_size, decoder_size)\n",
" params = model.init(key, jnp.ones(input_shape), output_size)[\"params\"]\n",
" tx = optax.nadamw(learning_rate)\n",
" return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n",
"\n",
"\n",
"# Define loss function\n",
"def mse_loss(params, model, X, y):\n",
" pred = model.apply(params, X, y[:, 0], y.shape[-1])\n",
" return optax.squared_error(pred, y[:, 1:]).mean()\n",
"\n",
"\n",
"# Training step\n",
"@jax.jit\n",
"def train_step(state, X, y):\n",
" loss, grads = jax.value_and_grad(mse_loss)(state.params, state.apply_fn, X, y)\n",
" state = state.apply_gradients(grads=grads)\n",
" return state, loss\n",
"\n",
"\n",
"# Main training loop\n",
"def train_lstm(input_dim, input_seq_length, output_seq_length):\n",
" # Hyperparameters\n",
" hidden_size = 64\n",
" learning_rate = 0.01\n",
" num_epochs = 1000\n",
" batch_size = 32\n",
"\n",
" # Generate data\n",
" X, y = generate_multidim_data(2000, input_dim, input_seq_length, output_seq_length)\n",
"\n",
" # Create and initialize model\n",
" state = create_train_state(\n",
" key, hidden_size, input_dim, output_seq_length, learning_rate, X.shape[1:]\n",
" )\n",
"\n",
" # Training loop\n",
" for epoch in range(num_epochs):\n",
" for i in range(0, len(X), batch_size):\n",
" batch_X = X[i : i + batch_size]\n",
" batch_y = y[i : i + batch_size]\n",
" state, loss = train_step(state, batch_X, batch_y)\n",
"\n",
" if epoch % 100 == 0:\n",
" print(f\"Epoch {epoch}, Loss: {loss}\")\n",
"\n",
" return state\n",
"\n",
"\n",
"# Run training\n",
"input_dim = 5 # Number of input dimensions\n",
"input_seq_length = 20 # Length of input sequence (K)\n",
"output_seq_length = 10 # Length of output sequence (N)\n",
"trained_state = train_lstm(input_dim, input_seq_length, output_seq_length)\n",
"\n",
"\n",
"# Evaluate the model\n",
"def evaluate_model(state, X, y):\n",
" predictions = state.apply_fn({\"params\": state.params}, X)\n",
" mse = jnp.mean((predictions - y) ** 2)\n",
" print(f\"Mean Squared Error: {mse}\")\n",
"\n",
" # Plot results for the first dimension\n",
" import matplotlib.pyplot as plt\n",
"\n",
" plt.figure(figsize=(12, 6))\n",
" plt.plot(y[0, :, 0], label=\"True\")\n",
" plt.plot(predictions[0, :, 0], label=\"Predicted\")\n",
" plt.legend()\n",
" plt.title(\"LSTM Multidimensional Sequence-to-Sequence Prediction\")\n",
" plt.xlabel(\"Time Steps\")\n",
" plt.ylabel(\"Value\")\n",
" plt.show()\n",
"\n",
"\n",
"# Generate test data and evaluate\n",
"X_test, y_test = generate_multidim_data(\n",
" 200, input_dim, input_seq_length, output_seq_length\n",
")\n",
"\n",
"evaluate_model(trained_state, X_test, y_test)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit c1ec2f7

Please sign in to comment.