diff --git a/.clang-format b/.clang-format index a77ae97c3..6bc4b3682 100644 --- a/.clang-format +++ b/.clang-format @@ -1,10 +1,12 @@ --- BasedOnStyle: LLVM IndentWidth: 4 # 缩进宽度,LLVM 默认值为 2,改为 4 +TabWidth: 4 # 制表符宽度,与 IndentWidth 一致 +UseTab: Never # 只用空格缩进,不用 Tab AccessModifierOffset: -4 # public/protected/private 访问控制符相对成员的偏移,与 IndentWidth 配合,LLVM 默认值为 -2 AlignOperands: AlignAfterOperator # 双目运算符的行间对齐,LLVM 默认值为 Align,改为带符号一起换行 BreakBeforeBinaryOperators: All # 在双目运算符之前换行,LLVM 默认值为 None,改为换行时总是把双目运算符放在行首,包括赋值(=) -ColumnLimit: 0 # 列宽限制,LLVM 默认值为 80,改为不限制 +ColumnLimit: 80 # 列宽限制,恢复为 LLVM 默认值 80 AllowShortBlocksOnASingleLine: Always # 是否允许短块(单个语句的块)不换行,LLVM 默认值为 Never,改为允许 AllowShortLoopsOnASingleLine: true # 是否允许短循环不换行,LLVM 默认值为 false,改为允许 InsertBraces: true # 是否在 if/for/while/switch 等语句后插入大括号,LLVM 默认值为 false,改为允许 diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3d31c23bb..709115791 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -1,5 +1,6 @@ name: Build and test on: + workflow_dispatch: # 手动触发 pull_request: push: paths-ignore: diff --git a/.gitignore b/.gitignore index e38cf5747..eb124141f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,15 @@ +# 辅助脚本与 IDE/构建生成 +.clangd +compile_commands.json +clean.sh +scripts/inspect_safetensors.py + +# ----------------------------------------------------------------------------- +# 构建与二进制 +# ----------------------------------------------------------------------------- # Xmake cache .xmake/ -build/ +build*/ # Binaries bin/ @@ -15,6 +24,9 @@ lib/ # Vscode .vscode/ +# But keep configuration files +!.vscode/c_cpp_properties.json +!.vscode/settings.json # Python __pycache__/ @@ -77,14 +89,18 @@ htmlcov/ # IDE and editor settings .vscode/ +# But keep configuration files +!.vscode/c_cpp_properties.json +!.vscode/settings.json .idea/ *.swp *~ - # macOS .DS_Store # Windows Thumbs.db ehthumbs.db -desktop.ini \ No newline at end of file +desktop.ini + 
+METAX_BACKEND_REPORT.md \ No newline at end of file diff --git a/METAX_BACKEND_REPORT.md b/METAX_BACKEND_REPORT.md new file mode 100644 index 000000000..c45312882 --- /dev/null +++ b/METAX_BACKEND_REPORT.md @@ -0,0 +1,272 @@ +# 项目二与项目三完成报告 + +## 一、完成概要 + +本次完成了 README 中的项目二和项目三,主要成果如下: + +1. 在 LLAISYS 中完成了双 GPU 后端接入,支持 `NVIDIA` 与 `MetaX` 两个平台。 +2. 完成了 `Qwen2` 模型在 LLAISYS 后端的推理实现,支持权重加载、KV-Cache 和逐 token 解码。 +3. 完成了项目二要求的核心算子实现与接入,包括 `add`、`argmax`、`embedding`、`linear`、`rms_norm`、`rope`、`self_attention`、`swiglu`。 +4. 完成了项目三要求的随机采样功能,支持 `temperature`、`top-k`、`top-p`。 +5. 完成了聊天服务与交互界面,提供 `FastAPI` 服务端、命令行客户端和 Web 界面,并支持流式输出。 +6. 编写了统一的推理 benchmark 脚本,用于比较 `Torch` 与 `LLAISYS` 的输出对齐情况和吞吐表现。 + +当前工程已经能够在本地 `NVIDIA` 平台和远程 `MetaX` 平台完成端到端模型推理,并具备聊天服务的基本交付能力。 + +## 二、开发环境 + +### 1. 本地开发与验证环境 + +- 操作系统:Linux +- GPU:NVIDIA RTX 4060 +- CUDA:本地安装 CUDA 工具链 +- 构建工具:`xmake` +- Python:Python 3.x +- 主要依赖:`transformers`、`huggingface_hub`、`fastapi`、`uvicorn` + +### 2. 远程 MetaX 验证环境 + +- 操作系统:Linux +- GPU:MetaX GPU +- 开发环境:`MACA / mcPyTorch` +- 头文件与库路径:远程环境已安装对应 MetaX SDK + +### 3. 模型与测试对象 + +- 模型:`deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` +- 权重格式:`safetensors` +- 主要数据类型:`bf16` + +## 三、项目二具体实现 + +### 1. 双平台 Runtime 与构建链路 + +在 LLAISYS 原有 CPU 框架基础上,补充了 `NVIDIA` 与 `MetaX` 两套设备后端: + +- 实现了 `nvidia` Runtime API 与 `metax` Runtime API。 +- 在构建系统中增加了平台开关,支持通过 `xmake` 分别编译 `NVIDIA` 与 `MetaX` 后端。 +- 在 Python 侧补充设备映射,使测试脚本和推理脚本能够通过 `--device nvidia` 与 `--device metax` 调用对应后端。 + +### 2. 核心算子实现 + +项目二要求的核心算子已经在 GPU 后端完成实现,并接入统一算子分发路径。主要包括: + +- `add` +- `argmax` +- `embedding` +- `linear` +- `rms_norm` +- `rope` +- `self_attention` +- `swiglu` + +其中: + +- `NVIDIA` 路径主要采用 CUDA 风格实现,并在 `linear` 等算子中使用官方库加速。 +- `MetaX` 路径尽量对齐 CUDA 实现风格,优先使用 MetaX 官方 API 与 `mcBLAS`。 +- 针对 MetaX 平台 `warp=64` 的特性,对部分 kernel 的 block 配置和规约方式做了适配。 + +### 3. 
模型推理实现 + +围绕 `Qwen2` 模型,完成了 LLAISYS 后端推理链路: + +- 在 C/C++ 后端实现模型结构、张量组织和推理逻辑。 +- 实现 `safetensors` 权重加载接口。 +- 实现 KV-Cache,支持逐 token 解码。 +- 在 Python 包装层中完成 `Qwen2` 模型封装,支持 `generate` 与 `generate_stream`。 + +### 4. 功能验证情况 + +项目二完成后,已完成以下验证: + +- Runtime 测试:验证设备运行时接口可用。 +- 算子测试:各核心算子均有对应测试脚本,可在指定设备上运行。 +- 推理测试:`test/test_infer.py` 可用于验证 LLAISYS 输出是否与 Torch 对齐。 +- Benchmark 测试:`test/benchmark_infer.py` 用于比较 Torch 与 LLAISYS 的推理性能与吞吐,输出对齐由 `test/test_infer.py` 单独负责验证。 + +本地 `NVIDIA` 平台最新 benchmark 结果如下: + +| Case | Torch mean(ms) | Torch tok/s | LLAISYS mean(ms) | LLAISYS tok/s | speedup | +|---|---:|---:|---:|---:|---:| +| short/32 | 810.54 | 39.48 | 495.97 | 64.52 | 1.63x | +| short/64 | 1563.33 | 40.94 | 1007.77 | 63.51 | 1.55x | +| short/128 | 2079.48 | 38.95 | 1280.56 | 63.25 | 1.62x | +| medium/32 | 786.33 | 40.70 | 506.45 | 63.19 | 1.55x | +| medium/64 | 1802.99 | 35.50 | 1029.44 | 62.17 | 1.75x | +| medium/128 | 3219.73 | 39.75 | 2114.44 | 60.54 | 1.52x | +| long/32 | 1032.12 | 31.00 | 522.34 | 61.26 | 1.98x | +| long/64 | 1616.44 | 39.59 | 1040.72 | 61.50 | 1.55x | +| long/128 | 3160.70 | 40.50 | 2155.55 | 59.38 | 1.47x | + +吞吐汇总如下: + +- Torch total throughput:`38.89 tok/s` +- LLAISYS total throughput:`61.56 tok/s` +- Overall speedup:`1.58x` + +从这组结果可以看到,LLAISYS 在本地 `NVIDIA` 平台上已经取得了稳定的端到端推理性能优势。 + +远程 `MetaX` 平台最新 benchmark 结果如下: + +| Case | Torch mean(ms) | Torch tok/s | LLAISYS mean(ms) | LLAISYS tok/s | speedup | +|---|---:|---:|---:|---:|---:| +| short/32 | 864.34 | 37.02 | 356.17 | 89.85 | 2.43x | +| short/64 | 1749.20 | 36.59 | 818.50 | 78.19 | 2.14x | +| short/128 | 2173.61 | 37.27 | 1105.36 | 73.28 | 1.97x | +| medium/32 | 865.01 | 36.99 | 437.44 | 73.15 | 1.98x | +| medium/64 | 1721.78 | 37.17 | 977.52 | 65.47 | 1.76x | +| medium/128 | 3439.50 | 37.21 | 2386.28 | 53.64 | 1.44x | +| long/32 | 863.88 | 37.04 | 516.00 | 62.02 | 1.67x | +| long/64 | 1724.36 | 37.12 | 1129.42 | 56.67 | 1.53x | +| long/128 | 3424.45 | 37.38 | 2703.57 | 47.34 | 1.27x | + 
+吞吐汇总如下: + +- Torch total throughput:`37.14 tok/s` +- LLAISYS total throughput:`59.92 tok/s` +- Overall speedup:`1.61x` + +从这组结果可以看到,LLAISYS 在远程 `MetaX` 平台上同样取得了稳定的端到端推理性能优势。结合 `test/test_infer.py` 的对齐测试,可以说明项目二的双平台推理链路已经打通并完成验证。 + +## 四、项目三具体实现 + +### 1. 随机采样 + +在模型推理接口中补充了随机采样逻辑,支持以下参数: + +- `temperature` +- `top-k` +- `top-p` + +当参数配置为 `top_k=1, top_p=1.0, temperature=1.0` 时,系统工作在确定性贪心解码模式,可用于和 Torch 做严格 token 对齐测试;其他配置可用于更自然的聊天生成。 + +### 2. 聊天服务端 + +实现了基于 `FastAPI` 的聊天服务端,主要能力包括: + +- 提供 `/v1/chat/completions` 接口 +- 接口风格对齐 OpenAI Chat Completion +- 支持普通返回模式 +- 支持基于 `text/event-stream` 的流式输出 +- 支持通过请求参数控制 `top-k`、`top-p`、`temperature`、`max_tokens` + +服务端入口文件为: + +- `test/chat_server.py` + +### 3. 命令行交互 + +实现了命令行聊天客户端,支持: + +- 向服务端发送多轮消息 +- 保持对话历史 +- 支持普通模式和流式模式 +- 支持 `/reset` 清空历史、`/exit` 退出 + +对应文件为: + +- `test/chat_cli.py` + +### 4. Web 交互界面 + +实现了简单的 Web 聊天页面,支持: + +- 输入对话消息 +- 设置 `top-k`、`top-p`、`temperature` +- 切换是否流式输出 +- 与 `FastAPI` 服务端联动完成对话 + +对应文件为: + +- `test/chat_web.html` + +### 5. 项目三完成情况 + +目前,项目三已经完成“可采样、可服务、可交互”的基础目标: + +- 模型可以通过 LLAISYS 后端执行聊天生成。 +- 服务端可以接收 HTTP 请求并返回响应。 +- 命令行和 Web 端都可以与服务端交互。 +- 系统支持单用户场景下的连续对话与流式输出。 + +## 五、复现流程 + +### 1. NVIDIA 平台构建与测试 + +```bash +cd ~/llaisys +xmake f -c -m release --nv-gpu=y --mx-gpu=n +xmake -r && xmake install +``` + +运行推理对齐测试: + +```bash +python test/test_infer.py \ + --device nvidia \ + --model ~/model_pkg/DeepSeek-R1-Distill-Qwen-1.5B/ \ + --test +``` + +运行 Torch 与 LLAISYS 的推理 benchmark: + +```bash +python test/benchmark_infer.py \ + --device nvidia \ + --model ~/model_pkg/DeepSeek-R1-Distill-Qwen-1.5B/ +``` + +### 2. 
MetaX 平台构建与测试 + +在远程 MetaX 服务器上执行: + +```bash +cd ~/llaisys +xmake f -c -m release --mx-gpu=y --nv-gpu=n +xmake -r && xmake install +``` + +运行推理对齐测试: + +```bash +python test/test_infer.py \ + --device metax \ + --model ~/model_pkg/DeepSeek-R1-Distill-Qwen-1.5B/ \ + --test +``` + +运行 benchmark: + +```bash +python test/benchmark_infer.py \ + --device metax \ + --model ~/model_pkg/DeepSeek-R1-Distill-Qwen-1.5B/ +``` + +### 3. 聊天服务复现 + +启动服务端: + +```bash +python test/chat_server.py \ + --device nvidia \ + --model ~/model_pkg/DeepSeek-R1-Distill-Qwen-1.5B/ +``` + +命令行客户端连接服务端: + +```bash +python test/chat_cli.py --stream +``` + +Web 端使用方法: + +1. 启动 `chat_server.py` +2. 浏览器访问 `http://127.0.0.1:8000/` +3. 在页面中输入消息并发起对话 + +## 结论 + +项目二已经完成 LLAISYS 在 `NVIDIA` 与 `MetaX` 双 GPU 平台上的推理后端集成,完成了核心算子、运行时接口和模型推理链路的实现与验证。项目三在此基础上完成了随机采样、聊天服务端、CLI 与 Web UI 的实现,使系统具备了单用户对话式推理的基础能力。 + +当前代码已经具备提交条件,并能够作为后续性能优化和工程化完善的基础版本。 diff --git a/include/llaisys.h b/include/llaisys.h index 73ca7eead..ca9f03184 100644 --- a/include/llaisys.h +++ b/include/llaisys.h @@ -24,6 +24,7 @@ typedef enum { LLAISYS_DEVICE_CPU = 0, //// TODO: Add more device types here. Numbers need to be consecutive. 
LLAISYS_DEVICE_NVIDIA = 1, + LLAISYS_DEVICE_METAX = 2, LLAISYS_DEVICE_TYPE_COUNT } llaisysDeviceType_t; diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d4..529725be1 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -30,13 +30,23 @@ __C { }; struct LlaisysQwen2Model; - + // __export用于导出函数,使得它们在DLL中可见 __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice); __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model); __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); + __export void llaisysQwen2ModelResetCache(struct LlaisysQwen2Model * model); + __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + + __export int64_t llaisysQwen2ModelInferSample( + struct LlaisysQwen2Model * model, + int64_t * token_ids, + size_t ntoken, + int top_k, + float top_p, + float temperature); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py index f536fb527..c40ffc8e3 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -12,6 +12,7 @@ from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops +from .models import load_models def load_shared_library(): @@ -38,6 +39,7 @@ def load_shared_library(): load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_models(LIB_LLAISYS) __all__ = [ diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py index c5a0b4679..cbe92132e 100644 --- a/python/llaisys/libllaisys/llaisys_types.py +++ b/python/llaisys/libllaisys/llaisys_types.py @@ -6,7 +6,8 @@ class DeviceType(IntEnum): CPU = 0 NVIDIA = 1 - COUNT = 2 + METAX = 2 + COUNT = 3 llaisysDeviceType_t = 
ctypes.c_int diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py new file mode 100644 index 000000000..fc47ae577 --- /dev/null +++ b/python/llaisys/libllaisys/models.py @@ -0,0 +1,103 @@ +from .tensor import llaisysTensor_t +from .llaisys_types import ( + llaisysDataType_t, + llaisysDeviceType_t, + DataType, + DeviceType, +) +from ctypes import ( + c_float, + c_int64, + c_size_t, + POINTER, + Structure, + c_int, + c_void_p, +) + + +class LlaisysQwen2Meta(Structure): + _fields_ = [ + ("dtype", c_int), # llaisysDataType_t + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + + +class LlaisysQwen2Weights(Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", POINTER(llaisysTensor_t)), + ("attn_q_w", POINTER(llaisysTensor_t)), + ("attn_q_b", POINTER(llaisysTensor_t)), + ("attn_k_w", POINTER(llaisysTensor_t)), + ("attn_k_b", POINTER(llaisysTensor_t)), + ("attn_v_w", POINTER(llaisysTensor_t)), + ("attn_v_b", POINTER(llaisysTensor_t)), + ("attn_o_w", POINTER(llaisysTensor_t)), + ("mlp_norm_w", POINTER(llaisysTensor_t)), + ("mlp_gate_w", POINTER(llaisysTensor_t)), + ("mlp_up_w", POINTER(llaisysTensor_t)), + ("mlp_down_w", POINTER(llaisysTensor_t)), + ] + + +llaisysQwen2Model_t = c_void_p + + +def load_models(lib): + # Meta structure + lib.LlaisysQwen2Meta = LlaisysQwen2Meta + lib.LlaisysQwen2Weights = LlaisysQwen2Weights + + # llaisysQwen2ModelCreate + # argtypes用于指定函数参数类型,restype用于指定返回类型 + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), + c_int, # llaisysDeviceType_t + POINTER(c_int), # int *device_ids + c_int, # int ndevice + ] + lib.llaisysQwen2ModelCreate.restype = llaisysQwen2Model_t + + # llaisysQwen2ModelDestroy + 
lib.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelDestroy.restype = None + + # llaisysQwen2ModelWeights + lib.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights) + + # llaisysQwen2ModelResetCache + lib.llaisysQwen2ModelResetCache.argtypes = [llaisysQwen2Model_t] + lib.llaisysQwen2ModelResetCache.restype = None + + # llaisysQwen2ModelInfer + lib.llaisysQwen2ModelInfer.argtypes = [ + llaisysQwen2Model_t, + POINTER(c_int64), # int64_t *token_ids + c_size_t, # size_t ntoken + ] + lib.llaisysQwen2ModelInfer.restype = c_int64 + + # llaisysQwen2ModelInferSample + lib.llaisysQwen2ModelInferSample.argtypes = [ + llaisysQwen2Model_t, + POINTER(c_int64), # int64_t *token_ids + c_size_t, # size_t ntoken + c_int, # int top_k + c_float, # float top_p + c_float, # float temperature + ] + lib.llaisysQwen2ModelInferSample.restype = c_int64 diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b21..135a7ed38 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,24 +1,226 @@ from typing import Sequence from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType +from ..libllaisys import DeviceType, DataType +from ..libllaisys.models import ( + LlaisysQwen2Meta, + LlaisysQwen2Weights, + llaisysQwen2Model_t, +) +from ..tensor import Tensor +from ctypes import c_int64, c_size_t, POINTER, byref, cast, c_int, c_void_p, c_float +import json from pathlib import Path -import safetensors +from safetensors.torch import load_file as safetensors_load_file +import torch class Qwen2: - def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor - model_path = Path(model_path) + self._device = device + + # 加载模型配置 + config_path = model_path / "config.json" # '/'拼接路径 + if not config_path.exists(): + raise FileNotFoundError(f"config.json not found in {model_path}") 
+ + with open(config_path, 'r') as f: + config = json.load(f) + + # 提取模型元数据 + self.meta = LlaisysQwen2Meta() + self.meta.dtype = DataType.BF16 # 根据模型配置确定 + self.meta.nlayer = config.get("num_hidden_layers", config.get("num_layers", 0)) + self.meta.hs = config.get("hidden_size", 0) + self.meta.nh = config.get("num_attention_heads", 0) + self.meta.nkvh = config.get("num_key_value_heads", self.meta.nh) # GQA + self.meta.dh = config.get("head_dim", self.meta.hs // self.meta.nh) # 一般有hs = nh * dh + if self.meta.dh == 0: + self.meta.dh = self.meta.hs // self.meta.nh + # intermediate_size是MLP(前馈层)的中间层维度,一般是hs的几倍;起到先升维再降维的作用,提高非线性表达能力 + self.meta.di = config.get("intermediate_size", 0) + self.meta.maxseq = config.get("max_position_embeddings", 32768) + self.meta.voc = config.get("vocab_size", 0) + self.meta.epsilon = config.get("rms_norm_eps", 1e-6) + self.meta.theta = config.get("rope_theta", 1000000.0) # RoPE的基数,控制位置编码的频率分布 + self.meta.end_token = config.get("eos_token_id", 151643) + + # 确定设备 + device_id = 0 + device_ids = (c_int * 1)(device_id) + + # 创建模型 + self.model = LIB_LLAISYS.llaisysQwen2ModelCreate( + byref(self.meta), # byref用于将Python对象转换为C语言的结构体指针 + device.value, + device_ids, + 1 + ) + + if not self.model: + raise RuntimeError("Failed to create model") + + # 获取权重结构 + self.weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self.model) + if not self.weights_ptr: + raise RuntimeError("Failed to get model weights") + + self.weights = self.weights_ptr.contents + # 持有所有权重 Tensor,延长权重的生命周期,避免 Python GC 导致底层 tensorDestroy 释放权重后悬空 + self._weight_tensors = [] + + # 加载权重 + self._load_weights(model_path) + + # 模型safetensors->LLAISYS:Tensor->C:LlaisysQwen2Weights + def _load_weights(self, model_path): + """从 safetensors 文件加载权重(流式加载 + BF16 直拷贝 + 进度输出)""" + safetensors_files = sorted(model_path.glob("*.safetensors")) + if not safetensors_files: + raise FileNotFoundError(f"No safetensors files found in {model_path}") + + print(f"[llaisys] Loading Qwen2 weights from: 
{model_path}") + print(f"[llaisys] Found {len(safetensors_files)} safetensors") + + # qwen2模型权重为bf16 + def to_bf16_cpu_contig(t: torch.Tensor) -> torch.Tensor: + t = t.detach().cpu() + if t.dtype != torch.bfloat16: + t = t.to(torch.bfloat16) + return t.contiguous() + + def load_llaisys_tensor_from_torch(t: torch.Tensor) -> Tensor: + t_cpu = to_bf16_cpu_contig(t) + lt = Tensor(shape=list(t_cpu.shape), dtype=DataType.BF16, device=self._device) + lt.load(c_void_p(t_cpu.data_ptr())) + self._weight_tensors.append(lt) + return lt + + def set_field(name: str, t: torch.Tensor): + lt = load_llaisys_tensor_from_torch(t) + setattr(self.weights, name, lt.lib_tensor()) # 为对象动态添加属性,等价于self.weights.name = lt.lib_tensor() + + loaded = 0 # 成功加载,没写进权重结构的tensor数量 + skipped = 0 # 遍历到但没用上的tensor数量 + + # 遍历所有safetensors文件 + for file_idx, file in enumerate(safetensors_files): + print(f"[llaisys] [{file_idx + 1}/{len(safetensors_files)}] reading {file.name}") + weights_dict = safetensors_load_file(str(file)) + print(f"[llaisys] tensors in shard: {len(weights_dict)}") + + for key, t in weights_dict.items(): + # Global weights + if key == "model.embed_tokens.weight": # 输入 embedding:[voc, hs] + set_field("in_embed", t) + loaded += 1 + continue + if key == "lm_head.weight": + set_field("out_embed", t) + loaded += 1 + continue + if key == "model.norm.weight": + set_field("out_norm_w", t) + loaded += 1 + continue + + # Per-layer weights + if not key.startswith("model.layers."): + skipped += 1 + continue + + parts = key.split(".") + if len(parts) < 4: + skipped += 1 + continue + + try: + layer_idx = int(parts[2]) + except ValueError: + skipped += 1 + continue + + if layer_idx < 0 or layer_idx >= int(self.meta.nlayer): + skipped += 1 + continue + + suffix = ".".join(parts[3:]) # 用'.'拼接层号后的元素 + + if suffix == "input_layernorm.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_norm_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + + if suffix == "self_attn.q_proj.weight": 
+ lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_q_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + if suffix == "self_attn.q_proj.bias": + lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_q_b[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + + if suffix == "self_attn.k_proj.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_k_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + if suffix == "self_attn.k_proj.bias": + lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_k_b[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + + if suffix == "self_attn.v_proj.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_v_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + if suffix == "self_attn.v_proj.bias": + lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_v_b[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + + if suffix == "self_attn.o_proj.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.attn_o_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue - for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + if suffix == "post_attention_layernorm.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.mlp_norm_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + if suffix == "mlp.gate_proj.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.mlp_gate_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + if suffix == "mlp.up_proj.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.mlp_up_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + if suffix == "mlp.down_proj.weight": + lt = load_llaisys_tensor_from_torch(t) + self.weights.mlp_down_w[layer_idx] = lt.lib_tensor() + loaded += 1 + continue + + skipped += 1 + + # 释放 shard dict 的引用(尽快回收内存) + del weights_dict + 
+ print(f"[llaisys] Done. loaded={loaded}, skipped={skipped}") + def generate( self, inputs: Sequence[int], @@ -27,7 +229,87 @@ def generate( top_p: float = 0.8, temperature: float = 0.8, ): + # 重置 KV Cache(开始新的生成序列) + LIB_LLAISYS.llaisysQwen2ModelResetCache(self.model) + + output_tokens = list(inputs) + if len(inputs) == 0: + return output_tokens + + if max_new_tokens is None: + max_new_tokens = 128 + max_new_tokens = max(int(max_new_tokens), 1) + + # Prefill 阶段 + next_token = self._infer_next(inputs, top_k, top_p, temperature) + output_tokens.append(next_token) + + # Decode 阶段 + for _ in range(max_new_tokens - 1): + if next_token == self.meta.end_token: + break + next_token = self._infer_next([next_token], top_k, top_p, temperature) + output_tokens.append(next_token) + + return output_tokens + + def generate_stream( + self, + inputs: Sequence[int], + max_new_tokens: int = None, + top_k: int = 1, + top_p: float = 1.0, + temperature: float = 1.0, + ): + LIB_LLAISYS.llaisysQwen2ModelResetCache(self.model) + if len(inputs) == 0: + return + + if max_new_tokens is None: + max_new_tokens = 128 + max_new_tokens = max(int(max_new_tokens), 1) + + next_token = self._infer_next(inputs, top_k, top_p, temperature) + yield next_token + for _ in range(max_new_tokens - 1): + if next_token == self.meta.end_token: + break + next_token = self._infer_next([next_token], top_k, top_p, temperature) + yield next_token + + def _infer_next( + self, + tokens: Sequence[int], + top_k: int, + top_p: float, + temperature: float, + ) -> int: + token_array = (c_int64 * len(tokens))(*tokens) + top_k_i = int(top_k) + top_p_f = float(top_p) + temp_f = float(temperature) - # TODO: Implement generate function + if top_k_i == 1 and top_p_f >= 1.0 and abs(temp_f - 1.0) < 1e-8: + return int( + LIB_LLAISYS.llaisysQwen2ModelInfer( + self.model, + token_array, + len(tokens), + ) + ) - return [] + return int( + LIB_LLAISYS.llaisysQwen2ModelInferSample( + self.model, + token_array, + len(tokens), + 
c_int(top_k_i), + c_float(top_p_f), + c_float(temp_f), + ) + ) + + def __del__(self): + if hasattr(self, 'model') and self.model: + LIB_LLAISYS.llaisysQwen2ModelDestroy(self.model) + self.model = None diff --git a/src/device/metax/metax_resource.hpp b/src/device/metax/metax_resource.hpp new file mode 100644 index 000000000..fd2679e0c --- /dev/null +++ b/src/device/metax/metax_resource.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "../device_resource.hpp" + +namespace llaisys::device::metax { +class Resource : public llaisys::device::DeviceResource { +public: + explicit Resource(int device_id); + ~Resource() = default; +}; +} // namespace llaisys::device::metax diff --git a/src/device/metax/metax_resource.maca b/src/device/metax/metax_resource.maca new file mode 100644 index 000000000..1fda42c09 --- /dev/null +++ b/src/device/metax/metax_resource.maca @@ -0,0 +1,7 @@ +#include "metax_resource.hpp" + +namespace llaisys::device::metax { + +Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_METAX, device_id) {} + +} // namespace llaisys::device::metax diff --git a/src/device/metax/metax_runtime_api.maca b/src/device/metax/metax_runtime_api.maca new file mode 100644 index 000000000..46dc324f9 --- /dev/null +++ b/src/device/metax/metax_runtime_api.maca @@ -0,0 +1,117 @@ +#include "../runtime_api.hpp" +#include "llaisys.h" + +#include + +#include +#include + +namespace llaisys::device::metax { + +namespace runtime_api { + +static mcMemcpyKind toMcMemcpyKind(llaisysMemcpyKind_t kind) { + switch (kind) { + case LLAISYS_MEMCPY_H2H: + return mcMemcpyHostToHost; + case LLAISYS_MEMCPY_H2D: + return mcMemcpyHostToDevice; + case LLAISYS_MEMCPY_D2H: + return mcMemcpyDeviceToHost; + case LLAISYS_MEMCPY_D2D: + return mcMemcpyDeviceToDevice; + default: + return mcMemcpyDefault; + } +} + +int getDeviceCount() { + int n = 0; + mcError_t e = mcGetDeviceCount(&n); + if (e != mcSuccess) { + return 0; + } + return n; +} + +void setDevice(int device_id) { + 
mcSetDevice(device_id); +} + +void deviceSynchronize() { + mcDeviceSynchronize(); +} + +llaisysStream_t createStream() { + mcStream_t s = nullptr; + mcError_t e = mcStreamCreate(&s); + if (e != mcSuccess) { + return nullptr; + } + return reinterpret_cast(s); +} + +void destroyStream(llaisysStream_t stream) { + if (stream) { + mcStreamDestroy(reinterpret_cast(stream)); + } +} + +void streamSynchronize(llaisysStream_t stream) { + if (stream) { + mcStreamSynchronize(reinterpret_cast(stream)); + } +} + +void *mallocDevice(size_t size) { + void *p = nullptr; + mcMalloc(&p, size); + return p; +} + +void freeDevice(void *ptr) { + if (ptr) { + mcFree(ptr); + } +} + +void *mallocHost(size_t size) { + // Keep host allocation policy aligned with CPU/NVIDIA backends. + return std::malloc(size); +} + +void freeHost(void *ptr) { + std::free(ptr); +} + +void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { + mcMemcpy(dst, src, size, toMcMemcpyKind(kind)); +} + +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + mcStream_t s = stream ? 
reinterpret_cast(stream) : (mcStream_t)0; + mcMemcpyAsync(dst, src, size, toMcMemcpyKind(kind), s); +} + +static const LlaisysRuntimeAPI RUNTIME_API = { + &getDeviceCount, + &setDevice, + &deviceSynchronize, + &createStream, + &destroyStream, + &streamSynchronize, + &mallocDevice, + &freeDevice, + &mallocHost, + &freeHost, + &memcpySync, + &memcpyAsync, +}; + +} // namespace runtime_api + +const LlaisysRuntimeAPI *getRuntimeAPI() { + return &runtime_api::RUNTIME_API; +} + +} // namespace llaisys::device::metax diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu index cab928261..65da83990 100644 --- a/src/device/nvidia/nvidia_runtime_api.cu +++ b/src/device/nvidia/nvidia_runtime_api.cu @@ -1,56 +1,98 @@ #include "../runtime_api.hpp" +#include "llaisys.h" +#include #include #include namespace llaisys::device::nvidia { namespace runtime_api { + int getDeviceCount() { - TO_BE_IMPLEMENTED(); + int n = 0; + cudaError_t e = cudaGetDeviceCount(&n); + if (e == cudaErrorNoDevice || e == cudaErrorInsufficientDriver) { + return 0; + } + if (e != cudaSuccess) { + return 0; + } + return n; } -void setDevice(int) { - TO_BE_IMPLEMENTED(); +void setDevice(int device_id) { + cudaSetDevice(device_id); } void deviceSynchronize() { - TO_BE_IMPLEMENTED(); + cudaDeviceSynchronize(); } llaisysStream_t createStream() { - TO_BE_IMPLEMENTED(); + cudaStream_t s = nullptr; + cudaStreamCreate(&s); + return (llaisysStream_t)s; } void destroyStream(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + if (stream) { + cudaStreamDestroy((cudaStream_t)stream); + } } + void streamSynchronize(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + if (stream) { + cudaStreamSynchronize((cudaStream_t)stream); + } } void *mallocDevice(size_t size) { - TO_BE_IMPLEMENTED(); + void *p = nullptr; + cudaMalloc(&p, size); + return p; } void freeDevice(void *ptr) { - TO_BE_IMPLEMENTED(); + if (ptr) { + cudaFree(ptr); + } } void *mallocHost(size_t size) { - TO_BE_IMPLEMENTED(); 
+ void *p = nullptr; + cudaMallocHost(&p, size); + return p; } void freeHost(void *ptr) { - TO_BE_IMPLEMENTED(); + if (ptr) { + cudaFreeHost(ptr); + } +} + +static cudaMemcpyKind toCudaMemcpyKind(llaisysMemcpyKind_t kind) { + switch (kind) { + case LLAISYS_MEMCPY_H2H: + return cudaMemcpyHostToHost; + case LLAISYS_MEMCPY_H2D: + return cudaMemcpyHostToDevice; + case LLAISYS_MEMCPY_D2H: + return cudaMemcpyDeviceToHost; + case LLAISYS_MEMCPY_D2D: + return cudaMemcpyDeviceToDevice; + default: + return cudaMemcpyDefault; + } } void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); + cudaMemcpy(dst, src, size, toCudaMemcpyKind(kind)); } -void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); +void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) { + cudaStream_t s = stream ? (cudaStream_t)stream : (cudaStream_t)0; + cudaMemcpyAsync(dst, src, size, toCudaMemcpyKind(kind), s); } static const LlaisysRuntimeAPI RUNTIME_API = { @@ -65,11 +107,13 @@ static const LlaisysRuntimeAPI RUNTIME_API = { &mallocHost, &freeHost, &memcpySync, - &memcpyAsync}; + &memcpyAsync, +}; } // namespace runtime_api const LlaisysRuntimeAPI *getRuntimeAPI() { return &runtime_api::RUNTIME_API; } + } // namespace llaisys::device::nvidia diff --git a/src/device/runtime_api.cpp b/src/device/runtime_api.cpp index 2de3eca02..233afa896 100644 --- a/src/device/runtime_api.cpp +++ b/src/device/runtime_api.cpp @@ -80,6 +80,12 @@ const LlaisysRuntimeAPI *getRuntimeAPI(llaisysDeviceType_t device_type) { return llaisys::device::nvidia::getRuntimeAPI(); #else return getUnsupportedRuntimeAPI(); +#endif + case LLAISYS_DEVICE_METAX: +#ifdef ENABLE_METAX_API + return llaisys::device::metax::getRuntimeAPI(); +#else + return getUnsupportedRuntimeAPI(); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/device/runtime_api.hpp 
b/src/device/runtime_api.hpp index e6b9f80d6..0e94644f5 100644 --- a/src/device/runtime_api.hpp +++ b/src/device/runtime_api.hpp @@ -17,4 +17,10 @@ namespace nvidia { const LlaisysRuntimeAPI *getRuntimeAPI(); } #endif + +#ifdef ENABLE_METAX_API +namespace metax { +const LlaisysRuntimeAPI *getRuntimeAPI(); +} +#endif } // namespace llaisys::device diff --git a/src/llaisys/models.cc b/src/llaisys/models.cc new file mode 100644 index 000000000..13c66c21a --- /dev/null +++ b/src/llaisys/models.cc @@ -0,0 +1,178 @@ +#include "llaisys/models/qwen2.h" + +#include "llaisys_tensor.hpp" +#include "../models/qwen2/model.hpp" + +#include +#include + +// C++ Model 的包装结构 +struct LlaisysQwen2Model { + std::unique_ptr model; + std::unique_ptr c_weights; // C 结构的权重,由 Python 设置 +}; + +// 同步权重从 C 结构到 C++ 模型 +static void sync_weights(struct LlaisysQwen2Model *model) { + if (!model->c_weights) return; + + auto& weights = model->model->weights(); + size_t nlayer = model->model->meta().nlayer; + + if (model->c_weights->in_embed) { + weights.in_embed = model->c_weights->in_embed->tensor; + } + if (model->c_weights->out_embed) { + weights.out_embed = model->c_weights->out_embed->tensor; + } + if (model->c_weights->out_norm_w) { + weights.out_norm_w = model->c_weights->out_norm_w->tensor; + } + for (size_t i = 0; i < nlayer; ++i) { + if (model->c_weights->attn_norm_w[i]) { + weights.attn_norm_w[i] = model->c_weights->attn_norm_w[i]->tensor; + } + if (model->c_weights->attn_q_w[i]) { + weights.attn_q_w[i] = model->c_weights->attn_q_w[i]->tensor; + } + if (model->c_weights->attn_q_b[i]) { + weights.attn_q_b[i] = model->c_weights->attn_q_b[i]->tensor; + } + if (model->c_weights->attn_k_w[i]) { + weights.attn_k_w[i] = model->c_weights->attn_k_w[i]->tensor; + } + if (model->c_weights->attn_k_b[i]) { + weights.attn_k_b[i] = model->c_weights->attn_k_b[i]->tensor; + } + if (model->c_weights->attn_v_w[i]) { + weights.attn_v_w[i] = model->c_weights->attn_v_w[i]->tensor; + } + if 
(model->c_weights->attn_v_b[i]) { + weights.attn_v_b[i] = model->c_weights->attn_v_b[i]->tensor; + } + if (model->c_weights->attn_o_w[i]) { + weights.attn_o_w[i] = model->c_weights->attn_o_w[i]->tensor; + } + if (model->c_weights->mlp_norm_w[i]) { + weights.mlp_norm_w[i] = model->c_weights->mlp_norm_w[i]->tensor; + } + if (model->c_weights->mlp_gate_w[i]) { + weights.mlp_gate_w[i] = model->c_weights->mlp_gate_w[i]->tensor; + } + if (model->c_weights->mlp_up_w[i]) { + weights.mlp_up_w[i] = model->c_weights->mlp_up_w[i]->tensor; + } + if (model->c_weights->mlp_down_w[i]) { + weights.mlp_down_w[i] = model->c_weights->mlp_down_w[i]->tensor; + } + } +} + +__C { + struct LlaisysQwen2Model *llaisysQwen2ModelCreate( + const LlaisysQwen2Meta *meta, + llaisysDeviceType_t device, + int *device_ids, + int ndevice) { + + llaisys::models::qwen2::ModelMeta cpp_meta; + cpp_meta.dtype = meta->dtype; + cpp_meta.nlayer = meta->nlayer; + cpp_meta.hs = meta->hs; + cpp_meta.nh = meta->nh; + cpp_meta.nkvh = meta->nkvh; + cpp_meta.dh = meta->dh; + cpp_meta.di = meta->di; + cpp_meta.maxseq = meta->maxseq; + cpp_meta.voc = meta->voc; + cpp_meta.epsilon = meta->epsilon; + cpp_meta.theta = meta->theta; + cpp_meta.end_token = meta->end_token; + + int device_id = (ndevice > 0 && device_ids) ? 
device_ids[0] : 0; + + auto model = std::make_unique(cpp_meta, device, device_id); + + return new LlaisysQwen2Model{std::move(model)}; + } + + void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) { + delete model; + } + + struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) { + // 返回模型权重的引用,Python 侧可以设置这些指针 + // 如果还没有创建,则创建并初始化 + if (!model->c_weights) { + size_t nlayer = model->model->meta().nlayer; + model->c_weights = std::make_unique(); + + // 初始化指针为 nullptr,由 Python 侧设置 + model->c_weights->in_embed = nullptr; + model->c_weights->out_embed = nullptr; + model->c_weights->out_norm_w = nullptr; + + // 为每层权重分配数组 + model->c_weights->attn_norm_w = new LlaisysTensor*[nlayer]; + model->c_weights->attn_q_w = new LlaisysTensor*[nlayer]; + model->c_weights->attn_q_b = new LlaisysTensor*[nlayer]; + model->c_weights->attn_k_w = new LlaisysTensor*[nlayer]; + model->c_weights->attn_k_b = new LlaisysTensor*[nlayer]; + model->c_weights->attn_v_w = new LlaisysTensor*[nlayer]; + model->c_weights->attn_v_b = new LlaisysTensor*[nlayer]; + model->c_weights->attn_o_w = new LlaisysTensor*[nlayer]; + model->c_weights->mlp_norm_w = new LlaisysTensor*[nlayer]; + model->c_weights->mlp_gate_w = new LlaisysTensor*[nlayer]; + model->c_weights->mlp_up_w = new LlaisysTensor*[nlayer]; + model->c_weights->mlp_down_w = new LlaisysTensor*[nlayer]; + + // 初始化为 nullptr + for (size_t i = 0; i < nlayer; ++i) { + model->c_weights->attn_norm_w[i] = nullptr; + model->c_weights->attn_q_w[i] = nullptr; + model->c_weights->attn_q_b[i] = nullptr; + model->c_weights->attn_k_w[i] = nullptr; + model->c_weights->attn_k_b[i] = nullptr; + model->c_weights->attn_v_w[i] = nullptr; + model->c_weights->attn_v_b[i] = nullptr; + model->c_weights->attn_o_w[i] = nullptr; + model->c_weights->mlp_norm_w[i] = nullptr; + model->c_weights->mlp_gate_w[i] = nullptr; + model->c_weights->mlp_up_w[i] = nullptr; + model->c_weights->mlp_down_w[i] = nullptr; + } + } + + // 每次调用时同步权重(确保权重是最新的) 
+ sync_weights(model); + + return model->c_weights.get(); + } + + void llaisysQwen2ModelResetCache(struct LlaisysQwen2Model *model) { + model->model->reset_cache(); + } + + int64_t llaisysQwen2ModelInfer( + struct LlaisysQwen2Model *model, + int64_t *token_ids, + size_t ntoken) { + + // 允许 Python 在任意时刻更新 c_weights 指针: + // 推理前再同步一次,避免"先拿到 weights 指针 -> Python 填充 -> 没再调用 Weights()"导致的未同步问题。 + sync_weights(model); + return model->model->infer(token_ids, ntoken); + } + + int64_t llaisysQwen2ModelInferSample( + struct LlaisysQwen2Model *model, + int64_t *token_ids, + size_t ntoken, + int top_k, + float top_p, + float temperature) { + + sync_weights(model); + return model->model->infer(token_ids, ntoken, top_k, top_p, temperature); + } +} diff --git a/src/models/qwen2/model.cpp b/src/models/qwen2/model.cpp new file mode 100644 index 000000000..2d74e409a --- /dev/null +++ b/src/models/qwen2/model.cpp @@ -0,0 +1,375 @@ +#include "model.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../device/runtime_api.hpp" +#include "../../ops/add/op.hpp" +#include "../../utils.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace llaisys::models::qwen2 { +namespace { +int64_t argmax_host(const std::vector &vals) { + ASSERT(!vals.empty(), "argmax_host: input must not be empty"); + size_t best = 0; + for (size_t i = 1; i < vals.size(); ++i) { + if (vals[i] > vals[best]) { + best = i; + } + } + return static_cast(best); +} + +std::vector logits_to_host_f32(tensor_t logits, + const LlaisysRuntimeAPI *api) { + const size_t n = logits->numel(); + std::vector out(n); + switch (logits->dtype()) { + case LLAISYS_DTYPE_F32: { + api->memcpy_sync(out.data(), logits->data(), n * sizeof(float), + LLAISYS_MEMCPY_D2H); + break; + } + case LLAISYS_DTYPE_F16: { + std::vector tmp(n); + api->memcpy_sync(tmp.data(), logits->data(), + n * sizeof(llaisys::fp16_t), LLAISYS_MEMCPY_D2H); + for (size_t i = 0; i < n; ++i) { + out[i] = llaisys::utils::cast(tmp[i]); + } 
+ break; + } + case LLAISYS_DTYPE_BF16: { + std::vector tmp(n); + api->memcpy_sync(tmp.data(), logits->data(), + n * sizeof(llaisys::bf16_t), LLAISYS_MEMCPY_D2H); + for (size_t i = 0; i < n; ++i) { + out[i] = llaisys::utils::cast(tmp[i]); + } + break; + } + default: + EXCEPTION_UNSUPPORTED_DATATYPE(logits->dtype()); + } + return out; +} + +int64_t sample_from_logits(const std::vector &logits, int top_k, + float top_p, float temperature) { + ASSERT(!logits.empty(), "sample_from_logits: logits must not be empty"); + + if (temperature <= 0.0f) { + return argmax_host(logits); + } + + const size_t vocab = logits.size(); + if (top_k <= 0 || top_k > static_cast(vocab)) { + top_k = static_cast(vocab); + } + if (top_p <= 0.0f || top_p > 1.0f) { + top_p = 1.0f; + } + + if (top_k == 1 && top_p >= 1.0f) { + return argmax_host(logits); + } + + std::vector idx(vocab); + std::iota(idx.begin(), idx.end(), 0); + auto by_logit_desc + = [&logits](int a, int b) { return logits[a] > logits[b]; }; + if (top_k < static_cast(vocab)) { + std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(), + by_logit_desc); + idx.resize(top_k); + } + std::sort(idx.begin(), idx.end(), by_logit_desc); + + const float inv_temp = 1.0f / temperature; + float max_scaled = -std::numeric_limits::infinity(); + for (int i : idx) { + max_scaled = std::max(max_scaled, logits[i] * inv_temp); + } + + std::vector probs(idx.size(), 0.0); + double total = 0.0; + for (size_t i = 0; i < idx.size(); ++i) { + double p = std::exp( + static_cast(logits[idx[i]] * inv_temp - max_scaled)); + if (!std::isfinite(p) || p < 0.0) { + p = 0.0; + } + probs[i] = p; + total += p; + } + if (total <= 0.0) { + return static_cast(idx.front()); + } + + if (top_p < 1.0f) { + double cum = 0.0; + size_t keep = 0; + for (size_t i = 0; i < probs.size(); ++i) { + cum += probs[i] / total; + keep = i + 1; + if (cum >= static_cast(top_p)) { + break; + } + } + keep = std::max(keep, 1); + idx.resize(keep); + probs.resize(keep); + } + + 
thread_local std::mt19937 rng(std::random_device{}()); + std::discrete_distribution dist(probs.begin(), probs.end()); + int chosen = dist(rng); + return static_cast(idx[static_cast(chosen)]); +} +} // namespace + +Model::Model(const ModelMeta &meta, llaisysDeviceType_t device_type, + int device_id) + : meta_(meta), device_type_(device_type), device_id_(device_id), + cache_len_(0) { + k_cache_.resize(meta_.nlayer); + v_cache_.resize(meta_.nlayer); + for (size_t i = 0; i < meta_.nlayer; ++i) { + k_cache_[i] = Tensor::create({meta_.maxseq, meta_.nkvh, meta_.dh}, + meta_.dtype, device_type_, device_id_); + v_cache_[i] = Tensor::create({meta_.maxseq, meta_.nkvh, meta_.dh}, + meta_.dtype, device_type_, device_id_); + } + + weights_.attn_norm_w.resize(meta_.nlayer); + weights_.attn_q_w.resize(meta_.nlayer); + weights_.attn_q_b.resize(meta_.nlayer); + weights_.attn_k_w.resize(meta_.nlayer); + weights_.attn_k_b.resize(meta_.nlayer); + weights_.attn_v_w.resize(meta_.nlayer); + weights_.attn_v_b.resize(meta_.nlayer); + weights_.attn_o_w.resize(meta_.nlayer); + weights_.mlp_norm_w.resize(meta_.nlayer); + weights_.mlp_gate_w.resize(meta_.nlayer); + weights_.mlp_up_w.resize(meta_.nlayer); + weights_.mlp_down_w.resize(meta_.nlayer); + + // Zero-initialized fallback bias for layers without bias terms. 
+ dummy_bias_hs_ + = Tensor::create({meta_.hs}, meta_.dtype, device_type_, device_id_); + dummy_bias_di_ + = Tensor::create({meta_.di}, meta_.dtype, device_type_, device_id_); + dummy_bias_q_ = Tensor::create({meta_.nh * meta_.dh}, meta_.dtype, + device_type_, device_id_); + dummy_bias_kv_ = Tensor::create({meta_.nkvh * meta_.dh}, meta_.dtype, + device_type_, device_id_); + dummy_bias_voc_ + = Tensor::create({meta_.voc}, meta_.dtype, device_type_, device_id_); + + auto zero_tensor = [](const tensor_t &t) { + std::vector zeros(t->numel() * t->elementSize(), + std::byte{0}); + t->load(zeros.data()); + }; + zero_tensor(dummy_bias_hs_); + zero_tensor(dummy_bias_di_); + zero_tensor(dummy_bias_q_); + zero_tensor(dummy_bias_kv_); + zero_tensor(dummy_bias_voc_); +} + +Model::~Model() {} + +void Model::reset_cache() { cache_len_ = 0; } + +void Model::ensure_tensor(tensor_t &tensor, const std::vector &shape, + llaisysDataType_t dtype) { + const bool need_new = (!tensor) || tensor->dtype() != dtype + || tensor->deviceType() != device_type_ + || tensor->deviceId() != device_id_ + || tensor->shape() != shape; + if (need_new) { + tensor = Tensor::create(shape, dtype, device_type_, device_id_); + } +} + +void Model::update_kv_cache(size_t layer_idx, tensor_t k_new, tensor_t v_new, + size_t seqlen, size_t old_len) { + // Append the current step K/V to the cache. 
+ ASSERT(old_len == cache_len_, + "update_kv_cache: old_len must equal cache_len_"); + size_t new_len = old_len + seqlen; + CHECK_ARGUMENT(new_len <= meta_.maxseq, "update_kv_cache: cache overflow"); + + llaisys::core::context().setDevice(device_type_, device_id_); + const LlaisysRuntimeAPI *api = llaisys::core::context().runtime().api(); + + size_t k_size = k_new->numel() * k_new->elementSize(); + size_t v_size = v_new->numel() * v_new->elementSize(); + + ASSERT(k_new->isContiguous() && v_new->isContiguous(), + "update_kv_cache: k_new and v_new must be contiguous"); + ASSERT(k_cache_[layer_idx]->isContiguous() + && v_cache_[layer_idx]->isContiguous(), + "update_kv_cache: cache tensors must be contiguous"); + + const size_t cache_row_bytes = meta_.nkvh * meta_.dh * k_new->elementSize(); + const size_t dst_offset_bytes = old_len * cache_row_bytes; + api->memcpy_sync(k_cache_[layer_idx]->data() + dst_offset_bytes, + k_new->data(), k_size, LLAISYS_MEMCPY_D2D); + api->memcpy_sync(v_cache_[layer_idx]->data() + dst_offset_bytes, + v_new->data(), v_size, LLAISYS_MEMCPY_D2D); +} + +void Model::forward_layer(size_t layer_idx, tensor_t &x, size_t seqlen, + size_t total_len, tensor_t pos_ids_q) { + llaisys::core::context().setDevice(device_type_, device_id_); + + ensure_tensor(x_norm_, {seqlen, meta_.hs}, meta_.dtype); + ops::rms_norm(x_norm_, x, weights_.attn_norm_w[layer_idx], meta_.epsilon); + + ensure_tensor(q_flat_, {seqlen, meta_.nh * meta_.dh}, meta_.dtype); + ensure_tensor(k_flat_, {seqlen, meta_.nkvh * meta_.dh}, meta_.dtype); + ensure_tensor(v_flat_, {seqlen, meta_.nkvh * meta_.dh}, meta_.dtype); + + tensor_t q_bias = (weights_.attn_q_b[layer_idx] + && weights_.attn_q_b[layer_idx]->numel() > 0) + ? weights_.attn_q_b[layer_idx] + : dummy_bias_q_; + tensor_t k_bias = (weights_.attn_k_b[layer_idx] + && weights_.attn_k_b[layer_idx]->numel() > 0) + ? 
weights_.attn_k_b[layer_idx] + : dummy_bias_kv_; + tensor_t v_bias = (weights_.attn_v_b[layer_idx] + && weights_.attn_v_b[layer_idx]->numel() > 0) + ? weights_.attn_v_b[layer_idx] + : dummy_bias_kv_; + + ops::linear(q_flat_, x_norm_, weights_.attn_q_w[layer_idx], q_bias); + ops::linear(k_flat_, x_norm_, weights_.attn_k_w[layer_idx], k_bias); + ops::linear(v_flat_, x_norm_, weights_.attn_v_w[layer_idx], v_bias); + + q_ = q_flat_->view({seqlen, meta_.nh, meta_.dh}); + k_ = k_flat_->view({seqlen, meta_.nkvh, meta_.dh}); + v_ = v_flat_->view({seqlen, meta_.nkvh, meta_.dh}); + + // RoPE is applied to newly generated tokens only. + ensure_tensor(q_rope_, {seqlen, meta_.nh, meta_.dh}, meta_.dtype); + ensure_tensor(k_rope_new_, {seqlen, meta_.nkvh, meta_.dh}, meta_.dtype); + ops::rope(k_rope_new_, k_, pos_ids_q, meta_.theta); + ops::rope(q_rope_, q_, pos_ids_q, meta_.theta); + + size_t old_len = total_len - seqlen; + update_kv_cache(layer_idx, k_rope_new_, v_, seqlen, old_len); + + k_full_ = k_cache_[layer_idx]->slice(0, 0, total_len); + v_full_ = v_cache_[layer_idx]->slice(0, 0, total_len); + + ensure_tensor(attn_out_, {seqlen, meta_.nh, meta_.dh}, meta_.dtype); + float scale = 1.0f / std::sqrt(static_cast(meta_.dh)); + ops::self_attention(attn_out_, q_rope_, k_full_, v_full_, scale); + + tensor_t attn_out_flat = attn_out_->view({seqlen, meta_.nh * meta_.dh}); + ensure_tensor(attn_proj_out_, {seqlen, meta_.hs}, meta_.dtype); + ops::linear(attn_proj_out_, attn_out_flat, weights_.attn_o_w[layer_idx], + nullptr); + + ensure_tensor(x_attn_, {seqlen, meta_.hs}, meta_.dtype); + ops::add(x_attn_, x, attn_proj_out_); + x = x_attn_; + + ensure_tensor(x_norm_, {seqlen, meta_.hs}, meta_.dtype); + ops::rms_norm(x_norm_, x, weights_.mlp_norm_w[layer_idx], meta_.epsilon); + + ensure_tensor(gate_, {seqlen, meta_.di}, meta_.dtype); + ensure_tensor(up_, {seqlen, meta_.di}, meta_.dtype); + + ops::linear(gate_, x_norm_, weights_.mlp_gate_w[layer_idx], nullptr); + ops::linear(up_, x_norm_, 
weights_.mlp_up_w[layer_idx], nullptr); + + ensure_tensor(swiglu_out_, {seqlen, meta_.di}, meta_.dtype); + ops::swiglu(swiglu_out_, gate_, up_); + + ensure_tensor(mlp_out_, {seqlen, meta_.hs}, meta_.dtype); + ops::linear(mlp_out_, swiglu_out_, weights_.mlp_down_w[layer_idx], nullptr); + + ensure_tensor(x_mlp_, {seqlen, meta_.hs}, meta_.dtype); + ops::add(x_mlp_, x, mlp_out_); + x = x_mlp_; +} + +tensor_t Model::forward(tensor_t input_ids, size_t seqlen, size_t total_len) { + llaisys::core::context().setDevice(device_type_, device_id_); + + ensure_tensor(x_, {seqlen, meta_.hs}, meta_.dtype); + ops::embedding(x_, input_ids, weights_.in_embed); + + // Reuse the same pos_ids across all layers in this forward pass. + size_t start_pos = total_len - seqlen; + ensure_tensor(pos_ids_q_, {seqlen}, LLAISYS_DTYPE_I64); + if (seqlen == 1) { + int64_t pos = static_cast(start_pos); + pos_ids_q_->load(&pos); + } else { + std::vector pos_ids_q_host(seqlen); + for (size_t i = 0; i < seqlen; ++i) { + pos_ids_q_host[i] = static_cast(start_pos + i); + } + pos_ids_q_->load(pos_ids_q_host.data()); + } + + for (size_t i = 0; i < meta_.nlayer; ++i) { + forward_layer(i, x_, seqlen, total_len, pos_ids_q_); + } + + ensure_tensor(x_norm_, {seqlen, meta_.hs}, meta_.dtype); + ops::rms_norm(x_norm_, x_, weights_.out_norm_w, meta_.epsilon); + + ensure_tensor(logits_, {seqlen, meta_.voc}, meta_.dtype); + ops::linear(logits_, x_norm_, weights_.out_embed, nullptr); + + return logits_; +} + +int64_t Model::infer(int64_t *token_ids, size_t ntoken, int top_k, float top_p, + float temperature) { + llaisys::core::context().setDevice(device_type_, device_id_); + + ensure_tensor(input_ids_buf_, {ntoken}, LLAISYS_DTYPE_I64); + input_ids_buf_->load(token_ids); + + size_t seqlen = ntoken; + size_t total_len = cache_len_ + seqlen; + + tensor_t logits = forward(input_ids_buf_, seqlen, total_len); + + cache_len_ = total_len; + + tensor_t last_logits = logits->slice(0, seqlen - 1, seqlen); + last_logits = 
last_logits->view({meta_.voc}); + + const bool greedy = (top_k == 1) && (top_p >= 1.0f) + && (std::abs(temperature - 1.0f) < 1e-6f); + if (greedy) { + // Fast path: keep current argmax operator pipeline. + ensure_tensor(max_idx_, {1}, LLAISYS_DTYPE_I64); + ensure_tensor(max_val_, {1}, meta_.dtype); + ops::argmax(max_idx_, max_val_, last_logits); + + int64_t host_result = 0; + llaisys::core::context().runtime().api()->memcpy_sync( + &host_result, max_idx_->data(), sizeof(int64_t), + LLAISYS_MEMCPY_D2H); + return host_result; + } + + const LlaisysRuntimeAPI *api = llaisys::core::context().runtime().api(); + std::vector host_logits = logits_to_host_f32(last_logits, api); + return sample_from_logits(host_logits, top_k, top_p, temperature); +} + +} // namespace llaisys::models::qwen2 diff --git a/src/models/qwen2/model.hpp b/src/models/qwen2/model.hpp new file mode 100644 index 000000000..2b9ce5621 --- /dev/null +++ b/src/models/qwen2/model.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include "../../tensor/tensor.hpp" +#include "../../ops/embedding/op.hpp" +#include "../../ops/linear/op.hpp" +#include "../../ops/rms_norm/op.hpp" +#include "../../ops/rope/op.hpp" +#include "../../ops/self_attention/op.hpp" +#include "../../ops/swiglu/op.hpp" +#include "../../ops/argmax/op.hpp" + +#include +#include + +namespace llaisys::models::qwen2 { +// 模型元数据 + struct ModelMeta { + llaisysDataType_t dtype; + size_t nlayer; // 层数 + size_t hs; // hidden size + size_t nh; // num heads + size_t nkvh; // num kv heads + size_t dh; // head dimension + size_t di; // intermediate size + size_t maxseq; // max sequence length + size_t voc; // vocabulary size + float epsilon; // RMS norm epsilon + float theta; // RoPE theta + int64_t end_token; // end token id +}; + +// 模型权重 + struct ModelWeights { + tensor_t in_embed; // [voc, hs] + tensor_t out_embed; // [voc, hs] + tensor_t out_norm_w; // [hs] + + // 每层的权重 + std::vector attn_norm_w; // [nlayer] x [hs] + std::vector attn_q_w; // [nlayer] x [nh * dh, 
hs] + std::vector attn_q_b; // [nlayer] x [nh * dh] (可能为空) + std::vector attn_k_w; // [nlayer] x [nkvh * dh, hs] + std::vector attn_k_b; // [nlayer] x [nkvh * dh] (可能为空) + std::vector attn_v_w; // [nlayer] x [nkvh * dh, hs] + std::vector attn_v_b; // [nlayer] x [nkvh * dh] (可能为空) + std::vector attn_o_w; // [nlayer] x [hs, nh * dh] + + std::vector mlp_norm_w; // [nlayer] x [hs] + std::vector mlp_gate_w; // [nlayer] x [di, hs] + std::vector mlp_up_w; // [nlayer] x [di, hs] + std::vector mlp_down_w; // [nlayer] x [hs, di] +}; + +// 模型类 +class Model { +private: + ModelMeta meta_; + ModelWeights weights_; + llaisysDeviceType_t device_type_; + int device_id_; + + // KV Cache: 每层的 K 和 V + std::vector k_cache_; // [nlayer] x [maxseq, nkvh, dh] + std::vector v_cache_; // [nlayer] x [maxseq, nkvh, dh] + size_t cache_len_; // 当前 cache 长度 + + // Dummy bias tensors(用于没有 bias 的层,必须全零) + tensor_t dummy_bias_hs_; // [hs] - 用于 o_proj, mlp_down, out_embed + tensor_t dummy_bias_di_; // [di] - 用于 mlp_gate, mlp_up + tensor_t dummy_bias_q_; // [nh * dh] - 用于 q_proj + tensor_t dummy_bias_kv_; // [nkvh * dh] - 用于 k_proj, v_proj + tensor_t dummy_bias_voc_; // [voc] - 用于 out_embed + + // 临时张量(避免重复分配) + tensor_t x_; // 当前隐藏状态 [seqlen, hs] + tensor_t x_norm_; // 归一化后的隐藏状态 + tensor_t q_flat_; // [seqlen, nh * dh] + tensor_t k_flat_; // [seqlen, nkvh * dh] + tensor_t v_flat_; // [seqlen, nkvh * dh] + tensor_t q_; // Query [seqlen, nh, dh] + tensor_t k_; // Key [seqlen, nkvh, dh] + tensor_t v_; // Value [seqlen, nkvh, dh] + tensor_t q_rope_; // [seqlen, nh, dh] + tensor_t k_rope_new_; // [seqlen, nkvh, dh] + tensor_t k_full_; // 完整的 K(包含 cache)[total_len, nkvh, dh] + tensor_t v_full_; // 完整的 V(包含 cache)[total_len, nkvh, dh] + tensor_t attn_out_; // Attention 输出 [seqlen, nh, dh] + tensor_t attn_proj_out_; // Attention 投影输出 [seqlen, hs] + tensor_t x_attn_; // Attention 残差输出 [seqlen, hs] + tensor_t gate_; // MLP gate [seqlen, di] + tensor_t up_; // MLP up [seqlen, di] + tensor_t swiglu_out_; // 
SwiGLU 输出 [seqlen, di] + tensor_t mlp_out_; // MLP 输出 [seqlen, hs] + tensor_t x_mlp_; // MLP 残差输出 [seqlen, hs] + tensor_t logits_; // 输出 logits [seqlen, voc] + tensor_t pos_ids_q_; // 位置 id [seqlen] + tensor_t input_ids_buf_; // infer 输入缓存 [ntoken] + tensor_t max_idx_; // argmax 索引缓存 [1] + tensor_t max_val_; // argmax 值缓存 [1] + + // 前向传播辅助函数 + void ensure_tensor(tensor_t &tensor, const std::vector &shape, llaisysDataType_t dtype); + void forward_layer(size_t layer_idx, tensor_t& x, size_t seqlen, size_t total_len, tensor_t pos_ids_q); + void update_kv_cache(size_t layer_idx, tensor_t k_new, tensor_t v_new, size_t seqlen, size_t old_len); + +public: + Model(const ModelMeta& meta, llaisysDeviceType_t device_type, int device_id); + ~Model(); + + ModelWeights& weights() { return weights_; } + const ModelWeights& weights() const { return weights_; } + const ModelMeta& meta() const { return meta_; } + + // 前向传播:返回 logits + tensor_t forward(tensor_t input_ids, size_t seqlen, size_t total_len); + + // 推理:生成下一个 token + int64_t infer( + int64_t* token_ids, + size_t ntoken, + int top_k = 1, + float top_p = 1.0f, + float temperature = 1.0f); + + // 重置 KV Cache + void reset_cache(); +}; + +} // namespace llaisys::models::qwen2 diff --git a/src/ops/add/metax/add_metax.hpp b/src/ops/add/metax/add_metax.hpp new file mode 100644 index 000000000..ea9dca0d3 --- /dev/null +++ b/src/ops/add/metax/add_metax.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "llaisys.h" + +namespace llaisys::ops::metax { + +void add(void *c, const void *a, const void *b, llaisysDataType_t type, size_t numel); + +} // namespace llaisys::ops::metax diff --git a/src/ops/add/metax/add_metax.maca b/src/ops/add/metax/add_metax.maca new file mode 100644 index 000000000..26c8f62fa --- /dev/null +++ b/src/ops/add/metax/add_metax.maca @@ -0,0 +1,77 @@ +#include "add_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +namespace { + +__global__ void add_f32_kernel(float *c, const float *a, 
const float *b, size_t n) { + const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} + +__global__ void add_f16_kernel(__half *c, const __half *a, const __half *b, size_t n) { + const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + if (idx < n) { + c[idx] = __hadd(a[idx], b[idx]); + } +} + +__global__ void add_bf16_kernel(__maca_bfloat16 *c, + const __maca_bfloat16 *a, + const __maca_bfloat16 *b, + size_t n) { + const size_t idx = static_cast(blockIdx.x) * static_cast(blockDim.x) + + static_cast(threadIdx.x); + if (idx < n) { + c[idx] = __hadd(a[idx], b[idx]); + } +} + +} // namespace + +namespace llaisys::ops::metax { + +void add(void *c, const void *a, const void *b, llaisysDataType_t type, size_t numel) { + if (numel == 0) { + return; + } + + const dim3 block(256); + const dim3 grid((numel + block.x - 1) / block.x); + + switch (type) { + case LLAISYS_DTYPE_F32: + add_f32_kernel<<>>( + reinterpret_cast(c), + reinterpret_cast(a), + reinterpret_cast(b), + numel); + break; + case LLAISYS_DTYPE_F16: + add_f16_kernel<<>>( + reinterpret_cast<__half *>(c), + reinterpret_cast(a), + reinterpret_cast(b), + numel); + break; + case LLAISYS_DTYPE_BF16: + add_bf16_kernel<<>>( + reinterpret_cast<__maca_bfloat16 *>(c), + reinterpret_cast(a), + reinterpret_cast(b), + numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/add/nvidia/add_nvidia.cu b/src/ops/add/nvidia/add_nvidia.cu new file mode 100644 index 000000000..65d22b425 --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.cu @@ -0,0 +1,96 @@ +#include "add_nvidia.cuh" + +#include "../../../utils.hpp" + +#include "../../../utils/gpu_utils.hpp" + +__global__ void add_f32_kernel(float *c, const float *a, const float *b, size_t n) { + int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x); + if (idx < n) { + 
float4 reg_a = LOAD_FLOAT4(a[idx]); + float4 reg_b = LOAD_FLOAT4(b[idx]); + float4 reg_c; + reg_c.x = reg_a.x + reg_b.x; + reg_c.y = reg_a.y + reg_b.y; + reg_c.z = reg_a.z + reg_b.z; + reg_c.w = reg_a.w + reg_b.w; + STORE_FLOAT4(c[idx]) = reg_c; + } +} + +__global__ void add_f16_kernel(half *c, const half *a, const half *b, size_t n) { + int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x); + if (idx < n) { + half2 reg_a = LOAD_HALF2(a[idx]); + half2 reg_b = LOAD_HALF2(b[idx]); + half2 reg_c; + reg_c.x = __hadd(reg_a.x, reg_b.x); + reg_c.y = __hadd(reg_a.y, reg_b.y); + STORE_HALF2(c[idx]) = reg_c; + } +} + +__global__ void add_bf16_kernel(__nv_bfloat16 *c, const __nv_bfloat16 *a, const __nv_bfloat16 *b, size_t n) { + int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x); + if (idx < n) { + __nv_bfloat162 reg_a = LOAD_BFLOAT2(a[idx]); + __nv_bfloat162 reg_b = LOAD_BFLOAT2(b[idx]); + __nv_bfloat162 reg_c; + reg_c.x = __hadd(reg_a.x, reg_b.x); + reg_c.y = __hadd(reg_a.y, reg_b.y); + STORE_BFLOAT2(c[idx]) = reg_c; + } +} + +void config_launch(dim3 &block, dim3 &grid, llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F32: + block = dim3(256); + grid = dim3(CEIL(CEIL(numel,4), 256)); + break; + case LLAISYS_DTYPE_F16: + block = dim3(256); + grid = dim3(CEIL(CEIL(numel,2), 256)); + break; + case LLAISYS_DTYPE_BF16: + block = dim3(256); + grid = dim3(CEIL(CEIL(numel,2), 256)); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +namespace llaisys::ops::nvidia { +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) { + if (numel == 0) { + return; + } + + dim3 block{0}; + dim3 grid{0}; + config_launch(block, grid, type, numel); + + switch (type) { + case LLAISYS_DTYPE_F32: + add_f32_kernel<<>>(reinterpret_cast(c), reinterpret_cast(a), + reinterpret_cast(b), numel); + break; + case LLAISYS_DTYPE_F16: + add_f16_kernel<<>>(reinterpret_cast(c), + reinterpret_cast(a), + 
reinterpret_cast(b), numel); + break; + case LLAISYS_DTYPE_BF16: + add_bf16_kernel<<>>(reinterpret_cast<__nv_bfloat16 *>(c), + reinterpret_cast(a), + reinterpret_cast(b), numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + CUDA_CHECK(cudaGetLastError()); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/add/nvidia/add_nvidia.cuh b/src/ops/add/nvidia/add_nvidia.cuh new file mode 100644 index 000000000..ba5b04bee --- /dev/null +++ b/src/ops/add/nvidia/add_nvidia.cuh @@ -0,0 +1,14 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::nvidia { + +// Elementwise add: c = a + b +// Pointers are device pointers. +void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel); + +} // namespace llaisys::ops::nvidia + diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp index a057330d7..39d6344ce 100644 --- a/src/ops/add/op.cpp +++ b/src/ops/add/op.cpp @@ -4,6 +4,12 @@ #include "../../utils.hpp" #include "cpu/add_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/add_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/add_metax.hpp" +#endif namespace llaisys::ops { void add(tensor_t c, tensor_t a, tensor_t b) { @@ -25,8 +31,11 @@ void add(tensor_t c, tensor_t a, tensor_t b) { return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #ifdef ENABLE_NVIDIA_API case LLAISYS_DEVICE_NVIDIA: - TO_BE_IMPLEMENTED(); - return; + return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 000000000..cdf97ebd7 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,70 @@ +#include "argmax_cpu.hpp" + +#include "../../../utils.hpp" +#include 
"llaisys.h" + +#include +#include +#include +#include +#include + +// cpu侧实现 +template +void argmax_(int64_t *max_idx, T *max_val, const T *vals, size_t numel) { + if (numel == 0) { + *max_idx = 0; + // 对于fp16和bf16这种非内置类型,需要用cast转换;其他类型使用默认构造赋0值 + if (std::is_same_v || std::is_same_v) { + *max_val = llaisys::utils::cast(0.0f); + } else { + *max_val = T{}; + } + return; + } + + T tmp_max_val = vals[0]; + int64_t tmp_max_idx = 0; + + // 对于fp16和bf16,先转为float32进行比较,避免精度丢失 + if constexpr (std::is_same_v || std::is_same_v) { + float max_val_float = llaisys::utils::cast(vals[0]); + for (size_t i = 1; i < numel; ++i) { + float cur_val_float = llaisys::utils::cast(vals[i]); + if (cur_val_float > max_val_float) { + max_val_float = cur_val_float; + tmp_max_val = vals[i]; + tmp_max_idx = i; + } + } + } else { + for (size_t i = 1; i < numel; i++) { + if (vals[i] > tmp_max_val) { + tmp_max_val = vals[i]; + tmp_max_idx = i; + } + } + } + + *max_idx = tmp_max_idx; + *max_val = tmp_max_val; +} + +namespace llaisys::ops::cpu { +void argmax(int64_t *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) { + // 传入的是std::byte类型的指针,需要转成对应的类型 + switch (type) { + case LLAISYS_DTYPE_F32: + return argmax_(max_idx, reinterpret_cast(max_val), + reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_BF16: + return argmax_(max_idx, reinterpret_cast(max_val), + reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_F16: + return argmax_(max_idx, reinterpret_cast(max_val), + reinterpret_cast(vals), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + } +} // namespace llaisys::ops::cpu diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 000000000..02f4ea703 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once +#include "llaisys.h" + +#include +#include + +// max_val应为std::byte*,用于支持多种数据类型的通用内存写入,不能简单换成float*等具体类型,否则类型不兼容。 +namespace llaisys::ops::cpu { +void 
argmax(int64_t *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel); +} \ No newline at end of file diff --git a/src/ops/argmax/metax/argmax_metax.hpp b/src/ops/argmax/metax/argmax_metax.hpp new file mode 100644 index 000000000..73d557d2f --- /dev/null +++ b/src/ops/argmax/metax/argmax_metax.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::metax { + +void argmax(int64_t *max_idx, + std::byte *max_val, + const std::byte *vals, + llaisysDataType_t type, + size_t numel); + +} // namespace llaisys::ops::metax + diff --git a/src/ops/argmax/metax/argmax_metax.maca b/src/ops/argmax/metax/argmax_metax.maca new file mode 100644 index 000000000..d36d3d469 --- /dev/null +++ b/src/ops/argmax/metax/argmax_metax.maca @@ -0,0 +1,140 @@ +#include "argmax_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +#include +#include +#include + +namespace { + +template +__device__ __forceinline__ float to_float(T v); + +template <> +__device__ __forceinline__ float to_float(float v) { + return v; +} + +template <> +__device__ __forceinline__ float to_float<__half>(__half v) { + return __half2float(v); +} + +template <> +__device__ __forceinline__ float to_float<__maca_bfloat16>(__maca_bfloat16 v) { + return __bfloat162float(v); +} + +__device__ __forceinline__ void warp_argmax(float &max_val, int64_t &max_idx) { + constexpr maca_uint64_t full_mask = static_cast(~0ULL); + for (int stride = warpSize / 2; stride > 0; stride >>= 1) { + const float other_max = __shfl_down_sync(full_mask, max_val, stride, warpSize); + const int64_t other_idx = __shfl_down_sync(full_mask, max_idx, stride, warpSize); + if (other_idx >= 0 && (other_max > max_val || (other_max == max_val && (max_idx < 0 || other_idx < max_idx)))) { + max_val = other_max; + max_idx = other_idx; + } + } +} + +template +__global__ void argmax_kernel(int64_t *max_idx, T *max_val, const T *vals, size_t numel) { + 
__shared__ float smax[BLOCK_SIZE]; + __shared__ int64_t sidx[BLOCK_SIZE]; + + const int tid = threadIdx.x; + const int lane_id = tid % warpSize; + const int warp_id = tid / warpSize; + const int warp_count = (BLOCK_SIZE + warpSize - 1) / warpSize; + + float local_max = -INFINITY; + int64_t local_idx = -1; + + for (size_t i = static_cast(tid); i < numel; i += static_cast(BLOCK_SIZE)) { + const float cur = to_float(vals[i]); + if (cur > local_max || (cur == local_max && (local_idx < 0 || static_cast(i) < local_idx))) { + local_max = cur; + local_idx = static_cast(i); + } + } + + // Warp-level reduction first to cut shared-memory traffic and barriers. + warp_argmax(local_max, local_idx); + + if (lane_id == 0) { + smax[warp_id] = local_max; + sidx[warp_id] = local_idx; + } + __syncthreads(); + + // Final reduction over warp leaders by warp 0. + if (warp_id == 0) { + float block_max = (lane_id < warp_count) ? smax[lane_id] : -INFINITY; + int64_t block_idx = (lane_id < warp_count) ? sidx[lane_id] : -1; + warp_argmax(block_max, block_idx); + if (lane_id == 0) { + *max_idx = block_idx; + *max_val = vals[block_idx]; + } + } +} + +} // namespace + +namespace llaisys::ops::metax { + +void argmax(int64_t *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) { + if (numel == 0) { + *max_idx = 0; + switch (type) { + case LLAISYS_DTYPE_F32: + *reinterpret_cast(max_val) = 0.0f; + break; + case LLAISYS_DTYPE_F16: + *reinterpret_cast<__half *>(max_val) = __float2half(0.0f); + break; + case LLAISYS_DTYPE_BF16: + *reinterpret_cast<__maca_bfloat16 *>(max_val) = __float2bfloat16(0.0f); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + return; + } + + constexpr int block_size = 512; + const int grid_size = 1; + + switch (type) { + case LLAISYS_DTYPE_F32: + argmax_kernel<<>>( + max_idx, + reinterpret_cast(max_val), + reinterpret_cast(vals), + numel); + break; + case LLAISYS_DTYPE_F16: + argmax_kernel<__half, block_size><<>>( + max_idx, 
+ reinterpret_cast<__half *>(max_val), + reinterpret_cast(vals), + numel); + break; + case LLAISYS_DTYPE_BF16: + argmax_kernel<__maca_bfloat16, block_size><<>>( + max_idx, + reinterpret_cast<__maca_bfloat16 *>(max_val), + reinterpret_cast(vals), + numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cu b/src/ops/argmax/nvidia/argmax_nvidia.cu new file mode 100644 index 000000000..8ed7c73f8 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.cu @@ -0,0 +1,128 @@ +#include "../../../utils.hpp" +#include "../../../utils/gpu_utils.hpp" +#include "argmax_nvidia.cuh" +#include + +namespace { + +template +__device__ __forceinline__ void warp_argmax(T local_val, int64_t local_idx, T &max_val, int64_t &max_idx) { +#pragma unroll + for (int stride = 16; stride > 0; stride >>= 1) { + T other_val = __shfl_down_sync(0xffffffff, local_val, stride); + int64_t other_idx = __shfl_down_sync(0xffffffff, local_idx, stride); + + if (other_val > local_val || (other_val == local_val && other_idx < local_idx)) { + local_val = other_val; + local_idx = other_idx; + } + } + + if (threadIdx.x % 32 == 0) { + max_val = local_val; + max_idx = local_idx; + } +} + +template +__global__ void argmax_kernel(int64_t *max_idx, T *max_val, const T *vals, size_t numel) { + constexpr int warp_per_block = BLOCK_SIZE / 32; + + int tid = threadIdx.x; + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + + __shared__ T vals_shared[warp_per_block]; + __shared__ int64_t idxs_shared[warp_per_block]; + + // 0. 
线程级别求局部最大值 + T thread_max_val = static_cast(-INFINITY); + int64_t thread_max_idx = -1; + for (int i = tid; i < numel; i += blockDim.x) { + T local_val = vals[i]; + if (local_val > thread_max_val || (local_val == thread_max_val && i < thread_max_idx)) { + thread_max_val = local_val; + thread_max_idx = i; + } + } + + // 1.warp内规约 + T warp_max_val = thread_max_val; + int64_t warp_max_idx = thread_max_idx; + warp_argmax(thread_max_val, thread_max_idx, warp_max_val, warp_max_idx); + + if (lane_id == 0) { + vals_shared[warp_id] = warp_max_val; + idxs_shared[warp_id] = warp_max_idx; + } + __syncthreads(); + + // 2. 用 warp 0 对共享内存里的各 warp 结果做规约,得到 block 的全局最大,再由 lane 0 写回 + if (warp_id == 0) { + // 每个 lane 持有一个候选 + T lane_val = lane_id < warp_per_block ? vals_shared[lane_id] : static_cast(-INFINITY); + int64_t lane_idx = lane_id < warp_per_block ? idxs_shared[lane_id] : -1; + T final_val; + int64_t final_idx; + warp_argmax(lane_val, lane_idx, final_val, final_idx); + if (lane_id == 0) { + *max_val = final_val; + *max_idx = final_idx; + } + } +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void argmax(int64_t *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) { + // 特殊处理空张量的情况:max_val 是 std::byte*,需按类型写入 + if (numel == 0) { + *max_idx = 0; + switch (type) { + case LLAISYS_DTYPE_F32: + *reinterpret_cast(max_val) = 0.0f; + break; + case LLAISYS_DTYPE_F16: + *reinterpret_cast(max_val) = __float2half(0.0f); + break; + case LLAISYS_DTYPE_BF16: + *reinterpret_cast<__nv_bfloat16 *>(max_val) = __float2bfloat16(0.0f); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + return; + } + + const int block_size = 256; + const int grid_size = 1; + + switch (type) { + case LLAISYS_DTYPE_F32: + argmax_kernel<<>>(max_idx, + reinterpret_cast(max_val), + reinterpret_cast(vals), + numel); + break; + case LLAISYS_DTYPE_F16: + argmax_kernel<<>>(max_idx, + reinterpret_cast(max_val), + reinterpret_cast(vals), + numel); + break; + case 
LLAISYS_DTYPE_BF16: + argmax_kernel<__nv_bfloat16, block_size><<>>(max_idx, + reinterpret_cast<__nv_bfloat16 *>(max_val), + reinterpret_cast(vals), + numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + CUDA_CHECK(cudaGetLastError()); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cuh b/src/ops/argmax/nvidia/argmax_nvidia.cuh new file mode 100644 index 000000000..51bc0f060 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nvidia.cuh @@ -0,0 +1,16 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::nvidia { + +void argmax(int64_t *max_idx, + std::byte *max_val, + const std::byte *vals, + llaisysDataType_t type, + size_t numel); + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d426..c442f7630 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,54 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/argmax_cpu.hpp" +#include "nvidia/argmax_nvidia.cuh" +#ifdef ENABLE_METAX_API +#include "metax/argmax_metax.hpp" +#endif +#include "llaisys.h" + namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + + CHECK_ARGUMENT(vals->ndim() == 1, "vals only support 1D tensor for now"); + CHECK_ARGUMENT(max_idx->ndim() == 1 && max_idx->numel() == 1, "max_idx should be a single element"); + CHECK_ARGUMENT(max_val->ndim() == 1 && max_val->numel() == 1, "max_val should be a single element"); + + CHECK_SAME_DTYPE(max_idx->dtype(), LLAISYS_DTYPE_I64); + CHECK_SAME_DTYPE(max_val->dtype(), vals->dtype()); + + ASSERT(max_idx->isContiguous() && max_val->isContiguous() && vals->isContiguous(), + "max_idx, max_val and vals must be contiguous"); + + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + + switch (vals->deviceType()) { + 
case LLAISYS_DEVICE_CPU: + return cpu::argmax( + reinterpret_cast(max_idx->data()), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::argmax( + reinterpret_cast(max_idx->data()), + reinterpret_cast(max_val->data()), + reinterpret_cast(vals->data()), + vals->dtype(), + vals->numel()); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::argmax(reinterpret_cast(max_idx->data()), + reinterpret_cast(max_val->data()), + reinterpret_cast(vals->data()), + vals->dtype(), + vals->numel()); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/argmax/op.hpp b/src/ops/argmax/op.hpp index 433fdacdb..4441ac595 100644 --- a/src/ops/argmax/op.hpp +++ b/src/ops/argmax/op.hpp @@ -2,6 +2,8 @@ #include "../../tensor/tensor.hpp" +// C++对外(python)暴露的接口声明 +// 功能:获取张量vals的最大值及其索引,并分别存储在max_val和max_idx中 namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals); -} +} \ No newline at end of file diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 000000000..f41b2ba1e --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,57 @@ +#include "embedding_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +// CPU 侧实现:逐行从 weight 中按 index 拷贝到 out +// out[i, :] = weight[index[i], :] +template +void embedding_(T *out, + const int64_t *index, + const T *weight, + size_t index_numel, + size_t embedding_dim) { + for (size_t i = 0; i < index_numel; i++) { + int64_t cur_idx = index[i]; + for (size_t j = 0; j < embedding_dim; j++) { + out[i * embedding_dim + j] = weight[cur_idx * embedding_dim + j]; + } + } +} + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, + const std::byte *index, + const std::byte *weight, + llaisysDataType_t type, + size_t index_numel, + size_t embedding_dim) { + // index 在 op 
层已经保证是 I64,这里直接按 int64_t 解释 + const auto *index_i64 = reinterpret_cast(index); + + switch (type) { + case LLAISYS_DTYPE_F32: + return embedding_(reinterpret_cast(out), + index_i64, + reinterpret_cast(weight), + index_numel, + embedding_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(reinterpret_cast(out), + index_i64, + reinterpret_cast(weight), + index_numel, + embedding_dim); + case LLAISYS_DTYPE_F16: + return embedding_(reinterpret_cast(out), + index_i64, + reinterpret_cast(weight), + index_numel, + embedding_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 000000000..260d5cc9b --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,20 @@ +#pragma once +#include "llaisys.h" + +#include + +// CPU 侧 embedding 接口: +// out : [seqlen, embedding_dim] +// index : [seqlen],int64 索引( +// weight: [num_embeddings, embedding_dim] +// type : out/weight 的数据类型(F32/F16/BF16) +// index_numel : seqlen +// embedding_dim : 每个 embedding 向量的维度 +namespace llaisys::ops::cpu { +void embedding(std::byte *out, + const std::byte *index, + const std::byte *weight, + llaisysDataType_t type, + size_t index_numel, + size_t embedding_dim); +} \ No newline at end of file diff --git a/src/ops/embedding/metax/embedding_metax.hpp b/src/ops/embedding/metax/embedding_metax.hpp new file mode 100644 index 000000000..1f931d364 --- /dev/null +++ b/src/ops/embedding/metax/embedding_metax.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::metax { + +void embedding(std::byte *out, + const std::byte *index, + const std::byte *weight, + llaisysDataType_t type, + size_t index_numel, + size_t embedding_dim); + +} // namespace llaisys::ops::metax diff --git a/src/ops/embedding/metax/embedding_metax.maca b/src/ops/embedding/metax/embedding_metax.maca new file mode 100644 index 
000000000..6602ac29c --- /dev/null +++ b/src/ops/embedding/metax/embedding_metax.maca @@ -0,0 +1,71 @@ +#include "embedding_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +#include +#include + +namespace { + +template +__global__ void embedding_kernel(T *out, const int64_t *index, const T *weight, + size_t index_numel, size_t embedding_dim) { + const size_t row = static_cast(blockIdx.x); + if (row >= index_numel) { + return; + } + + const int64_t idx = index[row]; + const size_t in_start = static_cast(idx) * embedding_dim; + const size_t out_start = row * embedding_dim; + + for (size_t col = static_cast(threadIdx.x); col < embedding_dim; + col += static_cast(blockDim.x)) { + out[out_start + col] = weight[in_start + col]; + } +} + +} // namespace + +namespace llaisys::ops::metax { + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, + llaisysDataType_t type, size_t index_numel, + size_t embedding_dim) { + if (index_numel == 0 || embedding_dim == 0) { + return; + } + + const int block_size = 512; + const int grid_size = static_cast(index_numel); + + switch (type) { + case LLAISYS_DTYPE_F32: + embedding_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(index), + reinterpret_cast(weight), index_numel, embedding_dim); + break; + case LLAISYS_DTYPE_F16: + embedding_kernel<__half><<>>( + reinterpret_cast<__half *>(out), + reinterpret_cast(index), + reinterpret_cast(weight), index_numel, embedding_dim); + break; + case LLAISYS_DTYPE_BF16: + embedding_kernel<__maca_bfloat16><<>>( + reinterpret_cast<__maca_bfloat16 *>(out), + reinterpret_cast(index), + reinterpret_cast(weight), index_numel, + embedding_dim); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cu b/src/ops/embedding/nvidia/embedding_nvidia.cu new file mode 100644 index 000000000..8fb53eea3 --- /dev/null +++ 
b/src/ops/embedding/nvidia/embedding_nvidia.cu @@ -0,0 +1,62 @@ +#include "embedding_nvidia.cuh" + +#include "../../../utils.hpp" +#include "../../../utils/gpu_utils.hpp" +#include + +namespace { + +template +__global__ void embedding_kernel(T *out, const int64_t *index, const T *weight, + size_t index_numel, size_t embedding_dim) { + const size_t row = blockIdx.x; + if (row >= index_numel) { + return; + } + + const int64_t idx = index[row]; + const size_t in_start = static_cast(idx) * embedding_dim; + const size_t out_start = row * embedding_dim; + + for (size_t col = threadIdx.x; col < embedding_dim; col += blockDim.x) { + out[out_start + col] = weight[in_start + col]; + } +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, + llaisysDataType_t type, size_t index_numel, + size_t embedding_dim) { + + const int block_size = 256; + const int grid_size = index_numel; + + switch (type) { + case LLAISYS_DTYPE_F32: + embedding_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(index), + reinterpret_cast(weight), index_numel, embedding_dim); + break; + case LLAISYS_DTYPE_F16: + embedding_kernel<<>>( + reinterpret_cast(out), reinterpret_cast(index), + reinterpret_cast(weight), index_numel, embedding_dim); + break; + case LLAISYS_DTYPE_BF16: + embedding_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16 *>(out), + reinterpret_cast(index), + reinterpret_cast(weight), index_numel, + embedding_dim); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + CUDA_CHECK(cudaGetLastError()); +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cuh b/src/ops/embedding/nvidia/embedding_nvidia.cuh new file mode 100644 index 000000000..14168ce59 --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nvidia.cuh @@ -0,0 +1,10 @@ +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { + +void embedding(std::byte *out, const 
std::byte *index, const std::byte *weight, + llaisysDataType_t type, size_t index_numel, + size_t embedding_dim); + +} // namespace llaisys::ops::nvidia \ No newline at end of file diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d06..3fb2a2549 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,67 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "./cpu/embedding_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "./nvidia/embedding_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "./metax/embedding_metax.hpp" +#endif + namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + // 1. 检查张量所在设备 + CHECK_SAME_DEVICE(out, index, weight); + + // 2. 检查张量形状 + CHECK_ARGUMENT(index->ndim() == 1, "index must be a 1D tensor"); + CHECK_ARGUMENT(weight->ndim() == 2, "weight must be a 2D tensor"); + CHECK_ARGUMENT(out->ndim() == 2, "out must be a 2D tensor"); + // 索引的数量就是输出的行数 + CHECK_ARGUMENT(index->numel() == out->shape()[0], + "index must have the same number of elements as the first " + "dimension of out"); + // 权重和输出的维度相同 + CHECK_ARGUMENT(weight->shape()[1] == out->shape()[1], + "weight must have the same number of rows as the second " + "dimension of out"); + // 索引的类型设为int64,与pytorch对齐 + CHECK_ARGUMENT(index->dtype() == LLAISYS_DTYPE_I64, + "index must be a 64-bit integer tensor"); + // 权重和输出的数据类型相同 + CHECK_ARGUMENT(weight->dtype() == out->dtype(), + "weight and out must have the same data type"); + // 索引、权重和输出必须连续 + ASSERT(index->isContiguous() && weight->isContiguous() && out->isContiguous(), + "index, weight and out must be contiguous"); + + // 3. 设置设备上下文 + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + // 4. 
设备分发 + size_t index_numel = index->numel(); + size_t embedding_dim = weight->shape()[1]; + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + // 需要传入index_numel和embedding_dim,因为传入类型为std::byte*,丢失shape信息 + return cpu::embedding(out->data(), index->data(), weight->data(), + out->dtype(), index_numel, embedding_dim); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::embedding(out->data(), index->data(), weight->data(), + out->dtype(), index_numel, embedding_dim); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::embedding(out->data(), index->data(), weight->data(), + out->dtype(), index_numel, embedding_dim); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/op.hpp b/src/ops/embedding/op.hpp index 37216c0cf..f7592a9d7 100644 --- a/src/ops/embedding/op.hpp +++ b/src/ops/embedding/op.hpp @@ -2,6 +2,10 @@ #include "../../tensor/tensor.hpp" +// 功能:按照索引(1-D)从权重矩阵(2-D)中抽取指定行,生成输出张量(2-D),即将索引映射为稠密向量 +// weight: 2-D tensor, shape: [num_embeddings, embedding_dim] +// index: 1-D tensor, shape: [batch_size] +// out: 2-D tensor, shape: [batch_size, embedding_dim] namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight); } diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 000000000..b1c3cff6f --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,241 @@ +#include "linear_cpu.hpp" + +#include "../../../utils.hpp" +#include "llaisys.h" + +#include +#include +#include +#include +#include + +#ifdef LLAISYS_USE_OPENBLAS +#if __has_include() +#include +#define LLAISYS_HAS_CBLAS 1 +#elif __has_include() +#include +#define LLAISYS_HAS_CBLAS 1 +#endif +#endif + +// 分块矩阵乘 (F32),提升 cache 命中,无 OpenBLAS 时使用 +static constexpr size_t kBlock = 64u; + +static void linear_f32_blocked(float *out, + const float *in, + const float *weight, + const float *bias, + size_t M, + 
size_t N, + size_t K) { + if (bias != nullptr) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) if (M >= 64) +#endif + for (int i = 0; i < static_cast(M); i++) { + for (size_t j = 0; j < N; j++) { + out[i * N + j] = bias[j]; + } + } + } else { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) if (M >= 64) +#endif + for (int i = 0; i < static_cast(M); i++) { + for (size_t j = 0; j < N; j++) { + out[i * N + j] = 0.0f; + } + } + } +#ifdef _OPENMP +#pragma omp parallel for schedule(static) if (M >= 2 * kBlock) +#endif + for (int ib = 0; ib < static_cast(M); ib += static_cast(kBlock)) { + size_t ie = (std::min)(ib + kBlock, M); + for (size_t kb = 0; kb < K; kb += kBlock) { + size_t ke = (std::min)(kb + kBlock, K); + for (size_t jb = 0; jb < N; jb += kBlock) { + size_t je = (std::min)(jb + kBlock, N); + for (size_t i = ib; i < ie; i++) { + for (size_t j = jb; j < je; j++) { + float sum = out[i * N + j]; + for (size_t k = kb; k < ke; k++) { + sum += in[i * K + k] * weight[j * K + k]; + } + out[i * N + j] = sum; + } + } + } + } + } +} + +// 通用内核:按外积方式实现 Y = X W^T + b(BF16/F16 或无 OpenBLAS 时使用) +template +static void linear_naive(T *out, + const T *in, + const T *weight, + const T *bias, + size_t M, + size_t N, + size_t K) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) if (M >= 64) +#endif + for (int i = 0; i < static_cast(M); i++) { + for (size_t j = 0; j < N; j++) { + float sum = 0.0f; + if (bias != nullptr) { + sum += llaisys::utils::cast(bias[j]); + } + if constexpr (std::is_same_v || std::is_same_v) { + for (size_t k = 0; k < K; k++) { + float data_x = llaisys::utils::cast(in[i * K + k]); + float data_w = llaisys::utils::cast(weight[j * K + k]); + sum += data_x * data_w; + } + out[i * N + j] = llaisys::utils::cast(sum); + } else { + for (size_t k = 0; k < K; k++) { + sum += in[i * K + k] * weight[j * K + k]; + } + out[i * N + j] = sum; + } + } + } +} + +#if defined(LLAISYS_USE_OPENBLAS) && defined(LLAISYS_HAS_CBLAS) +// F32: 直接调用 
SGEMM,再加 bias +static void linear_f32_openblas(float *out, + const float *in, + const float *weight, + const float *bias, + size_t M, + size_t N, + size_t K) { + // C = alpha * A * B^T + beta * C => out = 1 * in * weight^T + 0 * out + // RowMajor: A[M,K] lda=K, B[N,K] transB => B^T[K,N] ldb=K, C[M,N] ldc=N + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + (int)M, (int)N, (int)K, + 1.0f, in, (int)K, weight, (int)K, 0.0f, out, (int)N); + if (bias != nullptr) { +#ifdef _OPENMP +#pragma omp parallel for schedule(static) if (M >= 64) +#endif + for (int i = 0; i < static_cast(M); i++) { + for (size_t j = 0; j < N; j++) { + out[i * N + j] += bias[j]; + } + } + } +} + +// BF16/F16: 分块转 float -> SGEMM -> 转回,避免整块临时矩阵过大 +static constexpr size_t kLinearBlockRows = 256; + +template +static void linear_bf16_f16_openblas(T *out, + const T *in, + const T *weight, + const T *bias, + size_t M, + size_t N, + size_t K) { + std::vector w_float(static_cast(N) * K); + for (size_t j = 0; j < N; j++) { + for (size_t k = 0; k < K; k++) { + w_float[j * K + k] = llaisys::utils::cast(weight[j * K + k]); + } + } + std::vector in_block(kLinearBlockRows * K); + std::vector out_block(kLinearBlockRows * N); + + for (size_t i0 = 0; i0 < M; i0 += kLinearBlockRows) { + size_t rows = (std::min)(i0 + kLinearBlockRows, M) - i0; + for (size_t i = 0; i < rows; i++) { + for (size_t k = 0; k < K; k++) { + in_block[i * K + k] = llaisys::utils::cast(in[(i0 + i) * K + k]); + } + } + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + (int)rows, (int)N, (int)K, + 1.0f, in_block.data(), (int)K, w_float.data(), (int)K, + 0.0f, out_block.data(), (int)N); + for (size_t i = 0; i < rows; i++) { + for (size_t j = 0; j < N; j++) { + float v = out_block[i * N + j]; + if (bias != nullptr) { + v += llaisys::utils::cast(bias[j]); + } + out[(i0 + i) * N + j] = llaisys::utils::cast(v); + } + } + } +} +#endif // LLAISYS_USE_OPENBLAS && LLAISYS_HAS_CBLAS + +namespace llaisys::ops::cpu { +void linear(std::byte *out, + 
const std::byte *in, + const std::byte *weight, + const std::byte *bias, + llaisysDataType_t type, + size_t M, + size_t N, + size_t K) { +#if defined(LLAISYS_USE_OPENBLAS) && defined(LLAISYS_HAS_CBLAS) + if (type == LLAISYS_DTYPE_F32) { + return linear_f32_openblas( + reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, N, K); + } + if (type == LLAISYS_DTYPE_BF16) { + return linear_bf16_f16_openblas( + reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, N, K); + } + if (type == LLAISYS_DTYPE_F16) { + return linear_bf16_f16_openblas( + reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, N, K); + } +#else + (void)M; + (void)N; + (void)K; +#endif + switch (type) { + case LLAISYS_DTYPE_F16: + return linear_naive(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, N, K); + case LLAISYS_DTYPE_BF16: + return linear_naive(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, N, K); + case LLAISYS_DTYPE_F32: + return linear_f32_blocked(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, N, K); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 000000000..3c10e2ebe --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include "llaisys.h" + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, llaisysDataType_t type, size_t M, size_t N, size_t K); +} \ No newline at end of file diff --git a/src/ops/linear/metax/linear_metax.hpp b/src/ops/linear/metax/linear_metax.hpp new file mode 100644 index 
000000000..1eae5a9c2 --- /dev/null +++ b/src/ops/linear/metax/linear_metax.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::metax { + +void linear(std::byte *out, + const std::byte *in, + const std::byte *weight, + const std::byte *bias, + llaisysDataType_t type, + size_t M, + size_t N, + size_t K); + +} // namespace llaisys::ops::metax diff --git a/src/ops/linear/metax/linear_metax.maca b/src/ops/linear/metax/linear_metax.maca new file mode 100644 index 000000000..803f8ac6f --- /dev/null +++ b/src/ops/linear/metax/linear_metax.maca @@ -0,0 +1,301 @@ +#include "linear_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include +#include + +#include + +namespace { + +__host__ __device__ __forceinline__ int ceil_div_int(int x, int y) { + return (x + y - 1) / y; +} + +constexpr int METAX_WARP_SIZE = 64; + +template __device__ __forceinline__ float to_float(T v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ float to_float<__half>(__half v) { + return __half2float(v); +} + +template <> +__device__ __forceinline__ float to_float<__maca_bfloat16>(__maca_bfloat16 v) { + return __bfloat162float(v); +} + +template __device__ __forceinline__ T from_float(float v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ __half from_float<__half>(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ __maca_bfloat16 +from_float<__maca_bfloat16>(float v) { + return __float2bfloat16(v); +} + +template +__global__ void add_bias_rowwise_kernel(T *out, const T *bias, size_t M, + size_t N) { + const size_t idx + = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t total = M * N; + for (size_t i = idx; i < total; + i += static_cast(blockDim.x) * gridDim.x) { + const size_t col = i % N; + out[i] = from_float(to_float(out[i]) + to_float(bias[col])); + } +} + +template +inline void launch_add_bias(T *out, const T *bias, size_t M, size_t N) 
{ + if (bias == nullptr || M == 0 || N == 0) { + return; + } + constexpr int block_size = METAX_WARP_SIZE * 8; + const int grid_size = ceil_div_int(static_cast(M * N), block_size); + add_bias_rowwise_kernel<<>>(out, bias, M, N); +} + +inline bool mcblas_ok(mcblasStatus_t status) { + return static_cast(status) == 0; +} + +inline mcblasHandle_t get_mcblas_handle() { + static thread_local mcblasHandle_t handle = []() { + mcblasHandle_t h = nullptr; + if (!mcblas_ok(mcblasCreate(&h))) { + return static_cast(nullptr); + } + return h; + }(); + return handle; +} + +inline bool linear_mcblas_f32(float *out, const float *in, const float *weight, + const float *bias, size_t M, size_t N, size_t K) { + mcblasHandle_t handle = get_mcblas_handle(); + if (handle == nullptr) { + return false; + } + + if (!mcblas_ok(mcblasSetPointerMode(handle, MCBLAS_POINTER_MODE_HOST))) { + return false; + } + if (!mcblas_ok(mcblasSetAtomicsMode(handle, MCBLAS_ATOMICS_NOT_ALLOWED))) { + return false; + } + + mcblasMath_t math_mode = MCBLAS_PEDANTIC_MATH; +#ifdef MCBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION + math_mode = static_cast( + static_cast(math_mode) + | static_cast(MCBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION)); +#endif + if (!mcblas_ok(mcblasSetMathMode(handle, math_mode))) { + return false; + } + + const int m = static_cast(N); + const int n = static_cast(M); + const int k = static_cast(K); + const int lda = static_cast(K); + const int ldb = static_cast(K); + const int ldc = static_cast(N); + const float alpha = 1.0f; + const float beta = 0.0f; + + const mcblasStatus_t status + = mcblasSgemm(handle, MCBLAS_OP_T, MCBLAS_OP_N, m, n, k, &alpha, weight, + lda, in, ldb, &beta, out, ldc); + if (!mcblas_ok(status)) { + return false; + } + + launch_add_bias(out, bias, M, N); + return true; +} + +inline bool linear_mcblas_f16(__half *out, const __half *in, + const __half *weight, const __half *bias, + size_t M, size_t N, size_t K) { + mcblasHandle_t handle = get_mcblas_handle(); + if 
(handle == nullptr) { + return false; + } + + if (!mcblas_ok(mcblasSetPointerMode(handle, MCBLAS_POINTER_MODE_HOST))) { + return false; + } + if (!mcblas_ok(mcblasSetAtomicsMode(handle, MCBLAS_ATOMICS_ALLOWED))) { + return false; + } +#ifdef MCBLAS_TENSOR_OP_MATH + if (!mcblas_ok(mcblasSetMathMode(handle, MCBLAS_TENSOR_OP_MATH))) { + return false; + } +#endif + + const int m = static_cast(N); + const int n = static_cast(M); + const int k = static_cast(K); + const int lda = static_cast(K); + const int ldb = static_cast(K); + const int ldc = static_cast(N); + const float alpha = 1.0f; + const float beta = 0.0f; + + mcblasComputeType_t compute_type = MCBLAS_COMPUTE_32F; + bool used_fast_compute = false; +#ifdef MCBLAS_COMPUTE_32F_FAST_16F + compute_type = MCBLAS_COMPUTE_32F_FAST_16F; + used_fast_compute = true; +#endif + mcblasGemmAlgo_t algo = MCBLAS_GEMM_DEFAULT; +#ifdef MCBLAS_GEMM_DEFAULT_TENSOR_OP + algo = MCBLAS_GEMM_DEFAULT_TENSOR_OP; +#endif + + mcblasStatus_t status + = mcblasGemmEx(handle, MCBLAS_OP_T, MCBLAS_OP_N, m, n, k, &alpha, + weight, MACA_R_16F, lda, in, MACA_R_16F, ldb, &beta, out, + MACA_R_16F, ldc, compute_type, algo); + + if (!mcblas_ok(status) && algo != MCBLAS_GEMM_DEFAULT) { + status = mcblasGemmEx(handle, MCBLAS_OP_T, MCBLAS_OP_N, m, n, k, &alpha, + weight, MACA_R_16F, lda, in, MACA_R_16F, ldb, + &beta, out, MACA_R_16F, ldc, compute_type, + MCBLAS_GEMM_DEFAULT); + } + if (!mcblas_ok(status) && used_fast_compute) { + status = mcblasGemmEx(handle, MCBLAS_OP_T, MCBLAS_OP_N, m, n, k, &alpha, + weight, MACA_R_16F, lda, in, MACA_R_16F, ldb, + &beta, out, MACA_R_16F, ldc, MCBLAS_COMPUTE_32F, + MCBLAS_GEMM_DEFAULT); + } + if (!mcblas_ok(status)) { + return false; + } + + launch_add_bias(out, bias, M, N); + return true; +} + +inline bool linear_mcblas_bf16(__maca_bfloat16 *out, const __maca_bfloat16 *in, + const __maca_bfloat16 *weight, + const __maca_bfloat16 *bias, size_t M, size_t N, + size_t K) { + mcblasHandle_t handle = get_mcblas_handle(); + if 
(handle == nullptr) { + return false; + } + + if (!mcblas_ok(mcblasSetPointerMode(handle, MCBLAS_POINTER_MODE_HOST))) { + return false; + } + if (!mcblas_ok(mcblasSetAtomicsMode(handle, MCBLAS_ATOMICS_ALLOWED))) { + return false; + } +#ifdef MCBLAS_TENSOR_OP_MATH + if (!mcblas_ok(mcblasSetMathMode(handle, MCBLAS_TENSOR_OP_MATH))) { + return false; + } +#endif + + const int m = static_cast(N); + const int n = static_cast(M); + const int k = static_cast(K); + const int lda = static_cast(K); + const int ldb = static_cast(K); + const int ldc = static_cast(N); + const float alpha = 1.0f; + const float beta = 0.0f; + + mcblasComputeType_t compute_type = MCBLAS_COMPUTE_32F; + bool used_fast_compute = false; +#ifdef MCBLAS_COMPUTE_32F_FAST_16BF + compute_type = MCBLAS_COMPUTE_32F_FAST_16BF; + used_fast_compute = true; +#endif + mcblasGemmAlgo_t algo = MCBLAS_GEMM_DEFAULT; +#ifdef MCBLAS_GEMM_DEFAULT_TENSOR_OP + algo = MCBLAS_GEMM_DEFAULT_TENSOR_OP; +#endif + + mcblasStatus_t status + = mcblasGemmEx(handle, MCBLAS_OP_T, MCBLAS_OP_N, m, n, k, &alpha, + weight, MACA_R_16BF, lda, in, MACA_R_16BF, ldb, &beta, + out, MACA_R_16BF, ldc, compute_type, algo); + + if (!mcblas_ok(status) && algo != MCBLAS_GEMM_DEFAULT) { + status = mcblasGemmEx(handle, MCBLAS_OP_T, MCBLAS_OP_N, m, n, k, &alpha, + weight, MACA_R_16BF, lda, in, MACA_R_16BF, ldb, + &beta, out, MACA_R_16BF, ldc, compute_type, + MCBLAS_GEMM_DEFAULT); + } + if (!mcblas_ok(status) && used_fast_compute) { + status = mcblasGemmEx(handle, MCBLAS_OP_T, MCBLAS_OP_N, m, n, k, &alpha, + weight, MACA_R_16BF, lda, in, MACA_R_16BF, ldb, + &beta, out, MACA_R_16BF, ldc, MCBLAS_COMPUTE_32F, + MCBLAS_GEMM_DEFAULT); + } + if (!mcblas_ok(status)) { + return false; + } + + launch_add_bias(out, bias, M, N); + return true; +} + + +} // namespace + +namespace llaisys::ops::metax { + +void linear(std::byte *out, const std::byte *in, const std::byte *weight, + const std::byte *bias, llaisysDataType_t type, size_t M, size_t N, + size_t K) { + if 
(M == 0 || N == 0 || K == 0) { + return; + } + + switch (type) { + case LLAISYS_DTYPE_F32: { + const bool ok = linear_mcblas_f32( + reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), M, N, K); + break; + } + case LLAISYS_DTYPE_F16: { + const bool ok = linear_mcblas_f16( + reinterpret_cast<__half *>(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), M, N, K); + break; + } + case LLAISYS_DTYPE_BF16: { + const bool ok = linear_mcblas_bf16( + reinterpret_cast<__maca_bfloat16 *>(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), M, N, K); + break; + } + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/linear/nvidia/linear_nvidia.cu b/src/ops/linear/nvidia/linear_nvidia.cu new file mode 100644 index 000000000..d00317334 --- /dev/null +++ b/src/ops/linear/nvidia/linear_nvidia.cu @@ -0,0 +1,460 @@ +#include "linear_nvidia.cuh" + +#include "../../../utils.hpp" +#include "../../../utils/gpu_utils.hpp" + +#include + +#include +#include + +namespace { + +inline void cublas_check(cublasStatus_t status, const char *msg) { + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error(msg); + } +} + +inline cublasHandle_t get_cublas_handle() { + static thread_local cublasHandle_t handle = []() { + cublasHandle_t h = nullptr; + cublas_check(cublasCreate(&h), "cublasCreate failed"); + return h; + }(); + return handle; +} + +template +__global__ void add_bias_rowwise_kernel(T *out, const T *bias, size_t M, size_t N) { + const size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const size_t total = M * N; + for (size_t i = idx; i < total; i += static_cast(blockDim.x) * gridDim.x) { + const size_t col = i % N; + out[i] = from_float(to_float(out[i]) + to_float(bias[col])); + } +} + +template +inline void launch_add_bias(T *out, const T *bias, size_t M, size_t N) { + if (bias == nullptr || 
M == 0 || N == 0) { + return; + } + constexpr int block_size = 256; + const int grid_size = static_cast(CEIL(M * N, block_size)); + add_bias_rowwise_kernel<<>>(out, bias, M, N); +} + +inline void linear_cublas_f32(float *out, + const float *in, + const float *weight, + const float *bias, + size_t M, + size_t N, + size_t K) { + cublasHandle_t handle = get_cublas_handle(); + const float alpha = 1.0f; + const float beta = 0.0f; + cublas_check(cublasSgemm(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + weight, + static_cast(K), + in, + static_cast(K), + &beta, + out, + static_cast(N)), + "cublasSgemm failed"); + launch_add_bias(out, bias, M, N); +} + +inline void linear_cublas_f16(half *out, + const half *in, + const half *weight, + const half *bias, + size_t M, + size_t N, + size_t K) { + cublasHandle_t handle = get_cublas_handle(); + const float alpha = 1.0f; + const float beta = 0.0f; + cublasStatus_t status = cublasGemmEx(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + weight, + CUDA_R_16F, + static_cast(K), + in, + CUDA_R_16F, + static_cast(K), + &beta, + out, + CUDA_R_16F, + static_cast(N), + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + if (status == CUBLAS_STATUS_NOT_SUPPORTED) { + status = cublasGemmEx(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + weight, + CUDA_R_16F, + static_cast(K), + in, + CUDA_R_16F, + static_cast(K), + &beta, + out, + CUDA_R_16F, + static_cast(N), + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT); + } + cublas_check(status, "cublasGemmEx f16 failed"); + launch_add_bias(out, bias, M, N); +} + +inline void linear_cublas_bf16(__nv_bfloat16 *out, + const __nv_bfloat16 *in, + const __nv_bfloat16 *weight, + const __nv_bfloat16 *bias, + size_t M, + size_t N, + size_t K) { + cublasHandle_t handle = get_cublas_handle(); + const float alpha = 1.0f; + const float beta = 0.0f; 
+ cublasStatus_t status = cublasGemmEx(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + weight, + CUDA_R_16BF, + static_cast(K), + in, + CUDA_R_16BF, + static_cast(K), + &beta, + out, + CUDA_R_16BF, + static_cast(N), + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + if (status == CUBLAS_STATUS_NOT_SUPPORTED) { + status = cublasGemmEx(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &alpha, + weight, + CUDA_R_16BF, + static_cast(K), + in, + CUDA_R_16BF, + static_cast(K), + &beta, + out, + CUDA_R_16BF, + static_cast(N), + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT); + } + cublas_check(status, "cublasGemmEx bf16 failed"); + launch_add_bias(out, bias, M, N); +} + +// Reference-only hand-written kernel retained for review. It is not dispatched. +template +__global__ void sgemm_v7_float32(float *__restrict__ out, + const float *__restrict__ in, + const float *__restrict__ weight, + const float *__restrict__ bias, + size_t M, + size_t N, + size_t K) { + static_assert(BLOCK_SIZE_M == 128 && BLOCK_SIZE_N == 128 && BLOCK_SIZE_K == 8 && THREAD_SIZE_X == 8 && THREAD_SIZE_Y == 8, + "v7 is tuned for 128x128x8 tile and 8x8 thread tile."); + + const int bx = blockIdx.x; + const int by = blockIdx.y; + + const int tx = threadIdx.x; + const int ty = threadIdx.y; + + const int thread_x_per_block = BLOCK_SIZE_N / THREAD_SIZE_X; + const int thread_y_per_block = BLOCK_SIZE_M / THREAD_SIZE_Y; + const int thread_num_per_block = thread_x_per_block * thread_y_per_block; + + const int tid = ty * thread_x_per_block + tx; + + __shared__ float As[2][BLOCK_SIZE_K][BLOCK_SIZE_M]; + __shared__ float Bs[2][BLOCK_SIZE_K][BLOCK_SIZE_N]; + + float accum[THREAD_SIZE_Y][THREAD_SIZE_X] = {0.0f}; + + float frag_a[2][THREAD_SIZE_Y]; + float frag_b[2][THREAD_SIZE_X]; + + const int ldg_num_a = BLOCK_SIZE_M * BLOCK_SIZE_K / (thread_num_per_block * 4); + const int ldg_num_b = BLOCK_SIZE_K * BLOCK_SIZE_N / 
(thread_num_per_block * 4); + float ldg_a_reg[4 * ldg_num_a]; + float ldg_b_reg[4 * ldg_num_b]; + + const int a_load_thread_per_row = BLOCK_SIZE_K / 4; + const int b_load_thread_per_row = BLOCK_SIZE_K / 4; + + const int a_load_row_start = tid / a_load_thread_per_row; + const int b_load_row_start = tid / b_load_thread_per_row; + const int a_load_col = (tid % a_load_thread_per_row) * 4; + const int b_load_col = (tid % b_load_thread_per_row) * 4; + + const int a_load_row_stride = thread_num_per_block / a_load_thread_per_row; + const int b_load_row_stride = thread_num_per_block / b_load_thread_per_row; + + const float *A = &in[(BLOCK_SIZE_M * by) * K]; + const float *B = &weight[(BLOCK_SIZE_N * bx) * K]; + +#pragma unroll + for (int i = 0; i < BLOCK_SIZE_M; i += a_load_row_stride) { + const int ldg_index = i / a_load_row_stride * 4; + const int offset = (a_load_row_start + i) * K + a_load_col; + STORE_FLOAT4(ldg_a_reg[ldg_index]) = LOAD_FLOAT4(A[offset]); + As[0][a_load_col][a_load_row_start + i] = ldg_a_reg[ldg_index]; + As[0][a_load_col + 1][a_load_row_start + i] = ldg_a_reg[ldg_index + 1]; + As[0][a_load_col + 2][a_load_row_start + i] = ldg_a_reg[ldg_index + 2]; + As[0][a_load_col + 3][a_load_row_start + i] = ldg_a_reg[ldg_index + 3]; + } + +#pragma unroll + for (int i = 0; i < BLOCK_SIZE_N; i += b_load_row_stride) { + const int ldg_index = i / b_load_row_stride * 4; + const int offset = (b_load_row_start + i) * K + b_load_col; + STORE_FLOAT4(ldg_b_reg[ldg_index]) = LOAD_FLOAT4(B[offset]); + Bs[0][b_load_col][b_load_row_start + i] = ldg_b_reg[ldg_index]; + Bs[0][b_load_col + 1][b_load_row_start + i] = ldg_b_reg[ldg_index + 1]; + Bs[0][b_load_col + 2][b_load_row_start + i] = ldg_b_reg[ldg_index + 2]; + Bs[0][b_load_col + 3][b_load_row_start + i] = ldg_b_reg[ldg_index + 3]; + } + __syncthreads(); + + const int warp_id = tid / 32; + const int lane_id = tid % 32; + const int a_tile_index = warp_id / 2 * 16 + lane_id / 8 * 4; + const int b_tile_index = warp_id % 2 * 32 + 
lane_id % 8 * 4; + + STORE_FLOAT4(frag_a[0][0]) = LOAD_FLOAT4(As[0][0][a_tile_index]); + STORE_FLOAT4(frag_a[0][4]) = LOAD_FLOAT4(As[0][0][a_tile_index + BLOCK_SIZE_M / 2]); + STORE_FLOAT4(frag_b[0][0]) = LOAD_FLOAT4(Bs[0][0][b_tile_index]); + STORE_FLOAT4(frag_b[0][4]) = LOAD_FLOAT4(Bs[0][0][b_tile_index + BLOCK_SIZE_N / 2]); + + int write_stage_idx = 1; + int tile_idx = 0; + do { + tile_idx += BLOCK_SIZE_K; + if (tile_idx < K) { +#pragma unroll + for (int i = 0; i < BLOCK_SIZE_M; i += a_load_row_stride) { + const int ldg_index = i / a_load_row_stride * 4; + const int offset = (a_load_row_start + i) * K + (a_load_col + tile_idx); + STORE_FLOAT4(ldg_a_reg[ldg_index]) = LOAD_FLOAT4(A[offset]); + } +#pragma unroll + for (int i = 0; i < BLOCK_SIZE_N; i += b_load_row_stride) { + const int ldg_index = i / b_load_row_stride * 4; + const int offset = (b_load_row_start + i) * K + (b_load_col + tile_idx); + STORE_FLOAT4(ldg_b_reg[ldg_index]) = LOAD_FLOAT4(B[offset]); + } + } + + const int load_stage_idx = write_stage_idx ^ 1; + +#pragma unroll + for (int j = 0; j < BLOCK_SIZE_K - 1; ++j) { + STORE_FLOAT4(frag_a[(j + 1) % 2][0]) = LOAD_FLOAT4(As[load_stage_idx][j + 1][a_tile_index]); + STORE_FLOAT4(frag_a[(j + 1) % 2][4]) = LOAD_FLOAT4(As[load_stage_idx][j + 1][a_tile_index + BLOCK_SIZE_M / 2]); + STORE_FLOAT4(frag_b[(j + 1) % 2][0]) = LOAD_FLOAT4(Bs[load_stage_idx][j + 1][b_tile_index]); + STORE_FLOAT4(frag_b[(j + 1) % 2][4]) = LOAD_FLOAT4(Bs[load_stage_idx][j + 1][b_tile_index + BLOCK_SIZE_N / 2]); + +#pragma unroll + for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) { +#pragma unroll + for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) { + accum[thread_y][thread_x] += frag_a[j % 2][thread_y] * frag_b[j % 2][thread_x]; + } + } + } + + if (tile_idx < K) { +#pragma unroll + for (int i = 0; i < BLOCK_SIZE_M; i += a_load_row_stride) { + const int ldg_index = i / a_load_row_stride * 4; + As[write_stage_idx][a_load_col][a_load_row_start + i] = 
ldg_a_reg[ldg_index]; + As[write_stage_idx][a_load_col + 1][a_load_row_start + i] = ldg_a_reg[ldg_index + 1]; + As[write_stage_idx][a_load_col + 2][a_load_row_start + i] = ldg_a_reg[ldg_index + 2]; + As[write_stage_idx][a_load_col + 3][a_load_row_start + i] = ldg_a_reg[ldg_index + 3]; + } +#pragma unroll + for (int i = 0; i < BLOCK_SIZE_N; i += b_load_row_stride) { + const int ldg_index = i / b_load_row_stride * 4; + Bs[write_stage_idx][b_load_col][b_load_row_start + i] = ldg_b_reg[ldg_index]; + Bs[write_stage_idx][b_load_col + 1][b_load_row_start + i] = ldg_b_reg[ldg_index + 1]; + Bs[write_stage_idx][b_load_col + 2][b_load_row_start + i] = ldg_b_reg[ldg_index + 2]; + Bs[write_stage_idx][b_load_col + 3][b_load_row_start + i] = ldg_b_reg[ldg_index + 3]; + } + __syncthreads(); + write_stage_idx ^= 1; + } + + STORE_FLOAT4(frag_a[0][0]) = LOAD_FLOAT4(As[load_stage_idx ^ 1][0][a_tile_index]); + STORE_FLOAT4(frag_a[0][4]) = LOAD_FLOAT4(As[load_stage_idx ^ 1][0][a_tile_index + BLOCK_SIZE_M / 2]); + STORE_FLOAT4(frag_b[0][0]) = LOAD_FLOAT4(Bs[load_stage_idx ^ 1][0][b_tile_index]); + STORE_FLOAT4(frag_b[0][4]) = LOAD_FLOAT4(Bs[load_stage_idx ^ 1][0][b_tile_index + BLOCK_SIZE_N / 2]); + +#pragma unroll + for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) { +#pragma unroll + for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) { + accum[thread_y][thread_x] += frag_a[1][thread_y] * frag_b[1][thread_x]; + } + } + } while (tile_idx < K); + + const int c_block_row = a_tile_index; + const int c_block_col = b_tile_index; + + for (int i = 0; i < 4; i++) { + const int row = BLOCK_SIZE_M * by + c_block_row + i; + const int col = BLOCK_SIZE_N * bx + c_block_col; + float4 c_val; + c_val.x = accum[i][0]; + c_val.y = accum[i][1]; + c_val.z = accum[i][2]; + c_val.w = accum[i][3]; + if (bias != nullptr) { + c_val.x += bias[col]; + c_val.y += bias[col + 1]; + c_val.z += bias[col + 2]; + c_val.w += bias[col + 3]; + } + STORE_FLOAT4(out[row * N + col]) = c_val; + } + for (int 
i = 0; i < 4; i++) { + const int row = BLOCK_SIZE_M * by + c_block_row + i; + const int col = BLOCK_SIZE_N * bx + c_block_col + BLOCK_SIZE_N / 2; + float4 c_val; + c_val.x = accum[i][4]; + c_val.y = accum[i][5]; + c_val.z = accum[i][6]; + c_val.w = accum[i][7]; + if (bias != nullptr) { + c_val.x += bias[col]; + c_val.y += bias[col + 1]; + c_val.z += bias[col + 2]; + c_val.w += bias[col + 3]; + } + STORE_FLOAT4(out[row * N + col]) = c_val; + } + for (int i = 0; i < 4; i++) { + const int row = BLOCK_SIZE_M * by + c_block_row + BLOCK_SIZE_M / 2 + i; + const int col = BLOCK_SIZE_N * bx + c_block_col; + float4 c_val; + c_val.x = accum[i + 4][0]; + c_val.y = accum[i + 4][1]; + c_val.z = accum[i + 4][2]; + c_val.w = accum[i + 4][3]; + if (bias != nullptr) { + c_val.x += bias[col]; + c_val.y += bias[col + 1]; + c_val.z += bias[col + 2]; + c_val.w += bias[col + 3]; + } + STORE_FLOAT4(out[row * N + col]) = c_val; + } + for (int i = 0; i < 4; i++) { + const int row = BLOCK_SIZE_M * by + c_block_row + BLOCK_SIZE_M / 2 + i; + const int col = BLOCK_SIZE_N * bx + c_block_col + BLOCK_SIZE_N / 2; + float4 c_val; + c_val.x = accum[i + 4][4]; + c_val.y = accum[i + 4][5]; + c_val.z = accum[i + 4][6]; + c_val.w = accum[i + 4][7]; + if (bias != nullptr) { + c_val.x += bias[col]; + c_val.y += bias[col + 1]; + c_val.z += bias[col + 2]; + c_val.w += bias[col + 3]; + } + STORE_FLOAT4(out[row * N + col]) = c_val; + } +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void linear(std::byte *out, + const std::byte *in, + const std::byte *weight, + const std::byte *bias, + llaisysDataType_t type, + size_t M, + size_t N, + size_t K) { + switch (type) { + case LLAISYS_DTYPE_F32: + linear_cublas_f32(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, + N, + K); + break; + case LLAISYS_DTYPE_F16: + linear_cublas_f16(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, + N, + K); + 
break; + case LLAISYS_DTYPE_BF16: + linear_cublas_bf16(reinterpret_cast<__nv_bfloat16 *>(out), + reinterpret_cast(in), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, + N, + K); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + CUDA_CHECK(cudaGetLastError()); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/linear/nvidia/linear_nvidia.cuh b/src/ops/linear/nvidia/linear_nvidia.cuh new file mode 100644 index 000000000..c7fa94011 --- /dev/null +++ b/src/ops/linear/nvidia/linear_nvidia.cuh @@ -0,0 +1,9 @@ +#pragma once + +#include "llaisys.h" + +namespace llaisys::ops::nvidia { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, + const std::byte *bias, llaisysDataType_t type, size_t M, size_t N, + size_t K); +} \ No newline at end of file diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f8655..3854467d4 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,74 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "./cpu/linear_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "./nvidia/linear_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "./metax/linear_metax.hpp" +#endif +#include "llaisys.h" + namespace llaisys::ops { void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + // 1. 
参数校验 + CHECK_SAME_DEVICE(out, in, weight); + if (bias != nullptr) { + CHECK_SAME_DEVICE(out, bias); + CHECK_ARGUMENT(bias->ndim() == 1, "bias must be a 1D tensor"); + CHECK_ARGUMENT(bias->shape()[0] == out->shape()[1], + "N dim of bias and out must be the same"); + CHECK_ARGUMENT(out->dtype() == bias->dtype(), + "bias must have the same data type as out"); + } + CHECK_ARGUMENT(out->ndim() == 2, "out must be a 2D tensor"); + CHECK_ARGUMENT(in->ndim() == 2, "in must be a 2D tensor"); + CHECK_ARGUMENT(weight->ndim() == 2, "weight must be a 2D tensor"); + // X: [M, K], W: [N, K], b: [N], Y: [M, N] + CHECK_ARGUMENT(out->shape()[0] == in->shape()[0], + "M dim of out and in must be the same"); + CHECK_ARGUMENT(out->shape()[1] == weight->shape()[0], + "N dim of out and weight must be the same"); + CHECK_ARGUMENT(in->shape()[1] == weight->shape()[1], + "K dim of in and weight must be the same"); + CHECK_ARGUMENT(out->dtype() == in->dtype() && out->dtype() == weight->dtype(), + "out, in and weight must have the same data type"); + if (bias != nullptr) { + ASSERT(out->isContiguous() && in->isContiguous() && + weight->isContiguous() && bias->isContiguous(), + "out, in, weight and bias must be contiguous"); + } else { + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), + "out, in and weight must be contiguous"); + } + + // 2. 设置上下文 + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::linear(out->data(), in->data(), weight->data(), + (bias != nullptr) ? bias->data() : nullptr, out->dtype(), + out->shape()[0], out->shape()[1], in->shape()[1]); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::linear(out->data(), in->data(), weight->data(), + (bias != nullptr) ?
bias->data() : nullptr, + out->dtype(), out->shape()[0], out->shape()[1], + in->shape()[1]); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::linear(out->data(), in->data(), weight->data(), + (bias != nullptr) ? bias->data() : nullptr, + out->dtype(), out->shape()[0], out->shape()[1], + in->shape()[1]); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/linear/op.hpp b/src/ops/linear/op.hpp index 7bf06f017..6ed922633 100644 --- a/src/ops/linear/op.hpp +++ b/src/ops/linear/op.hpp @@ -2,6 +2,11 @@ #include "../../tensor/tensor.hpp" +// 功能:计算线性变换,即matmul +// in/X: 形状[M, K] +// weight/W: 形状[N, K],存的是未转置的W +// bias/b: 形状[N](可选;为 nullptr 时等价于不加 bias) +// out/Y: 形状[M, N] namespace llaisys::ops { -void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias); +void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias = nullptr); } diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp new file mode 100644 index 000000000..c0893ca17 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp @@ -0,0 +1,57 @@ +#include "rms_norm_cpu.hpp" + +#include "../../../utils.hpp" + +#include "llaisys.h" +#include +#include + +template +void rms_norm_(T *out, const T *in, const T *weight, size_t M, size_t N, float eps) { + for (size_t m = 0; m < M; m++) { + // 1. 计算当前行的均方 + float sum = 0.0f; + for (size_t n = 0; n < N; n++) { + float value = llaisys::utils::cast(in[m * N + n]); + sum += value * value; + } + float mean = sum / static_cast(N); + float scale_rms = 1.0f / std::sqrt(mean + eps); + + // 2. 
乘以权重并归一化 + for (size_t n = 0; n < N; n++) { + float value = llaisys::utils::cast(in[m * N + n]); + float wei = llaisys::utils::cast(weight[n]); + float res = value * wei * scale_rms; + out[m * N + n] = llaisys::utils::cast(res); + } + } +} + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, + const std::byte *in, + const std::byte *weight, + llaisysDataType_t dataType, + size_t M, size_t N, float eps){ + switch (dataType) { + case LLAISYS_DTYPE_F16: + return rms_norm_(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + M, N, eps); + case LLAISYS_DTYPE_BF16: + return rms_norm_(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + M, N, eps); + case LLAISYS_DTYPE_F32: + return rms_norm_(reinterpret_cast(out), + reinterpret_cast(in), + reinterpret_cast(weight), + M, N, eps); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dataType); + } +} +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp new file mode 100644 index 000000000..2745484a8 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, + const std::byte *in, + const std::byte *weight, + llaisysDataType_t type, + size_t M, + size_t N, + float eps); +} \ No newline at end of file diff --git a/src/ops/rms_norm/metax/rms_norm_metax.hpp b/src/ops/rms_norm/metax/rms_norm_metax.hpp new file mode 100644 index 000000000..7e57c47aa --- /dev/null +++ b/src/ops/rms_norm/metax/rms_norm_metax.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::metax { + +void rms_norm(std::byte *out, + const std::byte *in, + const std::byte *weight, + llaisysDataType_t type, + size_t M, + size_t N, + float eps); + +} // namespace llaisys::ops::metax + diff --git a/src/ops/rms_norm/metax/rms_norm_metax.maca 
b/src/ops/rms_norm/metax/rms_norm_metax.maca new file mode 100644 index 000000000..b2022805e --- /dev/null +++ b/src/ops/rms_norm/metax/rms_norm_metax.maca @@ -0,0 +1,139 @@ +#include "rms_norm_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +#include +#include + +namespace { + +constexpr int METAX_WARP_SIZE = 64; + +template +__device__ __forceinline__ T warp_reduce_sum(T local_val) { + constexpr maca_uint64_t full_mask = static_cast(~0ULL); +#pragma unroll + for (int stride = METAX_WARP_SIZE / 2; stride > 0; stride >>= 1) { + local_val + += __shfl_xor_sync(full_mask, local_val, stride, METAX_WARP_SIZE); + } + return local_val; +} + +template +__device__ __forceinline__ T block_reduce_sum(T local_val) { + constexpr int warp_per_block + = (BLOCK_SIZE + METAX_WARP_SIZE - 1) / METAX_WARP_SIZE; + const int warp_id = threadIdx.x / METAX_WARP_SIZE; + const int lane_id = threadIdx.x % METAX_WARP_SIZE; + __shared__ T shared_val[warp_per_block]; + + local_val = warp_reduce_sum(local_val); + if (lane_id == 0) { + shared_val[warp_id] = local_val; + } + __syncthreads(); + + const T lane_val + = (lane_id < warp_per_block) ? 
shared_val[lane_id] : static_cast(0); + return warp_reduce_sum(lane_val); +} + +template __device__ __forceinline__ float to_float_t(T v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ float to_float_t<__half>(__half v) { + return __half2float(v); +} + +template <> +__device__ __forceinline__ float +to_float_t<__maca_bfloat16>(__maca_bfloat16 v) { + return __bfloat162float(v); +} + +template __device__ __forceinline__ T from_float_t(float v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ __half from_float_t<__half>(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ __maca_bfloat16 +from_float_t<__maca_bfloat16>(float v) { + return __float2bfloat16(v); +} + +template +__global__ void rms_norm_kernel(T *out, const T *in, const T *weight, size_t M, + size_t N, float eps) { + const size_t row_id = static_cast(blockIdx.x); + if (row_id >= M) { + return; + } + + const int tid = threadIdx.x; + + float sum_thread = 0.0f; + for (size_t i = static_cast(tid); i < N; + i += static_cast(blockDim.x)) { + const float v = to_float_t(in[row_id * N + i]); + sum_thread += v * v; + } + + const float sum_block = block_reduce_sum(sum_thread); + const float mean_sq = sum_block / static_cast(N); + const float scale_rms = 1.0f / sqrtf(mean_sq + eps); + + for (size_t i = static_cast(tid); i < N; + i += static_cast(blockDim.x)) { + const float x = to_float_t(in[row_id * N + i]); + const float w = to_float_t(weight[i]); + out[row_id * N + i] = from_float_t(x * w * scale_rms); + } +} + +} // namespace + +namespace llaisys::ops::metax { + +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, + llaisysDataType_t type, size_t M, size_t N, float eps) { + if (M == 0 || N == 0) { + return; + } + + constexpr int block_size = 512; + const int grid_size = static_cast(M); + + switch (type) { + case LLAISYS_DTYPE_F32: + rms_norm_kernel<<>>( + reinterpret_cast(out), reinterpret_cast(in), + 
reinterpret_cast(weight), M, N, eps); + break; + case LLAISYS_DTYPE_F16: + rms_norm_kernel<__half, block_size><<>>( + reinterpret_cast<__half *>(out), + reinterpret_cast(in), + reinterpret_cast(weight), M, N, eps); + break; + case LLAISYS_DTYPE_BF16: + rms_norm_kernel<__maca_bfloat16, block_size><<>>( + reinterpret_cast<__maca_bfloat16 *>(out), + reinterpret_cast(in), + reinterpret_cast(weight), M, N, eps); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu new file mode 100644 index 000000000..c302bba0f --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cu @@ -0,0 +1,120 @@ +#include "llaisys.h" +#include "rms_norm_nvidia.cuh" + +#include "../../../utils.hpp" +#include "../../../utils/gpu_utils.hpp" +#include + +namespace { + +template +__device__ __forceinline__ T warp_reduce_sum(T local_val) { +#pragma unroll + for (int stride = 16; stride > 0; stride >>= 1) { + local_val += __shfl_xor_sync(0xffffffff, local_val, stride); + } + return local_val; +} + +template +__device__ __forceinline__ T block_reduce_sum(T local_val) { + constexpr int warp_per_block = CEIL(BLOCK_SIZE, 32); + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + __shared__ T shared_val[warp_per_block]; + + local_val = warp_reduce_sum(local_val); + if (lane_id == 0) { + shared_val[warp_id] = local_val; + } + __syncthreads(); + + T block_sum{0}; + T lane_val = lane_id < warp_per_block ? 
shared_val[lane_id] : 0; + block_sum = warp_reduce_sum(lane_val); + return block_sum; +} + +template __device__ __forceinline__ float to_float_t(T v) { + return static_cast(v); +} +template <> __device__ __forceinline__ float to_float_t(half v) { + return __half2float(v); +} +template <> __device__ __forceinline__ float to_float_t(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template __device__ __forceinline__ T from_float_t(float v) { + return static_cast(v); +} +template <> __device__ __forceinline__ half from_float_t(float v) { + return __float2half(v); +} +template <> +__device__ __forceinline__ __nv_bfloat16 from_float_t<__nv_bfloat16>(float v) { + return __float2bfloat16(v); +} + +template +__global__ void rms_norm_kernel(T *out, const T *in, const T *weight, size_t M, + size_t N, float eps) { + const size_t row_id = blockIdx.x; + if (row_id >= M) + return; + + const int tid = threadIdx.x; + + // 1. 每个线程求局部平方和(用 float 累加) + float sum_thread = 0.0f; + for (int i = tid; i < N; i += blockDim.x) { + float v = to_float_t(in[row_id * N + i]); + sum_thread += v * v; + } + + // 2. block 内归约得到整行平方和,所有线程得到同一 sum_sq + float sum_block = block_reduce_sum(sum_thread); + float mean_sq = sum_block / static_cast(N); + float scale_rms = 1.0f / sqrtf(mean_sq + eps); + + // 3. 
归一化并写回:out[i] = in[i] * weight[i] * scale_rms + for (int i = tid; i < N; i += blockDim.x) { + float x = to_float_t(in[row_id * N + i]); + float w = to_float_t(weight[i]); + float y = x * w * scale_rms; + out[row_id * N + i] = from_float_t(y); + } +} + +} // namespace + +namespace llaisys::ops::nvidia { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, + llaisysDataType_t type, size_t M, size_t N, float eps) { + if (M == 0 || N == 0) + return; + constexpr int block_size = 256; + const int grid_size = static_cast(M); + switch (type) { + case LLAISYS_DTYPE_F32: + rms_norm_kernel<<>>( + reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), M, N, eps); + break; + case LLAISYS_DTYPE_F16: + rms_norm_kernel<<>>( + reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), M, N, eps); + break; + case LLAISYS_DTYPE_BF16: + rms_norm_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16 *>(out), + reinterpret_cast(in), + reinterpret_cast(weight), M, N, eps); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + CUDA_CHECK(cudaGetLastError()); +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh new file mode 100644 index 000000000..d16ca5d95 --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_nvidia.cuh @@ -0,0 +1,7 @@ +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, + llaisysDataType_t type, size_t M, size_t N, float eps); +} \ No newline at end of file diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9d..a1d639a64 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,57 @@ #include "op.hpp" + +#include "./cpu/rms_norm_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "./nvidia/rms_norm_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "./metax/rms_norm_metax.hpp" 
+#endif +#include "llaisys.h" + namespace llaisys::ops { void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + // 1. 参数校验 + CHECK_SAME_DEVICE(out, in, weight); + CHECK_ARGUMENT(out->ndim() == 2, "out must be 2d"); + CHECK_ARGUMENT(in->ndim() == 2, "in must be 2d"); + CHECK_ARGUMENT(weight->ndim() == 1, "weight must be 1d"); + CHECK_ARGUMENT(out->shape()[0] == in->shape()[0] && out->shape()[1] == in->shape()[1], + "out's shape must be same as in's shape"); + CHECK_ARGUMENT(weight->shape()[0] == out->shape()[1], + "weight and out must have equal N"); + CHECK_ARGUMENT(out->dtype() == in->dtype() && out->dtype() == weight->dtype(), + "tensors must have the same dtype"); + ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), + "tensors must be contiguous"); + + // 2. 设置设备上下文 + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + // 3. 张量分发到指定设备 + size_t M = out->shape()[0]; + size_t N = out->shape()[1]; + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rms_norm(out->data(), + in->data(), + weight->data(), + out->dtype(), + M, N, eps); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rms_norm(out->data(), in->data(), weight->data(), + out->dtype(), M, N, eps); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::rms_norm(out->data(), in->data(), weight->data(), + out->dtype(), M, N, eps); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 000000000..829ac9912 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,63 @@ +#include "rope_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +template +static void rope_(T *out, const T *in, const int64_t *pos_ids, size_t seqlen, + size_t nhead, size_t head_dim, float theta) { + const size_t half = head_dim / 2; + + 
// denom[j] = theta^(2j/d) + std::vector denom(half); + for (size_t j = 0; j < half; ++j) { + const float exponent = + (2.0f * static_cast(j)) / static_cast(head_dim); + denom[j] = ::powf(theta, exponent); + } + + for (size_t s = 0; s < seqlen; ++s) { + // pos对应seqlen位置的position id + const float p = static_cast(pos_ids[s]); + for (size_t h = 0; h < nhead; ++h) { + const size_t offset = (s * nhead + h) * head_dim; + // 将相邻的两个特征维度合并为一组,然后一起旋转 + for (size_t j = 0; j < half; ++j) { + const float phi = p / denom[j]; + const float sinv = ::sinf(phi); + const float cosv = ::cosf(phi); + + const float a = llaisys::utils::cast(in[offset + j]); + const float b = llaisys::utils::cast(in[offset + j + half]); + + out[offset + j] = llaisys::utils::cast(a * cosv - b * sinv); + out[offset + j + half] = llaisys::utils::cast(a * sinv + b * cosv); + } + } + } +} + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const int64_t *pos_ids, + llaisysDataType_t type, size_t seqlen, size_t nhead, size_t head_dim, + float theta) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rope_(reinterpret_cast(out), + reinterpret_cast(in), pos_ids, seqlen, nhead, + head_dim, theta); + case LLAISYS_DTYPE_F16: + return rope_(reinterpret_cast(out), + reinterpret_cast(in), pos_ids, seqlen, + nhead, head_dim, theta); + case LLAISYS_DTYPE_BF16: + return rope_(reinterpret_cast(out), + reinterpret_cast(in), pos_ids, seqlen, + nhead, head_dim, theta); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 000000000..9c1c6352a --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void rope(std::byte *out, + const std::byte *in, + const int64_t *pos_ids, + llaisysDataType_t type, + size_t seqlen, + size_t nhead, + size_t head_dim, + float theta); 
+} // namespace llaisys::ops::cpu + diff --git a/src/ops/rope/metax/rope_metax.hpp b/src/ops/rope/metax/rope_metax.hpp new file mode 100644 index 000000000..4d93b18a7 --- /dev/null +++ b/src/ops/rope/metax/rope_metax.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include "llaisys.h" + +#include +#include + +namespace llaisys::ops::metax { + +void rope(std::byte *out, + const std::byte *in, + const int64_t *pos_ids, + llaisysDataType_t type, + size_t seqlen, + size_t nhead, + size_t head_dim, + float theta); + +} // namespace llaisys::ops::metax diff --git a/src/ops/rope/metax/rope_metax.maca b/src/ops/rope/metax/rope_metax.maca new file mode 100644 index 000000000..1278cbd0b --- /dev/null +++ b/src/ops/rope/metax/rope_metax.maca @@ -0,0 +1,112 @@ +#include "rope_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +#include + +namespace { + +template __device__ __forceinline__ float to_float_t(T v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ float to_float_t<__half>(__half v) { + return __half2float(v); +} + +template <> +__device__ __forceinline__ float +to_float_t<__maca_bfloat16>(__maca_bfloat16 v) { + return __bfloat162float(v); +} + +template __device__ __forceinline__ T from_float_t(float v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ __half from_float_t<__half>(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ __maca_bfloat16 +from_float_t<__maca_bfloat16>(float v) { + return __float2bfloat16(v); +} + +// in/out: [seqlen, nhead, head_dim] +// pos_ids: [seqlen] +template +__global__ void rope_kernel(T *out, const T *in, const int64_t *pos_ids, + size_t seqlen, size_t nhead, size_t head_dim, + float theta) { + const size_t bid = static_cast(blockIdx.x); + if (bid >= seqlen * nhead) { + return; + } + + const size_t seqlen_idx = bid / nhead; + const size_t head_id = bid % nhead; + const size_t half = head_dim / 2; + const size_t offset = (seqlen_idx * nhead 
+ head_id) * head_dim; + const float pos_val = static_cast(pos_ids[seqlen_idx]); + + for (size_t j = static_cast(threadIdx.x); j < half; + j += static_cast(blockDim.x)) { + const float exponent + = (2.0f * static_cast(j)) / static_cast(head_dim); + const float phi = pos_val / powf(theta, exponent); + const float sinv = sinf(phi); + const float cosv = cosf(phi); + + const float a = to_float_t(in[offset + j]); + const float b = to_float_t(in[offset + j + half]); + + out[offset + j] = from_float_t(a * cosv - b * sinv); + out[offset + j + half] = from_float_t(b * cosv + a * sinv); + } +} + +} // namespace + +namespace llaisys::ops::metax { + +void rope(std::byte *out, const std::byte *in, const int64_t *pos_ids, + llaisysDataType_t type, size_t seqlen, size_t nhead, size_t head_dim, + float theta) { + if (seqlen == 0 || nhead == 0 || head_dim == 0) { + return; + } + + const size_t total_heads = seqlen * nhead; + constexpr int block_size = 512; + const int grid_size = static_cast(total_heads); + + switch (type) { + case LLAISYS_DTYPE_F32: + rope_kernel<<>>( + reinterpret_cast(out), reinterpret_cast(in), + pos_ids, seqlen, nhead, head_dim, theta); + break; + case LLAISYS_DTYPE_F16: + rope_kernel<__half><<>>( + reinterpret_cast<__half *>(out), + reinterpret_cast(in), pos_ids, seqlen, nhead, + head_dim, theta); + break; + case LLAISYS_DTYPE_BF16: + rope_kernel<__maca_bfloat16><<>>( + reinterpret_cast<__maca_bfloat16 *>(out), + reinterpret_cast(in), pos_ids, seqlen, + nhead, head_dim, theta); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/rope/nvidia/rope_nvidia.cu b/src/ops/rope/nvidia/rope_nvidia.cu new file mode 100644 index 000000000..2ef08b726 --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.cu @@ -0,0 +1,108 @@ +#include "rope_nvidia.cuh" + +#include "../../../utils.hpp" +#include "../../../utils/gpu_utils.hpp" + +#include + +namespace { + +// 将不同 T 转为 float 做计算 +template __device__ 
__forceinline__ float to_float_t(T v) { + return static_cast(v); +} +template <> __device__ __forceinline__ float to_float_t(half v) { + return __half2float(v); +} +template <> __device__ __forceinline__ float to_float_t(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +// 将 float 转回不同 T +template __device__ __forceinline__ T from_float_t(float v) { + return static_cast(v); +} +template <> __device__ __forceinline__ half from_float_t(float v) { + return __float2half(v); +} +template <> +__device__ __forceinline__ __nv_bfloat16 from_float_t<__nv_bfloat16>(float v) { + return __float2bfloat16(v); +} + +// in/out: [seqlen, nhead, head_dim] +// pos_ids: [seqlen] +template +__global__ void rope_kernel(T *out, const T *in, const int64_t *pos_ids, + size_t seqlen, size_t nhead, size_t head_dim, + float theta) { + const size_t bid = blockIdx.x; + if (bid >= seqlen * nhead) { + return; + } + + const size_t seqlen_idx = bid / nhead; + const size_t head_id = bid % nhead; + + const size_t half = head_dim / 2; + const size_t offset = (seqlen_idx * nhead + head_id) * head_dim; + const float pos_val = to_float_t(pos_ids[seqlen_idx]); + + for (int j = threadIdx.x; j < half; j += blockDim.x) { + const float exponent = (2.0f * static_cast(j)) / static_cast(head_dim); + const float denom = powf(theta, exponent); + const float phi = pos_val / denom; + const float sinv = sinf(phi); + const float cosv = cosf(phi); + + const float a = to_float_t(in[offset + j]); + const float b = to_float_t(in[offset + j + half]); + + const float outa = a * cosv - b * sinv; + const float outb = b * cosv + a * sinv; + + out[offset + j] = from_float_t(outa); + out[offset + j + half] = from_float_t(outb); + } +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void rope(std::byte *out, const std::byte *in, const int64_t *pos_ids, + llaisysDataType_t type, size_t seqlen, size_t nhead, size_t head_dim, + float theta) { + if (seqlen == 0 || nhead == 0 || head_dim == 0) { + return; + } + + const 
size_t total_heads = seqlen * nhead; + constexpr int block_size = 256; + const int grid_size = static_cast(total_heads); + + switch (type) { + case LLAISYS_DTYPE_F32: + rope_kernel<<>>( + reinterpret_cast(out), reinterpret_cast(in), + pos_ids, seqlen, nhead, head_dim, theta); + break; + case LLAISYS_DTYPE_F16: + rope_kernel<<>>( + reinterpret_cast(out), reinterpret_cast(in), + pos_ids, seqlen, nhead, head_dim, theta); + break; + case LLAISYS_DTYPE_BF16: + rope_kernel<__nv_bfloat16> + <<>>(reinterpret_cast<__nv_bfloat16 *>(out), + reinterpret_cast(in), + pos_ids, seqlen, nhead, head_dim, theta); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + CUDA_CHECK(cudaGetLastError()); +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rope/nvidia/rope_nvidia.cuh b/src/ops/rope/nvidia/rope_nvidia.cuh new file mode 100644 index 000000000..1b1b1b9bc --- /dev/null +++ b/src/ops/rope/nvidia/rope_nvidia.cuh @@ -0,0 +1,12 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { + +void rope(std::byte *out, const std::byte *in, const int64_t *pos_ids, + llaisysDataType_t type, size_t seqlen, size_t nhead, size_t head_dim, + float theta); + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64e..eac2fc606 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,73 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/rope_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/rope_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/rope_metax.hpp" +#endif + namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + // 1. 
参数校验 + CHECK_SAME_DEVICE(out, in, pos_ids); + CHECK_SAME_SHAPE(out->shape(), in->shape()); + CHECK_SAME_DTYPE(out->dtype(), in->dtype()); + ASSERT(out->isContiguous() && in->isContiguous() && pos_ids->isContiguous(), + "RoPE: all tensors must be contiguous."); + + CHECK_ARGUMENT(out->ndim() == 3, "RoPE: out must be 3D [seqlen, nhead, d]."); + CHECK_ARGUMENT(pos_ids->ndim() == 1, "RoPE: pos_ids must be 1D [seqlen]."); + CHECK_ARGUMENT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "RoPE: pos_ids must be int64."); + CHECK_ARGUMENT(theta > 0.0f, "RoPE: theta must be positive."); + + const size_t seqlen = out->shape()[0]; + const size_t nhead = out->shape()[1]; + const size_t d = out->shape()[2]; + CHECK_ARGUMENT((d % 2) == 0, "RoPE: head_dim must be even."); + CHECK_ARGUMENT(pos_ids->shape()[0] == seqlen, "RoPE: pos_ids shape must match seqlen."); + + // 2. 设置设备上下文 + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), + in->data(), + reinterpret_cast(pos_ids->data()), + out->dtype(), + seqlen, + nhead, + d, + theta); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::rope(out->data(), + in->data(), + reinterpret_cast(pos_ids->data()), + out->dtype(), + seqlen, + nhead, + d, + theta); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::rope(out->data(), + in->data(), + reinterpret_cast(pos_ids->data()), + out->dtype(), + seqlen, + nhead, + d, + theta); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 000000000..07e1e633e --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,117 @@ +#include "self_attention_cpu.hpp" + +#include "../../../utils.hpp" + +#include +#include + +namespace { +constexpr float NEG_INF = -1e9f; 
+} + +template +static void self_attention_(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + size_t seqlen, + size_t nhead, + size_t nkvhead, + size_t d, + size_t dv, + size_t total_len, + float scale) { + const T *qT = reinterpret_cast(q); + const T *kT = reinterpret_cast(k); + const T *vT = reinterpret_cast(v); + T *outT = reinterpret_cast(attn_val); + + std::vector scores(seqlen * total_len); + + // 遍历层级:head(头)--->seqlen(序列长度) + for (size_t h = 0; h < nhead; ++h) { + const size_t kv_head = h * nkvhead / nhead; + + // 1. Scores: (seqlen, total_len), A[i,j] = scale * q[i,h,:] @ k[j,kv_head,:] + for (size_t i = 0; i < seqlen; ++i) { // 遍历每个query位置 + for (size_t j = 0; j < total_len; ++j) { // 遍历每个key位置 + float acc = 0.f; + for (size_t kd = 0; kd < d; ++kd) { + float qv = llaisys::utils::cast(qT[(i * nhead + h) * d + kd]); + float kv = llaisys::utils::cast(kT[(j * nkvhead + kv_head) * d + kd]); + acc += qv * kv; + } + scores[i * total_len + j] = scale * acc; + } + } + + // 2. Causal: mask (i,j) when j > i + (total_len - seqlen) + // 这是为了确保在推理时,模型只能看到当前位置之前的上下文,而不能看到未来的信息 + // total_len:kvcache的总长度 seqlen:当前序列的长度 + // diag = total_len - seqlen : 历史token的数量(也就是当前序列的起始位置) + // 置为-INF而不是0,因为exp(0) = 1,会导致softmax结果不正确 + const ptrdiff_t diag = static_cast(total_len) - static_cast(seqlen); + for (size_t i = 0; i < seqlen; ++i) { + for (size_t j = 0; j < total_len; ++j) { + // i:当前query在序列中的相对位置,j:当前key在KV Cache中的绝对位置 + if (static_cast(j) > static_cast(i) + diag) + scores[i * total_len + j] = NEG_INF; // mask掉未来的位置 + } + } + + // 3. 
对每个query位置,计算softmax:softmax(scores[i,:])->attn[i,:] + for (size_t i = 0; i < seqlen; ++i) { + float *row = &scores[i * total_len]; + float row_max = row[0]; + for (size_t j = 1; j < total_len; ++j) { + if (row[j] > row_max) + row_max = row[j]; + } + float sum = 0.f; + for (size_t j = 0; j < total_len; ++j) { + row[j] = std::exp(row[j] - row_max); + sum += row[j]; + } + for (size_t j = 0; j < total_len; ++j) + row[j] /= sum; + } + + // 4. 用注意力分数对V进行加权求和:attn_val[i,h,:](1 * dv) = attn[i,:] (1 * total_len) @ v[:,kv_head,:] (total_len * dv) + // scores[seqlen, total_len], v[total_len, nkvhead, dv], out[seqlen, nhead, dv] + for (size_t i = 0; i < seqlen; ++i) { + for (size_t m = 0; m < dv; ++m) { + float acc = 0.f; + for (size_t j = 0; j < total_len; ++j) { + acc += scores[i * total_len + j] * llaisys::utils::cast(vT[(j * nkvhead + kv_head) * dv + m]); + } + outT[(i * nhead + h) * dv + m] = llaisys::utils::cast(acc); + } + } + } +} + +namespace llaisys::ops::cpu { +void self_attention(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t dtype, + size_t seqlen, + size_t nhead, + size_t nkvhead, + size_t d, + size_t dv, + size_t total_len, + float scale) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return self_attention_(attn_val, q, k, v, seqlen, nhead, nkvhead, d, dv, total_len, scale); + case LLAISYS_DTYPE_F16: + return self_attention_(attn_val, q, k, v, seqlen, nhead, nkvhead, d, dv, total_len, scale); + case LLAISYS_DTYPE_BF16: + return self_attention_(attn_val, q, k, v, seqlen, nhead, nkvhead, d, dv, total_len, scale); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 000000000..b2a54b152 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include "llaisys.h" +#include + 
+namespace llaisys::ops::cpu { +void self_attention(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t dtype, + size_t seqlen, + size_t nhead, + size_t nkvhead, + size_t d, + size_t dv, + size_t total_len, + float scale); +} // namespace llaisys::ops::cpu diff --git a/src/ops/self_attention/metax/self_attention_metax.hpp b/src/ops/self_attention/metax/self_attention_metax.hpp new file mode 100644 index 000000000..a73ffae4b --- /dev/null +++ b/src/ops/self_attention/metax/self_attention_metax.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::metax { + +void self_attention(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t type, + size_t seqlen, + size_t nhead, + size_t nkvhead, + size_t d, + size_t dv, + size_t total_len, + float scale); + +} // namespace llaisys::ops::metax + diff --git a/src/ops/self_attention/metax/self_attention_metax.maca b/src/ops/self_attention/metax/self_attention_metax.maca new file mode 100644 index 000000000..62a103d08 --- /dev/null +++ b/src/ops/self_attention/metax/self_attention_metax.maca @@ -0,0 +1,228 @@ +#include "self_attention_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +#include +#include + +namespace { + +constexpr int METAX_WARP_SIZE = 64; + +template __device__ __forceinline__ float to_float_t(T v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ float to_float_t<__half>(__half v) { + return __half2float(v); +} + +template <> +__device__ __forceinline__ float +to_float_t<__maca_bfloat16>(__maca_bfloat16 v) { + return __bfloat162float(v); +} + +template __device__ __forceinline__ T from_float_t(float v) { + return static_cast(v); +} + +template <> __device__ __forceinline__ __half from_float_t<__half>(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ __maca_bfloat16 
+from_float_t<__maca_bfloat16>(float v) { + return __float2bfloat16(v); +} + +__device__ __forceinline__ float warp_sum(float val) { + constexpr maca_uint64_t full_mask = static_cast(~0ULL); +#pragma unroll + for (int offset = METAX_WARP_SIZE / 2; offset > 0; offset >>= 1) { + val += __shfl_down_sync(full_mask, val, offset, METAX_WARP_SIZE); + } + return val; +} + +template +__global__ void self_attention_online_kernel( + T *__restrict__ out, const T *__restrict__ q, const T *__restrict__ k, + const T *__restrict__ v, size_t seqlen, size_t nhead, size_t nkvhead, + size_t d, size_t dv, size_t total_len, float scale) { + const size_t block_id = static_cast(blockIdx.x); + if (block_id >= seqlen * nhead) { + return; + } + + const size_t qi = block_id / nhead; + const size_t qh = block_id % nhead; + const size_t kv_head = qh * nkvhead / nhead; + + const T *q_row = q + (qi * nhead + qh) * d; + T *out_row = out + (qi * nhead + qh) * dv; + + const ptrdiff_t diag + = static_cast(total_len) - static_cast(seqlen); + const ptrdiff_t max_visible_key = static_cast(qi) + diag; + if (max_visible_key < 0) { + for (size_t m = static_cast(threadIdx.x); m < dv; + m += BLOCK_SIZE) { + out_row[m] = from_float_t(0.0f); + } + return; + } + const size_t visible_len + = (static_cast(max_visible_key) + 1 < total_len) + ? static_cast(max_visible_key) + 1 + : total_len; + + // Dynamic shared memory layout: [q_cache(d), score(1)]. 
+ extern __shared__ float smem[]; + float *q_cache = smem; + float *score_ptr = q_cache + d; + + for (size_t kd = static_cast(threadIdx.x); kd < d; + kd += BLOCK_SIZE) { + q_cache[kd] = to_float_t(q_row[kd]); + } + __syncthreads(); + + int local_idx[MAX_LOCAL_OUT]; + double local_acc[MAX_LOCAL_OUT]; + int local_n = 0; + for (size_t m = static_cast(threadIdx.x); + m < dv && local_n < MAX_LOCAL_OUT; m += BLOCK_SIZE) { + local_idx[local_n] = static_cast(m); + local_acc[local_n] = 0.0; + ++local_n; + } + + double row_m = -INFINITY; + double row_l = 0.0; + + for (size_t j = 0; j < visible_len; ++j) { + if (threadIdx.x < METAX_WARP_SIZE) { + const T *k_row = k + (j * nkvhead + kv_head) * d; + float dot = 0.0f; + for (size_t kd = static_cast(threadIdx.x); kd < d; + kd += METAX_WARP_SIZE) { + dot += q_cache[kd] * to_float_t(k_row[kd]); + } + dot = warp_sum(dot); + if (threadIdx.x == 0) { + *score_ptr = dot * scale; + } + } + __syncthreads(); + + const double score = static_cast(*score_ptr); + const double m_new = fmax(row_m, score); + const double alpha = (row_l == 0.0) ? 0.0 : exp(row_m - m_new); + const double beta = exp(score - m_new); + const double l_new = row_l * alpha + beta; + + const T *v_row = v + (j * nkvhead + kv_head) * dv; +#pragma unroll + for (int t = 0; t < MAX_LOCAL_OUT; ++t) { + if (t < local_n) { + local_acc[t] = local_acc[t] * alpha + + beta * static_cast(to_float_t(v_row[local_idx[t]])); + } + } + row_m = m_new; + row_l = l_new; + __syncthreads(); + } + + const double inv_l = (row_l > 0.0) ? (1.0 / row_l) : 0.0; +#pragma unroll + for (int t = 0; t < MAX_LOCAL_OUT; ++t) { + if (t < local_n) { + out_row[local_idx[t]] + = from_float_t(static_cast(local_acc[t] * inv_l)); + } + } + + // Rare fallback for very large dv. 
+ for (size_t m = static_cast(threadIdx.x) + + static_cast(BLOCK_SIZE * MAX_LOCAL_OUT); + m < dv; m += BLOCK_SIZE) { + double acc = 0.0; + for (size_t j = 0; j < visible_len; ++j) { + const T *k_row = k + (j * nkvhead + kv_head) * d; + float dot = 0.0f; + for (size_t kd = 0; kd < d; ++kd) { + dot += q_cache[kd] * to_float_t(k_row[kd]); + } + const double prob + = (row_l > 0.0) + ? exp(static_cast(dot) * static_cast(scale) + - row_m) + * inv_l + : 0.0; + acc += prob + * static_cast( + to_float_t(v[(j * nkvhead + kv_head) * dv + m])); + } + out_row[m] = from_float_t(static_cast(acc)); + } +} + +} // namespace + +namespace llaisys::ops::metax { + +void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, + const std::byte *v, llaisysDataType_t type, size_t seqlen, + size_t nhead, size_t nkvhead, size_t d, size_t dv, + size_t total_len, float scale) { + if (seqlen == 0 || nhead == 0 || nkvhead == 0 || d == 0 || dv == 0 + || total_len == 0) { + return; + } + + const int grid_size = static_cast(seqlen * nhead); + constexpr int block_size = 128; + constexpr int max_local_out = 8; + const size_t smem_bytes = sizeof(float) * (d + 1); + + switch (type) { + case LLAISYS_DTYPE_F32: + self_attention_online_kernel + <<>>( + reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), seqlen, nhead, nkvhead, d, + dv, total_len, scale); + break; + case LLAISYS_DTYPE_F16: + self_attention_online_kernel<__half, block_size, max_local_out> + <<>>( + reinterpret_cast<__half *>(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), seqlen, nhead, nkvhead, d, + dv, total_len, scale); + break; + case LLAISYS_DTYPE_BF16: + self_attention_online_kernel<__maca_bfloat16, block_size, max_local_out> + <<>>( + reinterpret_cast<__maca_bfloat16 *>(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), seqlen, nhead, + nkvhead, d, dv, total_len, scale); + break; + default: + 
EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/self_attention/nvidia/self_attention_nvidia.cu b/src/ops/self_attention/nvidia/self_attention_nvidia.cu new file mode 100644 index 000000000..34db608c0 --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.cu @@ -0,0 +1,190 @@ +#include "self_attention_nvidia.cuh" + +#include "../../../utils.hpp" +#include "../../../utils/gpu_utils.hpp" + +#include + +namespace { + +__device__ __forceinline__ float warp_sum(float val) { + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return val; +} + +template +__global__ void self_attention_online_kernel(T *__restrict__ out, + const T *__restrict__ q, + const T *__restrict__ k, + const T *__restrict__ v, + size_t seqlen, + size_t nhead, + size_t nkvhead, + size_t d, + size_t dv, + size_t total_len, + float scale) { + const size_t block_id = static_cast(blockIdx.x); + if (block_id >= seqlen * nhead) { + return; + } + + const size_t qi = block_id / nhead; + const size_t qh = block_id % nhead; + const size_t kv_head = qh * nkvhead / nhead; + + const T *q_row = q + (qi * nhead + qh) * d; + T *out_row = out + (qi * nhead + qh) * dv; + + const ptrdiff_t diag = static_cast(total_len) - static_cast(seqlen); + const ptrdiff_t max_visible_key = static_cast(qi) + diag; + if (max_visible_key < 0) { + for (size_t m = static_cast(threadIdx.x); m < dv; m += BLOCK_SIZE) { + out_row[m] = from_float(0.0f); + } + return; + } + const size_t visible_len = (static_cast(max_visible_key) + 1 < total_len) + ? 
static_cast(max_visible_key) + 1 + : total_len; + + // Dynamic shared memory layout: [q_cache(d), score(1)] + extern __shared__ float smem[]; + float *q_cache = smem; + float *score_ptr = q_cache + d; + + for (size_t kd = static_cast(threadIdx.x); kd < d; kd += BLOCK_SIZE) { + q_cache[kd] = to_float(q_row[kd]); + } + __syncthreads(); + + int local_idx[MAX_LOCAL_OUT]; + float local_acc[MAX_LOCAL_OUT]; + int local_n = 0; + for (size_t m = static_cast(threadIdx.x); m < dv && local_n < MAX_LOCAL_OUT; m += BLOCK_SIZE) { + local_idx[local_n] = static_cast(m); + local_acc[local_n] = 0.0f; + ++local_n; + } + + float row_m = -INFINITY; + float row_l = 0.0f; + + for (size_t j = 0; j < visible_len; ++j) { + if (threadIdx.x < 32) { + const T *k_row = k + (j * nkvhead + kv_head) * d; + float dot = 0.0f; + for (size_t kd = static_cast(threadIdx.x); kd < d; kd += 32) { + dot += q_cache[kd] * to_float(k_row[kd]); + } + dot = warp_sum(dot); + if (threadIdx.x == 0) { + *score_ptr = dot * scale; + } + } + __syncthreads(); + + const float score = *score_ptr; + const float m_new = fmaxf(row_m, score); + const float alpha = (row_l == 0.0f) ? 0.0f : expf(row_m - m_new); + const float beta = expf(score - m_new); + const float l_new = row_l * alpha + beta; + + const T *v_row = v + (j * nkvhead + kv_head) * dv; + #pragma unroll + for (int t = 0; t < MAX_LOCAL_OUT; ++t) { + if (t < local_n) { + local_acc[t] = local_acc[t] * alpha + beta * to_float(v_row[local_idx[t]]); + } + } + row_m = m_new; + row_l = l_new; + __syncthreads(); + } + + const float inv_l = (row_l > 0.0f) ? (1.0f / row_l) : 0.0f; + #pragma unroll + for (int t = 0; t < MAX_LOCAL_OUT; ++t) { + if (t < local_n) { + out_row[local_idx[t]] = from_float(local_acc[t] * inv_l); + } + } + + // Rare fallback for very large dv. 
+ for (size_t m = static_cast(threadIdx.x) + static_cast(BLOCK_SIZE * MAX_LOCAL_OUT); m < dv; + m += BLOCK_SIZE) { + float acc = 0.0f; + for (size_t j = 0; j < visible_len; ++j) { + const T *k_row = k + (j * nkvhead + kv_head) * d; + float dot = 0.0f; + for (size_t kd = 0; kd < d; ++kd) { + dot += q_cache[kd] * to_float(k_row[kd]); + } + const float prob = (row_l > 0.0f) ? expf(dot * scale - row_m) * inv_l : 0.0f; + acc += prob * to_float(v[(j * nkvhead + kv_head) * dv + m]); + } + out_row[m] = from_float(acc); + } +} + +} // namespace + +namespace llaisys::ops::nvidia { + +void self_attention(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t type, + size_t seqlen, + size_t nhead, + size_t nkvhead, + size_t d, + size_t dv, + size_t total_len, + float scale) { + if (seqlen == 0 || nhead == 0 || nkvhead == 0 || d == 0 || dv == 0 || total_len == 0) { + return; + } + + const int grid_size = static_cast(seqlen * nhead); + constexpr int block_size = 128; + constexpr int max_local_out = 8; + const size_t smem_bytes = sizeof(float) * (d + 1); + + switch (type) { + case LLAISYS_DTYPE_F32: + self_attention_online_kernel<<>>( + reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, nhead, nkvhead, d, dv, total_len, scale); + break; + case LLAISYS_DTYPE_F16: + self_attention_online_kernel<<>>( + reinterpret_cast(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, nhead, nkvhead, d, dv, total_len, scale); + break; + case LLAISYS_DTYPE_BF16: + self_attention_online_kernel<__nv_bfloat16, block_size, max_local_out><<>>( + reinterpret_cast<__nv_bfloat16 *>(attn_val), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, nhead, nkvhead, d, dv, total_len, scale); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + CUDA_CHECK(cudaGetLastError()); +} + +} // namespace llaisys::ops::nvidia diff --git 
a/src/ops/self_attention/nvidia/self_attention_nvidia.cuh b/src/ops/self_attention/nvidia/self_attention_nvidia.cuh new file mode 100644 index 000000000..d088f5877 --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nvidia.cuh @@ -0,0 +1,22 @@ +#pragma once + +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { + +void self_attention(std::byte *attn_val, + const std::byte *q, + const std::byte *k, + const std::byte *v, + llaisysDataType_t type, + size_t seqlen, + size_t nhead, + size_t nkvhead, + size_t d, + size_t dv, + size_t total_len, + float scale); + +} // namespace llaisys::ops::nvidia + diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d620142..aba6204a0 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,81 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/self_attention_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/self_attention_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/self_attention_metax.hpp" +#endif + +// Q: [seqlen, nhead, d], K: [total_len, nkvhead, d], V: [total_len, nkvhead, dv], attn_val: [seqlen, nhead, dv] namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + // 1. 
参数校验 + CHECK_SAME_DEVICE(attn_val, q, k, v); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype(), k->dtype(), v->dtype()); + CHECK_ARGUMENT(attn_val->ndim() == 3 && q->ndim() == 3 && k->ndim() == 3 && v->ndim() == 3, + "self_attention: all tensors must be 3D"); + CHECK_ARGUMENT(attn_val->shape()[0] == q->shape()[0], "self_attention: seqlen of attn_val and q must match"); + CHECK_ARGUMENT(attn_val->shape()[1] == q->shape()[1], "self_attention: nhead of attn_val and q must match"); + CHECK_ARGUMENT(q->shape()[2] == k->shape()[2], "self_attention: d of q and k must match"); + CHECK_ARGUMENT(attn_val->shape()[2] == v->shape()[2], "self_attention: dv of attn_val and v must match"); + CHECK_ARGUMENT(k->shape()[0] == v->shape()[0] && k->shape()[1] == v->shape()[1], + "self_attention: total_len and nkvhead of k and v must match"); + CHECK_ARGUMENT((q->shape()[1] % k->shape()[1]) == 0, "self_attention: nhead must be divisible by nkvhead (GQA)"); + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), + "self_attention: all tensors must be contiguous"); + + const size_t seqlen = q->shape()[0]; + const size_t nhead = q->shape()[1]; + const size_t d = q->shape()[2]; + const size_t total_len = k->shape()[0]; + const size_t nkvhead = k->shape()[1]; + const size_t dv = v->shape()[2]; + + // 2. 设置设备上下文 + llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId()); + + // 3. 
设备分发 + switch (attn_val->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), + attn_val->dtype(), seqlen, nhead, nkvhead, d, dv, total_len, scale); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::self_attention(attn_val->data(), + q->data(), + k->data(), + v->data(), + attn_val->dtype(), + seqlen, + nhead, + nkvhead, + d, + dv, + total_len, + scale); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::self_attention(attn_val->data(), + q->data(), + k->data(), + v->data(), + attn_val->dtype(), + seqlen, + nhead, + nkvhead, + d, + dv, + total_len, + scale); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 000000000..762564c95 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,30 @@ +#include "swiglu_cpu.hpp" +#include "../../../utils.hpp" + +#include + +template +void swiglu_(T *out, const T *gate, const T *up, size_t numel) { + for (size_t i = 0; i < numel; i++) { + float gate_val = llaisys::utils::cast(gate[i]); + float up_val = llaisys::utils::cast(up[i]); + float res = up_val * gate_val / (1 + std::exp(-gate_val)); + out[i] = llaisys::utils::cast(res); + } +} + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel) { + switch (type) { + case LLAISYS_DTYPE_F16: + return swiglu_(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + case LLAISYS_DTYPE_BF16: + return swiglu_(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + case LLAISYS_DTYPE_F32: + return swiglu_(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} // namespace llaisys::ops::cpu \ No newline at end 
of file diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 000000000..c2945473a --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel); +} \ No newline at end of file diff --git a/src/ops/swiglu/metax/swiglu_metax.hpp b/src/ops/swiglu/metax/swiglu_metax.hpp new file mode 100644 index 000000000..b4da8d950 --- /dev/null +++ b/src/ops/swiglu/metax/swiglu_metax.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "llaisys.h" + +#include + +namespace llaisys::ops::metax { + +void swiglu(std::byte *out, + const std::byte *gate, + const std::byte *up, + llaisysDataType_t type, + size_t numel); + +} // namespace llaisys::ops::metax + diff --git a/src/ops/swiglu/metax/swiglu_metax.maca b/src/ops/swiglu/metax/swiglu_metax.maca new file mode 100644 index 000000000..e5e3baf89 --- /dev/null +++ b/src/ops/swiglu/metax/swiglu_metax.maca @@ -0,0 +1,97 @@ +#include "swiglu_metax.hpp" + +#include "../../../utils.hpp" + +#include +#include +#include + +#include + +namespace { + +template +__device__ __forceinline__ float to_float_t(T v) { + return static_cast(v); +} + +template <> +__device__ __forceinline__ float to_float_t<__half>(__half v) { + return __half2float(v); +} + +template <> +__device__ __forceinline__ float to_float_t<__maca_bfloat16>(__maca_bfloat16 v) { + return __bfloat162float(v); +} + +template +__device__ __forceinline__ T from_float_t(float v) { + return static_cast(v); +} + +template <> +__device__ __forceinline__ __half from_float_t<__half>(float v) { + return __float2half(v); +} + +template <> +__device__ __forceinline__ __maca_bfloat16 from_float_t<__maca_bfloat16>(float v) { + return __float2bfloat16(v); +} + +template +__global__ void swiglu_kernel(T *out, const T *gate, const T *up, size_t numel) { + 
const size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= numel) { + return; + } + + const float gate_val = to_float_t(gate[idx]); + const float up_val = to_float_t(up[idx]); + const float exp_gate = ::expf(-gate_val); + const float out_val = up_val * gate_val / (1.0f + exp_gate); + out[idx] = from_float_t(out_val); +} + +} // namespace + +namespace llaisys::ops::metax { + +void swiglu(std::byte *out, + const std::byte *gate, + const std::byte *up, + llaisysDataType_t type, + size_t numel) { + constexpr int block_size = 512; + const int grid_size = static_cast((numel + static_cast(block_size) - 1) / + static_cast(block_size)); + + switch (type) { + case LLAISYS_DTYPE_F32: + swiglu_kernel<<>>( + reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + case LLAISYS_DTYPE_F16: + swiglu_kernel<__half><<>>( + reinterpret_cast<__half *>(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + case LLAISYS_DTYPE_BF16: + swiglu_kernel<__maca_bfloat16><<>>( + reinterpret_cast<__maca_bfloat16 *>(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} + +} // namespace llaisys::ops::metax diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.cu b/src/ops/swiglu/nvidia/swiglu_nvidia.cu new file mode 100644 index 000000000..31f20f467 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.cu @@ -0,0 +1,46 @@ +#include "swiglu_nvidia.cuh" + +#include "../../../utils.hpp" +#include "../../../utils/gpu_utils.hpp" + +namespace { + +template +__global__ void swiglu_kernel(T *out, const T *gate, const T *up, size_t numel) { + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= numel) { + return; + } + + float gate_val = to_float(gate[idx]); + float up_val = to_float(up[idx]); + float exp_gate = ::expf(-gate_val); + float out_val = up_val * gate_val / (1 + exp_gate); + out[idx] = from_float(out_val); +} + +} // 
namespace + +namespace llaisys::ops::nvidia { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel) { + constexpr int block_size = 256; + const int grid_size = CEIL(numel, block_size); + + switch (type) { + case LLAISYS_DTYPE_F32: + swiglu_kernel<<>>(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + break; + case LLAISYS_DTYPE_F16: + swiglu_kernel<<>>(reinterpret_cast(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + break; + case LLAISYS_DTYPE_BF16: + swiglu_kernel<__nv_bfloat16><<>>(reinterpret_cast<__nv_bfloat16 *>(out), reinterpret_cast(gate), reinterpret_cast(up), numel); + break; + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } + + CUDA_CHECK(cudaGetLastError()); +} +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/nvidia/swiglu_nvidia.cuh b/src/ops/swiglu/nvidia/swiglu_nvidia.cuh new file mode 100644 index 000000000..1224b3b4a --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nvidia.cuh @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::nvidia { +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t type, size_t numel); +} \ No newline at end of file diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc97..43c04838f 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,43 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" + +#include "cpu/swiglu_cpu.hpp" +#ifdef ENABLE_NVIDIA_API +#include "nvidia/swiglu_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/swiglu_metax.hpp" +#endif + namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + // 1. 
参数校验 + CHECK_SAME_DEVICE(out, gate, up); + CHECK_SAME_SHAPE(out->shape(), gate->shape(), up->shape()); + CHECK_SAME_DTYPE(out->dtype(), gate->dtype(), up->dtype()); + CHECK_ARGUMENT(out->ndim() == 2, "out must be a 2D tensor"); + ASSERT(out->isContiguous() && gate->isContiguous() && up->isContiguous(), "out, gate and up must be contiguous"); + + // 2. 设置设备上下文 + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + const size_t numel = out->numel(); + + // 3. 设备分发 + switch(out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return nvidia::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#endif +#ifdef ENABLE_METAX_API + case LLAISYS_DEVICE_METAX: + return metax::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb65..068e7eab4 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -2,9 +2,14 @@ #include "../utils.hpp" +#include #include +#include #include #include +#include + +#include namespace llaisys { @@ -26,6 +31,7 @@ tensor_t Tensor::create(const std::vector &shape, size_t total_elems = stride; size_t dtype_size = utils::dsize(dtype); + // Fast path for host tensors when the active runtime is non-CPU. 
if (device_type == LLAISYS_DEVICE_CPU && core::context().runtime().deviceType() != LLAISYS_DEVICE_CPU) { auto storage = core::context().runtime().allocateHostStorage(total_elems * dtype_size); return std::shared_ptr(new Tensor(meta, storage)); @@ -164,27 +170,95 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + const auto &tensor_shape = shape(); + const auto &tensor_strides = strides(); + const size_t &tensor_ndim = ndim(); + + if (tensor_ndim == 0 || tensor_ndim == 1) { + return true; + } + + if (tensor_ndim == 1) { + return tensor_strides[0] == 1; + } + ptrdiff_t expected_stride = 1; + + for (ptrdiff_t i = static_cast(tensor_ndim) - 1; i >= 0; i--) { + if (tensor_strides[i] != expected_stride) { + return false; + } + expected_stride *= tensor_shape[i]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + CHECK_ARGUMENT(order.size() == ndim(), "order size != tensor ndim"); + + std::vector used(ndim(), false); + for (auto index : order) { + CHECK_ARGUMENT(index < ndim(), "order index out of dim range"); + CHECK_ARGUMENT(!used[index], "index repition"); + used[index] = true; + } + + llaisys::TensorMeta new_meta = _meta; + for (size_t i = 0; i < order.size(); ++i) { + new_meta.shape[i] = _meta.shape[order[i]]; + new_meta.strides[i] = _meta.strides[order[i]]; + } + + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); } +// View reshapes metadata only and requires a contiguous tensor. 
tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t new_numel = 1; + for (auto num : shape) { + new_numel *= num; + } + CHECK_ARGUMENT(new_numel == numel(), "view size match"); + + if (isContiguous()) { + TensorMeta new_meta = _meta; + new_meta.shape = shape; + + new_meta.strides.resize(shape.size()); + ptrdiff_t stride = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; i--) { + new_meta.strides[i] = stride; + stride *= static_cast(shape[i]); + } + + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); + } + + return nullptr; } +// Slice shares storage and only adjusts shape and offset. tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + CHECK_ARGUMENT(dim < ndim(), "dim out of range"); + CHECK_ARGUMENT(start < end, "start must less than end"); + CHECK_ARGUMENT(end <= shape()[dim], "end out of range"); + + llaisys::TensorMeta new_meta = _meta; + new_meta.shape[dim] = end - start; + + size_t new_offset = _offset + start * strides()[dim] * elementSize(); + + return std::shared_ptr(new Tensor(new_meta, _storage, new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + core::context().setDevice(this->deviceType(), this->deviceId()); + + const LlaisysRuntimeAPI *api = core::context().runtime().api(); + + size_t size_bytes = this->numel() * this->elementSize(); + + // Copy host data into the tensor storage. 
+ api->memcpy_sync(this->data(), src_, size_bytes, LLAISYS_MEMCPY_H2D); } tensor_t Tensor::contiguous() const { diff --git a/src/tensor/tensor.hpp b/src/tensor/tensor.hpp index 35e340922..7e147a944 100644 --- a/src/tensor/tensor.hpp +++ b/src/tensor/tensor.hpp @@ -9,14 +9,16 @@ using tensor_t = std::shared_ptr; struct TensorMeta { llaisysDataType_t dtype; std::vector shape; - std::vector strides; + std::vector strides; // 以元素为单位,计算每个维度上元素的偏移量 }; +// 逻辑上组织张量:shape、strides、offset +// 物理上组织张量:storage class Tensor { private: TensorMeta _meta; core::storage_t _storage; - size_t _offset; + size_t _offset; //以字节为单位,记录该张量在storage中的起始位置(一个storage存储不同的张量) Tensor(TensorMeta meta, core::storage_t storage, size_t offset = 0); public: diff --git a/src/utils.hpp b/src/utils.hpp index f038edfb6..d0bbcf603 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -1,3 +1,3 @@ #pragma once #include "utils/check.hpp" -#include "utils/types.hpp" +#include "utils/types.hpp" \ No newline at end of file diff --git a/src/utils/gpu_utils.hpp b/src/utils/gpu_utils.hpp new file mode 100644 index 000000000..952bb164e --- /dev/null +++ b/src/utils/gpu_utils.hpp @@ -0,0 +1,51 @@ +#if defined(ENABLE_NVIDIA_API) + +#include + +#include +#include + +#define LOAD_FLOAT4(value) *(reinterpret_cast(&value)) +#define STORE_FLOAT4(value) *(reinterpret_cast(&value)) +#define LOAD_HALF2(value) *(reinterpret_cast(&value)) +#define STORE_HALF2(value) *(reinterpret_cast(&(value))) +#define LOAD_BFLOAT2(value) *(reinterpret_cast(&value)) +#define STORE_BFLOAT2(value) *(reinterpret_cast<__nv_bfloat162*>(&value)) + +#define CEIL(x, y) ((x + y - 1) / y) + +#define CUDA_CHECK(err) _cudaCheck(err, __FILE__, __LINE__) +inline void _cudaCheck(cudaError_t err, const char* file, int line) { + if (err != cudaSuccess) { + std::cerr << "[CUDA Error] " << cudaGetErrorString(err) << " at " << file << ":" << line << std::endl; + throw std::runtime_error(cudaGetErrorString(err)); + } +} + +template +__device__ __forceinline__ float 
to_float(T v) { + return static_cast(v); +} +template <> +__device__ __forceinline__ float to_float(half v) { + return __half2float(v); +} +template <> +__device__ __forceinline__ float to_float(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template +__device__ __forceinline__ T from_float(float v) { + return static_cast(v); +} +template <> +__device__ __forceinline__ half from_float(float v) { + return __float2half(v); +} +template <> +__device__ __forceinline__ __nv_bfloat16 from_float<__nv_bfloat16>(float v) { + return __float2bfloat16(v); +} + +#endif // ENABLE_NVIDIA_API \ No newline at end of file diff --git a/test/benchmark_infer.py b/test/benchmark_infer.py new file mode 100644 index 000000000..f4d9fc179 --- /dev/null +++ b/test/benchmark_infer.py @@ -0,0 +1,238 @@ +import argparse +import gc +import io +import logging +import os +import statistics +import sys +import time + +import llaisys +import torch +from huggingface_hub import snapshot_download +from transformers import AutoModelForCausalLM, AutoTokenizer + +from test_utils import llaisys_device, torch_device + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + +PROMPTS = { + "short": "Who are you?", + "medium": ( + "Explain the role of KV cache in transformer decoding, and give a short " + "step-by-step example with one prompt token and two generated tokens." + ), + "long": ( + "I am building a tiny LLM inference system from scratch. Please provide a " + "concise engineering checklist that covers model loading, tensor layout, " + "runtime abstraction, memory reuse, operator profiling, and end-to-end " + "benchmarking. Keep the answer practical and implementation-oriented." 
+ ), +} + +logging.getLogger("transformers.dynamic_module_utils").setLevel(logging.ERROR) + + +def is_gpu_device(device_name): + return device_name in {"nvidia", "metax"} + + +def parse_csv(text, caster=str): + return [caster(x.strip()) for x in text.split(",") if x.strip()] + + +def load_hf_model(model_path, device_name): + model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" + if model_path and os.path.isdir(model_path): + model_path = os.path.expanduser(model_path) + print(f"Loading model from local path: {model_path}") + else: + print(f"Loading model from Hugging Face: {model_id}") + model_path = snapshot_download(model_id) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + kwargs = {"device_map": torch_device(device_name), "trust_remote_code": True} + try: + model = AutoModelForCausalLM.from_pretrained(model_path, dtype=torch.bfloat16, **kwargs) + except TypeError: + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, **kwargs) + return tokenizer, model, model_path + + +def load_llaisys_model(model_path, device_name): + return llaisys.models.Qwen2(model_path, llaisys_device(device_name)) + + +def sync_torch(device_name): + if is_gpu_device(device_name): + torch.cuda.synchronize() + + +def sync_llaisys(device_name): + llaisys.RuntimeAPI(llaisys_device(device_name)).device_synchronize() + + +def build_input_ids(tokenizer, prompt): + text = tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + return tokenizer.encode(text) + + +def run_torch_case(tokenizer, model, input_ids, max_new_tokens, top_k, top_p, temperature, device_name): + inputs = torch.tensor(input_ids, dtype=torch.int64, device=model.device).unsqueeze(0) + attention_mask = torch.ones_like(inputs) + + sync_torch(device_name) + start = time.perf_counter() + with torch.no_grad(): + outputs = model.generate( + inputs, + attention_mask=attention_mask, + 
max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + pad_token_id=tokenizer.eos_token_id, + ) + sync_torch(device_name) + out_tokens = outputs[0].tolist() + return time.perf_counter() - start, len(out_tokens) - len(input_ids), out_tokens + + +def run_llaisys_case(model, input_ids, max_new_tokens, top_k, top_p, temperature, device_name): + sync_llaisys(device_name) + start = time.perf_counter() + out_tokens = model.generate( + input_ids, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + sync_llaisys(device_name) + return time.perf_counter() - start, len(out_tokens) - len(input_ids), out_tokens + + +def benchmark_backend(backend, tokenizer, model, cases, warmup, repeat, top_k, top_p, temperature, device_name): + rows = {} + for case in cases: + for _ in range(warmup): + if backend == "torch": + run_torch_case(tokenizer, model, case["input_ids"], case["max_new_tokens"], top_k, top_p, temperature, device_name) + else: + run_llaisys_case(model, case["input_ids"], case["max_new_tokens"], top_k, top_p, temperature, device_name) + + latencies = [] + generated = [] + for _ in range(repeat): + if backend == "torch": + elapsed, new_tokens, _ = run_torch_case( + tokenizer, model, case["input_ids"], case["max_new_tokens"], top_k, top_p, temperature, device_name + ) + else: + elapsed, new_tokens, _ = run_llaisys_case( + model, case["input_ids"], case["max_new_tokens"], top_k, top_p, temperature, device_name + ) + latencies.append(elapsed) + generated.append(new_tokens) + + mean_s = statistics.mean(latencies) + rows[(case["prompt_name"], case["max_new_tokens"])] = { + "mean_ms": mean_s * 1000.0, + "mean_new_tokens": statistics.mean(generated), + "tokens_per_sec": statistics.mean(generated) / mean_s if mean_s > 0 else 0.0, + } + return rows + + +def print_report(cases, torch_rows, llaisys_rows): + print("\n=== Torch vs LLAISYS Inference Benchmark ===") + print("| Case | Torch mean(ms) | Torch tok/s 
| LLAISYS mean(ms) | LLAISYS tok/s | speedup |") + print("|---|---:|---:|---:|---:|---:|") + + torch_total_tokens = 0.0 + llaisys_total_tokens = 0.0 + torch_total_seconds = 0.0 + llaisys_total_seconds = 0.0 + + for case in cases: + key = (case["prompt_name"], case["max_new_tokens"]) + torch_row = torch_rows[key] + llaisys_row = llaisys_rows[key] + speedup = torch_row["mean_ms"] / llaisys_row["mean_ms"] if llaisys_row["mean_ms"] > 0 else 0.0 + + print( + f"| {case['prompt_name']}/{case['max_new_tokens']} | {torch_row['mean_ms']:.2f} | {torch_row['tokens_per_sec']:.2f} | " + f"{llaisys_row['mean_ms']:.2f} | {llaisys_row['tokens_per_sec']:.2f} | {speedup:.2f}x |" + ) + + torch_total_tokens += torch_row["mean_new_tokens"] + llaisys_total_tokens += llaisys_row["mean_new_tokens"] + torch_total_seconds += torch_row["mean_ms"] / 1000.0 + llaisys_total_seconds += llaisys_row["mean_ms"] / 1000.0 + + torch_total_tok_s = torch_total_tokens / torch_total_seconds if torch_total_seconds > 0 else 0.0 + llaisys_total_tok_s = llaisys_total_tokens / llaisys_total_seconds if llaisys_total_seconds > 0 else 0.0 + overall_speedup = llaisys_total_tok_s / torch_total_tok_s if torch_total_tok_s > 0 else 0.0 + + print("\n=== Throughput Summary ===") + print(f"Torch total throughput : {torch_total_tok_s:.2f} tok/s") + print(f"LLAISYS total throughput : {llaisys_total_tok_s:.2f} tok/s") + print(f"Overall speedup : {overall_speedup:.2f}x") + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Torch vs LLAISYS inference throughput.") + parser.add_argument("--model", required=True, type=str) + parser.add_argument("--device", default="nvidia", choices=["cpu", "nvidia", "metax"], type=str) + parser.add_argument("--prompts", default="short,medium,long", type=str) + parser.add_argument("--max-new-tokens", default="32,64,128", type=str) + parser.add_argument("--warmup", default=2, type=int) + parser.add_argument("--repeat", default=3, type=int) + parser.add_argument("--top-k", 
default=1, type=int) + parser.add_argument("--top-p", default=1.0, type=float) + parser.add_argument("--temperature", default=1.0, type=float) + args = parser.parse_args() + + top_k, top_p, temperature = args.top_k, args.top_p, args.temperature + + prompt_names = parse_csv(args.prompts) + max_new_tokens_list = parse_csv(args.max_new_tokens, int) + for name in prompt_names: + if name not in PROMPTS: + raise ValueError(f"Unknown prompt preset: {name}. Valid keys: {list(PROMPTS.keys())}") + + tokenizer, torch_model, model_path = load_hf_model(args.model, args.device) + cases = [ + { + "prompt_name": prompt_name, + "max_new_tokens": max_new_tokens, + "input_ids": build_input_ids(tokenizer, PROMPTS[prompt_name]), + } + for prompt_name in prompt_names + for max_new_tokens in max_new_tokens_list + ] + + torch_rows = benchmark_backend( + "torch", tokenizer, torch_model, cases, args.warmup, args.repeat, top_k, top_p, temperature, args.device + ) + + del torch_model + gc.collect() + if is_gpu_device(args.device): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + llaisys_model = load_llaisys_model(model_path, args.device) + llaisys_rows = benchmark_backend( + "llaisys", tokenizer, llaisys_model, cases, args.warmup, args.repeat, top_k, top_p, temperature, args.device + ) + + print_report(cases, torch_rows, llaisys_rows) + + +if __name__ == "__main__": + main() diff --git a/test/chat_cli.py b/test/chat_cli.py new file mode 100644 index 000000000..8d1db85b3 --- /dev/null +++ b/test/chat_cli.py @@ -0,0 +1,158 @@ +import argparse +import json +import sys +import urllib.error +import urllib.request +from typing import Any, Dict, List + + +def post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]: + data = json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req) as resp: + raw = resp.read().decode("utf-8") + return json.loads(raw) + 
+ +def stream_sse(url: str, payload: Dict[str, Any]): + data = json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={ + "Content-Type": "application/json", + "Accept": "text/event-stream", + }, + method="POST", + ) + with urllib.request.urlopen(req) as resp: + for raw_line in resp: + line = raw_line.decode("utf-8").strip() + if not line.startswith("data: "): + continue + data_part = line[6:] + if data_part == "[DONE]": + break + yield json.loads(data_part) + + +def request_assistant_reply( + url: str, + model_name: str, + messages: List[Dict[str, str]], + max_tokens: int, + top_k: int, + top_p: float, + temperature: float, + stream: bool, +) -> str: + payload = { + "model": model_name, + "messages": messages, + "max_tokens": max_tokens, + "top_k": top_k, + "top_p": top_p, + "temperature": temperature, + "stream": stream, + } + + if not stream: + obj = post_json(url, payload) + return obj["choices"][0]["message"]["content"] + + pieces: List[str] = [] + for chunk in stream_sse(url, payload): + choices = chunk.get("choices", []) + if not choices: + continue + delta = choices[0].get("delta", {}) + text = delta.get("content", "") + if text: + pieces.append(text) + sys.stdout.write(text) + sys.stdout.flush() + sys.stdout.write("\n") + return "".join(pieces) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Interactive CLI for LLAISYS chat server") + parser.add_argument("--url", default="http://127.0.0.1:8000/v1/chat/completions", type=str) + parser.add_argument("--model", default="llaisys-qwen2", type=str) + parser.add_argument("--system", default="", type=str) + parser.add_argument("--max-tokens", default=256, type=int) + parser.add_argument("--top-k", default=1, type=int) + parser.add_argument("--top-p", default=1.0, type=float) + parser.add_argument("--temperature", default=1.0, type=float) + parser.add_argument("--stream", action="store_true") + return parser.parse_args() + + +def main(): + args = 
parse_args() + history: List[Dict[str, str]] = [] + if args.system: + history.append({"role": "system", "content": args.system}) + + print("Interactive chat started.") + print("Commands: /reset clears history, /exit quits.") + + while True: + try: + user_text = input("\nYou: ").strip() + except (EOFError, KeyboardInterrupt): + print("\nBye.") + return + + if not user_text: + continue + if user_text in {"/exit", "/quit"}: + print("Bye.") + return + if user_text == "/reset": + history = [] + if args.system: + history.append({"role": "system", "content": args.system}) + print("History cleared.") + continue + + history.append({"role": "user", "content": user_text}) + try: + if not args.stream: + print("Assistant: ", end="") + reply = request_assistant_reply( + url=args.url, + model_name=args.model, + messages=history, + max_tokens=args.max_tokens, + top_k=args.top_k, + top_p=args.top_p, + temperature=args.temperature, + stream=args.stream, + ) + if not args.stream: + print(reply) + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + print(f"HTTP error {exc.code}: {body}") + history.pop() + continue + except urllib.error.URLError as exc: + print(f"Connection error: {exc}") + history.pop() + continue + except Exception as exc: # noqa: BLE001 + print(f"Request failed: {exc}") + history.pop() + continue + + history.append({"role": "assistant", "content": reply}) + + +if __name__ == "__main__": + main() diff --git a/test/chat_server.py b/test/chat_server.py new file mode 100644 index 000000000..4d1e0bfa3 --- /dev/null +++ b/test/chat_server.py @@ -0,0 +1,340 @@ +import argparse +import json +import threading +import time +import uuid +import sys +from pathlib import Path +from typing import Any, Dict, Iterable, List + +try: + from fastapi import FastAPI, HTTPException + from fastapi.responses import FileResponse, JSONResponse, StreamingResponse +except ModuleNotFoundError as exc: + raise SystemExit( + "Missing dependencies for chat 
server. Install with:\n" + " pip install fastapi uvicorn" + ) from exc + +from transformers import AutoTokenizer + +# Prefer local python package source under repo root. +REPO_ROOT = Path(__file__).resolve().parents[1] +PYTHON_SRC = REPO_ROOT / "python" +if str(PYTHON_SRC) not in sys.path: + sys.path.insert(0, str(PYTHON_SRC)) + +import llaisys +from test_utils import llaisys_device + +UI_HTML_PATH = Path(__file__).with_name("chat_web.html") + + +def parse_message_content(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: List[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + continue + if isinstance(item, dict) and item.get("type") == "text": + text = item.get("text", "") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + if content is None: + return "" + return str(content) + + +def normalize_messages(raw_messages: Any) -> List[Dict[str, str]]: + if not isinstance(raw_messages, list) or len(raw_messages) == 0: + raise ValueError("`messages` must be a non-empty list") + + out: List[Dict[str, str]] = [] + for item in raw_messages: + if not isinstance(item, dict): + raise ValueError("each message must be an object") + role = item.get("role") + if role not in {"system", "user", "assistant"}: + raise ValueError(f"unsupported role: {role}") + content = parse_message_content(item.get("content")) + out.append({"role": role, "content": content}) + return out + + +class ChatEngine: + def __init__(self, model_path: str, device: str): + self.model_path = model_path + self.device_name = device + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = llaisys.models.Qwen2(model_path, llaisys_device(device)) + self._infer_lock = threading.Lock() + + def _build_inputs(self, messages: List[Dict[str, str]]) -> List[int]: + prompt = self.tokenizer.apply_chat_template( + conversation=messages, + add_generation_prompt=True, + 
tokenize=False, + ) + return self.tokenizer.encode(prompt) + + def generate( + self, + messages: List[Dict[str, str]], + max_new_tokens: int, + top_k: int, + top_p: float, + temperature: float, + ) -> Dict[str, Any]: + with self._infer_lock: + input_ids = self._build_inputs(messages) + out_ids = self.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + completion_ids = out_ids[len(input_ids):] + completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True) + return { + "text": completion_text, + "prompt_tokens": len(input_ids), + "completion_tokens": len(completion_ids), + } + + def stream_generate( + self, + messages: List[Dict[str, str]], + max_new_tokens: int, + top_k: int, + top_p: float, + temperature: float, + ) -> Iterable[Dict[str, Any]]: + with self._infer_lock: + input_ids = self._build_inputs(messages) + if not hasattr(self.model, "generate_stream"): + out_ids = self.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + completion_ids = out_ids[len(input_ids):] + completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True) + if completion_text: + yield { + "delta": completion_text, + "prompt_tokens": len(input_ids), + "completion_tokens": len(completion_ids), + } + yield { + "delta": "", + "prompt_tokens": len(input_ids), + "completion_tokens": len(completion_ids), + "final_text": completion_text, + } + return + + generated_ids: List[int] = [] + previous_text = "" + + for token_id in self.model.generate_stream( + input_ids, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ): + generated_ids.append(int(token_id)) + current_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True) + if current_text.startswith(previous_text): + delta = current_text[len(previous_text):] + else: + # Fallback for rare decode normalization 
mismatch. + delta = self.tokenizer.decode([int(token_id)], skip_special_tokens=True) + previous_text = current_text + if delta: + yield { + "delta": delta, + "prompt_tokens": len(input_ids), + "completion_tokens": len(generated_ids), + } + + yield { + "delta": "", + "prompt_tokens": len(input_ids), + "completion_tokens": len(generated_ids), + "final_text": previous_text, + } + + +def create_app(engine: ChatEngine, served_model_name: str) -> FastAPI: + app = FastAPI(title="LLAISYS Chat Server", version="0.1.0") + + @app.get("/") + def chat_web() -> Any: + if not UI_HTML_PATH.exists(): + raise HTTPException(status_code=404, detail="chat_web.html not found") + return FileResponse( + UI_HTML_PATH, + headers={ + "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0", + "Pragma": "no-cache", + "Expires": "0", + }, + ) + + @app.get("/health") + def health() -> Dict[str, str]: + return {"status": "ok"} + + @app.post("/v1/chat/completions") + def chat_completions(payload: Dict[str, Any]) -> Any: + try: + messages = normalize_messages(payload.get("messages")) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + stream = bool(payload.get("stream", False)) + top_k = int(payload.get("top_k", 1)) + top_p = float(payload.get("top_p", 1.0)) + temperature = float(payload.get("temperature", 1.0)) + max_new_tokens = int(payload.get("max_tokens", payload.get("max_new_tokens", 128))) + max_new_tokens = max(1, max_new_tokens) + + request_model_name = payload.get("model") + model_name = request_model_name if isinstance(request_model_name, str) else served_model_name + + completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + if not stream: + result = engine.generate( + messages=messages, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + response_obj = { + "id": completion_id, + "object": "chat.completion", + "created": created, + "model": model_name, + 
"choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": result["text"]}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "total_tokens": result["prompt_tokens"] + result["completion_tokens"], + }, + } + return JSONResponse(response_obj) + + def stream_iter(): + first_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model_name, + "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}], + } + yield f"data: {json.dumps(first_chunk, ensure_ascii=False)}\n\n" + + final_usage = None + for item in engine.stream_generate( + messages=messages, + max_new_tokens=max_new_tokens, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ): + if "final_text" in item: + final_usage = item + break + delta_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model_name, + "choices": [ + { + "index": 0, + "delta": {"content": item["delta"]}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(delta_chunk, ensure_ascii=False)}\n\n" + + finish_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model_name, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + yield f"data: {json.dumps(finish_chunk, ensure_ascii=False)}\n\n" + + if final_usage is not None: + usage_chunk = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model_name, + "usage": { + "prompt_tokens": final_usage["prompt_tokens"], + "completion_tokens": final_usage["completion_tokens"], + "total_tokens": ( + final_usage["prompt_tokens"] + final_usage["completion_tokens"] + ), + }, + "choices": [], + } + yield f"data: {json.dumps(usage_chunk, ensure_ascii=False)}\n\n" + + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_iter(), 
media_type="text/event-stream") + + return app + + +def parse_args(): + parser = argparse.ArgumentParser(description="LLAISYS OpenAI-style Chat Server") + parser.add_argument("--model", required=True, type=str, help="Path to model directory") + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) + parser.add_argument("--host", default="127.0.0.1", type=str) + parser.add_argument("--port", default=8000, type=int) + parser.add_argument("--served-model-name", default="llaisys-qwen2", type=str) + return parser.parse_args() + + +def main(): + args = parse_args() + engine = ChatEngine(model_path=args.model, device=args.device) + app = create_app(engine, served_model_name=args.served_model_name) + + try: + import uvicorn + except ModuleNotFoundError as exc: + raise SystemExit( + "Missing uvicorn. Install with:\n" + " pip install uvicorn" + ) from exc + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/test/chat_web.html b/test/chat_web.html new file mode 100644 index 000000000..1c4892630 --- /dev/null +++ b/test/chat_web.html @@ -0,0 +1,632 @@ + + + + + + LLAISYS Chat + + + +
+ + +
+
+
+ Conversation + Streaming responses from the local LLAISYS server +
+ Idle +
+
+
+ Messages will appear here. Keep the left panel for generation settings and use the bottom composer for chat. +
+
+
+
+ +
+ + +
+
+
+
+
+ + + + diff --git a/test/ops/add.py b/test/ops/add.py index bb8bf8ca8..2abd75b31 100644 --- a/test/ops/add.py +++ b/test/ops/add.py @@ -42,16 +42,23 @@ def test_op_add( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(2, 3), (512, 4096)] - testDtypePrec = [ - # type, atol, rtol - ("f32", 1e-5, 1e-5), - ("f16", 1e-3, 1e-3), - ("bf16", 1e-3, 1e-3), - ] + if args.device == "metax": + testDtypePrec = [ + ("f32", 1e-5, 1e-5), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-3, 1e-3), + ] + else: + testDtypePrec = [ + # type, atol, rtol + ("f32", 1e-5, 1e-5), + ("f16", 1e-3, 1e-3), + ("bf16", 1e-3, 1e-3), + ] print(f"Testing Ops.add on {args.device}") for shape in testShapes: for dtype_name, atol, rtol in testDtypePrec: diff --git a/test/ops/argmax.py b/test/ops/argmax.py index d0f7ee298..87a5d970d 100644 --- a/test/ops/argmax.py +++ b/test/ops/argmax.py @@ -43,7 +43,7 @@ def test_op_argmax( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(4,), (4096,)] diff --git a/test/ops/embedding.py b/test/ops/embedding.py index 99cadc1b8..17286babf 100644 --- a/test/ops/embedding.py +++ b/test/ops/embedding.py @@ -39,7 +39,7 @@ def test_op_embedding( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = 
parser.parse_args() testShapes = [ diff --git a/test/ops/linear.py b/test/ops/linear.py index 38897331f..4ffff3943 100644 --- a/test/ops/linear.py +++ b/test/ops/linear.py @@ -49,19 +49,38 @@ def test_op_linear( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) + parser.add_argument( + "--dtype", + default="auto", + choices=["auto", "all", "f32", "f16", "bf16"], + type=str, + help="dtype set to test. auto: metax->bf16 only, others->all", + ) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ ((2, 3), (2, 4), (3, 4), True), ((512, 4096), (512, 4096), (4096, 4096), True), + # M=1 decode-like cases + ((1, 4096), (1, 4096), (4096, 4096), True), + ((1, 11008), (1, 4096), (11008, 4096), True), + ((1, 4096), (1, 11008), (4096, 11008), True), ] - testDtypePrec = [ + allDtypePrec = [ # type, atol, rtol ("f32", 1e-5, 1e-5), ("f16", 1e-3, 1e-3), ("bf16", 1e-2, 1e-2), ] + + if args.dtype == "auto": + testDtypePrec = [("bf16", 1e-2, 1e-2)] if args.device == "metax" else allDtypePrec + elif args.dtype == "all": + testDtypePrec = allDtypePrec + else: + testDtypePrec = [x for x in allDtypePrec if x[0] == args.dtype] + print(f"Testing Ops.linear on {args.device}") for shapes in testShapes: for dtype_name, atol, rtol in testDtypePrec: diff --git a/test/ops/rms_norm.py b/test/ops/rms_norm.py index 67b789e3f..b4b62d27b 100644 --- a/test/ops/rms_norm.py +++ b/test/ops/rms_norm.py @@ -48,7 +48,7 @@ def test_op_rms_norm( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(1, 4), (512, 4096)] 
diff --git a/test/ops/rope.py b/test/ops/rope.py index fe59dd11c..bfb620b24 100644 --- a/test/ops/rope.py +++ b/test/ops/rope.py @@ -63,7 +63,7 @@ def test_op_rope( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index a042b51be..f6058d0cd 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale): L, S = query.size(-2), key.size(-2) attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=S-L) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) @@ -65,7 +65,14 @@ def test_op_self_attention( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) + parser.add_argument( + "--dtype", + default="auto", + choices=["auto", "all", "f32", "f16", "bf16"], + type=str, + help="dtype set to test. 
auto: metax->bf16 only, others->all", + ) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [ @@ -73,12 +80,20 @@ def test_op_self_attention( (2, 2, 1, 1, 4), (5, 11, 4, 2, 8), ] - testDtypePrec = [ + allDtypePrec = [ # type, atol, rtol ("f32", 1e-5, 1e-5), ("f16", 1e-3, 1e-3), ("bf16", 1e-2, 1e-2), ] + + if args.dtype == "auto": + testDtypePrec = [("bf16", 1e-2, 1e-2)] if args.device == "metax" else allDtypePrec + elif args.dtype == "all": + testDtypePrec = allDtypePrec + else: + testDtypePrec = [x for x in allDtypePrec if x[0] == args.dtype] + print(f"Testing Ops.self_attention on {args.device}") for shape in testShapes: for dtype_name, atol, rtol in testDtypePrec: diff --git a/test/ops/swiglu.py b/test/ops/swiglu.py index 1fa08f739..1a1880565 100644 --- a/test/ops/swiglu.py +++ b/test/ops/swiglu.py @@ -42,7 +42,7 @@ def test_op_swiglu( import argparse parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() testShapes = [(2, 3), (512, 4096)] diff --git a/test/test_infer.py b/test/test_infer.py index 59d06b874..44b4c797d 100644 --- a/test/test_infer.py +++ b/test/test_infer.py @@ -81,7 +81,7 @@ def llaisys_infer( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) parser.add_argument("--model", default=None, type=str) parser.add_argument("--prompt", default="Who are you?", type=str) parser.add_argument("--max_steps", default=128, type=int) @@ -113,6 +113,10 @@ def llaisys_infer( del model gc.collect() + if args.device == "nvidia": + # Release PyTorch caching allocator blocks before running LLAISYS 
in the same process. + torch.cuda.empty_cache() + torch.cuda.synchronize() print("\n=== Answer ===\n") print("Tokens:") diff --git a/test/test_runtime.py b/test/test_runtime.py index e2ac218a1..c509af3a8 100644 --- a/test/test_runtime.py +++ b/test/test_runtime.py @@ -15,7 +15,7 @@ def test_basic_runtime_api(device_name: str = "cpu"): return for i in range(ndev): - print("Testing device {i}...") + print(f"Testing device {i}...") api.set_device(i) test_memcpy(api, 1024 * 1024) @@ -55,7 +55,7 @@ def test_memcpy(api, size_bytes: int): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) args = parser.parse_args() test_basic_runtime_api(args.device) diff --git a/test/test_tensor.py b/test/test_tensor.py index 9d2e9a075..5bf7fc56b 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,19 +1,20 @@ -import llaisys +import argparse +import llaisys import torch from test_utils import * -import argparse -def test_tensor(): - torch_tensor = torch.arange(60, dtype=torch_dtype("i64")).reshape(3, 4, 5) +def test_tensor(device_name: str = "cpu"): + torch_tensor_host = torch.arange(60, dtype=torch_dtype("i64")).reshape(3, 4, 5) + torch_tensor = torch_tensor_host.to(torch_baseline_device(device_name)) llaisys_tensor = llaisys.Tensor( - (3, 4, 5), dtype=llaisys_dtype("i64"), device=llaisys_device("cpu") + (3, 4, 5), dtype=llaisys_dtype("i64"), device=llaisys_device(device_name) ) # Test load print("===Test load===") - llaisys_tensor.load(torch_tensor.data_ptr()) + llaisys_tensor.load(torch_tensor_host.data_ptr()) llaisys_tensor.debug() assert llaisys_tensor.is_contiguous() == torch_tensor.is_contiguous() assert check_equal(llaisys_tensor, torch_tensor) @@ -50,6 +51,10 @@ def test_tensor(): if __name__ == "__main__": - test_tensor() + parser = argparse.ArgumentParser() + 
parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia", "metax"], type=str) + args = parser.parse_args() + + test_tensor(args.device) print("\n\033[92mTest passed!\033[0m\n") diff --git a/test/test_utils.py b/test/test_utils.py index 0f38f0c8e..4966c2271 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2,13 +2,55 @@ import torch +def torch_baseline_device(device_name: str, device_id=0): + if device_name in {"nvidia", "metax"}: + return torch_device(device_name, device_id) + return torch.device("cpu") + + +def torch_to_llaisys_memcpy_kind(torch_tensor: torch.Tensor, dst_device_name: str): + src_is_cpu = torch_tensor.device.type == "cpu" + dst_is_cpu = dst_device_name == "cpu" + if src_is_cpu and dst_is_cpu: + return llaisys.MemcpyKind.D2D + if src_is_cpu and not dst_is_cpu: + return llaisys.MemcpyKind.H2D + if (not src_is_cpu) and dst_is_cpu: + return llaisys.MemcpyKind.D2H + return llaisys.MemcpyKind.D2D + + +def llaisys_to_torch_memcpy_kind(src_device_type: llaisys.DeviceType, torch_tensor: torch.Tensor): + src_is_cpu = src_device_type == llaisys.DeviceType.CPU + dst_is_cpu = torch_tensor.device.type == "cpu" + if src_is_cpu and dst_is_cpu: + return llaisys.MemcpyKind.D2D + if src_is_cpu and not dst_is_cpu: + return llaisys.MemcpyKind.H2D + if (not src_is_cpu) and dst_is_cpu: + return llaisys.MemcpyKind.D2H + return llaisys.MemcpyKind.D2D + + +def host_to_llaisys_memcpy_kind(device_name: str): + if device_name == "cpu": + return llaisys.MemcpyKind.D2D + return llaisys.MemcpyKind.H2D + + +def llaisys_to_host_memcpy_kind(device_type: llaisys.DeviceType): + if device_type == llaisys.DeviceType.CPU: + return llaisys.MemcpyKind.D2D + return llaisys.MemcpyKind.D2H + + def random_tensor( shape, dtype_name, device_name, device_id=0, scale=None, bias=None ) -> tuple[torch.Tensor, llaisys.Tensor]: torch_tensor = torch.rand( shape, dtype=torch_dtype(dtype_name), - device=torch_device(device_name, device_id), + 
device=torch_baseline_device(device_name, device_id), ) if scale is not None: torch_tensor *= scale @@ -28,7 +70,7 @@ def random_tensor( llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + torch_to_llaisys_memcpy_kind(torch_tensor, device_name), ) return torch_tensor, llaisys_tensor @@ -40,7 +82,7 @@ def random_int_tensor(shape, device_name, dtype_name="i64", device_id=0, low=0, high, shape, dtype=torch_dtype(dtype_name), - device=torch_device(device_name, device_id), + device=torch_baseline_device(device_name, device_id), ) llaisys_tensor = llaisys.Tensor( @@ -56,7 +98,7 @@ def random_int_tensor(shape, device_name, dtype_name="i64", device_id=0, low=0, llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + torch_to_llaisys_memcpy_kind(torch_tensor, device_name), ) return torch_tensor, llaisys_tensor @@ -68,7 +110,7 @@ def zero_tensor( torch_tensor = torch.zeros( shape, dtype=torch_dtype(dtype_name), - device=torch_device(device_name, device_id), + device=torch_baseline_device(device_name, device_id), ) llaisys_tensor = llaisys.Tensor( @@ -84,7 +126,7 @@ def zero_tensor( llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + torch_to_llaisys_memcpy_kind(torch_tensor, device_name), ) return torch_tensor, llaisys_tensor @@ -93,7 +135,7 @@ def zero_tensor( def arrange_tensor( start, end, device_name, device_id=0 ) -> tuple[torch.Tensor, llaisys.Tensor]: - torch_tensor = torch.arange(start, end, device=torch_device(device_name, device_id)) + torch_tensor = torch.arange(start, end, device=torch_baseline_device(device_name, device_id)) llaisys_tensor = llaisys.Tensor( (end - start,), dtype=llaisys_dtype("i64"), @@ -107,7 +149,7 @@ def arrange_tensor( llaisys_tensor.data_ptr(), torch_tensor.data_ptr(), bytes_, - llaisys.MemcpyKind.D2D, + torch_to_llaisys_memcpy_kind(torch_tensor, device_name), ) return torch_tensor, llaisys_tensor @@ -135,9 +177,7 @@ def check_equal( tmp 
= torch.zeros( (right + 1,), dtype=torch_answer.dtype, - device=torch_device( - device_name(llaisys_result.device_type()), llaisys_result.device_id() - ), + device=torch_baseline_device(device_name(llaisys_result.device_type()), llaisys_result.device_id()), ) result = torch.as_strided(tmp, shape, strides) api = llaisys.RuntimeAPI(llaisys_result.device_type()) @@ -145,7 +185,7 @@ def check_equal( result.data_ptr(), llaisys_result.data_ptr(), (right + 1) * tmp.element_size(), - llaisys.MemcpyKind.D2D, + llaisys_to_torch_memcpy_kind(llaisys_result.device_type(), result), ) if strict: @@ -188,6 +228,9 @@ def torch_device(device_name: str, device_id=0): return torch.device("cpu") elif device_name == "nvidia": return torch.device(f"cuda:{device_id}") + elif device_name == "metax": + # mcPyTorch uses CUDA-compatible API; tensors are typically exposed as cuda devices. + return torch.device(f"cuda:{device_id}") else: raise ValueError(f"Unsupported device name: {device_name}") @@ -197,6 +240,8 @@ def llaisys_device(device_name: str): return llaisys.DeviceType.CPU elif device_name == "nvidia": return llaisys.DeviceType.NVIDIA + elif device_name == "metax": + return llaisys.DeviceType.METAX else: raise ValueError(f"Unsupported device name: {device_name}") @@ -206,6 +251,8 @@ def device_name(llaisys_device: llaisys.DeviceType): return "cpu" elif llaisys_device == llaisys.DeviceType.NVIDIA: return "nvidia" + elif llaisys_device == llaisys.DeviceType.METAX: + return "metax" else: raise ValueError(f"Unsupported llaisys device: {llaisys_device}") diff --git a/xmake.lua b/xmake.lua index 1f65f7a95..ced85c14c 100644 --- a/xmake.lua +++ b/xmake.lua @@ -7,17 +7,34 @@ add_includedirs("include") includes("xmake/cpu.lua") -- NVIDIA -- +option("openblas") + set_default(false) + set_showmenu(true) + set_description("Use OpenBLAS for linear (matmul) on CPU; install libopenblas-dev and run xmake f --openblas=y") +option_end() + option("nv-gpu") set_default(false) set_showmenu(true) 
set_description("Whether to compile implementations for Nvidia GPU") option_end() +option("mx-gpu") + set_default(false) + set_showmenu(true) + set_description("Whether to compile implementations for MetaX GPU") +option_end() + if has_config("nv-gpu") then add_defines("ENABLE_NVIDIA_API") includes("xmake/nvidia.lua") end +if has_config("mx-gpu") then + add_defines("ENABLE_METAX_API") + includes("xmake/metax.lua") +end + target("llaisys-utils") set_kind("static") @@ -37,6 +54,12 @@ target("llaisys-device") set_kind("static") add_deps("llaisys-utils") add_deps("llaisys-device-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-device-nvidia") + end + if has_config("mx-gpu") then + add_deps("llaisys-device-metax") + end set_languages("cxx17") set_warnings("all", "error") @@ -83,6 +106,14 @@ target_end() target("llaisys-ops") set_kind("static") add_deps("llaisys-ops-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-ops-nvidia") + end + if has_config("mx-gpu") then + add_deps("llaisys-ops-metax") + -- Propagate metax operator archive to final link step in dependency order. 
+ add_links("llaisys-ops-metax") + end set_languages("cxx17") set_warnings("all", "error") @@ -95,6 +126,34 @@ target("llaisys-ops") on_install(function (target) end) target_end() +if has_config("nv-gpu") then + target("llaisys-ops-nvidia") + set_kind("static") + add_deps("llaisys-tensor") + + set_languages("cxx17") + set_warnings("all", "error") + add_files("src/ops/*/nvidia/*.cu") + add_includedirs("include", "src") + + -- CUDA arch targets (keep simple; adjust later for perf/compat) + add_cugencodes("native") + add_cugencodes("compute_75") + + -- Ensure static lib does CUDA devlink once (because final .so has no .cu) + add_values("cuda.build.devlink", true) + + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + -- nvcc compile + devlink must be PIC + add_cuflags("-Xcompiler -fPIC", "-Xcompiler -Wno-unknown-pragmas") + add_culdflags("-Xcompiler -fPIC", "-Xcompiler -Wno-unknown-pragmas") + end + + on_install(function (target) end) + target_end() +end + target("llaisys") set_kind("shared") add_deps("llaisys-utils") @@ -105,7 +164,16 @@ target("llaisys") set_languages("cxx17") set_warnings("all", "error") + if not is_plat("windows") then + add_ldflags("-fopenmp") + add_syslinks("gomp") + end + if has_config("nv-gpu") then + add_syslinks("cudart") + add_syslinks("cublas") + end add_files("src/llaisys/*.cc") + add_files("src/models/qwen2/*.cpp") set_installdir(".") @@ -119,4 +187,4 @@ target("llaisys") os.cp("lib/*.so", "python/llaisys/libllaisys/") end end) -target_end() \ No newline at end of file +target_end() diff --git a/xmake/cpu.lua b/xmake/cpu.lua index 101d894e6..ccd8eb52a 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -18,6 +18,16 @@ target("llaisys-ops-cpu") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxflags("-fopenmp") + elseif is_plat("windows") then + add_cxflags("/openmp") + end + if has_config("openblas") then + add_defines("LLAISYS_USE_OPENBLAS") + 
add_links("openblas") + add_syslinks("openblas") + -- 常见 cblas 头路径(按需取消注释或添加本机路径) + add_includedirs("/usr/include/x86_64-linux-gnu", "/usr/include", {public = false}) end add_files("../src/ops/*/cpu/*.cpp") diff --git a/xmake/metax.lua b/xmake/metax.lua new file mode 100644 index 000000000..22aeecdaf --- /dev/null +++ b/xmake/metax.lua @@ -0,0 +1,200 @@ +-- MetaX GPU backend integration. +-- Usage: xmake f --mx-gpu=y + +local function _append_unique(list, value) + if not value or value == "" then + return + end + for _, item in ipairs(list) do + if item == value then + return + end + end + table.insert(list, value) +end + +local function _metax_roots() + local roots = {} + _append_unique(roots, os.getenv("MACA_HOME")) + _append_unique(roots, "/opt/maca") + _append_unique(roots, "/usr/local/maca") + _append_unique(roots, "/opt/maca-3.3.0") + _append_unique(roots, "/opt/maca-3.2.0") + _append_unique(roots, "/opt/maca-3.1.0") + return roots +end + +local function _metax_include_dirs() + local dirs = {} + for _, root in ipairs(_metax_roots()) do + local d1 = path.join(root, "include") + local d2 = path.join(root, "include", "mcr") + local d3 = path.join(root, "mxgpu_llvm", "include") + if os.isdir(d1) then _append_unique(dirs, d1) end + if os.isdir(d2) then _append_unique(dirs, d2) end + if os.isdir(d3) then _append_unique(dirs, d3) end + end + return dirs +end + +local function _metax_link_dirs() + local dirs = {} + for _, root in ipairs(_metax_roots()) do + local d1 = path.join(root, "lib") + local d2 = path.join(root, "lib64") + local d3 = path.join(root, "mxgpu_llvm", "lib") + local d4 = path.join(root, "mxgpu_llvm", "lib64") + if os.isdir(d1) then _append_unique(dirs, d1) end + if os.isdir(d2) then _append_unique(dirs, d2) end + if os.isdir(d3) then _append_unique(dirs, d3) end + if os.isdir(d4) then _append_unique(dirs, d4) end + end + return dirs +end + +local function _apply_metax_search_paths(target) + for _, includedir in ipairs(_metax_include_dirs()) do + 
target:add("includedirs", includedir, {public = true}) + end + for _, linkdir in ipairs(_metax_link_dirs()) do + target:add("linkdirs", linkdir, {public = true}) + end +end + +local function _resolve_mxcc() + local mxcc = os.getenv("MXCC") + if mxcc and mxcc ~= "" then + return mxcc + end + local maca_home = os.getenv("MACA_HOME") + if maca_home and maca_home ~= "" then + local candidate = path.join(maca_home, "mxgpu_llvm", "bin", "mxcc") + if os.isfile(candidate) then + return candidate + end + end + return "mxcc" +end + +target("llaisys-device-metax") + set_kind("static") + add_deps("llaisys-utils") + set_languages("cxx17") + set_warnings("all", "error") + + -- Keep .maca as canonical source files, but compile wrappers for xmake 2.8.x compatibility. + on_load(function (target) + local projectdir = os.projectdir() + local gen_dir = path.join(projectdir, "build", "_gen", "metax") + os.mkdir(gen_dir) + + local maca_sources = { + path.join(projectdir, "src", "device", "metax", "metax_resource.maca"), + path.join(projectdir, "src", "device", "metax", "metax_runtime_api.maca") + } + + for _, source in ipairs(maca_sources) do + local base = path.basename(source) + local wrap = path.join(gen_dir, base .. "_wrapper.cpp") + io.writefile(wrap, "#include \"" .. path.translate(source) .. "\"\n") + target:add("files", wrap) + end + + _apply_metax_search_paths(target) + end) + + add_includedirs("../include", "../src") + + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + -- Link common runtime library names shipped by MACA. 
+ add_syslinks("mcruntime", "mxc-runtime64", "runtime_cu", "mcblas", {public = true}) + + on_install(function (target) end) +target_end() + +target("llaisys-ops-metax") + set_kind("static") + add_deps("llaisys-tensor") + set_languages("cxx17") + set_warnings("all", "error") + + on_load(function (target) + local projectdir = os.projectdir() + local obj_dir = path.join(projectdir, "build", "_gen", "metax_ops_obj") + os.mkdir(obj_dir) + + local maca_sources = os.files(path.join(projectdir, "src", "ops", "*", "metax", "*.maca")) + local objectfiles = {} + for _, source in ipairs(maca_sources) do + local op_name = path.basename(path.directory(path.directory(source))) + local base = path.basename(source) + local objectfile = path.join(obj_dir, op_name .. "_" .. base .. ".o") + table.insert(objectfiles, objectfile) + end + + target:data_set("metax_maca_sources", maca_sources) + target:data_set("metax_maca_objectfiles", objectfiles) + _apply_metax_search_paths(target) + end) + + -- Build .maca sources via mxcc manually to avoid xmake 2.8.x toolscript limitations. + on_build(function (target) + local projectdir = os.projectdir() + local mxcc = _resolve_mxcc() + local include_dirs = { + path.join(projectdir, "include"), + path.join(projectdir, "src") + } + for _, includedir in ipairs(_metax_include_dirs()) do + table.insert(include_dirs, includedir) + end + + local sources = target:data("metax_maca_sources") or {} + local objectfiles = target:data("metax_maca_objectfiles") or {} + for i, source in ipairs(sources) do + local objectfile = objectfiles[i] + os.mkdir(path.directory(objectfile)) + + local args = { + "-std=c++17", + "-O3", + "-fPIC", + "-Wno-unknown-pragmas", + "-DENABLE_METAX_API" + } + for _, includedir in ipairs(include_dirs) do + table.insert(args, "-I" .. 
includedir) + end + table.insert(args, "-c") + table.insert(args, source) + table.insert(args, "-o") + table.insert(args, objectfile) + + os.vrunv(mxcc, args) + end + + local ar = target:tool("ar") or "ar" + local targetfile = target:targetfile() + os.mkdir(path.directory(targetfile)) + + local ar_args = {"-cr", targetfile} + for _, objectfile in ipairs(objectfiles) do + table.insert(ar_args, objectfile) + end + os.vrunv(ar, ar_args) + end) + + add_includedirs("../include", "../src") + + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + -- Link common runtime library names shipped by MACA. + add_syslinks("mcruntime", "mxc-runtime64", "runtime_cu", "mcblas", {public = true}) + + on_install(function (target) end) +target_end() diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua new file mode 100644 index 000000000..3b89f8807 --- /dev/null +++ b/xmake/nvidia.lua @@ -0,0 +1,19 @@ +-- NVIDIA GPU 设备:CUDA Runtime API + 资源 +-- 使用方式: xmake f --nv-gpu=y [--cuda=/path/to/cuda] +target("llaisys-device-nvidia") + set_kind("static") + add_deps("llaisys-utils") + set_languages("cxx17") + add_files("../src/device/nvidia/*.cu") + add_cugencodes("native") + add_cugencodes("compute_75") + add_values("cuda.build.devlink", true) + add_includedirs("../include", "../src") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + -- nvcc: pass -fPIC to host compiler and to devlink step (for _gpucode.cu.o) + add_cuflags("-Xcompiler -fPIC", "-Xcompiler -Wno-unknown-pragmas") + add_culdflags("-Xcompiler -fPIC", "-Xcompiler -Wno-unknown-pragmas") + end + on_install(function (target) end) +target_end()