
Commit 692a41c

fix(csrc): Remove strong dependency on specific Torch version. (#166)
A wheel built with Torch version A cannot be used in a local environment where a different Torch version B is installed. This PR fixes that issue by removing the extension's hard dependency on the exact Torch version it was compiled against.
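Concretely, the wheel previously registered its kernels through Torch's operator library and loaded them with torch.ops.load_library, which ties the compiled artifact to the build-time Torch version; this commit rebinds the kernels with plain Pybind11 and loads them with an ordinary Python import. A minimal sketch of the before/after loading paths (module and function names are taken from the diffs below; assumes a built VPTQ package is installed):

    import torch

    # Before this commit: the extension registered ops into Torch's
    # dispatcher, so loading could fail whenever the installed Torch
    # version differed from the version the wheel was built with.
    # torch.ops.load_library("libvptq.so")
    # gemm = torch.ops.vptq.gemm

    # After this commit: the extension is a plain pybind11 module shipped
    # inside the package and resolved by the normal import system.
    import vptq.libvptq as vptq_ops
    gemm = vptq_ops.gemm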
1 parent 02c4f47 commit 692a41c

9 files changed: +41 −108 lines changed

csrc/CMakeLists.txt (+4)

@@ -23,6 +23,10 @@ set_target_properties(
 # https://github.com/pytorch/pytorch/issues/13541
 target_compile_definitions(${TARGET} PUBLIC _GLIBCXX_USE_CXX11_ABI=0)
 
+# The find_package(Torch) command does not expose PyTorch's Python bindings.
+# However, when using Pybind11, we need to link against these bindings.
+list(APPEND TORCH_LIBRARIES "${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so")
+
 target_compile_options(
   ${TARGET}
   PUBLIC $<$<COMPILE_LANGUAGE:CUDA>:

csrc/common.h (+2)

@@ -5,6 +5,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
+namespace vptq {
 class OptionalCUDAGuard {
   int set_device_ = -1;
   int current_device_ = -1;
@@ -40,3 +41,4 @@ inline void gpuAssert(cudaError_t code, const char* file, int line) {
     TORCH_CHECK(false, cudaGetErrorString(code));
   }
 }
+}  // namespace vptq

csrc/dequant_impl_packed.cu (+2)

@@ -7,6 +7,7 @@
 #include "common.h"
 #include "utils.cuh"
 
+namespace vptq {
 template <typename T>
 struct C10ToNvType {
   typedef __bfloat16 type;
@@ -734,3 +735,4 @@ torch::Tensor launch_gemv_outliers_cuda_packkernel(
   }
   return output;
 }
+}  // namespace vptq

csrc/ops.cc (+8 −36)

@@ -6,8 +6,8 @@
 #include <c10/cuda/CUDAGuard.h>
 
 #include <torch/extension.h>
-#include <torch/library.h>
 
+namespace vptq {
 #define CHECK_CUDA(x) \
   TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) \
@@ -157,41 +157,13 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
 
   return output;
 }
+}  // namespace vptq
 
-TORCH_LIBRARY_IMPL(vptq, CUDA, m) {
-  m.impl("dequant", dequant);
-  m.impl("gemm", wqA16Gemm);
-}
+// NOTE: DO NOT change the module name "libvptq" here. It must match how
+// the module is loaded in the Python codes.
+PYBIND11_MODULE(libvptq, m) {
+  m.doc() = "VPTQ customized kernels.";
 
-TORCH_LIBRARY(vptq, m) {
-  m.def(
-      R"DOC(dequant(Tensor q_indice,
-                    Tensor centroids,
-                    Tensor? q_indice_residual,
-                    Tensor? residual_centroids,
-                    Tensor? q_indice_outliers,
-                    Tensor? outliers_centroids,
-                    Tensor? invperm,
-                    Tensor weight_scale,
-                    Tensor weight_bias,
-                    int groupsize,
-                    int in_features,
-                    int out_features) -> Tensor
-)DOC");
-  m.def(
-      R"DOC(gemm(Tensor input,
-                 Tensor q_indice,
-                 Tensor centroids,
-                 Tensor? q_indice_residual,
-                 Tensor? residual_centroids,
-                 Tensor? q_indice_outliers,
-                 Tensor? outliers_centroids,
-                 Tensor? invperm,
-                 Tensor weight_scale,
-                 Tensor weight_bias,
-                 Tensor? bias,
-                 int groupsize,
-                 int in_features,
-                 int out_features) -> Tensor
-)DOC");
+  m.def("dequant", &vptq::dequant, "vptq customized dequantization kernel.");
+  m.def("gemm", &vptq::wqA16Gemm, "vptq customized dequantized gemv kernel.");
 }
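Because the kernels are now bound with Pybind11, the strings passed to m.doc() and m.def above surface as ordinary Python docstrings, which gives a quick smoke test for the bindings. A minimal sketch (assumes the compiled libvptq extension from this commit is importable):

    import vptq.libvptq as vptq_ops

    # Set by m.doc() in PYBIND11_MODULE above.
    print(vptq_ops.__doc__)  # -> "VPTQ customized kernels."

    # Both kernels registered via m.def should be callable attributes.
    assert callable(vptq_ops.dequant)
    assert callable(vptq_ops.gemm)
    print(vptq_ops.gemm.__doc__)  # includes the docstring passed to m.def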

csrc/utils.cuh (+2 −46)

@@ -22,6 +22,7 @@ typedef __nv_bfloat162 __bfloat162;
 typedef __nv_bfloat16 __bfloat16;
 #endif
 
+namespace vptq {
 namespace cuda {
 
 constexpr int kBlockSize = 256;
@@ -93,20 +94,8 @@ __device__ __forceinline__ void ldg_vec_x(
   const int2* src = (const int2*)src_u32;
   if constexpr (GROUPSIZE == 2) {
     *dst_u32 = VPTQ_LDG(src_u32);
-    // uint32_t* dec = (uint32_t*)dst;
-    // asm volatile (
-    //     "ld.cg.global.v2.u32 {%0, %1}, [%2];"
-    //     : "=r"(dec[0]), "=r"(dec[1])
-    //     : "l"((const void*)src)
-    // );
   } else if constexpr (GROUPSIZE == 4) {
     *dst = VPTQ_LDG(src);
-    // uint32_t* dec = (uint32_t*)dst;
-    // asm volatile (
-    //     "ld.cg.global.v2.u32 {%0, %1}, [%2];"
-    //     : "=r"(dec[0]), "=r"(dec[1])
-    //     : "l"((const void*)src)
-    // );
   } else if constexpr (GROUPSIZE == 6) {
     dst_u32[0] = VPTQ_LDG(src_u32);
     dst_u32[1] = VPTQ_LDG(src_u32 + 1);
@@ -116,12 +105,6 @@ __device__ __forceinline__ void ldg_vec_x(
   } else if constexpr (GROUPSIZE == 16) {
     *(int4*)dst = VPTQ_LDG((const int4*)src);
     *(int4*)(dst + 2) = VPTQ_LDG((const int4*)(src + 2));
-    // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
-    //              : "=r"(dst_u32[0]), "=r"(dst_u32[1]), "=r"(dst_u32[2]),
-    //                "=r"(dst_u32[3]) : "l"((const void*)src_u32));
-    // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
-    //              : "=r"(dst_u32[4]), "=r"(dst_u32[5]), "=r"(dst_u32[6]),
-    //                "=r"(dst_u32[7]) : "l"((const void*)(src_u32 + 4)));
   } else if constexpr (GROUPSIZE == 12) {
     if (uint64_t(src) % 16) {
       dst[0] = VPTQ_LDG(src);
@@ -132,38 +115,11 @@ __device__ __forceinline__ void ldg_vec_x(
       *(int4*)dst = VPTQ_LDG((int4*)(src));
       dst[2] = VPTQ_LDG((src + 2));
     }
-    // dst[0] = VPTQ_LDG(src);
-    // dst[1] = VPTQ_LDG((src+1));
-    // dst[2] = VPTQ_LDG((src+2));
-
-    // uint32_t* dec = (uint32_t*)dst;
-    // asm volatile (
-    //     "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
-    //     : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
-    //     : "l"((const void*)src)
-    // );
-    // asm volatile (
-    //     "ld.cg.global.v2.u32 {%0, %1}, [%2];"
-    //     : "=r"(dec[4]), "=r"(dec[5])
-    //     : "l"((const void*)src)
-    // );
   } else if constexpr (GROUPSIZE == 24) {
     *((int4*)(dst)) = VPTQ_LDG((const int4*)(src));
     *(((int4*)(dst)) + 1) = VPTQ_LDG(((const int4*)(src)) + 1);
     *(((int4*)(dst)) + 2) = VPTQ_LDG(((const int4*)(src)) + 2);
   } else if constexpr (GROUPSIZE == 32) {
-    // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
-    //              : "=r"(dst_u32[0]), "=r"(dst_u32[1]), "=r"(dst_u32[2]),
-    //                "=r"(dst_u32[3]) : "l"((const void*)src_u32));
-    // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
-    //              : "=r"(dst_u32[4]), "=r"(dst_u32[5]), "=r"(dst_u32[6]),
-    //                "=r"(dst_u32[7]) : "l"((const void*)(src_u32 + 4)));
-    // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
-    //              : "=r"(dst_u32[8]), "=r"(dst_u32[9]), "=r"(dst_u32[10]),
-    //                "=r"(dst_u32[11]) : "l"((const void*)(src_u32 + 8)));
-    // asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
-    //              : "=r"(dst_u32[12]), "=r"(dst_u32[13]), "=r"(dst_u32[14]),
-    //                "=r"(dst_u32[15]) : "l"((const void*)(src_u32 + 12)));
     *((int4*)(dst)) = VPTQ_LDG((const int4*)(src));
     *(((int4*)(dst)) + 1) = VPTQ_LDG(((const int4*)(src)) + 1);
     *(((int4*)(dst)) + 2) = VPTQ_LDG(((const int4*)(src)) + 2);
@@ -203,7 +159,6 @@ template <typename T>
 __forceinline__ T ceil_div(T a, T b) {
   return (a + b - 1) / b;
 }
-
 }  // namespace cuda
 
 template <typename T>
@@ -288,3 +243,4 @@ __device__ __half operator*(const __half& a, const __half& b) {
   return __hmul(a, b);
 }
 #endif
+}  // namespace vptq

pyproject.toml (+3 −1)

@@ -20,7 +20,7 @@ classifiers = [
 # `pyproject.toml`'s `dependencies` field.
 # Make sure to keep this field in sync with what is in `requirements.txt`.
 dependencies = [
-    "torch",
+    "torch>=2.3.0",
     "datasets",
     "transformers>=4.45",
     "safetensors",
@@ -29,6 +29,8 @@ dependencies = [
     "gradio",
     "plotly==5.9.0",
     "pynvml",
+    "tqdm",
+    "sentence_transformers",
 ]
 
 [project.urls]

requirements.txt (+2 −1)

@@ -1,7 +1,7 @@
 cmake
 packaging
 setuptools>=64.0.0
-torch
+torch>=2.3.0
 wheel
 datasets
 transformers>=4.45
@@ -12,3 +12,4 @@ gradio
 plotly==5.9.0
 pynvml
 tqdm
+sentence_transformers

vptq/ops/quant_gemm.py (+18 −23)

@@ -9,35 +9,29 @@
 ]
 
 import math
-import os
 
 import torch
 from torch.nn import functional as F
 
 from vptq.utils.pack import unpack_index_tensor
 
+__cuda_ops_installed = False
 
-def _load_library(filename: str) -> bool:
-    """Load a shared library from the given filename."""
-    try:
-        libdir = os.path.dirname(os.path.dirname(__file__))
-        torch.ops.load_library(os.path.join(libdir, filename))
-        print(f"Successfully loaded: '{filename}'")
-        return True
-    except Exception as error:
-        print((
-            f"{error}\n"
-            "!!! Warning !!!: CUDA kernels are not found, "
-            "please check CUDA and VPTQ installation."
-        ))
-        print((
-            "!!! Warning !!!: Running on Torch implementations, "
-            "which is extremely slow."
-        ))
-        return False
+try:
+    import vptq.libvptq as vptq_ops
 
-
-__cuda_ops_installed: bool = _load_library("libvptq.so")
+    print("Successfully loaded VPTQ CUDA kernels.")
+    __cuda_ops_installed = True
+except Exception as error:
+    print((
+        f"{error}\n"
+        "!!! Warning !!!: CUDA kernels are not found, "
+        "please check CUDA and VPTQ installation."
+    ))
+    print((
+        "!!! Warning !!!: Running on Torch implementations, "
+        "which is extremely slow."
+    ))
 
 
 def dequant(
@@ -212,7 +206,7 @@ def quant_gemm(
     invert_perm = invert_perm.to(torch.uint16).view(torch.int16)
 
     if (x.numel() // x.shape[-1] < 3) and __cuda_ops_installed:
-        out = torch.ops.vptq.gemm(
+        out = vptq_ops.gemm(
            x,
            indices,
            centroids_,
@@ -231,7 +225,8 @@ def quant_gemm(
         return out
     else:
         if __cuda_ops_installed:
-            weight = torch.ops.vptq.dequant(
+
+            weight = vptq_ops.dequant(
                 indices,
                 centroids_,
                 residual_indices,
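Note that the Python side no longer computes a library path by hand or calls torch.ops.load_library: the extension is resolved by the regular import system, and the try/except around the import keeps the (much slower) pure-Torch fallback available whenever the compiled kernels are absent.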

vptq/utils/pack.py (−1)

@@ -12,7 +12,6 @@
 import tqdm
 from sentence_transformers.SentenceTransformer import SentenceTransformer
 
-# import time
 import vptq
 
 logging.basicConfig(
