Commit 9171836

jstac and claude authored
Improve NumPy vs Numba vs JAX lecture clarity and output formatting (#446)
- Add formatted output with .6f precision for all results
- Simplify race condition explanation, remove verbose details
- Improve code flow with better print statement placement
- Enhance readability of parallel Numba examples
- Clarify JAX vmap benefits and use cases
- Remove redundant multithreading note from NumPy section
- Strengthen conclusion about tool trade-offs

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <[email protected]>
1 parent 2f103af commit 9171836

1 file changed (+44, -40 lines)


lectures/numpy_vs_numba_vs_jax.md

Lines changed: 44 additions & 40 deletions
@@ -155,22 +155,13 @@ x, y = np.meshgrid(grid, grid)
with qe.Timer(precision=8):
    z_max_numpy = np.max(f(x, y))

-print(f"NumPy result: {z_max_numpy}")
+print(f"NumPy result: {z_max_numpy:.6f}")
```

In the vectorized version, all the looping takes place in compiled code.

Moreover, NumPy uses implicit multithreading, so that at least some parallelization occurs.

-```{note}
-If you have a system monitor such as htop (Linux/Mac) or perfmon
-(Windows), then try running this and then observing the load on your CPUs.
-
-(You will probably need to bump up the grid size to see large effects.)
-
-The output typically shows that the operation is successfully distributed across multiple threads.
-```
-
(The parallelization cannot be highly efficient because the binary is compiled
before it sees the size of the arrays `x` and `y`.)
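
For context, the pattern this hunk times looks roughly as follows as a standalone script. This is only a sketch: the bivariate function `f` and the use of `time.perf_counter` in place of `qe.Timer` are assumptions, since neither definition appears in this diff.

```python
# Minimal standalone sketch of the vectorized NumPy pattern timed above.
# The bivariate function below is a stand-in, not the lecture's `f`.
import time
import numpy as np

def f(x, y):
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

grid = np.linspace(-3, 3, 3_000)
x, y = np.meshgrid(grid, grid)

start = time.perf_counter()
z_max_numpy = np.max(f(x, y))   # all looping happens inside compiled NumPy routines
elapsed = time.perf_counter() - start

print(f"NumPy result: {z_max_numpy:.6f}  ({elapsed:.8f} seconds)")
```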

@@ -195,15 +186,18 @@ def compute_max_numba(grid):
grid = np.linspace(-3, 3, 3_000)

with qe.Timer(precision=8):
-    compute_max_numba(grid)
+    z_max_numpy = compute_max_numba(grid)
+
+print(f"Numba result: {z_max_numpy:.6f}")
```

+Let's run again to eliminate compile time.
+
```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba(grid)
```

-
Depending on your machine, the Numba version can be a bit slower or a bit faster
than NumPy.
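
The reason the cell is run twice is that the first call to a Numba-jitted function includes compilation time. A tiny self-contained illustration of that effect is below; it uses a trivial stand-in function, not the lecture's `compute_max_numba`, whose body is not shown in this diff.

```python
# Illustration of JIT warm-up: the first call compiles, later calls do not.
import time
import numpy as np
import numba

@numba.jit(nopython=True)
def total(grid):
    s = 0.0
    for x in grid:
        s += x
    return s

grid = np.linspace(-3, 3, 3_000)

for label in ("first call (includes compilation)", "second call (compiled)"):
    start = time.perf_counter()
    total(grid)
    print(f"{label}: {time.perf_counter() - start:.8f} seconds")
```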

@@ -240,17 +234,14 @@ Usually this returns an incorrect result:

```{code-cell} ipython3
z_max_parallel_incorrect = compute_max_numba_parallel(grid)
-print(f"Incorrect parallel Numba result: {z_max_parallel_incorrect}")
-print(f"NumPy result: {z_max_numpy}")
+print(f"Numba result: {z_max_parallel_incorrect} 😱")
```

-The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.
-
-The reason is that the variable $m$ is shared across threads and not properly controlled.
+The reason is that the variable `m` is shared across threads and not properly controlled.

-When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.
+When multiple threads try to read and write `m` simultaneously, they interfere with each other.

-This results in lost updates—threads read stale values of `m` or overwrite each other's updates—and the variable often never gets updated from its initial value of `-inf`.
+Threads read stale values of `m` or overwrite each other's updates --- or `m` never gets updated from its initial value.

Here's a more carefully written version.
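
As a rough reconstruction of the two variants discussed here: the unsafe version updates one shared scalar inside `numba.prange`, while the safer version has each iteration write only to its own slot of `row_maxes`. This is a hedged sketch under assumed function bodies, since the lecture's actual code is not shown in this diff, and the objective function is a stand-in.

```python
# Hedged sketch of the unsafe vs. safe parallel patterns described above.
import numpy as np
import numba

@numba.jit(nopython=True, parallel=True)
def max_unsafe(grid):
    n = len(grid)
    m = -np.inf
    for i in numba.prange(n):      # all threads read and write the same `m`
        for j in range(n):
            z = np.cos(grid[i]**2 + grid[j]**2) / (1 + grid[i]**2 + grid[j]**2)
            if z > m:              # stale reads / lost updates are possible here
                m = z
    return m

@numba.jit(nopython=True, parallel=True)
def max_safe(grid):
    n = len(grid)
    row_maxes = np.empty(n)
    for i in numba.prange(n):
        m = -np.inf                # thread-local accumulator
        for j in range(n):
            z = np.cos(grid[i]**2 + grid[j]**2) / (1 + grid[i]**2 + grid[j]**2)
            if z > m:
                m = z
        row_maxes[i] = m           # each iteration owns its own slot
    return np.max(row_maxes)

grid = np.linspace(-3, 3, 3_000)
print(f"unsafe: {max_unsafe(grid)}")     # often wrong, e.g. -inf
print(f"safe:   {max_safe(grid):.6f}")
```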

@@ -274,30 +265,31 @@ def compute_max_numba_parallel(grid):
Now the code block that `for i in numba.prange(n)` acts over is independent
across `i`.

-Each thread writes to a separate element of the array `row_maxes`.
-
-Hence the parallelization is safe.
-
-Here's the timings.
+Each thread writes to a separate element of the array `row_maxes` and
+the parallelization is safe.

```{code-cell} ipython3
-with qe.Timer(precision=8):
-    compute_max_numba_parallel(grid)
+z_max_parallel = compute_max_numba_parallel(grid)
+print(f"Numba result: {z_max_parallel:.6f}")
```

+Here's the timing.
+
```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
```

-If you have multiple cores, you should see at least some benefits from parallelization here.
+If you have multiple cores, you should see at least some benefits from
+parallelization here.

-For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU.
+For more powerful machines and larger grid sizes, parallelization can generate
+major speed gains, even on the CPU.


### Vectorized code with JAX

-In most ways, vectorization is the same in JAX as it is in NumPy.
+On the surface, vectorized code in JAX is similar to NumPy code.

But there are also some differences, which we highlight here.

@@ -319,14 +311,18 @@ grid = jnp.linspace(-3, 3, 3_000)
x_mesh, y_mesh = np.meshgrid(grid, grid)

with qe.Timer(precision=8):
-    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh))
+    z_max.block_until_ready()
+
+print(f"Plain vanilla JAX result: {z_max:.6f}")
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
with qe.Timer(precision=8):
-    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh))
+    z_max.block_until_ready()
```

Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.
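
The reason `block_until_ready()` stays inside the timed block is that JAX dispatches work to the device asynchronously, so without it the timer can stop before the computation has finished. Here is a minimal sketch of the same idea, using `time.perf_counter` and a stand-in function in place of the lecture's `qe.Timer` and `f` (assumptions, since neither appears in this diff).

```python
# Timing JAX correctly: wait for the device before stopping the clock.
import time
import jax.numpy as jnp

def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)   # stand-in function

grid = jnp.linspace(-3, 3, 3_000)
x_mesh, y_mesh = jnp.meshgrid(grid, grid)

start = time.perf_counter()
z_max = jnp.max(f(x_mesh, y_mesh))
z_max.block_until_ready()        # ensure the result is actually computed
elapsed = time.perf_counter() - start

print(f"JAX result: {z_max:.6f}  ({elapsed:.8f} seconds)")
```
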
@@ -374,6 +370,8 @@ Let's see the timing:
with qe.Timer(precision=8):
    z_max = jnp.max(f_vec(grid))
    z_max.block_until_ready()
+
+print(f"JAX vmap v1 result: {z_max:.6f}")
```

```{code-cell} ipython3
@@ -429,6 +427,8 @@ Let's try it.
```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = compute_max_vmap_v2(grid).block_until_ready()
+
+print(f"JAX vmap v1 result: {z_max:.6f}")
```

Let's run it again to eliminate compilation time:
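
The `vmap`-based functions referred to here (`f_vec`, `compute_max_vmap_v2`) are defined elsewhere in the lecture and are not shown in this diff. As a hedged sketch of the general idea, a double `jax.vmap` evaluates a scalar bivariate function over all grid pairs without explicitly building the meshgrid inputs; the names below are illustrative only.

```python
# Hedged sketch of the vmap approach: map a scalar function over all (x, y) pairs.
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)   # stand-in function

@jax.jit
def compute_max_vmap(grid):
    f_in_y = jax.vmap(f, in_axes=(None, 0))        # vary y, hold x fixed
    f_in_xy = jax.vmap(f_in_y, in_axes=(0, None))  # then vary x -> (n, n) values
    return jnp.max(f_in_xy(grid, grid))

grid = jnp.linspace(-3, 3, 3_000)
z_max = compute_max_vmap(grid).block_until_ready()
print(f"JAX vmap result: {z_max:.6f}")
```
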
@@ -445,19 +445,19 @@ If you are running this on a GPU, as we are, you should see another nontrivial s

In our view, JAX is the winner for vectorized operations.

-It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap).
+It dominates NumPy both in terms of speed (via JIT-compilation and
+parallelization) and memory efficiency (via vmap).

Moreover, the `vmap` approach can sometimes lead to significantly clearer code.

While Numba is impressive, the beauty of JAX is that, with fully vectorized
-operations, we can run exactly the
-same code on machines with hardware accelerators and reap all the benefits
-without extra effort.
+operations, we can run exactly the same code on machines with hardware
+accelerators and reap all the benefits without extra effort.

Moreover, JAX already knows how to effectively parallelize many common array
operations, which is key to fast execution.

-For almost all cases encountered in economics, econometrics, and finance, it is
+For most cases encountered in economics, econometrics, and finance, it is
far better to hand over to the JAX compiler for efficient parallelization than to
try to hand code these routines ourselves.

@@ -537,9 +537,11 @@ This code is not easy to read but, in essence, `lax.scan` repeatedly calls `upda
```{note}
Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.

-The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+The computation consists of many small sequential operations, leaving little
+opportunity for the GPU to exploit parallelism.

-As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU
+a better fit for this workload.

Curious readers can try removing this option to see how performance changes.
```
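
The `lax.scan` call discussed in this note runs a step function sequentially, carrying the state forward one iteration at a time, which is why there is little for a GPU to parallelize. Below is a hedged sketch of that shape: the quadratic-map recursion, the names, and the use of `jax.default_device` to keep the work on the CPU are assumptions (the note says the lecture passes a device option to `jax.jit` instead, and the actual `update`/`qm_jax` bodies are not shown in this diff).

```python
# Hedged sketch of a sequential recursion driven by lax.scan, kept on the CPU.
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=1)
def qm_scan(x0, n):
    def update(x, _):
        x_next = 4.0 * x * (1.0 - x)   # assumed stand-in recursion
        return x_next, x_next          # carry forward, and record the value
    _, path = lax.scan(update, x0, None, length=n)
    return path

cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):          # one way to keep this small sequential job on the CPU
    x = qm_scan(0.1, 10_000).block_until_ready()

print(f"final value: {x[-1]:.6f}")
```
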
@@ -558,16 +560,17 @@ with qe.Timer(precision=8):
    x_jax = qm_jax(0.1, n).block_until_ready()
```

-JAX is also efficient for this sequential operation.
+JAX is also quite efficient for this sequential operation.

Both JAX and Numba deliver strong performance after compilation, with Numba
typically (but not always) offering slightly better speeds on purely sequential
operations.

+
### Summary

While both Numba and JAX deliver strong performance for sequential operations,
-there are significant differences in code readability and ease of use.
+*there are significant differences in code readability and ease of use*.

The Numba version is straightforward and natural to read: we simply allocate an
array and fill it element by element using a standard Python loop.
@@ -580,3 +583,4 @@ Additionally, JAX's immutable arrays mean we cannot simply update array elements

For this type of sequential operation, Numba is the clear winner in terms of
code clarity and ease of implementation, as well as high performance.
+

0 commit comments
