diff --git a/challenges/medium/104_min_p_sampling/challenge.html b/challenges/medium/104_min_p_sampling/challenge.html
new file mode 100644
index 00000000..5513d3a1
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/challenge.html
@@ -0,0 +1,58 @@
+
+Implement batched min-p sampling, a logit-filtering primitive used in modern LLM serving
+stacks such as vLLM, Hugging Face TGI, and llama.cpp. Given a float32 matrix
+<code>logits</code> of shape \(B \times V\) and a scalar <code>min_p</code> in \([0, 1]\), produce
+a matrix of filtered probabilities <code>probs</code> of the same shape. For each row \(b\):
+
+
+ - Convert logits to probabilities with softmax: \(p_{b,i} = \exp(z_{b,i}) / \sum_j \exp(z_{b,j})\).
+ - Find the maximum probability in the row: \(p^{\max}_b = \max_i p_{b,i}\).
+ - Mask every token whose probability is strictly below \(\text{min\_p} \cdot p^{\max}_b\) (set it to 0); a token exactly at the threshold is kept.
+ - Renormalize the surviving probabilities so they sum to 1.
+
+
+Unlike top-k or top-p sampling, min-p does not require any sort: the cutoff is defined relative
+to the row's most likely token, which makes it well-suited to GPU parallelization while still
+adapting to how peaked or flat each distribution is.
+
+
+Implementation Requirements
+
+ - Implement
solve(logits, probs, min_p, B, V); do not change the signature or use external libraries beyond the standard GPU frameworks.
+ - Write the result into the provided
probs buffer. Every row must sum to 1 up to float32 rounding error; masked positions must be exactly 0.
+ - Apply min-p independently per row of the batch.
+ - Use the numerically stable softmax (subtract the per-row max before exponentiating).
+
+
+Example
+
+ With B = 2, V = 4, min_p = 0.1:
+
+
+ Input logits (2×4):
+ \[
+ \begin{bmatrix} 1.0 & 2.0 & 3.0 & 4.0 \\ -1.0 & 0.0 & 1.0 & -2.0 \end{bmatrix}
+ \]
+
+
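+ For row 0, \(\max z = 4\) so \(e = [e^{-3}, e^{-2}, e^{-1}, e^{0}] \approx [0.0498, 0.1353, 0.3679, 1.0000]\).
+ Every probability in a row shares the same denominator \(\sum e\), so the test \(p_i < \text{min\_p} \cdot p^{\max}\) is equivalent to \(e_i < \text{min\_p} \cdot e^{0} = 0.1\).
+ Only the first entry (\(0.0498\)) falls below that threshold, so it is masked. Dividing the remaining entries by their sum \(1.5032\) gives row 0 of probs.
+ For row 1, \(\max z = 1\) so \(e \approx [0.1353, 0.3679, 1.0000, 0.0498]\), and the last entry is masked by the same reasoning.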
+
+
+ Output probs (2×4):
+ \[
+ \begin{bmatrix} 0.0000 & 0.0900 & 0.2447 & 0.6652 \\ 0.0900 & 0.2447 & 0.6652 & 0.0000 \end{bmatrix}
+ \]
+
+
+Constraints
+
+ - 1 ≤
B ≤ 256
+ - 1 ≤
V ≤ 200,000
+ - 0.0 ≤
min_p ≤ 1.0
+ - -50.0 ≤
logits[b, i] ≤ 50.0
+ - Inputs and outputs are
float32
+ - Performance is measured with
B = 64, V = 128,000, min_p = 0.05
+
diff --git a/challenges/medium/104_min_p_sampling/challenge.py b/challenges/medium/104_min_p_sampling/challenge.py
new file mode 100644
index 00000000..067a0222
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/challenge.py
@@ -0,0 +1,221 @@
+import ctypes
+from typing import Any, Dict, List
+
+import torch
+from core.challenge_base import ChallengeBase
+
+
+class Challenge(ChallengeBase):
+ name = "Min-P Sampling"
+ atol = 1e-05
+ rtol = 1e-05
+ num_gpus = 1
+ access_tier = "free"
+
+ def reference_impl(
+ self,
+ logits: torch.Tensor,
+ probs: torch.Tensor,
+ min_p: float,
+ B: int,
+ V: int,
+ ):
+ assert logits.shape == (B, V)
+ assert probs.shape == (B, V)
+ assert logits.dtype == torch.float32
+ assert probs.dtype == torch.float32
+
+ max_logit, _ = torch.max(logits, dim=-1, keepdim=True)
+ exp_shifted = torch.exp(logits - max_logit)
+ keep = exp_shifted >= min_p
+ masked = torch.where(keep, exp_shifted, torch.zeros_like(exp_shifted))
+ denom = torch.sum(masked, dim=-1, keepdim=True)
+ probs.copy_(masked / denom)
+
+ def get_solve_signature(self) -> Dict[str, tuple]:
+ return {
+ "logits": (ctypes.POINTER(ctypes.c_float), "in"),
+ "probs": (ctypes.POINTER(ctypes.c_float), "out"),
+ "min_p": (ctypes.c_float, "in"),
+ "B": (ctypes.c_int, "in"),
+ "V": (ctypes.c_int, "in"),
+ }
+
+ def generate_example_test(self) -> Dict[str, Any]:
+ dtype = torch.float32
+ B, V = 2, 4
+ logits = torch.tensor(
+ [[1.0, 2.0, 3.0, 4.0], [-1.0, 0.0, 1.0, -2.0]],
+ device=self.device,
+ dtype=dtype,
+ )
+ probs = torch.empty(B, V, device=self.device, dtype=dtype)
+ return {
+ "logits": logits,
+ "probs": probs,
+ "min_p": 0.1,
+ "B": B,
+ "V": V,
+ }
+
+ def generate_functional_test(self) -> List[Dict[str, Any]]:
+ dtype = torch.float32
+ tests = []
+
+ # single row, tiny vocab
+ B, V = 1, 3
+ tests.append(
+ {
+ "logits": torch.tensor([[1.0, 2.0, 3.0]], device=self.device, dtype=dtype),
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.05,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # tied maxima (multiple winners survive)
+ B, V = 1, 4
+ tests.append(
+ {
+ "logits": torch.tensor([[3.0, 3.0, 1.0, -2.0]], device=self.device, dtype=dtype),
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.5,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # min_p = 0 (no filtering => plain softmax)
+ B, V = 2, 5
+ tests.append(
+ {
+ "logits": torch.tensor(
+ [[-1.0, 0.0, 1.0, 2.0, 3.0], [0.5, 0.5, 0.5, 0.5, 0.5]],
+ device=self.device,
+ dtype=dtype,
+ ),
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.0,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # min_p near 1 (only the maximum survives)
+ B, V = 3, 6
+ tests.append(
+ {
+ "logits": torch.tensor(
+ [
+ [0.0, 1.0, 5.0, 2.0, 3.0, -1.0],
+ [10.0, -10.0, 0.0, 0.0, 0.0, 0.0],
+ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+ ],
+ device=self.device,
+ dtype=dtype,
+ ),
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.99,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # all-zero logits (uniform distribution survives)
+ B, V = 2, 8
+ tests.append(
+ {
+ "logits": torch.zeros(B, V, device=self.device, dtype=dtype),
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.1,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # power-of-two vocab with mixed positives/negatives
+ B, V = 4, 128
+ torch.manual_seed(0)
+ tests.append(
+ {
+ "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 2.0,
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.1,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # non-power-of-two vocab
+ B, V = 8, 255
+ torch.manual_seed(1)
+ tests.append(
+ {
+ "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 3.0,
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.05,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # peaked distribution (a few dominating tokens)
+ B, V = 4, 1024
+ torch.manual_seed(2)
+ logits = torch.full((B, V), -5.0, device=self.device, dtype=dtype)
+ logits[0, 17] = 10.0
+ logits[0, 99] = 9.5
+ logits[1, 3] = 8.0
+ logits[2, 500] = 7.0
+ logits[3, 1000] = 12.0
+ logits[3, 0] = 11.0
+ tests.append(
+ {
+ "logits": logits,
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.05,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # realistic batch x vocab
+ B, V = 16, 32000
+ torch.manual_seed(3)
+ tests.append(
+ {
+ "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 2.5,
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.1,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ # larger batch, smaller vocab
+ B, V = 64, 4096
+ torch.manual_seed(4)
+ tests.append(
+ {
+ "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 1.5,
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.02,
+ "B": B,
+ "V": V,
+ }
+ )
+
+ return tests
+
+ def generate_performance_test(self) -> Dict[str, Any]:
+ dtype = torch.float32
+ B, V = 64, 128000
+ torch.manual_seed(42)
+ return {
+ "logits": torch.randn(B, V, device=self.device, dtype=dtype) * 2.0,
+ "probs": torch.empty(B, V, device=self.device, dtype=dtype),
+ "min_p": 0.05,
+ "B": B,
+ "V": V,
+ }
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.cu b/challenges/medium/104_min_p_sampling/starter/starter.cu
new file mode 100644
index 00000000..f3718ea6
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.cu
@@ -0,0 +1,4 @@
+#include <cuda_runtime.h>
+
+// logits, probs are device pointers to B x V float32 values
+// logits: input logits; probs: output buffer that solve must fill
+// min_p: cutoff relative to each row's largest probability
+extern "C" void solve(const float* logits, float* probs, float min_p, int B, int V) {}
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.cute.py b/challenges/medium/104_min_p_sampling/starter/starter.cute.py
new file mode 100644
index 00000000..7fe9b95f
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.cute.py
@@ -0,0 +1,14 @@
+import cutlass
+import cutlass.cute as cute
+
+
+# logits, probs are tensors on the GPU
+# logits: float32 input of shape (B, V); probs: float32 output of shape (B, V) that solve must fill
+# min_p: cutoff relative to each row's largest probability
+@cute.jit
+def solve(
+ logits: cute.Tensor,
+ probs: cute.Tensor,
+ min_p: cute.Float32,
+ B: cute.Int32,
+ V: cute.Int32,
+):
+ pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.jax.py b/challenges/medium/104_min_p_sampling/starter/starter.jax.py
new file mode 100644
index 00000000..7f4e9b70
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.jax.py
@@ -0,0 +1,9 @@
+import jax
+import jax.numpy as jnp
+
+
+# logits is a float32 array of shape (B, V) on the GPU
+# min_p: cutoff relative to each row's largest probability
+@jax.jit
+def solve(logits: jax.Array, min_p: float, B: int, V: int) -> jax.Array:
+ # return output tensor directly (the filtered probabilities, same shape as logits)
+ pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.mojo b/challenges/medium/104_min_p_sampling/starter/starter.mojo
new file mode 100644
index 00000000..0fab8d25
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.mojo
@@ -0,0 +1,14 @@
+from std.gpu.host import DeviceContext
+from std.memory import UnsafePointer
+
+
+# logits, probs are device pointers to B x V float32 values
+# logits: input logits; probs: output buffer that solve must fill
+# min_p: cutoff relative to each row's largest probability
+@export
+def solve(
+ logits: UnsafePointer[Float32, MutExternalOrigin],
+ probs: UnsafePointer[Float32, MutExternalOrigin],
+ min_p: Float32,
+ B: Int32,
+ V: Int32,
+) raises:
+ pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.pytorch.py b/challenges/medium/104_min_p_sampling/starter/starter.pytorch.py
new file mode 100644
index 00000000..60544083
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.pytorch.py
@@ -0,0 +1,6 @@
+import torch
+
+
+# logits, probs are tensors on the GPU
+# logits: float32 input of shape (B, V); probs: float32 output of shape (B, V) that solve must fill
+# min_p: cutoff relative to each row's largest probability
+def solve(logits: torch.Tensor, probs: torch.Tensor, min_p: float, B: int, V: int):
+ pass
diff --git a/challenges/medium/104_min_p_sampling/starter/starter.triton.py b/challenges/medium/104_min_p_sampling/starter/starter.triton.py
new file mode 100644
index 00000000..4e5c04ea
--- /dev/null
+++ b/challenges/medium/104_min_p_sampling/starter/starter.triton.py
@@ -0,0 +1,8 @@
+import torch
+import triton
+import triton.language as tl
+
+
+# logits, probs are tensors on the GPU
+# logits: float32 input of shape (B, V); probs: float32 output of shape (B, V) that solve must fill
+# min_p: cutoff relative to each row's largest probability
+def solve(logits: torch.Tensor, probs: torch.Tensor, min_p: float, B: int, V: int):
+ pass