Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions challenges/core/challenge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,54 @@
from typing import Any, Dict, List


class RandTensor:
"""Uniform random input in [low, high)."""

def __init__(self, shape, low=0.0, high=1.0, dtype="float32"):
self.shape = tuple(shape)
self.low = low
self.high = high
self.dtype = dtype


class RandnTensor:
"""Normal (Gaussian) random input."""

def __init__(self, shape, mean=0.0, std=1.0, dtype="float32"):
self.shape = tuple(shape)
self.mean = mean
self.std = std
self.dtype = dtype


class RandIntTensor:
"""Uniform integer random input in [low, high)."""

def __init__(self, shape, low, high, dtype="int32"):
self.shape = tuple(shape)
self.low = low
self.high = high
self.dtype = dtype


class FullTensor:
"""Constant-filled input (covers zeros / ones / full)."""

def __init__(self, shape, value=0.0, dtype="float32"):
self.shape = tuple(shape)
self.value = value
self.dtype = dtype


class OutTensor:
"""Output buffer (by shape): materialized empty where outputs are written in
place (torch), omitted where they're returned functionally (jax)."""

def __init__(self, shape, dtype="float32"):
self.shape = tuple(shape)
self.dtype = dtype


class ChallengeBase(ABC):
name: str
atol: float
Expand Down
14 changes: 9 additions & 5 deletions challenges/easy/2_matrix_multiplication/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase
from core.challenge_base import ChallengeBase, OutTensor, RandTensor


