
Commit f7c4ee9

Delete unnecessary code.

1 parent 1937767 commit f7c4ee9

File tree

7 files changed: +31 -239 lines

csrc/dequant.cu (+2)

@@ -208,7 +208,9 @@ torch::Tensor launch_deqantize_outliers_cuda_packkernel(
     TORCH_CHECK(false, "un-supported base_groupsize:" +
                            std::to_string(base_groupsize));
   }
+
 #undef CASE_DispatchDequantWithOutliers
+
   if (out_ouf_inf) {
     return output;
   } else {

csrc/dequant.h (-20)

This file was deleted.

csrc/ops.cc (+29 -8)

@@ -1,22 +1,43 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-/// register VPTQ APIs bindings in this file. ///
+/// register bindings for VPTQ APIs in this file. ///
 
-#include "dequant.h"
-#include "quant_gemv.h"
+#include <torch/extension.h>
+
+namespace vptq {
+
+torch::Tensor dequant(const torch::Tensor& q_indice,
+                      const torch::Tensor& centroids,
+                      const c10::optional<torch::Tensor>& q_indice_residual,
+                      const c10::optional<torch::Tensor>& residual_centroids,
+                      const c10::optional<torch::Tensor>& q_indice_outliers,
+                      const c10::optional<torch::Tensor>& outliers_centroids,
+                      const c10::optional<torch::Tensor>& invperm,
+                      const torch::Tensor& weight_scale,
+                      const torch::Tensor& weight_bias, int64_t groupsize,
+                      int64_t in_features, int64_t out_features);
+
+torch::Tensor wquant_act16_gemv(
+    const torch::Tensor& input, const torch::Tensor& q_indice,
+    const torch::Tensor& centroids,
+    const c10::optional<torch::Tensor>& q_indice_residual,
+    const c10::optional<torch::Tensor>& residual_centroids,
+    const c10::optional<torch::Tensor>& q_indice_outliers,
+    const c10::optional<torch::Tensor>& outliers_centroids,
+    const c10::optional<torch::Tensor>& invperm,
+    const torch::Tensor& weight_scale, const torch::Tensor& weight_bias,
+    const c10::optional<torch::Tensor>& bias, int64_t in_features,
+    int64_t out_features);
+
+}  // namespace vptq
 
 // NOTE: DO NOT change the module name "libvptq" here. It must match how
 // the module is loaded in the Python codes.
 PYBIND11_MODULE(libvptq, m) {
   m.doc() = "VPTQ customized kernels.";
 
-  // v1 kernels.
   m.def("dequant", &vptq::dequant, "vptq customized dequantization kernel.");
   m.def("quant_gemv", &vptq::wquant_act16_gemv,
         "vptq customized dequantized gemv kernel.");
-
-  // v2 kernels.
-  m.def("quant_gemv_v2", &vptq::quant_gemv_v2,
-        "vptq customized quantized gemm kernel.");
 }
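For context on the NOTE above: the name given to PYBIND11_MODULE must agree with the name the Python side uses when it loads the extension. Below is a minimal sketch of one common way such an extension can be loaded, assuming it has been compiled to a shared object; load_libvptq and the .so path are illustrative assumptions, not VPTQ's actual loading code.

# Minimal sketch, assuming the extension is compiled to a shared object.
import importlib.util

import torch  # import torch first so the extension's torch symbols resolve


def load_libvptq(so_path: str):
    """Import a compiled pybind11 extension directly from its .so file."""
    spec = importlib.util.spec_from_file_location("libvptq", so_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


# lib = load_libvptq("/path/to/libvptq.so")  # hypothetical path
# lib.dequant(...)     # binding registered above
# lib.quant_gemv(...)  # binding registered above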

csrc/quant_gemv.cu (-35)

@@ -286,39 +286,4 @@ torch::Tensor wquant_act16_gemv(
   return output;
 }
 
-torch::Tensor quant_gemv_v2(
-    const torch::Tensor& activations, const c10::optional<torch::Tensor>& bias,
-    const torch::Tensor& indices, const torch::Tensor& centroids,
-    const c10::optional<torch::Tensor>& residual_centroids,
-    const torch::Tensor& scale_weights, const torch::Tensor& scale_bias,
-    int64_t in_features, int64_t out_features) {
-  CHECK_INPUT(indices);
-  CHECK_INPUT(centroids);
-  CHECK_INPUT(scale_weights);
-  CHECK_INPUT(scale_bias);
-
-  int64_t ndim = activations.ndimension();
-  TORCH_CHECK(ndim == 3, "activations must be a 3D Tensor, but got: ",
-              activations.sizes());
-
-  const int64_t batch = activations.size(0);
-
-  std::cout << "batch: " << batch << std::endl;
-
-  const int64_t num_codebooks = centroids.size(0);
-  const int64_t num_centroids = centroids.size(1);
-  const int64_t vec_len = centroids.size(2);
-
-  std::cout << "num_codebooks: " << num_codebooks << std::endl
-            << "num_centroids: " << num_centroids << std::endl
-            << "vec_len: " << vec_len << std::endl;
-
-  torch::Tensor output;
-  output = at::empty({in_features, out_features}, centroids.options());
-
-  // auto stream = at::cuda::getCurrentCUDAStream().stream();
-
-  return output;
-}
-
 }  // namespace vptq

csrc/quant_gemv.h (-29)

This file was deleted.

vptq/ops/quant_gemm.py (-69)

@@ -270,72 +270,3 @@ def quant_gemm(
     )
     out = F.linear(x, weight, bias)
     return out
-
-
-def quant_gemv_v2(
-    x: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    indices: torch.Tensor,
-    centroids: torch.Tensor,
-    residual_centroids: Optional[torch.Tensor],
-    scale_weights: Optional[torch.Tensor],
-    scale_bias: Optional[torch.Tensor],
-    vector_len: int,
-    num_codebooks: int,
-    num_centroids: int,
-    num_residual_centroids: int,
-    in_features: int,
-    out_features: int,
-) -> torch.Tensor:
-    """Dequantize the input tensor and perform the GEMV operation.
-
-    Args:
-        x: Tensor[fp16|bf16], has a shape of (batch_size, sequence_length,
-            in_features). NOTE that `batch_size` here represents the number
-            of sequences, not tokens.
-        bias: (optional) Tensor[fp16|bf16], has a shape of (1, out_features).
-        indices: Tensor[int16], the original input tensor is flattened into a
-            vector with a shape of (1, numel). Then, internally, it will be
-            reshaped into a 3-D tensor with a shape of
-            (num_codebooks, num_indices, packed_groupsize).
-            NOTE: If the residual quantization component is enabled,
-            indices for the main quantization component and the residual
-            quantization component are packed together into this single
-            input tensor.
-        centroids: Tensor[fp16|bf16], the original input tensor is flattened
-            into a vector that has a shape of (1, numel), and then reshaped
-            internally into a 3-D tensor with a shape of
-            (num_codebooks, num_centroids, vector_len).
-        residual_centroids: (optional) Tensor[fp16|bf16], has a shape of
-            (num_codebooks, num_residual_centroids, vector_len).
-        scale_weights: (optional) Tensor[fp16|bf16], has a shape of
-            (in_features, 1), the scale factor for the quantized weight.
-        scale_bias: (optional) Tensor[fp16|bf16], has a shape of
-            (in_features, 1), the bias factor for the quantized weight.
-        vector_len: int, the length of the vector in vector quantization.
-        num_codebooks: int, the number of codebooks.
-        num_centroids: int, the number of centroids.
-        num_residual_centroids: int, the number of residual centroids.
-        in_features: int, the number of input features.
-        out_features: int, the number of output features.
-    """
-    centroids_ = centroids.view(num_codebooks, num_centroids, vector_len)
-
-    residual_centroids_ = None
-    if residual_centroids is not None:
-        shape = (num_codebooks, num_residual_centroids, vector_len)
-        residual_centroids_ = residual_centroids.view(shape)
-
-    out = vptq_ops.quant_gemv_v2(
-        x,
-        bias,
-        indices,
-        centroids_,
-        residual_centroids_,
-        scale_weights,
-        scale_bias,
-        in_features,
-        out_features,
-    )
-    return out
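The removed wrapper mostly implemented the reshape convention spelled out in its docstring: flattened codebook buffers are viewed back into 3-D (num_codebooks, num_centroids, vector_len) tensors before the kernel call. A small self-contained sketch of that convention follows; the sizes are illustrative, not VPTQ defaults.

import torch

# Illustrative sizes only; real models choose their own codebook shapes.
num_codebooks, num_centroids, vector_len = 1, 8192, 8

# A codebook stored flat as (1, numel), as the removed docstring describes,
flat = torch.randn(1, num_codebooks * num_centroids * vector_len,
                   dtype=torch.float16)

# is viewed back into (num_codebooks, num_centroids, vector_len) before the
# kernel call; .view() requires the flat buffer to be contiguous.
centroids = flat.view(num_codebooks, num_centroids, vector_len)
assert centroids.shape == (num_codebooks, num_centroids, vector_len)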

vptq/tests/ops/test_quant_gemm.py (-78)

This file was deleted.
