Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions challenges/medium/104_min_p_sampling/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<p>
Implement batched <em>min-p sampling</em>, a logit-filtering primitive used in modern LLM serving
stacks such as vLLM, Hugging Face TGI, and llama.cpp. Given a batch of logits
<code>logits</code> of shape \(B \times V\) and a scalar <code>min_p</code> in \([0, 1]\), produce
a batch of filtered probabilities <code>probs</code> of the same shape. For each row \(b\):
</p>
<ol>
<li>Convert logits to probabilities with softmax: \(p_{b,i} = \exp(z_{b,i}) / \sum_j \exp(z_{b,j})\).</li>
<li>Find the maximum probability in the row: \(p^{\max}_b = \max_i p_{b,i}\).</li>
<li>Mask every token whose probability is strictly below \(\text{min\_p} \cdot p^{\max}_b\) (set it to 0); a token exactly at the threshold is kept.</li>
<li>Renormalize the surviving probabilities so they sum to 1.</li>
</ol>
<p>
Unlike top-k or top-p sampling, min-p does not require any sort: the cutoff is defined relative
to the row's most likely token, which makes it well-suited to GPU parallelization while still
adapting to how peaked or flat each distribution is.
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement <code>solve(logits, probs, min_p, B, V)</code>; do not change the signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write the result into the provided <code>probs</code> buffer. Every row must sum to 1 (up to <code>float32</code> rounding); masked positions must be exactly 0.</li>
<li>Apply min-p independently per row of the batch.</li>
<li>Use the numerically stable softmax (subtract the per-row max before exponentiating).</li>
</ul>

<h2>Example</h2>
<p>
With <code>B</code> = 2, <code>V</code> = 4, <code>min_p</code> = 0.1:
</p>
<p>
<strong>Input</strong> <code>logits</code> (2&times;4):
\[
\begin{bmatrix} 1.0 & 2.0 & 3.0 & 4.0 \\ -1.0 & 0.0 & 1.0 & -2.0 \end{bmatrix}
\]
</p>
<p>
For row 0, \(\max z = 4\) so \(e = [e^{-3}, e^{-2}, e^{-1}, e^{0}] \approx [0.0498, 0.1353, 0.3679, 1.0000]\).
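The first entry falls below the threshold \(\text{min\_p} \cdot p^{\max}_0\): since every probability shares the denominator \(\sum e\), this is the same as comparing \(e^{-3} \approx 0.0498\) against \(0.1 \cdot e^{0} = 0.1\) (that is, \(e^{-3} / \sum e &lt; 0.1 \cdot e^{0} / \sum e\)),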
so it is masked. Renormalizing the rest gives row 0 of <code>probs</code>.
For row 1, \(\max z = 1\) and the last entry is masked by the same reasoning.
</p>
<p>
<strong>Output</strong> <code>probs</code> (2&times;4):
\[
\begin{bmatrix} 0.0000 & 0.0900 & 0.2447 & 0.6652 \\ 0.0900 & 0.2447 & 0.6652 & 0.0000 \end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>B</code> &le; 256</li>
<li>1 &le; <code>V</code> &le; 200,000</li>
<li>0.0 &le; <code>min_p</code> &le; 1.0</li>
<li>-50.0 &le; <code>logits[b, i]</code> &le; 50.0</li>
<li>Inputs and outputs are <code>float32</code></li>
<li>Performance is measured with <code>B</code> = 64, <code>V</code> = 128,000, <code>min_p</code> = 0.05</li>
</ul>
221 changes: 221 additions & 0 deletions challenges/medium/104_min_p_sampling/challenge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    """Min-P Sampling challenge: reference implementation plus test-case generators.

    Every generator returns keyword arguments for
    ``solve(logits, probs, min_p, B, V)``. Tensors are allocated on
    ``self.device``, which this class does not define itself (presumably
    supplied by ``ChallengeBase`` -- confirm there).
    """

    name = "Min-P Sampling"
    # Comparison tolerances; presumably read by the ChallengeBase harness when a
    # submission's ``probs`` is checked against the reference -- confirm there.
    atol = 1e-05
    rtol = 1e-05
    num_gpus = 1
    access_tier = "free"

    def reference_impl(
        self,
        logits: torch.Tensor,
        probs: torch.Tensor,
        min_p: float,
        B: int,
        V: int,
    ):
        """Write min-p filtered, renormalized probabilities into ``probs``.

        For each row, the logits are shifted by the row maximum and
        exponentiated, entries whose shifted exponential is below ``min_p``
        are zeroed, and the survivors are divided by their sum.

        Args:
            logits: float32 tensor of shape ``(B, V)``.
            probs: float32 output tensor of shape ``(B, V)``; overwritten in place.
            min_p: relative cutoff in ``[0, 1]``.
            B: number of rows.
            V: number of columns (vocabulary size).

        Raises:
            AssertionError: if a shape or dtype does not match, or if ``min_p``
                lies outside ``[0, 1]``.
        """
        assert logits.shape == (B, V)
        assert probs.shape == (B, V)
        assert logits.dtype == torch.float32
        assert probs.dtype == torch.float32
        # The row maximum has a shifted exponential of exactly 1, so it only
        # survives the filter when min_p <= 1. Beyond that every entry is
        # masked, the denominator is 0 and the row would silently become NaN.
        assert 0.0 <= min_p <= 1.0

        max_logit, _ = torch.max(logits, dim=-1, keepdim=True)
        exp_shifted = torch.exp(logits - max_logit)
        # p_i >= min_p * p_max  <=>  exp_shifted_i >= min_p, because p_i and
        # p_max share the same softmax denominator and exp_shifted_max == 1.
        # The filter can therefore run before any normalization.
        keep = exp_shifted >= min_p
        masked = torch.where(keep, exp_shifted, torch.zeros_like(exp_shifted))
        denom = torch.sum(masked, dim=-1, keepdim=True)
        probs.copy_(masked / denom)

    def get_solve_signature(self) -> Dict[str, tuple]:
        """Return the ctypes type and an "in"/"out" tag for each ``solve`` argument."""
        return {
            "logits": (ctypes.POINTER(ctypes.c_float), "in"),
            "probs": (ctypes.POINTER(ctypes.c_float), "out"),
            "min_p": (ctypes.c_float, "in"),
            "B": (ctypes.c_int, "in"),
            "V": (ctypes.c_int, "in"),
        }

    def _make_case(self, logits: torch.Tensor, min_p: float) -> Dict[str, Any]:
        """Bundle 2-D ``logits`` with a matching uninitialized ``probs`` buffer.

        ``B`` and ``V`` are read from ``logits.shape`` so they cannot drift
        from the tensor actually passed to ``solve``.
        """
        B, V = logits.shape
        return {
            "logits": logits,
            "probs": torch.empty(B, V, device=self.device, dtype=logits.dtype),
            "min_p": min_p,
            "B": B,
            "V": V,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """Return the 2x4, ``min_p`` = 0.1 example case."""
        logits = torch.tensor(
            [[1.0, 2.0, 3.0, 4.0], [-1.0, 0.0, 1.0, -2.0]],
            device=self.device,
            dtype=torch.float32,
        )
        return self._make_case(logits, 0.1)

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Return the functional cases: hand-written edge cases, then seeded random ones."""
        dtype = torch.float32
        device = self.device

        def literal(rows):
            # Hand-written logits placed on the challenge device.
            return torch.tensor(rows, device=device, dtype=dtype)

        def seeded_randn(seed, B, V, scale):
            # Reseed immediately before the draw so every case is reproducible
            # regardless of which cases were generated before it.
            torch.manual_seed(seed)
            return torch.randn(B, V, device=device, dtype=dtype) * scale

        tests = []

        # single row, tiny vocab
        tests.append(self._make_case(literal([[1.0, 2.0, 3.0]]), 0.05))

        # tied maxima (multiple winners survive)
        tests.append(self._make_case(literal([[3.0, 3.0, 1.0, -2.0]]), 0.5))

        # min_p = 0 (no filtering => plain softmax)
        tests.append(
            self._make_case(
                literal([[-1.0, 0.0, 1.0, 2.0, 3.0], [0.5, 0.5, 0.5, 0.5, 0.5]]),
                0.0,
            )
        )

        # min_p near 1 (only the maximum survives)
        tests.append(
            self._make_case(
                literal(
                    [
                        [0.0, 1.0, 5.0, 2.0, 3.0, -1.0],
                        [10.0, -10.0, 0.0, 0.0, 0.0, 0.0],
                        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                    ]
                ),
                0.99,
            )
        )

        # all-zero logits (uniform distribution survives)
        tests.append(self._make_case(torch.zeros(2, 8, device=device, dtype=dtype), 0.1))

        # power-of-two vocab with mixed positives/negatives
        tests.append(self._make_case(seeded_randn(0, 4, 128, 2.0), 0.1))

        # non-power-of-two vocab
        tests.append(self._make_case(seeded_randn(1, 8, 255, 3.0), 0.05))

        # peaked distribution (a few dominating tokens); fully deterministic,
        # so no seeding is needed here
        peaked = torch.full((4, 1024), -5.0, device=device, dtype=dtype)
        peaked[0, 17] = 10.0
        peaked[0, 99] = 9.5
        peaked[1, 3] = 8.0
        peaked[2, 500] = 7.0
        peaked[3, 1000] = 12.0
        peaked[3, 0] = 11.0
        tests.append(self._make_case(peaked, 0.05))

        # realistic batch x vocab
        tests.append(self._make_case(seeded_randn(3, 16, 32000, 2.5), 0.1))

        # larger batch, smaller vocab
        tests.append(self._make_case(seeded_randn(4, 64, 4096, 1.5), 0.02))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        """Return the benchmark case: B = 64, V = 128000, ``min_p`` = 0.05, seed 42."""
        B, V = 64, 128000
        torch.manual_seed(42)
        logits = torch.randn(B, V, device=self.device, dtype=torch.float32) * 2.0
        return self._make_case(logits, 0.05)
4 changes: 4 additions & 0 deletions challenges/medium/104_min_p_sampling/starter/starter.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include <cuda_runtime.h>

// Starter entry point for the min-p challenge.
//   logits : device pointer to the input values (read-only)
//   probs  : device pointer to the output buffer
//   min_p  : scalar cutoff
//   B, V   : integer size arguments
extern "C" void solve(const float* logits, float* probs, float min_p, int B, int V) {
    // Intentionally empty: the starter performs no work.
}
14 changes: 14 additions & 0 deletions challenges/medium/104_min_p_sampling/starter/starter.cute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import cutlass
import cutlass.cute as cute


# Starter entry point for the min-p challenge.
# `logits` and `probs` are tensors that live on the GPU; `min_p`, `B` and `V`
# arrive as scalar arguments.
@cute.jit
def solve(
    logits: cute.Tensor,
    probs: cute.Tensor,
    min_p: cute.Float32,
    B: cute.Int32,
    V: cute.Int32,
):
    # Intentionally empty: the starter performs no work.
    pass
9 changes: 9 additions & 0 deletions challenges/medium/104_min_p_sampling/starter/starter.jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import jax
import jax.numpy as jnp


# Starter entry point for the min-p challenge; `logits` is an array on the GPU.
@jax.jit
def solve(logits: jax.Array, min_p: float, B: int, V: int) -> jax.Array:
    """Starter stub: return the output array directly.

    Unlike the other starters there is no ``probs`` parameter to write into;
    the result is the function's return value.
    """
    pass
14 changes: 14 additions & 0 deletions challenges/medium/104_min_p_sampling/starter/starter.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from std.gpu.host import DeviceContext
from std.memory import UnsafePointer


# Starter entry point for the min-p challenge.
# `logits` and `probs` are device pointers; `min_p`, `B` and `V` are scalars.
@export
def solve(
    logits: UnsafePointer[Float32, MutExternalOrigin],
    probs: UnsafePointer[Float32, MutExternalOrigin],
    min_p: Float32,
    B: Int32,
    V: Int32,
) raises:
    # Intentionally empty: the starter performs no work.
    pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch


# Starter entry point for the min-p challenge; both tensors live on the GPU.
def solve(
    logits: torch.Tensor,
    probs: torch.Tensor,
    min_p: float,
    B: int,
    V: int,
):
    """Starter stub: performs no work and leaves ``probs`` untouched."""
    pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import torch
import triton
import triton.language as tl


# Starter entry point for the min-p challenge; both tensors live on the GPU.
def solve(
    logits: torch.Tensor,
    probs: torch.Tensor,
    min_p: float,
    B: int,
    V: int,
):
    """Starter stub: performs no work and leaves ``probs`` untouched."""
    pass
Loading