class Challenge(ChallengeBase):
Expand All @@ -23,6 +23,11 @@ def reference_impl(

torch.matmul(A, B, out=C)

def reference_impl_jax(self, A, B, M, N, K):
import jax.numpy as jnp

return jnp.matmul(A, B)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"A": (ctypes.POINTER(ctypes.c_float), "in"),
Expand Down Expand Up @@ -126,12 +131,11 @@ def generate_functional_test(self) -> List[Dict[str, Any]]:
return test_cases

def generate_performance_test(self) -> Dict[str, Any]:
dtype = torch.float32
M, N, K = 8192, 6144, 4096
return {
"A": torch.empty(M, N, device=self.device, dtype=dtype).uniform_(-10.0, 10.0),
"B": torch.empty(N, K, device=self.device, dtype=dtype).uniform_(-10.0, 10.0),
"C": torch.empty(M, K, device=self.device, dtype=dtype),
"A": RandTensor((M, N), -10.0, 10.0),
"B": RandTensor((N, K), -10.0, 10.0),
"C": OutTensor((M, K)),
"M": M,
"N": N,
"K": K,
Expand Down
25 changes: 20 additions & 5 deletions challenges/easy/9_1d_convolution/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase
from core.challenge_base import ChallengeBase, OutTensor, RandTensor


class Challenge(ChallengeBase):
Expand Down Expand Up @@ -33,6 +33,22 @@ def reference_impl(
# 'ij,j->i' means: for each window i, multiply with kernel j and sum over j
output.copy_(torch.einsum("ij,j->i", windows, kernel))

def reference_impl_jax(self, input, kernel, input_size, kernel_size):
import jax

# Cross-correlation, valid padding, stride 1.
# Shapes: input (N=1, C=1, W), kernel (O=1, I=1, W).
lhs = input.reshape(1, 1, input_size)
rhs = kernel.reshape(1, 1, kernel_size)
result = jax.lax.conv_general_dilated(
lhs,
rhs,
window_strides=(1,),
padding="VALID",
precision=jax.lax.Precision.HIGHEST,
)
return result.reshape(-1)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"input": (ctypes.POINTER(ctypes.c_float), "in"),
Expand Down Expand Up @@ -128,13 +144,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]:
return test_cases

def generate_performance_test(self) -> Dict[str, Any]:
dtype = torch.float32
input_size, kernel_size = 1500000, 2047 # Large convolution for performance testing
output_size = input_size - kernel_size + 1
return {
"input": torch.empty(input_size, device=self.device, dtype=dtype).uniform_(-1.0, 1.0),
"kernel": torch.empty(kernel_size, device=self.device, dtype=dtype).uniform_(-1.0, 1.0),
"output": torch.empty(output_size, device=self.device, dtype=dtype),
"input": RandTensor((input_size,), -1.0, 1.0),
"kernel": RandTensor((kernel_size,), -1.0, 1.0),
"output": OutTensor((output_size,)),
"input_size": input_size,
"kernel_size": kernel_size,
}
31 changes: 21 additions & 10 deletions challenges/hard/12_multi_head_attention/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase
from core.challenge_base import ChallengeBase, OutTensor, RandTensor


class Challenge(ChallengeBase):
Expand Down Expand Up @@ -40,6 +40,22 @@ def reference_impl(
result[:, head * d_k : (head + 1) * d_k] = head_output
output.copy_(result)

def reference_impl_jax(self, Q, K, V, N, d_model, h):
import jax
import jax.numpy as jnp

d_k = d_model // h
# Reshape (N, d_model) -> (N, h, d_k) -> (h, N, d_k)
Q_h = jnp.transpose(jnp.reshape(Q, (N, h, d_k)), (1, 0, 2))
K_h = jnp.transpose(jnp.reshape(K, (N, h, d_k)), (1, 0, 2))
V_h = jnp.transpose(jnp.reshape(V, (N, h, d_k)), (1, 0, 2))
scores = jnp.matmul(Q_h, jnp.transpose(K_h, (0, 2, 1))) / (d_k**0.5)
softmax = jax.nn.softmax(scores, axis=-1)
head_output = jnp.matmul(softmax, V_h) # (h, N, d_k)
# (h, N, d_k) -> (N, h, d_k) -> (N, d_model)
result = jnp.reshape(jnp.transpose(head_output, (1, 0, 2)), (N, d_model))
return result

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
Expand Down Expand Up @@ -111,16 +127,11 @@ def generate_functional_test(self) -> List[Dict[str, Any]]:
return test_cases

def generate_performance_test(self) -> Dict[str, Any]:
dtype = torch.float32
Q = torch.empty(1024, 1024, device=self.device, dtype=dtype).uniform_(-10.0, 10.0)
K = torch.empty(1024, 1024, device=self.device, dtype=dtype).uniform_(-10.0, 10.0)
V = torch.empty(1024, 1024, device=self.device, dtype=dtype).uniform_(-10.0, 10.0)
output = torch.zeros(1024, 1024, device=self.device, dtype=dtype)
return {
"Q": Q,
"K": K,
"V": V,
"output": output,
"Q": RandTensor((1024, 1024), -10.0, 10.0),
"K": RandTensor((1024, 1024), -10.0, 10.0),
"V": RandTensor((1024, 1024), -10.0, 10.0),
"output": OutTensor((1024, 1024)),
"N": 1024,
"d_model": 1024,
"h": 16,
Expand Down
40 changes: 34 additions & 6 deletions challenges/hard/14_multi_agent_sim/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase
from core.challenge_base import ChallengeBase, OutTensor, RandTensor


class Challenge(ChallengeBase):
Expand Down Expand Up @@ -39,6 +39,37 @@ def reference_impl(self, agents: torch.Tensor, agents_next: torch.Tensor, N: int
agents_next_reshaped[:] = torch.cat([new_positions, new_velocities], dim=1)
agents_next.copy_(agents_next_reshaped.view(-1))

def reference_impl_jax(self, agents, N):
import jax
import jax.numpy as jnp

r = 5.0
r2 = r * r
alpha = 0.05
agents_reshaped = agents.reshape(N, 4)
positions = agents_reshaped[:, :2]
velocities = agents_reshaped[:, 2:]
diff = positions[:, None, :] - positions[None, :, :]
dist_sq = (diff**2).sum(axis=2)
dist_sq = dist_sq + jnp.eye(N) * (r2 + 1)
neighbor_mask = dist_sq < r2
sum_velocities = jnp.matmul(
neighbor_mask.astype(velocities.dtype),
velocities,
precision=jax.lax.Precision.HIGHEST,
)
neighbor_counts = neighbor_mask.sum(axis=1, keepdims=True)
nonzero_mask = neighbor_counts[:, 0] > 0
avg_velocities = jnp.where(
nonzero_mask[:, None],
sum_velocities / jnp.where(neighbor_counts == 0, 1, neighbor_counts),
velocities,
)
new_velocities = velocities + alpha * (avg_velocities - velocities)
new_positions = positions + new_velocities
agents_next = jnp.concatenate([new_positions, new_velocities], axis=1)
return agents_next.reshape(-1)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"agents": (ctypes.POINTER(ctypes.c_float), "in"),
Expand Down Expand Up @@ -99,11 +130,8 @@ def generate_functional_test(self) -> List[Dict[str, Any]]:
return test_cases

def generate_performance_test(self) -> Dict[str, Any]:
dtype = torch.float32
agents = torch.empty(40000, device=self.device, dtype=dtype).uniform_(-1000.0, 1000.0)
agents_next = torch.empty(40000, device=self.device, dtype=dtype)
return {
"agents": agents,
"agents_next": agents_next,
"agents": RandTensor((40000,), -1000.0, 1000.0),
"agents_next": OutTensor((40000,)),
"N": 10000,
}
25 changes: 25 additions & 0 deletions challenges/hard/20_kmeans_clustering/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,31 @@ def reference_impl(
final_centroid_x[i] = data_x[mask].mean()
final_centroid_y[i] = data_y[mask].mean()

def reference_impl_jax(
self, data_x, data_y, initial_centroid_x, initial_centroid_y, sample_size, k, max_iterations
):
import jax
import jax.numpy as jnp

final_centroid_x = initial_centroid_x
final_centroid_y = initial_centroid_y
labels = jnp.zeros((sample_size,), dtype=jnp.int32)
for _ in range(max_iterations):
expanded_x = data_x.reshape(-1, 1) - final_centroid_x.reshape(1, -1)
expanded_y = data_y.reshape(-1, 1) - final_centroid_y.reshape(1, -1)
distances = expanded_x**2 + expanded_y**2
labels = jnp.argmin(distances, axis=1).astype(jnp.int32)
onehot = jax.nn.one_hot(labels, k, dtype=data_x.dtype)
counts = onehot.sum(axis=0)
sum_x = jnp.matmul(data_x, onehot, precision=jax.lax.Precision.HIGHEST)
sum_y = jnp.matmul(data_y, onehot, precision=jax.lax.Precision.HIGHEST)
safe_counts = jnp.where(counts == 0, 1, counts)
mean_x = sum_x / safe_counts
mean_y = sum_y / safe_counts
final_centroid_x = jnp.where(counts > 0, mean_x, final_centroid_x)
final_centroid_y = jnp.where(counts > 0, mean_y, final_centroid_y)
return labels, final_centroid_x, final_centroid_y

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"data_x": (ctypes.POINTER(ctypes.c_float), "in"),
Expand Down
28 changes: 21 additions & 7 deletions challenges/hard/53_casual_attention/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase
from core.challenge_base import ChallengeBase, OutTensor, RandTensor


class Challenge(ChallengeBase):
Expand Down Expand Up @@ -30,6 +30,18 @@ def reference_impl(
attn = torch.softmax(attn, dim=1)
torch.matmul(attn, V, out=output)

def reference_impl_jax(self, Q, K, V, M, d):
import jax
import jax.numpy as jnp

scale = d**0.5
attn = jnp.matmul(Q, K.T) / scale

mask = jnp.triu(jnp.ones((M, M), dtype=bool), k=1)
attn = jnp.where(mask, -jnp.inf, attn)
attn = jax.nn.softmax(attn, axis=1)
return jnp.matmul(attn, V)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
Expand Down Expand Up @@ -127,10 +139,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]:
return tests

def generate_performance_test(self) -> Dict[str, Any]:
dtype = torch.float32
M, d = 5000, 128
Q = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100)
K = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100)
V = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100)
output = torch.empty(M, d, device=self.device, dtype=dtype)
return {"Q": Q, "K": K, "V": V, "output": output, "M": M, "d": d}
return {
"Q": RandTensor((M, d), -100.0, 100.0),
"K": RandTensor((M, d), -100.0, 100.0),
"V": RandTensor((M, d), -100.0, 100.0),
"output": OutTensor((M, d)),
"M": M,
"d": d,
}
33 changes: 26 additions & 7 deletions challenges/hard/56_linear_attention/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase
from core.challenge_base import ChallengeBase, OutTensor, RandTensor


class Challenge(ChallengeBase):
Expand Down Expand Up @@ -38,6 +38,23 @@ def reference_impl(

output.copy_(numerator / denominator.unsqueeze(-1)) # (M, d)

def reference_impl_jax(self, Q, K, V, M, d):
import jax.numpy as jnp

# φ(x) = ELU(x) + 1
phi_Q = jnp.where(Q > 0, Q + 1, jnp.exp(Q))
phi_K = jnp.where(K > 0, K + 1, jnp.exp(K))

# S = φ(K)^T V → (d, d)
S = phi_K.T @ V
# z = sum_j φ(K_j) → (d,)
z = phi_K.sum(axis=0)

numerator = phi_Q @ S # (M, d)
denominator = phi_Q @ z # (M,)

return numerator / denominator[:, None] # (M, d)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"Q": (ctypes.POINTER(ctypes.c_float), "in"),
Expand Down Expand Up @@ -147,10 +164,12 @@ def generate_functional_test(self) -> List[Dict[str, Any]]:
return tests

def generate_performance_test(self) -> Dict[str, Any]:
dtype = torch.float32
M, d = 10000, 128
Q = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100)
K = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100)
V = torch.empty((M, d), device=self.device, dtype=dtype).uniform_(-100, 100)
output = torch.empty(M, d, device=self.device, dtype=dtype)
return {"Q": Q, "K": K, "V": V, "output": output, "M": M, "d": d}
return {
"Q": RandTensor((M, d), -100, 100),
"K": RandTensor((M, d), -100, 100),
"V": RandTensor((M, d), -100, 100),
"output": OutTensor((M, d)),
"M": M,
"d": d,
}
Loading
Loading