diff --git a/.github/workflows/smoke-test.yml b/.github/workflows/smoke-test.yml
index b81d1b2..2adf005 100644
--- a/.github/workflows/smoke-test.yml
+++ b/.github/workflows/smoke-test.yml
@@ -33,6 +33,7 @@ jobs:
      - name: Run smoke test
        run: uv run python -m BackendBench.scripts.main --suite smoke --backend aten
+     - name: Run FACTO test
+       run: uv run python -m BackendBench.scripts.main --suite facto --backend aten --ops "add.Tensor"
diff --git a/BackendBench/backends/directory.py b/BackendBench/backends/directory.py
index 7b43972..4cb9652 100644
--- a/BackendBench/backends/directory.py
+++ b/BackendBench/backends/directory.py
@@ -9,6 +9,8 @@
 import os
 from typing import Callable, Dict
 
+from torch.utils.cpp_extension import load_inline
+
 from ..utils import folder_name_to_op_name, get_pytorch_op
 from .base import Backend
 
@@ -47,7 +49,8 @@ def _load_kernels(self):
         impl_files = [
             f
             for f in os.listdir(op_dir)
-            if f.endswith(".py") and f.startswith(f"{folder_name}_implementation")
+            if (f.endswith(".py") or f.endswith(".cu") or f.endswith(".cpp"))
+            and f.startswith(f"{folder_name}_implementation")
         ]
         if not impl_files:
             logger.debug(f"No implementation files found in {op_dir}")
@@ -71,17 +74,13 @@ def _load_kernels(self):
 
         logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/")
 
-    def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
+    def _load_python_kernel(self, file_path: str, folder_name: str) -> Callable:
         """
-        Dynamically load a kernel implementation function from a Python file.
-
-        Each operator directory should contain implementation files that export a function
-        named {op_name}_kernel_impl. This function becomes the kernel implementation
-        that gets registered for all variants of the operator.
+        Load a kernel implementation from a Python file.
 
         Args:
             file_path: Path to the Python implementation file
-            op_name: Base name of the operator (e.g., "add", "mul", "conv2d")
+            folder_name: Base name of the operator (e.g., "add__Tensor")
 
         Returns:
             Callable kernel implementation function
@@ -99,6 +98,85 @@ def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
         else:
             raise ValueError(f"No function named {kernel_func_name} found in {file_path}")
 
+    def _load_cuda_kernel(self, file_path: str, folder_name: str) -> Callable:
+        """
+        Load and compile a kernel implementation from CUDA files using load_inline.
+
+        Args:
+            file_path: Path to the CUDA implementation file (.cu or .cpp)
+            folder_name: Base name of the operator (e.g., "add__Tensor")
+
+        Returns:
+            Callable kernel implementation function
+
+        Raises:
+            ValueError: If the expected kernel function is not found in the compiled module
+        """
+        file_dir = os.path.dirname(file_path)
+        file_name = os.path.basename(file_path)
+        base_name = file_name.rsplit(".", 1)[0]
+
+        cu_file = os.path.join(file_dir, f"{base_name}.cu")
+        cpp_file = os.path.join(file_dir, f"{base_name}.cpp")
+
+        cpp_source = ""
+        cuda_source = ""
+
+        # Read both files if they exist
+        if os.path.exists(cu_file):
+            with open(cu_file, "r") as f:
+                cuda_source = f.read()
+
+        if os.path.exists(cpp_file):
+            with open(cpp_file, "r") as f:
+                cpp_source = f.read()
+
+        # Use load_inline for all cases
+        module_name = f"{folder_name}_cuda_inline"
+        cuda_module = load_inline(
+            name=module_name,
+            cpp_sources=cpp_source,
+            cuda_sources=cuda_source,
+            functions=[folder_name],
+            no_implicit_headers=True,
+        )
+
+        if hasattr(cuda_module, folder_name):
+            return getattr(cuda_module, folder_name)
+        else:
+            raise ValueError(
+                f"No function named {folder_name} found in compiled CUDA module from {file_path}"
+            )
+
+    def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
+        """
+        Dynamically load a kernel implementation function from a Python or CUDA file.
+
+        Dispatches to the appropriate loader based on file extension:
+        - .py files -> _load_python_kernel
+        - .cu or .cpp files -> _load_cuda_kernel
+
+        Args:
+            file_path: Path to the implementation file (Python or CUDA)
+            folder_name: Base name of the operator folder (e.g., "add__Tensor")
+
+        Returns:
+            Callable kernel implementation function
+
+        Raises:
+            ValueError: If the file extension is unsupported or the kernel function is not found
+        """
+        file_ext = os.path.splitext(file_path)[1]
+
+        if file_ext == ".py":
+            return self._load_python_kernel(file_path, folder_name)
+        elif file_ext in [".cu", ".cpp"]:
+            return self._load_cuda_kernel(file_path, folder_name)
+        else:
+            raise ValueError(
+                f"Unsupported file extension {file_ext} for {file_path}. Expected .py, .cu, or .cpp"
+            )
+
     def __getitem__(self, key):
         if key in self.compiled_kernels:
             return self.compiled_kernels[key]
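Reviewer note: for anyone unfamiliar with `torch.utils.cpp_extension.load_inline`, here is a minimal, self-contained sketch of the compile-and-bind flow that `_load_cuda_kernel` relies on. The `scale_tensor` name and both source strings are invented for illustration; only the `load_inline` call pattern mirrors this PR.

```python
# Hypothetical load_inline example; "scale_tensor" is invented, not part of
# this PR. By default load_inline prepends <torch/extension.h> to cpp_sources
# and <torch/types.h>/<cuda.h>/<cuda_runtime.h> to cuda_sources (this PR opts
# out via no_implicit_headers=True and puts the includes in the files).
import torch
from torch.utils.cpp_extension import load_inline

# .cu side: the kernel plus a C++ launcher that load_inline will compile.
cuda_source = r"""
__global__ void scale_kernel(const float* in, float* out, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] * s;
}

at::Tensor scale_tensor(const at::Tensor& x, double s) {
    auto out = at::empty_like(x);
    const int n = x.numel();
    const int threads = 256;
    scale_kernel<<<(n + threads - 1) / threads, threads>>>(
        x.data_ptr<float>(), out.data_ptr<float>(), static_cast<float>(s), n);
    return out;
}
"""

# .cpp side: a forward declaration, so the pybind11 module that load_inline
# generates in the C++ translation unit can see the CUDA-side symbol. This is
# exactly why create_simple_test_ops_cuda.py writes a paired .cpp file.
cpp_source = "at::Tensor scale_tensor(const at::Tensor& x, double s);"

module = load_inline(
    name="scale_tensor_inline",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=["scale_tensor"],  # names exposed as Python callables
)

x = torch.arange(4, dtype=torch.float32, device="cuda")
print(module.scale_tensor(x, 2.0))  # tensor([0., 2., 4., 6.], device='cuda:0')
```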
diff --git a/BackendBench/scripts/create_simple_test_ops_cuda.py b/BackendBench/scripts/create_simple_test_ops_cuda.py
new file mode 100644
index 0000000..b585a0d
--- /dev/null
+++ b/BackendBench/scripts/create_simple_test_ops_cuda.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Create simple CUDA kernel implementations for testing.
+Currently only add.Tensor is implemented; the kernel reproduces the
+behavior of the original PyTorch operation.
+"""
+
+import argparse
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def create_add(base_dir):
+    os.makedirs(f"{base_dir}/add__Tensor", exist_ok=True)
+    with open(f"{base_dir}/add__Tensor/add__Tensor_implementation_v1.cu", "w") as f:
+        f.write("""#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+__global__ void add__Tensor_kernel(
+    const float* __restrict__ x,
+    const float* __restrict__ y,
+    float* __restrict__ output,
+    const int size) {
+    const auto index = blockIdx.x * blockDim.x + threadIdx.x;
+    if (index < size) {
+        output[index] = x[index] + y[index];
+    }
+}
+
+at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b) {
+    auto out = at::empty_like(a);
+    int64_t numel = a.numel();
+    const int threads = 256;
+    const int blocks = (numel + threads - 1) / threads;
+    add__Tensor_kernel<<<blocks, threads>>>(
+        a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), numel
+    );
+    return out;
+}
+""")
+    with open(f"{base_dir}/add__Tensor/add__Tensor_implementation_v1.cpp", "w") as f:
+        f.write("""#include <torch/extension.h>
+
+at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b);""")
+    logger.info("Created add implementation")
+
+
+def main():
+    """Create one simple test operation."""
+    parser = argparse.ArgumentParser(description="Create CUDA kernel implementations for testing")
+    parser.add_argument(
+        "--base-dir",
+        default="generated_kernels",
+        help="Base directory containing operator subdirectories",
+    )
+
+    args = parser.parse_args()
+
+    create_add(args.base_dir)
+
+
+if __name__ == "__main__":
+    main()
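One coupling worth noting: `_load_cuda_kernel` passes `functions=[folder_name]` to `load_inline`, so the C++ entry point in these generated sources must be named exactly after the operator folder (`add__Tensor` here); a mismatch surfaces as the loader's `ValueError`. A small sketch of that naming chain, with values following this PR's convention:

```python
# Sketch of the naming chain the loader assumes, for aten.add.Tensor.
# The "."-to-"__" mapping matches the folders this PR creates (see
# op_name_to_folder_name in BackendBench.utils).
op_name = "add.Tensor"
folder_name = op_name.replace(".", "__")        # "add__Tensor"
impl_stem = f"{folder_name}_implementation_v1"  # shared stem of the .cu/.cpp pair

# load_inline(..., functions=[folder_name]) therefore requires the .cu file
# to define, and the .cpp file to declare:
#   at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b);
print(folder_name, impl_stem)
```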
+""" + +import argparse +import logging +import os + +logger = logging.getLogger(__name__) + + +def create_add(base_dir): + os.makedirs(f"{base_dir}/add__Tensor", exist_ok=True) + with open(f"{base_dir}/add__Tensor/add__Tensor_implementation_v1.cu", "w") as f: + f.write("""#include +#include + +__global__ void add__Tensor_kernel( + const float* __restrict__ x, + const float* __restrict__ y, + float* __restrict__ output, + const int size) { + const auto index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + output[index] = x[index] + y[index]; + } +} + +at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b) { + auto out = at::empty_like(a); + int64_t numel = a.numel(); + const int threads = 256; + const int blocks = (numel + threads - 1) / threads; + add__Tensor_kernel<<>>( + a.data_ptr(), b.data_ptr(), out.data_ptr(), numel + ); + return out; +} +""") + with open(f"{base_dir}/add__Tensor/add__Tensor_implementation_v1.cpp", "w") as f: + f.write("""#include + +at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b);""") + logger.info("Created add implementation") + + +def main(): + """Create 1 simple test operations.""" + parser = argparse.ArgumentParser(description="Creating cuda kernel implementations for testing") + parser.add_argument( + "--base-dir", + default="generated_kernels", + help="Base directory containing operator subdirectories", + ) + + args = parser.parse_args() + + create_add(args.base_dir) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index ae6bafc..a335d74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "pandas", "datasets", "tenacity", + "ninja", ] [project.optional-dependencies] diff --git a/test/test_directory_backend.py b/test/test_directory_backend.py index a045d0b..1baaef2 100644 --- a/test/test_directory_backend.py +++ b/test/test_directory_backend.py @@ -19,86 +19,90 @@ from BackendBench.backends import DirectoryBackend from BackendBench.utils import op_name_to_folder_name +try: + from torch.utils.cpp_extension import CUDA_HOME +except ImportError: + CUDA_HOME = None -@pytest.fixture(scope="module") -def backend(): + +@pytest.fixture(scope="class") +def backend(request): # Always create correct test implementations, overriding any watermarked ones import subprocess subprocess.run( [sys.executable, "-m", "BackendBench.scripts.create_simple_test_ops"], check=True ) + yield DirectoryBackend(ops_dir="generated_kernels") - return DirectoryBackend(ops_dir="generated_kernels") - - -def test_relu_operation(backend): - relu_op = torch.ops.aten.relu.default - assert relu_op in backend - - our_impl = backend[relu_op] - x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) - result = our_impl(x) - expected = relu_op(x) + import shutil - assert torch.allclose(result, expected) + shutil.rmtree("generated_kernels", ignore_errors=True) -def test_add_operation(backend): - add_op = torch.ops.aten.add.Tensor - assert add_op in backend +class TestDirectoryBackend: + def test_relu_operation(self, backend): + relu_op = torch.ops.aten.relu.default + assert relu_op in backend - our_impl = backend[add_op] - a = torch.tensor([1.0, 2.0, 3.0]) - b = torch.tensor([4.0, 5.0, 6.0]) - result = our_impl(a, b) - expected = add_op(a, b) + our_impl = backend[relu_op] + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + result = our_impl(x) + expected = relu_op(x) - assert torch.allclose(result, expected) + assert torch.allclose(result, expected) + def test_add_operation(self, backend): + add_op = 
diff --git a/test/test_directory_backend.py b/test/test_directory_backend.py
index a045d0b..1baaef2 100644
--- a/test/test_directory_backend.py
+++ b/test/test_directory_backend.py
@@ -19,86 +19,90 @@
 from BackendBench.backends import DirectoryBackend
 from BackendBench.utils import op_name_to_folder_name
 
+try:
+    from torch.utils.cpp_extension import CUDA_HOME
+except ImportError:
+    CUDA_HOME = None
 
-@pytest.fixture(scope="module")
-def backend():
+
+@pytest.fixture(scope="class")
+def backend(request):
     # Always create correct test implementations, overriding any watermarked ones
     import subprocess
 
     subprocess.run(
         [sys.executable, "-m", "BackendBench.scripts.create_simple_test_ops"], check=True
     )
+    yield DirectoryBackend(ops_dir="generated_kernels")
 
-    return DirectoryBackend(ops_dir="generated_kernels")
-
-
-def test_relu_operation(backend):
-    relu_op = torch.ops.aten.relu.default
-    assert relu_op in backend
-
-    our_impl = backend[relu_op]
-    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
-    result = our_impl(x)
-    expected = relu_op(x)
+    import shutil
 
-    assert torch.allclose(result, expected)
+    shutil.rmtree("generated_kernels", ignore_errors=True)
 
 
-def test_add_operation(backend):
-    add_op = torch.ops.aten.add.Tensor
-    assert add_op in backend
+class TestDirectoryBackend:
+    def test_relu_operation(self, backend):
+        relu_op = torch.ops.aten.relu.default
+        assert relu_op in backend
 
-    our_impl = backend[add_op]
-    a = torch.tensor([1.0, 2.0, 3.0])
-    b = torch.tensor([4.0, 5.0, 6.0])
-    result = our_impl(a, b)
-    expected = add_op(a, b)
+        our_impl = backend[relu_op]
+        x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+        result = our_impl(x)
+        expected = relu_op(x)
 
-    assert torch.allclose(result, expected)
+        assert torch.allclose(result, expected)
 
+    def test_add_operation(self, backend):
+        add_op = torch.ops.aten.add.Tensor
+        assert add_op in backend
 
-def test_mul_operation(backend):
-    mul_op = torch.ops.aten.mul.Tensor
-    assert mul_op in backend
+        our_impl = backend[add_op]
+        a = torch.tensor([1.0, 2.0, 3.0])
+        b = torch.tensor([4.0, 5.0, 6.0])
+        result = our_impl(a, b)
+        expected = add_op(a, b)
 
-    our_impl = backend[mul_op]
-    a = torch.tensor([1.0, 2.0, 3.0])
-    b = torch.tensor([4.0, 5.0, 6.0])
-    result = our_impl(a, b)
-    expected = mul_op(a, b)
+        assert torch.allclose(result, expected)
 
-    assert torch.allclose(result, expected)
+    def test_mul_operation(self, backend):
+        mul_op = torch.ops.aten.mul.Tensor
+        assert mul_op in backend
 
+        our_impl = backend[mul_op]
+        a = torch.tensor([1.0, 2.0, 3.0])
+        b = torch.tensor([4.0, 5.0, 6.0])
+        result = our_impl(a, b)
+        expected = mul_op(a, b)
 
-def test_abs_operation(backend):
-    abs_op = torch.ops.aten.abs.default
-    assert abs_op in backend
+        assert torch.allclose(result, expected)
 
-    our_impl = backend[abs_op]
-    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
-    result = our_impl(x)
-    expected = abs_op(x)
+    def test_abs_operation(self, backend):
+        abs_op = torch.ops.aten.abs.default
+        assert abs_op in backend
 
-    assert torch.allclose(result, expected)
+        our_impl = backend[abs_op]
+        x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+        result = our_impl(x)
+        expected = abs_op(x)
 
+        assert torch.allclose(result, expected)
 
-def test_sum_operation(backend):
-    sum_op = torch.ops.aten.sum.default
-    assert sum_op in backend
+    def test_sum_operation(self, backend):
+        sum_op = torch.ops.aten.sum.default
+        assert sum_op in backend
 
-    our_impl = backend[sum_op]
-    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
-    result = our_impl(x)
-    expected = sum_op(x)
+        our_impl = backend[sum_op]
+        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+        result = our_impl(x)
+        expected = sum_op(x)
 
-    assert torch.allclose(result, expected)
+        assert torch.allclose(result, expected)
 
+    def test_backend_loading(self, backend):
+        loaded_ops = set(backend.compiled_kernels.keys())
+        assert len(loaded_ops) > 0
 
-def test_backend_loading(backend):
-    loaded_ops = set(backend.compiled_kernels.keys())
-    assert len(loaded_ops) > 0
-
-    if os.path.exists("generated_kernels"):
+        assert os.path.exists("generated_kernels")
         dirs = [
             d
             for d in os.listdir("generated_kernels")
@@ -106,15 +110,80 @@
         ]
         assert len(dirs) > 0
 
+    def test_kernel_directories_exist(self, backend):
+        assert os.path.exists("generated_kernels")
+
+        expected_ops = ["relu.default", "add.Tensor", "mul.Tensor", "abs.default", "sum.default"]
+        for expected_op in expected_ops:
+            expected_dir = op_name_to_folder_name(expected_op)
+            dir_path = os.path.join("generated_kernels", expected_dir)
+            assert os.path.isdir(dir_path)
+
+            py_files = [f for f in os.listdir(dir_path) if f.endswith(".py")]
+            assert len(py_files) > 0
+
+
+@pytest.fixture(scope="class")
+def backend_cuda(request):
+    import subprocess
+
+    # Access class attribute via request.cls
+    base_dir = getattr(request.cls, "base_dir", "generated_kernels_cuda")
+    subprocess.run(
+        [
+            sys.executable,
+            "-m",
+            "BackendBench.scripts.create_simple_test_ops_cuda",
+            "--base-dir",
+            base_dir,
+        ],
+        check=True,
+    )
+    backend_instance = DirectoryBackend(ops_dir=base_dir)
+
+    yield backend_instance
+
+    import shutil
+
+    shutil.rmtree(base_dir, ignore_errors=True)
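The `request.cls` lookup above is the standard pytest idiom for letting a class-scoped fixture read configuration off the consuming test class; a minimal illustration with invented names:

```python
# Minimal illustration of the request.cls idiom used by backend_cuda: a
# class-scoped fixture sees the test class via request.cls and can read
# attributes such as base_dir from it. All names here are invented.
import pytest


@pytest.fixture(scope="class")
def workdir(request):
    # request.cls is the test class when the fixture is used inside one
    yield getattr(request.cls, "base_dir", "default_dir")


class TestWorkdir:
    base_dir = "custom_dir"

    def test_reads_class_attribute(self, workdir):
        assert workdir == "custom_dir"
```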
"generated_kernels_cuda" + + def test_add_operation(self, backend_cuda): + add_op = torch.ops.aten.add.Tensor + assert add_op in backend_cuda + + our_impl = backend_cuda[add_op] + a = torch.tensor([1.0, 2.0, 3.0]).cuda() + b = torch.tensor([4.0, 5.0, 6.0]).cuda() + result = our_impl(a, b) + expected = add_op(a, b) + + assert torch.allclose(result, expected) + + def test_backend_loading(self, backend_cuda): + loaded_ops = set(backend_cuda.compiled_kernels.keys()) + assert len(loaded_ops) > 0 + os.path.exists(self.base_dir) + + dirs = [ + d for d in os.listdir(self.base_dir) if os.path.isdir(os.path.join(self.base_dir, d)) + ] + assert len(dirs) > 0 -def test_kernel_directories_exist(backend): - assert os.path.exists("generated_kernels") + def test_kernel_directories_exist(self, backend_cuda): + assert os.path.exists(self.base_dir) - expected_ops = ["relu.default", "add.Tensor", "mul.Tensor", "abs.default", "sum.default"] - for expected_op in expected_ops: - expected_dir = op_name_to_folder_name(expected_op) - dir_path = os.path.join("generated_kernels", expected_dir) - assert os.path.isdir(dir_path) + expected_dirs = ["add__Tensor"] + for expected_dir in expected_dirs: + dir_path = os.path.join(self.base_dir, expected_dir) + assert os.path.isdir(dir_path) - py_files = [f for f in os.listdir(dir_path) if f.endswith(".py")] - assert len(py_files) > 0 + cuda_files = [ + f for f in os.listdir(dir_path) if f.endswith(".cu") or f.endswith(".cpp") + ] + assert len(cuda_files) > 0