Skip to content

Commit b776cdb

Browse files
committed
Support new flattened kernel builds
Support kernels that have the main module at `build/<variant>`. See: huggingface/kernel-builder#293
1 parent 39d2ade commit b776cdb

File tree

2 files changed

+89
-66
lines changed

2 files changed

+89
-66
lines changed

src/kernels/utils.py

Lines changed: 50 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,11 @@ def universal_build_variant() -> str:
8888
return "torch-universal"
8989

9090

91-
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
91+
def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
92+
file_path = variant_path / "__init__.py"
93+
if not file_path.exists():
94+
file_path = variant_path / module_name / "__init__.py"
95+
9296
# We cannot use the module name as-is: after adding it to `sys.modules`,
9397
# it would also be used for other imports. So, we make a module name that
9498
# depends on the path for it to be unique using the hex-encoded hash of
@@ -149,42 +153,48 @@ def install_kernel(
149153
)
150154

151155
try:
152-
return _load_kernel_from_path(repo_path, package_name, variant_locks)
153-
except FileNotFoundError:
154-
# Redo with more specific error message.
156+
return _find_kernel_in_repo_path(repo_path, package_name, variant_locks)
157+
except:
155158
raise FileNotFoundError(
156-
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
159+
f"Cannot install kernel from repo {repo_id} (revision: {revision})"
157160
)
158161

159162

160-
def _load_kernel_from_path(
163+
def _find_kernel_in_repo_path(
161164
repo_path: Path,
162165
package_name: str,
163166
variant_locks: Optional[Dict[str, VariantLock]] = None,
164167
) -> Tuple[str, Path]:
165-
variant = build_variant()
168+
specific_variant = build_variant()
166169
universal_variant = universal_build_variant()
167170

168-
variant_path = repo_path / "build" / variant
171+
specific_variant_path = repo_path / "build" / specific_variant
169172
universal_variant_path = repo_path / "build" / universal_variant
170173

171-
if not variant_path.exists() and universal_variant_path.exists():
172-
# Fall back to universal variant.
174+
if specific_variant_path.exists():
175+
variant = specific_variant
176+
variant_path = specific_variant_path
177+
elif universal_variant_path.exists():
173178
variant = universal_variant
174179
variant_path = universal_variant_path
180+
else:
181+
raise FileNotFoundError(
182+
f"Kernel at path `{repo_path}` does not have one of build variants: {specific_variant}, {universal_variant}"
183+
)
175184

176185
if variant_locks is not None:
177186
variant_lock = variant_locks.get(variant)
178187
if variant_lock is None:
179188
raise ValueError(f"No lock found for build variant: {variant}")
180189
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)
181190

182-
module_init_path = variant_path / package_name / "__init__.py"
191+
module_init_path = variant_path / "__init__.py"
192+
if not os.path.exists(module_init_path):
193+
# Compatibility with older kernels.
194+
module_init_path = variant_path / package_name / "__init__.py"
183195

184196
if not os.path.exists(module_init_path):
185-
raise FileNotFoundError(
186-
f"Kernel at path `{repo_path}` does not have build: {variant}"
187-
)
197+
raise FileNotFoundError(f"No kernel module found at: `{variant_path}`")
188198

189199
return package_name, variant_path
190200

