diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000..51cbe64bd4a --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "csrc/third_party/catlass"] + path = csrc/third_party/catlass + url = https://gitcode.com/cann/catlass.git + branch = catlass-v1-stable diff --git a/csrc/build_aclnn.sh b/csrc/build_aclnn.sh index 9dba287e3ae..3685bdd107a 100644 --- a/csrc/build_aclnn.sh +++ b/csrc/build_aclnn.sh @@ -15,7 +15,15 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then SOC_ARG="ascend910b" elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then # ASCEND910C (A3) series - CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list" + # depdendency: catlass + CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass + if [[ ! -d "${CATLASS_PATH}" ]]; then + echo "depdendency catlass does not exist, please run 'git submodule update --init --recursive'" + exit 1 + fi + ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd) + export CPATH=${ABSOLUTE_CATLASS_PATH}/include:${CPATH} + CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;dispatch_gmm_combine_decode" SOC_ARG="ascend910_93" else # others diff --git a/csrc/cmake/func.cmake b/csrc/cmake/func.cmake index f2bebf75639..e8ce57564fc 100644 --- a/csrc/cmake/func.cmake +++ b/csrc/cmake/func.cmake @@ -282,7 +282,7 @@ function(add_ops_src_copy) set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done) add_custom_command(OUTPUT ${_BUILD_FLAG} COMMAND mkdir -p ${SRC_COPY_DST} - COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST} + COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/* ${SRC_COPY_DST} COMMAND touch ${_BUILD_FLAG} ) diff --git a/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt new file mode 100644 index 00000000000..04d3c702009 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ====================================================================================================================== + +add_ops_compile_options( + OP_NAME DispatchGmmCombineDecode + OPTIONS --cce-auto-sync=off + -Wno-deprecated-declarations + -Werror +) + +target_sources(op_host_aclnnInner PRIVATE + dispatch_gmm_combine_decode_def.cpp +) + +target_sources(opapi PRIVATE + aclnn_dispatch_gmm_combine_decode.cpp +) + +if (NOT BUILD_OPEN_PROJECT) + target_sources(aclnn_ops_train PRIVATE + aclnn_dispatch_gmm_combine_decode.cpp + ) + + target_sources(aclnn_ops_infer PRIVATE + aclnn_dispatch_gmm_combine_decode.cpp + ) +endif () + +target_sources(optiling PRIVATE + dispatch_gmm_combine_decode_tiling.cpp +) + +target_include_directories(optiling PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_sources(opsproto PRIVATE + dispatch_gmm_combine_decode_proto.cpp +) + +file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h") + +install(FILES ${_GMM_Aclnn_header} + DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL +) diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp new file mode 100644 index 00000000000..30c890c0121 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "graph/types.h" +#include "aclnn/opdev/platform.h" +#include "aclnn_dispatch_gmm_combine_decode.h" + +enum NnopbaseHcclServerType { + NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0, + NNOPBASE_HCCL_SERVER_TYPE_MTE, + NNOPBASE_HCCL_SERVER_TYPE_END +}; +extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType); + +#ifdef __cplusplus +extern "C" { +#endif + +extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize( + const aclTensor *x, + const aclTensor *expertIds, + const aclTensor *gmm1PermutedWeight, + const aclTensor *gmm1PermutedWeightScale, + const aclTensor *gmm2Weight, + const aclTensor *gmm2WeightScale, + const aclTensor *expertSmoothScalesOptional, + const aclTensor *expertScalesOptional, + char *groupEp, + int64_t epRankSize, + int64_t epRankId, + int64_t moeExpertNum, + int64_t shareExpertNum, + int64_t shareExpertRankNum, + int64_t quantMode, + int64_t globalBs, + const aclTensor *output, + const aclTensor *epRecvCount, + uint64_t *workspaceSize, + aclOpExecutor **executor); +extern aclnnStatus aclnnInnerDispatchGmmCombineDecode( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + +aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( + const aclTensor *x, + const aclTensor *expertIds, + const aclTensor *gmm1PermutedWeight, + const aclTensor *gmm1PermutedWeightScale, + const aclTensor *gmm2Weight, + const aclTensor *gmm2WeightScale, + const aclTensor *expertSmoothScalesOptional, + const aclTensor *expertScalesOptional, + char *groupEp, + int64_t epRankSize, + int64_t epRankId, + int64_t moeExpertNum, + int64_t shareExpertNum, + int64_t shareExpertRankNum, + int64_t quantMode, + int64_t globalBs, + const aclTensor *output, + const aclTensor *epRecvCount, + uint64_t *workspaceSize, + aclOpExecutor **executor) +{ + return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale, + gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize, + epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs, + output, epRecvCount, workspaceSize, executor); +} + +aclnnStatus aclnnDispatchGmmCombineDecode( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream) +{ + if (NnopbaseSetHcclServerType) { + if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU); + } else { + NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE); + } + } + return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream); +} + +#ifdef __cplusplus +} +#endif + + diff --git a/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h new file mode 100644 index 00000000000..bf7fc18bccf --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_host/aclnn_dispatch_gmm_combine_decode.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef DISPATCH_GMM_COMBINE_DECODE +#define DISPATCH_GMM_COMBINE_DECODE + +#include "aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize( + const aclTensor *x, + const aclTensor *expertIds, + const aclTensor *gmm1PermutedWeight, + const aclTensor *gmm1PermutedWeightScale, + const aclTensor *gmm2Weight, + const aclTensor *gmm2WeightScale, + const aclTensor *expertSmoothScalesOptional, + const aclTensor *expertScalesOptional, + char *groupEp, + int64_t epRankSize, + int64_t epRankId, + int64_t moeExpertNum, + int64_t shareExpertNum, + int64_t shareExpertRankNum, + int64_t quantMode, + int64_t globalBs, + const aclTensor *output, + const aclTensor *epRecvCount, + uint64_t *workspaceSize, + aclOpExecutor **executor); + +__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode( + void *workspace, + uint64_t workspaceSize, + aclOpExecutor *executor, + aclrtStream stream); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp new file mode 100644 index 00000000000..0a0737b27f1 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "register/op_def_registry.h" + +namespace ops { +class DispatchGmmCombineDecode : public OpDef +{ +public: + explicit DispatchGmmCombineDecode(const char *name) : OpDef(name) + { + this->Input("x") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("expert_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("gmm1_permuted_weight") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + this->Input("gmm1_permuted_weight_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("gmm2_weight") + .ParamType(REQUIRED) + .DataType({ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + this->Input("gmm2_weight_scale") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("expert_smooth_scales") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Input("expert_scales") + .ParamType(OPTIONAL) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("output") + .ParamType(REQUIRED) + .DataType({ge::DT_BF16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("ep_recv_count") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("group_ep").String(); + this->Attr("ep_rank_size").Int(); + this->Attr("ep_rank_id").Int(); + this->Attr("moe_expert_num").Int(); + this->Attr("share_expert_num").Int(); + this->Attr("share_expert_rank_num").Int(); + this->Attr("quant_mode").Int(); + this->Attr("global_bs").Int(); + + this->MC2().HcclGroup({"group_ep"}); + this->AICore().AddConfig("ascend910_93"); + } +}; + +OP_ADD(DispatchGmmCombineDecode); +} // namespace ops diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp new file mode 100644 index 00000000000..008d9316eaf --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_proto.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include +#include "error_log.h" +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" + +namespace ge { +constexpr uint32_t EXPAND_X_INDEX = 0; +constexpr uint32_t EXPERT_IDS_INDEX = 1; +constexpr uint32_t OUTPUT_X_INDEX = 0; +constexpr uint32_t OUTPUT_REC_COUNT_INDEX = 1; + +constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; +constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; +constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; +constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3; +constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4; +constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5; +constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6; +constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7; + +static ge::graphStatus InferShape(gert::InferShapeContext *context) +{ + const char *nodeName = context->GetNodeName(); + // infer output shape + const gert::Shape *expandXShape = context->GetInputShape(EXPAND_X_INDEX); + const gert::Shape *expertIdsShape = context->GetInputShape(EXPERT_IDS_INDEX); + gert::Shape *expandXOutShape = context->GetOutputShape(OUTPUT_X_INDEX); + gert::Shape *recvCountOutShape = context->GetOutputShape(OUTPUT_REC_COUNT_INDEX); + if (expandXShape == nullptr || expertIdsShape == nullptr || expandXOutShape == nullptr || + recvCountOutShape == nullptr) { + return GRAPH_FAILED; + } + if (expandXShape->GetDimNum() < 2 || expertIdsShape->GetDimNum() < 1) { + return GRAPH_FAILED; + } + + int bs = expertIdsShape->GetDim(0); + int h = expandXShape->GetDim(1); + + expandXOutShape->SetDimNum(expandXShape->GetDimNum()); + expandXOutShape->SetDim(0, bs); + expandXOutShape->SetDim(1, h); + + // infer recvCount shape + auto attrs = context->GetAttrs(); + OP_TILING_CHECK(attrs == nullptr, OP_LOGE(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto epRankSizePtr = attrs->GetAttrPointer(ATTR_EP_RANK_SIZE_INDEX); + auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); + auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); + auto sharedExpertRankNumPtr = attrs->GetAttrPointer(ATTR_SHARE_EXPERT_RANK_NUM_INDEX); + + OP_TILING_CHECK(epRankIdPtr == nullptr, OP_LOGE(nodeName, "epRankIdPtr is nullptr."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(moeExpertNumPtr == nullptr, OP_LOGE(nodeName, "moeExpertNumPtr is nullptr."), + return ge::GRAPH_FAILED); + OP_TILING_CHECK(epRankSizePtr == nullptr, OP_LOGE(nodeName, "epRankSizePtr is nullptr."), return ge::GRAPH_FAILED); + OP_TILING_CHECK(sharedExpertRankNumPtr == nullptr, OP_LOGE(nodeName, "sharedExpertRankNumPtr is nullptr."), + return ge::GRAPH_FAILED); + uint32_t epRankSize = static_cast(*epRankSizePtr); + uint32_t moeExpertNum = static_cast(*moeExpertNumPtr); + uint32_t epRankId = static_cast(*epRankIdPtr); + uint32_t sharedExpertRankNum = static_cast(*sharedExpertRankNumPtr); + + recvCountOutShape->SetDimNum(1); + bool isShareExpert = (epRankId < sharedExpertRankNum); + if (isShareExpert) { + recvCountOutShape->SetDim(0, epRankSize); + } else { + recvCountOutShape->SetDim(0, epRankSize * (moeExpertNum / (epRankSize - sharedExpertRankNum))); + } + + return GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataType(gert::InferDataTypeContext *context) +{ + const auto expandXDataType = context->GetInputDataType(EXPAND_X_INDEX); + context->SetOutputDataType(OUTPUT_X_INDEX, expandXDataType); + context->SetOutputDataType(OUTPUT_REC_COUNT_INDEX, ge::DT_INT32); + return ge::GRAPH_SUCCESS; +} + +IMPL_OP(DispatchGmmCombineDecode).InferShape(InferShape).InferDataType(InferDataType); +} // namespace ge diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp new file mode 100644 index 00000000000..487d36e3875 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include + +#include "log/ops_log.h" +#include "error/ops_error.h" +#include "graph/utils/type_utils.h" +#include "register/op_def_registry.h" +#include "../op_kernel/dispatch_gmm_combine_decode_tiling.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/hccl/hccl_tiling.h" + +using namespace ge; +namespace { +constexpr uint32_t OP_TYPE_ALL_TO_ALL = 8; +constexpr uint32_t SYSTEM_NEED_WORKSPACE = 16 * 1024 * 1024; +constexpr uint32_t GM_ALIGN_SIZE = 512; +constexpr uint32_t TOKEN_DTYPE_BYTE_SIZE = 2; +constexpr uint32_t USE_CORE_NUM = 24; +constexpr uint32_t L1_TILE_BYTE_SIZE = 32 * 1024; +constexpr uint32_t CUBE_WORKSPACE_STAGE = 4; +constexpr uint32_t RESERVED_WORKSPACE_SIZE = 256 * 1024; + +constexpr uint32_t INPUT_X_INDEX = 0; +constexpr uint32_t INPUT_EXPERT_IDS_INDEX = 1; +constexpr uint32_t INPUT_GMM1_WEIGHT_INDEX = 2; +constexpr uint32_t INPUT_GMM1_WEIGHT_SCALE_INDEX = 3; +constexpr uint32_t INPUT_GMM2_WEIGHT_INDEX = 4; +constexpr uint32_t INPUT_GMM2_WEIGHT_SCALE_INDEX = 5; +constexpr uint32_t INPUT_SMOOTH_SCALE_INDEX = 6; +constexpr uint32_t INPUT_EXPERT_SCALE_INDEX = 7; + +constexpr uint32_t ATTR_GROUP_EP_INDEX = 0; +constexpr uint32_t ATTR_EP_RANK_SIZE_INDEX = 1; +constexpr uint32_t ATTR_EP_RANK_ID_INDEX = 2; +constexpr uint32_t ATTR_MOE_EXPERT_NUM_INDEX = 3; +constexpr uint32_t ATTR_SHARE_EXPERT_NUM_INDEX = 4; +constexpr uint32_t ATTR_SHARE_EXPERT_RANK_NUM_INDEX = 5; +constexpr uint32_t ATTR_QUANT_MODE_INDEX = 6; +constexpr uint32_t ATTR_GLOBAL_BS_INDEX = 7; + +constexpr uint32_t MIN_BATCH_SIZE = 1; +constexpr uint32_t MAX_BATCH_SIZE = 256; +constexpr uint32_t MAX_MOE_EXERT_NUM = 512; +constexpr uint32_t RECV_AIV_NUM = 24; +constexpr uint32_t SUPPORT_TOP_K = 12; +constexpr uint32_t TWO_DIMS = 2; +constexpr uint32_t MIN_TOKEN_LENGTH = 512; +constexpr uint32_t MAX_TOKEN_LENGTH = 7168; +constexpr uint32_t MIN_GMM1_HIDDEN = 1024; +constexpr uint32_t MAX_GMM1_HIDDEN = 6144; +} // namespace + +namespace optiling { +static size_t CeilUp(size_t x, size_t y) +{ + return (x + y - 1) / y * y; +} + +static ge::graphStatus CheckTensorShape(gert::TilingContext *context, const char *nodeName, + DispatchGmmCombineDecodeTilingData &tilingData) +{ + uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + uint32_t moeExpertNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint64_t gmm1WeightDim2 = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + + uint32_t localExpertNum = epRankId < sharedExpertRankNum ? 1 : moeExpertNumPerRank; + const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); + OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight shape is null."), + return ge::GRAPH_FAILED); + const int64_t gmm1WeightDim0 = gmm1WeightStorageShape->GetStorageShape().GetDim(0); + OPS_ERR_IF(gmm1WeightDim0 != localExpertNum, + OPS_LOG_E(nodeName, "gmm1Weight Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + + const gert::StorageShape *gmm1WeightScaleStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_SCALE_INDEX); + OPS_ERR_IF(gmm1WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1 weight scale shape is null."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OPS_LOG_E(nodeName, "gmm1 weight scale shape dims must be 2, but current dim num is %lu.", + gmm1WeightScaleStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t gmm1WeightScaleDim0 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(0); + OPS_ERR_IF(gmm1WeightScaleDim0 != localExpertNum, + OPS_LOG_E(nodeName, "gmm1WeightScale Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + const int64_t gmm1WeightScaleDim1 = gmm1WeightScaleStorageShape->GetStorageShape().GetDim(1); + OPS_ERR_IF(gmm1WeightScaleDim1 != gmm1WeightDim2, + OPS_LOG_E(nodeName, "gmm1WeightScale Dim1 must be %lu(gmm1WeightDim2).", gmm1WeightDim2), + return ge::GRAPH_FAILED); + + const gert::StorageShape *gmm2WeightStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_INDEX); + OPS_ERR_IF(gmm2WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight shape is null."), + return ge::GRAPH_FAILED); + const int64_t gmm2WeightDim0 = gmm2WeightStorageShape->GetStorageShape().GetDim(0); + OPS_ERR_IF(gmm2WeightDim0 != localExpertNum, + OPS_LOG_E(nodeName, "gmm2Weight Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + + const gert::StorageShape *gmm2WeightScaleStorageShape = context->GetInputShape(INPUT_GMM2_WEIGHT_SCALE_INDEX); + OPS_ERR_IF(gmm2WeightScaleStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm2 weight scale shape is null."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OPS_LOG_E(nodeName, "gmm2 weight scale shape dims must be 2, but current dim num is %lu.", + gmm2WeightScaleStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t gmm2WeightScaleDim0 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(0); + OPS_ERR_IF(gmm2WeightScaleDim0 != localExpertNum, + OPS_LOG_E(nodeName, "gmm2WeightScale Dim0 must be expert number in current rank."), + return ge::GRAPH_FAILED); + const int64_t gmm2WeightScaleDim1 = gmm2WeightScaleStorageShape->GetStorageShape().GetDim(1); + OPS_ERR_IF(gmm2WeightScaleDim1 != h, OPS_LOG_E(nodeName, "gmm2WeightScale Dim1 must be %u.", h), + return ge::GRAPH_FAILED); + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus CheckData(const char *nodeName, DispatchGmmCombineDecodeTilingData &tilingData) +{ + uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; + OPS_ERR_IF(batchSize < MIN_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must >= %d.", MIN_BATCH_SIZE), + return ge::GRAPH_FAILED); + OPS_ERR_IF(batchSize > MAX_BATCH_SIZE, OPS_LOG_E(nodeName, "batchSize(bs) must <= %d.", MAX_BATCH_SIZE), + return ge::GRAPH_FAILED); + uint32_t tokenLength = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; + OPS_ERR_IF( + tokenLength < MIN_TOKEN_LENGTH || tokenLength > MAX_TOKEN_LENGTH, + OPS_LOG_E(nodeName, "tokenLength(h) is invalid. Only support [%u, %u].", MIN_TOKEN_LENGTH, MAX_TOKEN_LENGTH), + return ge::GRAPH_FAILED); + uint32_t gmm1HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + OPS_ERR_IF( + gmm1HLen < MIN_GMM1_HIDDEN || gmm1HLen > MAX_GMM1_HIDDEN, + OPS_LOG_E(nodeName, "gmm1 hidden size is invalid. Only support [%u, %u].", MIN_GMM1_HIDDEN, MAX_GMM1_HIDDEN), + return ge::GRAPH_FAILED); + uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k; + OPS_ERR_IF(topK > SUPPORT_TOP_K, OPS_LOG_E(nodeName, "topK(k) must <= %d.", SUPPORT_TOP_K), + return ge::GRAPH_FAILED); + uint32_t globalBatchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs; + uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + if (globalBatchSize == 0) { + globalBatchSize = epRankSize * batchSize; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = globalBatchSize; + } else { + OPS_ERR_IF(globalBatchSize < 0, OPS_LOG_E(nodeName, "globalBatchSize must >= 0."), return ge::GRAPH_FAILED); + OPS_ERR_IF(globalBatchSize % epRankSize > 0, + OPS_LOG_E(nodeName, "globalBatchSize must be divisible by epRankSize."), return ge::GRAPH_FAILED); + } + + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus GetAttrAndSetTilingData(gert::TilingContext *context, const char *nodeName, + DispatchGmmCombineDecodeTilingData &tilingData, std::string &groupEp) +{ + auto attrs = context->GetAttrs(); + OPS_ERR_IF(attrs == nullptr, OPS_LOG_E(nodeName, "attrs is nullptr."), return ge::GRAPH_FAILED); + + auto groupEpPtr = attrs->GetAttrPointer(static_cast(ATTR_GROUP_EP_INDEX)); + auto epRankSizePtr = attrs->GetAttrPointer(ATTR_EP_RANK_SIZE_INDEX); + auto epRankIdPtr = attrs->GetAttrPointer(ATTR_EP_RANK_ID_INDEX); + auto moeExpertNumPtr = attrs->GetAttrPointer(ATTR_MOE_EXPERT_NUM_INDEX); + auto sharedExpertNumPtr = attrs->GetAttrPointer(ATTR_SHARE_EXPERT_NUM_INDEX); + auto sharedExpertRankNumPtr = attrs->GetAttrPointer(ATTR_SHARE_EXPERT_RANK_NUM_INDEX); + auto quantModePtr = attrs->GetAttrPointer(ATTR_QUANT_MODE_INDEX); + auto globalBsPtr = attrs->GetAttrPointer(ATTR_GLOBAL_BS_INDEX); + + uint32_t epRankSize = static_cast(*epRankSizePtr); + uint32_t epRankId = static_cast(*epRankIdPtr); + uint32_t moeExpertNum = static_cast(*moeExpertNumPtr); + uint32_t sharedExpertNum = static_cast(*sharedExpertNumPtr); + uint32_t sharedExpertRankNum = static_cast(*sharedExpertRankNumPtr); + uint32_t moeExpertNumPerRank = moeExpertNum / (epRankSize - sharedExpertRankNum); + + OPS_ERR_IF(epRankId < 0, OPS_LOG_E(nodeName, "epRankId must >= 0."), return ge::GRAPH_FAILED); + OPS_ERR_IF(epRankId >= epRankSize, OPS_LOG_E(nodeName, "epRankId must < epRankSize."), return ge::GRAPH_FAILED); + OPS_ERR_IF(moeExpertNum > MAX_MOE_EXERT_NUM, OPS_LOG_E(nodeName, "moeExpertNum must <= %d.", MAX_MOE_EXERT_NUM), + return ge::GRAPH_FAILED); + OPS_ERR_IF(moeExpertNum <= 0, OPS_LOG_E(nodeName, "moeExpertNum must > 0."), return ge::GRAPH_FAILED); + OPS_ERR_IF(sharedExpertNum != 1, OPS_LOG_E(nodeName, "sharedExpertNum must be 1."), return ge::GRAPH_FAILED); + OPS_ERR_IF(moeExpertNum % (epRankSize - sharedExpertRankNum) != 0, + OPS_LOG_E(nodeName, "moeExpertNum must be divisible by (epRankSize - sharedExpertRankNum)."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(moeExpertNumPerRank > RECV_AIV_NUM, + OPS_LOG_E(nodeName, "moeExpertNumPerRank must <= %d.", RECV_AIV_NUM), return ge::GRAPH_FAILED); + + groupEp = std::string(groupEpPtr); + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize = epRankSize; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId = epRankId; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum = moeExpertNum; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum = sharedExpertNum; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum = sharedExpertRankNum; + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.quantMode = static_cast(*quantModePtr); + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs = static_cast(*globalBsPtr); + tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank = moeExpertNumPerRank; + return ge::GRAPH_SUCCESS; +} + +static void SetHcommCfg(const gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tiling, const std::string groupEp) +{ + const char *nodeName = context->GetNodeName(); + OPS_LOG_D(nodeName, "DispatchGmmCombineDecode groupEp = %s", groupEp.c_str()); + uint32_t opType = OP_TYPE_ALL_TO_ALL; + std::string algConfigAllToAllStr = "AlltoAll=level0:fullmesh;level1:pairwise"; + std::string algConfigAllGatherStr = "AllGather=level0:ring"; + + AscendC::Mc2CcTilingConfig mc2CcTilingConfig(groupEp, opType, algConfigAllToAllStr); + mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling); + mc2CcTilingConfig.GetTiling(tiling->mc2CcTiling); +} + +static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *nodeName, + DispatchGmmCombineDecodeTilingData &tilingData) +{ + size_t *workSpaces = context->GetWorkspaceSizes(1); + OPS_ERR_IF(workSpaces == nullptr, OPS_LOG_E(nodeName, "workSpaces is nullptr."), return ge::GRAPH_FAILED); + size_t maxTokenNum; + uint32_t epRankSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + uint32_t epRankId = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + uint32_t sharedExpertRankNum = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + uint32_t batchSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.bs; + uint32_t globalBs = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.globalBs; + uint32_t maxBatchSize = globalBs / epRankSize; + uint32_t topK = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.k; + uint32_t moeExpertNumPerRank = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + uint32_t h = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.h; + uint64_t gmm2HLen = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen / 2; + if (epRankId < sharedExpertRankNum) { + maxTokenNum = maxBatchSize * epRankSize / sharedExpertRankNum; + } else { + maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank); + } + + size_t x2TokenSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(int8_t), GM_ALIGN_SIZE); + size_t x2ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); + size_t CVSwapBufferSize = + CeilUp(USE_CORE_NUM * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE); + size_t swigluOutSize = CeilUp(maxTokenNum * gmm2HLen * sizeof(float), GM_ALIGN_SIZE); + size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE); + size_t expandIdxSize = CeilUp(batchSize * topK * sizeof(int32_t), GM_ALIGN_SIZE); + size_t epSendCountSize = CeilUp(epRankSize * moeExpertNumPerRank * sizeof(int32_t), GM_ALIGN_SIZE); + size_t x1TokenSize = CeilUp(maxTokenNum * h * sizeof(int8_t), GM_ALIGN_SIZE); + size_t x1ScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); + size_t gmm2DepOutSize = CeilUp(maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE, GM_ALIGN_SIZE); + size_t resveredSize = CeilUp(RESERVED_WORKSPACE_SIZE, GM_ALIGN_SIZE); + size_t usrSize = x2TokenSize + x2ScaleSize + CVSwapBufferSize + swigluOutSize + groupListSize + expandIdxSize + + epSendCountSize + x1TokenSize + x1ScaleSize + gmm2DepOutSize + resveredSize; + + workSpaces[0] = SYSTEM_NEED_WORKSPACE + usrSize; + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContext *context) +{ + const char *nodeName = context->GetNodeName(); + DispatchGmmCombineDecodeTilingData *tilingData = context->GetTilingData(); + OPS_ERR_IF(tilingData == nullptr, OPS_LOG_E(nodeName, "tilingData is nullptr."), return ge::GRAPH_FAILED); + std::string groupEp = ""; + + const gert::StorageShape *xStorageShape = context->GetInputShape(INPUT_X_INDEX); + OPS_ERR_IF(xStorageShape == nullptr, OPS_LOG_E(nodeName, "x shape is null."), return ge::GRAPH_FAILED); + OPS_ERR_IF(xStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OPS_LOG_E(nodeName, "x shape dims must be 2, but current dim num is %lu.", + xStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t batchSize = xStorageShape->GetStorageShape().GetDim(0); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs = batchSize; + const int64_t hiddenSize = xStorageShape->GetStorageShape().GetDim(1); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h = hiddenSize; + + const gert::StorageShape *expertIdsStorageShape = context->GetInputShape(INPUT_EXPERT_IDS_INDEX); + OPS_ERR_IF(expertIdsStorageShape == nullptr, OPS_LOG_E(nodeName, "expertIds shape is null."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(expertIdsStorageShape->GetStorageShape().GetDimNum() != TWO_DIMS, + OPS_LOG_E(nodeName, "expertIds shape dims must be 2, but current dim num is %lu.", + expertIdsStorageShape->GetStorageShape().GetDimNum()), + return ge::GRAPH_FAILED); + const int64_t topK = expertIdsStorageShape->GetStorageShape().GetDim(1); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k = topK; + OPS_ERR_IF(GetAttrAndSetTilingData(context, nodeName, *tilingData, groupEp) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "Get attr and set tiling data failed."), return ge::GRAPH_FAILED); + const gert::StorageShape *gmm1WeightStorageShape = context->GetInputShape(INPUT_GMM1_WEIGHT_INDEX); + OPS_ERR_IF(gmm1WeightStorageShape == nullptr, OPS_LOG_E(nodeName, "gmm1Weight shape is null."), + return ge::GRAPH_FAILED); + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen = gmm1WeightStorageShape->GetOriginShape().GetDim(TWO_DIMS); + OPS_ERR_IF(CheckData(nodeName, *tilingData) != ge::GRAPH_SUCCESS, OPS_LOG_E(nodeName, "CheckData failed."), + return ge::GRAPH_FAILED); + OPS_ERR_IF(SetWorkSpace(context, nodeName, *tilingData) != ge::GRAPH_SUCCESS, + OPS_LOG_E(nodeName, "Tiling set workspace failed."), return ge::GRAPH_FAILED); + SetHcommCfg(context, tilingData, groupEp); + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank == 1) { + context->SetTilingKey(0); + } else { + context->SetTilingKey(EXEC_FLAG_DEEP_FUSE); + } + context->SetBlockDim(USE_CORE_NUM); + return ge::GRAPH_SUCCESS; +} + +static ge::graphStatus DispatchGmmCombineDecodeTilingFunc(gert::TilingContext *context) +{ + ge::graphStatus ret = DispatchGmmCombineDecodeTilingFuncImpl(context); + return ret; +} + +struct DispatchGmmCombineDecodeCompileInfo {}; +ge::graphStatus TilingParseForDispatchGmmCombineDecode(gert::TilingParseContext *context) +{ + (void)context; + return ge::GRAPH_SUCCESS; +} + +IMPL_OP_OPTILING(DispatchGmmCombineDecode) + .Tiling(DispatchGmmCombineDecodeTilingFunc) + .TilingParse(TilingParseForDispatchGmmCombineDecode); +} // namespace optiling diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp new file mode 100644 index 00000000000..a061c95a622 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "dispatch_gmm_combine_decode.h" +#include +#include "lib/matmul_intf.h" + +extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + // output + GM_ADDR output, GM_ADDR outputRecvCount, + // system + GM_ADDR workspace, GM_ADDR tiling) +{ + icache_preload(8); + // New output recvCount + REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); + KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V + GET_TILING_DATA(tiling_data, tiling); + if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1)) { + DispatchGmmCombineDecode op; + op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, + expert_smooth_scales, expert_scales, output, outputRecvCount, workspace, nullptr, &tiling_data); + op.Process(); + } +} diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h new file mode 100644 index 00000000000..179c9c93495 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -0,0 +1,442 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef DISPATCH_GMM_COMBINE_DECODE_H +#define DISPATCH_GMM_COMBINE_DECODE_H + +#include "lib/matmul_intf.h" +#include + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/epilogue/tile/tile_broadcast_mul.hpp" +#include "catlass/epilogue/tile/tile_broadcast_one_blk.hpp" +#include "catlass/epilogue/tile/tile_swizzle.hpp" +#include "catlass/gemm/block/block_swizzle.hpp" +#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h" +#include "catlass/gemm/gemm_type.hpp" +#include "dispatch_gmm_combine_decode/epilogue/dispatch_policy.h" +#include "dispatch_gmm_combine_decode/gemm/dispatch_policy.h" +#include "dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h" +#include "dispatch_gmm_combine_decode/gemm/block/block_mmad.h" +#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h" + +#include "dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h" + +#include "dispatch_gmm_combine_decode_tiling.h" +#include "dispatch_gmm_combine_decode_base.h" + +using namespace Catlass; + +using MmadAtlasA2Custom = + Gemm::MmadAtlasA2PreloadAsyncWithCallback; + +using Gmm1L1TileShape = GemmShape; +using Gmm1L0TileShape = GemmShape; +using Gmm1EpilogueTileShape = MatrixShape; +using Gmm1BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; + +using Gmm2L1TileShape = GemmShape; +using Gmm2L0TileShape = GemmShape; +using Gmm2EpilogueTileShape = MatrixShape; +using Gmm2BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; +using Gmm2DispatchPolicy = + Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA; + +template +CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, + GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, + GM_ADDR gmEpSendCount, GM_ADDR gmResvered, GM_ADDR gmOutputRecvCount, + uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, + uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, + uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using XType = XType_; + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu; + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + using BlockEpilogue = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = BlockScheduler_; + + // kernel level + using ElementGroupList = int64_t; + + using GemmKernel = typename std::conditional< + (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), + Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< + XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, + Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace, + gmX, + debugGm, + gmexpertIds, + gmExpandIdx, + gmEpSendCount, + gmResvered, + gmOutputRecvCount, + epRankSize, + epRankId, + moeExpertNum, + moeExpertNumPerRank, + sharedExpertNum, + sharedExpertRankNum, + quantMode, + globalBs, + bs, + topK, + tokenLen}; + // call a kernel + GemmKernel gemm; + gemm(params); + } else { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace}; + // call a kernel + GemmKernel gemm; + gemm(params); + } +} + +template +CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, layout::nZ layoutB, GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmWorkspace, void *combiner) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine; + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + using BlockEpilogue = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = BlockScheduler_; + + // kernel level + using ElementGroupList = int64_t; + using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + typename GemmKernel::Params params{ + problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale, + layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner}; + + // call a kernel + GemmKernel gemm; + gemm(params); +} + +template +class DispatchGmmCombineDecode +{ +public: + __aicore__ inline DispatchGmmCombineDecode(){}; + __aicore__ inline void Init( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + // output + GM_ADDR output, GM_ADDR outputRecvCount, + // system + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); + __aicore__ inline void Process(); + +private: + GM_ADDR gmX_; + GM_ADDR gmexpertIds_; + GM_ADDR gmPermuteWeight1_; + GM_ADDR gmPermuteScale1_; + GM_ADDR gmWeight2_; + GM_ADDR gmScale2_; + GM_ADDR gmOutput_; + GM_ADDR gmOutputRecvCount_; + GM_ADDR workspaceGM_; + GM_ADDR gmSmoothScales_; + GM_ADDR gmexpertScales_; + + uint32_t m_{0}; + uint32_t n_{0}; + uint32_t k_{0}; + uint32_t groupCount_{0}; + uint32_t n2_{0}; + uint32_t k2_{0}; + uint32_t globalRankId_{0}; + uint32_t winSizePerRank_{0}; + uint32_t blockDim_{0}; + uint32_t epRankSize_{0}; + uint32_t epRankId_{0}; + uint32_t moeExpertNum_{0}; + uint32_t moeExpertNumPerRank_{0}; + uint32_t sharedExpertNum_{0}; + uint32_t sharedExpertRankNum_{0}; + uint32_t quantMode_{0}; + uint32_t globalBs_{0}; + uint32_t bs_{0}; + uint32_t maxBs_{0}; + uint32_t topK_{0}; + + AscendC::TPipe *tpipe_{nullptr}; + __gm__ HcclOpResParam *winContext_{nullptr}; + const DispatchGmmCombineDecodeTilingData *tilingData_; +}; + +template +__aicore__ inline void DispatchGmmCombineDecode::Init( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_smooth_scales, GM_ADDR expert_scales, + // output + GM_ADDR output, GM_ADDR outputRecvCount, + // system + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) +{ + tpipe_ = pipe; + blockDim_ = AscendC::GetBlockNum(); + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + gmSmoothScales_ = expert_smooth_scales; // 量化平滑系数,当前无作用 + gmX_ = x; // dispatch的输入 + gmexpertIds_ = expert_ids; + gmPermuteWeight1_ = gmm1_permuted_weight; + gmPermuteScale1_ = gmm1_permuted_weight_scale; + gmWeight2_ = gmm2_weight; + gmScale2_ = gmm2_weight_scale; + gmOutput_ = output; + gmOutputRecvCount_ = outputRecvCount; + workspaceGM_ = workspaceGM; + gmexpertScales_ = expert_scales; + tilingData_ = tilingData; + epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + moeExpertNumPerRank_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + sharedExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + quantMode_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.quantMode; + globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs; + bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + maxBs_ = globalBs_ / epRankSize_; + + bool isShareExpert = (epRankId_ < sharedExpertRankNum_); + if (isShareExpert) { + m_ = maxBs_ * epRankSize_ / sharedExpertRankNum_; + } else { + m_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_); + } + + n_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + k_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + n2_ = k_; + k2_ = n_ / 2; +} + +template +__aicore__ inline void DispatchGmmCombineDecode::Process() +{ + if (g_coreType == AscendC::AIV) { + ((DispatchGmmCombineDecodeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aicNum = get_block_num(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + ((DispatchGmmCombineDecodeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = get_block_num(); + } else { + ((DispatchGmmCombineDecodeTilingData *)tilingData_)->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum = + get_block_num() * get_subblockdim(); + } + } + GemmCoord gmm1ProblemShape{m_, n_, k_}; + GemmCoord gmm2ProblemShape{m_, n2_, k2_}; + + layout::RowMajor layoutX1{m_, k_}; + layout::zN layoutWeight1 = layout::zN::template MakeLayout(k_, n_); + layout::VectorLayout layoutScale1{n_}; + layout::VectorLayout layoutPerTokenScale1{m_}; + layout::RowMajor layoutX2{m_, k2_}; + layout::nZ layoutWeight2 = layout::nZ::template MakeLayout(k2_, n2_); + layout::VectorLayout layoutScale2{n2_}; + layout::VectorLayout layoutPerTokenScale2{m_}; + layout::RowMajor layoutOutput{m_, n2_}; + + size_t workspaceOffset = 0; + constexpr int32_t resveredWorkSpaceSize = 256 * 1024; + GM_ADDR gmX2 = workspaceGM_; + workspaceOffset += RoundUp(static_cast(m_) * k2_ * sizeof(int8_t)); + GM_ADDR gmPerTokenScale2 = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * sizeof(float)); + GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset; + + GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(blockDim_) * (GMM1_L1M * GMM1_L1N) * + WORKSPACE_STAGES * sizeof(int32_t)); + GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * k2_ * sizeof(float)); + GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(groupCount_) * sizeof(int64_t)); + GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(bs_) * topK_ * sizeof(int32_t)); + GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(epRankSize_) * groupCount_ * sizeof(int32_t)); + GM_ADDR gmX1Token = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * k_ * sizeof(int8_t)); + GM_ADDR gmX1Scale = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * sizeof(float)); + GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(m_) * k_ * sizeof(ExpandXType)); + GM_ADDR gmResvered = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(resveredWorkSpaceSize); + + if constexpr (EXEC_FLAG == 0) { + if constexpr (g_coreType == AscendC::AIV) { + AscendC::TPipe tpipe; + MoeDistributeDispatchImpl::CamMoeDistributeDispatch + dispatcher; + dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, gmX1Token, gmX1Scale, gmExpandIdx, gmGroupList, + gmEpSendCount, gmOutputRecvCount_, nullptr, gmWorkspace, &tpipe, tilingData_); + dispatcher.Process(); + tpipe.Destroy(); + icache_preload(8); + } + + AscendC::PipeBarrier(); + Arch::CrossCoreFlag gmm1AivFinished{0}; + if constexpr (g_coreType == AscendC::AIV) { + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished); + } else { + Arch::CrossCoreWaitFlag(gmm1AivFinished); + } + } + GmmDeqSwigluQuant( + gmm1ProblemShape, groupCount_, gmGroupList, gmX1Token, layoutX1, gmPermuteWeight1_, layoutWeight1, + gmPermuteScale1_, layoutScale1, gmX1Scale, layoutPerTokenScale1, gmX2, layoutX2, gmPerTokenScale2, + layoutPerTokenScale2, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, gmResvered, + gmOutputRecvCount_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, + sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, k_); + AscendC::PipeBarrier(); + Arch::CrossCoreFlag gmm1AivFinished{0}; + if constexpr (g_coreType == AscendC::AIV) { + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished); + } else { + Arch::CrossCoreWaitFlag(gmm1AivFinished); + } + + MoeDistributeCombineImpl::CamMoeDistributeCombine combiner; + if (g_coreType == AscendC::AIV) { + combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, gmOutput_, + workspaceGM_, nullptr, tilingData_); + } + GmmDeq(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2, + gmScale2_, layoutScale2, gmPerTokenScale2, layoutPerTokenScale2, gmGmm2DepOut, + layoutOutput, gmWorkspace, &combiner); +} +#endif // DISPATCH_GMM_COMBINE_DECODE_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h new file mode 100644 index 00000000000..4bbbc792762 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/epilogue/block/block_epilogue.hpp" + +#include "block_epilogue_per_token_dequant_swiglu.h" +#include "block_epilogue_per_token_dequant.hpp" diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp new file mode 100644 index 00000000000..cf7d16e36f6 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp @@ -0,0 +1,760 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP +#define ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP + +#include "../../raw_distributed/cam_moe_distribute_combine.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/matrix_coord.hpp" + +#define ENABLE_EP_SEND_COUNT_HASH 0 + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, CType_, ScaleType_, PerTokenScaleType_, + DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = typename ScaleType_::Element; + using LayoutScale = typename ScaleType_::Layout; + using ElementPerTokenScale = typename PerTokenScaleType_::Element; + using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v) && + std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(float); + ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubMul; + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + callback(); + + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); + tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubScaleFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleFp32; + AscendC::LocalTensor ubPerTokenScaleFp32Brcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +template +class BlockEpilogue, CType_, Gemm::GemmType, + Gemm::GemmType, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, + TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= + ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE void AlignUbOffset() + { + size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); + if (ubMask != 0) { + ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; + } + } + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo, + Params const ¶ms = Params{}) + : resource(resource), calcInfo(calcInfo), params(params) + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleBrcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubCFp32; + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AlignUbOffset(); + epSendCountLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); + AlignUbOffset(); + AscendC::GlobalTensor epSendCountGM; + epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); + uint32_t epSendCountSize = calcInfo.isShardExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_; + AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast(epSendCountSize * sizeof(uint32_t)), + 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); + AscendC::SetFlag(eventMTE2S); + AscendC::WaitFlag(eventMTE2S); +#if ENABLE_EP_SEND_COUNT_HASH + tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + uint32_t maxGroupSendCount = 0; + uint32_t groupSendCount = 0; + for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) { + uint32_t prevGroupSendCount = groupSendCount; + groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1); + if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { + maxGroupSendCount = groupSendCount - prevGroupSendCount; + } + } + ubOffset += maxGroupSendCount * sizeof(int32_t); + AlignUbOffset(); + // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or + // AscendC::TOTAL_VEC_LOCAL_SIZE +#endif + } + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U) + { + return (GM_ADDR)((calcInfo.epRankId_ == rankId) + ? calcInfo.epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; + } +#if ENABLE_EP_SEND_COUNT_HASH + CATLASS_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen) + { + constexpr uint32_t DUPLICATE_MASK_COUNT = 8; + uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); + if (hashOffsetMask != 0) { + uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; + if (copyLen < remainMaskCount) { + remainMaskCount = copyLen; + } + uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; + AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, ©Mask, 1, 1, + DUPLICATE_MASK_COUNT); + hashOffset += remainMaskCount; + copyLen -= remainMaskCount; + } + if (copyLen > 0) { + AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen); + hashOffset += copyLen; + } + } +#endif + + CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank) + { + if ((calcInfo.isShardExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) { + remoteEpRank = calcInfo.epRankId_; + localEpRank = epRank; + } else { + remoteEpRank = epRank; + localEpRank = calcInfo.epRankId_; + } + } + + CATLASS_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, layout::RowMajor &layoutGmTileD, + LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD) + { + const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); + const uint32_t copyTokenSrcStride = + (layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD)); + const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); + + int64_t offsetD = groupOffsetD + tileOffsetD; + uint32_t startToken = offsetD / calcInfo.axisH_; + uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; + uint32_t itToken = startToken; + uint32_t endToken = startToken + layoutGmTileD.shape(0); +#if ENABLE_EP_SEND_COUNT_HASH + uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); +#else + constexpr uint32_t epRankStart = 0; +#endif + uint32_t sendCount = + expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); + for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + if (prevSendCount <= itToken && itToken < sendCount) { + uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken; + AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride, + copyTokenDstStride, 0); + uint32_t remoteEpRank; + uint32_t localEpRank; + SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); + GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + + localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_; + AscendC::GlobalTensor rankWindow; + rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); + AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset], + ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); + itToken += copyTokenCount; + } + } + } + + CATLASS_DEVICE + void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK, + GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutBlockC, + Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + expertOffset = expertIdx * calcInfo.epWorldSize_; +#if ENABLE_EP_SEND_COUNT_HASH + if (currentExpertIdx_ != expertIdx) { + uint32_t hashOffset = 0; + uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1); + for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount); + } + AscendC::SetFlag(eventVS); + AscendC::WaitFlag(eventVS); + currentExpertIdx_ = expertIdx; + } +#endif + } + + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + tileRowBroadcastMul(ubMul, ubCFp32, ubScale); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto tileOffsetD = params.layoutD.GetOffset(tileOffset); + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD); + } else { + auto gmTileD = gmD[tileOffsetD]; + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + } + + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + Arch::Resource &resource; + MoeDistributeCombineImpl::CombineCalcInfo calcInfo; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + AscendC::LocalTensor epSendCountLocal_; +#if ENABLE_EP_SEND_COUNT_HASH + AscendC::LocalTensor tokenToEpRankHashLocal_; + uint32_t currentExpertIdx_{static_cast(-1)}; +#endif + + size_t ubOffset{0}; + int32_t eventVMTE2{0}; + int32_t eventMTE2V{0}; + int32_t eventMTE3V{0}; + int32_t eventVMTE3{0}; + int32_t eventVS{0}; + int32_t eventMTE2S{0}; + + uint32_t expertOffset; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleBrcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h new file mode 100644 index 00000000000..c89b62d224e --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +#include "../tile/tile_stride_muls.h" +#include "../tile/tile_stride_binary.h" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, CType_, + Gemm::GemmType, Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0, + "The per token scale granularity for word calculation must be 32 bytes aligned."); + static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts."); + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static constexpr uint32_t CHUNK_TILE_COLUMN = TileShape::COLUMN / 2; + using ChunkTileShape = MatrixShape; + + using TileStrideMuls = Tile::TileStrideMuls; + using TileStrideDiv = Tile::TileStrideDiv; + using TileStrideMul = Tile::TileStrideMul; + + static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + ChunkTileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= + ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubTmpMxN = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubTmpMx32B = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubTmpMxChunkN = resource.ubBuf.template GetBufferByByte(ubOffset); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (0 == actualBlockShapeMNK.k()) { + return; + } + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto ubChunkTileStride = MakeCoord(static_cast(ChunkTileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = 0; // 原本是AscendC::GetSubBlockIdx(); + uint32_t subblockNum = 1; // 原本是AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto actualChunkTileShape = MakeCoord(actualTileShape.row(), actualTileShape.column() >> 1); + auto chunkTileOffset = MakeCoord(tileOffset.row(), tileOffset.column() >> 1); + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubTmpMxN, ubTmpMxN, ubTmpMx32B); + AscendC::PipeBarrier(); + tileStrideMuls(ubTmpMxChunkN, ubTmpMxN, -1.0f); + AscendC::PipeBarrier(); + AscendC::Exp(ubTmpMxChunkN, ubTmpMxChunkN, ChunkTileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::Adds(ubTmpMxChunkN, ubTmpMxChunkN, 1.0f, ChunkTileShape::COUNT); + AscendC::PipeBarrier(); + tileStrideDiv(ubTmpMxChunkN, ubTmpMxN, ubTmpMxChunkN); + AscendC::PipeBarrier(); + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualChunkTileShape, ubChunkTileStride}; + + auto ubTmpMxNR = ubTmpMxN[ChunkTileShape::COLUMN]; + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + tileStrideMul(ubD, ubTmpMxNR, ubTmpMxChunkN); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(chunkTileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualChunkTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubTmpMxN; + AscendC::LocalTensor ubTmpMx32B; + AscendC::LocalTensor ubTmpMxChunkN; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + TileStrideMuls tileStrideMuls; + TileStrideDiv tileStrideDiv; + TileStrideMul tileStrideMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h new file mode 100644 index 00000000000..df70c101a13 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/epilogue/dispatch_policy.hpp" + +namespace Catlass::Epilogue { + +template +struct EpilogueAtlasA2PerTokenDequantSwiglu { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; + +template +struct EpilogueAtlasA2PerTokenDequantCombine { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; + +} // namespace Catlass::Epilogue diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/tile/tile_stride_binary.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/tile/tile_stride_binary.h new file mode 100644 index 00000000000..d7ff4ee91fe --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/tile/tile_stride_binary.h @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template +struct TileStrideBinary { + using ArchTag = ArchTag_; + using ElementCompute = ElementCompute_; + using TileShape = TileShape_; + static constexpr int64_t DST_STRIDE = DST_STRIDE_; + static constexpr int64_t SRC0_STRIDE = SRC0_STRIDE_; + static constexpr int64_t SRC1_STRIDE = SRC1_STRIDE_; + + static constexpr uint32_t MAX_REPEAT_TIMES = 255; + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementCompute); + + static constexpr uint32_t DST_BLK_NUM_PER_COLUMN = DST_STRIDE / ELE_NUM_PER_BLK; + static constexpr uint32_t SRC0_BLK_NUM_PER_COLUMN = SRC0_STRIDE / ELE_NUM_PER_BLK; + static constexpr uint32_t SRC1_BLK_NUM_PER_COLUMN = SRC1_STRIDE / ELE_NUM_PER_BLK; + + static constexpr uint32_t ROW_NUM_PER_COMPUTE = MAX_REPEAT_TIMES; + static constexpr uint32_t COL_NUM_PER_COMPUTE = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + + CATLASS_DEVICE + TileStrideBinary() + { + repeatParams.dstBlkStride = 1; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = DST_BLK_NUM_PER_COLUMN; + repeatParams.src0RepStride = SRC0_BLK_NUM_PER_COLUMN; + repeatParams.src1RepStride = SRC1_BLK_NUM_PER_COLUMN; + } + + AscendC::BinaryRepeatParams repeatParams; +}; + +template +struct TileStrideMul + : TileStrideBinary { + using Base = TileStrideBinary; + + CATLASS_DEVICE + TileStrideMul() : Base() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc0, + AscendC::LocalTensor const &ubSrc1) + { + for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) { + uint32_t residueM = Base::TileShape::ROW - rowOffset; + uint8_t repeatTimes = + static_cast((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM); + for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) { + uint32_t residueN = Base::TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN; + AscendC::Mul(ubDst[rowOffset * Base::DST_STRIDE + colOffset], + ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset], + ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams); + } + } + } +}; + +template +struct TileStrideDiv + : TileStrideBinary { + using Base = TileStrideBinary; + + CATLASS_DEVICE + TileStrideDiv() : Base() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc0, + AscendC::LocalTensor const &ubSrc1) + { + for (uint32_t rowOffset = 0; rowOffset < Base::TileShape::ROW; rowOffset += Base::ROW_NUM_PER_COMPUTE) { + uint32_t residueM = Base::TileShape::ROW - rowOffset; + uint8_t repeatTimes = + static_cast((residueM > Base::ROW_NUM_PER_COMPUTE) ? Base::ROW_NUM_PER_COMPUTE : residueM); + for (uint32_t colOffset = 0; colOffset < Base::TileShape::COLUMN; colOffset += Base::COL_NUM_PER_COMPUTE) { + uint32_t residueN = Base::TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > Base::COL_NUM_PER_COMPUTE) ? Base::COL_NUM_PER_COMPUTE : residueN; + AscendC::Div(ubDst[rowOffset * Base::DST_STRIDE + colOffset], + ubSrc0[rowOffset * Base::SRC0_STRIDE + colOffset], + ubSrc1[rowOffset * Base::SRC1_STRIDE + colOffset], mask, repeatTimes, this->repeatParams); + } + } + } +}; + +} // namespace Catlass::Epilogue::Tile diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/tile/tile_stride_muls.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/tile/tile_stride_muls.h new file mode 100644 index 00000000000..7093ef12c45 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/tile/tile_stride_muls.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/catlass.hpp" + +namespace Catlass::Epilogue::Tile { + +template +struct TileStrideMuls { + using ArchTag = ArchTag_; + using ElementCompute = ElementCompute_; + using TileShape = TileShape_; + using DstTileShape = DstTileShape_; + using SrcTileShape = SrcTileShape_; + + static_assert(DstTileShape::ROW == SrcTileShape::ROW && DstTileShape::ROW == TileShape::ROW, "Error"); + + CATLASS_DEVICE + TileStrideMuls() {} + + CATLASS_DEVICE + void operator()(AscendC::LocalTensor const &ubDst, + AscendC::LocalTensor const &ubSrc, ElementCompute scalar) + { + constexpr uint32_t maxRepeatTimes = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t dstBlkNumPerColumn = DstTileShape::COLUMN / eleNumPerBlk; + constexpr uint32_t srcBlkNumPerColumn = SrcTileShape::COLUMN / eleNumPerBlk; + AscendC::UnaryRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.srcBlkStride = 1; + repeatParams.dstRepStride = dstBlkNumPerColumn; + repeatParams.srcRepStride = srcBlkNumPerColumn; + + constexpr uint32_t rowNumPerCompute = maxRepeatTimes; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint8_t repeatTimes = static_cast((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN; + AscendC::Muls(ubDst[rowOffset * DstTileShape::COLUMN + colOffset], + ubSrc[rowOffset * SrcTileShape::COLUMN + colOffset], scalar, mask, repeatTimes, + repeatParams); + } + } + } +}; + +} // namespace Catlass::Epilogue::Tile diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/block/block_mmad.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/block/block_mmad.h new file mode 100644 index 00000000000..bd2ce09d3dd --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/block/block_mmad.h @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/gemm/block/block_mmad.hpp" + +#include "block_mmad_preload_async_with_callback_resident_a.h" diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/block/block_mmad_preload_async_with_callback_resident_a.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/block/block_mmad_preload_async_with_callback_resident_a.h new file mode 100644 index 00000000000..87612763eba --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/block/block_mmad_preload_async_with_callback_resident_a.h @@ -0,0 +1,420 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/gemm/dispatch_policy.hpp" +#include "catlass/gemm/helper.hpp" + +namespace Catlass::Gemm::Block { + +template +struct BlockMmad< + MmadAtlasA2PreloadAsyncWithCallbackResidentA, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = + MmadAtlasA2PreloadAsyncWithCallbackResidentA; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1A_STAGES = DispatchPolicy::L1A_STAGES; + static constexpr uint32_t L1B_STAGES = DispatchPolicy::L1B_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert(L1A_TILE_SIZE * L1A_STAGES + L1B_TILE_SIZE * L1B_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + CATLASS_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + CATLASS_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1A_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + } + for (uint32_t i = 0; i < L1B_STAGES; ++i) { + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe, + Callback const &callbackAfterFixpipe) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + bool useResidentA = + (kTileCount == L1A_STAGES) && (!isFirstLoad) && (gmBlockA.GetPhyAddr() == lastGmBlockA.GetPhyAddr()); + isFirstLoad = false; + lastGmBlockA = gmBlockA; + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = + (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1AListId]); + if (!useResidentA) { + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1AListId], gmTileA, L1A_LAYOUT, layoutTileA); + } + AscendC::SetFlag(l1AEventList[l1AListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1BListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1BListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1BListId]); + + // If the number of preload instructions reaches the upper limit, perform an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? (l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1AListId = l1AListId; + l1TileMmadParams.l1BListId = l1BListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; + l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1AListId = (l1AListId + 1 < L1A_STAGES) ? (l1AListId + 1) : 0; + l1BListId = (l1BListId + 1 < L1B_STAGES) ? (l1BListId + 1) : 0; + } + } + + CATLASS_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + --preloadCount; + } + } + +private: + struct L1TileMmadParams { + uint32_t l1AListId; + uint32_t l1BListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callbackBeforeFixpipe; + Callback callbackAfterFixpipe; + + CATLASS_DEVICE + L1TileMmadParams() = default; + }; + + CATLASS_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + for (uint32_t i = 0; i < L1A_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1AEventList[i] = i; + AscendC::SetFlag(l1AEventList[i]); + } + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1A_STAGES; + for (uint32_t i = 0; i < L1B_STAGES; ++i) { + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1BEventList[i] = i + L1A_STAGES; + AscendC::SetFlag(l1BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + CATLASS_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + CATLASS_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1AListId]; + auto &l1BTensor = l1BTensorList[params.l1BListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1AListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1AListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1BListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1BListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && + (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + + params.callbackBeforeFixpipe(); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0; + + params.callbackAfterFixpipe(); + } + } + + AscendC::LocalTensor l1ATensorList[L1A_STAGES]; + AscendC::LocalTensor l1BTensorList[L1B_STAGES]; + int32_t l1AEventList[L1A_STAGES]; + int32_t l1BEventList[L1B_STAGES]; + uint32_t l1AListId{0}; + uint32_t l1BListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; + + bool isFirstLoad{true}; + AscendC::GlobalTensor lastGmBlockA; +}; + +} // namespace Catlass::Gemm::Block diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/dispatch_policy.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/dispatch_policy.h new file mode 100644 index 00000000000..40522a0b78a --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/dispatch_policy.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/gemm/dispatch_policy.hpp" + +namespace Catlass::Gemm { + +template +struct MmadAtlasA2PreloadAsyncWithCallbackResidentA : public MmadAtlasA2Async { + static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // Stages of emitting load instruction in advance + static constexpr uint32_t L1A_STAGES = L1A_STAGES_; + static constexpr uint32_t L1B_STAGES = L1B_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +} // namespace Catlass::Gemm diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h new file mode 100644 index 00000000000..22cfe2b1cc6 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP +#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP + +#include "../../raw_distributed/cam_moe_distribute_combine.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + void *combiner; + + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_, + void *combiner_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_), + combiner(combiner_) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current + // groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine *)params.combiner; + { + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); + } + } + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, + layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + icache_preload(4); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->AllToAllSend(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->ReducePermute(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->Process(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h new file mode 100644 index 00000000000..2ace6ed826b --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -0,0 +1,1998 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once + +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_swizzle.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" + +#include "../../../dispatch_gmm_combine_decode_base.h" + +constexpr uint32_t STATE_OFFSET = 512; +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024; +constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024; +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024; +constexpr uint32_t UB_ALIGN = 32; +constexpr uint32_t TOKEN_EXTRA_SPACE = 512; +constexpr uint32_t INT32_COUNT_PER_BLOCK = 8; +constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512; +constexpr uint32_t COMP_AIV_CORE_NUM = 24; // 24 AIV 做deq-swiglu计算,当前不支持自己调整 +constexpr uint32_t SEND_AIV_CORE_NUM = 48; // 单卡单专家时全部核发送/接收,多专家时砍半 +constexpr uint32_t RECV_AIV_CORE_NUM = 48; // 单卡单专家时全部核发送/接收,多专家时砍半 +constexpr int64_t LOOP_TMP_SIZE = 4096; // 计算地址偏移优化使用空间 +constexpr int32_t SUB_AIV_NUM = 2; // 1C配2V,即1个cube搭配两个vector +constexpr int32_t ODD_EVEN_BASE = 2; // 判断奇偶的基数 +constexpr int32_t BUFFER_NUM = 2; +constexpr int32_t GATHER_SECOND_NUM = 2; +constexpr uint32_t MAX_QUANT_ROW_ONCE = 8; +constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11; // 量化使用UB不超过176KB +#define OPT_RANK_OFFSET 512 + +#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN) +#define CEIL(x, y) (((x) + (y - 1)) / (y)) +#define UB_BLOCK_SIZE (32) +#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \ + (((epRankId == rankId) \ + ? ((GM_ADDR)(winContext_->localWindowsExp)) \ + : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \ + dataState * WIN_STATE_OFFSET) +#define GET_WIND_ADDR_BY_RANK_ID(rankId) \ + (((epRankId == rankId) \ + ? ((GM_ADDR)(winContext_->localWindowsIn)) \ + : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \ + winDataSizeOffset + rankId * OPT_RANK_OFFSET) +#define TOKEN_FLAG_1 (0x55555555) +#define TOKEN_FLAG_2 (0x33333333) +#define V_TO_C_FLAG_1 (0x03030303) +#define V_TO_C_FLAG_2 (0x05050505) +#define AIC_STATE_SPACE_IDNEX (48) +#define AIV_STATE_SPACE_IDNEX (72) +#define CV_FLAG_INDEX 0 +#define GROUP_ID_INDEX 1 +#define PRE_COUNT_INDEX 2 +#define SELF_COUNT_INDEX 3 +#define TOTAL_COUNT_INDEX 4 +#define GROUP_TOKEN_COUNT 3 // 等于SELF_COUNT_INDEX +#define GROUP_INFO_SIZE 32 + +namespace Catlass::Gemm::Kernel { + +template +class BlockQuant +{ +public: + using ElementInput = float; + using LayoutInput = layout::RowMajor; + using ElementDequantScale = float; + using LayoutDequantScale = layout::VectorLayout; + using ElementOutput = int8_t; + using LayoutOutput = layout::RowMajor; + + using InputType = GemmType; + using DequantScaleType = GemmType; + using OutputType = GemmType; + + using EpilogueTileSwizzle = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + struct Params { + __gm__ ElementInput *ptrInput{nullptr}; + LayoutInput layoutInput; + __gm__ ElementDequantScale *ptrDequantScale{nullptr}; + LayoutDequantScale layoutDequantScale; + __gm__ ElementOutput *ptrOutput{nullptr}; + LayoutOutput layoutOutput; + uint32_t tileRow; + uint32_t tileColumn; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementInput *ptrInput_, LayoutInput const &layoutInput_, + __gm__ ElementDequantScale *ptrQuantScale_, LayoutDequantScale const &layoutQuantScale_, + __gm__ ElementOutput *ptrOutput_, LayoutOutput const layoutOutput_, const uint32_t tileRow_, + const uint32_t tileColumn_) + : ptrInput(ptrInput_), + layoutInput(layoutInput_), + ptrDequantScale(ptrQuantScale_), + layoutDequantScale(layoutQuantScale_), + ptrOutput(ptrOutput_), + layoutOutput(layoutOutput_), + tileRow(tileRow_), + tileColumn(tileColumn_) + {} + }; + + CATLASS_DEVICE + BlockQuant(Arch::Resource const &resource, Params const ¶ms_) : params(params_) + { + int64_t ubOffset = 0; + tileRow = params_.tileRow; + tileColumn = params_.tileColumn; + tileCount = tileRow * tileColumn; + halfTileColumn = tileColumn / 2; + halfTileCount = tileRow * halfTileColumn; + + ubInput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementInput); + ubDequantScale = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tileRow * sizeof(ElementDequantScale)); + ubOutput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementOutput); + + ubAbs = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(float); + ubMax = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += halfTileCount * sizeof(float); + ubReduceMax = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tileRow * sizeof(float)); + ubQuantScale = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tileRow * sizeof(float)); + ubInputTmp = ubAbs; + ubQuantF32 = ubAbs; + ubQuantS32 = ubAbs.ReinterpretCast(); + ubQuantF16 = ubAbs.ReinterpretCast(); + + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } + + CATLASS_DEVICE + ~BlockQuant() + { + AscendC::WaitFlag(0); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + + CATLASS_DEVICE + void operator()(MatrixCoord const &blockShape, MatrixCoord const &blockCoord, MatrixCoord const &actualBlockShape) + { + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmInput; + gmInput.SetGlobalBuffer(params.ptrInput); + AscendC::GlobalTensor gmDequantScale; + gmDequantScale.SetGlobalBuffer(params.ptrDequantScale); + AscendC::GlobalTensor gmOutput; + gmOutput.SetGlobalBuffer(params.ptrOutput); + + auto ubTileStride = MakeCoord(static_cast(tileColumn), 1L); + auto ubHalfTileStride = MakeCoord(static_cast(halfTileColumn), 1L); + auto tileShape = MakeCoord(tileRow, tileColumn); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileInput = gmInput[params.layoutInput.GetOffset(tileOffset)]; + auto layoutGmTileInput = params.layoutInput.GetTileLayout(actualTileShape); + + layout::RowMajor layoutUbInput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(0); + copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + AscendC::Abs(ubAbs, ubInput, tileCount); + AscendC::PipeBarrier(); + + for (uint32_t rowIdx = 0; rowIdx < tileRow; ++rowIdx) { + AscendC::Max(ubMax[rowIdx * halfTileColumn], ubAbs[rowIdx * tileColumn], + ubAbs[rowIdx * tileColumn + halfTileColumn], halfTileColumn); + } + + AscendC::PipeBarrier(); + AscendC::Muls(ubInputTmp, ubInput, 127.f, tileCount); + + constexpr uint32_t elementPerBlk = BYTE_PER_BLK / sizeof(float); + constexpr int32_t mask = 64; + + AscendC::BinaryRepeatParams maxParams; + maxParams.dstBlkStride = halfTileColumn / elementPerBlk; + maxParams.src0BlkStride = halfTileColumn / elementPerBlk; + maxParams.src1BlkStride = halfTileColumn / elementPerBlk; + maxParams.dstRepStride = 1; + maxParams.src0RepStride = 1; + maxParams.src1RepStride = 1; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(float); + uint32_t reduceWidth = halfTileColumn; + while (reduceWidth > (BLK_NUM_PER_VECTOR_FRACTAL * BYTE_PER_BLK / sizeof(float))) { + reduceWidth >>= 1; + AscendC::Max(ubMax, ubMax, ubMax[reduceWidth], mask, reduceWidth / elementPerBlk, maxParams); + AscendC::PipeBarrier(); + } + + AscendC::WholeReduceMax(ubReduceMax, ubMax, mask, tileRow, 1, 1, halfTileColumn / elementPerBlk, + AscendC::ReduceOrder::ORDER_ONLY_VALUE); + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(0); + AscendC::Muls(ubDequantScale, ubReduceMax, 1.0f / 127.0f, tileRow); + AscendC::SetFlag(0); + + auto dequantScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto dequantScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTileDequantScale = gmDequantScale[params.layoutDequantScale.GetOffset(dequantScaleTileOffset)]; + auto layoutGmTileDequantScale = params.layoutDequantScale.GetTileLayout(dequantScaleTileShape); + + auto layoutUbDequantScale = + LayoutDequantScale::template MakeLayoutInUb(dequantScaleTileShape); + + AscendC::WaitFlag(0); + copyUbToGmDequantScale(gmTileDequantScale, ubDequantScale, layoutGmTileDequantScale, layoutUbDequantScale); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + for (uint32_t rowIdx = 0; rowIdx < tileRow; ++rowIdx) { + AscendC::Muls(ubQuantF32[rowIdx * tileColumn], ubInputTmp[rowIdx * tileColumn], + 1.f / ubReduceMax.GetValue(rowIdx), tileColumn); + } + + AscendC::PipeBarrier(); + AscendC::Cast(ubQuantS32, ubQuantF32, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::PipeBarrier(); + AscendC::SetDeqScale(static_cast(1.0)); + AscendC::Cast(ubQuantF16, ubQuantS32, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::PipeBarrier(); + + AscendC::WaitFlag(1); + AscendC::Cast(ubOutput, ubQuantF16, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::SetFlag(1); + + auto gmTileOutput = gmOutput[params.layoutOutput.GetOffset(tileOffset)]; + auto layoutGmTileOutput = params.layoutOutput.GetTileLayout(actualTileShape); + + LayoutOutput layoutUbOutput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(1); + copyUbToGmOutput(gmTileOutput, ubOutput, layoutGmTileOutput, layoutUbOutput); + AscendC::SetFlag(1); + } + } + +private: + Params params; + uint32_t tileRow; + uint32_t tileColumn; + uint32_t tileCount; + uint32_t halfTileColumn; + uint32_t halfTileCount; + + AscendC::LocalTensor ubInput; + AscendC::LocalTensor ubDequantScale; + AscendC::LocalTensor ubOutput; + + AscendC::LocalTensor ubAbs; + AscendC::LocalTensor ubMax; + AscendC::LocalTensor ubReduceMax; + AscendC::LocalTensor ubQuantScale; + AscendC::LocalTensor ubQuantScaleBrcb; + AscendC::LocalTensor ubInputTmp; + AscendC::LocalTensor ubQuantF32; + AscendC::LocalTensor ubQuantS32; + AscendC::LocalTensor ubQuantF16; + + Epilogue::Tile::CopyGm2Ub copyGmToUbInput; + Epilogue::Tile::CopyUb2Gm copyUbToGmDequantScale; + Epilogue::Tile::CopyUb2Gm copyUbToGmOutput; +}; + +__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx) +{ + // flag++,类似set flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + global.SetValue(0, value + 1); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + AscendC::PipeBarrier(); +} + +__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target) +{ + // 查看flag,类似wait flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + if (value >= target) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + break; + } + } + AscendC::PipeBarrier(); +} + +__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row) +{ + row = QUANT_SPACE_FACTOR / column; + row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE; +} + +template +class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementDequantScale = typename BlockQuant::ElementDequantScale; + using LayoutDequantScale = typename BlockQuant::LayoutDequantScale; + using ElementOutput = typename BlockQuant::ElementOutput; + using LayoutOutput = typename BlockQuant::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + using XType = XType_; + + // Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementDequantScale *ptrDequantScale; + LayoutDequantScale layoutDequantScale; + GM_ADDR ptrWorkspace; + GM_ADDR gmX; + GM_ADDR debugGm; + GM_ADDR gmexpertIds; + + GM_ADDR gmExpandIdx; + GM_ADDR gmEpSendCount; + GM_ADDR gmResvered; + GM_ADDR gmOutputRecvCount; + + uint32_t epRankSize; + uint32_t epRankId; + uint32_t moeExpertNum; + uint32_t moeExpertNumPerRank; + uint32_t sharedExpertNum; + uint32_t sharedExpertRankNum; + uint32_t quantMode; + uint32_t globalBs; + uint32_t bs; + uint32_t topK; + uint32_t tokenLen; + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_, + GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, + GM_ADDR gmResvered_, GM_ADDR gmOutputRecvCount_, uint32_t epRankSize_, uint32_t epRankId_, + uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, + uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, + uint32_t h) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrDequantScale(reinterpret_cast<__gm__ ElementDequantScale *>(ptrDequantScale_)), + layoutDequantScale(layoutDequantScale_), + ptrWorkspace(ptrWorkspace_), + gmX(gmX_), + debugGm(debugGm_), + gmexpertIds(gmexpertIds_), + gmExpandIdx(gmExpandIdx_), + gmEpSendCount(gmEpSendCount_), + gmOutputRecvCount(gmOutputRecvCount_), + gmResvered(gmResvered_), + epRankSize(epRankSize_), + epRankId(epRankId_), + moeExpertNum(moeExpertNum_), + moeExpertNumPerRank(moeExpertNumPerRank_), + sharedExpertNum(sharedExpertNum_), + sharedExpertRankNum(sharedExpertRankNum_), + quantMode(quantMode_), + globalBs(globalBs_), + bs(bs_), + topK(topK_), + tokenLen(h) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + aicIdx = AscendC::GetBlockIdx(); + subBlockNum = AscendC::GetSubBlockNum(); + aiCoreGroupNum = AscendC::GetBlockNum(); + aicNum = aiCoreGroupNum; + aicStateGlobalCoreIdx = AIC_STATE_SPACE_IDNEX + aicIdx; + moeExpertNumPerRank = params.moeExpertNumPerRank; + isShareExpert = (params.epRankId < params.sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + // 单卡单专家48发48收 + recvCoreNum = RECV_AIV_CORE_NUM; + // 单卡多专家24收24发 + if (localExpertNum > 1) { + recvCoreNum = RECV_AIV_CORE_NUM / SUB_AIV_NUM; + } + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; // 与V接收分配逻辑保持一致 + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + // 更新状态,影响CV交互使用的信号值 + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aicStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * aicNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + AscendC::GlobalTensor groupTokenNumStateTensor; + aicSetFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(aicNum + AscendC::GetBlockIdx())}; // AIV等待的信息在24~48 + uint32_t target = 1; + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + // 等待AIV的token收齐信号后,再往下走 + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((aicIdx < startCoreIdx) ? (aicIdx + aicNum) : aicIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + aicWaitFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(AscendC::GetBlockIdx()), + target}; // AIC等待的信号在前24个 + target += 1; + callbackBeforeFixpipe = MakeCallback(&aicWaitFunc1); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFunc1); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * aicNum + aicIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % aicNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + target += 1; // 追平AIV多余的软同步 + --stageUsed; + } + AscendC::SyncAll(); + } + + CATLASS_DEVICE + void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset) + { + // 使用AIV计算发送到对端的偏移量 + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor dstExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor subExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor workLocalTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::Duplicate(dstExpIdTensor_, dstExpertId, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Sub(subExpIdTensor_, expertIdsTensor_, dstExpIdTensor_, tokenIndex); + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpFp32 = subExpIdTensor_.ReinterpretCast(); + AscendC::LocalTensor tmpoutFp32 = dstExpIdTensor_.ReinterpretCast(); + AscendC::Abs(tmpoutFp32, tmpFp32, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Mins(subExpIdTensor_, dstExpIdTensor_, 1, tokenIndex); + AscendC::PipeBarrier(); + AscendC::ReduceSum(tmpoutFp32, tmpFp32, workLocalTensor_, tokenIndex); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + int32_t curOtherExpertCnt = dstExpIdTensor_(0); + if (tokenIndex > curOtherExpertCnt) { + curExpertCnt = tokenIndex - curOtherExpertCnt; + } + } + + CATLASS_DEVICE + void CalAndSendTokenCount() + { + // 计算发送token的数量,并且发送出去 + uint32_t totalExpertNum = sharedExpertRankNum + moeExpertNum; + uint32_t sendCountExpertNum = totalExpertNum / sendCoreNum; // 每个aiv需要处理的专家数 + uint32_t remainderRankNum = totalExpertNum % sendCoreNum; + uint32_t startExpertId = sendCountExpertNum * sendCoreIdx; // sharedExpertRankNum, 每个aiv发送的起始rankid + if (sendCoreIdx < remainderRankNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendCountExpertNum += 1; + startExpertId += sendCoreIdx; + } else { + startExpertId += remainderRankNum; + } + uint32_t endExpertId = startExpertId + sendCountExpertNum; + if (startExpertId >= totalExpertNum) { + return; + } + // 计算count及偏移 + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(CEIL(expertCntUp, INT32_COUNT_PER_BLOCK) * INT32_COUNT_PER_BLOCK * UB_BLOCK_SIZE); + AscendC::Duplicate(statusTensor_, (int32_t)0, + expertCntUp * INT32_COUNT_PER_BLOCK); // 先清零再赋值 + if (state == 0) { + // 一次性操作256字节,也是64个int32_t,每8个数将首个设置为0x3F800000,即浮点数的1.0 + uint64_t mask[2] = {0x101010101010101, 0}; + AscendC::PipeBarrier(); + AscendC::Duplicate(statusTensor_, 0x3F800000, mask, CEIL(expertCntUp, 8), 1, 8); + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + if (!isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + for (uint32_t curExpertId = startExpertId; curExpertId < endExpertId; ++curExpertId) { + if (curExpertId < sharedExpertRankNum) { + continue; + } + int32_t curExpertCnt = 0; + int32_t dstExpertId = curExpertId - sharedExpertRankNum; + CalExpandxIdx(dstExpertId, expertIdsCnt, curExpertCnt, ubOffset); + int32_t cntPosIndex = curExpertId * INT32_COUNT_PER_BLOCK + 1; + statusTensor_(cntPosIndex) = curExpertCnt; + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::GlobalTensor rankGMTensor; + uint32_t offset = stateOffset * epRankId; + for (uint32_t rankIndex = startExpertId; rankIndex < endExpertId; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank > 1 && (rankIndex >= sharedExpertRankNum)) { + dstRankId = ((rankIndex - sharedExpertRankNum) / moeExpertNumPerRank + sharedExpertRankNum); + offset = + (epRankId + (rankIndex - sharedExpertRankNum) % moeExpertNumPerRank * epRankSize) * stateOffset; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_STATE_ADDR_BY_RANK_ID(dstRankId) + offset); // 计算地址偏移 + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + AscendC::DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); + } + } + + CATLASS_DEVICE + void QuantToken(AscendC::LocalTensor &xInTensor, AscendC::LocalTensor &yInt8Tensor, int64_t ubOffset) + { + // 量化token + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor xFp32TmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(tokenLength * sizeof(float)); + AscendC::LocalTensor xFp32AbsTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(tokenLength * sizeof(float)); + AscendC::LocalTensor xRowMaxTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor ytmpInt32Tensor = xFp32TmpTensor.template ReinterpretCast(); + AscendC::LocalTensor yHalfTensor = xFp32TmpTensor.template ReinterpretCast(); + AscendC::LocalTensor yFp32Tensor = yInt8Tensor.template ReinterpretCast(); + AscendC::LocalTensor yInt32Tensor = yInt8Tensor.template ReinterpretCast(); + + AscendC::Cast(xFp32TmpTensor, xInTensor, AscendC::RoundMode::CAST_NONE, tokenLength); + AscendC::PipeBarrier(); + AscendC::Abs(xFp32AbsTensor, xFp32TmpTensor, tokenLength); + AscendC::PipeBarrier(); + AscendC::ReduceMax(xRowMaxTensor, xFp32AbsTensor, xFp32AbsTensor, tokenLength, false); + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + float dynamicQuantScale = float(127.0) / xRowMaxTensor.GetValue(0); + yFp32Tensor.SetValue(tokenLength / sizeof(float), float(1.0) / dynamicQuantScale); + yInt32Tensor.SetValue(tokenLength / sizeof(int32_t) + 1, tokenFlag); + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::Muls(xFp32TmpTensor, xFp32TmpTensor, dynamicQuantScale, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(ytmpInt32Tensor, xFp32TmpTensor, AscendC::RoundMode::CAST_RINT, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(yHalfTensor, ytmpInt32Tensor, AscendC::RoundMode::CAST_ROUND, tokenLength); + AscendC::PipeBarrier(); + AscendC::Cast(yInt8Tensor, yHalfTensor, AscendC::RoundMode::CAST_TRUNC, tokenLength); + } + + CATLASS_DEVICE + void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale) + { + // 给共享专家发送token + uint32_t newAivId = sendCoreIdx - sendToMoeAivNum; + uint32_t sendTokenNum = axisBS / sendToShareAivNum; // 每个aiv需要发送的token数 + uint32_t remainderTokenNum = axisBS % sendToShareAivNum; + uint32_t startTokenId = sendTokenNum * newAivId; // 每个aiv发送时的起始rankid + if (newAivId < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += newAivId; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + if (startTokenId >= axisBS) { + return; + } + + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor yInt8Tensor[BUFFER_NUM]; + AscendC::LocalTensor yFp32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; // token输入 + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + yInt8Tensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + yFp32Tensor[0] = yInt8Tensor[0].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + yInt8Tensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + yFp32Tensor[1] = yInt8Tensor[1].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::GlobalTensor dstWinGMTensor; // token输出 + AscendC::GlobalTensor expandXOutGlobal; + expandXOutGlobal.SetGlobalBuffer((__gm__ int8_t *)(gmX1)); + AscendC::GlobalTensor dynamicScalesOutGMTensor_; + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)(gmX1Scale)); + + // 输入输出开double buffer + AscendC::SetFlag(0); // MTE2等MTE3 + AscendC::SetFlag(1); // MTE2等MTE3 + AscendC::SetFlag(0); + AscendC::SetFlag(1); + + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + uint32_t index = (tokenIndex & 1) ? 0 : 1; + int32_t eventId = (tokenIndex & 1) ? 0 : 1; + // 下面的计算有点绕,目的是计算目的专家卡和偏移 + uint32_t temp = (epRankId * axisBS) / sharedExpertRankNum; + // 当前token发给哪个共享专家 + uint32_t moeOnShareRank = CEIL((tokenIndex + 1 + temp) * sharedExpertRankNum, axisBS) - 1 - epRankId; + // 发给该共享专家已经有多少token数据 + uint32_t preCnt = (moeOnShareRank + epRankId) * axisBS / sharedExpertRankNum - + epRankId * axisBS / sharedExpertRankNum; + dstWinGMTensor.SetGlobalBuffer( + (__gm__ int8_t *)(GET_WIND_ADDR_BY_RANK_ID(moeOnShareRank) + expertPerSizeOnWin * epRankId)); + + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + QuantToken(xInTensor[index], yInt8Tensor[index], ubOffset); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(0); + + AscendC::WaitFlag(eventId); + if (isShareExpert) { + AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + AscendC::DataCopy(expandXOutGlobal[tokenIndex * tokenLength], yInt8Tensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopyPad(dynamicScalesOutGMTensor_[tokenIndex], + yFp32Tensor[index][tokenLength / sizeof(float)], dataCopyParamsFloat); + } else { + // 可能有时序问题,所以分开发送 + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu], yInt8Tensor[index], + tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu + tokenLength], + yInt8Tensor[index][tokenLength], scaleParamPad); + } + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + + CATLASS_DEVICE + void SendToMoeExprt(GM_ADDR gmX, GM_ADDR gmExpandIdx) + { + // 给路由专家发送token + uint32_t sendTokenNum = expertIdsCnt / sendToMoeAivNum; + uint32_t remainderTokenNum = expertIdsCnt % sendToMoeAivNum; + uint32_t startTokenId = sendTokenNum * sendCoreIdx; + if (sendCoreIdx < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += sendCoreIdx; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + if (startTokenId >= expertIdsCnt) { + return; + } + AscendC::LocalTensor expertCountTensor = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + AscendC::Duplicate(expertCountTensor, (int32_t)0, expertIdsCnt); // 清零 + AscendC::SetFlag(1); + AscendC::WaitFlag(1); + + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor yInt8Tensor[BUFFER_NUM]; + AscendC::LocalTensor yFp32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; // token输入 + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(tokenLength * sizeof(XType)); + yInt8Tensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + yInt8Tensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::GlobalTensor dstWinGMTensor; // token输出 + // 输入输出开double buffer + AscendC::SetFlag(0); // MTE2等MTE3 + AscendC::SetFlag(1); // MTE2等MTE3 + AscendC::SetFlag(0); + AscendC::SetFlag(1); + uint32_t sendValidTokenIndex = 0; + for (uint32_t sendGroupIndex = 0; sendGroupIndex < moeExpertNumPerRank; ++sendGroupIndex) { + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + int32_t dstExpertId = expertIdsTensor_(tokenIndex); + if (dstExpertId < 0) { + continue; + } + if ((dstExpertId % moeExpertNumPerRank) != sendGroupIndex) { // 优先发送指定专家的token + continue; + } + uint32_t index = (sendValidTokenIndex & 1) ? 0 : 1; + int32_t eventId = (sendValidTokenIndex & 1) ? 0 : 1; + sendValidTokenIndex += 1; + int32_t curExpertCnt = 0; + CalExpandxIdx(dstExpertId, tokenIndex, curExpertCnt, ubOffset); + expertCountTensor(tokenIndex - startTokenId) = curExpertCnt; + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank + sharedExpertRankNum; + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(tempRankId) + + (expertPerSizeOnWin * (epRankId * moeExpertNumPerRank + + dstExpertId % moeExpertNumPerRank)) + + hCommuSize * curExpertCnt); + dstWinGMTensor.SetGlobalBuffer((__gm__ int8_t *)rankGM); + + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex / axisK * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + QuantToken(xInTensor[index], yInt8Tensor[index], ubOffset); + AscendC::SetFlag(eventId); + + AscendC::WaitFlag(0); + AscendC::WaitFlag(eventId); + + // 担心有时序问题,所以分开发送 + AscendC::DataCopy(dstWinGMTensor, yInt8Tensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[tokenLength], yInt8Tensor[index][tokenLength], scaleParamPad); + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(0); // MTE2等MTE3 + AscendC::WaitFlag(1); // MTE2等MTE3 + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + + AscendC::GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)gmExpandIdx + startTokenId); + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, + 0U, 0U}; + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(expandIdxGMTensor, expertCountTensor, expertIdsCntParams); + } + + CATLASS_DEVICE void + SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx) + { + ubOffset = 0; + expertIdsCnt = axisBS * axisK; + + AscendC::GlobalTensor expertIdsGMTensor_; + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds); + expertIdsTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, + 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + CalAndSendTokenCount(); + AscendC::PipeBarrier(); + if (hasShareExpert) { + sendToShareAivNum = sendCoreNum / (axisK + 1); // 均等分,取整 + if (sendToShareAivNum == 0) { + sendToShareAivNum = 1; + } + } + sendToMoeAivNum = sendCoreNum - sendToShareAivNum; + + AscendC::SetDeqScale((half)1.000000e+00f); + if (hasShareExpert && sendCoreIdx >= sendToMoeAivNum) { + SendToShareExprt(gmX, gmX1, gmX1Scale); + } else { + SendToMoeExprt(gmX, gmExpandIdx); + } + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void RecvCount(int64_t ubOffset) + { + // 接收count数据 + uint32_t recStatusNumPerCore = isShareExpert ? epRankSize : expertCntUp; + uint32_t startStatusIndex = 0; // 目前每个核都要收集所有的count + + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + + AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor sumTmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE); + gatherTmpTensor.SetValue(0, 1); + + uint32_t mask = 1; // gatherMask + sum 相关参数 + uint64_t rsvdCnt = 0; + AscendC::SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore}; + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget * recStatusNumPerCore) - (float)0.5; + float maxTarget = (sumTarget * recStatusNumPerCore) + (float)0.5; + AscendC::DataCopyParams intriParams{static_cast(recStatusNumPerCore), 1, static_cast(15), + 0}; // srcStride为15个block + AscendC::GlobalTensor windowInstatusFp32Tensor_; + windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId)); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + uint32_t preRecvTokenCount = 0; + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + AscendC::DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset / sizeof(float)], + intriParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + AscendC::PipeBarrier(); + AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumTmpTensor, sumParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + CATLASS_DEVICE + void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset, GM_ADDR gmOutputRecvCount) + { + // 计算前缀和,目的是知道自己收到的token在output中的偏移 + int64_t subUbOffset = ubOffset; + uint32_t recStatusNumPerCore = isShareExpert ? epRankSize : expertCntUp; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + if (isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + uint64_t rsvdCnt = 0; + gatherTmpTensor.SetValue(0, GATHER_SECOND_NUM); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + if (isRecvCore && recvCoreIdx == 0) { + AscendC::GlobalTensor recvCountTensor; + recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount); + AscendC::DataCopyExtParams dataCopyParams = { + 1U, static_cast(localExpertNum * epRankSize * sizeof(int32_t)), 0U, 0U, 0U}; + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(recvCountTensor, gatherMaskOutTensor.ReinterpretCast(), dataCopyParams); + } + // 这里是为ReduceSum准备所需空间,本应该计算好需要多大空间,但当前是给偏移,且用完就释放,所以就不计算了 + AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + AscendC::PipeBarrier(); + AscendC::ReduceSum(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor, + (startRankId + 1) <= recvExpertNum ? (startRankId + 1) : recvExpertNum); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + } + + CATLASS_DEVICE + void RecvToken(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, uint32_t &coreTokenCount, uint32_t startRankId, + uint32_t endRankId, uint32_t recvRankNumPerCore, int64_t ubOffset) + { + // 接收token + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + + AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + AscendC::LocalTensor xTmpTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(axisHCommu * sizeof(int8_t)); + AscendC::LocalTensor xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutCountTensor = (gatherMaskOutTensor.template ReinterpretCast()); + AscendC::GlobalTensor tokGlobal; + AscendC::GlobalTensor tokGlobalInt32; + AscendC::GlobalTensor expandXOutGlobal; + AscendC::GlobalTensor dynamicScalesOutGMTensor_; + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)(gmX1Scale)); + uint32_t beginIdx = 0; + for (uint32_t index = startRankId; index < endRankId; index++) { + uint32_t i = index - startRankId; + if (i > 0) { + gatherMaskOutCountTensor.SetValue( + i, gatherMaskOutCountTensor.GetValue(i - 1) + gatherMaskOutCountTensor.GetValue(index)); + } + uint32_t count = statusTensor_.GetValue(index * INT32_COUNT_PER_BLOCK + 1); + coreTokenCount += count; + beginIdx = gatherMaskOutCountTensor.GetValue(i) - count; + if (isShareExpert && index < sharedExpertRankNum) { + beginIdx += count; + continue; + } + uint32_t winOffset = index; + if (!isShareExpert && moeExpertNumPerRank > 1) { + // count的空间排布,与token数据的空间排布不同,需要转换成数据区的排布偏移 + // srcRank: index % epRankSize + // localExpertId: index / epRankSize + // Addr: (srcRank * moeExpertNumPerRank + localExpertId) * expertPerSizeOnWin + winOffset = (index % epRankSize) * moeExpertNumPerRank + index / epRankSize; + } + GM_ADDR wAddr = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(epRankId)) + winOffset * expertPerSizeOnWin; + AscendC::SetFlag(0); + for (uint32_t j = 0; j < count; j++) { + tokGlobal.SetGlobalBuffer((__gm__ int8_t *)(wAddr + j * hCommuSize)); + tokGlobalInt32.SetGlobalBuffer((__gm__ int32_t *)(wAddr + j * hCommuSize + hOutSize)); + expandXOutGlobal.SetGlobalBuffer((__gm__ int8_t *)(gmX1) + (beginIdx + j) * tokenLength, tokenLength); + + while (true) { + AscendC::DataCopy(tmpLocalTensor, tokGlobalInt32, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + if (tmpLocalTensor.GetValue(1) == tokenFlag) { + tokGlobalInt32.SetValue(1, 0); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(tokGlobalInt32[1]); + __asm__ __volatile__(""); + break; + } + } + AscendC::PipeBarrier(); + + AscendC::WaitFlag(0); + AscendC::DataCopy(xTmpTensor_, tokGlobal, axisHCommu); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(dynamicScalesOutGMTensor_[beginIdx + j], xOutFp32Tensor_[tokenLength / sizeof(float)], + dataCopyParamsFloat); + AscendC::DataCopy(expandXOutGlobal, xTmpTensor_, tokenLength); + AscendC::SetFlag(0); + } + AscendC::WaitFlag(0); + beginIdx += count; + } + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyExtParams dataCopyOutParams = {1U, static_cast(recvRankNumPerCore * sizeof(int32_t)), 0U, + 0U, 0U}; + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + AscendC::DataCopyPad(sendCountsGlobal[startRankId], gatherMaskOutCountTensor, dataCopyOutParams); + } + + CATLASS_DEVICE + void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, GM_ADDR gmOutputRecvCount) + { + ubOffset = 0; + RecvCount(ubOffset); + + // 先按本地专家分核,再在专家内进一步分核 + uint32_t recvExpertNum = isShareExpert ? epRankSize : expertCntUp; + uint32_t recvCoreNumPerGroup = recvCoreNum / localExpertNum; // 每个group由若干核处理,下取整,可能有空闲核 + uint32_t recvRankNumPerCore = epRankSize / recvCoreNumPerGroup; // 每个核处理的rank数量 + uint32_t remainderRankNum = epRankSize % recvCoreNumPerGroup; + + uint32_t groupId = recvCoreIdx / recvCoreNumPerGroup; // 当前核处理的是哪个group + uint32_t recvCoreIdxInGroup = recvCoreIdx % recvCoreNumPerGroup; // 当前核处理的是group中第几个 + uint32_t startRankIdInGroup = recvRankNumPerCore * recvCoreIdxInGroup; // 当前核处理的起始rank + if (recvCoreIdxInGroup < remainderRankNum) { + recvRankNumPerCore += 1; + startRankIdInGroup += recvCoreIdxInGroup; + } else { + startRankIdInGroup += remainderRankNum; + } + uint32_t endRankIdInGroup = startRankIdInGroup + recvRankNumPerCore; + uint32_t startRankId = epRankSize * groupId + startRankIdInGroup; + uint32_t endRankId = epRankSize * groupId + endRankIdInGroup; + + uint32_t coreTokenCount = 0; + + if (startRankId < recvExpertNum) { + // 计算前缀和,以及接收token。这里有隐含约束,下面两个函数与RecvCount的ubOffset入参应保持一致,这样才能拿到有效数据 + GetCumSum(startRankId, recvExpertNum, ubOffset, gmOutputRecvCount); + RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset); + } + + // 接收完成,通过写GM告知C核和计算V核 + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(0); + ubOffset += CEIL_UP(UB_BLOCK_SIZE); + tmpLocalTensor.SetValue(CV_FLAG_INDEX, vToCFlag); + tmpLocalTensor.SetValue(GROUP_ID_INDEX, groupId); + tmpLocalTensor.SetValue(SELF_COUNT_INDEX, coreTokenCount); + AscendC::SetFlag(0); + + AscendC::GlobalTensor groupTokenNumStateTensor; + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET)); + AscendC::WaitFlag(0); + AscendC::SetAtomicAdd(); + // 用原子加,各个核收到的token数量加一起,就是专家收到的token数量 + AscendC::DataCopy(groupTokenNumStateTensor[groupId * GROUP_INFO_SIZE], tmpLocalTensor, INT32_COUNT_PER_BLOCK); + AscendC::SetAtomicNone(); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void CompCoreFunc(GM_ADDR gmCVSwapBuff, __gm__ ElementScale *gmScale, __gm__ ElementPerTokenScale *gmTokenScale, + __gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale, + LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput) + { + uint32_t nOut = n / 2; + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; // 与V接收分配逻辑保持一致 + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(gmCVSwapBuff)); + auto layoutC = layout::RowMajor{L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES, L1TileShape::N}; + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t target = 1; + uint32_t startCoreIdx = 0; + + AscendC::GlobalTensor groupTokenNumStateTensor; + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + // 流程与C核类似,等专家token数据,以及计算、软同步 + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, n, k}; + LayoutPerTokenScale layoutPerTokenScale = + wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = layoutOutput.GetTileLayout(MakeCoord(currentM, nOut)); + EpilogueParams epilogueParams{gmScale + gmGroupOffsetScale, + layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, + layoutD}; + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = + ((compCoreIdx < startCoreIdx) ? (compCoreIdx + aiCoreGroupNum) : compCoreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aiCoreGroupNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * aiCoreGroupNum + aiCoreGroupIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + CheckSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(COMP_AIV_CORE_NUM + compCoreIdx), target); // AIV等待的信号在24~48 + target += 1; + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + EncreaseSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(compCoreIdx)); + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * nOut; + + startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum; + } + } + // 清理软同步残留信息,避免影响别处或者下次运行 + AscendC::PipeBarrier(); + AscendC::GlobalTensor softSyncTensor; + softSyncTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + SOFT_SYNC_OFFSET)); + AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(softSyncTensor[compCoreIdx * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], tmpZeroLocalTensor, + INT32_COUNT_PER_BLOCK); + AscendC::DataCopy(softSyncTensor[(compCoreIdx + COMP_AIV_CORE_NUM) * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], + tmpZeroLocalTensor, INT32_COUNT_PER_BLOCK); + } + + CATLASS_DEVICE + void AivInitParams(Params const ¶ms) + { + aiCoreGroupNum = AscendC::GetBlockNum(); + subBlockNum = AscendC::GetSubBlockNum(); + aivIdx = AscendC::GetBlockIdx(); + aiCoreGroupIdx = aivIdx / subBlockNum; + aivStateGlobalCoreIdx = AIV_STATE_SPACE_IDNEX + aivIdx; + + isCompCore = (aivIdx % SUB_AIV_NUM) == 0; // 偶数核做计算 + compCoreNum = COMP_AIV_CORE_NUM; + compCoreIdx = aiCoreGroupIdx; + // 单卡单专家48发48收 + isRecvCore = true; + isSendCore = true; + recvCoreIdx = aivIdx; + sendCoreIdx = aivIdx; + sendCoreNum = SEND_AIV_CORE_NUM; + recvCoreNum = RECV_AIV_CORE_NUM; + + moeExpertNumPerRank = params.moeExpertNumPerRank; + + epRankSize = params.epRankSize; + epRankId = params.epRankId; + expertCntUp = epRankSize * moeExpertNumPerRank; + sharedExpertRankNum = params.sharedExpertRankNum; + hasShareExpert = (sharedExpertRankNum > 0); + isShareExpert = (epRankId < sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + moeExpertNum = params.moeExpertNum; + tokenLength = params.tokenLen; + + // 单卡多专家改为24收24发 + if (localExpertNum > 1) { + isRecvCore = ((aivIdx % ODD_EVEN_BASE) == 0); // 偶数核接收 + isSendCore = ((aivIdx % ODD_EVEN_BASE) == 1); // 奇数核发送 + recvCoreIdx = aivIdx / SUB_AIV_NUM; + sendCoreIdx = aivIdx / SUB_AIV_NUM; + sendCoreNum = SEND_AIV_CORE_NUM / SUB_AIV_NUM; + recvCoreNum = RECV_AIV_CORE_NUM / SUB_AIV_NUM; + } + + hOutSize = tokenLength * sizeof(int8_t); + scaleParamPad = TOKEN_EXTRA_SPACE; // 预留512B给量化参数,实际只使用了4B(fp32) + hCommuSize = hOutSize + scaleParamPad; + axisHCommu = hCommuSize / sizeof(int8_t); + axisBS = params.bs; + axisK = params.topK; + uint32_t maxAxisBs = params.globalBs / epRankSize; + + stateOffset = STATE_OFFSET; + expertPerSizeOnWin = maxAxisBs * tokenLength * sizeof(XType); + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); + } + + CATLASS_DEVICE + void AivInitState() + { + // 核状态更新,决定使用哪一半空间,以及各种信号的切换 + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + dataState = selfDataStatusTensor(aivIdx * UB_ALIGN); + if (dataState == 0) { + selfDataStatusTensor(aivIdx * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(aivIdx * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + + // 专家token数据信号 + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + + AscendC::PipeBarrier(); + winDataSizeOffset = dataState * epRankSize * expertPerSizeOnWin * moeExpertNumPerRank; + GM_ADDR statusSpaceGm_ = GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId); + AscendC::GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + state = selfStatusTensor(aivIdx * UB_ALIGN); + if (state == 0) { + sumTarget = (float)1.0; + tokenFlag = TOKEN_FLAG_1; + selfStatusTensor(aivIdx * UB_ALIGN) = 0x3F800000; // 浮点数的1.0 + } else { + sumTarget = 0.0; + tokenFlag = TOKEN_FLAG_2; + selfStatusTensor(aivIdx * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + } + + CATLASS_DEVICE + void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount) + { + if (aivIdx == aiCoreGroupNum * subBlockNum - 1) { + // 清理专家token数量信息 + AscendC::GlobalTensor groupTokenNumStateTensor; + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET)); + AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, GROUP_INFO_SIZE * localExpertNum); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(groupTokenNumStateTensor, tmpZeroLocalTensor, GROUP_INFO_SIZE * localExpertNum); + } + + if (isRecvCore && recvCoreIdx == (recvCoreNum - 1)) { + // 更新group_list信息 + AscendC::GlobalTensor expertTokenNumsOutGMTensor_; + expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList)); + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]); + __asm__ __volatile__(""); + uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1); + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + __asm__ __volatile__(""); + } + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + AivInitParams(params); + AivInitState(); + if (isSendCore) { + SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA, + (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx); + } + if (isRecvCore) { + RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount, + (GM_ADDR)params.gmOutputRecvCount); + } + + auto gmSwigluOutput = reinterpret_cast<__gm__ float *>( + params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES * L1TileShape::N)); + if (isCompCore) { + CompCoreFunc(params.ptrWorkspace, params.ptrScale, params.ptrPerTokenScale, gmSwigluOutput, + params.problemShape.n(), params.problemShape.k(), params.layoutScale, params.layoutPerTokenScale, + params.layoutOutput); + } + + icache_preload(8); + AscendC::SyncAll(); + AscendC::PipeBarrier(); + + UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount); + { + // 量化计算 + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.gmEpSendCount)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(sendCountsGlobal); + __asm__ __volatile__(""); + totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1); + AscendC::PipeBarrier(); + uint32_t nOut = params.problemShape.n() / 2; + uint32_t quantRowOnce = 0; + CalQuantRow(nOut, quantRowOnce); + typename BlockQuant::Params quantParams{ + gmSwigluOutput, params.layoutOutput, params.ptrDequantScale, params.layoutDequantScale, + params.ptrOutput, params.layoutOutput, quantRowOnce, nOut}; + + BlockQuant blockQuant(resource, quantParams); + MatrixCoord quantShape(totalTokenCount, nOut); + MatrixCoord quantBlockShape((uint16_t)(subBlockNum * quantRowOnce), nOut); + Epilogue::Tile::EpilogueHorizontalTileSwizzle quantSwizzle(quantShape, quantBlockShape); + for (uint32_t loopIdx = aiCoreGroupIdx; loopIdx < quantSwizzle.GetLoops(); loopIdx += aiCoreGroupNum) { + auto blockCoord = quantSwizzle.GetTileCoord(loopIdx); + auto actualBlockShape = quantSwizzle.GetActualTileShape(blockCoord); + blockQuant(quantBlockShape, blockCoord, actualBlockShape); + } + } + } + +private: + friend struct AicWaitFunc1; + friend struct AicSetFunc1; + + struct AicWaitFunc1 { + CATLASS_DEVICE + AicWaitFunc1() = default; + + CATLASS_DEVICE + void operator()() const + { + CheckSyncFlag(flagAddr, idx, target); + } + + __gm__ uint8_t *flagAddr; + uint8_t idx; + uint32_t target; + }; + + struct AicSetFunc1 { + CATLASS_DEVICE + AicSetFunc1() = default; + + CATLASS_DEVICE + void operator()() const + { + EncreaseSyncFlag(flagAddr, idx); + } + + __gm__ uint8_t *flagAddr; + uint8_t idx; + }; + + AicWaitFunc1 aicWaitFunc1; + AicSetFunc1 aicSetFunc1; + Arch::Resource resource; + + AscendC::LocalTensor expertIdsTensor_; + + // 卡与专家相关 + uint32_t epRankSize{0}; + uint32_t epRankId{0}; + bool hasShareExpert{false}; + bool isShareExpert{false}; + uint32_t expertCntUp{0}; + uint32_t localExpertNum{0}; + uint32_t sharedExpertRankNum{0}; + uint32_t moeExpertNumPerRank{0}; + uint32_t moeExpertNum{0}; + + // token相关 + uint32_t hOutSize{0}; + uint32_t scaleParamPad{0}; + uint32_t hCommuSize{0}; + uint32_t axisHCommu{0}; + uint32_t axisBS{0}; + uint32_t axisK{0}; + uint32_t totalTokenCount{0}; + uint32_t expertIdsCnt{0}; + uint32_t tokenLength{0}; + + // 状态相关 + int32_t tokenFlag{0}; // token到达的flag + int32_t vToCFlag{0}; // V通知C的flag + int32_t dataState{0}; // 当前核的状态,与combine配合 + int32_t cvDataState{0}; // 当前核的状态,CV配合 + int32_t state{0}; // count的flag选择依据 + float sumTarget{0.0}; // count达到的数量 + + // 共享内存相关 + __gm__ HcclOpResParam *winContext_; + GM_ADDR statusDataSpaceGm; + uint32_t stateOffset{0}; + uint64_t expertPerSizeOnWin{0}; + uint64_t winDataSizeOffset{0}; + + // 核上资源相关 + int64_t ubOffset; + + // 分核相关 + bool isSendCore{false}; + bool isRecvCore{false}; + bool isCompCore{false}; // 参与计算deq_swiglu + uint32_t aiCoreGroupNum{0}; + uint32_t aiCoreGroupIdx{0}; + uint32_t subBlockNum{0}; + uint32_t aicNum{0}; + uint32_t sendCoreNum{0}; + uint32_t recvCoreNum{0}; + uint32_t compCoreNum{0}; + uint32_t aivIdx{0}; + uint32_t aicIdx{0}; + uint32_t sendCoreIdx{0}; + uint32_t recvCoreIdx{0}; + uint32_t compCoreIdx{0}; + uint32_t aivStateGlobalCoreIdx{0}; + uint32_t aicStateGlobalCoreIdx{0}; + uint32_t sendToMoeAivNum{0}; + uint32_t sendToShareAivNum{0}; +}; + +} // namespace Catlass::Gemm::Kernel + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using ElementDequantScale = typename BlockQuant::ElementDequantScale; + using LayoutDequantScale = typename BlockQuant::LayoutDequantScale; + using ElementOutput = typename BlockQuant::ElementOutput; + using LayoutOutput = typename BlockQuant::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementDequantScale *ptrDequantScale; + LayoutDequantScale layoutDequantScale; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrDequantScale_, LayoutDequantScale const &layoutDequantScale_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrDequantScale(reinterpret_cast<__gm__ ElementDequantScale *>(ptrDequantScale_)), + layoutDequantScale(layoutDequantScale_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + auto ptrD = reinterpret_cast<__gm__ float *>( + params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N)); + + uint32_t mActual = groupList.GetValue(params.problemCount - 1); + uint32_t nOut = params.problemShape.n() / 2; + + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutOutput.GetTileLayout(MakeCoord(currentM, nOut)); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * nOut; + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + + { + uint32_t quantRowOnce = 0; + CalQuantRow(nOut, quantRowOnce); + typename BlockQuant::Params quantParams{ptrD, + params.layoutOutput, + params.ptrDequantScale, + params.layoutDequantScale, + params.ptrOutput, + params.layoutOutput, + quantRowOnce, + nOut}; + + BlockQuant blockQuant(resource, quantParams); + MatrixCoord quantShape(mActual, nOut); + MatrixCoord quantBlockShape((uint16_t)(AscendC::GetSubBlockNum() * quantRowOnce), nOut); + Epilogue::Tile::EpilogueHorizontalTileSwizzle quantSwizzle(quantShape, quantBlockShape); + for (uint32_t loopIdx = coreIdx; loopIdx < quantSwizzle.GetLoops(); loopIdx += coreNum) { + auto blockCoord = quantSwizzle.GetTileCoord(loopIdx); + auto actualBlockShape = quantSwizzle.GetActualTileShape(blockCoord); + + blockQuant(quantBlockShape, blockCoord, actualBlockShape); + } + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< + BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h new file mode 100644 index 00000000000..0b8ec9a69e6 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h @@ -0,0 +1,813 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef CAM_MOE_DISTRIBUTE_COMBINE_H +#define CAM_MOE_DISTRIBUTE_COMBINE_H +#define OPT_RANK_OFFSET 512 + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../../dispatch_gmm_combine_decode_base.h" +#include "../../dispatch_gmm_combine_decode_tiling.h" + +namespace MoeDistributeCombineImpl { +constexpr uint8_t BUFFER_NUM = 2; // multi-buf +constexpr uint32_t STATE_OFFSET = 512; +constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M +constexpr uint32_t RANK_SIZE_ON_WIN_512 = 512 * 1024; +constexpr uint32_t RANK_SIZE_ON_WIN_256 = 256 * 1024; +constexpr uint32_t TP_RANK_SIZE_ON_WIN = 0; +constexpr uint32_t UB_ALIGN = 32; +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint8_t EP_DOMAIN = 0; +constexpr uint8_t TP_DOMAIN = 1; +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint16_t SEND_SYNC_EVENT_ID = 9; +constexpr uint16_t RECV_SYNC_EVENT_ID = 10; + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +using namespace AscendC; + +struct CombineCalcInfo { + uint64_t expertPerSizeOnWin_; + uint32_t epRankId_; + uint32_t epWorldSize_; + uint32_t moeExpertPerRankNum_; + uint32_t sharedExpertRankNum_; + uint32_t axisH_; + uint32_t moeSendNum_; + bool isShardExpert_; + GM_ADDR epSendCount_; + __gm__ HcclOpResParam *epWinContext_; + uint64_t winDataSizeOffset_; +}; + +template +class CamMoeDistributeCombine +{ +public: + __aicore__ inline CamMoeDistributeCombine(){}; + __aicore__ inline void Init(GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, + GM_ADDR tpSendCount, GM_ADDR scales, GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, + const DispatchGmmCombineDecodeTilingData *tilingData); + __aicore__ inline void Process(); + __aicore__ inline void AllToAllSend(); + __aicore__ inline void ReducePermute(); + + __aicore__ inline CombineCalcInfo &GetCalcInfo() + { + return calcInfo_; + } + + __aicore__ inline void TPipeSet(AscendC::TPipe *pipe) + { + tpipe_ = pipe; + } + +private: + __aicore__ inline void InitStatusTargetSum(); + __aicore__ inline void AlltoAllBuffInit(); + __aicore__ inline void ReduceScatterTrans(); + __aicore__ inline void SetWaitTpStatusAndDisPatch(); + __aicore__ inline void CustomAdd(LocalTensor &dst, LocalTensor &src0, + LocalTensor &src1, uint32_t dataCnt); + __aicore__ inline void ExpertAlltoAllDispatchInnerCopyAdd(uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, + uint32_t ep, uint32_t expertIdx); + __aicore__ inline void ExpertAlltoAllDispatchCopyAdd(); + __aicore__ inline void LocalWindowCopy(); + __aicore__ inline void BuffInit(); + __aicore__ inline void SplitCoreCal(); + __aicore__ inline void SetStatus(); + __aicore__ inline void WaitDispatch(); + __aicore__ GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t domain, const uint8_t expertLocalId = 0U) + { + if (domain == EP_DOMAIN) { + return (GM_ADDR)((epRankId_ == rankId) + ? epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + winDataSizeOffset_ + expertLocalId * expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; + } else { + return (GM_ADDR)((tpRankId_ == rankId) + ? tpWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + } + + __aicore__ GM_ADDR GetWinStateAddrByRankId(const int32_t rankId, const uint8_t domain) + { + if (domain == EP_DOMAIN) { + return (GM_ADDR)((epRankId_ == rankId) + ? epWinContext_->localWindowsExp + : ((HcclRankRelationResV2 *)(epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } else { + return (GM_ADDR)((tpRankId_ == rankId) + ? tpWinContext_->localWindowsExp + : ((HcclRankRelationResV2 *)(tpWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } + } + + __aicore__ inline uint32_t MIN(uint32_t x, uint32_t y) + { + return (x < y) ? x : y; + } + + __aicore__ static void DoCombineRecv(void *ptr) + { + auto *combiner = (CamMoeDistributeCombine *)ptr; + combiner->ReducePermute(); + } + + TPipe *tpipe_{nullptr}; + GlobalTensor expandXGM_; + GlobalTensor expertIdsGM_; + GlobalTensor expandIdxGM_; + GlobalTensor epSendCountGM_; + GlobalTensor tpSendCountGM_; + GlobalTensor expandScalesGM_; + GlobalTensor expandOutGlobal_; + GlobalTensor rankWindow_; + GlobalTensor rankStates_; + GlobalTensor epStatusSpaceGlobalTensor_; + GlobalTensor tpStatusSpaceGlobalTensor_; + GlobalTensor tpRankWindow_; + GlobalTensor rowTmpGlobal_; + GM_ADDR workspaceGM_; + GM_ADDR epWindowGM_; + GM_ADDR epStatusSpaceGm_; + GM_ADDR tpWindowGM_; + GM_ADDR tpStatusSpaceGm_; + GM_ADDR stateGM_; + + LocalTensor winTpSendCountTensor_; + LocalTensor gmTpSendCountTensor_; + LocalTensor outTensor_; + LocalTensor winTpSendCountFloatTensor_; + LocalTensor gmTpSendCountFloatTensor_; + LocalTensor epSendCountLocal_; + + CombineCalcInfo calcInfo_; + uint32_t axisBS_{0}; + uint32_t axisMaxBs_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t aivNum_{0}; + uint32_t epWorldSize_{0}; + uint32_t tpWorldSize_{0}; + uint32_t epRankId_{0}; + uint32_t tpRankId_{0}; + uint32_t coreIdx_{0}; // aiv id + uint32_t sharedExpertRankNum_{0}; + uint32_t moeExpertNum_{0}; + uint32_t moeExpertPerRankNum_{0}; + uint32_t moeSendNum_{0}; // moeExpertPerRankNum_ * epWorldSize_ + uint32_t tpScatterNum_{0}; + uint32_t firstTpTokenEndIdx_{0}; + uint32_t firstTpTokenEndOffset_{0}; + uint32_t endTok_{0}; + __gm__ HcclOpResParam *epWinContext_{nullptr}; + __gm__ HcclOpResParam *tpWinContext_{nullptr}; + uint32_t epDataOffsetOnWin_{0}; + uint32_t tpDataOffsetOnWin_{0}; + uint32_t epStateOffsetOnWin_{0}; + uint32_t tpStateOffsetOnWin_{0}; + uint32_t axisHFloatSize_{0}; + uint32_t axisHExpandXTypeSize_{0}; + uint32_t bsKNum_{0}; + uint32_t startRankId_{0}; + uint32_t endRankId_{0}; + uint32_t sendRankNum_{0}; + uint32_t ubSize_{0}; + uint32_t dataState_{0}; + uint32_t stateOffset_{0}; + uint64_t winDataSizeOffset_{0}; + uint64_t expertPerSizeOnWin_{0}; + uint64_t totalWinSize_{0}; + TQueBind moeQueue_; + TQue moeSumQueue_; + TQueBind gmTpSendCountQueue_; + TQue gmTpSendCountInQueue_; + TQue winTpSendCountInQueue_; + TQue xOutQueue_; + TBuf<> readStateBuf_; + TBuf<> expertIdsBuf_; + TBuf<> expandScalesBuf_; + TBuf<> rowTmpFloatBuf_; + TBuf<> sumFloatBuf_; + TBuf<> mulBuf_; + TBuf<> sendCountBuf_; + TBuf<> indexCountsBuf_; + TBuf<> winTpSendCountFloatBuf_; + TBuf<> gmTpSendCountFloatBuf_; + TBuf<> tokenBuf_; + TBuf<> statusBuf_; + TBuf<> gatherMaskOutBuf_; // gather mask output buf + TBuf<> gatherTmpBuf_; + TBuf<> statusSumOutBuf_; + float sumTarget_{0.0}; + int32_t epStateValue_; + bool isShardExpert_{false}; +}; + +template +__aicore__ inline void CamMoeDistributeCombine::Init( + GM_ADDR expandX, GM_ADDR expertIds, GM_ADDR expandIdx, GM_ADDR epSendCount, GM_ADDR tpSendCount, GM_ADDR scales, + GM_ADDR XOut, GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) +{ + tpipe_ = pipe; + coreIdx_ = GetBlockIdx(); + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + auto contextGM0 = AscendC::GetHcclContext(); + epWinContext_ = (__gm__ HcclOpResParam *)contextGM0; + GlobalTensor selfDataStatusTensor; + GM_ADDR statusDataSpaceGm = (GM_ADDR)epWinContext_->localWindowsExp; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + dataState_ = selfDataStatusTensor(coreIdx_ * UB_ALIGN); + if (dataState_ == 0) { + selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(coreIdx_ * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + pipe_barrier(PIPE_ALL); + + workspaceGM_ = workspaceGM; + expandXGM_.SetGlobalBuffer((__gm__ ExpandXType *)expandX); + expertIdsGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expertIds); + expandIdxGM_.SetGlobalBuffer((__gm__ ExpandIdxType *)expandIdx); + epSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)epSendCount); + expandScalesGM_.SetGlobalBuffer((__gm__ float *)scales); + expandOutGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)XOut); + axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + aivNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.aivNum; + ubSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalUbSize; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + moeExpertPerRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + axisMaxBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs / epWorldSize_; + moeSendNum_ = epWorldSize_ * moeExpertPerRankNum_; + tpWorldSize_ = 1; + tpRankId_ = 0; + totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize; + stateOffset_ = (moeSendNum_ > 512) ? (STATE_OFFSET / 2) : STATE_OFFSET; + expertPerSizeOnWin_ = + static_cast(axisMaxBs_) * static_cast(axisH_) * static_cast(sizeof(ExpandXType)); + winDataSizeOffset_ = static_cast(dataState_) * static_cast(moeSendNum_) * expertPerSizeOnWin_; + epWindowGM_ = GetWinAddrByRankId(epRankId_, EP_DOMAIN); + epStatusSpaceGm_ = GetWinStateAddrByRankId(epRankId_, EP_DOMAIN); + epStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)epStatusSpaceGm_); + epDataOffsetOnWin_ = epRankId_ * moeExpertPerRankNum_ * static_cast(expertPerSizeOnWin_); + epStateOffsetOnWin_ = epRankId_ * stateOffset_; + isShardExpert_ = (epRankId_ < sharedExpertRankNum_); + axisHFloatSize_ = axisH_ * sizeof(float); + axisHExpandXTypeSize_ = axisH_ * sizeof(ExpandXType); + bsKNum_ = axisBS_ * axisK_; + + if constexpr (IsNeedReduceScatter) { + tpSendCountGM_.SetGlobalBuffer((__gm__ int32_t *)tpSendCount); + tpWindowGM_ = GetWinAddrByRankId(tpRankId_, TP_DOMAIN); + tpStatusSpaceGm_ = GetWinStateAddrByRankId(tpRankId_, TP_DOMAIN); + tpStatusSpaceGlobalTensor_.SetGlobalBuffer((__gm__ float *)tpStatusSpaceGm_); + tpDataOffsetOnWin_ = tpRankId_ * TP_RANK_SIZE_ON_WIN; + tpStateOffsetOnWin_ = tpRankId_ * stateOffset_; + uint32_t tpScatterRankWinOffset = (tpRankId_ == 0) ? TP_RANK_SIZE_ON_WIN : 0; + GM_ADDR rankGM = tpWindowGM_ + tpScatterRankWinOffset; + tpRankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + } + + InitStatusTargetSum(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + coreIdx_ = get_block_idx(); + } + SplitCoreCal(); + + calcInfo_.epRankId_ = epRankId_; + calcInfo_.epWorldSize_ = epWorldSize_; + calcInfo_.expertPerSizeOnWin_ = expertPerSizeOnWin_; + calcInfo_.moeExpertPerRankNum_ = moeExpertPerRankNum_; + calcInfo_.sharedExpertRankNum_ = sharedExpertRankNum_; + calcInfo_.axisH_ = axisH_; + calcInfo_.moeSendNum_ = moeSendNum_; + calcInfo_.isShardExpert_ = isShardExpert_; + calcInfo_.epSendCount_ = epSendCount; + calcInfo_.epWinContext_ = epWinContext_; + calcInfo_.winDataSizeOffset_ = winDataSizeOffset_; +} + +template +__aicore__ inline void CamMoeDistributeCombine::InitStatusTargetSum() +{ + // ep state + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(epStatusSpaceGm_ + SELF_STATE_OFFSET)); + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); + int32_t state = selfStatusTensor(coreIdx_ * UB_ALIGN); + if (state == 0) { + sumTarget_ = static_cast(1.0); + selfStatusTensor(coreIdx_ * UB_ALIGN) = 0x3F800000; // 1.0f + epStateValue_ = 0x3F800000; // 1.0f + } else { + sumTarget_ = static_cast(0.0); + selfStatusTensor(coreIdx_ * UB_ALIGN) = 0; + epStateValue_ = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfStatusTensor[coreIdx_ * UB_ALIGN]); + __asm__ __volatile__(""); +} + +template +__aicore__ inline void CamMoeDistributeCombine::BuffInit() +{ + tpipe_->Reset(); + tpipe_->InitBuffer(readStateBuf_, UB_ALIGN); // 32 + uint32_t sendNumAlign = Ceil(moeSendNum_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; + tpipe_->InitBuffer(sendCountBuf_, sendNumAlign); // epWorldSize_ * moeExpertPerRankNum_ * 4 + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(winTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + tpipe_->InitBuffer(gmTpSendCountInQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + if constexpr (AscendC::IsSameType::value) { + tpipe_->InitBuffer(winTpSendCountFloatBuf_, axisHFloatSize_); + tpipe_->InitBuffer(gmTpSendCountFloatBuf_, axisHFloatSize_); + winTpSendCountFloatTensor_ = winTpSendCountFloatBuf_.Get(); + gmTpSendCountFloatTensor_ = gmTpSendCountFloatBuf_.Get(); + } + } else { + tpipe_->InitBuffer(gmTpSendCountQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 28K + } + epSendCountLocal_ = sendCountBuf_.Get(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::AlltoAllBuffInit() +{ + tpipe_->Reset(); + uint32_t bsMulTopkSizeAligned = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; // 防止UB不对齐 + tpipe_->InitBuffer(readStateBuf_, UB_ALIGN); + tpipe_->InitBuffer(statusBuf_, sendRankNum_ * UB_ALIGN); + tpipe_->InitBuffer(expertIdsBuf_, bsMulTopkSizeAligned); + tpipe_->InitBuffer(expandScalesBuf_, bsMulTopkSizeAligned); + tpipe_->InitBuffer(tokenBuf_, axisH_ * sizeof(ExpandXType)); + tpipe_->InitBuffer(rowTmpFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(mulBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(sumFloatBuf_, axisHFloatSize_); // 7168 * 4 = 28672 + tpipe_->InitBuffer(indexCountsBuf_, bsMulTopkSizeAligned); + tpipe_->InitBuffer(moeSumQueue_, BUFFER_NUM, axisHExpandXTypeSize_); + tpipe_->InitBuffer(gatherMaskOutBuf_, epWorldSize_ * sizeof(float)); + tpipe_->InitBuffer(gatherTmpBuf_, sizeof(uint32_t)); // 4 + tpipe_->InitBuffer(statusSumOutBuf_, sizeof(float)); // 4 +} + +template +__aicore__ inline void CamMoeDistributeCombine::SplitCoreCal() +{ + sendRankNum_ = epWorldSize_ / aivNum_; + uint32_t remainderRankNum = epWorldSize_ % aivNum_; + startRankId_ = sendRankNum_ * coreIdx_; + if (coreIdx_ < remainderRankNum) { + sendRankNum_++; + startRankId_ += coreIdx_; + } else { + startRankId_ += remainderRankNum; + } + endRankId_ = startRankId_ + sendRankNum_; +} + +template +__aicore__ inline void CamMoeDistributeCombine::ReduceScatterTrans() +{ + __asm__ __volatile__(""); + DataCacheCleanAndInvalid(tpSendCountGM_[tpRankId_]); + __asm__ __volatile__(""); + uint32_t offset = tpSendCountGM_.GetValue(tpRankId_) * axisH_; + GlobalTensor dataCopyInGM = expandXGM_[offset]; + GM_ADDR rankGM = GetWinAddrByRankId(1 - tpRankId_, TP_DOMAIN) + tpDataOffsetOnWin_; + rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + uint32_t copyStartIdx = 0; + if (startRankId_ > 0) { + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + epSendCountGM_[epWorldSize_ + startRankId_ - 1]); + __asm__ __volatile__(""); + copyStartIdx = epSendCountGM_.GetValue(epWorldSize_ + startRankId_ - 1); + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + epSendCountGM_[epWorldSize_ + endRankId_ - 1]); + __asm__ __volatile__(""); + uint32_t copyEndIdx = epSendCountGM_.GetValue(epWorldSize_ + endRankId_ - 1); + LocalTensor tmpUb; + for (uint32_t tokenNumIdx = copyStartIdx; tokenNumIdx < copyEndIdx; tokenNumIdx++) { + tmpUb = moeQueue_.AllocTensor(); + DataCopy(tmpUb, dataCopyInGM[tokenNumIdx * axisH_], axisH_); + moeQueue_.EnQue(tmpUb); + tmpUb = moeQueue_.DeQue(); + DataCopy(rankWindow_[tokenNumIdx * axisH_], tmpUb, axisH_); + moeQueue_.FreeTensor(tmpUb); + } +} + +// 46 -> gm -> ub syncall win->gm add -> alltoall +// 2 -> win wait syncall gm -> ub win ->gm add -> alltoall +template +__aicore__ inline void CamMoeDistributeCombine::SetWaitTpStatusAndDisPatch() +{ + pipe_barrier(PIPE_ALL); + if (startRankId_ >= epWorldSize_) { + return; + } + if constexpr (IsNeedReduceScatter) { + uint32_t tpToRankId = 1 - tpRankId_; + pipe_barrier(PIPE_ALL); + LocalTensor statusFlagUb = readStateBuf_.Get(); + statusFlagUb(0) = sumTarget_; + SyncFunc(); + GlobalTensor tpWindowInstatusFp32Tensor_; + stateGM_ = GetWinStateAddrByRankId(tpToRankId, TP_DOMAIN) + coreIdx_ * stateOffset_; + tpWindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)stateGM_); + DataCopy(tpWindowInstatusFp32Tensor_, statusFlagUb, 8UL); + SyncFunc(); + LocalTensor statusFp32Tensor_ = readStateBuf_.Get(); + float sumOfFlag = static_cast(-1.0); + uint32_t statusRankOffset = coreIdx_ * stateOffset_ / sizeof(float); // tp = 2 + while (sumOfFlag != sumTarget_) { + DataCopy(statusFp32Tensor_, tpStatusSpaceGlobalTensor_[statusRankOffset], 8); + SyncFunc(); + sumOfFlag = statusFp32Tensor_.GetValue(0); + SyncFunc(); + } + } + // Copy win gm->ub add ->alltoall send + ExpertAlltoAllDispatchCopyAdd(); + SyncFunc(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::ExpertAlltoAllDispatchCopyAdd() +{ + if (startRankId_ >= epWorldSize_) { + return; + } + uint32_t curRankExpertNum = 0; + DataCopyExtParams epSendCntParams; + if (isShardExpert_) { + curRankExpertNum = 1; + epSendCntParams = {1U, static_cast(epWorldSize_ * sizeof(uint32_t)), 0U, 0U, 0U}; + } else { + curRankExpertNum = moeExpertPerRankNum_; + epSendCntParams = {1U, static_cast(moeSendNum_ * sizeof(uint32_t)), 0U, 0U, 0U}; + } + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(epSendCountLocal_, epSendCountGM_, epSendCntParams, copyPadParams); + SyncFunc(); + uint32_t preCount = 0; + uint32_t startTokenIdx = 0; + uint32_t curTokenNum = 0; + + for (uint32_t expertIdx = 0U; expertIdx < curRankExpertNum; expertIdx++) { + uint32_t sendEpCount = endRankId_ - startRankId_; + for (uint32_t i = 0; i < sendEpCount; ++i) { + uint32_t ep = startRankId_ + (i + epRankId_) % sendEpCount; + if ((ep > 0) || (expertIdx > 0U)) { + preCount = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep - 1); + } else { + preCount = 0; + } + curTokenNum = epSendCountLocal_.GetValue(expertIdx * epWorldSize_ + ep) - preCount; + if (curTokenNum == 0) { + continue; + } + startTokenIdx = preCount * axisH_; + ExpertAlltoAllDispatchInnerCopyAdd(curTokenNum, startTokenIdx, ep, expertIdx); + } + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::ExpertAlltoAllDispatchInnerCopyAdd( + uint32_t tokenNumLoop, uint32_t srcStartTokenIdx, uint32_t ep, uint32_t expertIdx) +{ + GM_ADDR rankGM = GetWinAddrByRankId(ep, EP_DOMAIN, expertIdx) + epDataOffsetOnWin_; + if ((isShardExpert_) && (ep < sharedExpertRankNum_)) { + rankGM = GetWinAddrByRankId(epRankId_, EP_DOMAIN, expertIdx) + ep * moeExpertPerRankNum_ * expertPerSizeOnWin_; + } + rankWindow_.SetGlobalBuffer((__gm__ ExpandXType *)rankGM); + uint32_t dataCnt = axisH_; + for (uint32_t loopIdx = 0; loopIdx < tokenNumLoop; loopIdx++) { + if constexpr (IsNeedReduceScatter) { + gmTpSendCountTensor_ = gmTpSendCountInQueue_.AllocTensor(); + DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt); + gmTpSendCountInQueue_.EnQue(gmTpSendCountTensor_); + + winTpSendCountTensor_ = winTpSendCountInQueue_.AllocTensor(); + DataCopy(winTpSendCountTensor_, tpRankWindow_[srcStartTokenIdx], dataCnt); + winTpSendCountInQueue_.EnQue(winTpSendCountTensor_); + + gmTpSendCountTensor_ = gmTpSendCountInQueue_.DeQue(); + winTpSendCountTensor_ = winTpSendCountInQueue_.DeQue(); + outTensor_ = xOutQueue_.AllocTensor(); + + CustomAdd(outTensor_, winTpSendCountTensor_, gmTpSendCountTensor_, dataCnt); + gmTpSendCountInQueue_.FreeTensor(gmTpSendCountTensor_); + winTpSendCountInQueue_.FreeTensor(winTpSendCountTensor_); + xOutQueue_.EnQue(outTensor_); + + outTensor_ = xOutQueue_.DeQue(); + DataCopy(rankWindow_[loopIdx * dataCnt], outTensor_, dataCnt); + xOutQueue_.FreeTensor(outTensor_); + } else { + gmTpSendCountTensor_ = gmTpSendCountQueue_.AllocTensor(); + DataCopy(gmTpSendCountTensor_, expandXGM_[srcStartTokenIdx], dataCnt); + ExpandXType val = expandXGM_[srcStartTokenIdx].GetValue(0); + gmTpSendCountQueue_.EnQue(gmTpSendCountTensor_); + gmTpSendCountTensor_ = gmTpSendCountQueue_.DeQue(); + DataCopy(rankWindow_[loopIdx * dataCnt], gmTpSendCountTensor_, dataCnt); + gmTpSendCountQueue_.FreeTensor(gmTpSendCountTensor_); + } + srcStartTokenIdx += dataCnt; + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::CustomAdd(LocalTensor &dst, + LocalTensor &src0, + LocalTensor &src1, + uint32_t dataCnt) +{ + if constexpr (AscendC::IsSameType::value) { + Cast(winTpSendCountFloatTensor_, src0, RoundMode::CAST_NONE, dataCnt); + Cast(gmTpSendCountFloatTensor_, src1, RoundMode::CAST_NONE, dataCnt); + pipe_barrier(PIPE_V); + Add(winTpSendCountFloatTensor_, winTpSendCountFloatTensor_, gmTpSendCountFloatTensor_, dataCnt); + pipe_barrier(PIPE_V); + Cast(dst, winTpSendCountFloatTensor_, RoundMode::CAST_ROUND, dataCnt); + } else { + Add(dst, src0, src1, dataCnt); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::SetStatus() +{ + pipe_barrier(PIPE_ALL); + if (startRankId_ >= epWorldSize_) { + return; + } + + LocalTensor statusFlagUb = readStateBuf_.Get(); + statusFlagUb.SetValue(0, epStateValue_); + SyncFunc(); + + for (uint32_t epIdx = startRankId_; epIdx < endRankId_; epIdx++) { + stateGM_ = GetWinStateAddrByRankId(epIdx, EP_DOMAIN) + epStateOffsetOnWin_; + rankStates_.SetGlobalBuffer((__gm__ int32_t *)stateGM_); + DataCopy(rankStates_, statusFlagUb, 8); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::WaitDispatch() +{ + if (startRankId_ < epWorldSize_) { + LocalTensor statusTensor = statusBuf_.Get(); + LocalTensor gatherMaskOutTensor = gatherMaskOutBuf_.Get(); + LocalTensor gatherTmpTensor = gatherTmpBuf_.Get(); + LocalTensor statusSumOutTensor = statusSumOutBuf_.Get(); + PipeBarrier(); + + gatherTmpTensor.SetValue(0, 1); + uint32_t mask = 1; // gatherMask + sum + uint64_t rsvdCnt = 0; + DataCopyParams intriParams{static_cast(sendRankNum_), 1, + static_cast((moeSendNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * sendRankNum_) - (float)0.5; + float maxTarget = (sumTarget_ * sendRankNum_) + (float)0.5; + SumParams sumParams{1, sendRankNum_, sendRankNum_}; + SyncFunc(); + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + DataCopy(statusTensor, epStatusSpaceGlobalTensor_[startRankId_ * stateOffset_ / sizeof(float)], + intriParams); + SyncFunc(); + GatherMask(gatherMaskOutTensor, statusTensor, gatherTmpTensor, true, mask, + {1, (uint16_t)sendRankNum_, 1, 0}, rsvdCnt); + PipeBarrier(); + Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + SyncFunc(); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(RECV_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } else { + SyncAll(); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::LocalWindowCopy() +{ + uint32_t beginIndex = 0; + uint32_t endIndex = 0; + uint32_t processLen = 0; + uint32_t tokenOffset = 0; + if (axisBS_ < aivNum_) { + uint32_t aivNumPerToken = aivNum_ / axisBS_; // axisBS_ < aivNum_ + if (coreIdx_ >= (axisBS_ * aivNumPerToken)) { + return; + } + uint32_t tokenIndex = coreIdx_ / aivNumPerToken; + processLen = ((axisH_ / UB_ALIGN) / aivNumPerToken) * UB_ALIGN; + tokenOffset = processLen * (coreIdx_ % aivNumPerToken); + if ((coreIdx_ % aivNumPerToken) == (aivNumPerToken - 1)) { + processLen = axisH_ - ((aivNumPerToken - 1) * processLen); + } + beginIndex = tokenIndex; + endIndex = beginIndex + 1U; + } else { + uint32_t tokenPerAivNum = axisBS_ / aivNum_; + uint32_t remainderToken = axisBS_ % aivNum_; + beginIndex = tokenPerAivNum * coreIdx_; + if (coreIdx_ < remainderToken) { + tokenPerAivNum++; + beginIndex = tokenPerAivNum * coreIdx_; + } else { + beginIndex += remainderToken; + } + endIndex = beginIndex + tokenPerAivNum; + processLen = axisH_; + } + LocalTensor expertIdsLocal = expertIdsBuf_.Get(); + LocalTensor expandScalesLocal = expandScalesBuf_.Get(); + + LocalTensor rowTmpFloatLocal = rowTmpFloatBuf_.Get(); + LocalTensor mulBufLocal = mulBuf_.Get(); + LocalTensor sumFloatBufLocal = sumFloatBuf_.Get(); + + LocalTensor indexCountsLocal = indexCountsBuf_.Get(); + const DataCopyExtParams bskParams = {1U, static_cast(bsKNum_ * sizeof(uint32_t)), 0U, 0U, 0U}; + const DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + const DataCopyPadExtParams copyPadFloatParams{false, 0U, 0U, 0U}; + + DataCopyPad(indexCountsLocal, expandIdxGM_, bskParams, copyPadParams); + DataCopyPad(expertIdsLocal, expertIdsGM_, bskParams, copyPadParams); + DataCopyPad(expandScalesLocal, expandScalesGM_, bskParams, copyPadFloatParams); + SyncFunc(); + + for (uint32_t tokenIndex = beginIndex; tokenIndex < endIndex; tokenIndex++) { + uint32_t index = tokenIndex * axisK_; + SyncFunc(); + Duplicate(sumFloatBufLocal, (float)0, axisH_); + for (uint32_t i = 0; i < axisK_; i++) { + int32_t moeExpert = expertIdsLocal.GetValue(index); + if (moeExpert < 0) { + index++; + continue; + } + float scaleVal = expandScalesLocal.GetValue(index); + GM_ADDR wAddr = (__gm__ uint8_t *)(epWindowGM_) + + expertPerSizeOnWin_ * moeExpertPerRankNum_ * sharedExpertRankNum_ + + expertPerSizeOnWin_ * moeExpert + indexCountsLocal.GetValue(index) * axisHExpandXTypeSize_ + + tokenOffset * sizeof(ExpandXType); + rowTmpGlobal_.SetGlobalBuffer((__gm__ ExpandXType *)wAddr); + ExpandXType val = rowTmpGlobal_.GetValue(0); + LocalTensor tmpUb = moeSumQueue_.AllocTensor(); + DataCopy(tmpUb, rowTmpGlobal_, processLen); + moeSumQueue_.EnQue(tmpUb); + tmpUb = moeSumQueue_.DeQue(); + Cast(rowTmpFloatLocal, tmpUb, AscendC::RoundMode::CAST_NONE, processLen); + AscendC::PipeBarrier(); + AscendC::Muls(mulBufLocal, rowTmpFloatLocal, scaleVal, processLen); + AscendC::PipeBarrier(); + AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, mulBufLocal, processLen); + index++; + moeSumQueue_.FreeTensor(tmpUb); + } + LocalTensor rowTmpLocal = tokenBuf_.Get(); + if (sharedExpertRankNum_ > 0U) { + uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; + uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - + epRankId_ * axisBS_ / sharedExpertRankNum_; + __gm__ ExpandXType *shareAddr = + (__gm__ ExpandXType *)(epWindowGM_ + moeOnShareRank * expertPerSizeOnWin_ * moeExpertPerRankNum_) + + (tokenIndex - preCnt) * axisH_ + tokenOffset; + GlobalTensor shareTokGlobal; + shareTokGlobal.SetGlobalBuffer((__gm__ ExpandXType *)(shareAddr)); + SyncFunc(); + DataCopy(rowTmpLocal, shareTokGlobal, processLen); + SyncFunc(); + Cast(rowTmpFloatLocal, rowTmpLocal, AscendC::RoundMode::CAST_NONE, processLen); + AscendC::PipeBarrier(); + AscendC::Add(sumFloatBufLocal, sumFloatBufLocal, rowTmpFloatLocal, processLen); + } + // 结果搬出 + AscendC::PipeBarrier(); + LocalTensor sumBufLocal = tokenBuf_.Get(); + Cast(sumBufLocal, sumFloatBufLocal, AscendC::RoundMode::CAST_RINT, processLen); + SyncFunc(); + DataCopy(expandOutGlobal_[tokenIndex * axisH_ + tokenOffset], sumBufLocal, processLen); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::Process() +{ + SyncAll(); + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 7168 * 2 * 2 = 28672 + ReduceScatterTrans(); + } + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { + BuffInit(); + SetWaitTpStatusAndDisPatch(); + } + AlltoAllBuffInit(); + SetStatus(); + WaitDispatch(); + LocalWindowCopy(); +} + +template +__aicore__ inline void CamMoeDistributeCombine::AllToAllSend() +{ + if constexpr (IsNeedReduceScatter) { + tpipe_->InitBuffer(moeQueue_, BUFFER_NUM, axisHExpandXTypeSize_); // 7168 * 2 * 2 = 28672 + ReduceScatterTrans(); + } + BuffInit(); + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { + SetWaitTpStatusAndDisPatch(); + AlltoAllBuffInit(); + } + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } else { + SyncAll(); + } + SetStatus(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreWaitFlag(RECV_SYNC_EVENT_ID); + } else { + SyncAll(); + } +} + +template +__aicore__ inline void CamMoeDistributeCombine::ReducePermute() +{ + AlltoAllBuffInit(); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(SEND_SYNC_EVENT_ID); + } else { + SyncAll(); + } + + WaitDispatch(); + LocalWindowCopy(); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AscendC::CrossCoreWaitFlag(SEND_SYNC_EVENT_ID); + } +} +} // namespace MoeDistributeCombineImpl + +#endif // CAM_MOE_DISTRIBUTE_COMBINE_IMPL_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h new file mode 100644 index 00000000000..bcafe846109 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h @@ -0,0 +1,1091 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef CAM_MOE_DISTRIBUTE_DISPATCH_H +#define CAM_MOE_DISTRIBUTE_DISPATCH_H +#define OPT_RANK_OFFSET 512 + +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "../../dispatch_gmm_combine_decode_base.h" +#include "../../dispatch_gmm_combine_decode_tiling.h" + +namespace MoeDistributeDispatchImpl { +constexpr uint8_t BUFFER_NUM = 2; // 多buf +constexpr uint32_t STATE_OFFSET = 512; // 状态空间偏移地址 +constexpr uint32_t STATE_SIZE = 1024 * 1024; // 1M +constexpr uint32_t UB_ALIGN = 32; // UB按32字节对齐 +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint8_t COMM_NUM = 2; // 通信域大小 +constexpr uint8_t COMM_EP_IDX = 0; +constexpr uint8_t COMM_TP_IDX = 1; +constexpr uint32_t GATHER_NUM_PER_TIME = 6; +// 先写死这个偏移,如果TP固定为2,可直接往起始数据偏移开始读写 +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint32_t TP_STATE_SIZE = 100 * 1024; +constexpr int CAM_MAX_RANK_SIZE = 384; // Cam通信库最大支持的npu卡数 +constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024; // 前2MB作为flag标志位,之后100MB作为数据存储 + +// 循环优化相关变量 +using countType = uint8_t; // 循环优化使用的数据类型 +constexpr uint32_t LOOP_OPT_MAX_BS = 64; +constexpr uint32_t LOOP_OPT_MAX_MOE_RANK = 256; +constexpr uint32_t TOPK_ELEM_COUNT_PER_BLOCK = UB_ALIGN / sizeof(int32_t); +constexpr uint32_t TABLE_ELEM_COUNT_PER_BLOCK = UB_ALIGN / sizeof(countType); +constexpr uint32_t INT32_NUM_PER_BLOCK = UB_ALIGN / sizeof(int32_t); + +template +__aicore__ inline void SyncFunc() +{ + int32_t eventID = static_cast(GetTPipePtr()->FetchEventID(event)); + AscendC::SetFlag(eventID); + AscendC::WaitFlag(eventID); +} + +#define TemplateDispatchTypeClass \ + typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \ + bool IsNeedAllgater +#define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater + +using namespace AscendC; +template +class CamMoeDistributeDispatch +{ +public: + __aicore__ inline CamMoeDistributeDispatch(){}; + __aicore__ inline void Init(GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, + GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, GM_ADDR expertTokenNumsOut, + GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, + GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); + __aicore__ inline void Process(); + +private: + __aicore__ inline void SendToSharedExpert(); + __aicore__ inline void SendToMoeExpert(); + __aicore__ inline void AlltoAllDispatch(); + __aicore__ inline void LocalWindowCopy(); + __aicore__ inline void QuantProcess(uint32_t expertIndex); + __aicore__ inline void LocalSharedExpertCopyWindow(uint32_t rankIndex, uint32_t tokenOffset, + uint32_t currendTokenIndex, uint32_t &dynamicScalesLocalIdx); + __aicore__ inline void SetStatus(); + __aicore__ inline void WaitDispatch(); + __aicore__ inline void GetCumSum(LocalTensor &inLocal, LocalTensor &outLocal, int32_t totalCount, + GM_ADDR gmOutputRecvCount); + __aicore__ inline void CreateZeroTensor(LocalTensor &outTensor); + __aicore__ inline void AllGatherSetStatusAndWait(); + __aicore__ inline void ResetStatus(); + __aicore__ inline void QuantInit(GM_ADDR scales); + __aicore__ inline void AllgatherProcessOut(); + __aicore__ inline void UpdataMultiMoeTokenNumsOut(); + __aicore__ inline void UpdataTokenNumsOut(); + __aicore__ inline GM_ADDR GetWindAddrByRankId(uint8_t ctxIdx, const int32_t rankId) + { + uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId_ : tpRankId_; + if (curRankId == rankId) { + return (GM_ADDR)(winContext_[ctxIdx]->localWindowsIn) + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr))->windowsIn) + + winDataSizeOffset_ + rankId * OPT_RANK_OFFSET; + } + + __aicore__ inline GM_ADDR GetWindStateAddrByRankId(uint8_t ctxIdx, const int32_t rankId) + { + uint32_t curRankId = ctxIdx == COMM_EP_IDX ? epRankId_ : tpRankId_; + if (curRankId == rankId) { + return (GM_ADDR)(winContext_[ctxIdx]->localWindowsExp) + dataState_ * WIN_STATE_OFFSET; + } + return (GM_ADDR)(((HcclRankRelationResV2 *)(winContext_[ctxIdx]->remoteRes[rankId].nextDevicePtr)) + ->windowsExp) + + dataState_ * WIN_STATE_OFFSET; + } + + __aicore__ inline uint32_t MIN(uint32_t x, uint32_t y) + { + return (x < y) ? x : y; + } + TPipe *tpipe_{nullptr}; + GlobalTensor xGMTensor_; + GlobalTensor expertIdsGMTensor_; + GlobalTensor scalesGMTensor_; + GlobalTensor expandXOutGMTensor_; + GlobalTensor dynamicScalesOutGMTensor_; + GlobalTensor expertTokenNumsOutGMTensor_; + GlobalTensor windowInQuantTensor_; + GlobalTensor windowInstatusTensor_; + GlobalTensor windowInstatusFp32Tensor_; + GlobalTensor winTpGatherOutGMTensor_; + GlobalTensor fpWinTpGatherOutGMTensor_; + GlobalTensor winTpEpCntGMTensor_; + LocalTensor xTmpTensor_; + LocalTensor tpTmpTensor_; + LocalTensor xInTensor_; + LocalTensor xOutTensor_; + LocalTensor xOutFp32Tensor_; + LocalTensor expertCountTensor_; + LocalTensor expertIdsTensor_; + LocalTensor receivestatusTensor_; + LocalTensor rowMaxTensor_; + LocalTensor statusTensor_; + LocalTensor statusFp32Tensor_; + LocalTensor smoothScalesTensor_; + LocalTensor dynamicScalesTensor_; + TBuf<> dynamicScalesBuf_; + TBuf<> expertCountBuf_; + TBuf<> expertIdsBuf_; + TBuf<> statusBuf_; + TBuf<> gatherMaskOutBuf_; // gather mask输出buf + TBuf<> getTotalBuf_; // 计算totalCnt + TBuf<> scalarBuf_; // 辅助gather tensor定义 + TBuf<> rowMaxBuf_; + TBuf<> receiveDataCastFloatBuf_; + TBuf<> smoothScalesBuf_; + TQueBind xQueue_; // 非量化使用,量化场景接收也可使用 + TQue xInQueue_; // 量化使用,量化前的输入 + TQue xOutQueue_; // 量化使用,量化后的输出 + GM_ADDR expandXOutGM_; + GM_ADDR expandIdxOutGM_; + GM_ADDR expertTokenNumsOutGM_; // 这个输出没有使用 + GM_ADDR sendCountsOutGM_; + GM_ADDR outputRecvCountGM_; + GM_ADDR sendTpCountOutGM_; + GM_ADDR statusSpaceGm_; + GM_ADDR windowGM_; + GM_ADDR tpWindowGM_; + GM_ADDR tpStatusWindowGM_; + GM_ADDR tpLocalWindowGM_; + GM_ADDR tpLocalStatusWindowGM_; + GlobalTensor peerMemsAddrGm_; + // tiling侧已确保数据上限,相乘不会越界,因此统一采用uint32_t进行处理 + uint32_t axisBS_{0}; + uint32_t axisMaxBS_{0}; + uint32_t axisH_{0}; + uint32_t axisK_{0}; + uint32_t aivNum_{0}; + uint32_t sharedUsedAivNum_{0}; + uint32_t moeUsedAivNum_{0}; + uint32_t epWorldSize_{0}; + uint32_t tpWorldSize_{0}; + uint32_t epRankId_{0}; + uint32_t tpGatherRankId_{0}; // gather 对端ID + uint32_t tpRankId_{0}; // 本卡 ID + uint32_t aivId_{0}; // aiv id + uint32_t sharedExpertRankNum_{0}; // 共享专家卡数 + uint32_t moeExpertRankNum_{0}; // moe专家卡数,等于worldSize_ - 共享专家卡数 + uint32_t moeExpertNumPerRank_{0}; + uint32_t moeExpertNum_{0}; + uint32_t totalExpertNum_{0}; + uint32_t bufferSizePerRank_{0}; + uint32_t recvWinBlockNum_{0}; + uint32_t hSize_{0}; + uint32_t hOutSize_{0}; + uint32_t hCommuSize_{0}; + uint32_t scaleParamPad_{0}; + uint32_t axisHCommu_{0}; + uint32_t startExpertId_; + uint32_t endExpertId_; + uint32_t sendExpertNum_; + uint32_t localCopyCoreNum_; + uint32_t totalCnt_; + uint32_t lastCore_{0}; + uint32_t dataState_{0}; + uint32_t stateOffset_{0}; + uint64_t winDataSizeOffset_{0}; + uint64_t expertPerSizeOnWin_{0}; + uint64_t windyquantOffset_; + bool isShareExpertRank_ = false; + bool isQuant_ = false; + float sumTarget_; + uint64_t totalWinSize_{0}; + uint32_t gatherCount_{0}; + uint32_t expertTokenNumsType_{1}; + uint32_t preCnt_{0}; + __gm__ HcclOpResParam *winContext_[COMM_NUM]{nullptr, nullptr}; + // 循环优化使用的变量 + TBuf<> sendTableIdsBuf_; + LocalTensor tableLocalTensor_; + LocalTensor sendCountLocalTensor_; + uint32_t moeExpertRankNumAligned_; + uint32_t moeExpertRankNumInt16Aligned_; + uint32_t tableElemCount_; + bool enableAivOpt_{false}; +}; + +template +__aicore__ inline void CamMoeDistributeDispatch::Init( + GM_ADDR x, GM_ADDR expertIds, GM_ADDR scales, GM_ADDR expandXOut, GM_ADDR dynamicScalesOut, GM_ADDR expandIdxOut, + GM_ADDR expertTokenNumsOut, GM_ADDR sendCountsOut, GM_ADDR outputRecvCount, GM_ADDR tpSendCountsOut, + GM_ADDR workspaceGM, TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) +{ + tpipe_ = pipe; + aivId_ = GetBlockIdx(); + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + GlobalTensor selfDataStatusTensor; + GM_ADDR statusDataSpaceGm; + + winContext_[COMM_EP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + winContext_[COMM_TP_IDX] = (__gm__ HcclOpResParam *)AscendC::GetHcclContext<1>(); // 没有相关公共宏 + + statusDataSpaceGm = (GM_ADDR)(winContext_[COMM_EP_IDX]->localWindowsExp); + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[aivId_ * UB_ALIGN]); + __asm__ __volatile__(""); + dataState_ = selfDataStatusTensor(aivId_ * UB_ALIGN); + if (dataState_ == 0) { + selfDataStatusTensor(aivId_ * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(aivId_ * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + DataCacheCleanAndInvalid( + selfDataStatusTensor[aivId_ * UB_ALIGN]); + __asm__ __volatile__(""); + pipe_barrier(PIPE_ALL); + axisBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + axisH_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + epWorldSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + // axisMaxBS_ = axisBS_; + axisMaxBS_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs / epWorldSize_; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + expertTokenNumsType_ = 0; + totalWinSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.totalWinSize; + moeExpertRankNum_ = epWorldSize_ - sharedExpertRankNum_; + moeExpertNumPerRank_ = moeExpertNum_ / moeExpertRankNum_; + expertPerSizeOnWin_ = axisMaxBS_ * axisH_ * sizeof(XType); + winDataSizeOffset_ = dataState_ * epWorldSize_ * expertPerSizeOnWin_ * moeExpertNumPerRank_; + tpRankId_ = 0; + windowGM_ = GetWindAddrByRankId(COMM_EP_IDX, epRankId_); + statusSpaceGm_ = GetWindStateAddrByRankId(COMM_EP_IDX, epRankId_); + tpGatherRankId_ = tpRankId_ == 0 ? 1 : 0; + axisK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + aivNum_ = 48; + tpWorldSize_ = 1; + xGMTensor_.SetGlobalBuffer((__gm__ XType *)x); + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)expertIds); + expandXOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)expandXOut); + dynamicScalesOutGMTensor_.SetGlobalBuffer((__gm__ float *)dynamicScalesOut); + expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)expertTokenNumsOut); + windowInQuantTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)windowGM_); + windowInstatusTensor_.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_)); + windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(statusSpaceGm_)); + if constexpr (IsNeedAllgater) { + tpLocalWindowGM_ = GetWindAddrByRankId(COMM_TP_IDX, tpRankId_); + tpLocalStatusWindowGM_ = GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_); + tpWindowGM_ = GetWindAddrByRankId(COMM_TP_IDX, tpGatherRankId_); + tpStatusWindowGM_ = GetWindStateAddrByRankId(COMM_TP_IDX, tpGatherRankId_); + winTpGatherOutGMTensor_.SetGlobalBuffer((__gm__ ExpandXOutType *)tpWindowGM_); + fpWinTpGatherOutGMTensor_.SetGlobalBuffer((__gm__ float *)tpWindowGM_); + winTpEpCntGMTensor_.SetGlobalBuffer((__gm__ int32_t *)(tpStatusWindowGM_ + TP_STATE_SIZE)); + } + expandXOutGM_ = expandXOut; + expandIdxOutGM_ = expandIdxOut; // 无GlobalTensor + sendCountsOutGM_ = sendCountsOut; // 无GlobalTensor + outputRecvCountGM_ = outputRecvCount; + sendTpCountOutGM_ = tpSendCountsOut; + isQuant_ = StaticQuant | DynamicQuant; + hSize_ = axisH_ * sizeof(XType); + hOutSize_ = axisH_ * sizeof(ExpandXOutType); // 如有量化,需要量化后通信 + scaleParamPad_ = (isQuant_ ? 128 : 0); // 预留128B给量化参数,实际只使用了4B(fp32) + hCommuSize_ = hOutSize_ + scaleParamPad_; + axisHCommu_ = hCommuSize_ / sizeof(ExpandXOutType); + if (sharedExpertRankNum_ != 0) { // 后面的卡才需要发给共享专家发数据 + sharedUsedAivNum_ = aivNum_ / (axisK_ + 1); // 均等分,取整 + if (sharedUsedAivNum_ == 0) { + sharedUsedAivNum_ = 1; + } + } + moeUsedAivNum_ = aivNum_ - sharedUsedAivNum_; + bufferSizePerRank_ = 32 * hSize_; + recvWinBlockNum_ = epWorldSize_ * moeExpertNumPerRank_; + isShareExpertRank_ = (epRankId_ < sharedExpertRankNum_) ? true : false; + windyquantOffset_ = epWorldSize_ * axisMaxBS_ * hOutSize_; + GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + DataCacheCleanAndInvalid( + selfStatusTensor[aivId_ * UB_ALIGN]); + int32_t state = selfStatusTensor(aivId_ * UB_ALIGN); + stateOffset_ = (recvWinBlockNum_ > 512) ? (STATE_OFFSET / 2) : STATE_OFFSET; + tpipe_->InitBuffer(statusBuf_, recvWinBlockNum_ * UB_ALIGN); // expertNum * 32B + statusTensor_ = statusBuf_.Get(); // 保存发送数据量及flag,同时用于计算windows中的偏移 + Duplicate(statusTensor_, 0, recvWinBlockNum_ * 8); // 8 = UB_ALIGN / sizeof(int32_t) + if (state == 0) { + sumTarget_ = (float)1.0; + selfStatusTensor(aivId_ * UB_ALIGN) = 0x3F800000; + uint64_t mask[2] = {0x101010101010101, 0}; // 一次性操作256字节,也是64个int32_t,每8个数将首个设置为0x3F800000 + Duplicate(statusTensor_, 0x3F800000, mask, recvWinBlockNum_ / 8, 1, 8); // 0x3F800000是float的1 + } else { + sumTarget_ = 0.0; + selfStatusTensor(aivId_ * UB_ALIGN) = 0; + } + DataCacheCleanAndInvalid( + selfStatusTensor[aivId_ * UB_ALIGN]); + tpipe_->InitBuffer(xQueue_, BUFFER_NUM, hCommuSize_); // 14k *2 + if (isQuant_) { + QuantInit(scales); + } + uint32_t expertIdsSize = Ceil(axisBS_ * axisK_ * sizeof(int32_t), UB_ALIGN) * UB_ALIGN; // 约束32对齐 + tpipe_->InitBuffer(expertIdsBuf_, expertIdsSize); // BS * K * 4 + expertIdsTensor_ = expertIdsBuf_.Get(); + tpipe_->InitBuffer(expertCountBuf_, expertIdsSize); // BS * K * 4 + expertCountTensor_ = expertCountBuf_.Get(); + + tpipe_->InitBuffer(gatherMaskOutBuf_, recvWinBlockNum_ * sizeof(float)); // worldsize * 4B + tpipe_->InitBuffer(getTotalBuf_, + epWorldSize_ * moeExpertNumPerRank_ * sizeof(int32_t)); // worldsize * 单卡专家数 * 4B + tpipe_->InitBuffer(scalarBuf_, UB_ALIGN * 2); // 72B + + moeExpertRankNumAligned_ = Ceil(moeExpertNum_, TABLE_ELEM_COUNT_PER_BLOCK) * TABLE_ELEM_COUNT_PER_BLOCK; + if (axisBS_ <= LOOP_OPT_MAX_BS && moeExpertRankNumAligned_ <= LOOP_OPT_MAX_MOE_RANK && + axisK_ % TOPK_ELEM_COUNT_PER_BLOCK == 0) { + // UB空间限制BS不大于64、路由专家数量不大于256;对齐要求限制axisK_是8的倍数 + enableAivOpt_ = true; + moeExpertRankNumInt16Aligned_ = moeExpertRankNumAligned_ / 2; // 每个int16_t装2个uint8_t + tableElemCount_ = (axisBS_ + 1) * moeExpertRankNumAligned_; // 额外加一行(首行全0) + + tpipe_->InitBuffer(sendTableIdsBuf_, tableElemCount_ * sizeof(countType)); + tableLocalTensor_ = sendTableIdsBuf_.Get(); + sendCountLocalTensor_ = tableLocalTensor_[axisBS_ * moeExpertRankNumAligned_]; // 计算完成后,最后一行为count + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::QuantInit(GM_ADDR scales) +{ + tpipe_->InitBuffer(xInQueue_, BUFFER_NUM, hSize_); // 14K *2 + tpipe_->InitBuffer(xOutQueue_, BUFFER_NUM, hCommuSize_); // 7K *2 + scalesGMTensor_.SetGlobalBuffer((__gm__ float *)scales); + uint32_t hFp32Size = axisH_ * sizeof(float); + if constexpr (DynamicQuant) { + tpipe_->InitBuffer(rowMaxBuf_, UB_ALIGN); // 32B + } + tpipe_->InitBuffer(receiveDataCastFloatBuf_, 1 * hFp32Size); // 28KB + tpipe_->InitBuffer(smoothScalesBuf_, axisH_ * sizeof(float)); // 28KB + smoothScalesTensor_ = smoothScalesBuf_.Get(); + tpipe_->InitBuffer(dynamicScalesBuf_, axisBS_ * sizeof(float)); // 32 * 4 + dynamicScalesTensor_ = dynamicScalesBuf_.Get(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SendToSharedExpert() +{ + uint32_t sendTokenNum = axisBS_ / sharedUsedAivNum_; // 每个aiv需要发送的token数 + uint32_t remainderTokenNum = axisBS_ % sharedUsedAivNum_; // 余数 + uint32_t newAivId = aivId_ - moeUsedAivNum_; // 由于是后面的核作为发送的共享专家,因此需要换算 + uint32_t startTokenId = sendTokenNum * newAivId; // 每个aiv发送时的起始rankid + if (newAivId < remainderTokenNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendTokenNum += 1; + startTokenId += newAivId; + } else { + startTokenId += remainderTokenNum; + } + if (startTokenId >= axisBS_) { + return; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + for (uint32_t tokenShuffleIndex = 0; tokenShuffleIndex < sendTokenNum; ++tokenShuffleIndex) { + uint32_t tokenIndex = startTokenId + ((tokenShuffleIndex + epRankId_) % sendTokenNum); + uint32_t temp = (epRankId_ * axisBS_) / sharedExpertRankNum_; + uint32_t moeOnShareRank = Ceil((tokenIndex + 1 + temp) * sharedExpertRankNum_, axisBS_) - 1 - epRankId_; // dst + uint32_t preCnt = (moeOnShareRank + epRankId_) * axisBS_ / sharedExpertRankNum_ - + epRankId_ * axisBS_ / sharedExpertRankNum_; // 发给该共享专家已经有多少token数据 + GlobalTensor dstWinGMTensor; + dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)(GetWindAddrByRankId(COMM_EP_IDX, moeOnShareRank) + + expertPerSizeOnWin_ * epRankId_)); + if constexpr (DynamicQuant || StaticQuant) { + xInTensor_ = xInQueue_.AllocTensor(); + DataCopy(xInTensor_, xGMTensor_[tokenIndex * axisH_], axisH_); // 约束对齐 + xInQueue_.EnQue(xInTensor_); + xInTensor_ = xInQueue_.DeQue(); + xOutTensor_ = xOutQueue_.AllocTensor(); + QuantProcess(0); + xOutQueue_.EnQue(xOutTensor_); + + xOutTensor_ = xOutQueue_.DeQue(); + if (isShareExpertRank_) { + xOutFp32Tensor_ = xOutTensor_.template ReinterpretCast(); + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + DataCopyPad(dynamicScalesOutGMTensor_[tokenIndex], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenIndex * axisHCommu_], xOutTensor_, axisHCommu_); // 约束对齐 + } + DataCopy(expandXOutGMTensor_[tokenIndex * axisH_], xOutTensor_, axisH_); // 约束对齐 + } else { + DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu_], xOutTensor_, axisHCommu_); // 约束对齐 + } + xOutQueue_.FreeTensor(xOutTensor_); + } else { + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, xGMTensor_[tokenIndex * axisH_], axisH_); // 约束对齐 + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if (isShareExpertRank_) { + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenIndex * axisHCommu_], xTmpTensor_, axisHCommu_); + } + DataCopy(expandXOutGMTensor_[tokenIndex * axisHCommu_], xTmpTensor_, axisHCommu_); + } else { + DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommu_], xTmpTensor_, axisHCommu_); // 约束对齐 + } + xQueue_.FreeTensor(xTmpTensor_); + } + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SendToMoeExpert() +{ + uint32_t expertIdsCnt = axisBS_ * axisK_; + uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; // 每个aiv需要发送的token数 + uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; // 余数 + uint32_t startTokenId = sendTokenNum * aivId_; // 每个aiv发送时的起始rankid + if (aivId_ < remainderTokenNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendTokenNum += 1; + startTokenId += aivId_; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + GlobalTensor dstWinGMTensor; + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + int32_t dstExpertId = expertIdsTensor_(tokenIndex); + if (dstExpertId < 0) { + continue; + } + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank_ + sharedExpertRankNum_; + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindAddrByRankId(COMM_EP_IDX, tempRankId) + + (expertPerSizeOnWin_ * + (epRankId_ * moeExpertNumPerRank_ + dstExpertId % moeExpertNumPerRank_)) + + hCommuSize_ * expertCountTensor_(tokenIndex)); // 计算地址偏移 + dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType *)rankGM); + if constexpr (DynamicQuant || StaticQuant) { + xInTensor_ = xInQueue_.AllocTensor(); + DataCopy(xInTensor_, xGMTensor_[tokenIndex / axisK_ * axisH_], axisH_); // 约束对齐 + xInQueue_.EnQue(xInTensor_); + xInTensor_ = xInQueue_.DeQue(); + xOutTensor_ = xOutQueue_.AllocTensor(); + uint32_t expertIndex = sharedExpertRankNum_ != 0 ? (dstExpertId + 1) : dstExpertId; + QuantProcess(expertIndex); + xOutQueue_.EnQue(xOutTensor_); + + xOutTensor_ = xOutQueue_.DeQue(); + DataCopy(dstWinGMTensor, xOutTensor_, axisHCommu_); // 约束对齐 + xOutQueue_.FreeTensor(xOutTensor_); + } else { + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, xGMTensor_[tokenIndex / axisK_ * axisH_], axisH_); // 约束对齐 + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + DataCopy(dstWinGMTensor, xTmpTensor_, axisHCommu_); // 约束对齐 + xQueue_.FreeTensor(xTmpTensor_); + } + } + if (aivId_ == (moeUsedAivNum_ - 1) && (!enableAivOpt_)) { + // 不启用循环优化时,这里才需要写出结果 + GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)expandIdxOutGM_); + DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; + DataCopyPad(expandIdxGMTensor, expertCountTensor_, expertIdsCntParams); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AlltoAllDispatch() +{ + uint32_t expertIdsCnt = axisBS_ * axisK_; + DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::TQueSync expertCntLocalSync; + expertCntLocalSync.SetFlag(0); + expertCntLocalSync.WaitFlag(0); + if (enableAivOpt_) { + LocalTensor tableInt16LocalTensor_ = tableLocalTensor_.template ReinterpretCast(); + Duplicate(tableInt16LocalTensor_, (int16_t)0, tableElemCount_ / 2); // 清零 + SyncFunc(); + for (int tokenIndex = 0; tokenIndex < expertIdsCnt; ++tokenIndex) { // 填表。默认为0,发送置1 + int expertId = expertIdsTensor_(tokenIndex); + if (expertId < 0) { + continue; + } + tableLocalTensor_((tokenIndex / axisK_ + 1) * moeExpertRankNumAligned_ + expertId) = 1; + } + pipe_barrier(PIPE_ALL); + + // 分核,确定每个核要处理的token + uint32_t sendTokenNum = expertIdsCnt / moeUsedAivNum_; + uint32_t remainderTokenNum = expertIdsCnt % moeUsedAivNum_; + uint32_t startTokenId = sendTokenNum * aivId_; + if (aivId_ < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += aivId_; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + uint32_t startTokenRow = startTokenId / axisK_; + uint32_t endTokenRow = (endTokenId + axisK_ - 1) / axisK_; + + for (int row = 1; row <= axisBS_; ++row) { + Add(tableInt16LocalTensor_[row * moeExpertRankNumInt16Aligned_], + tableInt16LocalTensor_[row * moeExpertRankNumInt16Aligned_], + tableInt16LocalTensor_[(row - 1) * moeExpertRankNumInt16Aligned_], moeExpertRankNumInt16Aligned_); + pipe_barrier(PIPE_V); + } + + // 计算完成后,下标为的i的行为下标为i+1的token在远端的偏移,最后一行为总count + GlobalTensor expandIdxGMTensor; + if (aivId_ < moeUsedAivNum_) { + SyncFunc(); + for (int row = startTokenRow; row < endTokenRow; ++row) { + for (int expertIndex = 0; expertIndex < axisK_; ++expertIndex) { + int32_t expertId = expertIdsTensor_(row * axisK_ + expertIndex); + if (expertId < 0) { + continue; + } + expertCountTensor_(row * axisK_ + expertIndex) = + (int32_t)tableLocalTensor_(row * moeExpertRankNumAligned_ + expertId); + } + SyncFunc(); + expandIdxGMTensor.SetGlobalBuffer( + (__gm__ int32_t *)(expandIdxOutGM_ + row * axisK_ * sizeof(uint32_t))); + DataCopy(expandIdxGMTensor, expertCountTensor_[row * axisK_], axisK_); + } + } + + // 分核,确定每个核要set status的rank + uint32_t preTotalExpertNum = sharedExpertRankNum_ + moeExpertNum_; + uint32_t preSendExpertNum = preTotalExpertNum / aivNum_; + uint32_t preRemainderRankNum = preTotalExpertNum % aivNum_; + uint32_t preStartExpertId = preSendExpertNum * aivId_; + if (aivId_ < preRemainderRankNum) { + preSendExpertNum += 1; + preStartExpertId += aivId_; + } else { + preStartExpertId += preRemainderRankNum; + } + uint32_t preEndExpertId = preStartExpertId + preSendExpertNum; + preStartExpertId = preStartExpertId >= sharedExpertRankNum_ ? preStartExpertId : sharedExpertRankNum_; + + SyncFunc(); + for (int32_t tmpExpertId = preStartExpertId; tmpExpertId < preEndExpertId; ++tmpExpertId) { + statusTensor_(tmpExpertId * INT32_NUM_PER_BLOCK + 1) = + (int32_t)sendCountLocalTensor_(tmpExpertId - sharedExpertRankNum_); + } + } else { + for (uint32_t tokenIndex = 0; tokenIndex < expertIdsCnt; ++tokenIndex) { + // 防止越界,越界判断(expertId >= epWorldSize_) || (expertId < sharedExpertRankNum_) + int32_t expertId = expertIdsTensor_(tokenIndex) + sharedExpertRankNum_; + if (expertId < 0) { + continue; + } + expertCountTensor_(tokenIndex) = statusTensor_(expertId * INT32_NUM_PER_BLOCK + 1); + statusTensor_(expertId * INT32_NUM_PER_BLOCK + 1)++; + } + } + if (!isShareExpertRank_) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum_; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_ - + (curSatatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_; + statusTensor_((curSatatusExpId)*INT32_NUM_PER_BLOCK + 1) = curExpertCnt; + } + } + if ((sharedExpertRankNum_ != 0) && (aivId_ >= moeUsedAivNum_)) { // 后面的核进行发给共享专家 + SendToSharedExpert(); + return; + } + SendToMoeExpert(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::SetStatus() +{ + pipe_barrier(PIPE_ALL); + SyncAll(); + totalExpertNum_ = sharedExpertRankNum_ + moeExpertNum_; + sendExpertNum_ = totalExpertNum_ / aivNum_; // 每个aiv需要处理的专家数 + uint32_t remainderRankNum = totalExpertNum_ % aivNum_; + startExpertId_ = sendExpertNum_ * aivId_; // + sharedExpertRankNum_, 每个aiv发送的起始rankid + if (aivId_ < remainderRankNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendExpertNum_ += 1; + startExpertId_ += aivId_; + } else { + startExpertId_ += remainderRankNum; + } + endExpertId_ = startExpertId_ + sendExpertNum_; + if (startExpertId_ >= totalExpertNum_) { // 多余的核return + return; + } + GlobalTensor rankGMTensor; + uint32_t offset = stateOffset_ * epRankId_; + for (uint32_t rankIndex = startExpertId_; rankIndex < endExpertId_; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank_ > 1 && (rankIndex >= sharedExpertRankNum_)) { + dstRankId = ((rankIndex - sharedExpertRankNum_) / moeExpertNumPerRank_ + sharedExpertRankNum_); + offset = + (epRankId_ + (rankIndex - sharedExpertRankNum_) % moeExpertNumPerRank_ * epWorldSize_) * stateOffset_; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_EP_IDX, dstRankId) + offset); // 计算地址偏移 + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); // 8时数据大小,按32对齐拷贝 + } + SyncFunc(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::QuantProcess(uint32_t expertIndex) +{ + float dynamicScale = 0.0; + LocalTensor floatLocalTemp; + floatLocalTemp = receiveDataCastFloatBuf_.Get(); + Cast(floatLocalTemp, xInTensor_, RoundMode::CAST_NONE, axisH_); + xInQueue_.FreeTensor(xInTensor_); + pipe_barrier(PIPE_V); + if constexpr (IsSmoothScaleExist) { + if constexpr (DynamicQuant) { + SyncFunc(); // ub复用,循环同步 + } + DataCopy(smoothScalesTensor_, scalesGMTensor_[expertIndex * axisH_], axisH_); + SyncFunc(); + Mul(floatLocalTemp, floatLocalTemp, smoothScalesTensor_, axisH_); + pipe_barrier(PIPE_V); + } + if constexpr (DynamicQuant) { + LocalTensor floatLocalAbsTemp = smoothScalesBuf_.Get(); + rowMaxTensor_ = rowMaxBuf_.Get(); + Abs(floatLocalAbsTemp, floatLocalTemp, axisH_); + pipe_barrier(PIPE_V); + ReduceMax(rowMaxTensor_, floatLocalAbsTemp, floatLocalAbsTemp, axisH_, false); + SyncFunc(); + dynamicScale = float(127.0) / rowMaxTensor_.GetValue(0); + SyncFunc(); + Muls(floatLocalTemp, floatLocalTemp, dynamicScale, axisH_); + pipe_barrier(PIPE_V); + } + LocalTensor halfLocalTemp = floatLocalTemp.ReinterpretCast(); + LocalTensor int32LocalTemp = floatLocalTemp.ReinterpretCast(); + Cast(int32LocalTemp, floatLocalTemp, RoundMode::CAST_RINT, axisH_); + pipe_barrier(PIPE_V); + SetDeqScale((half)1.000000e+00f); + PipeBarrier(); + Cast(halfLocalTemp, int32LocalTemp, RoundMode::CAST_ROUND, axisH_); + pipe_barrier(PIPE_V); + Cast(xOutTensor_, halfLocalTemp, RoundMode::CAST_TRUNC, axisH_); + floatLocalTemp = xOutTensor_.template ReinterpretCast(); + floatLocalTemp.SetValue(axisH_ / sizeof(float), float(1.0) / dynamicScale); // int8->float32 +} + +template +__aicore__ inline void CamMoeDistributeDispatch::LocalSharedExpertCopyWindow( + uint32_t rankIndex, uint32_t tokenOffset, uint32_t currendTokenIndex, uint32_t &dynamicScalesLocalIdx) +{ + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, + windowInQuantTensor_[rankIndex * (expertPerSizeOnWin_ / sizeof(ExpandXOutType)) + + currendTokenIndex * axisHCommu_], + axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if constexpr (DynamicQuant || StaticQuant) { + pipe_barrier(PIPE_ALL); + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + dynamicScalesTensor_.SetValue(dynamicScalesLocalIdx++, xOutFp32Tensor_.GetValue(axisH_ / sizeof(float))); + pipe_barrier(PIPE_ALL); + } + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[tokenOffset * axisH_], xTmpTensor_, axisH_); + } + DataCopy(expandXOutGMTensor_[tokenOffset * axisH_], xTmpTensor_, axisH_); + xQueue_.FreeTensor(xTmpTensor_); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::WaitDispatch() +{ + uint32_t rscvStatusNum = isShareExpertRank_ ? epWorldSize_ : recvWinBlockNum_; + uint32_t recStatusNumPerCore = rscvStatusNum / aivNum_; // 每个aiv需要处理的专家数 + uint32_t remainderRankNum = rscvStatusNum % aivNum_; + uint32_t startStatusIndex = recStatusNumPerCore * aivId_; // + sharedExpertRankNum_, 每个aiv发送的起始rankid + if (aivId_ < remainderRankNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + recStatusNumPerCore += 1; + startStatusIndex += aivId_; + } else { + startStatusIndex += remainderRankNum; + } + if (startStatusIndex >= rscvStatusNum) { + SyncAll(); + return; + } + LocalTensor gatherMaskOutTensor = gatherMaskOutBuf_.Get(); + LocalTensor gatherTmpTensor = scalarBuf_.GetWithOffset(UB_ALIGN / sizeof(uint32_t), 0); + gatherTmpTensor.SetValue(0, 1); + LocalTensor statusSumOutTensor = scalarBuf_.GetWithOffset(UB_ALIGN / sizeof(float), UB_ALIGN); + statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + uint32_t mask = 1; // gatherMask + sum 相关参数 + uint64_t rsvdCnt = 0; + SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore}; + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget_ * recStatusNumPerCore) - (float)0.5; + float maxTarget = (sumTarget_ * recStatusNumPerCore) + (float)0.5; + DataCopyParams intriParams{static_cast(recStatusNumPerCore), 1, + static_cast((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride为15个block + SyncFunc(); + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset_ / sizeof(float)], + intriParams); + SyncFunc(); + GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + pipe_barrier(PIPE_V); + Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams); + SyncFunc(); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + SyncAll(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::GetCumSum(LocalTensor &inLocal, + LocalTensor &outLocal, + int32_t totalCount, + GM_ADDR gmOutputRecvCount) +{ + statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + DataCopyParams intriParams{static_cast(recvWinBlockNum_), 1, + static_cast((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride为15个block + DataCopy(statusTensor_, windowInstatusTensor_, intriParams); + SyncFunc(); + if (isShareExpertRank_) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum_; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_ - + (curSatatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_; + statusTensor_((curSatatusExpId)*INT32_NUM_PER_BLOCK + 1) = curExpertCnt; + } + } + outLocal = gatherMaskOutBuf_.Get(); // 内存复用 + LocalTensor getTotalLocal = getTotalBuf_.Get(); + // gather mask在一起 + TBuf<> gatherTmpBuf; + TBuf<> workLocalBuf; + tpipe_->InitBuffer(gatherTmpBuf, sizeof(uint32_t) * recvWinBlockNum_ / 4); + LocalTensor gatherTmpTensor = gatherTmpBuf.Get(); + Duplicate(gatherTmpTensor, (uint32_t)33686018, recvWinBlockNum_ / 4); // 0000 0010 0000 0010 0000 0010 0000 0010 + PipeBarrier(); + uint32_t mask = recvWinBlockNum_ * 8; // 512 / 32 + uint64_t rsvdCnt = 0; + GatherMask(outLocal, inLocal, gatherTmpTensor, true, mask, {1, 1, 0, 0}, rsvdCnt); + AscendC::GlobalTensor recvCountTensor; + recvCountTensor.SetGlobalBuffer((__gm__ int32_t *)gmOutputRecvCount); + uint32_t localExpertNum = isShareExpertRank_ ? 1 : moeExpertNumPerRank_; + AscendC::DataCopyExtParams dataCopyParams = { + 1U, static_cast(localExpertNum * epWorldSize_ * sizeof(int32_t)), 0U, 0U, 0U}; + SyncFunc(); + AscendC::DataCopyPad(recvCountTensor, outLocal.ReinterpretCast(), dataCopyParams); + SyncFunc(); + // 再用cumsum累加,按照列相加 + int typeSize = sizeof(int32_t); + int32_t elementsPerBlock = 32 / typeSize; + int32_t elementsPerRepeat = 256 / typeSize; + int32_t firstMaxRepeat = epWorldSize_; + int32_t iter1OutputCount = firstMaxRepeat; + int32_t iter1AlignEnd = ((iter1OutputCount + elementsPerBlock - 1) / elementsPerBlock) * elementsPerBlock; + int32_t finalWorkLocalNeedSize = iter1AlignEnd; + tpipe_->InitBuffer(workLocalBuf, finalWorkLocalNeedSize * sizeof(int32_t)); + LocalTensor workLocalTensor = workLocalBuf.Get(); + LocalTensor tmpFp32 = outLocal.ReinterpretCast(); + PipeBarrier(); + ReduceSum(getTotalLocal, tmpFp32, workLocalTensor, epWorldSize_); + totalCnt_ = getTotalLocal.ReinterpretCast().GetValue(0); + PipeBarrier(); + ReduceSum(tmpFp32, tmpFp32, workLocalTensor, totalCount); + PipeBarrier(); +} + +template +__aicore__ inline void +CamMoeDistributeDispatch::CreateZeroTensor(LocalTensor &outLocal) +{ + TBuf<> outBuf; + tpipe_->InitBuffer(outBuf, UB_ALIGN); + outLocal = outBuf.Get(); + for (uint32_t i = 0; i < 2; i++) { + outLocal.SetValue(i, 0); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::LocalWindowCopy() +{ + uint32_t totalMoeExpert = 0; + LocalTensor outCountLocal; + if (isShareExpertRank_) { + totalMoeExpert = epWorldSize_; + } else { + totalMoeExpert = epWorldSize_ * moeExpertNumPerRank_; + } + sendExpertNum_ = totalMoeExpert / aivNum_; // 每个aiv需要处理的专家数 + uint32_t remainderRankNum = totalMoeExpert % aivNum_; + startExpertId_ = sendExpertNum_ * aivId_; // + sharedExpertRankNum_, 每个aiv发送的起始rankid + if (aivId_ < remainderRankNum) { // 前remainderRankNum个aiv需要多发1个卡的数据 + sendExpertNum_ += 1; + startExpertId_ += aivId_; + } else { + startExpertId_ += remainderRankNum; + } + endExpertId_ = startExpertId_ + sendExpertNum_; + if (startExpertId_ >= totalMoeExpert) { // 多余的核return + return; + } + GetCumSum(statusTensor_, outCountLocal, startExpertId_ + 1, outputRecvCountGM_); + uint32_t index = 0; + uint32_t beginIdx = 0; + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + for (uint32_t index = startExpertId_; index < endExpertId_; index++) { + uint32_t i = index - startExpertId_; + if (i > 0) { + outCountLocal.SetValue(i, outCountLocal.GetValue(i - 1) + outCountLocal.GetValue(index)); + } + uint32_t count = statusTensor_.GetValue(index * INT32_NUM_PER_BLOCK + 1); + beginIdx = outCountLocal.GetValue(i) - count; + if constexpr (IsNeedAllgater) { + gatherCount_ += count; + } + if (i == 0) { + preCnt_ = beginIdx; + } + if (isShareExpertRank_) { + if (index < sharedExpertRankNum_) { // 共享专家前面排布的是本卡数据,只需要统计epRecvCnt,不需要去搬出 + beginIdx += count; + continue; + } + } + uint32_t winOffset = index; + if (!isShareExpertRank_) { + if (moeExpertNumPerRank_ > 1) { + winOffset = + index % epWorldSize_ * moeExpertNumPerRank_ + index / epWorldSize_; // 转换成数据区的排布偏移 + } + } + GM_ADDR wAddr = (__gm__ uint8_t *)(windowGM_) + winOffset * expertPerSizeOnWin_; + GlobalTensor tokGlobal; + GlobalTensor expandXOutGlobal; + for (uint32_t j = 0; j < count; j++) { + tokGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(wAddr + j * hCommuSize_)); + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, tokGlobal, axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + if constexpr (DynamicQuant || StaticQuant) { + pipe_barrier(PIPE_ALL); + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + DataCopyPad(dynamicScalesOutGMTensor_[beginIdx + j], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + pipe_barrier(PIPE_ALL); + } + if constexpr (IsNeedAllgater) { + DataCopy(winTpGatherOutGMTensor_[(beginIdx + j) * axisHCommu_], xTmpTensor_, axisHCommu_); + } + expandXOutGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(expandXOutGM_) + (beginIdx + j) * axisH_, + axisH_); + DataCopy(expandXOutGlobal, xTmpTensor_, axisH_); + xQueue_.FreeTensor(xTmpTensor_); + } + beginIdx += count; + } + if constexpr (!IsNeedAllgater) { + totalCnt_ = beginIdx; + } + lastCore_ = MIN(totalMoeExpert, aivNum_) - 1; + if constexpr (IsNeedAllgater) { + DataCopyExtParams dataCopyOutParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPad(winTpEpCntGMTensor_[startExpertId_], outCountLocal, dataCopyOutParams); + } + DataCopyExtParams dataCopyOutParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + DataCopyPad(sendCountsGlobal[startExpertId_], outCountLocal, dataCopyOutParams); + PipeBarrier(); +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AllGatherSetStatusAndWait() +{ + pipe_barrier(PIPE_ALL); + if (startExpertId_ >= totalExpertNum_) { + return; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpGatherRankId_) + stateOffset_ * aivId_); + GlobalTensor tpwindowInstatusFp32Tensor_; + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(rankGM)); + statusTensor_(aivId_ * INT32_NUM_PER_BLOCK + 1) = gatherCount_; + statusTensor_(aivId_ * INT32_NUM_PER_BLOCK + 2) = preCnt_; + LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + statusFp32Tensor_(aivId_ * 8) = sumTarget_; + SyncFunc(); + DataCopy(tpwindowInstatusFp32Tensor_, statusFp32Tensor_[aivId_ * 8], + UB_ALIGN); // 12是数据大小,按32对齐拷贝 + SyncFunc(); + float sumOfFlag = static_cast(-1.0); + rankGM = + (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_) + stateOffset_ * aivId_); // 计算地址偏移 + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)(rankGM)); + while (sumOfFlag != sumTarget_) { + DataCopy(statusFp32Tensor_, tpwindowInstatusFp32Tensor_, UB_ALIGN); + SyncFunc(); + sumOfFlag = statusFp32Tensor_.GetValue(0); + SyncFunc(); + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::AllgatherProcessOut() +{ + if (startExpertId_ >= totalExpertNum_) { + return; + } + // 获取需要allgather的tokens数量 + GlobalTensor tpwindowInstatusFp32Tensor_; + GM_ADDR rankGM = (__gm__ uint8_t *)(GetWindStateAddrByRankId(COMM_TP_IDX, tpRankId_) + stateOffset_ * aivId_); + tpwindowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)rankGM); + LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + DataCopy(statusFp32Tensor_, tpwindowInstatusFp32Tensor_, UB_ALIGN); + SyncFunc(); + uint32_t coreGatherCount = statusFp32Tensor_.ReinterpretCast().GetValue(1); + uint32_t preCount = statusFp32Tensor_.ReinterpretCast().GetValue(2); + gatherCount_ = coreGatherCount; + preCnt_ = preCount; + GlobalTensor sendCountsGlobal; + GlobalTensor tpGlobal; + // 搬运另一个tp域卡传来的epRcvCnt + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + tpGlobal.SetGlobalBuffer((__gm__ int32_t *)(tpLocalStatusWindowGM_ + TP_STATE_SIZE)); + DataCopyExtParams dataCopyParams = {1U, static_cast(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U}; + DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + tpTmpTensor_ = xQueue_.AllocTensor(); + DataCopyPad(tpTmpTensor_, tpGlobal[startExpertId_], dataCopyParams, copyPadParams); + xQueue_.EnQue(tpTmpTensor_); + tpTmpTensor_ = xQueue_.DeQue(); + DataCopyPad(sendCountsGlobal[epWorldSize_ + startExpertId_], tpTmpTensor_, dataCopyParams); + xQueue_.FreeTensor(tpTmpTensor_); + if (coreGatherCount == 0) { + return; + } + // 输出起始偏移本卡数据 + GlobalTensor tokGlobal; + GlobalTensor expandXOutGlobal; + DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + for (uint32_t i = 0; i < coreGatherCount; i++) { + tokGlobal.SetGlobalBuffer((__gm__ ExpandXOutType *)(tpLocalWindowGM_ + (preCount + i) * hCommuSize_)); + xTmpTensor_ = xQueue_.AllocTensor(); + DataCopy(xTmpTensor_, tokGlobal, axisHCommu_); + xQueue_.EnQue(xTmpTensor_); + xTmpTensor_ = xQueue_.DeQue(); + expandXOutGlobal.SetGlobalBuffer( + (__gm__ ExpandXOutType *)(expandXOutGM_ + (preCount + totalCnt_ + i) * hOutSize_)); + DataCopy(expandXOutGlobal, xTmpTensor_, axisH_); + if constexpr (StaticQuant || DynamicQuant) { + xOutFp32Tensor_ = xTmpTensor_.template ReinterpretCast(); + DataCopyPad(dynamicScalesOutGMTensor_[preCount + totalCnt_ + i], xOutFp32Tensor_[axisH_ / sizeof(float)], + dataCopyParamsFloat); + } + xQueue_.FreeTensor(xTmpTensor_); + } +} + +// 更新多专家卡上的tokenNumsOut tensor +template +__aicore__ inline void CamMoeDistributeDispatch::UpdataMultiMoeTokenNumsOut() +{ + uint32_t tokenSums = 0; + GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_)); + for (uint32_t localMoeIndex = 0; localMoeIndex < moeExpertNumPerRank_; ++localMoeIndex) { + if (localMoeIndex == 0) { + DataCacheCleanAndInvalid( + sendCountsGlobal[epWorldSize_ - 1]); + uint32_t firstMoeCnt = sendCountsGlobal.GetValue(epWorldSize_ - 1); + tokenSums = firstMoeCnt + gatherCount_; + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + } else { + uint32_t preIndex = epWorldSize_ * (localMoeIndex - 1) + epWorldSize_ - 1; + uint32_t curIndex = epWorldSize_ * localMoeIndex + epWorldSize_ - 1; + DataCacheCleanAndInvalid( + sendCountsGlobal[preIndex]); + DataCacheCleanAndInvalid( + sendCountsGlobal[curIndex]); + uint32_t preMoeIndexCnt = sendCountsGlobal.GetValue(preIndex); + uint32_t curMoeIndexCnt = sendCountsGlobal.GetValue(curIndex); + tokenSums = + ((expertTokenNumsType_ == 0) ? tokenSums : 0) + (curMoeIndexCnt - preMoeIndexCnt) + gatherCount_; + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + } + } +} + +// 更新tokenNumsOut tensor +template +__aicore__ inline void CamMoeDistributeDispatch::UpdataTokenNumsOut() +{ + // 最后一个核做更新,Moe专家只有最后一个核有计算出所有 sendCountsGlobal + if (!isShareExpertRank_ && moeExpertNumPerRank_ > 1) { + SyncAll(); + if (aivId_ != lastCore_) return; + SyncFunc(); + UpdataMultiMoeTokenNumsOut(); + } else { + if (aivId_ != lastCore_) return; + uint32_t tokenNum = 0; + // Moe专家token总数在Cumsum内计算得出 + tokenNum = totalCnt_; + if constexpr (IsNeedAllgater) { + tokenNum += preCnt_; + tokenNum += gatherCount_; + } + expertTokenNumsOutGMTensor_.SetValue(0, tokenNum); + DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_); + } + // token总数 = 其他专家搬进来的token数 + allgather拿到的另一张卡token数 + if constexpr (IsNeedAllgater) { + GlobalTensor sendTpCountsGlobal; + sendTpCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendTpCountOutGM_)); + sendTpCountsGlobal.SetValue(tpRankId_, totalCnt_); + sendTpCountsGlobal.SetValue(tpGatherRankId_, gatherCount_ + preCnt_); + DataCacheCleanAndInvalid( + sendTpCountsGlobal); // 当前tpId只会为0或1,只需要刷一次Cache + } +} + +template +__aicore__ inline void CamMoeDistributeDispatch::Process() +{ + if ASCEND_IS_AIV { // 全aiv处理 + AlltoAllDispatch(); + SetStatus(); + WaitDispatch(); + LocalWindowCopy(); + if constexpr (IsNeedAllgater) { + AllGatherSetStatusAndWait(); + AllgatherProcessOut(); + } + UpdataTokenNumsOut(); + } +} + +} // namespace MoeDistributeDispatchImpl +#endif // CAM_MOE_DISTRIBUTE_DISPATCH_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h new file mode 100644 index 00000000000..97a329da75d --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef DISPATCH_GMM_COMBINE_DECODE_BASE_H +#define DISPATCH_GMM_COMBINE_DECODE_BASE_H + +#include "moe_distribute_base.h" + +#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG +#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG + +#endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h new file mode 100644 index 00000000000..b8a831ac39c --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef DISPATCH_GMM_COMBINE_DECODE_TILING_H +#define DISPATCH_GMM_COMBINE_DECODE_TILING_H + +#include "kernel_tiling/kernel_tiling.h" + +struct DispatchGmmCombineDecodeInfo { + uint32_t epRankSize; // epRankSize + uint32_t epRankId; // epRankId + uint32_t moeExpertNum; // moe expert number + uint32_t moeExpertNumPerRank; // moe expert number per rank + uint32_t sharedExpertNum; // shared expert number + uint32_t sharedExpertRankNum; // shared expert rank number + uint32_t quantMode; // quant mode + uint32_t globalBs; // globalBs = BS * worldSize + uint32_t bs; // bs + uint32_t k; // k + uint32_t h; // h + uint32_t aicNum; // aicNum + uint32_t aivNum; // aivNum + uint64_t totalUbSize; + uint64_t totalWinSize; + uint64_t gmm1HLen; +}; + +struct DispatchGmmCombineDecodeTilingData { + Mc2InitTiling mc2InitTiling; + Mc2CcTiling mc2CcTiling; + DispatchGmmCombineDecodeInfo disGmmDeqSwigluQuantGmmDeqComInfo; +}; + +constexpr uint32_t GM_ALIGN_BYTE = 512; +constexpr uint32_t CUSTOM_PRELOAD_STAGES = 1; +constexpr uint32_t CUSTOM_L1_STAGES = 2; +constexpr uint32_t CUSTOM_L0A_STAGES = 2; +constexpr uint32_t CUSTOM_L0B_STAGES = 2; +constexpr uint32_t CUSTOM_L0C_STAGES = 1; +constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; +constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; + +constexpr uint32_t GMM1_L1M = 256; +constexpr uint32_t GMM1_L1N = 128; +constexpr uint32_t GMM1_L1K = 512; +constexpr uint32_t GMM1_L0K = 128; +constexpr uint32_t GMM1_EPIM = 64; +constexpr uint32_t GMM1_SWIZZLE_OFFSET = 3; +constexpr uint32_t GMM1_SWIZZLE_DIRECTION = 0; + +constexpr uint32_t GMM2_L1A_STAGES = 4; +constexpr uint32_t GMM2_L1B_STAGES = 2; +constexpr uint32_t GMM2_L0A_STAGES = 4; +constexpr uint32_t GMM2_L0B_STAGES = 2; +constexpr uint32_t GMM2_L1M = 128; +constexpr uint32_t GMM2_L1N = 256; +constexpr uint32_t GMM2_L1K = 512; +constexpr uint32_t GMM2_L0K = 128; +constexpr uint32_t GMM2_EPIM = 32; +constexpr uint32_t GMM2_SWIZZLE_OFFSET = 3; +constexpr uint32_t GMM2_SWIZZLE_DIRECTION = 0; + +constexpr uint32_t WORKSPACE_STAGES = 4; + +constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); + +#endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/moe_distribute_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/moe_distribute_base.h new file mode 100644 index 00000000000..1bddf66261c --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/moe_distribute_base.h @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef MOE_DISTRIBUTE_BASE_H +#define MOE_DISTRIBUTE_BASE_H + +constexpr uint32_t LOCAL_NOTIFY_MAX_NUM = 64; +constexpr uint32_t LOCAL_STREAM_MAX_NUM = 19; +constexpr uint32_t AICPU_OP_NOTIFY_MAX_NUM = 2; +constexpr uint32_t AICPU_MAX_RANK_NUM = 128 * 1024; + +struct HcclSignalInfo { + uint64_t resId; + uint64_t addr; + uint32_t devId; + uint32_t tsId; + uint32_t rankId; + uint32_t flag; +}; + +struct ListCommon { + uint64_t nextHost; + uint64_t preHost; + uint64_t nextDevice; + uint64_t preDevice; +}; + +struct HcclStreamInfo { + int32_t streamIds; + uint32_t sqIds; + uint32_t cqIds; + uint32_t logicCqids; +}; + +struct LocalResInfoV2 { + uint32_t streamNum; + uint32_t signalNum; + HcclSignalInfo localSignals[LOCAL_NOTIFY_MAX_NUM]; + HcclStreamInfo streamInfo[LOCAL_STREAM_MAX_NUM]; + HcclStreamInfo mainStreamInfo; + HcclSignalInfo aicpuOpNotify[AICPU_OP_NOTIFY_MAX_NUM]; + ListCommon nextTagRes; // HccltagLocalResV2 +}; + +enum class rtFloatOverflowMode_t { + RT_OVERFLOW_MODE_SATURATION = 0, + RT_OVERFLOW_MODE_INFNAN, + RT_OVERFLOW_MODE_UNDEF, +}; + +struct AlgoTopoInfo { + uint32_t userRank; // RankID + uint32_t userRankSize; // Rank Number + int32_t deviceLogicId; + bool isSingleMeshAggregation; + uint32_t deviceNumPerAggregation; + uint32_t superPodNum; + uint32_t devicePhyId; + uint32_t topoType; // TopoType + uint32_t deviceType; + uint32_t serverNum; + uint32_t meshAggregationRankSize; + uint32_t multiModuleDiffDeviceNumMode; + uint32_t multiSuperPodDiffServerNumMode; + uint32_t realUserRank; + bool isDiffDeviceModule; + bool isDiffDeviceType; + uint32_t gcdDeviceNumPerAggregation; + uint32_t moduleNum; + uint32_t isUsedRdmaRankPairNum; + uint64_t isUsedRdmaRankPair; + uint32_t pairLinkCounterNum; + uint64_t pairLinkCounter; + uint32_t nicNum; + uint64_t nicList; + uint64_t complanRankLength; + uint64_t complanRank; + uint64_t bridgeRankNum; + uint64_t bridgeRank; + uint64_t serverAndsuperPodRankLength; + uint64_t serverAndsuperPodRank; +}; + +struct HcclOpConfig { + uint8_t deterministic; + uint8_t retryEnable; + uint8_t highPerfEnable; + uint8_t padding[5]; + uint8_t linkTimeOut[8]; + uint64_t notifyWaitTime; + uint32_t retryHoldTime; + uint32_t retryIntervalTime; + bool interHccsDisable = false; + rtFloatOverflowMode_t floatOverflowMode = rtFloatOverflowMode_t::RT_OVERFLOW_MODE_UNDEF; + uint32_t multiQpThreshold = 512; +}; + +struct HcclMC2WorkSpace { + uint64_t workSpace; + uint64_t workSpaceSize; +}; + +struct RemoteResPtr { + uint64_t nextHostPtr; + uint64_t nextDevicePtr; +}; + +struct HDCommunicateParams { + uint64_t hostAddr{0}; + uint64_t deviceAddr{0}; + uint64_t readCacheAddr{0}; + uint32_t devMemSize{0}; + uint32_t buffLen{0}; + uint32_t flag{0}; +}; + +struct HcclRankRelationResV2 { + uint32_t remoteUsrRankId; + uint32_t remoteWorldRank; + uint64_t windowsIn; + uint64_t windowsOut; + uint64_t windowsExp; + ListCommon nextTagRes; +}; + +struct HcclOpResParam { + // local resource + HcclMC2WorkSpace mc2WorkSpace; + uint32_t localUsrRankId; // usrrankid + uint32_t rankSize; + uint64_t winSize; + uint64_t localWindowsIn; + uint64_t localWindowsOut; + char hcomId[128]; + // aicore detect remote window + uint64_t winExpSize; + uint64_t localWindowsExp; + uint32_t rWinStart; + uint32_t rWinOffset; + uint64_t version; + LocalResInfoV2 localRes; + AlgoTopoInfo topoInfo; + + // config parameters + HcclOpConfig config; + uint64_t hostStateInfo; + uint64_t aicpuStateInfo; + uint64_t lockAddr; + uint32_t rsv[16]; + uint32_t notifysize; + uint32_t remoteResNum; + RemoteResPtr remoteRes[AICPU_MAX_RANK_NUM]; + + // communicate retry + HDCommunicateParams kfcControlTransferH2DParams; + HDCommunicateParams kfcStatusTransferD2HParams; + uint64_t tinyMem; // for all2all + uint64_t tinyMemSize; + // zero-copy + uint64_t zeroCopyHeadPtr; + uint64_t zeroCopyTailPtr; + uint64_t zeroCopyRingBuffer; + uint64_t zeroCopyIpcPtrs[16]; + uint32_t zeroCopyDevicePhyId[16]; + + bool utraceStatusFlag; +}; + +#endif // MOE_DISTRIBUTE_BASE_H diff --git a/csrc/third_party/catlass b/csrc/third_party/catlass new file mode 160000 index 00000000000..716fd7baa7f --- /dev/null +++ b/csrc/third_party/catlass @@ -0,0 +1 @@ +Subproject commit 716fd7baa7fb7f6cac0488bb628fd1dd0e875641 diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp index 06338e4f475..5b7515c307f 100644 --- a/csrc/torch_binding.cpp +++ b/csrc/torch_binding.cpp @@ -587,6 +587,63 @@ std::tuple grouped_matmul_swiglu_quant_weigh return std::tuple(output, output_scale, output_offset); } +std::tuple dispatch_gmm_combine_decode( + const at::Tensor &x, + const at::Tensor &expert_ids, + const at::Tensor &gmm1_permuted_weight, + const at::Tensor &gmm1_permuted_weight_scale, + const at::Tensor &gmm2_weight, + const at::Tensor &gmm2_weight_scale, + const c10::optional &expert_smooth_scales, + const c10::optional &expert_scales, + c10::string_view group_ep, + int64_t ep_rank_size, + int64_t ep_rank_id, + int64_t moe_expert_num, + int64_t shared_expert_num, + int64_t shared_expert_rank_num, + int64_t quant_mode, + int64_t global_bs) +{ + auto x_shape = x.sizes(); + int bs = x_shape[0]; + int h = x_shape[1]; + + at::Tensor output = at::empty({bs, h}, x.options()); + + bool is_shared_expert = (ep_rank_id < shared_expert_rank_num); + int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num); + at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options()); + + vector group_ep_chrs(group_ep.begin(), group_ep.end()); + group_ep_chrs.push_back('\0'); + char *group_ep_ptr = &group_ep_chrs[0]; + EXEC_NPU_CMD( + // op api + aclnnDispatchGmmCombineDecode, + // input tensors + x, + expert_ids, + gmm1_permuted_weight, + gmm1_permuted_weight_scale, + gmm2_weight, + gmm2_weight_scale, + expert_smooth_scales, + expert_scales, + //input attrs + group_ep_ptr, + ep_rank_size, + ep_rank_id, + moe_expert_num, + shared_expert_num, + shared_expert_rank_num, + quant_mode, + global_bs, + // output tensors + output, + ep_recv_count); + return {output, ep_recv_count}; +} } // namespace vllm_ascend TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) @@ -657,4 +714,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops) " (Tensor output, Tensor output_scale, Tensor output_offset)" ); ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", torch::kPrivateUse1, &vllm_ascend::grouped_matmul_swiglu_quant_weight_nz_tensor_list); + + ops.def( + "dispatch_gmm_combine_decode(Tensor x, Tensor expert_ids, Tensor gmm1_permuted_weight," + " Tensor gmm1_permuted_weight_scale," + " Tensor gmm2_weight, Tensor gmm2_weight_scale," + " Tensor? expert_smooth_scales=None, Tensor? expert_scales=None," + " str group_ep=''," + " int ep_rank_size=0, int ep_rank_id=0, int moe_expert_num=0," + " int shared_expert_num=1, int shared_expert_rank_num=0," + " int quant_mode=0," + " int global_bs=0) -> (Tensor output, Tensor ep_recv_count)" + ); + ops.impl("dispatch_gmm_combine_decode", torch::kPrivateUse1, &vllm_ascend::dispatch_gmm_combine_decode); } diff --git a/csrc/torch_binding_meta.cpp b/csrc/torch_binding_meta.cpp index 26b3d66de03..a382c9e6abc 100644 --- a/csrc/torch_binding_meta.cpp +++ b/csrc/torch_binding_meta.cpp @@ -151,6 +151,37 @@ std::tuple grouped_matmul_swiglu_quant_weigh return std::tuple(output, output_scale, output_offset); } +std::tuple dispatch_gmm_combine_decode_meta( + const at::Tensor &x, + const at::Tensor &expert_ids, + const at::Tensor &gmm1_permuted_weight, + const at::Tensor &gmm1_permuted_weight_scale, + const at::Tensor &gmm2_weight, + const at::Tensor &gmm2_weight_scale, + const c10::optional &expert_smooth_scales, + const c10::optional &expert_scales, + c10::string_view group_ep, + int64_t ep_rank_size, + int64_t ep_rank_id, + int64_t moe_expert_num, + int64_t shared_expert_num, + int64_t shared_expert_rank_num, + int64_t quant_mode, + int64_t global_bs) +{ + auto x_shape = x.sizes(); + int bs = x_shape[0]; + int h = x_shape[1]; + + at::Tensor output = at::empty({bs, h}, x.options().device(at::kMeta)); + + bool is_shared_expert = (ep_rank_id < shared_expert_rank_num); + int64_t num_local_experts = is_shared_expert ? 1 : moe_expert_num / (ep_rank_size - shared_expert_rank_num); + at::Tensor ep_recv_count = at::empty({num_local_experts * ep_rank_size}, expert_ids.options().device(at::kMeta)); + + return {output, ep_recv_count}; +} + } // namespace meta } // namespace vllm_ascend @@ -172,5 +203,7 @@ TORCH_LIBRARY_IMPL_EXPAND(CONCAT(_C, _ascend), Meta, ops) { ops.impl("grouped_matmul_swiglu_quant", &vllm_ascend::meta::grouped_matmul_swiglu_quant); // Grouped matmul swiglu quant weight nz tensor list ops.impl("grouped_matmul_swiglu_quant_weight_nz_tensor_list", &vllm_ascend::meta::grouped_matmul_swiglu_quant_weight_nz_tensor_list_meta); + // dispatch_gmm_combine_decode meta implementation + ops.impl("dispatch_gmm_combine_decode", &vllm_ascend::meta::dispatch_gmm_combine_decode_meta); } } diff --git a/tests/e2e/nightly/ops/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/ops/test_dispatch_gmm_combine_decode.py new file mode 100644 index 00000000000..f47533d422b --- /dev/null +++ b/tests/e2e/nightly/ops/test_dispatch_gmm_combine_decode.py @@ -0,0 +1,319 @@ +import os +import sys +import gc +import numpy as np +import torch +import torch_npu +import torch.distributed as dist +import torch.multiprocessing as mp +import torchair +from pathlib import Path +from torchair.configs.compiler_config import CompilerConfig +from vllm_ascend.utils import enable_custom_op + +torch_npu.npu.set_compile_mode(jit_compile=True) +config = CompilerConfig() +npu_backend = torchair.get_npu_backend(compiler_config=config) +torch_npu.npu.config.allow_internal_format = True +enable_custom_op() +LOG_NAME = "dispatch_gmm_combine_decode_test_logs" + +def redirect_output(log_file_path): + log_path = Path(LOG_NAME) / log_file_path + log_path.parent.mkdir(parents=True, exist_ok=True) + f = open(LOG_NAME + "/" + log_file_path, "w") + os.dup2(f.fileno(), sys.stdout.fileno()) + os.dup2(f.fileno(), sys.stderr.fileno()) + return f + +def permute_weight(w: torch.Tensor, tile_n): + *dims, n = w.shape + order = list(range(len(dims))) + [-2, -3, -1] + return w.reshape(*dims, 2, n // tile_n, + tile_n // 2).permute(order).reshape(*dims, + n).contiguous() + +def output_to_file(rank_id): + return False + +class DecodeMoeOps(torch.nn.Module): + def __init__(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, + batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num=0): + super().__init__() + self.gmm1_weight = None + self.gmm1_weight_scale = None + self.gmm2_weight = None + self.gmm2_weight_scale = None + self.ep_hcomm_info = ep_hcomm_info + self.batch_size = batch_size + self.token_hidden_size = token_hidden_size + self.moe_intermediate_size = moe_intermediate_size + self.ep_world_size = ep_world_size + self.moe_expert_num = moe_expert_num + self.global_rank_id = global_rank_id + self.shared_expert_rank_num = shared_expert_rank_num + self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale) + + def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale): + raise NotImplementedError("To be implemented in subclass") + + def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): + raise NotImplementedError("To be implemented in subclass") + + def forward(self, x, expert_ids, smooth_scales, expert_scales): + return self._apply_ops(x, expert_ids, smooth_scales, expert_scales) + + +class SmallOps(DecodeMoeOps): + + def __init__(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, + batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num=0): + super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, + batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num) + self.tp_hcomm_info = "" + + def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale): + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, torch_npu.Format.FRACTAL_NZ) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, torch_npu.Format.FRACTAL_NZ) + self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False) + self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale, requires_grad=False) + self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False) + self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale, requires_grad=False) + + def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): + outputs = torch_npu.npu_moe_distribute_dispatch_v2( + x=x, + expert_ids=expert_ids, + expert_scales=expert_scales, + group_ep=self.ep_hcomm_info, + ep_world_size=self.ep_world_size, + ep_rank_id=self.global_rank_id, + moe_expert_num=self.moe_expert_num, + group_tp=self.tp_hcomm_info, + tp_world_size=1, + tp_rank_id=0, + expert_shard_type=0, + shared_expert_num=1, + shared_expert_rank_num=self.shared_expert_rank_num, + quant_mode=2, + global_bs=self.batch_size * self.ep_world_size, + expert_token_nums_type=1, # 0代表前缀和,1代表各自数量 + ) + expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs + output_dtype = x.dtype + + y1_int32 = torch_npu.npu_grouped_matmul( + x=[expand_x], + weight=[self.gmm1_weight], + split_item=3, + group_list_type=1, # 默认为0,代表前缀和形式 + group_type=0, # 0代表m轴分组 + group_list=expert_token_nums, + output_dtype=torch.int32)[0] + y1, y1_scale = torch_npu.npu_dequant_swiglu_quant( + x=y1_int32, + weight_scale=self.gmm1_weight_scale.to(torch.float32), + activation_scale=dynamic_scales, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=expert_token_nums, + activate_left=True, + quant_mode=1, + ) + y2 = torch_npu.npu_grouped_matmul(x=[y1], + weight=[self.gmm2_weight], + scale=[self.gmm2_weight_scale], + per_token_scale=[y1_scale], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_nums, + output_dtype=output_dtype)[0] + combine_output = torch_npu.npu_moe_distribute_combine_v2( + expand_x=y2, + expert_ids=expert_ids, + assist_info_for_combine=assist_info_for_combine, + ep_send_counts=ep_send_counts, + expert_scales=expert_scales, + group_ep=self.ep_hcomm_info, + ep_world_size=self.ep_world_size, + ep_rank_id=self.global_rank_id, + moe_expert_num=self.moe_expert_num, + tp_send_counts=tp_send_counts, + expand_scales=expand_scales, + group_tp=self.tp_hcomm_info, + tp_world_size=1, + tp_rank_id=0, + expert_shard_type=0, + shared_expert_num=1, + shared_expert_rank_num=self.shared_expert_rank_num, + global_bs=self.batch_size * self.ep_world_size) + return (combine_output, ep_send_counts) + + +class FusionOp(DecodeMoeOps): + + def __init__(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, + batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num=0): + super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, + batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num) + + def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale): + gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\ + .view(-1, 2, 32, 64, 7168).transpose(1,2).contiguous()\ + .view(-1, 4096, 7168).transpose(1,2).contiguous() + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, torch_npu.Format.ND) + gmm1_weight.add_(0) + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, torch_npu.Format.FRACTAL_NZ) + gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128) + gmm2_weight = torch_npu.npu_format_cast( + gmm2_weight.transpose(1, 2).contiguous(), torch_npu.Format.FRACTAL_NZ) + + gmm1_weight_scale = gmm1_weight_scale.float() + gmm2_weight_scale = gmm2_weight_scale.float() + + self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False) + self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale, requires_grad=False) + self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False) + self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale, requires_grad=False) + + def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales): + output = torch.ops._C_ascend.dispatch_gmm_combine_decode( + x=x, + expert_ids=expert_ids, + gmm1_permuted_weight=self.gmm1_weight, + gmm1_permuted_weight_scale=self.gmm1_weight_scale, + gmm2_weight=self.gmm2_weight, + gmm2_weight_scale=self.gmm2_weight_scale, + expert_smooth_scales=smooth_scales, + expert_scales=expert_scales, + group_ep=self.ep_hcomm_info, + ep_rank_size=self.ep_world_size, + ep_rank_id=self.global_rank_id, + moe_expert_num=self.moe_expert_num, + shared_expert_num=1, + shared_expert_rank_num=self.shared_expert_rank_num, + quant_mode=0, + global_bs=self.batch_size * self.ep_world_size) + return output + + +def generate_datas( + batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num=0, + top_k=8, test_bfloat16=True, enable_dynamic_bs=False): + is_shared_expert = global_rank_id < shared_expert_rank_num + moe_expert_num_per_rank = moe_expert_num // (ep_world_size - shared_expert_rank_num) + actual_bs = int(torch.randint(1, batch_size, [1]).item() if enable_dynamic_bs else batch_size) + local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank + gmm1_input_dim = token_hidden_size + gmm1_output_dim = moe_intermediate_size * 2 + gmm2_input_dim = moe_intermediate_size + gmm2_output_dim = token_hidden_size + x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5 + expert_ids = torch.arange( + global_rank_id * batch_size * top_k, + global_rank_id * batch_size * top_k + actual_bs * top_k).to(torch.int32).view(actual_bs, top_k) + if is_shared_expert: + gmm1_weight = torch.ones([local_expert_num, gmm1_input_dim, gmm1_output_dim + ]).to(torch.int8) * 4 + gmm2_weight = torch.ones([local_expert_num, gmm2_input_dim, gmm2_output_dim + ]).to(torch.int8) * 4 + gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1 + gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1 + gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim + ]) * 0.0015 + gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim + ]) * 0.0015 + else: + gmm1_weight = torch.randint( + -16, 16, + [local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8) + gmm2_weight = torch.randint( + -16, 16, + [local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8) + gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim + ]) * 0.003 + 0.0015 + gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim + ]) * 0.003 + 0.0015 + expert_scales = torch.rand(actual_bs, top_k) + if test_bfloat16: + x = x.bfloat16() + gmm1_weight_scale = gmm1_weight_scale.bfloat16() + gmm2_weight_scale = gmm2_weight_scale.bfloat16() + else: + x = x.half() + smooth_sales = None + return (x, expert_ids, smooth_sales, expert_scales), \ + (gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale), \ + actual_bs + +def run_once(local_rank_id, + batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, shared_expert_rank_num=0, + top_k=8, test_bfloat16=True, enable_dynamic_bs=False, test_graph=False): + log_file = redirect_output(f"local_rank_{local_rank_id}.log") if output_to_file(local_rank_id) else None + global_rank_id = local_rank_id # 单机 + device_id = local_rank_id % 16 + torch_npu.npu.set_device(device_id) + + # 初始化分布式环境 + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" # 端口号随意 + dist.init_process_group(backend="hccl", + rank=local_rank_id, + world_size=ep_world_size) + ep_ranks_list = list(np.arange(0, ep_world_size)) + ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list) + ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list) + + ep_hcomm_info_fused = ep_group._get_backend( + torch.device("npu")).get_hccl_comm_name(local_rank_id) + ep_hcomm_info_small = ep_group_small._get_backend( + torch.device("npu")).get_hccl_comm_name(local_rank_id) + torch_npu.npu.synchronize(device_id) + + parameter=(batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num) + input_datas, weight_datas, actual_bs = generate_datas( + *parameter, top_k, test_bfloat16, enable_dynamic_bs) + input_datas = [data.npu() if data is not None else None for data in input_datas] + weight_datas = [data.npu() if data is not None else None for data in weight_datas] + small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter).npu() + fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter).npu() + if test_graph: + fused_ops = torch.compile(fused_ops, mode="default") + small_op_token_output, small_op_count_output = small_ops(*input_datas) + fused_op_token_output, fused_op_count_output = fused_ops(*input_datas) + torch_npu.npu.synchronize(device_id) + dist.destroy_process_group() + if log_file is not None: + log_file.close() + torch.testing.assert_close(small_op_token_output.cpu(), fused_op_token_output.cpu()) + torch.testing.assert_close(small_op_count_output.cpu(), fused_op_count_output.cpu()) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + +@torch.inference_mode() +def test(): + batch_size=64 + token_hidden_size=7168 + moe_intermediate_size=2048 + ep_world_size=16 + moe_expert_num=64 + shared_expert_rank_num=0 + top_k=8 + test_bfloat16=True + enable_dynamic_bs=False + test_graph=False + args=(batch_size, token_hidden_size, moe_intermediate_size, + ep_world_size, moe_expert_num, shared_expert_rank_num, + top_k, test_bfloat16, enable_dynamic_bs, test_graph) + mp.spawn(run_once, args=args, nprocs=ep_world_size, join=True) diff --git a/tests/e2e/nightly/ops/test_fused_deep_moe_accuracy.py b/tests/e2e/nightly/ops/test_fused_deep_moe_accuracy.py new file mode 100644 index 00000000000..f347d75a24e --- /dev/null +++ b/tests/e2e/nightly/ops/test_fused_deep_moe_accuracy.py @@ -0,0 +1,340 @@ +import os +import sys +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch_npu + +from vllm_ascend.utils import enable_custom_op + +torch_npu.npu.config.allow_internal_format = True +use_graph = False +test_bfloat16 = True +enable_dynamic_bs = False +if use_graph: + import torchair + from torchair.configs.compiler_config import CompilerConfig + torch_npu.npu.set_compile_mode(jit_compile=True) + config = CompilerConfig() + npu_backend = torchair.get_npu_backend(compiler_config=config) + +enable_custom_op() + +TP = 1 +print( + f"{len(sys.argv)= }, {sys.argv= }\n{use_graph= }, {test_bfloat16= }, {enable_dynamic_bs= }" +) +assert len( + sys.argv +) == 7, "入参列表:[0]rank_size, [1]share_expert_rank_num, [2]moe_expert_num, [3]bs, [4]name, [5]loop_cnt" +ep_world_size = int(sys.argv[1]) +SHARE_RANK_NUM = int(sys.argv[2]) +MOE_RANK_NUM = ep_world_size - SHARE_RANK_NUM +MOE_EXPERT_NUM = int(sys.argv[3]) +MOE_EXPERT_NUM_PER_RANK = MOE_EXPERT_NUM // MOE_RANK_NUM +RANK_BS = int(sys.argv[4]) +LOG_NAME = str(sys.argv[5]) +loop_times = int(str(sys.argv[6])) +node_num = 1 + +SHARE_EXPERT_NUM = SHARE_RANK_NUM +DISPATCH_QUANT = True +H = 7168 +K = 8 +GMM1_INPUT = H +GMM1_HIDDEN = 4096 +GMM2_INPUT = GMM1_HIDDEN // 2 +GMM2_HIDDEN = H + +global_rank_id = 0 +ep_hcomm_info = None +ep_hcomm_info_small = None +commArgs = None +tp_hcomm_info = None +device_id = None + + +def redirect_output(log_file_path): + log_path = Path(LOG_NAME) / log_file_path + log_path.parent.mkdir(parents=True, exist_ok=True) + f = open(LOG_NAME + "/" + log_file_path, "w") + os.dup2(f.fileno(), sys.stdout.fileno()) + os.dup2(f.fileno(), sys.stderr.fileno()) + return f + + +def permute_weight(w: torch.Tensor, tile_n): + *dims, n = w.shape + order = list(range(len(dims))) + [-2, -3, -1] + return w.reshape(*dims, 2, n // tile_n, + tile_n // 2).permute(order).reshape(*dims, + n).contiguous() + + +def output_to_file(rank_id): + # return True + return rank_id not in [0, SHARE_RANK_NUM] + + +class SmallOps(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, expert_ids, gmm1_weight, gmm1_weight_scale, + gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales): + outputs = torch_npu.npu_moe_distribute_dispatch_v2( + x=x, + expert_ids=expert_ids, + expert_scales=expert_scales, + group_ep=ep_hcomm_info_small, + ep_world_size=ep_world_size, + ep_rank_id=global_rank_id, + moe_expert_num=MOE_EXPERT_NUM, + group_tp=tp_hcomm_info, + tp_world_size=1, + tp_rank_id=0, + expert_shard_type=0, + shared_expert_num=1, + shared_expert_rank_num=SHARE_RANK_NUM, + quant_mode=2 if DISPATCH_QUANT else 0, + global_bs=RANK_BS * ep_world_size, + expert_token_nums_type=1, # 0代表前缀和,1代表各自数量 + ) + expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs + output_dtype = torch.bfloat16 if test_bfloat16 else torch.half + + y1_int32 = torch_npu.npu_grouped_matmul( + x=[expand_x], + weight=[gmm1_weight], + split_item=3, + group_list_type=1, # 默认为0,代表前缀和形式 + group_type=0, # 0代表m轴分组 + group_list=expert_token_nums, + output_dtype=torch.int32)[0] + y1, y1_scale = torch_npu.npu_dequant_swiglu_quant( + x=y1_int32, + weight_scale=gmm1_weight_scale.to(torch.float32), + activation_scale=dynamic_scales, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=expert_token_nums, + activate_left=True, + quant_mode=1, + ) + y2 = torch_npu.npu_grouped_matmul(x=[y1], + weight=[gmm2_weight], + scale=[gmm2_weight_scale], + per_token_scale=[y1_scale], + split_item=2, + group_list_type=1, + group_type=0, + group_list=expert_token_nums, + output_dtype=output_dtype)[0] + combine_output = torch_npu.npu_moe_distribute_combine_v2( + expand_x=y2, + expert_ids=expert_ids, + assist_info_for_combine=assist_info_for_combine, + ep_send_counts=ep_send_counts, + expert_scales=expert_scales, + group_ep=ep_hcomm_info_small, + ep_world_size=ep_world_size, + ep_rank_id=global_rank_id, + moe_expert_num=MOE_EXPERT_NUM, + tp_send_counts=tp_send_counts, + expand_scales=expand_scales, + group_tp=tp_hcomm_info, + tp_world_size=1, + tp_rank_id=0, + expert_shard_type=0, + shared_expert_num=1, + shared_expert_rank_num=SHARE_RANK_NUM, + global_bs=RANK_BS * ep_world_size) + return combine_output + + +class FusionOp(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, expert_ids, gmm1_weight, gmm1_weight_scale, + gmm2_weight, gmm2_weight_scale, smooth_scales, expert_scales): + output = torch.ops._C_ascend.dispatch_gmm_combine_decode( + x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, + gmm2_weight_scale, smooth_scales, expert_scales, ep_hcomm_info, + ep_world_size, global_rank_id, MOE_EXPERT_NUM, 1, SHARE_RANK_NUM, + 0, RANK_BS * ep_world_size) + return output + + +def generate_datas(): + if enable_dynamic_bs: + actual_bs = torch.randint(1, RANK_BS, [1]).item() + print(f"rank-{global_rank_id}: {actual_bs=}") + else: + actual_bs = RANK_BS + local_expert_num = 1 if global_rank_id < SHARE_RANK_NUM else MOE_EXPERT_NUM_PER_RANK + x = torch.rand([actual_bs, H]).half() + x = x * 10 - 5 + expert_ids = [ + i % MOE_EXPERT_NUM + for i in range(global_rank_id * RANK_BS * + K, global_rank_id * RANK_BS * K + actual_bs * K) + ] + expert_ids = torch.Tensor(expert_ids).to(torch.int32).view(actual_bs, K) + if global_rank_id < SHARE_RANK_NUM: + gmm1_weight = torch.ones([local_expert_num, GMM1_INPUT, GMM1_HIDDEN + ]).to(torch.int8) * 4 + gmm2_weight = torch.ones([local_expert_num, GMM2_INPUT, GMM2_HIDDEN + ]).to(torch.int8) * 4 + gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1 + gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1 + gmm1_weight_scale = torch.ones([local_expert_num, GMM1_HIDDEN + ]) * 0.0015 + gmm2_weight_scale = torch.ones([local_expert_num, GMM2_HIDDEN + ]) * 0.0015 + else: + gmm1_weight = torch.randint( + -16, 16, + [local_expert_num, GMM1_INPUT, GMM1_HIDDEN]).to(torch.int8) + gmm2_weight = torch.randint( + -16, 16, + [local_expert_num, GMM2_INPUT, GMM2_HIDDEN]).to(torch.int8) + gmm1_weight_scale = torch.rand([local_expert_num, GMM1_HIDDEN + ]) * 0.003 + 0.0015 + gmm2_weight_scale = torch.rand([local_expert_num, GMM2_HIDDEN + ]) * 0.003 + 0.0015 + expert_scales = torch.rand(actual_bs, K) + if test_bfloat16: + x = x.bfloat16() + gmm1_weight_scale = gmm1_weight_scale.bfloat16() + gmm2_weight_scale = gmm2_weight_scale.bfloat16() + else: + x = x.half() + return x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, None, expert_scales + + +def test_small_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, + gmm2_weight_scale, smooth_sales, expert_scales): + small_op = SmallOps().npu() + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, 29) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, 29) + for _ in range(1, loop_times + 1): + output = small_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, + gmm2_weight, gmm2_weight_scale, smooth_sales, + expert_scales) + return output + + +def test_fused_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, gmm2_weight, + gmm2_weight_scale, smooth_sales, expert_scales): + fused_op = FusionOp().npu() + if use_graph: + fused_op = torch.compile(fused_op, backend=npu_backend) + gmm1_weight = gmm1_weight.transpose(1,2).contiguous()\ + .view(-1, 2, 32, 64, 7168).transpose(1,2).contiguous()\ + .view(-1, 4096, 7168).transpose(1,2).contiguous() + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, 2) + gmm1_weight.add_(0) + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, 29) + + gmm1_weight_scale = permute_weight(gmm1_weight_scale, 128) + gmm2_weight = torch_npu.npu_format_cast( + gmm2_weight.transpose(1, 2).contiguous(), 29) + + if test_bfloat16: + gmm1_weight_scale = gmm1_weight_scale.float() + gmm2_weight_scale = gmm2_weight_scale.float() + + smooth_sales = torch.ones( + [RANK_BS]).float().npu() if smooth_sales is None else smooth_sales + for _ in range(1, loop_times + 1): + # print(f"iter: {_} / {loop_times}") + output = fused_op(x, expert_ids, gmm1_weight, gmm1_weight_scale, + gmm2_weight, gmm2_weight_scale, smooth_sales, + expert_scales) + torch_npu.npu.synchronize(device_id) + print("fused op run end") + # 只返回一个出参 + return output[0] + + +def test(): + tensor_datas = [ + data.npu() if data is not None else None for data in generate_datas() + ] + + small_op_datas = [ + data.clone().detach() if data is not None else None + for data in tensor_datas + ] + small_op_output = test_small_op(*small_op_datas) + print( + f"{small_op_output= }\n {small_op_output.abs().mean()=}, {small_op_output.abs().max()=}" + ) + + fused_op_datas = [ + data.clone().detach() if data is not None else None + for data in tensor_datas + ] + fused_op_output = test_fused_op(*fused_op_datas) + print( + f"{fused_op_output= }\n {fused_op_output.abs().mean()=}, {fused_op_output.abs().max()=}" + ) + + diff = (small_op_output - fused_op_output).abs() + print( + f"[info-{global_rank_id}] dispatch gmm combine decode: {diff.max()= }, {diff.mean()= }" + ) + + +def worker(rank, ep_world_size): + if output_to_file(rank): + log_file = redirect_output(f"log_test_accuracy_rank_{rank}.txt") + global global_rank_id, ep_hcomm_info, ep_hcomm_info_small, tp_hcomm_info, device_id + global_rank_id = rank + device_id = rank % 16 + torch_npu.npu.set_device(device_id) + + # 初始化分布式环境 + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" # 端口号随意 + dist.init_process_group(backend="hccl", + rank=rank, + world_size=ep_world_size) + + print(f"[info-{rank}] start ep comm init...") + ep_ranks_list = list(np.arange(0, ep_world_size)) + print(f"[info-{rank}] ep rank list:", ep_ranks_list) + ep_group = dist.new_group(backend="hccl", ranks=ep_ranks_list) + ep_group_small = dist.new_group(backend="hccl", ranks=ep_ranks_list) + tp_group = dist.new_group(backend="hccl", ranks=[rank]) + + ep_hcomm_info = ep_group._get_backend( + torch.device("npu")).get_hccl_comm_name(rank) + ep_hcomm_info_small = ep_group_small._get_backend( + torch.device("npu")).get_hccl_comm_name(rank) + tp_hcomm_info = tp_group._get_backend( + torch.device("npu")).get_hccl_comm_name(rank) + + torch_npu.npu.synchronize(device_id) + print( + f"[info-{rank}] ep group: {ep_group}, ep_hcomm_info:{type(ep_hcomm_info)}" + ) + + # 测试 + test() + # 关闭进程组 + torch_npu.npu.synchronize(device_id) + dist.destroy_process_group() + if output_to_file(rank): + log_file.close() + + +if __name__ == "__main__": + mp.spawn(worker, args=(ep_world_size, ), nprocs=ep_world_size, join=True)