
Commit 6c36789

feat: address edit comments and improve examples
1 parent cc9e805 commit 6c36789

File tree

2 files changed (+96, -44 lines)


_blog.yml

Lines changed: 12 additions & 0 deletions
@@ -5951,3 +5951,15 @@
   - gradio
   - tool
   - llm
+
+- local: hello-hf-kernels
+  title: "Learn the Hugging Face Kernel Hub in 5 Minutes"
+  author: drbh
+  thumbnail: /blog/assets/hello-hf-kernels/kernel-hub-five-mins-short-21.png
+  date: May 8, 2025
+  tags:
+  - guide
+  - hub
+  - optimization
+  - open-source
+  - performance

hello-hf-kernels.md

Lines changed: 84 additions & 44 deletions
@@ -3,14 +3,17 @@ title: "Learn the Hugging Face Kernel Hub in 5 Minutes"
 thumbnail: /blog/assets/hello-hf-kernels/kernel-hub-five-mins-short.png
 authors:
 - user: drbh
+- user: danieldk
+- user: pcuenca
+- user: pagezyhf
 date: 2025-03-28
 ---

 # 🏎️ Learn the Hugging Face Kernel Hub in 5 Minutes

-**Unlock performance boosts for your models with pre-optimized compute kernels, easily loaded from the Hub.**
+**Boost your model performance with pre-optimized kernels, easily loaded from the Hub.**

-Today, we'll explore an exciting development from Hugging Face: the **Kernel Hub**! As ML practitioners, we know that maximizing performance often involves diving deep into optimized code, custom CUDA kernels, or complex build systems. The Kernel Hub aims to simplify this dramatically.
+Today, we'll explore an exciting development from Hugging Face: the **Kernel Hub**! As ML practitioners, we know that maximizing performance often involves diving deep into optimized code, custom CUDA kernels, or complex build systems. The Kernel Hub simplifies this process dramatically!

 We'll cover the following topics:

@@ -19,30 +22,36 @@ We'll cover the following topics:
 3. **Adding a Kernel to a Simple Model** - A practical integration using RMSNorm.
 4. **Reviewing Performance Impact** - Benchmarking the RMSNorm difference.

-And we'll introduce these concepts quickly – the core idea can be grasped in about 5 minutes (though experimenting and benchmarking might take a bit longer!).
+We'll introduce these concepts quickly – the core idea can be grasped in about 5 minutes (though experimenting and benchmarking might take a bit longer!).

 ## 1. What is the Kernel Hub?

-The [Kernel Hub](https://huggingface.co/kernels) (👈 Check it out!) allows Python libraries and applications to **load optimized compute kernels directly from the Hugging Face Hub**. Think of it like the Model Hub, but for low-level, high-performance code snippets (kernels) that accelerate specific operations, often on GPUs. Examples include optimized attention mechanisms (like FlashAttention), activation functions, and normalization layers (like LayerNorm or RMSNorm).
+The [Kernel Hub](https://huggingface.co/kernels-community) (👈 Check it out!) allows Python libraries and applications to **load optimized compute kernels directly from the Hugging Face Hub**. Think of it like the Model Hub, but for low-level, high-performance code snippets (kernels) that accelerate specific operations, often on GPUs.
+
+Examples include advanced attention mechanisms (like [FlashAttention](https://huggingface.co/kernels-community/flash-attn), for dramatic speedups and memory savings), custom [quantization kernels](https://huggingface.co/kernels-community/quantization) (enabling efficient computation with lower-precision data types like INT8 or INT4), specialized kernels required for complex architectures like [Mixture of Experts (MoE) layers](https://huggingface.co/kernels-community/moe) (which involve intricate routing and computation patterns), as well as [activation functions](https://huggingface.co/kernels-community/activation) and [normalization layers (like LayerNorm or RMSNorm)](https://huggingface.co/kernels-community/triton-layer-norm).

 Instead of manually managing complex dependencies, wrestling with compilation flags, or building libraries like Triton or CUTLASS from source, you can use the `kernels` library to instantly fetch and run pre-compiled, optimized kernels.

 ### Benefits of the Kernel Hub:

-* **Instant Access to Optimized Kernels**: Load and run kernels optimized for various hardware (like NVIDIA GPUs) without local compilation hassles.
+* **Instant Access to Optimized Kernels**: Load and run kernels optimized for various hardware, starting with NVIDIA and AMD GPUs, without local compilation hassles.
 * **Share and Reuse**: Discover, share, and reuse kernels across different projects and the community.
 * **Easy Updates**: Stay up-to-date with the latest kernel improvements simply by pulling the latest version from the Hub.
 * **Accelerate Development**: Focus on your model architecture and logic, not on the intricacies of kernel compilation and deployment.
 * **Improve Performance**: Leverage kernels optimized by experts to potentially speed up training and inference.
 * **Simplify Deployment**: Reduce the complexity of your deployment environment by fetching kernels on demand.
+* **Develop and Share Your Own Kernels**: If you create optimized kernels, you can easily share them on the Hub for others to use. This encourages collaboration and knowledge sharing within the community.

 > As many machine learning developers know, managing dependencies and building low-level code from source can be a time-consuming and error-prone process. The Kernel Hub aims to simplify this by providing a centralized repository of optimized compute kernels that can be easily loaded and run.

 Spend more time building great models and less time fighting build systems!

 ## 2. How to Use the Kernel Hub (Basic Example)

-Using the Kernel Hub is designed to be straightforward. The `kernels` library provides the main interface. Here's a quick example loading an optimized GELU activation function kernel (we'll use a different kernel for the main example later).
+Using the Kernel Hub is designed to be straightforward. The `kernels` library provides the main interface. Here's a quick example that loads an optimized GELU activation function kernel. (Later, we'll see how to integrate a kernel into a model.)
+
+File: `activation_validation_example.py`

 ~~~python
 # /// script
@@ -54,19 +63,15 @@ Using the Kernel Hub is designed to be straightforward. The `kernels` library pr
 # ///

 import torch
+import torch.nn.functional as F
 from kernels import get_kernel

-# Ensure you have a CUDA-enabled device
-if not torch.cuda.is_available():
-    raise RuntimeError("This example requires a CUDA-enabled GPU")
-
 DEVICE = "cuda"

 # Make reproducible
 torch.manual_seed(42)

 # Download optimized activation kernels from the Hub
-# This fetches the kernel code if not already cached
 activation_kernels = get_kernel("kernels-community/activation")

 # Create a random tensor on the GPU
@@ -75,26 +80,26 @@ x = torch.randn((4, 4), dtype=torch.float16, device=DEVICE)
 # Prepare an output tensor
 y = torch.empty_like(x)

-# Run the specific kernel function (e.g., fast GELU)
-# The `activation_kernels` object holds multiple functions
+# Run the fast GELU kernel
 activation_kernels.gelu_fast(y, x)

-# Check the output against expected values
-expected = torch.tensor(
-    [
-        [0.1100, 2.1309, -0.0700, 0.6802],
-        [-0.0500, 0.4800, -0.1700, -0.1700],
-        [0.3701, -0.1300, -0.0800, -0.1200],
-        [-0.0400, 0.1200, -0.1500, 1.7998],
-    ],
-    dtype=torch.float16,
-    device=DEVICE,
-)
+# Get expected output using PyTorch's built-in GELU
+expected = F.gelu(x)
+
+# Compare the kernel output with PyTorch's result
 torch.testing.assert_close(y, expected, rtol=1e-2, atol=1e-2)

-print("Kernel executed successfully and output matches expected values!")
+print("✅ Kernel output matches PyTorch GELU!")
+
+# Optional: print both tensors for inspection
+print("\nInput tensor:")
+print(x)
+print("\nFast GELU kernel output:")
+print(y)
+print("\nPyTorch GELU output:")
+print(expected)

-# You can list available functions in the loaded kernel module
+# List available functions in the loaded kernel module
 print("\nAvailable functions in 'kernels-community/activation':")
 print(dir(activation_kernels))
 ~~~
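A note on the loose `1e-2` tolerances above: "fast GELU" kernels usually implement the tanh approximation of GELU, while `F.gelu` defaults to the exact erf formulation. Below is a minimal sketch that compares against PyTorch's tanh variant instead, under the assumption (not confirmed by the kernel repo) that this kernel follows the common tanh approximation:

~~~python
# Sketch: compare gelu_fast against PyTorch's tanh-approximated GELU.
# Assumes a CUDA GPU and that this kernel implements the common tanh approximation.
import torch
import torch.nn.functional as F
from kernels import get_kernel

torch.manual_seed(42)
activation_kernels = get_kernel("kernels-community/activation")

x = torch.randn((4, 4), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)
activation_kernels.gelu_fast(y, x)

# approximate="tanh" selects PyTorch's tanh-based GELU variant
expected_tanh = F.gelu(x, approximate="tanh")
torch.testing.assert_close(y, expected_tanh, rtol=1e-2, atol=1e-2)
print("✅ Matches tanh-approximated GELU")
~~~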
@@ -117,6 +122,9 @@ Let's integrate an optimized **RMS Normalization** kernel into a basic model. We

 First, define a simple RMSNorm module in PyTorch and a baseline model using it:

+
+File: `rmsnorm_baseline.py`
+
 ~~~python
 # /// script
 # dependencies = [
@@ -129,17 +137,16 @@ import torch
 import torch.nn as nn

 DEVICE = "cuda"
-if not torch.cuda.is_available():
-    raise RuntimeError("This example requires a CUDA-enabled GPU")
+
 DTYPE = torch.float16  # Use float16 for better kernel performance potential


 # Simple PyTorch implementation of RMSNorm for baseline comparison
 class RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5):
+    def __init__(self, hidden_size, variance_epsilon=1e-5):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.eps = eps
+        self.eps = variance_epsilon
         self.hidden_size = hidden_size

     def forward(self, x):
@@ -157,7 +164,7 @@ class BaselineModel(nn.Module):
     def __init__(self, input_size, hidden_size, output_size, eps=1e-5):
         super().__init__()
         self.linear1 = nn.Linear(input_size, hidden_size)
-        self.norm = RMSNorm(hidden_size, eps=eps)
+        self.norm = RMSNorm(hidden_size, variance_epsilon=eps)
         self.activation = nn.GELU()
         self.linear2 = nn.Linear(hidden_size, output_size)

@@ -195,6 +202,8 @@ print("Baseline RMSNorm model output shape:", output.shape)

 Now, let's create a version using the `LlamaRMSNorm` kernel loaded via `kernels`.

+File: `rmsnorm_kernel.py`
+
 ~~~python
 # /// script
 # dependencies = [
@@ -205,26 +214,53 @@ Now, let's create a version using the `LlamaRMSNorm` kernel loaded via `kernels`
 # ///
 import torch
 import torch.nn as nn
-from kernels import get_kernel
+from kernels import get_kernel, use_kernel_forward_from_hub

 # reuse the model from the previous snippet or copy the class
 # definition here to run this script independently
-from snippet2 import BaselineModel
+from rmsnorm_baseline import BaselineModel

 DEVICE = "cuda"
-if not torch.cuda.is_available():
-    raise RuntimeError("This example requires a CUDA-enabled GPU")
 DTYPE = torch.float16  # Use float16 for better kernel performance potential


 layer_norm_kernel_module = get_kernel("kernels-community/triton-layer-norm")

-
-class KernelRMSNorm(layer_norm_kernel_module.layers.LlamaRMSNorm):
+# Simply add the decorator to the LlamaRMSNorm class to automatically replace the forward
+# function with the optimized kernel version.
+#
+# Note: not all kernels ship with layers already mapped; those require calling the kernel
+# function directly. However, in this case the LlamaRMSNorm class is already mapped to the
+# kernel function. Otherwise we'd need to call the function directly like this:
+# ```python
+# layer_norm_kernel_module.rms_norm_fn(
+#     hidden_states,
+#     self.weight,
+#     bias=None,
+#     residual=None,
+#     eps=self.eps,
+#     dropout_p=0.0,
+#     prenorm=False,
+#     residual_in_fp32=False,
+# )
+# ```
+@use_kernel_forward_from_hub("LlamaRMSNorm")
+class OriginalRMSNorm(nn.Module):
     def __init__(self, hidden_size, variance_epsilon=1e-5):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = variance_epsilon
+        self.eps = variance_epsilon
+        self.hidden_size = hidden_size
+
+    def forward(self, x):
+        # Assumes x is (batch_size, ..., hidden_size)
+        input_dtype = x.dtype
+        # Calculate variance in float32 for stability
+        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.eps)
+
+        # Apply weight and convert back to original dtype
+        return (self.weight * x).to(input_dtype)


 class KernelModel(nn.Module):
@@ -239,7 +275,9 @@ class KernelModel(nn.Module):
     ):
         super().__init__()
         self.linear1 = nn.Linear(input_size, hidden_size)
-        self.norm = KernelRMSNorm(hidden_size, variance_epsilon=eps)
+        # OriginalRMSNorm will be replaced with the optimized kernel layer
+        # when the model is loaded
+        self.norm = OriginalRMSNorm(hidden_size, variance_epsilon=eps)
         self.activation = nn.GELU()
         self.linear2 = nn.Linear(hidden_size, output_size)

@@ -299,6 +337,7 @@ except NameError:
 ~~~

 **Important Notes on the `KernelModel`:**
+
 * **Kernel Mapping:** The `@use_kernel_forward_from_hub("LlamaRMSNorm")` decorator maps `OriginalRMSNorm` to `layer_norm_kernel_module.layers.LlamaRMSNorm`, the RMSNorm implementation shipped with the kernel. This lets the model use the optimized kernel directly.
 * **Accessing the Function:** The exact way to access a kernel's functions (`layer_norm_kernel_module.layers.LlamaRMSNorm`, `layer_norm_kernel_module.rms_norm_fn`, or something else) **depends entirely on how the kernel creator structured the repository on the Hub.** You may need to inspect the loaded `layer_norm_kernel_module` object (e.g., using `dir()`) or check the kernel's documentation on the Hub to find the correct function/method and its signature; here, the direct entry point is `rms_norm_fn`, shown in the comment in the snippet above.
 * **Parameters:** We only define a `weight` parameter (no bias), consistent with RMSNorm.
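Not every kernel repository maps its functions to ready-made layers. To see what a loaded kernel actually exposes, and to call the RMSNorm function directly, here is a minimal sketch (assuming a CUDA GPU; the `rms_norm_fn` arguments follow the signature quoted in the comment inside the snippet above):

~~~python
# Minimal sketch: inspect a loaded kernel module and call its RMSNorm function directly.
# Assumes a CUDA GPU; rms_norm_fn and its arguments mirror the comment shown earlier.
import torch
from kernels import get_kernel

layer_norm_kernel_module = get_kernel("kernels-community/triton-layer-norm")

# Inspect what the kernel repository exposes
print(dir(layer_norm_kernel_module))         # top-level functions, e.g. rms_norm_fn
print(dir(layer_norm_kernel_module.layers))  # mapped layer classes, e.g. LlamaRMSNorm

hidden_states = torch.randn(2, 16, 64, dtype=torch.float16, device="cuda")
weight = torch.ones(64, dtype=torch.float16, device="cuda")

out = layer_norm_kernel_module.rms_norm_fn(
    hidden_states,
    weight,
    bias=None,
    residual=None,
    eps=1e-5,
    dropout_p=0.0,
    prenorm=False,
    residual_in_fp32=False,
)
print(out.shape)  # same shape as hidden_states
~~~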
@@ -307,6 +346,9 @@ except NameError:

 Does using the optimized Triton RMSNorm kernel provide a speedup compared to the basic PyTorch version? Let's benchmark the forward pass again.

+
+File: `rmsnorm_benchmark.py`
+
 ~~~python
 # /// script
 # dependencies = [
@@ -319,12 +361,10 @@ import torch

 # reuse the models from the previous snippets or copy the class
 # definitions here to run this script independently
-from snippet2 import BaselineModel
-from snippet3 import KernelModel
+from rmsnorm_baseline import BaselineModel
+from rmsnorm_kernel import KernelModel

 DEVICE = "cuda"
-if not torch.cuda.is_available():
-    raise RuntimeError("This example requires a CUDA-enabled GPU")
 DTYPE = torch.float16  # Use float16 for better kernel performance potential


@@ -462,7 +502,7 @@ You've seen how easy it is to fetch and use optimized kernels with the Hugging F
 ~~~bash
 pip install kernels torch numpy
 ~~~
-Ensure you have a compatible PyTorch version and CUDA installed if using GPU kernels.
+Ensure you have a compatible PyTorch version and GPU driver installed.

 2. **Browse the Hub:** Explore available kernels on the Hugging Face Hub under the [`kernels` tag](https://huggingface.co/kernels) or within organizations like [`kernels-community`](https://huggingface.co/kernels-community). Look for kernels relevant to your operations (activations, attention, normalization like LayerNorm/RMSNorm, etc.).
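To get a first feel for both steps, here is a rough quickstart that fetches the activation kernel and times `gelu_fast` against `F.gelu` (a minimal sketch, assuming a CUDA GPU; illustrative timings, not a rigorous benchmark):

~~~python
# Quickstart sketch: fetch a kernel from the Hub, run it, and time it against PyTorch.
# Assumes a CUDA-enabled GPU; numbers are illustrative only.
import torch
import torch.nn.functional as F
from kernels import get_kernel

activation = get_kernel("kernels-community/activation")

x = torch.randn((4096, 4096), dtype=torch.float16, device="cuda")
y = torch.empty_like(x)

def time_fn(fn, iters=100):
    # Warm up, then time with CUDA events
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

print(f"Hub gelu_fast : {time_fn(lambda: activation.gelu_fast(y, x)):.3f} ms")
print(f"torch F.gelu  : {time_fn(lambda: F.gelu(x)):.3f} ms")
~~~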
