diff --git a/challenges/medium/103_repetition_penalty/challenge.html b/challenges/medium/103_repetition_penalty/challenge.html new file mode 100644 index 00000000..3d402359 --- /dev/null +++ b/challenges/medium/103_repetition_penalty/challenge.html @@ -0,0 +1,77 @@ +

+Implement the repetition penalty logit processor used by LLM samplers in production +inference engines such as Hugging Face Transformers, vLLM, and TGI. Given a batch of raw +logits from a language model and the input_ids already seen by +each sequence in the batch, downweight the logits of tokens that appeared in the prompt +or prior generation so the model is less likely to repeat them. The processor modifies +logits in place. +

+ +

+For each batch index b and for every token id v that occurs in +input_ids[b], update logits[b, v] as follows: +

+ +

+\[ +\text{logits}[b, v] \leftarrow +\begin{cases} +\text{logits}[b, v] / \text{penalty} & \text{if } \text{logits}[b, v] \ge 0 \\ +\text{logits}[b, v] \times \text{penalty} & \text{if } \text{logits}[b, v] < 0 +\end{cases} +\] +

+ +

+The penalty is applied once per unique token id: if the same token id +appears multiple times in input_ids[b], the corresponding logit is still +scaled by penalty only once. Logits at indices that never appear in +input_ids[b] remain unchanged. +

+ +

Implementation Requirements

+ + +

Example

+

+ With B = 1, V = 5, T = 3, penalty = 2.0: +

+

+ Input:
+ \(\text{logits}\) (1×5): + \[ + \begin{bmatrix} 2.0 & -1.0 & 0.5 & -3.0 & 1.0 \end{bmatrix} + \] + \(\text{input\_ids}\) (1×3): + \[ + \begin{bmatrix} 0 & 3 & 3 \end{bmatrix} + \] +

+

+ Token id 0 appears in input_ids: \(\text{logits}[0, 0] = 2.0 \ge 0\), so it becomes \(2.0 / 2.0 = 1.0\).
+ Token id 3 appears in input_ids (twice, but the penalty is applied only once): \(\text{logits}[0, 3] = -3.0 < 0\), so it becomes \(-3.0 \times 2.0 = -6.0\).
+ Token ids 1, 2, 4 do not appear in input_ids, so their logits are unchanged. +

+

+ Output (in logits, 1×5): + \[ + \begin{bmatrix} 1.0 & -1.0 & 0.5 & -6.0 & 1.0 \end{bmatrix} + \] +

+ +

Constraints

