Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions examples/run_lyrics_transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,22 @@ def parse_args():

if __name__ == "__main__":
args = parse_args()

if torch.backends.mps.is_available():
device = torch.device("mps")
# MPS commonly lacks bf16 support; fp16 is the safest default.
dtype = torch.float16
elif torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
device = torch.device("cpu")
dtype = torch.bfloat16

pipe = HeartTranscriptorPipeline.from_pretrained(
args.model_path,
device=torch.device("cuda"),
dtype=torch.float16,
device=device,
dtype=dtype,
)
with torch.no_grad():
result = pipe(
Expand Down
20 changes: 17 additions & 3 deletions examples/run_music_generation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from heartlib import HeartMuLaGenPipeline
import argparse

import torch

from heartlib import HeartMuLaGenPipeline


def parse_args():
parser = argparse.ArgumentParser()
Expand All @@ -20,10 +22,22 @@ def parse_args():

if __name__ == "__main__":
args = parse_args()

if torch.backends.mps.is_available():
device = torch.device("mps")
# MPS commonly lacks bf16 support; fp16 is the safest default.
dtype = torch.float16
elif torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
device = torch.device("cpu")
dtype = torch.bfloat16

pipe = HeartMuLaGenPipeline.from_pretrained(
args.model_path,
device=torch.device("cuda"),
dtype=torch.bfloat16,
device=device,
dtype=dtype,
version=args.version,
)
with torch.no_grad():
Expand Down
24 changes: 24 additions & 0 deletions src/heartlib/accelerators/metal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Optional Metal (MPS) fused kernels for Apple Silicon.

This is intentionally self-contained and opt-in:
- No import-time dependency on Xcode toolchains.
- The extension is built on-demand via `torch.utils.cpp_extension` when enabled.
"""

from __future__ import annotations

from .runtime import metal_supported, metal_build_tools_available
from .jit import load_heartlib_metal_ops
from .rmsnorm import metal_rmsnorm_available, rmsnorm_fp16
from .rope import metal_rope_available, rope_fp16

# Public API of the package: exactly the names re-exported from the
# `runtime`, `jit`, `rmsnorm` and `rope` submodules imported above.
__all__ = [
    "metal_supported",
    "metal_build_tools_available",
    "load_heartlib_metal_ops",
    "metal_rmsnorm_available",
    "rmsnorm_fp16",
    "metal_rope_available",
    "rope_fp16",
]

143 changes: 143 additions & 0 deletions src/heartlib/accelerators/metal/jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""JIT build + load the Metal extension.

Built only when explicitly enabled. Requires Xcode command line tools.
"""

from __future__ import annotations

from pathlib import Path
import subprocess
from typing import Any

from .runtime import metal_build_tools_available, metal_supported


def _this_dir() -> Path:
    """Return the absolute, symlink-resolved directory containing this module."""
    module_file = Path(__file__).resolve()
    return module_file.parent


# Process-wide memo for `load_heartlib_metal_ops`.
# Holds the loaded extension module once a build/load has succeeded.
_CACHED_MOD: Any | None = None
# Holds the exception from the first failed attempt; it is re-raised on
# later calls, so a failed build is not retried within the same process.
_CACHED_ERR: Exception | None = None


def _xcrun_find(tool: str) -> str:
    """Return the path `xcrun` reports for *tool* in the `macosx` SDK.

    Runs `xcrun -sdk macosx --find <tool>` with stderr merged into stdout.

    Raises:
        RuntimeError: if `xcrun` succeeds but prints nothing.
        subprocess.CalledProcessError: if `xcrun` exits with a non-zero status
            (raised by `subprocess.check_output`).
    """
    argv = ["xcrun", "-sdk", "macosx", "--find", str(tool)]
    raw = subprocess.check_output(argv, stderr=subprocess.STDOUT)
    resolved = raw.decode("utf-8", errors="replace").strip()
    if resolved:
        return resolved
    raise RuntimeError(f"xcrun returned empty path for tool {tool!r}")


def _run_metal_tool(cmd: list, *, banner: str, failure: str, verbose: bool) -> None:
    """Run one toolchain command; raise `RuntimeError` with its output on failure."""
    if verbose:
        print(banner, " ".join(cmd))
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(
            f"{failure}\n\n"
            f"Command:\n {' '.join(cmd)}\n\n"
            f"stdout:\n{proc.stdout}\n\n"
            f"stderr:\n{proc.stderr}\n"
        )


def _compile_metallib(*, out_dir: Path, verbose: bool) -> Path:
    """Compile minimal Metal shaders -> `heartlib_ops.metallib` in `out_dir`.

    Each `.metal` source next to this module is compiled to an `.air` file
    with `metal -c`, then all `.air` files are linked with `metallib`.

    Args:
        out_dir: Directory that receives the `.air` intermediates and the
            final metallib. Created (with parents) if it does not exist.
        verbose: When true, print each toolchain command before running it.

    Returns:
        Path to `out_dir / "heartlib_ops.metallib"`.

    Raises:
        RuntimeError: if a compile or link command exits non-zero (the
            message includes the command, stdout and stderr), or if `xcrun`
            returns an empty tool path.
    """
    shader_dir = _this_dir()
    sources = [
        shader_dir / "rmsnorm.metal",
        shader_dir / "rope.metal",
    ]
    metallib = out_dir / "heartlib_ops.metallib"

    # Check freshness BEFORE resolving the toolchain: when the existing
    # metallib is at least as new as every shader source there is nothing to
    # build, so the two `xcrun` subprocess launches would be wasted work.
    if metallib.exists():
        built_at = metallib.stat().st_mtime
        if all(built_at >= src.stat().st_mtime for src in sources):
            return metallib

    metal = _xcrun_find("metal")
    metallib_tool = _xcrun_find("metallib")

    out_dir.mkdir(parents=True, exist_ok=True)

    airs = [out_dir / f"{src.stem}.air" for src in sources]
    for src, air in zip(sources, airs, strict=True):
        _run_metal_tool(
            [metal, "-c", str(src), "-o", str(air)],
            banner="[heartlib] compiling Metal shader:",
            failure="Failed to compile Metal shaders.",
            verbose=verbose,
        )

    _run_metal_tool(
        [metallib_tool, *[str(air) for air in airs], "-o", str(metallib)],
        banner="[heartlib] linking Metal metallib:",
        failure="Failed to link Metal metallib (`metallib`).",
        verbose=verbose,
    )
    return metallib


def load_heartlib_metal_ops(*, verbose: bool = False) -> Any:
    """Build (if needed) and import the `heartlib_metal_ops` extension.

    The outcome is memoized at module level: a successful load returns the
    same module on later calls, and a failure re-raises the same exception
    without retrying.

    Args:
        verbose: Forwarded to the metallib compile step and to
            `torch.utils.cpp_extension.load`.

    Returns:
        The loaded extension module.

    Raises:
        RuntimeError: if Metal/MPS is unsupported or the Metal build tools
            are unavailable.
        Exception: whatever the shader compile or extension build raised.
    """
    global _CACHED_MOD, _CACHED_ERR

    # Replay a previous outcome instead of rebuilding.
    if _CACHED_MOD is not None:
        return _CACHED_MOD
    if _CACHED_ERR is not None:
        raise _CACHED_ERR

    # Precondition checks: the first one that fails is cached and raised.
    problem = None
    if not metal_supported():
        problem = "Metal/MPS is not supported on this runtime"
    elif not metal_build_tools_available():
        problem = (
            "Metal build tools unavailable.\n\n"
            "heartlib's fused Metal kernels require Xcode's Metal toolchain (`metal`, `metallib`).\n"
            "Install/select it:\n"
            " - `xcode-select --install`\n"
            " - or install Xcode.app then:\n"
            " `sudo xcode-select -s /Applications/Xcode.app/Contents/Developer`\n"
            " `sudo xcodebuild -license accept`\n\n"
            "Verify:\n"
            " `xcrun -sdk macosx --find metal`\n"
            " `xcrun -sdk macosx --find metallib`\n"
        )
    if problem is not None:
        _CACHED_ERR = RuntimeError(problem)
        raise _CACHED_ERR

    # Imported lazily so that importing this module does not pull in the
    # C++ extension machinery.
    import torch.utils.cpp_extension as cpp_ext

    ext_name = "heartlib_metal_ops"
    try:
        build_root = Path(cpp_ext._get_build_directory(ext_name, verbose=verbose))

        # The metallib is written into the extension's build directory.
        _compile_metallib(out_dir=build_root, verbose=verbose)

        module = cpp_ext.load(
            name=ext_name,
            sources=[str(_this_dir() / "ops.mm")],
            extra_cflags=["-O3", "-std=c++17", "-fobjc-arc"],
            extra_ldflags=["-framework", "Metal", "-framework", "Foundation"],
            with_cuda=False,
            is_python_module=True,
            build_directory=str(build_root),
            verbose=verbose,
        )
    except Exception as exc:
        # Remember the failure so later calls fail fast with the same error.
        _CACHED_ERR = exc
        raise

    _CACHED_MOD = module
    return module

Loading