diff --git a/challenges/core/challenge_base.py b/challenges/core/challenge_base.py index 97faebb0..201e087e 100644 --- a/challenges/core/challenge_base.py +++ b/challenges/core/challenge_base.py @@ -2,6 +2,54 @@ from typing import Any, Dict, List +class RandTensor: + """Uniform random input in [low, high).""" + + def __init__(self, shape, low=0.0, high=1.0, dtype="float32"): + self.shape = tuple(shape) + self.low = low + self.high = high + self.dtype = dtype + + +class RandnTensor: + """Normal (Gaussian) random input.""" + + def __init__(self, shape, mean=0.0, std=1.0, dtype="float32"): + self.shape = tuple(shape) + self.mean = mean + self.std = std + self.dtype = dtype + + +class RandIntTensor: + """Uniform integer random input in [low, high).""" + + def __init__(self, shape, low, high, dtype="int32"): + self.shape = tuple(shape) + self.low = low + self.high = high + self.dtype = dtype + + +class FullTensor: + """Constant-filled input (covers zeros / ones / full).""" + + def __init__(self, shape, value=0.0, dtype="float32"): + self.shape = tuple(shape) + self.value = value + self.dtype = dtype + + +class OutTensor: + """Output buffer (by shape): materialized empty where outputs are written in + place (torch), omitted where they're returned functionally (jax).""" + + def __init__(self, shape, dtype="float32"): + self.shape = tuple(shape) + self.dtype = dtype + + class ChallengeBase(ABC): name: str atol: float diff --git a/challenges/easy/2_matrix_multiplication/challenge.py b/challenges/easy/2_matrix_multiplication/challenge.py index 03fee9e2..6013af07 100644 --- a/challenges/easy/2_matrix_multiplication/challenge.py +++ b/challenges/easy/2_matrix_multiplication/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -23,6 +23,11 @@ def reference_impl( torch.matmul(A, B, out=C) + def reference_impl_jax(self, A, B, M, N, K): + import jax.numpy as jnp + + return jnp.matmul(A, B) + def get_solve_signature(self) -> Dict[str, tuple]: return { "A": (ctypes.POINTER(ctypes.c_float), "in"), @@ -126,12 +131,11 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return test_cases def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 M, N, K = 8192, 6144, 4096 return { - "A": torch.empty(M, N, device=self.device, dtype=dtype).uniform_(-10.0, 10.0), - "B": torch.empty(N, K, device=self.device, dtype=dtype).uniform_(-10.0, 10.0), - "C": torch.empty(M, K, device=self.device, dtype=dtype), + "A": RandTensor((M, N), -10.0, 10.0), + "B": RandTensor((N, K), -10.0, 10.0), + "C": OutTensor((M, K)), "M": M, "N": N, "K": K, diff --git a/challenges/easy/9_1d_convolution/challenge.py b/challenges/easy/9_1d_convolution/challenge.py index 2929b9cf..f3f059b2 100644 --- a/challenges/easy/9_1d_convolution/challenge.py +++ b/challenges/easy/9_1d_convolution/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -33,6 +33,22 @@ def reference_impl( # 'ij,j->i' means: for each window i, multiply with kernel j and sum over j output.copy_(torch.einsum("ij,j->i", windows, kernel)) + def reference_impl_jax(self, input, kernel, input_size, kernel_size): + import jax + + # Cross-correlation, valid padding, stride 1. + # Shapes: input (N=1, C=1, W), kernel (O=1, I=1, W). + lhs = input.reshape(1, 1, input_size) + rhs = kernel.reshape(1, 1, kernel_size) + result = jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=(1,), + padding="VALID", + precision=jax.lax.Precision.HIGHEST, + ) + return result.reshape(-1) + def get_solve_signature(self) -> Dict[str, tuple]: return { "input": (ctypes.POINTER(ctypes.c_float), "in"), @@ -128,13 +144,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return test_cases def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 input_size, kernel_size = 1500000, 2047 # Large convolution for performance testing output_size = input_size - kernel_size + 1 return { - "input": torch.empty(input_size, device=self.device, dtype=dtype).uniform_(-1.0, 1.0), - "kernel": torch.empty(kernel_size, device=self.device, dtype=dtype).uniform_(-1.0, 1.0), - "output": torch.empty(output_size, device=self.device, dtype=dtype), + "input": RandTensor((input_size,), -1.0, 1.0), + "kernel": RandTensor((kernel_size,), -1.0, 1.0), + "output": OutTensor((output_size,)), "input_size": input_size, "kernel_size": kernel_size, } diff --git a/challenges/hard/12_multi_head_attention/challenge.py b/challenges/hard/12_multi_head_attention/challenge.py index 103806d4..0cfd2ed6 100644 --- a/challenges/hard/12_multi_head_attention/challenge.py +++ b/challenges/hard/12_multi_head_attention/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -40,6 +40,22 @@ def reference_impl( result[:, head * d_k : (head + 1) * d_k] = head_output output.copy_(result) + def reference_impl_jax(self, Q, K, V, N, d_model, h): + import jax + import jax.numpy as jnp + + d_k = d_model // h + # Reshape (N, d_model) -> (N, h, d_k) -> (h, N, d_k) + Q_h = jnp.transpose(jnp.reshape(Q, (N, h, d_k)), (1, 0, 2)) + K_h = jnp.transpose(jnp.reshape(K, (N, h, d_k)), (1, 0, 2)) + V_h = jnp.transpose(jnp.reshape(V, (N, h, d_k)), (1, 0, 2)) + scores = jnp.matmul(Q_h, jnp.transpose(K_h, (0, 2, 1))) / (d_k**0.5) + softmax = jax.nn.softmax(scores, axis=-1) + head_output = jnp.matmul(softmax, V_h) # (h, N, d_k) + # (h, N, d_k) -> (N, h, d_k) -> (N, d_model) + result = jnp.reshape(jnp.transpose(head_output, (1, 0, 2)), (N, d_model)) + return result + def get_solve_signature(self) -> Dict[str, tuple]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"), @@ -111,16 +127,11 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return test_cases def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 - Q = torch.empty(1024, 1024, device=self.device, dtype=dtype).uniform_(-10.0, 10.0) - K = torch.empty(1024, 1024, device=self.device, dtype=dtype).uniform_(-10.0, 10.0) - V = torch.empty(1024, 1024, device=self.device, dtype=dtype).uniform_(-10.0, 10.0) - output = torch.zeros(1024, 1024, device=self.device, dtype=dtype) return { - "Q": Q, - "K": K, - "V": V, - "output": output, + "Q": RandTensor((1024, 1024), -10.0, 10.0), + "K": RandTensor((1024, 1024), -10.0, 10.0), + "V": RandTensor((1024, 1024), -10.0, 10.0), + "output": OutTensor((1024, 1024)), "N": 1024, "d_model": 1024, "h": 16, diff --git a/challenges/hard/14_multi_agent_sim/challenge.py b/challenges/hard/14_multi_agent_sim/challenge.py index dd4e538d..a730931c 100644 --- a/challenges/hard/14_multi_agent_sim/challenge.py +++ b/challenges/hard/14_multi_agent_sim/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -39,6 +39,37 @@ def reference_impl(self, agents: torch.Tensor, agents_next: torch.Tensor, N: int agents_next_reshaped[:] = torch.cat([new_positions, new_velocities], dim=1) agents_next.copy_(agents_next_reshaped.view(-1)) + def reference_impl_jax(self, agents, N): + import jax + import jax.numpy as jnp + + r = 5.0 + r2 = r * r + alpha = 0.05 + agents_reshaped = agents.reshape(N, 4) + positions = agents_reshaped[:, :2] + velocities = agents_reshaped[:, 2:] + diff = positions[:, None, :] - positions[None, :, :] + dist_sq = (diff**2).sum(axis=2) + dist_sq = dist_sq + jnp.eye(N) * (r2 + 1) + neighbor_mask = dist_sq < r2 + sum_velocities = jnp.matmul( + neighbor_mask.astype(velocities.dtype), + velocities, + precision=jax.lax.Precision.HIGHEST, + ) + neighbor_counts = neighbor_mask.sum(axis=1, keepdims=True) + nonzero_mask = neighbor_counts[:, 0] > 0 + avg_velocities = jnp.where( + nonzero_mask[:, None], + sum_velocities / jnp.where(neighbor_counts == 0, 1, neighbor_counts), + velocities, + ) + new_velocities = velocities + alpha * (avg_velocities - velocities) + new_positions = positions + new_velocities + agents_next = jnp.concatenate([new_positions, new_velocities], axis=1) + return agents_next.reshape(-1) + def get_solve_signature(self) -> Dict[str, tuple]: return { "agents": (ctypes.POINTER(ctypes.c_float), "in"), @@ -99,11 +130,8 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return test_cases def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 - agents = torch.empty(40000, device=self.device, dtype=dtype).uniform_(-1000.0, 1000.0) - agents_next = torch.empty(40000, device=self.device, dtype=dtype) return { - "agents": agents, - "agents_next": agents_next, + "agents": RandTensor((40000,), -1000.0, 1000.0), + "agents_next": OutTensor((40000,)), "N": 10000, } diff --git a/challenges/hard/20_kmeans_clustering/challenge.py b/challenges/hard/20_kmeans_clustering/challenge.py index f0689d5e..00411ce6 100644 --- a/challenges/hard/20_kmeans_clustering/challenge.py +++ b/challenges/hard/20_kmeans_clustering/challenge.py @@ -45,6 +45,31 @@ def reference_impl( final_centroid_x[i] = data_x[mask].mean() final_centroid_y[i] = data_y[mask].mean() + def reference_impl_jax( + self, data_x, data_y, initial_centroid_x, initial_centroid_y, sample_size, k, max_iterations + ): + import jax + import jax.numpy as jnp + + final_centroid_x = initial_centroid_x + final_centroid_y = initial_centroid_y + labels = jnp.zeros((sample_size,), dtype=jnp.int32) + for _ in range(max_iterations): + expanded_x = data_x.reshape(-1, 1) - final_centroid_x.reshape(1, -1) + expanded_y = data_y.reshape(-1, 1) - final_centroid_y.reshape(1, -1) + distances = expanded_x**2 + expanded_y**2 + labels = jnp.argmin(distances, axis=1).astype(jnp.int32) + onehot = jax.nn.one_hot(labels, k, dtype=data_x.dtype) + counts = onehot.sum(axis=0) + sum_x = jnp.matmul(data_x, onehot, precision=jax.lax.Precision.HIGHEST) + sum_y = jnp.matmul(data_y, onehot, precision=jax.lax.Precision.HIGHEST) + safe_counts = jnp.where(counts == 0, 1, counts) + mean_x = sum_x / safe_counts + mean_y = sum_y / safe_counts + final_centroid_x = jnp.where(counts > 0, mean_x, final_centroid_x) + final_centroid_y = jnp.where(counts > 0, mean_y, final_centroid_y) + return labels, final_centroid_x, final_centroid_y + def get_solve_signature(self) -> Dict[str, tuple]: return { "data_x": (ctypes.POINTER(ctypes.c_float), "in"), diff --git a/challenges/hard/53_casual_attention/challenge.py b/challenges/hard/53_casual_attention/challenge.py index d6bfd241..af087d0f 100644 --- a/challenges/hard/53_casual_attention/challenge.py +++ b/challenges/hard/53_casual_attention/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -30,6 +30,18 @@ def reference_impl( attn = torch.softmax(attn, dim=1) torch.matmul(attn, V, out=output) + def reference_impl_jax(self, Q, K, V, M, d): + import jax + import jax.numpy as jnp + + scale = d**0.5 + attn = jnp.matmul(Q, K.T) / scale + + mask = jnp.triu(jnp.ones((M, M), dtype=bool), k=1) + attn = jnp.where(mask, -jnp.inf, attn) + attn = jax.nn.softmax(attn, axis=1) + return jnp.matmul(attn, V) + def get_solve_signature(self) -> Dict[str, tuple]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"), @@ -127,10 +139,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 M, d = 5000, 128 - Q = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - K = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - V = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - output = torch.empty(M, d, device=self.device, dtype=dtype) - return {"Q": Q, "K": K, "V": V, "output": output, "M": M, "d": d} + return { + "Q": RandTensor((M, d), -100.0, 100.0), + "K": RandTensor((M, d), -100.0, 100.0), + "V": RandTensor((M, d), -100.0, 100.0), + "output": OutTensor((M, d)), + "M": M, + "d": d, + } diff --git a/challenges/hard/56_linear_attention/challenge.py b/challenges/hard/56_linear_attention/challenge.py index fdab57a0..73b6b9cf 100644 --- a/challenges/hard/56_linear_attention/challenge.py +++ b/challenges/hard/56_linear_attention/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -38,6 +38,23 @@ def reference_impl( output.copy_(numerator / denominator.unsqueeze(-1)) # (M, d) + def reference_impl_jax(self, Q, K, V, M, d): + import jax.numpy as jnp + + # φ(x) = ELU(x) + 1 + phi_Q = jnp.where(Q > 0, Q + 1, jnp.exp(Q)) + phi_K = jnp.where(K > 0, K + 1, jnp.exp(K)) + + # S = φ(K)^T V → (d, d) + S = phi_K.T @ V + # z = sum_j φ(K_j) → (d,) + z = phi_K.sum(axis=0) + + numerator = phi_Q @ S # (M, d) + denominator = phi_Q @ z # (M,) + + return numerator / denominator[:, None] # (M, d) + def get_solve_signature(self) -> Dict[str, tuple]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"), @@ -147,10 +164,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 M, d = 10000, 128 - Q = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - K = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - V = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - output = torch.empty(M, d, device=self.device, dtype=dtype) - return {"Q": Q, "K": K, "V": V, "output": output, "M": M, "d": d} + return { + "Q": RandTensor((M, d), -100, 100), + "K": RandTensor((M, d), -100, 100), + "V": RandTensor((M, d), -100, 100), + "output": OutTensor((M, d)), + "M": M, + "d": d, + } diff --git a/challenges/hard/59_sliding_window_attn/challenge.py b/challenges/hard/59_sliding_window_attn/challenge.py index f2eee8bd..5be3aa4f 100644 --- a/challenges/hard/59_sliding_window_attn/challenge.py +++ b/challenges/hard/59_sliding_window_attn/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -34,6 +34,19 @@ def reference_impl( torch.matmul(attn, V, out=output) + def reference_impl_jax(self, Q, K, V, M, d, window_size): + import jax + import jax.numpy as jnp + + scores = (Q @ K.T) / (d**0.5) + + idxs = jnp.arange(M) + mask = jnp.abs(idxs[None, :] - idxs[:, None]) > window_size + scores = jnp.where(mask, -jnp.inf, scores) + attn = jax.nn.softmax(scores, axis=1) + + return jnp.matmul(attn, V) + def get_solve_signature(self) -> Dict[str, Any]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"), @@ -155,17 +168,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 M, d, window_size = 5000, 64, 16 - Q = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - K = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - V = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100) - output = torch.empty(M, d, device=self.device, dtype=dtype) return { - "Q": Q, - "K": K, - "V": V, - "output": output, + "Q": RandTensor((M, d), -100, 100), + "K": RandTensor((M, d), -100, 100), + "V": RandTensor((M, d), -100, 100), + "output": OutTensor((M, d)), "M": M, "d": d, "window_size": window_size, diff --git a/challenges/hard/73_all_pairs_shortest_paths/challenge.py b/challenges/hard/73_all_pairs_shortest_paths/challenge.py index 846f78da..a3467f44 100644 --- a/challenges/hard/73_all_pairs_shortest_paths/challenge.py +++ b/challenges/hard/73_all_pairs_shortest_paths/challenge.py @@ -37,6 +37,22 @@ def reference_impl(self, dist: torch.Tensor, output: torch.Tensor, N: int): d = torch.minimum(d, d[:, k : k + 1] + d[k : k + 1, :]) output.copy_(d.view(-1)) + def reference_impl_jax(self, dist, N): + import jax + import jax.numpy as jnp + + d = jnp.asarray(dist, dtype=jnp.float32).reshape(N, N) + + # Floyd-Warshall: for each k, d = min(d, d[:, k] + d[k, :]). + # inf for unreachable preserved (inf + x = inf, min stays inf). + def body(k, d): + col = jax.lax.dynamic_slice_in_dim(d, k, 1, axis=1) # (N, 1) + row = jax.lax.dynamic_slice_in_dim(d, k, 1, axis=0) # (1, N) + return jnp.minimum(d, col + row) + + d = jax.lax.fori_loop(0, N, body, d) + return d.reshape(-1) + def get_solve_signature(self) -> Dict[str, tuple]: return { "dist": (ctypes.POINTER(ctypes.c_float), "in"), diff --git a/challenges/hard/74_gpt2_block/challenge.py b/challenges/hard/74_gpt2_block/challenge.py index ffce5be0..701a6954 100644 --- a/challenges/hard/74_gpt2_block/challenge.py +++ b/challenges/hard/74_gpt2_block/challenge.py @@ -96,6 +96,64 @@ def reference_impl( # residual connection 2 output.copy_(hidden + proj) + def reference_impl_jax(self, x, weights, seq_len): + import jax + import jax.numpy as jnp + + def layer_norm(z, w, b, eps=1e-5): + mean = jnp.mean(z, axis=-1, keepdims=True) + var = jnp.mean((z - mean) ** 2, axis=-1, keepdims=True) + return (z - mean) * jax.lax.rsqrt(var + eps) * w + b + + # unpack weights + ln1_w = weights[O_LN1_W:O_LN1_B] + ln1_b = weights[O_LN1_B:O_WQKV] + W_qkv = weights[O_WQKV:O_BQKV].reshape(D, 3 * D) + b_qkv = weights[O_BQKV:O_WAPROJ] + W_attn = weights[O_WAPROJ:O_BAPROJ].reshape(D, D) + b_attn = weights[O_BAPROJ:O_LN2_W] + ln2_w = weights[O_LN2_W:O_LN2_B] + ln2_b = weights[O_LN2_B:O_WFC] + W_fc = weights[O_WFC:O_BFC].reshape(D, FFN) + b_fc = weights[O_BFC:O_WPROJ] + W_proj = weights[O_WPROJ:O_BPROJ].reshape(FFN, D) + b_proj = weights[O_BPROJ : O_BPROJ + D] + + # layer norm 1 + x_norm = layer_norm(x, ln1_w, ln1_b, eps=1e-5) + + # qkv projection + qkv = x_norm @ W_qkv + b_qkv + q, k, v = jnp.split(qkv, 3, axis=-1) + + # reshape for multi-head attention: (H, seq_len, DH) + q = q.reshape(seq_len, H, DH).transpose(1, 0, 2) + k = k.reshape(seq_len, H, DH).transpose(1, 0, 2) + v = v.reshape(seq_len, H, DH).transpose(1, 0, 2) + + # scaled dot-product attention + scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(DH) + attn_weights = jax.nn.softmax(scores, axis=-1) + attn_out = jnp.matmul(attn_weights, v) + + # concat heads and project + attn_out = attn_out.transpose(1, 0, 2).reshape(seq_len, D) + attn_proj = attn_out @ W_attn + b_attn + + # residual connection 1 + hidden = x + attn_proj + + # layer norm 2 + h_norm = layer_norm(hidden, ln2_w, ln2_b, eps=1e-5) + + # ffn: linear -> gelu (tanh approx) -> linear + fc = h_norm @ W_fc + b_fc + fc = jax.nn.gelu(fc, approximate=True) + proj = fc @ W_proj + b_proj + + # residual connection 2 + return hidden + proj + def get_solve_signature(self) -> Dict[str, tuple]: return { "x": (ctypes.POINTER(ctypes.c_float), "in"), diff --git a/challenges/hard/93_llama_transformer_block/challenge.py b/challenges/hard/93_llama_transformer_block/challenge.py index 0e89ca83..a6a1a408 100644 --- a/challenges/hard/93_llama_transformer_block/challenge.py +++ b/challenges/hard/93_llama_transformer_block/challenge.py @@ -122,6 +122,80 @@ def apply_rope(qk, c, s): # Residual 2 output.copy_(hidden + ffn_out) + def reference_impl_jax(self, x, weights, cos, sin, seq_len): + import jax + import jax.numpy as jnp + + def rms_norm(z, w): + return z * jax.lax.rsqrt(jnp.mean(z**2, axis=-1, keepdims=True) + 1e-5) * w + + def apply_rope(qk, c, s): + # qk: (seq_len, num_heads, head_dim) + q1, q2 = qk[..., : HEAD_DIM // 2], qk[..., HEAD_DIM // 2 :] + c = c[:, None, :] # (seq_len, 1, head_dim//2) + s = s[:, None, :] + return jnp.concatenate([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1) + + # unpack weights + rms1_w = weights[O_RMS1_W:O_WQ] + W_Q = weights[O_WQ:O_WK].reshape(Q_DIM, D) + W_K = weights[O_WK:O_WV].reshape(KV_DIM, D) + W_V = weights[O_WV:O_WO].reshape(KV_DIM, D) + W_O = weights[O_WO:O_RMS2_W].reshape(D, D) + rms2_w = weights[O_RMS2_W:O_WGATE] + W_gate = weights[O_WGATE:O_WUP].reshape(FFN_HIDDEN, D) + W_up = weights[O_WUP:O_WDOWN].reshape(FFN_HIDDEN, D) + W_down = weights[O_WDOWN:TOTAL_WEIGHTS].reshape(D, FFN_HIDDEN) + + # --- Attention sub-block --- + x_norm = rms_norm(x, rms1_w) + + # QKV projections + q = (x_norm @ W_Q.T).reshape(seq_len, NUM_Q_HEADS, HEAD_DIM) + k = (x_norm @ W_K.T).reshape(seq_len, NUM_KV_HEADS, HEAD_DIM) + v = (x_norm @ W_V.T).reshape(seq_len, NUM_KV_HEADS, HEAD_DIM) + + # Apply RoPE to Q and K + q = apply_rope(q, cos, sin) + k = apply_rope(k, cos, sin) + + # Reshape for batched matmul: (num_heads, seq_len, head_dim) + q = q.transpose(1, 0, 2) # (NUM_Q_HEADS, seq_len, HEAD_DIM) + k = k.transpose(1, 0, 2) # (NUM_KV_HEADS, seq_len, HEAD_DIM) + v = v.transpose(1, 0, 2) # (NUM_KV_HEADS, seq_len, HEAD_DIM) + + # GQA: broadcast K and V to match Q heads + k = jnp.repeat(k, GQA_GROUPS, axis=0) # (NUM_Q_HEADS, seq_len, HEAD_DIM) + v = jnp.repeat(v, GQA_GROUPS, axis=0) + + # Causal scaled dot-product attention + scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(HEAD_DIM) + causal_mask = jnp.triu( + jnp.full((seq_len, seq_len), -jnp.inf, dtype=x.dtype), + k=1, + ) + scores = scores + causal_mask + attn_weights = jax.nn.softmax(scores, axis=-1) + attn_out = jnp.matmul(attn_weights, v) # (NUM_Q_HEADS, seq_len, HEAD_DIM) + + # Merge heads and project + attn_out = attn_out.transpose(1, 0, 2).reshape(seq_len, D) + attn_proj = attn_out @ W_O.T + + # Residual 1 + hidden = x + attn_proj + + # --- FFN sub-block --- + h_norm = rms_norm(hidden, rms2_w) + + # SwiGLU: gate * up, then project down + gate = jax.nn.silu(h_norm @ W_gate.T) + up = h_norm @ W_up.T + ffn_out = (gate * up) @ W_down.T + + # Residual 2 + return hidden + ffn_out + def get_solve_signature(self) -> Dict[str, tuple]: return { "x": (ctypes.POINTER(ctypes.c_float), "in"), diff --git a/challenges/medium/10_2d_convolution/challenge.py b/challenges/medium/10_2d_convolution/challenge.py index db462bdd..b90f7c09 100644 --- a/challenges/medium/10_2d_convolution/challenge.py +++ b/challenges/medium/10_2d_convolution/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -34,6 +34,22 @@ def reference_impl( # Copy result to output tensor (removing the extra dimensions and flattening) output.copy_(result.view(-1)) + def reference_impl_jax(self, input, kernel, input_rows, input_cols, kernel_rows, kernel_cols): + import jax + + # Cross-correlation (matches F.conv2d), valid padding, stride 1. + # Shapes: input (N=1, C=1, H, W), kernel (O=1, I=1, H, W). + lhs = input.reshape(1, 1, input_rows, input_cols) + rhs = kernel.reshape(1, 1, kernel_rows, kernel_cols) + result = jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=(1, 1), + padding="VALID", + precision=jax.lax.Precision.HIGHEST, + ) + return result.reshape(-1) + def get_solve_signature(self) -> Dict[str, tuple]: return { "input": (ctypes.POINTER(ctypes.c_float), "in"), @@ -164,20 +180,15 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 input_rows = 3072 input_cols = 3072 kernel_rows = 15 kernel_cols = 15 - input = torch.empty(input_rows * input_cols, device=self.device, dtype=dtype).uniform_( - -1.0, 1.0 - ) - kernel = torch.empty(kernel_rows * kernel_cols, device=self.device, dtype=dtype).uniform_( - -1.0, 1.0 - ) + input = RandTensor((input_rows * input_cols,), -1.0, 1.0) + kernel = RandTensor((kernel_rows * kernel_cols,), -1.0, 1.0) output_rows = input_rows - kernel_rows + 1 output_cols = input_cols - kernel_cols + 1 - output = torch.empty(output_rows * output_cols, device=self.device, dtype=dtype) + output = OutTensor((output_rows * output_cols,)) return { "input": input, "kernel": kernel, diff --git a/challenges/medium/11_3d_convolution/challenge.py b/challenges/medium/11_3d_convolution/challenge.py index 1e4a4618..b9702aa6 100644 --- a/challenges/medium/11_3d_convolution/challenge.py +++ b/challenges/medium/11_3d_convolution/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -43,6 +43,32 @@ def reference_impl( output.copy_(result.squeeze(0).squeeze(0)) + def reference_impl_jax( + self, + input, + kernel, + input_depth, + input_rows, + input_cols, + kernel_depth, + kernel_rows, + kernel_cols, + ): + import jax + + # Cross-correlation (matches F.conv3d), valid padding, stride 1. + # Shapes: input (N=1, C=1, D, H, W), kernel (O=1, I=1, D, H, W). + lhs = input.reshape(1, 1, input_depth, input_rows, input_cols) + rhs = kernel.reshape(1, 1, kernel_depth, kernel_rows, kernel_cols) + result = jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=(1, 1, 1), + padding="VALID", + precision=jax.lax.Precision.HIGHEST, + ) + return result.reshape(result.shape[2:]) + def get_solve_signature(self) -> Dict[str, tuple]: return { "input": (ctypes.POINTER(ctypes.c_float), "in"), @@ -242,26 +268,18 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 input_depth, input_rows, input_cols = 256, 128, 128 kernel_depth, kernel_rows, kernel_cols = 5, 5, 5 - input_tensor = torch.empty( - input_depth, input_rows, input_cols, device=self.device, dtype=dtype - ).uniform_(-1.0, 1.0) - kernel_tensor = torch.empty( - kernel_depth, kernel_rows, kernel_cols, device=self.device, dtype=dtype - ).uniform_(-1.0, 1.0) - output_tensor = torch.zeros( - input_depth - kernel_depth + 1, - input_rows - kernel_rows + 1, - input_cols - kernel_cols + 1, - device=self.device, - dtype=dtype, - ) return { - "input": input_tensor, - "kernel": kernel_tensor, - "output": output_tensor, + "input": RandTensor((input_depth, input_rows, input_cols), -1.0, 1.0), + "kernel": RandTensor((kernel_depth, kernel_rows, kernel_cols), -1.0, 1.0), + "output": OutTensor( + ( + input_depth - kernel_depth + 1, + input_rows - kernel_rows + 1, + input_cols - kernel_cols + 1, + ) + ), "input_depth": input_depth, "input_rows": input_rows, "input_cols": input_cols, diff --git a/challenges/medium/22_gemm/challenge.py b/challenges/medium/22_gemm/challenge.py index 945433f6..44d5ea99 100644 --- a/challenges/medium/22_gemm/challenge.py +++ b/challenges/medium/22_gemm/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, RandTensor class Challenge(ChallengeBase): @@ -33,6 +33,16 @@ def reference_impl( final_result = alpha * matmul_result + beta * C_f32 C.copy_(final_result.to(torch.float16)) + def reference_impl_jax(self, A, B, C, M, N, K, alpha, beta): + import jax.numpy as jnp + + A_f32 = A.reshape(M, K).astype(jnp.float32) + B_f32 = B.reshape(K, N).astype(jnp.float32) + C_f32 = C.reshape(M, N).astype(jnp.float32) + matmul_result = jnp.matmul(A_f32, B_f32) + final_result = alpha * matmul_result + beta * C_f32 + return final_result.astype(jnp.float16) + def get_solve_signature(self) -> Dict[str, tuple]: return { "A": (ctypes.POINTER(ctypes.c_uint16), "in"), @@ -138,13 +148,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float16 M = 1024 N = 1024 K = 1024 - A = torch.empty((M, K), device=self.device, dtype=dtype).uniform_(-1.0, 1.0) - B = torch.empty((K, N), device=self.device, dtype=dtype).uniform_(-1.0, 1.0) - C = torch.empty((M, N), device=self.device, dtype=dtype).uniform_(-1.0, 1.0) + A = RandTensor((M, K), -1.0, 1.0, dtype="float16") + B = RandTensor((K, N), -1.0, 1.0, dtype="float16") + C = RandTensor((M, N), -1.0, 1.0, dtype="float16") return { "A": A, "B": B, diff --git a/challenges/medium/25_categorical_cross_entropy_loss/challenge.py b/challenges/medium/25_categorical_cross_entropy_loss/challenge.py index d74cd5cc..46fc6de6 100644 --- a/challenges/medium/25_categorical_cross_entropy_loss/challenge.py +++ b/challenges/medium/25_categorical_cross_entropy_loss/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandIntTensor, RandTensor class Challenge(ChallengeBase): @@ -33,6 +33,20 @@ def reference_impl( total_loss += loss_i.item() loss[0] = total_loss / N + def reference_impl_jax(self, logits, true_labels, N, C): + import jax.numpy as jnp + + logits = jnp.asarray(logits, dtype=jnp.float32) # (N, C) + true_labels = jnp.asarray(true_labels, dtype=jnp.int32) # (N,) + + # Per-row: log_sum_exp - logit[true_label], stable via max subtraction. + max_logit = jnp.max(logits, axis=1) # (N,) + log_sum_exp = max_logit + jnp.log(jnp.sum(jnp.exp(logits - max_logit[:, None]), axis=1)) + true_logit = jnp.take_along_axis(logits, true_labels[:, None], axis=1)[:, 0] # (N,) + loss_per = log_sum_exp - true_logit # (N,) + mean_loss = jnp.sum(loss_per) / N # mean over N + return mean_loss.reshape(1) + def get_solve_signature(self) -> Dict[str, tuple]: return { "logits": (ctypes.POINTER(ctypes.c_float), "in"), @@ -153,13 +167,9 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype_logits = torch.float32 - dtype_labels = torch.int32 - logits = torch.empty(10000, 1000, device=self.device, dtype=dtype_logits).uniform_( - -10.0, 10.0 - ) - true_labels = torch.randint(0, 1000, (10000,), device=self.device, dtype=dtype_labels) - loss = torch.zeros(1, device=self.device, dtype=dtype_logits) + logits = RandTensor((10000, 1000), -10.0, 10.0) + true_labels = RandIntTensor((10000,), 0, 1000, dtype="int32") + loss = OutTensor((1,)) return { "logits": logits, "true_labels": true_labels, diff --git a/challenges/medium/32_int8_quantized_matmul/challenge.py b/challenges/medium/32_int8_quantized_matmul/challenge.py index 43a602dc..b23b9ca8 100644 --- a/challenges/medium/32_int8_quantized_matmul/challenge.py +++ b/challenges/medium/32_int8_quantized_matmul/challenge.py @@ -37,6 +37,21 @@ def reference_impl( C_q = torch.clamp(C_q, -128, 127).to(torch.int8) C.view(M, N).copy_(C_q) + def reference_impl_jax( + self, A, B, M, N, K, scale_A, scale_B, scale_C, zero_point_A, zero_point_B, zero_point_C + ): + import jax.numpy as jnp + + A = A.reshape(M, K).astype(jnp.int32) + B = B.reshape(K, N).astype(jnp.int32) + A_f = (A - zero_point_A).astype(jnp.float32) + B_f = (B - zero_point_B).astype(jnp.float32) + C_f = jnp.round(jnp.matmul(A_f, B_f)).astype(jnp.int32) + C_f = C_f * scale_A * scale_B / scale_C + C_q = jnp.round(C_f).astype(jnp.int32) + zero_point_C + C_q = jnp.clip(C_q, -128, 127).astype(jnp.int8) + return C_q + def get_solve_signature(self) -> Dict[str, tuple]: return { "A": (ctypes.POINTER(ctypes.c_int8), "in"), diff --git a/challenges/medium/38_nearest_neighbor/challenge.py b/challenges/medium/38_nearest_neighbor/challenge.py index 832f10c7..04c9fe1f 100644 --- a/challenges/medium/38_nearest_neighbor/challenge.py +++ b/challenges/medium/38_nearest_neighbor/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -43,6 +43,16 @@ def reference_impl(self, points: torch.Tensor, indices: torch.Tensor, N: int): # Find nearest neighbor indices indices.copy_(torch.argmin(dist_sq, dim=1).int()) + def reference_impl_jax(self, points, N): + import jax.numpy as jnp + + pts = points.reshape(N, 3) + diff = pts[:, None, :] - pts[None, :, :] + dist_sq = jnp.sum(diff * diff, axis=2) + mask = jnp.eye(N, dtype=bool) + dist_sq = jnp.where(mask, jnp.inf, dist_sq) + return jnp.argmin(dist_sq, axis=1).astype(jnp.int32) + def get_solve_signature(self) -> Dict[str, tuple]: return { "points": (ctypes.POINTER(ctypes.c_float), "in"), @@ -208,14 +218,9 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return test_cases def generate_performance_test(self) -> Dict[str, Any]: - dtype_float = torch.float32 - dtype_int = torch.int32 N = 10000 - return { - "points": torch.empty((N, 3), device=self.device, dtype=dtype_float) - .uniform_(-1000.0, 1000.0) - .flatten(), - "indices": torch.full((N,), -1, device=self.device, dtype=dtype_int), + "points": RandTensor((N * 3,), -1000.0, 1000.0), + "indices": OutTensor((N,), dtype="int32"), "N": N, } diff --git a/challenges/medium/55_attn_w_linear_bias/challenge.py b/challenges/medium/55_attn_w_linear_bias/challenge.py index 96780a99..b2605da2 100644 --- a/challenges/medium/55_attn_w_linear_bias/challenge.py +++ b/challenges/medium/55_attn_w_linear_bias/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandTensor class Challenge(ChallengeBase): @@ -40,6 +40,21 @@ def reference_impl( attn = torch.softmax(attn, dim=1) # M , N torch.matmul(attn, V, out=output) + def reference_impl_jax(self, Q, K, V, M, N, d, alpha): + import jax + import jax.numpy as jnp + + scale = d**0.5 + attn = jnp.matmul(Q, K.T, precision=jax.lax.Precision.HIGHEST) / scale + + pos_bias = alpha * (jnp.arange(M).reshape(M, 1) - jnp.arange(N).reshape(1, N)).astype( + attn.dtype + ) + attn = attn + pos_bias + + attn = jax.nn.softmax(attn, axis=1) + return jnp.matmul(attn, V, precision=jax.lax.Precision.HIGHEST) + def get_solve_signature(self) -> Dict[str, tuple]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"), @@ -183,10 +198,14 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - dtype = torch.float32 M, N, d = 2048, 2048, 1024 - Q = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-0.1, 0.1) - K = torch.empty((N, d), device=self.device, dtype=dtype).uniform_(-0.1, 0.1) - V = torch.empty((N, d), device=self.device, dtype=dtype).uniform_(-0.1, 0.1) - output = torch.empty(M, d, device=self.device, dtype=dtype) - return {"Q": Q, "K": K, "V": V, "output": output, "M": M, "N": N, "d": d, "alpha": 0.5} + return { + "Q": RandTensor((M, d), -0.1, 0.1), + "K": RandTensor((N, d), -0.1, 0.1), + "V": RandTensor((N, d), -0.1, 0.1), + "output": OutTensor((M, d)), + "M": M, + "N": N, + "d": d, + "alpha": 0.5, + } diff --git a/challenges/medium/75_sparse_matrix_dense_matrix_multiplication/challenge.py b/challenges/medium/75_sparse_matrix_dense_matrix_multiplication/challenge.py index 9daf45f7..c7e9a73b 100644 --- a/challenges/medium/75_sparse_matrix_dense_matrix_multiplication/challenge.py +++ b/challenges/medium/75_sparse_matrix_dense_matrix_multiplication/challenge.py @@ -46,6 +46,13 @@ def reference_impl( result = torch.matmul(A_matrix, B_matrix) C.copy_(result.view(C.shape)) + def reference_impl_jax(self, A, B, M, N, K, nnz): + import jax.numpy as jnp + + A_matrix = A.reshape(M, N) + B_matrix = B.reshape(N, K) + return jnp.matmul(A_matrix, B_matrix) + def get_solve_signature(self) -> Dict[str, tuple]: return { "A": (ctypes.POINTER(ctypes.c_float), "in"), diff --git a/challenges/medium/76_adder_transformer/challenge.py b/challenges/medium/76_adder_transformer/challenge.py index 0aafad71..5a4ed587 100644 --- a/challenges/medium/76_adder_transformer/challenge.py +++ b/challenges/medium/76_adder_transformer/challenge.py @@ -204,6 +204,125 @@ def reference_impl( next_token = last_logits.argmax(dim=-1).to(torch.int32) seq = torch.cat([seq, next_token.unsqueeze(1)], dim=1) + def reference_impl_jax(self, prompts, weights, batch_size): + import jax + import jax.numpy as jnp + + prompts = jnp.asarray(prompts, dtype=jnp.int32) + weights = jnp.asarray(weights, dtype=jnp.float32) + + max_len = PROMPT_LEN + OUTPUT_DIGITS + + embed_w = weights[O_EMBED : O_EMBED + 2] + q_w = weights[O_QPROJ : O_QPROJ + 2] + v_w = weights[O_VPROJ] + gate_w = weights[O_GATE : O_GATE + 2] + carry_w = weights[O_CARRY] + norm_w = weights[O_NORM : O_NORM + 2] + + digits = jnp.arange(VOCAB_SIZE, dtype=jnp.float32) + embed_table = jnp.stack( + [embed_w[0] - embed_w[1] * digits * digits, -digits], axis=-1 + ) # [10, 2] + + positions = jnp.arange(max_len, dtype=jnp.float32) + angles = positions * OMEGA + cos_a = jnp.cos(angles) + sin_a = jnp.sin(angles) + + def unit_rms_norm(x): + return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + RMS_EPS) + + def forward_last(seq, cur_len): + # seq: [batch, max_len] int32 (padded); cur_len: scalar valid length. + h = embed_table[seq.astype(jnp.int32)] # [batch, max_len, 2] + + h_norm = unit_rms_norm(h) + + q = jnp.stack([h_norm[..., 0] * q_w[0], h_norm[..., 0] * q_w[1]], axis=-1) + k = jnp.stack([h_norm[..., 0], jnp.zeros_like(h_norm[..., 0])], axis=-1) + v = jnp.stack([h_norm[..., 1] * v_w, jnp.zeros_like(h_norm[..., 1])], axis=-1) + + q = unit_rms_norm(q) + k = unit_rms_norm(k) + + q_rot = jnp.stack( + [ + q[..., 0] * cos_a - q[..., 1] * sin_a, + q[..., 0] * sin_a + q[..., 1] * cos_a, + ], + axis=-1, + ) + k_rot = jnp.stack( + [ + k[..., 0] * cos_a - k[..., 1] * sin_a, + k[..., 0] * sin_a + k[..., 1] * cos_a, + ], + axis=-1, + ) + + q_rot = q_rot[:, None, :, :] + k_rot = k_rot[:, None, :, :] + v = v[:, None, :, :] + + attn_scores = ( + jnp.matmul(q_rot, jnp.swapaxes(k_rot, -2, -1), precision="highest") * ATTN_SCALE + ) + row = jnp.arange(max_len)[:, None] + col = jnp.arange(max_len)[None, :] + # Causal mask plus padding: queries/keys at index >= cur_len excluded. + mask = (col > row) | (col >= cur_len) + attn_scores = jnp.where(mask[None, None, :, :], -jnp.inf, attn_scores) + attn_probs = jax.nn.softmax(attn_scores, axis=-1) + attn_out = jnp.matmul(attn_probs, v, precision="highest")[ + :, 0, :, : + ] # [batch, max_len, 2] + + o = jnp.stack([jnp.zeros_like(attn_out[..., 0]), attn_out[..., 0]], axis=-1) + h = h + o + + h_norm2 = unit_rms_norm(h) + + a_gate = gate_w[0] + c_gate = gate_w[1] + g0 = h_norm2[..., 0] * a_gate + h_norm2[..., 1] * c_gate + g1 = h_norm2[..., 0] * (a_gate - c_gate / EMBED_CONST) + h_norm2[..., 1] * c_gate + gate = jnp.stack([g0, g1], axis=-1) + + base = h_norm2[..., 0] + up = jnp.broadcast_to(base[..., None], gate.shape) + mix = jax.nn.silu(gate) * up + mlp_out = jnp.stack( + [jnp.zeros_like(base), carry_w * (mix[..., 1] - mix[..., 0])], axis=-1 + ) + h = h + mlp_out + + rms = jnp.sqrt(jnp.mean(h * h, axis=-1, keepdims=True) + RMS_EPS) + h = (h / rms) * norm_w + + logits = jnp.matmul(h, embed_table.T, precision="highest") # [batch, max_len, 10] + # Logits at the last valid position (cur_len - 1). + last = jnp.take_along_axis( + logits, jnp.full((logits.shape[0], 1, VOCAB_SIZE), cur_len - 1), axis=1 + )[:, 0, :] + return last + + bsz = prompts.shape[0] + seq0 = jnp.zeros((bsz, max_len), dtype=jnp.int32) + seq0 = seq0.at[:, :PROMPT_LEN].set(prompts) + + def step(carry, step_idx): + seq, cur_len = carry + last_logits = forward_last(seq, cur_len) + next_token = jnp.argmax(last_logits, axis=-1).astype(jnp.int32) + seq = seq.at[:, cur_len].set(next_token) + return (seq, cur_len + 1), last_logits + + (_, _), outputs = jax.lax.scan(step, (seq0, PROMPT_LEN), jnp.arange(OUTPUT_DIGITS)) + # outputs: [OUTPUT_DIGITS, batch, VOCAB_SIZE] -> [batch, OUTPUT_DIGITS, VOCAB_SIZE] + output = jnp.transpose(outputs, (1, 0, 2)).astype(jnp.float32) + return output + def get_solve_signature(self) -> Dict[str, tuple]: return { "prompts": (ctypes.POINTER(ctypes.c_int), "in"), diff --git a/challenges/medium/80_grouped_query_attention/challenge.py b/challenges/medium/80_grouped_query_attention/challenge.py index edf133bf..20df4930 100644 --- a/challenges/medium/80_grouped_query_attention/challenge.py +++ b/challenges/medium/80_grouped_query_attention/challenge.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandnTensor class Challenge(ChallengeBase): @@ -48,6 +48,23 @@ def reference_impl( # Weighted sum of values: (num_q_heads, seq_len, head_dim) output.copy_(torch.bmm(attn_weights, V_expanded)) + def reference_impl_jax(self, Q, K, V, num_q_heads, num_kv_heads, seq_len, head_dim): + import jax + import jax.numpy as jnp + + num_groups = num_q_heads // num_kv_heads + scale = 1.0 / math.sqrt(head_dim) + + K_expanded = jnp.repeat(K, num_groups, axis=0) + V_expanded = jnp.repeat(V, num_groups, axis=0) + + scores = ( + jnp.matmul(Q, jnp.transpose(K_expanded, (0, 2, 1)), precision=jax.lax.Precision.HIGHEST) + * scale + ) + attn_weights = jax.nn.softmax(scores, axis=-1) + return jnp.matmul(attn_weights, V_expanded, precision=jax.lax.Precision.HIGHEST) + def get_solve_signature(self) -> Dict[str, tuple]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"), @@ -167,6 +184,15 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - torch.manual_seed(0) # LLaMA-3 8B style: 32 Q heads, 8 KV heads, head_dim=128 - return self._make_test_case(32, 8, 1024, 128) + num_q_heads, num_kv_heads, seq_len, head_dim = 32, 8, 1024, 128 + return { + "Q": RandnTensor((num_q_heads, seq_len, head_dim)), + "K": RandnTensor((num_kv_heads, seq_len, head_dim)), + "V": RandnTensor((num_kv_heads, seq_len, head_dim)), + "output": OutTensor((num_q_heads, seq_len, head_dim)), + "num_q_heads": num_q_heads, + "num_kv_heads": num_kv_heads, + "seq_len": seq_len, + "head_dim": head_dim, + } diff --git a/challenges/medium/81_int4_matmul/challenge.py b/challenges/medium/81_int4_matmul/challenge.py index 2b2f36cd..64e10567 100644 --- a/challenges/medium/81_int4_matmul/challenge.py +++ b/challenges/medium/81_int4_matmul/challenge.py @@ -52,6 +52,25 @@ def reference_impl( # MatMul: x [M, K] @ w_dequant.T [K, N] = y [M, N] y.copy_((x.float() @ w_dequant.T).half()) + def reference_impl_jax(self, x, w_q, scales, M, N, K, group_size): + import jax.numpy as jnp + + # Unpack INT4 weights from packed uint8 bytes. + w_high = ((w_q >> 4) & 0xF).astype(jnp.int32) - 8 # [N, K//2] + w_low = (w_q & 0xF).astype(jnp.int32) - 8 # [N, K//2] + + # Interleave high and low nibbles to reconstruct [N, K] + w_int = jnp.stack([w_high, w_low], axis=-1).reshape(N, K) # [N, K] + + # Apply group-wise scales: dequantize each group + n_groups = K // group_size + w_groups = w_int.reshape(N, n_groups, group_size).astype(jnp.float32) + scales_f = scales.astype(jnp.float32)[..., None] # [N, n_groups, 1] + w_dequant = (w_groups * scales_f).reshape(N, K) # [N, K] + + # MatMul: x [M, K] @ w_dequant.T [K, N] = y [M, N] + return (x.astype(jnp.float32) @ w_dequant.T).astype(jnp.float16) + def get_solve_signature(self) -> Dict[str, tuple]: return { "x": (ctypes.POINTER(ctypes.c_uint16), "in"), diff --git a/challenges/medium/82_linear_recurrence/challenge.py b/challenges/medium/82_linear_recurrence/challenge.py index c09a5d0c..fe2c3bd1 100644 --- a/challenges/medium/82_linear_recurrence/challenge.py +++ b/challenges/medium/82_linear_recurrence/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandnTensor, RandTensor class Challenge(ChallengeBase): @@ -31,6 +31,27 @@ def reference_impl( out[:, t] = a[:, t] * out[:, t - 1] + x[:, t] h.copy_(out) + def reference_impl_jax(self, a, x, B, L): + import jax + import jax.numpy as jnp + + a = jnp.asarray(a, dtype=jnp.float32) + x = jnp.asarray(x, dtype=jnp.float32) + + # out[:, 0] = x[:, 0]; out[:, t] = a[:, t] * out[:, t-1] + x[:, t] + # Scan over time dimension (axis 1). a_t, x_t are (B,). + def step(carry, inp): + a_t, x_t = inp + new = a_t * carry + x_t + return new, new + + a_T = jnp.transpose(a, (1, 0)) # (L, B) + x_T = jnp.transpose(x, (1, 0)) # (L, B) + init = x_T[0] # out[:, 0] + _, ys = jax.lax.scan(step, init, (a_T[1:], x_T[1:])) # ys: (L-1, B) + out = jnp.concatenate([init[None, :], ys], axis=0) # (L, B) + return jnp.transpose(out, (1, 0)) # (B, L) + def get_solve_signature(self) -> Dict[str, tuple]: return { "a": (ctypes.POINTER(ctypes.c_float), "in"), @@ -111,6 +132,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - torch.manual_seed(0) # B=64 sequences, L=16384 tokens — typical long-context SSM workload - return self._make_test_case(64, 16384) + B, L = 64, 16384 + return { + "a": RandTensor((B, L), 0.0, 1.0), + "x": RandnTensor((B, L)), + "h": OutTensor((B, L)), + "B": B, + "L": L, + } diff --git a/challenges/medium/84_swiglu_mlp_block/challenge.py b/challenges/medium/84_swiglu_mlp_block/challenge.py index 9d3a704a..bab54e84 100644 --- a/challenges/medium/84_swiglu_mlp_block/challenge.py +++ b/challenges/medium/84_swiglu_mlp_block/challenge.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandnTensor class Challenge(ChallengeBase): @@ -38,6 +38,14 @@ def reference_impl( hidden = F.silu(gate) * up # [M, d_ffn] output.copy_(hidden @ W_down) # [M, d_model] + def reference_impl_jax(self, x, W_gate, W_up, W_down, M, d_model, d_ffn): + import jax + + gate = x @ W_gate # [M, d_ffn] + up = x @ W_up # [M, d_ffn] + hidden = jax.nn.silu(gate) * up # [M, d_ffn] + return hidden @ W_down # [M, d_model] + def get_solve_signature(self) -> Dict[str, tuple]: return { "x": (ctypes.POINTER(ctypes.c_float), "in"), @@ -149,6 +157,15 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - torch.manual_seed(0) # LLaMA-3 8B style: d_model=4096, d_ffn=14336, M=512 (batch=4 x seq=128) - return self._make_test_case(512, 4096, 14336) + M, d_model, d_ffn = 512, 4096, 14336 + return { + "x": RandnTensor((M, d_model), std=0.1), + "W_gate": RandnTensor((d_model, d_ffn), std=0.02), + "W_up": RandnTensor((d_model, d_ffn), std=0.02), + "W_down": RandnTensor((d_ffn, d_model), std=0.02), + "output": OutTensor((M, d_model)), + "M": M, + "d_model": d_model, + "d_ffn": d_ffn, + } diff --git a/challenges/medium/85_lora_linear/challenge.py b/challenges/medium/85_lora_linear/challenge.py index 8b9d6337..60b4d707 100644 --- a/challenges/medium/85_lora_linear/challenge.py +++ b/challenges/medium/85_lora_linear/challenge.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, FullTensor, OutTensor, RandnTensor class Challenge(ChallengeBase): @@ -41,6 +41,17 @@ def reference_impl( output.copy_(base + lora_scale * delta) + def reference_impl_jax(self, x, W, A, B, batch, d_in, d_out, rank, lora_scale): + + # Base linear: output = x @ W^T + base = x @ W.T + + # LoRA path: delta = lora_scale * (x @ A^T) @ B^T + lora_hidden = x @ A.T # (batch, rank) + delta = lora_hidden @ B.T # (batch, d_out) + + return base + lora_scale * delta + def get_solve_signature(self) -> Dict[str, tuple]: return { "x": (ctypes.POINTER(ctypes.c_float), "in"), @@ -158,6 +169,17 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - torch.manual_seed(0) # LLaMA-style: d_in=d_out=4096, rank=64, batch=256 - return self._make_test_case(256, 4096, 4096, 64, lora_scale=0.015625) + batch, d_in, d_out, rank, lora_scale = 256, 4096, 4096, 64, 0.015625 + return { + "x": RandnTensor((batch, d_in)), + "W": RandnTensor((d_out, d_in), std=0.02), + "A": RandnTensor((rank, d_in), std=0.02), + "B": FullTensor((d_out, rank), 0.0), + "output": OutTensor((batch, d_out)), + "batch": batch, + "d_in": d_in, + "d_out": d_out, + "rank": rank, + "lora_scale": lora_scale, + } diff --git a/challenges/medium/87_speculative_decoding_verification/challenge.py b/challenges/medium/87_speculative_decoding_verification/challenge.py index 5857e906..32962c10 100644 --- a/challenges/medium/87_speculative_decoding_verification/challenge.py +++ b/challenges/medium/87_speculative_decoding_verification/challenge.py @@ -65,6 +65,62 @@ def reference_impl( bonus_tok = int(torch.searchsorted(cdf.contiguous(), r).item()) output_tokens[b, T] = min(bonus_tok, V - 1) + def reference_impl_jax(self, draft_tokens, draft_probs, target_probs, uniform_samples, B, T, V): + import jax.numpy as jnp + + draft_tokens = jnp.asarray(draft_tokens, dtype=jnp.int32) + draft_probs = jnp.asarray(draft_probs, dtype=jnp.float32) + target_probs = jnp.asarray(target_probs, dtype=jnp.float32) + uniform_samples = jnp.asarray(uniform_samples, dtype=jnp.float32) + + # Gather p = draft_probs[b, i, tok] and q = target_probs[b, i, tok]. + tok = draft_tokens.astype(jnp.int32) # [B, T] + p = jnp.take_along_axis(draft_probs, tok[..., None], axis=2)[..., 0] # [B, T] + q = jnp.take_along_axis(target_probs, tok[..., None], axis=2)[..., 0] # [B, T] + alpha = jnp.minimum(1.0, q / p) # [B, T] + + accept = uniform_samples[:, :T] < alpha # [B, T] True where accepted + reject = ~accept + + # First rejection position per batch (T if none rejected -> all accepted). + any_reject = jnp.any(reject, axis=1) # [B] + first_reject = jnp.argmax(reject.astype(jnp.int32), axis=1) # [B], 0 if none + first_reject = jnp.where(any_reject, first_reject, T) # [B] + + # Resampled token at every (b, i): clamp(target - draft, 0), normalize, cdf, + # searchsorted(cdf, uniform_samples[b, T]). Uniform fallback when total == 0. + adjusted = jnp.clip(target_probs - draft_probs, a_min=0.0) # [B, T, V] + total = jnp.sum(adjusted, axis=2, keepdims=True) # [B, T, 1] + uniform_dist = jnp.ones_like(adjusted) / V + normalized = jnp.where(total > 0.0, adjusted / total, uniform_dist) # [B, T, V] + cdf = jnp.cumsum(normalized, axis=2) # [B, T, V] + r = uniform_samples[:, T] # [B] + # searchsorted (side='left') over the last axis for scalar r per batch. + resample_tok = jnp.sum((cdf < r[:, None, None]).astype(jnp.int32), axis=2) # [B, T] + resample_tok = jnp.minimum(resample_tok, V - 1) # [B, T] + + # Bonus token (all accepted): cumsum(target_probs[b, T-1]) searchsorted r. + cdf_bonus = jnp.cumsum(target_probs[:, T - 1, :], axis=1) # [B, V] + bonus_tok = jnp.sum((cdf_bonus < r[:, None]).astype(jnp.int32), axis=1) # [B] + bonus_tok = jnp.minimum(bonus_tok, V - 1) # [B] + + # Assemble output_tokens [B, T+1], all zeros by default. + positions = jnp.arange(T)[None, :] # [1, T] + fr = first_reject[:, None] # [B, 1] + + out_first_T = jnp.zeros((B, T), dtype=jnp.int32) + # Accepted positions i < first_reject -> draft token. + out_first_T = jnp.where(positions < fr, tok, out_first_T) + # Exactly the first-reject position -> resampled token (only when one exists). + is_reject_pos = (positions == fr) & any_reject[:, None] + out_first_T = jnp.where(is_reject_pos, resample_tok, out_first_T) + + # Position T column: bonus token only when all accepted, else 0. + out_last = jnp.where(any_reject, 0, bonus_tok).astype(jnp.int32)[:, None] # [B, 1] + + output_tokens = jnp.concatenate([out_first_T, out_last], axis=1) # [B, T+1] + return output_tokens.astype(jnp.int32) + def get_solve_signature(self) -> Dict[str, tuple]: return { "draft_tokens": (ctypes.POINTER(ctypes.c_int), "in"), diff --git a/challenges/medium/90_causal_depthwise_conv1d/challenge.py b/challenges/medium/90_causal_depthwise_conv1d/challenge.py index f7039127..bcfc91c1 100644 --- a/challenges/medium/90_causal_depthwise_conv1d/challenge.py +++ b/challenges/medium/90_causal_depthwise_conv1d/challenge.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandnTensor class Challenge(ChallengeBase): @@ -46,6 +46,31 @@ def reference_impl( output.copy_(result.permute(0, 2, 1)) # (B, L, D) + def reference_impl_jax(self, x, weight, bias, B, L, D, K): + import jax + import jax.numpy as jnp + + # x (B, L, D) -> (B, D, L) for conv + x_t = jnp.transpose(x, (0, 2, 1)) # (B, D, L) + + # Causal padding: K-1 zeros on the left. + x_padded = jnp.pad(x_t, ((0, 0), (0, 0), (K - 1, 0))) # (B, D, L + K - 1) + + # Depthwise (groups=D) cross-correlation with flipped kernel, + # mirroring the torch reference. rhs shape (out=D, in/groups=1, K). + w = jnp.flip(weight, axis=1).reshape(D, 1, K) + result = jax.lax.conv_general_dilated( + x_padded, + w, + window_strides=(1,), + padding="VALID", + feature_group_count=D, + precision=jax.lax.Precision.HIGHEST, + ) # (B, D, L) + result = result + bias.reshape(1, D, 1) + + return jnp.transpose(result, (0, 2, 1)) # (B, L, D) + def get_solve_signature(self) -> Dict[str, tuple]: return { "x": (ctypes.POINTER(ctypes.c_float), "in"), @@ -163,16 +188,11 @@ def make_case(B, L, D, K, x_vals=None, w_vals=None, b_vals=None): def generate_performance_test(self) -> Dict[str, Any]: B, L, D, K = 8, 2048, 4096, 4 - dtype = torch.float32 - x = torch.randn(B, L, D, device=self.device, dtype=dtype) - weight = torch.randn(D, K, device=self.device, dtype=dtype) - bias = torch.randn(D, device=self.device, dtype=dtype) - output = torch.empty(B, L, D, device=self.device, dtype=dtype) return { - "x": x, - "weight": weight, - "bias": bias, - "output": output, + "x": RandnTensor((B, L, D)), + "weight": RandnTensor((D, K)), + "bias": RandnTensor((D,)), + "output": OutTensor((B, L, D)), "B": B, "L": L, "D": D, diff --git a/challenges/medium/92_decaying_causal_attention/challenge.py b/challenges/medium/92_decaying_causal_attention/challenge.py index 1d9380dd..8db0bf0e 100644 --- a/challenges/medium/92_decaying_causal_attention/challenge.py +++ b/challenges/medium/92_decaying_causal_attention/challenge.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List import torch -from core.challenge_base import ChallengeBase +from core.challenge_base import ChallengeBase, OutTensor, RandnTensor class Challenge(ChallengeBase): @@ -39,6 +39,18 @@ def reference_impl( attn = torch.matmul(Q, K.T) / scale output.copy_(torch.matmul(attn * decay_mask, V)) + def reference_impl_jax(self, Q, K, V, seq_len, d_model, gamma): + import jax + import jax.numpy as jnp + + scale = math.sqrt(d_model) + positions = jnp.arange(seq_len, dtype=Q.dtype) + distances = positions.reshape(seq_len, 1) - positions.reshape(1, seq_len) + causal = (distances >= 0).astype(Q.dtype) + decay_mask = jnp.power(gamma, jnp.maximum(distances, 0)) * causal + attn = jnp.matmul(Q, K.T, precision=jax.lax.Precision.HIGHEST) / scale + return jnp.matmul(attn * decay_mask, V, precision=jax.lax.Precision.HIGHEST) + def get_solve_signature(self) -> Dict[str, tuple]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"), @@ -134,6 +146,14 @@ def generate_functional_test(self) -> List[Dict[str, Any]]: return tests def generate_performance_test(self) -> Dict[str, Any]: - torch.manual_seed(0) # Typical LLM head: seq_len=4096, head_dim=64 - return self._make_test_case(4096, 64, gamma=0.9) + seq_len, d_model = 4096, 64 + return { + "Q": RandnTensor((seq_len, d_model)), + "K": RandnTensor((seq_len, d_model)), + "V": RandnTensor((seq_len, d_model)), + "output": OutTensor((seq_len, d_model)), + "seq_len": seq_len, + "d_model": d_model, + "gamma": 0.9, + } diff --git a/challenges/medium/94_ssm_selective_scan/challenge.py b/challenges/medium/94_ssm_selective_scan/challenge.py index e96681ba..e8140537 100644 --- a/challenges/medium/94_ssm_selective_scan/challenge.py +++ b/challenges/medium/94_ssm_selective_scan/challenge.py @@ -62,6 +62,43 @@ def reference_impl( y_t = torch.einsum("bn,bdn->bd", C_t, h) + skip * u_t # (batch, d_model) y[:, t, :] = y_t + def reference_impl_jax(self, u, delta, A, B, C, skip, batch, seq_len, d_model, d_state): + import jax + import jax.numpy as jnp + + u = jnp.asarray(u, dtype=jnp.float32) # (batch, seq_len, d_model) + delta = jnp.asarray(delta, dtype=jnp.float32) # (batch, seq_len, d_model) + A = jnp.asarray(A, dtype=jnp.float32) # (d_model, d_state) + B = jnp.asarray(B, dtype=jnp.float32) # (batch, seq_len, d_state) + C = jnp.asarray(C, dtype=jnp.float32) # (batch, seq_len, d_state) + skip = jnp.asarray(skip, dtype=jnp.float32) # (d_model,) + + batch = u.shape[0] + d_model = u.shape[2] + d_state = A.shape[1] + + # Move sequence axis to front for scanning. + u_t = jnp.transpose(u, (1, 0, 2)) # (seq_len, batch, d_model) + delta_t = jnp.transpose(delta, (1, 0, 2)) # (seq_len, batch, d_model) + B_t = jnp.transpose(B, (1, 0, 2)) # (seq_len, batch, d_state) + C_t = jnp.transpose(C, (1, 0, 2)) # (seq_len, batch, d_state) + + A_b = A[None, :, :] # (1, d_model, d_state) + + def step(h, inp): + dt, ut, bt, ct = inp # dt,ut:(batch,d_model) bt,ct:(batch,d_state) + A_bar = jnp.exp(dt[:, :, None] * A_b) # (batch, d_model, d_state) + B_bar = dt[:, :, None] * bt[:, None, :] # (batch, d_model, d_state) + h = A_bar * h + B_bar * ut[:, :, None] # (batch, d_model, d_state) + y_t = ( + jnp.einsum("bn,bdn->bd", ct, h, precision=jax.lax.Precision.HIGHEST) + skip * ut + ) # (batch, d_model) + return h, y_t + + h0 = jnp.zeros((batch, d_model, d_state), dtype=jnp.float32) + _, ys = jax.lax.scan(step, h0, (delta_t, u_t, B_t, C_t)) # (seq_len, batch, d_model) + return jnp.transpose(ys, (1, 0, 2)) # (batch, seq_len, d_model) + def get_solve_signature(self) -> Dict[str, tuple]: return { "u": (ctypes.POINTER(ctypes.c_float), "in"), diff --git a/challenges/medium/96_int8_kv_cache_attention/challenge.py b/challenges/medium/96_int8_kv_cache_attention/challenge.py index 8c8207e5..88a91be8 100644 --- a/challenges/medium/96_int8_kv_cache_attention/challenge.py +++ b/challenges/medium/96_int8_kv_cache_attention/challenge.py @@ -52,6 +52,26 @@ def reference_impl( out = torch.bmm(weights, V_float) # [num_heads, 1, head_dim] output.copy_(out.squeeze(1)) + def reference_impl_jax(self, Q, K_int8, V_int8, k_scale, v_scale, num_heads, seq_len, head_dim): + import jax + import jax.numpy as jnp + + K_float = K_int8.astype(jnp.float32) * jnp.expand_dims(k_scale, axis=-1) + V_float = V_int8.astype(jnp.float32) * jnp.expand_dims(v_scale, axis=-1) + + scale = 1.0 / math.sqrt(head_dim) + scores = ( + jnp.matmul( + jnp.expand_dims(Q, axis=1), + jnp.transpose(K_float, (0, 2, 1)), + precision=jax.lax.Precision.HIGHEST, + ) + * scale + ) + weights = jax.nn.softmax(scores, axis=-1) + out = jnp.matmul(weights, V_float, precision=jax.lax.Precision.HIGHEST) + return jnp.squeeze(out, axis=1) + def get_solve_signature(self) -> Dict[str, tuple]: return { "Q": (ctypes.POINTER(ctypes.c_float), "in"),