From 7bee7179190b66ec5e1f18fa8d033a57e9bbc31e Mon Sep 17 00:00:00 2001 From: jiafei96 Date: Tue, 22 Jul 2025 06:13:08 +0000 Subject: [PATCH] fix rocm build error --- csrc/custom_marlin/gptq_marlin/gptq_marlin.cu | 2 +- csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh | 2 +- csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh | 5 +++++ csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu | 2 +- ktransformers/operators/linear.py | 2 +- 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu index 73ba3ddd..e7a762a2 100644 --- a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu @@ -34,7 +34,7 @@ template inline std::string str(T x) { return std::to_string(x); } namespace gptq_marlin { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh index 5b4b0599..e2220bac 100644 --- a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh @@ -38,7 +38,7 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) // No support for async #else diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh index 3e8c3ca9..ccffb2e5 100644 --- a/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh @@ -8,6 +8,11 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; +#endif + namespace gptq_marlin { template class ScalarType {}; diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu b/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu index 4adcbd5a..a6f82daa 100644 --- a/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu @@ -9,7 +9,7 @@ static constexpr int repack_threads = 256; static constexpr int tile_k_size = tile_size; static constexpr int tile_n_size = tile_k_size * 4; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__) template __global__ void marlin_repack_kernel( diff --git a/ktransformers/operators/linear.py b/ktransformers/operators/linear.py index 654c9f98..a60ff594 100644 --- a/ktransformers/operators/linear.py +++ b/ktransformers/operators/linear.py @@ -16,7 +16,6 @@ from torch import Tensor, nn if not torch.xpu.is_available(): import KTransformersOps - import vLLMMarlin from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader from ktransformers.util.utils import InferenceState if not torch.xpu.is_available(): @@ -520,6 +519,7 @@ def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Ten # padding x.shape[0] to avoid CUDA illegal memory access error x, orig_size_m = self._pad_input(x) + import vLLMMarlin x = vLLMMarlin.gptq_marlin_gemm( x, self.marlin_q_w,