
Commit 9f34dda

feat: integrate CNNL and add unary/binary/softmax/batchnorm/reduce/transpose/pooling operators
1 parent 7f82d74 commit 9f34dda

40 files changed: +1918 -9 lines

src/02hardware/CMakeLists.txt (+2 -6)

@@ -3,14 +3,10 @@ project(hardware VERSION 0.0.0 LANGUAGES CXX)
 message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION})
 
 # Source files
-file(GLOB HARDWARE_SRC src/*.cc src/*.cpp src/devices/cpu/*.cc)
+file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp)
 
 if(USE_CUDA)
-    file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu src/devices/nvidia/*.cc)
-endif()
-
-if(USE_BANG)
-    file(GLOB_RECURSE HARDWARE_BANG_SRC src/devices/mlu/*.cc)
+    file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu)
 endif()
 
 add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC} ${HARDWARE_BANG_SRC})

src/02hardware/src/device_manager.cpp (+2)

@@ -1,6 +1,7 @@
 #include "hardware/device_manager.h"
 #include "hardware/devices/cpu.h"
 #include "hardware/devices/nvidia.h"
+#include "hardware/devices/mlu.h"
 
 namespace refactor::hardware::device {
 
@@ -37,6 +38,7 @@ namespace refactor::hardware::device {
         using T = Device::Type;
         // clang-format off
         auto device = type == T::Nvidia ? std::make_shared<Nvidia>(card)
+                    : type == T::Mlu    ? std::make_shared<Mlu>(card)
                     : UNREACHABLEX(Arc<Device>, "");
         // clang-format on
         auto [kind, ok] = DEVICES.try_emplace(static_cast<int32_t>(type));
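With the new include and the extra ternary branch, an Mlu request now resolves to a concrete device the same way an Nvidia one does. Below is a self-contained sketch of this type-dispatch pattern; the Device/Nvidia/Mlu classes and the create() helper are stand-ins for illustration, not the project's real types.

    // Sketch of the device-type dispatch; names here are illustrative only.
    #include <memory>
    #include <stdexcept>

    struct Device { virtual ~Device() = default; };
    struct Nvidia : Device {};
    struct Mlu : Device {};

    enum class Type { Nvidia, Mlu };
    template<class T> using Arc = std::shared_ptr<T>;

    Arc<Device> create(Type type) {
        // same shape as the ternary chain above; the throw stands in for UNREACHABLEX
        return type == Type::Nvidia ? Arc<Device>(std::make_shared<Nvidia>())
             : type == Type::Mlu    ? Arc<Device>(std::make_shared<Mlu>())
             : throw std::runtime_error("unknown device type");
    }

    int main() {
        auto dev = create(Type::Mlu);
        return dev ? 0 : 1;
    }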

src/02hardware/src/devices/mlu/device.cc (+6 -2)

@@ -1,21 +1,25 @@
-#include "functions.cc"
+#include "functions.hh"
 #include "hardware/devices/mlu.h"
 #include "hardware/mem_pool.h"
 #include "memory.hh"
 
 namespace refactor::hardware {
 
     static Arc<Memory> bangMemory(int32_t card) {
+#ifdef USE_BANG
         ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
         setDevice(card);
         auto [free, total] = getMemInfo();
         auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
-        fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}",
+        fmt::println("initializing Cambricon MLU {}, memory {} / {}, alloc {}",
                      card, free, total, size);
         return std::make_shared<MemPool>(
             std::make_shared<MluMemory>(),
             size,
             256ul);
+#else
+        return nullptr;
+#endif
     }
 
     Mlu::Mlu(int32_t card) : Device(card, bangMemory(card)) {}
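bangMemory sizes the MLU pool with the same policy as the CUDA path: take 4/5 of total device memory, but never less than 5 GiB and never more than is actually free. A small worked example of that expression with hypothetical free/total values, not tied to any real card:

    // Pool-sizing policy from bangMemory/cudaMemory, on made-up numbers:
    // 16 GiB free, 32 GiB total -> max(5 GiB, 25.6 GiB) = 25.6 GiB, capped at the 16 GiB free.
    #include <algorithm>
    #include <cstdio>

    int main() {
        unsigned long free = 16ul << 30, total = 32ul << 30;
        auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
        std::printf("alloc %lu bytes (%lu GiB)\n", size, size >> 30);// prints 16 GiB
        return 0;
    }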

src/02hardware/src/devices/mlu/functions.cc (+2)

@@ -2,6 +2,7 @@
 
 namespace refactor::hardware {
 
+#ifdef USE_BANG
     int getDeviceCount() {
         unsigned deviceCount;
         BANG_ASSERT(cnrtGetDeviceCount(&deviceCount));
@@ -15,5 +16,6 @@ namespace refactor::hardware {
         BANG_ASSERT(cnrtMemGetInfo(&memInfo.free, &memInfo.total));
         return memInfo;
     }
+#endif
 
 }// namespace refactor::hardware

src/02hardware/src/devices/mlu/functions.hh (+4 -1)

@@ -1,14 +1,17 @@
 #ifndef HARDWARE_DEVICES_MLU_FUNCTIONS_CUH
 #define HARDWARE_DEVICES_MLU_FUNCTIONS_CUH
 
-#include "cnrt.h"
 #include "common.h"
 
+#ifdef USE_BANG
+#include "cnrt.h"
+
 #define BANG_ASSERT(STATUS) \
     if (auto status = (STATUS); status != CNRT_RET_SUCCESS) { \
         RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \
                                   cnrtGetErrorStr(status), (int) status)); \
     }
+#endif
 
 namespace refactor::hardware {
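BANG_ASSERT wraps a CNRT call: it evaluates the call once and, on any status other than CNRT_RET_SUCCESS, raises a RUNTIME_ERROR carrying the stringified call, cnrtGetErrorStr's text, and the numeric code. A minimal call-site sketch, assuming USE_BANG is defined; it mirrors getDeviceCount from functions.cc, and the helper name is hypothetical:

    // Hypothetical call site, for illustration only; cnrtGetDeviceCount is the
    // same CNRT call already used in functions.cc above.
    #include "functions.hh"

    namespace refactor::hardware {

        int deviceCountOrThrow() {
            unsigned deviceCount;
            BANG_ASSERT(cnrtGetDeviceCount(&deviceCount));// throws via RUNTIME_ERROR on failure
            return static_cast<int>(deviceCount);
        }

    }// namespace refactor::hardware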

src/02hardware/src/devices/mlu/memory.cc (+2)

@@ -2,6 +2,7 @@
 #include "functions.hh"
 
 namespace refactor::hardware {
+#ifdef USE_BANG
     using M = MluMemory;
 
     void *M::malloc(size_t size) {
@@ -27,5 +28,6 @@ namespace refactor::hardware {
                              CNRT_MEM_TRANS_DIR_PEER2PEER));
         return dst;
     }
+#endif
 
 }// namespace refactor::hardware

src/02hardware/src/devices/nvidia/device.cc (+4)

@@ -6,6 +6,7 @@
 namespace refactor::hardware {
 
     static Arc<Memory> cudaMemory(int32_t card) {
+#ifdef USE_CUDA
         ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
         setDevice(card);
         auto [free, total] = getMemInfo();
@@ -16,6 +17,9 @@ namespace refactor::hardware {
             std::make_shared<NvidiaMemory>(),
             size,
             256ul);
+#else
+        return nullptr;
+#endif
     }
 
     Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {}

src/04kernel/src/collectors/batch_normalization.cc (+4)

@@ -1,6 +1,7 @@
 #include "kernel/collectors/batch_normalization.h"
 #include "../kernels/batch_normalization/cpu_kernel.hh"
 #include "../kernels/batch_normalization/cudnn_kernel.hh"
+#include "../kernels/batch_normalization/cnnl_kernel.hh"
 
 namespace refactor::kernel {
 
@@ -20,6 +21,9 @@ namespace refactor::kernel {
             case decltype(_target)::Nvidia:
                 REGISTER(BatchNormalizationCudnn)
                 break;
+            case decltype(_target)::Mlu:
+                REGISTER(BatchNormalizationCnnl)
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }
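REGISTER only appends a kernel when its build() returns a non-null KernelBox, so a target that cannot handle the given tensors simply contributes nothing; the softmax and transpose collectors below spell the same pattern out explicitly. A self-contained sketch of that collect-if-buildable pattern, using stand-in types rather than the project's real Kernel/KernelBox:

    // Sketch of the collect pattern: only kernels whose build() returns a
    // non-null box are appended. All names here are illustrative.
    #include <memory>
    #include <vector>

    struct Kernel { virtual ~Kernel() = default; };
    using KernelBox = std::unique_ptr<Kernel>;

    struct FakeCnnlKernel : Kernel {
        static KernelBox build(bool supported) {
            return supported ? std::make_unique<FakeCnnlKernel>() : nullptr;
        }
    };

    int main() {
        std::vector<KernelBox> ans;
        if (auto ptr = FakeCnnlKernel::build(true); ptr) {// same shape as the REGISTER body
            ans.emplace_back(std::move(ptr));
        }
        return ans.size() == 1 ? 0 : 1;
    }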

src/04kernel/src/collectors/reduce.cc (+4)

@@ -1,6 +1,7 @@
 #include "kernel/collectors/reduce.h"
 #include "../kernels/reduce/cpu_kernel.hh"
 #include "../kernels/reduce/cudnn_kernel.hh"
+#include "../kernels/reduce/cnnl_kernel.hh"
 
 namespace refactor::kernel {
 
@@ -27,6 +28,9 @@ namespace refactor::kernel {
             case decltype(_target)::Nvidia:
                 REGISTER(ReduceCudnn)
                 break;
+            case decltype(_target)::Mlu:
+                REGISTER(ReduceCnnl)
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }

src/04kernel/src/collectors/simple_binary.cc (+4)

@@ -2,6 +2,7 @@
 #include "../kernels/simple_binary/binary_cudnn.hh"
 #include "../kernels/simple_binary/cpu_kernel.hh"
 #include "../kernels/simple_binary/cuda_kernel.hh"
+#include "../kernels/simple_binary/binary_cnnl.hh"
 
 namespace refactor::kernel {
 
@@ -48,6 +49,9 @@ namespace refactor::kernel {
                 REGISTER_BROCAST(BinaryCudnn)
                 REGISTER(BinaryCuda)
                 break;
+            case decltype(_target)::Mlu:
+                REGISTER_BROCAST(BinaryCnnl)
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }

src/04kernel/src/collectors/simple_unary.cc (+6)

@@ -2,6 +2,8 @@
 #include "../kernels/simple_unary/cpu_kernel.hh"
 #include "../kernels/simple_unary/cuda_kernel.hh"
 #include "../kernels/simple_unary/cudnn_activation_kernel.hh"
+#include "../kernels/simple_unary/cnnl_activation_kernel.hh"
+#include "../kernels/simple_unary/cnnl_simple_unary_kernel.hh"
 #include "common.h"
 
 namespace refactor::kernel {
@@ -54,6 +56,10 @@ namespace refactor::kernel {
                 REGISTER(ActivationCudnn)
                 REGISTER(SimpleUnaryCuda)
                 break;
+            case decltype(_target)::Mlu:
+                REGISTER(ActivationCnnl)
+                REGISTER(SimpleUnaryCnnl)
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }

src/04kernel/src/collectors/softmax.cc (+7)

@@ -1,4 +1,5 @@
 #include "kernel/collectors/softmax.h"
+#include "../kernels/softmax/cnnl_kernel.hh"
 #include "../kernels/softmax/cpu_kernel.hh"
 #include "../kernels/softmax/cuda_kernel.hh"
 #include "../kernels/softmax/cudnn_kernel.hh"
@@ -28,6 +29,12 @@ namespace refactor::kernel {
                 }
                 break;
             }
+            case decltype(_target)::Mlu: {
+                if (auto ptr = SoftmaxCnnl::build(cnnl::SoftmaxAlgo::ACCURATE, info); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            }
             default:
                 UNREACHABLEX(void, "Unknown target");
         }

src/04kernel/src/collectors/transpose.cc (+6)

@@ -1,6 +1,7 @@
 #include "kernel/collectors/transpose.h"
 #include "../kernels/transpose/cpu_kernel.hh"
 #include "../kernels/transpose/cuda_kernel.hh"
+#include "../kernels/transpose/cnnl_kernel.hh"
 
 namespace refactor::kernel {
 
@@ -25,6 +26,11 @@ namespace refactor::kernel {
                     ans.emplace_back(std::move(ptr));
                 }
                 break;
+            case decltype(_target)::Mlu:
+                if (auto ptr = TransposeCnnl::build(data.dataType, data.shape, perm); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
             default:
                 UNREACHABLEX(void, "Unknown target");
         }
src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc (new file, +158)

@@ -0,0 +1,158 @@
+#include "cnnl_kernel.hh"
+
+#ifdef USE_BANG
+#include "../../utilities/bang/cnnl_context.hh"
+#include "../../utilities/bang/cnnl_functions.h"
+#include <cnnl.h>
+#endif
+
+namespace refactor::kernel {
+    using K = BatchNormalizationCnnl;
+    using DT = DataType;
+
+    K::BatchNormalizationCnnl(decltype(info) info_) noexcept
+        : info(info_) {}
+
+    auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox {
+#ifndef USE_BANG
+        return nullptr;
+#endif
+
+        auto const &x = inputs[0].get();
+        auto const &scale = inputs[1].get();
+        auto const &mean = inputs[3].get();
+
+        if (x.rank() != 4) {
+            return nullptr;
+        }
+
+        // see "Supported Configurations for `cnnlBatchNormalizationForwardInference`"
+        if (scale.dataType != mean.dataType) {
+            return nullptr;
+        }
+        if (x.dataType == DT::F64) {
+            if (scale.dataType != DT::F64) {
+                return nullptr;
+            }
+        } else {
+            if (scale.dataType != DT::F32) {
+                return nullptr;
+            }
+        }
+        return std::make_unique<K>(decltype(info){
+            epsilon,
+            x.dataType,
+            scale.dataType,
+            x.layout,
+            {
+                static_cast<int>(x.shape[0]),
+                static_cast<int>(x.shape[1]),
+                static_cast<int>(x.shape[2]),
+                static_cast<int>(x.shape[3]),
+            }});
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing batch normalization for non-training-mode using CNNL";
+    }
+
+#ifdef USE_BANG
+
+    auto K::lower(Resources &res) const -> RoutineWorkspace {
+        using namespace cnnl;
+        using namespace runtime;
+        using DT = DataType;
+
+        // RAII for closure
+        struct Descriptors {
+            cnnlTensorDescriptor_t inDesc, inDescTrans, p;
+            cnnlTransposeDescriptor_t NCHW2NHWC, NHWC2NCHW;
+            bool f32;
+
+            explicit Descriptors(decltype(f32) f32_)
+                : inDesc(nullptr), inDescTrans(nullptr), p(nullptr),
+                  NCHW2NHWC(nullptr), NHWC2NCHW(nullptr), f32(f32_) {
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDescTrans));
+                CNNL_ASSERT(cnnlCreateTensorDescriptor(&p));
+                CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NCHW2NHWC));
+                CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NHWC2NCHW));
+            }
+            ~Descriptors() noexcept(false) {
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDescTrans));
+                CNNL_ASSERT(cnnlDestroyTensorDescriptor(p));
+                CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NCHW2NHWC));
+                CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NHWC2NCHW));
+            }
+
+            Descriptors(const Descriptors &) = delete;
+            Descriptors(Descriptors &&) = delete;
+        };
+        auto d = std::make_shared<Descriptors>(info.dtX != DT::F64);
+        int dimNCHW[4] = {info.dimAx[0], info.dimAx[1], info.dimAx[2], info.dimAx[3]};
+        int dimNHWC[4] = {info.dimAx[0], info.dimAx[2], info.dimAx[3], info.dimAx[1]};
+        int dimParam[]{info.dimAx[1]};
+        setCnnlTensor(d->inDesc, info.dtX, slice(dimNCHW, 4));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->inDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dtX), 4, dimNHWC));
+        CNNL_ASSERT(cnnlSetTensorDescriptor(d->p, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dtP), 1, dimParam));
+        int permute[4] = {0, 2, 3, 1};
+        int permuteOut[4] = {0, 3, 1, 2};
+        CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NCHW2NHWC, 4, permute));
+        CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut));
+
+        auto handle = res.fetchOrStore<CnnlContext>()->handle;
+        auto xTransSize = cnnlGetTensorElementNum(d->inDescTrans) * sizeof(info.dtX);
+        size_t workspaceSize;
+        CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->inDesc, d->NCHW2NHWC, &workspaceSize));
+        size_t totalWorkspaceSize = xTransSize + workspaceSize;
+
+        res.fetchOrStore<CnnlContext>();
+        auto routine = [d = std::move(d),
+                        epsilon = info.epsilon,
+                        xTransSize, workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
+            // fetch cnnl handle from resources
+            auto handle = res.fetchOrStore<CnnlContext>()->handle;
+
+            // name inputs and outputs
+            auto x = inputs[0],
+                 scale = inputs[1],
+                 bias = inputs[2],
+                 mean = inputs[3],
+                 var = inputs[4];
+            auto y = outputs[0];
+
+            void *xTrans = workspace;
+            void *yTrans = xTrans + xTransSize;
+            void *cursor = yTrans + workspaceSize;
+
+            // transpose NCHW input to NHWC
+            CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->inDesc, x,
+                                         d->inDescTrans, xTrans, cursor, workspaceSize));
+
+            // build alpha/beta for double
+            auto a = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1),
+                 b = d->f32 ? factor<fp32_t>(0) : factor<fp64_t>(0);
+            CNNL_ASSERT(cnnlBatchNormForwardInference(
+                handle, &a, &b,
+                d->inDescTrans, xTrans, d->p, scale, bias, mean, var,
+                epsilon, d->inDescTrans, yTrans));
+
+            // transpose NHWC intermediates to NCHW
+            CNNL_ASSERT(cnnlTranspose_v2(handle, d->NHWC2NCHW, d->inDescTrans, yTrans,
+                                         d->inDesc, y, cursor, workspaceSize));
+
+            BANG_ASSERT(cnrtQueueSync(res.fetchOrStore<CnnlContext>()->queue));
+        };
+
+        return {std::move(routine), totalWorkspaceSize};
+    }
+
+#endif
+
+}// namespace refactor::kernel
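The lower routine above converts the NCHW input to NHWC in the workspace before calling cnnlBatchNormForwardInference and converts the result back afterwards. The two permutations {0, 2, 3, 1} and {0, 3, 1, 2} used for the transpose descriptors are exact inverses; a small self-contained check with a hypothetical shape (the convention, as in the kernel, is out_dim[i] = in_dim[perm[i]]):

    // Round-trip check for the NCHW <-> NHWC permutations used by the kernel;
    // the shape {1, 64, 56, 56} is made up for illustration.
    #include <array>
    #include <cstdio>

    int main() {
        std::array<int, 4> nchw{1, 64, 56, 56};// N, C, H, W
        std::array<int, 4> toNHWC{0, 2, 3, 1}, toNCHW{0, 3, 1, 2};

        std::array<int, 4> nhwc{};
        for (int i = 0; i < 4; ++i) nhwc[i] = nchw[toNHWC[i]];// {1, 56, 56, 64}

        std::array<int, 4> back{};
        for (int i = 0; i < 4; ++i) back[i] = nhwc[toNCHW[i]];// {1, 64, 56, 56}

        std::printf("%d %d %d %d\n", back[0], back[1], back[2], back[3]);
        return back == nchw ? 0 : 1;
    }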
