Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b03225b

Browse files
committedFeb 6, 2025·
Add utils and re-organize code structures.
1 parent 9199281 commit b03225b

13 files changed

+117
-48
lines changed
 

‎csrc/dequant.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#include "common.h"
54
#include "dequant.cuh"
5+
#include "util/common.h"
66

77
namespace vptq {
88

‎csrc/dequant.cuh

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
3+
34
#pragma once
45

5-
#include "cuda_utils.cuh"
6+
#include "util/cuda_utils.cuh"
67

78
namespace vptq {
89

‎csrc/ops.cc

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Licensed under the MIT License.
33

44
/// register bindings for VPTQ APIs in this file. ///
5-
65
#include <torch/extension.h>
76

87
namespace vptq {

‎csrc/quant_gemv.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
#include "common.h"
54
#include "quant_gemv.cuh"
5+
#include "util/common.h"
66

77
namespace vptq {
88

‎csrc/quant_gemv.cuh

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5-
#include "cuda_utils.cuh"
5+
#include "util/cuda_utils.cuh"
6+
#include "util/debug.cuh"
67

78
namespace vptq {
89

‎csrc/quant_gemv_v2.cu

+24-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
1-
#include "common.h"
2-
#include "dispatch_macros.h"
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
34
#include "quant_gemv_v2.cuh"
5+
#include "util/common.h"
6+
#include "util/cuda_utils.cuh"
7+
#include "util/dispatch_macros.h"
48

59
namespace vptq {
610

11+
/**
12+
* @brief Quantized GEMV kernel.
13+
* @param act The input activations.
14+
* @param bias The bias.
15+
* @param indices The indices.
16+
* @param centroids The codebook for the main vector quantized weights.
17+
* Stored in row-major order. Element type: fp16, bf16.
18+
* Shape: (num_codebooks, num_centroids, vec_len).
19+
* @param residual_centroids The residual centroids.
20+
* @param scale_weights The scale weights.
21+
* @param scale_bias The scale bias.
22+
* @param in_features The number of input features.
23+
* @param out_features The number of output features.
24+
*/
725
torch::Tensor quant_gemv_v2(
826
const torch::Tensor& act, const c10::optional<torch::Tensor>& bias,
927
const torch::Tensor& indices, const torch::Tensor& centroids,
@@ -44,6 +62,9 @@ torch::Tensor quant_gemv_v2(
4462
const int64_t num_centroids = centroids.size(1);
4563
const int64_t vec_len = centroids.size(2);
4664

65+
TORCH_CHECK_LT(batch, 16)
66+
<< "In GEMV, the batch size is suggested to be less than 16.";
67+
4768
TORCH_CHECK_EQ(num_codebooks, 1) << "Only support one codebook.";
4869

4970
TORCH_CHECK(
@@ -60,7 +81,7 @@ torch::Tensor quant_gemv_v2(
6081
dim3 blocks(batch, num_codebooks, block_z);
6182
// FIXME(ying): refine the choice of threads in a thread block.
6283
// For test at the moment.
63-
dim3 threads(256, 1, 1);
84+
dim3 threads(256, 1, 1); // four warps in a thread block.
6485

6586
std::cout << "num_codebooks: " << num_codebooks << std::endl
6687
<< "num_centroids: " << num_centroids << std::endl

‎csrc/quant_gemv_v2.cuh

+2-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5-
#include "cuda_utils.cuh"
5+
#include "util/cuda_utils.cuh"
6+
#include "util/debug.cuh"
67

78
namespace vptq {
89

@@ -16,11 +17,6 @@ __global__ void quant_gemv_v2_kernel(
1617
const DType* const __restrict__ scale_weights,
1718
const DType* const __restrict__ scale_bias, int64_t in_features,
1819
int64_t out_features, int64_t vec_len) {
19-
if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 &&
20-
blockIdx.z == 0) {
21-
printf("quant_gemv_v2_kernel\n");
22-
}
23-
2420
return;
2521
}
2622

‎csrc/common.h ‎csrc/util/common.h

File renamed without changes.

‎csrc/util/config.cuh

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if defined(__CUDA_ARCH__)
7+
#define HOST_DEVICE __forceinline__ __host__ __device__
8+
#define DEVICE __forceinline__ __device__
9+
#define HOST __forceinline__ __host__
10+
#else
11+
#define HOST_DEVICE inline
12+
#define DEVICE inline
13+
#define HOST inline
14+
#endif
15+
16+
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
17+
#define CP_ASYNC_SM80_ENABLED
18+
#endif
19+
20+
#if defined(USE_ROCM)
21+
#include <hip/hip_bf16.h>
22+
#include <hip/hip_fp16.h>
23+
24+
#define VPTQ_LDG(arg) __ldg(arg)
25+
#define SHFL_DOWN(val, offset) __shfl_down(val, offset)
26+
#define WARP_SIZE warpSize
27+
28+
typedef __hip_bfloat162 __bfloat162;
29+
typedef __hip_bfloat16 __bfloat16;
30+
31+
#else
32+
#include <cuda_bf16.h>
33+
#include <cuda_fp16.h>
34+
35+
#define WARP_SIZE 32
36+
#define VPTQ_LDG(arg) *(arg)
37+
#define SHFL_DOWN(val, offset) __shfl_down_sync(0xffffffff, val, offset)
38+
39+
typedef __nv_bfloat162 __bfloat162;
40+
typedef __nv_bfloat16 __bfloat16;
41+
42+
#endif

‎csrc/util/copy.cuh

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
namespace vptq {
5+
namespace cutlass_wrapper {} // namespace cutlass_wrapper
6+
} // namespace vptq

‎csrc/cuda_utils.cuh ‎csrc/util/cuda_utils.cuh

+2-34
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,9 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5-
#include <ATen/cuda/CUDAContext.h>
6-
7-
#if defined(USE_ROCM)
8-
#include <hip/hip_bf16.h>
9-
#include <hip/hip_fp16.h>
10-
11-
#define VPTQ_LDG(arg) __ldg(arg)
12-
#define SHFL_DOWN(val, offset) __shfl_down(val, offset)
13-
#define WARP_SIZE warpSize
14-
15-
typedef __hip_bfloat162 __bfloat162;
16-
typedef __hip_bfloat16 __bfloat16;
17-
18-
#else
19-
#include <cuda_bf16.h>
20-
#include <cuda_fp16.h>
5+
#include "config.cuh"
216

22-
#define WARP_SIZE 32
23-
#define VPTQ_LDG(arg) *(arg)
24-
#define SHFL_DOWN(val, offset) __shfl_down_sync(0xffffffff, val, offset)
25-
26-
typedef __nv_bfloat162 __bfloat162;
27-
typedef __nv_bfloat16 __bfloat16;
28-
29-
#endif
30-
31-
#if defined(__CUDA_ARCH__)
32-
#define HOST_DEVICE __forceinline__ __host__ __device__
33-
#define DEVICE __forceinline__ __device__
34-
#define HOST __forceinline__ __host__
35-
#else
36-
#define HOST_DEVICE inline
37-
#define DEVICE inline
38-
#define HOST inline
39-
#endif
7+
#include <ATen/cuda/CUDAContext.h>
408

419
namespace vptq {
4210

‎csrc/util/debug.cuh

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "config.cuh"
7+
8+
#include <cuda_runtime_api.h>
9+
10+
namespace vptq {
11+
12+
DEVICE bool block(int bid) {
13+
int id =
14+
blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
15+
return id == bid;
16+
}
17+
18+
DEVICE bool thread(int tid, int bid) {
19+
int id = threadIdx.x + threadIdx.y * blockDim.x +
20+
threadIdx.z * blockDim.x * blockDim.y;
21+
return id == tid && block(bid);
22+
}
23+
24+
// usage, e.g.
25+
// if (thread(0, 0)) { ... }
26+
// if (thread(37)) { ... }
27+
// if (block(0)) { ... }
28+
29+
DEVICE bool thread(int tid) { return thread(tid, 0); }
30+
31+
DEVICE bool thread0() { return thread(0, 0); }
32+
33+
DEVICE bool block0() { return block(0); }
34+
35+
} // namespace vptq
File renamed without changes.

0 commit comments

Comments
 (0)
Please sign in to comment.