diff --git a/challenges/medium/104_min_p_sampling/challenge.html b/challenges/medium/104_min_p_sampling/challenge.html new file mode 100644 index 00000000..5513d3a1 --- /dev/null +++ b/challenges/medium/104_min_p_sampling/challenge.html @@ -0,0 +1,58 @@ +

+Implement batched min-p sampling, a logit-filtering primitive used in modern LLM serving +stacks such as vLLM, Hugging Face TGI, and llama.cpp. Given a batch of logits +<code>logits</code> of shape \(B \times V\) and a scalar <code>min_p</code> in \([0, 1]\), produce +a batch of filtered probabilities <code>probs</code> of the same shape. For each row \(b\): +

+
+  1. Convert logits to probabilities with softmax: \(p_{b,i} = \exp(z_{b,i}) / \sum_j \exp(z_{b,j})\).
+  2. Find the maximum probability in the row: \(p^{\max}_b = \max_i p_{b,i}\).
+  3. Mask every token whose probability is below \(\text{min\_p} \cdot p^{\max}_b\) (set it to 0); tokens exactly at the threshold are kept.
+  4. Renormalize the surviving probabilities so they sum to 1.
+

+Unlike top-k or top-p sampling, min-p does not require any sort: the cutoff is defined relative +to the row's most likely token, which makes it well-suited to GPU parallelization while still +adapting to how peaked or flat each distribution is. +

+ +

Implementation Requirements

+ + +

Example

+

+ With B = 2, V = 4, min_p = 0.1: +

+

+ Input logits (2×4): + \[ + \begin{bmatrix} 1.0 & 2.0 & 3.0 & 4.0 \\ -1.0 & 0.0 & 1.0 & -2.0 \end{bmatrix} + \] +

+

+ For row 0, \(\max z = 4\), so the shifted exponentials \(e_i = \exp(z_i - \max z)\) are \(e = [e^{-3}, e^{-2}, e^{-1}, e^{0}] \approx [0.0498, 0.1353, 0.3679, 1.0000]\). + Because every \(e_i\) equals \(p_i / p^{\max}\), a token is masked exactly when \(e_i < \text{min\_p}\). The first entry has \(e^{-3} \approx 0.0498 < 0.1\) (equivalently \(e^{-3} / \sum e < 0.1 \cdot e^{0} / \sum e\)), + so it is masked. Renormalizing the rest gives row 0 of <code>probs</code>. + For row 1, \(\max z = 1\) and the last entry (again \(e^{-3}\)) is masked by the same reasoning. +

+

+ Output probs (2×4): + \[ + \begin{bmatrix} 0.0000 & 0.0900 & 0.2447 & 0.6652 \\ 0.0900 & 0.2447 & 0.6652 & 0.0000 \end{bmatrix} + \] +

+ +

Constraints

