Skip to content

Commit

Permalink
fix(python): bug fix for loading built library. (#29)
Browse files Browse the repository at this point in the history
bug fix.
  • Loading branch information
lcy-seso authored Dec 31, 2024
1 parent dc7314d commit 36cf17a
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 20 deletions.
6 changes: 3 additions & 3 deletions include/kernels/flash_attn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ template <typename InType, typename AccType, typename OutType,
void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV,
OutType* dO);

void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
const torch::Tensor& V, torch::Tensor& O,
int64_t m, int64_t n, int64_t k, int64_t p);
/// Tensor-level entry point for the flash-attention forward pass; this is the
/// function bound to the `flash_attention_fwd` operator in src/torch_bind.cc.
///
/// @param Q, K, V  Input tensors, taken by const reference.
/// @param O        Taken by non-const reference — presumably the output
///                 tensor written by the kernel; confirm against the
///                 definition in src/kernels/flash_attn.cu.
/// @param m, n, k, p  Integer size parameters forwarded by the caller; their
///                 exact meaning is not visible from this declaration —
///                 TODO confirm.
void flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
const torch::Tensor& V, torch::Tensor& O, int64_t m,
int64_t n, int64_t k, int64_t p);

} // namespace tilefusion::kernels
4 changes: 2 additions & 2 deletions include/kernels/scatter_nd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ template <typename T>
void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices);

void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices);
/// Tensor-level entry point for scatter_nd; this is the function bound to the
/// `scatter_nd` operator in src/torch_bind.cc. The definition inspects
/// `data.dtype()` and forwards to the matching `scatter_nd<T>` instantiation.
///
/// @param data     Taken by non-const reference; the operator schema marks it
///                 as mutated in place.
/// @param updates  Taken by const reference.
/// @param indices  Taken by const reference.
void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices);

} // namespace tilefusion::kernels
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Issues = "https://github.com/microsoft/TileFusion/issues"
requires = [
"cmake",
"packaging",
"setuptools>=49.4.0",
"setuptools>=64.0.0",
"wheel",
]
build-backend = "setuptools.build_meta"
Expand Down
15 changes: 15 additions & 0 deletions pytilefusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,24 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os

import torch


def _load_library(filename: str) -> bool:
"""Load a shared library from the given filename."""
try:
libdir = os.path.dirname(os.path.dirname(__file__))
torch.ops.load_library(os.path.join(libdir, "pytilefusion", filename))
print(f"Successfully loaded: '{filename}'")
except Exception as error:
print(f"Fail to load library: '{filename}', {error}\n")


_load_library("libtilefusion.so")


def scatter_nd(scatter_data, scatter_indices, scatter_updates):
torch.ops.tilefusion.scatter_nd(
scatter_data, scatter_updates, scatter_indices
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cmake
packaging
setuptools>=49.4.0
setuptools>=64.0.0
torch
wheel
11 changes: 8 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,23 @@ def get_requirements():
class CMakeExtension(Extension):
""" specify the root folder of the CMake projects"""

def __init__(self, name, cmake_lists_dir=".", **kwargs):
def __init__(self, name="tilefusion", cmake_lists_dir=".", **kwargs):
Extension.__init__(self, name, sources=[], **kwargs)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


class CMakeBuildExt(build_ext):
"""launches the CMake build."""

def get_ext_filename(self, name):
    """Return the shared-library file name used for extension ``name``."""
    # The built library is named "lib<name>.so".
    return "lib{}.so".format(name)

def copy_extensions_to_source(self) -> None:
build_py = self.get_finalized_command("build_py")
for ext in self.extensions:
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
source_path = os.path.join(
self.build_lib, self.get_ext_filename(ext.name)
)
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)

target_path = os.path.join(
Expand Down Expand Up @@ -164,7 +169,7 @@ def run(self):
python_requires=">=3.10",
packages=find_packages(exclude=[""]),
install_requires=get_requirements(),
ext_modules=[CMakeExtension("tilefusion")],
ext_modules=[CMakeExtension()],
cmdclass={
"build_ext": CMakeBuildExt,
"clean": Clean,
Expand Down
6 changes: 3 additions & 3 deletions src/kernels/flash_attn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ void run_flash_attention(const InType* dQ, const InType* dK, const InType* dV,
cudaDeviceSynchronize();
}

void custom_flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
const torch::Tensor& V, torch::Tensor& O,
int64_t m, int64_t n, int64_t k, int64_t p) {
void flash_attention_op(const torch::Tensor& Q, const torch::Tensor& K,
const torch::Tensor& V, torch::Tensor& O, int64_t m,
int64_t n, int64_t k, int64_t p) {
using InType = __half;
using AccType = float;
using OutType = __half;
Expand Down
4 changes: 2 additions & 2 deletions src/kernels/scatter_nd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ void scatter_nd(torch::Tensor& data, const torch::Tensor& updates,
slice_size);
}

void custom_scatter_op(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices) {
void scatter_op(torch::Tensor& data, const torch::Tensor& updates,
const torch::Tensor& indices) {
auto dtype = data.dtype();
if (dtype == torch::kFloat32) {
scatter_nd<float>(data, updates, indices);
Expand Down
17 changes: 12 additions & 5 deletions src/torch_bind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@

#include "kernels/mod.hpp"

#include <torch/script.h>

namespace tilefusion {
using namespace tilefusion::kernels;

TORCH_LIBRARY(tilefusion, t) {
t.def("scatter_nd", &custom_scatter_op);
t.def("flash_attention_fwd", &custom_flash_attention_op);
// Register the CUDA-backend implementations of the operators in the
// `tilefusion` library. The operator schemas themselves are declared by the
// TORCH_LIBRARY(tilefusion, ...) block in this file.
TORCH_LIBRARY_IMPL(tilefusion, CUDA, m) {
    m.impl("scatter_nd", scatter_op);
    m.impl("flash_attention_fwd", flash_attention_op);
}

// Declare the operator schemas of the `tilefusion` library. `Tensor(a!)`
// marks an argument that the operator mutates in place.
TORCH_LIBRARY(tilefusion, m) {
    m.def("scatter_nd(Tensor(a!) data, Tensor updates, Tensor indices) -> ()");
    // The bound function takes Q, K and V by const reference and O by
    // non-const reference, so the mutation annotation belongs on O, not Q.
    m.def(
        R"DOC(flash_attention_fwd(
            Tensor Q, Tensor K, Tensor V,
            Tensor(a!) O,
            int m, int n, int k, int p) -> ()
    )DOC");
}
} // namespace tilefusion

0 comments on commit 36cf17a

Please sign in to comment.