Skip to content

Commit 8011f0e

Browse files
committed
[kernel] Add DispatchGmmCombineDecode aclnn operator
Signed-off-by: wangqiankun13 <[email protected]>
1 parent 048d350 commit 8011f0e

30 files changed

+7917
-2
lines changed

.gitsubmodules

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[submodule "csrc/third_party/catlass"]
2+
path = csrc/third_party/catlass
3+
url = https://gitee.com/ascend/catlass.git
4+
branch = catlass-v1-stable

csrc/build_aclnn.sh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
1515
SOC_ARG="ascend910b"
1616
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
1717
# ASCEND910C (A3) series
18-
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
18+
# depdendency: catlass
19+
CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass
20+
if [[ ! -d "${CATLASS_PATH}" ]]; then
21+
echo "depdendency catlass does not exist, please run 'git submodule update --init --recursive'"
22+
exit 1
23+
fi
24+
export CPATH=${CATLASS_PATH}/include:${CPATH}
25+
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;dispatch_gmm_combine_decode"
1926
SOC_ARG="ascend910_93"
2027
else
2128
# others

csrc/cmake/func.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ function(add_ops_src_copy)
282282
set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done)
283283
add_custom_command(OUTPUT ${_BUILD_FLAG}
284284
COMMAND mkdir -p ${SRC_COPY_DST}
285-
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST}
285+
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/* ${SRC_COPY_DST}
286286
COMMAND touch ${_BUILD_FLAG}
287287
)
288288

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
2+
# This file is a part of the CANN Open Software.
3+
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
4+
# Please refer to the License for details. You may not use this file except in compliance with the License.
5+
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
6+
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
7+
# See LICENSE in the root of the software repository for the full text of the License.
8+
# ======================================================================================================================
9+
10+
add_ops_compile_options(
11+
OP_NAME GroupedMatmulSwigluQuantWeightNzTensorList
12+
OPTIONS --cce-auto-sync=off
13+
-Wno-deprecated-declarations
14+
-Werror
15+
)
16+
17+
target_sources(op_host_aclnnInner PRIVATE
18+
dispatch_gmm_combine_decode_def.cpp
19+
)
20+
21+
target_sources(opapi PRIVATE
22+
aclnn_dispatch_gmm_combine_decode.cpp
23+
)
24+
25+
if (NOT BUILD_OPEN_PROJECT)
26+
target_sources(aclnn_ops_train PRIVATE
27+
aclnn_dispatch_gmm_combine_decode.cpp
28+
)
29+
30+
target_sources(aclnn_ops_infer PRIVATE
31+
aclnn_dispatch_gmm_combine_decode.cpp
32+
)
33+
endif ()
34+
35+
target_sources(optiling PRIVATE
36+
dispatch_gmm_combine_decode_tiling.cpp
37+
)
38+
39+
target_include_directories(optiling PRIVATE
40+
${CMAKE_CURRENT_SOURCE_DIR}
41+
)
42+
43+
target_sources(opsproto PRIVATE
44+
dispatch_gmm_combine_decode_proto.cpp
45+
)
46+
47+
file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h")
48+
49+
install(FILES ${_GMM_Aclnn_header}
50+
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
51+
)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
#include <string.h>
11+
#include "graph/types.h"
12+
#include "aclnn/opdev/platform.h"
13+
#include "aclnn_dispatch_gmm_combine_decode.h"
14+
15+
enum NnopbaseHcclServerType {
16+
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
17+
NNOPBASE_HCCL_SERVER_TYPE_MTE,
18+
NNOPBASE_HCCL_SERVER_TYPE_END
19+
};
20+
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);
21+
extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
22+
const aclTensor *x,
23+
const aclTensor *expertIds,
24+
const aclTensor *gmm1PermutedWeight,
25+
const aclTensor *gmm1PermutedWeightScale,
26+
const aclTensor *gmm2Weight,
27+
const aclTensor *gmm2WeightScale,
28+
const aclTensor *expertSmoothScalesOptional,
29+
const aclTensor *expertScalesOptional,
30+
char *groupEp,
31+
int64_t epRankSize,
32+
int64_t epRankId,
33+
int64_t moeExpertNum,
34+
int64_t shareExpertNum,
35+
int64_t shareExpertRankNum,
36+
int64_t quantMode,
37+
int64_t globalBs,
38+
const aclTensor *output,
39+
const aclTensor *epRecvCount,
40+
uint64_t *workspaceSize,
41+
aclOpExecutor **executor);
42+
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
43+
void *workspace,
44+
uint64_t workspaceSize,
45+
aclOpExecutor *executor,
46+
aclrtStream stream);
47+
48+
#ifdef __cplusplus
49+
extern "C" {
50+
#endif
51+
52+
aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
53+
const aclTensor *x,
54+
const aclTensor *expertIds,
55+
const aclTensor *gmm1PermutedWeight,
56+
const aclTensor *gmm1PermutedWeightScale,
57+
const aclTensor *gmm2Weight,
58+
const aclTensor *gmm2WeightScale,
59+
const aclTensor *expertSmoothScalesOptional,
60+
const aclTensor *expertScalesOptional,
61+
char *groupEp,
62+
int64_t epRankSize,
63+
int64_t epRankId,
64+
int64_t moeExpertNum,
65+
int64_t shareExpertNum,
66+
int64_t shareExpertRankNum,
67+
int64_t quantMode,
68+
int64_t globalBs,
69+
const aclTensor *output,
70+
const aclTensor *epRecvCount,
71+
uint64_t *workspaceSize,
72+
aclOpExecutor **executor)
73+
{
74+
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
75+
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
76+
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
77+
output, epRecvCount, workspaceSize, executor);
78+
}
79+
80+
aclnnStatus aclnnDispatchGmmCombineDecode(
81+
void *workspace,
82+
uint64_t workspaceSize,
83+
aclOpExecutor *executor,
84+
aclrtStream stream)
85+
{
86+
if (NnopbaseSetHcclServerType) {
87+
if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) {
88+
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU);
89+
} else {
90+
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
91+
}
92+
}
93+
return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
94+
}
95+
96+
#ifdef __cplusplus
97+
}
98+
#endif
99+
100+
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
#ifndef DISPATCH_GMM_COMBINE_DECODE
11+
#define DISPATCH_GMM_COMBINE_DECODE
12+
13+
#include "aclnn/acl_meta.h"
14+
15+
#ifdef __cplusplus
16+
extern "C" {
17+
#endif
18+
19+
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
20+
const aclTensor *x,
21+
const aclTensor *expertIds,
22+
const aclTensor *gmm1PermutedWeight,
23+
const aclTensor *gmm1PermutedWeightScale,
24+
const aclTensor *gmm2Weight,
25+
const aclTensor *gmm2WeightScale,
26+
const aclTensor *expertSmoothScalesOptional,
27+
const aclTensor *expertScalesOptional,
28+
char *groupEp,
29+
int64_t epRankSize,
30+
int64_t epRankId,
31+
int64_t moeExpertNum,
32+
int64_t shareExpertNum,
33+
int64_t shareExpertRankNum,
34+
int64_t quantMode,
35+
int64_t globalBs,
36+
const aclTensor *output,
37+
const aclTensor *epRecvCount,
38+
uint64_t *workspaceSize,
39+
aclOpExecutor **executor);
40+
41+
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode(
42+
void *workspace,
43+
uint64_t workspaceSize,
44+
aclOpExecutor *executor,
45+
aclrtStream stream);
46+
47+
#ifdef __cplusplus
48+
}
49+
#endif
50+
51+
#endif
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
3+
* This file is a part of the CANN Open Software.
4+
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
5+
* Please refer to the License for details. You may not use this file except in compliance with the License.
6+
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
7+
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
8+
* See LICENSE in the root of the software repository for the full text of the License.
9+
*/
10+
#include "register/op_def_registry.h"
11+
12+
namespace ops {
13+
class DispatchGmmCombineDecode : public OpDef
14+
{
15+
public:
16+
explicit DispatchGmmCombineDecode(const char *name) : OpDef(name)
17+
{
18+
this->Input("x")
19+
.ParamType(REQUIRED)
20+
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
21+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
22+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
23+
this->Input("expert_ids")
24+
.ParamType(REQUIRED)
25+
.DataType({ge::DT_INT32, ge::DT_INT32})
26+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
27+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
28+
this->Input("gmm1_permuted_weight")
29+
.ParamType(REQUIRED)
30+
.DataType({ge::DT_INT8, ge::DT_INT8})
31+
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
32+
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
33+
this->Input("gmm1_permuted_weight_scale")
34+
.ParamType(REQUIRED)
35+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
36+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
37+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
38+
this->Input("gmm2_weight")
39+
.ParamType(REQUIRED)
40+
.DataType({ge::DT_INT8, ge::DT_INT8})
41+
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
42+
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
43+
this->Input("gmm2_weight_scale")
44+
.ParamType(REQUIRED)
45+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
46+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
47+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
48+
this->Input("expert_smooth_scales")
49+
.ParamType(OPTIONAL)
50+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
51+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
52+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
53+
this->Input("expert_scales")
54+
.ParamType(OPTIONAL)
55+
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
56+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
57+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
58+
this->Output("output")
59+
.ParamType(REQUIRED)
60+
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
61+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
62+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
63+
this->Output("ep_recv_count")
64+
.ParamType(REQUIRED)
65+
.DataType({ge::DT_INT32, ge::DT_INT32})
66+
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
67+
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
68+
this->Attr("group_ep").String();
69+
this->Attr("ep_rank_size").Int();
70+
this->Attr("ep_rank_id").Int();
71+
this->Attr("moe_expert_num").Int();
72+
this->Attr("share_expert_num").Int();
73+
this->Attr("share_expert_rank_num").Int();
74+
this->Attr("quant_mode").Int();
75+
this->Attr("global_bs").Int();
76+
77+
this->MC2().HcclGroup({"group_ep"});
78+
this->AICore().AddConfig("ascend910_93");
79+
}
80+
};
81+
82+
OP_ADD(DispatchGmmCombineDecode);
83+
} // namespace ops

0 commit comments

Comments
 (0)