Skip to content

Commit 676964f

Browse files
authored
refactor (csrc): Restructure C++ code organization to facilitate adding new kernels (#169)
This PR does not change any logic or functionality in the main branch. It restructures the existing C++ code organization to facilitate the addition of new kernels in subsequent PRs. In general, each customized CUDA kernel involves modifying the following three files at the backend: 1. Registering Python bindings and adding declarations in `ops.cc`. 2. Placing the host's kernel launch function in `x.cc`, where `x` is the operator's name. 3. Implementing CUDA kernels in `x.cuh`.
1 parent 86685d5 commit 676964f

12 files changed

+976
-909
lines changed

.clang-format

+26-11
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,39 @@ UseTab: Never
33
IndentWidth: 2
44
ColumnLimit: 80
55

6+
AccessModifierOffset: -2
7+
68
# Force pointers to the type for C++.
79
DerivePointerAlignment: false
810
PointerAlignment: Left
911

10-
# Reordering #include statements can (and currently will) introduce errors
11-
SortIncludes: false
12-
1312
# Style choices
1413
AlignConsecutiveAssignments: false
1514
AlignConsecutiveDeclarations: false
1615
IndentPPDirectives: BeforeHash
1716

17+
SortIncludes: true
18+
IncludeBlocks: Regroup
1819
IncludeCategories:
19-
- Regex: '^<'
20-
Priority: 4
21-
- Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
22-
Priority: 3
23-
- Regex: '^"(qoda|\.\.)/'
24-
Priority: 2
25-
- Regex: '.*'
26-
Priority: 1
20+
- Regex: '<([A-Za-z0-9\Q/-_\E])+>'
21+
Priority: 4
22+
- Regex: '<(catch2|boost)\/'
23+
Priority: 3
24+
- Regex: '<([A-Za-z0-9.\Q/-_\E])+>'
25+
Priority: 2
26+
- Regex: '"([A-Za-z0-9.\Q/-_\E])+"'
27+
Priority: 1
28+
29+
# If true, empty lines at the start of blocks are kept.
30+
KeepEmptyLinesAtTheStartOfBlocks: false
31+
AllowShortLoopsOnASingleLine: true
32+
AllowShortIfStatementsOnASingleLine: true
33+
Cpp11BracedListStyle: true
34+
# If true, always break after the template<...> of a template declaration.
35+
AlwaysBreakTemplateDeclarations: true
36+
# If false, a function declaration's or function definition's parameters will
37+
# either all be on the same line or will have one line each.
38+
BinPackArguments: true
39+
BreakConstructorInitializersBeforeComma: true
40+
# The maximum number of consecutive empty lines to keep.
41+
MaxEmptyLinesToKeep: 1

.vscode/settings.json

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
{
2-
"yapf.args":["--style={based_on_style: google, column_limit: 80, indent_width: 4}"]
2+
"yapf.args": [
3+
"--style={based_on_style: google, column_limit: 80, indent_width: 4}"
4+
],
5+
"files.associations": {
6+
"optional": "cpp"
7+
}
38
}

csrc/common.h

+14-3
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,26 @@
33
#pragma once
44

55
#include <ATen/cuda/CUDAContext.h>
6+
#include <c10/cuda/CUDAGuard.h>
67
#include <torch/extension.h>
78

89
namespace vptq {
10+
11+
#define CHECK_CUDA(x) \
12+
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
13+
#define CHECK_CONTIGUOUS(x) \
14+
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
15+
#define CHECK_INPUT(x) \
16+
CHECK_CUDA(x); \
17+
CHECK_CONTIGUOUS(x)
18+
19+
#define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__);
20+
921
class OptionalCUDAGuard {
1022
int set_device_ = -1;
1123
int current_device_ = -1;
1224

13-
public:
25+
public:
1426
OptionalCUDAGuard(int device) : set_device_(device) {
1527
cudaError_t err = cudaGetDevice(&current_device_);
1628
std::stringstream ss;
@@ -32,13 +44,12 @@ class OptionalCUDAGuard {
3244
}
3345
};
3446

35-
#define gpuErrchk(ret) gpuAssert((ret), __FILE__, __LINE__);
36-
3747
inline void gpuAssert(cudaError_t code, const char* file, int line) {
3848
if (code != cudaSuccess) {
3949
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
4050
line);
4151
TORCH_CHECK(false, cudaGetErrorString(code));
4252
}
4353
}
54+
4455
} // namespace vptq

csrc/utils.cuh csrc/cuda_utils.cuh

+23-2
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,47 @@
22
// Licensed under the MIT License.
33
#pragma once
44

5+
#include <ATen/cuda/CUDAContext.h>
6+
57
#if defined(USE_ROCM)
6-
#include <hip/hip_fp16.h>
78
#include <hip/hip_bf16.h>
9+
#include <hip/hip_fp16.h>
810

911
#define VPTQ_LDG(arg) __ldg(arg)
1012
#define SHFL_DOWN(val, offset) __shfl_down(val, offset)
1113
#define WARP_SIZE warpSize
14+
1215
typedef __hip_bfloat162 __bfloat162;
1316
typedef __hip_bfloat16 __bfloat16;
1417
#else
15-
#include <cuda_fp16.h>
1618
#include <cuda_bf16.h>
19+
#include <cuda_fp16.h>
1720

1821
#define WARP_SIZE 32
1922
#define VPTQ_LDG(arg) *(arg)
2023
#define SHFL_DOWN(val, offset) __shfl_down_sync(0xffffffff, val, offset)
24+
2125
typedef __nv_bfloat162 __bfloat162;
2226
typedef __nv_bfloat16 __bfloat16;
2327
#endif
2428

2529
namespace vptq {
30+
31+
template <typename T>
32+
struct C10ToNvType {
33+
typedef __bfloat16 type;
34+
};
35+
36+
template <>
37+
struct C10ToNvType<c10::Half> {
38+
typedef __half type;
39+
};
40+
41+
template <>
42+
struct C10ToNvType<float> {
43+
typedef float type;
44+
};
45+
2646
namespace cuda {
2747

2848
constexpr int kBlockSize = 256;
@@ -243,4 +263,5 @@ __device__ __half operator*(const __half& a, const __half& b) {
243263
return __hmul(a, b);
244264
}
245265
#endif
266+
246267
} // namespace vptq

0 commit comments

Comments
 (0)