Skip to content

Commit 2f103af

Browse files
authored
FIX: jax_intro timeout: use lax.fori_loop instead of Python for loop (#442)
* Fix jax_intro timeout: use lax.fori_loop instead of Python for loop The compute_call_price_jax function was timing out during cache.yml builds because JAX unrolls Python for loops during JIT compilation. With large arrays (M=10M), this causes excessive compilation time. Solution: Replace Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Fixes cell execution timeout in jax_intro.md * style: use jstac's fori_loop naming conventions - loop_body -> update - state -> loop_state - Added explicit new_loop_state and final_loop_state variables - More verbose but clearer for first-time fori_loop readers * style: loop_state -> initial_loop_state
1 parent 39297ba commit 2f103af

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

lectures/jax_intro.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,16 +832,31 @@ def compute_call_price_jax(β=β,
832832
833833
s = jnp.full(M, np.log(S0))
834834
h = jnp.full(M, h0)
835-
for t in range(n):
835+
836+
def update(i, loop_state):
837+
s, h, key = loop_state
836838
key, subkey = jax.random.split(key)
837839
Z = jax.random.normal(subkey, (2, M))
838840
s = s + μ + jnp.exp(h) * Z[0, :]
839841
h = ρ * h + ν * Z[1, :]
842+
new_loop_state = s, h, key
843+
return new_loop_state
844+
845+
initial_loop_state = s, h, key
846+
final_loop_state = jax.lax.fori_loop(0, n, update, initial_loop_state)
847+
s, h, key = final_loop_state
848+
840849
expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))
841850
842851
return β**n * expectation
843852
```
844853

854+
```{note}
855+
We use `jax.lax.fori_loop` instead of a Python `for` loop.
856+
This allows JAX to compile the loop efficiently without unrolling it,
857+
which significantly reduces compilation time for large arrays.
858+
```
859+
845860
Let's run it once to compile it:
846861

847862
```{code-cell} ipython3

0 commit comments

Comments
 (0)