Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ Each starter file must have exactly one comment describing the parameters, place
| CUDA | `// <params> are device pointers` |
| Mojo | `# <params> are device pointers` |
| PyTorch, Triton, CuTe | `# <params> are tensors on the GPU` |
| JAX | `# <params> are tensors on GPU` (+ `# return output tensor directly` inside body) |
| JAX | `# <params> are tensors on device` (+ `# return output tensor directly` inside body) |

**Rules:**
- Easy challenges: include the parenthetical `(i.e. pointers to memory on the GPU)` for CUDA/Mojo (matches vector_add reference)
Expand Down Expand Up @@ -200,7 +200,7 @@ Verify every item before submitting. This is the single source of truth — work
- [ ] All 6 files present: `.cu`, `.pytorch.py`, `.triton.py`, `.jax.py`, `.cute.py`, `.mojo`
- [ ] Exactly 1 parameter description comment per file, no other comments
- [ ] CUDA/Mojo use "device pointers"; easy challenges include `(i.e. pointers to memory on the GPU)`, medium/hard omit it
- [ ] Python frameworks use "tensors on the GPU"; JAX also has `# return output tensor directly`
- [ ] PyTorch/Triton/CuTe use "tensors on the GPU"; JAX uses "tensors on device" and also has `# return output tensor directly`
- [ ] Starters compile/run but do NOT produce correct output

### General
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/19_reverse_array/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/1_vector_add/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A, B are tensors on GPU
# A, B are tensors on device
@jax.jit
def solve(A: jax.Array, B: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/21_relu/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/23_leaky_relu/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/24_rainbow_table/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def fnv1a_hash(x: jax.Array) -> jax.Array:
return hash_val


# input is a tensor on the GPU
# input is a tensor on device
def solve(input: jax.Array, N: int, R: int) -> jax.Array:
# return output tensor directly
pass
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A, B are tensors on GPU
# A, B are tensors on device
@jax.jit
def solve(A: jax.Array, B: jax.Array, M: int, N: int, K: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/31_matrix_copy/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A is a tensor on the GPU
# A is a tensor on device
@jax.jit
def solve(A: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, rows: int, cols: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input and model are on the GPU
# input and model are on device
def solve(input: jax.Array, model) -> jax.Array:
# return output tensor directly
pass
2 changes: 1 addition & 1 deletion challenges/easy/52_silu/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/54_swiglu/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/62_value_clipping/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, lo: float, hi: float, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/63_interleave/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A, B are tensors on the GPU
# A, B are tensors on device
@jax.jit
def solve(A: jax.Array, B: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/65_geglu/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, width: int, height: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/68_sigmoid/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# X is a tensor on GPU
# X is a tensor on device
@jax.jit
def solve(X: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/7_color_inversion/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# image is a tensor on the GPU
# image is a tensor on device
@jax.jit
def solve(image: jax.Array, width: int, height: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/8_matrix_addition/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A, B are tensors on the GPU
# A, B are tensors on device
@jax.jit
def solve(A: jax.Array, B: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/easy/9_1d_convolution/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input, kernel are tensors on the GPU
# input, kernel are tensors on device
@jax.jit
def solve(input: jax.Array, kernel: jax.Array, input_size: int, kernel_size: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# Q, K, V are tensors on the GPU
# Q, K, V are tensors on device
@jax.jit
def solve(Q: jax.Array, K: jax.Array, V: jax.Array, N: int, d_model: int, h: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# agents is a tensor on the GPU
# agents is a tensor on device
@jax.jit
def solve(agents: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/hard/15_sorting/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# data is a tensor on the GPU
# data is a tensor on device
@jax.jit
def solve(data: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# data_x, data_y, initial_centroid_x, initial_centroid_y are tensors on the GPU
# data_x, data_y, initial_centroid_x, initial_centroid_y are tensors on device
@jax.jit
def solve(
data_x: jax.Array,
Expand Down
2 changes: 1 addition & 1 deletion challenges/hard/36_radix_sort/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# signal is a tensor on GPU
# signal is a tensor on device
@jax.jit
def solve(signal: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# grid is a tensor on the GPU
# grid is a tensor on device
@jax.jit
def solve(
grid: jax.Array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# Q, K, V are tensors on the GPU
# Q, K, V are tensors on device
@jax.jit
def solve(Q: jax.Array, K: jax.Array, V: jax.Array, M: int, d: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# Q, K, V are tensors on the GPU
# Q, K, V are tensors on device
@jax.jit
def solve(Q: jax.Array, K: jax.Array, V: jax.Array, M: int, d: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# Q, K, V are tensors on the GPU
# Q, K, V are tensors on device
@jax.jit
def solve(Q: jax.Array, K: jax.Array, V: jax.Array, M: int, d: int, window_size: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# dist is a tensor on the GPU
# dist is a tensor on device
@jax.jit
def solve(dist: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/hard/74_gpt2_block/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# x, weights are tensors on GPU
# x, weights are tensors on device
@jax.jit
def solve(x: jax.Array, weights: jax.Array, seq_len: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# x, weights, cos, sin are tensors on GPU
# x, weights, cos, sin are tensors on device
@jax.jit
def solve(
x: jax.Array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input, kernel are tensors on the GPU
# input, kernel are tensors on device
@jax.jit
def solve(
input: jax.Array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input, kernel are tensors on the GPU
# input, kernel are tensors on device
@jax.jit
def solve(
input: jax.Array,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int, num_bins: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/medium/16_prefix_sum/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input is a tensor on the GPU
# input is a tensor on device
@jax.jit
def solve(input: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/medium/17_dot_product/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A, B are tensors on the GPU
# A, B are tensors on device
@jax.jit
def solve(A: jax.Array, B: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A, x are tensors on the GPU
# A, x are tensors on device
@jax.jit
def solve(A: jax.Array, x: jax.Array, M: int, N: int, nnz: int) -> jax.Array:
# return output tensor directly
Expand Down
2 changes: 1 addition & 1 deletion challenges/medium/22_gemm/starter/starter.jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# A, B are tensors on the GPU
# A, B are tensors on device
@jax.jit
def solve(
A: jax.Array, B: jax.Array, M: int, N: int, K: int, alpha: float, beta: float
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# logits, true_labels are tensors on the GPU
# logits, true_labels are tensors on device
@jax.jit
def solve(logits: jax.Array, true_labels: jax.Array, N: int, C: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# predictions, targets are tensors on the GPU
# predictions, targets are tensors on device
@jax.jit
def solve(predictions: jax.Array, targets: jax.Array, N: int) -> jax.Array:
# return output tensor directly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp


# input, kernel are tensors on the GPU
# input, kernel are tensors on device
@jax.jit
def solve(
input: jax.Array,
Expand Down
Loading
Loading