ENH: Enable RunsOn GPU support for lecture builds (#437)
* Enable RunsOn GPU support for lecture builds
- Add scripts/test-jax-install.py to verify JAX/GPU installation
- Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration
- Update cache.yml to use RunsOn g4dn.2xlarge GPU runner
- Update ci.yml to use RunsOn g4dn.2xlarge GPU runner
- Update publish.yml to use RunsOn g4dn.2xlarge GPU runner
- Install JAX with CUDA 13 support and Numpyro on all workflows
- Add nvidia-smi check to verify GPU availability
This mirrors the setup used in lecture-python.myst repository.
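A minimal sketch of the kind of check a script like scripts/test-jax-install.py might perform; the function name and output format here are illustrative, not the script's actual contents:

```python
# Hypothetical JAX installation check (illustrative, not the real script).
import jax
import jax.numpy as jnp

def check_jax_install():
    # Report which backend JAX selected (gpu when CUDA is available, else cpu).
    backend = jax.default_backend()
    devices = jax.devices()
    # Run a small computation to confirm the backend actually executes code.
    x = jnp.arange(10.0)
    result = float(jnp.sum(x ** 2))
    return backend, devices, result

backend, devices, result = check_jax_install()
print(f"backend={backend}, devices={devices}, sum of squares={result}")
```

On a RunsOn GPU runner this would report a `gpu` backend; on a plain GitHub Actions runner it falls back to `cpu`.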
* DOC: Update JAX lectures with GPU admonition and narrative
- Add standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
- Update introduction in jax_intro.md to reflect GPU access
- Update conditional GPU language to reflect lectures now run on GPU
- Following QuantEcon style guide for JAX lectures
* DEBUG: Add hardware benchmark script to diagnose performance
- Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks
- Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners
- Include warm-up vs compiled timing to isolate JIT overhead
- Add system info collection (CPU model, frequency, GPU detection)
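The warm-up vs compiled split mentioned above can be sketched with the standard JAX timing pattern around `block_until_ready`; the function `f` and array size are illustrative:

```python
# Sketch: separate JIT compile time from steady-state run time.
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.ones((1000, 1000))

# First call triggers tracing and XLA compilation.
t0 = time.perf_counter()
f(x).block_until_ready()
warmup = time.perf_counter() - t0

# Subsequent calls reuse the compiled executable.
t0 = time.perf_counter()
f(x).block_until_ready()
compiled = time.perf_counter() - t0

print(f"warm-up (incl. compile): {warmup:.4f}s, compiled: {compiled:.4f}s")
```

`block_until_ready` matters here: JAX dispatches asynchronously, so timing without it would measure dispatch, not execution.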
* Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book)
* Fix: Add content to benchmark-jupyter.ipynb (was empty)
* Fix: Add benchmark content to benchmark-jupyter.ipynb
* Add JSON output to benchmarks and upload as artifacts
- Update benchmark-hardware.py to save results to JSON
- Update benchmark-jupyter.ipynb to save results to JSON
- Update benchmark-jupyterbook.md to save results to JSON
- Add CI step to collect and display benchmark results
- Add CI step to upload benchmark results as artifact
* Fix syntax errors in benchmark-hardware.py
- Remove extra triple quote at start of file
- Remove stray parentheses in result assignments
* Sync benchmark scripts with CPU branch for comparable results
- Copy benchmark-hardware.py from debug/benchmark-github-actions
- Copy benchmark-jupyter.ipynb from debug/benchmark-github-actions
- Copy benchmark-jupyterbook.md from debug/benchmark-github-actions
- Update ci.yml to use matching file names
The test scripts are now identical across both branches; only the CI workflow differs (runner type and JAX installation).
* ENH: Force lax.scan sequential operation to run on CPU
Add device=cpu to the qm_jax function decorator to avoid the known
XLA limitation where lax.scan with millions of lightweight iterations
performs poorly on GPU due to CPU-GPU synchronization overhead.
Added explanatory note about this pattern.
Co-authored-by: HumphreyYang <[email protected]>
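The CPU-pinning pattern can be sketched roughly as follows. The update rule and names are stand-ins, not the lecture's actual qm_jax code, and `jax.default_device` is used here as a version-stable way to express the device=cpu placement the commit describes:

```python
# Sketch: run a sequential lax.scan on CPU to avoid per-iteration
# CPU-GPU synchronization overhead (illustrative stand-in for qm_jax).
import jax
from jax import lax

cpu = jax.devices("cpu")[0]

def qm(x0, n):
    def update(x, _):
        x_new = 4.0 * x * (1.0 - x)  # quadratic map step
        return x_new, x_new
    # lax.scan calls update sequentially n times, accumulating x_new.
    _, x_path = lax.scan(update, x0, None, length=n)
    return x_path

# length must be static for lax.scan, hence static_argnums=1.
qm_jit = jax.jit(qm, static_argnums=1)

# Pin placement to CPU; the commit expresses the same intent via a
# device=cpu argument in the jax.jit decorator.
with jax.default_device(cpu):
    traj = qm_jit(0.2, 5)
print(traj)
```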
* update note
* Add lax.scan profiler to CI for GPU debugging
- Add scripts/profile_lax_scan.py: Profiles lax.scan performance on GPU vs CPU
to investigate the synchronization overhead issue (JAX Issue #2491)
- Add CI step to run profiler with 100K iterations on RunsOn GPU environment
- Script supports multiple profiling modes: basic timing, Nsight, JAX profiler, XLA dumps
* Add diagnostic mode to lax.scan profiler
- Add --diagnose flag that tests time scaling across iteration counts
- If runtime scales linearly with the iteration count rather than with
  compute, that points to a constant per-iteration overhead (CPU-GPU synchronization)
- Also add --verbose flag for CUDA/XLA logging
- Update CI to run with --diagnose flag
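The linear-scaling diagnostic can be sketched like this; the iteration counts and scan body are illustrative:

```python
# Sketch: time lax.scan at several iteration counts and check whether
# wall time grows roughly linearly, suggesting constant per-step overhead.
import time
import jax
from jax import lax

def scan_n(n):
    def step(x, _):
        return 4.0 * x * (1.0 - x), None
    final, _ = lax.scan(step, 0.2, None, length=n)
    return final

fn = jax.jit(scan_n, static_argnums=0)

timings = {}
for n in (1_000, 10_000, 100_000):
    fn(n).block_until_ready()          # warm up / compile for this length
    t0 = time.perf_counter()
    fn(n).block_until_ready()
    timings[n] = time.perf_counter() - t0

for n, t in timings.items():
    print(f"n={n:>7}: {t * 1e3:.3f} ms")
```

If the per-step work is trivial, near-linear growth in these timings is consistent with a fixed synchronization cost per iteration rather than compute-bound scaling.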
* Add Nsight Systems profiling to CI
- Run nsys profile with 1000 iterations if nsys is available
- Captures CUDA, NVTX, and OS runtime traces
- Uploads .nsys-rep file as artifact for visual analysis
- continue-on-error: true so CI doesn't fail if nsys unavailable
* address @jstac comment
* Improve JAX lecture content and pedagogy
- Reorganize jax_intro.md to introduce JAX features upfront with clearer structure
- Expand JAX introduction with bulleted list of key capabilities (parallelization, JIT, autodiff)
- Add explicit GPU performance notes in vmap sections
- Enhance vmap explanation with detailed function composition breakdown
- Clarify memory efficiency tradeoffs between different vmap approaches
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
* Remove benchmark scripts (moved to QuantEcon/benchmarks)
- Remove profile_lax_scan.py, benchmark-hardware.py, benchmark-jupyter.ipynb, benchmark-jupyterbook.md
- Remove profiling/benchmarking steps from ci.yml
- Keep test-jax-install.py for JAX installation verification
Benchmark scripts are now maintained in: https://github.com/QuantEcon/benchmarks
* Update lectures/numpy_vs_numba_vs_jax.md
* Add GPU and JAX hardware details to status page
- Add nvidia-smi output to show GPU availability
- Add JAX backend check to confirm GPU usage
- Matches format used in lecture-python.myst
---------
Co-authored-by: HumphreyYang <[email protected]>
Co-authored-by: Humphrey Yang <[email protected]>
Co-authored-by: John Stachurski <[email protected]>
Co-authored-by: Claude <[email protected]>
This code is not easy to read, but in essence `lax.scan` repeatedly calls `update` and accumulates the returned `x_new` values into an array.
```{note}
Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.

The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.

As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.

Curious readers can try removing this option to see how performance changes.
```