{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prediction of time series with recurrent neural networks\n",
    "\n",
    "Compared to the token-based Flax seq2seq example\n",
    "(https://github.com/google/flax/blob/main/examples/seq2seq/models.py),\n",
    "predicting a real-valued time series is simpler:\n",
    "\n",
    "* we don't need a vocabulary size\n",
    "* we don't need one-hot encoding\n",
    "* we don't need to sample from the output of the decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
"source": [ | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import flax.linen as nn\n", | ||
"from flax.training import train_state\n", | ||
"import optax\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"Array = jax.Array\n", | ||
"PRNGKey = jax.Array\n", | ||
"CellCarry = tuple[Array, Array]\n", | ||
"\n", | ||
"\n", | ||
"# Set random seed for reproducibility\n", | ||
"root_key = jax.random.PRNGKey(0)\n", | ||
"\n", | ||
"\n", | ||
"# Generate multidimensional sequential data\n", | ||
"def generate_multidim_data(input_seq_length, output_seq_length):\n", | ||
" num_points = input_seq_length + output_seq_length\n", | ||
" t = np.linspace(0, 8 * jnp.pi, num_points)\n", | ||
" X = np.empty((num_points, input_dim))\n", | ||
"\n", | ||
" # Generate different patterns for each dimension\n", | ||
" X[:, 0] = np.sin(t)\n", | ||
" X[:, 1] = np.cos(t)\n", | ||
" X[:, 2] = np.sin(2 * t)\n", | ||
"\n", | ||
" # let's have one point overlap between input and output\n", | ||
" Y = X[input_seq_length-1:, 0] + X[input_seq_length-1:, 1] + X[input_seq_length:-1, 2]\n", | ||
"\n", | ||
" return jnp.array(X[:input_seq_length]), jnp.array(Y)\n", | ||
"\n", | ||
"\n", | ||
"class DecoderCell(nn.RNNCellBase):\n", | ||
" features: int\n", | ||
"\n", | ||
" @nn.compact\n", | ||
" def __call__(self, carry, x):\n", | ||
" state, last_prediction = carry\n", | ||
" \n", | ||
"\n", | ||
"\n", | ||
"class Seq2seq(nn.Module):\n", | ||
" \"\"\"Sequence-to-sequence class using encoder/decoder architecture.\"\"\"\n", | ||
"\n", | ||
" encoder_size: int\n", | ||
" decoder_size: int\n", | ||
"\n", | ||
" @nn.compact\n", | ||
" def __call__(self, X: Array, y: Array) -> tuple[Array, Array]:\n", | ||
" \"\"\"Applies the seq2seq model.\"\"\"\n", | ||
" # X shape (batch size, input length, input dims)\n", | ||
" # y shape (batch size, output length)\n", | ||
" encoder = nn.RNN(\n", | ||
" nn.GRUCell(self.encoder_size), return_carry=True, name=\"encoder\"\n", | ||
" )\n", | ||
" decoder = nn.RNN(nn.GRUCell(self.decoder_size), name=\"decoder\")\n", | ||
"\n", | ||
" state, _ = encoder(X)\n", | ||
" y = decoder(state, y)\n", | ||
" return y\n", | ||
"\n", | ||
"\n", | ||
"# Create and initialize the model\n", | ||
"def create_train_state(key, encoder_size, decoder_size, learning_rate, input_shape, output_size):\n", | ||
" model = Seq2seq(encoder_size, decoder_size)\n", | ||
" params = model.init(key, jnp.ones(input_shape), output_size)[\"params\"]\n", | ||
" tx = optax.nadamw(learning_rate)\n", | ||
" return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n", | ||
"\n", | ||
"\n", | ||
"# Define loss function\n", | ||
"def mse_loss(params, model, X, y):\n", | ||
" pred = model.apply(params, X, y[:, 0], y.shape[-1])\n", | ||
" return optax.squared_error(pred, y[:, 1:]).mean()\n", | ||
"\n", | ||
"\n", | ||
"# Training step\n", | ||
"@jax.jit\n", | ||
"def train_step(state, X, y):\n", | ||
" loss, grads = jax.value_and_grad(mse_loss)(state.params, state.apply_fn, X, y)\n", | ||
" state = state.apply_gradients(grads=grads)\n", | ||
" return state, loss\n", | ||
"\n", | ||
"\n", | ||
"# Main training loop\n", | ||
"def train_lstm(input_dim, input_seq_length, output_seq_length):\n", | ||
" # Hyperparameters\n", | ||
" hidden_size = 64\n", | ||
" learning_rate = 0.01\n", | ||
" num_epochs = 1000\n", | ||
" batch_size = 32\n", | ||
"\n", | ||
" # Generate data\n", | ||
" X, y = generate_multidim_data(2000, input_dim, input_seq_length, output_seq_length)\n", | ||
"\n", | ||
" # Create and initialize model\n", | ||
" state = create_train_state(\n", | ||
" key, hidden_size, input_dim, output_seq_length, learning_rate, X.shape[1:]\n", | ||
" )\n", | ||
"\n", | ||
" # Training loop\n", | ||
" for epoch in range(num_epochs):\n", | ||
" for i in range(0, len(X), batch_size):\n", | ||
" batch_X = X[i : i + batch_size]\n", | ||
" batch_y = y[i : i + batch_size]\n", | ||
" state, loss = train_step(state, batch_X, batch_y)\n", | ||
"\n", | ||
" if epoch % 100 == 0:\n", | ||
" print(f\"Epoch {epoch}, Loss: {loss}\")\n", | ||
"\n", | ||
" return state\n", | ||
"\n", | ||
"\n", | ||
"# Run training\n", | ||
"input_dim = 5 # Number of input dimensions\n", | ||
"input_seq_length = 20 # Length of input sequence (K)\n", | ||
"output_seq_length = 10 # Length of output sequence (N)\n", | ||
"trained_state = train_lstm(input_dim, input_seq_length, output_seq_length)\n", | ||
"\n", | ||
"\n", | ||
"# Evaluate the model\n", | ||
"def evaluate_model(state, X, y):\n", | ||
" predictions = state.apply_fn({\"params\": state.params}, X)\n", | ||
" mse = jnp.mean((predictions - y) ** 2)\n", | ||
" print(f\"Mean Squared Error: {mse}\")\n", | ||
"\n", | ||
" # Plot results for the first dimension\n", | ||
" import matplotlib.pyplot as plt\n", | ||
"\n", | ||
" plt.figure(figsize=(12, 6))\n", | ||
" plt.plot(y[0, :, 0], label=\"True\")\n", | ||
" plt.plot(predictions[0, :, 0], label=\"Predicted\")\n", | ||
" plt.legend()\n", | ||
" plt.title(\"LSTM Multidimensional Sequence-to-Sequence Prediction\")\n", | ||
" plt.xlabel(\"Time Steps\")\n", | ||
" plt.ylabel(\"Value\")\n", | ||
" plt.show()\n", | ||
"\n", | ||
"\n", | ||
"# Generate test data and evaluate\n", | ||
"X_test, y_test = generate_multidim_data(\n", | ||
" 200, input_dim, input_seq_length, output_seq_length\n", | ||
")\n", | ||
"\n", | ||
"evaluate_model(trained_state, X_test, y_test)" | ||
   ]
  },
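  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check (an addition, not part of the original analysis), we can compare the teacher-forced MSE against a naive baseline that simply repeats the last observed target value `y_test[:, 0]` for all future steps. The trained model should beat this baseline clearly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# naive baseline: repeat the last observed value for all N future steps\n",
    "naive_pred = jnp.repeat(y_test[:, :1], y_test.shape[1] - 1, axis=1)\n",
    "naive_mse = jnp.mean((naive_pred - y_test[:, 1:]) ** 2)\n",
    "print(f\"Naive last-value MSE: {naive_mse}\")"
   ]
  }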
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |