
Add JAX reference implementations for CPU-slow challenges #278

Merged
kunal-mansukhani merged 2 commits into main from add-jax-reference-impls
Jun 7, 2026
@kunal-mansukhani kunal-mansukhani commented Jun 7, 2026

Contributor

No description provided.

The TPU runner uses CPU torch as ground truth, but reference_impl is
intractably slow on CPU for large matmuls, attention, transformer blocks,
convolutions, O(n^2) ops, and challenges with Python-level sequential
loops. The runner prefers a challenge-supplied reference_impl_jax when
present.
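
The preference described above can be sketched roughly as follows. This is an illustration only: `pick_reference` and the two stand-in challenge objects are hypothetical, and the real runner's lookup logic is not shown in this PR.

```python
from types import SimpleNamespace


def pick_reference(challenge):
    """Return (name, callable), preferring reference_impl_jax when it is defined.

    Hypothetical helper: the actual runner may select the reference differently.
    """
    jax_ref = getattr(challenge, "reference_impl_jax", None)
    if callable(jax_ref):
        return "reference_impl_jax", jax_ref
    return "reference_impl", challenge.reference_impl


# Stand-ins for two challenge objects (assumed shape, for illustration).
torch_only = SimpleNamespace(reference_impl=lambda x: x + 1)
with_jax = SimpleNamespace(
    reference_impl=lambda x: x + 1,
    reference_impl_jax=lambda x: x + 1,
)

print(pick_reference(torch_only)[0])
print(pick_reference(with_jax)[0])
```

Falling back to `reference_impl` keeps challenges without a JAX reference working unchanged.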

Add reference_impl_jax to 30 such challenges. Each mirrors reference_impl
exactly under the runner contract: it takes the in/inout params in
signature order as jax arrays, returns the out/inout results, and runs
fp32 matmuls at highest precision. Python sequential loops are replaced
with jax.lax.scan
(linear_recurrence, ssm_selective_scan, adder_transformer decode) and
jax.lax.fori_loop (all_pairs_shortest_paths), and nested .item() loops are
vectorized (categorical_cross_entropy, speculative_decoding_verification).
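
As a minimal sketch of the loop-to-scan rewrite, assume a recurrence of the form h_t = a_t * h_{t-1} + x_t with h_0 = 0. That form is an assumption for illustration; the actual linear_recurrence challenge may define its update differently. The function names here are hypothetical, while `jax.lax.scan` and `jax.lax.Precision.HIGHEST` are standard JAX.

```python
import jax
import jax.numpy as jnp


def linear_recurrence_loop(a, x):
    """Python-level loop: one dispatch per time step, slow for long sequences."""
    h = jnp.zeros_like(x[0])
    out = []
    for t in range(x.shape[0]):
        h = a[t] * h + x[t]
        out.append(h)
    return jnp.stack(out)


def linear_recurrence_scan(a, x):
    """Same recurrence expressed with jax.lax.scan, traced as a single loop."""
    def step(h, inputs):
        a_t, x_t = inputs
        h = a_t * h + x_t
        return h, h  # (new carry, per-step output)

    h0 = jnp.zeros_like(x[0])
    _, hs = jax.lax.scan(step, h0, (a, x))
    return hs


def matmul_fp32(a, b):
    """fp32 matmul with the highest available precision requested."""
    return jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)


a = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)
x = jnp.array([1.0, 1.0, 1.0], dtype=jnp.float32)
print(linear_recurrence_scan(a, x))
```

Requesting `Precision.HIGHEST` matters on accelerators, where the default matmul precision may be lower than fp32.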

jax is imported lazily inside reference_impl_jax (not at module scope) so
challenge.py still imports and runs the torch reference_impl on images
where jax is not installed (e.g. the CUDA runner).
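
A lazy import of this kind might look like the sketch below. The matmul body and the `jax_available` helper are hypothetical and only illustrate the pattern; the point is that importing the module never touches jax, so only calling `reference_impl_jax` requires it.

```python
import importlib.util


def reference_impl_jax(a, b):
    """Illustrative JAX reference; jax is imported only when this is called."""
    import jax
    import jax.numpy as jnp

    return jnp.matmul(
        jnp.asarray(a), jnp.asarray(b), precision=jax.lax.Precision.HIGHEST
    )


def jax_available():
    """Hypothetical helper: report whether jax can be imported on this image."""
    return importlib.util.find_spec("jax") is not None


if jax_available():
    print(reference_impl_jax([[1.0, 2.0]], [[3.0], [4.0]]))
else:
    print("jax not installed; module still imports cleanly")
```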

Verified each against the torch reference on its functional tests at the
challenge's own atol/rtol.
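
For reference, an atol/rtol comparison is sketched below using the same criterion as torch.allclose, |actual - expected| <= atol + rtol * |expected|. The helper name and the flat-list inputs are assumptions for illustration; the verification actually performed for this PR is not shown.

```python
def within_tolerance(actual, expected, atol, rtol):
    """Elementwise check: |a - e| <= atol + rtol * |e| for every pair."""
    if len(actual) != len(expected):
        return False
    return all(
        abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected)
    )


# Hypothetical flattened outputs from the JAX and torch references.
jax_out = [1.0, 2.00001, -3.0]
torch_out = [1.0, 2.0, -3.0]
print(within_tolerance(jax_out, torch_out, atol=1e-4, rtol=1e-4))
```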

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@kunal-mansukhani kunal-mansukhani force-pushed the add-jax-reference-impls branch from 3566a2a to f3190f5 Compare June 7, 2026 04:16
Migrate generate_performance_test for the challenges that gained a JAX reference
from imperative torch construction to declarative specs (RandTensor/RandnTensor/
RandIntTensor/FullTensor/OutTensor); the runner materializes inputs on the
accelerator (jax on TPU) instead of generating them on CPU and copying them over.
18 challenges are migrated; structured-input challenges stay imperative.
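
The difference between imperative construction and declarative specs can be sketched as follows. `RandnSpec`, `OutSpec`, and `materialize` are hypothetical stand-ins and are not the repository's RandnTensor/OutTensor classes, whose fields are not shown here. NumPy stands in for the accelerator backend.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class RandnSpec:
    """Hypothetical spec: a normally distributed input of the given shape."""
    shape: tuple
    dtype: str = "float32"


@dataclass(frozen=True)
class OutSpec:
    """Hypothetical spec: an output buffer of the given shape."""
    shape: tuple
    dtype: str = "float32"


def generate_performance_test():
    """Declarative form: describe the tensors, do not allocate them here."""
    return {
        "A": RandnSpec((64, 32)),
        "B": RandnSpec((32, 16)),
        "C": OutSpec((64, 16)),
    }


def materialize(spec, rng):
    """Runner side: allocate on whichever backend is in use (NumPy here)."""
    if isinstance(spec, RandnSpec):
        return rng.standard_normal(spec.shape).astype(spec.dtype)
    if isinstance(spec, OutSpec):
        return np.zeros(spec.shape, dtype=spec.dtype)
    raise TypeError(f"unknown spec: {spec!r}")


rng = np.random.default_rng(0)
tensors = {k: materialize(v, rng) for k, v in generate_performance_test().items()}
print({k: v.shape for k, v in tensors.items()})
```

Because the challenge only describes shapes and distributions, the runner is free to generate the data directly on the device.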

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@kunal-mansukhani kunal-mansukhani force-pushed the add-jax-reference-impls branch from ca8d713 to 9b536e7 Compare June 7, 2026 06:37
@kunal-mansukhani kunal-mansukhani merged commit 47240a7 into main Jun 7, 2026
4 checks passed