+
diff --git a/challenges/medium/104_min_p_sampling/challenge.py b/challenges/medium/104_min_p_sampling/challenge.py
new file mode 100644
index 00000000..067a0222
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/challenge.py
@@ -0,0 +1,246 @@
+import ctypes
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+    """Min-p sampling challenge: reference implementation and test-case generators."""
+
+    name = "Min-P Sampling"
+    atol = 1e-05
+    rtol = 1e-05
+    num_gpus = 1
+    access_tier = "free"
+
+    def reference_impl(
+        self,
+        logits: torch.Tensor,
+        probs: torch.Tensor,
+        min_p: float,
+        B: int,
+        V: int,
+    ):
+        """Write the min-p filtered, renormalized softmax of ``logits`` into ``probs``.
+
+        Args:
+            logits: float32 tensor of shape ``(B, V)``.
+            probs: float32 tensor of shape ``(B, V)``; overwritten in place.
+            min_p: cutoff as a fraction of each row's largest probability; must be in [0, 1].
+            B: number of rows in ``logits`` and ``probs``.
+            V: number of columns in ``logits`` and ``probs``.
+
+        Raises:
+            ValueError: if ``min_p`` is outside ``[0, 1]`` (including NaN).
+        """
+        assert logits.shape == (B, V)
+        assert probs.shape == (B, V)
+        assert logits.dtype == torch.float32
+        assert probs.dtype == torch.float32
+        # min_p > 1 would mask every token, including the row maximum, and the
+        # renormalization below would then divide zero by zero and produce NaN.
+        if not 0.0 <= min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {min_p}")
+
+        # Shift by the row maximum: exp_shifted is then p_i / p_max, so it can be
+        # compared against min_p directly without computing the softmax sum.
+        max_logit, _ = torch.max(logits, dim=-1, keepdim=True)
+        exp_shifted = torch.exp(logits - max_logit)
+        keep = exp_shifted >= min_p
+        masked = torch.where(keep, exp_shifted, torch.zeros_like(exp_shifted))
+        # The softmax denominator cancels, so normalizing the kept exponentials
+        # yields the renormalized probabilities.
+        denom = torch.sum(masked, dim=-1, keepdim=True)
+        probs.copy_(masked / denom)
+
+    def get_solve_signature(self) -> Dict[str, tuple]:
+        """Map each ``solve`` argument name to its ctypes type and "in"/"out" tag."""
+        return {
+            "logits": (ctypes.POINTER(ctypes.c_float), "in"),
+            "probs": (ctypes.POINTER(ctypes.c_float), "out"),
+            "min_p": (ctypes.c_float, "in"),
+            "B": (ctypes.c_int, "in"),
+            "V": (ctypes.c_int, "in"),
+        }
+
+    def generate_example_test(self) -> Dict[str, Any]:
+        """Return the 2 x 4 example case (min_p = 0.1) with an uninitialized ``probs``."""
+        dtype = torch.float32
+        B, V = 2, 4
+        logits = torch.tensor(
+            [[1.0, 2.0, 3.0, 4.0], [-1.0, 0.0, 1.0, -2.0]],
+            device=self.device,
+            dtype=dtype,
+        )
+        probs = torch.empty(B, V, device=self.device, dtype=dtype)
+        return {
+            "logits": logits,
+            "probs": probs,
+            "min_p": 0.1,
+            "B": B,
+            "V": V,
+        }
+
+    def generate_functional_test(self) -> List[Dict[str, Any]]:
+        """Return the functional cases: hand-written edge cases and seeded random batches."""
+        dtype = torch.float32
+        tests = []
+
+        # single row, tiny vocab
+        B, V = 1, 3
+        tests.append(
+            {
+                "logits": torch.tensor([[1.0, 2.0, 3.0]], device=self.device, dtype=dtype),
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.05,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # tied maxima (multiple winners survive)
+        B, V = 1, 4
+        tests.append(
+            {
+                "logits": torch.tensor([[3.0, 3.0, 1.0, -2.0]], device=self.device, dtype=dtype),
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.5,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # min_p = 0 (no filtering => plain softmax)
+        B, V = 2, 5
+        tests.append(
+            {
+                "logits": torch.tensor(
+                    [[-1.0, 0.0, 1.0, 2.0, 3.0], [0.5, 0.5, 0.5, 0.5, 0.5]],
+                    device=self.device,
+                    dtype=dtype,
+                ),
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.0,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # min_p near 1 (only the maximum survives)
+        B, V = 3, 6
+        tests.append(
+            {
+                "logits": torch.tensor(
+                    [
+                        [0.0, 1.0, 5.0, 2.0, 3.0, -1.0],
+                        [10.0, -10.0, 0.0, 0.0, 0.0, 0.0],
+                        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+                    ],
+                    device=self.device,
+                    dtype=dtype,
+                ),
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.99,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # all-zero logits (uniform distribution survives)
+        B, V = 2, 8
+        tests.append(
+            {
+                "logits": torch.zeros(B, V, device=self.device, dtype=dtype),
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.1,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # power-of-two vocab with mixed positives/negatives
+        B, V = 4, 128
+        torch.manual_seed(0)
+        tests.append(
+            {
+                "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 2.0,
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.1,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # non-power-of-two vocab
+        B, V = 8, 255
+        torch.manual_seed(1)
+        tests.append(
+            {
+                "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 3.0,
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.05,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # peaked distribution (a few dominating tokens)
+        B, V = 4, 1024
+        logits = torch.full((B, V), -5.0, device=self.device, dtype=dtype)
+        logits[0, 17] = 10.0
+        logits[0, 99] = 9.5
+        logits[1, 3] = 8.0
+        logits[2, 500] = 7.0
+        logits[3, 1000] = 12.0
+        logits[3, 0] = 11.0
+        tests.append(
+            {
+                "logits": logits,
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.05,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # realistic batch x vocab
+        B, V = 16, 32000
+        torch.manual_seed(3)
+        tests.append(
+            {
+                "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 2.5,
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.1,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        # larger batch, smaller vocab
+        B, V = 64, 4096
+        torch.manual_seed(4)
+        tests.append(
+            {
+                "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 1.5,
+                "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+                "min_p": 0.02,
+                "B": B,
+                "V": V,
+            }
+        )
+
+        return tests
+
+    def generate_performance_test(self) -> Dict[str, Any]:
+        """Return one seeded random case with B = 64 and V = 128000 (min_p = 0.05)."""
+        dtype = torch.float32
+        B, V = 64, 128000
+        torch.manual_seed(42)
+        return {
+            "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 2.0,
+            "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+            "min_p": 0.05,
+            "B": B,
+            "V": V,
+        }
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.cu b/challenges/medium/104_min_p_sampling/starter/starter.cu
new file mode 100644
index 00000000..f3718ea6
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.cu
@@ -0,0 +1,4 @@
+#include <cuda_runtime.h>
+
+// logits, probs are device pointers
+extern "C" void solve(const float* logits, float* probs, float min_p, int B, int V) {}
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.cute.py b/challenges/medium/104_min_p_sampling/starter/starter.cute.py
new file mode 100644
index 00000000..7fe9b95f
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.cute.py
@@ -0,0 +1,14 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# logits, probs are tensors on the GPU
+@cute.jit
+def solve(
+    logits: cute.Tensor,
+    probs: cute.Tensor,
+    min_p: cute.Float32,
+    B: cute.Int32,
+    V: cute.Int32,
+):
+    pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.jax.py b/challenges/medium/104_min_p_sampling/starter/starter.jax.py
new file mode 100644
index 00000000..7f4e9b70
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.jax.py
@@ -0,0 +1,9 @@
+import jax
+import jax.numpy as jnp
+
+
+# logits is a tensor on the GPU
+@jax.jit
+def solve(logits: jax.Array, min_p: float, B: int, V: int) -> jax.Array:
+    # return output tensor directly
+    pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.mojo b/challenges/medium/104_min_p_sampling/starter/starter.mojo
new file mode 100644
index 00000000..0fab8d25
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.mojo
@@ -0,0 +1,14 @@
+from std.gpu.host import DeviceContext
+from std.memory import UnsafePointer
+
+
+# logits, probs are device pointers
+@export
+def solve(
+    logits: UnsafePointer[Float32, MutExternalOrigin],
+    probs: UnsafePointer[Float32, MutExternalOrigin],
+    min_p: Float32,
+    B: Int32,
+    V: Int32,
+) raises:
+    pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.pytorch.py b/challenges/medium/104_min_p_sampling/starter/starter.pytorch.py
new file mode 100644
index 00000000..60544083
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.pytorch.py
@@ -0,0 +1,6 @@
+import torch
+
+
+# logits, probs are tensors on the GPU
+def solve(logits: torch.Tensor, probs: torch.Tensor, min_p: float, B: int, V: int):
+    pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.triton.py b/challenges/medium/104_min_p_sampling/starter/starter.triton.py
new file mode 100644
index 00000000..4e5c04ea
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.triton.py
@@ -0,0 +1,8 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# logits, probs are tensors on the GPU
+def solve(logits: torch.Tensor, probs: torch.Tensor, min_p: float, B: int, V: int):
+    pass