From a95b334249c063530887c1e4153092c61643de6c Mon Sep 17 00:00:00 2001
From: Deepak CH <deepak.23BCS10092@ms.sst.scaler.com>
Date: Fri, 2 Aug 2024 07:54:38 +0000
Subject: [PATCH 1/2] Fix Minibatch alignment in Bayesian Neural Network
 example

---
 .../bayesian_neural_network_advi.ipynb          | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/examples/variational_inference/bayesian_neural_network_advi.ipynb b/examples/variational_inference/bayesian_neural_network_advi.ipynb
index 731de639a..bc46b320a 100644
--- a/examples/variational_inference/bayesian_neural_network_advi.ipynb
+++ b/examples/variational_inference/bayesian_neural_network_advi.ipynb
@@ -190,7 +190,7 @@
    },
    "outputs": [],
    "source": [
-    "def construct_nn(ann_input, ann_output):\n",
+    "def construct_nn():\n",
     "    n_hidden = 5\n",
     "\n",
     "    # Initialize random weights between each layer\n",
@@ -204,9 +204,14 @@
     "        \"train_cols\": np.arange(X_train.shape[1]),\n",
     "        \"obs_id\": np.arange(X_train.shape[0]),\n",
     "    }\n",
+    "    \n",
     "    with pm.Model(coords=coords) as neural_network:\n",
-    "        ann_input = pm.Data(\"ann_input\", X_train, dims=(\"obs_id\", \"train_cols\"))\n",
-    "        ann_output = pm.Data(\"ann_output\", Y_train, dims=\"obs_id\")\n",
+    "        # Define minibatch variables\n",
+    "        minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n",
+    "        \n",
+    "        # Define data variables using minibatches\n",
+    "        ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n",
+    "        ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n",
     "\n",
     "        # Weights from input to hidden layer\n",
     "        weights_in_1 = pm.Normal(\n",
@@ -231,13 +236,13 @@
     "            \"out\",\n",
     "            act_out,\n",
     "            observed=ann_output,\n",
-    "            total_size=Y_train.shape[0],  # IMPORTANT for minibatches\n",
+    "            total_size=X_train.shape[0],  # IMPORTANT for minibatches\n",
     "            dims=\"obs_id\",\n",
     "        )\n",
     "    return neural_network\n",
     "\n",
-    "\n",
-    "neural_network = construct_nn(X_train, Y_train)"
+    "# Create the neural network model\n",
+    "neural_network = construct_nn()\n"
    ]
   },
   {

From 6c2effefc2fc1d21852593ffeeebe68e487edfeb Mon Sep 17 00:00:00 2001
From: Win Wang <wang.win0+git@gmail.com>
Date: Tue, 5 Nov 2024 11:18:46 -0500
Subject: [PATCH 2/2] Run: pre-commit run --all-files

---
 .../bayesian_neural_network_advi.ipynb           |  7 ++++---
 .../bayesian_neural_network_advi.myst.md         | 16 +++++++++++-----
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/examples/variational_inference/bayesian_neural_network_advi.ipynb b/examples/variational_inference/bayesian_neural_network_advi.ipynb
index bc46b320a..a21624b2f 100644
--- a/examples/variational_inference/bayesian_neural_network_advi.ipynb
+++ b/examples/variational_inference/bayesian_neural_network_advi.ipynb
@@ -204,11 +204,11 @@
     "        \"train_cols\": np.arange(X_train.shape[1]),\n",
     "        \"obs_id\": np.arange(X_train.shape[0]),\n",
     "    }\n",
-    "    \n",
+    "\n",
     "    with pm.Model(coords=coords) as neural_network:\n",
     "        # Define minibatch variables\n",
     "        minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n",
-    "        \n",
+    "\n",
     "        # Define data variables using minibatches\n",
     "        ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n",
     "        ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n",
@@ -241,8 +241,9 @@
     "        )\n",
     "    return neural_network\n",
     "\n",
+    "\n",
     "# Create the neural network model\n",
-    "neural_network = construct_nn()\n"
+    "neural_network = construct_nn()"
    ]
   },
   {
diff --git a/examples/variational_inference/bayesian_neural_network_advi.myst.md b/examples/variational_inference/bayesian_neural_network_advi.myst.md
index 04dc6a4f1..0c3649a09 100644
--- a/examples/variational_inference/bayesian_neural_network_advi.myst.md
+++ b/examples/variational_inference/bayesian_neural_network_advi.myst.md
@@ -114,7 +114,7 @@ A neural network is quite simple. The basic unit is a [perceptron](https://en.wi
 jupyter:
   outputs_hidden: true
 ---
-def construct_nn(ann_input, ann_output):
+def construct_nn():
     n_hidden = 5
 
     # Initialize random weights between each layer
@@ -128,9 +128,14 @@ def construct_nn(ann_input, ann_output):
         "train_cols": np.arange(X_train.shape[1]),
         "obs_id": np.arange(X_train.shape[0]),
     }
+
     with pm.Model(coords=coords) as neural_network:
-        ann_input = pm.Data("ann_input", X_train, dims=("obs_id", "train_cols"))
-        ann_output = pm.Data("ann_output", Y_train, dims="obs_id")
+        # Define minibatch variables
+        minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)
+
+        # Define data variables using minibatches
+        ann_input = pm.Data("ann_input", minibatch_x, mutable=True, dims=("obs_id", "train_cols"))
+        ann_output = pm.Data("ann_output", minibatch_y, mutable=True, dims="obs_id")
 
         # Weights from input to hidden layer
         weights_in_1 = pm.Normal(
@@ -155,13 +160,14 @@ def construct_nn(ann_input, ann_output):
             "out",
             act_out,
             observed=ann_output,
-            total_size=Y_train.shape[0],  # IMPORTANT for minibatches
+            total_size=X_train.shape[0],  # IMPORTANT for minibatches
             dims="obs_id",
         )
     return neural_network
 
 
-neural_network = construct_nn(X_train, Y_train)
+# Create the neural network model
+neural_network = construct_nn()
 ```
 
 That's not so bad. The `Normal` priors help regularize the weights. Usually we would add a constant `b` to the inputs but I omitted it here to keep the code cleaner.