Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions challenges/medium/103_repetition_penalty/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
<p>
Implement the repetition penalty logit processor used by LLM samplers in production
inference engines such as Hugging Face Transformers, vLLM, and TGI. Given a batch of raw
<code>logits</code> from a language model and the <code>input_ids</code> already seen by
each sequence in the batch, downweight the logits of tokens that appeared in the prompt
or prior generation so the model is less likely to repeat them. The processor modifies
<code>logits</code> in place.
</p>

<p>
For each batch <code>b</code> and for every token id <code>v</code> that occurs in
<code>input_ids[b]</code>, update <code>logits[b, v]</code> as follows:
</p>

<p>
\[
\text{logits}[b, v] \leftarrow
\begin{cases}
\text{logits}[b, v] / \text{penalty} & \text{if } \text{logits}[b, v] \ge 0 \\
\text{logits}[b, v] \times \text{penalty} & \text{if } \text{logits}[b, v] < 0
\end{cases}
\]
</p>

<p>
The penalty is applied <strong>once per unique token id</strong>: if the same token id
appears multiple times in <code>input_ids[b]</code>, the corresponding logit is still
scaled by <code>penalty</code> only once. Logits at indices that never appear in
<code>input_ids[b]</code> remain unchanged.
</p>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the function <code>solve(logits, input_ids, penalty, B, V, T)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Modify the <code>logits</code> buffer in place.</li>
<li><code>logits</code> is <code>float32</code> with shape <code>(B, V)</code>; <code>input_ids</code> is <code>int32</code> with shape <code>(B, T)</code>.</li>
<li>Each batch is independent: token ids in <code>input_ids[b]</code> only affect <code>logits[b, :]</code>.</li>
</ul>

<h2>Example</h2>
<p>
With <code>B</code> = 1, <code>V</code> = 5, <code>T</code> = 3, <code>penalty</code> = 2.0:
</p>
<p>
<strong>Input:</strong><br>
\(\text{logits}\) (1&times;5):
\[
\begin{bmatrix} 2.0 & -1.0 & 0.5 & -3.0 & 1.0 \end{bmatrix}
\]
\(\text{input\_ids}\) (1&times;3):
\[
\begin{bmatrix} 0 & 3 & 3 \end{bmatrix}
\]
</p>
<p>
Token id 0 appears in <code>input_ids</code>: \(\text{logits}[0, 0] = 2.0 \ge 0\), so it becomes \(2.0 / 2.0 = 1.0\).<br>
Token id 3 appears in <code>input_ids</code> (twice, but the penalty is applied only once): \(\text{logits}[0, 3] = -3.0 < 0\), so it becomes \(-3.0 \times 2.0 = -6.0\).<br>
Token ids 1, 2, 4 do not appear in <code>input_ids</code>, so their logits are unchanged.
</p>
<p>
<strong>Output</strong> (in <code>logits</code>, 1&times;5):
\[
\begin{bmatrix} 1.0 & -1.0 & 0.5 & -6.0 & 1.0 \end{bmatrix}
\]
</p>

<h2>Constraints</h2>
<ul>
<li>1 &le; <code>B</code> &le; 128</li>
<li>1 &le; <code>V</code> &le; 200,000</li>
<li>1 &le; <code>T</code> &le; 8,192</li>
<li>1.0 &le; <code>penalty</code> &le; 4.0</li>
<li>Each <code>input_ids[b, t]</code> is in <code>[0, V)</code></li>
<li>-100.0 &le; <code>logits[b, v]</code> &le; 100.0</li>
<li>Performance is measured with <code>B</code> = 64, <code>V</code> = 131,072, <code>T</code> = 4,096</li>
</ul>
144 changes: 144 additions & 0 deletions challenges/medium/103_repetition_penalty/challenge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
name = "Repetition Penalty Logit Processor"
atol = 1e-05
rtol = 1e-05
num_gpus = 1
access_tier = "free"

def reference_impl(
self,
logits: torch.Tensor,
input_ids: torch.Tensor,
penalty: float,
B: int,
V: int,
T: int,
):
assert logits.shape == (B, V)
assert input_ids.shape == (B, T)
assert logits.dtype == torch.float32
assert input_ids.dtype == torch.int32

score = torch.gather(logits, 1, input_ids.long())
score = torch.where(score < 0, score * penalty, score / penalty)
logits.scatter_(1, input_ids.long(), score)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
"logits": (ctypes.POINTER(ctypes.c_float), "inout"),
"input_ids": (ctypes.POINTER(ctypes.c_int), "in"),
"penalty": (ctypes.c_float, "in"),
"B": (ctypes.c_int, "in"),
"V": (ctypes.c_int, "in"),
"T": (ctypes.c_int, "in"),
}

