
Commit 9045b9f

Authored by mmcky, HumphreyYang, jstac, and claude
ENH: Enable RunsOn GPU support for lecture builds (#437)
* Enable RunsOn GPU support for lecture builds
  - Add scripts/test-jax-install.py to verify JAX/GPU installation
  - Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration
  - Update cache.yml to use RunsOn g4dn.2xlarge GPU runner
  - Update ci.yml to use RunsOn g4dn.2xlarge GPU runner
  - Update publish.yml to use RunsOn g4dn.2xlarge GPU runner
  - Install JAX with CUDA 13 support and Numpyro on all workflows
  - Add nvidia-smi check to verify GPU availability
  This mirrors the setup used in lecture-python.myst repository.
* DOC: Update JAX lectures with GPU admonition and narrative
  - Add standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
  - Update introduction in jax_intro.md to reflect GPU access
  - Update conditional GPU language to reflect lectures now run on GPU
  - Following QuantEcon style guide for JAX lectures
* DEBUG: Add hardware benchmark script to diagnose performance
  - Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks
  - Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners
  - Include warm-up vs compiled timing to isolate JIT overhead
  - Add system info collection (CPU model, frequency, GPU detection)
* Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book)
* Fix: Add content to benchmark-jupyter.ipynb (was empty)
* Fix: Add benchmark content to benchmark-jupyter.ipynb
* Add JSON output to benchmarks and upload as artifacts
  - Update benchmark-hardware.py to save results to JSON
  - Update benchmark-jupyter.ipynb to save results to JSON
  - Update benchmark-jupyterbook.md to save results to JSON
  - Add CI step to collect and display benchmark results
  - Add CI step to upload benchmark results as artifact
* Fix syntax errors in benchmark-hardware.py
  - Remove extra triple quote at start of file
  - Remove stray parentheses in result assignments
* Sync benchmark scripts with CPU branch for comparable results
  - Copy benchmark-hardware.py from debug/benchmark-github-actions
  - Copy benchmark-jupyter.ipynb from debug/benchmark-github-actions
  - Copy benchmark-jupyterbook.md from debug/benchmark-github-actions
  - Update ci.yml to use matching file names
  The test scripts are now identical between both branches, only the CI workflow differs (runner type and JAX installation).
* ENH: Force lax.scan sequential operation to run on CPU
  Add device=cpu to the qm_jax function decorator to avoid the known XLA limitation where lax.scan with millions of lightweight iterations performs poorly on GPU due to CPU-GPU synchronization overhead. Added explanatory note about this pattern.
  Co-authored-by: HumphreyYang <[email protected]>
* update note
* Add lax.scan profiler to CI for GPU debugging
  - Add scripts/profile_lax_scan.py: Profiles lax.scan performance on GPU vs CPU to investigate the synchronization overhead issue (JAX Issue #2491)
  - Add CI step to run profiler with 100K iterations on RunsOn GPU environment
  - Script supports multiple profiling modes: basic timing, Nsight, JAX profiler, XLA dumps
* Add diagnostic mode to lax.scan profiler
  - Add --diagnose flag that tests time scaling across iteration counts
  - If time scales linearly with iterations (not compute), it proves constant per-iteration overhead (CPU-GPU synchronization)
  - Also add --verbose flag for CUDA/XLA logging
  - Update CI to run with --diagnose flag
* Add Nsight Systems profiling to CI
  - Run nsys profile with 1000 iterations if nsys is available
  - Captures CUDA, NVTX, and OS runtime traces
  - Uploads .nsys-rep file as artifact for visual analysis
  - continue-on-error: true so CI doesn't fail if nsys unavailable
* address @jstac comment
* Improve JAX lecture content and pedagogy
  - Reorganize jax_intro.md to introduce JAX features upfront with clearer structure
  - Expand JAX introduction with bulleted list of key capabilities (parallelization, JIT, autodiff)
  - Add explicit GPU performance notes in vmap sections
  - Enhance vmap explanation with detailed function composition breakdown
  - Clarify memory efficiency tradeoffs between different vmap approaches
  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude <[email protected]>
* Remove benchmark scripts (moved to QuantEcon/benchmarks)
  - Remove profile_lax_scan.py, benchmark-hardware.py, benchmark-jupyter.ipynb, benchmark-jupyterbook.md
  - Remove profiling/benchmarking steps from ci.yml
  - Keep test-jax-install.py for JAX installation verification
  Benchmark scripts are now maintained in: https://github.com/QuantEcon/benchmarks
* Update lectures/numpy_vs_numba_vs_jax.md
* Add GPU and JAX hardware details to status page
  - Add nvidia-smi output to show GPU availability
  - Add JAX backend check to confirm GPU usage
  - Matches format used in lecture-python.myst

---------

Co-authored-by: HumphreyYang <[email protected]>
Co-authored-by: Humphrey Yang <[email protected]>
Co-authored-by: John Stachurski <[email protected]>
Co-authored-by: Claude <[email protected]>
1 parent a4b89d4 commit 9045b9f

File tree: 7 files changed (+141, -42 lines)


.github/workflows/cache.yml

Lines changed: 11 additions & 1 deletion
@@ -6,7 +6,7 @@ on:
   workflow_dispatch:
 jobs:
   cache:
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - uses: actions/checkout@v6
       - name: Setup Anaconda
@@ -18,6 +18,16 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
+      - name: Check nvidia drivers
+        shell: bash -l {0}
+        run: |
+          nvidia-smi
       - name: Build HTML
         shell: bash -l {0}
         run: |

.github/workflows/ci.yml

Lines changed: 10 additions & 1 deletion
@@ -2,7 +2,7 @@ name: Build Project [using jupyter-book]
 on: [pull_request]
 jobs:
   preview:
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - uses: actions/checkout@v6
         with:
@@ -16,6 +16,15 @@ jobs:
           python-version: "3.13"
          environment-file: environment.yml
           activate-environment: quantecon
+      - name: Check nvidia Drivers
+        shell: bash -l {0}
+        run: nvidia-smi
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
       - name: Install latex dependencies
         run: |
           sudo apt-get -qq update

.github/workflows/publish.yml

Lines changed: 11 additions & 1 deletion
@@ -6,7 +6,7 @@ on:
 jobs:
   publish:
     if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - name: Checkout
         uses: actions/checkout@v6
@@ -21,6 +21,16 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
+      - name: Check nvidia drivers
+        shell: bash -l {0}
+        run: |
+          nvidia-smi
       - name: Install latex dependencies
         run: |
           sudo apt-get -qq update
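
All three workflows install the CUDA 13 build of JAX and then run `scripts/test-jax-install.py` (shown in full at the bottom of this diff), which only prints the device list. A stricter variant that fails the job when JAX cannot see the GPU might look like the following sketch; this is a hypothetical hardening, not part of the commit.

```python
# Hypothetical stricter install check (not part of this commit):
# exit non-zero when JAX cannot see a GPU, so the workflow fails fast.
import sys
import jax

devices = jax.devices()
print(f"JAX devices: {devices}")

if not any(d.platform == "gpu" for d in devices):
    sys.exit("No GPU visible to JAX -- check the CUDA 13 wheel and the runner image.")
```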

lectures/jax_intro.md

Lines changed: 26 additions & 26 deletions
@@ -13,6 +13,18 @@ kernelspec:
 
 # JAX
 
+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+
+JAX is a high-performance scientific computing library that provides
+
+* a NumPy-like interface that can automatically parallelize across CPUs and GPUs,
+* a just-in-time compiler for accelerating a large range of numerical
+  operations, and
+* automatic differentiation.
+
+Increasingly, JAX also maintains and provides more specialized scientific
+computing routines, such as those originally found in SciPy.
+
 In addition to what's in Anaconda, this lecture will need the following libraries:
 
 ```{code-cell} ipython3
@@ -21,28 +33,24 @@ In addition to what's in Anaconda, this lecture will need the following libraries:
 !pip install jax quantecon
 ```
 
-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
-
-Here we are focused on using JAX on the CPU, rather than on accelerators such as
-GPUs or TPUs.
-
-This means we will only see a small amount of the possible benefits from using JAX.
-
-However, JAX seamlessly handles transitions across different hardware platforms.
+```{admonition} GPU
+:class: warning
 
-As a result, if you run this code on a machine with a GPU and a GPU-aware
-version of JAX installed, your code will be automatically accelerated and you
-will receive the full benefits.
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and targets JAX for GPU programming.
 
-For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
 
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```
 
 ## JAX as a NumPy Replacement
 
-One of the attractive features of JAX is that, whenever possible, it conforms to
-the NumPy API for array operations.
+One of the attractive features of JAX is that, whenever possible, its array
+processing operations conform to the NumPy API.
 
-This means that, to a large extent, we can use JAX is as a drop-in NumPy replacement.
+This means that, in many cases, we can use JAX as a drop-in NumPy replacement.
 
 Let's look at the similarities and differences between JAX and NumPy.
 
@@ -523,16 +531,9 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-If you are running this on a GPU the code will run much faster than its NumPy
-equivalent, which ran on the CPU.
-
-Even if you are running on a machine with many CPUs, the second JAX run should
-be substantially faster with JAX.
-
-Also, typically, the second run is faster than the first.
+On a GPU, this code runs much faster than its NumPy equivalent.
 
-(This might not be noticable on the CPU but it should definitely be noticable on
-the GPU.)
+Also, typically, the second run is faster than the first due to JIT compilation.
 
 This is because even built in functions like `jnp.cos` are JIT-compiled --- and the
 first run includes compile time.
@@ -634,8 +635,7 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially if you
-use a GPU and especially on the second run.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
 
 Moreover, with JAX, we have another trick up our sleeve:
 
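
The rewritten passage above hinges on the point that the first run includes compile time while the second uses the cached kernel. A minimal timing sketch in the lecture's style is below; the array `x` and its size are illustrative assumptions, since the lecture defines its own `x` earlier in the file.

```python
import jax
import jax.numpy as jnp
import quantecon as qe

# Assumed test array -- the lecture defines its own x earlier in the file.
x = jnp.linspace(0, 10, 10_000_000)

with qe.Timer():          # first run: includes JIT compilation of jnp.cos
    y = jnp.cos(x)
    jax.block_until_ready(y)

with qe.Timer():          # second run: the compiled kernel only
    y = jnp.cos(x)
    jax.block_until_ready(y)
```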

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 48 additions & 13 deletions
@@ -48,6 +48,18 @@ tags: [hide-output]
 !pip install quantecon jax
 ```
 
+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and targets JAX for GPU programming.
+
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon top right, select Colab, and set the runtime environment to include a GPU.
+
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the `cpu` only, you can use `pip install jax[cpu]`.
+```
+
 We will use the following imports.
 
 ```{code-cell} ipython3
@@ -317,7 +329,7 @@ with qe.Timer(precision=8):
     z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
-Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
+Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.
 
 The compilation overhead is a one-time cost that pays off when the function is called repeatedly.
 
@@ -370,23 +382,29 @@ with qe.Timer(precision=8):
     z_max.block_until_ready()
 ```
 
-The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
-we are using far less memory.
+By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.
+
+When run on a CPU, its runtime is similar to that of the meshgrid version.
 
-In addition, `vmap` allows us to break vectorization up into stages, which is
-often easier to comprehend than the traditional approach.
+When run on a GPU, it is usually significantly faster.
 
-This will become more obvious when we tackle larger problems.
+In fact, using `vmap` has another advantage: It allows us to break vectorization up into stages.
+
+This leads to code that is often easier to comprehend than traditional vectorized code.
+
+We will investigate these ideas more when we tackle larger problems.
 
 
 ### vmap version 2
 
 We can be still more memory efficient using vmap.
 
-While we avoided large input arrays in the preceding version,
+While we avoid large input arrays in the preceding version,
 we still create the large output array `f(x,y)` before we compute the max.
 
-Let's use a slightly different approach that takes the max to the inside.
+Let's try a slightly different approach that takes the max to the inside.
+
+Because of this change, we never compute the two-dimensional array `f(x,y)`.
 
 ```{code-cell} ipython3
 @jax.jit
@@ -399,23 +417,28 @@ def compute_max_vmap_v2(grid):
     return jnp.max(f_vec_max(grid))
 ```
 
-Let's try it
+Here
+
+* `f_vec_x_max` computes the max along any given row
+* `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.
+
+We apply this function to all rows and then take the max of the row maxes.
+
+Let's try it.
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-
 Let's run it again to eliminate compilation time:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-We don't get much speed gain but we do save some memory.
-
+If you are running this on a GPU, as we are, you should see another nontrivial speed gain.
 
 
 ### Summary
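
For readers reconstructing `compute_max_vmap_v2` from the fragments above, here is a self-contained sketch consistent with the names used in the new prose (`f_vec_x_max`, `f_vec_max`). The test function `f` and the grid are placeholders; the lecture's own definitions appear earlier in the file, and its exact composition of the two `vmap` stages may differ.

```python
import jax
import jax.numpy as jnp

# Placeholder test function and grid -- the lecture defines its own earlier in the file.
def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

grid = jnp.linspace(-3, 3, 5_000)

@jax.jit
def compute_max_vmap_v2(grid):
    # Max of f(x, y) over all y for one fixed x (one "row").
    f_vec_x_max = lambda x: jnp.max(jax.vmap(lambda y: f(x, y))(grid))
    # Vectorized over x: computes every row max in parallel.
    f_vec_max = jax.vmap(f_vec_x_max)
    # Max of the row maxes -- the full 2D array f(x, y) is never materialized explicitly.
    return jnp.max(f_vec_max(grid))

z_max = compute_max_vmap_v2(grid).block_until_ready()
print(z_max)
```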
@@ -497,7 +520,9 @@ Now let's create a JAX version using `lax.scan`:
 from jax import lax
 from functools import partial
 
-@partial(jax.jit, static_argnums=(1,))
+cpu = jax.devices("cpu")[0]
+
+@partial(jax.jit, static_argnums=(1,), device=cpu)
 def qm_jax(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
@@ -509,6 +534,16 @@ def qm_jax(x0, n, α=4.0):
 
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array.
 
+```{note}
+Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.
+
+The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+
+As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+
+Curious readers can try removing this option to see how performance changes.
+```
+
 Let's time it with the same parameters:
 
 ```{code-cell} ipython3
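
The hunk above shows only the opening lines of `qm_jax`. A complete sketch of the CPU-pinned `lax.scan` pattern follows; the scan body is reconstructed from the surrounding prose ("`lax.scan` repeatedly calls `update` and accumulates the returns `x_new` into an array"), so details such as what is carried between iterations may differ from the lecture's version.

```python
import jax
import jax.numpy as jnp
from jax import lax
from functools import partial

cpu = jax.devices("cpu")[0]

# Reconstruction of the full function -- the diff shows only its first lines.
@partial(jax.jit, static_argnums=(1,), device=cpu)
def qm_jax(x0, n, α=4.0):
    def update(x, t):
        x_new = α * x * (1 - x)
        # Return (new carry, value to accumulate into the output array).
        return x_new, x_new
    _, x_seq = lax.scan(update, x0, jnp.arange(n))
    return x_seq

x = qm_jax(0.2, 10_000_000)
print(x[-1])
```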

lectures/status.md

Lines changed: 14 additions & 0 deletions
@@ -31,4 +31,18 @@ and the following package versions
 ```{code-cell} ipython
 :tags: [hide-output]
 !conda list
+```
+
+This lecture series has access to the following GPU:
+
+```{code-cell} ipython
+!nvidia-smi
+```
+
+You can check the backend used by JAX using:
+
+```{code-cell} ipython3
+import jax
+# Check if JAX is using GPU
+print(f"JAX backend: {jax.devices()[0].platform}")
 ```

scripts/test-jax-install.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import jax
+import jax.numpy as jnp
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")
+
+@jax.jit
+def matrix_multiply(a, b):
+    return jnp.dot(a, b)
+
+# Example usage:
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (1000, 1000))
+y = jax.random.normal(key, (1000, 1000))
+z = matrix_multiply(x, y)
+
+# Now the function is JIT compiled and will likely run on GPU (if available)
+print(z)
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")