From ec21970896304c0cda463d962419ea29ebff110b Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 29 Aug 2023 19:58:53 +0100 Subject: [PATCH] Typo fix --- examples/stateful.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/stateful.ipynb b/examples/stateful.ipynb index 9e93ec4e..38044b8c 100644 --- a/examples/stateful.ipynb +++ b/examples/stateful.ipynb @@ -79,8 +79,7 @@ "def compute_loss(model, state, xs, ys):\n", " # The `axis_name` argument is needed specifically for `BatchNorm`: so it knows\n", " # what axis to compute batch statistics over.\n", - " # The `in_axes` and `out_axes` are needed with all stateful operations, so that\n", - " # `ctx` isn't batched.\n", + " # The `in_axes` and `out_axes` are needed so that `state` isn't batched.\n", " batch_model = jax.vmap(\n", " model, axis_name=\"batch\", in_axes=(0, None), out_axes=(0, None)\n", " )\n",