diff --git a/src/02hardware/src/device_manager.cpp b/src/02hardware/src/device_manager.cpp index 15ae3b901..8a7f8edd2 100644 --- a/src/02hardware/src/device_manager.cpp +++ b/src/02hardware/src/device_manager.cpp @@ -2,6 +2,7 @@ #include "hardware/devices/cpu.h" +#include "hardware/devices/mlu.h" #include "hardware/devices/nvidia.h" namespace refactor::hardware::device { diff --git a/src/04kernel/src/collectors/batch_normalization.cc b/src/04kernel/src/collectors/batch_normalization.cc index 93bcb240e..e944e37d7 100644 --- a/src/04kernel/src/collectors/batch_normalization.cc +++ b/src/04kernel/src/collectors/batch_normalization.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/batch_normalization.h" #include "../kernels/batch_normalization/cpu_kernel.hh" #include "../kernels/batch_normalization/cudnn_kernel.hh" +#include "../kernels/batch_normalization/cnnl_kernel.hh" namespace refactor::kernel { @@ -20,6 +21,9 @@ namespace refactor::kernel { case decltype(_target)::Nvidia: REGISTER(BatchNormalizationCudnn) break; + case decltype(_target)::Mlu: + REGISTER(BatchNormalizationCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/cast.cc b/src/04kernel/src/collectors/cast.cc index bca9d2629..45bf3372b 100644 --- a/src/04kernel/src/collectors/cast.cc +++ b/src/04kernel/src/collectors/cast.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/cast.h" #include "../kernels/cast/cpu_kernel.hh" #include "../kernels/cast/cuda_kernel.hh" +#include "../kernels/cast/cnnl_kernel.hh" namespace refactor::kernel { @@ -24,6 +25,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = CastCnnl::build(from, to); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/clip.cc b/src/04kernel/src/collectors/clip.cc index 06ccd020b..364fbe588 100644 --- a/src/04kernel/src/collectors/clip.cc +++ b/src/04kernel/src/collectors/clip.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/clip.h" #include "../kernels/clip/cpu_kernel.hh" #include "../kernels/clip/cuda_kernel.hh" +#include "../kernels/clip/cnnl_kernel.hh" namespace refactor::kernel { @@ -24,6 +25,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = ClipCnnl::build(data, hasMax); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/concat.cc b/src/04kernel/src/collectors/concat.cc index 8e6386907..0802de8e8 100644 --- a/src/04kernel/src/collectors/concat.cc +++ b/src/04kernel/src/collectors/concat.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/concat.h" #include "../kernels/concat/cpu_kernel.hh" #include "../kernels/concat/cuda_kernel.hh" +#include "../kernels/concat/cnnl_kernel.hh" namespace refactor::kernel { @@ -8,6 +9,8 @@ namespace refactor::kernel { ConcatCollector::filter(TensorRefs inputs, TensorRefs outputs) const { SplitInfo info(axis, inputs); + auto const &b = outputs[0]; + std::vector<KernelBox> ans; switch (_target) { case decltype(_target)::Cpu: @@ -20,6 +23,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = ConcatCnnl::build(axis, inputs, b); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git
a/src/04kernel/src/collectors/conv.cc b/src/04kernel/src/collectors/conv.cc index 14b61835f..d9cc0ea27 100644 --- a/src/04kernel/src/collectors/conv.cc +++ b/src/04kernel/src/collectors/conv.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/conv.h" +#include "../kernels/conv/cnnl_kernel.hh" #include "../kernels/conv/cudnn_kernel.hh" namespace refactor::kernel { @@ -23,6 +24,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = ConvCnnl::build(poolAttrs, x, w, b, y); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/gather.cc b/src/04kernel/src/collectors/gather.cc index 9c30f1c62..0a8b8b0e8 100644 --- a/src/04kernel/src/collectors/gather.cc +++ b/src/04kernel/src/collectors/gather.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/gather.h" +#include "../kernels/gather/cnnl_kernel.hh" #include "../kernels/gather/cpu_kernel.hh" #include "../kernels/gather/cuda_kernel.hh" @@ -8,7 +9,12 @@ namespace refactor::kernel { GatherCollector::filter(TensorRefs inputs, TensorRefs outputs) const { GatherInfo info(axis, inputs[0], inputs[1]); - std::vector<KernelBox> ans; + auto const &a = inputs[0]; + auto const &b = inputs[1]; + auto const &c = outputs[0]; + + std::vector<KernelBox> + ans; switch (_target) { case decltype(_target)::Cpu: if (auto ptr = GatherCpu::build(info); ptr != nullptr) { @@ -20,6 +26,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = GatherCnnl::build(axis, a, b, c); ptr != nullptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/global_pool.cc b/src/04kernel/src/collectors/global_pool.cc index 1ae1d7fc2..e6a278c1f 100644 --- a/src/04kernel/src/collectors/global_pool.cc +++ b/src/04kernel/src/collectors/global_pool.cc @@ -1,5 +1,6 @@ #include "kernel/collectors/global_pool.h" #include "../kernels/pool/cudnn_kernel.hh" +#include "../kernels/pool/cnnl_kernel.hh" namespace refactor::kernel { @@ -28,6 +29,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = PoolCnnl::build(type, false, kernelShape, attributes, x, y); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/hard_sigmoid.cc b/src/04kernel/src/collectors/hard_sigmoid.cc index 69d2f9d1e..c44151c20 100644 --- a/src/04kernel/src/collectors/hard_sigmoid.cc +++ b/src/04kernel/src/collectors/hard_sigmoid.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/hard_sigmoid.h" +#include "../kernels/hard_sigmoid/cnnl_kernel.hh" #include "../kernels/hard_sigmoid/cpu_kernel.hh" #include "../kernels/hard_sigmoid/cuda_kernel.hh" @@ -20,6 +21,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = HardSigmoidCnnl::build(alpha, beta, a); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/mat_mul.cc b/src/04kernel/src/collectors/mat_mul.cc index 7581200cd..d6b323260 100644 --- a/src/04kernel/src/collectors/mat_mul.cc +++ b/src/04kernel/src/collectors/mat_mul.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/mat_mul.h" +#include "../kernels/mat_mul/cnnl_kernel.hh" #include 
"../kernels/mat_mul/cpu_kernel.hh" #include "../kernels/mat_mul/cublas_kernel.hh" #include "kernel/attributes/mat_mul_info.h" @@ -26,6 +27,11 @@ namespace refactor::kernel { case decltype(_target)::Nvidia: REGISTER(MatMulCublas) break; + case decltype(_target)::Mlu: + if (auto ptr = MatMulCnnl::build(inputs, outputs, transA, transB, alpha, beta); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/pad.cc b/src/04kernel/src/collectors/pad.cc index f4c995e0b..1c7be68c4 100644 --- a/src/04kernel/src/collectors/pad.cc +++ b/src/04kernel/src/collectors/pad.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/pad.h" +#include "../kernels/pad/cnnl_kernel.hh" #include "../kernels/pad/cpu_kernel.hh" #include "../kernels/pad/cuda_kernel.hh" @@ -22,6 +23,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = PadCnnl::build(dims, input.get().dataType, mode, const_value); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } @@ -29,4 +35,3 @@ namespace refactor::kernel { } }// namespace refactor::kernel - diff --git a/src/04kernel/src/collectors/pool.cc b/src/04kernel/src/collectors/pool.cc index 458d3a375..d034e199f 100644 --- a/src/04kernel/src/collectors/pool.cc +++ b/src/04kernel/src/collectors/pool.cc @@ -1,5 +1,6 @@ #include "kernel/collectors/pool.h" #include "../kernels/pool/cudnn_kernel.hh" +#include "../kernels/pool/cnnl_kernel.hh" namespace refactor::kernel { @@ -29,6 +30,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = PoolCnnl::build(type, ceil, kernelShape, attributes, x, y); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/reduce.cc b/src/04kernel/src/collectors/reduce.cc index bec37731d..71fa194ba 100644 --- a/src/04kernel/src/collectors/reduce.cc +++ b/src/04kernel/src/collectors/reduce.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/reduce.h" #include "../kernels/reduce/cpu_kernel.hh" #include "../kernels/reduce/cudnn_kernel.hh" +#include "../kernels/reduce/cnnl_kernel.hh" namespace refactor::kernel { @@ -27,6 +28,9 @@ namespace refactor::kernel { case decltype(_target)::Nvidia: REGISTER(ReduceCudnn) break; + case decltype(_target)::Mlu: + REGISTER(ReduceCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/scatter_nd.cc b/src/04kernel/src/collectors/scatter_nd.cc index 62d63c802..3b43a25fb 100644 --- a/src/04kernel/src/collectors/scatter_nd.cc +++ b/src/04kernel/src/collectors/scatter_nd.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/scatter_nd.h" #include "../kernels/scatter_nd/cpu_kernel.hh" #include "../kernels/scatter_nd/cuda_kernel.hh" +#include "../kernels/scatter_nd/cnnl_kernel.hh" namespace refactor::kernel { @@ -23,6 +24,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = ScatterNDCnnl::build(inputs, outputs); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/select.cc b/src/04kernel/src/collectors/select.cc index e4eff8f4b..3cec750e8 100644 --- a/src/04kernel/src/collectors/select.cc +++ b/src/04kernel/src/collectors/select.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/select.h" 
#include "../kernels/select/cpu_kernel.hh" #include "../kernels/select/cuda_kernel.hh" +#include "../kernels/select/cnnl_kernel.hh" namespace refactor::kernel { @@ -35,6 +36,9 @@ namespace refactor::kernel { case decltype(_target)::Nvidia: REGISTER(SelectCuda) break; + case decltype(_target)::Mlu: + REGISTER(SelectCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/simple_binary.cc b/src/04kernel/src/collectors/simple_binary.cc index 53ae6723c..e61d013f9 100644 --- a/src/04kernel/src/collectors/simple_binary.cc +++ b/src/04kernel/src/collectors/simple_binary.cc @@ -2,6 +2,7 @@ #include "../kernels/simple_binary/binary_cudnn.hh" #include "../kernels/simple_binary/cpu_kernel.hh" #include "../kernels/simple_binary/cuda_kernel.hh" +#include "../kernels/simple_binary/binary_cnnl.hh" namespace refactor::kernel { @@ -50,6 +51,9 @@ namespace refactor::kernel { REGISTER_BROCAST(BinaryCudnn) REGISTER(BinaryCuda) break; + case decltype(_target)::Mlu: + REGISTER_BROCAST(BinaryCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/simple_unary.cc b/src/04kernel/src/collectors/simple_unary.cc index 51a334c91..c489acecf 100644 --- a/src/04kernel/src/collectors/simple_unary.cc +++ b/src/04kernel/src/collectors/simple_unary.cc @@ -2,6 +2,8 @@ #include "../kernels/simple_unary/cpu_kernel.hh" #include "../kernels/simple_unary/cuda_kernel.hh" #include "../kernels/simple_unary/cudnn_activation_kernel.hh" +#include "../kernels/simple_unary/cnnl_activation_kernel.hh" +#include "../kernels/simple_unary/cnnl_simple_unary_kernel.hh" #include "common.h" namespace refactor::kernel { @@ -55,6 +57,10 @@ namespace refactor::kernel { REGISTER(ActivationCudnn) REGISTER(SimpleUnaryCuda) break; + case decltype(_target)::Mlu: + REGISTER(ActivationCnnl) + REGISTER(SimpleUnaryCnnl) + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/slice.cc b/src/04kernel/src/collectors/slice.cc index 0b063dd17..60c93cb28 100644 --- a/src/04kernel/src/collectors/slice.cc +++ b/src/04kernel/src/collectors/slice.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/slice.h" #include "../kernels/slice/cpu_kernel.hh" #include "../kernels/slice/cuda_kernel.hh" +#include "../kernels/slice/cnnl_kernel.hh" namespace refactor::kernel { @@ -26,6 +27,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = SliceCnnl::build(inputs[0].get().dataType, dimentions, inputs[0].get().shape, outputs[0].get().shape); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/softmax.cc b/src/04kernel/src/collectors/softmax.cc index 2ce442696..020bc6ded 100644 --- a/src/04kernel/src/collectors/softmax.cc +++ b/src/04kernel/src/collectors/softmax.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/softmax.h" +#include "../kernels/softmax/cnnl_kernel.hh" #include "../kernels/softmax/cpu_kernel.hh" #include "../kernels/softmax/cuda_kernel.hh" #include "../kernels/softmax/cudnn_kernel.hh" @@ -28,6 +29,12 @@ namespace refactor::kernel { } break; } + case decltype(_target)::Mlu: { + if (auto ptr = SoftmaxCnnl::build(cnnl::SoftmaxAlgo::ACCURATE, info); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; + } default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/split.cc b/src/04kernel/src/collectors/split.cc index 6fe22548d..b862d8959 100644 
--- a/src/04kernel/src/collectors/split.cc +++ b/src/04kernel/src/collectors/split.cc @@ -1,4 +1,5 @@ #include "kernel/collectors/split.h" +#include "../kernels/split/cnnl_kernel.hh" #include "../kernels/split/cpu_kernel.hh" #include "../kernels/split/cuda_kernel.hh" @@ -8,6 +9,8 @@ namespace refactor::kernel { SplitCollector::filter(TensorRefs inputs, TensorRefs outputs) const { SplitInfo info(axis, outputs); + auto const &a = inputs[0]; + std::vector<KernelBox> ans; switch (_target) { case decltype(_target)::Cpu: @@ -20,6 +23,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = SplitCnnl::build(axis, a, outputs); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/transpose.cc b/src/04kernel/src/collectors/transpose.cc index c8d651974..c91493f98 100644 --- a/src/04kernel/src/collectors/transpose.cc +++ b/src/04kernel/src/collectors/transpose.cc @@ -1,6 +1,7 @@ #include "kernel/collectors/transpose.h" #include "../kernels/transpose/cpu_kernel.hh" #include "../kernels/transpose/cuda_kernel.hh" +#include "../kernels/transpose/cnnl_kernel.hh" namespace refactor::kernel { @@ -25,6 +26,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = TransposeCnnl::build(data.dataType, data.shape, perm); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/collectors/where.cc b/src/04kernel/src/collectors/where.cc index 1ce2b60c6..14aa21952 100644 --- a/src/04kernel/src/collectors/where.cc +++ b/src/04kernel/src/collectors/where.cc @@ -1,11 +1,12 @@ #include "kernel/collectors/where.h" +#include "../kernels/where/cnnl_kernel.hh" #include "../kernels/where/cpu_kernel.hh" #include "../kernels/where/where_cuda.hh" namespace refactor::kernel { std::vector<KernelBox> - WhereCollector::filter(TensorRefs inputs, TensorRefs) const { + WhereCollector::filter(TensorRefs inputs, TensorRefs outputs) const { std::vector<KernelBox> ans; switch (_target) { case decltype(_target)::Cpu: @@ -18,6 +19,11 @@ namespace refactor::kernel { ans.emplace_back(std::move(ptr)); } break; + case decltype(_target)::Mlu: + if (auto ptr = WhereCnnl::build(inputs, outputs); ptr) { + ans.emplace_back(std::move(ptr)); + } + break; default: UNREACHABLEX(void, "Unknown target"); } diff --git a/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc new file mode 100644 index 000000000..1330cfa7a --- /dev/null +++ b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.cc @@ -0,0 +1,157 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = BatchNormalizationCnnl; + using DT = DataType; + + K::BatchNormalizationCnnl(decltype(info) info_) noexcept + : info(info_) {} + + auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + auto const &x = inputs[0].get(); + auto const &scale = inputs[1].get(); + auto const &mean = inputs[3].get(); + + if (x.rank() != 4) { + return nullptr; + } + + // see "Supported Configurations for `cnnlBatchNormalizationForwardInference`" + if (scale.dataType != mean.dataType) { + return nullptr; + } + if (x.dataType 
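// (Restating the support matrix referenced above: F64 activations require F64
//  scale/bias/mean/var, and every other activation type requires F32 parameters.)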
== DT::F64) { + if (scale.dataType != DT::F64) { + return nullptr; + } + } else { + if (scale.dataType != DT::F32) { + return nullptr; + } + } + return std::make_unique<K>(decltype(info){ + epsilon, + x.dataType, + scale.dataType, + x.layout, + { + static_cast<int>(x.shape[0]), + static_cast<int>(x.shape[1]), + static_cast<int>(x.shape[2]), + static_cast<int>(x.shape[3]), + }}); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing batch normalization for non-training-mode using CNNL"; + } + +#ifdef USE_BANG + + auto K::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t inDesc, inDescTrans, p; + cnnlTransposeDescriptor_t NCHW2NHWC, NHWC2NCHW; + bool f32; + + explicit Descriptors(decltype(f32) f32_) + : inDesc(nullptr), inDescTrans(nullptr), p(nullptr), + NCHW2NHWC(nullptr), NHWC2NCHW(nullptr), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDescTrans)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&p)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NCHW2NHWC)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NHWC2NCHW)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDescTrans)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(p)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NCHW2NHWC)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NHWC2NCHW)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(info.dtX != DT::F64); + int dimNCHW[4] = {info.dimAx[0], info.dimAx[1], info.dimAx[2], info.dimAx[3]}; + int dimNHWC[4] = {info.dimAx[0], info.dimAx[2], info.dimAx[3], info.dimAx[1]}; + int dimParam[]{info.dimAx[1]}; + setCnnlTensor(d->inDesc, info.dtX, slice(dimNCHW, 4)); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->inDescTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dtX), 4, dimNHWC)); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->p, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dtP), 1, dimParam)); + int permute[4] = {0, 2, 3, 1}; + int permuteOut[4] = {0, 3, 1, 2}; + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NCHW2NHWC, 4, permute)); + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut)); + + auto handle = res.fetchOrStore<CnnlContext>()->handle; + auto xTransSize = cnnlGetTensorElementNum(d->inDescTrans) * info.dtX.size(); + size_t workspaceSize; + CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->inDesc, d->NCHW2NHWC, &workspaceSize)); + size_t totalWorkspaceSize = xTransSize * 2 + workspaceSize; + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), + epsilon = info.epsilon, + xTransSize, workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + // name inputs and outputs + auto x = inputs[0], + scale = inputs[1], + bias = inputs[2], + mean = inputs[3], + var = inputs[4]; + auto y = outputs[0]; + + void *xTrans = workspace; + void *yTrans = reinterpret_cast<uint8_t *>(xTrans) + xTransSize; + void *cursor = reinterpret_cast<uint8_t *>(yTrans) + xTransSize; + 
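// Workspace layout assumed by the three pointers above:
//   [ xTrans : xTransSize | yTrans : xTransSize | transpose scratch : workspaceSize ]
// which is exactly the totalWorkspaceSize = xTransSize * 2 + workspaceSize
// computed during lowering.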
+ // transpose NCHW input to NHWC + CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->inDesc, x, + d->inDescTrans, xTrans, cursor, workspaceSize)); + + // build alpha/beta for double + auto a = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1), + b = d->f32 ? factor<fp32_t>(0) : factor<fp64_t>(0); + CNNL_ASSERT(cnnlBatchNormForwardInference( + handle, &a, &b, + d->inDescTrans, xTrans, d->p, scale, bias, mean, var, + epsilon, d->inDescTrans, yTrans)); + + // transpose NHWC intermediates to NCHW + CNNL_ASSERT(cnnlTranspose_v2(handle, d->NHWC2NCHW, d->inDescTrans, yTrans, + d->inDesc, y, cursor, workspaceSize)); + + }; + + return {std::move(routine), totalWorkspaceSize}; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.hh b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.hh new file mode 100644 index 000000000..978b0dedc --- /dev/null +++ b/src/04kernel/src/kernels/batch_normalization/cnnl_kernel.hh @@ -0,0 +1,32 @@ +#ifndef KERNEL_BATCH_NORMALIZATION_CNNL_KERNEL_HH +#define KERNEL_BATCH_NORMALIZATION_CNNL_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + /// @brief Use `cnnlBatchNormalizationForwardInference`. + /// It only supports 4D and 5D tensors. + struct BatchNormalizationCnnl final : public Kernel { + struct { + float epsilon; + DataType dtX, dtP; + LayoutType layout; + int dimAx[4];// dimA for x + } info; + + explicit BatchNormalizationCnnl(decltype(info)) noexcept; + + static KernelBox build(float, TensorRefs) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_BATCH_NORMALIZATION_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/cast/cnnl_kernel.cc b/src/04kernel/src/kernels/cast/cnnl_kernel.cc new file mode 100644 index 000000000..735692b90 --- /dev/null +++ b/src/04kernel/src/kernels/cast/cnnl_kernel.cc @@ -0,0 +1,234 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + + +namespace refactor::kernel { + using K = CastCnnl; + using DT = DataType; + + K::CastCnnl(decltype(from) from_, + decltype(to) to_, + decltype(shape) shape_) noexcept + : from(from_), to(to_), shape(shape_) {} + + auto K::build(Tensor const &from, Tensor const &to) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + return std::make_unique<K>(from.dataType, to.dataType, + std::vector<int>(from.shape.begin(), from.shape.end())); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing cast operation using CNNL"; + } + +#ifdef USE_BANG + + static cnnlCastDataType_t castType(DataType from, DataType to); + + auto K::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + cnnlTensorDescriptor_t inDesc, outDesc; + cnnlCastDataType_t cast; + bool needCast; + + Descriptors(bool need) : inDesc(nullptr), outDesc(nullptr), + needCast(need) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&outDesc)); + } + 
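// The throwing destructor below is deliberate: CNNL_ASSERT may throw if
// descriptor destruction fails, hence noexcept(false); the same RAII pattern
// recurs in every Descriptors struct in this PR.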
~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(outDesc)); + } + }; + auto d = std::make_shared<Descriptors>(from != to); + if (d->needCast) { + d->cast = castType(from, to); + } + setCnnlTensor(d->inDesc, from, slice(shape.data(), shape.size())); + setCnnlTensor(d->outDesc, to, slice(shape.data(), shape.size())); + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + if (d->needCast) { + CNNL_ASSERT(cnnlCastDataType(res.fetchOrStore<CnnlContext>()->handle, + d->inDesc, inputs[0], d->cast, d->outDesc, outputs[0])); + } else { + CNNL_ASSERT(cnnlCopy(res.fetchOrStore<CnnlContext>()->handle, + d->inDesc, inputs[0], d->outDesc, outputs[0])); + } + }; + } + + static cnnlCastDataType_t castType(DataType from, DataType to) { + switch (from) { + case DT::F32: + switch (to) { + case DT::F64: + return CNNL_CAST_FLOAT_TO_DOUBLE; + case DT::FP16: + return CNNL_CAST_FLOAT_TO_HALF; + case DT::I64: + return CNNL_CAST_FLOAT_TO_INT64; + case DT::I32: + return CNNL_CAST_FLOAT_TO_INT32; + case DT::I16: + return CNNL_CAST_FLOAT_TO_INT16; + case DT::I8: + return CNNL_CAST_FLOAT_TO_INT8; + case DT::U8: + return CNNL_CAST_FLOAT_TO_UINT8; + // case DT::BF16: + // return CNNL_CAST_FLOAT_TO_BFLOAT16; + case DT::Bool: + return CNNL_CAST_FLOAT_TO_BOOL; + default: + UNREACHABLE(); + } + case DT::FP16: + switch (to) { + case DT::F32: + return CNNL_CAST_HALF_TO_FLOAT; + case DT::I64: + return CNNL_CAST_HALF_TO_INT64; + case DT::I32: + return CNNL_CAST_HALF_TO_INT32; + case DT::I16: + return CNNL_CAST_HALF_TO_INT16; + case DT::I8: + return CNNL_CAST_HALF_TO_INT8; + case DT::U8: + return CNNL_CAST_HALF_TO_UINT8; + case DT::Bool: + return CNNL_CAST_HALF_TO_BOOL; + default: + UNREACHABLE(); + } + case DT::I32: + switch (to) { + case DT::F32: + return CNNL_CAST_INT32_TO_FLOAT; + case DT::FP16: + return CNNL_CAST_INT32_TO_HALF; + case DT::I64: + return CNNL_CAST_INT32_TO_INT64; + case DT::I16: + return CNNL_CAST_INT32_TO_INT16; + case DT::I8: + return CNNL_CAST_INT32_TO_INT8; + case DT::Bool: + return CNNL_CAST_INT32_TO_BOOL; + default: + UNREACHABLE(); + } + case DT::I16: + switch (to) { + case DT::F32: + return CNNL_CAST_INT16_TO_FLOAT; + case DT::FP16: + return CNNL_CAST_INT16_TO_HALF; + case DT::I32: + return CNNL_CAST_INT16_TO_INT32; + // case DT::I8: + // return CNNL_CAST_INT16_TO_INT8; + default: + UNREACHABLE(); + } + case DT::I8: + switch (to) { + case DT::F32: + return CNNL_CAST_INT8_TO_FLOAT; + case DT::FP16: + return CNNL_CAST_INT8_TO_HALF; + case DT::I32: + return CNNL_CAST_INT8_TO_INT32; + case DT::I16: + return CNNL_CAST_INT8_TO_INT16; + default: + UNREACHABLE(); + } + case DT::U8: + switch (to) { + case DT::F32: + return CNNL_CAST_UINT8_TO_FLOAT; + case DT::FP16: + return CNNL_CAST_UINT8_TO_HALF; + case DT::I64: + return CNNL_CAST_UINT8_TO_INT64; + case DT::I32: + return CNNL_CAST_UINT8_TO_INT32; + default: + UNREACHABLE(); + } + case DT::Bool: + switch (to) { + case DT::F32: + return CNNL_CAST_BOOL_TO_FLOAT; + case DT::FP16: + return CNNL_CAST_BOOL_TO_HALF; + case DT::I32: + return CNNL_CAST_BOOL_TO_INT32; + default: + UNREACHABLE(); + } + case DT::I64: + switch (to) { + case DT::F32: + return CNNL_CAST_INT64_TO_FLOAT; + case DT::FP16: + return CNNL_CAST_INT64_TO_HALF; + case DT::I32: + return CNNL_CAST_INT64_TO_INT32; + case DT::U32: + return CNNL_CAST_INT64_TO_UINT32; + default: + UNREACHABLE(); + } + case DT::U32: + switch (to) { + 
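// Only pairs that exist in cnnlCastDataType_t are mapped in this switch;
// combinations CNNL cannot cast directly fall through to UNREACHABLE()
// rather than being emulated with a two-step cast.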
case DT::I64: + return CNNL_CAST_UINT32_TO_INT64; + case DT::U64: + return CNNL_CAST_UINT32_TO_UINT64; + default: + UNREACHABLE(); + } + case DT::F64: + switch (to) { + case DT::F32: + return CNNL_CAST_DOUBLE_TO_FLOAT; + default: + UNREACHABLE(); + } + case DT::BF16: + switch (to) { + // case DT::F32: + // return CNNL_CAST_BF16_TO_FLOAT; + default: + UNREACHABLE(); + } + default: + UNREACHABLE(); + } + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/cast/cnnl_kernel.hh b/src/04kernel/src/kernels/cast/cnnl_kernel.hh new file mode 100644 index 000000000..b1e638080 --- /dev/null +++ b/src/04kernel/src/kernels/cast/cnnl_kernel.hh @@ -0,0 +1,27 @@ +#ifndef KERNEL_CAST_CNNL_KERNEL_HH +#define KERNEL_CAST_CNNL_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct CastCnnl final : public Kernel { + DataType from, to; + std::vector<int> shape; + + CastCnnl(decltype(from), decltype(to), decltype(shape)) noexcept; + + static KernelBox build(Tensor const &, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_CAST_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/clip/cnnl_kernel.cc b/src/04kernel/src/kernels/clip/cnnl_kernel.cc new file mode 100644 index 000000000..fe65e99b3 --- /dev/null +++ b/src/04kernel/src/kernels/clip/cnnl_kernel.cc @@ -0,0 +1,65 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = ClipCnnl; + + K::ClipCnnl(decltype(dataType) dt, + decltype(shape) shape_, + decltype(hasMax) hasMax_) noexcept + : dataType(dt), shape(shape_), hasMax(hasMax_) { + } + + auto K::build(Tensor const &data, bool hasMax) noexcept -> KernelBox { + return data.dataType.isCpuNumberic() + ? std::make_unique<K>(data.dataType, + std::vector<int>(data.shape.begin(), data.shape.end()), + hasMax) + : nullptr; + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing clip operation using CNNL"; + } + +#ifdef USE_BANG + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + cnnlTensorDescriptor_t t; + + Descriptors() : t(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&t)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(t)); + } + }; + auto d = std::make_shared<Descriptors>(); + setCnnlTensor(d->t, dataType, slice(shape.data(), shape.size())); + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d), hasMax = this->hasMax](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + CNNL_ASSERT(cnnlClip_v2(res.fetchOrStore<CnnlContext>()->handle, + CNNL_POINTER_MODE_DEVICE, d->t, + inputs[0], inputs[1], hasMax ? 
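// (CNNL_POINTER_MODE_DEVICE above means min/max are read from device memory,
//  so the graph's min/max tensors are passed through unchanged; the null
//  max pointer below is how "no upper bound" is expressed.)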
inputs[2] : nullptr, + d->t, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/clip/cnnl_kernel.hh b/src/04kernel/src/kernels/clip/cnnl_kernel.hh new file mode 100644 index 000000000..37d168062 --- /dev/null +++ b/src/04kernel/src/kernels/clip/cnnl_kernel.hh @@ -0,0 +1,28 @@ +#ifndef KERNEL_CLIP_CNNL_KERNEL_HH +#define KERNEL_CLIP_CNNL_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct ClipCnnl final : public Kernel { + DataType dataType; + std::vector<int> shape; + bool hasMax; + + ClipCnnl(decltype(dataType), decltype(shape), decltype(hasMax)) noexcept; + + static KernelBox build(Tensor const &, bool hasMax) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_CLIP_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/concat/cnnl_kernel.cc b/src/04kernel/src/kernels/concat/cnnl_kernel.cc new file mode 100644 index 000000000..c35b1c33f --- /dev/null +++ b/src/04kernel/src/kernels/concat/cnnl_kernel.cc @@ -0,0 +1,93 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = ConcatCnnl; + + K::ConcatCnnl(SplitInfoCnnl info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(int axis, TensorRefs inputs, Tensor const &output) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + return std::make_unique<K>(SplitInfoCnnl(axis, output, inputs)); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing concat operation using CNNL"; + } + +#ifdef USE_BANG + auto ConcatCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + struct Descriptors { + cnnlTensorDescriptor_t in; + std::vector<cnnlTensorDescriptor_t> out; + bool f32; + + explicit Descriptors(int n, decltype(f32) f32_) + : in(nullptr), + out(std::vector<cnnlTensorDescriptor_t>(n, nullptr)), + f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&in)); + for (auto i = 0; i < n; i++) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&out[i])); + } + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(in)); + for (size_t i = 0; i < out.size(); i++) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(out[i])); + } + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(info.num, info.dataType != DT::F64); + setCnnlTensor(d->in, info.dataType, slice(info.inDim.data(), info.inDim.size())); + for (size_t i = 0; i < info.outDims.size(); i++) { + setCnnlTensor(d->out[i], info.dataType, slice(info.outDims[i].data(), info.outDims[i].size())); + } + + auto handle = res.fetchOrStore<CnnlContext>()->handle; + size_t workspaceSize; + CNNL_ASSERT(cnnlGetConcatWorkspaceSize(handle, info.num, &workspaceSize)); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), n = info.num, axis = info.axis, workspaceSize](Resources &res,
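// Note the reversed roles relative to split: d->out describes the n input
// tensors and d->in describes the single concatenated output, because
// ConcatCnnl reuses SplitInfoCnnl as-is.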
void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + std::vector<void const *> argv(inputs, inputs + n); + + CNNL_ASSERT(cnnlConcat( + handle, n, axis, d->out.data(), argv.data(), + workspace, workspaceSize, d->in, outputs[0])); + }; + + return {std::move(routine), workspaceSize}; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/concat/cnnl_kernel.hh b/src/04kernel/src/kernels/concat/cnnl_kernel.hh new file mode 100644 index 000000000..0d4f6f853 --- /dev/null +++ b/src/04kernel/src/kernels/concat/cnnl_kernel.hh @@ -0,0 +1,26 @@ +#ifndef KERNEL_CONCAT_CNNL_KERNEL_HH +#define KERNEL_CONCAT_CNNL_KERNEL_HH + +#include "../../kernels/split/cnnl_kernel.hh" +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct ConcatCnnl final : public Kernel { + SplitInfoCnnl info; + + explicit ConcatCnnl(SplitInfoCnnl) noexcept; + + static KernelBox build(int, TensorRefs, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_CONCAT_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/conv/cnnl_kernel.cc b/src/04kernel/src/kernels/conv/cnnl_kernel.cc new file mode 100644 index 000000000..0974a7600 --- /dev/null +++ b/src/04kernel/src/kernels/conv/cnnl_kernel.cc @@ -0,0 +1,224 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include "../expand/cnnl_kernel.hh" +#include "hardware/functions.h" +#endif + +namespace refactor::kernel { + using K = ConvCnnl; + + K::ConvCnnl(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(PoolAttributes const &poolAttributes, + Tensor const &x, + Tensor const &w, + std::optional<std::reference_wrapper<Tensor const>> b, + Tensor const &y) -> KernelBox { + static const std::unordered_set<decltype(DataType::internal)> + SET{DataType::FP16, DataType::BF16, DataType::F32, DataType::F64, DataType::I8}; +#ifndef USE_BANG + return nullptr; +#endif + + auto dt = x.dataType; + if (!SET.contains(dt) || w.dataType != dt || y.dataType != dt) { + return nullptr; + } + + int biasSize_ = 0; + if (b) { + ASSERT(b->get().shape[0] == y.shape[1], ""); + biasSize_ = b->get().shape[0]; + } + + // only 2D convolutions over 4D tensors are handled here + if (w.rank() != 4 || poolAttributes.rank() != 2) { + return nullptr; + } + auto d = poolAttributes.dilations(), + p = poolAttributes.pads(), + s = poolAttributes.strides(); + return std::make_unique<K>(decltype(info){ + dt, + { + static_cast<int>(x.shape[0]), + static_cast<int>(x.shape[1]), + static_cast<int>(x.shape[2]), + static_cast<int>(x.shape[3]), + }, + { + static_cast<int>(w.shape[0]), + static_cast<int>(w.shape[1]), + static_cast<int>(w.shape[2]), + static_cast<int>(w.shape[3]), + }, + { + static_cast<int>(y.shape[0]), + static_cast<int>(y.shape[1]), + static_cast<int>(y.shape[2]), + static_cast<int>(y.shape[3]), + }, + {d[0], d[1]}, + {p[0], p[1], p[2], p[3]}, + {s[0], s[1]}, + biasSize_, + }); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept
-> std::string_view { + return "Performing conv using CNNL"; + } + +#ifdef USE_BANG + + auto ConvCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t x, y, w, b; + cnnlTensorDescriptor_t xTrans, yTrans, wTrans; + cnnlTransposeDescriptor_t NCHW2NHWC, NHWC2NCHW; + cnnlConvolutionDescriptor_t conv; + cnnlConvolutionForwardAlgo_t algo; + bool bias; + + Descriptors(decltype(bias) bias_) : bias(bias_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&w)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&b)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&xTrans)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&yTrans)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&wTrans)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NCHW2NHWC)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&NHWC2NCHW)); + CNNL_ASSERT(cnnlCreateConvolutionDescriptor(&conv)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(w)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(b)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(xTrans)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(yTrans)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(wTrans)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NCHW2NHWC)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(NHWC2NCHW)); + CNNL_ASSERT(cnnlDestroyConvolutionDescriptor(conv)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(info.biasSize > 0); + + int xs[]{ + info.xShape[0], + info.xShape[1], + info.xShape[2] + std::abs(info.pad[0] - info.pad[2]), + info.xShape[3] + std::abs(info.pad[1] - info.pad[3]), + }; + + auto NHWC = [](const int shape[]) -> std::vector<int> { + return { + shape[0], shape[2], shape[3], shape[1]}; + }; + + std::vector<int> xsNHWC = NHWC(xs); + std::vector<int> wsNHWC = NHWC(info.wShape); + std::vector<int> ysNHWC = NHWC(info.yShape); + + setCnnlTensor(d->x, info.dt, slice(xs, 4)); + setCnnlTensor(d->y, info.dt, slice(info.yShape, 4)); + setCnnlTensor(d->w, info.dt, slice(info.wShape, 4)); + + CNNL_ASSERT(cnnlSetTensorDescriptor(d->xTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, xsNHWC.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->yTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, ysNHWC.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->wTrans, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, wsNHWC.data())); + if (d->bias) { + int biasDim[] = {1, 1, 1, info.biasSize}; + CNNL_ASSERT(cnnlSetTensorDescriptor(d->b, CNNL_LAYOUT_NHWC, cnnlDataTypeConvert(info.dt), 4, biasDim)); + } + + auto xTransSize = cnnlGetTensorElementNum(d->xTrans) * info.dt.size(); + auto yTransSize = cnnlGetTensorElementNum(d->yTrans) * info.dt.size(); + auto wTransSize = cnnlGetTensorElementNum(d->wTrans) * info.dt.size(); + + int permuteIn[4] = {0, 2, 3, 1}; + int permuteOut[4] = {0, 3, 1, 2}; + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NCHW2NHWC, 4, permuteIn)); + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->NHWC2NCHW, 4, permuteOut)); + + size_t xWorkspaceSize, yWorkspaceSize, wWorkspaceSize, convWorkspaceSize; + auto handle = res.fetchOrStore<CnnlContext>()->handle; + CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->x, d->NCHW2NHWC, &xWorkspaceSize)); + 
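// Four scratch requirements are queried in total (the x->NHWC, w->NHWC and
// y->NCHW transposes plus the convolution itself); only their maximum is
// added to the workspace, since each call reuses the same scratch region
// sequentially.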
CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->w, d->NCHW2NHWC, &wWorkspaceSize)); + CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->yTrans, d->NHWC2NCHW, &yWorkspaceSize)); + + // clang-format off + auto computation = info.dt == DataType::F64 ? DataType::F64 + : info.dt == DataType::I8 ? DataType::I32 + : DataType::F32; + // clang-format on + auto group = xs[1] / info.wShape[1]; + CNNL_ASSERT(cnnlSetConvolutionDescriptor(d->conv, 4, info.pad, info.stride, info.dilation, group, cnnlDataTypeConvert(computation))); + CNNL_ASSERT(cnnlGetConvolutionForwardAlgorithm( + handle, d->conv, d->xTrans, d->wTrans, d->yTrans, + CNNL_CONVOLUTION_FWD_FASTEST, &d->algo)); + + CNNL_ASSERT(cnnlGetConvolutionForwardWorkspaceSize( + handle, d->xTrans, d->wTrans, d->yTrans, NULL, + d->conv, d->algo, &convWorkspaceSize)); + + size_t workspaceSize = xTransSize + yTransSize + wTransSize + std::max({xWorkspaceSize, wWorkspaceSize, yWorkspaceSize, convWorkspaceSize}); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d, xTransSize, yTransSize, wTransSize, + xWorkspaceSize, wWorkspaceSize, + yWorkspaceSize, convWorkspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + auto handle = res.fetchOrStore<CnnlContext>()->handle; + void const *x = inputs[0], *w = inputs[1]; + void *y = outputs[0]; + + void *xTrans = workspace; + void *wTrans = reinterpret_cast<uint8_t *>(xTrans) + xTransSize; + void *yTrans = reinterpret_cast<uint8_t *>(wTrans) + wTransSize; + void *opWorkspace = reinterpret_cast<uint8_t *>(yTrans) + yTransSize; + + // transpose NCHW input to NHWC + CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->x, x, + d->xTrans, xTrans, opWorkspace, xWorkspaceSize)); + CNNL_ASSERT(cnnlTranspose_v2(handle, d->NCHW2NHWC, d->w, w, + d->wTrans, wTrans, opWorkspace, wWorkspaceSize)); + + auto bDesc = (d->bias) ? d->b : NULL; + auto bData = (d->bias) ? inputs[2] : NULL; + CNNL_ASSERT(cnnlConvolutionForward( + handle, + d->conv, d->algo, NULL, + d->xTrans, xTrans, d->wTrans, wTrans, + bDesc, bData, opWorkspace, convWorkspaceSize, + NULL, d->yTrans, yTrans)); + + // transpose NHWC intermediates to NCHW + CNNL_ASSERT(cnnlTranspose_v2(handle, d->NHWC2NCHW, d->yTrans, yTrans, + d->y, y, opWorkspace, yWorkspaceSize)); + }; + return {std::move(routine), workspaceSize}; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/conv/cnnl_kernel.hh b/src/04kernel/src/kernels/conv/cnnl_kernel.hh new file mode 100644 index 000000000..c486cdd17 --- /dev/null +++ b/src/04kernel/src/kernels/conv/cnnl_kernel.hh @@ -0,0 +1,43 @@ +#ifndef KERNEL_CONV_CNNL_KERNEL_HH +#define KERNEL_CONV_CNNL_KERNEL_HH + +#include "../../kernels/expand/cnnl_kernel.hh" +#include "kernel/attributes/pool_attributes.h" +#include "kernel/kernel.h" +#include <optional> + +namespace refactor::kernel { + + /// @brief Use `cnnlConvolutionForward`. + /// It only supports 4D tensors. 
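/// A hypothetical call site, for orientation only (argument names are
/// illustrative, not from this PR):
///
///     auto kernel = ConvCnnl::build(poolAttrs, x, w, std::nullopt, y);
///     if (kernel) {
///         auto rw = kernel->lower(res);// assumes RoutineWorkspace carries
///                                      // the routine plus workspaceSize
///     }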
+ struct ConvCnnl final : public Kernel { + struct { + DataType dt; + int xShape[4], + wShape[4], + yShape[4], + dilation[2], + pad[4], + stride[2]; + int biasSize; + } info; + + explicit ConvCnnl(decltype(info)) noexcept; + + static KernelBox build(PoolAttributes const &, + Tensor const &, + Tensor const &, + std::optional<std::reference_wrapper<Tensor const>>, + Tensor const &); + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_CONV_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/expand/cnnl_kernel.cc b/src/04kernel/src/kernels/expand/cnnl_kernel.cc new file mode 100644 index 000000000..426eac49d --- /dev/null +++ b/src/04kernel/src/kernels/expand/cnnl_kernel.cc @@ -0,0 +1,67 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = ExpandCnnl; + + K::ExpandCnnl(ExpandInfoCnnl info_) noexcept + : Kernel(), info(info_) {} + + auto K::build(Tensor const &input, Tensor const &output) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + return std::make_unique<K>(ExpandInfoCnnl( + input.dataType, + slice(input.shape.data(), input.rank()), + slice(output.shape.data(), output.rank()) + )); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing expand operation using CNNL"; + } + +#ifdef USE_BANG + auto ExpandCnnl::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + cnnlTensorDescriptor_t inDesc, outDesc; + + Descriptors() : inDesc(nullptr), outDesc(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&outDesc)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(outDesc)); + } + }; + auto d = std::make_shared<Descriptors>(); + std::vector<int> in(info.inDims.begin(), info.inDims.end()), + out(info.outDims.begin(), info.outDims.end()); + setCnnlTensor(d->inDesc, info.dataType, slice(in.data(), in.size())); + setCnnlTensor(d->outDesc, info.dataType, slice(out.data(), out.size())); + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + CNNL_ASSERT(cnnlExpand(res.fetchOrStore<CnnlContext>()->handle, + d->inDesc, inputs[0], d->outDesc, outputs[0])); + }; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/expand/cnnl_kernel.hh b/src/04kernel/src/kernels/expand/cnnl_kernel.hh new file mode 100644 index 000000000..a6271907e --- /dev/null +++ b/src/04kernel/src/kernels/expand/cnnl_kernel.hh @@ -0,0 +1,31 @@ +#ifndef KERNEL_EXPAND_CNNL_KERNEL_HH +#define KERNEL_EXPAND_CNNL_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct ExpandInfoCnnl { + DataType dataType; + slice_t<dim_t> inDims, outDims; + }; + + struct ExpandCnnl final : public Kernel { + ExpandInfoCnnl info; + + explicit ExpandCnnl(ExpandInfoCnnl) 
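// (ExpandInfoCnnl stores non-owning slice_t views of the shape storage,
//  so the source tensors are assumed to outlive lower().)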
noexcept; + + static KernelBox build(Tensor const &input, Tensor const &output) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_EXPAND_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/gather/cnnl_kernel.cc b/src/04kernel/src/kernels/gather/cnnl_kernel.cc new file mode 100644 index 000000000..58b86660b --- /dev/null +++ b/src/04kernel/src/kernels/gather/cnnl_kernel.cc @@ -0,0 +1,92 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif +#include <iostream> + +namespace refactor::kernel { + using K = GatherCnnl; + + K::GatherCnnl(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(int axis, Tensor const &input, Tensor const &index, Tensor const &output) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + auto indicesDim = std::vector<int>(index.shape.begin(), index.shape.end()); + if (indicesDim.size() == 0) { + indicesDim.push_back(1); + } + return std::make_unique<K>(decltype(info){ + input.dataType, + index.dataType, + axis, + std::vector<int>(input.shape.begin(), input.shape.end()), + std::move(indicesDim), + std::vector<int>(output.shape.begin(), output.shape.end()), + }); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing gather using CNNL"; + } + +#ifdef USE_BANG + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + cnnlTensorDescriptor_t inDesc, indexDesc, outDesc; + + Descriptors() : inDesc(nullptr), indexDesc(nullptr), outDesc(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&indexDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&outDesc)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(indexDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(outDesc)); + } + }; + auto d = std::make_shared<Descriptors>(); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.inDim.size(), info.inDim.data())); + // cnnlGatherV2 does not support int64 indices + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->indexDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT32, + info.indexDim.size(), info.indexDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.outDim.size(), info.outDim.data())); + + size_t workspaceSize = info.inDim.size() * sizeof(int); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), + shape = std::vector<int>(info.inDim.begin(), info.inDim.end()), + workspaceSize, + dim = info.axis](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + res.fetchOrStore<CnnlContext>()->copyFromCPU(workspace, shape.data(), workspaceSize); + CNNL_ASSERT(cnnlGatherV2(res.fetchOrStore<CnnlContext>()->handle, dim, + d->inDesc, inputs[0], reinterpret_cast<const int *>(workspace), + d->indexDesc, 
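// (The index descriptor was declared INT32 above because cnnlGatherV2
//  rejects int64 indices; the reinterpret_cast below therefore assumes
//  the indices were already narrowed to int32 upstream.)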
reinterpret_cast<const int *>(inputs[1]), + d->outDesc, outputs[0])); + }; + + return {std::move(routine), workspaceSize}; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/gather/cnnl_kernel.hh b/src/04kernel/src/kernels/gather/cnnl_kernel.hh new file mode 100644 index 000000000..154345929 --- /dev/null +++ b/src/04kernel/src/kernels/gather/cnnl_kernel.hh @@ -0,0 +1,30 @@ +#ifndef KERNEL_GATHER_CNNL_KERNEL_HH +#define KERNEL_GATHER_CNNL_KERNEL_HH + +#include "kernel/attributes/gather_info.h" +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct GatherCnnl final : public Kernel { + struct { + DataType dataType, indexDataType; + int axis; + std::vector<int> inDim, indexDim, outDim; + } info; + + explicit GatherCnnl(decltype(info)) noexcept; + + static KernelBox build(int, Tensor const &, Tensor const &, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_GATHER_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/hard_sigmoid/cnnl_kernel.cc b/src/04kernel/src/kernels/hard_sigmoid/cnnl_kernel.cc new file mode 100644 index 000000000..25ec6d0cb --- /dev/null +++ b/src/04kernel/src/kernels/hard_sigmoid/cnnl_kernel.cc @@ -0,0 +1,81 @@ +#include "cnnl_kernel.hh" +#include "kernel/collectors/hard_sigmoid.h" +#include <unordered_set> + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = HardSigmoidCnnl; + using DT = DataType; + + K::HardSigmoidCnnl(float alpha_, float beta_, DT dataType_, int size_) noexcept + : Kernel(), alpha(alpha_), beta(beta_), dataType(dataType_), size(size_) {} + + auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox { + +#ifndef USE_BANG + return nullptr; +#endif + + return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize()); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing hardsigmoid using CNNL"; + } + +#ifdef USE_BANG + + auto HardSigmoidCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + // RAII for closure + struct Descriptors { + cnnlActivationDescriptor_t activation; + cnnlTensorDescriptor_t tensor; + + Descriptors() : activation(nullptr), tensor(nullptr) { + CNNL_ASSERT(cnnlCreateActivationDescriptor(&activation)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&tensor)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyActivationDescriptor(activation)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(tensor)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(); + + setCnnlTensor(d->tensor, dataType, slice(&size, 1)); + CNNL_ASSERT(cnnlSetActivationDescriptor_v5(d->activation, CNNL_ACTIVATION_HARDSIGMOID, + CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, 0.0, + 0.0, alpha, beta, true)); + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d)]// + (Resources & res, void *, void const *const *inputs, void *const
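// (cnnlActivationForward computes y = alpha * act(x) + beta * y; the
//  1.f/0.f pair below is the usual "overwrite the output" setting.)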
*outputs) { + float alpha = 1.f, beta = 0.f; + CNNL_ASSERT(cnnlActivationForward( + res.fetchOrStore<CnnlContext>()->handle, + d->activation, + &alpha, d->tensor, inputs[0], + &beta, d->tensor, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/hard_sigmoid/cnnl_kernel.hh b/src/04kernel/src/kernels/hard_sigmoid/cnnl_kernel.hh new file mode 100644 index 000000000..c343d1fed --- /dev/null +++ b/src/04kernel/src/kernels/hard_sigmoid/cnnl_kernel.hh @@ -0,0 +1,27 @@ +#ifndef KERNEL_HARD_SIGMOID_CNNL_KERNEL_HH +#define KERNEL_HARD_SIGMOID_CNNL_KERNEL_HH + +#include "kernel/collectors/simple_unary.h" + +namespace refactor::kernel { + + struct HardSigmoidCnnl final : public Kernel { + float alpha, beta; + DataType dataType; + int size; + + HardSigmoidCnnl(float, float, DataType, int) noexcept; + + static KernelBox build(float, float, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_HARD_SIGMOID_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/mat_mul/cnnl_kernel.cc b/src/04kernel/src/kernels/mat_mul/cnnl_kernel.cc new file mode 100644 index 000000000..3eac35723 --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul/cnnl_kernel.cc @@ -0,0 +1,152 @@ +#include "cnnl_kernel.hh" +#include <numeric> + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = MatMulCnnl; + using DT = DataType; + + K::MatMulCnnl(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(TensorRefs inputs_, TensorRefs outputs_, bool transA_, bool transB_, float alpha_, float beta_) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + auto dt = inputs_[0].get().dataType; + return dt.isIeee754() || dt == DT::I8 + ? std::make_unique<K>(decltype(info){ + dt, + transA_, + transB_, + alpha_, + beta_, + std::vector<int>(inputs_[0].get().shape.begin(), inputs_[0].get().shape.end()), + std::vector<int>(inputs_[1].get().shape.begin(), inputs_[1].get().shape.end()), + std::vector<int>(outputs_[0].get().shape.begin(), outputs_[0].get().shape.end()), + inputs_.size() == 3 + ? inputs_[2].get().shape.size() == 0 ? 
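// (A rank-0 bias is promoted to shape {1} here so it can be described to
//  CNNL and broadcast with cnnlExpand inside the routine.)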
std::make_optional(std::vector<int>(1, 1)) + : std::make_optional(std::vector<int>( + inputs_[2].get().shape.begin(), + inputs_[2].get().shape.end())) + : std::nullopt, + }) + : nullptr; + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing MatMul using CNNL"; + } + + +#ifdef USE_BANG + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t a, b, c; + cnnlMatMulDescriptor_t bmm; + cnnlMatMulAlgo_t algo; + cnnlMatMulHeuristicResult_t heuristic; + cnnlTensorDescriptor_t bias; + bool addBias, f32; + + explicit Descriptors(bool addBias_, bool f32_) + : a(nullptr), b(nullptr), c(nullptr), + bmm(nullptr), algo(nullptr), heuristic(nullptr), + bias(nullptr), addBias(addBias_), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&a)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&b)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&c)); + if (addBias) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&bias)); + } + CNNL_ASSERT(cnnlMatMulDescCreate(&bmm)); + CNNL_ASSERT(cnnlMatMulAlgoCreate(&algo)); + CNNL_ASSERT(cnnlCreateMatMulHeuristicResult(&heuristic)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(a)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(b)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(c)); + if (addBias) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(bias)); + } + CNNL_ASSERT(cnnlMatMulDescDestroy(bmm)); + CNNL_ASSERT(cnnlMatMulAlgoDestroy(algo)); + CNNL_ASSERT(cnnlDestroyMatMulHeuristicResult(heuristic)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(info.biasDim.has_value(), info.dataType != DT::F64); + setCnnlTensor(d->a, info.dataType, slice(info.aDim.data(), info.aDim.size())); + setCnnlTensor(d->b, info.dataType, slice(info.bDim.data(), info.bDim.size())); + setCnnlTensor(d->c, info.dataType, slice(info.cDim.data(), info.cDim.size())); + if (d->addBias) { + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->bias, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), + info.biasDim.value().size(), info.biasDim.value().data())); + } + int32_t tA = info.transA, tB = info.transB; + CNNL_ASSERT(cnnlSetMatMulDescAttr(d->bmm, CNNL_MATMUL_DESC_TRANSA, + &tA, sizeof(int32_t))); + CNNL_ASSERT(cnnlSetMatMulDescAttr(d->bmm, CNNL_MATMUL_DESC_TRANSB, + &tB, sizeof(int32_t))); + auto handle = res.fetchOrStore<CnnlContext>()->handle; + int returnedAlgoCount = 0; + CNNL_ASSERT(cnnlGetBatchMatMulAlgoHeuristic( + handle, d->bmm, d->a, d->b, d->c, + NULL, 1, &(d->heuristic), &returnedAlgoCount)); + + size_t algoWorkspaceSize; + CNNL_ASSERT(cnnlGetBatchMatMulHeuristicResult(d->heuristic, d->algo, &algoWorkspaceSize)); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), algoWorkspaceSize, + aa = info.alpha, bb = info.beta](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + // build alpha/beta for double + auto alpha = d->f32 ? factor<fp32_t>(aa) : factor<fp64_t>(aa), + beta = d->f32 ? factor<fp32_t>(bb) : factor<fp64_t>(bb), + // one = d->f32 ? 
factor<fp32_t>(1) : factor<fp64_t>(1), + zero = d->f32 ? factor<fp32_t>(0) : factor<fp64_t>(0); + + if (d->addBias) { + CNNL_ASSERT(cnnlExpand(handle, d->bias, inputs[2], d->c, outputs[0])); + } + + if (alpha != 0) { + CNNL_ASSERT(cnnlBatchMatMulBCast_v2( + handle, d->bmm, d->algo, &alpha, + d->a, inputs[0], d->b, inputs[1], + d->addBias ? &beta : &zero, d->c, outputs[0], + workspace, algoWorkspaceSize)); + } + + }; + + return {std::move(routine), algoWorkspaceSize}; + } + + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/mat_mul/cnnl_kernel.hh b/src/04kernel/src/kernels/mat_mul/cnnl_kernel.hh new file mode 100644 index 000000000..9b44b192c --- /dev/null +++ b/src/04kernel/src/kernels/mat_mul/cnnl_kernel.hh @@ -0,0 +1,32 @@ +#ifndef KERNEL_MATMUL_CNNL_KERNEL_HH +#define KERNEL_MATMUL_CNNL_KERNEL_HH + +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct MatMulCnnl final : public Kernel { + struct { + DataType dataType; + bool transA, transB; + float alpha, beta; + std::vector<int> aDim, bDim, cDim; + std::optional<std::vector<int>> biasDim; + } info; + + explicit MatMulCnnl(decltype(info)) noexcept; + + static KernelBox build(TensorRefs, TensorRefs, bool, bool, float, float) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_MATMUL_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/pad/cnnl_kernel.cc b/src/04kernel/src/kernels/pad/cnnl_kernel.cc new file mode 100644 index 000000000..39648f851 --- /dev/null +++ b/src/04kernel/src/kernels/pad/cnnl_kernel.cc @@ -0,0 +1,95 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = PadCnnl; + + K::PadCnnl(DataType dataType_, PadType mode_, std::vector<int> inDim_, + std::vector<int> outDim_, std::vector<int> padDim_, size_t len_) noexcept + : Kernel(), dataType(dataType_), mode(mode_), inDim(std::move(inDim_)), + outDim(std::move(outDim_)), padDim(std::move(padDim_)), valueLength(len_) {} + + auto K::build(PadDimension dims_, DataType dataType_, PadType mode_, std::optional<std::reference_wrapper<Tensor const>> value_) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + if (mode_ != PadType::Constant || (value_ && value_->get().dataType != dataType_)) { + return nullptr; + } + size_t valueLength_ = value_ ? 
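When a bias is present, the MatMul routine above composes two CNNL calls: cnnlExpand broadcasts the bias into the output buffer first, then cnnlBatchMatMulBCast_v2 accumulates into that buffer with beta, giving C = alpha * op(A) * op(B) + beta * broadcast(bias). A plain-loop reference for the single-batch case (a sketch for testing; the function is ours, not part of the patch):

    // C[i][j] = alpha * sum_p A[i][p] * B[p][j] + beta * bias[j]
    void matmulBiasRef(float const *a, float const *b, float const *bias, float *c,
                       int m, int k, int n, float alpha, float beta) {
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j) {
                float acc = 0.f;
                for (int p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
                c[i * n + j] = alpha * acc + beta * bias[j];// bias broadcast across rows
            }
    }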
value_->get().dataType.size() : 0; + std::vector<int> inDim_, outDim_, padDim_; + for (auto dim : dims_) { + inDim_.push_back(dim.dimI); + outDim_.push_back(dim.dimO); + padDim_.push_back(dim.pads); + } + + return std::make_unique<K>(dataType_, mode_, inDim_, outDim_, padDim_, valueLength_); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing Pad using CNNL"; + } + +#ifdef USE_BANG + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + cnnlTensorDescriptor_t inDesc, outDesc; + + Descriptors() : inDesc(nullptr), outDesc(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&outDesc)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(outDesc)); + } + }; + auto d = std::make_shared<Descriptors>(); + setCnnlTensor(d->inDesc, dataType, slice(inDim.data(), inDim.size())); + setCnnlTensor(d->outDesc, dataType, slice(outDim.data(), outDim.size())); + + std::vector<int> pads; + for (auto d : padDim) { + pads.push_back(d); + pads.push_back(d); + } + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d), val = valueLength, + p = std::vector<int>(pads.begin(), pads.end())](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + float zero = 0.0;// fallback pad value; declared at lambda scope so its address stays valid for cnnlPad below + void *paddingValue; + if (val != 0) { + paddingValue = malloc(val); + BANG_ASSERT(cnrtMemcpy(paddingValue, const_cast<void *>(inputs[2]), + val, CNRT_MEM_TRANS_DIR_DEV2HOST)); + } else { + paddingValue = &zero; + } + + CNNL_ASSERT(cnnlPad(res.fetchOrStore<CnnlContext>()->handle, + d->inDesc, inputs[0], p.data(), paddingValue, + d->outDesc, outputs[0])); + + if (val != 0) { + free(paddingValue); + } + }; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pad/cnnl_kernel.hh b/src/04kernel/src/kernels/pad/cnnl_kernel.hh new file mode 100644 index 000000000..202518fd4 --- /dev/null +++ b/src/04kernel/src/kernels/pad/cnnl_kernel.hh @@ -0,0 +1,28 @@ +#ifndef KERNEL_PAD_CNNL_HH +#define KERNEL_PAD_CNNL_HH + +#include "kernel/attributes/pad_info.h" +#include "kernel/collectors/pad.h" + +namespace refactor::kernel { + + struct PadCnnl final : public Kernel { + DataType dataType; + PadType mode; + std::vector<int> inDim, outDim, padDim; + size_t valueLength; + + PadCnnl(DataType, PadType, std::vector<int>, std::vector<int>, std::vector<int>, size_t) noexcept; + static KernelBox build(PadDimension, DataType, PadType, std::optional<std::reference_wrapper<Tensor const>>) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif//KERNEL_PAD_CNNL_HH diff --git a/src/04kernel/src/kernels/pool/cnnl_kernel.cc b/src/04kernel/src/kernels/pool/cnnl_kernel.cc new file mode 100644 index 000000000..929ea5789 --- /dev/null +++ b/src/04kernel/src/kernels/pool/cnnl_kernel.cc @@ -0,0 +1,156 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace
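Note how the pad info is flattened above: cnnlPad wants a begin/end pad pair per axis, and the loop pushes the same padDim value twice, so every axis is padded symmetrically. A quick shape check under that assumption (sketch; the helper is ours):

    #include <vector>

    // an axis of extent e with symmetric pad p ends up with extent e + 2 * p
    std::vector<int> paddedShape(std::vector<int> const &in, std::vector<int> const &pad) {
        std::vector<int> out(in.size());
        for (size_t i = 0; i < in.size(); ++i) out[i] = in[i] + 2 * pad[i];
        return out;
    }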
refactor::kernel { + using K = PoolCnnl; + + K::PoolCnnl(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(PoolType poolType, + bool ceil, + KernelShape const &kernelShape, + PoolAttributes const &poolAttributes, + Tensor const &x, + Tensor const &y) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + // TODO check data type + auto p = poolAttributes.pads(), + d = poolAttributes.dilations(), + s = poolAttributes.strides(); + if (x.rank() != 4 || + poolType == PoolType::Lp || + d[0] != 1 || d[1] != 1) { + return nullptr; + } + return std::make_unique<K>(decltype(info){ + poolType, + x.dataType, + { + static_cast<int>(x.shape[0]), + static_cast<int>(x.shape[1]), + static_cast<int>(x.shape[2]), + static_cast<int>(x.shape[3]), + }, + { + static_cast<int>(y.shape[0]), + static_cast<int>(y.shape[1]), + static_cast<int>(y.shape[2]), + static_cast<int>(y.shape[3]), + }, + { + static_cast<int>(kernelShape[0]), + static_cast<int>(kernelShape[1]), + }, + {p[0], p[1], p[2], p[3]}, + {s[0], s[1]}, + {d[0], d[1]}, + ceil + }); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing pool using CNNL"; + } + +#ifdef USE_BANG + + auto PoolCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using Ty = PoolType; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t x, y; + cnnlPoolingDescriptor_t pooling; + bool f32; + + Descriptors(decltype(f32) f32_) : f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreatePoolingDescriptor(&pooling)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyPoolingDescriptor(pooling)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(info.dt != DataType::F64); + int const + xs[]{ + info.xShape[0], + info.xShape[1], + info.xShape[2] + std::abs(info.pads[0] - info.pads[2]), + info.xShape[3] + std::abs(info.pads[1] - info.pads[3]), + }, + *ys = info.yShape; + setCnnlTensor(d->x, info.dt, slice(xs, 4)); + setCnnlTensor(d->y, info.dt, slice(ys, 4)); + + // clang-format off + auto mode = info.poolType == Ty::Average ? CNNL_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING + : info.poolType == Ty::Max ? 
CNNL_POOLING_MAX + : UNREACHABLEX(cnnlPoolingMode_t, ""); + // clang-format on + auto pp = info.pads; + auto ss = info.strides; + auto kk = info.kernelShape; + auto dd = info.dilations; + CNNL_ASSERT(cnnlSetPooling2dDescriptor_v2( + d->pooling, mode, CNNL_NOT_PROPAGATE_NAN, + kk[0], kk[1], pp[0], pp[2], pp[1], pp[3], + ss[0], ss[1], dd[0], dd[1], info.ceil)); + + auto handle = res.fetchOrStore<CnnlContext>()->handle; + size_t extraInputSize, workspaceSize; + CNNL_ASSERT(cnnlGetPoolingWorkspaceSize(handle, mode, ys[3], ys[2], &workspaceSize)); + CNNL_ASSERT(cnnlGetPoolingExtraInputSize(handle, mode, ys[3], ys[2], &extraInputSize)); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d, workspaceSize, + extraInputSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + void *extraInputDev = workspace; + void *poolWorkSpace = reinterpret_cast<uint8_t *>(workspace) + extraInputSize; + + void *extraInputHost = malloc(extraInputSize); + CNNL_ASSERT(cnnlInitPoolingExtraInput(handle, d->pooling, d->x, d->y, extraInputHost)); + BANG_ASSERT(cnrtMemcpy(extraInputDev, extraInputHost, extraInputSize, CNRT_MEM_TRANS_DIR_HOST2DEV)); + + // build alpha/beta for double + auto a = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1), + b = d->f32 ? factor<fp32_t>(0) : factor<fp64_t>(0); + CNNL_ASSERT(cnnlPoolingForward_v2( + handle, d->pooling, + &a, d->x, inputs[0], + &b, extraInputDev, d->y, outputs[0], + poolWorkSpace, workspaceSize)); + + res.fetchOrStore<CnnlContext>()->queueSync(); + + free(extraInputHost); + }; + return {std::move(routine), workspaceSize + extraInputSize}; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/pool/cnnl_kernel.hh b/src/04kernel/src/kernels/pool/cnnl_kernel.hh new file mode 100644 index 000000000..0a0298ede --- /dev/null +++ b/src/04kernel/src/kernels/pool/cnnl_kernel.hh @@ -0,0 +1,45 @@ +#ifndef KERNEL_POOL_CNNL_KERNEL_HH +#define KERNEL_POOL_CNNL_KERNEL_HH + +#include "kernel/attributes/pool_attributes.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + /// @brief Use `cnnlPoolingForward`. + /// It only supports 4D tensors.
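The pooling routine above asks the runtime for a single workspace of workspaceSize + extraInputSize bytes and carves it by hand: the CNNL "extra input" blob sits at offset zero and is refilled from the host on every call, and the pooling scratch space follows it. The layout, spelled out as in the routine body:

    // [0, extraInputSize)                               -> extraInputDev, filled via cnrtMemcpy each call
    // [extraInputSize, extraInputSize + workspaceSize)  -> poolWorkSpace for cnnlPoolingForward_v2
    void *extraInputDev = workspace;
    void *poolWorkSpace = static_cast<uint8_t *>(workspace) + extraInputSize;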
+ struct PoolCnnl final : public Kernel { + struct + { + PoolType poolType; + DataType dt; + int xShape[4], + yShape[4], + kernelShape[2], + pads[4], + strides[2], + dilations[2]; + bool ceil; + } info; + + explicit PoolCnnl(decltype(info)) noexcept; + + static KernelBox build(PoolType, + bool, + KernelShape const &, + PoolAttributes const &, + Tensor const &, + Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_POOL_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/reduce/cnnl_kernel.cc b/src/04kernel/src/kernels/reduce/cnnl_kernel.cc new file mode 100644 index 000000000..4ea6fd827 --- /dev/null +++ b/src/04kernel/src/kernels/reduce/cnnl_kernel.cc @@ -0,0 +1,131 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include "hardware/functions.h" +#endif + +namespace refactor::kernel { + using K = ReduceCnnl; + + K::ReduceCnnl( + decltype(dataType) dataType_, + decltype(reduceType) reduceType_, + decltype(axes) axes_, + decltype(shape) shape_) noexcept + : Kernel(), + dataType(dataType_), + reduceType(reduceType_), + axes(std::move(axes_)), + shape(std::move(shape_)) {} + + auto K::build(decltype(axes) axes_, ReduceType reduceType_, TensorRefs inputs_) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + auto const &x = inputs_[0].get(); + return x.dataType.isFloat() + ? std::make_unique<K>(x.dataType, reduceType_, std::move(axes_), x.shape) + : nullptr; + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing reduce operation using CNNL"; + } + +#ifdef USE_BANG + + auto ReduceCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t x, y; + cnnlReduceDescriptor_t reduce; + bool f32; + + explicit Descriptors(decltype(f32) f32_) : f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreateReduceDescriptor(&reduce)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyReduceDescriptor(reduce)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(dataType != DataType::F64); + + std::vector<int> + dimsI(shape.begin(), shape.end()), + dimsO(shape.begin(), shape.end()), + indices(axes.begin(), axes.end()); + for (auto axis : axes) { + dimsO[axis] = 1; + } + // setCnnlTensor(d->x, dataType, slice(dimsI.data(), dimsI.size())); + // setCnnlTensor(d->y, dataType, slice(dimsO.data(), dimsO.size())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->x, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), dimsI.size(), dimsI.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->y, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), dimsO.size(), dimsO.data())); + + // clang-format off + auto reduceOp = reduceType == ReduceType::Mean ? 
CNNL_REDUCE_AVG + : reduceType == ReduceType::Sum ? CNNL_REDUCE_ADD + : reduceType == ReduceType::Min ? CNNL_REDUCE_MIN + : reduceType == ReduceType::Max ? CNNL_REDUCE_MAX + : reduceType == ReduceType::L1 ? CNNL_REDUCE_NORM1 + : reduceType == ReduceType::L2 ? CNNL_REDUCE_NORM2 + : reduceType == ReduceType::Prod ? CNNL_REDUCE_MUL + : UNREACHABLEX(cnnlReduceOp_t, ""); + // clang-format on + CNNL_ASSERT(cnnlSetReduceDescriptor_v2( + d->reduce, indices.data(), indices.size(), reduceOp, + cnnlDataTypeConvert(d->f32 ? DataType::F32 : DataType::F64), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES, 0.0)); + + auto handler = res.fetchOrStore<CnnlContext>()->handle; + size_t idxWorkspaceSize = indices.size() * sizeof(int); + // idxWorkspaceSize = hardware::alignBytes(idxWorkspaceSize, 256); + size_t workspaceSize; + // get workspace + CNNL_ASSERT(cnnlGetReduceOpWorkspaceSize(handler, d->x, d->y, d->reduce, &workspaceSize)); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), + idxWorkspaceSize, + workspaceSize](Resources &res, + void *workspace, + void const *const *inputs, + void *const *outputs) { + void *idxWorkspace = workspace, + *dataWorkspace = reinterpret_cast<uint8_t *>(workspace) + idxWorkspaceSize; + // build alpha/beta for double + auto a = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1), + b = d->f32 ? factor<fp32_t>(0) : factor<fp64_t>(0); + CNNL_ASSERT(cnnlReduce( + res.fetchOrStore<CnnlContext>()->handle, + d->reduce, + dataWorkspace, workspaceSize, + &a, d->x, inputs[0], + idxWorkspaceSize, idxWorkspace, + &b, d->y, outputs[0])); + }; + return RoutineWorkspace(std::move(routine), idxWorkspaceSize + workspaceSize); + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/reduce/cnnl_kernel.hh b/src/04kernel/src/kernels/reduce/cnnl_kernel.hh new file mode 100644 index 000000000..6ffaf7387 --- /dev/null +++ b/src/04kernel/src/kernels/reduce/cnnl_kernel.hh @@ -0,0 +1,32 @@ +#ifndef KERNEL_REDUCE_MEAN_CNNL_KERNEL_HH +#define KERNEL_REDUCE_MEAN_CNNL_KERNEL_HH + +#include "kernel/collectors/reduce.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct ReduceCnnl final : public Kernel { + DataType dataType; + ReduceType reduceType; + Axes axes; + Shape shape; + + ReduceCnnl(decltype(dataType), + decltype(reduceType), + decltype(axes), + decltype(shape)) noexcept; + + static KernelBox build(decltype(axes), ReduceType, TensorRefs) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; +}// namespace refactor::kernel + +#endif// KERNEL_REDUCE_MEAN_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/scatter_nd/cnnl_kernel.cc b/src/04kernel/src/kernels/scatter_nd/cnnl_kernel.cc new file mode 100644 index 000000000..2152acf5b --- /dev/null +++ b/src/04kernel/src/kernels/scatter_nd/cnnl_kernel.cc @@ -0,0 +1,85 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = ScatterNDCnnl; + + K::ScatterNDCnnl(decltype(info) info_) + : Kernel(), info(std::move(info_)) {} + + auto K::build(TensorRefs inputs, TensorRefs outputs) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + return std::make_unique<ScatterNDCnnl>(decltype(info){ + inputs[0].get().dataType, + 
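dimsO above starts as a copy of the input shape and only the reduced axes are overwritten with 1, i.e. the reduce descriptor is built with keep-dims semantics. The transform in isolation (sketch; the helper name is ours):

    #include <vector>

    // shape {2, 3, 4, 5} with axes {1, 3} -> output dims {2, 1, 4, 1}
    std::vector<int> reducedDims(std::vector<int> dims, std::vector<int> const &axes) {
        for (int a : axes) dims[a] = 1;// reduced axes kept as size-1 dims
        return dims;
    }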
inputs[1].get().dataType, + inputs[2].get().dataType, + std::vector<int>(inputs[0].get().shape.begin(), inputs[0].get().shape.end()), + std::vector<int>(inputs[1].get().shape.begin(), inputs[1].get().shape.end()), + std::vector<int>(inputs[2].get().shape.begin(), inputs[2].get().shape.end()), + std::vector<int>(outputs[0].get().shape.begin(), outputs[0].get().shape.end()), + }); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing scatterNd operation using CNNL"; + } + +#ifdef USE_BANG + auto K::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + cnnlTensorDescriptor_t inDesc, indexDesc, updateDesc, outDesc; + + Descriptors() : inDesc(nullptr), indexDesc(nullptr), + updateDesc(nullptr), outDesc(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&inDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&indexDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&updateDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&outDesc)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(inDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(indexDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(updateDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(outDesc)); + } + }; + auto d = std::make_shared<Descriptors>(); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->inDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.inDim.size(), info.inDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->indexDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.indexDataType), + info.indexDim.size(), info.indexDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->updateDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.updateDataType), + info.updateDim.size(), info.updateDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->outDesc, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.updateDataType), + info.outDim.size(), info.outDim.data())); + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d)](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + CNNL_ASSERT(cnnlScatterNd_v2(res.fetchOrStore<CnnlContext>()->handle, CNNL_SCATTERND_UPDATE, + d->indexDesc, inputs[1], d->updateDesc, inputs[2], + d->inDesc, inputs[0], d->outDesc, outputs[0])); + }; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/scatter_nd/cnnl_kernel.hh b/src/04kernel/src/kernels/scatter_nd/cnnl_kernel.hh new file mode 100644 index 000000000..81fca0357 --- /dev/null +++ b/src/04kernel/src/kernels/scatter_nd/cnnl_kernel.hh @@ -0,0 +1,30 @@ +#ifndef KERNEL_SCATTER_ND_CNNL_KERNEL_HH +#define KERNEL_SCATTER_ND_CNNL_KERNEL_HH + +#include "kernel/attributes/scatter_nd_info.h" +#include "kernel/kernel.h" + +namespace refactor::kernel { + + struct ScatterNDCnnl final : public Kernel { + struct { + DataType dataType, indexDataType, updateDataType; + std::vector<int> inDim, indexDim, updateDim, outDim; + } info; + + explicit ScatterNDCnnl(decltype(info)); + + static KernelBox build(TensorRefs, TensorRefs) noexcept; + + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace 
refactor::kernel + +#endif// KERNEL_SCATTER_ND_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/select/cnnl_kernel.cc b/src/04kernel/src/kernels/select/cnnl_kernel.cc new file mode 100644 index 000000000..be54a8904 --- /dev/null +++ b/src/04kernel/src/kernels/select/cnnl_kernel.cc @@ -0,0 +1,151 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = SelectCnnl; + + K::SelectCnnl(decltype(dataType) dataType_, + decltype(selectType) selectType_, + decltype(inputDims) inputDims_, + decltype(outputDims) outputDims_, + decltype(inputsNum) inputsNum_) noexcept + : dataType(dataType_), + selectType(selectType_), + inputDims(std::move(inputDims_)), + outputDims(std::move(outputDims_)), + inputsNum(inputsNum_) {} + + auto K::build(SelectType selectType_, TensorRefs inputs_) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + auto dt = inputs_[0].get().dataType; + std::vector<std::vector<int>> inputDims_, outputDims_; + for (size_t i = 0; i < inputs_.size(); i++) { + auto shape = std::vector<int>(inputs_[i].get().shape.begin(), + inputs_[i].get().shape.end()); + if (shape.size() == 0) { + shape.push_back(1); + } + inputDims_.push_back(shape); + } + + auto broadcastShape = [](const std::vector<int> &shape1, const std::vector<int> &shape2) -> std::vector<int> { + int max_dim = std::max(shape1.size(), shape2.size()); + + std::vector<int> resultShape(max_dim, 1); + int dim_diff1 = max_dim - shape1.size(); + int dim_diff2 = max_dim - shape2.size(); + + for (int i = 0; i < max_dim; ++i) { + int dim_size1 = (i >= dim_diff1) ? shape1[i - dim_diff1] : 1; + int dim_size2 = (i >= dim_diff2) ? shape2[i - dim_diff2] : 1; + resultShape[i] = std::max(dim_size1, dim_size2); + } + + return resultShape; + }; + + for (size_t i = 1; i < inputs_.size(); i++) { + outputDims_.push_back(broadcastShape(inputDims_[i - 1], inputDims_[i])); + } + + return std::make_unique<K>(dt, selectType_, inputDims_, outputDims_, inputs_.size()); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing select operation using CNNL"; + } + +#ifdef USE_BANG + auto K::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + std::vector<cnnlTensorDescriptor_t> in, out; + + explicit Descriptors(int n) + : in(std::vector<cnnlTensorDescriptor_t>(n, nullptr)), + out(std::vector<cnnlTensorDescriptor_t>(n - 1, nullptr)) { + for (auto i = 0; i < n; i++) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&in[i])); + if (i != n - 1) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&out[i])); + } + } + } + ~Descriptors() noexcept(false) { + for (size_t i = 0; i < in.size(); i++) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(in[i])); + if (i != in.size() - 1) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(out[i])); + } + } + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(inputsNum); + for (size_t i = 0; i < inputsNum; i++) { + setCnnlTensor(d->in[i], dataType, slice(inputDims[i].data(), inputDims[i].size())); + if (i != inputsNum - 1) { + setCnnlTensor(d->out[i], dataType, slice(outputDims[i].data(), outputDims[i].size())); + } + } + + auto 
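broadcastShape above right-aligns the two shapes and takes the per-axis maximum, which matches numpy-style broadcasting of size-1 axes; note it does not reject incompatible pairs (e.g. extents 3 vs 4), so inputs are assumed to be pre-validated upstream. A worked example:

    // {3, 1, 5} vs {4, 5}:
    //   right-aligned:  {3, 1, 5}
    //                   {   4, 5}
    //   result:         {3, 4, 5}
    auto r = broadcastShape({3, 1, 5}, {4, 5});// r == {3, 4, 5}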
handle = res.fetchOrStore<CnnlContext>()->handle; + size_t workspaceSize; + switch (selectType) { + case SelectType::Max: + CNNL_ASSERT(cnnlGetMaximumWorkspaceSize(handle, d->out.back(), &workspaceSize)); + break; + case SelectType::Min: + CNNL_ASSERT(cnnlGetMinimumWorkspaceSize(handle, d->out.back(), &workspaceSize)); + break; + default: + UNREACHABLE(); + } + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), type = selectType, workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + auto select = + (type == SelectType::Max) ? cnnlMaximum + : (type == SelectType::Min) ? cnnlMinimum + : nullptr; + + for (size_t i = 1; i < d->in.size(); i++) { + if (i == 1) { + CNNL_ASSERT(select( + handle, d->in[0], inputs[0], d->in[1], inputs[1], + d->out[0], outputs[0], workspace, workspaceSize)); + } else { + CNNL_ASSERT(select( + handle, d->out[i - 2], outputs[0], d->in[i], inputs[i], + d->out[i - 1], outputs[0], workspace, workspaceSize)); + } + } + }; + + return {std::move(routine), workspaceSize}; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/select/cnnl_kernel.hh b/src/04kernel/src/kernels/select/cnnl_kernel.hh new file mode 100644 index 000000000..680911f36 --- /dev/null +++ b/src/04kernel/src/kernels/select/cnnl_kernel.hh @@ -0,0 +1,33 @@ +#ifndef KERNEL_SELECT_CNNL_KERNEL_HH +#define KERNEL_SELECT_CNNL_KERNEL_HH + +#include "kernel/attributes/broadcaster.h" +#include "kernel/collectors/select.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct SelectCnnl final : public Kernel { + DataType dataType; + SelectType selectType; + std::vector<std::vector<int>> inputDims; + std::vector<std::vector<int>> outputDims; + size_t inputsNum; + + SelectCnnl(decltype(dataType), decltype(selectType), decltype(inputDims), + decltype(outputDims), decltype(inputsNum)) noexcept; + + static KernelBox build(SelectType, TensorRefs) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SELECT_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/simple_binary/binary_cnnl.cc b/src/04kernel/src/kernels/simple_binary/binary_cnnl.cc new file mode 100644 index 000000000..d62ba0ab8 --- /dev/null +++ b/src/04kernel/src/kernels/simple_binary/binary_cnnl.cc @@ -0,0 +1,186 @@ +#include "binary_cnnl.hh" +#include <unordered_set> + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = BinaryCnnl; + using Op = SimpleBinaryType; + using DT = DataType; + + K::BinaryCnnl(Op opType_, DT dataType_, std::vector<int> aDims_, std::vector<int> bDims_, std::vector<int> cDims_) noexcept + : Kernel(), dataType(dataType_), opType(opType_), aDims(aDims_), bDims(bDims_), cDims(cDims_) {} + + auto K::build(Op op, Tensor const &a, Tensor const &b, Tensor const &c) noexcept -> KernelBox { + static const std::unordered_set<Op> + ARTHIMETIC{Op::Add, Op::Sub, Op::Mul, Op::Div, Op::And, Op::Or, Op::Xor, Op::Pow, Op::Mod, Op::Fmod}; + +#ifndef USE_BANG + return nullptr; +#endif + + if (a.dataType != b.dataType || + // !a.dataType.isFloat() || + !ARTHIMETIC.contains(op) || + 
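With more than two inputs, the select routine above folds pairwise through the output buffer: out = op(in0, in1), then out = op(out, in_i) for each remaining input, reusing outputs[0] as the accumulator on every step. The same left fold on scalars, as a host-side sketch with names of our own:

    #include <algorithm>
    #include <numeric>
    #include <vector>

    float maxOfAll(std::vector<float> const &v) {
        // requires a non-empty v; out = max(v0, v1); out = max(out, v2); ...
        return std::accumulate(v.begin() + 1, v.end(), v.front(),
                               [](float a, float b) { return std::max(a, b); });
    }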
// At least one of a,b should have the same shape as c + (a.shape != c.shape && b.shape != c.shape)) { + return nullptr; + } + + auto shape2IntVec = [](Shape shape) -> std::vector<int> { + std::vector<int> intVector; + intVector.reserve(shape.size()); + for (const uint32_t &element : shape) { + intVector.push_back(static_cast<int>(element)); + } + return intVector; + }; + + return std::make_unique<K>(op, a.dataType, shape2IntVec(a.shape), shape2IntVec(b.shape), shape2IntVec(c.shape)); + } + + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing element-wise op of 2 tensors with CNNL"; + } + +#ifdef USE_BANG + + auto BinaryCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + struct Descriptors { + cnnlOpTensorDescriptor_t opDesc; + cnnlTensorDescriptor_t aDesc, bDesc, cDesc; + bool f32, sub; + + Descriptors(decltype(f32) f32_) : f32(f32_), sub(false) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&aDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&bDesc)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&cDesc)); + CNNL_ASSERT(cnnlCreateOpTensorDescriptor(&opDesc)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(aDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(bDesc)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(cDesc)); + CNNL_ASSERT(cnnlDestroyOpTensorDescriptor(opDesc)); + } + }; + auto d = std::make_shared<Descriptors>(dataType != DT::F64); + // both enums initialized up front; useOpTensor marks the Add/Sub/Mul path explicitly, + // since the enum value itself is not a reliable flag (CNNL_OP_TENSOR_ADD may be 0) + cnnlOpTensorDesc_t cnnlOP = CNNL_OP_TENSOR_ADD; + cnnlLogicOp_t cnnlLogicOP = CNNL_LOGIC_OP_AND; + bool useOpTensor = false; + if (opType == SimpleBinaryType::Add) { + cnnlOP = CNNL_OP_TENSOR_ADD; + useOpTensor = true; + } else if (opType == SimpleBinaryType::Sub) { + cnnlOP = CNNL_OP_TENSOR_ADD; + d->sub = true; + useOpTensor = true; + } else if (opType == SimpleBinaryType::Mul) { + cnnlOP = CNNL_OP_TENSOR_MUL; + useOpTensor = true; + } else if (opType == SimpleBinaryType::And) { + cnnlLogicOP = CNNL_LOGIC_OP_AND; + } else if (opType == SimpleBinaryType::Or) { + cnnlLogicOP = CNNL_LOGIC_OP_OR; + } else if (opType == SimpleBinaryType::Xor) { + cnnlLogicOP = CNNL_LOGIC_OP_XOR; + } + + setCnnlTensor(d->aDesc, dataType, slice(aDims.data(), aDims.size())); + setCnnlTensor(d->bDesc, dataType, slice(bDims.data(), bDims.size())); + setCnnlTensor(d->cDesc, dataType, slice(cDims.data(), cDims.size())); + if (useOpTensor) { + CNNL_ASSERT(cnnlSetOpTensorDescriptor( + d->opDesc, cnnlOP, + cnnlDataTypeConvert(dataType), + CNNL_NOT_PROPAGATE_NAN)); + } + + auto cnnlGetBinaryWorkspaceSize = + (opType == SimpleBinaryType::Add || opType == SimpleBinaryType::Sub || opType == SimpleBinaryType::Mul) ? cnnlGetOpTensorWorkspaceSize + : (opType == SimpleBinaryType::Div) ? cnnlGetDivWorkspaceSize + : (opType == SimpleBinaryType::And || opType == SimpleBinaryType::Or || opType == SimpleBinaryType::Xor) ? cnnlGetLogicOpWorkspaceSize + : (opType == SimpleBinaryType::Pow) ? cnnlGetPowWorkspaceSize + : (opType == SimpleBinaryType::Mod || opType == SimpleBinaryType::Fmod) ?
cnnlGetFloorModWorkspaceSize + : nullptr; + + if (cnnlGetBinaryWorkspaceSize == nullptr) { + UNREACHABLE(); + } + + auto handle = res.fetchOrStore<CnnlContext>()->handle; + size_t workspaceSize; + CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc, + d->bDesc, d->cDesc, + &workspaceSize)); + + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), + workspaceSize, cnnlLogicOP, + op = this->opType](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + auto handle = res.fetchOrStore<CnnlContext>()->handle; + // name inputs and outputs + auto a = inputs[0], + b = inputs[1]; + auto c = outputs[0]; + if (op == SimpleBinaryType::Add || op == SimpleBinaryType::Sub || op == SimpleBinaryType::Mul) { + auto alphaA = d->f32 + ? factor<fp32_t>(1) + : factor<fp64_t>(1), + alphaB = d->f32 + ? factor<fp32_t>(d->sub ? -1 : 1) + : factor<fp64_t>(d->sub ? -1 : 1), + beta = d->f32 + ? factor<fp32_t>(0) + : factor<fp64_t>(0); + CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc, + &alphaA, d->aDesc, a, + &alphaB, d->bDesc, b, + workspace, workspaceSize, + &beta, d->cDesc, c)); + } else if (op == SimpleBinaryType::Div) { + CNNL_ASSERT(cnnlDiv_v2(handle, + CNNL_COMPUTATION_HIGH_PRECISION, + d->aDesc, a, + d->bDesc, b, + workspace, workspaceSize, + d->cDesc, c)); + } else if (op == SimpleBinaryType::And || op == SimpleBinaryType::Or || op == SimpleBinaryType::Xor) { + CNNL_ASSERT(cnnlLogicOp(handle, cnnlLogicOP, + d->aDesc, a, + d->bDesc, b, + workspace, workspaceSize, + d->cDesc, c)); + } else if (op == SimpleBinaryType::Pow) { + CNNL_ASSERT(cnnlPow(handle, + CNNL_COMPUTATION_HIGH_PRECISION, + d->aDesc, a, + d->bDesc, b, + workspace, workspaceSize, + d->cDesc, c)); + } else if (op == SimpleBinaryType::Mod || op == SimpleBinaryType::Fmod) { + CNNL_ASSERT(cnnlFloorMod(handle, + d->aDesc, a, + d->bDesc, b, + d->cDesc, c, + workspace, workspaceSize)); + } + }; + + return {std::move(routine), workspaceSize}; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/simple_binary/binary_cnnl.hh b/src/04kernel/src/kernels/simple_binary/binary_cnnl.hh new file mode 100644 index 000000000..2d5c7cfaa --- /dev/null +++ b/src/04kernel/src/kernels/simple_binary/binary_cnnl.hh @@ -0,0 +1,28 @@ +#ifndef KERNEL_BINARY_CNNL_HH +#define KERNEL_BINARY_CNNL_HH + +#include "kernel/collectors/simple_binary.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct BinaryCnnl final : public Kernel { + DataType dataType; + SimpleBinaryType opType; + std::vector<int> aDims, bDims, cDims; + + BinaryCnnl(SimpleBinaryType, DataType, std::vector<int> aDims_, std::vector<int> bDims_, std::vector<int> cDims_) noexcept; + + static KernelBox build(SimpleBinaryType, Tensor const &, Tensor const &, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_BINARY_CNNL_HH diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.cc b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.cc new file mode 100644 index 000000000..68670662c --- /dev/null +++ b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.cc @@ -0,0 +1,98 @@ +#include "cnnl_activation_kernel.hh" +#include "kernel/collectors/simple_unary.h" +#include <unordered_set> + +#ifdef USE_BANG +#include 
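Sub never reaches a dedicated CNNL op above; it rides on cnnlOpTensor's ADD with the second operand's scale flipped. cnnlOpTensor evaluates the affine form below, so alphaA = 1, alphaB = -1, beta = 0 yields exactly a - b. Scalar sketch of the identity (helper name ours):

    // cnnlOpTensor's affine combination, per element
    float opTensorAddRef(float a, float b, float c, float alphaA, float alphaB, float beta) {
        return alphaA * a + alphaB * b + beta * c;
    }
    // subtraction: opTensorAddRef(a, b, /*c=*/0.f, 1.f, -1.f, 0.f) == a - b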
"../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = ActivationCnnl; + using DT = DataType; + using Op = SimpleUnaryType; + + K::ActivationCnnl(Op type_, DT dataType_, int size_) noexcept + : Kernel(), type(type_), dataType(dataType_), size(size_) {} + + auto K::build(Op op, Tensor const &a) noexcept -> KernelBox { + static const std::unordered_set<Op> ARTHIMETIC{Op::Sigmoid, Op::Relu, Op::Tanh, Op::HardSwish}; + +#ifndef USE_BANG + return nullptr; +#endif + + return ARTHIMETIC.contains(op) + ? std::make_unique<K>(op, a.dataType, static_cast<int>(a.elementsSize())) + : nullptr; + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing activation using CNNL"; + } + +#ifdef USE_BANG + + auto ActivationCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using Ty = SimpleUnaryType; + + // RAII for closure + struct Descriptors { + cnnlActivationDescriptor_t activation; + cnnlTensorDescriptor_t tensor; + + Descriptors() : activation(nullptr), tensor(nullptr) { + CNNL_ASSERT(cnnlCreateActivationDescriptor(&activation)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&tensor)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyActivationDescriptor(activation)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(tensor)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(); + + // clang-format off + auto mode = type == Ty::Relu ? CNNL_ACTIVATION_RELU + : type == Ty::Sigmoid ? CNNL_ACTIVATION_SIGMOID + : type == Ty::Tanh ? CNNL_ACTIVATION_TANH + : type == Ty::HardSwish ? 
CNNL_ACTIVATION_HARDSWISH + : UNREACHABLEX(cnnlActivationMode_t, ""); + float coef = 0.0; + float slicedDim = 0.0; + float gamma = 0.0; + float scale = 0.0; + // clang-format on + + setCnnlTensor(d->tensor, dataType, slice(&size, 1)); + CNNL_ASSERT(cnnlSetActivationDescriptor_v5(d->activation, mode, + CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, coef, + slicedDim, gamma, scale, true)); + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d)]// + (Resources & res, void *, void const *const *inputs, void *const *outputs) { + float alpha = 1.f, beta = 0.f; + CNNL_ASSERT(cnnlActivationForward( + res.fetchOrStore<CnnlContext>()->handle, + d->activation, + &alpha, d->tensor, inputs[0], + &beta, d->tensor, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.hh b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.hh new file mode 100644 index 000000000..a5d7ad65c --- /dev/null +++ b/src/04kernel/src/kernels/simple_unary/cnnl_activation_kernel.hh @@ -0,0 +1,27 @@ +#ifndef KERNEL_ACTIVATION_CNNL_KERNEL_HH +#define KERNEL_ACTIVATION_CNNL_KERNEL_HH + +#include "kernel/collectors/simple_unary.h" + +namespace refactor::kernel { + + struct ActivationCnnl final : public Kernel { + SimpleUnaryType type; + DataType dataType; + int size; + + ActivationCnnl(SimpleUnaryType, DataType, int) noexcept; + + static KernelBox build(SimpleUnaryType, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_ACTIVATION_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc new file mode 100644 index 000000000..f8c0d7d01 --- /dev/null +++ b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.cc @@ -0,0 +1,97 @@ +#include "cnnl_simple_unary_kernel.hh" +#include "kernel/collectors/simple_unary.h" +#include <unordered_set> + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = SimpleUnaryCnnl; + using DT = DataType; + using Op = SimpleUnaryType; + + K::SimpleUnaryCnnl(Op type_, DT dataType_, int size_) noexcept + : Kernel(), type(type_), dataType(dataType_), size(size_) {} + + auto K::build(Op op, Tensor const &a) noexcept -> KernelBox { + static const std::unordered_set<Op> supportedOp{Op::Abs, Op::Sqrt, Op::Neg, Op::Erf}; + +#ifndef USE_BANG + return nullptr; +#endif + + return supportedOp.contains(op) + ? 
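Of the four activations the kernel above accepts, Relu, Sigmoid, and Tanh have the obvious definitions; HardSwish is commonly defined as x * relu6(x + 3) / 6, which we assume is what CNNL_ACTIVATION_HARDSWISH implements. A host-side baseline under that assumption (the helper name is ours):

    #include <algorithm>

    float hardSwishRef(float x) {
        return x * std::clamp(x + 3.f, 0.f, 6.f) / 6.f;// x * relu6(x + 3) / 6
    }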
std::make_unique<K>(op, a.dataType, static_cast<int>(a.elementsSize())) + : nullptr; + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { return typeId(); } + auto K::description() const noexcept -> std::string_view { + return "Performing simple unary using CNNL"; + } + +#ifdef USE_BANG + + auto SimpleUnaryCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using Ty = SimpleUnaryType; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t tensor; + + Descriptors() : tensor(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&tensor)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(tensor)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(); + + setCnnlTensor(d->tensor, dataType, slice(&size, 1)); + + auto cnnlUnaryForward = [t = this->type](cnnlHandle_t handle, + const cnnlTensorDescriptor_t x_desc, + const void *x, + const cnnlTensorDescriptor_t y_desc, + void *y) -> cnnlStatus_t { + switch (t) { + case Ty::Abs: + return cnnlAbs(handle, x_desc, x, y_desc, y); + case Ty::Neg: + return cnnlNegTensor(handle, x_desc, x, y_desc, y); + case Ty::Sqrt: + return cnnlSqrt_v2(handle, CNNL_COMPUTATION_HIGH_PRECISION, x_desc, x, y_desc, y); + case Ty::Erf: + return cnnlErf_v2(handle, CNNL_COMPUTATION_HIGH_PRECISION, x_desc, x, y_desc, y); + default: + // fmt::println("{}", unaryName(t)); + UNREACHABLE(); + } + }; + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d), cnnlUnaryForward]// + (Resources & res, void *, void const *const *inputs, void *const *outputs) { + CNNL_ASSERT(cnnlUnaryForward( + res.fetchOrStore<CnnlContext>()->handle, + d->tensor, inputs[0], + d->tensor, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.hh b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.hh new file mode 100644 index 000000000..b69902f7b --- /dev/null +++ b/src/04kernel/src/kernels/simple_unary/cnnl_simple_unary_kernel.hh @@ -0,0 +1,27 @@ +#ifndef KERNEL_SIMPLE_UNARY_CNNL_KERNEL_HH +#define KERNEL_SIMPLE_UNARY_CNNL_KERNEL_HH + +#include "kernel/collectors/simple_unary.h" + +namespace refactor::kernel { + + struct SimpleUnaryCnnl final : public Kernel { + SimpleUnaryType type; + DataType dataType; + int size; + + SimpleUnaryCnnl(SimpleUnaryType, DataType, int) noexcept; + + static KernelBox build(SimpleUnaryType, Tensor const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SIMPLE_UNARY_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/slice/cnnl_kernel.cc b/src/04kernel/src/kernels/slice/cnnl_kernel.cc new file mode 100644 index 000000000..85bc90938 --- /dev/null +++ b/src/04kernel/src/kernels/slice/cnnl_kernel.cc @@ -0,0 +1,88 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = SliceCnnl; + + K::SliceCnnl(decltype(info) info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto 
K::build(DataType dt_, Dimensions dims_, Shape in_, Shape out_) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + return std::make_unique<K>(decltype(info){ + dt_, + dims_, + std::vector<int>(in_.begin(), in_.end()), + std::vector<int>(out_.begin(), out_.end()), + }); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing slice operation using CNNL"; + } + +#ifdef USE_BANG + auto SliceCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + struct Descriptors { + cnnlTensorDescriptor_t in, out; + bool f32; + + explicit Descriptors(decltype(f32) f32_) + : in(nullptr), out(nullptr), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&in)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&out)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(in)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(out)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(info.dataType != DT::F64); + // setCnnlTensor(d->in, info.dataType, slice(info.inDim.data(), info.inDim.size())); + // setCnnlTensor(d->out, info.dataType, slice(info.outDim.data(), info.outDim.size())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->in, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), info.inDim.size(), info.inDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->out, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), info.outDim.size(), info.outDim.data())); + std::vector<int> begin, end, stride; + for (size_t i = 0; i < info.dims.size(); i++) { + // [begin, end), end is not inclued + begin.push_back(info.dims[i].start); + auto sign = info.dims[i].step > 0 ? 
1 : -1; + end.push_back(info.dims[i].start + info.dims[i].step * (info.dims[i].length - 1) + sign); + stride.push_back(info.dims[i].step); + } + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d), begin, end, stride](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + CNNL_ASSERT(cnnlStridedSlice( + handle, d->in, inputs[0], + begin.data(), end.data(), stride.data(), + d->out, outputs[0])); + }; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/slice/cnnl_kernel.hh b/src/04kernel/src/kernels/slice/cnnl_kernel.hh new file mode 100644 index 000000000..ace79da1d --- /dev/null +++ b/src/04kernel/src/kernels/slice/cnnl_kernel.hh @@ -0,0 +1,32 @@ +#ifndef KERNEL_SLICE_CNNL_KERNEL_HH +#define KERNEL_SLICE_CNNL_KERNEL_HH + +#include "kernel/attributes/slice_info.h" +#include "kernel/collectors/slice.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct SliceCnnl final : public Kernel { + struct { + DataType dataType; + Dimensions dims; + std::vector<int> inDim, outDim; + } info; + + explicit SliceCnnl(decltype(info)) noexcept; + + static KernelBox build(DataType, Dimensions, Shape, Shape) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SLICE_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/softmax/cnnl_kernel.cc b/src/04kernel/src/kernels/softmax/cnnl_kernel.cc new file mode 100644 index 000000000..babaf33cc --- /dev/null +++ b/src/04kernel/src/kernels/softmax/cnnl_kernel.cc @@ -0,0 +1,88 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#endif + +namespace refactor::kernel { + using K = SoftmaxCnnl; + + K::SoftmaxCnnl(cnnl::SoftmaxAlgo algo_, DataType type_, + int pre_, int mid_, int post_) noexcept + : Kernel(), algo(algo_), dataType(type_), + pre(pre_), mid(mid_), post(post_) {} + + auto K::build(cnnl::SoftmaxAlgo algo, SoftmaxInfo info) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + + return std::make_unique<K>(algo, info.type, info.pre, info.mid, info.post); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing softmax forward with CNNL"; + } + +#ifdef USE_BANG + + auto SoftmaxCnnl::lower(Resources &res) const -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + + // RAII for closure + struct Descriptors { + cnnlTensorDescriptor_t t; + cnnlSoftmaxAlgorithm_t algo; + bool f32; + + Descriptors(decltype(algo) algo_, decltype(f32) f32_) + : algo(algo_), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&t)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(t)); + } + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + + auto d = std::make_shared<Descriptors>( + static_cast<cnnlSoftmaxAlgorithm_t>(algo), + dataType != DataType::F64); + int dims[]{pre, mid, post}; + // cnnlSoftmaxMode_t mode = (pre == 1) ? 
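The strided-slice bounds in the slice kernel above turn each (start, step, length) triple into CNNL's half-open [begin, end) form, with the sign term nudging end one step past the last selected index in the direction of travel. Two worked cases:

    // start = 1, step = 2, length = 3 selects {1, 3, 5}:
    //   end = 1 + 2 * (3 - 1) + 1 = 6, i.e. the range [1, 6) with stride 2
    // start = 5, step = -2, length = 3 selects {5, 3, 1}:
    //   end = 5 + (-2) * (3 - 1) - 1 = 0, i.e. [5, 0) walking backwards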
CNNL_SOFTMAX_MODE_HIGH_DIMENSION + // : (post == 1) ? CNNL_SOFTMAX_MODE_LOW_DIMENSION + // : CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION; + // FIXME(bolun): CNNL Softmax mode + cnnlSoftmaxMode_t mode = CNNL_SOFTMAX_MODE_MEDIUM_DIMENSION; + + // cnnlSoftmaxForward_v2 is applied to a 3D input tensor only + CNNL_ASSERT(cnnlSetTensorDescriptor(d->t, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(dataType), 3, dims)); + + res.fetchOrStore<CnnlContext>(); + return [d = std::move(d), mode](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // build alpha/beta for double + auto a = d->f32 ? factor<fp32_t>(1) : factor<fp64_t>(1), + b = d->f32 ? factor<fp32_t>(0) : factor<fp64_t>(0); + CNNL_ASSERT(cnnlSoftmaxForward_v2( + res.fetchOrStore<CnnlContext>()->handle, + d->algo, + mode, + CNNL_COMPUTATION_ULTRAHIGH_PRECISION, + &a, d->t, inputs[0], + &b, d->t, outputs[0])); + }; + } + +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/softmax/cnnl_kernel.hh b/src/04kernel/src/kernels/softmax/cnnl_kernel.hh new file mode 100644 index 000000000..b9bedb5a4 --- /dev/null +++ b/src/04kernel/src/kernels/softmax/cnnl_kernel.hh @@ -0,0 +1,36 @@ +#ifndef KERNEL_SOFTMAX_CNNL_HH +#define KERNEL_SOFTMAX_CNNL_HH + +#include "kernel/attributes/softmax_info.h" +#include "kernel/collectors/softmax.h" + +namespace refactor::kernel { + + namespace cnnl { + enum class SoftmaxAlgo { + FAST = 0, + ACCURATE = 1, + LOG = 2, + }; + }// namespace cnnl + + struct SoftmaxCnnl final : public Kernel { + cnnl::SoftmaxAlgo algo; + DataType dataType; + int pre, mid, post; + + SoftmaxCnnl(cnnl::SoftmaxAlgo, DataType, int, int, int) noexcept; + + static KernelBox build(cnnl::SoftmaxAlgo, SoftmaxInfo) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SOFTMAX_CNNL_HH diff --git a/src/04kernel/src/kernels/split/cnnl_kernel.cc b/src/04kernel/src/kernels/split/cnnl_kernel.cc new file mode 100644 index 000000000..8f686d597 --- /dev/null +++ b/src/04kernel/src/kernels/split/cnnl_kernel.cc @@ -0,0 +1,114 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = SplitCnnl; + using Info = SplitInfoCnnl; + + Info::SplitInfoCnnl(DataType dt_, int axis_, int num_, std::vector<int> in_, std::vector<std::vector<int>> out_) + : dataType(dt_), axis(axis_), num(num_), inDim(std::move(in_)), outDims(std::move(out_)) {} + + + Info::SplitInfoCnnl(int axis, Tensor const &input, TensorRefs outputs) + : SplitInfoCnnl(input.dataType, axis, outputs.size(), + std::move(std::vector<int>(input.shape.begin(), input.shape.end())), + std::move([](TensorRefs tensors) -> std::vector<std::vector<int>> { + std::vector<std::vector<int>> res; + for (uint32_t i = 0; i < tensors.size(); i++) { + res.push_back(std::vector<int>(tensors[i].get().shape.begin(), + tensors[i].get().shape.end())); + } + return res; + }(outputs))) {} + + K::SplitCnnl(SplitInfoCnnl info_) noexcept + : Kernel(), info(std::move(info_)) {} + + auto K::build(int axis, Tensor const &input, TensorRefs outputs) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + return std::make_unique<K>(SplitInfoCnnl(axis, input, outputs)); + } + auto K::typeId() noexcept -> 
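Since cnnlSoftmaxForward_v2 only accepts a 3-D view, SoftmaxInfo factorizes the tensor around the softmax axis: pre is the product of the leading dims, mid the axis extent, post the product of the trailing dims. A sketch of that factorization (assuming this is how SoftmaxInfo derives the three values; the helper is ours):

    #include <array>
    #include <vector>

    // shape {2, 3, 4, 5}, axis = 1 -> {pre, mid, post} = {2, 3, 20}
    std::array<int, 3> softmaxView(std::vector<int> const &shape, int axis) {
        int pre = 1, post = 1;
        for (int i = 0; i < axis; ++i) pre *= shape[i];
        for (size_t i = axis + 1; i < shape.size(); ++i) post *= shape[i];
        return {pre, shape[axis], post};
    }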
size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing concat operation using CNNL"; + } + +#ifdef USE_BANG + auto SplitCnnl::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + struct Descriptors { + cnnlTensorDescriptor_t in; + std::vector<cnnlTensorDescriptor_t> out; + bool f32; + + explicit Descriptors(int n, decltype(f32) f32_) + : in(nullptr), + out(std::vector<cnnlTensorDescriptor_t>(n, nullptr)), + f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&in)); + for (auto i = 0; i < n; i++) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&out[i])); + } + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(in)); + for (size_t i = 0; i < out.size(); i++) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(out[i])); + } + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(info.num, info.dataType != DT::F64); + // setCnnlTensor(d->in, info.dataType, slice(info.inDim.data(), info.inDim.size())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->in, CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), info.inDim.size(), info.inDim.data())); + + for (size_t i = 0; i < info.outDims.size(); i++) { + // setCnnlTensor(d->out[i], info.dataType, slice(info.outDims[i].data(), info.outDims[i].size())); + CNNL_ASSERT(cnnlSetTensorDescriptor(d->out[i], CNNL_LAYOUT_NCHW, cnnlDataTypeConvert(info.dataType), info.outDims[i].size(), info.outDims[i].data())); + } + + auto handle = res.fetchOrStore<CnnlContext>()->handle; + size_t workspaceSize; + CNNL_ASSERT(cnnlGetSplitWorkspaceSize(handle, info.num, &workspaceSize)); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), n = info.num, axis = info.axis, workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + void *argv[n]; + for (auto i = 0; i < n; i++) { + argv[i] = outputs[i]; + } + + CNNL_ASSERT(cnnlSplit( + handle, n, axis, d->in, inputs[0], + workspace, workspaceSize, d->out.data(), argv)); + }; + + return {std::move(routine), workspaceSize}; + } + +#endif + + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/split/cnnl_kernel.hh b/src/04kernel/src/kernels/split/cnnl_kernel.hh new file mode 100644 index 000000000..7fb4147d4 --- /dev/null +++ b/src/04kernel/src/kernels/split/cnnl_kernel.hh @@ -0,0 +1,37 @@ +#ifndef KERNEL_SPLIT_CNNL_KERNEL_HH +#define KERNEL_SPLIT_CNNL_KERNEL_HH + +#include "kernel/collectors/split.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + struct SplitInfoCnnl { + DataType dataType; + int axis; + int num; + std::vector<int> inDim; + std::vector<std::vector<int>> outDims; + + SplitInfoCnnl(DataType, int, int, std::vector<int>, std::vector<std::vector<int>>); + SplitInfoCnnl(int, Tensor const &, TensorRefs); + }; + + struct SplitCnnl final : public Kernel { + SplitInfoCnnl info; + + explicit SplitCnnl(SplitInfoCnnl) noexcept; + + static KernelBox build(int, Tensor const &, TensorRefs) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + 
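One portability note on the split routine above: void *argv[n] with a runtime n is a variable-length array, a compiler extension rather than standard C++. A standard-conforming drop-in that keeps the same cnnlSplit call shape, using the names already in scope in that routine (sketch):

    std::vector<void *> argv(n);
    for (int i = 0; i < n; ++i) argv[i] = outputs[i];
    CNNL_ASSERT(cnnlSplit(
        handle, n, axis, d->in, inputs[0],
        workspace, workspaceSize, d->out.data(), argv.data()));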
RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_SPLIT_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/transpose/cnnl_kernel.cc b/src/04kernel/src/kernels/transpose/cnnl_kernel.cc new file mode 100644 index 000000000..58f2d4fd4 --- /dev/null +++ b/src/04kernel/src/kernels/transpose/cnnl_kernel.cc @@ -0,0 +1,104 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = TransposeCnnl; + using Info = TransposeInfoCnnl; + + Info::TransposeInfoCnnl(DataType dataType_, std::vector<int> input_, std::vector<int> perm_) + : dataType(dataType_), inDim(input_), perm(perm_) { + ASSERT(input_.size() == perm_.size(), "Unreachable"); + for (uint32_t i = 0; i < input_.size(); i++) { + outDim.push_back(input_[perm_[i]]); + } + } + + Info::TransposeInfoCnnl(DataType dataType, Shape shape, Permutation perm) + : TransposeInfoCnnl(dataType, + std::move(std::vector<int>(shape.begin(), shape.end())), + std::move(std::vector<int>(perm.begin(), perm.end()))) { } + + K::TransposeCnnl(Info info_) noexcept + : Kernel(), info(std::move(info_)) { } + + auto K::build(DataType dataType, Shape shape, Permutation perm) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + return std::make_unique<K>(TransposeInfoCnnl(dataType, shape, perm)); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing transpose operation using CNNL"; + } + +#ifdef USE_BANG + auto TransposeCnnl::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + struct Descriptors { + cnnlTensorDescriptor_t x, y; + cnnlTransposeDescriptor_t trans; + bool f32; + + explicit Descriptors(decltype(f32) f32_) + : x(nullptr), y(nullptr), trans(nullptr), f32(f32_) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreateTransposeDescriptor(&trans)); + } + ~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyTransposeDescriptor(trans)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + + auto d = std::make_shared<Descriptors>(info.dataType != DT::F64); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->x, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.inDim.size(), info.inDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->y, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.outDim.size(), info.outDim.data())); + CNNL_ASSERT(cnnlSetTransposeDescriptor(d->trans, info.perm.size(), info.perm.data())); + + auto handle = res.fetchOrStore<CnnlContext>()->handle; + size_t workspaceSize; + CNNL_ASSERT(cnnlGetTransposeWorkspaceSize(handle, d->x, d->trans, &workspaceSize)); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + // fetch cnnl handle from resources + auto handle = res.fetchOrStore<CnnlContext>()->handle; + + // name inputs and outputs + auto x = inputs[0]; + auto y = 
outputs[0]; + + CNNL_ASSERT(cnnlTranspose_v2(handle, d->trans, d->x, x, + d->y, y, workspace, workspaceSize)); + }; + + return {std::move(routine), workspaceSize}; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/transpose/cnnl_kernel.hh b/src/04kernel/src/kernels/transpose/cnnl_kernel.hh new file mode 100644 index 000000000..62254dc4e --- /dev/null +++ b/src/04kernel/src/kernels/transpose/cnnl_kernel.hh @@ -0,0 +1,37 @@ +#ifndef KERNEL_TRANSPOSE_CNNL_KERNEL_HH +#define KERNEL_TRANSPOSE_CNNL_KERNEL_HH + +#include "kernel/collectors/transpose.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + using Shape = absl::InlinedVector<dim_t, 4>; + using Permutation = Shape; + + struct TransposeInfoCnnl { + DataType dataType; + std::vector<int> inDim, outDim, perm; + + TransposeInfoCnnl(DataType, std::vector<int>, std::vector<int>); + TransposeInfoCnnl(DataType, Shape, Permutation); + }; + + struct TransposeCnnl final : public Kernel { + TransposeInfoCnnl info; + + TransposeCnnl(TransposeInfoCnnl) noexcept; + + static KernelBox build(DataType, Shape, Permutation) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_TRANSPOSE_CNNL_KERNEL_HH diff --git a/src/04kernel/src/kernels/where/cnnl_kernel.cc b/src/04kernel/src/kernels/where/cnnl_kernel.cc new file mode 100644 index 000000000..14f8d6676 --- /dev/null +++ b/src/04kernel/src/kernels/where/cnnl_kernel.cc @@ -0,0 +1,111 @@ +#include "cnnl_kernel.hh" + +#ifdef USE_BANG +#include "../../utilities/bang/cnnl_context.hh" +#include "../../utilities/bang/cnnl_functions.h" +#include <cnnl.h> +#endif + +namespace refactor::kernel { + using K = WhereCnnl; + + K::WhereCnnl(decltype(info) info_) noexcept + : Kernel(), info(info_) {} + + auto K::build(TensorRefs const &inputs, TensorRefs const &outputs) noexcept -> KernelBox { +#ifndef USE_BANG + return nullptr; +#endif + std::vector<int> cDim(inputs[0].get().shape.begin(), inputs[0].get().shape.end()), + xDim(inputs[1].get().shape.begin(), inputs[1].get().shape.end()), + yDim(inputs[2].get().shape.begin(), inputs[2].get().shape.end()), + ansDim(outputs[0].get().shape.begin(), outputs[0].get().shape.end()); + if (ansDim.size() == 0) { + ansDim.push_back(1); + } + if (xDim.size() == 0) { + xDim.push_back(1); + } + if (yDim.size() == 0) { + yDim.push_back(1); + } + if (cDim.size() == 0) { + cDim.push_back(1); + } + return std::make_unique<K>(decltype(info){ + inputs[1].get().dataType, cDim, xDim, yDim, ansDim}); + } + auto K::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + + auto K::kernelTypeId() const noexcept -> size_t { + return typeId(); + } + auto K::description() const noexcept -> std::string_view { + return "Performing where operation using CNNL"; + } + +#ifdef USE_BANG + auto WhereCnnl::lower(Resources &res) const noexcept -> RoutineWorkspace { + using namespace cnnl; + using namespace runtime; + using DT = DataType; + + struct Descriptors { + cnnlTensorDescriptor_t cond, x, y, ans; + + explicit Descriptors() + : cond(nullptr), x(nullptr), y(nullptr), + ans(nullptr) { + CNNL_ASSERT(cnnlCreateTensorDescriptor(&cond)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&x)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&y)); + CNNL_ASSERT(cnnlCreateTensorDescriptor(&ans)); + } + 
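// noexcept(false) mirrors the throwing RAII wrappers used for the cuDNN descriptors; the CNNL_ASSERT defined in cnnl_functions.h logs and aborts instead +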
~Descriptors() noexcept(false) { + CNNL_ASSERT(cnnlDestroyTensorDescriptor(cond)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(x)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(y)); + CNNL_ASSERT(cnnlDestroyTensorDescriptor(ans)); + } + + Descriptors(const Descriptors &) = delete; + Descriptors(Descriptors &&) = delete; + }; + auto d = std::make_shared<Descriptors>(); + + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->cond, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(DT::Bool), + info.condDim.size(), info.condDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->x, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.thenDim.size(), info.thenDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->y, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.elseDim.size(), info.elseDim.data())); + CNNL_ASSERT(cnnlSetTensorDescriptor( + d->ans, CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert(info.dataType), + info.outputDim.size(), info.outputDim.data())); + + auto handle = res.fetchOrStore<CnnlContext>()->handle; + size_t workspaceSize; + CNNL_ASSERT(cnnlGetSelectV2WorkspaceSize(handle, d->cond, d->x, d->y, &workspaceSize)); + + res.fetchOrStore<CnnlContext>(); + auto routine = [d = std::move(d), workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) { + + CNNL_ASSERT(cnnlSelectV2( + res.fetchOrStore<CnnlContext>()->handle, + d->cond, inputs[0], d->x, inputs[1], + d->y, inputs[2], workspace, workspaceSize, + d->ans, outputs[0])); + + }; + + return {std::move(routine), workspaceSize}; + } +#endif + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/where/cnnl_kernel.hh b/src/04kernel/src/kernels/where/cnnl_kernel.hh new file mode 100644 index 000000000..ffe39a875 --- /dev/null +++ b/src/04kernel/src/kernels/where/cnnl_kernel.hh @@ -0,0 +1,30 @@ +#ifndef KERNEL_WHERE_CNNL_HH +#define KERNEL_WHERE_CNNL_HH + +#include "kernel/collectors/where.h" +#include "kernel/kernel.h" +#include "kernel/tensor.h" + +namespace refactor::kernel { + + struct WhereCnnl final : public Kernel { + struct { + DataType dataType; + std::vector<int> condDim, thenDim, elseDim, outputDim; + } info; + + WhereCnnl(decltype(info)) noexcept; + + static KernelBox build(TensorRefs const &, TensorRefs const &) noexcept; + static size_t typeId() noexcept; + + size_t kernelTypeId() const noexcept final; + std::string_view description() const noexcept final; +#ifdef USE_BANG + RoutineWorkspace lower(Resources &) const noexcept final; +#endif + }; + +}// namespace refactor::kernel + +#endif// KERNEL_WHERE_CNNL_HH diff --git a/src/04kernel/src/utilities/bang/cnnl_context.cc b/src/04kernel/src/utilities/bang/cnnl_context.cc new file mode 100644 index 000000000..f2ad33ab5 --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnnl_context.cc @@ -0,0 +1,44 @@ +#ifdef USE_BANG + +#include "cnnl_context.hh" +#include "cnnl_functions.h" + +namespace refactor::kernel::cnnl { + + CnnlContext::CnnlContext() : runtime::Resource() { + BANG_ASSERT(cnrtQueueCreate(&queue)); + CNNL_ASSERT(cnnlCreate(&handle)); + CNNL_ASSERT(cnnlSetQueue(handle, queue)); + } + CnnlContext::~CnnlContext() { + BANG_ASSERT(cnrtQueueDestroy(queue)); + CNNL_ASSERT(cnnlDestroy(handle)); + } + + auto CnnlContext::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast<size_t>(&ID); + } + auto CnnlContext::build() -> runtime::ResourceBox { + return std::make_unique<CnnlContext>(); + } + + auto CnnlContext::resourceTypeId() const noexcept -> size_t { + return typeId(); + } + auto 
CnnlContext::description() const noexcept -> std::string_view { + return "CnnlContext"; + } + + void CnnlContext::copyFromCPU(void *dst, const void *src, size_t size) { + BANG_ASSERT(cnrtMemcpy(dst, const_cast<void *>(src), size, + CNRT_MEM_TRANS_DIR_HOST2DEV)); + } + + void CnnlContext::queueSync() { + BANG_ASSERT(cnrtQueueSync(queue)); + } + +}// namespace refactor::kernel::cnnl + +#endif diff --git a/src/04kernel/src/utilities/bang/cnnl_context.hh b/src/04kernel/src/utilities/bang/cnnl_context.hh new file mode 100644 index 000000000..4743a0e4e --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnnl_context.hh @@ -0,0 +1,31 @@ +#ifndef KERNEL_CNNL_CONTEXT_HH +#define KERNEL_CNNL_CONTEXT_HH + +#include "runtime/resource.h" +#include <cnnl.h> +#include <cnrt.h> + +namespace refactor::kernel::cnnl { + + struct CnnlContext final : public runtime::Resource { + cnnlHandle_t handle; + cnrtQueue_t queue; + + CnnlContext(); + ~CnnlContext(); + CnnlContext(CnnlContext const &) noexcept = delete; + CnnlContext(CnnlContext &&) noexcept = delete; + + static size_t typeId() noexcept; + static runtime::ResourceBox build(); + + size_t resourceTypeId() const noexcept final; + std::string_view description() const noexcept final; + + void copyFromCPU(void *dst, const void *src, size_t size); + void queueSync(); + }; + +}// namespace refactor::kernel::cnnl + +#endif// KERNEL_CNNL_CONTEXT_HH diff --git a/src/04kernel/src/utilities/bang/cnnl_functions.cpp b/src/04kernel/src/utilities/bang/cnnl_functions.cpp new file mode 100644 index 000000000..8dfeb6457 --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnnl_functions.cpp @@ -0,0 +1,38 @@ +#ifdef USE_BANG + +#include "cnnl_functions.h" + +namespace refactor::kernel::cnnl { + + cnnlDataType_t cnnlDataTypeConvert(DataType dataType) { + // clang-format off + switch (dataType) { + case DataType::F32 : return CNNL_DTYPE_FLOAT; break; + case DataType::F64 : return CNNL_DTYPE_DOUBLE; break; + case DataType::FP16: return CNNL_DTYPE_HALF; break; + case DataType::I8 : return CNNL_DTYPE_INT8; break; + case DataType::I32 : return CNNL_DTYPE_INT32; break; + case DataType::U8 : return CNNL_DTYPE_UINT8; break; + case DataType::BF16: return CNNL_DTYPE_BFLOAT16; break; + case DataType::I64 : return CNNL_DTYPE_INT64; break; + case DataType::Bool: return CNNL_DTYPE_BOOL; break; + default: UNREACHABLE(); + } + // clang-format on + } + + void setCnnlTensor(cnnlTensorDescriptor_t t, DataType dt, slice_t<int> d) { + auto dt_ = cnnlDataTypeConvert(dt); + if (auto n = d.size(); n == 4) { + CNNL_ASSERT(cnnlSetTensorDescriptor(t, CNNL_LAYOUT_NCHW, dt_, d.size(), d.begin())); + } else if (n < 4) { + int d_[]{1, 1, 1, 1}; + std::copy_n(d.begin(), n, d_ + 4 - n); + CNNL_ASSERT(cnnlSetTensorDescriptor(t, CNNL_LAYOUT_NCHW, dt_, 4, std::move(d_))); + } else { + CNNL_ASSERT(cnnlSetTensorDescriptor(t, CNNL_LAYOUT_NCHW, dt_, d.size(), d.begin())); + } + } +}// namespace refactor::kernel::cnnl + +#endif diff --git a/src/04kernel/src/utilities/bang/cnnl_functions.h b/src/04kernel/src/utilities/bang/cnnl_functions.h new file mode 100644 index 000000000..4ba2f89d7 --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnnl_functions.h @@ -0,0 +1,40 @@ +#ifndef KERNEL_CNNL_FUNCTIONS_H +#define KERNEL_CNNL_FUNCTIONS_H + +#include "common.h" +#include <cnnl.h> + +#define BANG_ASSERT(STATUS) \ + if (auto status = (STATUS); status != CNRT_RET_SUCCESS) { \ + RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \ + cnrtGetErrorStr(status), (int) status)); \ + } + +#define 
CNNL_ASSERT(STATUS) \ + if (auto status = (STATUS); status != CNNL_STATUS_SUCCESS) { \ + fmt::println("cnnl failed on \"" #STATUS "\" with {}", \ + cnnlGetErrorString(status)); \ + abort(); \ + } + +namespace refactor::kernel::cnnl { + + cnnlDataType_t cnnlDataTypeConvert(DataType); + + // A helper function that set Cnnl tensor descriptor given tensor shape and type + void setCnnlTensor(cnnlTensorDescriptor_t, DataType, slice_t<int>); + + template<class T> + constexpr uint64_t factor(T x) noexcept { + static_assert(std::is_floating_point_v<T>); + static_assert(sizeof(T) <= sizeof(uint64_t)); + union { + T f; + uint64_t i; + } u{x}; + return u.i; + } + +}// namespace refactor::kernel::cnnl + +#endif// KERNEL_CNNL_FUNCTIONS_H diff --git a/src/04kernel/src/utilities/bang/cnrt_functions.cc b/src/04kernel/src/utilities/bang/cnrt_functions.cc new file mode 100644 index 000000000..26c1b975d --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnrt_functions.cc @@ -0,0 +1,27 @@ +#ifdef USE_BANG +#include "cnrt_functions.h" +#include "cnnl_functions.h" +#include <cnrt.h> +#include <cstdio> + +namespace refactor::kernel::bang { + + int currentDevice() { + int device; + BANG_ASSERT(cnrtGetDevice(&device)); + return device; + } + + void sync() { + BANG_ASSERT(cnrtSyncDevice()); + } + + void copyOut(void *dst, const void *src, size_t size) { + sync(); + BANG_ASSERT(cnrtMemcpy(dst, const_cast<void *>(src), size, + CNRT_MEM_TRANS_DIR_DEV2HOST)); + } + +}// namespace refactor::kernel::bang + +#endif diff --git a/src/04kernel/src/utilities/bang/cnrt_functions.h b/src/04kernel/src/utilities/bang/cnrt_functions.h new file mode 100644 index 000000000..3a05195ce --- /dev/null +++ b/src/04kernel/src/utilities/bang/cnrt_functions.h @@ -0,0 +1,16 @@ +#ifndef KERNEL_CNRT_FUNCTIONS_H +#define KERNEL_CNRT_FUNCTIONS_H + +#include "common.h" + +namespace refactor::kernel::bang { + + int currentDevice(); + + void sync(); + + void copyOut(void *dst, const void *src, size_t size); + +}// namespace refactor::kernel::bang + +#endif// KERNEL_CNRT_FUNCTIONS_H diff --git a/src/04kernel/test/kernels/batch_normalization/test_cnnl.cpp b/src/04kernel/test/kernels/batch_normalization/test_cnnl.cpp new file mode 100644 index 000000000..d1ad9bd78 --- /dev/null +++ b/src/04kernel/test/kernels/batch_normalization/test_cnnl.cpp @@ -0,0 +1,72 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/batch_normalization/cnnl_kernel.hh" +#include "../../../src/kernels/batch_normalization/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, BatchNormalizationCnnl) { + // build routine + auto xTensor = Tensor::share(DataType::F32, Shape{1, 2, 3, 2}); + auto outTensor = Tensor::share(DataType::F32, Shape{1, 2, 3, 2}); + auto scaleTensor = Tensor::share(DataType::F32, Shape{2}); + auto biasTensor = Tensor::share(DataType::F32, Shape{2}); + auto meanTensor = Tensor::share(DataType::F32, Shape{2}); + auto varTensor = Tensor::share(DataType::F32, Shape{2}); + float epsilon = 0.00001; + TensorRefs inputs = TensorRefs{*xTensor, *scaleTensor, *biasTensor, *meanTensor, *varTensor}; + auto kCpu = BatchNormalization::build(epsilon, inputs); + auto kCnnl = BatchNormalizationCnnl::build(epsilon, inputs); + ASSERT_TRUE(kCpu && kCnnl); + auto res = runtime::Resources(); + auto rCpu = kCpu->lower(res).routine; + auto [rMlu, workspaceSize] = kCnnl->lower(res); + // malloc + auto &dev = 
*device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + mluIn = dev.malloc(xTensor->bytesSize()), + mluScale = dev.malloc(scaleTensor->bytesSize()), + mluBias = dev.malloc(biasTensor->bytesSize()), + mluMean = dev.malloc(meanTensor->bytesSize()), + mluVar = dev.malloc(varTensor->bytesSize()), + mluOut = dev.malloc(outTensor->bytesSize()); + // put input data + std::vector<float> + data(xTensor->elementsSize(), 1.0f), + scale(scaleTensor->elementsSize(), 0.5f), + bias(biasTensor->elementsSize(), 1.0f), + mean(meanTensor->elementsSize(), 0.5f), + var(varTensor->elementsSize(), 1.0f), + cpuOut(outTensor->elementsSize()); + mluIn->copyFromHost(data.data(), xTensor->bytesSize()); + mluScale->copyFromHost(scale.data(), scaleTensor->bytesSize()); + mluBias->copyFromHost(bias.data(), biasTensor->bytesSize()); + mluMean->copyFromHost(mean.data(), meanTensor->bytesSize()); + mluVar->copyFromHost(var.data(), varTensor->bytesSize()); + // inference + { + void const *inputs[]{data.data(), scale.data(), bias.data(), mean.data(), var.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + { + void const *inputs[]{*mluIn, *mluScale, *mluBias, *mluMean, *mluVar}; + void *outputs[]{*mluOut}; + rMlu(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + // take output data + std::vector<float> result(outTensor->elementsSize()); + mluOut->copyToHost(result.data(), outTensor->bytesSize()); + // check + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } +} + +#endif diff --git a/src/04kernel/test/kernels/cast/test_cnnl.cpp b/src/04kernel/test/kernels/cast/test_cnnl.cpp new file mode 100644 index 000000000..94297357d --- /dev/null +++ b/src/04kernel/test/kernels/cast/test_cnnl.cpp @@ -0,0 +1,51 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/cast/cnnl_kernel.hh" +#include "../../../src/kernels/cast/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, CastCnnl) { + // build routine + auto x = Tensor::share(DataType::F32, Shape{2, 3, 4, 5}); + auto y = Tensor::share(DataType::I8, Shape{2, 3, 4, 5}); + auto kernel = CastCnnl::build(*x, *y), + kCpu = CastCpu::build(*x, *y); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto xMlu = dev.malloc(x->bytesSize()), + yMlu = dev.malloc(y->bytesSize()); + // put input data + std::vector<float> x_(x->elementsSize()); + std::vector<int8_t> y_(y->elementsSize()); + std::iota(x_.begin(), x_.end(), 0); + xMlu->copyFromHost(x_.data(), x->bytesSize()); + // inference + { + void const *inputs[]{*xMlu}; + void *outputs[]{*yMlu}; + routine(res, nullptr, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{x_.data()}; + void *outputs[]{y_.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + std::vector<int8_t> result(y->elementsSize()); + yMlu->copyToHost(result.data(), y->bytesSize()); + EXPECT_EQ(result, y_); +} + +#endif diff --git a/src/04kernel/test/kernels/clip/test_cnnl.cpp b/src/04kernel/test/kernels/clip/test_cnnl.cpp new file mode 100644 index 000000000..ff2e77290 --- /dev/null +++ b/src/04kernel/test/kernels/clip/test_cnnl.cpp @@ -0,0 +1,53 @@ +#ifdef USE_BANG + +#include 
"../../../src/kernels/clip/cnnl_kernel.hh" +#include "../../../src/kernels/clip/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, ClipCnnl) { + // build routine + auto data = Tensor::share(DataType::F32, Shape{2, 3, 4, 5}); + auto kernel = ClipCnnl::build(*data, true), + kCpu = ClipCpu::build(*data, true); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto mluMem = dev.malloc(data->bytesSize()), + mluMin = dev.malloc(sizeof(float)), + mluMax = dev.malloc(sizeof(float)); + // put input data + std::vector<float> value(data->elementsSize()); + float min = 30, max = 80; + std::iota(value.begin(), value.end(), 0); + mluMem->copyFromHost(value.data(), data->bytesSize()); + mluMin->copyFromHost(&min, sizeof(float)); + mluMax->copyFromHost(&max, sizeof(float)); + // inference + { + void const *inputs[]{*mluMem, *mluMin, *mluMax}; + void *outputs[]{*mluMem}; + routine(res, nullptr, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{value.data(), &min, &max}; + void *outputs[]{value.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + std::vector<float> result(data->elementsSize()); + mluMem->copyToHost(result.data(), data->bytesSize()); + EXPECT_EQ(result, value); +} + +#endif diff --git a/src/04kernel/test/kernels/concat/test_cnnl.cpp b/src/04kernel/test/kernels/concat/test_cnnl.cpp new file mode 100644 index 000000000..ecc817aca --- /dev/null +++ b/src/04kernel/test/kernels/concat/test_cnnl.cpp @@ -0,0 +1,81 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/concat/cnnl_kernel.hh" +#include "../../../src/kernels/concat/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, ConcatCnnl) { + // build routine + std::vector<Arc<Tensor>> inputTensors{ + Tensor::share(DataType::F32, Shape{2, 3, 1, 1, 7, 7}),// 勿 + Tensor::share(DataType::F32, Shape{2, 3, 1, 9, 7, 7}),// 忘 + Tensor::share(DataType::F32, Shape{2, 3, 1, 3, 7, 7}),// 国 + Tensor::share(DataType::F32, Shape{2, 3, 1, 7, 7, 7}),// 耻 + }; + auto result = Tensor::share(DataType::F32, Shape{2, 3, 1, 20, 7, 7}); + TensorRefs inputs_; + inputs_.reserve(inputTensors.size()); + std::transform(inputTensors.begin(), inputTensors.end(), + std::back_inserter(inputs_), + [](auto const &it) { return std::cref(*it); }); + SplitInfo info(3, inputs_); + auto kCpu = ConcatCpu::build(info); + auto kernel = ConcatCnnl::build(3, inputs_, *result); + ASSERT_TRUE(kCpu && kernel); + auto res = runtime::Resources(); + auto rCpu = kCpu->lower(res).routine; + auto [routine, workspaceSize] = kernel->lower(res); + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + Arc<Device::Blob> + workspace = dev.malloc(workspaceSize), + mluIns[]{ + dev.malloc(inputTensors[0]->bytesSize()), + dev.malloc(inputTensors[1]->bytesSize()), + dev.malloc(inputTensors[2]->bytesSize()), + dev.malloc(inputTensors[3]->bytesSize()), + }, + mluOut = dev.malloc(result->bytesSize()); + // put input data + std::vector<float> + cpuIns[]{ + std::vector<float>(inputTensors[0]->elementsSize()), + 
std::vector<float>(inputTensors[1]->elementsSize()), + std::vector<float>(inputTensors[2]->elementsSize()), + std::vector<float>(inputTensors[3]->elementsSize()), + }, + cpuOut(result->elementsSize()), + out(result->elementsSize()); + std::iota(cpuIns[0].begin(), cpuIns[0].end(), 0); + std::iota(cpuIns[1].begin(), cpuIns[1].end(), 0); + std::iota(cpuIns[2].begin(), cpuIns[2].end(), 0); + std::iota(cpuIns[3].begin(), cpuIns[3].end(), 0); + mluIns[0]->copyFromHost(cpuIns[0].data(), inputTensors[0]->bytesSize()); + mluIns[1]->copyFromHost(cpuIns[1].data(), inputTensors[1]->bytesSize()); + mluIns[2]->copyFromHost(cpuIns[2].data(), inputTensors[2]->bytesSize()); + mluIns[3]->copyFromHost(cpuIns[3].data(), inputTensors[3]->bytesSize()); + // inference + { + void const *inputs[]{*mluIns[0], *mluIns[1], *mluIns[2], *mluIns[3]}; + void *outputs[]{*mluOut}; + routine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{cpuIns[0].data(), cpuIns[1].data(), cpuIns[2].data(), cpuIns[3].data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + mluOut->copyToHost(out.data(), result->bytesSize()); + EXPECT_EQ(out, cpuOut); +} + +#endif diff --git a/src/04kernel/test/kernels/conv/test_cnnl.cpp b/src/04kernel/test/kernels/conv/test_cnnl.cpp new file mode 100644 index 000000000..74e799f5a --- /dev/null +++ b/src/04kernel/test/kernels/conv/test_cnnl.cpp @@ -0,0 +1,69 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/conv/cnnl_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +void testConvCnnl(int rank, const int64_t *pads, const int64_t *strides, const int64_t *dilations, + Shape xShape, Shape wShape, Shape yShape, + const std::vector<float> &ExpectData) { + auto xTensor = Tensor::share(DataType::F32, xShape); + auto wTensor = Tensor::share(DataType::F32, wShape); + auto yTensor = Tensor::share(DataType::F32, yShape); + PoolAttributes poolAttributes(rank, dilations, pads, strides); + auto kernel = ConvCnnl::build(poolAttributes, *xTensor, *wTensor, std::nullopt, *yTensor); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // bang malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + xMlu = dev.malloc(xTensor->bytesSize()), + wMlu = dev.malloc(wTensor->bytesSize()), + yMlu = dev.malloc(yTensor->bytesSize()); + // put input data + std::vector<int> xIncremental(xTensor->elementsSize()), + wIncremental(wTensor->elementsSize()); + std::iota(xIncremental.begin(), xIncremental.end(), 0); + std::iota(wIncremental.begin(), wIncremental.end(), 0); + std::vector<float> xData(xIncremental.begin(), xIncremental.end()), + wData(wIncremental.begin(), wIncremental.end()); + xMlu->copyFromHost(xData.data(), xTensor->bytesSize()); + wMlu->copyFromHost(wData.data(), wTensor->bytesSize()); + // inference + void const *inputs[]{*xMlu, *wMlu}; + void *outputs[]{*yMlu}; + routine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + + // take output data + std::vector<float> result(yTensor->elementsSize()); + yMlu->copyToHost(result.data(), yTensor->bytesSize()); + // check + for (auto i : range0_(ExpectData.size())) { + EXPECT_FLOAT_EQ(ExpectData[i], result[i]); + } +} + +TEST(kernel, ConvCnnl) { + int rank = 2; + int64_t const + pads[]{1, 1, 1, 1}, 
+ strides[]{1, 1}, + dilations[]{1, 1}; + Shape + xShape{1, 3, 3, 2}, + wShape{1, 3, 3, 2}, + yShape{1, 1, 3, 3}; + const std::vector<float> ExpectData = {570, 1158, 582, 888, 1785, 888, 582, 1158, 570}; + testConvCnnl(rank, pads, strides, dilations, xShape, wShape, yShape, ExpectData); +} + + +#endif diff --git a/src/04kernel/test/kernels/expand/test_cnnl.cpp b/src/04kernel/test/kernels/expand/test_cnnl.cpp new file mode 100644 index 000000000..43fb07e8d --- /dev/null +++ b/src/04kernel/test/kernels/expand/test_cnnl.cpp @@ -0,0 +1,52 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/expand/cnnl_kernel.hh" +#include "../../../src/kernels/expand/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, ExpandCnnl) { + // build routine + auto input = Tensor::share(DataType::F32, {3, 4, 1, 6}), + output = Tensor::share(DataType::F32, {2, 3, 4, 5, 6}); + auto kernel = ExpandCnnl::build(*input, *output); + auto kCpu = ExpandCpu::build(ExpandInfo(*input, *output)); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine; + auto rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto mluIn = dev.malloc(input->bytesSize()), + mluOut = dev.malloc(output->bytesSize()); + // put input data + std::vector<float> + data(input->elementsSize()), + ans(output->elementsSize()), + result(ans.size()); + std::iota(data.begin(), data.end(), 0); + mluIn->copyFromHost(data.data(), input->bytesSize()); + // inference + { + void const *inputs[]{*mluIn}; + void *outputs[]{*mluOut}; + routine(res, nullptr, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{ans.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + mluOut->copyToHost(result.data(), output->bytesSize()); + EXPECT_EQ(result, ans); +} + +#endif diff --git a/src/04kernel/test/kernels/gather/test_gather_cnnl.cpp b/src/04kernel/test/kernels/gather/test_gather_cnnl.cpp new file mode 100644 index 000000000..b63682d26 --- /dev/null +++ b/src/04kernel/test/kernels/gather/test_gather_cnnl.cpp @@ -0,0 +1,148 @@ +#ifdef USE_BANG + +#include "../src/kernels/gather/cnnl_kernel.hh" +#include "../src/kernels/gather/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, GatherCnnl) { + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + // Case axis = 0, indexType= int32 + { + // Create Tensor and build kernels + auto data = Tensor::share(DataType::F32, Shape{3, 2}, LayoutType::NCHW); + auto indices = Tensor::share(DataType::I32, Shape{2, 2}, LayoutType::NCHW); + auto output = Tensor::share(DataType::F32, Shape{2, 2, 2}, LayoutType::NCHW); + GatherInfo info(0, *data, *indices); + auto cnnlKernel = GatherCnnl::build(0, *data, *indices, *output); + auto cpuKernel = GatherCpu::build(info); + ASSERT_TRUE(cnnlKernel && cpuKernel); + auto res = runtime::Resources(); + auto [cnnlRoutine, workspaceSize] = cnnlKernel->lower(res); + auto cpuRoutine = cpuKernel->lower(res).routine; + // Init inputs and outputs + std::vector<float> a{1.0, 1.2, 2.3, 3.4, 4.5, 5.7}; + std::vector<int> b{0, 1, 1, 2}; + std::vector<float> c(output->elementsSize()); + auto 
workspace = dev.malloc(workspaceSize), + aMLU = dev.malloc(data->bytesSize()), + bMLU = dev.malloc(indices->bytesSize()), + cMLU = dev.malloc(output->bytesSize()); + aMLU->copyFromHost(a.data(), data->bytesSize()); + bMLU->copyFromHost(b.data(), indices->bytesSize()); + // Compute + { + void const *inputs[]{*aMLU, *bMLU}; + void *outputs[]{*cMLU}; + cnnlRoutine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{a.data(), b.data()}; + void *outputs[]{c.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + } + // Compare + std::vector<float> result(output->elementsSize()); + cMLU->copyToHost(result.data(), output->bytesSize()); + for (auto i : range0_(c.size())) { + EXPECT_FLOAT_EQ(c[i], result[i]); + } + } + + // Case axis = 1, indexType = int32 + { + // Create Tensor and build kernels + auto data = Tensor::share(DataType::F32, Shape{3, 3}, LayoutType::NCHW); + auto indices = Tensor::share(DataType::I32, Shape{1, 2}, LayoutType::NCHW); + auto output = Tensor::share(DataType::F32, Shape{3, 1, 2}, LayoutType::NCHW); + GatherInfo info(1, *data, *indices); + auto cnnlKernel = GatherCnnl::build(1, *data, *indices, *output); + auto cpuKernel = GatherCpu::build(info); + ASSERT_TRUE(cnnlKernel && cpuKernel); + auto res = runtime::Resources(); + auto [cnnlRoutine, workspaceSize] = cnnlKernel->lower(res); + auto cpuRoutine = cpuKernel->lower(res).routine; + // Init inputs and outputs + std::vector<float> a{1.0, 1.2, 1.9, 2.3, 3.4, 3.9, 4.5, 5.7, 5.9}; + std::vector<int> b{0, 2}; + std::vector<float> c(output->elementsSize()); + auto workspace = dev.malloc(workspaceSize), + aMLU = dev.malloc(data->bytesSize()), + bMLU = dev.malloc(indices->bytesSize()), + cMLU = dev.malloc(output->bytesSize()); + aMLU->copyFromHost(a.data(), data->bytesSize()); + bMLU->copyFromHost(b.data(), indices->bytesSize()); + // Compute + { + void const *inputs[]{*aMLU, *bMLU}; + void *outputs[]{*cMLU}; + cnnlRoutine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{a.data(), b.data()}; + void *outputs[]{c.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + } + // Compare + std::vector<float> result(output->elementsSize()); + cMLU->copyToHost(result.data(), output->bytesSize()); + for (auto i : range0_(c.size())) { + EXPECT_FLOAT_EQ(c[i], result[i]); + } + } + + // Case axis = 0, indexType = int64 + { + // Create Tensor and build kernels + auto data = Tensor::share(DataType::F32, Shape{32, 16}, LayoutType::NCHW); + auto indices = Tensor::share(DataType::I64, Shape{1, 4}, LayoutType::NCHW); + auto output = Tensor::share(DataType::F32, Shape{1, 4, 16}, LayoutType::NCHW); + GatherInfo info(0, *data, *indices); + auto cnnlKernel = GatherCnnl::build(0, *data, *indices, *output); + auto cpuKernel = GatherCpu::build(info); + ASSERT_TRUE(cnnlKernel && cpuKernel); + auto res = runtime::Resources(); + auto [cnnlRoutine, workspaceSize] = cnnlKernel->lower(res); + auto cpuRoutine = cpuKernel->lower(res).routine; + // Init inputs and outputs + std::vector<float> a; + for (size_t i = 0; i < data->elementsSize(); i++) { + a.push_back(i + 0.1f); + } + std::vector<int64_t> b(indices->elementsSize(), 0); + std::vector<float> c(output->elementsSize()); + auto workspace = dev.malloc(workspaceSize), + aMLU = dev.malloc(data->bytesSize()), + bMLU = dev.malloc(indices->bytesSize()), + cMLU = dev.malloc(output->bytesSize()); + aMLU->copyFromHost(a.data(), data->bytesSize()); + bMLU->copyFromHost(b.data(), indices->bytesSize()); + // Compute + { + void const 
*inputs[]{*aMLU, *bMLU}; + void *outputs[]{*cMLU}; + cnnlRoutine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{a.data(), b.data()}; + void *outputs[]{c.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + } + // Compare + std::vector<float> result(output->elementsSize()); + cMLU->copyToHost(result.data(), output->bytesSize()); + for (auto i : range0_(c.size())) { + EXPECT_FLOAT_EQ(c[i], result[i]); + } + } +} + +#endif diff --git a/src/04kernel/test/kernels/hard_sigmoid/test_cnnl.cpp b/src/04kernel/test/kernels/hard_sigmoid/test_cnnl.cpp new file mode 100644 index 000000000..ad26438bf --- /dev/null +++ b/src/04kernel/test/kernels/hard_sigmoid/test_cnnl.cpp @@ -0,0 +1,51 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/hard_sigmoid/cnnl_kernel.hh" +#include "../../../src/kernels/hard_sigmoid/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, HardSigmoidCnnl) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 5}); + float alpha = 0.2f, beta = 0.5f; + auto kernel = HardSigmoidCnnl::build(alpha, beta, *dataTensor); + auto kCpu = HardSigmoidCpu::build(alpha, beta, *dataTensor); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto mluMem = dev.malloc(dataTensor->bytesSize()); + // put input data + std::vector<float> data(dataTensor->elementsSize()); + for (auto i : range0_(data.size())) { data[i] = i; } + mluMem->copyFromHost(data.data(), dataTensor->bytesSize()); + // inference + { + void const *inputs[]{*mluMem}; + void *outputs[]{*mluMem}; + routine(res, nullptr, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{data.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // take output data + std::vector<float> result(dataTensor->elementsSize()); + mluMem->copyToHost(result.data(), dataTensor->bytesSize()); + // check + for (auto i : range0_(data.size())) { + EXPECT_FLOAT_EQ(data[i], result[i]); + } +} + +#endif diff --git a/src/04kernel/test/kernels/mat_mul/test_cnnl.cpp b/src/04kernel/test/kernels/mat_mul/test_cnnl.cpp new file mode 100644 index 000000000..f079b4445 --- /dev/null +++ b/src/04kernel/test/kernels/mat_mul/test_cnnl.cpp @@ -0,0 +1,210 @@ +#ifdef USE_BANG + +#include "../src/kernels/mat_mul/cnnl_kernel.hh" +#include "../src/kernels/mat_mul/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TensorRefs getRefs(std::vector<Arc<Tensor>> tensors) { + TensorRefs refs; + std::transform(tensors.begin(), tensors.end(), std::back_inserter(refs), + [](auto const &it) { return std::cref(*it); }); + return refs; +} + +TEST(kernel, MatMulCnnl_OnlyBias) { + // build routine + auto A = Tensor::share(DataType::F32, Shape{2, 2, 2}); + auto B = Tensor::share(DataType::F32, Shape{2, 2}); + auto C = Tensor::share(DataType::F32, Shape{}); + auto Y = Tensor::share(DataType::F32, Shape{2, 2, 2}); + bool tA = false, tB = false; + float alpha = 0.0, beta = 1.0; + MatMulInfo info(*A, *B, *C, tA, tB, alpha, beta); + auto kernel = MatMulCnnl::build(getRefs({A, B, C}), 
getRefs({Y}), tA, tB, alpha, beta); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + ma = dev.malloc(A->bytesSize()), + mb = dev.malloc(B->bytesSize()), + mc = dev.malloc(C->bytesSize()), + my = dev.malloc(Y->bytesSize()); + // put input data + std::vector<float> dataA{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + std::vector<float> dataB{0.0, 0.0, 0.0, 0.0}; + std::vector<float> dataC{2.5}; + std::vector<float> ans{2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5}; + ma->copyFromHost(dataA.data(), A->bytesSize()); + mb->copyFromHost(dataB.data(), B->bytesSize()); + mc->copyFromHost(dataC.data(), C->bytesSize()); + // inference + void const *inputs[]{*ma, *mb, *mc}; + void *outputs[]{*my}; + routine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + // take output data + std::vector<float> result(Y->elementsSize()); + my->copyToHost(result.data(), Y->bytesSize()); + // check + for (auto i : range0_(result.size())) { + EXPECT_FLOAT_EQ(result[i], ans[i]); + } +} + +TEST(kernel, MatMulCnnl_Broadcast) { + // build routine + auto A = Tensor::share(DataType::F32, Shape{2, 1, 2, 2}); + auto B = Tensor::share(DataType::F32, Shape{1, 2, 2, 2}); + auto C = Tensor::share(DataType::F32, Shape{2, 1}); + auto Y = Tensor::share(DataType::F32, Shape{2, 2, 2, 2}); + MatMulInfo info(*A, *B, *C, false, false, 1, 1); + auto cpuKernel = MatMulCPU::build(info); + auto mluKernel = MatMulCnnl::build(getRefs({A, B, C}), getRefs({Y}), false, false, 1.0, 1.0); + ASSERT_TRUE(cpuKernel && mluKernel); + auto res = runtime::Resources(); + auto cpuRoutine = cpuKernel->lower(res).routine; + auto [mluRoutine, workspaceSize] = mluKernel->lower(res); + // put input data + std::vector<float> dataA{1.0, 2.0, 0.0, 0.5, + 1.0, 0.0, 0.0, 1.0}; + std::vector<float> dataB{1.0, 2.0, 0.0, 0.5, + 1.0, 0.0, 0.0, 1.0}; + std::vector<float> dataC{1.0, 0.0}; + std::vector<float> cpuOut(Y->elementsSize()); + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + ma = dev.malloc(A->bytesSize()), + mb = dev.malloc(B->bytesSize()), + mc = dev.malloc(C->bytesSize()), + my = dev.malloc(Y->bytesSize()); + ma->copyFromHost(dataA.data(), A->bytesSize()); + mb->copyFromHost(dataB.data(), B->bytesSize()); + mc->copyFromHost(dataC.data(), C->bytesSize()); + // inference + { + void const *inputs[]{*ma, *mb, *mc}; + void *outputs[]{*my}; + mluRoutine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{dataA.data(), dataB.data(), dataC.data()}; + void *outputs[]{cpuOut.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + } + // take output data + std::vector<float> result(Y->elementsSize()); + my->copyToHost(result.data(), Y->bytesSize()); + // check + EXPECT_EQ(result, cpuOut); +} + +TEST(kernel, MatMulCnnl_TransABNoBias) { + // build routine + auto A = Tensor::share(DataType::F32, Shape{1, 3, 2, 2}); + auto B = Tensor::share(DataType::F32, Shape{2, 1, 2, 2}); + auto Y = Tensor::share(DataType::F32, Shape{2, 3, 2, 2}); + MatMulInfo info(*A, *B, {}, true, true, 2.0, 1); + auto cpuKernel = MatMulCPU::build(info); + auto mluKernel = MatMulCnnl::build(getRefs({A, B}), getRefs({Y}), true, true, 2.0, 1.0); + ASSERT_TRUE(cpuKernel && mluKernel); + auto res = runtime::Resources(); + auto cpuRoutine = cpuKernel->lower(res).routine; + auto [mluRoutine, workspaceSize] = mluKernel->lower(res); + // put input data 
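+ // A carries batch dims {1, 3} and B {2, 1}, broadcasting to {2, 3}; transA/transB transpose each trailing 2x2 block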
+ std::vector<float> dataA{1.0, 2.0, 0.0, 0.5, + 1.0, 0.0, 0.0, 1.0, + 1.0, 2.0, 3.0, 4.0}; + std::vector<float> dataB{1.0, 2.0, 0.0, 0.5, + 1.0, 0.0, 0.0, 1.0}; + std::vector<float> cpuOut(Y->elementsSize()); + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + ma = dev.malloc(A->bytesSize()), + mb = dev.malloc(B->bytesSize()), + my = dev.malloc(Y->bytesSize()); + ma->copyFromHost(dataA.data(), A->bytesSize()); + mb->copyFromHost(dataB.data(), B->bytesSize()); + // inference + { + void const *inputs[]{*ma, *mb}; + void *outputs[]{*my}; + mluRoutine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{dataA.data(), dataB.data()}; + void *outputs[]{cpuOut.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + } + // take output data + std::vector<float> result(Y->elementsSize()); + my->copyToHost(result.data(), Y->bytesSize()); + // check + EXPECT_EQ(result, cpuOut); +} + +TEST(kernel, MatMulCnnl_Large) { + // build routine + auto A = Tensor::share(DataType::F32, Shape{1, 512}); + auto B = Tensor::share(DataType::F32, Shape{1000, 512}); + auto C = Tensor::share(DataType::F32, Shape{1000}); + auto Y = Tensor::share(DataType::F32, Shape{1, 1000}); + MatMulInfo info(*A, *B, *C, false, true, 1, 1); + auto cpuKernel = MatMulCPU::build(info); + auto mluKernel = MatMulCnnl::build(getRefs({A, B, C}), getRefs({Y}), false, true, 1.0, 1.0); + ASSERT_TRUE(cpuKernel && mluKernel); + auto res = runtime::Resources(); + auto cpuRoutine = cpuKernel->lower(res).routine; + auto [mluRoutine, workspaceSize] = mluKernel->lower(res); + // put input data + std::vector<float> dataA(A->elementsSize()); + for (auto i : range0_(dataA.size())) { + dataA[i] = 1.0 * (i % 4) - 2.0; + } + std::vector<float> dataB(B->elementsSize()); + for (auto i : range0_(dataB.size())) { + dataB[i] = 1.0 * (i % 4) - 2.0; + } + std::vector<float> dataC(C->elementsSize()); + for (auto i : range0_(dataC.size())) { + dataC[i] = 1.0 * (i % 4) - 2.0; + } + std::vector<float> cpuOut(Y->elementsSize()); + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + ma = dev.malloc(A->bytesSize()), + mb = dev.malloc(B->bytesSize()), + mc = dev.malloc(C->bytesSize()), + my = dev.malloc(Y->bytesSize()); + ma->copyFromHost(dataA.data(), A->bytesSize()); + mb->copyFromHost(dataB.data(), B->bytesSize()); + mc->copyFromHost(dataC.data(), C->bytesSize()); + // inference + { + void const *inputs[]{*ma, *mb, *mc}; + void *outputs[]{*my}; + mluRoutine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{dataA.data(), dataB.data(), dataC.data()}; + void *outputs[]{cpuOut.data()}; + cpuRoutine(res, nullptr, inputs, outputs); + } + // take output data + std::vector<float> result(Y->elementsSize()); + my->copyToHost(result.data(), Y->bytesSize()); + // check + EXPECT_EQ(result, cpuOut); +} + +#endif diff --git a/src/04kernel/test/kernels/pad/test_cnnl.cpp b/src/04kernel/test/kernels/pad/test_cnnl.cpp new file mode 100644 index 000000000..9243be4c7 --- /dev/null +++ b/src/04kernel/test/kernels/pad/test_cnnl.cpp @@ -0,0 +1,131 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/pad/cpu_kernel.hh" +#include "../../../src/kernels/pad/cnnl_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, PadCnnl) { + { + PadDimension dims{ + {2, 
4, 1}, + {3, 5, 1}, + {1, 1, 0}, + {4, 8, 2}, + }; + // build routine + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8}); + PadType type = PadType::Constant; + auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + auto kernel = PadCnnl::build(dims, DataType::F32, type, std::make_optional(std::reference_wrapper(*t3Tensor))); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto mluIn = dev.malloc(t1Tensor->bytesSize()), + mluIn2 = dev.malloc(t2Tensor->bytesSize()), + mluIn3 = dev.malloc(t3Tensor->bytesSize()), + mluOut = dev.malloc(yTensor->bytesSize()); + // put input data + std::vector<float> data(t1Tensor->elementsSize()), + constvalue(1, 1.2f), + cpuOut(yTensor->elementsSize()); + std::vector<int64_t> pads{1, 1, 0, 2, 1, 1, 0, 2}; + + + for (auto i : range0_(data.size())) { data[i] = i; } + mluIn->copyFromHost(data.data(), t1Tensor->bytesSize()); + mluIn2->copyFromHost(pads.data(), t2Tensor->bytesSize()); + mluIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize()); + + // inference + { + void const *inputs[]{*mluIn, *mluIn2, *mluIn3}; + void *outputs[]{*mluOut}; + routine(res, nullptr, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{data.data(), pads.data(), constvalue.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // take output data + std::vector<float> result(yTensor->elementsSize()); + mluOut->copyToHost(result.data(), yTensor->bytesSize()); + // check + for (auto i : range0_(cpuOut.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } + } + + { + PadDimension dims{ + {2, 2, 0}, + {3, 3, 0}, + {1, 1, 0}, + {4, 4, 0}, + }; + // build routine + auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + auto t2Tensor = Tensor::share(DataType::I64, Shape{8}); + auto t3Tensor = Tensor::share(DataType::F32, Shape{}); + auto yTensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4}); + PadType type = PadType::Constant; + auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor))); + auto kernel = PadCnnl::build(dims, DataType::F32, type, std::make_optional(std::reference_wrapper(*t3Tensor))); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto mluIn = dev.malloc(t1Tensor->bytesSize()), + mluIn2 = dev.malloc(t2Tensor->bytesSize()), + mluIn3 = dev.malloc(t3Tensor->bytesSize()), + mluOut = dev.malloc(yTensor->bytesSize()); + // put input data + std::vector<float> data(t1Tensor->elementsSize()), + constvalue(1, 1.2f), + cpuOut(yTensor->elementsSize()); + std::vector<int64_t> pads{0, 0, 0, 0, 0, 0, 0, 0}; + + + for (auto i : range0_(data.size())) { data[i] = i; } + mluIn->copyFromHost(data.data(), t1Tensor->bytesSize()); + mluIn2->copyFromHost(pads.data(), t2Tensor->bytesSize()); + mluIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize()); + + // inference + { + void const *inputs[]{*mluIn, *mluIn2, *mluIn3}; + void *outputs[]{*mluOut}; + routine(res, nullptr, inputs, outputs); 
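+ // CNNL kernels are enqueued asynchronously on the cnrt queue; sync before copying results back to the host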
+ kernel::bang::sync(); + } + { + void const *inputs[]{data.data(), pads.data(), constvalue.data()}; + void *outputs[]{cpuOut.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // take output data + std::vector<float> result(yTensor->elementsSize()); + mluOut->copyToHost(result.data(), yTensor->bytesSize()); + // check + for (auto i : range0_(cpuOut.size())) { + EXPECT_FLOAT_EQ(cpuOut[i], result[i]); + } + } +} + +#endif diff --git a/src/04kernel/test/kernels/pool/test_cnnl.cpp b/src/04kernel/test/kernels/pool/test_cnnl.cpp new file mode 100644 index 000000000..1adf45139 --- /dev/null +++ b/src/04kernel/test/kernels/pool/test_cnnl.cpp @@ -0,0 +1,72 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/pool/cnnl_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +void testPoolCnnl(PoolType poolType, int rank, const int64_t *pads, const int64_t *strides, KernelShape kernelShape, Shape xShape, Shape yShape, const std::vector<float> &ExpectData) { + auto dataTensor = Tensor::share(DataType::F32, xShape); + auto yTensor = Tensor::share(DataType::F32, yShape); + //bool ceil = false; + bool ceil = true; + int64_t const dilations[] = {1, 1}; + PoolAttributes poolAttributes(rank, dilations, pads, strides); + + auto kernel = PoolCnnl::build(poolType, ceil, kernelShape, poolAttributes, *dataTensor, *yTensor); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // bang malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + mluMem = dev.malloc(dataTensor->bytesSize()); + // put input data + std::vector<float> data(dataTensor->elementsSize()); + for (auto i : range0_(data.size())) { data[i] = i * 0.1f; } + mluMem->copyFromHost(data.data(), dataTensor->bytesSize()); + // inference + void const *inputs[]{*mluMem}; + void *outputs[]{*mluMem}; + routine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + // take output data + std::vector<float> result(yTensor->elementsSize()); + mluMem->copyToHost(result.data(), yTensor->bytesSize()); + // check + for (auto i : range0_(ExpectData.size())) { + EXPECT_FLOAT_EQ(ExpectData[i], result[i]); + } +} + +TEST(kernel, PoolCnnlMax) { + int rank = 2; + int64_t const + pads[]{0, 0, 0, 0}, + strides[]{2, 2}; + KernelShape kernelShape{2, 2}; + Shape + xShape{1, 1, 4, 4}, + yShape{1, 1, 2, 2}; + const std::vector<float> ExpectData = {0.5, 0.7, 1.3, 1.5}; + testPoolCnnl(PoolType::Max, rank, pads, strides, kernelShape, xShape, yShape, ExpectData); +} + +TEST(kernel, PoolCnnlAvg) { + int rank = 2; + int64_t const + pads[]{0, 0, 0, 0}, + strides[]{2, 2}; + KernelShape kernelShape{2, 2}; + Shape + xShape{1, 1, 4, 4}, + yShape{1, 1, 2, 2}; + const std::vector<float> ExpectData = {0.25, 0.45, 1.05, 1.25}; + testPoolCnnl(PoolType::Average, rank, pads, strides, kernelShape, xShape, yShape, ExpectData); +} + +#endif diff --git a/src/04kernel/test/kernels/reduce/test_cnnl.cpp b/src/04kernel/test/kernels/reduce/test_cnnl.cpp new file mode 100644 index 000000000..113fe7664 --- /dev/null +++ b/src/04kernel/test/kernels/reduce/test_cnnl.cpp @@ -0,0 +1,66 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/reduce/cnnl_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> + +using namespace refactor; +using namespace kernel; +using 
namespace hardware; + +static void testReducemean(const Shape &shape, const std::vector<float> &data, + Axes axes, const std::vector<float> ExpectData) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, shape); + auto kernel = ReduceCnnl::build(axes, ReduceType::Mean, {*dataTensor}); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // bang malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto workspace = dev.malloc(workspaceSize), + mluMemIn = dev.malloc(dataTensor->bytesSize()), + mluMemOut = dev.malloc(dataTensor->bytesSize()); + // put input data + mluMemIn->copyFromHost(data.data(), dataTensor->bytesSize()); + // inference + { + void const *inputs[]{*mluMemIn}; + void *outputs[]{*mluMemOut}; + routine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + // take output data + Shape outDimArray; + std::unordered_set axesSet(axes.begin(), axes.end()); + for (size_t i = 0; i < shape.size(); ++i) { + // keep the dimensions that are not reduced + if (!axesSet.contains(i)) { + outDimArray.push_back(shape[i]); + } + } + auto outputTensor = Tensor::share(DataType::F32, outDimArray); + std::vector<float> result(outputTensor->elementsSize()); + mluMemOut->copyToHost(result.data(), outputTensor->bytesSize()); + // check + for (auto i : range0_(ExpectData.size())) { + EXPECT_FLOAT_EQ(ExpectData[i], result[i]); + } +} + +TEST(kernel, ReduceMeanCnnl) { + testReducemean({2, 3, 2, 2}, + {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 2}, + {5, 6, 17, 18}); + testReducemean({2, 3, 2, 2, 1}, + {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 2}, + {5, 6, 17, 18}); +} + +#endif diff --git a/src/04kernel/test/kernels/scatter_nd/test_cnnl.cpp b/src/04kernel/test/kernels/scatter_nd/test_cnnl.cpp new file mode 100644 index 000000000..cf5b9c367 --- /dev/null +++ b/src/04kernel/test/kernels/scatter_nd/test_cnnl.cpp @@ -0,0 +1,65 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/scatter_nd/cnnl_kernel.hh" +#include "../../../src/kernels/scatter_nd/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, ScatterNDCnnl) { + // build routine + auto data = Tensor::share(DataType::F32, Shape{8}); + auto indices = Tensor::share(DataType::I64, Shape{4, 1}); + auto updates = Tensor::share(DataType::F32, Shape{4}); + auto output = Tensor::share(DataType::F32, Shape{8}); + ScatterNDInfo info(*data, *indices); + auto getRefs = [](std::vector<Arc<Tensor>> tensors) -> TensorRefs { + TensorRefs refs; + std::transform(tensors.begin(), tensors.end(), std::back_inserter(refs), + [](auto const &it) { return std::cref(*it); }); + return refs; + }; + auto kernel = ScatterNDCnnl::build(getRefs({data, indices, updates}), getRefs({output})), + kCpu = ScatterNDCpu::build(info); + ASSERT_TRUE(kernel && kCpu); + auto res = runtime::Resources(); + auto routine = kernel->lower(res).routine, + rCpu = kCpu->lower(res).routine; + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + auto mluData = dev.malloc(data->bytesSize()), + mluIndices = dev.malloc(indices->bytesSize()), + mluUpdates = dev.malloc(updates->bytesSize()), + mluOut = dev.malloc(output->bytesSize()); + // put input data + std::vector<float> data_(data->elementsSize()); + std::iota(data_.begin(), data_.end(), 1); + 
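// each entry of indices_ selects the output slot that receives the matching element of updates_ +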
std::vector<int64_t> indices_{4, 3, 1, 7}; + std::vector<float> updates_{9, 10, 11, 12}; + mluData->copyFromHost(data_.data(), data->bytesSize()); + mluIndices->copyFromHost(indices_.data(), indices->bytesSize()); + mluUpdates->copyFromHost(updates_.data(), updates->bytesSize()); + // inference + { + void const *inputs[]{*mluData, *mluIndices, *mluUpdates}; + void *outputs[]{*mluOut}; + routine(res, nullptr, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{data_.data(), indices_.data(), updates_.data()}; + void *outputs[]{data_.data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + std::vector<float> result(output->elementsSize()); + mluOut->copyToHost(result.data(), output->bytesSize()); + EXPECT_EQ(result, data_); +} + +#endif diff --git a/src/04kernel/test/kernels/select/test_cnnl.cpp b/src/04kernel/test/kernels/select/test_cnnl.cpp new file mode 100644 index 000000000..b691fc3d1 --- /dev/null +++ b/src/04kernel/test/kernels/select/test_cnnl.cpp @@ -0,0 +1,99 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/select/cnnl_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <functional> +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +static void testSelect(const SelectType selectType, const std::vector<Shape> &shapes, const Shape &outShape, const std::vector<std::vector<float>> &data, + const std::vector<float> expectData) { + // build routine + TensorRefs dataTensors; + std::vector<Tensor> tensorsVec; + for (size_t i = 0; i < shapes.size(); ++i) { + tensorsVec.push_back(Tensor(DataType::F32, shapes[i], LayoutType::Others, nullptr)); + } + for (size_t i = 0; i < shapes.size(); ++i) { + dataTensors.push_back(std::cref(tensorsVec[i])); + } + auto result = Tensor::share(DataType::F32, outShape); + auto kernel = SelectCnnl::build(selectType, dataTensors); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // cnnl malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + Arc<Device::Blob> + workspace = dev.malloc(workspaceSize), + mluIns[]{ + dev.malloc(dataTensors[0].get().bytesSize()), + dev.malloc(dataTensors[1].get().bytesSize()), + dev.malloc(dataTensors[2].get().bytesSize()), + }, + mluOut = dev.malloc(result->bytesSize()); + // put input data + mluIns[0]->copyFromHost(data[0].data(), dataTensors[0].get().bytesSize()); + mluIns[1]->copyFromHost(data[1].data(), dataTensors[1].get().bytesSize()); + mluIns[2]->copyFromHost(data[2].data(), dataTensors[2].get().bytesSize()); + // inference + { + void const *inputs[]{*mluIns[0], *mluIns[1], *mluIns[2]}; + void *outputs[]{*mluOut}; + routine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + // check + std::vector<float> out(result->elementsSize()); + mluOut->copyToHost(out.data(), result->bytesSize()); + for (auto i : range0_(expectData.size())) { + EXPECT_FLOAT_EQ(expectData[i], out[i]); + } +} + +TEST(kernel, SelectCnnl) { + // no need broadcast + testSelect(SelectType::Max, + {{1, 3}, {1, 3}, {1, 3}}, + {1, 3}, + {{3, 2, 1}, {1, 4, 4}, {2, 5, 3}}, + {3, 5, 4}); + + testSelect(SelectType::Min, + {{1, 3}, {1, 3}, {1, 3}}, + {1, 3}, + {{3, 2, 1}, {1, 4, 4}, {2, 5, 3}}, + {1, 2, 1}); + + // need broadcast + testSelect(SelectType::Max, + {{3}, {1, 3}, {1, 3}}, + {1, 3}, + {{3, 3, 3}, {1, 4, 4}, {2, 5, 3}}, + {3, 5, 4}); + + testSelect(SelectType::Min, + {{3}, {1, 3}, {1, 3}}, + {1, 3}, + 
+    testSelect(SelectType::Min,
+               {{3}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3, 3, 3}, {1, 4, 4}, {2, 5, 3}},
+               {1, 3, 3});
+
+    testSelect(SelectType::Max,
+               {{1}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3}, {1, 4, 4}, {2, 5, 3}},
+               {3, 5, 4});
+
+    testSelect(SelectType::Min,
+               {{1}, {1, 3}, {1, 3}},
+               {1, 3},
+               {{3}, {1, 4, 4}, {2, 5, 3}},
+               {1, 3, 3});
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp b/src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp
new file mode 100644
index 000000000..dfd3c9b80
--- /dev/null
+++ b/src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp
@@ -0,0 +1,110 @@
+#ifdef USE_BANG
+
+#include "../src/kernels/simple_binary/binary_cnnl.hh"
+#include "../src/kernels/simple_binary/cpu_kernel.hh"
+#include "../src/utilities/bang/cnrt_functions.h"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+template<decltype(DataType::internal) T>
+void testBinaryCnnl(SimpleBinaryType binaryOPT, Shape dimA, Shape dimB, Shape dimC) {
+    // Create tensors and build kernels
+    using T_ = typename primitive<T>::type;
+    auto aTensor = Tensor::share(T, dimA, LayoutType::NCHW);
+    auto bTensor = Tensor::share(T, dimB, LayoutType::NCHW);
+    auto cTensor = Tensor::share(T, dimC, LayoutType::NCHW);
+    auto kernel = BinaryCnnl::build(binaryOPT, *aTensor, *bTensor, *cTensor);
+    auto kCpu = BinaryCpu::build(binaryOPT, *aTensor, *bTensor);
+    ASSERT_TRUE(kCpu && kernel);
+    auto res = runtime::Resources();
+    auto [routine, workspaceSize] = kernel->lower(res);
+    auto rCpu = kCpu->lower(res).routine;
+    // Init inputs and outputs
+    std::vector<T_> a(aTensor->elementsSize(), 3);
+    std::vector<T_> b(bTensor->elementsSize(), 2);
+    std::vector<T_> c(cTensor->elementsSize());
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto workspace = dev.malloc(workspaceSize),
+         aMLU = dev.malloc(aTensor->bytesSize()),
+         bMLU = dev.malloc(bTensor->bytesSize()),
+         cMLU = dev.malloc(cTensor->bytesSize());
+    aMLU->copyFromHost(a.data(), aTensor->bytesSize());
+    bMLU->copyFromHost(b.data(), bTensor->bytesSize());
+    // Compute
+    {
+        void const *inputs[]{*aMLU, *bMLU};
+        void *outputs[]{*cMLU};
+        routine(res, *workspace, inputs, outputs);
+        kernel::bang::sync();
+    }
+    {
+        void const *inputs[]{a.data(), b.data()};
+        void *outputs[]{c.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // Compare
+    std::vector<T_> result(cTensor->elementsSize());
+    cMLU->copyToHost(result.data(), cTensor->bytesSize());
+    for (auto i : range0_(result.size())) {
+        EXPECT_EQ(c[i], result[i]);
+    }
+}
+
+TEST(kernel, BinaryCnnlAdd) {
+    testBinaryCnnl<DataType::F32>(SimpleBinaryType::Add,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCnnlMul) {
+    testBinaryCnnl<DataType::F32>(SimpleBinaryType::Mul,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCnnlSub) {
+    testBinaryCnnl<DataType::F32>(SimpleBinaryType::Sub,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCnnlDiv) {
+    testBinaryCnnl<DataType::F32>(SimpleBinaryType::Div,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCnnlPow) {
+    testBinaryCnnl<DataType::F32>(SimpleBinaryType::Pow,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
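+// Mod runs on I32 operands; Fmod exercises the floating-point remainder path.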
+TEST(kernel, BinaryCnnlMod) {
+    testBinaryCnnl<DataType::I32>(SimpleBinaryType::Mod,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCnnlFMod) {
+    testBinaryCnnl<DataType::F32>(SimpleBinaryType::Fmod,
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4},
+                                  Shape{2, 5, 10, 20, 3, 4});
+}
+
+TEST(kernel, BinaryCnnlBroadcast) {
+    testBinaryCnnl<DataType::F32>(SimpleBinaryType::Add, Shape{1, 2, 3, 4, 5, 6}, Shape{}, Shape{1, 2, 3, 4, 5, 6});
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/simple_unary/test_cnnl.cpp b/src/04kernel/test/kernels/simple_unary/test_cnnl.cpp
new file mode 100644
index 000000000..3964989c9
--- /dev/null
+++ b/src/04kernel/test/kernels/simple_unary/test_cnnl.cpp
@@ -0,0 +1,67 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/simple_unary/cnnl_activation_kernel.hh"
+#include "../../../src/kernels/simple_unary/cnnl_simple_unary_kernel.hh"
+#include "../../../src/kernels/simple_unary/cpu_kernel.hh"
+#include "../src/utilities/bang/cnrt_functions.h"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+static void testOp(SimpleUnaryType opType, bool activation = true) {
+    // build routine
+    auto dataTensor = Tensor::share(DataType::F32, Shape{20, 30, 50});
+    auto kernel = activation ? ActivationCnnl::build(opType, *dataTensor)
+                             : SimpleUnaryCnnl::build(opType, *dataTensor);
+    auto kCpu = SimpleUnaryCpu::build(opType, *dataTensor);
+    ASSERT_TRUE(kernel && kCpu);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine,
+         rCpu = kCpu->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto mluMem = dev.malloc(dataTensor->bytesSize());
+    // put input data
+    std::vector<float> data(dataTensor->elementsSize());
+    for (auto i : range0_(data.size())) { data[i] = i * 1e-4f; }
+    mluMem->copyFromHost(data.data(), dataTensor->bytesSize());
+    // inference
+    {
+        void const *inputs[]{*mluMem};
+        void *outputs[]{*mluMem};
+        routine(res, nullptr, inputs, outputs);
+        kernel::bang::sync();
+    }
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{data.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // take output data
+    std::vector<float> result(dataTensor->elementsSize());
+    mluMem->copyToHost(result.data(), dataTensor->bytesSize());
+    // check
+    for (auto i : range0_(data.size())) {
+        EXPECT_NEAR(data[i], result[i], 1e-4);
+    }
+}
+
+TEST(kernel, SimpleUnaryCnnl) {
+    testOp(SimpleUnaryType::Abs, false);
+    testOp(SimpleUnaryType::Neg, false);
+    testOp(SimpleUnaryType::Sqrt, false);
+    testOp(SimpleUnaryType::Erf, false);
+}
+
+TEST(kernel, ActivationCnnl) {
+    testOp(SimpleUnaryType::Relu);
+    testOp(SimpleUnaryType::Sigmoid);
+    testOp(SimpleUnaryType::Tanh);
+    testOp(SimpleUnaryType::HardSwish);
+}
+
+#endif// USE_BANG
diff --git a/src/04kernel/test/kernels/slice/test_cnnl.cpp b/src/04kernel/test/kernels/slice/test_cnnl.cpp
new file mode 100644
index 000000000..1685d7aaa
--- /dev/null
+++ b/src/04kernel/test/kernels/slice/test_cnnl.cpp
@@ -0,0 +1,61 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/slice/cnnl_kernel.hh"
+#include "../../../src/kernels/slice/cpu_kernel.hh"
+#include "../src/utilities/bang/cnrt_functions.h"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+#include <numeric>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, SliceCnnl) {
+    // build routine
+    Dimensions dims{
+        {5, -2, 3},// 7 -> {5, 3, 1} -> {108, 900, -360}
+        {2, 3, 2}, // 6 -> {2, 5} -> { 36, 60, 90}
+        {1, 1, 3}, // 5 -> {1, 2, 3} -> { 18, 6, 30}
+        {0, 1, 1}, // 1 -> {0}
+        {0, 1, 2}, // 2 -> {0, 1}
+        {0, 1, 3}, // 3 -> {0, 1, 2}
+    };
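+    // each triple is {start, step, length}: e.g. {5, -2, 3} on a dim of size 7 selects indices 5, 3, 1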
+    auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}),
+         output = Tensor::share(DataType::F32, Shape{3, 2, 3, 1, 2, 3});
+    SliceInfo info(dims, *input);
+    auto kernel = SliceCnnl::build(DataType::F32, dims, input->shape, output->shape);
+    auto kCpu = SliceCpu::build(info);
+    ASSERT_TRUE(kernel && kCpu);
+    auto res = runtime::Resources();
+    auto routine = kernel->lower(res).routine;
+    auto rCpu = kCpu->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto mluIn = dev.malloc(input->bytesSize()),
+         mluOut = dev.malloc(output->bytesSize());
+    // put input data
+    std::vector<float>
+        data(input->elementsSize()),
+        ans(output->elementsSize()),
+        result(ans.size());
+    std::iota(data.begin(), data.end(), 0);
+    mluIn->copyFromHost(data.data(), input->bytesSize());
+    // inference
+    {
+        void const *inputs[]{*mluIn};
+        void *outputs[]{*mluOut};
+        routine(res, nullptr, inputs, outputs);
+        kernel::bang::sync();
+    }
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{ans.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    // check
+    mluOut->copyToHost(result.data(), output->bytesSize());
+    EXPECT_EQ(result, ans);
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/softmax/test_cnnl.cpp b/src/04kernel/test/kernels/softmax/test_cnnl.cpp
new file mode 100644
index 000000000..09874d207
--- /dev/null
+++ b/src/04kernel/test/kernels/softmax/test_cnnl.cpp
@@ -0,0 +1,54 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/softmax/cnnl_kernel.hh"
+#include "../../../src/kernels/softmax/cpu_kernel.hh"
+#include "../src/utilities/bang/cnrt_functions.h"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, SoftmaxCnnl) {
+    // build routine
+    auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
+    auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4});
+    dim_t axis = 2;
+    auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis));
+    auto kCnnl = SoftmaxCnnl::build(cnnl::SoftmaxAlgo::FAST, SoftmaxInfo(*xTensor, axis));
+    ASSERT_TRUE(kCpu && kCnnl);
+    auto res = runtime::Resources();
+    auto rCpu = kCpu->lower(res).routine;
+    auto rCnnl = kCnnl->lower(res).routine;
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto mluIn = dev.malloc(xTensor->bytesSize()),
+         mluOut = dev.malloc(outTensor->bytesSize());
+    // put input data
+    std::vector<float>
+        data(xTensor->elementsSize(), 0),
+        cpuOut(outTensor->elementsSize());
+    mluIn->copyFromHost(data.data(), xTensor->bytesSize());
+    // inference
+    {
+        void const *inputs[]{data.data()};
+        void *outputs[]{cpuOut.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{*mluIn};
+        void *outputs[]{*mluOut};
+        rCnnl(res, nullptr, inputs, outputs);
+        kernel::bang::sync();
+    }
+    // take output data
+    std::vector<float> result(outTensor->elementsSize());
+    mluOut->copyToHost(result.data(), outTensor->bytesSize());
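+    // softmax of an all-zero input is uniform: along axis 2 (size 2) every value is 0.5,
+    // so the CPU and CNNL backends must agree exactly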
"../../../src/kernels/split/cnnl_kernel.hh" +#include "../../../src/kernels/split/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, SplitCnnl) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 20, 7, 7}); + std::vector<Arc<Tensor>> outputTensors{ + Tensor::share(DataType::F32, Shape{2, 3, 1, 1, 7, 7}),// 勿 + Tensor::share(DataType::F32, Shape{2, 3, 1, 9, 7, 7}),// 忘 + Tensor::share(DataType::F32, Shape{2, 3, 1, 3, 7, 7}),// 国 + Tensor::share(DataType::F32, Shape{2, 3, 1, 7, 7, 7}),// 耻 + }; + TensorRefs outputs_; + outputs_.reserve(outputTensors.size()); + std::transform(outputTensors.begin(), outputTensors.end(), + std::back_inserter(outputs_), + [](auto const &it) { return std::cref(*it); }); + auto info = SplitInfo(3, outputs_); + auto kCpu = SplitCpu::build(info); + auto kernel = SplitCnnl::build(3, *dataTensor, outputs_); + ASSERT_TRUE(kCpu && kernel); + auto res = runtime::Resources(); + auto rCpu = kCpu->lower(res).routine; + auto [routine, workspaceSize]= kernel->lower(res); + // malloc + auto &dev = *device::init(Device::Type::Mlu, 0, ""); + Arc<Device::Blob> + workspace = dev.malloc(workspaceSize), + mluIn = dev.malloc(dataTensor->bytesSize()), + mluOuts[]{ + dev.malloc(outputTensors[0]->bytesSize()), + dev.malloc(outputTensors[1]->bytesSize()), + dev.malloc(outputTensors[2]->bytesSize()), + dev.malloc(outputTensors[3]->bytesSize()), + }; + // put input data + std::vector<float> + data(dataTensor->elementsSize()), + outsCpu[]{ + std::vector<float>(outputTensors[0]->elementsSize()), + std::vector<float>(outputTensors[1]->elementsSize()), + std::vector<float>(outputTensors[2]->elementsSize()), + std::vector<float>(outputTensors[3]->elementsSize()), + }, + outs[]{ + std::vector<float>(outputTensors[0]->elementsSize()), + std::vector<float>(outputTensors[1]->elementsSize()), + std::vector<float>(outputTensors[2]->elementsSize()), + std::vector<float>(outputTensors[3]->elementsSize()), + }; + std::iota(data.begin(), data.end(), 0); + mluIn->copyFromHost(data.data(), dataTensor->bytesSize()); + // inference + { + void const *inputs[]{*mluIn}; + void *outputs[]{*mluOuts[0], *mluOuts[1], *mluOuts[2], *mluOuts[3]}; + routine(res, *workspace, inputs, outputs); + kernel::bang::sync(); + } + { + void const *inputs[]{data.data()}; + void *outputs[]{outsCpu[0].data(), outsCpu[1].data(), outsCpu[2].data(), outsCpu[3].data()}; + rCpu(res, nullptr, inputs, outputs); + } + // check + for (auto i : range0_(outputTensors.size())) { + mluOuts[i]->copyToHost(outs[i].data(), outputTensors[i]->bytesSize()); + EXPECT_EQ(outs[i], outsCpu[i]); + } +} + +#endif diff --git a/src/04kernel/test/kernels/transpose/test_cnnl.cpp b/src/04kernel/test/kernels/transpose/test_cnnl.cpp new file mode 100644 index 000000000..9acdd95ba --- /dev/null +++ b/src/04kernel/test/kernels/transpose/test_cnnl.cpp @@ -0,0 +1,57 @@ +#ifdef USE_BANG + +#include "../../../src/kernels/transpose/cnnl_kernel.hh" +#include "../../../src/kernels/transpose/cpu_kernel.hh" +#include "../src/utilities/bang/cnrt_functions.h" +#include "hardware/device_manager.h" +#include <gtest/gtest.h> +#include <numeric> + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, TransposeCnnl) { + // build routine + auto dataTensor = Tensor::share(DataType::F32, Shape{1, 3, 2, 5}); + auto info 
+    auto info = TransposeInfo(dataTensor->dataType, dataTensor->shape, Permutation{2, 3, 0, 1});
+    auto kCpu = TransposeCpu::build(info);
+    auto kernel = TransposeCnnl::build(dataTensor->dataType, dataTensor->shape, Permutation{2, 3, 0, 1});
+    ASSERT_TRUE(kCpu && kernel);
+    auto res = runtime::Resources();
+    auto rCpu = kCpu->lower(res).routine;
+    auto [routine, workspaceSize] = kernel->lower(res);
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto bytes = dataTensor->bytesSize();
+    auto workspace = dev.malloc(workspaceSize),
+         mluIn = dev.malloc(bytes),
+         mluOut = dev.malloc(bytes);
+    // put input data
+    std::vector<float>
+        cpuIn(dataTensor->elementsSize()),
+        cpuOut(cpuIn.size());
+    std::iota(cpuIn.begin(), cpuIn.end(), 0);
+    mluIn->copyFromHost(cpuIn.data(), bytes);
+    // inference
+    {
+        void const *inputs[]{cpuIn.data()};
+        void *outputs[]{cpuOut.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{*mluIn};
+        void *outputs[]{*mluOut};
+        routine(res, *workspace, inputs, outputs);
+        kernel::bang::sync();
+    }
+    // take output data
+    std::vector<float> result(dataTensor->elementsSize());
+    mluOut->copyToHost(result.data(), bytes);
+    // check
+    for (auto i : range0_(result.size())) {
+        EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
+    }
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/where/test_cnnl.cpp b/src/04kernel/test/kernels/where/test_cnnl.cpp
new file mode 100644
index 000000000..6e26ddbd2
--- /dev/null
+++ b/src/04kernel/test/kernels/where/test_cnnl.cpp
@@ -0,0 +1,70 @@
+#ifdef USE_BANG
+
+#include "../../../src/kernels/where/cnnl_kernel.hh"
+#include "../../../src/kernels/where/cpu_kernel.hh"
+#include "../src/utilities/bang/cnrt_functions.h"
+#include "hardware/device_manager.h"
+#include <cstdint>
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+static void testWhereCnnl(Shape cDim, Shape xDim, Shape yDim, Shape outDim) {
+    // build routine
+    auto cTensor = Tensor::share(DataType::Bool, cDim);
+    auto xTensor = Tensor::share(DataType::F32, xDim);
+    auto yTensor = Tensor::share(DataType::F32, yDim);
+    auto outTensor = Tensor::share(DataType::F32, outDim);
+    auto kCpu = WhereCpu::build({*cTensor, *xTensor, *yTensor});
+    auto kCnnl = WhereCnnl::build({*cTensor, *xTensor, *yTensor}, {*outTensor});
+    ASSERT_TRUE(kCpu && kCnnl);
+    auto res = runtime::Resources();
+    auto rCpu = kCpu->lower(res).routine;
+    auto [rCnnl, workspaceSize] = kCnnl->lower(res);
+    // malloc
+    auto &dev = *device::init(Device::Type::Mlu, 0, "");
+    auto workspace = dev.malloc(workspaceSize),
+         mluC = dev.malloc(cTensor->bytesSize()),
+         mluX = dev.malloc(xTensor->bytesSize()),
+         mluY = dev.malloc(yTensor->bytesSize()),
+         mluOut = dev.malloc(outTensor->bytesSize());
+    // put input data: an all-true condition (one byte per bool element), x = 7, y = 3
+    std::vector<uint8_t> dataC(cTensor->elementsSize(), 1);
+    mluC->copyFromHost(dataC.data(), cTensor->bytesSize());
+    std::vector<float> dataX(xTensor->elementsSize(), 7);
+    mluX->copyFromHost(dataX.data(), xTensor->bytesSize());
+    std::vector<float> dataY(yTensor->elementsSize(), 3);
+    mluY->copyFromHost(dataY.data(), yTensor->bytesSize());
+    std::vector<float> cpuOut(outTensor->elementsSize());
+    // inference
+    {
+        void const *inputs[]{dataC.data(), dataX.data(), dataY.data()};
+        void *outputs[]{cpuOut.data()};
+        rCpu(res, nullptr, inputs, outputs);
+    }
+    {
+        void const *inputs[]{*mluC, *mluX, *mluY};
+        void *outputs[]{*mluOut};
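+        // unlike the CPU reference above, the CNNL routine needs its workspace blob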
+        rCnnl(res, *workspace, inputs, outputs);
+        kernel::bang::sync();
+    }
+    // take output data
+    std::vector<float> result(outTensor->elementsSize());
+    mluOut->copyToHost(result.data(), outTensor->bytesSize());
+    // check
+    for (auto i : range0_(result.size())) {
+        EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
+    }
+}
+
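+// cases: multidirectional broadcast, {1}-shaped operands, and a broadcast condition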
+TEST(kernel, WhereCnnl) {
+    testWhereCnnl(Shape{2, 5}, Shape{2, 3, 1, 5}, Shape{3, 2, 5}, Shape{2, 3, 2, 5});
+    testWhereCnnl(Shape{1}, Shape{4}, Shape{1}, Shape{4});
+    testWhereCnnl(Shape{3}, Shape{2, 3}, Shape{2, 3}, Shape{2, 3});
+}
+
+#endif
diff --git a/src/07onnx/src/operators/simple_binary.cc b/src/07onnx/src/operators/simple_binary.cc
index 2db99bdd3..8ff7660b4 100644
--- a/src/07onnx/src/operators/simple_binary.cc
+++ b/src/07onnx/src/operators/simple_binary.cc
@@ -10,7 +10,7 @@ namespace refactor::onnx {
         : Operator(), type(type_) {}
 
     auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
-        auto fmod = attributes.getOrInsert( "fmod", {0}).int_();
+        auto fmod = attributes.getOrInsert("fmod", {0}).int_();
         // clang-format off
         auto type =
             opType == "onnx::Add" ? Ty::Add :
diff --git a/src/09python_ffi/CMakeLists.txt b/src/09python_ffi/CMakeLists.txt
index ccce34d37..09567c9da 100644
--- a/src/09python_ffi/CMakeLists.txt
+++ b/src/09python_ffi/CMakeLists.txt
@@ -10,6 +10,10 @@ pybind11_add_module(python_ffi SHARED ${PYFFI_SRC})
 target_link_libraries(python_ffi PRIVATE onnx llm communication)
 target_include_directories(python_ffi PRIVATE include)
 
+if(USE_BANG)
+    target_include_directories(python_ffi PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../04kernel/src/utilities/bang)
+endif()
+
 # EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
 # define (VERSION_INFO) here.
 # target_compile_definitions(python_ffi
diff --git a/src/09python_ffi/src/compiler.cc b/src/09python_ffi/src/compiler.cc
index bf04053e9..45450582e 100644
--- a/src/09python_ffi/src/compiler.cc
+++ b/src/09python_ffi/src/compiler.cc
@@ -95,6 +95,7 @@ namespace refactor::python_ffi {
         // clang-format off
         auto target_ = target == "cpu"  ? Target::Cpu
                      : target == "cuda" ? Target::Nvidia
+                     : target == "mlu"  ? Target::Mlu
                      : UNREACHABLEX(Target, "Unknown target: {}", target);
         // clang-format on
         return compileOn(hardware::device::fetch(target_),
diff --git a/src/09python_ffi/src/executor.cc b/src/09python_ffi/src/executor.cc
index c6a20cb95..1d0e543f1 100644
--- a/src/09python_ffi/src/executor.cc
+++ b/src/09python_ffi/src/executor.cc
@@ -7,6 +7,10 @@
 #include "kernel/cuda/functions.cuh"
 #endif// USE_CUDA
 
+#ifdef USE_BANG
+#include "cnrt_functions.h"
+#endif// USE_BANG
+
 namespace refactor::python_ffi {
 
     Executor::Executor(computation::Graph graph, runtime::Stream stream)
@@ -70,9 +74,11 @@ void Executor::bench(bool sync) {
 #ifdef USE_CUDA
         auto ans = _stream.bench(sync ? kernel::cuda::sync : nullptr);
+#elif defined(USE_BANG)
+        auto ans = _stream.bench(sync ? kernel::bang::sync : nullptr);
 #else
         auto ans = _stream.bench(nullptr);
-#endif// USE_CUDA
+#endif
         auto const &nodes = _graph.internal().contiguous().nodes;
         for (auto i : range0_(nodes.size())) {
             fmt::println("{} {} {}",
@@ -213,6 +221,9 @@
 #ifdef USE_CUDA
                 kernel::cuda::copyOut(buffer.data(), addresses[idx], size);
 #endif
+#ifdef USE_BANG
+                kernel::bang::copyOut(buffer.data(), addresses[idx], size);
+#endif
 
                 auto file = path / fmt::format("data{:06}.{}", dataIdx++, format);
                 fs::remove(file);
diff --git a/src/09python_ffi/src/import.cpp b/src/09python_ffi/src/import.cpp
index dda0e660c..74cf7cf01 100644
--- a/src/09python_ffi/src/import.cpp
+++ b/src/09python_ffi/src/import.cpp
@@ -16,6 +16,7 @@ namespace refactor::python_ffi {
         // clang-format off
         auto type_ = type == "cpu"    ? Device::Type::Cpu
                    : type == "nvidia" ? Device::Type::Nvidia
+                   : type == "mlu"    ? Device::Type::Mlu
                    : UNREACHABLEX(Device::Type, "Unknown device type: \"{}\"", type);
         // clang-format on
         return device::init(type_, card, "");