def _make_test_case(
self,
B: int,
V: int,
T: int,
penalty: float = 1.2,
seed: int = 0,
zero_logits: bool = False,
) -> Dict[str, Any]:
device = self.device
torch.manual_seed(seed)
if zero_logits:
logits = torch.zeros(B, V, device=device, dtype=torch.float32)
else:
logits = torch.empty(B, V, device=device, dtype=torch.float32).uniform_(-5.0, 5.0)
input_ids = torch.randint(0, V, (B, T), dtype=torch.int32, device=device)
return {
"logits": logits,
"input_ids": input_ids,
"penalty": penalty,
"B": B,
"V": V,
"T": T,
}

def generate_example_test(self) -> Dict[str, Any]:
device = self.device
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0, 1.0]], device=device, dtype=torch.float32)
input_ids = torch.tensor([[0, 3, 3]], device=device, dtype=torch.int32)
return {
"logits": logits,
"input_ids": input_ids,
"penalty": 2.0,
"B": 1,
"V": 5,
"T": 3,
}

def generate_functional_test(self) -> List[Dict[str, Any]]:
device = self.device
tests = []

tests.append(
{
"logits": torch.tensor(
[[2.0, -1.0, 0.5, -3.0, 1.0]], device=device, dtype=torch.float32
),
"input_ids": torch.tensor([[0, 3, 3]], device=device, dtype=torch.int32),
"penalty": 2.0,
"B": 1,
"V": 5,
"T": 3,
}
)

tests.append(
{
"logits": torch.tensor(
[[-4.0, -2.0, 0.0, 2.0]], device=device, dtype=torch.float32
),
"input_ids": torch.tensor([[1, 2]], device=device, dtype=torch.int32),
"penalty": 1.5,
"B": 1,
"V": 4,
"T": 2,
}
)

tests.append(
{
"logits": torch.tensor(
[
[3.0, -3.0, 1.5, -1.5, 0.0, 4.0, -2.0, 2.0],
[-1.0, 2.0, 0.5, -0.5, 1.0, -2.0, 3.0, -3.0],
],
device=device,
dtype=torch.float32,
),
"input_ids": torch.tensor(
[[0, 5, 7, 7], [1, 6, 6, 4]], device=device, dtype=torch.int32
),
"penalty": 1.3,
"B": 2,
"V": 8,
"T": 4,
}
)

tests.append(self._make_test_case(B=1, V=32, T=1, penalty=1.2, seed=10))
tests.append(self._make_test_case(B=4, V=64, T=16, penalty=1.1, seed=11))
tests.append(self._make_test_case(B=2, V=255, T=30, penalty=1.5, seed=12))
tests.append(
self._make_test_case(B=2, V=128, T=100, penalty=1.25, seed=13, zero_logits=True)
)
tests.append(self._make_test_case(B=8, V=1024, T=256, penalty=1.1, seed=14))
tests.append(self._make_test_case(B=4, V=8192, T=512, penalty=1.3, seed=15))
tests.append(self._make_test_case(B=2, V=32000, T=1024, penalty=1.15, seed=16))

return tests

def generate_performance_test(self) -> Dict[str, Any]:
return self._make_test_case(B=64, V=131072, T=4096, penalty=1.2, seed=42)
4 changes: 4 additions & 0 deletions challenges/medium/103_repetition_penalty/starter/starter.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include <cuda_runtime.h>

// logits, input_ids are device pointers
extern "C" void solve(float* logits, const int* input_ids, float penalty, int B, int V, int T) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import cutlass
import cutlass.cute as cute


# logits, input_ids are tensors on the GPU
@cute.jit
def solve(
logits: cute.Tensor,
input_ids: cute.Tensor,
penalty: cute.Float32,
B: cute.Int32,
V: cute.Int32,
T: cute.Int32,
):
pass
16 changes: 16 additions & 0 deletions challenges/medium/103_repetition_penalty/starter/starter.jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import jax
import jax.numpy as jnp


# logits, input_ids are tensors on GPU
@jax.jit
def solve(
logits: jax.Array,
input_ids: jax.Array,
penalty: float,
B: int,
V: int,
T: int,
) -> jax.Array:
# return output tensor directly
pass
15 changes: 15 additions & 0 deletions challenges/medium/103_repetition_penalty/starter/starter.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from std.gpu.host import DeviceContext
from std.memory import UnsafePointer


# logits, input_ids are device pointers
@export
def solve(
logits: UnsafePointer[Float32, MutExternalOrigin],
input_ids: UnsafePointer[Int32, MutExternalOrigin],
penalty: Float32,
B: Int32,
V: Int32,
T: Int32,
) raises:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch


# logits, input_ids are tensors on the GPU
def solve(
logits: torch.Tensor,
input_ids: torch.Tensor,
penalty: float,
B: int,
V: int,
T: int,
):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch
import triton
import triton.language as tl


# logits, input_ids are tensors on the GPU
def solve(
logits: torch.Tensor,
input_ids: torch.Tensor,
penalty: float,
B: int,
V: int,
T: int,
):
pass
Loading