1 change: 1 addition & 0 deletions .github/workflows/smoke-test.yml
@@ -33,6 +33,7 @@ jobs:
      - name: Run smoke test
        run: uv run python -m BackendBench.scripts.main --suite smoke --backend aten


      - name: Run FACTO test
        run: uv run python -m BackendBench.scripts.main --suite facto --backend aten --ops "add.Tensor"

94 changes: 86 additions & 8 deletions BackendBench/backends/directory.py
@@ -9,6 +9,8 @@
import os
from typing import Callable, Dict

from torch.utils.cpp_extension import load_inline

from ..utils import folder_name_to_op_name, get_pytorch_op
from .base import Backend

@@ -47,7 +49,8 @@ def _load_kernels(self):
            impl_files = [
                f
                for f in os.listdir(op_dir)
                if f.endswith(".py") and f.startswith(f"{folder_name}_implementation")
                if (f.endswith(".py") or f.endswith(".cu") or f.endswith(".cpp"))
                and f.startswith(f"{folder_name}_implementation")
            ]
            if not impl_files:
                logger.debug(f"No implementation files found in {op_dir}")
@@ -71,17 +74,13 @@ def _load_kernels(self):

logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/")

    def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
    def _load_python_kernel(self, file_path: str, folder_name: str) -> Callable:
        """
        Dynamically load a kernel implementation function from a Python file.

        Each operator directory should contain implementation files that export a function
        named {op_name}_kernel_impl. This function becomes the kernel implementation
        that gets registered for all variants of the operator.
        Load a kernel implementation from a Python file.

        Args:
            file_path: Path to the Python implementation file
            op_name: Base name of the operator (e.g., "add", "mul", "conv2d")
            folder_name: Base name of the operator (e.g., "add__Tensor")

        Returns:
            Callable kernel implementation function
@@ -99,6 +98,85 @@ def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
        else:
            raise ValueError(f"No function named {kernel_func_name} found in {file_path}")

    def _load_cuda_kernel(self, file_path: str, folder_name: str) -> Callable:
        """
        Load and compile a kernel implementation from CUDA files using load_inline.

        Args:
            file_path: Path to the CUDA implementation file (.cu or .cpp)
            folder_name: Base name of the operator (e.g., "add__Tensor")

        Returns:
            Callable kernel implementation function

        Raises:
            ValueError: If the expected kernel function is not found in the compiled module
        """
        file_dir = os.path.dirname(file_path)
        file_name = os.path.basename(file_path)
        base_name = file_name.rsplit(".", 1)[0]

        cu_file = os.path.join(file_dir, f"{base_name}.cu")
        cpp_file = os.path.join(file_dir, f"{base_name}.cpp")

        cpp_source = ""
        cuda_source = ""

        # Read both files if they exist
        if os.path.exists(cu_file):
            with open(cu_file, "r") as f:
                cuda_source = f.read()

        if os.path.exists(cpp_file):
            with open(cpp_file, "r") as f:
                cpp_source = f.read()

        # Use load_inline for all cases
        module_name = f"{folder_name}_cuda_inline"
        cuda_module = load_inline(
            name=module_name,
            cpp_sources=cpp_source,
            cuda_sources=cuda_source,
            functions=[folder_name],
            no_implicit_headers=True,
        )

        if hasattr(cuda_module, folder_name):
            return getattr(cuda_module, folder_name)
        else:
            raise ValueError(
                f"No function named {folder_name} found in compiled CUDA module from {file_path}"
            )

    def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
        """
        Dynamically load a kernel implementation function from a Python or CUDA file.

        Dispatches to the appropriate loader based on file extension:
        - .py files -> _load_python_kernel
        - .cu or .cpp files -> _load_cuda_kernel

        Args:
            file_path: Path to the implementation file (Python or CUDA)
            folder_name: Base name of the operator (e.g., "add__Tensor")

        Returns:
            Callable kernel implementation function

        Raises:
            ValueError: If the file extension is unsupported or the kernel function is not found
        """
        file_ext = os.path.splitext(file_path)[1]

        if file_ext == ".py":
            return self._load_python_kernel(file_path, folder_name)
        elif file_ext in [".cu", ".cpp"]:
            return self._load_cuda_kernel(file_path, folder_name)
        else:
            raise ValueError(
                f"Unsupported file extension {file_ext} for {file_path}. Expected .py, .cu, or .cpp"
            )

    def __getitem__(self, key):
        if key in self.compiled_kernels:
            return self.compiled_kernels[key]
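For context, the .cpp/.cu pairing that _load_cuda_kernel relies on can be reproduced standalone: the .cpp source carries only a declaration, the .cu source defines the kernel and its host wrapper, and functions=[...] asks load_inline to generate the Python binding. A minimal sketch, assuming a CUDA-enabled PyTorch build with nvcc and ninja available (the "scale" op here is hypothetical, not part of this PR):

import torch
from torch.utils.cpp_extension import load_inline

# .cpp side: declaration only; load_inline generates the pybind11 binding.
cpp_source = "at::Tensor scale(const at::Tensor& x);"

# .cu side: the kernel plus the host wrapper that launches it.
cuda_source = r"""
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void scale_kernel(const float* x, float* out, int size) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < size) out[i] = 2.0f * x[i];
}

at::Tensor scale(const at::Tensor& x) {
    auto out = at::empty_like(x);
    const int size = static_cast<int>(x.numel());
    const int threads = 256;
    const int blocks = (size + threads - 1) / threads;
    scale_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(), out.data_ptr<float>(), size);
    return out;
}
"""

module = load_inline(
    name="scale_inline",
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=["scale"],  # names exposed to Python, mirroring functions=[folder_name]
)

x = torch.ones(1024, device="cuda")
assert torch.allclose(module.scale(x), 2 * x)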
71 changes: 71 additions & 0 deletions BackendBench/scripts/create_simple_test_ops_cuda.py
@@ -0,0 +1,71 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Create simple kernel implementations for 5 common operations.
Each just calls the original PyTorch function.
"""

import argparse
import logging
import os

logger = logging.getLogger(__name__)


def create_add(base_dir):
    os.makedirs(f"{base_dir}/add__Tensor", exist_ok=True)
    with open(f"{base_dir}/add__Tensor/add__Tensor_implementation_v1.cu", "w") as f:
        f.write("""#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void add__Tensor_kernel(
    const float* __restrict__ x,
    const float* __restrict__ y,
    float* __restrict__ output,
    const int size) {
    const auto index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < size) {
        output[index] = x[index] + y[index];
    }
}

at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b) {
    auto out = at::empty_like(a);
    const int numel = static_cast<int>(a.numel());  // test kernel: assumes element count fits in int
    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;
    add__Tensor_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), numel
    );
    return out;
}
""")
    with open(f"{base_dir}/add__Tensor/add__Tensor_implementation_v1.cpp", "w") as f:
        f.write("""#include <torch/extension.h>

at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b);""")
    logger.info("Created add implementation")


def main():
    """Create one simple test operation."""
    parser = argparse.ArgumentParser(description="Create CUDA kernel implementations for testing")
    parser.add_argument(
        "--base-dir",
        default="generated_kernels",
        help="Base directory containing operator subdirectories",
    )

    args = parser.parse_args()

    create_add(args.base_dir)


if __name__ == "__main__":
    main()
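A usage sketch tying the script to the backend; the DirectoryBackend constructor argument and the torch-op keying below are inferred from this diff and may differ in the repo:

import subprocess
import torch

from BackendBench.backends.directory import DirectoryBackend

# Writes generated_kernels/add__Tensor/add__Tensor_implementation_v1.{cu,cpp}
subprocess.run(
    ["python", "-m", "BackendBench.scripts.create_simple_test_ops_cuda",
     "--base-dir", "generated_kernels"],
    check=True,
)

# Loading the directory compiles the .cu/.cpp pair via load_inline
backend = DirectoryBackend("generated_kernels")  # assumed constructor signature
add_impl = backend[torch.ops.aten.add.Tensor]  # assumed key type

a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")
assert torch.allclose(add_impl(a, b), a + b)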
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
"pandas",
"datasets",
"tenacity",
"ninja",
]

[project.optional-dependencies]