import ctypes
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    """Challenge definition for the repetition-penalty logit processor.

    For every token id that occurs in ``input_ids[b]``, the matching entry
    of ``logits[b]`` is divided by ``penalty`` when it is non-negative and
    multiplied by ``penalty`` when it is negative.  The update is written
    back into ``logits`` in place.
    """

    name = "Repetition Penalty Logit Processor"
    # NOTE(review): tolerances are presumably consumed by the ChallengeBase
    # harness when comparing a submission against reference_impl -- confirm.
    atol = 1e-05
    rtol = 1e-05
    num_gpus = 1
    access_tier = "free"

    def reference_impl(
        self,
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        penalty: float,
        B: int,
        V: int,
        T: int,
    ) -> None:
        """Apply the repetition penalty to ``logits`` in place.

        Args:
            logits: float32 tensor of shape ``(B, V)``; modified in place.
            input_ids: int32 tensor of shape ``(B, T)`` holding token ids
                used as column indices into ``logits``.
            penalty: scale factor; non-negative logits are divided by it,
                negative logits are multiplied by it.
            B: batch size.
            V: vocabulary size.
            T: number of token ids per sequence.

        Raises:
            AssertionError: if a shape or dtype does not match the above.
        """
        assert logits.shape == (B, V)
        assert input_ids.shape == (B, T)
        assert logits.dtype == torch.float32
        assert input_ids.dtype == torch.int32

        # gather/scatter_ require int64 indices; convert once and reuse.
        token_ids = input_ids.long()

        # All reads happen before any write, so a token id that appears
        # several times yields the same penalized value every time and the
        # scatter below stores that single value: the penalty is applied
        # once per unique id.
        seen = torch.gather(logits, 1, token_ids)
        penalized = torch.where(seen < 0, seen * penalty, seen / penalty)
        logits.scatter_(1, token_ids, penalized)

    def get_solve_signature(self) -> Dict[str, tuple]:
        """Return the ``solve`` parameters as ``name -> (ctypes type, direction)``.

        ``logits`` is the only ``"inout"`` argument; every other argument
        is ``"in"``.
        """
        return {
            "logits": (ctypes.POINTER(ctypes.c_float), "inout"),
            "input_ids": (ctypes.POINTER(ctypes.c_int), "in"),
            "penalty": (ctypes.c_float, "in"),
            "B": (ctypes.c_int, "in"),
            "V": (ctypes.c_int, "in"),
            "T": (ctypes.c_int, "in"),
        }

    def _make_test_case(
        self,
        B: int,
        V: int,
        T: int,
        penalty: float = 1.2,
        seed: int = 0,
        zero_logits: bool = False,
    ) -> Dict[str, Any]:
        """Build one seeded random test case.

        Args:
            B: batch size.
            V: vocabulary size.
            T: number of token ids per sequence.
            penalty: penalty value stored in the returned case.
            seed: value passed to ``torch.manual_seed`` before sampling.
            zero_logits: when true, ``logits`` is all zeros instead of
                uniform samples from ``[-5, 5)``.

        Returns:
            Keyword arguments for ``solve`` / ``reference_impl``.
        """
        # NOTE(review): self.device is not defined in this class; presumably
        # it is provided by ChallengeBase -- confirm.
        device = self.device
        torch.manual_seed(seed)
        if zero_logits:
            logits = torch.zeros(B, V, device=device, dtype=torch.float32)
        else:
            logits = torch.empty(B, V, device=device, dtype=torch.float32).uniform_(-5.0, 5.0)
        # randint's upper bound is exclusive, so every id is a valid column.
        input_ids = torch.randint(0, V, (B, T), dtype=torch.int32, device=device)
        return {
            "logits": logits,
            "input_ids": input_ids,
            "penalty": penalty,
            "B": B,
            "V": V,
            "T": T,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        """Return the small worked example (B=1, V=5, T=3, penalty=2.0)."""
        device = self.device
        logits = torch.tensor([[2.0, -1.0, 0.5, -3.0, 1.0]], device=device, dtype=torch.float32)
        input_ids = torch.tensor([[0, 3, 3]], device=device, dtype=torch.int32)
        return {
            "logits": logits,
            "input_ids": input_ids,
            "penalty": 2.0,
            "B": 1,
            "V": 5,
            "T": 3,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        """Return the functional cases: three hand-written, then seeded random ones."""
        device = self.device
        tests = []

        # The worked example is the first functional case; build it through
        # generate_example_test so the two cannot drift apart.
        tests.append(self.generate_example_test())

        # Includes a token whose logit is exactly 0.0 (the ">= 0" branch).
        tests.append(
            {
                "logits": torch.tensor(
                    [[-4.0, -2.0, 0.0, 2.0]], device=device, dtype=torch.float32
                ),
                "input_ids": torch.tensor([[1, 2]], device=device, dtype=torch.int32),
                "penalty": 1.5,
                "B": 1,
                "V": 4,
                "T": 2,
            }
        )

        # Two sequences, each with a repeated token id.
        tests.append(
            {
                "logits": torch.tensor(
                    [
                        [3.0, -3.0, 1.5, -1.5, 0.0, 4.0, -2.0, 2.0],
                        [-1.0, 2.0, 0.5, -0.5, 1.0, -2.0, 3.0, -3.0],
                    ],
                    device=device,
                    dtype=torch.float32,
                ),
                "input_ids": torch.tensor(
                    [[0, 5, 7, 7], [1, 6, 6, 4]], device=device, dtype=torch.int32
                ),
                "penalty": 1.3,
                "B": 2,
                "V": 8,
                "T": 4,
            }
        )

        tests.append(self._make_test_case(B=1, V=32, T=1, penalty=1.2, seed=10))
        tests.append(self._make_test_case(B=4, V=64, T=16, penalty=1.1, seed=11))
        tests.append(self._make_test_case(B=2, V=255, T=30, penalty=1.5, seed=12))
        tests.append(
            self._make_test_case(B=2, V=128, T=100, penalty=1.25, seed=13, zero_logits=True)
        )
        tests.append(self._make_test_case(B=8, V=1024, T=256, penalty=1.1, seed=14))
        tests.append(self._make_test_case(B=4, V=8192, T=512, penalty=1.3, seed=15))
        tests.append(self._make_test_case(B=2, V=32000, T=1024, penalty=1.15, seed=16))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        """Return the large benchmark case (B=64, V=131072, T=4096)."""
        return self._make_test_case(B=64, V=131072, T=4096, penalty=1.2, seed=42)
b/challenges/medium/103_repetition_penalty/starter/starter.cu new file mode 100644 index 00000000..28b7b3f5 --- /dev/null +++ b/challenges/medium/103_repetition_penalty/starter/starter.cu @@ -0,0 +1,4 @@ +#include <cuda_runtime.h> + +// logits, input_ids are device pointers +extern "C" void solve(float* logits, const int* input_ids, float penalty, int B, int V, int T) {} diff --git a/challenges/medium/103_repetition_penalty/starter/starter.cute.py b/challenges/medium/103_repetition_penalty/starter/starter.cute.py new file mode 100644 index 00000000..86f2c742 --- /dev/null +++ b/challenges/medium/103_repetition_penalty/starter/starter.cute.py @@ -0,0 +1,15 @@ +import cutlass +import cutlass.cute as cute + + +# logits, input_ids are tensors on the GPU +@cute.jit +def solve( + logits: cute.Tensor, + input_ids: cute.Tensor, + penalty: cute.Float32, + B: cute.Int32, + V: cute.Int32, + T: cute.Int32, +): + pass diff --git a/challenges/medium/103_repetition_penalty/starter/starter.jax.py b/challenges/medium/103_repetition_penalty/starter/starter.jax.py new file mode 100644 index 00000000..eff33fc9 --- /dev/null +++ b/challenges/medium/103_repetition_penalty/starter/starter.jax.py @@ -0,0 +1,16 @@ +import jax +import jax.numpy as jnp + + +# logits, input_ids are tensors on GPU +@jax.jit +def solve( + logits: jax.Array, + input_ids: jax.Array, + penalty: float, + B: int, + V: int, + T: int, +) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/103_repetition_penalty/starter/starter.mojo b/challenges/medium/103_repetition_penalty/starter/starter.mojo new file mode 100644 index 00000000..e2c15e36 --- /dev/null +++ b/challenges/medium/103_repetition_penalty/starter/starter.mojo @@ -0,0 +1,15 @@ +from std.gpu.host import DeviceContext +from std.memory import UnsafePointer + + +# logits, input_ids are device pointers +@export +def solve( + logits: UnsafePointer[Float32, MutExternalOrigin], + input_ids: UnsafePointer[Int32, MutExternalOrigin], + penalty: 
Float32, + B: Int32, + V: Int32, + T: Int32, +) raises: + pass diff --git a/challenges/medium/103_repetition_penalty/starter/starter.pytorch.py b/challenges/medium/103_repetition_penalty/starter/starter.pytorch.py new file mode 100644 index 00000000..25663235 --- /dev/null +++ b/challenges/medium/103_repetition_penalty/starter/starter.pytorch.py @@ -0,0 +1,13 @@ +import torch + + +# logits, input_ids are tensors on the GPU +def solve( + logits: torch.Tensor, + input_ids: torch.Tensor, + penalty: float, + B: int, + V: int, + T: int, +): + pass diff --git a/challenges/medium/103_repetition_penalty/starter/starter.triton.py b/challenges/medium/103_repetition_penalty/starter/starter.triton.py new file mode 100644 index 00000000..6147612f --- /dev/null +++ b/challenges/medium/103_repetition_penalty/starter/starter.triton.py @@ -0,0 +1,15 @@ +import torch +import triton +import triton.language as tl + + +# logits, input_ids are tensors on the GPU +def solve( + logits: torch.Tensor, + input_ids: torch.Tensor, + penalty: float, + B: int, + V: int, + T: int, +): + pass