The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.
- The reason is that the variable $m$ is shared across threads and not properly controlled.
+ The reason is that the variable `m` is shared across threads and not properly controlled.
- When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.
+ When multiple threads try to read and write `m` simultaneously, they interfere with each other.
- This results in lost updates—threads read stale values of `m` or overwrite each other's updates—and the variable often never gets updated from its initial value of `-inf`.
+ Threads read stale values of `m` or overwrite each other's updates, or `m` never gets updated from its initial value.
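To make the failure mode concrete, here is a minimal Numba sketch of the broken pattern alongside one safe alternative. The function names and the chunked reduction are our own illustration, not code from the lecture:

```python
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def max_racy(x):
    # BROKEN: `m` is shared by all threads with no synchronization,
    # so updates are lost and `m` can remain at -inf.
    m = -np.inf
    for i in prange(len(x)):
        if x[i] > m:
            m = x[i]
    return m

@njit(parallel=True)
def max_chunked(x, n_chunks=8):
    # Safe: each parallel iteration reduces over a disjoint chunk and
    # writes to its own slot, so threads never touch the same memory.
    partial = np.full(n_chunks, -np.inf)
    chunk = (len(x) + n_chunks - 1) // n_chunks
    for j in prange(n_chunks):
        for i in range(j * chunk, min((j + 1) * chunk, len(x))):
            if x[i] > partial[j]:
                partial[j] = x[i]
    return partial.max()
```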
@@ -445,19 +445,19 @@ If you are running this on a GPU, as we are, you should see another nontrivial s
In our view, JAX is the winner for vectorized operations.
- It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap).
+ It dominates NumPy both in terms of speed (via JIT-compilation and
+ parallelization) and memory efficiency (via vmap).
Moreover, the `vmap` approach can sometimes lead to significantly clearer code.
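As a rough illustration of that point, a scalar function can be lifted onto a full two-dimensional grid with two nested `vmap` calls. The function `f` and grid below are hypothetical stand-ins, not necessarily the lecture's code:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # a hypothetical scalar-valued test function
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

x = jnp.linspace(-3, 3, 1_000)

# The inner vmap maps over y for fixed x; the outer vmap maps over x.
# jit then compiles the full 1000 x 1000 evaluation into one fused kernel.
f_mesh = jax.jit(jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None)))

z = f_mesh(x, x)   # shape (1000, 1000)
print(z.max())
```

No explicit meshgrid or preallocated output array is needed, which is one sense in which the `vmap` version can read more clearly than a NumPy counterpart.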
While Numba is impressive, the beauty of JAX is that, with fully vectorized
- operations, we can run exactly the
- same code on machines with hardware accelerators and reap all the benefits
- without extra effort.
+ operations, we can run exactly the same code on machines with hardware
+ accelerators and reap all the benefits without extra effort.
Moreover, JAX already knows how to effectively parallelize many common array
operations, which is key to fast execution.
- For almost all cases encountered in economics, econometrics, and finance, it is
+ For most cases encountered in economics, econometrics, and finance, it is
far better to hand over to the JAX compiler for efficient parallelization than to
try to hand code these routines ourselves.
@@ -537,9 +537,11 @@ This code is not easy to read but, in essence, `lax.scan` repeatedly calls `upda
```{note}
Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.
- The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+ The computation consists of many small sequential operations, leaving little
+ opportunity for the GPU to exploit parallelism.
- As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+ As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU
+ a better fit for this workload.
Curious readers can try removing this option to see how performance changes.
```
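For readers who want the shape of the pattern, a `lax.scan` version of the quadratic map might look like the sketch below. The update rule $x_{t+1} = 4 x_t (1 - x_t)$, the function name `qm_scan`, and the exact signature are our assumptions based on the surrounding discussion, not the lecture's verbatim code:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(1,))   # n fixes the length of the scan
def qm_scan(x0, n):
    def update(x, _):
        x_next = 4.0 * x * (1.0 - x)
        return x_next, x_next            # new carry, plus the value to record
    _, path = jax.lax.scan(update, x0, xs=None, length=n)
    return jnp.concatenate((jnp.array([x0]), path))
```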
@@ -558,16 +560,17 @@ with qe.Timer(precision=8):
x_jax = qm_jax(0.1, n).block_until_ready()
```
- JAX is also efficient for this sequential operation.
+ JAX is also quite efficient for this sequential operation.
Both JAX and Numba deliver strong performance after compilation, with Numba
typically (but not always) offering slightly better speeds on purely sequential
operations.
+
### Summary
While both Numba and JAX deliver strong performance for sequential operations,
- there are significant differences in code readability and ease of use.
+ *there are significant differences in code readability and ease of use*.
The Numba version is straightforward and natural to read: we simply allocate an
array and fill it element by element using a standard Python loop.
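Under the same assumed update rule, that Numba version is just a few lines (a sketch, not necessarily the lecture's exact code):

```python
import numpy as np
from numba import njit

@njit
def qm(x0, n):
    # allocate the output array, then fill it with an ordinary loop
    x = np.empty(n + 1)
    x[0] = x0
    for t in range(n):
        x[t + 1] = 4.0 * x[t] * (1.0 - x[t])
    return x
```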
@@ -580,3 +583,4 @@ Additionally, JAX's immutable arrays mean we cannot simply update array elements
For this type of sequential operation, Numba is the clear winner in terms of
code clarity and ease of implementation, as well as high performance.