@@ -258,10 +268,10 @@ def get_kernel(
258268
```
259269
"""
260270
revision = select_revision_or_version(repo_id, revision, version)
261-
package_name, package_path = install_kernel(
271+
package_name, variant_path = install_kernel(
262272
repo_id, revision=revision, user_agent=user_agent
263273
)
264-
return import_from_path(package_name, package_path / package_name / "__init__.py")
274+
return _import_from_path(package_name, variant_path)
265275

266276

267277
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
@@ -284,15 +294,15 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
284294
for base_path in [repo_path, repo_path / "build"]:
285295
# Prefer the universal variant if it exists.
286296
for v in [universal_variant, variant]:
287-
package_path = base_path / v / package_name / "__init__.py"
288-
if package_path.exists():
289-
return import_from_path(package_name, package_path)
297+
variant_path = base_path / v
298+
if variant_path.exists():
299+
return _import_from_path(package_name, variant_path)
290300

291301
# If we didn't find the package in the repo we may have an explicit
292302
# package path.
293-
package_path = repo_path / package_name / "__init__.py"
294-
if package_path.exists():
295-
return import_from_path(package_name, package_path)
303+
variant_path = repo_path
304+
if variant_path.exists():
305+
return _import_from_path(package_name, variant_path)
296306

297307
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
298308

@@ -321,18 +331,16 @@ def has_kernel(
321331
variant = build_variant()
322332
universal_variant = universal_build_variant()
323333

324-
if file_exists(
325-
repo_id,
326-
revision=revision,
327-
filename=f"build/{universal_variant}/{package_name}/__init__.py",
328-
):
329-
return True
330-
331-
return file_exists(
332-
repo_id,
333-
revision=revision,
334-
filename=f"build/{variant}/{package_name}/__init__.py",
335-
)
334+
for variant in [universal_variant, variant]:
335+
for init_file in ["__init__.py", f"{package_name}/__init__.py"]:
336+
if file_exists(
337+
repo_id,
338+
revision=revision,
339+
filename=f"build/{variant}/{init_file}",
340+
):
341+
return True
342+
343+
return False
336344

337345

338346
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
@@ -376,21 +384,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
376384
)
377385
)
378386

379-
variant_path = repo_path / "build" / variant
380-
universal_variant_path = repo_path / "build" / universal_variant
381-
if not variant_path.exists() and universal_variant_path.exists():
382-
# Fall back to universal variant.
383-
variant = universal_variant
384-
variant_path = universal_variant_path
385-
386-
module_init_path = variant_path / package_name / "__init__.py"
387-
if not os.path.exists(module_init_path):
387+
try:
388+
package_name, variant_path = _find_kernel_in_repo_path(
389+
repo_path, package_name, variant_locks=None
390+
)
391+
_import_from_path(package_name, variant_path)
392+
except FileNotFoundError:
388393
raise FileNotFoundError(
389394
f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
390395
)
391396

392-
return import_from_path(package_name, variant_path / package_name / "__init__.py")
393-
394397

395398
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
396399
"""
@@ -410,11 +413,11 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
410413
if locked_sha is None:
411414
raise ValueError(f"Kernel `{repo_id}` is not locked")
412415

413-
package_name, package_path = install_kernel(
416+
package_name, variant_path = install_kernel(
414417
repo_id, locked_sha, local_files_only=local_files_only
415418
)
416419

417-
return import_from_path(package_name, package_path / package_name / "__init__.py")
420+
return _import_from_path(package_name, variant_path)
418421

419422

420423
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:

tests/test_basic.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
import torch
3+
import torch.nn.functional as F
34

45
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
56

@@ -72,37 +73,33 @@ def test_local_kernel(local_kernel, device):
7273
assert torch.allclose(y, expected)
7374

7475

75-
@pytest.mark.cuda_only
76-
def test_local_kernel_path_types(local_kernel_path, device):
77-
package_name, path = local_kernel_path
76+
@pytest.mark.parametrize(
77+
"repo_revision",
78+
[
79+
("kernels-test/flattened-build", "pre-flattening"),
80+
("kernels-test/flattened-build", "main"),
81+
("kernels-test/flattened-build", "without-compat-module"),
82+
],
83+
)
84+
def test_local_kernel_path_types(repo_revision, device):
85+
repo_id, revision = repo_revision
86+
package_name, path = install_kernel(repo_id, revision)
7887

7988
# Top-level repo path
8089
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
8190
kernel = get_local_kernel(path.parent.parent, package_name)
82-
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
83-
y = torch.empty_like(x)
84-
85-
kernel.gelu_fast(y, x)
86-
expected = torch.tensor(
87-
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
88-
device=device,
89-
dtype=torch.float16,
90-
)
91-
assert torch.allclose(y, expected)
91+
x = torch.arange(0, 32, dtype=torch.float16, device=device).view(2, 16)
92+
torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x))
9293

9394
# Build directory path
9495
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
9596
kernel = get_local_kernel(path.parent.parent / "build", package_name)
96-
y = torch.empty_like(x)
97-
kernel.gelu_fast(y, x)
98-
assert torch.allclose(y, expected)
97+
torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x))
9998

10099
# Explicit package path
101100
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
102101
kernel = get_local_kernel(path, package_name)
103-
y = torch.empty_like(x)
104-
kernel.gelu_fast(y, x)
105-
assert torch.allclose(y, expected)
102+
torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x))
106103

107104

108105
@pytest.mark.darwin_only
@@ -123,6 +120,8 @@ def test_relu_metal(metal_kernel, dtype):
123120
# support/test against this version).
124121
("kernels-test/only-torch-2.4", "main", False),
125122
("google-bert/bert-base-uncased", "87565a309", False),
123+
("kernels-test/flattened-build", "main", True),
124+
("kernels-test/flattened-build", "without-compat-module", True),
126125
],
127126
)
128127
def test_has_kernel(kernel_exists):
@@ -162,3 +161,24 @@ def test_universal_kernel(universal_kernel):
162161
out_check = out_check.to(torch.float16)
163162

164163
torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)
164+
165+
166+
@pytest.mark.parametrize(
    "repo_revision",
    [
        ("kernels-test/flattened-build", branch)
        for branch in ("pre-flattening", "main", "without-compat-module")
    ],
)
def test_flattened_build(repo_revision, device):
    """Load a kernel revision with `get_kernel` and check `silu_and_mul`.

    Each parameter is a `(repo_id, revision)` pair. The kernel's
    `silu_and_mul` output is compared against the pure-torch reference
    `silu_and_mul_torch` on a small float16 input.

    NOTE(review): the revision names presumably correspond to the legacy
    nested layout, the flattened layout, and a flattened layout without the
    compatibility module -- confirm against the test repository.
    """
    repo, rev = repo_revision
    loaded = get_kernel(repo, revision=rev)

    # 32 consecutive values reshaped to (2, 16), so the last dimension is even.
    inputs = torch.arange(0, 32, dtype=torch.float16, device=device).view(2, 16)
    actual = loaded.silu_and_mul(inputs)
    expected = silu_and_mul_torch(inputs)
    torch.testing.assert_close(actual, expected)
180+
181+
182+
def silu_and_mul_torch(x: torch.Tensor) -> torch.Tensor:
    """Pure-torch reference for the SiLU-and-multiply operation.

    The last dimension of ``x`` is split into two equal halves and the
    result is ``silu(first_half) * second_half``.

    Args:
        x: Input tensor whose last dimension has an even size.

    Returns:
        A tensor shaped like ``x`` except that the last dimension is halved.

    Raises:
        ValueError: If the last dimension of ``x`` has an odd size.
    """
    last_dim = x.shape[-1]
    if last_dim % 2 != 0:
        # With an odd size the halves differ in length; for sizes 1 and 3
        # broadcasting would silently produce a wrong-shaped result.
        raise ValueError(
            f"Last dimension must have an even size, got {last_dim}"
        )
    d = last_dim // 2
    return F.silu(x[..., :d]) * x[..., d:]

0 commit comments

Comments
 (0)