Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitsubmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[submodule "csrc/third_party/catlass"]
path = csrc/third_party/catlass
url = https://gitee.com/ascend/catlass.git
branch = catlass-v1-stable
10 changes: 9 additions & 1 deletion csrc/build_aclnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ elif [[ "$SOC_VERSION" =~ ^ascend910b ]]; then
SOC_ARG="ascend910b"
elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
# ASCEND910C (A3) series
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list"
# depdendency: catlass
CATLASS_PATH=${ROOT_DIR}/csrc/third_party/catlass
if [[ ! -d "${CATLASS_PATH}" ]]; then
echo "depdendency catlass does not exist, please run 'git submodule update --init --recursive'"
exit 1
fi
ABSOLUTE_CATLASS_PATH=$(cd "${CATLASS_PATH}" && pwd)
export CPATH=${ABSOLUTE_CATLASS_PATH}/include:${CPATH}
CUSTOM_OPS="grouped_matmul_swiglu_quant_weight_nz_tensor_list;dispatch_gmm_combine_decode"
SOC_ARG="ascend910_93"
else
# others
Expand Down
2 changes: 1 addition & 1 deletion csrc/cmake/func.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ function(add_ops_src_copy)
set(_BUILD_FLAG ${SRC_COPY_DST}/${SRC_COPY_TARGET_NAME}.done)
add_custom_command(OUTPUT ${_BUILD_FLAG}
COMMAND mkdir -p ${SRC_COPY_DST}
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/*.* ${SRC_COPY_DST}
COMMAND cp -rf ${SRC_COPY_SRC}/op_kernel/* ${SRC_COPY_DST}
COMMAND touch ${_BUILD_FLAG}
)

Expand Down
51 changes: 51 additions & 0 deletions csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ======================================================================================================================

add_ops_compile_options(
OP_NAME DispatchGmmCombineDecode
OPTIONS --cce-auto-sync=off
-Wno-deprecated-declarations
-Werror
)

target_sources(op_host_aclnnInner PRIVATE
dispatch_gmm_combine_decode_def.cpp
)

target_sources(opapi PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)

if (NOT BUILD_OPEN_PROJECT)
target_sources(aclnn_ops_train PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)

target_sources(aclnn_ops_infer PRIVATE
aclnn_dispatch_gmm_combine_decode.cpp
)
endif ()

target_sources(optiling PRIVATE
dispatch_gmm_combine_decode_tiling.cpp
)

target_include_directories(optiling PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}
)

target_sources(opsproto PRIVATE
dispatch_gmm_combine_decode_proto.cpp
)

file(GLOB _GMM_Aclnn_header "${CMAKE_CURRENT_SOURCE_DIR}/aclnn_dispatch_gmm_combine_decode.h")

install(FILES ${_GMM_Aclnn_header}
DESTINATION ${ACLNN_INC_INSTALL_DIR} OPTIONAL
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <string.h>
#include "graph/types.h"
#include "aclnn/opdev/platform.h"
#include "aclnn_dispatch_gmm_combine_decode.h"

enum NnopbaseHcclServerType {
NNOPBASE_HCCL_SERVER_TYPE_AICPU = 0,
NNOPBASE_HCCL_SERVER_TYPE_MTE,
NNOPBASE_HCCL_SERVER_TYPE_END
};
extern "C" void __attribute__((weak)) NnopbaseSetHcclServerType(void *executor, NnopbaseHcclServerType sType);

#ifdef __cplusplus
extern "C" {
#endif

extern aclnnStatus aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor);
extern aclnnStatus aclnnInnerDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);

aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor)
{
return aclnnInnerDispatchGmmCombineDecodeGetWorkspaceSize(x, expertIds, gmm1PermutedWeight, gmm1PermutedWeightScale,
gmm2Weight, gmm2WeightScale, expertSmoothScalesOptional, expertScalesOptional, groupEp, epRankSize,
epRankId, moeExpertNum, shareExpertNum, shareExpertRankNum, quantMode, globalBs,
output, epRecvCount, workspaceSize, executor);
}

aclnnStatus aclnnDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream)
{
if (NnopbaseSetHcclServerType) {
if (op::GetCurrentPlatformInfo().GetSocVersion() == op::SocVersion::ASCEND910B) {
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_AICPU);
} else {
NnopbaseSetHcclServerType(executor, NNOPBASE_HCCL_SERVER_TYPE_MTE);
}
}
return aclnnInnerDispatchGmmCombineDecode(workspace, workspaceSize, executor, stream);
}

#ifdef __cplusplus
}
#endif


Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef DISPATCH_GMM_COMBINE_DECODE
#define DISPATCH_GMM_COMBINE_DECODE

#include "aclnn/acl_meta.h"

#ifdef __cplusplus
extern "C" {
#endif

__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecodeGetWorkspaceSize(
const aclTensor *x,
const aclTensor *expertIds,
const aclTensor *gmm1PermutedWeight,
const aclTensor *gmm1PermutedWeightScale,
const aclTensor *gmm2Weight,
const aclTensor *gmm2WeightScale,
const aclTensor *expertSmoothScalesOptional,
const aclTensor *expertScalesOptional,
char *groupEp,
int64_t epRankSize,
int64_t epRankId,
int64_t moeExpertNum,
int64_t shareExpertNum,
int64_t shareExpertRankNum,
int64_t quantMode,
int64_t globalBs,
const aclTensor *output,
const aclTensor *epRecvCount,
uint64_t *workspaceSize,
aclOpExecutor **executor);

__attribute__((visibility("default"))) aclnnStatus aclnnDispatchGmmCombineDecode(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream);

#ifdef __cplusplus
}
#endif

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "register/op_def_registry.h"

namespace ops {
class DispatchGmmCombineDecode : public OpDef
{
public:
explicit DispatchGmmCombineDecode(const char *name) : OpDef(name)
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm1_permuted_weight")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm1_permuted_weight_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight")
.ParamType(REQUIRED)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm2_weight_scale")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("ep_recv_count")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group_ep").String();
this->Attr("ep_rank_size").Int();
this->Attr("ep_rank_id").Int();
this->Attr("moe_expert_num").Int();
this->Attr("share_expert_num").Int();
this->Attr("share_expert_rank_num").Int();
this->Attr("quant_mode").Int();
this->Attr("global_bs").Int();

this->MC2().HcclGroup({"group_ep"});
this->AICore().AddConfig("ascend910_93");
}
};

OP_ADD(DispatchGmmCombineDecode);
} // namespace ops
Loading
Loading