Commit 2f103af
authored
FIX: jax_intro timeout: use lax.fori_loop instead of Python for loop (#442)
* Fix jax_intro timeout: use lax.fori_loop instead of Python for loop
The compute_call_price_jax function was timing out during cache.yml builds
because JAX unrolls Python for loops during JIT compilation. With large
arrays (M=10M), this causes excessive compilation time.
Solution: Replace Python for loop with jax.lax.fori_loop, which compiles
the loop efficiently without unrolling.
Fixes cell execution timeout in jax_intro.md
* style: use jstac's fori_loop naming conventions
- loop_body -> update
- state -> loop_state
- Added explicit new_loop_state and final_loop_state variables
- More verbose but clearer for first-time fori_loop readers
* style: loop_state -> initial_loop_state1 parent 39297ba commit 2f103af
1 file changed
+16
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
832 | 832 | | |
833 | 833 | | |
834 | 834 | | |
835 | | - | |
| 835 | + | |
| 836 | + | |
| 837 | + | |
836 | 838 | | |
837 | 839 | | |
838 | 840 | | |
839 | 841 | | |
| 842 | + | |
| 843 | + | |
| 844 | + | |
| 845 | + | |
| 846 | + | |
| 847 | + | |
| 848 | + | |
840 | 849 | | |
841 | 850 | | |
842 | 851 | | |
843 | 852 | | |
844 | 853 | | |
| 854 | + | |
| 855 | + | |
| 856 | + | |
| 857 | + | |
| 858 | + | |
| 859 | + | |
845 | 860 | | |
846 | 861 | | |
847 | 862 | | |
| |||
0 commit comments